本文 已参与「新人创作礼」活动,一起开启掘金创作之路。
想必大家做深度学习分类的时候经常会遇到一个疑问,分类的模型最后需要怎么用,今天就来简单的介绍下这个问题
1、先预览
其实也就是分析每一帧照片,来判断最大可能属于哪个类别,跟目标检测很像,目标检测判断一张照片里有哪几种物品,而分类仅仅是判断这个照片属于哪个类别
2、官网地址
我们可以从官网的这里找到pytroch部署到移动端的例子,当然,现在这个还不是很成熟,感觉自己训练的数据集和网络,可能会多少有点问题,建议使用官方提供的网络模型来训练
3、流程
4、加载库
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
5、代码分析
1、加载图片
bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
2、加载模型
module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
3、准备Tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
4、运行model
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
5、得到分数集合
final float[] scores = outputTensor.getDataAsFloatArray();
6、获取物体标签
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
7、注意事项
这里的Demo加载的是需要pt模型数据,而我们Pytroch模型得到的数据都是pth的,如果直接使用,肯定是不行的,因此需要先转换后再使用,转换的方法如下所示
import torch
import torch.utils.data.distributed
# pytorch环境中
from models.base_model import BaseModel
model_pth = 'ghostnet.pth' #模型的参数文件
mobile_pt ='model.pt' # 将模型保存为Android可以调用的文件
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = BaseModel(name='resnet', num_classes=5).to(device)
model.load_state_dict(torch.load('ghostnet.pth', map_location=device), strict=False)
model.eval() # 模型设为评估模式
# 1张3通道224*224的图片
input_tensor = torch.rand(1, 3, 224, 224) # 设定输入数据格式
mobile = torch.jit.trace(model, input_tensor) # 模型转化
mobile.save(mobile_pt) # 保存文件
以上的例子,在官网中均有介绍,建议读者去官网下载好代码例子后跑通,在使用自己的数据集,目前感觉这个方法还是不太成熟,如果读者想做目标检测,比如YOLO的相关算法部署,小编在这里建议读者使用腾讯的NCNN这个推理框架,还是很好用的,我屡试不爽。