华为云AI开发平台ModelArts断点续训练和增量训练_云淘科技

什么是断点续训练和增量训练

断点续训练是指因为某些原因(例如容错重启、资源抢占、作业卡死等)导致训练作业还未完成就被中断,下一次训练可以在上一次的训练基础上继续进行。这种方式对于需要长时间训练的模型而言比较友好。

增量训练是指增加新的训练数据到当前训练流程中,扩展当前模型的知识和能力。

断点续训练和增量训练均是通过checkpoint机制实现。

checkpoint的机制是:在模型训练的过程中,不断地保存训练结果(包括但不限于EPOCH、模型权重、优化器状态、调度器状态)。即便模型训练中断,也可以基于checkpoint接续训练。

当需要从训练中断的位置接续训练,只需要加载checkpoint,并用checkpoint信息初始化训练状态即可。用户需要在代码里加上reload ckpt的代码,使能读取前一次训练保存的预训练模型。

ModelArts中如何实现断点续训练和增量训练

在ModelArts训练中实现断点续训练或增量训练,建议使用“训练输出”功能。

在创建训练作业时,设置训练输出参数train_url,在指定的训练输出的数据存储位置中保存checkpoint,并打开“预下载至本地目录”开关。选择预下载至本地目录时,系统在训练作业启动前,自动将数据存储位置中的checkpoint文件下载到训练容器的本地目录。

图1 训练输出设置

断点续训练建议和训练容错检查(即故障自动重启)功能同时使用。在创建训练作业页面,开启故障自动重启开关。训练环境预检测失败、或者训练容器硬件检测故障、或者训练作业失败时会自动重新下发并运行训练作业。

图2 故障自动重启设置

Pytorch版reload ckpt

简单介绍一下pytorch模型保存的两种方式。

仅保存模型参数

state_dict = model.state_dict()
torch.save(state_dict, path)

保存整个Model(不推荐)

torch.save(model, path)

保存模型的训练过程的产物

将模型训练过程中的网络权重、优化器权重、以及epoch进行保存,便于中断后继续训练恢复。

   checkpoint = {
           "net": model.state_dict(),
           "optimizer": optimizer.state_dict(),
           "epoch": epoch   
   }
   if not os.path.isdir('model_save_dir'):
       os.makedirs('model_save_dir')
   torch.save(checkpoint,'model_save_dir/ckpt_{}.pth'.format(str(epoch)))

详细代码

import os
import argparse
parser.add_argument("--train_url", type=str)
args = parser.parse_known_args()
# train_url 将被赋值为"/home/ma-user/modelarts/outputs/train_url_0" 
train_url = args.train_url

# 判断输出路径中是否有模型文件。若无文件则默认从头训练,若有模型文件,则加载epoch值最大的ckpt文件,当做预训练模型
if os.listdir(train_url):
    print('> load last ckpt and continue training!!')
    last_ckpt = sorted([file for file in os.listdir(train_url) if file.endswith(".pth")])[-1]
    local_ckpt_file = os.path.join(train_url, last_ckpt)
    print('last_ckpt:', last_ckpt)
    # 加载断点
    checkpoint = torch.load(local_ckpt_file)  
    # 加载模型可学习参数
    model.load_state_dict(checkpoint['net'])  
    # 加载优化器参数
    optimizer.load_state_dict(checkpoint['optimizer'])  
    # 获取保存的epoch,模型会在此epoch的基础上继续训练
    start_epoch = checkpoint['epoch']  
start = datetime.now()
total_step = len(train_loader)
for epoch in range(start_epoch + 1, args.epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ...

    # 保存模型训练过程中的网络权重、优化器权重、以及epoch
    checkpoint = {
          "net": model.state_dict(),
          "optimizer": optimizer.state_dict(),
          "epoch": epoch
        }
    if not os.path.isdir(train_url):
        os.makedirs(train_url)
        torch.save(checkpoint, os.path.join(train_url, 'ckpt_best_{}.pth'.format(epoch)))

MindSpore版reload ckpt

import os
import argparse
parser.add_argument("--train_url", type=str)
args = parser.parse_known_args()
# train_url 将被赋值为"/home/ma-user/modelarts/outputs/train_url_0" 
train_url = args.train_url

# 初始定义的网络、损失函数及优化器
net = resnet50(args_opt.batch_size, args_opt.num_classes)
ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
# 首次训练的epoch初始值,mindspore1.3及以后版本会支持定义epoch_size初始值
# cur_epoch_num = 0
# 判断输出obs路径中是否有模型文件。若无文件则默认从头训练,若有模型文件,则加载epoch值最大的ckpt文件,当做预训练模型
if os.listdir(train_url):
    last_ckpt = sorted([file for file in os.listdir(train_url) if file.endswith(".ckpt")])[-1]
    print('last_ckpt:', last_ckpt)
    last_ckpt_file = os.path.join(train_url, last_ckpt)
     # 加载断点
    param_dict = load_checkpoint(last_ckpt_file)
    print('> load last ckpt and continue training!!')
    # 加载模型参数到net
    load_param_into_net(net, param_dict)
    # 加载模型参数到opt
    load_param_into_net(opt, param_dict)

    # 获取保存的epoch值,模型会在此epoch的基础上继续训练,此参数在mindspore1.3及以后版本会支持
    # if param_dict.get("epoch_num"):
    #     cur_epoch_num = int(param_dict["epoch_num"].data.asnumpy())
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
# as for train, users could use model.train
if args_opt.do_train:
    dataset = create_dataset()
    batch_num = dataset.get_dataset_size()
    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num,
                                     keep_checkpoint_max=35)
    # append_info=[{"epoch_num": cur_epoch_num}],mindspore1.3及以后版本会支持append_info参数,保存当前时刻的epoch值
    ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10",
                                     directory=args_opt.train_url,
                                     config=config_ck)
    loss_cb = LossMonitor()
    model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
    # model.train(epoch_size-cur_epoch_num, dataset, callbacks=[ckpoint_cb, loss_cb]),mindspore1.3及以后版本支持从断点恢复训练

父主题: 训练进阶

同意关联代理商云淘科技,购买华为云产品更优惠(QQ 78315851)

内容没看懂? 不太想学习?想快速解决? 有偿解决: 联系专家