第六章 - 高频考题(困难)
5999. 统计数组中好三元组数目

题目地址(5999. 统计数组中好三元组数目)

https://leetcode-cn.com/problems/count-good-triplets-in-an-array/

题目描述

给你两个下标从 0 开始且长度为 n 的整数数组 nums1 和 nums2 ,两者都是 [0, 1, ..., n - 1] 的 排列 。
好三元组 指的是 3 个 互不相同 的值,且它们在数组 nums1 和 nums2 中出现顺序保持一致。换句话说,如果我们将 pos1v 记为值 v 在 nums1 中出现的位置,pos2v 为值 v 在 nums2 中的位置,那么一个好三元组定义为 0 <= x, y, z <= n - 1 ,且 pos1x < pos1y < pos1z 和 pos2x < pos2y < pos2z 都成立的 (x, y, z) 。
请你返回好三元组的 总数目 。
示例 1:
输入:nums1 = [2,0,1,3], nums2 = [0,1,2,3]
输出:1
解释:
总共有 4 个三元组 (x,y,z) 满足 pos1x < pos1y < pos1z ,分别是 (2,0,1) ,(2,0,3) ,(2,1,3) 和 (0,1,3) 。
这些三元组中,只有 (0,1,3) 满足 pos2x < pos2y < pos2z 。所以只有 1 个好三元组。
示例 2:
输入:nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]
输出:4
解释:总共有 4 个好三元组 (4,0,3) ,(4,0,2) ,(4,1,3) 和 (4,1,2) 。
提示:
n == nums1.length == nums2.length
3 <= n <= 10^5
0 <= nums1[i], nums2[i] <= n - 1
nums1 和 nums2 是 [0, 1, ..., n - 1] 的排列。

前置知识

  • 平衡二叉树
  • 枚举

公司

  • 暂无

思路

本题的第一个关键点是:根据数组 A 的索引对应关系置换数组 B,得到新的数组 C,问题转化为堆 C 求递增三元组的个数
比如对于题目给的:nums1 = [2,0,1,3], nums2 = [0,1,2,3]
我们可以获取到 nums1 的索引对应关系,即 2->0, 0->1, 1->2, 3->3。
n = len(nums1)
for i in range(n):
d[nums1[i]] = i
用这个对应关系更新 nums2,最终得到的数组为 [1,2,0,3],我们只需要求 [1,2,0,3] 的递增三元组的个数即可。
for i in range(n):
nums.append(d[nums2[i]])
第二个关键单是枚举中间值 x,这样以 x 为中间值的递增三元组的个数就是 ycnt * zcnt(笛卡尔积),其中 ycnt 为 x 前面的比 x 小的,zcnt 为后面的比 x 大的。
比较容易想到的方法是使用两个平衡二叉树,代码:
sl1 = SortedList()
sl2 = SortedList(nums)
for num in nums:
sl1.add(num)
sl2.remove(num)
ans += sl1.bisect_left(num) * (len(sl2) - sl2.bisect_left(num + 1))
return ans
实际上使用 sl1 就足够了。这是因为 zcnt 其实也等价于所有比 x 大的数的总数 - sl1 中比 x 大的数的个数,而这个信息通过 sl1 就足以求得。具体代码见下方代码区。我们可以省去一个 SortedList 的开销,因此不管是空间复杂度还是时间复杂度都可以获得常数级别的优化。

关键点

  • 根据数组 A 的索引对应关系置换数组 B,得到新的数组 C,问题转化为堆 C 求递增三元组的个数
  • 枚举三元组中中间的数

代码

  • 语言支持:Python3
Python3 Code:
from sortedcontainers import SortedList
class Solution:
def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int:
d = {}
nums = []
ans = 0
n = len(nums1)
for i in range(n):
d[nums1[i]] = i
for i in range(n):
nums.append(d[nums2[i]])
sl1 = SortedList()
for num in nums:
sl1.add(num)
ans += sl1.bisect_left(num) * ((n - num - (len(sl1) - sl1.bisect_left(num))))
return ans
复杂度分析
令 n 为数组长度。
  • 时间复杂度:$O(nlogn)$
  • 空间复杂度:$O(n)$
此题解由 力扣刷题插件 自动生成。
力扣的小伙伴可以关注我,这样就会第一时间收到我的动态啦~
以上就是本文的全部内容了。大家对此有何看法,欢迎给我留言,我有时间都会一一查看回答。更多算法套路可以访问我的 LeetCode 题解仓库:https://github.com/azl397985856/leetcode 。 目前已经 40K star 啦。大家也可以关注我的公众号《力扣加加》带你啃下算法这块硬骨头。
关注公众号力扣加加,努力用清晰直白的语言还原解题思路,并且有大量图解,手把手教你识别套路,高效刷题。