Pytorch学习笔记-07 模型finetune

Pytorch学习笔记-07 模型finetune,第1张

Pytorch学习笔记-07 模型finetune Pytorch学习笔记-07 模型finetune

文章目录
  • Pytorch学习笔记-07 模型finetune
    • 模型保存与加载
      • 序列化与反序列化
      • torch.save
      • torch.load
      • e.g.
    • 模型微调( Finetune)
    • GPU 的使用

模型保存与加载 序列化与反序列化

torch.save

主要参数:

  • obj :对象
  • f :输出路径
torch.load

主要参数

  • f :文件路径
  • map_location :指定存放位置 cpu or gpu

法1: 保存整个 Module
torch.save(net,path)

法2: 保存模型参数
state_dict = net.state_dict()
torch.save(state_dict , path)

e.g.

保存

net = LeNet2(classes=2019)

# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

加载

# ================================== load net ===========================
# flag = 1
flag = 0
if flag:

    path_model = "./model.pkl"
    net_load = torch.load(path_model)

    print(net_load)

# ================================== load state_dict ===========================

flag = 1
# flag = 0
if flag:

    path_state_dict = "./model_state_dict.pkl"
    state_dict_load = torch.load(path_state_dict)

    print(state_dict_load.keys())

# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:

    net_new = LeNet2(classes=2019)

    print("加载前: ", net_new.features[0].weight[0, ...])
    net_new.load_state_dict(state_dict_load)
    print("加载后: ", net_new.features[0].weight[0, ...])
模型微调( Finetune)

Transfer Learning :机器学习分支,研究 源域 (source 的知识如何应用到 目标域 (target domain)

模型微调步骤:

  1. 获取预训练模型参数
  2. 加载模型( load_state_dict
  3. 修改输出层

模型微调训练方法:

  • 固定预训练的参数 requires_grad =False lr =
  • Features Extractor 较小学习率( params_group

e.g.

# 1/3 构建模型
resnet18_ft = models.resnet18()

# 2/3 加载参数
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(baseDIR, "..", "..", "data", "finetune_resnet18-5c106cde.pth")
    if not os.path.exists(path_pretrained_model):
        raise Exception("n{} 不存在,请下载 07-02-数据-模型finetune.zipn放到 {}下,并解压即可".format(
            path_pretrained_model, os.path.dirname(path_pretrained_model)))
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# 法1 : 冻结卷积层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))


# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)


resnet18_ft.to(device)
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数

# ============================ step 4/5 优化器 ============================
# 法2 : conv 小学习率
# flag = 0
flag = 1
if flag:
    fc_params_id = list(map(id, resnet18_ft.fc.parameters()))     # 返回的是parameters的 内存地址
    base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())
    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR*0},   # 0
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)

else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)               # 选择优化器

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)     # 设置学习率下降策略


# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    resnet18_ft.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet18_ft(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().cpu().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

            # if flag_m1:
            print("epoch:{} conv1.weights[0, 0, ...] :n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # 更新学习率

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        resnet18_ft.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().cpu().sum().numpy()

                loss_val += loss.item()

            loss_val_mean = loss_val/len(valid_loader)
            valid_curve.append(loss_val_mean)
            print("Valid:t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
GPU 的使用

CPU Central Processing Unit, 中央处理器):主要包括控制器和运算器
GPU(Graphics Processing Unit, 图形处理器 ):处理统一的,无依赖的大规模数据运算

to函数:转换数据 类型 设备

  1. tensor.to args , kwargs
  2. module.to args , kwargs

区别:张量不执行 inplace ,模型 执行 inplace

torch.cuda常用方法

  • torch.cuda.device_count ()():计算当前可见可用 gpu 数
  • torch.cuda.get_device_name ()():获取 gpu 名称
  • torch.cuda.manual_seed ()():为当前 gpu 设置随机种子
  • torch.cuda.manual_seed_all ()():为所有可见可用 gpu 设置随机种子
  • torch.cuda.set_device ()():设置主 gpu 为哪一个物理 gpu (不推荐
    • 推荐:os.environ.setdefault (“CUDA_VISIBLE_DEVICES”, “2,3”)

多gpu 运算的 分发并行 机制

分发 并行运算 结果回收

torch.nn.DataParallel
功能:包装模型,实现分发并行机制
主要参数:

  • module 需要包装分发的模型
  • device_ids 可分发的 gpu ,默认分发到所有 可见可用 gpu
  • output_device 结果输出设备

查询当前 gpu 内存剩余

def get_gpu_memory():
    import os
    os.system('nvidia-smi -q -d Memory | grep -A4 GPU | grep Free > tmp.txt')
    memory_gpu = [int(x.split()[2]) for x in open('tmp.txt', 'r').readlines()]
    os.system('rm tmp.txt')
    return memory_gpu

gpu模型加载

报错1:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU -only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

解决: torch.load(path_state_dict, map_location=“cpu”)

报错2:

RuntimeError: Error(s) in loading state_dict for FooNet:
Missing key(s) in state_dict: "linears.0.weight", "linears.1.weight", "linears.2.weight".
Unexpected key(s) in state_dict: "module.linears.0.weight",
"module.linears.1.weight", "module.linears.2.weight".

解决:

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict_load.items():
    namekey = k[7:] if k.startswith('module.') else k
    new_state_dict[namekey] = v

s) in state_dict: “module.linears.0.weight”,
“module.linears.1.weight”, “module.linears.2.weight”.

解决:

```python
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict_load.items():
    namekey = k[7:] if k.startswith('module.') else k
    new_state_dict[namekey] = v

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/zaji/5572382.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-12-14
下一篇2022-12-14

发表评论

登录后才能评论

评论列表(0条)

    保存