更新于 2026年6月29日

10.7 从零实现BERT#

经过「第10.6节 BERT原理:双向编码器预训练模型解析」内容的介绍,我们对于BERT模型的整体结构已经有了一定的了解。根据图10-26可知,从本质上来说BERT就是由Transformer中的编码器构建而来,同时在输入层部分额外加入了一个句子编码来区分输入的不同部分。在本节内容中,我们将以图10-26中黑色加粗字体所示的部分为一个类来分别构建实现整个BERT模型。

10.7.1 工程结构#

由于整个项目涉及到的代码模块较多,所以我们在这里先进行简单说明, 这样便于各位读者在阅读后续内容时能够快速地定位到相应的代码部分。同时, 这里也建议各位读者在阅读内容的能够同时结合代码一起阅读并动手实践。

在整个工程项目中一共包含有6个主要的文件目录,cache、data、model、Tasks、test和uils。其中cache目录用来存放训练过程中所保存下来的模型;data用来存放各类数据集,包括后续将要用到的文本分类、问题回答和训练预料等;model目录中存放的是整个BERT模型的实现代码,以及相关下游任务的构造模型;Tasks目录中存放的是model中各个任务对应的模型训练代码;test目录中存放的是各个模块的测试案例,用于在实现过程中的验证,后续内容中的相关使用示例都能在其中找到;utils目录中存放的是一些辅助工具模块,包括数据集构建和日志模块等。

10.7.2 Input Embedding实现#

首先,我们先来看Input Embedding的实现过程。由于在「第10.4节 Transformer实现:从模块到代码的搭建过程」内容中我们已经介绍过了字符嵌入的实现,所以在复用这部分代码之后只需要再实现位置编码和句子编码即可。本节内容及后续多个下游任务的完整示例代码可参见Code/Chapter10/C04_BERT文件。

1. Positional Embedding实现

不同于Transformer中位置编码的实现方式,在BERT中位置编码并没有采用固定的变换公式来计算每个位置上的值,而是采用了普通嵌入层的方式来为每个位置生成一个向量然后随着模型一起训练。因此,这也就限制了在使用预训练的中文BERT模型时最大的序列长度只能是512,因为在训练时只初始化了512个位 置向量。示例代码如下所示:

1 class PositionalEmbedding(nn.Module):
2     def __init__(self, hidden_size, max_position_embeddings=512, 
3                  initializer_range=0.02):
4         super(PositionalEmbedding, self).__init__()
5         self.embedding = nn.Embedding(max_position_embeddings, hidden_size)
6 
7     def forward(self, position_ids):
8         return self.embedding(position_ids).transpose(0, 1)

在上述代码中,第7行position_ids的形状为[1,position_ids_len]。第8行返回结果的形状为[position_ids_len, 1, hidden_size]

2. Segment Embedding 实现

句子编码是对输入的两个序列分别赋予一个位置向量用以区分各自所在的位置,这一点可以和上面的位置编码进行类比。具体地,示例代码如下所示:

1 class SegmentEmbedding(nn.Module):
2     def __init__(self, type_vocab_size, hidden_size, initializer_range=0.02):
3         super(SegmentEmbedding, self).__init__()
4         self.embedding = nn.Embedding(type_vocab_size, hidden_size)
5 
6     def forward(self, token_type_ids):
7         return self.embedding(token_type_ids)

在上述代码中,第2行type_vocab_size的默认值为 2,即只用于区分两个序列的不同位置。第6行token_type_ids的形状为[token_type_ids_len, batch_size]。第7行返回结果的形状为[token_type_ids_len, batch_size, hidden_size]

3. Bert Embedding实现

在完成3个部分的代码实现之后,只需要将每个部分的结果相加便可以得到最终的嵌入层表示作为BERT模型的输入,示例代码如如下所示:

 1 class BertEmbeddings(nn.Module):
 2     def __init__(self, config):
 3         super().__init__()
 4         self.word_embeddings = TokenEmbedding(config.vocab_size,
 5                                       config.hidden_size,config.pad_token_id)
 6         self.position_embeddings = PositionalEmbedding(
 7                            config.max_position_embeddings,config.hidden_size)
 8         self.token_type_embeddings = SegmentEmbedding(config.type_vocab_size,
 9                                  config.hidden_size,config.initializer_range)
10         self.LayerNorm = nn.LayerNorm(config.hidden_size)
11         self.dropout = nn.Dropout(config.hidden_dropout_prob)
12         self.register_buffer("position_ids",
13                 torch.arange(config.max_position_embeddings).expand((1, -1)))

在上述代码中,第2行config是传入的一个配置类,里面各个类成员就是BERT模型中对应的模型参数。第 4~8行是分别用来定义图10-27中的3个编码部分。第12~13行是用来生成一个默认的位置编号,即[0,1,....,511]

进一步,其前向传播过程代码为:

 1     def forward(self,input_ids=None,position_ids=None,token_type_ids=None):
 2         src_len = input_ids.size(0)
 3         token_emb = self.word_embeddings(input_ids)
 4         if position_ids is None: 
 5             position_ids = self.position_ids[:, :src_len]
 6         positional_emb = self.position_embeddings(position_ids)
 7         if token_type_ids is None: 
 8             token_type_ids = torch.zeros_like(input_ids)
 9         segment_emb = self.token_type_embeddings(token_type_ids)
10         embeddings = token_emb + positional_emb + segment_emb
11         embeddings = self.LayerNorm(embeddings)
12         embeddings = self.dropout(embeddings)
13         return embeddings

在上述代码中,第1行input_ids表示输入序列的原始索引编号,即根据词表映射后的索引形状为[src_len, batch_size]。第4~6行position_ids是位置序列, 本质是[0,1,2,3,...,src_len-1]形状为[1,src_len],在实际建模时这个参数可以不用传值,因为当其为空时会自动从self.position_ids截取一段。第7~9行token_type_ids用于不同序列之间的分割,例如[0,0,0,0,1,1,1,1]用于区分前后不同的两个句子形状为[src_len,batch_size]。如果输入模型的只有一个序列,那么这个参数也不用传值。第 10~12行代码则是用来将3部分的编码结果进行相加。

4. 使用示例

在实现完上述代码之后,便可以通过如下方式进行使用:

1 if __name__ == '__main__':
2     json_file = '../bert_base_chinese/config.json'
3     config = BertConfig.from_json_file(json_file)
4     src = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], dtype=torch.long)
5     token_type_ids = torch.LongTensor([[0, 0, 0, 1, 1], [0, 0, 1, 1, 1]])
6     src, token_type_ids = src.transpose(0, 1), token_type_ids.transpose(0, 1)  
7     bert_embedding = BertEmbeddings(config)
8     bert_embedding_result = bert_embedding(src, token_type_ids=token_type_ids)
9     print(bert_embedding_result.shape) #  torch.Size([5, 2, 768])

在上述代码中,第2~3行是载入原始的BERT模型配置文件,里面包含了hidden_sizemax_position_embeddings等默认参数的取值。第4~6行是生成输入层对应的输入部分。第7~9行是实例化BERT嵌入层并计算前向传播的输出结果,形状为[src_len, batch_size, hidden_size]

10.7.3 BERT网络实现#

在实现完Input Embedding部分的代码后,下面可以着手来实现构成BERT模型的第2个重要组成部分BertEncoder。如图10-26所示,整个 BertEncoder由多个BertLayer堆叠形成;而BertLayer又是由BertOutputBertIntermediateBertAttention这3个部分组成;而BertAttention是由BertSelfAttentionBertSelfOutput所构成。之所以需要将整个模型拆分成各个模块进行实现,主要是为了降低功能模块之间的耦合性,以便按需进行调整。

1. BertaAttention实现

对于BertAttention来说其核心是Transformer中所提出的self-attention机制,即图10-26中的BertSelfAttention模块;其次是一个残差连接和标准化操作。对于BertSelfAttention的实现,示例代码如下所示:

 1 class BertSelfAttention(nn.Module):
 2     def __init__(self, config):
 3         super(BertSelfAttention, self).__init__()
 4         if 'use_torch_multi_head' in config.__dict__ and config.use_torch_multi_head:
 5             MultiHeadAttention = nn.MultiheadAttention
 6         else:
 7             MultiHeadAttention = MyMultiheadAttention
 8         self.multi_head_attention = MultiHeadAttention(config.hidden_size,
 9                 config.num_attention_heads,config.attention_probs_dropout_prob)
10 
11     def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
12         return self.multi_head_attention(query, key, value, 
13             attn_mask=attn_mask, key_padding_mask=key_padding_mask)

在上述代码中,第4~10行是实例化一个多头注意力机制对象,并且这里我们提供了两种多头实现,一种「第10.3节 Transformer结构:编码器解码器整体架构解析」内容中介绍的都头注意力实现,另一种是直接使用PyTorch框架中的默认实现,可以通过设置参数use_torch_multi_head = True进行切换。第12~15行则是多头注意力的前向传播过程,其返回包含两个部分,多头注意力的线性组合以及多头注意力权重的均值,形状分别为[tgt_len, batch_size, hidden_size][batch_size, tgt_len, src_len]

进一步,对于BertSelfOutput的实现包括层 Dropout、标准化和残差连接3个操作,示例代码如下:

 1 class BertSelfOutput(nn.Module):
 2     def __init__(self, config):
 3         super().__init__()
 4         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
 5         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 6 
 7     def forward(self, hidden_states, input_tensor):
 8         hidden_states = self.dropout(hidden_states)
 9         hidden_states = self.LayerNorm(hidden_states + input_tensor)
10         return hidden_states

上述代码便是BertSelfOutput的实现,其过程也十分简单这里就不再赘述,最后第10行返回结果的形状为[src_len, batch_size, hidden_size]

接下来就是对BertAttention部分进行实现,其由BertSelfAttentionBertSelfOutput这两个类构成,示例代码如下所示:

 1 class BertAttention(nn.Module):
 2     def __init__(self, config):
 3         super().__init__()
 4         self.self = BertSelfAttention(config)
 5         self.output = BertSelfOutput(config)
 6 
 7     def forward(self,hidden_states,attention_mask=None):
 8         self_outputs = self.self(hidden_states,hidden_states,hidden_states,
 9                             attn_mask=None,key_padding_mask=attention_mask)
10         attention_output = self.output(self_outputs[0], hidden_states)
11         return attention_output

在上述代码中,第7行hidden_states是输入层处理后的结果,形状为[src_len, batch_size, hidden_size]attention_mask是同一个小批量样本中不同长度序列的掩码填充信息,即在第10.4节内容中所介绍的key_padding_mask,形状为[batch_size, src_len],这里只是为了和PyTorch中的命名方式保持一致。第8~9行是自注意力机制的输出结果。第10~11行便是执行BertSelfOutput中的3个操作,最后返回结果的形状为[src_len, batch_size, hidden_size]

2. BertLayer实现

根据图10-26可知,BertLayer里还有BertOutputBertIntermediate这两个模块,因此下面先来实现这两个部分。对于BertIntermediate来说也是一个普通的全连接层,示例代码如下所示:

 1 class BertIntermediate(nn.Module):
 2     def __init__(self, config):
 3         super().__init__()
 4         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
 5         if isinstance(config.hidden_act, str):
 6             self.intermediate_act_fn = get_activation(config.hidden_act)
 7         else:
 8             self.intermediate_act_fn = config.hidden_act
 9 
10     def forward(self, hidden_states):
11         hidden_states = self.dense(hidden_states)
12         if self.intermediate_act_fn is None:
13             hidden_states = hidden_states
14         else:
15             hidden_states = self.intermediate_act_fn(hidden_states)
16         return hidden_states

在上述代码中,第6行用来根据指定参数获取激活函数。第11~15行则是根据对应的激活函数对输入进行非线性变化。第16行是最后返回的结果,形状为[src_len, batch_size, intermediate_size]

进一步,对于BertOutput来说,其包含有一个全连接层和残差连接, 实现代码如下所示:

 1 class BertOutput(nn.Module):
 2     def __init__(self, config):
 3         super().__init__()
 4         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
 5         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
 6         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 7 
 8     def forward(self, hidden_states, input_tensor):
 9         hidden_states = self.dense(hidden_states)
10         hidden_states = self.dropout(hidden_states)
11         hidden_states = self.LayerNorm(hidden_states + input_tensor)
12         return hidden_states

在上述代码中,第8行里hidden_states指的就是BertIntermediate模块的输出,而input_tensor则是BertAttention部分的输出。

在实现完这两个部分的代码后,便可以通过BertAttentionBertIntermediateBertOutput这3个部分来实现组合的BertLayer部分,示例代码如下所示:

 1 class BertLayer(nn.Module):
 2     def __init__(self, config):
 3         super().__init__()
 4         self.bert_attention = BertAttention(config)
 5         self.bert_intermediate = BertIntermediate(config)
 6         self.bert_output = BertOutput(config)
 7 
 8     def forward(self,hidden_states,attention_mask=None):
 9         attention_output = self.bert_attention(hidden_states, attention_mask)
10         intermediate_output = self.bert_intermediate(attention_output)
11         layer_output = self.bert_output(intermediate_output, attention_output)
12         return layer_output

到此,对于BertLayer部分的实现就介绍完了,下面继续来看如何对BertEncoder进行实现。

3. BertEncoder实现

根据图10-26可知,BERT主要由Input EmbeddingBertEncoder这两部分构成,而BertEncoder是由多个BertLayer堆叠所形成,因此需要先实现BertEncoder,示例代码如下:

 1 class BertEncoder(nn.Module):
 2     def __init__(self, config):
 3         super().__init__()
 4         self.config = config
 5         self.bert_layers = nn.ModuleList([BertLayer(config)
 6                           for _ in range(config.num_hidden_layers)])
 7 
 8     def forward(self,hidden_states,attention_mask=None):
 9         all_encoder_layers = []
10         layer_output = hidden_states
11         for i, layer_module in enumerate(self.bert_layers):
12             layer_output = layer_module(layer_output,attention_mask)
13             all_encoder_layers.append(layer_output)
14         return all_encoder_layers

在上述代码中,第5~6行是用来实例化多个BertLayer层。第11~13行用来循环计算多层BertLayer堆叠后的输出结果,其中每一层的输出结果形状为[src_len, batch_size, hidden_size]。最后,只需要按需将 BertEncoder部分的输出结果输入到下游任务即可。

进一步,在将BertEncoder部分的输出结果输入到下游任务前,需要将其进行略微的处理,示例代码如下所示:

 1 class BertPooler(nn.Module):
 2     def __init__(self, config):
 3         super().__init__()
 4         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
 5         self.activation = nn.Tanh()
 6         self.config = config
 7 
 8     def forward(self, hidden_states):
 9         if self.config.pooler_type == "first_token_transform":
10             token_tensor = hidden_states[0, :].reshape(-1, self.config.hidden_size)
11         elif self.config.pooler_type == "all_token_average":
12             token_tensor = torch.mean(hidden_states, dim=0)
13         pooled_output = self.dense(token_tensor)
14         pooled_output = self.activation(pooled_output)
15         return pooled_output 

在上述代码中,第9~10行代码用来取BertEncoder输出的第一个位置,即[CLS]位置对应的编码向量。例如在进行文本分类时可以取该位置上的结果进行下一步的分类处理。第11~12行是我们自己额外加入的一个选项,表示取所有位置的平均值,当然还也可以根据自己的需要在添加下面添加其它的方式。注意,此时需要在config.json这个配置文件中加入pooler_type这个字段。第13~15行是一个普通 的全连接层,最后输出结果的形状[batch_size, hidden_size]

4. BertModel实现

在对BERT模型中的各个基础模块实现完成后,根据图10-26所示只需要再将各个部分的代码组合到一起便完成了BERT模型的实现,示例代码如下所示:

 1 class BertModel(nn.Module):
 2     def __init__(self, config):
 3         super().__init__()
 4         self.bert_embed = BertEmbeddings(config)
 5         self.bert_encoder = BertEncoder(config)
 6         self.bert_pooler = BertPooler(config)
 7         self.config = config
 8 
 9     def forward(self,input_ids=None,attention_mask=None,
10                 token_type_ids=None,position_ids=None):
11         embed_output = self.bert_embed(input_ids,position_ids,token_type_ids)
12         all_encoder_outputs = self.bert_encoder(embed_output,attention_mask)
13         sequence_output = all_encoder_outputs[-1]
14         pooled_output = self.bert_pooler(sequence_output)
15         return pooled_output, all_encoder_outputs

在上述代码中,第4~6行是分别实例化BERT模型中的各个模块。第9~10行是BERT模型的输入,其中input_ids的形状为[src_len, batch_size]attention_mask的形状为[batch_size, src_len]token_type_ids的形状为[src_len, batch_size]。第11行便是输入层的输出结 果,形状为[src_len, batch_size, hidden_size]。第12行是整个BERT编码部分的输出,其中all_encoder_outputs为一个包含有num_hidden_layers个层的输出。第13行是处理得到整个BERT 网络的输出,这里取了最后一层的输出,形状为[src_len, batch_size, hidden_size]。第14行默认是最后一层的第1个向量,即[CLS]位置经BertPooler层后的结果,其形状为[batch_size, hidden_size]

5. 使用示例

在完成上述整个BERT模型的代码实现后可以通过如下方式来进行使用:

1 if __name__ == '__main__':
2     src = torch.tensor([[1, 3, 5, 7, 9, 2, 3], [2, 4, 6, 8, 10, 0, 0]], dtype=torch.long)
3     token_type_ids = torch.LongTensor([[0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, 0, 0]])
4     attention_mask = torch.tensor([[False, False, False, False, False, True, True],
5                                    [False, False, False, False, False, False, True]])
6     src, token_type_ids = src.transpose(0, 1)token_type_ids.transpose(0, 1)
7     bert_model = BertModel(config)  
8     bert_model_output = bert_model(src,attention_mask,token_type_ids)[0]
9     print(bert_model_output.shape) # torch.Size([2, 768])

在上述代码中,第2~6行是定义模型相关输入,其中attention_mask向量中True表示该位置为填充值。第7行是实例化一个BERT模型。第8~9行是模型的前向传播计算结果,形状为[batch_size, hidden_size]

10.7.4 小结#

在本节内容中,我们首先介绍了整个工程的目录结构;然后分别依次介绍了BERT模型中Input Embedding层各个部分的实现;最后分模块详细介绍了BERT模型中BertaAttentionBertLayerBertEncoderBertModel的实现过程,并进行了使用示例。在下节内容中,我们将介绍第1个基于BERT预训练模型的下游文本任务的构建过程。

您当前阅读的内容现已出版,点击右侧了解

10章教学课件,400余幅示意插图、40个示例源代码,助力读者轻松迈入深度学习的大门!

查看详情
阅读 --

10.4 Transformer实现

在前面两节内容中我们分别详细介绍了Transformer模型的原理与多头注意力机制的实现过程,接下来,我们将会一步一步地来详细介绍如何通过 PyTorch框架实现Transformer的整体网络结构, 包括嵌入层、编码器和解码器等等。下面, …

10.3 Transformer结构

在10.2节内容中我们详细介绍了自注意力机制的动机和原理,在介绍下来的这节内容中我们将继续介绍Transformer的整个网络结构,以及多头注意力机制的实现。