
在使用pytorch利用测试集进行网络预测时,给网络输入数据,默认会构建计算图,构建计算图是为了方便后续的反向传播进行梯度计算,如果只是为了利用网络进行预测,则不需要构建完整的计算图。构建完整计算图会增加计算和累积内存消耗,导致所占GPU显存越来越大。
解决方案
在测试代码处于如下命令下:
with torch.no_grad():
例如:
with torch.no_grad():
prediction = net(images)
loss = loss_func(prediction , label) / batch_size
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)