Notes on in-place operation errors in torch


import torch
w = torch.rand(4, requires_grad=True)
w += 1
loss = w.sum()
loss.backward()

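Running this fails with RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. The culprit is the third line, w += 1. The error is raised on that line itself, before loss.backward() is reached. Changing it to w = w + 1 makes the code run without errors.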
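A minimal sketch of the behaviours involved, assuming a reasonably recent PyTorch release. The in-place update fails on the `w += 1` line itself, and an out-of-place add works. An in-place update inside `torch.no_grad()` is also allowed. This is the pattern optimizers use to update parameters, and it is not part of the original post. The name `w_shifted` is illustrative. It is used instead of `w = w + 1` so that the original leaf `w`, and therefore `w.grad`, stays reachable afterwards.

```python
import torch

w = torch.rand(4, requires_grad=True)

# In-place update of a leaf that requires grad: the RuntimeError is raised
# on this line itself, before backward() is ever called.
try:
    w += 1
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = "leaf Variable" in str(err)

# Out-of-place version under a new name, so the leaf w stays reachable.
w_shifted = w + 1              # new non-leaf tensor with a grad_fn
loss = w_shifted.sum()
loss.backward()
print(w.grad)                  # d(loss)/dw is 1 for every element

# To change the parameter's value itself (as an optimizer step does),
# do the in-place update with gradient tracking switched off.
with torch.no_grad():
    w += 1
print(w.is_leaf, w.requires_grad)  # w is still a leaf that requires grad
```

With `w = w + 1` as written in the text, the name `w` ends up pointing at a non-leaf tensor. Reading `w.grad` afterwards therefore typically gives `None` (with a warning), even though the gradient was accumulated on the original leaf.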

import torch
x = torch.zeros(4)
w = torch.rand(4, requires_grad=True)
x[0] = torch.rand(1) * w[0]
for i in range(3):
    x[i+1] = torch.sin(x[i]) * w[i]
loss = x.sum()
loss.backward()
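Running this also fails, this time when loss.backward() is called: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. To pinpoint the operation responsible, wrap the code in with torch.autograd.set_detect_anomaly(True):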

with torch.autograd.set_detect_anomaly(True):
    x = torch.zeros(4)
    w = torch.rand(4, requires_grad=True)
    x[0] = torch.rand(1) * w[0]
    for i in range(3):
        x[i+1] = torch.sin(x[i]) * w[i]
    loss = x.sum()
    loss.backward()

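The anomaly report says "Error detected in SinBackward" (SinBackward0 in newer PyTorch versions), which points at torch.sin(). torch.sin() saves its input x[i] so it can compute the gradient later. However, x[i] is a view of x, and the assignment x[i+1] = ... then modifies x in place, so the saved value is no longer valid. The fix is to change line 6, x[i+1] = torch.sin(x[i]) * w[i], to x[i+1] = torch.sin(x[i].clone()) * w[i]: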

import torch
x = torch.zeros(4)
w = torch.rand(4, requires_grad=True)
x[0] = torch.rand(1) * w[0]
for i in range(3):
    x[i+1] = torch.sin(x[i].clone()) * w[i]
loss = x.sum()
loss.backward()
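The reason `.clone()` helps can be checked directly. The sketch below reruns the first loop iteration by hand and reads the tensor's version counter through `_version`. That is an internal attribute, not documented public API, so treat its availability as an assumption about your PyTorch version. Because `x[0]` is a view of `x`, the in-place write `x[1] = ...` bumps the counter that `sin()` recorded when it saved `x[0]`, and backward refuses to use the stale value. With `.clone()`, `sin()` saves a separate copy that no later assignment touches. The names `xi`, `w2` and `x2` are illustrative.

```python
import torch

w = torch.rand(4, requires_grad=True)

# Without clone(): reproduce the error from the first loop iteration.
x = torch.zeros(4)
x[0] = torch.rand(1) * w[0]
xi = x[0]                      # view: shares storage and version counter with x
saved_version = xi._version    # version at the moment sin() saves xi
y = torch.sin(xi)              # sin's backward keeps xi to compute cos(xi)
x[1] = y * w[0]                # in-place write into x bumps the shared counter
bumped = xi._version > saved_version

try:
    x.sum().backward()
    failed = False
except RuntimeError as err:
    failed = "modified by an inplace operation" in str(err)

# With clone(): sin() saves a private copy that nothing writes to.
w2 = torch.rand(4, requires_grad=True)
x2 = torch.zeros(4)
x2[0] = torch.rand(1) * w2[0]
y2 = torch.sin(x2[0].clone())
x2[1] = y2 * w2[0]
x2.sum().backward()            # no error this time
print(bumped, failed, w2.grad)
```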

Summary of fixes for in-place operation errors:

change x += 1 to x = x + 1;
change x[:, :, 0:3] = x[:, :, 0:3] + 1 to x[:, :, 0:3] = x[:, :, 0:3].clone() + 1;
change x[i+1] = torch.sin(x[i]) * w[i] to x[i+1] = torch.sin(x[i].clone()) * w[i];
use with torch.autograd.set_detect_anomaly(True) to help locate the failing operation; anomaly mode makes the run noticeably slower.
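The `.clone()` fix keeps the in-place writes into `x` and only protects the tensor that `sin()` saves. An alternative, not mentioned in the original post, is to avoid in-place writes altogether. You keep each step in a Python list and build `x` once with `torch.stack`. Below is a minimal sketch, with illustrative names `steps` and `x_stacked`, that checks both versions give the same values and gradients.

```python
import torch

torch.manual_seed(0)
w = torch.rand(4, requires_grad=True)
x0 = torch.rand(1)

# (a) The fix from the list: write into a preallocated x, clone() before sin().
x = torch.zeros(4)
x[0] = x0 * w[0]
for i in range(3):
    x[i + 1] = torch.sin(x[i].clone()) * w[i]
x.sum().backward()
grad_clone = w.grad.clone()
w.grad = None                  # reset the accumulated gradient before run (b)

# (b) No in-place writes: keep each step in a list and stack at the end.
steps = [(x0 * w[0]).squeeze(0)]          # 0-d tensor, like x[0] in (a)
for i in range(3):
    steps.append(torch.sin(steps[-1]) * w[i])
x_stacked = torch.stack(steps)            # shape (4,), built out of place
x_stacked.sum().backward()
grad_stack = w.grad.clone()

print(torch.allclose(x.detach(), x_stacked.detach()))  # compare values
print(torch.allclose(grad_clone, grad_stack))          # compare gradients
```

The stacked form trades a preallocated buffer for a list of intermediate tensors, which is often the simpler choice for short recurrences like this one.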

x = x + 1 is not in-place: it takes the object that x points to, creates a new Variable, stores x + 1 in it, and then rebinds the name x to that new Variable. Nothing is modified in place; only the Python reference changes (you can check that id(x) differs before and after that line).

On the other hand, x += 1 or x[0] = 1 modifies the Variable's data in place, so no copy is made. However, some functions (for example * or torch.sin()) need their inputs to stay unchanged after they compute the output; otherwise they could not compute the gradient. That is why an error is raised.
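The claim about `id(x)` is easy to verify. The sketch below assumes a recent PyTorch, where `Variable` has been merged into `Tensor` (since 0.4). `x = x + 1` yields a new object with new storage. `y += 1` keeps the same object and storage and increments `_version`, the internal counter that autograd checks against saved tensors. `_version` is not public API.

```python
import torch

# Out-of-place: x + 1 builds a new tensor, then the name x is rebound to it.
x = torch.zeros(3)
old_id, old_ptr = id(x), x.data_ptr()
x = x + 1
print(id(x) != old_id, x.data_ptr() != old_ptr)   # new object, new storage

# In-place: y += 1 writes into the existing storage of the same object,
# and the version counter goes up.
y = torch.zeros(3)
old_id_y, old_ptr_y, old_ver_y = id(y), y.data_ptr(), y._version
y += 1
print(id(y) == old_id_y, y.data_ptr() == old_ptr_y, y._version > old_ver_y)
```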


Reference: https://blog.csdn.net/weixin_39679367/article/details/122754199


Original URL: https://54852.com/langs/917238.html
