使用Pytorch提供的标准网络结构训练Cifar10模型

释放双眼,带上耳机,听听看~!
本文介绍了如何使用Pytorch提供的标准网络结构ResNet18来训练Cifar10模型,通过调用已定义好的模型可以节省自定义模型结构的工作量。

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第28天,点击查看活动详情

前言

在之前的文章中,我们介绍了如何去自定义去完成关于ResNet这样的网络结构,VGGNet这样的网络结构,MobileNet这样的网络结构,以及Inception这样不同的四大类结构。实际上,在Pytorch中提供了非常多的已经定义好的模型,这些模型也是目前来说比较标准的网络结构,我们经常会利用这些标准的网络结构去作为我们的预训练的模型,这样就可以节省很多的工作,就不需要自己去自定义模型结构。

今天,我们通过调用Pytorch提供的标准网络ResNet18来完成Cifar10模型的训练。

  • 1.1 调用Pytorch提供的标准网络

相比于之前自定义的网络结构,使用Pytorch提供的标准网络的代码量是比较少的,如果不需要对网络结构进行自己定义或者进行模型压缩裁剪等操作的时候,推荐大家使用Pytorch提供的标准网络结构

import torch.nn as nn
from torchvision import models

class resnet18(nn.Module):
    def __init__(self):
        super(resnet18, self).__init__()
        self.model = models.resnet18(pretrained=True)
        # 这里主要用来解决cifar10,需要修改类别数
        self.num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(self.num_features, 10)

    def forward(self, x):
        out = self.model(x)

        return out


def pytorch_resnet18():
    return resnet18()

注:cifar10数据训练的代码参考我之前的文章Pytorch——Cifar10图像分类中的训练模型的代码,只需要修改一下net即可。

在进行cifar10的数据训练,可以看到在第一个epoch之后,准确率到了26%,并且整个网络是处于收敛过程中的,如果需要使用其它的网络结构的时候,也可以利用这个模板来调用其他的模型。

使用Pytorch提供的标准网络结构训练Cifar10模型

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

PyTorch落地Linux基金会,AI社区讨论十年后的技术发展

2023-12-21 18:54:14

AI教程

清华大学开源多模态对话模型VisualGLM-6B能解读表情包

2023-12-21 19:05:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索