PyTorch Implementation of VGG Loss


Code adapted from: https://github.com/bryandlee/stylegan2-encoder-pytorch/blob/master/train_encoder.py

import torch
import torch.nn as nn
import torchvision


class VGGLoss(nn.Module):
    def __init__(self, device, n_layers=5):
        super().__init__()

        # Slice boundaries in vgg19.features: each slice ends at the ReLU
        # following conv1_1, conv2_1, conv3_1, conv4_1, and conv5_1.
        feature_layers = (2, 7, 12, 21, 30)
        self.weights = (1.0, 1.0, 1.0, 1.0, 1.0)

        vgg = torchvision.models.vgg19(pretrained=True).features

        # Split the VGG feature stack into consecutive slices so the
        # activations at each boundary can be compared separately.
        self.layers = nn.ModuleList()
        prev_layer = 0
        for next_layer in feature_layers[:n_layers]:
            layers = nn.Sequential()
            for layer in range(prev_layer, next_layer):
                layers.add_module(str(layer), vgg[layer])
            self.layers.append(layers.to(device))
            prev_layer = next_layer

        # VGG is used as a fixed feature extractor; freeze its parameters.
        for param in self.parameters():
            param.requires_grad = False

        self.criterion = nn.L1Loss().to(device)

    def forward(self, source, target):
        loss = 0
        for layer, weight in zip(self.layers, self.weights):
            # Feed each slice's output into the next slice, comparing
            # features at every boundary along the way.
            source = layer(source)
            with torch.no_grad():
                target = layer(target)
            loss += weight * self.criterion(source, target)

        return loss

VGG loss is commonly used with GANs and in tasks such as style transfer, where it measures the perceptual difference between two images. Concretely, both images are passed through a pretrained VGG19 network, features are extracted at several layers, and the differences between corresponding feature maps are measured with an L1 loss. A key hyperparameter is the weight assigned to each layer's features, which can have a significant effect on training results.
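The mechanism can be sketched without downloading pretrained weights by swapping in a small, randomly initialized stand-in extractor (the three conv stages below are hypothetical placeholders for the VGG19 slices; a real perceptual loss would use the pretrained VGG19 features as in the class above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the VGG19 slices: three small conv stages
# (randomly initialized, for illustration only).
stages = nn.ModuleList([
    nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()),
    nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU()),
    nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU()),
])
weights = (1.0, 1.0, 1.0)
criterion = nn.L1Loss()

source = torch.rand(1, 3, 32, 32)
target = source.clone()  # identical images -> loss should be exactly 0

loss = torch.tensor(0.0)
for stage, w in zip(stages, weights):
    # Chain the stages, comparing features after each one,
    # exactly as in VGGLoss.forward above.
    source = stage(source)
    with torch.no_grad():
        target = stage(target)
    loss = loss + w * criterion(source, target)

print(loss.item())  # prints 0.0 for identical inputs
```

With two different images, each term contributes a positive L1 distance scaled by its layer weight, so deeper (more semantic) layers can be emphasized or de-emphasized by adjusting `weights`.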

Original article: https://54852.com/zaji/5652090.html
