5.3 模型的保存与复用#

在深度学习中通常训练一个可用的模型都需要耗费极大的成本,因此在模型训练过程中就需要对满足某些条件下的网络权重参数进行保存,然后在实际推理过程中直接载入这些权重参数来完成模型的推理过程。同时,另外一种场景便是模型已经在一批数据上训练完成且完成了本地持久化保存,但可能过了一段时间后又收集到了一批新的数据,因此这时候就需要将之前的模型载入进行在新数据上进行增量训练或者是在整个数据上进行全量训练。

在PyTorch中可以通过torch.save()torch.load()来完成上述场景中的主要步骤。下面,我们将以之前介绍的LeNet5网络模型为例来分别进行介绍。不过在这之前,我们先来看看PyTorch中模型参数的保存形式。

5.3.1 查看模型参数#

这里,我们依旧以4.4节内容中介绍的LeNet5网络模型为例进行讲解分析。在定义完LeNet5网络模型并完成实例化操作后,那么网络中对应的权重参数也都完成了初始化的工作,即有了一个初始值。同时,可以通过如下方式来访问:

1 import sys
2 sys.path.append("../../")
3 from Chapter04.C03_LeNet5.LeNet5 import LeNet5
4 if __name__ == '__main__':
5     model = LeNet5()
6     print("Model's state_dict:")
7     for (name, param) in model.state_dict().items():
8         print(name, param.size())

在上述代码中,第1~2行是将Chapter04这个搜索路径加入到系统路径中,否则第3行会提示” No module named ‘Chapter04’ “。第5行是实例化模型LeNet5,即初始化整个模型。第7~8行则是遍历模型中的每个参数。同时,需要注意的是通过model.state_dict()函数返回得到的是一个Python中的有序字段(OrderedDict),即是遍历输出的顺序结果就是元素插入字典时的顺序,例如这里插入的网络层。

上述代码运行结束后,其输出的结果为:

 1 Model's state_dict:
 2 conv.0.weight torch.Size([6, 1, 5, 5])
 3 conv.0.bias torch.Size([6])
 4 conv.3.weight torch.Size([16, 6, 5, 5])
 5 conv.3.bias torch.Size([16])
 6 fc.1.weight torch.Size([120, 400])
 7 fc.1.bias torch.Size([120])
 8 fc.3.weight torch.Size([84, 120])
 9 fc.3.bias torch.Size([84])
10 fc.5.weight torch.Size([10, 84])
11 fc.5.bias torch.Size([10])

在上述输出结果中,每一行的前半部分表示参数的名称,如conv.0.weight,后面部分表示该权重参数对应的形状。同时从输出结果可以看出,模型一共有5层权重参数,即conv.0conv.3fc.1fc.3fc.5

5.3.2 自定义参数前缀#

在上面的输出结果中有两个地方值得注意:①参数名中的fcconv前缀是根据定义LeNet5模型时nn.Sequential()时的名字所确定,即在第4.4.3节中定模型时使用了两个Sequential()实例对象,名称分别为convfc;②参数名中的数字表示每个Sequential()中网络层所在的位置。例如,如果将LeNet5网络结构定义成如下形式:

 1 class LeNet5(nn.Module):
 2     def __init__(self, ):
 3         super(LeNet5, self).__init__()
 4         self.lenet5 = nn.Sequential(  
 5             nn.Conv2d(1, 6, 5, padding=2),  
 6             nn.ReLU(), 
 7             nn.MaxPool2d(2, 2),  
 8             nn.Conv2d(6, 16, 5), 
 9             nn.ReLU(),
10             nn.MaxPool2d(2, 2),
11             nn.Flatten(),
12             nn.Linear(16 * 5 * 5, 120),
13             nn.ReLU(),
14             nn.Linear(120, 84),
15             nn.ReLU(),
16             nn.Linear(84, 10))

那么其参数名则为:

1 print(model.state_dict().keys())
2 odict_keys(['lenet5.0.weight', 'lenet5.0.bias', 'lenet5.3.weight', 'lenet5.3.bias', 'lenet5.7.weight', 'lenet5.7.bias', 'lenet5.9.weight', 'lenet5.9.bias', 'lenet5.11.weight', 'lenet5.11.bias'])

可以看出,参数名最前面的部分就是Sequential()对象的名字,理解了这一点对于后续我们去解析和载入一些预训练模型很有帮助。

除此之外,对于PyTorch中的优化器等,其同样有对应的state_dict()方法来获取相关参数信息,例如:

1 optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
2 print(optimizer.state_dict())
3 {'state': {}, 'param_groups': [{'initial_lr': 0.01, 'lr': 0.0, 
4 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False,
5 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 
6 'fused': False,'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]}

在介绍完模型参数的查看方法后,便可以进入到模型复用阶段的内容了。上述完整示例代码可参见Code/Chapter05/C04_ModelSaving/E01_CheckParams.py文件。

5.3.3 保存训练模型#

在PyTorch中对于模型的保存来说非常容易,通常来说通过如下两行代码便可以实现:

1 model_save_path = os.path.join(model_save_dir, 'model.pt')
2 torch.save(model.state_dict(), model_save_path)

在指定保存的模型名称时PyTorch官方建议的后缀为.pt或者.pth(当然也不强制)。最后,只需要在合适的地方加入第2行代码即可完成模型的保存。

同时,如果想要在训练过程中保存某个条件下的最优模型,那么应该通过如下方式:

1 from copy import deepcopy 
2 best_model_state = deepcopy(model.state_dict()) 
3 torch.save(best_model_state, model_save_path)

而不是:

1 best_model_state = model.state_dict() 
2 torch.save(best_model_state, model_save_path)

因为后者best_model_state得到只是model.state_dict()的引用,它依旧会随着训练过程而发生改变。

5.3.4 复用模型推理#

在推理复用模型的过程中,首先需要完成网络的初始化工作,然后再载入已有的模型参数来覆盖网络中的权重参数即可,示例代码如下所示:

 1 def inference(config, test_iter):
 2     test_data = test_iter.dataset
 3     model = LeNet5()
 4 	   model.eval()
 5     if os.path.exists(config.model_save_path):
 6         checkpoint = torch.load(config.model_save_path)
 7         model.load_state_dict(checkpoint)
 8     else:
 9         raise ValueError(f"模型{config.model_save_path}不存在!")
10     y_true = test_data.targets[:5]
11     imgs = test_data.data[:5].unsqueeze(1).to(torch.float32)
12     with torch.no_grad():
13         logits = model(imgs)
14     y_pred = logits.argmax(1)
15     print(f"真实标签为:{y_true}")
16     print(f"预测标签为:{y_pred}")

在上述代码中,第1行传入的是模型配置参数和测试文件,即并没有像之前一样将模型也作为参数传递进来。第3行用于实例化得到一个模型,此时模型中的权重参数都是随机初始化的。第4行是将模型的状态切换至推理状态。第5~7行则是分别先校验本地指定路径中是否已经存在模型文件,如果存在则载入并用其重新初始化网络模型。第10~16行介绍见4.4.3节内容。

5.3.5 复用模型训练#

在介绍完模型的保存与复用之后,模型的追加训练过程就很简单了。在网络训练之前,只需要按照5.3.4节中的方法重新初始化网络权重参数,然后按照正常的步骤训练模型即可。

52