--- comments: true difficulty: 困难 edit_url: https://github.com/doocs/leetcode/edit/main/solution/2100-2199/2179.Count%20Good%20Triplets%20in%20an%20Array/README.md rating: 2272 source: 第 72 场双周赛 Q4 tags: - 树状数组 - 线段树 - 数组 - 二分查找 - 分治 - 有序集合 - 归并排序 --- # [2179. 统计数组中好三元组数目](https://leetcode.cn/problems/count-good-triplets-in-an-array) [English Version](/solution/2100-2199/2179.Count%20Good%20Triplets%20in%20an%20Array/README_EN.md) ## 题目描述

给你两个下标从 0 开始且长度为 n 的整数数组 nums1 和 nums2 ,两者都是 [0, 1, ..., n - 1] 的 排列 。

好三元组 指的是 3 个 互不相同 的值,且它们在数组 nums1 和 nums2 中出现顺序保持一致。换句话说,如果我们将 pos1v 记为值 v 在 nums1 中出现的位置,pos2v 为值 v 在 nums2 中的位置,那么一个好三元组定义为 0 <= x, y, z <= n - 1 ,且 pos1x < pos1y < pos1z 和 pos2x < pos2y < pos2z 都成立的 (x, y, z) 。

请你返回好三元组的 总数目 。

 

示例 1:

输入:nums1 = [2,0,1,3], nums2 = [0,1,2,3]
输出:1
解释:
总共有 4 个三元组 (x,y,z) 满足 pos1x < pos1y < pos1,分别是 (2,0,1) ,(2,0,3) ,(2,1,3) 和 (0,1,3) 。
这些三元组中,只有 (0,1,3) 满足 pos2x < pos2y < pos2z 。所以只有 1 个好三元组。

示例 2:

输入:nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]
输出:4
解释:总共有 4 个好三元组 (4,0,3) ,(4,0,2) ,(4,1,3) 和 (4,1,2) 。

 

提示:

## 解法 ### 方法一:树状数组或线段树 对于本题,我们先用 pos 记录每个数在 nums2 中的位置,然后依次对 nums1 中的每个元素进行处理。 考虑**以当前数字作为三元组中间数字**的好三元组的数目。第一个数字需要是之前已经遍历过的,并且在 nums2 中的位置比当前数字更靠前的;第三个数字需要是当前还没有遍历过的,并且在 nums2 中的位置比当前数字更靠后的。 以 `nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]`为例,考虑我们的遍历过程: 1. 首先处理 4,此时 nums2 中出现情况为 `[4,X,X,X,X]`,4 之前有值的个数是 0,4 之后没有值的个数有 4 个。因此以 4 为中间数字能形成 0 个好三元组。 1. 接下来是 0,此时 nums2 中出现情况为 `[4,X,0,X,X]`,0 之前有值的个数是 1,0 之后没有值的个数有 2 个。因此以 0 为中间数字能形成 2 个好三元组。 1. 接下来是 1,此时 nums2 中出现情况为 `[4,1,0,X,X]`,1 之前有值的个数是 1,0 之后没有值的个数有 2 个。因此以 1 为中间数字能形成 2 个好三元组。 1. ... 1. 最后是 2,此时 nums2 中出现情况为 `[4,1,0,2,3]`,2 之前有值的个数是 4,2 之后没有值的个数是 0。因此以 2 为中间数字能形成 0 个好三元组。 我们可以用**树状数组**或**线段树**这两种数据结构来更新 nums2 中各个位置数字的出现情况,快速算出每个数字左侧 1 的个数,以及右侧 0 的个数。 **数据结构 1:树状数组** 树状数组,也称作“二叉索引树”(Binary Indexed Tree)或 Fenwick 树。 它可以高效地实现如下两个操作: 1. **单点更新** `update(x, delta)`: 把序列 x 位置的数加上一个值 delta; 1. **前缀和查询** `query(x)`:查询序列 `[1,...x]` 区间的区间和,即位置 x 的前缀和。 这两个操作的时间复杂度均为 $O(\log n)$。 **数据结构 2:线段树** 线段树将整个区间分割为多个不连续的子区间,子区间的数量不超过 `log(width)`。更新某个元素的值,只需要更新 `log(width)` 个区间,并且这些区间都包含在一个包含该元素的大区间内。 - 线段树的每个节点代表一个区间; - 线段树具有唯一的根节点,代表的区间是整个统计范围,如 `[1, N]`; - 线段树的每个叶子节点代表一个长度为 1 的元区间 `[x, x]`; - 对于每个内部节点 `[l, r]`,它的左儿子是 `[l, mid]`,右儿子是 `[mid + 1, r]`, 其中 `mid = ⌊(l + r) / 2⌋` (即向下取整)。 > 本题 Python3 线段树代码 TLE。 #### Python3 ```python class BinaryIndexedTree: def __init__(self, n): self.n = n self.c = [0] * (n + 1) @staticmethod def lowbit(x): return x & -x def update(self, x, delta): while x <= self.n: self.c[x] += delta x += BinaryIndexedTree.lowbit(x) def query(self, x): s = 0 while x > 0: s += self.c[x] x -= BinaryIndexedTree.lowbit(x) return s class Solution: def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int: pos = {v: i for i, v in enumerate(nums2, 1)} ans = 0 n = len(nums1) tree = BinaryIndexedTree(n) for num in nums1: p = pos[num] left = tree.query(p) right = n - p - (tree.query(n) - tree.query(p)) ans += left * right tree.update(p, 1) return ans ``` #### Java ```java class Solution { public long goodTriplets(int[] nums1, int[] nums2) { int n = nums1.length; int[] pos = new int[n]; BinaryIndexedTree tree = new BinaryIndexedTree(n); for (int i = 0; i < n; ++i) { pos[nums2[i]] = i + 1; } long ans = 0; for (int num : nums1) { int p = pos[num]; long left = tree.query(p); long right = n - p - (tree.query(n) - tree.query(p)); ans += left * right; tree.update(p, 1); } return ans; } } class BinaryIndexedTree { private int n; private int[] c; public BinaryIndexedTree(int n) { this.n = n; c = new int[n + 1]; } public void update(int x, int delta) { while (x <= n) { c[x] += delta; x += lowbit(x); } } public int query(int x) { int s = 0; while (x > 0) { s += c[x]; x -= lowbit(x); } return s; } public static int lowbit(int x) { return x & -x; } } ``` #### C++ ```cpp class BinaryIndexedTree { public: int n; vector c; BinaryIndexedTree(int _n) : n(_n) , c(_n + 1) {} void update(int x, int delta) { while (x <= n) { c[x] += delta; x += lowbit(x); } } int query(int x) { int s = 0; while (x > 0) { s += c[x]; x -= lowbit(x); } return s; } int lowbit(int x) { return x & -x; } }; class Solution { public: long long goodTriplets(vector& nums1, vector& nums2) { int n = nums1.size(); vector pos(n); for (int i = 0; i < n; ++i) pos[nums2[i]] = i + 1; BinaryIndexedTree* tree = new BinaryIndexedTree(n); long long ans = 0; for (int& num : nums1) { int p = pos[num]; int left = tree->query(p); int right = n - p - (tree->query(n) - tree->query(p)); ans += 1ll * left * right; tree->update(p, 1); } return ans; } }; ``` #### Go ```go type BinaryIndexedTree struct { n int c []int } func newBinaryIndexedTree(n int) *BinaryIndexedTree { c := make([]int, n+1) return &BinaryIndexedTree{n, c} } func (this *BinaryIndexedTree) lowbit(x int) int { return x & -x } func (this *BinaryIndexedTree) update(x, delta int) { for x <= this.n { this.c[x] += delta x += this.lowbit(x) } } func (this *BinaryIndexedTree) query(x int) int { s := 0 for x > 0 { s += this.c[x] x -= this.lowbit(x) } return s } func goodTriplets(nums1 []int, nums2 []int) int64 { n := len(nums1) pos := make([]int, n) for i, v := range nums2 { pos[v] = i + 1 } tree := newBinaryIndexedTree(n) var ans int64 for _, num := range nums1 { p := pos[num] left := tree.query(p) right := n - p - (tree.query(n) - tree.query(p)) ans += int64(left) * int64(right) tree.update(p, 1) } return ans } ``` ### 方法二 #### Python3 ```python class Node: def __init__(self): self.l = 0 self.r = 0 self.v = 0 class SegmentTree: def __init__(self, n): self.tr = [Node() for _ in range(4 * n)] self.build(1, 1, n) def build(self, u, l, r): self.tr[u].l = l self.tr[u].r = r if l == r: return mid = (l + r) >> 1 self.build(u << 1, l, mid) self.build(u << 1 | 1, mid + 1, r) def modify(self, u, x, v): if self.tr[u].l == x and self.tr[u].r == x: self.tr[u].v += v return mid = (self.tr[u].l + self.tr[u].r) >> 1 if x <= mid: self.modify(u << 1, x, v) else: self.modify(u << 1 | 1, x, v) self.pushup(u) def pushup(self, u): self.tr[u].v = self.tr[u << 1].v + self.tr[u << 1 | 1].v def query(self, u, l, r): if self.tr[u].l >= l and self.tr[u].r <= r: return self.tr[u].v mid = (self.tr[u].l + self.tr[u].r) >> 1 v = 0 if l <= mid: v += self.query(u << 1, l, r) if r > mid: v += self.query(u << 1 | 1, l, r) return v class Solution: def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int: pos = {v: i for i, v in enumerate(nums2, 1)} ans = 0 n = len(nums1) tree = SegmentTree(n) for num in nums1: p = pos[num] left = tree.query(1, 1, p) right = n - p - (tree.query(1, 1, n) - tree.query(1, 1, p)) ans += left * right tree.modify(1, p, 1) return ans ``` #### Java ```java class Solution { public long goodTriplets(int[] nums1, int[] nums2) { int n = nums1.length; int[] pos = new int[n]; SegmentTree tree = new SegmentTree(n); for (int i = 0; i < n; ++i) { pos[nums2[i]] = i + 1; } long ans = 0; for (int num : nums1) { int p = pos[num]; long left = tree.query(1, 1, p); long right = n - p - (tree.query(1, 1, n) - tree.query(1, 1, p)); ans += left * right; tree.modify(1, p, 1); } return ans; } } class Node { int l; int r; int v; } class SegmentTree { private Node[] tr; public SegmentTree(int n) { tr = new Node[4 * n]; for (int i = 0; i < tr.length; ++i) { tr[i] = new Node(); } build(1, 1, n); } public void build(int u, int l, int r) { tr[u].l = l; tr[u].r = r; if (l == r) { return; } int mid = (l + r) >> 1; build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r); } public void modify(int u, int x, int v) { if (tr[u].l == x && tr[u].r == x) { tr[u].v += v; return; } int mid = (tr[u].l + tr[u].r) >> 1; if (x <= mid) { modify(u << 1, x, v); } else { modify(u << 1 | 1, x, v); } pushup(u); } public void pushup(int u) { tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v; } public int query(int u, int l, int r) { if (tr[u].l >= l && tr[u].r <= r) { return tr[u].v; } int mid = (tr[u].l + tr[u].r) >> 1; int v = 0; if (l <= mid) { v += query(u << 1, l, r); } if (r > mid) { v += query(u << 1 | 1, l, r); } return v; } } ``` #### C++ ```cpp class Node { public: int l; int r; int v; }; class SegmentTree { public: vector tr; SegmentTree(int n) { tr.resize(4 * n); for (int i = 0; i < tr.size(); ++i) tr[i] = new Node(); build(1, 1, n); } void build(int u, int l, int r) { tr[u]->l = l; tr[u]->r = r; if (l == r) return; int mid = (l + r) >> 1; build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r); } void modify(int u, int x, int v) { if (tr[u]->l == x && tr[u]->r == x) { tr[u]->v += v; return; } int mid = (tr[u]->l + tr[u]->r) >> 1; if (x <= mid) modify(u << 1, x, v); else modify(u << 1 | 1, x, v); pushup(u); } void pushup(int u) { tr[u]->v = tr[u << 1]->v + tr[u << 1 | 1]->v; } int query(int u, int l, int r) { if (tr[u]->l >= l && tr[u]->r <= r) return tr[u]->v; int mid = (tr[u]->l + tr[u]->r) >> 1; int v = 0; if (l <= mid) v += query(u << 1, l, r); if (r > mid) v += query(u << 1 | 1, l, r); return v; } }; class Solution { public: long long goodTriplets(vector& nums1, vector& nums2) { int n = nums1.size(); vector pos(n); for (int i = 0; i < n; ++i) pos[nums2[i]] = i + 1; SegmentTree* tree = new SegmentTree(n); long long ans = 0; for (int& num : nums1) { int p = pos[num]; int left = tree->query(1, 1, p); int right = n - p - (tree->query(1, 1, n) - tree->query(1, 1, p)); ans += 1ll * left * right; tree->modify(1, p, 1); } return ans; } }; ```