P3241开店 分块题解

· · 题解

前言

现有题解只有动态点分治和树剖加主席树两种做法,都需要高难度的前置知识和代码能力。而本题解只需要分块、ST表和换根DP。码量主要靠多个简单部分堆砌,实现优良完全没有调试难度,对代码能力要求也不高

方法

假如把 n 个妖怪按年龄排序,查询就是一个连续的区间。考虑分块,设块长为 B。预处理每个块对每个点的贡献({a_i}_j 表示每个对于第 j 块,若询问的点为 i 的距离总和),这个我们若一个一个求显然会超时,若是对每个整块做一遍换根复杂度是 O(\frac{n}{B} \times n),然后查询是 O(q \times \frac{n}{B} \times k)。对于散块直接一个一个点加入需要求 lca 来得到距离,设求 lca 复杂度为 O(k),复杂度为 O(qBk)。对于第一个部分显然 B\ge O(\sqrt{n}) 复杂度才正确,那么考虑第二部分,显然我们需要 O(1) 求 lca,即 kO(1)。那么我们需要一个科技。

科技

首先 O(n+m) 的离线做法不行,而且也难度超标,常数也不优秀。所以我们使用在提高组难度范围内的算法:ST表。因为可以有足够的预处理时间,ST表此时成为了最优的选择。那么我们如何做呢?考虑查询 xy 的 lca。对于一个节点 uxy 是其两个不同的子树的节点,lca 一定是 u 或其祖先。那么我们考虑在每个节点前插入其父亲,此时节点 u 一定在节点 xy 之间。所以 xy 之间深度最小的点就是它们的 lca。具体实现只需要把节点按 dfn 序排序,并在每个节点前插入它的父亲,每次查询 [dfn_x,dfn_y] 区间的 dep 最小的节点即为它们的 lca。

实现

换根dp部分考虑对于每一个块,把块内所有点设为关键点,换根求出每个点与最近关键的距离。 分块部分整块直接使用预处理的部分,散块点 x 与点 u 的距离即为 dis_x+dis_u-2 \times dis_{lca_{u,v}}dis 为点到根节点的距离)。 空间可能会被卡,注意到空间为 nB,把块长开到 1.5\sqrt{n} 即可通过,时间和空间均可接受,微调块长或许更好。

代码

#include<bits/stdc++.h>
using namespace std;
#define pb push_back 
typedef pair<int,int> pr;
#define fi first
#define se second
typedef long long ll;
const int N=1.5e5+5;
const int M=395;
const ll inf=1e15+7ll;
const int N3=N<<1;
vector<pr> g[N];
struct LCA{ 
    int mx,cnt,rk[N3],a[N3],dep[N3],dfn[N3],rnk[N3];
    ll dis[N3];
    int lg[N3],mi[21],st[21][N<<1],id[21][N3];
    void dfs(int u,int fa){
        dep[u]=dep[fa]+1,dfn[u]=++cnt;
        a[cnt]=dep[u],rnk[cnt]=u;
        for(pr it:g[u]){
            int v=it.fi,w=it.se;
            if(v==fa) continue;
            dis[v]=dis[u]+w;
            dfs(v,u);
            a[++cnt]=dep[u],rnk[cnt]=u;
        }
    }
    void init_ST(){
        lg[1]=0;
        for(int i=2;i<=cnt;i++) lg[i]=lg[i>>1]+1;
        mx=lg[cnt];
        mi[0]=1;
        for(int i=1;i<=mx;i++) mi[i]=mi[i-1]<<1;
        for(int i=1;i<=cnt;i++) st[0][i]=a[i],id[0][i]=rnk[i];
        for(int i=1;i<=mx;i++){
            for(int j=1;j+mi[i]-1<=cnt;j++){
                int nx=j+mi[i-1];
                if(st[i-1][j]<st[i-1][nx]){
                    st[i][j]=st[i-1][j],id[i][j]=id[i-1][j];
                }else{
                    st[i][j]=st[i-1][nx],id[i][j]=id[i-1][nx];
                }
            }
        }
    }
    inline int lca(int u,int v){
        int l=dfn[u],r=dfn[v];
        if(l>r) swap(l,r);
        int k=lg[r-l+1],res;
        int nx=r-mi[k]+1;
        if(st[k][l]<st[k][nx]) res=id[k][l];
        else res=id[k][nx];
        return res;
    }
    inline ll d(int u,int v){
        return dis[u]+dis[v]-(dis[lca(u,v)]<<1);
    }
    void init(int rt){
        cnt=0;
        dfs(rt,0);init_ST();
    }
}tr;
struct dfsdp{
    ll f[N];int sz[N];
    void dfs1(int u,int fa){
        for(pr it:g[u]){
            int v=it.fi,w=it.se;
            if(v==fa) continue;
            dfs1(v,u);
            sz[u]+=sz[v];
            f[u]+=f[v]+1ll*w*sz[v];
        }
    }
    void dfs2(int u,int fa){
        for(pr it:g[u]){
            int v=it.fi,w=it.se;
            if(v==fa) continue;
            ll k=f[u]-f[v]-1ll*w*sz[v];
            f[v]+=k,f[v]+=1ll*(sz[1]-sz[v])*w;
            dfs2(v,u);
        }
    }
    inline void getdis(){
        dfs1(1,0),dfs2(1,0);
    }
}xx;
int n,m,ans,cnt,l[M],r[M],bel[N];ll ax[(int)(M/1.5)][N];
struct node{ int x,id; }a[N];
inline bool cmp(node ax,node ay){ return ax.x<ay.x;}
inline void bld(int x){
    for(int i=1;i<=n;i++) xx.f[i]=xx.sz[i]=0;
    for(int i=l[x];i<=r[x];i++) xx.sz[a[i].id]=1;
    xx.getdis();
    for(int i=1;i<=n;i++) ax[x][i]=xx.f[i];
}
inline ll get(int li,int ri,int x){
    ll res=0;
    for(int i=li;i<=ri;i++) res+=tr.d(x,a[i].id);
    return res;
}
inline ll qiu(int li,int ri,int x){
    if(li>ri) return 0;
    int bl=bel[li],br=bel[ri];
    if(bl==br) return get(li,ri,x);
    ll res=get(li,r[bl],x)+get(l[br],ri,x);
    for(int i=bl+1;i<br;i++) res+=ax[i][x];
    return res;
}
void init(){
    int k=min(n,(int)(sqrt(n)*1.5)),nw=1,la=0;
    while(nw<n){
        nw=min(nw+k,n),cnt++;
        l[cnt]=la,r[cnt]=nw;
        for(int i=la;i<=nw;i++) bel[i]=cnt;
        la=nw+1;
    }
    for(int i=1;i<=cnt;i++) bld(i);
}
struct bs{
    inline int findl(int x){
        int l=1,r=n,pos=n+1;
        while(l<=r){
            int mid=(l+r)>>1;
            if(a[mid].x>=x) r=mid-1,pos=mid;
            else l=mid+1;
        }
        return pos;
    }
    inline int findr(int x){
        int l=1,r=n,pos=0;
        while(l<=r){
            int mid=(l+r)>>1;
            if(a[mid].x<=x) l=mid+1,pos=mid;
            else r=mid-1;
        }
        return pos;
    }
}ef;
int mod;
int main(){
    scanf("%d%d%d",&n,&m,&mod);
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i].x);a[i].id=i;
    }
    sort(a+1,a+n+1,cmp);
    int u,v,w;
    for(int i=1;i<n;i++){
        scanf("%d%d%d",&u,&v,&w);
        g[u].pb({v,w}),g[v].pb({u,w});
    }
    tr.init(1);
    init();
    int li,ri,aa,bb;
    ll las=0;
    for(int i=1;i<=m;i++){
        scanf("%d%d%d",&u,&aa,&bb);
        int li=(las+aa)%mod,ri=(las+bb)%mod;
        if(li>ri) swap(li,ri);
        int nl=ef.findl(li),nr=ef.findr(ri);
//      printf("111 l:%d r:%d\n",li,ri);
//      printf("000 L:%d R:%d\n",nl,nr);
        las=qiu(nl,nr,u);
        printf("%lld\n",las);
    }
    return 0;
}