0%

Python中的堆排序与优先队列

对数据进行排序是一个很常见的需求,但有时候我们并不需要对完整的数据进行排序,只需要排前几的数据,也就是经典的 Top-K 问题。

Top-K 问题的经典解法有两种:一种是脱胎于快速排序(Quick Sort)的快速选择(Quick Select)算法,核心思路是在每一次Partion操作后下一次递归只操作前K项数据。另一种是基于堆排序的方法。

Python 中有两个标准库可以原生的支持堆排序(优先队列),分别是heapqPriorityQueue(queue)

heapq

heapq标准库提供了一些工具函数用来对list对象实现二叉堆的各种操作(就地修改list对象)。简单的用法如下:

建堆

1
2
3
4
5
6
import heapq

# 可以用过random.shuffle函数创造乱序数组
arr = [4, 0, 3, 1, 6, 5, 9, 7, 8, 2]
heapq.heapify(arr)
assert arr == [0, 1, 3, 4, 2, 5, 9, 7, 8, 6]

获取堆顶元素

1
2
assert heapq.heappop(arr) == 0
assert arr == [1, 2, 3, 4, 6, 5, 9, 7, 8]

插入新元素

1
2
heapq.heappush(arr, 11)
assert arr == [1, 2, 3, 4, 6, 5, 9, 7, 8, 11]

heapq也提供了直接获取nlargestnsmallest函数,并且这两个函数并不会就地修改原数据。

1
2
3
4
arr = [4, 0, 3, 1, 6, 5, 9, 7, 8, 2]
assert heapq.nlargest(5, arr) == [9, 8, 7, 6, 5]
assert arr == [4, 0, 3, 1, 6, 5, 9, 7, 8, 2]
assert heapq.nsmallest(5, arr) == [0, 1, 2, 3, 4]

queue.PriorityQueue

queue标准库为 Python 代码提供了原生线程安全的队列实现。queue.PriorityQueue则是 Python 原生的优先队列实现,相比heapq有着更直观易用的接口。

创建优先队列

1
2
3
4
5
6
7
8
from queue import PriorityQueue

pq = PriorityQueue()

arr = [4, 0, 3, 1, 6, 5, 9, 7, 8, 2]

for num in arr:
pq.put(num)

获取队首元素

1
2
while not pq.empty():
assert pq.get() == 0

对比

heapq标准库是专门用来做堆排序相关操作的,而PriorityQueue类毕竟继承于queue.Queue,适用于多线程通信场景。两者的效率还是有着不小差距的。

我们以 LeetCode 973(最接近原点的 K 个点)为例,分别用heapqPriorityQueue实现,比较一下二者的运行效率。

题目描述

973. 最接近原点的 K 个点

我们有一个由平面上的点组成的列表 points。需要从中找出 K 个距离原点 (0, 0) 最近的点。

(这里,平面上两点之间的距离是欧几里德距离。)

你可以按任何顺序返回答案。除了点坐标的顺序之外,答案确保是唯一的。

示例 1

输入:points = [[1,3],[-2,2]], K = 1
输出:[[-2,2]]
解释:
(1, 3) 和原点之间的距离为 sqrt(10),
(-2, 2) 和原点之间的距离为 sqrt(8),
由于 sqrt(8) < sqrt(10),(-2, 2) 离原点更近。
我们只需要距离原点最近的 K = 1 个点,所以答案就是 [[-2,2]]。

示例 2

输入:points = [[3,3],[5,-1],[-2,4]], K = 2
输出:[[3,3],[-2,4]]
(答案 [[-2,4],[3,3]] 也会被接受。)

提示:

1 <= K <= points.length <= 10000
-10000 < points[i][0] < 10000
-10000 < points[i][1] < 10000

生成测试数据

1
2
3
4
5
from random import randint
def genPoints(n:int = 100):
return [(randint(0, 100), randint(0, 100)) for _ in range(n)]
points = genPoints(1_0000)
less_points = genPoints(100)

heapq实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import heapq

from typing import List


def distance(point: List[int]):
return point[0] ** 2 + point[1] ** 2


class Solution:
def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
distances = [(distance(point), point) for point in points]
return [e[1] for e in heapq.nsmallest(K, distances)]


solution = Solution()
%timeit solution.kClosest(points, 100)
# 6.79 ms ± 181 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

PriorityQueue实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from queue import PriorityQueue
from typing import List


def distance(point: List[int]):
return point[0] ** 2 + point[1] ** 2


class Solution:
def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
pq = PriorityQueue()
for point in points:
pq.put((distance(point), point), block=False)
ret = []
while not pq.empty():
ret.append(pq.get()[1])
return ret


solution = Solution()
%timeit solution.kClosest(points,100)
# 52.2 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

我们可以看到heapq版本比PriorityQueue版本快接近一个数量级,并且代码也更精简。

这也说明了我们要在合适的地方使用合适的工具。

扫码加入技术交流群🖱️
QR code