Pytorch训练过程中GPU显存不断增加的解决方案

Pytorch训练过程中GPU显存不断增加的解决方案,第1张

Pytorch训练过程中显存不断增加原因之一

在使用pytorch利用测试集进行网络预测时,给网络输入数据,默认会构建计算图,构建计算图是为了方便后续的反向传播进行梯度计算,如果只是为了利用网络进行预测,则不需要构建完整的计算图。构建完整计算图会增加计算和累积内存消耗,导致所占GPU显存越来越大。

解决方案
在测试代码处于如下命令下:

with torch.no_grad():

例如:

with torch.no_grad():
	prediction = net(images)
	loss = loss_func(prediction , label) / batch_size

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/942275.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-05-18
下一篇2022-05-18

发表评论

登录后才能评论

评论列表(0条)

    保存