5.7 数据预处理与缓存#
随着任务场景和深度学习模型的复杂化,使得模型在训练过程中每次调试时都需要花费较长的时间来等待数据集预处理结果。一个简单直接的办法就是在模型每次载入数据集时都预先判断本地是否有对应的缓存文件,如果有则直接载入,没有则重新处理并进行缓存。同时,为了方便这段处理逻辑能够方便地迁移到其它类似情况,因此我们需要将其定义成一个Python修饰器来进行使用。
下面,我们先来简单介绍一个Python中修饰器的功能及用法。
5.7.1 修饰器介绍#
关于什么是修饰器或装饰器(Decorator)我们这里就不从Python语法上来做详细的解释了。简单一句话,修饰器的作用的就是在正式执行某个功能函数之前,预先执行你想要执行的某些逻辑。例如在进行数据预处理之前先判断是否有对应的缓存文件。下面,我们直接从用法的层面来逐步了解Python中的修饰器。
首先来看这样一个场景,假如你之前已经定义了多个功能函数,但此时需要在日志文件中同时也输出每个函数的实际运行时间和其它相关信息。例如:
打印出当前主程序正在调用哪个功能函数的信息。例如:
1 def add(a=1, b=2):
2 time.sleep(2)
3 r = a + b
4 return r
5
6 def subtract(a=1, b=2):
7 time.sleep(3)
8 r = a - b
9 return r在上述代码中,time.sleep(2)是为了模拟运行所花费的时间。
进一步,对于上述两个函数,如果需要实现打印运行时间等相关信息,可以通过如下类似方式实现:
1 def add(a=1, b=2):
2 print(f"正在执行函数 add() !")
3 start_time = time.time()
4 time.sleep(2)
5 r = a + b
6 end_time = time.time()
7 print(f"一共耗时{(end_time - start_time):.3f}s")
8 return r在上述代码中,第2、3、6和7行便是需要打印输出的相关信息。虽然通过这样的方式也能解决,但是如果有大量的函数都需要添加这么一段逻辑,那这种做法显然不可取。另外一种高效的方法则是使用Python中的修饰器。
假如现在已经定义好了一个名为get_info的修饰器,那么只需要通过如下方式便可以打印上述相关信息,示例代码如下所示:
1 @get_info
2 def subtract(a=1, b=2):
3 time.sleep(3)
4 r = a - b
5 return r
6
7 if __name__ == '__main__':
8 subtract(3, 4)
9
10 正在执行函数 subtract() !
11 一共耗时3.002s在上述代码中,第1行便是调用了get_info修饰器。第2~5行是subtract函数原有的计算逻辑,并没有进行任何修改。所以,此时我们只需要在所有函数定义的地方使用get_info修饰器便可以实现运行时间计算的功能。
5.7.2 修饰器定义#
在Python语法中,修饰器可以简单的分为包含参数和不包含参数两种。例如上面在使用@get_info时便没有传入相关参数,如果包含有参数则使用方式类似@get_info(book_name="《跟我一起学深度学习》")。下面我们分别就这两种情况进行介绍。
1. 不含参数的修饰器
在使用修饰器之前,需要先定义一个完成目标功能的函数。对于第5.7.1节中的例子来说,示例代码如下:
1 def get_info(func):
2 def wrapper(*args, **kwargs):
3 print(f"正在执行函数 {func.__name__}() !")
4 start_time = time.time()
5 result = func(*args, **kwargs)
6 end_time = time.time()
7 print(f"一共耗时{(end_time - start_time):.3f}s")
8 return result
9 return wrapper在上述代码中,第3~4行和第6~7行便是为了实现目标功能所加入的逻辑。第5行则是原有功能函数的执行逻辑,例如第5.7.1节中的add和subtract函数。
此时可以看出,get_info本质上就是定义了一个多层嵌套的函数,因此也可以通过函数调用的方式来进行使用,示例代码如下所示:
1 def subtract(a=1, b=2):
2 time.sleep(3)
3 r = a - b
4 return r
5
6 if __name__ == '__main__':
7 get_info(subtract)(7, 8)虽然这样的方式也能实现同样的逻辑,但是使用起来不如修饰器简洁。
通过上述介绍可以发现,定义修饰器函数的大致格式如下所示:
1 def decorator(func):
2 def wrapper(*args, **kwargs):
3 ## 在这里添加需要预先执行的代码语句
4 result = func(*args, **kwargs)
5 ## 在这里添加需要事后执行的代码语句
6 return result
7 return wrapper在上述代码中,第1行decorator为修饰器的名称,func为使用该修饰器的函数。第2行*args, **kwargs则为使用该修饰器函数的相关参数。第3行则是需要预先执行的计算逻辑。第4行则是执行原有函数的计算逻辑。第5行是事后需要执行的计算逻辑。
同时,由于通过@符号来将decorator作为修饰器调用本质上只是一种快速简洁的方式,所以@decorator还等价于decorator(func)(*args, **kwargs)这样的调用方式。因此,通过后者我们还能够更加清晰地认识到整个修饰器的工作流程。
2. 包含参数的修饰器
所谓包含参数的修饰器指的是在调用修饰器时同时也传入相关参数。例如在后续介绍数据预处理结果缓存时,为了能够区分缓存结果的唯一性,我们就需要传入预处理时的相关参数来构造一个缓存文件名,如top_k、max_len或者cut_words这样的参数。因为对于不同的参数,构造得到的数据集并不一样。
对于需要传入用户参数的修饰器,其定义代码如下所示:
1 def get_info_with_para(name=None):
2 print(f"name = {name}")
3 def decorating_function(func):
4 def wrapper(*args, **kwargs):
5 print(f"正在执行函数 {func.__name__}() !")
6 start_time = time.time()
7 result = func(*args, **kwargs)
8 end_time = time.time()
9 print(f"一共耗时{(end_time - start_time):.3f}s")
10 return result
11 return wrapper
12 return decorating_function在上述代码中,为了实现传入自定义参数,我们在已有的两层函数之上又嵌套了一个函数。
进一步,可以通过如下方式来进行使用:
1 @get_info_with_para(name='power function')
2 def power(num):
3 time.sleep(3)
4 r = num ** 2
5 return r
6
7 name = power function
8 正在执行函数 power() !
9 一共耗时3.005s上述完整示例代码可参见Code/Chapter05/C08_DataCache/decorator.py文件。
5.7.3 定义数据集构造类#
在介绍完修饰器的基本原理及用法之后我们再来看如何实现数据预处理结果缓存。整理逻辑依旧是本节内容伊始所提,载入数据集之前首先判断本地是否存在缓存,如果存在则直接载入缓存,如果不存在则再调用函数进行数据预处理并进行缓存。
通常来说,在构造训练集时可以通过定义一个类来完成,并且这个类至少会包含3个方法:__init__、data_process和load_train_val_test_data,其中__init__用来初始化类中的相关参数(如batch_size、max_len、file_ptah等等);data_process用来对数据集进行预处理并返回预处理后的结果;load_train_test_data用来构造最后模型训练时的DataLoader迭代器。
进一步,其定义代码如下所示:
1 class LoadData(object):
2 FILE_PATH = './text_train.txt'
3
4 def __init__(self):
5 self.max_len = 5
6 self.batch_size = 2
7
8 def data_process(self, file_path=None):
9 time.sleep(10)
10 logging.info("正在进行预处理数据……")
11 x = torch.randn((10, 5))
12 y = torch.randint(2, [10])
13 data = {"x": x, "y": y}
14 return data
15
16 def load_train_val_test_data(self):
17 data = self.data_process(file_path=self.FILE_PATH)
18 x, y = data['x'], data['y']
19 data_iter = TensorDataset(x, y)
20 data_iter = DataLoader(data_iter, batch_size=self.batch_size)
21 return data_iter在上述代码中,第4~6行是初始化数据预处理的相关参数。第8~14行则是模拟数据集的处理过程,这里直接随机进行了生成,其中第9行是用来模拟消耗的时间。第16~21行则是用来构造最后的迭代器。
5.7.4 定义缓存修饰器#
在完成数据集构造类之后,只需要再按照5.7.2节中的语法完成缓存修饰器的实现即可,具体示例代码如下所示:
1 def process_cache(unique_key=None):
2 if unique_key is None:
3 raise ValueError("unique_key 不能为空, 请指定相关数据集构造类的成员变量")
4 def decorating_function(func):
5 def wrapper(*args, **kwargs):
6 obj = args[0]
7 file_path = kwargs['file_path']
8 file_dir = f"{os.sep}".join(file_path.split(os.sep)[:-1])
9 file_name = "".join(file_path.split(os.sep)[-1].split('.')[:-1])
10 paras = f"cache_{file_name}_"
11 for k in unique_key:
12 paras += f"{k}{obj.__dict__[k]}_" # 遍历对象中的所有参数
13 cache_path = os.path.join(file_dir, paras[:-1] + '.pt')
14 start_time = time.time()
15 if not os.path.exists(cache_path):
16 logging.info(f"缓存文件 {cache_path} 不存在,重新处理并缓存!")
17 data = func(*args, **kwargs)
18 with open(cache_path, 'wb') as f:
19 torch.save(data, f)
20 else:
21 logging.info(f"缓存文件 {cache_path} 存在,直接载入缓存文件!")
22 with open(cache_path, 'rb') as f:
23 data = torch.load(f)
24 end_time = time.time()
25 logging.info(f"数据预处理一共耗时{(end_time - start_time):.3f}s")
26 return data
27 return wrapper
28 return decorating_function在上述代码中,第1行unique_key指定用于区分根据同一原始数据但不同超参数所生成的缓存文件,如['top_k', 'cut_words', 'max_sen_len']等。第6行用于获取类对象,因为data_process(self, file_path=None)中的第1个参数为self。第7行是获取方法data_process中file_path的取值。第8~13行是根据文件名和传入的unique_key构造一个唯一的缓存文件名。第15~19行则是当本地不存在缓存文件时,根据第17行来对原始数据预处理并根据第18~19行将处理好的结果存放到本地。第20~23行则是直接从本地载入缓存文件。
在函数process_cache实现完成后,只需要以修饰器@process_cache(['max_len'])的形式将其作用于data_process方法上即可,此时指定了用于区分不同缓存文件的参数名max_len。
最后,在第1次使用上述数据集构造类,将会得到如下所示输出信息:
1 ## 索引预处理缓存文件的参数为:['max_len']
2 缓存文件 ./cache_text_train_max_len5.pt 不存在,重新处理并缓存!
3 正在进行预处理数据……
4 数据预处理一共耗时10.006s从上述结果可以看出,数据集处理完毕后将会生成一个名为cache_text_train_max_len5.pt的缓存文件,且一共耗费了10秒钟的时间。
当第2次再次载入同样的缓存文件时,则会得到如下所示输出信息:
1 ## 索引预处理缓存文件的参数为:['max_len']
2 缓存文件 ./cache_text_train_max_len5.pt 存在,直接载入缓存文件!
3 数据预处理一共耗时0.002s从上述结果可以看出,由于此时本地缓存文件存在,所以便直接从本地载入了一共耗时不到1秒。
到此,对于如何利用Python修饰器来便捷缓存数据预处理结果的内容就介绍完了,上述完整示例代码可参见Code/utils/tools.py文件。
5.7.5 小结#
在本节内容中,我们首先从使用示例的角度来介绍了Python修饰器的用法及工作原理,即其本质上只是Python中所支持的一种快速简洁的函数调用方式;然后介绍了不含参数和含有参数两种修饰器的实现方法;最后通过一个实际的使用示例来详细介绍了如何从零实现一个可通用的数据预处理缓存修饰器。