10.12 BERT从零训练#

在前面几节内容中我们已经介绍了几种常见的基于BERT预训练模型的下游任务,在接下来的这节内容中我们将会介绍如何从零实现整个NSP和MLM任务并从头训练BERT模型。通常来说,我们既可以通过MLM和NSP这两个任务来从头训练一个BERT模型,当然也可以在开源预训练模型的基础上再次通过MLM和NSP任务来在特定语料中对模型进行微调,以使得整个模型参数更加符合这一场景,并且一般来说更加倾向于第2种做法。

10.6节内容中我们已经就MLM和NSP这两个任务的原理做了详细介绍,所以这里就不再赘述。一言以蔽之,MLM就是随机掩盖掉部分字符让模型来预测,而NSP则是同时输入模型两句话让模型判断后一句话是否真的为前一句话的下一句话,最终通过这两个任务来训练BERT模型中的权重参数。

10.12.1 构建流程与格式化#

1. 数据集构建流程

在正式介绍数据预处理之前,我们依旧先通过一张流程图来了解一下整个数据集的构建流程。

图 10-38 MLM和NSP数据集构建流程图

如图10-38所示便是整个NSP和MLM任务数据集的构建流程。第①②步是根据原始语料来构造NSP任务所需要的输入和标签。第③步则是随机掩盖掉部分字符来构造MLM任务的输入,并同时进行填充处理。第④步则是根据第③步处理后的结果来构造MLM任务的标签值,其中[P]表示填充的含 义,这样做的目的是为了方面在计算损失时直接忽略那些不需要进行预测的位置。在大致清楚了整个数据集的构建流程后,我们下面就可以一步一步来完成数据集的构建。

同时,为了能够使得整个数据预处理代码具有通用性,同时支持构造不同场景语料下的训练数据集,因此我们需要为每一类不同的数据源定义一个格式化函数来完成标准化的输入。这样即使是换了不同的语料只需要重写一个针对该数据集的格式化函数即可,其余部分的代码都不需要进行改动。

2. 英文数据格式化

这里首先以英文维基百科数据wiki2 [1]为例来介绍如何得到格式化后的标准数据。如下所示便是wiki2中的原始文本数据的存储形式:

1 The development of [UNK] powder , based on [UNK] or [UNK] , by the French inventor Paul [UNK] in 1884 was a     further step allowing smaller charges of propellant with longer barrels . The guns of the pre @-@ [UNK]         battleships of the 1890s tended to be smaller in calibre...
2 The nature of the projectiles also changed during the ironclad period . Initially , the best armor @-@           piercing [UNK] was a solid cast @-@ iron shot . Later , shot of [UNK] iron , a harder iron alloy , gave better   armor @-@ piercing qualities . Eventually the armor @-@ piercing shell was developed .

在上述示例数据中,每一行都表示一个段落,其由一句话或多句话组成。此时我们需要在目录utils下新建create_pretraining_data.py模块,然后定义一个函数来对其进行预处理:

 1 def read_wiki2(filepath=None, seps='.'):
 2     with open(filepath, 'r') as f:
 3         lines = f.readlines() 
 4     paragraphs = []
 5     for line in tqdm(lines, ncols=80, desc=" ## 正在读取原始数据"):
 6         if len(line.split(' . ')) < 2:
 7             continue
 8         line = line.strip()
 9         paragraphs.append([line[0]])
10         for w in line[1:]:
11             if paragraphs[-1][-1][-1] in seps:
12                 paragraphs[-1].append(w)
13             else:
14                 paragraphs[-1][-1] += w
15     random.shuffle(paragraphs)  # 将所有段落打乱
16     return paragraphs

在上述代码中,第1行seps用于指定句子与句子之间的分隔符。第2~3行用于一次读取所有原始数据,每一行为一个段落。第5~14行用于遍历每一个段落, 并进行相应的处理。第6~7行用于过滤掉段落中只有一个句子的情况,因为后续我们要构造NSP任务所以一个段落至少要有两句话。第8行用于去掉整个段落两端的空格或换行符。第9~14行开始遍历段落中的每一句话并进行分割,同时保留了分隔符在句子中。第15行则是将所有的段落给打乱,注意不是句子。

最终,经过read_wiki2()函数处理后,我们便能得到一个标准的二维列表,格式形如:

1 [ [sentence_a1, sentence_a2, ...], [sentence_b1, sentence_b2,...],...,[] ]

上述格式就是后续代码处理时所接受的标准格式,如果需要引入自己的数据那么只需要处理成这样的格式即可。

3. 中文数据格式化

在介绍完英文数据集的格式化过程后我们再来看一个中文原始数据的格式化过程。如下所示便是我们后续所需要用到的中文宋词数据集:

1 红酥手黄縢酒满城春色宫墙柳东风恶欢情薄一怀愁绪几年离索错错错 春如旧人空瘦泪痕红鲛绡透桃花落闲池阁山盟虽     在锦书难托莫莫莫
2 十年生死两茫茫不思量自难忘千里孤坟无处话凄凉纵使相逢应不识尘满面鬓如霜夜来幽梦忽还乡小轩窗正梳妆相顾无言惟有   泪千行料得年年断 肠处明月夜短松冈

在上述示例中,每一行表示一首词,句与句之间通过句号进行分割。下面我们同样需要定义一个函数来对其进行预处理并返回指定的标准格式:

 1 def read_songci(filepath=None, seps='。'):
 2     with open(filepath, 'r', encoding='utf-8') as f:
 3         lines = f.readlines()  
 4     paragraphs = []
 5     for line in tqdm(lines, ncols=80, desc=" ## 正在读取原始数据"):
 6         if "□" in line or "……" in line or len(line.split('。')) < 2:
 7             continue
 8         paragraphs.append([line[0]])
 9         line = line.strip()
10         for w in line[1:]:
11             if paragraphs[-1][-1][-1] in seps:
12                 paragraphs[-1].append(w)
13             else:
14                 paragraphs[-1][-1] += w
15     random.shuffle(paragraphs) 
16     return paragraphs

上述代码中整体上与read_wiki2()函数一致,所以就不再赘述。

10.12.2 数据预处理#

在正式构造NSP任务数据之前,我们需要在create_pretraining_data.py模块中先定义一个类并定义相关的类成员变量以方便在其它成员方法中使用,核心示例代码如下所示:

 1 class LoadBertPretrainingDataset(object):
 2     def __init__(self, vocab_path='./vocab.txt', tokenizer=None,
 3         batch_size=32, max_sen_len=None, max_position_embeddings=512, 
 4         pad_index=0, is_sample_shuffle=True, random_state=2021,
 5         data_name='wiki2', masked_rate=0.15,seps="。",
 6         masked_token_rate=0.8, masked_token_unchanged_rate=0.5):
 7         self.vocab = build_vocab(vocab_path)
 8         self.max_sen_len = max_sen_len
 9         self.pad_index = pad_index
10         self.data_name = data_name
11         self.masked_rate = masked_rate
12         self.masked_token_rate = masked_token_rate
13         random.seed(random_state)

52