5.5 开源模型复用#
在前面两节内容中我们陆续介绍了在PyTorch框架中模型保存和迁移的基本原理,在接下来的这节内容中我们将以ResNet18在ImageNet上训练得到的1000分类预训练模型为例,将其迁移到CIFAR10数据集上进行微调。总体上来讲,我们首先需要实例化一个ResNet模型,然后再用预训练模型对其初始化;然后再将原始ResNet中最后一个1000分类的分类层改为CIFAR10数据对应10分类层;最后在CIFAR10数据集上完成整个模型的微调。以下完整示例代码可以参见Code/Chapter05/C06_PretrainedModel/文件。
5.5.1 ResNet结构介绍#
为了方便使用PyTorch官方开源的预训练模型,下面我们直接使用PyTorch框架中ResNet模型。同时,为了便于后续模型迁移理解,这里先简单介绍一下PyTorch中ResNet实现部分的代码。在PyTorch框架中,可以通过如下2代码来实例化一个残差网络,以ResNet18为例,示例代码如下:
1 from torchvision.models import resnet18
2 model = resnet18()其中函数resnet18的实现过程为:
1 def resnet18(*, weights = None, ):
2 weights = ResNet18_Weights.verify(weights)
3 return _resnet(BasicBlock, [2, 2, 2, 2], weights,...)在上述代码中,第2行是验证传入的预训练模型是否合法。第3行则是根据残差结构的数量返回ResNet18模型。
进一步,_restnet函数中的核心部分为:
1 model = ResNet(block, layers, **kwargs)在上述代码中返回的便是一个残差网络的实例化对象,而类ResNet中的网络结构定义过程为:
1 class ResNet(nn.Module):
2 def __init__(self,......):
3 super().__init__()
4 ......
5 self.layer1 = self._make_layer(block, 64, layers[0])
6 self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
7 self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
8 self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
9 self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
10 self.fc = nn.Linear(512 * block.expansion, num_classes)在上述代码中,第5~9行便是相应的残差结构和全局平均池化层。第10行则是对应最后的分类层,而有序将ResNet18迁移到CIFAR10数据集上需要修改的便是最后一个分类层。
5.5.2 迁移模型构造#
在清楚PyTorch中ResNet模型的基本实现结构之后,我们便可以对其进行修改以适应CIFAR10数据集。示例代码如下所示:
1 from torchvision.models import resnet18
2 from torchvision.models import ResNet18_Weights
3
4 class ResNet18(nn.Module):
5 def __init__(self, num_classes=10, frozen=False):
6 super(ResNet18, self).__init__()
7 self.resnet18 = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
8 if frozen:
9 for (name, param) in self.resnet18.named_parameters():
10 param.requires_grad = False
11 logging.info(f"冻结参数: {name}, {param.shape}")
12 self.resnet18.fc = nn.Linear(512, num_classes)在上述代码中,第7行是返回一个实例化的18层残差网络,同时指定了需要通过预训练模型来对其进行初始化。第8~11行则是用来判断是否需要对预训练部分的参数进行冻结,即不参与后续模型的训练过程,当然也可根据需要修改为对其中一部分参数进行冻结。第12行则是将原始残差网络的最后一层替换为符合新数据集的分类层。
进一步,其对应的前向传播实现过程为:
1 def forward(self, x, labels=None):
2 logits = self.resnet18(x)
3 if labels is not None:
4 loss_fct = nn.CrossEntropyLoss(reduction='mean')
5 loss = loss_fct(logits, labels)
6 return loss, logits
7 else:
8 return logits然后,可通过如下方式打印网络结构信息:
1 if __name__ == '__main__':
2 model = ResNet18(frozen=True)
3 x = torch.rand(1, 3, 96, 96)
4 out = model(x)
5 print(out)
6 for (name, param) in model.named_parameters():
7 print(f"name = {name,param.shape} requires_grad = {param.requires_grad}")在上述代码中,第2行用于实例化一个残差网络并且冻结相关的预训练参数。第5行则是输出前向传播最后的结果。第6~7行是查看模型中的权重参数是否被冻结。
1 冻结参数: conv1.weight, torch.Size([64, 3, 7, 7])
2 冻结参数: bn1.weight, torch.Size([64])
3 冻结参数: bn1.bias, torch.Size([64])
4 ......
5 冻结参数: layer4.1.bn2.weight, torch.Size([512])
6 冻结参数: layer4.1.bn2.bias, torch.Size([512])
7 冻结参数: fc.weight, torch.Size([1000, 512])
8 冻结参数: fc.bias, torch.Size([1000])
9 tensor([[-1.3807, -0.2270, 0.4926, 0.6058, 1.0789, 0.0495, 1.0578, 0.4514,
10 0.2397, -0.2712]], grad_fn=<AddmmBackward0>)
11 ......
12 name = ('resnet18.layer4.1.bn2.weight', torch.Size([512])) requires_grad = False
13 name = ('resnet18.layer4.1.bn2.bias', torch.Size([512])) requires_grad = False
14 name = ('resnet18.fc.weight', torch.Size([10, 512])) requires_grad = True
15 name = ('resnet18.fc.bias', torch.Size([10])) requires_grad = True在上述输出结果中,第1~8行为原始ResNet18的参数信息,且均已经被冻结。第9~15行为迁移后残差网络的相关输出信息,其中第10~11行是前向传播输出结果,第11~15行则是各层权重的名称、形状及是否被冻结等信息,从这里可以看出除了最后两层之外其余层的参数均不参与训练,且最后一个分类层也已经变成了10分类。
到此,对于迁移模型的网络结构实现就介绍完了,整个网络训练代码与 4.9节中相同这里就不再赘述,各位读者直接阅读代码即可。
5.5.3 结果对比#
在完成模型的训练过程后,我们可以将原始ResNet18模型、迁移冻结后的ResNet18模型以及进行微调的ResNet18模型这三者在CIFAR10上的结果进行一个简单的对比,如表5-1所示。

从表5-1中的结果可以看出,如果整个网络模型的权重都随机初始化,那么虽然第1轮迭代结束后它在测试集上的准确率最差,但是随后却超越了冻结整个预训练参数只有分类层参与训练的模型。同时,在这3种情况中,将预训练模型一同进行微调时的效果最好,经过50轮迭代之后在测试集上的准确率达到了$90\%$以上。
5.5.4 小结#
在本节内容中,我们首先介绍了PyTorch框架中ResNet残差网络的基本实现逻辑;然后详细介绍了如何基于预训练模型来完成ResNet18的迁移任务并对相关输出结果进行了分析;最后,对比了3种不同初始化方法或训练策略的残差模型在CIFAR10数据集上的分类准确率。
引用#
[1] Paszke A, Gross S, Massa F, et al. Pytorch: An imperative style, high-performance deep learning library[J]. Advances in neural information processing systems, 2019, 32.