Preface

Codeforces Round 958 (Div. 2) E 题时,需要求得一个数组任意元素左右第 11 和第 22 个比它小的元素。

研究了 Jiangly 的写法,得到一个通用的解法,记录如下。

问题引入

我们都知道,如果我们想求得一个数组任意元素左边或右边的第一个比他大或小的元素,可以通过单调栈实现,时间和空间复杂度都是 O(n)O(n) 。但是如果我们想求得一个数组任意元素左边或右边的第 kk 个比他大或小的元素,该怎么办呢?

解决思路

线段树 + 二分(线段树上二分优化)

我们考虑使用线段树来解决这个问题。我们可以用线段树来维护区间最小值,然后通过二分查找来找到第 ii 个比他小或大的元素的位置,共 kk 次二分查找,时间复杂度为 O(nklog2n)O(nk\log^2 n)

考虑线段树上二分优化,可以消去一层二分,时间复杂度为 O(nklogn)O(nk\log n)

这样,我们就得到了时间复杂度为 O(nklogn)O(nk\log n) 的解法。

但是这种解法显然不够优秀。不论从实现复杂性上看,还是从复杂度上看都不够好。

时间上,我们带有一个 logn\log n ,而空间上,我们需要四倍的空间来维护线段树。

再加上线段树的常数较大,这显然不是一个好的解法。

单调栈 + 动态规划

回顾

考虑再次回到我们一开始的思路,我们可以通过单调栈 + 动态规划来解决这个问题。

让我们先考虑如何求得一个数组任意元素右边的第一个比他小的元素。

我们可以通过单调栈来实现这个功能,我们维护一个单调递增的栈,栈中存放的是元素的下标。

我们从左往右遍历数组,如果当前元素比栈顶元素小,我们将这个元素出栈,直到栈为空或者栈顶元素比当前元素小。

我们就得到出栈元素的右边第一个比他小的元素的位置,就是我们当前元素的位置。

简单实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
std::vector<int> next_smaller(std::vector<int> const& nums) {
std::vector<int> res(nums.size(), -1);
std::stack<int> stk;
for (int i = 0; i < nums.size(); i++) {
while (!stk.empty() && nums[stk.top()] > nums[i]) {
res[stk.top()] = i;
stk.pop();
}
stk.push(i);
}
return res;
}

推广

我们发现,解决问题的本质是记录每个元素是被谁弹出的。

于是,对于该元素右侧第 kk 个比他小的元素,我们只需要维护 kk 层单调栈,让其被弹出 kk 次,每次向后转移即可。

容易发现,kk 层栈依然具有单调性。

实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
auto next_kth_smaller(std::vector<int> const& nums, int k) {
int n = nums.size();
std::vector<std::vector<int>> res(k, std::vector<int>(n, -1));
std::vector<std::vector<int>> stk(k);
for (int i = 0; i < n; i++) {
for (int j = k - 1; j >= 0; j--) {
std::vector<int> tmp;
while (!stk[j].empty() && nums[stk[j].back()] > nums[i]) {
if (j + 1 < k) {
tmp.push_back(stk[j].back());
}
res[j][stk[j].back()] = i;
stk[j].pop_back();
}
if (j + 1 < k) {
stk[j + 1].insert(stk[j + 1].end(), tmp.rbegin(), tmp.rend());
}
}
stk[0].push_back(i);
}
return res;
}

进行简单的封装:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
template <class T, class Cmp = std::less<T>>
auto next_kth_smaller(std::vector<T> const& nums, int k, Cmp&& cmp = Cmp{}) {
int n = nums.size();
std::vector<std::vector<T>> res(k, std::vector<T>(n, -1));
std::vector<std::vector<T>> stk(k);
for (int i = 0; i < n; i++) {
for (int j = k - 1; j >= 0; j--) {
std::vector<int> tmp;
while (!stk[j].empty() && cmp(nums[i], nums[stk[j].back()])) {
if (j + 1 < k) {
tmp.push_back(stk[j].back());
}
res[j][stk[j].back()] = i;
stk[j].pop_back();
}
if (j + 1 < k) {
stk[j + 1].insert(stk[j + 1].end(), tmp.rbegin(), tmp.rend());
}
}
stk[0].push_back(i);
}
return res;
}

template <class T, class Cmp = std::less<T>>
auto prev_kth_smaller(std::vector<T> const& nums, int k, Cmp&& cmp = Cmp{}) {
int n = nums.size();
std::vector<std::vector<T>> res(k, std::vector<T>(n, -1));
std::vector<std::vector<T>> stk(k);
for (int i = n - 1; i >= 0; i--) {
for (int j = k - 1; j >= 0; j--) {
std::vector<int> tmp;
while (!stk[j].empty() && cmp(nums[i], nums[stk[j].back()])) {
if (j + 1 < k) {
tmp.push_back(stk[j].back());
}
res[j][stk[j].back()] = i;
stk[j].pop_back();
}
if (j + 1 < k) {
stk[j + 1].insert(stk[j + 1].end(), tmp.rbegin(), tmp.rend());
}
}
stk[0].push_back(i);
}
return res;
}

这样,我们就得到了一个时间复杂度为 O(nk)O(nk) , 空间复杂度为 O(nk)O(nk) 的解法。