【Pytorch】模型参数初始化

【Pytorch】模型参数初始化,第1张

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

【Pytorch】模型参数初始化
  • 一、在模型外使用apply方法
  • 二、在模型__init__()中进行初始化


一、在模型外使用apply方法

apply方法是在nn.Module中实现的, 递归地调用self.children()去处理自己以及子模块。Pytorch中的任何网络,都是torch.nn.Module的子类,也就是模块。
model.apply(weight_init)会递归地将函数weight_init应用到父模块的每个子模块,也包括model这个父模块自身,经常用于初始化模型权重。
注意此种初始化方式采用递归,而在python中,对递归层数是有限制的,所以当网络结构很深时,可能会递归层数过深的错误。

class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1), 
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            nn.Linear(n_hidden_2, out_dim)
             )  
    def forward(self, x):
        x = self.layer(x)
        return x

# 根据网络层的不同定义不同的初始化方式
def weight_init(m):
	# classname = m.__class__.__name__
	# if classname.find('Conv'):
	# if type(m) == nn.Conv2d:
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
     # 是否为批归一化层
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)
    # 也可以判断是否为conv2d,使用相应的初始化方式 
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
# 将weight_init应用在模块上
model.apply(weight_init)
二、在模型__init__()中进行初始化

在模型中以迭代的方式定义一个初始化函数weights_init并在__init__中调用它。

class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1), 
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            nn.Linear(n_hidden_2, out_dim)
             )
         self.weights_init()
    
    def forward(self, x):
        x = self.layer(x)
        return x
    
    def weights_init(self):
        # 迭代循环初始化参数
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, -100)
            # 也可以判断是否为conv2d,使用相应的初始化方式 
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight.item(), 1)
                nn.init.constant_(m.bias.item(), 0)

model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/727262.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-04-26
下一篇2022-04-26

发表评论

登录后才能评论

评论列表(0条)

    保存