pytorch的梯度传递

pytorch的梯度传递,第1张

pytorch的梯度传递
  • 1.requires_grad的传递
    • 1.1三种情况下的梯度传递
    • 1.2利用requires_grad=False冻结骨干网络
    • 1.3网络中的数据是记录梯度的

1.requires_grad的传递

requires_gard 是tensor的一个属性,requires_gard=False表示不记录梯度,requires_gard=True表示记录张量的梯度。

每次的计算抽象为张量 A 与 B 做数学运算得到张量 C,C 是否记录梯度取决于 A 和 B的情况。

1.1三种情况下的梯度传递
  • A.requires_gard=False B.requires_gard=True ⇒ C.requires_gard=True
    A = torch.tensor([1., 2., 3.], requires_grad=True)
    B = torch.tensor([4., 5., 6.], requires_grad=False)
    C = A + B
    C.requires_grad
    ---------------------------------------------------------------------------------
    True
    
  • A.requires_gard=True B.requires_gard=False ⇒ C.requires_gard=True
    A = torch.tensor([1., 2., 3.], requires_grad=False)
    B = torch.tensor([4., 5., 6.], requires_grad=True)
    C = A + B
    C.requires_grad
    ----------------------------------------------------------------------------------
    True
    
  • A.requires_gard=False B.requires_gard=False ⇒ C.requires_gard=False
    A = torch.tensor([1., 2., 3.], requires_grad=False)
    B = torch.tensor([4., 5., 6.], requires_grad=False)
    C = A + B
    C.requires_grad
    -----------------------------------------------------------------------------------
    False
    

由此可见,只有当输入都不需要记录梯度时,后续计算的张量才不记录梯度,只要有一个输入张量计算梯度,后续的张量均需要记录梯度

1.2利用requires_grad=False冻结骨干网络
# 获得pytorch的预训练模型
model = torchvision.models.resnet18(pretrained=True)
# 冻结model的梯度计算
for p in model.parameters():
    p.requires_grad = False
# 替换最上层的fc
model.fc = torch.nn.Linear(512, 100)
# 新创建的liner层默认requires_grad=True
optmizer = torch.optim.SGD(model.fc.parameters(), lr=0.001)
1.3网络中的数据是记录梯度的
model = torchvision.models.resnet18(pretrained=True)
inputs = torch.randn(1, 3, 128, 128)
inputs.requires_grad
model(inputs).requires_grad
--------------------------------------------------------------------
False
True

虽然输入网络的tensor inputs 是不记录梯度的(requires_grad=False),但是网络的参数记录梯度,导致中间层的输出数据和最终的输出数据的requires_grad=True。

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/919118.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-05-16
下一篇2022-05-16

发表评论

登录后才能评论

评论列表(0条)

    保存