P16829

· · 题解

官方题解使用了震撼的根号分治,但是只需要仔细观察就可以发现,这个题可以质因数分解,所以根号分治毫无意义。

考虑对每个 a_i 质因数分解。对于 a_i 的每一个形如 p^k 的因子,我们都容易维护其被操作一删除的时间。例如我们对 120 = 2^3 \times 3^1 \times 5^1 进行一次操作之后,假如其变为了 10 = 2 \times 5,那么我们认为 2^2, 2^33^1 被删除了。一共显然只会有 O(n \log a) 次删除。

对于剩下的部分,由于本题询问的特殊性,实际上是不需要像官方题解一样上 cdq 大力数颜色的。

对于每一次删除,假如我们在 u 号点删除了 p^k,考虑询问哪些位置时,答案会变小。显然就是子树内除了 u 号点以外不包含 p^k 的所有点的答案会除掉一个 p。不难发现,所有这样的点构成一条链。相当于要找 u 的某个深度最大的祖先满足其子树内包含其他的 p^k。我们可以找到离 u 号点前后分别最近(在 DFS 序意义下)的包含 p^k 的点 v_l,v_r,显然 \operatorname{LCA}(u,v) 及其祖先的答案均不会变,因此这条链的端点也是好找的。

至于如何找包含 p^k 的点,直接全扔进 set 里面即可。

链除,单点查询,树状数组即可。时间复杂度 O(n \log n \log a)

代码不是特别难写,但是比较长。

#include<bits/stdc++.h>
using namespace std;
#define mod 998244353
#define N 100005
int n,Q;
int qp(int x,int y){
    int ans=1;
    while(y){
        if(y&1)ans=1LL*ans*x%mod;
        x=1LL*x*x%mod;
        y>>=1;
    }
    return ans;
}
int a[N],m=100000;
vector<int>T[N];
int ft[N],ift[N],inv[N],pa[N];
int vis[N],p[N],cnt=0,lpf[N];
vector<pair<int,int> >vd[N];
struct Query{
    int op,x,y;
}q[N];
int dep[N],st[20][N],dfn[N],tot=0;
int dfr[N],nfd[N];
void dfs0(int x,int fa){
    dfn[x]=dfr[x]=++tot;
    nfd[tot]=x;
    pa[x]=fa;
    st[0][dfn[x]]=dfn[fa];
    for(int v:T[x])if(v!=fa){
        dfs0(v,x);
        dfr[x]=dfr[v];
    }
}
int lca(int x,int y){
    if(x==y)return x;
    if(x>y)swap(x,y);
    int d=__lg(y-x);
    return min(st[d][x+1],st[d][y-(1<<d)+1]);
}
int tag[N*4];
void update(int l,int r,int x,int L,int R,int v,int t){
    if(L<=l && r<=R){
        int res=__gcd(tag[x],v);
        if(res==tag[x])return;
        tag[x]=res;
    }
    if(l==r){
        int id=nfd[l];
        int tmp=__gcd(a[id],v);
        if(a[id]!=tmp)vd[t].push_back({id,a[id]});
        a[id]=tmp;
        return;
    }
    int mid=(l+r)/2;
    if(L<=mid)update(l,mid,x*2,L,R,v,t);
    if(R>mid)update(mid+1,r,x*2+1,L,R,v,t);
}
set<int>s[N];
struct fwt{
    int c[N];
    void init(){
        for(int i=0;i<=n;i++)c[i]=1;
    }
    void upd(int x,int v){
        for(;x<=n;x+=x&-x)c[x]=1LL*c[x]*v%mod;
    }
    int qry(int x){
        int res=1;
        if(!x)return res;
        for(;x;x-=x&-x)res=1LL*res*c[x]%mod;
        return res;
    }
    int qry(int l,int r){
        return 1LL*qry(r)*qp(qry(l-1),mod-2)%mod;
    }
}ds;
void upd(int x,int v){
    x=dfn[x];
    auto it=s[v].lower_bound(x);
    int u=0;
    if(it!=s[v].end())u=max(u,lca(x,*it));
    if(it!=s[v].begin()){
        --it;
        u=max(u,lca(x,*it));
    }
    s[v].insert(x);
    ds.upd(x,lpf[v]);
    if(u)ds.upd(u,inv[lpf[v]]);
}
int ANS[N];
int main(){
    vis[1]=1;
    for(int i=2;i<=m;i++){
        if(!vis[i])p[++cnt]=i,lpf[i]=i;
        for(int j=1;j<=cnt&&p[j]*i<=m;j++){
            vis[i*p[j]]=1;
            lpf[i*p[j]]=p[j];
            if(i%p[j]==0)break;
        }
    }
    ft[0]=1;
    for(int i=1;i<=m;i++)ft[i]=1LL*ft[i-1]*i%mod;
    ift[m]=qp(ft[m],mod-2);
    for(int i=m;i>=1;i--)ift[i-1]=1LL*ift[i]*i%mod;
    for(int i=1;i<=m;i++)inv[i]=1LL*ift[i]*ft[i-1]%mod;
    scanf("%d%d",&n,&Q);
    ds.init();
    for(int i=1;i<=n;i++)scanf("%d",&a[i]);
    for(int i=1;i<n;i++){
        int u,v;
        scanf("%d%d",&u,&v);
        T[u].push_back(v);
        T[v].push_back(u);
    }
    dfs0(1,0);
    for(int i=1;i<=Q;i++){
        scanf("%d",&q[i].op);
        if(q[i].op==1)scanf("%d%d",&q[i].x,&q[i].y);
        else scanf("%d",&q[i].x);
    }
    for(int i=1;i<=Q;i++){
        if(q[i].op==1)update(1,n,1,dfn[q[i].x],dfr[q[i].x],q[i].y,i);
    }
    for(int j=1;j<20;j++)for(int i=1;i+(1<<j)-1<=n;i++){
        st[j][i]=min(st[j-1][i],st[j-1][i+(1<<j-1)]);
    }
    q[Q+1].y=1;
    update(1,n,1,1,n,1,Q+1);
    for(int i=Q+1;i>=0;i--){
        for(auto [x,v]:vd[i]){
            int t1=v,t2=a[x],p1=1,p2=1;
            while(t1>1){
                int f=lpf[t1];
                t1/=f,p1*=f;
                if(t2%f==0)t2/=f;
                else upd(x,p1);
                if(f!=lpf[t1]){
                    p1=1;
                }
            }
            a[x]=v;
        }
        if(q[i].op==2){
            ANS[i]=ds.qry(dfn[q[i].x],dfr[q[i].x]);
        }
    }
    for(int i=1;i<=Q;i++)if(q[i].op==2)printf("%d\n",ANS[i]);
    return 0;
}