5.7 数据预处理与缓存#

随着任务场景和深度学习模型的复杂化,使得模型在训练过程中每次调试时都需要花费较长的时间来等待数据集预处理结果。一个简单直接的办法就是在模型每次载入数据集时都预先判断本地是否有对应的缓存文件,如果有则直接载入,没有则重新处理并进行缓存。同时,为了方便这段处理逻辑能够方便地迁移到其它类似情况,因此我们需要将其定义成一个Python修饰器来进行使用。

下面,我们先来简单介绍一个Python中修饰器的功能及用法。

5.7.1 修饰器介绍#

关于什么是修饰器或装饰器(Decorator)我们这里就不从Python语法上来做详细的解释了。简单一句话,修饰器的作用的就是在正式执行某个功能函数之前,预先执行你想要执行的某些逻辑。例如在进行数据预处理之前先判断是否有对应的缓存文件。下面,我们直接从用法的层面来逐步了解Python中的修饰器。

首先来看这样一个场景,假如你之前已经定义了多个功能函数,但此时需要在日志文件中同时也输出每个函数的实际运行时间和其它相关信息。例如:

打印出当前主程序正在调用哪个功能函数的信息。例如:

1 def add(a=1, b=2):    
2     time.sleep(2)
3     r = a + b
4     return r
5 
6 def subtract(a=1, b=2):
7     time.sleep(3)
8     r = a - b
9     return r

在上述代码中,time.sleep(2)是为了模拟运行所花费的时间。

进一步,对于上述两个函数,如果需要实现打印运行时间等相关信息,可以通过如下类似方式实现:

1 def add(a=1, b=2):
2     print(f"正在执行函数 add() !")
3     start_time = time.time()
4     time.sleep(2)
5     r = a + b
6     end_time = time.time()
7     print(f"一共耗时{(end_time - start_time):.3f}s")
8     return r

在上述代码中,第2、3、6和7行便是需要打印输出的相关信息。虽然通过这样的方式也能解决,但是如果有大量的函数都需要添加这么一段逻辑,那这种做法显然不可取。另外一种高效的方法则是使用Python中的修饰器。

假如现在已经定义好了一个名为get_info的修饰器,那么只需要通过如下方式便可以打印上述相关信息,示例代码如下所示:

 1 @get_info
 2 def subtract(a=1, b=2):
 3     time.sleep(3)
 4     r = a - b
 5     return r
 6 
 7 if __name__ == '__main__':
 8     subtract(3, 4)
 9 
10 正在执行函数 subtract() 
11 一共耗时3.002s

在上述代码中,第1行便是调用了get_info修饰器。第2~5行是subtract函数原有的计算逻辑,并没有进行任何修改。所以,此时我们只需要在所有函数定义的地方使用get_info修饰器便可以实现运行时间计算的功能。

5.7.2 修饰器定义#

在Python语法中,修饰器可以简单的分为包含参数和不包含参数两种。例如上面在使用@get_info时便没有传入相关参数,如果包含有参数则使用方式类似@get_info(book_name="《跟我一起学深度学习》")。下面我们分别就这两种情况进行介绍。

1. 不含参数的修饰器

在使用修饰器之前,需要先定义一个完成目标功能的函数。对于5.7.1节中的例子来说,示例代码如下:

1 def get_info(func):
2     def wrapper(*args, **kwargs):
3         print(f"正在执行函数 {func.__name__}() !")
4         start_time = time.time()
5         result = func(*args, **kwargs)
6         end_time = time.time()
7         print(f"一共耗时{(end_time - start_time):.3f}s")
8         return result
9     return wrapper

在上述代码中,第3~4行和第6~7行便是为了实现目标功能所加入的逻辑。第5行则是原有功能函数的执行逻辑,例如5.7.1节中的addsubtract函数。

此时可以看出,get_info本质上就是定义了一个多层嵌套的函数,因此也可以通过函数调用的方式来进行使用,示例代码如下所示:

1 def subtract(a=1, b=2):
2     time.sleep(3)
3     r = a - b
4     return r
5     
6 if __name__ == '__main__':
7     get_info(subtract)(7, 8)

虽然这样的方式也能实现同样的逻辑,但是使用起来不如修饰器简洁。

通过上述介绍可以发现,定义修饰器函数的大致格式如下所示:

1 def decorator(func):
2     def wrapper(*args, **kwargs):
3         ## 在这里添加需要预先执行的代码语句
4         result = func(*args, **kwargs)
5         ## 在这里添加需要事后执行的代码语句
6         return result
7     return wrapper

在上述代码中,第1行decorator为修饰器的名称,func为使用该修饰器的函数。第2行*args, **kwargs则为使用该修饰器函数的相关参数。第3行则是需要预先执行的计算逻辑。第4行则是执行原有函数的计算逻辑。第5行是事后需要执行的计算逻辑。

同时,由于通过@符号来将decorator作为修饰器调用本质上只是一种快速简洁的方式,所以@decorator还等价于decorator(func)(*args, **kwargs)这样的调用方式。因此,通过后者我们还能够更加清晰地认识到整个修饰器的工作流程。

2. 包含参数的修饰器

所谓包含参数的修饰器指的是在调用修饰器时同时也传入相关参数。例如在后续介绍数据预处理结果缓存时,为了能够区分缓存结果的唯一性,我们就需要传入预处理时的相关参数来构造一个缓存文件名,如top_kmax_len或者cut_words这样的参数。因为对于不同的参数,构造得到的数据集并不一样。

对于需要传入用户参数的修饰器,其定义代码如下所示:

 1 def get_info_with_para(name=None):
 2     print(f"name = {name}")
 3     def decorating_function(func):
 4         def wrapper(*args, **kwargs):
 5             print(f"正在执行函数 {func.__name__}() !")
 6             start_time = time.time()
 7             result = func(*args, **kwargs)
 8             end_time = time.time()
 9             print(f"一共耗时{(end_time - start_time):.3f}s")
10             return result
11         return wrapper
12     return decorating_function

在上述代码中,为了实现传入自定义参数,我们在已有的两层函数之上又嵌套了一个函数。

进一步,可以通过如下方式来进行使用:

1 @get_info_with_para(name='power function')
2 def power(num):
3     time.sleep(3)
4     r = num ** 2
5     return r
6 
7 name = power function
8 正在执行函数 power() 
9 一共耗时3.005s

上述完整示例代码可参见Code/Chapter05/C08_DataCache/decorator.py文件。

5.7.3 定义数据集构造类#

在介绍完修饰器的基本原理及用法之后我们再来看如何实现数据预处理结果缓存。整理逻辑依旧是本节内容伊始所提,载入数据集之前首先判断本地是否存在缓存,如果存在则直接载入缓存,如果不存在则再调用函数进行数据预处理并进行缓存。

通常来说,在构造训练集时可以通过定义一个类来完成,并且这个类至少会包含3个方法:__init__data_processload_train_val_test_data,其中__init__用来初始化类中的相关参数(如batch_sizemax_lenfile_ptah等等);data_process用来对数据集进行预处理并返回预处理后的结果;load_train_test_data用来构造最后模型训练时的DataLoader迭代器。

进一步,其定义代码如下所示:

 1 class LoadData(object):
 2     FILE_PATH = './text_train.txt'
 3 
 4     def __init__(self):
 5         self.max_len = 5
 6         self.batch_size = 2
 7 
 8     def data_process(self, file_path=None):
 9         time.sleep(10)
10         logging.info("正在进行预处理数据……")
11         x = torch.randn((10, 5))
12         y = torch.randint(2, [10])
13         data = {"x": x, "y": y}
14         return data
15 
16     def load_train_val_test_data(self):
17         data = self.data_process(file_path=self.FILE_PATH)
18         x, y = data['x'], data['y']
19         data_iter = TensorDataset(x, y)
20         data_iter = DataLoader(data_iter, batch_size=self.batch_size)
21         return data_iter

在上述代码中,第4~6行是初始化数据预处理的相关参数。第8~14行则是模拟数据集的处理过程,这里直接随机进行了生成,其中第9行是用来模拟消耗的时间。第16~21行则是用来构造最后的迭代器。

50