由于 Python 没有大顶堆。因此我这里使用了小顶堆进行模拟实现。即将原有的数全部取相反数,比如原数字是 5,就将 -5 入堆。经过这样的处理,小顶堆就可以当成大顶堆用了。不过需要注意的是,当你 pop 出来的时候, 记得也要取反,将其还原回来哦。
代码示例:
h = []
A = [1,2,3,4,5]
for a in A:
heapq.heappush(h, -a)
-1 * heapq.heappop(h) # 5
-1 * heapq.heappop(h) # 4
-1 * heapq.heappop(h) # 3
-1 * heapq.heappop(h) # 2
-1 * heapq.heappop(h) # 1
用图来表示就是下面这样:
铺垫就到这里,接下来进入正题。
三个技巧
技巧一 - 固定堆
这个技巧指的是固定堆的大小 k 不变,代码上可通过每 pop 出去一个就 push 进来一个来实现。而由于初始堆可能是 0,我们刚开始需要一个一个 push 进堆以达到堆的大小为 k,因此严格来说应该是维持堆的大小不大于 k。
固定堆一个典型的应用就是求第 k 小的数。其实求第 k 小的数最简单的思路是建立小顶堆,将所有的数先全部入堆,然后逐个出堆,一共出堆 k 次。最后一次出堆的就是第 k 小的数。
然而,我们也可不先全部入堆,而是建立大顶堆(注意不是上面的小顶堆),并维持堆的大小为 k 个。如果新的数入堆之后堆的大小大于 k,则需要将堆顶的数和新的数进行比较,并将较大的移除。这样可以保证堆中的数是全体数字中最小的 k 个,而这最小的 k 个中最大的(即堆顶)不就是第 k 小的么?这也就是选择建立大顶堆,而不是小顶堆的原因。
简单一句话总结就是固定一个大小为 k 的大顶堆可以快速求第 k 小的数,反之固定一个大小为 k 的小顶堆可以快速求第 k 大的数。比如力扣 2020-02-24 的周赛第三题5663. 找出第 K 大的异或坐标值就可以用固定小顶堆技巧来实现(这道题让你求第 k 大的数)。
又因为 sum_of_q 一定的时候, w/q 越小,总工资越小。因此我们可以从小到大枚举 w/q,然后在其中选 k 个 最小的q,使得总工资最小。
因此思路就是:
枚举最大的 w/q, 然后用堆存储 k 个 q。当堆中元素大于 k 个时,将最大的 q 移除。
由于移除的时候我们希望移除“最大的”q,因此用大根堆
于是我们可以写出下面的代码:
class Solution:
def mincostToHireWorkers(self, quality: List[int], wage: List[int], K: int) -> float:
eff = [(w/q, q, w) for q, w in zip(quality, wage)]
eff.sort(key=lambda a: a[0])
ans = float('inf')
for i in range(K-1, len(eff)):
h = []
k = K - 1
rate, _, total = eff[i]
# 找出工作效率比它高的 k 个人,这 k 个人的工资尽可能低。
# 由于已经工作效率倒序排了,因此前面的都是比它高的,然后使用堆就可得到 k 个工资最低的。
for j in range(i):
heapq.heappush(h, eff[j][1] * rate)
while k > 0:
total += heapq.heappop(h)
k -= 1
ans = min(ans, total)
return ans
(代码 1.3.2)
这种做法每次都 push 很多数,并 pop k 次,并没有很好地利用堆的动态特性,而只利用了其求极值的特性。
一个更好的做法是使用固定堆技巧。
这道题可以换个角度思考。其实这道题不就是让我们选 k 个人,工作效率比取他们中最低的,并按照这个最低的工作效率计算总工资,找出最低的总工资么? 因此这道题可以固定一个大小为 k 的大顶堆,通过一定操作保证堆顶的就是第 k 小的(操作和前面的题类似)。
并且前面的解法中堆使用了三元组 (q / w, q, w),实际上这也没有必要。因为已知其中两个,可推导出另外一个,因此存储两个就行了,而又由于我们需要根据工作效率比做堆的键,因此任意选一个 q 或者 w 即可,这里我选择了 q,即存 (q/2, q) 二元组。
具体来说就是:以 rate 为最低工作效率比的 k 个人的总工资 = $\displaystyle \sum_{n=1}^{k}{q}_{n}/rate$,这里的 rate 就是当前的 q / w,同时也是 k 个人的 q / w 的最小值。
代码
class Solution:
def mincostToHireWorkers(self, quality: List[int], wage: List[int], K: int) -> float:
# 如果最大的 w/q 确定,那么总工资就是确定的。就是 sum_of_q * w/q, 也就是说 sum_of_q 越小,总工资越小
# 枚举最大的 w/q, 然后用堆在其中选 k 个 q 即可。由于移除的时候我们希望移除“最大的”q,因此用大根堆
A = [(w/q, q) for w, q in zip(wage, quality)]
A.sort()
ans = inf
sum_of_q = 0
h = []
for rate, q in A:
heapq.heappush(h, -q)
sum_of_q += q
if len(h) == K:
ans = min(ans, sum_of_q * rate)
sum_of_q += heapq.heappop(h)
return ans
(代码 1.3.3)
技巧二 - 多路归并
这个技巧其实在前面讲超级丑数的时候已经提到了,只是没有给这种类型的题目一个名字。
其实这个技巧,叫做多指针优化可能会更合适,只不过这个名字实在太过朴素且容易和双指针什么的混淆,因此我给 ta 起了个别致的名字 - 多路归并。
多路体现在:有多条候选路线。代码上,我们可使用多指针来表示。
归并体现在:结果可能是多个候选路线中最长的或者最短,也可能是第 k 个 等。因此我们需要对多条路线的结果进行比较,并根据题目描述舍弃或者选取某一个或多个路线。
这样描述比较抽象,接下来通过几个例子来加深一下大家的理解。
这里我给大家精心准备了四道难度为 hard 的题目。 掌握了这个套路就可以去快乐地 AC 这四道题啦。
1439. 有序矩阵中的第 k 个最小数组和
题目描述
给你一个 m * n 的矩阵 mat,以及一个整数 k ,矩阵中的每一行都以非递减的顺序排列。
你可以从每一行中选出 1 个元素形成一个数组。返回所有可能数组中的第 k 个 最小 数组和。
示例 1:
输入:mat = [[1,3,11],[2,4,6]], k = 5
输出:7
解释:从每一行中选出一个元素,前 k 个和最小的数组分别是:
[1,2], [1,4], [3,2], [3,4], [1,6]。其中第 5 个的和是 7 。
示例 2:
输入:mat = [[1,3,11],[2,4,6]], k = 9
输出:17
示例 3:
输入:mat = [[1,10,10],[1,4,5],[2,3,6]], k = 7
输出:9
解释:从每一行中选出一个元素,前 k 个和最小的数组分别是:
[1,1,2], [1,1,3], [1,4,2], [1,4,3], [1,1,6], [1,5,2], [1,5,3]。其中第 7 个的和是 9 。
示例 4:
输入:mat = [[1,1,10],[2,2,9]], k = 7
输出:12
提示:
m == mat.length
n == mat.length[i]
1 <= m, n <= 40
1 <= k <= min(200, n ^ m)
1 <= mat[i][j] <= 5000
mat[i] 是一个非递减数组
思路
其实这道题就是给你 m 个长度均相同的一维数组,让我们从这 m 个数组中分别选出一个数,即一共选取 m 个数,求这 m 个数的和是所有选取可能性中和第 k 小的。
一个朴素的想法是使用多指针来解。对于这道题来说就是使用 m 个指针,分别指向 m 个一维数组,指针的位置表示当前选取的是该一维数组中第几个。
上面提到了题目需要求的其实是第 k 小的和,而最小的我们是容易知道的,即所有的一维数组首项和。我们又发现,根据最小的,我们可以推导出第 2 小,推导的方式就是移动其中一个指针,这就一共分裂出了 n 种情况了,其中 n 为一维数组长度,第 2 小的就在这分裂中的 n 种情况中,而筛选的方式是这 n 种情况和最小的,后面的情况也是类似。不难看出每次分裂之后极值也发生了变化,因此这是一个明显的求动态求极值的信号,使用堆是一个不错的选择。
那代码该如何书写呢?
上面说了,我们先要初始化 m 个指针,并赋值为 0。对应伪代码:
# 初始化堆
h = []
# sum(vec[0] for vec in mat) 是 m 个一维数组的首项和
# [0] * m 就是初始化了一个长度为 m 且全部填充为 0 的数组。
# 我们将上面的两个信息组装成元祖 cur 方便使用
cur = (sum(vec[0] for vec in mat), [0] * m)
# 将其入堆
heapq.heappush(h, cur)
接下来,我们每次都移动一个指针,从而形成分叉出一条新的分支。每次从堆中弹出一个最小的,弹出 k 次就是第 k 小的了。伪代码:
for 1 to K:
# acc 当前的和, pointers 是指针情况。
acc, pointers = heapq.heappop(h)
# 每次都粗暴地移动指针数组中的一个指针。每移动一个指针就分叉一次, 一共可能移动的情况是 n,其中 n 为一维数组的长度。
for i, pointer in enumerate(pointers):
# 如果 pointer == len(mat[0]) - 1 说明到头了,不能移动了
if pointer != len(mat[0]) - 1:
# 下面两句话的含义是修改 pointers[i] 的指针 为 pointers[i] + 1
new_pointers = pointers.copy()
new_pointers[i] += 1
# 将更新后的 acc 和指针数组重新入堆
heapq.heappush(h, (acc + mat[i][pointer + 1] - mat[i][pointer], new_pointers))
class Solution:
def kthSmallest(self, mat, k: int) -> int:
h = []
cur = (sum(vec[0] for vec in mat), tuple([0] * len(mat)))
heapq.heappush(h, cur)
seen = set(cur)
for _ in range(k):
acc, pointers = heapq.heappop(h)
for i, pointer in enumerate(pointers):
if pointer != len(mat[0]) - 1:
t = list(pointers)
t[i] = pointer + 1
tt = tuple(t)
if tt not in seen:
seen.add(tt)
heapq.heappush(h, (acc + mat[i][pointer + 1] - mat[i][pointer], tt))
return acc
(代码 1.3.4)
719. 找出第 k 小的距离对
题目描述
给定一个整数数组,返回所有数对之间的第 k 个最小距离。一对 (A, B) 的距离被定义为 A 和 B 之间的绝对差值。
示例 1:
输入:
nums = [1,3,1]
k = 1
输出:0
解释:
所有数对如下:
(1,3) -> 2
(1,1) -> 0
(3,1) -> 2
因此第 1 个最小距离的数对是 (1,1),它们之间的距离为 0。
提示:
2 <= len(nums) <= 10000.
0 <= nums[i] < 1000000.
1 <= k <= len(nums) * (len(nums) - 1) / 2.
思路
不难看出所有的数对可能共 $C_n^2$ 个,也就是 $n\times(n-1)\div2$。
因此我们可以使用两次循环找出所有的数对,并升序排序,之后取第 k 个。
实际上,我们可使用固定堆技巧,维护一个大小为 k 的大顶堆,这样堆顶的元素就是第 k 小的,这在前面的固定堆中已经讲过,不再赘述。
class Solution:
def smallestDistancePair(self, nums: List[int], k: int) -> int:
h = []
for i in range(len(nums)):
for j in range(i + 1, len(nums)):
a, b = nums[i], nums[j]
# 维持堆大小不超过 k
if len(h) == k and -abs(a - b) > h[0]:
heapq.heappop(h)
if len(h) < k:
heapq.heappush(h, -abs(a - b))
return -h[0]
(代码 1.3.5)
不过这种优化意义不大,因为算法的瓶颈在于 $N^2$ 部分的枚举,我们应当设法优化这一点。
如果我们将数对进行排序,那么最小的数对距离一定在 nums[i] - nums[i - 1] 中,其中 i 为从 1 到 n 的整数,究竟是哪个取决于谁更小。接下来就可以使用上面多路归并的思路来解决了。
如果 nums[i] - nums[i - 1] 的差是最小的,那么第 2 小的一定是剩下的 n - 1 种情况和 nums[i] - nums[i - 1] 分裂的新情况。关于如何分裂,和上面类似,我们只需要移动其中 i 的指针为 i + 1 即可。这里的指针数组长度固定为 2,而不是上面题目中的 m。这里我将两个指针分别命名为 fr 和 to,分别代表 from 和 to。
代码
class Solution(object):
def smallestDistancePair(self, nums, k):
nums.sort()
# n 种候选答案
h = [(nums[i+1] - nums[i], i, i+1) for i in range(len(nums) - 1)]
heapq.heapify(h)
for _ in range(k):
diff, fr, to = heapq.heappop(h)
if to + 1 < len(nums):
heapq.heappush((nums[to + 1] - nums[fr], fr, to + 1))
return diff
(代码 1.3.6)
由于时间复杂度和 k 有关,而 k 最多可能达到 $N^2$ 的量级,因此此方法实际上也会超时。不过这证明了这种思路的正确性,如果题目稍加改变说不定就能用上。
这道题可通过二分法来解决,由于和堆主题有偏差,因此这里简单讲一下。
求第 k 小的数比较容易想到的就是堆和二分法。二分的原因在于求第 k 小,本质就是求不大于其本身的有 k - 1 个的那个数。而这个问题很多时候满足单调性,因此就可使用二分来解决。
而我们知道,发问的答案也是不严格递减的,因此使用二分就应该被想到。我们不断发问直到问到小于 x 的有 k - 1 个即可。然而这样的发问也有问题。原因有两个:
小于 x 的有 k - 1 个的数可能不止一个
我们无法确定小于 x 的有 k - 1 个的数一定存在。 比如数对差分别为 [1,1,1,1,2],让你求第 3 大的,那么小于 x 有两个的数根本就不存在。
我们的思路可调整为求小于等于 x 有 k 个的,接下来我们使用二分法的最左模板即可解决。关于最左模板可参考我的二分查找专题
代码:
class Solution:
def smallestDistancePair(self, A: List[int], K: int) -> int:
A.sort()
l, r = 0, A[-1] - A[0]
def count_ngt(mid):
slow = 0
ans = 0
for fast in range(len(A)):
while A[fast] - A[slow] > mid:
slow += 1
ans += fast - slow
return ans
while l <= r:
mid = (l + r) // 2
if count_ngt(mid) >= K:
r = mid - 1
else:
l = mid + 1
return l
class Solution:
def smallestRange(self, martrix: List[List[int]]) -> List[int]:
l, r = -10**9, 10**9
# 将每一行最小的都放到堆中,同时记录其所在的行号和列号,一共 n 个齐头并进
h = [(row[0], i, 0) for i, row in enumerate(martrix)]
heapq.heapify(h)
# 维护最大值
max_v = max(row[0] for row in martrix)
while True:
min_v, row, col = heapq.heappop(h)
# max_v - min_v 是当前的最大最小差值, r - l 为全局的最大最小差值。因为如果当前的更小,我们就更新全局结果
if max_v - min_v < r - l:
l, r = min_v, max_v
if col == len(martrix[row]) - 1: return [l, r]
# 更新指针,继续往后移动一位
heapq.heappush(h, (martrix[row][col + 1], row, col + 1))
max_v = max(max_v, martrix[row][col + 1])
class Solution:
def smallestRange(self, martrix: List[List[int]]) -> List[int]:
l, r = -10**9, 10**9
# 将每一行最小的都放到堆中,同时记录其所在的行号和列号,一共 n 个齐头并进
h = [(row[0], i, 0) for i, row in enumerate(martrix)]
heapq.heapify(h)
# 维护最大值
max_v = max(row[0] for row in martrix)
while True:
min_v, row, col = heapq.heappop(h)
# max_v - min_v 是当前的最大最小差值, r - l 为全局的最大最小差值。因为如果当前的更小,我们就更新全局结果
if max_v - min_v < r - l:
l, r = min_v, max_v
if col == len(martrix[row]) - 1: return [l, r]
# 更新指针,继续往后移动一位
heapq.heappush(h, (martrix[row][col + 1], row, col + 1))
max_v = max(max_v, martrix[row][col + 1])
def minimumDeviation(self, nums: List[int]) -> int:
matrix = [[] for _ in range(len(nums))]
for i, num in enumerate(nums):
if num & 1 == 1:
matrix[i] += [num, num * 2]
else:
temp = []
while num and num & 1 == 0:
temp += [num]
num //= 2
temp += [num]
matrix[i] += temp[::-1]
a, b = self.smallestRange(matrix)
return b - a
cur = startFuel # 刚开始有 startFuel 升汽油
last = 0 # 上一次的位置
for i, fuel in stations:
cur -= i - last # 走过两个 staton 的耗油为两个 station 的距离,也就是 i - last
if cur < 0:
# 我们必须在前面就加油,否则到不了这里
# 但是在前面的哪个 station 加油呢?
# 直觉告诉我们应该贪心地选择可以加汽油最多的站 i,如果加上 i 的汽油还是 cur < 0,继续加次大的站 j,直到没有更多汽油可加或者 cur > 0
所以这个事后诸葛亮本质上解决的是,基于当前信息无法获取最优解,我们必须掌握全部信息之后回溯。以这道题来说,我们可以先遍历一边 station,然后将每个 station 的油量记录到一个数组中,每次我们“预见“到无法到达下个站的时候,就从这个数组中取最大的。。。。 基于此,我们可以考虑使用堆优化取极值的过程,而不是使用数组的方式。
代码
class Solution:
def minRefuelStops(self, target: int, startFuel: int, stations: List[List[int]]) -> int:
stations += [(target, 0)]
cur = startFuel
ans = 0
h = []
last = 0
for i, fuel in stations:
cur -= i - last
while cur < 0 and h:
cur -= heapq.heappop(h)
ans += 1
if cur < 0:
return -1
heappush(h, -fuel)
last = i
return ans
class Solution:
def furthestBuilding(self, heights: List[int], bricks: int, ladders: int) -> int:
h = []
for i in range(1, len(heights)):
diff = heights[i] - heights[i - 1]
if diff <= 0:
continue
if bricks < diff and ladders > 0:
ladders -= 1
if h and -h[0] > diff:
bricks -= heapq.heappop(h)
else:
continue
bricks -= diff
if bricks < 0:
return i - 1
heapq.heappush(h, -diff)
return len(heights) - 1
(代码 1.3.12)
四大应用
接下来是本文的最后一个部分《四大应用》,目的是通过这几个例子来帮助大家巩固前面的知识。
1. topK
求解 topK 是堆的一个很重要的功能。这个其实已经在前面的固定堆部分给大家介绍过了。
这里直接引用前面的话:
“其实求第 k 小的数最简单的思路是建立小顶堆,将所有的数先全部入堆,然后逐个出堆,一共出堆 k 次。最后一次出堆的就是第 k 小的数。然而,我们也可不先全部入堆,而是建立大顶堆(注意不是上面的小顶堆),并维持堆的大小为 k 个。如果新的数入堆之后堆的大小大于 k,则需要将堆顶的数和新的数进行比较,并将较大的移除。这样可以保证堆中的数是全体数字中最小的 k 个,而这最小的 k 个中最大的(即堆顶)不就是第 k 小的么?这也就是选择建立大顶堆,而不是小顶堆的原因。”
其实除了第 k 小的数,我们也可以将中间的数全部收集起来,这就可以求出最小的 k 个数。和上面第 k 小的数唯一不同的点在于需要收集 popp 出来的所有的数。
算法的基本思想是贪心,每次都遍历所有邻居,并从中找到距离最小的,本质上是一种广度优先遍历。这里我们借助堆这种数据结构,使得可以在 $logN$ 的时间内找到 cost 最小的点,其中 N 为 堆的大小。
代码模板:
def dijkstra(graph, start, end):
# 堆里的数据都是 (cost, i) 的二元祖,其含义是“从 start 走到 i 的距离是 cost”。
heap = [(0, start)]
visited = set()
while heap:
(cost, u) = heapq.heappop(heap)
if u in visited:
continue
visited.add(u)
if u == end:
return cost
for v, c in graph[u]:
if v in visited:
continue
next = cost + c
heapq.heappush(heap, (next, v))
return -1
class Solution:
def dijkstra(self, graph, start, end):
heap = [(0, start)]
visited = set()
while heap:
(cost, u) = heapq.heappop(heap)
if u in visited:
continue
visited.add(u)
if u == end:
return cost
for v, c in graph[u]:
if v in visited:
continue
next = cost + c
heapq.heappush(heap, (next, v))
return -1
def networkDelayTime(self, times: List[List[int]], N: int, K: int) -> int:
graph = collections.defaultdict(list)
for fr, to, w in times:
graph[fr - 1].append((to - 1, w))
ans = -1
for to in range(N):
# 调用封装好的 dijkstra 方法
dist = self.dijkstra(graph, K - 1, to)
if dist == -1: return -1
ans = max(ans, dist)
return ans
(代码 1.4.2)
你学会了么?
上面的算法并不是最优解,我只是为了体现将 dijkstra 封装为 api 调用 的思想。一个更好的做法是一次遍历记录所有的距离信息,而不是每次都重复计算。时间复杂度会大大降低。这在计算一个点到图中所有点的距离时有很大的意义。 为了实现这个目的,我们的算法会有什么样的调整?
class Solution:
def dijkstra(self, graph, start, end):
heap = [(0, start)] # cost from start node,end node
dist = {}
while heap:
(cost, u) = heapq.heappop(heap)
if u in dist:
continue
dist[u] = cost
for v, c in graph[u]:
if v in dist:
continue
next = cost + c
heapq.heappush(heap, (next, v))
return dist
def networkDelayTime(self, times: List[List[int]], N: int, K: int) -> int:
graph = collections.defaultdict(list)
for fr, to, w in times:
graph[fr - 1].append((to - 1, w))
ans = -1
dist = self.dijkstra(graph, K - 1, to)
return -1 if len(dist) != N else max(dist.values())