Codeforces 上的教程:link
基本思想
问题描述
我们考虑一个问题:
给定某个序列,要求维护的操作有:单点修改,求前缀和,搜索某个前缀和(类似于在前缀和数组上求 lower_bound
)。
这个事实上容易用线段树完成,但是由于树状数组具有空间更小、常数小、代码简单等优点,我们想用树状数组完成这个操作。
$O(\lg^2n)$ 的实现
单点修改和求前缀和都可以用树状数组完成,问题主要在于如何搜索某个前缀和。
由于前缀和这玩意儿具有单调性,一个简单的想法就是二分查找。代码如下:
>folded1 2 3 4 5 6 7 8 9
| int search(int val){ int l = 1, r = n; while(l < r){ int mid = (l + r) >> 1; if(sum(mid) < val) l = mid + 1; else r = mid; } return l; }
|
二分是 $O(\lg n)$ 的,在树状数组上求前缀和是 $O(\lg n)$ 的,所以总复杂度是 $O(\lg^2n)$ 的。
$O(\lg n)$ 的实现——倍增思想
倍增思想有很多重要的应用,例如 $ST$ 表、倍增求 lca 等,这里可以帮助我们在树状数组上完成 lower_bound
操作。
假设我们想要搜索前缀和为 $val$ 的地方,设定一个 pos
指针,它初始为 $0$,最终将指向最大的前缀和小于 $val$ 的位置;再设置一个变量 sum
,存储 pos
处的前缀和;设置倍增的长度 i
,最初为 $\lg n$(为了代码方便,一般取 $20$ 即可),在倍增的过程中不断减小至 $0$。每一个状态(pos
,sum
,i
)表示我们现在考虑的是位置 pos+(1<<i)
的前缀和,这个前缀和的值是 sum+c[pos+(1<<i)]
,如果它大于等于了 $val$,那么我们减小倍增的长度 i
;否则,我们把 pos
提到 pos+(1<<i)
处。
我们用例子来更直观地说明【以下例子和图片均来源于 Codeforces 的教程】:
给定数组 a[]
:
它的树状数组 c[]
长这样:
我们想搜索 $val=27$ 的位置,那么算法过程如下:
最后 pos
值为 $13$,是最大的前缀和小于 $27$ 的位置。所以我们的目标位置就是 pos+1
.
代码如下:
>folded1 2 3 4 5 6 7
| int search(int val){ int pos = 0, sum = 0; for(int i = 20; i >= 0; i--) if(pos + (1<<i) <= n && sum + c[pos+(1<<i)] < val) pos += (1<<i), sum += c[pos]; return pos + 1; }
|
进一步
容易发现,只要我们维护的信息具有单调性,就可以用这个方法。
练习
CF1354D Multiset
题目链接
其实这道题是我学树状数组倍增的原因。
开一个值域树状数组,维护前缀个数,这玩意儿是单调增加的,所以查询第 $k$ 个数可以用上树状数组倍增。
>folded1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
| #include<bits/stdc++.h> using namespace std; template<typename T>void read(T&x){x=0;int fl=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-') fl=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}x*=fl;} template<typename T,typename...Args>inline void read(T&t,Args&...args){read(t);read(args...);} typedef long long LL; typedef vector<int> vi; typedef pair<int, int> pii; #define mp(x, y) make_pair(x, y) #define pb(x) emplace_back(x) const int N = 1000005;
int n, q;
int c[N]; inline int lowbit(int x){ return x & -x; } inline void add(int x, int val){ while(x <= n){ c[x] += val; x += lowbit(x); } } inline int search(int val){ int pos = 0, sum = 0; for(int i = 20; i >= 0; i--) if(pos + (1<<i) <= n && sum + c[pos+(1<<i)] < val) pos += (1<<i), sum += c[pos]; return pos + 1; }
int main(){ read(n, q); for(int i = 1; i <= n; i++){ int x; read(x); add(x, 1); } while(q--){ int x; read(x); if(x > 0) add(x, 1); else add(search(-x), -1); } int ans = search(1); if(ans == n + 1) puts("0"); else printf("%d\n", ans); return 0; }
|
CF992E Nastya and King-Shamans
题目链接
由 $a_i=s_{i-1}$ 可以推出 $s_i=2s_{i-1}$,那么我们用树状数组维护前缀和,每次询问时从 $sum=0$ 开始查找第一个前缀和大于等于 $2sum$ 的位置,由于每次乘 $2$,所以最多查找 $\lg 10^{14}$ 次。
>folded1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
| #include<bits/stdc++.h> using namespace std; template<typename T>void read(T&x){x=0;int fl=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-') fl=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}x*=fl;} template<typename T,typename...Args>inline void read(T&t,Args&...args){read(t);read(args...);} typedef long long LL; typedef vector<int> vi; typedef pair<int, int> pii; #define mp(x, y) make_pair(x, y) #define pb(x) emplace_back(x) const int N = 200005;
int n, q, a[N];
LL c[N]; inline int lowbit(int x){ return x & -x; } inline void add(int x, LL val){ while(x <= n){ c[x] += val; x += lowbit(x); } } inline LL sum(int x){ LL res = 0; while(x){ res += c[x]; x -= lowbit(x); } return res; } inline int search(LL val){ int pos = 0; LL sum = 0; for(int i = 20; i >= 0; i--) if(pos + (1<<i) <= n && sum + c[pos + (1<<i)] < val) pos += (1<<i), sum += c[pos]; return pos + 1; }
int main(){ read(n, q); for(int i = 1; i <= n; i++){ read(a[i]); add(i, a[i]); } while(q--){ int p, x; read(p, x); add(p, x - a[p]); a[p] = x; LL s = 0; while(1){ int pos = search(s << 1); if(pos == n + 1){ puts("-1"); break; } if(sum(pos) == sum(pos-1) << 1){ printf("%d\n", pos); break; } s = sum(pos); } } return 0; }
|