---
comments: true
difficulty: 困难
edit_url: https://github.com/doocs/leetcode/edit/main/solution/0300-0399/0315.Count%20of%20Smaller%20Numbers%20After%20Self/README.md
tags:
    - 树状数组
    - 线段树
    - 数组
    - 二分查找
    - 分治
    - 有序集合
    - 归并排序
---

# [315. 计算右侧小于当前元素的个数](https://leetcode.cn/problems/count-of-smaller-numbers-after-self)

[English Version](/solution/0300-0399/0315.Count%20of%20Smaller%20Numbers%20After%20Self/README_EN.md)

## 题目描述

给你一个整数数组 `nums`,按要求返回一个新数组 `counts`。数组 `counts` 有该性质:`counts[i]` 的值是 `nums[i]` 右侧小于 `nums[i]` 的元素的数量。

 

示例 1:

输入:nums = [5,2,6,1]
输出:[2,1,1,0] 
解释:
5 的右侧有 2 个更小的元素 (2 和 1)
2 的右侧仅有 1 个更小的元素 (1)
6 的右侧有 1 个更小的元素 (1)
1 的右侧有 0 个更小的元素

示例 2:

输入:nums = [-1]
输出:[0]

示例 3:

输入:nums = [-1,-1]
输出:[0,0]

 

提示:

- `1 <= nums.length <= 10^5`
- `-10^4 <= nums[i] <= 10^4`

## 解法

### 方法一:树状数组

树状数组,也称作“二叉索引树”(Binary Indexed Tree)或 Fenwick 树。它可以高效地实现如下两个操作:

1. **单点更新** `update(x, delta)`:把序列 x 位置的数加上一个值 delta;
1. **前缀和查询** `query(x)`:查询序列 `[1, ..., x]` 区间的区间和,即位置 x 的前缀和。

这两个操作的时间复杂度均为 $O(\log n)$。

树状数组最基本的功能就是求比某点 x 小的点的个数(这里的比较是抽象的概念,可以是数的大小、坐标的大小、质量的大小等等)。

比如给定数组 `a[5] = {2, 5, 3, 4, 1}`,求 `b[i] = 位置 i 左边小于等于 a[i] 的数的个数`。对于此例,`b[5] = {0, 1, 1, 2, 0}`。

解决方案是直接遍历数组,每个位置先求出 `query(a[i])`,然后再修改树状数组 `update(a[i], 1)` 即可。当数的范围比较大时,需要进行离散化,即先进行去重并排序,然后对每个数字进行编号。

#### Python3

```python
class BinaryIndexedTree:
    """Fenwick tree over positions 1..n: point update and prefix-sum query."""

    def __init__(self, n):
        self.n = n
        # c[0] is unused; positions are 1-based.
        self.c = [0] * (n + 1)

    @staticmethod
    def lowbit(x):
        """Return the lowest set bit of x."""
        return x & -x

    def update(self, x, delta):
        """Add delta to position x."""
        while x <= self.n:
            self.c[x] += delta
            x += BinaryIndexedTree.lowbit(x)

    def query(self, x):
        """Return the sum of positions 1..x (0 when x <= 0)."""
        s = 0
        while x > 0:
            s += self.c[x]
            x -= BinaryIndexedTree.lowbit(x)
        return s


class Solution:
    def countSmaller(self, nums: List[int]) -> List[int]:
        # Discretize: map every distinct value to its 1-based rank.
        alls = sorted(set(nums))
        m = {v: i for i, v in enumerate(alls, 1)}
        tree = BinaryIndexedTree(len(m))
        ans = []
        # Scan right to left so the tree only holds elements to the right.
        for v in nums[::-1]:
            x = m[v]
            tree.update(x, 1)
            # Ranks 1..x-1 are exactly the strictly smaller values.
            ans.append(tree.query(x - 1))
        return ans[::-1]
```

#### Java

```java
class Solution {
    public List<Integer> countSmaller(int[] nums) {
        // Discretize: map every distinct value to its 1-based rank.
        Set<Integer> s = new HashSet<>();
        for (int v : nums) {
            s.add(v);
        }
        List<Integer> alls = new ArrayList<>(s);
        alls.sort(Comparator.comparingInt(a -> a));
        int n = alls.size();
        Map<Integer, Integer> m = new HashMap<>(n);
        for (int i = 0; i < n; ++i) {
            m.put(alls.get(i), i + 1);
        }
        BinaryIndexedTree tree = new BinaryIndexedTree(n);
        LinkedList<Integer> ans = new LinkedList<>();
        // Scan right to left so the tree only holds elements to the right.
        for (int i = nums.length - 1; i >= 0; --i) {
            int x = m.get(nums[i]);
            tree.update(x, 1);
            ans.addFirst(tree.query(x - 1));
        }
        return ans;
    }
}

/** Fenwick tree over positions 1..n: point update and prefix-sum query. */
class BinaryIndexedTree {
    private int n;
    private int[] c;

    public BinaryIndexedTree(int n) {
        this.n = n;
        c = new int[n + 1];
    }

    /** Adds delta to position x. */
    public void update(int x, int delta) {
        while (x <= n) {
            c[x] += delta;
            x += lowbit(x);
        }
    }

    /** Returns the sum of positions 1..x (0 when x <= 0). */
    public int query(int x) {
        int s = 0;
        while (x > 0) {
            s += c[x];
            x -= lowbit(x);
        }
        return s;
    }

    /** Returns the lowest set bit of x. */
    public static int lowbit(int x) {
        return x & -x;
    }
}
```

#### C++

```cpp
// Fenwick tree over positions 1..n: point update and prefix-sum query.
class BinaryIndexedTree {
public:
    int n;
    vector<int> c;

    BinaryIndexedTree(int _n)
        : n(_n)
        , c(_n + 1) {}

    // Adds delta to position x.
    void update(int x, int delta) {
        while (x <= n) {
            c[x] += delta;
            x += lowbit(x);
        }
    }

    // Returns the sum of positions 1..x (0 when x <= 0).
    int query(int x) {
        int s = 0;
        while (x > 0) {
            s += c[x];
            x -= lowbit(x);
        }
        return s;
    }

    // Returns the lowest set bit of x.
    int lowbit(int x) {
        return x & -x;
    }
};

class Solution {
public:
    vector<int> countSmaller(vector<int>& nums) {
        // Discretize: map every distinct value to its 1-based rank.
        unordered_set<int> s(nums.begin(), nums.end());
        vector<int> alls(s.begin(), s.end());
        sort(alls.begin(), alls.end());
        unordered_map<int, int> m;
        int n = alls.size();
        for (int i = 0; i < n; ++i) m[alls[i]] = i + 1;
        // Automatic storage: nothing to delete when the function returns.
        BinaryIndexedTree tree(n);
        vector<int> ans(nums.size());
        // Scan right to left so the tree only holds elements to the right.
        for (int i = nums.size() - 1; i >= 0; --i) {
            int x = m[nums[i]];
            tree.update(x, 1);
            ans[i] = tree.query(x - 1);
        }
        return ans;
    }
};
```

#### Go

```go
// BinaryIndexedTree is a Fenwick tree over positions 1..n.
type BinaryIndexedTree struct {
	n int
	c []int
}

func newBinaryIndexedTree(n int) *BinaryIndexedTree {
	c := make([]int, n+1)
	return &BinaryIndexedTree{n, c}
}

// lowbit returns the lowest set bit of x.
func (t *BinaryIndexedTree) lowbit(x int) int {
	return x & -x
}

// update adds delta to position x.
func (t *BinaryIndexedTree) update(x, delta int) {
	for x <= t.n {
		t.c[x] += delta
		x += t.lowbit(x)
	}
}

// query returns the sum of positions 1..x (0 when x <= 0).
func (t *BinaryIndexedTree) query(x int) int {
	s := 0
	for x > 0 {
		s += t.c[x]
		x -= t.lowbit(x)
	}
	return s
}

func countSmaller(nums []int) []int {
	// Discretize: map every distinct value to its 1-based rank.
	s := make(map[int]bool)
	for _, v := range nums {
		s[v] = true
	}
	var alls []int
	for v := range s {
		alls = append(alls, v)
	}
	sort.Ints(alls)
	m := make(map[int]int)
	for i, v := range alls {
		m[v] = i + 1
	}
	ans := make([]int, len(nums))
	tree := newBinaryIndexedTree(len(alls))
	// Scan right to left so the tree only holds elements to the right.
	for i := len(nums) - 1; i >= 0; i-- {
		x := m[nums[i]]
		tree.update(x, 1)
		ans[i] = tree.query(x - 1)
	}
	return ans
}
```

### 方法二:线段树

线段树将整个区间分割为多个不连续的子区间,子区间的数量不超过 `log(width)`。更新某个元素的值,只需要更新 `log(width)` 个区间,并且这些区间都包含在一个包含该元素的大区间内。

- 线段树的每个节点代表一个区间;
- 线段树具有唯一的根节点,代表的区间是整个统计范围,如 `[1, N]`;
- 线段树的每个叶子节点代表一个长度为 1 的元区间 `[x, x]`;
- 对于每个内部节点 `[l, r]`,它的左儿子是 `[l, mid]`,右儿子是 `[mid + 1, r]`,其中 `mid = ⌊(l + r) / 2⌋`(即向下取整)。

#### Python3

```python
class Node:
    def __init__(self):
        self.l = 0  # left end of the interval covered by this node
        self.r = 0  # right end of the interval covered by this node
        self.v = 0  # number of inserted elements inside [l, r]


class SegmentTree:
    """Segment tree over ranks 1..n storing element counts."""

    def __init__(self, n):
        self.tr = [Node() for _ in range(n << 2)]
        self.build(1, 1, n)

    def build(self, u, l, r):
        """Assign interval [l, r] to node u and build its subtree."""
        self.tr[u].l = l
        self.tr[u].r = r
        if l == r:
            return
        mid = (l + r) >> 1
        self.build(u << 1, l, mid)
        self.build(u << 1 | 1, mid + 1, r)

    def modify(self, u, x, v):
        """Add v to the leaf for position x, starting from node u."""
        if self.tr[u].l == x and self.tr[u].r == x:
            self.tr[u].v += v
            return
        mid = (self.tr[u].l + self.tr[u].r) >> 1
        if x <= mid:
            self.modify(u << 1, x, v)
        else:
            self.modify(u << 1 | 1, x, v)
        self.pushup(u)

    def query(self, u, l, r):
        """Return the count stored in [l, r], starting from node u."""
        if l > r:
            # Empty range (asked for rank 1): answer directly instead of
            # descending into nodes that build() never initialised.
            return 0
        if self.tr[u].l >= l and self.tr[u].r <= r:
            return self.tr[u].v
        mid = (self.tr[u].l + self.tr[u].r) >> 1
        v = 0
        if l <= mid:
            v += self.query(u << 1, l, r)
        if r > mid:
            v += self.query(u << 1 | 1, l, r)
        return v

    def pushup(self, u):
        """Recompute node u from its two children."""
        self.tr[u].v = self.tr[u << 1].v + self.tr[u << 1 | 1].v


class Solution:
    def countSmaller(self, nums: List[int]) -> List[int]:
        # Discretize: map every distinct value to its 1-based rank.
        s = sorted(set(nums))
        m = {v: i for i, v in enumerate(s, 1)}
        tree = SegmentTree(len(s))
        ans = []
        # Scan right to left so the tree only holds elements to the right.
        for v in nums[::-1]:
            x = m[v]
            ans.append(tree.query(1, 1, x - 1))
            tree.modify(1, x, 1)
        return ans[::-1]
```

#### Java

```java
class Solution {
    public List<Integer> countSmaller(int[] nums) {
        // Discretize: map every distinct value to its 1-based rank.
        Set<Integer> s = new HashSet<>();
        for (int v : nums) {
            s.add(v);
        }
        List<Integer> alls = new ArrayList<>(s);
        alls.sort(Comparator.comparingInt(a -> a));
        int n = alls.size();
        Map<Integer, Integer> m = new HashMap<>(n);
        for (int i = 0; i < n; ++i) {
            m.put(alls.get(i), i + 1);
        }
        SegmentTree tree = new SegmentTree(n);
        LinkedList<Integer> ans = new LinkedList<>();
        // Scan right to left so the tree only holds elements to the right.
        for (int i = nums.length - 1; i >= 0; --i) {
            int x = m.get(nums[i]);
            tree.modify(1, x, 1);
            ans.addFirst(tree.query(1, 1, x - 1));
        }
        return ans;
    }
}

class Node {
    int l; // left end of the interval covered by this node
    int r; // right end of the interval covered by this node
    int v; // number of inserted elements inside [l, r]
}

/** Segment tree over ranks 1..n storing element counts. */
class SegmentTree {
    private Node[] tr;

    public SegmentTree(int n) {
        tr = new Node[4 * n];
        for (int i = 0; i < tr.length; ++i) {
            tr[i] = new Node();
        }
        build(1, 1, n);
    }

    /** Assigns interval [l, r] to node u and builds its subtree. */
    public void build(int u, int l, int r) {
        tr[u].l = l;
        tr[u].r = r;
        if (l == r) {
            return;
        }
        int mid = (l + r) >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
    }

    /** Adds v to the leaf for position x, starting from node u. */
    public void modify(int u, int x, int v) {
        if (tr[u].l == x && tr[u].r == x) {
            tr[u].v += v;
            return;
        }
        int mid = (tr[u].l + tr[u].r) >> 1;
        if (x <= mid) {
            modify(u << 1, x, v);
        } else {
            modify(u << 1 | 1, x, v);
        }
        pushup(u);
    }

    /** Recomputes node u from its two children. */
    public void pushup(int u) {
        tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v;
    }

    /** Returns the count stored in [l, r], starting from node u. */
    public int query(int u, int l, int r) {
        if (l > r) {
            // Empty range (asked for rank 1): answer directly instead of
            // descending into nodes that build() never initialised.
            return 0;
        }
        if (tr[u].l >= l && tr[u].r <= r) {
            return tr[u].v;
        }
        int mid = (tr[u].l + tr[u].r) >> 1;
        int v = 0;
        if (l <= mid) {
            v += query(u << 1, l, r);
        }
        if (r > mid) {
            v += query(u << 1 | 1, l, r);
        }
        return v;
    }
}
```

#### C++

```cpp
class Node {
public:
    int l = 0; // left end of the interval covered by this node
    int r = 0; // right end of the interval covered by this node
    int v = 0; // number of inserted elements inside [l, r]
};

// Segment tree over ranks 1..n storing element counts.
class SegmentTree {
public:
    // Nodes are stored by value, so the vector owns them and nothing leaks.
    vector<Node> tr;

    SegmentTree(int n) {
        tr.resize(4 * n);
        build(1, 1, n);
    }

    // Assigns interval [l, r] to node u and builds its subtree.
    void build(int u, int l, int r) {
        tr[u].l = l;
        tr[u].r = r;
        if (l == r) return;
        int mid = (l + r) >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
    }

    // Adds v to the leaf for position x, starting from node u.
    void modify(int u, int x, int v) {
        if (tr[u].l == x && tr[u].r == x) {
            tr[u].v += v;
            return;
        }
        int mid = (tr[u].l + tr[u].r) >> 1;
        if (x <= mid)
            modify(u << 1, x, v);
        else
            modify(u << 1 | 1, x, v);
        pushup(u);
    }

    // Recomputes node u from its two children.
    void pushup(int u) {
        tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v;
    }

    // Returns the count stored in [l, r], starting from node u.
    int query(int u, int l, int r) {
        // Empty range (asked for rank 1): answer directly instead of
        // descending into nodes that build() never initialised.
        if (l > r) return 0;
        if (tr[u].l >= l && tr[u].r <= r) return tr[u].v;
        int mid = (tr[u].l + tr[u].r) >> 1;
        int v = 0;
        if (l <= mid) v += query(u << 1, l, r);
        if (r > mid) v += query(u << 1 | 1, l, r);
        return v;
    }
};

class Solution {
public:
    vector<int> countSmaller(vector<int>& nums) {
        // Discretize: map every distinct value to its 1-based rank.
        unordered_set<int> s(nums.begin(), nums.end());
        vector<int> alls(s.begin(), s.end());
        sort(alls.begin(), alls.end());
        unordered_map<int, int> m;
        int n = alls.size();
        for (int i = 0; i < n; ++i) m[alls[i]] = i + 1;
        // Automatic storage: nothing to delete when the function returns.
        SegmentTree tree(n);
        vector<int> ans(nums.size());
        // Scan right to left so the tree only holds elements to the right.
        for (int i = nums.size() - 1; i >= 0; --i) {
            int x = m[nums[i]];
            tree.modify(1, x, 1);
            ans[i] = tree.query(1, 1, x - 1);
        }
        return ans;
    }
};
```

#### Go

> 注:下面的 Go 代码实际采用的是归并排序的思路(参见方法三),而不是线段树。

```go
// Pair keeps a value together with its index in the original slice.
type Pair struct {
	val   int
	index int
}

var (
	tmp   []Pair // scratch buffer shared by every merge step
	count []int  // count[i] = number of smaller elements to the right of nums[i]
)

func countSmaller(nums []int) []int {
	tmp, count = make([]Pair, len(nums)), make([]int, len(nums))
	array := make([]Pair, len(nums))
	for i, v := range nums {
		array[i] = Pair{val: v, index: i}
	}
	sorted(array, 0, len(array)-1)
	return count
}

// sorted merge-sorts arr[low..high] by value, updating count along the way.
func sorted(arr []Pair, low, high int) {
	if low >= high {
		return
	}
	mid := low + (high-low)/2
	sorted(arr, low, mid)
	sorted(arr, mid+1, high)
	merge(arr, low, mid, high)
}

// merge combines the sorted halves arr[low..mid] and arr[mid+1..high].
func merge(arr []Pair, low, mid, high int) {
	left, right := low, mid+1
	idx := low
	for left <= mid && right <= high {
		if arr[left].val <= arr[right].val {
			// Every right-half element already emitted is strictly smaller
			// than arr[left] and originally stood to its right.
			count[arr[left].index] += right - mid - 1
			tmp[idx], left = arr[left], left+1
		} else {
			tmp[idx], right = arr[right], right+1
		}
		idx++
	}
	for left <= mid {
		count[arr[left].index] += right - mid - 1
		tmp[idx] = arr[left]
		idx, left = idx+1, left+1
	}
	for right <= high {
		tmp[idx] = arr[right]
		idx, right = idx+1, right+1
	}
	// Copy the merged run back into arr.
	for i := low; i <= high; i++ {
		arr[i] = tmp[i]
	}
}
```

### 方法三:归并排序