ABC397F 题解

· · 题解

可以看看这个 CF833B。

题意

给你一个序列,你要把这个序列分成连续的三段,每一段的权值为每一段中不同的数的个数,问你这个序列的最大权值是多少。

思路

首先发现可以预处理来第一段和最后一段的权值,设为 ans1ans2ans1_i 即为从 1i 这一段的权值,ans2_i 即为从 in 这一段的权值,于是我们就有了 O(n^2) 的 dp 做法。

f_i 表示以 i 为第二段结尾的最大权值,答案即为 \max f_i

f_i=\max_{j=1}^{i-1} ans1_j+val(j+1,i)+ans2_j 考虑优化。 我们发现,难点在于如何计算 $val(j+1,i)$。 考虑说我们可以把每一个 $ans1$ 扔到线段树上,一个点 $j$ 就表示说以 $j$ 为第一段的结尾最大权值,然后我们枚举第二段的结尾 $i$,每次 dp。 我们记 $a_i$ 上一次出现的位置为 $la_i$。 考虑说我们每个数只对于起点在 $[la_i+1,i]$ 这个闭区间的所有区间有 $1$ 的贡献,也就是线段树上 $[la_i,i-1]$ 这个区间。 最后一段的直接加上即可。 然后每次转移找最大值。 复杂度 $O(n\log n)$。 不理解可以结合代码食用。 ## Code ```c++ #include <bits/stdc++.h> #define endl '\n' #define int long long #define fi first #define se second using namespace std; const int N=3e5+10; const int inf=0x3f3f3f3f3f3f3f3f; int n; int a[N]; int ans1[N],ans2[N]; unordered_map<int,int> mp; int ans; int pre[N],nxt[N]; struct Node { int l,r,w; int lt; }tr[N<<2]; void build(int rt,int l,int r) { tr[rt].l=l,tr[rt].r=r; if(l==r) { tr[rt].w=ans1[l]; return ; } int mid=(tr[rt].l+tr[rt].r)>>1; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); tr[rt].w=max(tr[rt<<1].w,tr[rt<<1|1].w); } void pushdown(int rt) { int &tag=tr[rt].lt; tr[rt<<1].w+=tag; tr[rt<<1|1].w+=tag; tr[rt<<1].lt+=tag; tr[rt<<1|1].lt+=tag; tag=0; } void add(int rt,int l,int r,int k) { if(tr[rt].l>=l&&tr[rt].r<=r) { tr[rt].w+=k; tr[rt].lt+=k; return ; } pushdown(rt); int mid=(tr[rt].r+tr[rt].l)>>1; if(l<=mid) add(rt<<1,l,r,k); if(r>mid) add(rt<<1|1,l,r,k); tr[rt].w=max(tr[rt<<1].w,tr[rt<<1|1].w); } int check(int rt,int l,int r) { if(tr[rt].l>=l&&tr[rt].r<=r) return tr[rt].w; pushdown(rt); int mid=(tr[rt].l+tr[rt].r)>>1; int res=0; if(l<=mid) res=max(res,check(rt<<1,l,r)); if(r>mid) res=max(res,check(rt<<1|1,l,r)); return res; } signed main() { //freopen(".in","r",stdin); //freopen(".out","w",stdout); cin.tie(0); cout.tie(0); ios::sync_with_stdio(false); cin>>n; for(int i=1;i<=n;i++) cin>>a[i]; for(int i=1;i<=n;i++) { pre[i]=mp[a[i]]; mp[a[i]]=i; ans1[i]=mp.size(); } mp.clear(); for(int i=n;i>=1;i--) { nxt[i]=mp[a[i]]; mp[a[i]]=i; ans2[i]=mp.size(); } build(1,1,n); for(int i=2;i<n;i++) { add(1,(pre[i]?pre[i]:1),i-1,1); ans=max(check(1,1,i-1)+ans2[i+1],ans); } cout<<ans; return 0; } ```