PyTorch中register

PyTorch中register,第1张

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

PyTorch中register_hook函数学习
  • 一、backward函数
  • 二、register_hook函数


一、backward函数

当输出o不是标量时,不能直接o.backward(),需要向backward传入与输入x具有相同维度的tensor w,o.backward(w) 求的不是 o 对 x 的导数,而是 l = torch.sum(o*w)对 x 的导数,相当于多加了一步按权重线性求和,使得 o 变成了标量。

需要注意:当中间有变量时,如o=f(y),y=g(x),则该w同样作用于求y的梯度上。

import torch

def y_grad(grad):
    print('y的梯度(z对y)为:', grad)

x = torch.tensor([1.,2.,3.], requires_grad=True)
y = torch.pow(x, 2)
z = x + y

y.register_hook(y_grad)
z.backward(torch.tensor([1,1,1]))

输出为:

y的梯度(z对y)为: tensor([1., 1., 1.])

y.register_hook(y_grad)
z.backward(torch.tensor([1,2,1]))

输出为:

y的梯度(z对y)为: tensor([1., 2., 1.])

ps:requires_grad=False的变量可以输入进PyTorch的model,且修改变量requires_grad=True

二、register_hook函数

由于反向传播时,不会保留中间变量的梯度,因此该函数的目的主要是对中间变量的梯度进行需要的 *** 作

  1. register_hook(),该函数的参数必须为函数,调用方式为x.register_hook(func),将x的梯度作为参数传入func,func即可对x的梯度进行所需 *** 作
  2. func对中间变量进行 *** 作后,会改变该中间变量的梯度值,将改变的梯度值向后传播,影响叶子变量梯度
  3. 具体计算过程如下
import torch

def y_grad(grad):
    print('y的梯度(z对y)为:', grad)
    return grad**2

x = torch.tensor([1.,2.,3.], requires_grad=True)
y = torch.pow(x, 2)
z = x + y

y.register_hook(y_grad)
z.backward(torch.tensor([1,2,1]))
print(x.grad)

输出为:

y的梯度(z对y)为: tensor([1., 2., 1.])
tensor([ 3., 10., 7.])

计算推导:
z = y + x = x 2 + x z = y + x = x^2 + x z=y+x=x2+x
此时 x x x y y y z z z w w w 都是vector,将 z z z 乘以 w w w 得到 z z z 为标量, z z z x x x的导数为:
∂ z ∂ x = ( w ∂ z ∂ y ) ⋅ ∂ y ∂ x + w ⋅ 1 \frac{\partial z}{\partial x} = (w\frac{\partial z}{\partial y}) \cdot \frac{\partial y}{\partial x} + w \cdot 1 xz=(wyz)xy+w1
括号里的是 新的y的梯度,因此函数对y梯度的平方 *** 作要包含w,即
( w ∂ z ∂ y ) 2 (w\frac{\partial z}{\partial y})^2 (wyz)2 因此对于 x [ 1 ] = 2 x[1]=2 x[1]=2,对应的 w = 2 w=2 w=2 ∂ z ∂ y = 1 \frac{\partial z}{\partial y}=1 yz=1 ∂ y ∂ x = 2 x \frac{\partial y}{\partial x}=2x xy=2x,新的 z 对 y z对y zy的梯度为 ( 2 ⋅ 1 ) = 2 (2\cdot1)=2 (21)=2,经过平方后等于4,传到 x [ 1 ] x[1] x[1]处时 ∂ z ∂ x [ 1 ] = ( w ∂ z ∂ y ) 2 ⋅ ∂ y ∂ x + w ⋅ 1 = ( 2 ⋅ 1 ) 2 ⋅ 2 ⋅ 2 + 2 = 18 \frac{\partial z}{\partial x[1]} = (w\frac{\partial z}{\partial y})^2 \cdot \frac{\partial y}{\partial x} + w \cdot 1=(2\cdot1)^2\cdot2\cdot2+2=18 x[1]z=(wyz)2xy+w1=(21)222+2=18

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/869961.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-05-13
下一篇2022-05-13

发表评论

登录后才能评论

评论列表(0条)

    保存