
【Pytorch】模型参数初始化提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
- 一、在模型外使用apply方法
- 二、在模型__init__()中进行初始化
一、在模型外使用apply方法
apply方法是在nn.Module中实现的, 递归地调用self.children()去处理自己以及子模块。Pytorch中的任何网络,都是torch.nn.Module的子类,也就是模块。
model.apply(weight_init)会递归地将函数weight_init应用到父模块的每个子模块,也包括model这个父模块自身,经常用于初始化模型权重。
注意此种初始化方式采用递归,而在python中,对递归层数是有限制的,所以当网络结构很深时,可能会递归层数过深的错误。
class Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(in_dim, n_hidden_1),
nn.ReLU(True),
nn.Linear(n_hidden_1, n_hidden_2),
nn.ReLU(True),
nn.Linear(n_hidden_2, out_dim)
)
def forward(self, x):
x = self.layer(x)
return x
# 根据网络层的不同定义不同的初始化方式
def weight_init(m):
# classname = m.__class__.__name__
# if classname.find('Conv'):
# if type(m) == nn.Conv2d:
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# 是否为批归一化层
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
# 也可以判断是否为conv2d,使用相应的初始化方式
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
# 将weight_init应用在模块上
model.apply(weight_init)
二、在模型__init__()中进行初始化
在模型中以迭代的方式定义一个初始化函数weights_init并在__init__中调用它。
class Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(in_dim, n_hidden_1),
nn.ReLU(True),
nn.Linear(n_hidden_1, n_hidden_2),
nn.ReLU(True),
nn.Linear(n_hidden_2, out_dim)
)
self.weights_init()
def forward(self, x):
x = self.layer(x)
return x
def weights_init(self):
# 迭代循环初始化参数
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, -100)
# 也可以判断是否为conv2d,使用相应的初始化方式
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight.item(), 1)
nn.init.constant_(m.bias.item(), 0)
model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
欢迎分享,转载请注明来源:内存溢出
微信扫一扫
支付宝扫一扫
评论列表(0条)