第 5 章 模型训练与复用#
经过前面几个章节内容的介绍我们算是已经逐步迈入了深度学习的大门。所谓工欲善其事必先利其器,因此在接下来的这章内容中,我们将会逐一对深度学习模型训练过程中将会用到的一些辅助技能进行介绍,包括:如何有效对模型参数进行管理、怎么从本地文件中载入参数、如何保证模型训练过程的可追溯、模型的持久化与迁移方法以及模型的多GPU训练和预处理结果缓存等内容。
5.1 参数及日志管理#
在深度学习模型的实现过程中由于我们会频繁调整整个模型的超参数,例如需要突然新增一个丢弃率参数或者是模型的控制参数等等,而且这样的操作经常是跨多个函数或模块,如果依旧采用显示的参数名来传递参数就会变得十分复杂。如果模型参数数量较多,通过参数名来传递参数也会显得代码十分臃肿。
同时,由于在深度学习模型中通常会有较多的超参数,模型在训练过程中也会输出相应的评估结果、损失值甚至是部分权重参数结果等,为了使得整个模型的训练过程可追溯,因此就需要有效地将这些信息给保存下来以便不时之需。
5.1.1 参数传递#
例如对于某个深度学习模型来说,其训练部分的函数实现过程如下所示:
1 def train(train_file_path=os.path.join('data', 'train.txt'),
2 val_file_path=os.path.join('data', 'val.txt'),
3 test_file_path=os.path.join('data', 'test.txt'),
4 split_sep='_!_',
5 is_sample_shuffle=True,
6 batch_size=16,
7 learning_rate=3.5e-5,
8 max_sen_len=None,
9 num_labels=3,
10 epochs=5):
11 dataset = get_dataset(train_file_path, val_file_path,max_sen_len,
12 test_file_path, split_sep, is_sample_shuffle)
13 model = get_model(max_sen_len, num_labels)从上述代码可以看出,第1~10行定义很多需要用到的参数,且在第11~13行中分别将这些参数传入到了对应的函数中。这样看起来似乎没有问题,但是此时你需要再添加一个参数到get_model()这个函数中,例如加入丢弃率来提高模型泛化能力,并且get_model()函数里面也是在不同模块都要用到丢弃率这个参数,如果直接采用新加参数的方式那难免会涉及到诸多地方的修改。
因此,对于模型参数有效管理的一种高效做法就是在所有地方均传入一个实例化的类对象,通过类对象访问类成员变量的方式来获取相应的参数值,这样在增删模型参数时只需要在原始类对象实例化的地方修改一次就能实现。首先,需要定义一个配置类,实现代码如下所示:
1 class ModelConfig(object):
2 def __init__(self,train_file_path=os.path.join('data', 'train.txt'),
3 val_file_path=os.path.join('data', 'val.txt'),
4 test_file_path=os.path.join('data', 'test.txt'),
5 split_sep='_!_',is_sample_shuffle=True,
6 batch_size=16,learning_rate=3.5e-5,
7 max_sen_len=None,num_labels=3,epochs=5):
8 self.train_file_path = train_file_path
9 self.val_file_path = val_file_path
10 self.test_file_path = test_file_path
11 self.split_sep = split_sep
12 self.is_sample_shuffle = is_sample_shuffle
13 self.batch_size = batch_size
14 self.learning_rate = learning_rate
15 self.max_sen_len = max_sen_len
16 self.num_labels = num_labels
17 self.epochs = epochs在上述代码中便定义了模型所需要用到的参数,并且可以通过如下方式来进行访问,示例代码如下所示:
1 if __name__ == '__main__':
2 config = ModelConfig(epochs=10)
3 print(f"epochs = {config.epochs}")
4 ## epochs = 10进一步,对于上面train()函数中的示例可以改写为如下形式:
1 def train(config):
2 dataset = get_dataset(config)
3 model = get_mode(config)通过这样的管理方式,即使后续需要在模型中新增参数则只需要在类ModelConfig中新增一个成员变量即可,然后在需要的地方以config.para_name的方式来进行获取。
5.1.2 参数载入#
在上述示例中,我们介绍到了如何通过定义一个ModelConfig来管理模型参数,但是在一些场景中还需要从本地载入一个模型参数文件进行使用。例如在后面介绍BERT模型时就需要载入本地一个名为config.json的参数文件,形式如下:
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
} 对于使用存放在本地文件中的参数一种最直观的方式就是直接将这些参数手动添加到ModelConfig类的成员变量中。当然,通常来说一种更常见的做法是在ModelConfig类中实现一个方法来加载这些本地参数,实现代码如下所示:
1 @classmethod
2 def from_json_file(cls, json_file):
3 with open(json_file, 'r') as reader:
4 text = reader.read()
5 model_config = cls()
6 for (key, value) in dict(json.loads(text)).items():
7 model_config.__dict__[key] = value
8 return model_config在上述代码中,第1行@classmethod表示声明from_json_file()方法为类ModelConfig的一个类方法,作用是在不实例化一个ModelConfig类对象之前一样可以调用类ModelConfig中的方法,即后续可以通过ModelConfig.from_json_file()的形式来进行调用,这一点在后续载入BERT预训练模型时也会遇到。第3~4行是打开配置文件。第6~7行是遍历文件中的每个参数并加入到类ModelConfig的成员变量里,其中 dict(json.loads(text))表示将文本内容转换为dict对象。
最后,通过如下方式便可进行参数加载和访问相关参数:
1 if __name__ == '__main__':
2 config = ModelConfig.from_json_file("./config.json")
3 print(config.hidden_dropout_prob)
4 print(config.hidden_size)
5 # 0.1
6 # 768以上完整示例代码可以参见Code/Chapter05/C01_ConfigManage/E03_LoadConfig.py文件。
5.1.3 定义日志函数#
在模型开发中,可以借助logging这个Python模块来完成上述功能(如果没有的话通过pip install logging命令安装)。同时,为了满足日志信息也能在控制端输出等功能需要基于logging再改进一下,实现代码如下所示:
1 import logging
2 import os,sys
3 def logger_init(log_file_name='monitor',log_level=logging.DEBUG,
4 log_dir='./logs/',only_file=False):
5 if not os.path.exists(log_dir):
6 os.makedirs(log_dir)
7 log_path = os.path.join(log_dir, log_file_name + '_' + str(datetime.now())[:10] + '.txt')
8 formatter = '[%(asctime)s] - %(levelname)s: [%(filename)s][%(lineno)s] %(message)s'
9 datefmt = "%Y-%d-%m %H:%M:%S'"
10 if only_file:
11 logging.basicConfig(filename=log_path, level=log_level,
12 format=formatter, datefmt=datefmt)
13 else:
14 logging.basicConfig(level=log_level, format=formatter, datefmt=datefmt,
15 handlers=[logging.FileHandler(log_path),
16 logging.StreamHandler(sys.stdout)])在上述代码中,第3行中log_file_name用于指定日志文件名的前缀;log_level用于指定日志的输出等级,一般常见的有WARNING、INFO和DEBUG3种,其重要性降序排列(重要性越高输出内容越少);log_dir用于指定日志的保存目录;only_file用于指定是否输出到日志文件。第5~6行用于判断日志目录是否存在,如果不存在则创建。第7行是构建最终日志保存的路径,且同时在文件名后面加上了当天日期。第8~9行是定义日志信息的输出格式,其中lineno表示打印语句所在的行号。第10~16行是根据条件判断日志输出方式。最后,logs文件中将会生成一个类似名为monitor_2023-03-03.txt的日志文件。
5.1.4 日志输出示例#
在完成上述工作后便可以在任意模块或者文件中使用logging来进行日志记录,下面进行一个具体的示例。首先在classA.py中新建了一个名为ClassA的类,代码如下:
1 import logging
2 class ClassA(object):
3 def __init__(self):
4 logging.info(f"我在{__name__}中!")
5 logging.debug(f"我在文件{__file__}中,这是一条debug信息!")
6 logging.warning(f"我在文件{__file__}中,这是一条warning信息!")在上述代码中,第4行__name__表示取当前模块的名称,即classA。第5行__file__表示所在文件的绝对路径。
接着在classB.py中新建了一个名为ClassB的类,代码如下:
1 class ClassB(object):
2 def __init__(self):
3 logging.info(f"我在{__name__}中!")
4 logging.debug(f"我在文件{__file__}中,这是一条debug信息!")最后在main.py中调用这两个类,并输出相应的日志信息,代码如下:
1 from classA import ClassA
2 from classB import ClassB
3 from log_manage import logger_init
4 import logging
5
6 def log_test():
7 a = ClassA()
8 b = ClassB()
9 logging.info(f"我在{__name__}中!")
10
11 if __name__ == '__main__':
12 logger_init(log_file_name='monitor', log_level=logging.INFO,
13 log_dir='./logs',only_file=False)
14 log_test()在运行完上述代码后,日志文件monitor_2023_03_04.txt和终端里就会输出如下所示的日志信息:
1 我在classA中!
2 我在文件DeepLearningWithMe/Code/Chapter05/C02_LogManage/classA.py中,这是一条warning信息!
3 我在classB中!
4 我在__main__中!可以发现,classA和classB这两个模块中的日志信息都有被打印出来,而且也都满足了跨模块日志打印的需求。但是可以发现,logging.debug这样的信息并没有打印出来,其原因就在于通过logger_init()函数初始化时指定的日志输出等级为logging.INFO,这就意味着不会输出调试信息。当然,只需要将log_level指定为logging.DEBUG即可输出所有信息。
5.1.5 打印模型参数#
在介绍完日志的打印输出方法后,进一步只需要在上面ModelConfig类的定义中加入如下几行代码便可以在模型训练时打印相关的模型信息:
1 class ModelConfig(object):
2 def __init__(self, ):
3 .....
4 logging.info("#### <----------------------->")
5 for key, value in self.__dict__.items():
6 logging.info(f"## {key} = {value}")
7
8 if __name__ == '__main__':
9 logger_init(log_file_name='monitor', log_level=logging.DEBUG,
10 log_dir='./logs', only_file=False)
11 config = ModelConfig()在上述代码中,第4~6行用于遍历类中所有的成员变量(即模型参数)并打印输出。最后,在控制台和日志文件中便会输出类似如下的信息:
1 #### <----------------------->
2 ## batch_size = 16
3 ## learning_rate = 3.5e-05
4 ## num_labels = 3
5 ## epochs = 5以上完整示例代码可以参见Code/Chapter05/C02_LogManage文件夹。
5.1.6 小结#
在本节内容中,我们首先介绍了在编写代码模型的过程中参数管理的重要性和必要性,并介绍了如何定义一个类配置类并通过类成员的方式来管理和获取参数;然后详细介绍了如何载入本地文件中的参数值并添加到配置类中进行使用;接着进一步介绍了如何基于logging模块来定义一个初始化函数;最后详细展示了如何来使用logging在各个模块中将相关信息打印到同一个日志文件中。在实际使用过程中只需要在需要输出日志信息的地方通过函数logging.info()进行打印,然后在主函数运行的地方调用logger_init()函数来初始化即可完成日志信息的输出或打印。