Torchvision
机器学习与深度学习,数据集组织、加载处理、按需按批装载、送入模型训练,不论是图片、文字还是音视频,流程基本上一致。
具体图片处理的大部分实现transform包上,实际使用时需要加入业务场景才能丰满起来。
当我们静下心来,花时间去接触AI相关的知识与工具,我们会深刻的感觉到技术真的只是一个工具,是场景将它丰富了起来。—— 笔者个人观点
简单介绍
Torchvision 是 PyTorch 的一个独立子库,它服务于PyTorch
深度学习框架的,主要用于计算机视觉任务,包括图像处理、数据加载、数据增强、预训练模型等。
核心包如下:
- torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
- torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
- torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
- torchvision.utils: 其他的一些有用的方法。
官网文档入口: pytorch.org/docs/stable…
读取数据集
可以从网上得到数据集,再用torchvision加载数据并处理,也可以从自建的数据集上加载并。
本文用torchvision数据集来演示读取的过程,内部会使用transform对数据进行变型。
数据集准备 – CIFAR10
此次代码中要用到的数据集,见附件有介绍与中文的参数。
用代码下载数据集-CIFAR10
通过py代码
# 使用CIFAR10数据集
# 训练集
# 如果下载比较慢,可以将控制台打印的下载链接放到专门的下载工具中下载
# 首先下载的是一个压缩包,会自动解压
train_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=True, download=True)
# 测试集
test_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=False, download=True)
运行代码,控制台显示如下信息
50000 — 说明有5w张训练数据
10000 –说明有1w张测试数据会自动下载数据集到
torchvision_dataset
文件夹
已下载就不会继续下载,控制台会出输Files already downloaded and verified
字样
操作数据
torchvision.datasets和transform的联合使用
- 下载数据集
- 装载图片
- 图片处理
- 图片展示
import torchvision.datasets
from torch.utils.tensorboard import SummaryWriter
# 将图片数据都转为tensor类型
# 可以对数据集做任何transforms范围内的操作,该例子只针对数据做toTensor
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor
])
# 使用CIFAR10数据集
train_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=True, transform=dataset_transform, download=True)
# 测试集
test_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=False, transform=dataset_transform, download=True)
# 用tensorboard显示前10张图片
# 运行tensorboard --logdir=p10
writer = SummaryWriter('p10')
for i in range(0):
img, target = test_set[i]
writer.add_image("test_set", img,i)
writer.close()
Torchvision.transform
Transforms是torchvision
模块下面的一个子模块,在Dataset中很常用可以方便地对图像进行各种变换操作。该模块中包含大量用户数据类型转化的类型和方法,比如统一size,每一个图像数据进行类的转化等。
“””
transforms.ToTensor
转化:PIL Image或numpy.ndarray(H * W * C) 转到 tensor 的数据类型
主要方法
call(self, pic)
参数:pic – Image或numpy的图像对象
返回值 : 返回tensor类型的图片
“””
Tensor图像结构:
Tensor是PyTorch中最基本的数据结构,你可以将其视为多维数组或者矩阵。PyTorch tensor和NumPy array非常相似,但是tensor可以在GPU上运算,而NumPy array则只能在CPU上运算。
可对图像直接操作,代码如下:
导入
import torch
创建一个未初始化的5×3矩阵
# 创建一个未初始化的5x3矩阵
x = torch.empty(5, 3)
print(x)
用.backward() 计算梯度
# 因为out包含一个标量,out.backward()等价于out.backward(torch.tensor(1.))
out.backward()
# 打印梯度 d(out)/dx
print(x.grad)
常用方法
ToTensor
- 作用: PIL Image或numpy.ndarray(H * W * C) 转到 tensor 的数据类型
- 输入: PIL
Image.open()
- 输出: 类型
ToTensor
Normalize
- 作用: 根据均值与标准差归一化tensor类图片
- 输入: tensor类型图片的均值与标准差
- 输出: 归一化后的图片数据
- 计算公式: (Input[channel] – mean[channel]) / std[channel]
举例
Input[channel] - mean[channel]) / std[channel= (input - 0.5)/0.5= 2 * input - 1结论 input像素值[0-1] --> result[-1,1]
Resize
- 作用: 将(PIL Image or Tensor)调整为给定的大小。
- 输入:
- size (sequence or int):如果size是(h, w)这样的序列,则输出size将与此匹配。如果size为int,图像的较小边缘将匹配此数字。即,如果高度>宽度,那么图像将被重新缩放为(size*高度/宽度,size)
- …
- 输出: 变型后的PIL Image or Tensor
Compose
- 作用:把几个tranforms组合在一起使用,相当于一个组合器,可以对输入图片一次进行多个transforms的操作。比如 compose负责把ToTensor和resize组合起来,一步到位实现PIL图形到resize后的tensor图形的转换
RandomCrop
- 作用:随机裁剪
- 输入:
- size: size可以是tuple也可以是Integer
- …
- 输出:裁后图片
代码
# 导入
"""
Torchvision 是 PyTorch 的一个独立子库,主要用于计算机视觉任务,包括图像处理、数据加载、数据增强、预训练模型等。
Torchvision 提供了各种经典的计算机视觉数据集的加载器,如CIFAR-10、ImageNet,以及用于数据预处理和数据增强的工具,可以帮助用户更轻松地进行图像分类、目标检测、图像分割等任务。
"""
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image
"""
用ToTensor将PIL图片转为Tensor图片
"""
# 绝对路径 D:workspacepythonlearn_torchdatatrainants013035.jpg
img_path = "data/train/bees/16838648_415acd9e3f.jpg"
img = Image.open(img_path)
trans_toTensor = transforms.ToTensor()
img_tensor = trans_toTensor(img)
writer = SummaryWriter("logs")
writer.add_image("tensor_img", img_tensor)
"""
2. 用Normalize实现Tensor图片归一化
"""
trans_norm_0 = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5, ])
img_norm_0 = trans_norm_0(img_tensor)
writer.add_image("Normalize", img_norm_0, 1)
trans_norm = transforms.Normalize([6, 3, 2], [9, 3, 5])
img_norm = trans_norm(img_tensor)
writer.add_image("Normalize", img_norm, 2)
# 运行,测蔗
# tensorboard --logdir=logs
"""
2. Resize-等比例缩放
"""
trans_resize = transforms.Resize((512, 512))
# img_PIl --> img_resize PIL
img_resize = trans_resize(img)
# image PIl ---> toTensor --> 转为 tensor
img_resize = trans_toTensor(img_resize)
# print(img_resize)
writer.add_image("Resize", img_resize, 1)
"""
transforms.Compose
# trans_toTensor: 输入
# trans_resize_2: 输出
"""
trans_resize_2 = transforms.Resize(100)
trans_compose = transforms.Compose([trans_resize_2, trans_toTensor])
img_resize_2 = trans_compose(img)
writer.add_image("Resize", img_resize_2, 0)
"""
transforms.RandomCrop:随机裁剪
"""
trans_random = transforms.RandomCrop((150, 500))
trans_compose_2 = transforms.Compose([trans_random, trans_toTensor])
for i in range(10):
img_crop = trans_compose_2(img)
writer.add_image("RandomCrop", img_crop, i)
writer.close()
启动tensorboard查看结果
相关知识
Torchvision
Torchvision.dataset
文档入口
数据集
CIFAR10
CIFAR10由10个不同标签的图像组成。其中包括卡车、青蛙、船、汽车、鹿等常见图像。还有一个CIFAR100版本,有 100 个不同的类别
CIFAR10/CIFAR100一般用于物价识别,其广泛用于机器学习领域的计算机视觉算法基准测试。详情 官网地址
包名 torchvision.datasets.FashionMNIST()
包名 torchvision.datasets.CIFAR10()
参数说明:
- root: 数据集根路径,可以是相对路径
- train: = ture 训练集,否则为测试集
- transform: 对数据集进行的transform操作
- target_transform: 训练后的目标数据集执行指定的transform操作
- download:=true 自动下载数据集,false不会下载
COOC
目前有超过 100,000 个日常物品,如人、瓶子、文具、书籍等。 广泛用于目标检测,语义分割和图像描述
MNIST
MNIST 常用的入门级数据集,手写文字数据集
包名 torchvision.datasets.MNIST()
Fashion MNIST
该数据集与 MNIST 类似,但该数据集不是手写数字,而是 T 恤、裤子、包等服装项目。
包名 torchvision.datasets.FashionMNIST()
torchvision.models
提供神经网络常见的神经网络,有一些神经网络已经预训练好了。
torchvision.transform
图像处理与变形等
transforms.CenterCrop 对图片中心进行裁剪
transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
transforms.Grayscale 对图像进行灰度变换
transforms.Pad 使用固定值进行像素填充
transforms.RandomAffine 随机仿射变换
transforms.RandomCrop 随机区域裁剪
transforms.RandomHorizontalFlip 随机水平翻转
transforms.RandomRotation 随机旋转
transforms.RandomVerticalFlip 随机垂直翻转
文档入口
torchvision.utils
提供一些常用的工具,比如tensorboard
等
from torch.utils.tensorboard import SummaryWriter
文档入口
记录于:2012/11/10 _山海
[参考]