更新于 2026年6月28日

5.5 从零实现K近邻#

在前面几节内容中,我们已经详细地介绍了KNN的基本思想与原理,以及kd树的构建过程和搜索原理等。但是对于KNN和kd树具体的实现细节并没有做过多的介绍。下面我们就开始正式介绍如何从零实现kd树以及完成整个KNN的代码实现。以下完整示例代码可参见AllBooKCode/Chapter05/C05_knn_imp_from_scratch.py 文件。

5.5.1 kd树节点定义#

根据第5.4.1节内容介绍,kd树本质上也就等同于二叉搜索树,因此,首先我们需要定义kd树中的节点信息,以及kd树的构建与查询等。同时,由于在KNN的预测结果中需要根据训练样本给出每个预测样本的标签值,因此就需要知道每个训练样本的原始标签值,故需要在节点中保存每个样本索引。最终,kd树的节点信息定义如下:

1 class Node(object):
2     def __init__(self, data=None, index=-1):
3         self.data = data
4         self.left_child = None
5         self.right_child = None
6         self.index = index
7 
8     def __str__(self):
9         return f"data({self.data}),index({int(self.index)})"

在上述代码中,第2~6行定义了节点Node中保存的具体信息,包括样本点、左右子树以及在原始样本中的索引;第8~9行定义了__str__()方法,其作用是在使用print()函数时可以直接打印出节点的信息,而不必用node.data这样的形式来访问节点中的样本。

5.5.2 kd树构建#

在完成kd树节点的定义之后,下一步就可以开始定义构建kd树的整个过程。首先,我们需要定义类的初始化函数,示例代码如下:

1 class MyKDTree(object):
2     def __init__(self, points):
3         self.root = None
4         self.dim = points.shape[1]
5         points = np.hstack(([points, np.arange(0, len(points)).reshape(-1, 1)]))
6         self.insert(points, order=0)  # 递归构建KD树
7 
8     def is_empty(self):
9         return not self.root

在上述代码中,第3行定义了kd树的根节点。第4行定义了原始样本的维度。第5行用于在样本点的最后一列附加上每个样本点的索引值。第6行则是调用insert()方法递归完成kd树的构建;第8~9行定义了一个方法来判断当前kd是否为空。

接下来便是完成insert()方法的实现过程,示例代码如下:

 1     def insert(self, data, order=0):
 2         if len(data) < 1:
 3             return
 4         data = sorted(data, key=lambda x: x[order % self.dim])  # 按某个维度进行排序
 5         idx = len(data) // 2
 6         node = Node(data[idx][:-1], data[idx][-1])
 7         left_data = data[:idx]
 8         right_data = data[idx + 1:]
 9         if self.is_empty():
10             self.root = node  # 整个kd树的根节点
11         node.left_child = self.insert(left_data, order + 1)  # 递归构建左子树
12         node.right_child = self.insert(right_data, order + 1)  # 递归构建右子树
13         return node

在上述代码中,第2~3行用来判断当前传入的样本点是否为空,如果为空则结束当前递归。第4行用于将当前样本按照某个维度的大小顺序进行排序,其中样本点维度的比较顺序为从左到右依次轮询,order在每次进行递归时都会累加;第5~6行用来获取并保存当前样本点排序后中间位置的样本并保存到一个新初始化的节点中;第7~8行则是分别取当前排序后样本的左边部分和右边部分,以此来分别作为当前节点的左右子树。第11~12行则是分别递归构建左右子树。

5.5.3 kd构建示例#

在实现kd树的构建代码后,便可以通过如下方式来进行使用:

1 def test_kd_tree_build(points):
2     tree = MyKDTree(points)
3     tree.level_order()
4     
5 if __name__ == '__main__':
6     points = np.array([[5, 7], [3, 8], [6, 3], [8, 5], [15, 6.], [10, 4], [12, 13], [9, 10], [11, 14]])
7     test_kd_tree_build_in_book(points)

在上述代码中,第2行便是根据传入的样本点来递归的构建kd树。第3行则是将构建完成的kd树以层次遍历的方式打印出来。

以上代码运行结束后便会有类似如下信息输出:

 1 当前待划分样本点[[3., 8., 1.], [5., 7., 0.], [6., 3., 2.], [8., 5., 3.], 
 2 [9., 10., 7.], [10., 4., 5.], [11., 14., 8.], [12., 13., 6.], [15., 6., 4.]]
 3 父节点[ 9. 10.  7.]
 4 左子树: [[3., 8., 1.], [5., 7., 0.], [6., 3., 2.], [8., 5., 3.]]
 5 右子树: [[10., 4., 5.], [11., 14., 8.], [12., 13., 6.], [15., 6., 4.]]
 6 ============
 7 当前待划分样本点[[6., 3., 2.], [8., 5., 3.], [5., 7., 0.], [3., 8., 1.]]
 8 父节点[5. 7. 0.]
 9 左子树: [[6., 3., 2.], [8., 5., 3.]]
10 右子树: [[3., 8., 1.]]
11 ============
12 ......
13 层次遍历结果为:
14 第1层的节点为:<p([ 9. 10.]), idx(7)>
15 第2层的节点为:<p([5. 7.]), idx(0)> <p([12. 13.]), idx(6)>
16 第3层的节点为:<p([8. 5.]), idx(3)> <p([3. 8.]), idx(1)>
17 			<p([15.  6.]), idx(4)> <p([11. 14.]), idx(8)>
18 第4层的节点为:<p([6. 3.]), idx(2)> <p([10.  4.]), idx(5)>

在上述输出结果中,第1~12行便是kd树在构建过程中所输出的信息,需要再次提醒的是样本点的最后一个维度为当前样本点在原始样本中的索引值。第13~18行则是构建完成后kd树的层次遍历结果。最后,根据层次遍历以及构建过程的输出结果,也可以还原得到图5-5中所示的kd树。

5.5.4 kd树最近邻搜索#

在实现最近邻的搜索过程之前首先需要根据式(5-1)中的定义来实现两个点之间距离的计算,实现代码如下所示:

1 def distance(p1, p2, p=2):
2     return np.sum((p1 - p2) ** p) ** (1 / p)

在上述代码中,当$p=2$时就是我们熟悉的欧式距离。

进一步,根据5.4.2节中kd树最近邻搜索的伪代码实现过程,我们在类MyKDTree实现一个方法来完成最近邻的搜索过程,示例代码如下:

 1     def nearest_search(self, point):
 2         best_node,best_dist = None, np.inf
 3         visited,point = [],point.reshape(-1)
 4         def nearest_node_search(point, curr_node, order=0):
 5             nonlocal best_node, best_dist, visited
 6             if curr_node is None:
 7                 return None
 8             visited.append(curr_node)
 9             dist = self.distance(curr_node.data, point, self.p)
10             if dist < best_dist:
11                 best_dist,best_node = dist,curr_node
12             cmp_dim = order % self.dim
13             if point[cmp_dim] < curr_node.data[cmp_dim]:
14                 nearest_node_search(point, curr_node.left_child, order + 1)
15             else:
16                 nearest_node_search(point, curr_node.right_child, order + 1)
17             if np.abs(curr_node.data[cmp_dim] - point[cmp_dim]) < best_dist:
18                 child = curr_node.left_child if curr_node.left_child 
19                 				   not in visited else curr_node.right_child
20                 nearest_node_search(point, child, order + 1)
21         nearest_node_search(point, self.root, 0)
22         return best_node, best_dist

在上述代码中,第2~3行定义了相关的全局记录变量。第5行则是用来声明这3个变量不是局部变量而是上面定义的全局变量。第4行是定义一个函数来完成后续的递归搜索。第8行则是用来记录当前哪些节点已经被访问过。第9行是计算当前节点到被搜索点的距离。第10~11行则是判断是否要更新当前最佳节点。第12行是计算得到进入左右子树时的判断维度。第13~16行是根据维度比较信息递归遍历相应的左子树或右子树。第17~20行则是根据第5.4.2节中的子空间排除原理来判断当前节点左右子树中未访问过的节点是否存在最佳节点,并进行递归遍历。第21行则是开始进入到递归搜索中。第22行是返回最后的最佳节点和最短距离。

最后,可以通过如下方式来进行kd树中最近邻样本点的搜索:

 1 def test_kd_nearest_search(points, q=None):
 2     tree = MyKDTree(points)
 3     best_node, best_dist = tree.nearest_search(q)
 4     logging.info("MyKDTree 运行结果:")
 5     logging.info(f"离样本点{q}最近的节点是:{best_node},距离为:{round(best_dist, 3)}")
 6 
 7     kd_tree = KDTree(points)
 8     dist, ind = kd_tree.query(q, k=1)
 9     logging.info("sklearn KDTree 运行结果:")
10     logging.info(f"离样本点{q}最近的节点是:{points[ind]},距离为:{dist}")
11 if __name__ == '__main__':
12 	points = np.array([[5, 7], [3, 8], [6, 3], [8, 5], [15, 6.], 
13 					   [10, 4], [12, 13], [9, 10], [11, 14]])
14     test_kd_nearest_search(points, q=np.array([[8.9, 4]]))

上述代码运行结束后便会输出类似如下的结果:

1 MyKDTree 运行结果
2 离样本点[[8.9 4. ]]最近的节点是data([10. 4.]),index(5),距离为1.1
3 sklearn KDTree 运行结果
4 离样本点[[8.9 4. ]]最近的节点是[[[10. 4.]]],距离为[[1.1]]

根据上述输出结果,可以得知距离样本点[8.9,4]最近的样本点是[10,4],并且整个kd树的搜索过程将如图5-9所示。

5.5.5 kd树K近邻搜索#

在第5.4节内容中我们已经详细介绍了kd树K近邻的搜索原理,因此在这里就不再赘述直接按照之前给出的伪代码来进行实现即可。在实现K近邻搜索之前,需要先在类MyKDTree实现一个方法用于对节点进行插入并排序,示例代码如下:

1     def append(self, k_nearest_nodes, curr_node, point):
2         k_nearest_nodes.append(curr_node)
3         k_nearest_nodes = sorted(k_nearest_nodes,
4              key=lambda x: self.distance(x.data, point, self.p))
5         return k_nearest_nodes

在上述代码中,第2行是将当前节点加入到k_nearest_nodes中。第3-4行是根据k_nearest_nodes中每个样本点到被搜索点的距离来进行升序排序。

进一步,便可以实现得到kd树的K近邻搜索逻辑,示例代码如下:

 1    def _k_nearest_search(self, point, k):
 2         k_nearest_nodes,visited,n = [],[],0
 3         def k_nearest_node_search(point, curr_node, order=0):
 4             nonlocal k_nearest_nodes, n
 5             if curr_node is None:
 6                 return None
 7             visited.append(curr_node)
 8             if n < k: 
 9                 n += 1
10                 k_nearest_nodes = self.append(k_nearest_nodes, curr_node, point)
11             else: 
12                 d1 = self.distance(curr_node.data, point, self.p)
13                 d2 = self.distance(point, k_nearest_nodes[-1].data, self.p)
14                 if d1 < d2:
15                     k_nearest_nodes.pop()
16                     k_nearest_nodes = self.append(k_nearest_nodes, curr_node, point)
17             cmp_dim = order % self.dim
18             if point[cmp_dim] < curr_node.data[cmp_dim]:
19                 k_nearest_node_search(point, curr_node.left_child, order + 1)
20             else:
21                 k_nearest_node_search(point, curr_node.right_child, order + 1)
22             if n < k or np.abs(curr_node.data[cmp_dim] - point[cmp_dim]) < \
23                     self.distance(point, k_nearest_nodes[-1].data, self.p):
24                 child = curr_node.left_child if curr_node.left_child \
25                 		not in visited else curr_node.right_child
26                 k_nearest_node_search(point, child, order + 1)
27         k_nearest_node_search(point, self.root, 0)
28         return k_nearest_nodes

在上述代码中,第2行定义了相关的全局记录变量。第3行是定义一个函数来完成后续的递归搜索过程。第4行则是用来声明这3个遍历不是局部变量而是上面定义的全局变量。第7行用来记录访问过的节点。第8~10行是判断如果当前还没找到K个点,则直接进行保存。第11~16行是判断如果已经找到K个局部最优点,则开始按距离进行筛选。第17~21行则是根据当前比较维度来判断进入左子树还是右子树进行递归遍历。第22~26行是根据子空间排除原理来判断当前节点左右子树中未访问过的节点是否存在最佳节点,并进行递归遍历。

到此,对于任一点的K近邻查找就已经实现完毕了,下面只需要再定义一个方法即可循环完成多个样本点的K近邻搜索过程,示例代码如下:

 1 	def k_nearest_search(self, points, k):
 2         result_points, result_ind = [], []
 3         for point in points:
 4             k_nodes = self._k_nearest_search(point, k)
 5             tmp_points,tmp_ind = [], []
 6             for node in k_nodes:
 7                 tmp_points.append(node.data)
 8                 tmp_ind.append(int(node.index))
 9             result_points.append(tmp_points)
10             result_ind.append(tmp_ind)
11         return np.array(result_points), np.array(result_ind)

在上述代码中,第3~4行开始便是遍历被一个被搜索的样本,并根据上面介绍的_k_nearest_search()方法来完成每个样本的最近K个样本的搜索。第6~10行是将每个被搜索样本的的结果进行格式化处理。

最后,可以通过如下方式来进行kd树中K近邻样本点的搜索:

 1 def test_kd_k_nearest_search(points):
 2     tree = MyKDTree(points)
 3     p = np.array([[8.9, 4]])
 4     k_best_nodes, ind = tree.k_nearest_search(p, k=3)
 5     logging.info("MyKDTree 运行结果:")
 6     logging.info(f"\n{k_best_nodes}")
 7     logging.info(f"样本索引编号为:{ind}")
 8 
 9     logging.info("sklearn KDTree 运行结果:")
10     kd_tree = KDTree(points)
11     dist, ind = kd_tree.query(p, k=3)
12     logging.info(f"\n{points[ind]}", )
13     logging.info(f"样本索引编号为:{ind}")
14 
15 if __name__ == '__main__':
16     points = np.array([[5, 7], [3, 8], [6, 3], [8, 5], 
17     	[15, 6.], [10, 4], [12, 13], [9, 10], [11, 14]])
18     test_kd_k_nearest_search(points)

上述代码运行结束后便会输出类似如下的结果:

 1 # =========== 正在查找离样本点[8.9 4. ]最近的3个样本点!==========
 2 K近邻搜索当前访问节点为data([ 9. 10.]),index(7)
 3 有序列表中的节点数目为0 < 3直接加入新节点并排序
 4 当前K近邻有序列表中的节点为已按距离升序排序):
 5 data([ 9. 10.]),index(7)
 6 访问当前节点data([ 9. 10.]),index(7)的左孩子
 7 K近邻搜索当前访问节点为data([5. 7.]),index(0)
 8 有序列表中的节点数目为1 < 3直接加入新节点并排序
 9 当前K近邻有序列表中的节点为已按距离升序排序):
10 data([5. 7.]),index(0), data([ 9. 10.]),index(7)
11 ...
12 回到上一层递归当前访问节点为 data([3. 8.]),index(1)开始判断步骤(6)
13 回到上一层递归当前访问节点为 data([ 9. 10.]),index(7)开始判断步骤(6)
14 被搜索点到当前节点划分维度的距离小于列表最后一个元素到被搜索点的距离|9.0 - 8.9|<4.92
15 访问当前节点data([ 9. 10.]),index(7)的右孩子
16 ...
17 MyKDTree 运行结果[[[10.  4.] [ 8.  5.] [ 6.  3.]]]
18 样本索引编号为:[[5 3 2]]
19 sklearn KDTree 运行结果[[[10.  4.] [ 8.  5.][ 6.  3.]]]
20 样本索引编号为:[[5 3 2]]

到此,对于kd树的整体实现就介绍完了,接下来就是KNN分类模型的实现。

5.5.6 KNN实现#

在完成kd的相关功能实现后,我们只需要在此基础上稍加封装根据投票原则便可以实现KNN分类模型。

1. 模型拟合

对于KNN分类模型来说,所谓的模型拟合其实就是根据给定的训练样本来构造完成对应的kd树,示例代码如下:

1 class MyKNN():
2     def __init__(self, n_neighbors, p=2):
3         self.n_neighbors = n_neighbors
4         self.p = p
5 
6     def fit(self, x, y):
7         self._y = y
8         self.kd_tree = MyKDTree(x, self.p)
9         return self

在上述代码中,第2~4行为类MyKNN的初始化方法,用来保存K值和距离计算方法p值。第6~9行则是模型的拟合过程,即建立kd树。

2. 模型预测

在模型拟合完成后,便可以对新输入的样本进行预测。不过根据k_nearest_search()方法的返回结果可知,K近邻搜索返回的是训练集中K个节点对应的索引位置,因此需要先根据索引取到对应的标签值;然后再根据取到的标签通过投票法来确定新样本的类别。因此此处需要先定义一个辅助方法来完成类别的确定,示例代码如下:

1     def get_pred_labels(query_label):
2         y_pred = [0] * len(query_label)
3         for i, label in enumerate(query_label):
4             max_fre, count_dict = 0, {}
5             for l in label:
6                 count_dict[l] = count_dict.setdefault(l, 0) + 1
7                 if count_dict[l] > max_freq:
8                     max_freq, y_pred[i] = count_dict[l], l
9         return np.array(y_pred)

在上述代码中,第1行query_label是一个二维数组,query_label[i] 表示离第i个样本最近的K个样本点对应的正确标签。第3~8行则是分别开始遍历每个样本的预测结果,其中第5~8行是根据投票规则来确定当前样本的所属类别。

进一步,可以通过如下方式来完成新样本的预测:

1     def predict(self, x):
2         k_best_nodes, ind = self.kd_tree.k_nearest_search(x, k=self.n_neighbors)
3         query_label = self._y[ind]
4         y_pred = self.get_pred_labels(query_label)
5         return y_pred

最后,以iris数据集为例来进行实验,并同时与sklearn中KNeighborsClassifier的分类结果进行对比。

 1 if __name__ == '__main__':
 2     x_train, x_test, y_train, y_test = load_data()
 3     k = 5
 4     model = KNeighborsClassifier(n_neighbors=k)
 5     model.fit(x_train, y_train)
 6     y_pred = model.predict(x_test)
 7     logging.info(f"impl_by_sklearn 准确率:{accuracy_score(y_test, y_pred)}")
 8 
 9     my_model = MyKNN(n_neighbors=k)
10     my_model.fit(x_train, y_train)
11     y_pred = my_model.predict(x_test)
12     logging.info(f"impl_by_ours 准确率:{accuracy_score(y_test, y_pred)}")

上述代码运行结束后的结果如下所示:

1 impl_by_sklearn 准确率0.9556
2 impl_by_ours 准确率0.9556

从上述结果可以看出,两者的分类准确率并没有任何差异。

5.5.7 总结#

在本节内容中,我们首先介绍了kd树节点的定义、kd树的从零实现过程以及其使用示例;然后详细介绍了如何实现kd树的最近邻和K近邻搜索过程;最后介绍了如何基于kd树搜索实现整个K近邻算法并将其同sklearn中的KNN算法进行了对比。

总结一下,在本章中我们首先介绍了K近邻算法的主要思想及其原理,包括K值选择和距离的度量方式等;接着简单地总结了sklearn框架接口的设计风格及通用的建模步骤;然后介绍了如何通过sklearn建立完整的K近邻分类器,包括模型训练、模型选择、并行搜索和交叉验证等;最后详细介绍了如何通过kd树来实现K近邻算法,并与sklearn中的对应模块进行了对比。

引用#

[1] 李航.统计学习方法[M].2版.北京: 清华大学出版社,2019.

[2] PEDREGOSA.scikitlearn: Machine Learning in Python[J].JMLR 12,2011: 28252830.

[3] https://en.wikipedia.org/wiki/Kd_tree

[4] http://web.stanford.edu/class/cs106l/

阅读 --