2. PyTorch: Saving and Loading Models
2.1. The Simple Way
import torch
## save
torch.save(model, 'model.pkl')
## load
model = torch.load('model.pkl')
A model saved this way contains both the model structure and all of its parameters, so the resulting pkl file is fairly large.
2.2. The Detailed Way
Besides the model's own structure and parameters, a checkpoint should also record training information, such as the number of training epochs completed and the optimizer's state.
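For comparison, saving only the state_dict() yields a smaller file, at the cost of having to rebuild the architecture in code before loading. A minimal sketch (the nn.Linear model is a hypothetical stand-in for any nn.Module):

```python
import io

import torch
import torch.nn as nn

# Hypothetical tiny model; any nn.Module works the same way.
model = nn.Linear(4, 2)

# Save only the parameters: a smaller file, but the architecture must be
# rebuilt in code before the parameters can be restored.
buf = io.BytesIO()
torch.save(model.state_dict(), buf)

# Load: construct the same architecture, then restore its parameters.
buf.seek(0)
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buf))

# The restored model reproduces the original's outputs.
x = torch.randn(3, 4)
assert torch.allclose(model(x), restored(x))
```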
import os
import shutil

import torch

## save
def save_checkpoint(state, is_best, save_path, filename):
    filename = os.path.join(save_path, filename)
    torch.save(state, filename)
    if is_best:
        bestname = os.path.join(save_path, 'model_best.pth.tar')
        shutil.copyfile(filename, bestname)

save_checkpoint({
    'epoch': cur_epoch,
    'state_dict': model.state_dict(),
    'best_prec': best_prec,
    'loss_train': loss_train,
    'optimizer': optimizer.state_dict(),
    }, is_best, save_path, 'epoch-{}_checkpoint.pth.tar'.format(cur_epoch))

## load
def load_checkpoint(load_path, model, optimizer):
    """ loads state into model and optimizer and returns:
        epoch, best_precision, loss_train[]
        e.g., model = alexnet(pretrained=False)
    """
    if os.path.isfile(load_path):
        print("=> loading checkpoint '{}'".format(load_path))
        checkpoint = torch.load(load_path)
        epoch = checkpoint['epoch']
        best_prec = checkpoint['best_prec']
        loss_train = checkpoint['loss_train']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(load_path, epoch))
        return epoch, best_prec, loss_train
    else:
        print("=> no checkpoint found at '{}'".format(load_path))
        # epoch, best_precision, loss_train
        return 1, 0, []
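A minimal round-trip following the same pattern might look like this; the model, optimizer, and metric values are toy stand-ins for a real training setup:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy model and optimizer standing in for a real training setup.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

save_path = tempfile.mkdtemp()
filename = os.path.join(save_path, 'epoch-3_checkpoint.pth.tar')

# Save the full training state in one dict.
torch.save({
    'epoch': 3,
    'state_dict': model.state_dict(),
    'best_prec': 0.75,
    'loss_train': [1.2, 0.8, 0.5],
    'optimizer': optimizer.state_dict(),
}, filename)

# Restore into freshly constructed objects of the same architecture.
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1)
checkpoint = torch.load(filename)
model2.load_state_dict(checkpoint['state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer'])

# The restored model behaves identically to the original.
x = torch.randn(2, 4)
assert checkpoint['epoch'] == 3
assert torch.allclose(model(x), model2(x))
```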
2.3. Loading Part of the Parameters
When we only need to load a subset of the parameters from a state_dict(), we can proceed as follows:
# args has the model name, num classes and other irrelevant stuff
>>> pretrained_state = model_zoo.load_url(model_names[args.arch])
>>> model_state = my_model.state_dict()
>>> pretrained_state = { k: v for k, v in pretrained_state.items() if k in model_state and v.size() == model_state[k].size() }
>>> model_state.update(pretrained_state)
>>> my_model.load_state_dict(model_state)
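The same filtering step can be demonstrated end to end with two hypothetical models that share only their first layer's shape; only parameters whose name and shape both match are carried over:

```python
import torch
import torch.nn as nn

# Hypothetical source and target models: same first layer,
# different output layer (10 classes vs. 3 classes).
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10))
my_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

pretrained_state = pretrained.state_dict()
model_state = my_model.state_dict()

# Keep only entries whose key exists in the target and whose shape matches.
filtered = {k: v for k, v in pretrained_state.items()
            if k in model_state and v.size() == model_state[k].size()}
model_state.update(filtered)
my_model.load_state_dict(model_state)

print(sorted(filtered))  # → ['0.bias', '0.weight']
```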
Note
The parameters in a state_dict() carry device information: if torch.save stored the state of a model that lived on the GPU, its parameters are GPU tensors, and torch.load will by default put them back on the GPU. To avoid running out of GPU memory, you can use torch.load(checkpoint, map_location='cpu') to map the parameters onto the CPU first, and then call load_state_dict.
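A sketch of the map_location usage; since a GPU may not be available here, the state is saved from CPU tensors, but the map_location='cpu' call is exactly the one the note describes:

```python
import io

import torch
import torch.nn as nn

# Stand-in model; imagine its state was saved from a GPU.
model = nn.Linear(4, 2)

buf = io.BytesIO()
torch.save(model.state_dict(), buf)
buf.seek(0)

# Force every tensor onto the CPU, regardless of where it was saved.
state = torch.load(buf, map_location='cpu')
assert all(v.device.type == 'cpu' for v in state.values())

# Restore on CPU; move to a device afterwards with .to(device) if needed.
cpu_model = nn.Linear(4, 2)
cpu_model.load_state_dict(state)
```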
2.4. References
Saving and loading a model in Pytorch?
How to load part of pre trained model?
Serialization