树链剖分get!

· · 个人记录

模板:

传送门

了解了线段树和树剖的思路后

唯一的难点就是

如何将二者结合

利用 dfsx 数组给节点标dfs序

再用 pos 和 end 数组记录每个子树的起止位置

利用dfs序建线段树

便有如下代码

#define ll long long
#define maxn 100010

// p: modulus applied to every printed answer (read in main()).
ll p;
// bui: index of the next dfsx[] entry that build() copies into a leaf.
ll bui = 1;

// Scalars read from the input in main(); x, y, z and flag are per-operation scratch.
int n, m, root;
int x, y, z, flag;

// a[i]: initial value of node i.
int a[maxn];

// head[u]: index in map[] of the most recently added edge leaving u (0 = none).
int head[maxn];
// cnt: edge counter used by add(); main() resets it to 1 and dfs2() then
// advances it as the DFS-order clock.
int cnt;

// Filled by dfs1(): subtree size, heavy child, parent and depth of each node.
int siz[maxn];
int son[maxn];
int fa[maxn];
int dep[maxn];
// Filled by dfs2(): head of the chain each node belongs to.
int top[maxn];

// dfsx[i]: node visited i-th by dfs2(); pos[u]: inverse mapping (filled in main());
// end[u]: last DFS-order index inside the subtree of u.
int dfsx[maxn];
int pos[maxn];
int end[maxn];

// One directed edge of the linked adjacency list indexed by head[].
struct node
{
    int to;     // node this edge leads to
    int next;   // index of the previously added edge with the same source (0 ends the list)
};
// Edge pool; main() stores every input edge in both directions.
node map[maxn*2];
// Segment-tree node over DFS-order positions.
struct Node
{
    ll w;       // sum of the values in [l, r]
    ll l;       // left end of the covered range
    ll r;       // right end of the covered range
    ll tag;     // pending per-element increment not yet pushed to the children
};
Node tree[4*maxn];

// Prepends the directed edge u -> v to u's adjacency list.
void add(int u,int v)
{
    const int id = ++cnt;       // edge ids start at 1 so that 0 can mean "no edge"
    map[id].next = head[u];
    map[id].to = v;
    head[u] = id;
}

// Builds the segment-tree node k covering [L, R].
// Leaves are reached in increasing position order, so each leaf takes the value
// of the next node in DFS order (dfsx[bui]) and advances bui.
void build(int k,int L,int R)
{
    tree[k].l = L;
    tree[k].r = R;
    if(L != R)
    {
        const int half = (L+R)/2;
        build(2*k,L,half);
        build(2*k+1,half+1,R);
        tree[k].w = tree[2*k].w + tree[2*k+1].w;
        return;
    }
    tree[k].w = a[dfsx[bui]];
    ++bui;
}

void down(int k)
{
int L = k*2;
int R = k*2+1;
    tree[L].tag += tree[k].tag;
    tree[R].tag += tree[k].tag;
    tree[L].w   += tree[k].tag*(tree[L].r-tree[L].l+1);
    tree[R].w   += tree[k].tag*(tree[R].r-tree[R].l+1);
    tree[k].tag = 0;
}

// Adds W to every position in [L, R], starting from segment-tree node k.
void add_line(int k,int L,int R,ll W)
{
    const int lo = tree[k].l;
    const int hi = tree[k].r;
    // Node fully inside the update range: record the increment lazily.
    if(L <= lo && hi <= R)
    {
        tree[k].w += W*(hi-lo+1);
        tree[k].tag += W;
        return;
    }
    down(k);
    const int half = (lo+hi)/2;
    if(half >= L)   add_line(k*2,L,R,W);
    if(half <  R)   add_line(k*2+1,L,R,W);
    tree[k].w = tree[k*2].w + tree[k*2+1].w;
}

// Returns the sum of the positions in [L, R] below segment-tree node k.
// NOTE(review): the sum is not reduced modulo p here (only when printed) -
// confirm the input range cannot overflow long long.
ll query_line(int k,int L,int R)
{
    const int lo = tree[k].l;
    const int hi = tree[k].r;
    if(L <= lo && hi <= R)  return tree[k].w;   // fully covered
    if(R < lo || hi < L)    return 0;           // disjoint
    down(k);
    ll sum = query_line(k*2,L,R);
    sum += query_line(k*2+1,L,R);
    return sum;
}

// First pass: records parent, depth and subtree size of u and picks its heavy
// child (the child with the largest subtree) into son[u].
void dfs1(int u, int FA, int deep)
{
    siz[u] = 1;
    dep[u] = deep;
    fa[u] = FA;
    for(int e=head[u]; e!=0; e=map[e].next)
    {
        const int child = map[e].to;
        if(child == fa[u])  continue;   // do not walk back up the tree
        dfs1(child,u,deep+1);
        siz[u] += siz[child];
        if(siz[son[u]] < siz[child])    son[u] = child;
    }
}

// Second pass: assigns u the next DFS-order slot, records its chain head tp,
// and visits the heavy child first so each chain occupies consecutive slots.
void dfs2(int u,int tp)
{
    dfsx[cnt] = u;
    ++cnt;
    top[u] = tp;
    if(son[u] != 0) dfs2(son[u],tp);
    for(int e=head[u]; e!=0; e=map[e].next)
    {
        const int child = map[e].to;
        if(child == fa[u] || child == son[u])   continue;
        dfs2(child,child);      // a light child starts a new chain
    }
    // cnt already points past the last node of this subtree.
    end[u] = cnt-1;
}

void add_path(int x,int y,int z)
{
int fx = top[x];
int fy = top[y];
    while(fx != fy)
    {
        if(dep[fx] >= dep[fy])
        {
            add_line(1,pos[fx],pos[x],z);
            x = fa[fx]; fx = top[x];
        }
        else 
        {
            add_line(1,pos[fy],pos[y],z);
            y = fa[fy]; fy = top[y];
        }
    }
    if(pos[x] < pos[y]) add_line(1,pos[x],pos[y],z);
    else                add_line(1,pos[y],pos[x],z);
}

// Returns the sum of the node values on the tree path between x and y.
ll query_path(int x,int y)
{
    ll total = 0;
    while(top[x] != top[y])
    {
        // Climb from the endpoint whose chain head is deeper (ties climb from x).
        if(dep[top[x]] < dep[top[y]])
        {
            total += query_line(1,pos[top[y]],pos[y]);
            y = fa[top[y]];
        }
        else
        {
            total += query_line(1,pos[top[x]],pos[x]);
            x = fa[top[x]];
        }
    }
    // Both endpoints now share a chain: add the slots between them.
    const int lo = pos[x] < pos[y] ? pos[x] : pos[y];
    const int hi = pos[x] < pos[y] ? pos[y] : pos[x];
    total += query_line(1,lo,hi);
    return total;
}

// Reads the tree and the operations, builds the decomposition and answers:
//   1 x y z : add z along the path x..y      2 x y : print path sum mod p
//   3 x z   : add z to the subtree of x      4 x   : print subtree sum mod p
int main()
{
    cin >> n >> m >> root >> p;
    for(int i=1;i<=n;++i)   cin >> a[i];
    for(int e=1;e<n;++e)
    {
        cin >> x >> y;
        add(x,y);
        add(y,x);
    }

    // From here on cnt is the DFS-order clock (slots start at 1).
    cnt = 1;
    dfs1(root,root,1);
    dfs2(root,root);
    for(int i=1;i<cnt;++i)  pos[dfsx[i]] = i;
    build(1,1,n);

    while(m--)
    {
        cin >> flag;
        if(flag == 1)
        {
            cin >> x >> y >> z;
            add_path(x,y,z);
        }
        else if(flag == 2)
        {
            cin >> x >> y;
            cout << query_path(x,y) % p << endl;
        }
        else if(flag == 3)
        {
            cin >> x >> z;
            // A subtree occupies the contiguous slots pos[x]..end[x].
            add_line(1,pos[x],end[x],z);
        }
        else if(flag == 4)
        {
            cin >> x;
            cout << query_line(1,pos[x],end[x]) % p << endl;
        }
    }
    return 0;
}

思考:

传送门

(前排提示:结合样例食用效果更佳)

一遍切紫啊啊啊啊啊!!!!

代码没什么难度,重在思考。

首先就是,为什么要用树剖,

这个题,直接想是没有思路的,所以要换个角度,

对于一棵树,什么时候会用到额外的路径?

当然是当前节点所在的链与整棵树断开的时候。

所以,当额外路径两端的点所在的链被断开时,则要用到某一额外路径,

这就很显然了,非树剖莫属,此时额外路径长度就相当于加在两点间链上的权值

不过值得注意的是,每次添加是区间修改,向下传递时要取最小值(详见 down() 函数)

#define PII pair<int,int>
#define MAXN 50005
// Left / right child index of segment-tree node `k` (the macros read whatever
// variable named k is in scope at the point of use).
// The expansions are parenthesised so they stay a single operand when used
// inside a larger expression: without the parentheses `_L + 1` would expand to
// `k<<1 + 1`, i.e. `k << 2`.
#define _L (k<<1)
#define _R (k<<1|1)

// n, m: the two counts read first in main(); cnt: edge counter used by add();
// tot: DFS-order clock advanced by dfs2().
int n, m;
int cnt;
int tot;

// head[u]: index in map[] of the most recently added edge leaving u (0 = none).
int head[MAXN];

// Filled by dfs1(): subtree size, heavy child, depth and parent of each node.
int siz[MAXN];
int son[MAXN];
int dep[MAXN];
int FA[MAXN];

// Filled by dfs2(): DFS-order index and chain head of each node.
int DFN[MAXN];
int Top[MAXN];

// Tree edges in input order; main() answers one query per stored edge.
queue<PII> que;

// One directed edge of the linked adjacency list indexed by head[].
struct node{
    int to;     // node this edge leads to
    int next;   // index of the previously added edge with the same source (0 ends the list)
};
node map[MAXN*2];

// Segment-tree node over DFS-order positions.
struct Node{
    int L;      // left end of the covered range
    int R;      // right end of the covered range
    int w;      // smallest weight applied to the whole range so far (0x3f3f3f3f = none)
};
Node tree[MAXN*4];

void add(int u, int v)
{
    map[++cnt] = (node){v,head[u]};
    head[u] = cnt;
}

// First pass: records parent, depth and subtree size of u and selects the
// child with the largest subtree as its heavy child.
void dfs1(int u, int fa)
{
    FA[u] = fa;
    dep[u] = dep[fa]+1;
    siz[u] = 1;
    son[u] = 0;
    int heaviest = 0;       // largest child subtree seen so far
    for(int e=head[u]; e!=0; e=map[e].next)
    {
        const int v = map[e].to;
        if(v != fa)
        {
            dfs1(v,u);
            siz[u] += siz[v];
            if(heaviest < siz[v])
            {
                son[u] = v;
                heaviest = siz[v];
            }
        }
    }
}

// Second pass: numbers u in DFS order and records its chain head; the heavy
// child is visited first so that every chain gets consecutive indices.
void dfs2(int u, int fa, int top)
{
    DFN[u] = ++tot;
    Top[u] = top;
    if(son[u] != 0) dfs2(son[u],u,top);
    for(int e=head[u]; e!=0; e=map[e].next)
    {
        const int v = map[e].to;
        if(v != fa && v != son[u])  dfs2(v,u,v);    // a light child starts a new chain
    }
}

// Builds segment-tree node k over [L, R]; every node starts at 0x3f3f3f3f,
// the value main() later reports as "-1".
void build(int L, int R, int k)
{
    tree[k].L = L;
    tree[k].R = R;
    tree[k].w = 0x3f3f3f3f;
    if(L != R)
    {
        const int half = (L+R)>>1;
        build(L,half,_L);
        build(half+1,R,_R);
    }
}

// Propagates node k's minimum to both children. The parent's value is left in
// place; applying the same minimum again later has no further effect.
void down(int k)
{
    const int cap = tree[k].w;
    if(cap < tree[_L].w)    tree[_L].w = cap;
    if(cap < tree[_R].w)    tree[_R].w = cap;
}

// Lowers every position in [L, R] to at most w, starting from node k.
void update(int L, int R, int k, int w)
{
    const int lo = tree[k].L;
    const int hi = tree[k].R;
    // Node fully inside the range: keep the smaller value here and stop.
    if(L <= lo && hi <= R)
    {
        if(w < tree[k].w)   tree[k].w = w;
        return;
    }
    down(k);
    const int half = (lo+hi)>>1;
    if(half >= L)   update(L,R,_L,w);
    if(half <  R)   update(L,R,_R,w);
}

// Applies weight w to every tree edge on the path between x and y. An edge is
// stored at the DFS index of its deeper endpoint.
void LCA(int x, int y, int w)
{
    while(Top[x] != Top[y])
    {
        // Always climb from the endpoint whose chain head is deeper.
        if(dep[Top[x]] < dep[Top[y]])   swap(x,y);
        update(DFN[Top[x]],DFN[x],1,w);
        x = FA[Top[x]];
    }
    if(x == y)  return;
    if(dep[y] < dep[x]) swap(x,y);
    // Both nodes are on one chain now; start at DFN[x]+1 because slot DFN[x]
    // stands for the edge above x, which is not part of the path.
    update(DFN[x]+1,DFN[y],1,w);
}

int query(int k, int x)
{
    int L = tree[k].L;
    int R = tree[k].R;
    if(L == R)  return tree[k].w;
    down(k);
    int mid = (L+R)/2;
    if(x <= mid)    return query(_L,x); else
    if(x >  mid)    return query(_R,x);
}

// Reads the tree edges and the m weighted (u, v, w) paths, then prints for each
// tree edge, in input order, the smallest w applied to it or -1 if none was.
int main(void)
{
    cin >> n >> m;
    int a,b;
    for(int i=1;i<n;i++)
    {
        cin >> a >> b;
        add(a,b);
        add(b,a);
        que.push(make_pair(a,b));   // remember the edge for the output phase
    }

    dfs1(1,0);
    dfs2(1,0,1);
    build(1,n,1);

    int u,v,w;
    for(int i=1;i<=m;i++)
    {
        cin >> u >> v >> w;
        LCA(u,v,w);
    }

    for(int i=1;i<n;i++)
    {
        int x = que.front().first ;
        int y = que.front().second;
        que.pop();
        // The edge is stored at its deeper endpoint.
        const int lower = dep[x] < dep[y] ? y : x;
        const int ans = query(1,DFN[lower]);
        if(ans == 0x3f3f3f3f)   cout << "-1" << endl;
        else                    cout << ans  << endl;
    }
    return 0;
}

重构:

传送门

以前的码风太糟糕了,而且也没有完全掌握树链剖分 + 线段树,所以决定再找一道题做一次。

感觉现在的码风好看多了,嗝儿~

// Left / right child index of segment-tree node `k` (the macros read whatever
// variable named k is in scope at the point of use). Parenthesised so the
// expansion stays one operand inside larger expressions: without the
// parentheses `_L + 1` would expand to `k<<1 + 1`, i.e. `k << 2`.
#define _L (k<<1)
#define _R (k<<1|1)

// n, q: the two counts read in main(); cnt: DFS-order clock advanced by dfs2().
int n, q;
int cnt;

// Filled by dfs1(): heavy child, parent, subtree size and depth of each node.
int son[MAXN];
int FA[MAXN];
int size[MAXN];
int dep[MAXN];

// Filled by dfs2(): chain head, DFS-order index, and the node holding the
// largest DFS index inside each subtree.
int Top[MAXN];
int DFN[MAXN];
int End[MAXN];

// Buffer for the command word read in main().
char c[100];

// Segment-tree node over DFS-order positions.
struct Node{
    int L;      // left end of the covered range
    int R;      // right end of the covered range
    int w;      // number of positions in [L, R] still needing installation (flag 1)
    int tag;    // pending range assignment (0 or 1); -1 means nothing pending
};
Node tree[4*MAXN];

// Pushes a pending range assignment from node k to both children.
void down(int k)
{
    const int pending = tree[k].tag;
    if(pending == -1) return;   // nothing to push
    tree[_L].tag = pending;
    tree[_R].tag = pending;
    // Every position of a child takes the assigned flag, so its count is flag * length.
    tree[_L].w = pending*(tree[_L].R-tree[_L].L+1);
    tree[_R].w = pending*(tree[_R].R-tree[_R].L+1);
    tree[k].tag = -1;
}

// First pass: records parent, depth and subtree size of u and picks the child
// with the largest subtree as its heavy child.
void dfs1(int u, int fa)
{
    dep[u] = dep[fa]+1;
    size[u] = 1;
    FA[u] = fa;
    for(int e=head[u]; e!=0; e=map[e].next)
    {
        const int v = map[e].to;
        if(v != fa)
        {
            dfs1(v,u);
            size[u] += size[v];
            if(size[son[u]] < size[v]) son[u] = v;
        }
    }
}

// Second pass: numbers u in DFS order (heavy child first), records its chain
// head, and keeps End[u] = the node with the largest DFS index in u's subtree.
void dfs2(int u, int top)
{
    DFN[u] = ++cnt;
    Top[u] = top;
    End[u] = u;
    if(son[u] == 0) return;     // leaf: the subtree is u alone

    dfs2(son[u],top);
    if(DFN[son[u]] > DFN[u]) End[u] = End[son[u]];
    for(int e=head[u]; e!=0; e=map[e].next)
    {
        const int v = map[e].to;
        if(v != FA[u] && v != son[u])
        {
            dfs2(v,v);          // a light child starts a new chain
            if(DFN[End[u]] < DFN[End[v]]) End[u] = End[v];
        }
    }
}

// Builds segment-tree node k over [L, R]; every leaf starts with flag 1 and no
// node has a pending assignment.
void build(int L ,int R, int k)
{
    tree[k].L = L;
    tree[k].R = R;
    tree[k].tag = -1;
    if(L != R)
    {
        const int half = (L+R)>>1;
        build(L,half,_L);
        build(half+1,R,_R);
        tree[k].w = tree[_L].w+tree[_R].w;
        return;
    }
    tree[k].w = 1;
}

void update(int l, int r, int k, int F)
{
    int L = tree[k].L;
    int R = tree[k].R;
    if(L>=l && R<=r)
    {
        tree[k].w = F*(R-L+1);
        tree[k].tag = F;
        return;
    }
    down(k);
    int mid = (L+R)>>1;
    if(l <= mid) update(l,r,_L,F);
    if(r >  mid) update(l,r,_R,F);
    tree[k].w = tree[_L].w+tree[_R].w;
}

int query(int l, int r, int k) //查询区间内 1 的个数
{
    int ans = 0;
    int L = tree[k].L;
    int R = tree[k].R;
    if(L>=l && R<=r) return tree[k].w;
    down(k);
    int mid = (L+R)>>1;
    if(l <= mid) ans += query(l,r,_L);
    if(r >  mid) ans += query(l,r,_R);
    return ans;
}

// Clears the flags of every node on the tree path between x and y and returns
// how many of them were still 1 beforehand (i.e. how many states changed).
//   x, y - path endpoints (main() passes y = 1, the root)
//
// While the endpoints are on different chains, the one whose chain HEAD is
// deeper is lifted. Comparing the endpoints' own depths (as before) can lift
// the wrong side past the common ancestor when a deep node sits on a chain
// with a shallow head. With y = 1 both comparisons give the same result.
int install(int x, int y) //LCA
{
    int ans = 0;
    while(Top[x] != Top[y])
    {
        if(dep[Top[x]] < dep[Top[y]]) swap(x,y);
        ans += query(DFN[Top[x]],DFN[x],1);
        update(DFN[Top[x]],DFN[x],1,0);
        x = FA[Top[x]];
    }
    // Same chain: y becomes the endpoint with the smaller DFS index.
    if(DFN[x] < DFN[y]) swap(x,y);
    ans += query(DFN[y],DFN[x],1);
    update(DFN[y],DFN[x],1,0);
    return ans;
}

// Sets the flags of DFS positions DFN[x]..DFN[y] back to 1 and returns how many
// of them were 0 beforehand. main() passes y = End[x], so the range is the
// whole subtree of x.
int uninstall(int x, int y)
{
    const int lo = DFN[x];
    const int hi = DFN[y];
    const int ones = query(lo,hi,1);    // positions that already carry flag 1
    update(lo,hi,1,1);
    return hi-lo+1-ones;
}

// Reads the parent of each node 2..n, builds the decomposition and processes q
// commands. A command word starting with 'i' runs install() on the path from
// the node to the root; any other word runs uninstall() on the node's subtree.
// Ids read from the input are shifted by one before being used as node numbers.
int main(void)
{
    cin >> n;
    for(int i=2;i<=n;i++)
    {
        int u;
        cin >> u;
        add(u+1,i);
        add(i,u+1);
    }
    cnt = 0;    // reused as the DFS-order clock from here on
    dfs1(1,0);  dfs2(1,1);
    build(1,n,1);
    cin >> q;
    for(int i=1,u;i<=q;i++)
    {
        // Bound the read to the buffer size (99 chars + terminator) so that an
        // over-long token cannot overflow c[]; stop if no token could be read.
        if(scanf("%99s",c) != 1) break;
        cin >> u;
        u++;
        if(c[0] == 'i') cout << install(u,1) << endl;
        else            cout << uninstall(u,End[u]) << endl;
    }
    return 0;
}

复习:

传送门

感觉每次打的代码码风都不一样呢。。。

看了蓝皮书之后又有了些新感受,决定再复习一遍

// Left / right child index of segment-tree node `k` (the macros read whatever
// variable named k is in scope at the point of use). Parenthesised so the
// expansion stays one operand inside larger expressions: without the
// parentheses `_L + 1` would expand to `k<<1 + 1`, i.e. `k << 2`.
#define _L (k<<1)
#define _R (k<<1|1)
// Field accessors for segment-tree node a / b.
#define L(a) tree[a].L
#define R(b) tree[b].R
#define W(a) tree[a].w
#define tag(a) tree[a].tag

// n, m: the two counts read in main(); cnt: edge counter for add(), later reset
// and reused by dfs2() as the DFS-order clock; ans: accumulator filled by
// query_section().
int n, m;
int cnt;
int ans;

// val[i]: initial value of node i; head[u]: newest edge leaving u (0 = none).
int val[MAXN];
int head[MAXN];

// Filled by dfs1(): depth, subtree size, parent and heavy child of each node.
int dep[MAXN];
int size[MAXN];
int FA[MAXN];
int son[MAXN];

// Filled by dfs2(): DFS index, largest DFS index in the subtree, node at each
// DFS index, and chain head.
int DFN[MAXN];
int END[MAXN];
int RNK[MAXN];
int Top[MAXN];

// One directed edge of the linked adjacency list indexed by head[].
struct node{
    int to;     // node this edge leads to
    int next;   // index of the previously added edge with the same source (0 ends the list)
};
node map[2*MAXN];

// Segment-tree node over DFS-order positions.
struct Node{
    int L;      // left end of the covered range
    int R;      // right end of the covered range
    int w;      // sum of the values in [L, R]
    int tag;    // pending per-element increment not yet pushed to the children
};
Node tree[4*MAXN];

void add(int u, int v)
{
    map[++cnt] = (node){v,head[u]};
    head[u] = cnt;
}
// First pass: records parent, depth and subtree size of u and picks the child
// with the largest subtree as its heavy child.
void dfs1(int u, int fa)
{
    FA[u] = fa;
    dep[u] = dep[fa]+1;
    size[u] = 1;
    for(int e=head[u]; e!=0; e=map[e].next)
    {
        const int v = map[e].to;
        if(v != fa)
        {
            dfs1(v,u);
            size[u] += size[v];
            if(size[v] > size[son[u]]) son[u] = v;
        }
    }
}
// Second pass: numbers u in DFS order (heavy child first), records its chain
// head and the inverse mapping RNK, and keeps END[u] = the largest DFS index
// inside u's subtree.
void dfs2(int u, int top)
{
    DFN[u] = ++cnt;
    RNK[cnt] = u;
    END[u] = cnt;
    Top[u] = top;
    if(son[u] == 0) return;     // leaf: the subtree is u alone

    dfs2(son[u],top);
    END[u] = END[son[u]];
    for(int e=head[u]; e!=0; e=map[e].next)
    {
        const int v = map[e].to;
        if(v != FA[u] && v != son[u])
        {
            dfs2(v,v);          // a light child starts a new chain
            if(END[u] < END[v]) END[u] = END[v];
        }
    }
}
// Pushes the pending increment of node k onto both children and clears it.
void down(int k)
{
    const int pending = tag(k);
    tag(k) = 0;
    // Every element of a child's range receives the increment.
    W(_L) += pending*(R(_L)-L(_L)+1);
    W(_R) += pending*(R(_R)-L(_R)+1);
    tag(_L) += pending;
    tag(_R) += pending;
}
// Builds segment-tree node k over [L, R]; the leaf for position L holds the
// value of the node that dfs2() placed at DFS index L.
void build(int L, int R, int k)
{
    L(k) = L;
    R(k) = R;
    if(L != R)
    {
        const int half = (L+R)>>1;
        build(L,half,_L);
        build(half+1,R,_R);
        W(k) = W(_L)+W(_R);
        return;
    }
    W(k) = val[RNK[L]];
}
// Adds w to the single position x, descending from node k.
void update_point(int k, int x, int w)
{
    if(L(k) != R(k))
    {
        down(k);
        const int half = (L(k)+R(k))>>1;
        if(half >= x) update_point(_L,x,w);
        else          update_point(_R,x,w);
        W(k) = W(_L)+W(_R);
        return;
    }
    W(k) += w;      // leaf reached
}
// Adds w to every position in [l, r], starting from node k.
void update_section(int k, int l, int r, int w)
{
    // Node fully inside the range: record the increment lazily.
    if(l <= L(k) && R(k) <= r)
    {
        tag(k) += w;
        W(k) += w*(R(k)-L(k)+1);
        return;
    }
    down(k);
    const int half = (L(k)+R(k))>>1;
    if(half >= l) update_section(_L,l,r,w);
    if(half <  r) update_section(_R,l,r,w);
    W(k) = W(_L)+W(_R);
}
void query_section(int k, int l ,int r)
{
    if(L(k)>=l && R(k)<=r) {ans += tree[k].w; return;}
    down(k);
    int mid = (L(k)+R(k))>>1;
    if(l <= mid) query_section(_L,l,r);
    if(r >  mid) query_section(_R,l,r);
}
// Adds the values of every node on the tree path between x and y to the global
// `ans` (the caller resets ans beforehand).
//   x, y - path endpoints (main() passes x = 1, the root)
//
// While the endpoints are on different chains, the one whose chain HEAD is
// deeper is lifted. Comparing the endpoints' own depths (as before) can lift
// the wrong side past the common ancestor when a deep node sits on a chain
// with a shallow head. With x = 1 both comparisons give the same result.
void LCA(int x, int y)
{
    int nx = Top[x];
    int ny = Top[y];
    while(nx != ny)
    {
        if(dep[nx] < dep[ny])
        {
            swap( x, y);
            swap(nx,ny);
        }
        query_section(1,DFN[nx],DFN[x]);
        x = FA[nx];
        nx = Top[x];
    }
    // Same chain: y becomes the shallower endpoint, which has the smaller DFS index.
    if(dep[x] < dep[y]) swap(x,y);
    query_section(1,DFN[y],DFN[x]);
}
// Reads node values, tree edges and m operations:
//   1 a b : add b to node a          2 a b : add b to the subtree of a
//   3 a   : print the sum of the values on the path from node 1 to a
signed main(void)
{
    cin >> n >> m;
    for(int i=1;i<=n;i++) cin >> val[i];
    for(int i=1;i< n;i++)
    {
        int u,v;
        cin >> u >> v;
        add(u,v);
        add(v,u);
    }

    cnt = 0;    // reused as the DFS-order clock from here on
    dfs1(1,0);
    dfs2(1,1);
    build(1,n,1);

    for(int i=1;i<=m;i++)
    {
        int x,a,b;
        cin >> x >> a;
        if(x == 1)
        {
            cin >> b;
            update_point(1,DFN[a],b);
        }
        else if(x == 2)
        {
            cin >> b;
            // A subtree occupies the contiguous DFS indices DFN[a]..END[a].
            update_section(1,DFN[a],END[a],b);
        }
        else if(x == 3)
        {
            ans = 0;
            LCA(1,a);
            cout << ans << "\n";
        }
    }
    return 0;
}