11.12 T2

· · 个人记录

没时间了,随便口胡一下吧。没听评讲。

我们对每一个点计算贡献,答案就是所有点贡献的和。比如 x 的贡献就是 分母为 i = root \to x_{fa}|sum(i) 乘起来,分子为 i = nxt(root) \to x | w(i) 乘起来,再对着这个分数乘 a_i。

我们先改 w_u,再改 a_u。

对于 w(u) 的修改,会直接修改 sum_{fa_u},还会修改 w(u)。我们看这会对哪些点的贡献产生影响,发现会影响 fa 子树,不包括 fa,的分母。会影响 u 子树,包括 u,的分子。如果我们把这棵树变成 dfs 序,就是区间乘。

然后对于 a_u 的修改,相当于对于点 u 的值先区间除再区间乘,或者是对于差值区间加,不过这应该不太重要吧。

于是得到了一个需要维护区间乘法,区间加法,区间求和,或者只是区间乘法,区间求和的线段树。

参考线段树 2 吧。

#include<bits/stdc++.h>
using namespace std;
#define int long long
constexpr int N = 1e5+10, mod = 998244353;

int mi(int x,int k) {
    return k?1ll*mi(1ll*x*x%mod,k/2)*(k%2?x:1)%mod:1;
}

int inv(int x) {
    return mi(x,mod-2);
}
int n;
vector<int>e[N]; int fa[N];
int w[N],a[N],sum[N],val[N];

int L[N],R[N],d[N],dn;
void dfs_val(int u) {
    L[u] = ++dn;
    d[dn] = u;
    for (int v : e[u]) {
        // cout << u << ' '<< v << ' ' << val[u] << ' ' << w[v] << ' ' << sum[u] << endl;
        val[v] = 1ll*val[u] * w[v] % mod * inv(sum[u]) % mod;
        dfs_val(v);
    }
    R[u] = dn;
}

#define mid ((l+r)>>1)
#define ls (u*2)
#define rs (u*2+1)
struct Node{
    int val,tag;
}c[N<<2];

void pushup(int u) {
    c[u].val = (c[ls].val + c[rs].val) % mod;
}
void build(int u,int l,int r) {
    c[u].tag=1;
    if (l == r) {
        c[u].val = val[d[l]] * a[d[l]] % mod;
        return ;
    }
    build(ls,l,mid),build(rs,mid+1,r);
    pushup(u);
}
void addtag(int u,int x) {
    (c[u].val *= x) %= mod;
    (c[u].tag *= x) %= mod;
}

void pushdw(int u) {
    if (c[u].tag != 1) {
        addtag(ls,c[u].tag), addtag(rs,c[u].tag);
        c[u].tag = 1;
    }
}
void upd(int u,int l,int r,int L,int R,int x) {
    if (L <= l && r<= R) {
        addtag(u,x);
        return ;
    }
    pushdw(u);
    if (L <= mid) upd(ls,l,mid,L,R,x);
    if (R > mid) upd(rs,mid+1,r,L,R,x);
    pushup(u);
}
signed main() {
    //freopen("b.in","r",stdin);
    //freopen("b.out","w",stdout);
    cin.tie(nullptr) -> sync_with_stdio(false);
    cin>>n;
    // cout << "DOWN";
    for (int i=2;i<=n;i++) {
        int x;
        cin>>x;
        fa[i] = x;
        e[x].push_back(i);
    }

    for (int i=1;i<=n;i++) {cin>>w[i];}
    for (int i=1;i<=n;i++) {cin>>a[i];}
    for (int i=1;i<=n;i++) {
        for (int v : e[i]) {
            sum[i] += w[v];
            sum[i] %= mod;
        }
    }
    val[1] = 1;dfs_val(1);
    build(1,1,n);

    cout << c[1].val % mod<< endl;
    int Q; cin>>Q;
    while(Q--) {
        int u,neww,newa;
        cin>>u>>neww>>newa;
        int fu = fa[u],newsum = sum[fu] - w[u] + neww;
        int invw = inv(w[u]), invsum = inv(newsum), inva = inv(a[u]);
        // cout << L[fu] << ' ' << R[fu] << endl;
        if (u != 1) {
            upd(1,1,n,L[fu]+1,R[fu],sum[fu]);
            upd(1,1,n,L[fu]+1,R[fu],invsum);
            upd(1,1,n,L[u],R[u],invw);
            upd(1,1,n,L[u],R[u],neww);
        }
        upd(1,1,n,L[u],L[u],inva);
        upd(1,1,n,L[u],L[u],newa);

        cout << c[1].val % mod << '\n';

        sum[fu] = newsum;
        w[u] = neww;
        a[u] = newa;
    }
    return 0;
}