Zer0e's Blog

TopK问题与快速选择算法

字数统计: 1.9k阅读时长: 8 min
2020/05/13 Share

前言

TopK问题通常指的是寻找数组中第k大(小)的数或前k大(小)的数组。最简单的方法当然是使用排序,本文将从TopK问题入手,讲讲常见的几种解决方法并详细讲解快速选择算法。

实践1

解决TopK问题的解决方案有以下几种:

  • 排序
  • 快速选择算法

我们使用剑指offer中给出的一道题目作为例子。
输入整数数组 arr ,找出其中最小的 k 个数。例如,输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。
接下来我将详细讲解这几种解决方案来解决上述问题。

排序

排序是最容易想到也是比较简单的方法,许多语言中都内置了排序方法,当然自己实现排序也是可以的。

1
2
3
4
class Solution:
def getLeastNumbers(self, arr, k)
arr.sort()
return arr[:k]

由于python中的排序使用的是快速排序,所以平均时间复杂度为O(nlogn)。

我们使用大根堆来解决上述问题。由于上述题目是寻找前k小的数,所以我们使用大根堆,poll出n-k个数,留下的就是前k小的数。详细思路为:将k个数插入大根堆中,从第k+1个数开始,如果当前数小于堆顶的数,把堆顶数弹出,再插入当前数。最后留在堆中的数即为前k小的数。
在java当中,可以使用PriorityQueue并重写比较器来实现一个大根堆,而python中因为heapq模块只支持小根堆,我们需要将数组中的数取反,才能使用小根堆来获得前k个最小值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Solution {
public int[] getLeastNumbers(int[] arr, int k) {
if (k == 0 || arr.length == 0) {
return new int[0];
}
Queue<Integer> pq = new PriorityQueue<>((v1, v2) -> v2 - v1);
for (int num: arr) {
if (pq.size() < k) {
pq.offer(num);
} else if (num < pq.peek()) {
pq.poll();
pq.offer(num);
}
}

int[] res = new int[pq.size()];
int idx = 0;
for(int num: pq) {
res[idx++] = num;
}
return res;
}
}

使用小根堆

1
2
3
4
5
6
7
8
9
10
11
12
13
class Solution:
def getLeastNumbers(self, arr, k)
if k == 0:
return list()

pq = [-x for x in arr[:k]]
heapq.heapify(hp)
for i in range(k, len(arr)):
if -pq[0] > arr[i]:
heapq.heappop(pq)
heapq.heappush(pq, -arr[i])
ans = [-x for x in pq]
return ans

使用堆的平均时间复杂度为O(nlogk),空间复杂度为O(k)。

快速选择

快速选择算法其实是快速排序的思想,我们可以先回忆下快排的思想。使用快排思想可以将数组分隔为左右两边,数组下标为[0,a)与[a,n),如果a刚好等于k-1的话,那么[0,a)就是我们要的前k小的数,如果a小于k-1则在右区间继续寻找a,如果a大于k-1的话则在左区间寻找。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Solution:
def getLeastNumbers(self, arr: List[int], k: int) -> List[int]:
def quickSelect(arr,left,right,k):
if k < 0:
return []

if left <= right:
i = left
j = right
key = arr[left]
while i < j:
while i < j and arr[j] > key:
j -= 1
while i < j and arr[i] <= key:
i += 1
if i < j:
arr[i],arr[j] = arr[j],arr[i]
arr[left],arr[j] = arr[j],arr[left]

if j == k:
return arr[:j+1]

if j > k:
return quickSelect(arr,left,j-1,k)
else:
return quickSelect(arr,j+1,right,k)

return quickSelect(arr,0,len(arr)-1,k-1)

这个算法的改进之处与快排的改进之处一致,在于每次对于key的选取,如果数组本身有序,并且key总是取左边一个数作为对比,或者说key的选取总是最大值或最小值,那么可能导致时间复杂度退化为O(n^2),并且由于快速选择相较于快速排序,只需要对左区间或者右区间进行partition,而不是左右区间都要partition,因此时间复杂度为N + N/2 + N/4 + … + N/N = 2N,即O(N)时间复杂度。

实践2

上面一道题我们解决了前k小的数,而TopK其实说的是top,即第k个大的数。我们使用leetcode第215题。
在未排序的数组中找到第 k 个最大的元素。请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

我们照样使用堆和快速选择来解决这个问题。

这里我们使用小根堆,将数全部入堆,如果堆大小超过k,则poll出堆顶元素,最后在堆顶的就是第k大的数。

1
2
3
4
5
6
7
8
9
10
11
12
class Solution {
public int findKthLargest(int[] nums, int k) {
PriorityQueue<Integer> hp =
new PriorityQueue<Integer>();
for (int n: nums) {
hp.add(n);
if (hp.size() > k)
hp.poll();
}
return hp.poll();
}
}

在python的heapq模块中,我们可以使用nlargest方法来获取前k个大的数,并返回最后一个

1
2
3
class Solution:
def findKthLargest(self, nums, k):
return heapq.nlargest(k, nums)[-1]

快速选择

我们可以完全复制上一道题的代码,只需改动些许地方。1.当j==k时,返回的是一个数。2.由于上一道题代码是找第k个小的数,所以刚好是下标与k-1相等时返回,也就是说寻找第k大相当于寻找第n-k+1小的数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
def quickSelect(arr,left,right,k):
if k < 0:
return []

if left <= right:
i = left
j = right
key = arr[left]
while i < j:
while i < j and arr[j] > key:
j -= 1
while i < j and arr[i] <= key:
i += 1
if i < j:
arr[i],arr[j] = arr[j],arr[i]
arr[left],arr[j] = arr[j],arr[left]

if j == k:
return arr[j]

if j > k:
return quickSelect(arr,left,j-1,k)
else:
return quickSelect(arr,j+1,right,k)
# len(nums)-k 是数组下标
return quickSelect(nums,0,len(nums)-1,len(nums)-k)

老生常谈的优化,对于key的选择很关键,在LeetCode中,如果key总是为左边那个数,则时间耗时1100ms,而如果使用随机下标与left进行交换,则时间降至50ms以内。并且减少了递归所需要的内存消耗。具体部分代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
....
if left <= right:
i = left
j = right
a = random.randint(left, right)
arr[left],arr[a] = arr[a],arr[left]
key = arr[left]
while i < j:
while i < j and arr[j] > key:
j -= 1
while i < j and arr[i] <= key:
i += 1
if i < j:
arr[i],arr[j] = arr[j],arr[i]
arr[left],arr[j] = arr[j],arr[left]
...

partition思路

在官方答案中是将右边作为起始点,其思想大同小异,这里简单讲讲与我思路的不同。
首先随机选取一个pivot,并将这个数与最右边那个数进行一次交换。
第二步,定义i,j指针,初始化为left,循环退出条件为j指针等于最右边数的下标。查看nums[j]是否小于等于pivot,如果不是,则j向右移动。如果是,交换i,j位置的元素,并且i,j都向右移动。
第三步,重复第二步,直到j==right,此时交换i与j的元素,此时,i左边元素都小于它,右边元素都大于它。
以上就是另一种partition的思路。
这篇文章中的partition思路与那篇快速排序的文章相同。

总结

快速选择算法与快速排序思想一致,通过对数组进行partition来获取前k小的数,通过写这篇文章,再一次复习了快速排序算法,并对两种算法有了自己的认识与理解。

CATALOG
  1. 1. 前言
  2. 2. 实践1
    1. 2.1. 排序
    2. 2.2.
    3. 2.3. 快速选择
  3. 3. 实践2
    1. 3.1.
    2. 3.2. 快速选择
    3. 3.3. partition思路
  4. 4. 总结