日照夏令营 day2 regress 题解

· · 个人记录

做法

对于一个点的 k 级祖先,可以用倍增或长链剖分求出,十分容易。

对于一个点的所有 k 级子孙,因为一个点 ik 级子孙的深度是 dep_i - k,所以要查询的是,以 i 为根的子树中,有多少个深度为 dep_i + k 的。
所以可以对于每个点 i,建一棵权值线段树,维护以 i 为根的子树中深度为 j(代表任意深度) 的节点有多少个。
这样通过线段树合并就能处理出所有信息了。

Code

#include <bits/stdc++.h>
using namespace std;
namespace Main
{
    typedef long long ll;
    const int maxn=3e5+5;
    int n,m;
    int log[maxn];
    int head[maxn];
    struct EDGE
    {
        int to,nxt;
    }edge[maxn<<1];
    int cnt=0;
    inline void add(int u,int to)
    {
        edge[++cnt].to=to;
        edge[cnt].nxt=head[u];
        head[u]=cnt;
    }
    int h[maxn],hs[maxn],fa[20][maxn];
    int dep[maxn];
    int ans[maxn];
    //这里的h[i]是以i为根的子树的深度
    int count[maxn];
    //count[i]是深度为i的点的个数
    void dfs1(int u,int _fa)
    {
        h[u]=dep[u]=dep[_fa]+1;
        fa[0][u]=_fa;
        count[dep[u]]++;
        for(int i=1;i<=19;i++)
        {
            fa[i][u]=fa[i-1][fa[i-1][u]];
        }
        for(int i=head[u];i;i=edge[i].nxt)
        {
            int to=edge[i].to;
            if(to==_fa)continue;
            dfs1(to,u);
            h[u]=max(h[u],h[to]);
            if(h[to]>h[hs[u]])hs[u]=to;
        }
        h[u]++;
    }
    int top[maxn];
    vector<int> up[maxn],down[maxn];
    void dfs2(int u,int _top)
    {
        top[u]=_top;
        if(u==top[u])
        {
            for(int i=0,now=u;i<=h[u]-dep[u];i++)
            {

                up[u].emplace_back(now);
                now=fa[0][now];
            }
            for(int i=0,now=u;i<=h[u]-dep[u];i++)
            {
                down[u].emplace_back(now);
                now=hs[now];
            }
        }
        if(hs[u])
        {
            dfs2(hs[u],_top);
        }
        for(int i=head[u];i;i=edge[i].nxt)
        {
            int to=edge[i].to;
            if(to==fa[0][u]||to==hs[u])continue;
            dfs2(to,to);
        }
    }
    int ask(int x,int k)
    {
        if(k==0)
        {
            return x;
        }
        int mbsd=dep[x]-k;//目标深度
        if(mbsd<=0)return -1;
        int dqd=top[fa[log[k]][x]];
        if(dep[dqd]<mbsd)
        {
            dqd=down[dqd][mbsd-dep[dqd]];
        }
        if(dep[dqd]>mbsd)
        {
            dqd=up[dqd][dep[dqd]-mbsd];
        }
        return dqd;
    }
    struct Tree
    {
        int ls,rs,val;
    }tree[maxn*100];
    int nodecnt;
    inline void push_up(int node)
    {
        tree[node].val=tree[tree[node].ls].val+tree[tree[node].rs].val;
    }
    int rt[maxn];
    int modify(int node,int l,int r,int pos,int val)
    {
        if(!node)node=++nodecnt;
        if(l==r)
        {
            tree[node].val+=val;
            return node;
        }
        int mid=l+r>>1;
        if(mid>=pos)
        {
            tree[node].ls=modify(tree[node].ls,l,mid,pos,val);
        }
        else tree[node].rs=modify(tree[node].rs,mid+1,r,pos,val);
        push_up(node);
        return node;
    }
    int merge(int a,int b,int l,int r)
    {
        if(!a||!b)return a|b;
        int root=++nodecnt;
        if(l==r)
        {
            tree[root].val=tree[a].val+tree[b].val;
            return root;
        }
        int mid=l+r>>1;
        tree[root].ls=merge(tree[a].ls,tree[b].ls,l,mid);
        tree[root].rs=merge(tree[a].rs,tree[b].rs,mid+1,r);
        push_up(root);
        return root;
    }
    int query(int node,int l,int r,int pos)
    {
        if(l==r)
        {
            return tree[node].val;
        }
        int mid=l+r>>1;
        int ans=0;
        if(mid>=pos)ans=query(tree[node].ls,l,mid,pos);
        else ans=query(tree[node].rs,mid+1,r,pos);
        return ans;
    }
    struct Question
    {
        int k,id;
        Question(int k2,int id2)
        {
            k=k2;
            id=id2;
        }
    };
    vector<Question> qs[maxn];
    void solve(int u,int _fa)
    {
        for(int i=head[u];i;i=edge[i].nxt)
        {
            int to=edge[i].to;
            if(to==_fa)continue;
            solve(to,u);
            rt[u]=merge(rt[u],rt[to],1,n+1);
        }
        rt[u]=modify(rt[u],1,n+1,dep[u],1);
        for(int i=0;i<qs[u].size();i++)
        {
            ans[qs[u][i].id]=query(rt[u],1,n+1,dep[u]+qs[u][i].k)-1;
            if(ans[qs[u][i].id]<0)ans[qs[u][i].id]=0;
        }
    }
    void main()
    {
        scanf("%d",&n);
        for(int i=2;i<=n;i++)
        {
            log[i]=log[i>>1]+1;
        }
        int __fa;
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&__fa);
            add(i,__fa);
            add(__fa,i);
        }
        dfs1(0,-1);
        dfs2(0,0);
        scanf("%d",&m);
        for(int i=1;i<=m;i++)
        {
            int xi,ki;
            scanf("%d%d",&xi,&ki);
            int zx=ask(xi,ki);
            if(zx!=-1)
            {
                qs[zx].emplace_back(ki,i);
            }

        }
        solve(0,-1);
        for(int i=1;i<=m;i++)
        {
            printf("%d ",ans[i]);
        }
    }
}
int main()
{
    Main::main();
    return 0;
}