[CF833B] The Bakery

题目链接

Solution

考虑最简单的 $dp$:设 $dp[i][j]$ 表示将前 $i$ 个数字分成 $j$ 段能得到的最大价值,则:

其中,$c(k+1,i)$ 表示 $a[k+1],\,a[k+2],\cdots,a[i]$ 中不同数字的个数。

为了快速转移,我们考虑用线段树维护 $dp[k][j-1]+c(k+1,i)$,那么就能够做到 $O(nk\log n)$。注意 $c(k+1,i)$ 是与 $i$ 有关的,当我们遍历 $i$ 时,我们需要随时更新线段树中的信息以保证正确性。考虑 $c(k+1,i)$ 什么时候会改变——当区间 $[k+1,i]$ 中没有 $a[i]$ 这个数时,$c(k+1,i)$ 就相比 $c(k+1,i-1)$ 加了 $1$,所以更新信息的操作就是找到与 $a[i]$ 相同的数上一次出现的位置(设为 $p$),对 $[p,i-1]$ 进行区间加 $1$。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include<bits/stdc++.h>

using namespace std;

template<typename T>void read(T&x){x=0;int fl=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')
fl=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}x*=fl;}
template<typename T,typename...Args>inline void read(T&t,Args&...args){read(t);read(args...);}

typedef long long LL;
typedef vector<int> vi;
typedef pair<int, int> pii;
#define mp(x, y) make_pair(x, y)
#define pb(x) emplace_back(x)

const int N = 35005;
const int K = 55;

struct segTree{
int l, r, mx, lazy;
}tr[K][N<<2];
#define lid (id<<1)
#define rid (id<<1|1)
#define mid ((tr[k][id].l + tr[k][id].r) >> 1)
inline void pushup(int k, int id){
tr[k][id].mx = max(tr[k][lid].mx, tr[k][rid].mx);
}
inline void pushdown(int k, int id){
if(tr[k][id].l == tr[k][id].r) return;
if(tr[k][id].lazy){
tr[k][lid].lazy += tr[k][id].lazy;
tr[k][lid].mx += tr[k][id].lazy;
tr[k][rid].lazy += tr[k][id].lazy;
tr[k][rid].mx += tr[k][id].lazy;
tr[k][id].lazy = 0;
}
}
void build(int k, int id, int l, int r){
tr[k][id].l = l, tr[k][id].r = r;
tr[k][id].mx = tr[k][id].lazy = 0;
if(l == r) return;
build(k, lid, l, mid), build(k, rid, mid+1, r);
pushup(k, id);
}
void st(int k, int id, int pos, int val){
pushdown(k, id);
if(tr[k][id].l == tr[k][id].r){
tr[k][id].mx = val;
tr[k][id].lazy = 0;
return;
}
if(pos <= mid) st(k, lid, pos, val);
else st(k, rid, pos, val);
pushup(k, id);
}
void add(int k, int id, int l, int r, int val){
pushdown(k, id);
if(tr[k][id].l == l && tr[k][id].r == r){
tr[k][id].lazy += val;
tr[k][id].mx += val;
return;
}
if(r <= mid) add(k, lid, l, r, val);
else if(l > mid) add(k, rid, l, r, val);
else add(k, lid, l, mid, val), add(k, rid, mid+1, r, val);
pushup(k, id);
}
int queryMax(int k, int id, int l, int r){
pushdown(k, id);
if(tr[k][id].l == l && tr[k][id].r == r) return tr[k][id].mx;
if(r <= mid) return queryMax(k, lid, l, r);
else if(l > mid) return queryMax(k, rid, l, r);
else return max(queryMax(k, lid, l, mid), queryMax(k, rid, mid+1, r));
}

int n, k, a[N], lst[N], pre[N], dp[N][K];

int main(){
read(n, k);
for(int i = 1; i <= n; i++){
read(a[i]);
dp[i][1] = dp[i-1][1] + (lst[a[i]] == 0);
pre[i] = max(1, lst[a[i]]);
lst[a[i]] = i;
}
for(int j = 1; j <= k; j++) build(j, 1, 1, n);
for(int i = 1; i <= n; i++) st(1, 1, i, dp[i][1]);
for(int j = 2; j <= k; j++){
for(int i = j; i <= n; i++){
add(j-1, 1, pre[i], i-1, 1);
dp[i][j] = queryMax(j-1, 1, 1, i-1);
st(j, 1, i, dp[i][j]);
}
}
printf("%d\n", dp[n][k]);
return 0;
}
作者

xyfJASON

发布于

2021-05-19

更新于

2021-05-19

许可协议

评论