题意:把一个长度为$n$的序列分成$m$段,每一段的代价是区间中相同的数的对数,最小化代价和
------------
贴一下[FlashHu大佬的博客讲解](https://www.cnblogs.com/flashhu/p/9480669.html)
**决策单调性优化**$dp$
首先有一个暴力的$dp$
设$f[i][j]$表示前$i$个数分成$j$段的最小代价
那么有方程:$f[i][j]=min\left\{f[k][j-1]+cost(k,i)\right\}$
用滚动数组去掉第二维,得到:$f[i]=min\left\{g[k]+cost(k,i)\right\}$
------------
设$f[i]$的最优决策点为$p[i]$,可以发现$p[i]<=p[i+1]$
所以这个转移是具有决策单调性的
那么能不能和[[JSOI2011]柠檬](https://www.lydsy.com/JudgeOnline/problem.php?id=4709)一样用单调栈做呢
因为这里的$cost(k,i)$不能$O(1)$算出
所以二分单调栈的复杂度会很高
------------
因为转移具有单调性并且离线,考虑用分治优化
设当前求解区间为$[l,r]$,决策区间为$[L,R]$
用一个桶来记录每个值在区间中出现的次数
暴力扫$[L,min(mid,R)]$区间
通过计算$f[i]+cost(i,mid)$找到这一段中的最优决策点$k$
然后递归到左子问题$(l,mid-1,L,k)$
注意这里的子问题$[l,mid-1],[L,k]$和当前问题$[L,l-1]$其实是一样的
所以把这一层修改过的桶和代价还原之后再递归进去
再看右子问题$(mid+1,r,k,R)$
它需要先处理出$[k,mid]$这个区间
当前是$[L,l-1]$,把左右端点分别移动,处理一下桶和代价
回溯的时候再还原,因为上一层还要用
```
#include<cstdio>
#include<cstring>
#include<cctype>
#include<algorithm>
#define reg register
using namespace std;
typedef long long ll;
const int N=1e5+5;
int n,m,a[N],cnt[N];
ll f[N],g[N];
inline int read()
{
int x=0,w=1;
char c=getchar();
while (!isdigit(c)&&c!='-') c=getchar();
if (c=='-') c=getchar(),w=-1;
while (isdigit(c))
{
x=(x<<1)+(x<<3)+c-'0';
c=getchar();
}
return x*w;
}
void solve(int l,int r,int L,int R,ll sum)//求解区间[l,r],决策区间[L,R],花费sum
{
if (l>r) return;
int mid=(l+r)>>1,pos=min(mid,R),k=0;
for (reg int i=l;i<=mid;i++) sum+=(cnt[a[i]]++);
for (reg int i=L;i<=pos;i++) {sum-=(--cnt[a[i]]); if (g[mid]>f[i]+sum) g[mid]=f[i]+sum,k=i;}
for (reg int i=L;i<=pos;i++) sum+=(cnt[a[i]]++);
for (reg int i=l;i<=mid;i++) sum-=(--cnt[a[i]]);//还原
solve(l,mid-1,L,k,sum);
for (reg int i=l;i<=mid;i++) sum+=(cnt[a[i]]++);
for (reg int i=L;i<k;i++) sum-=(--cnt[a[i]]);
solve(mid+1,r,k,R,sum);
for (reg int i=L;i<k;i++) sum+=(cnt[a[i]]++);
for (reg int i=l;i<=mid;i++) sum-=(--cnt[a[i]]);
}
int main()
{
n=read(),m=read();
for (reg int i=1;i<=n;a[i++]=read());
for (reg int i=1;i<=n;i++) f[i]=f[i-1]+(cnt[a[i]]++);//只分成一段的代价可以直接算出
memset(cnt,0,sizeof(cnt));
for (reg int i=2;i<=m;i++)
{
memset(g,127/3,sizeof(g));
solve(1,n,1,n,0);
memcpy(f,g,sizeof(g));
}
printf("%lld\n",f[n]);
return 0;
}
```