Pytorch断点续训原理及DFGAN2版本实操

释放双眼,带上耳机,听听看~!
了解Pytorch断点续训的原理和实际操作,以及DFGAN2版本的断点续训实操。学习如何保存和加载模型参数,以确保训练中断后能够继续训练模型。

我们在训练模型的时候经常会出现各种问题导致训练中断,比方说断电、系统中断、内存溢出、断连、硬件故障、地震火灾等之类的导致电脑系统关闭,从而将模型训练中断。

所以在实际运行当中,我们经常需要每100轮epoch或者每50轮epoch要保存训练好的参数,以防不测,这样下次可以直接加载该轮epoch的参数接着训练,就不用重头开始。下面我们来介绍Pytorch断点续训原理及DFGAN20版本和22版本断点续训实操

一、Pytorch断点续训

1.1、保存模型

pytorch保存模型等相关参数,需要利用torch.save(),torch.save()是PyTorch框架中用于保存Python对象到磁盘上的函数,一般为

torch.save(checkpoint, checkpoint_path)

其中checkpoint为保存模型的所有参数和缓存的键值对,checkpoint_path表示最终保存的模型,通常以.pth格式保存。

torch.save()函数会将obj序列化为字节流,并将字节流写入f指定的文件中。在读取数据时,可以使用torch.load()函数来将文件中的字节流反序列化成Python对象。使用这两个函数可以轻松地将PyTorch模型保存到磁盘上,并在需要的时候重新加载使用。

一般在实际操作中,我们写为:

torch.save(netG.state_dict(),'%s/netG_epoch_%d.pth' % (self.model_dir, epoch))

它接受两个参数:要保存的对象(即状态字典)和文件路径。在这里,状态字典是通过调用netG.state_dict()方法获得的,而文件路径是使用字符串格式化操作构建的。字符串'%s/netG_epoch_%d.pth' % (self.model_dir, epoch) 中,%s表示第一个字符串占位符将被替换为self.model_dir(即保存.pth文件的目录路径),%d表示第二个字符串占位符将被替换为epoch(即当前训练的轮数)。这样就可以在每一轮训练结束后将当前的网络模型参数保存到一个新的.pth文件中,文件名中包含轮数以便于后续的查看和比较。

1.2、读取模型

对应的,torch.load()函数是PyTorch框架中用于从磁盘上加载Python对象的函数。一般为:

 checkpoint = torch.load(log_dir)
 model.load_state_dict(checkpoint['model'])

torch.load()函数会从文件中读取字节流,并将其反序列化成Python对象。对于PyTorch模型,可以直接将其反序列化成模型对象。

一般实际操作中,我们常常写为:

model.load_state_dict(torch.load(path))

首先使用torch.load()函数从指定的路径中加载模型参数,得到一个字典对象,即state_dict。其中,字典的键是各个层次结构的名称,而键所对应的值则是该层次结构中各个参数的值。

然后,使用model.load_state_dict()函数将state_dict中的参数加载到已经定义好的模型中。这个函数的作用是将state_dict中每个键所对应的参数加载到模型中对应的键所指定的层次结构上。

需要注意的是,由于模型的结构和保存的参数的结构必须匹配,因此在加载参数之前,需要先定义好模型的结构,使其与保存的参数的结构相同。如果结构不匹配,会导致加载参数失败,甚至会引发错误。

二、DFGAN20版本

在DFGAN20版本当中,模型保存在DFGAN/code/models当中,其中netG_300.pth就是代表生成器第300轮的模型netD_300.pth也就是代表鉴别器第300轮的模型。
Pytorch断点续训原理及DFGAN2版本实操

我们可以将需要的模型的路径记下来,然后打开main.py文件,其中在270行左右的# # validation data #下面
Pytorch断点续训原理及DFGAN2版本实操
可以在下面这段代码的后面

netG = NetG(cfg.TRAIN.NF, 100, sentencelstm, wordlstm).to(device)
netD = NetD(cfg.TRAIN.NF).to(device)

增加两句:

netG.load_state_dict(torch.load('models/%s/netG_300.pth' % (cfg.CONFIG_NAME)))
netD.load_state_dict(torch.load('models/%s/netD_300.pth' % (cfg.CONFIG_NAME)))

这样,就成功读取了所选文件夹目录下的netG_300.pthnetD_300.pth,如果要在这个epoch下进行采样,只需要把code/cfg/bird.ymlB_VALIDATION改为 True,如果需要在这个epoch下进行断点续训则B_VALIDATION改为False就可以了。

三、DFGAN22版本

DFGAN22版本与DFGAN20版本代码结构有所不同,但是在断点续训的原理上是一样的。

DFGAN22版本在保存模型时并没有单独保存netG, netD, netC, optG, optD等模型,而且将他们的模型都保存为一个.pth文件,如名为state_epoch_940.pth代表的就是第940轮的所有断点文件。这些断点文件保存在code/saved_models/bird或cooc下,如:
Pytorch断点续训原理及DFGAN2版本实操
如果要进行断点续训,我们可以把这个文件路径记下来或者将文件挪到需要的位置,我一般将需要断点续训或者采样的模型放在pretrained文件夹下。

然后下一步,打开code/cfg/bird.yml文件,如果是coco数据集则打开coco.yml
Pytorch断点续训原理及DFGAN2版本实操
修改state_epoch为自己选定的第几轮模型(想读取state_epoch_940.pth,则state_epoch改为940,这样后面打印结果、保存模型就是从941开始了),然后修改checkpoint为相应模型的路径如:./saved_models/bird/pretrained/state_epoch_940.pth,最终如下所示:

state_epoch: 940
checkpoint: ./saved_models/bird/pretrained/state_epoch_940.pth

如果你想更深层次了解其原理,即DFGAN22 版是如何保存模型和读取模型的,可以打开code/lib/utils.py文件,在第140行附近写了保存模型的函数,与我们之前讲的原理是一样的,只不过他将netG, netD, netC, optG, optD等又做了一层,然后将其统一保存到state_epoch_中:

def save_models(netG, netD, netC, optG, optD, epoch, multi_gpus, save_path):
    if (multi_gpus==True) and (get_rank() != 0):
        None
    else:
        state = {'model': {'netG': netG.state_dict(), 'netD': netD.state_dict(), 'netC': netC.state_dict()}, 
                'optimizers': {'optimizer_G': optG.state_dict(), 'optimizer_D': optD.state_dict()},
                'epoch': epoch}
        torch.save(state, '%s/state_epoch_%03d.pth' % (save_path, epoch))

在第90行到140行附近,也写了读取模型的方法,也就是读相应checkpoint的checkpoint['model']['netG'],看完你会发觉,原理很简单,代码也不算很难,遇到问题建议大家多多阅读源码。

def load_opt_weights(optimizer, weights):
    optimizer.load_state_dict(weights)
    return optimizer


def load_model_opt(netG, netD, netC, optim_G, optim_D, path, multi_gpus=False):
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus)
    netD = load_model_weights(netD, checkpoint['model']['netD'], multi_gpus)
    netC = load_model_weights(netC, checkpoint['model']['netC'], multi_gpus)
    optim_G = load_opt_weights(optim_G, checkpoint['optimizers']['optimizer_G'])
    optim_D = load_opt_weights(optim_D, checkpoint['optimizers']['optimizer_D'])
    return netG, netD, netC, optim_G, optim_D


def load_models(netG, netD, netC, path):
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    netG = load_model_weights(netG, checkpoint['model']['netG'])
    netD = load_model_weights(netD, checkpoint['model']['netD'])
    netC = load_model_weights(netC, checkpoint['model']['netC'])
    return netG, netD, netC


def load_netG(netG, path, multi_gpus, train):
    checkpoint = torch.load(path, map_location="cpu")
    netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus, train)
    return netG


def load_model_weights(model, weights, multi_gpus=False, train=True):
    if list(weights.keys())[0].find('module')==-1:
        pretrained_with_multi_gpu = False
    else:
        pretrained_with_multi_gpu = True
    if (multi_gpus==False) or (train==False):
        if pretrained_with_multi_gpu:
            state_dict = {
                key[7:]: value
                for key, value in weights.items()
            }
        else:
            state_dict = weights
    else:
        state_dict = weights
    model.load_state_dict(state_dict)
    return model

三、可能遇见的问题

问题1:模型中断后继续训练出错

在有些时候我们需要保存训练好的参数为path文件,以防不测,下次可以直接加载该轮epoch的参数接着训练,但是在重新加载时发现类似报错:

size mismatch for block0.affine0.linear1.linear2.weight: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for block0.affine0.linear1.linear2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).

问题原因:这是说明某个超参数出现了问题,可能你之前训练时候用的是64,现在准备在另外的机器上面续训的时候某个超参数设置的是32,导致了size mismatch,也有可能是你动过了模型的代码,导致现在代码和训练的模型匹配不上了。

解决方案:查看size mismatch的模型部分,将超参数改回来,并将代码和原本训练的代码保持一致。

问题2:模型中断后继续训练 效果直降

加载该轮epoch的参数接着训练,继续训练的过程是能够运行的,但是发现继续训练时效果大打折扣,完全没有中断前的最后几轮好。
问题原因:暂时未知,推测是续训时模型加载的问题,也有可能是保存和加载的方式问题
解决方案:统一保存和加载的方式,当我采用以下方式时,貌似避免了这个问题:
模型的保存:

torch.save(netG.state_dict(), 'models/%s/netG_%03d.pth' % (cfg.CONFIG_NAME, epoch))

模型的重新加载:

netD.load_state_dict(torch.load('models/%s/netD_300.pth' % (cfg.CONFIG_NAME), map_loca

💡 最后

我们已经建立了🏤T2I研学社群,如果你还有其他疑问或者对🎓文本生成图像很感兴趣,可以私信我加入社群

🎉 支持我:点赞👍+收藏⭐️+留言📝

本文正在参加「金石计划」

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

MTCNN人脸检测算法原理与应用

2023-12-4 12:12:14

AI教程

城市场景渲染新方法:基于MLP的NeRF与特征网格的集成

2023-12-4 12:18:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索