--- comments: true difficulty: 困难 edit_url: https://github.com/doocs/leetcode/edit/main/solution/2100-2199/2179.Count%20Good%20Triplets%20in%20an%20Array/README.md rating: 2272 source: 第 72 场双周赛 Q4 tags: - 树状数组 - 线段树 - 数组 - 二分查找 - 分治 - 有序集合 - 归并排序 --- # [2179. 统计数组中好三元组数目](https://leetcode.cn/problems/count-good-triplets-in-an-array) [English Version](/solution/2100-2199/2179.Count%20Good%20Triplets%20in%20an%20Array/README_EN.md) ## 题目描述

给你两个下标从 0 开始且长度为 n 的整数数组 nums1 和 nums2 ,两者都是 [0, 1, ..., n - 1] 的 排列 。

好三元组 指的是 3 个 互不相同 的值,且它们在数组 nums1 和 nums2 中出现顺序保持一致。换句话说,如果我们将 pos1v 记为值 v 在 nums1 中出现的位置,pos2v 为值 v 在 nums2 中的位置,那么一个好三元组定义为 0 <= x, y, z <= n - 1 ,且 pos1x < pos1y < pos1z 和 pos2x < pos2y < pos2z 都成立的 (x, y, z) 。

请你返回好三元组的 总数目 。

 

示例 1:

输入:nums1 = [2,0,1,3], nums2 = [0,1,2,3]
输出:1
解释:
总共有 4 个三元组 (x,y,z) 满足 pos1x < pos1y < pos1,分别是 (2,0,1) ,(2,0,3) ,(2,1,3) 和 (0,1,3) 。
这些三元组中,只有 (0,1,3) 满足 pos2x < pos2y < pos2z 。所以只有 1 个好三元组。

示例 2:

输入:nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]
输出:4
解释:总共有 4 个好三元组 (4,0,3) ,(4,0,2) ,(4,1,3) 和 (4,1,2) 。

 

提示:

## 解法 ### 方法一:树状数组 对于本题,我们先用 pos 记录每个数在 nums2 中的位置,然后依次对 nums1 中的每个元素进行处理。 考虑**以当前数字作为三元组中间数字**的好三元组的数目。第一个数字需要是之前已经遍历过的,并且在 nums2 中的位置比当前数字更靠前的;第三个数字需要是当前还没有遍历过的,并且在 nums2 中的位置比当前数字更靠后的。 以 `nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]`为例,考虑我们的遍历过程: 1. 首先处理 4,此时 nums2 中出现情况为 `[4,X,X,X,X]`,4 之前有值的个数是 0,4 之后没有值的个数有 4 个。因此以 4 为中间数字能形成 0 个好三元组。 1. 接下来是 0,此时 nums2 中出现情况为 `[4,X,0,X,X]`,0 之前有值的个数是 1,0 之后没有值的个数有 2 个。因此以 0 为中间数字能形成 2 个好三元组。 1. 接下来是 1,此时 nums2 中出现情况为 `[4,1,0,X,X]`,1 之前有值的个数是 1,0 之后没有值的个数有 2 个。因此以 1 为中间数字能形成 2 个好三元组。 1. ... 1. 最后是 2,此时 nums2 中出现情况为 `[4,1,0,2,3]`,2 之前有值的个数是 4,2 之后没有值的个数是 0。因此以 2 为中间数字能形成 0 个好三元组。 我们可以用**树状数组**来更新 nums2 中各个位置数字的出现情况,快速算出每个数字左侧 1 的个数,以及右侧 0 的个数。 树状数组,也称作“二叉索引树”(Binary Indexed Tree)或 Fenwick 树。 它可以高效地实现如下两个操作: 1. **单点更新** `update(x, delta)`: 把序列 x 位置的数加上一个值 delta; 1. **前缀和查询** `query(x)`:查询序列 `[1,...x]` 区间的区间和,即位置 x 的前缀和。 这两个操作的时间复杂度均为 $O(\log n)$。因此,整体的时间复杂度为 $O(n \log n)$,其中 $n$ 为数组 $\textit{nums1}$ 的长度。空间复杂度 $O(n)$。 #### Python3 ```python class BinaryIndexedTree: def __init__(self, n): self.n = n self.c = [0] * (n + 1) @staticmethod def lowbit(x): return x & -x def update(self, x, delta): while x <= self.n: self.c[x] += delta x += BinaryIndexedTree.lowbit(x) def query(self, x): s = 0 while x > 0: s += self.c[x] x -= BinaryIndexedTree.lowbit(x) return s class Solution: def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int: pos = {v: i for i, v in enumerate(nums2, 1)} ans = 0 n = len(nums1) tree = BinaryIndexedTree(n) for num in nums1: p = pos[num] left = tree.query(p) right = n - p - (tree.query(n) - tree.query(p)) ans += left * right tree.update(p, 1) return ans ``` #### Java ```java class Solution { public long goodTriplets(int[] nums1, int[] nums2) { int n = nums1.length; int[] pos = new int[n]; BinaryIndexedTree tree = new BinaryIndexedTree(n); for (int i = 0; i < n; ++i) { pos[nums2[i]] = i + 1; } long ans = 0; for (int num : nums1) { int p = pos[num]; long left = tree.query(p); long right = n - p - (tree.query(n) - tree.query(p)); ans += left * right; tree.update(p, 1); } return ans; } } class BinaryIndexedTree { private int n; private int[] c; public BinaryIndexedTree(int n) { this.n = n; c = new int[n + 1]; } public void update(int x, int delta) { while (x <= n) { c[x] += delta; x += lowbit(x); } } public int query(int x) { int s = 0; while (x > 0) { s += c[x]; x -= lowbit(x); } return s; } public static int lowbit(int x) { return x & -x; } } ``` #### C++ ```cpp class BinaryIndexedTree { public: int n; vector c; BinaryIndexedTree(int _n) : n(_n) , c(_n + 1) {} void update(int x, int delta) { while (x <= n) { c[x] += delta; x += lowbit(x); } } int query(int x) { int s = 0; while (x > 0) { s += c[x]; x -= lowbit(x); } return s; } int lowbit(int x) { return x & -x; } }; class Solution { public: long long goodTriplets(vector& nums1, vector& nums2) { int n = nums1.size(); vector pos(n); for (int i = 0; i < n; ++i) pos[nums2[i]] = i + 1; BinaryIndexedTree* tree = new BinaryIndexedTree(n); long long ans = 0; for (int& num : nums1) { int p = pos[num]; int left = tree->query(p); int right = n - p - (tree->query(n) - tree->query(p)); ans += 1ll * left * right; tree->update(p, 1); } return ans; } }; ``` #### Go ```go type BinaryIndexedTree struct { n int c []int } func newBinaryIndexedTree(n int) *BinaryIndexedTree { c := make([]int, n+1) return &BinaryIndexedTree{n, c} } func (this *BinaryIndexedTree) lowbit(x int) int { return x & -x } func (this *BinaryIndexedTree) update(x, delta int) { for x <= this.n { this.c[x] += delta x += this.lowbit(x) } } func (this *BinaryIndexedTree) query(x int) int { s := 0 for x > 0 { s += this.c[x] x -= this.lowbit(x) } return s } func goodTriplets(nums1 []int, nums2 []int) int64 { n := len(nums1) pos := make([]int, n) for i, v := range nums2 { pos[v] = i + 1 } tree := newBinaryIndexedTree(n) var ans int64 for _, num := range nums1 { p := pos[num] left := tree.query(p) right := n - p - (tree.query(n) - tree.query(p)) ans += int64(left) * int64(right) tree.update(p, 1) } return ans } ``` #### TypeScript ```ts class BinaryIndexedTree { private c: number[]; private n: number; constructor(n: number) { this.n = n; this.c = Array(n + 1).fill(0); } private static lowbit(x: number): number { return x & -x; } update(x: number, delta: number): void { while (x <= this.n) { this.c[x] += delta; x += BinaryIndexedTree.lowbit(x); } } query(x: number): number { let s = 0; while (x > 0) { s += this.c[x]; x -= BinaryIndexedTree.lowbit(x); } return s; } } function goodTriplets(nums1: number[], nums2: number[]): number { const n = nums1.length; const pos = new Map(); nums2.forEach((v, i) => pos.set(v, i + 1)); const tree = new BinaryIndexedTree(n); let ans = 0; for (const num of nums1) { const p = pos.get(num)!; const left = tree.query(p); const total = tree.query(n); const right = n - p - (total - left); ans += left * right; tree.update(p, 1); } return ans; } ``` ### 方法二:线段树 我们也可以用线段树来实现。线段树是一种数据结构,能够高效地进行区间查询和更新操作。它的基本思想是将一个区间划分为多个子区间,并且每个子区间都可以用一个节点来表示。 线段树将整个区间分割为多个不连续的子区间,子区间的数量不超过 `log(width)`。更新某个元素的值,只需要更新 `log(width)` 个区间,并且这些区间都包含在一个包含该元素的大区间内。 - 线段树的每个节点代表一个区间; - 线段树具有唯一的根节点,代表的区间是整个统计范围,如 `[1, N]`; - 线段树的每个叶子节点代表一个长度为 1 的元区间 `[x, x]`; - 对于每个内部节点 `[l, r]`,它的左儿子是 `[l, mid]`,右儿子是 `[mid + 1, r]`, 其中 `mid = ⌊(l + r) / 2⌋` (即向下取整)。 时间复杂度 $O(n \log n)$,其中 $n$ 为数组 $\textit{nums1}$ 的长度。空间复杂度 $O(n)$。 #### Python3 ```python class Node: __slots__ = ("l", "r", "v") def __init__(self): self.l = 0 self.r = 0 self.v = 0 class SegmentTree: def __init__(self, n): self.tr = [Node() for _ in range(4 * n)] self.build(1, 1, n) def build(self, u, l, r): self.tr[u].l = l self.tr[u].r = r if l == r: return mid = (l + r) >> 1 self.build(u << 1, l, mid) self.build(u << 1 | 1, mid + 1, r) def modify(self, u, x, v): if self.tr[u].l == x and self.tr[u].r == x: self.tr[u].v += v return mid = (self.tr[u].l + self.tr[u].r) >> 1 if x <= mid: self.modify(u << 1, x, v) else: self.modify(u << 1 | 1, x, v) self.pushup(u) def pushup(self, u): self.tr[u].v = self.tr[u << 1].v + self.tr[u << 1 | 1].v def query(self, u, l, r): if self.tr[u].l >= l and self.tr[u].r <= r: return self.tr[u].v mid = (self.tr[u].l + self.tr[u].r) >> 1 v = 0 if l <= mid: v += self.query(u << 1, l, r) if r > mid: v += self.query(u << 1 | 1, l, r) return v class Solution: def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int: pos = {v: i for i, v in enumerate(nums2, 1)} ans = 0 n = len(nums1) tree = SegmentTree(n) for num in nums1: p = pos[num] left = tree.query(1, 1, p) right = n - p - (tree.query(1, 1, n) - tree.query(1, 1, p)) ans += left * right tree.modify(1, p, 1) return ans ``` #### Java ```java class Solution { public long goodTriplets(int[] nums1, int[] nums2) { int n = nums1.length; int[] pos = new int[n]; SegmentTree tree = new SegmentTree(n); for (int i = 0; i < n; ++i) { pos[nums2[i]] = i + 1; } long ans = 0; for (int num : nums1) { int p = pos[num]; long left = tree.query(1, 1, p); long right = n - p - (tree.query(1, 1, n) - tree.query(1, 1, p)); ans += left * right; tree.modify(1, p, 1); } return ans; } } class Node { int l; int r; int v; } class SegmentTree { private Node[] tr; public SegmentTree(int n) { tr = new Node[4 * n]; for (int i = 0; i < tr.length; ++i) { tr[i] = new Node(); } build(1, 1, n); } public void build(int u, int l, int r) { tr[u].l = l; tr[u].r = r; if (l == r) { return; } int mid = (l + r) >> 1; build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r); } public void modify(int u, int x, int v) { if (tr[u].l == x && tr[u].r == x) { tr[u].v += v; return; } int mid = (tr[u].l + tr[u].r) >> 1; if (x <= mid) { modify(u << 1, x, v); } else { modify(u << 1 | 1, x, v); } pushup(u); } public void pushup(int u) { tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v; } public int query(int u, int l, int r) { if (tr[u].l >= l && tr[u].r <= r) { return tr[u].v; } int mid = (tr[u].l + tr[u].r) >> 1; int v = 0; if (l <= mid) { v += query(u << 1, l, r); } if (r > mid) { v += query(u << 1 | 1, l, r); } return v; } } ``` #### C++ ```cpp class Node { public: int l; int r; int v; }; class SegmentTree { public: vector tr; SegmentTree(int n) { tr.resize(4 * n); for (int i = 0; i < tr.size(); ++i) tr[i] = new Node(); build(1, 1, n); } void build(int u, int l, int r) { tr[u]->l = l; tr[u]->r = r; if (l == r) return; int mid = (l + r) >> 1; build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r); } void modify(int u, int x, int v) { if (tr[u]->l == x && tr[u]->r == x) { tr[u]->v += v; return; } int mid = (tr[u]->l + tr[u]->r) >> 1; if (x <= mid) modify(u << 1, x, v); else modify(u << 1 | 1, x, v); pushup(u); } void pushup(int u) { tr[u]->v = tr[u << 1]->v + tr[u << 1 | 1]->v; } int query(int u, int l, int r) { if (tr[u]->l >= l && tr[u]->r <= r) return tr[u]->v; int mid = (tr[u]->l + tr[u]->r) >> 1; int v = 0; if (l <= mid) v += query(u << 1, l, r); if (r > mid) v += query(u << 1 | 1, l, r); return v; } }; class Solution { public: long long goodTriplets(vector& nums1, vector& nums2) { int n = nums1.size(); vector pos(n); for (int i = 0; i < n; ++i) pos[nums2[i]] = i + 1; SegmentTree* tree = new SegmentTree(n); long long ans = 0; for (int& num : nums1) { int p = pos[num]; int left = tree->query(1, 1, p); int right = n - p - (tree->query(1, 1, n) - tree->query(1, 1, p)); ans += 1ll * left * right; tree->modify(1, p, 1); } return ans; } }; ``` #### Go ```go type Node struct { l, r, v int } type SegmentTree struct { tr []Node } func NewSegmentTree(n int) *SegmentTree { tr := make([]Node, 4*n) st := &SegmentTree{tr: tr} st.build(1, 1, n) return st } func (st *SegmentTree) build(u, l, r int) { st.tr[u].l = l st.tr[u].r = r if l == r { return } mid := (l + r) >> 1 st.build(u<<1, l, mid) st.build(u<<1|1, mid+1, r) } func (st *SegmentTree) modify(u, x, v int) { if st.tr[u].l == x && st.tr[u].r == x { st.tr[u].v += v return } mid := (st.tr[u].l + st.tr[u].r) >> 1 if x <= mid { st.modify(u<<1, x, v) } else { st.modify(u<<1|1, x, v) } st.pushup(u) } func (st *SegmentTree) pushup(u int) { st.tr[u].v = st.tr[u<<1].v + st.tr[u<<1|1].v } func (st *SegmentTree) query(u, l, r int) int { if st.tr[u].l >= l && st.tr[u].r <= r { return st.tr[u].v } mid := (st.tr[u].l + st.tr[u].r) >> 1 res := 0 if l <= mid { res += st.query(u<<1, l, r) } if r > mid { res += st.query(u<<1|1, l, r) } return res } func goodTriplets(nums1 []int, nums2 []int) int64 { n := len(nums1) pos := make(map[int]int) for i, v := range nums2 { pos[v] = i + 1 } tree := NewSegmentTree(n) var ans int64 for _, num := range nums1 { p := pos[num] left := tree.query(1, 1, p) right := n - p - (tree.query(1, 1, n) - tree.query(1, 1, p)) ans += int64(left * right) tree.modify(1, p, 1) } return ans } ``` #### TypeScript ```ts class Node { l: number = 0; r: number = 0; v: number = 0; } class SegmentTree { private tr: Node[]; constructor(n: number) { this.tr = Array(4 * n); for (let i = 0; i < 4 * n; i++) { this.tr[i] = new Node(); } this.build(1, 1, n); } private build(u: number, l: number, r: number): void { this.tr[u].l = l; this.tr[u].r = r; if (l === r) return; const mid = (l + r) >> 1; this.build(u << 1, l, mid); this.build((u << 1) | 1, mid + 1, r); } modify(u: number, x: number, v: number): void { if (this.tr[u].l === x && this.tr[u].r === x) { this.tr[u].v += v; return; } const mid = (this.tr[u].l + this.tr[u].r) >> 1; if (x <= mid) { this.modify(u << 1, x, v); } else { this.modify((u << 1) | 1, x, v); } this.pushup(u); } private pushup(u: number): void { this.tr[u].v = this.tr[u << 1].v + this.tr[(u << 1) | 1].v; } query(u: number, l: number, r: number): number { if (this.tr[u].l >= l && this.tr[u].r <= r) { return this.tr[u].v; } const mid = (this.tr[u].l + this.tr[u].r) >> 1; let res = 0; if (l <= mid) { res += this.query(u << 1, l, r); } if (r > mid) { res += this.query((u << 1) | 1, l, r); } return res; } } function goodTriplets(nums1: number[], nums2: number[]): number { const n = nums1.length; const pos = new Map(); nums2.forEach((v, i) => pos.set(v, i + 1)); const tree = new SegmentTree(n); let ans = 0; for (const num of nums1) { const p = pos.get(num)!; const left = tree.query(1, 1, p); const total = tree.query(1, 1, n); const right = n - p - (total - left); ans += left * right; tree.modify(1, p, 1); } return ans; } ```