5.5 从零实现K近邻#
在前面几节内容中,我们已经详细地介绍了KNN的基本思想与原理,以及kd树的构建过程和搜索原理等。但是对于KNN和kd树具体的实现细节并没有做过多的介绍。下面我们就开始正式介绍如何从零实现kd树以及完成整个KNN的代码实现。以下完整示例代码可参见AllBooKCode/Chapter05/C05_knn_imp_from_scratch.py 文件。
5.5.1 kd树节点定义#
根据第5.4.1节内容介绍,kd树本质上也就等同于二叉搜索树,因此,首先我们需要定义kd树中的节点信息,以及kd树的构建与查询等。同时,由于在KNN的预测结果中需要根据训练样本给出每个预测样本的标签值,因此就需要知道每个训练样本的原始标签值,故需要在节点中保存每个样本索引。最终,kd树的节点信息定义如下:
1 class Node(object):
2 def __init__(self, data=None, index=-1):
3 self.data = data
4 self.left_child = None
5 self.right_child = None
6 self.index = index
7
8 def __str__(self):
9 return f"data({self.data}),index({int(self.index)})"在上述代码中,第2~6行定义了节点Node中保存的具体信息,包括样本点、左右子树以及在原始样本中的索引;第8~9行定义了__str__()方法,其作用是在使用print()函数时可以直接打印出节点的信息,而不必用node.data这样的形式来访问节点中的样本。
5.5.2 kd树构建#
在完成kd树节点的定义之后,下一步就可以开始定义构建kd树的整个过程。首先,我们需要定义类的初始化函数,示例代码如下:
1 class MyKDTree(object):
2 def __init__(self, points):
3 self.root = None
4 self.dim = points.shape[1]
5 points = np.hstack(([points, np.arange(0, len(points)).reshape(-1, 1)]))
6 self.insert(points, order=0) # 递归构建KD树
7
8 def is_empty(self):
9 return not self.root在上述代码中,第3行定义了kd树的根节点。第4行定义了原始样本的维度。第5行用于在样本点的最后一列附加上每个样本点的索引值。第6行则是调用insert()方法递归完成kd树的构建;第8~9行定义了一个方法来判断当前kd是否为空。
接下来便是完成insert()方法的实现过程,示例代码如下:
1 def insert(self, data, order=0):
2 if len(data) < 1:
3 return
4 data = sorted(data, key=lambda x: x[order % self.dim]) # 按某个维度进行排序
5 idx = len(data) // 2
6 node = Node(data[idx][:-1], data[idx][-1])
7 left_data = data[:idx]
8 right_data = data[idx + 1:]
9 if self.is_empty():
10 self.root = node # 整个kd树的根节点
11 node.left_child = self.insert(left_data, order + 1) # 递归构建左子树
12 node.right_child = self.insert(right_data, order + 1) # 递归构建右子树
13 return node在上述代码中,第2~3行用来判断当前传入的样本点是否为空,如果为空则结束当前递归。第4行用于将当前样本按照某个维度的大小顺序进行排序,其中样本点维度的比较顺序为从左到右依次轮询,order在每次进行递归时都会累加;第5~6行用来获取并保存当前样本点排序后中间位置的样本并保存到一个新初始化的节点中;第7~8行则是分别取当前排序后样本的左边部分和右边部分,以此来分别作为当前节点的左右子树。第11~12行则是分别递归构建左右子树。
5.5.3 kd构建示例#
在实现kd树的构建代码后,便可以通过如下方式来进行使用:
1 def test_kd_tree_build(points):
2 tree = MyKDTree(points)
3 tree.level_order()
4
5 if __name__ == '__main__':
6 points = np.array([[5, 7], [3, 8], [6, 3], [8, 5], [15, 6.], [10, 4], [12, 13], [9, 10], [11, 14]])
7 test_kd_tree_build_in_book(points)在上述代码中,第2行便是根据传入的样本点来递归的构建kd树。第3行则是将构建完成的kd树以层次遍历的方式打印出来。
以上代码运行结束后便会有类似如下信息输出:
1 当前待划分样本点:[[3., 8., 1.], [5., 7., 0.], [6., 3., 2.], [8., 5., 3.],
2 [9., 10., 7.], [10., 4., 5.], [11., 14., 8.], [12., 13., 6.], [15., 6., 4.]]
3 父节点:[ 9. 10. 7.]
4 左子树: [[3., 8., 1.], [5., 7., 0.], [6., 3., 2.], [8., 5., 3.]]
5 右子树: [[10., 4., 5.], [11., 14., 8.], [12., 13., 6.], [15., 6., 4.]]
6 ============
7 当前待划分样本点:[[6., 3., 2.], [8., 5., 3.], [5., 7., 0.], [3., 8., 1.]]
8 父节点:[5. 7. 0.]
9 左子树: [[6., 3., 2.], [8., 5., 3.]]
10 右子树: [[3., 8., 1.]]
11 ============
12 ......
13 层次遍历结果为:
14 第1层的节点为:<p([ 9. 10.]), idx(7)>
15 第2层的节点为:<p([5. 7.]), idx(0)> <p([12. 13.]), idx(6)>
16 第3层的节点为:<p([8. 5.]), idx(3)> <p([3. 8.]), idx(1)>
17 <p([15. 6.]), idx(4)> <p([11. 14.]), idx(8)>
18 第4层的节点为:<p([6. 3.]), idx(2)> <p([10. 4.]), idx(5)>在上述输出结果中,第1~12行便是kd树在构建过程中所输出的信息,需要再次提醒的是样本点的最后一个维度为当前样本点在原始样本中的索引值。第13~18行则是构建完成后kd树的层次遍历结果。最后,根据层次遍历以及构建过程的输出结果,也可以还原得到图5-5中所示的kd树。
5.5.4 kd树最近邻搜索#
在实现最近邻的搜索过程之前首先需要根据式(5-1)中的定义来实现两个点之间距离的计算,实现代码如下所示:
1 def distance(p1, p2, p=2):
2 return np.sum((p1 - p2) ** p) ** (1 / p)在上述代码中,当$p=2$时就是我们熟悉的欧式距离。
进一步,根据5.4.2节中kd树最近邻搜索的伪代码实现过程,我们在类MyKDTree实现一个方法来完成最近邻的搜索过程,示例代码如下:
1 def nearest_search(self, point):
2 best_node,best_dist = None, np.inf
3 visited,point = [],point.reshape(-1)
4 def nearest_node_search(point, curr_node, order=0):
5 nonlocal best_node, best_dist, visited
6 if curr_node is None:
7 return None
8 visited.append(curr_node)
9 dist = self.distance(curr_node.data, point, self.p)
10 if dist < best_dist:
11 best_dist,best_node = dist,curr_node
12 cmp_dim = order % self.dim
13 if point[cmp_dim] < curr_node.data[cmp_dim]:
14 nearest_node_search(point, curr_node.left_child, order + 1)
15 else:
16 nearest_node_search(point, curr_node.right_child, order + 1)
17 if np.abs(curr_node.data[cmp_dim] - point[cmp_dim]) < best_dist:
18 child = curr_node.left_child if curr_node.left_child
19 not in visited else curr_node.right_child
20 nearest_node_search(point, child, order + 1)
21 nearest_node_search(point, self.root, 0)
22 return best_node, best_dist在上述代码中,第2~3行定义了相关的全局记录变量。第5行则是用来声明这3个变量不是局部变量而是上面定义的全局变量。第4行是定义一个函数来完成后续的递归搜索。第8行则是用来记录当前哪些节点已经被访问过。第9行是计算当前节点到被搜索点的距离。第10~11行则是判断是否要更新当前最佳节点。第12行是计算得到进入左右子树时的判断维度。第13~16行是根据维度比较信息递归遍历相应的左子树或右子树。第17~20行则是根据第5.4.2节中的子空间排除原理来判断当前节点左右子树中未访问过的节点是否存在最佳节点,并进行递归遍历。第21行则是开始进入到递归搜索中。第22行是返回最后的最佳节点和最短距离。