P5024
[NOIP2018 提高组] 保卫王国
调了一个下午一个小时加一节晚自习终于写完了。。。
自找的麻烦,自己给自己创造了超级多需要注意的细节。
先是一个整体的思路。
我们首先往暴力的
暴力之后就能够拿到 44pts 的好成绩。
然后是正解。一眼动态DP裸题倍增优化。
然后是一些比较麻烦的事情,我们这个待修改的倍增 DP 要注意的细节超级多。
首先定义
然后我们就可以瞎搞(因为这样想过之后方法就是会很麻烦)。
转移随便搞搞罢,注意时间复杂度,,我一开始居然写退化了。。。(丢脸)然后就是注意细节,虽然我们定义两个节点的状态
首先分类讨论,如果链的一端就是 LCA 的话,我们从较低的点往上跳,然后跳到 LCA 下面再转移给 LCA,然后再从 LCA 跳到根节点。
如果是一般情况,我们从两端开始跳到 LCA 下面,LCA 单独转移(注意还要加上其他子节点),然后从 LCA 再转移到根节点。
实现的细节很多。。。其实主要都是跳的过程中有点重合的存在,需要单独拎出来。然后就是重新修改
对拍+查错弄了几乎一整个晚自习,自闭了。代码弄到 6.83KB,自闭了。至今写过的最长的代码。。。
时间复杂度
代码:
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;
const ll inf=(1ll<<40),N=1e5;
ll n,m,u,v,tot,cnt,a,b,x,y,ans;
char type[5];
ll lg[N+5],dt[N+5],fa[N+5][20],g[N+5][20][2][2],pos[N+5],f[N+5][2],sum[N+5][2];
bool vis[N+5][2];
ll cst[N+5],ver[N*2+5],nxt[N*2+5],head[N+5];
ll getlca(ll a,ll b) {
if(dt[a]<dt[b]) swap(a,b);
while(dt[a]>dt[b]) a=fa[a][lg[dt[a]-dt[b]]-1];
if(a==b) return a;
for(ll k=lg[dt[a]];k>=0;k--) {
if(fa[a][k]!=fa[b][k]) {
a=fa[a][k];b=fa[b][k];
}
}
return fa[a][0];
}
void dfs(ll p,ll fath) {
dt[p]=dt[fath]+1;fa[p][0]=fath;
if(fath!=0) {
g[p][0][1][1]=f[fath][1]-min(f[p][0],f[p][1]);//注意这里 -min 的还原
g[p][0][1][0]=f[fath][0]-f[p][1];//我们转移 f 的时候也特判了,这里同理
g[p][0][0][0]=inf;
g[p][0][0][1]=f[fath][1]-min(f[p][0],f[p][1]);//同理
}
for(ll i=1;i<=lg[dt[p]];i++) {
fa[p][i]=fa[fa[p][i-1]][i-1];
g[p][i][0][0]=min(g[p][i-1][0][0]+g[fa[p][i-1]][i-1][0][0],g[p][i-1][0][1]+g[fa[p][i-1]][i-1][1][0]);
g[p][i][0][1]=min(g[p][i-1][0][0]+g[fa[p][i-1]][i-1][0][1],g[p][i-1][0][1]+g[fa[p][i-1]][i-1][1][1]);
g[p][i][1][0]=min(g[p][i-1][1][0]+g[fa[p][i-1]][i-1][0][0],g[p][i-1][1][1]+g[fa[p][i-1]][i-1][1][0]);
g[p][i][1][1]=min(g[p][i-1][1][0]+g[fa[p][i-1]][i-1][0][1],g[p][i-1][1][1]+g[fa[p][i-1]][i-1][1][1]);
}
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
dfs(ver[i],p);
}
}
ll dp(ll p,ll stat,ll fath) {
if(vis[p][stat]) return f[p][stat];vis[p][stat]=1;
f[p][stat]=stat*cst[p];
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
if(stat) f[p][stat]+=min(dp(ver[i],0,p),dp(ver[i],1,p));
else f[p][stat]+=dp(ver[i],1,p);
}
return f[p][stat];
}
void add(ll u,ll v) {
ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}
void restitute() {
// printf("cnt=%lld\n",cnt);
for(ll i=1;i<=cnt;i++) {
sum[pos[i]][0]=sum[pos[i]][1]=inf;
}
cnt=0;
}
void init() {
for(ll i=1;i<=n;i++) {
sum[i][0]=sum[i][1]=inf;
}
}
inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
return ret*f;
}
void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {
do{buf[++len]=x%10+48;x/=10;}while(x);
}
else {
putchar('-');
do{buf[++len]=-(x%10)+48;x/=10;}while(x);
}
while(len>=0) putchar(buf[len--]);
}
void writeln(ll x) {
write(x);putchar('\n');
}
int main() {
n=read();m=read();cin>>type;
for(ll i=1;i<=n;i++) {
lg[i]=lg[i-1]+(1<<lg[i-1]==i);
}
for(ll i=1;i<=n;i++) {
cst[i]=read();
}
for(ll i=1;i<n;i++) {
u=read();v=read();
add(u,v);add(v,u);
}
dp(1,0,0);dp(1,1,0);
dfs(1,0);dfs(1,0);
// for(ll i=1;i<=n;i++) {
// for(ll j=0;j<=lg[dt[i]];j++) {
// printf("When p=%lld fa=%lld j=%lld stat=0 stat_fa=0 g=%lld\n",i,fa[i][j],j,g[i][j][0][0]);
// printf("When p=%lld fa=%lld j=%lld stat=0 stat_fa=1 g=%lld\n",i,fa[i][j],j,g[i][j][0][1]);
// printf("When p=%lld fa=%lld j=%lld stat=1 stat_fa=0 g=%lld\n",i,fa[i][j],j,g[i][j][1][0]);
// printf("When p=%lld fa=%lld j=%lld stat=1 stat_fa=1 g=%lld\n",i,fa[i][j],j,g[i][j][1][1]);
// }
// }
//
init();
for(ll i=1;i<=m;i++) {
a=read();x=read();b=read();y=read();
if(dt[a]<dt[b]) {swap(a,b);swap(x,y);}
ll lca=getlca(a,b);
// printf("a=%lld , b=%lld , lca=%lld\n",a,b,lca);
if(b==lca) {
// printf("It is in the case 1:\n");
sum[a][x]=f[a][x];pos[++cnt]=a;
// printf("sum[%lld][%lld]=%lld\n",a,x,sum[a][x]);
for(ll j=lg[dt[a]-dt[b]];j>=0;j--) {
if(dt[fa[a][j]]>dt[b]) {
sum[fa[a][j]][0]=min(g[a][j][0][0]+sum[a][0],g[a][j][1][0]+sum[a][1]);
sum[fa[a][j]][1]=min(g[a][j][0][1]+sum[a][0],g[a][j][1][1]+sum[a][1]);
a=fa[a][j];pos[++cnt]=a;
}
}
// printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",a,sum[a][0],a,sum[a][1]);
if(y==0) sum[b][y]=sum[a][1]+g[a][0][1][y];
if(y==1) sum[b][y]=min(sum[a][0]+g[a][0][0][y],sum[a][1]+g[a][0][1][y]);
pos[++cnt]=b;
// printf("sum[%lld][%lld]=%lld ? sum[%lld][0]=%lld , sum[%lld][1]=%lld , g[%lld][0][0][%lld]=%lld , g[%lld][0][1][%lld]=%lld\n",b,y,sum[b][y],a,sum[a][0],a,sum[a][1],a,y,g[a][0][0][y],a,y,g[a][0][1][y]);
for(ll j=lg[dt[b]];j>=0;j--) {
if(dt[fa[b][j]]>dt[1]) {
sum[fa[b][j]][0]=min(g[b][j][0][0]+sum[b][0],g[b][j][1][0]+sum[b][1]);
sum[fa[b][j]][1]=min(g[b][j][0][1]+sum[b][0],g[b][j][1][1]+sum[b][1]);
b=fa[b][j];pos[++cnt]=b;
}
}
if(b!=1) {
sum[1][0]=min(sum[b][0]+g[b][0][0][0],sum[b][1]+g[b][0][1][0]);
sum[1][1]=min(sum[b][0]+g[b][0][0][1],sum[b][1]+g[b][0][1][1]);
pos[++cnt]=1;
}
ans=min(sum[1][0],sum[1][1]);
}
else {
// printf("It is in the case 2:\n");
sum[a][x]=f[a][x];pos[++cnt]=a;
// printf("sum[%lld][%lld]=%lld\n",a,x,sum[a][x]);
for(ll j=lg[dt[a]-dt[lca]];j>=0;j--) {
if(dt[fa[a][j]]>dt[lca]) {
sum[fa[a][j]][0]=min(g[a][j][0][0]+sum[a][0],g[a][j][1][0]+sum[a][1]);
sum[fa[a][j]][1]=min(g[a][j][0][1]+sum[a][0],g[a][j][1][1]+sum[a][1]);
a=fa[a][j];pos[++cnt]=a;
}
}
// printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",a,sum[a][0],a,sum[a][1]);
sum[b][y]=f[b][y];pos[++cnt]=b;
// printf("sum[%lld][%lld]=%lld\n",b,y,sum[b][y]);
for(ll j=lg[dt[b]-dt[lca]];j>=0;j--) {
if(dt[fa[b][j]]>dt[lca]) {
sum[fa[b][j]][0]=min(g[b][j][0][0]+sum[b][0],g[b][j][1][0]+sum[b][1]);
sum[fa[b][j]][1]=min(g[b][j][0][1]+sum[b][0],g[b][j][1][1]+sum[b][1]);
b=fa[b][j];pos[++cnt]=b;
}
}
// printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",b,sum[b][0],b,sum[b][1]);
sum[lca][1]=cst[lca];sum[lca][0]=0;
// printf("lca=%lld\n",lca);
for(ll j=head[lca];j;j=nxt[j]) {
if(ver[j]==fa[lca][0]) continue;
if(ver[j]==a||ver[j]==b) {
sum[lca][0]+=sum[ver[j]][1];
sum[lca][1]+=min(sum[ver[j]][0],sum[ver[j]][1]);
// printf("ver=%lld , sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",ver[j],ver[j],sum[ver[j]][0],ver[j],sum[ver[j]][1]);
}
else {
sum[lca][0]+=f[ver[j]][1];
sum[lca][1]+=min(f[ver[j]][0],f[ver[j]][1]);
}
}
pos[++cnt]=lca;
// printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",lca,sum[lca][0],lca,sum[lca][1]);
for(ll j=lg[dt[lca]];j>=0;j--) {
if(dt[fa[lca][j]]>dt[1]) {
sum[fa[lca][j]][0]=min(g[lca][j][0][0]+sum[lca][0],g[lca][j][1][0]+sum[lca][1]);
sum[fa[lca][j]][1]=min(g[lca][j][0][1]+sum[lca][0],g[lca][j][1][1]+sum[lca][1]);
lca=fa[lca][j];pos[++cnt]=lca;
}
}
if(lca!=1) {
sum[1][0]=sum[lca][1]+g[lca][0][1][0];
sum[1][1]=min(sum[lca][0]+g[lca][0][0][1],sum[lca][1]+g[lca][0][1][1]);
pos[++cnt]=1;
}
ans=min(sum[1][0],sum[1][1]);
}
// for(ll i=1;i<=n;i++) {
// printf("sum[%lld][0]=%lld , sum[%lld][1]=%lld\n",i,sum[i][0],i,sum[i][1]);
// }
if(ans>=inf) writeln(-1);
else writeln(ans);
restitute();
}
return 0;
}
代码(暴力):
#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;
const ll inf=(1ll<<40),N=1e5;
ll n,m,u,v,tot,cnt,a,b,x,y,ans,lca;
char type[5];
ll lg[N+5],dt[N+5],fa[N+5][20],g[N+5][2],pos[N+5],f[N+5][2],sum[N+5][2];
ll cst[N+5],ver[N*2+5],nxt[N*2+5],head[N+5];
void dfs(ll p,ll fath) {
f[p][1]=cst[p];dt[p]=dt[fath]+1;fa[p][0]=fath;
for(ll i=1;i<=lg[dt[p]];i++) {
fa[p][i]=fa[fa[p][i-1]][i-1];
}
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
dfs(ver[i],p);
f[p][1]+=min(f[ver[i]][0],f[ver[i]][1]);
f[p][0]+=f[ver[i]][1];
}
}
ll getlca(ll a,ll b) {
if(dt[a]<dt[b]) swap(a,b);
while(dt[a]>dt[b]) a=fa[a][lg[dt[a]-dt[b]]-1];
if(a==b) return a;
for(ll k=lg[dt[a]];k>=0;k--) {
if(fa[a][k]!=fa[b][k]) {
a=fa[a][k];b=fa[b][k];
}
}
return fa[a][0];
}
void dp(ll p,ll fath) {
if(p==a) return;
g[p][1]=cst[p];g[p][0]=0;
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
dp(ver[i],p);
if(p!=b) {
g[p][1]+=min(g[ver[i]][0],g[ver[i]][1]);
g[p][0]+=g[ver[i]][1];
}
else {
if(y==1) {g[p][1]+=min(g[ver[i]][0],g[ver[i]][1]);g[p][0]=inf;}
if(y==0) {g[p][0]+=g[ver[i]][1];g[p][1]=inf;}
}
}
}
void _dp(ll p,ll fath) {
if(p==lca) return;
g[p][1]=cst[p];g[p][0]=0;
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
_dp(ver[i],p);
g[p][1]+=min(g[ver[i]][0],g[ver[i]][1]);
g[p][0]+=g[ver[i]][1];
}
}
void __dp(ll p,ll fath) {
if(p==a||p==b) return;
g[p][1]=cst[p];g[p][0]=0;
for(ll i=head[p];i;i=nxt[i]) {
if(ver[i]==fath) continue;
__dp(ver[i],p);
g[p][1]+=min(g[ver[i]][0],g[ver[i]][1]);
g[p][0]+=g[ver[i]][1];
}
}
void init() {
for(ll i=1;i<=n;i++) {
g[i][0]=g[i][1]=inf;
}
}
void add(ll u,ll v) {
ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;
}
inline ll read() {
ll ret=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') f=-f;ch=getchar();}
while(ch>='0'&&ch<='9') {ret=(ret<<3)+(ret<<1)+ch-'0';ch=getchar();}
return ret*f;
}
void write(ll x) {
static char buf[22];static ll len=-1;
if(x>=0) {
do{buf[++len]=x%10+48;x/=10;}while(x);
}
else {
putchar('-');
do{buf[++len]=-(x%10)+48;x/=10;}while(x);
}
while(len>=0) putchar(buf[len--]);
}
void writeln(ll x) {
write(x);putchar('\n');
}
int main() {
n=read();m=read();cin>>type;
for(ll i=1;i<=n;i++) {
lg[i]=lg[i-1]+(1<<lg[i-1]==i);
}
for(ll i=1;i<=n;i++) {
cst[i]=read();
}
for(ll i=1;i<n;i++) {
u=read();v=read();
add(u,v);add(v,u);
}
dfs(1,0);
// for(ll i=1;i<=n;i++) {
// printf("f[%lld][0]=%lld , f[%lld][1]=%lld\n",i,f[i][0],i,f[i][1]);
// }
for(ll i=1;i<=m;i++) {
init();
a=read();x=read();b=read();y=read();
if(dt[a]<dt[b]) {swap(a,b);swap(x,y);}
lca=getlca(a,b);
if(lca==b) {
g[a][x]=f[a][x];dp(lca,fa[lca][0]);
}
else {
g[a][x]=f[a][x];g[b][y]=f[b][y];
__dp(lca,fa[lca][0]);
}
_dp(1,0);
if(b!=1) ans=min(g[1][0],g[1][1]);
else ans=g[b][y];
// printf("ans=%lld\n",ans);
// for(ll j=1;j<=n;j++) {
// printf("g[%lld][0]=%lld , g[%lld][1]=%lld\n",j,g[j][0],j,g[j][1]);
// }
if(ans>=inf) writeln(-1);
else writeln(ans);
}
return 0;
}
代码(数据生成):
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<ctime>
#define ll long long
using namespace std;
ll n,m,fa,a,b,x,y;
ll random(ll x) {
return rand()*rand()%x;
}
int main() {
freopen("r.in","w",stdout);
srand((unsigned)time(0));
n=random(2000)+5;m=random(2000)+3;
printf("%lld %lld C3\n",n,m);
for(ll i=1;i<=n;i++) {
printf("%lld ",random(5)+1);
}
printf("\n");
for(ll i=1;i<n;i++) {
fa=random(i)+1;printf("%lld %lld\n",fa,i+1);
}
for(ll i=1;i<=m;i++) {
a=random(n)+1;b=a;
while(b==a) b=random(n)+1;
x=random(2);y=random(2);
printf("%lld %lld %lld %lld\n",a,x,b,y);
}
fclose(stdout);
return 0;
}