5.5 开源模型复用#

在前面两节内容中我们陆续介绍了在PyTorch框架中模型保存和迁移的基本原理,在接下来的这节内容中我们将以ResNet18在ImageNet上训练得到的1000分类预训练模型为例,将其迁移到CIFAR10数据集上进行微调。总体上来讲,我们首先需要实例化一个ResNet模型,然后再用预训练模型对其初始化;然后再将原始ResNet中最后一个1000分类的分类层改为CIFAR10数据对应10分类层;最后在CIFAR10数据集上完成整个模型的微调。以下完整示例代码可以参见Code/Chapter05/C06_PretrainedModel/文件。

5.5.1 ResNet结构介绍#

为了方便使用PyTorch官方开源的预训练模型,下面我们直接使用PyTorch框架中ResNet模型。同时,为了便于后续模型迁移理解,这里先简单介绍一下PyTorch中ResNet实现部分的代码。在PyTorch框架中,可以通过如下2代码来实例化一个残差网络,以ResNet18为例,示例代码如下:

1 from torchvision.models import resnet18
2 model = resnet18()

其中函数resnet18的实现过程为:

1 def resnet18(*, weights = None, ):
2     weights = ResNet18_Weights.verify(weights)
3     return _resnet(BasicBlock, [2, 2, 2, 2], weights,...)

在上述代码中,第2行是验证传入的预训练模型是否合法。第3行则是根据残差结构的数量返回ResNet18模型。

进一步,_restnet函数中的核心部分为:

1 model = ResNet(block, layers, **kwargs)

在上述代码中返回的便是一个残差网络的实例化对象,而类ResNet中的网络结构定义过程为:

 1 class ResNet(nn.Module):
 2     def __init__(self,......):
 3         super().__init__()
 4         ......
 5         self.layer1 = self._make_layer(block, 64, layers[0])
 6         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
 7         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
 8         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
 9         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
10         self.fc = nn.Linear(512 * block.expansion, num_classes)

在上述代码中,第5~9行便是相应的残差结构和全局平均池化层。第10行则是对应最后的分类层,而有序将ResNet18迁移到CIFAR10数据集上需要修改的便是最后一个分类层。

5.5.2 迁移模型构造#

在清楚PyTorch中ResNet模型的基本实现结构之后,我们便可以对其进行修改以适应CIFAR10数据集。示例代码如下所示:

 1 from torchvision.models import resnet18
 2 from torchvision.models import ResNet18_Weights
 3 
 4 class ResNet18(nn.Module):
 5     def __init__(self, num_classes=10, frozen=False):
 6         super(ResNet18, self).__init__()
 7         self.resnet18 = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
 8         if frozen:
 9             for (name, param) in self.resnet18.named_parameters():
10                 param.requires_grad = False
11                 logging.info(f"冻结参数: {name}, {param.shape}")
12         self.resnet18.fc = nn.Linear(512, num_classes)

在上述代码中,第7行是返回一个实例化的18层残差网络,同时指定了需要通过预训练模型来对其进行初始化。第8~11行则是用来判断是否需要对预训练部分的参数进行冻结,即不参与后续模型的训练过程,当然也可根据需要修改为对其中一部分参数进行冻结。第12行则是将原始残差网络的最后一层替换为符合新数据集的分类层。

进一步,其对应的前向传播实现过程为:

1     def forward(self, x, labels=None):
2         logits = self.resnet18(x)
3         if labels is not None:
4             loss_fct = nn.CrossEntropyLoss(reduction='mean')
5             loss = loss_fct(logits, labels)
6             return loss, logits
7         else:
8             return logits

然后,可通过如下方式打印网络结构信息:

1 if __name__ == '__main__':
2     model = ResNet18(frozen=True)
3     x = torch.rand(1, 3, 96, 96)
4     out = model(x)
5     print(out)
6     for (name, param) in model.named_parameters():
7         print(f"name = {name,param.shape} requires_grad = {param.requires_grad}")

在上述代码中,第2行用于实例化一个残差网络并且冻结相关的预训练参数。第5行则是输出前向传播最后的结果。第6~7行是查看模型中的权重参数是否被冻结。

 1 冻结参数: conv1.weight, torch.Size([64, 3, 7, 7])
 2 冻结参数: bn1.weight, torch.Size([64])
 3 冻结参数: bn1.bias, torch.Size([64])
 4 ......
 5 冻结参数: layer4.1.bn2.weight, torch.Size([512])
 6 冻结参数: layer4.1.bn2.bias, torch.Size([512])
 7 冻结参数: fc.weight, torch.Size([1000, 512])
 8 冻结参数: fc.bias, torch.Size([1000])
 9 tensor([[-1.3807, -0.2270,  0.4926,  0.6058,  1.0789,  0.0495,  1.0578,  0.4514,
10           0.2397, -0.2712]], grad_fn=<AddmmBackward0>)
11 ......
12 name = ('resnet18.layer4.1.bn2.weight', torch.Size([512])) requires_grad = False
13 name = ('resnet18.layer4.1.bn2.bias', torch.Size([512])) requires_grad = False
14 name = ('resnet18.fc.weight', torch.Size([10, 512])) requires_grad = True
15 name = ('resnet18.fc.bias', torch.Size([10])) requires_grad = True

在上述输出结果中,第1~8行为原始ResNet18的参数信息,且均已经被冻结。第9~15行为迁移后残差网络的相关输出信息,其中第10~11行是前向传播输出结果,第11~15行则是各层权重的名称、形状及是否被冻结等信息,从这里可以看出除了最后两层之外其余层的参数均不参与训练,且最后一个分类层也已经变成了10分类。

到此,对于迁移模型的网络结构实现就介绍完了,整个网络训练代码与 4.9节中相同这里就不再赘述,各位读者直接阅读代码即可。

5.5.3 结果对比#

在完成模型的训练过程后,我们可以将原始ResNet18模型、迁移冻结后的ResNet18模型以及进行微调的ResNet18模型这三者在CIFAR10上的结果进行一个简单的对比,如表5-1所示。

表 5-1 模型分类准确率对比

从表5-1中的结果可以看出,如果整个网络模型的权重都随机初始化,那么虽然第1轮迭代结束后它在测试集上的准确率最差,但是随后却超越了冻结整个预训练参数只有分类层参与训练的模型。同时,在这3种情况中,将预训练模型一同进行微调时的效果最好,经过50轮迭代之后在测试集上的准确率达到了$90\%$以上。

5.5.4 小结#

在本节内容中,我们首先介绍了PyTorch框架中ResNet残差网络的基本实现逻辑;然后详细介绍了如何基于预训练模型来完成ResNet18的迁移任务并对相关输出结果进行了分析;最后,对比了3种不同初始化方法或训练策略的残差模型在CIFAR10数据集上的分类准确率。

引用#

[1] Paszke A, Gross S, Massa F, et al. Pytorch: An imperative style, high-performance deep learning library[J]. Advances in neural information processing systems, 2019, 32.