pix2pix、pix2pixHD 通过损失日志进行训练可视化

pix2pix、pix2pixHD 通过损失日志进行训练可视化,第1张

目录

背景

代码

结果

总结


背景

pix2pix(HD)代码在训练时会自动保存一个损失变化的txt文件,通过该文件能够对训练过程进行一个简单的可视化,代码如下。

训练的损失文件如图,对其进行可视化。

代码

#coding:utf-8
##
#author: QQ:1913434222 WeChat:Alocus
##
import matplotlib.pyplot as plt
import re
import numpy as np
import os,sys

file = "loss_logA - 副本.txt"
savepath = r"J:\混合样本去雾\调试代码&脚本&可视化\visualiz\"
#dirs = os.listdir( filepath )
#for file in dirs:
#    if file.endswith('txt'):
#        f  = open(filepath+file,'r')
f  = open(file,'r')
lines = f.readlines()
G_GAN = []
G_GAN_Feat = []
G_VGG = []
G_KL= []
D_real = []
D_fake = []
G_featD = []
featD_real = []
featD_fake = []
total_data =[]

for line in lines:
    if "(epoch" in line:

        if "G_GAN" in line :
            G_GAN_list = line.split()
            G_GAN.append(float(G_GAN_list[9]))
        if "G_GAN_Feat" in line :
            G_GAN_Feat_list = line.split()
            G_GAN_Feat.append(float(G_GAN_Feat_list[11]))
        if "G_VGG" in line :
            G_VGG_list = line.split()
            G_VGG.append(float(G_VGG_list[13]))
        if "G_KL" in line :
            G_KL_list = line.split()
            G_KL.append(float(G_KL_list[15]))
        if "D_real" in line:
            D_real_list = line.split()
            D_real.append(float(D_real_list[17]))
        if "D_fake" in line:
            D_fake_list = line.split()
            D_fake.append(float(D_fake_list[19]))
        if "G_featD" in line:
            G_featD_list = line.split()
            G_featD.append(float(G_featD_list[21]))
        if "featD_real" in line:
            featD_real_list = line.split()
            featD_real.append(float(featD_real_list[23]))
        if "featD_fake" in line:
            featD_fake_list = line.split()
            featD_fake.append(float(featD_fake_list[25]))


    total_data = [(G_GAN [i] + G_GAN_Feat[i]+G_VGG[i]+
    G_KL[i]+
    D_real[i]+
    D_fake[i]+
    G_featD[i]+
    featD_real[i]+
    featD_fake[i])  for i in range(0, len(featD_fake))]

    fig = plt.figure()#(figsize=(50,6))
    ax = np.linspace(0,len(featD_fake),len(featD_fake))
    plt.plot(ax, total_data, label="total")
    plt.plot(ax,G_GAN,label="G_GAN")
    plt.plot(ax,G_GAN_Feat,label="G_GAN_Feat")
    plt.plot(ax,G_VGG,label="G_VGG")
    plt.plot(ax,D_real,label="D_real")
    plt.plot(ax,D_real,label="D_fake")
    plt.plot(ax,G_KL,label="G_KL")
    plt.plot(ax,G_featD,label="G_featD")
    plt.plot(ax,featD_real,label="featD_real")
    plt.plot(ax,featD_fake,label="featD_fake")

plt.grid(color='gray', linestyle='--', linewidth=1, alpha=0.3)
plt.legend()
plt.title('VAE$_1$')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.savefig(savepath+file.split('.txt')[0]+'.png')


#G_GAN: 0.973 G_GAN_Feat: 7.709 G_VGG: 8.092 G_KL: 1.172 D_real: 0.908 D_fake: 0.870
# G_Feat_L2: 46.512 G_GAN: 3.917 G_GAN_Feat: 14.005 G_VGG: 8.741 D_real: 3.828 D_fake: 3.098
结果

总结

如果代码的损失函数等进行了修改,或者损失函数有变化,则需要对代码进行对应的修改,修改很简单,看下前面如何写的,照着改就ok啦!

有一些细节,如他们代码保存损失时不是每个epoch都一直保存的,有的epoch不同iter会保存两次所以会有些小问题,不过还好,问题不大。

祝好!

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/800078.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-05-06
下一篇2022-05-06

发表评论

登录后才能评论

评论列表(0条)

    保存