Pytorch保存checkpoint(检查点):通常在训练模型的过程中,每隔一段时间就将训练模型信息保存一次【包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复】

Pytorch保存checkpoint(检查点):通常在训练模型的过程中,每隔一段时间就将训练模型信息保存一次【包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复】,第1张

通常在训练模型的过程中,可能会遭遇断电、断网的尴尬,一旦出现这种情况,先前训练的模型就白费了,又得重头开始训练。因此每隔一段时间就将训练模型信息保存一次很有必要。而这些信息不光包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复训练。

state = {
    'epoch' : epoch + 1,  #保存当前的迭代次数
    'state_dict' : model.state_dict(), #保存模型参数
    'optimizer' : optimizer.state_dict(), #保存优化器参数
    ...,      #其余一些想保持的参数都可以添加进来
    ...,
}

torch.save(state, 'checkpoint.pth.tar')     #将state中的信息保存到checkpoint.pth.tar
#Pytorch 约定使用.tar格式来保存这些检查点


#当想恢复训练时
checkpoint = torch.load('checkpoint.pth.tar')
epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])   #加载模型的参数
optimizer.load_state_dict(checkpoint['optimizer']) #加载优化器的参数



参考资料:
pytorch如何保存模型?

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/885514.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-05-14
下一篇2022-05-14

发表评论

登录后才能评论

评论列表(0条)

    保存