基于Densenet&Xception融合的102种鲜花识别
开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 15 天,PaddleAPI
# 经过step-1处理后的的预训练模型
pretrained_model_path = 'models/step-8_model/'
# 加载经过处理的模型
fluid.io.load_params(executor=exe, dirname=pretrained_model_path)
# 定义输入数据维度
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
# 迭代一次,测试程序是否跑通。
for pass_id in range(1):
# 进行训练
for batch_id, data in enumerate(train_reader()):
train_cost, train_acc = exe.run(program=fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost, acc])
# 每100个batch打印一次信息
if batch_id % 10 == 0:
print('Pass:%d, Batch:%d, Cost:%0.5f, Accuracy:%0.5f' %
(pass_id, batch_id, train_cost[0], train_acc[0]))
# 保存参数模型
save_pretrain_model_path = 'models/step-6_model/'
# 删除旧的模型文件
shutil.rmtree(save_pretrain_model_path, ignore_errors=True)
# 创建保持模型文件目录
os.makedirs(save_pretrain_model_path)
# 保存推断模型
fluid.io.save_inference_model(dirname=save_pretrain_model_path, feeded_var_names=['image'],
target_vars=[model], executor=exe)
六、预测
# 加载推断模型
use_gpu = True
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
save_freeze_dir = 'models/step-6_model/'
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(dirname=save_freeze_dir, executor=exe)
# 读取测试数据
from PIL import Image
import numpy as np
def reader(img_path):
img = Image.open(img_path)
if img.mode != 'RGB':
img = img.convert('RGB')
img = img.resize((640, 640), Image.ANTIALIAS)
img = np.array(img).astype('float32')
img -= [127.5, 127.5, 127.5]
img = img.transpose((2, 0, 1)) # HWC to CHW
img *= 0.007843
img = img[np.newaxis,:]
return img
# 单例模式,预测数据
# 此处直接生成比赛提交用的CSV文件,大家可以去平台上提交,测试自己的得分哦。
img_list = os.listdir('data/data30606/54_data/')
img_list.sort()
img_list.sort(key=lambda x: int(x[:-4])) ##文件名按数字排序
img_nums = len(img_list)
# print(img_list)
test_path = 'data/data30606/54_data/test/'
# img_path = test_path + img_list[i]
labels = []
for i in range(img_nums):
img_path = test_path + img_list[i]
tensor_img = reader(img_path)
label = exe.run(inference_program, feed={feed_target_names[0]: tensor_img}, fetch_list=fetch_targets)
lab = np.argmax(label)
# print(lab)
labels.append(lab)
submit = pd.DataFrame()
submit[1] = labels
submit.to_csv('submit123.csv', header=False)
总结
总体来说,这种融合方法不是很优雅,相对于计算量的提升所带来的精度提升收益不是很大,比赛中有人这么干,但是有AIstudio,
显卡算力足够,大家可以尽情的堆,精度越高,比赛排名越高。
下面是我的得分,大家可以调整迭代次数、学习率等超参,或者增加全连接层,添加DropOut,来调整网络,大家加油哦。