我们在训练模型的时候经常会出现各种问题导致训练中断,比方说断电、系统中断、内存溢出、断连、硬件故障、地震火灾等之类的导致电脑系统关闭,从而将模型训练中断。
所以在实际运行当中,我们经常需要每100轮epoch或者每50轮epoch要保存训练好的参数,以防不测,这样下次可以直接加载该轮epoch的参数接着训练,就不用重头开始。下面我们来介绍Pytorch断点续训原理以及DFGAN20版本和22版本断点续训实操。
文末评论【人生苦短,我用Pytorch!】抽一位小伙伴送出《PyTorch教程:21个项目玩转PyTorch实战》书籍一本,包邮到家。
一、Pytorch断点续训
1.1、保存模型
pytorch保存模型等相关参数,需要利用torch.save(),torch.save()是PyTorch框架中用于保存Python对象到磁盘上的函数,一般为
torch.save(checkpoint, checkpoint_path)
其中
checkpoint
为保存模型的所有参数和缓存的键值对,
checkpoint_path
表示最终保存的模型,通常以.pth格式保存。
torch.save()
函数会将obj序列化为字节流,并将字节流写入f指定的文件中。在读取数据时,可以使用
torch.load()
函数来将文件中的字节流反序列化成Python对象。使用这两个函数可以轻松地将PyTorch模型保存到磁盘上,并在需要的时候重新加载使用。
一般在实际操作中,我们写为:
torch.save(netG.state_dict(),'%s/netG_epoch_%d.pth'%(self.model_dir, epoch))
它接受两个参数:要保存的对象(即状态字典)和文件路径。在这里,状态字典是通过调用
netG.state_dict()
方法获得的,而文件路径是使用字符串格式化操作构建的。字符串
'%s/netG_epoch_%d.pth' % (self.model_dir, epoch)
中,%s表示第一个字符串占位符将被替换为
self.model_dir
(即保存.pth文件的目录路径),%d表示第二个字符串占位符将被替换为epoch(即当前训练的轮数)。这样就可以在每一轮训练结束后将当前的网络模型参数保存到一个新的.pth文件中,文件名中包含轮数以便于后续的查看和比较。
1.2、读取模型
对应的,
torch.load()
函数是PyTorch框架中用于从磁盘上加载Python对象的函数。一般为:
checkpoint = torch.load(log_dir)
model.load_state_dict(checkpoint['model'])
torch.load()
函数会从文件中读取字节流,并将其反序列化成Python对象。对于PyTorch模型,可以直接将其反序列化成模型对象。
一般实际操作中,我们常常写为:
model.load_state_dict(torch.load(path))
首先使用
torch.load()
函数从指定的路径中加载模型参数,得到一个字典对象,即
state_dict
。其中,字典的键是各个层次结构的名称,而键所对应的值则是该层次结构中各个参数的值。
然后,使用
model.load_state_dict()
函数将
state_dict
中的参数加载到已经定义好的模型中。这个函数的作用是将
state_dict
中每个键所对应的参数加载到模型中对应的键所指定的层次结构上。
需要注意的是,由于模型的结构和保存的参数的结构必须匹配,因此在加载参数之前,需要先定义好模型的结构,使其与保存的参数的结构相同。如果结构不匹配,会导致加载参数失败,甚至会引发错误。
二、DFGAN20版本
在DFGAN20版本当中,模型保存在
DFGAN/code/models
当中,其中
netG_300.pth
就是代表生成器第300轮的模型
netD_300.pth
也就是代表鉴别器第300轮的模型。
我们可以将需要的模型的路径记下来,然后打开
main.py
文件,其中在270行左右的
# # validation data #
下面
可以在下面这段代码的后面
netG =NetG(cfg.TRAIN.NF,100, sentencelstm, wordlstm).to(device)
netD =NetD(cfg.TRAIN.NF).to(device)
增加两句:
netG.load_state_dict(torch.load('models/%s/netG_300.pth'%(cfg.CONFIG_NAME)))
netD.load_state_dict(torch.load('models/%s/netD_300.pth'%(cfg.CONFIG_NAME)))
这样,就成功读取了所选文件夹目录下的
netG_300.pth
和
netD_300.pth
,如果要在这个epoch下进行采样,只需要把
code/cfg/bird.yml
下
B_VALIDATION
改为
True
,如果需要在这个epoch下进行断点续训则
B_VALIDATION
改为
False
就可以了。
三、DFGAN22版本
DFGAN22版本与DFGAN20版本代码结构有所不同,但是在断点续训的原理上是一样的。
DFGAN22版本在保存模型时并没有单独保存netG, netD, netC, optG, optD等模型,而且将他们的模型都保存为一个.pth文件,如名为
state_epoch_940.pth
代表的就是第940轮的所有断点文件。这些断点文件保存在
code/saved_models/bird或cooc
下,如:
如果要进行断点续训,我们可以把这个文件路径记下来或者将文件挪到需要的位置,我一般将需要断点续训或者采样的模型放在pretrained文件夹下。
然后下一步,打开
code/cfg/bird.yml
文件,如果是coco数据集则打开
coco.yml
:
修改state_epoch为自己选定的第几轮模型(想读取
state_epoch_940.pth
,则state_epoch改为940,这样后面打印结果、保存模型就是从941开始了),然后修改checkpoint为相应模型的路径如:
./saved_models/bird/pretrained/state_epoch_940.pth
,最终如下所示:
state_epoch:940
checkpoint:./saved_models/bird/pretrained/state_epoch_940.pth
如果你想更深层次了解其原理,即DFGAN22 版是如何保存模型和读取模型的,可以打开
code/lib/utils.py
文件,在第140行附近写了保存模型的函数,与我们之前讲的原理是一样的,只不过他将netG, netD, netC, optG, optD等又做了一层,然后将其统一保存到state_epoch_中:
def save_models(netG, netD, netC, optG, optD, epoch, multi_gpus, save_path):if(multi_gpus==True)and(get_rank()!=0):
None
else:
state ={'model':{'netG': netG.state_dict(),'netD': netD.state_dict(),'netC': netC.state_dict()}, \
'optimizers':{'optimizer_G': optG.state_dict(),'optimizer_D': optD.state_dict()},\
'epoch': epoch}
torch.save(state,'%s/state_epoch_%03d.pth'%(save_path, epoch))
在第90行到140行附近,也写了读取模型的方法,也就是读相应checkpoint的
checkpoint['model']['netG']
,看完你会发觉,原理很简单,代码也不算很难,遇到问题建议大家多多阅读源码。
def load_opt_weights(optimizer, weights):
optimizer.load_state_dict(weights)return optimizer
def load_model_opt(netG, netD, netC, optim_G, optim_D, path, multi_gpus=False):
checkpoint = torch.load(path, map_location=torch.device('cpu'))
netG =load_model_weights(netG, checkpoint['model']['netG'], multi_gpus)
netD =load_model_weights(netD, checkpoint['model']['netD'], multi_gpus)
netC =load_model_weights(netC, checkpoint['model']['netC'], multi_gpus)
optim_G =load_opt_weights(optim_G, checkpoint['optimizers']['optimizer_G'])
optim_D =load_opt_weights(optim_D, checkpoint['optimizers']['optimizer_D'])return netG, netD, netC, optim_G, optim_D
def load_models(netG, netD, netC, path):
checkpoint = torch.load(path, map_location=torch.device('cpu'))
netG =load_model_weights(netG, checkpoint['model']['netG'])
netD =load_model_weights(netD, checkpoint['model']['netD'])
netC =load_model_weights(netC, checkpoint['model']['netC'])return netG, netD, netC
def load_netG(netG, path, multi_gpus, train):
checkpoint = torch.load(path, map_location="cpu")
netG =load_model_weights(netG, checkpoint['model']['netG'], multi_gpus, train)return netG
def load_model_weights(model, weights, multi_gpus=False, train=True):iflist(weights.keys())[0].find('module')==-1:
pretrained_with_multi_gpu = False
else:
pretrained_with_multi_gpu = True
if(multi_gpus==False)or(train==False):if pretrained_with_multi_gpu:
state_dict ={
key[7:]: value
for key, value in weights.items()}else:
state_dict = weights
else:
state_dict = weights
model.load_state_dict(state_dict)return model
三、可能遇见的问题
问题1:模型中断后继续训练出错
在有些时候我们需要保存训练好的参数为path文件,以防不测,下次可以直接加载该轮epoch的参数接着训练,但是在重新加载时发现类似报错:
size mismatch for block0.affine0.linear1.linear2.weight: copying a param with shape torch.Size([512, 256])from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for block0.affine0.linear1.linear2.bias: copying a param with shape torch.Size([512])from checkpoint, the shape in current model is torch.Size([256]).
问题原因:这是说明某个超参数出现了问题,可能你之前训练时候用的是64,现在准备在另外的机器上面续训的时候某个超参数设置的是32,导致了size mismatch,也有可能是你动过了模型的代码,导致现在代码和训练的模型匹配不上了。
解决方案:查看size mismatch的模型部分,将超参数改回来,并将代码和原本训练的代码保持一致。
问题2:模型中断后继续训练 效果直降
加载该轮epoch的参数接着训练,继续训练的过程是能够运行的,但是发现继续训练时效果大打折扣,完全没有中断前的最后几轮好。
问题原因:暂时未知,推测是续训时模型加载的问题,也有可能是保存和加载的方式问题
解决方案:统一保存和加载的方式,当我采用以下方式时,貌似避免了这个问题:
模型的保存:
torch.save(netG.state_dict(),'models/%s/netG_%03d.pth'%(cfg.CONFIG_NAME, epoch))
模型的重新加载:
netD.load_state_dict(torch.load('models/%s/netD_300.pth'%(cfg.CONFIG_NAME), map_location='cuda:0'))
四、好书推荐(评论送书)
4.1、好书推荐
《PyTorch教程:21个项目玩转PyTorch实战》
阅读这本书,可以通过经典项目入门 PyTorch,通过前沿项目提升 PyTorch,基于PyTorch玩转深度学习,本书适合人工智能、机器学习、深度学习方面的人员阅读,也适合其他 IT 方面从业者,另外,还可以作为相关专业的教材。
京东自营购买链接:https://item.jd.com/13522327.html
评论区评论【人生苦短,我用Pytorch!】抽一位小伙伴送出《PyTorch教程:21个项目玩转PyTorch实战》书籍一本,包邮到家。
目前买还有优惠:北京大学出版社4月“423世界读书日”促销活动
当当活动日期:4.6-4.11,4.18-4.23
京东活动日期: 4.6 一天, 4.17-4.23
活动期间满100减50或者半价5折销售
希望大家关注参与423读书日北大社促销活动
京东自营购买链接:https://item.jd.com/13522327.html
4.2、内容简介
PyTorch 是基于 Torch 库的开源机器学习库,它主要由 Meta(原 Facebook)的人工智能研究实验室开发,在自然语言处理和计算机视觉领域都具有广泛的应用。本书介绍了简单且经典的入门项目,方便快速上手,如 MNIST数字识别,读者在完成项目的过程中可以了解数据集、模型和训练等基础概念。本书还介绍了一些实用且经典的模型,如 R-CNN 模型,通过这个模型的学习,读者可以对目标检测任务有一个基本的认识,对于基本的网络结构原理有一定的了解。另外,本书对于当前比较热门的生成对抗网络和强化学习也有一定的介绍,方便读者拓宽视野,掌握前沿方向。
4.3、作者简介
王飞,2019年翻译了PyTorch官方文档,读研期间研究方向为自然语言处理,主要是中文分词、文本分类和数据挖掘。目前在教育行业工作,探索人工智能技术在教育中的应用。
何健伟,曾任香港大学助理研究员,研究方向为自然语言处理,目前从事大规模推荐算法架构研究工作。
林宏彬,硕士期间研究方向为自然语言处理,现任阿里巴巴算法工程师,目前从事广告推荐领域的算法研究工作。
史周安,软件工程硕士,人工智能技术爱好者、实践者与探索者。目前从事弱监督学习、迁移学习与医学图像相关工作。
💡 最后
我们已经建立了🏤T2I研学社群,如果你还有其他疑问或者对🎓文本生成图像很感兴趣,可以私信我加入社群。
📝 加入社群 抱团学习:中杯可乐多加冰-深度学习T2I研习群
🔥 限时免费订阅:文本生成图像T2I专栏
🎉 支持我:点赞👍+收藏⭐️+留言📝
评论区评论【人生苦短,我用Pytorch!】抽一位小伙伴送出《PyTorch教程:21个项目玩转PyTorch实战》书籍一本,包邮到家。
版权归原作者 中杯可乐多加冰 所有, 如有侵权,请联系我们删除。