更新于 2026年6月29日

8.6 STResNet网络#

在前面几节内容中, 我们陆续介绍了多种通过结合CNN和RNN的模型来对时序数据进行建模,并且从「第8.4节 ConvLSTM原理:时空序列预测与卷积循环模型」内容开始还首次引入了基于时空数据相关任务。不管是第8.4节中介绍的ConvLSTM模型还是「第8.5节 3DCNN原理:三维卷积网络处理视频与时序数据」中引入的3DCNN模型,为了能同时提取时空数据在时间和空间两个维度的特征信息均从模型本身进行了改进。在本节内容中,我们将会介绍一种通过改进任务建模方式而仅依靠2DCNN来进行时空数据特征提取的模型,并同时介绍如何对多个子模块的输出结果进行融合。

8.6.1 STResNet动机#

尽管已有的时序模型例如3DCNN或者是ConvLSTM也能够用于提取时空数据在时间和空间两个维度上的特征信息,但是在流量预测这个场景中却不能有效得捕捉到在时间维度上的周期或者是趋势信息。例如在交通流量这一场景中,传统的时序模型只能提取到时空数据在紧邻若干时间内容的时序信息,但显然在这个场景中交通流量还具有周期规律或者趋势性。

通常来说,交通流量数据在相邻时刻间、连续多天的同一时间段以及在长期趋势上都具有一定的规律性。同时,在这一场景中当天是否为节假日也会对交通流量的预测结果产生极大的影响。

图 8-13 交通流量数据趋势图
图 8-13 交通流量数据趋势图

如图8-13所示,左图展示了以周为单位节假日和非节假日的流量变化图;右图展示了以周为单位不同时间段的流量变化趋势。从图8-13中的结果可以看出,不管在哪种情况下流量数据都具有一定的周期或者规律性。

基于这样的动机,郑宇[1]等人在2017年提出了一种能够同时提取时空数据在邻近性(Closeness )、周期性(Period)和趋势性(Trend)上的时序特征,并同时考虑到节假日及天气状况等额外信息的时空模型(Spatio Temporal ResNet, STResNet)。该模型以残差模块为基础,通过构建不同的子模块来对不同的采样数据进行特征提取,然后再将各个模块的结果进行融合以实现上述目的。

8.6.2 任务背景#

1. 任务介绍

交通流量(Traffic Flow)预测任务是指使用历史数据和其他相关信息来预测特定区域或道路上未来的交通流量情况。它的目标是估计未来某一时间段内的车辆数量、速度和拥堵状况等交通指标,并且通常可以通过交通传感器、视频监控、GPS设备、移动应用程序等方式获取交通流量原始数据。在论文[1]中,作者使用了GPS原始数据来构建对应的数据集并用于模型训练。

图 8-14 网格划分及流量图
图 8-14 网格划分及流量图

如图8-14(左)所示,将整个城市以经纬度划分为大小为$32\times32$的方格,然后再根据GPS轨迹以半小时为间隔统计得到每个格子中的流入流出量并看做是两个通道,这样便得到了整个城市各个区域在每个时刻的流量分布情况。进一步,图8-14(右)便是某一时刻的流入交通流量分布情况。经过这样的处理,我们便得到了一系列具有时序关系的图片数据,而任务的目的便是将前$T$个时间片(时刻)的流量作为输入,然后预测第$T+1$时刻的流入流出流量。

2. 样本采样

为了使得STResNet模型具备提取时空数据在时间维度上的近邻性、周期性和趋势性,作者以不同的时间间隔对原始数据进行了采样,然后再将这3部分采样得到的数据输入到3个由残差网络构建的子模型中进行空间上的特征提取,并将三部分的输出融合作为整体的特征表示。具体采样方式如图8-15所示。

图 8-15 数据采样图
图 8-15 数据采样图

如图8-15所示便是原始构建完成的流量数据,每个时间片的流量情况用一个形状为[2,32,32]的矩阵来进行表示。例如以预测时刻片$t_i$的流量为例,可以取$t_i$的前3个时间片$t_{i-1},t_{i-2},t_{i-3}$来模拟邻近性、取$t_{i}$前一天同时刻的时间片$t_j$来模拟周期性、取$t_i$前一周同时刻的时间片$t_k$来模拟趋势性,并将这3部分作为一个样本同时输入到3个子模块中。最后,再将窗口向右滑动一个时间片来构建下一个样,即$t_i,t_{i-1},t_{i-2}$作为近邻性的输入、$t_{j+1}$作为周期性的输入、$t_{k+1}$作为趋势性的输入,然后预测第$t_{i+1}$时刻的交通流量。当然,这3部分的采样方法以及连续时间片的长度都可以作为超参数进行调整,详见第8.6.4节代码实现。

同时,为了考虑到其它额外因素对最终预测结果的影响,因此作者:①使用了一个8维向量来表示当天是否为工作日,其中前7个维度为独热编码表示当天是星期几,第8个维度表示当天是否为工作日;②使用了一个维度来表示当天是否为节假日;③使用了一个19维向量来表示天气状况,其中前17个维度同样也为独热编码用于表示其中一种天气情况,最后两个维度则表示风速和温度。最终,用这个28维向量来表示除了流量数据之外的额外因素。

8.6.3 STResNet结构#

1. 整体结构

在介绍完任务背景之后接下来开始介绍STResNet模型的原理。STResNet模型整体上可以分为4个部分,其中前3个部分以不同间隔采样的流量数据作为输入通过3个深度残差网络子模块来提取时空数据在时间维度和空间维度上的特征信息,第4个部分则是一个浅层的全连接网络用于考虑节假日和天气等因素对预测结果的影响。整体结构如图8-16所示。

图 8-16 STResNet网络结构图
图 8-16 STResNet网络结构图

如图8-16所示,右侧部分3个相同的结构便是用于提取时刻数据特征的深度残差网络模块,从右至左依次用于对邻近性、周期性和趋势性采样数据在空间上特征提取,然后再将这3部分的结果融合得到流量数据在时间上的特征信息,从而得到$X_{\text{Res}}$来表征整体流量数据的时空特征信息。左侧部分则是用于处理天气等额外信息对预测结果的影响。最终,将两部分的结果相加经过$\tanh$非线性变换后便得到了整个模型的预测输出。

2. 残差模块

深度残差模块主要由多个残差单元所构成,并且首尾各插入了一个单独的卷积层来调整特征图的通道数。对于每一个残差单元来说,其由两个卷积层所构成并且还加入了批归一化层,如图8-17所示。

图 8-17 残差单元结构图
图 8-17 残差单元结构图

如图8-17所示便是STResNet模型中的残差单元。在整个残差模块中,卷积核的大小均是$3\times3$,且除了Conv2之外所有卷积层的卷积核个数均为64,Conv2的卷积核个数则为2,因为预测输出包含两个通道。

3. 模型融合

对于每个残差模块来说,其输出的结果均为一个形状为[2,32,32]的特征图,因此作者采用了式(8-4)所示的方式进行融合

$$ X_{\text{Res}}=W_c\odot X_c+W_p\odot X_p+W_q\odot X_q\tag{8-5} $$

其中$X_c$、$X_p$和$X_q$分别为3个残差模块的输出结果,$W_c$、$W_p$和$W_q$为3个可学习的权重矩阵形状均为[2,32,32],$\odot$表示按位乘。

在对外部因素进行处理时使用了两个全连接层,由于这部分原始输入整体上是由独热编码构成,所以第1个全连接层还可以近似看成一个嵌入层。为了匹配最后的输出形状,在第2个全连接层中权重参数的维度则必须为$2\times32\times32$,然后再将输出$X_{\text{Ext}}$变形为[2,32,32]并直接与$X_{\text{Res}}$按位相加得到融合后的结果。最后,为了使得模型能够快速收敛作者还使用了$\tanh$非线性变换,这也就意味着输入模型的交通流量数据需要先做$[-1,1]$的标准化,然后在实际推理过程中再将预测结果还原,计算过程如式(8-6)所示。

$$ \begin{aligned} y=2\cdot\frac{x-\min}{\max-\min}-1 \end{aligned}\tag{8-6} $$

对于不同模块间输出结果的融合一般常见的有:①在某个维度进行拼接;②直接以按位加的方式进行处理;③先以按位乘的方式作用一个可学习的权重参数再进行按位加处理;④先以矩阵乘法的方式将不同模块的输出变换到同一个维度再按位加处理。通常来说,先作用一个可学习参数再进行按位加是一个比较好的选择。

最后,由于STResNet模型完成的是一个回归任务,因此选择了均方误差作为整体的目标函数,如式(8-7)所示。

$$ \begin{aligned} \mathcal{L}(\theta) &= \|X_t-\hat{X}_t\|^2_2,\\[2ex]\hat{X}_t&=\tanh(X_{\text{Res}}+X_{\text{Ext}}) \end{aligned}\tag{8-7} $$

其中$X_t$和$\hat{X}_t$分别表示第$t$时刻的真实值和预测值。

8.6.4 数据集构建#

在原始数据集中一共包含有6个文件,分别是BJ_Holiday.txtBJ_Meteorology.h5BJ13_M32x32_T30_InOut.h5BJ14_M32x32_T30_InOut.h5BJ15_M32x32_T30_InOut.h5BJ16_M32x32_T30_InOut.h5,其中第1个是节假日信息,第2个是气象信息,最后4个则是交通流量数据。下面我们开始简单介绍一下整个数据集的构建流程,以下完整示例代码及注释可以参见[Code/utils/data_helper.py`文件。

1. 读取原始数据

在清楚数据集的相关信息后进一步便可以编码读取并进行相关的预处理工作。首先我们需要先定义一个类并完成其初始化函数的构造,示例代码如下所示:

 1 class TaxiBJ(object):
 2     DATA_DIR = os.path.join(DATA_HOME, 'TaxiBJ')
 3     FILE_PATH_FLOW = [os.path.join(DATA_DIR, 'BJ13_M32x32_T30_InOut.h5'),
 4                       os.path.join(DATA_DIR, 'BJ14_M32x32_T30_InOut.h5'),
 5                       os.path.join(DATA_DIR, 'BJ15_M32x32_T30_InOut.h5'),
 6                       os.path.join(DATA_DIR, 'BJ16_M32x32_T30_InOut.h5')]
 7     FILE_PATH_HOLIDAY = os.path.join(DATA_DIR, 'BJ_Holiday.txt')
 8     FILE_PATH_METEORO = os.path.join(DATA_DIR, 'BJ_Meteorology.h5')
 9     CATH_FILE_PATH = os.path.join(DATA_DIR, 'TaxiBJ.pt')
10 
11     def __init__(self, T=48, nb_flow=2, len_test=None, len_closeness=None,
12          len_period=None, len_trend=None, meta_data=True,meteorol_data=True,
13          holiday_data=True, batch_size=4, is_sample_shuffle=True):
14         self.T = T
15         self.nb_flow = nb_flow
16         self.len_test = len_test
17         self.len_closeness = len_closeness
18         self.len_period = len_period
19         self.len_trend = len_trend
20         self.meta_data = meta_data
21         self.meteorology_data = meteorol_data
22         self.holiday_data = holiday_data

在上述代码中,第2~9行用于定义数据集的目录和文件名以及缓存文件的文件名。第11~22行则是定义相关构造数据集时的超参数。

接着,分别定义3个类方法来载入节假日数据、气象数据和流量数据。对于节假日数据来说,其载入方法如下所示:

 1     def load_holiday(self, timeslots=None):
 2         filepath = self.FILE_PATH_HOLIDAY
 3         with open(filepath, 'r') as f:
 4             holidays = f.readlines()
 5             holidays = set([h.strip() for h in holidays])
 6             H = np.zeros(len(timeslots))
 7             for i, slot in enumerate(timeslots):
 8                 if slot[:8] in holidays: 
 9                     H[i] = 1
10         return H[:, None] 

在上述代码中,第1行timeslots是表示日期的时间戳,如['2014120106','2014120206']。第3~5行是读取原始节假日信息文件并进行去重,最终将得到一个假期列表,形如:['20130101','20130102','20130103','20130209', ...]。第6~9行则是遍历timeslots中的每个时间戳并取前8位为日期,判断其是否为节假日。最后,第10行将会返回得到一个形状为[n,1]且仅包含0和1取值的列向量。

例如对于时间戳来说:

1 load_holiday(timeslots=['2014120106','2014120206','2014010106','2014120706'])

返回结果为:

1 [[0.]
2  [0.]
3  [1.] # 这一天元旦为节假日
4  [0.]]

进一步,对于气象数据来说其存储格式为h5因此需要先通过pip安装h5py模块,其载入方法如下所示:

 1     def load_meteorology(self, timeslots=None):
 2         file_path = self.FILE_PATH_METEORO
 3         with h5py.File(file_path, 'r') as f:
 4             Timeslot, WindSpeed = f['date'][:], f['WindSpeed'][:]
 5             Weather, Temperature = f['Weather'][:], f['Temperature'][:]
 6         M = dict() 
 7         for i, slot in enumerate(Timeslot):
 8             M[slot] = i
 9         WS, WR, TE = [], [], [] 
10         for slot in timeslots:
11             predicted_id, cur_id = M[slot], predicted_id - 1 
12             WS.append(WindSpeed[cur_id])
13             WR.append(Weather[cur_id])
14             TE.append(Temperature[cur_id])
15         WS, WR, TE = np.asarray(WS), np.asarray(WR), np.asarray(TE)
16         WS = 1. * (WS - WS.min()) / (WS.max() - WS.min())
17         TE = 1. * (TE - TE.min()) / (TE.max() - TE.min())
18         merge_data = np.hstack([WR, WS[:, None], TE[:, None]])
19         return merge_data

在上述代码中,第3~5行为载入原始气象数据集和时间戳。第6~8行是给每个时间戳赋一个索引,形容{...,b'2016061335': 59003, b'2016061336': 59004, b'2016061337': 59005}。第10~14行表示取上一个索引,因为一般来说预测第$t$时刻时只能取其$t-1$时刻的天气信息。第16~17行是一次对所有的温度和风速进行标准化。第18行则是将天气、风速和温度这3列特征拼接起来得到一个形状为[n,19]的输入特征。

最后则是实现原始数据的载入,示例代码如下所示:

1     def load_stdata(fname):
2         f = h5py.File(fname, 'r')
3         data = f['data'][:]
4         timestamps = f['date'][:]
5         f.close()
6         return data, timestamps

在上述代码中,第2行是载入原始h5文件。第3~4行则是分别取每个时间片对应的双通道流量数据和相应的时间戳信息。进一步还可以通过show_example()方法来进行可视化,即图8-14所示。

2. 样本构建

在实现各部分数据的载入之后,下面开始进行采用并构建样本。由于这部分代码较长所以分为两个部分进行介绍。首先是载入流量数据并进行采样处理,示例代码如下所示:

 1 		 @process_cache(unique_key=["T", "nb_flow", "len_test", "len_closeness","len_period"])
 2     def data_process(self, file_path=None):
 3         data_all, timestamps_all= [], []
 4         for fname in self.FILE_PATH_FLOW:
 5             self.stat(fname)
 6             data, timestamps = self.load_stdata(fname)
 7             data, timestamps = self.remove_incomplete_days(data, timestamps, self.T)
 8             data = data[:, :self.nb_flow]
 9             data[data < 0] = 0.  
10             data_all.append(data)
11             timestamps_all.append(timestamps)
12         data_train = np.vstack(copy(data_all))[:-self.len_test]
13         mmn = MinMaxNormalization()
14         mmn.fit(data_train)
15         data_all_mmn = [mmn.transform(d) for d in data_all]
16         XC, XP, XT, Y, timestamps_Y = [], [], [], [], []
17         for data, timestamps in zip(data_all_mmn, timestamps_all):
18             st = STMatrix(data, timestamps, self.T, CheckComplete=False)  # 采样构造流量数据
19             _XC, _XP, _XT, _Y, _timestamps_Y = st.create_dataset(
20                 len_closeness=self.len_closeness, len_period=self.len_period, len_trend=self.len_trend)
21             XC.append(_XC)
22             XP.append(_XP)
23             XT.append(_XT)
24             Y.append(_Y)
25             timestamps_Y += _timestamps_Y 

在上述代码中,第1行是根据参数取值构建文件名并缓存data_process()处理完成的结果。第4~11行是逐一读取原始流量数据文件并进行相关预处理,其中第5行是查看数据集信息包括时间跨度和数量等,第6行是载入原始流量数据,第7行检查是否有数据不完整的某一天(如以30分钟为时间片则1天便有48条数据),第9行是处理异常把小于0的数据替换为0,第10~11行则是保存每个文件载入的结果,其中data的形状为[num,2,32,32]。第12行是先将4部分的数据堆叠到一起然后再换分出训练集部分。第13~15行是实例化一个标准化对象,然后再拟合相应的参数并对所有数据按式(8-5)中的计算方式进行标准化。第17~25行则是遍历每个部分的流量数据构造样本,其中第18行是实例化一个用于样本采样的类对象,第19行则是根据传入参数进行样本构造,len_closeness表示邻近性的时间片长度默认为3,len_period表示周期性的时间片长度默认为1,第25行是保存所有预测时间的时间戳。

进一步,载入气象数据并划分整个数据集,示例代码如下所示:

 1         meta_feature = []
 2         if self.meta_data:
 3             time_feature = timestamp2vec(timestamps_Y)  
 4             meta_feature.append(time_feature)
 5         if self.holiday_data:
 6             holiday_feature = self.load_holiday(timestamps_Y) 
 7             meta_feature.append(holiday_feature)
 8         if self.meteorology_data:
 9             meteorol_feature = self.load_meteorology(timestamps_Y) 
10             meta_feature.append(meteorol_feature)
11         meta_feature = np.hstack(meta_feature) if len(
12                         meta_feature) > 0 else np.asarray(meta_feature)
13         XC_train, XP_train, XT_train, Y_train = XC[:-self.len_test], \
14                 XP[:-self.len_test], XT[:-self.len_test], Y[:-self.len_test]
15         XC_test, XP_test, XT_test, Y_test = XC[-self.len_test:], \
16                 XP[-self.len_test:], XT[-self.len_test:], Y[-self.len_test:]
17         times_train, times_test = timestamps_Y[:-self.len_test], timestamps_Y[-self.len_test:]
18         meta_train, meta_test = meta_feature[:-self.len_test], meta_feature[-self.len_test:]
19         train_data = [item for item in zip(XC_train, XP_train, XT_train, Y_train, meta_train, times_train)]
20         test_data = [item for item in zip(XC_test, XP_test, XT_test, Y_test, meta_test, times_test)]
21         data = {"train_data": train_data, "test_data": test_data, "mmn": mmn}
22         return data

在上述代码中,第2~4行是将时间戳转换为一个8维向量表示当前是否为工作日。第5~7行是载入节假日数据并用1个维度来进行表示。第8~10行是载入气象数据,一共19个维度。第11~12行是就上述特征进行拼接,最终得到一个[n,28]的特征矩阵。第13~18行是各部分数据的训练集和测试集划分,默认为使用最后4周作为测试集。第19~20行是将所有输入数据以样本为单位进行组合,以便后续构造迭代器。第21~22行则是返回最后处理完成的样本,其中mmn用于在后续推理过程时将预测结果还原为真实值。

3. 构建迭代器

在完成原始样本构建后进一步可以构造得到迭代器,示例代码如下所示:

 1     def load_train_test_data(self, is_train=False):
 2         data = self.data_process(file_path=self.CATH_FILE_PATH)
 3         mmn = data['mmn']
 4         if not is_train:
 5             test_data = data['test_data']
 6             test_iter = DataLoader(test_data, self.batch_size, True)
 7             return test_iter, mmn
 8         train_data = data['train_data']
 9         train_iter = DataLoader(train_data, self.batch_size, self.is_sample_shuffle)
10         return train_iter, mmn

在上述代码中,第2行是返回处理完成的数据样本。第3行是取标准化实例化对象。第4~7行则是构造测试集对应的迭代器。第8~10行则是构造训练集对应的迭代器。

8.6.5 STResNet实现#

在完成整个数据集的构建之后便可以根据图8-16中的网络结构来实现STResNet模型。以下完整示例代码可以参见Code/Chapter08/C07_STResNet/STResNet.py文件。

1. 前向传播

根据图8-16所示,首先需要实现网络结构中的残差单元和残差模块。对于残差单元来说,其实现代码如下所示:

 1 class ResUnit(nn.Module):
 2     def __init__(self, res_in_chs=16, res_out_chs=32):
 3         super().__init__()
 4         self.block = nn.Sequential(
 5             nn.BatchNorm2d(res_in_chs),
 6             nn.ReLU(inplace=True),
 7             nn.Conv2d(res_in_chs, res_out_chs, 3, stride=1, padding=1),
 8             nn.BatchNorm2d(res_out_chs),
 9             nn.ReLU(inplace=True),
10             nn.Conv2d(res_out_chs, res_in_chs, 3, stride=1, padding=1))
11     def forward(self, x):
12         return x + self.block(x)

在上述代码中,第2行res_in_chsres_out_chs分别表示卷积单元中两个卷积操作对应的通道数。第7、10行则是对应的卷积操作,其中卷积核的大小固定为$3\times3$且进行了填充处理,即卷积操作后不改变特征图的大小。第11~12行则是残差连接。

进一步,对于残差模块来说,其由多个残差单元所构成,实现代码如下所示:

 1 class ResComponent(nn.Module):
 2     def __init__(self, conv1_in_chs=8, conv1_out_chs=16, num_res_unit=3, 
 3                  res_out_chs=32, nb_flow=2):
 4         super().__init__()
 5         self.conv1 = nn.Conv2d(conv1_in_chs, conv1_out_chs, 3, stride=1, padding=1)
 6         res_units = []
 7         for i in range(num_res_unit):
 8             res_units.append(ResUnit(conv1_out_chs, res_out_chs))
 9         self.res_units = nn.ModuleList(res_units)
10         self.conv2 = nn.Conv2d(conv1_out_chs, nb_flow, 3, stride=1, padding=1)
11     def forward(self, x):
12         x = self.conv1(x)
13         for res_unit in self.res_units:
14             x = res_unit(x)
15         x = self.conv2(x)
16         return x  

在上述代码中,第5行和第10行分别是残差模块中首尾的两个卷积操作。第7~8行便是连续的多个残差单元。第9行是res_units转换为ModuleList对象,否则在GPU上运行可能会出现模型参数不在同一个设备上的错误。第12~15行便是整个残差模块的前向传播过程。

接着,根据图8-16中的网络结构,气象等额外信息的特征提取模块实现如下所示:

 1 class FeatureExt(nn.Module):
 2     def __init__(self, ext_dim=20, nb_flow=2, map_height=32, map_width=32):
 3         super().__init__()
 4         self.nb_flow = nb_flow
 5         self.map_height = map_height
 6         self.map_width = map_width
 7         self.feature = nn.Sequential(
 8             nn.Linear(ext_dim, 10),
 9             nn.ReLU(inplace=True),
10             nn.Linear(10, nb_flow * map_height * map_width),
11             nn.ReLU(inplace=True))
12     def forward(self, x):
13         x = self.feature(x.to(torch.float32))
14         x = torch.reshape(x, [-1, self.nb_flow, self.map_height, self.map_width])
15         return x

在上述代码中,第7~11行是对应的两个全连接层,其中第2个全连接层的输出维度为$2\times32\times32$。第13~14行便是整个前向传播过程,其中第14行是把输出特征变形为特征图的形式以便后续进行特征融合。

基于上述已实现的各个子模块便可以实现完整的STResNet模型,示例代码如下所示:

 1 class STResNet(nn.Module):
 2     def __init__(self, config=None):
 3         super().__init__()
 4         self.close = ResComponent(config.nb_flow * config.len_closeness,
 5                     	config.conv1_out_chs, config.num_res_unit, config.res_out_chs)
 6         self.period = ResComponent(config.nb_flow * config.len_period,
 7                       config.conv1_out_chs, config.num_res_unit, config.res_out_chs)
 8         self.trend = ResComponent(config.nb_flow * config.len_trend,
 9                      config.conv1_out_chs, config.num_res_unit, config.res_out_chs)
10         self.ext_feature = FeatureExt(config.ext_dim, config.nb_flow, 
11                                       config.map_height, config.map_width)
12         self.w_c = nn.Parameter(torch.randn([1, config.nb_flow, 
13                                             config.map_height, config.map_width]))
14         self.w_p = nn.Parameter(torch.randn([1, config.nb_flow, 
15                                              config.map_height, config.map_width]))
16         self.w_t = nn.Parameter(torch.randn([1, config.nb_flow, 
17                                              config.map_height, config.map_width]))

在上述代码中,第4~10行分别用于构建图8-16中邻近性、周期性、趋势性和额外因素这4个模块。第12~17行则是随机初始化3个权重矩阵用于后续的特征融合。

整个前向传播计算过程的示例代码如下所示:

 1     def forward(self, x, y=None):
 2         x0 = self.close(x[0])
 3         x1 = self.period(x[1])
 4         x2 = self.trend(x[2])
 5         x3 = self.ext_feature(x[3])
 6         y1 = x0 * self.w_c + x1 * self.w_p + x2 * self.w_t
 7         logits = torch.tanh(y1 + x3)
 8         if y is not None:
 9             loss_fct = nn.MSELoss()
10             loss = loss_fct(logits, y)
11             return loss, logits
12         else:
13             return logits

在上述代码中,第1行中x是一个列表包含有图8-16中4个子模块的输入内容。第2~5行是4个子模块各自特征提取的前向传播计算过程。第6行是对邻近性、周期性和趋势性3个模块输出结果的融合。第7行是对整个输出进行非线性变换将结果压缩至$[-1,1]$。第8~13行是根据输入值返回不同的计算结果,其中nn.MSELoss()用于回归任务中计算均方误差作为损失,详见「第3.8.1节 深度学习回归模型评估指标:MSE、RMSE 与 R2」内容。

最后,可以通过如下方式来进行使用:

 1 if __name__ == '__main__':
 2     x0 = torch.randn([16, 6, 32, 32])
 3     x1 = torch.randn([16, 2, 32, 32])
 4     x2 = torch.randn([16, 2, 32, 32])
 5     x3 = torch.randint(0, 2, size=[16, 28])
 6     y = torch.randn([16, 2, 32, 32])
 7     x = [x0, x1, x2, x3]
 8     st = STResNet(config)
 9     loss, logits = st(x, y)
10     print(logits.shape) # torch.Size([16, 2, 32, 32])

2. 模型训练

在完成整个网络结构的实现后便可以进行模型训练。对于模型训练的整个实现过程同之前类似这里就不再赘述,下面介绍一下模型评估部分的实现过程,示例代码如下所示:

 1 def evaluate(data_iter, model, device, mmn):
 2     model.eval()
 3     all_logits, all_labels = [], []
 4     with torch.no_grad():
 5         for i, (XC, XP, XT, Y, meta_test, _) in enumerate(data_iter):
 6             XC, XP, XT, Y = XC.to(device), XP.to(device), XT.to(device), Y.to(device)
 7             meta_test = meta_test.to(device)
 8             loss, logits = model([XC, XP, XT, meta_test], Y)
 9             all_logits.append(logits) 
10             all_labels.append(Y) 
11         model.train()
12         rmse = compute_rmse(all_logits, all_labels, mmn)
13     return rmse

在上述代码中,第5~8行是模型的前向传播计算过程。第9~10行则是将模型的预测结果和真实结果以此存放至列表中。第12行则是计算均方根误差来作为模型的评估指标,计算公式详见3.8.1节内容,compute_rmse()函数的实现如下:

1 def compute_rmse(all_logits=None, all_labels=None, mmn=None):
2     all_logits = torch.cat(all_logits, dim=0)  
3     all_labels = torch.cat(all_labels, dim=0) 
4     y_pred = all_logits.detach().cpu().numpy()
5     y_true = all_labels.detach().cpu().numpy()
6     y_pred = mmn.inverse_transform(y_pred)
7     y_true = mmn.inverse_transform(y_true)
8     rmse = np.sqrt(np.mean((y_pred - y_true) ** 2))
9     return rmse

在上述代码中,第1行all_logitsall_labels分别为上面预测值和真实值的保存结果,均为一个列表。第2~3行则是将两者转换为4维的张量,形状为[n,2,32,32]。第4~5行则是将两者存放至CPU上并转换为ndarray类型。第6~7行则是将预测结果从范围$[-1,1]$还原至真实值。第8行是计算预测值与真实值之间的均方根误差。

3. 模型推理

在完成模型训练后则可以进一步实现模型的推理过程,示例代码如下所示:

 1 def inference(config):
 2     data_loader = TaxiBJ(config.T, config.nb_flow, config.len_test, config.len_closeness,
 3                          config.len_period, config.len_trend, batch_size=config.batch_size)
 4     test_iter, mmn = data_loader.load_train_test_data(is_train=False)
 5     model = STResNet(config)
 6     if os.path.exists(config.model_save_path):
 7         checkpoint = torch.load(config.model_save_path)
 8         model.load_state_dict(checkpoint)
 9     else:
10         raise ValueError(f" # 模型{config.model_save_path}不存在!")
11     rmse = evaluate(test_iter, model, config.device, mmn)
12     logging.info(f" # RMSE on test: {rmse}")

在上述代码中,第2~4行是根据配置信息返回测试集对应的迭代器。第6~10行用于判断本地是否存在模型文件,存在则返回。第11行是计算对应的均方根误差评估值。

最后,在对网络模型进行训练时将会得到类似如下的输出结果:

 1 Epochs[1/50]--batch[0/215]--loss: 1.1327
 2 Epochs[1/50]--batch[50/215]--loss: 0.0412
 3 Epochs[1/50]--batch[100/201]--loss: 0.0098
 4 Epochs[1/50]--batch[150/201]--loss: 0.0102
 5 Epochs[1/50]--batch[200/201]--loss: 0.0071
 6 Epochs[1/50]--Total loss: 22.2378
 7 RMSE on train: 54.104
 8 RMSE on test: 55.241
 9 Epochs[13/50]--Total loss: 0.2029
10 RMSE on test: 19.942

8.6.6 小结#

在本节内容中,我们首先介绍了STResNet模型所提出的动机及其需要解决的问题;然后详细了整个任务的背景和模型样本采样的整体逻辑;接着进一步详细介绍了STResNet模型各部分的原理和北京出租车数据集的构建过程;最后介绍了如何从零开始实现整个STResNet模型并在北京出租车数据集上进行了测试。

引用#

[1] Zhang J, Zheng Y, Qi D. Deep spatio-temporal residual networks for citywide crowd flows prediction[C]. Proceedings of the AAAI conference on artificial intelligence. 2017, 31(1).

您当前阅读的内容现已出版,点击右侧了解

10章教学课件,400余幅示意插图、40个示例源代码,助力读者轻松迈入深度学习的大门!

查看详情
阅读 --

8.4 ConvLSTM网络

在8.3节内容中,我们介绍了几种将CNN和RNN进行结合的时序模型,包括串行的方式将CNN和RNN进行结合、以并行的方式将CNN和RNN进行结合。同时,在这些任务场景中序列样本所拥有的一个共同特点便是对于每个序列中的每个时刻来说,其特征表示 …

3.8 回归模型评估指标

在本节中,我们首先通过一个示例介绍了为什么我们需要引入评估指标,即如何评价一个回归模型的优与劣;然后详细地逐一介绍了5种常用的评估指标和实现方法;最后,我们还逐一展示了评价指标的示例用法。