[DS记录]P5608 [Ynoi2013] 文化课

command_block

2021-07-03 08:46:01

Personal

**题意** : 维护一个长度为 $n$ 的,只有 整数,加号和乘号 的算式。 支持下列操作 : - 将区间内的数值全部修改为 $x$。 - 将区间内的运算符全部修改为 $opt$。 - 查询区间 $[l,r]$ 取出后,按照运算符的优先级计算出的结果。答案对 $10^9+7$ 取模。 $n,m\leq 10^5$ ,时限$\texttt{1.5s}$,空限$\texttt{64M}$。 ------------ - 先不考虑修改。 使用线段树维护。 对于每个线段树节点,维护区间算式的结果。 合并两个区间时,若中间的运算符是加法,则简单相加。 若中间的运算符是乘法,找出连乘前缀后缀,相乘即可。这需要维护最长连乘前后缀的积。 - 接着考虑修改运算符。 若区间改为加法,直接将算式结果替换为区间和。 若区间改为乘法,直接将算式结果替换为区间积。 注意要更新最长连乘前后缀。 - 最后考虑区间修改数值。 记 $s_i$ 为区间内长为 $i$ 的极长连乘段的个数,则将区间数值修改为 $x$ 后算式结果为 : $$\sum\limits_{i}s_ix^i$$ 注意到有值的 $s_i$ 只有 $O(\sqrt{len})$ 个(自然根号),复杂度为 $O(\sqrt{n}+\sqrt{n/2}+\sqrt{n/4}+...)=O(\sqrt{n})$。 维护 $s$ ,`pushup` 时用归并合并,时间复杂度同上。 空间复杂度为 $T(n)=O(\sqrt{n})+2T(n/2)=O(n)$。 还有个小问题,我们计算上式时需要计算 $x$ 的幂,直接使用快速幂会使得复杂度带个 $\log$。 注意到我们维护了有序的 $s_i$ ,对于相邻两个有值的 $s_i,s_j$ ,快速幂计算 $x^{j-i}$ 以从 $x^i$ 得到 $x^j$。可以证明这样优化后复杂度变为 $O(\sqrt{n})$。 实际维护的东西有亿点点多…… 交两次就过了,爽! ```cpp #include<algorithm> #include<cstdio> #include<vector> #define uint unsigned int #define ll long long #define pb emplace_back #define Pr pair<int,int> #define fir first #define sec second #define mp make_pair #define Itor vector<Pr>::iterator #define MaxN 100500 using namespace std; const int mod=1000000007; ll powM(ll a,int t){ ll ret=1; while(t){ if (t&1)ret=ret*a%mod; a=a*a%mod;t>>=1; }return ret; } int sx[MaxN],sop[MaxN]; struct Node{ //算式结果,左侧值,右侧值,和,积,区间长度 //左连乘长度,右连乘长度,左连乘值,右连乘值,符号标记,值标记 //区间左侧符号 int s,xl,xr,s0,s1,len ,cpl,cpr,pl,pr,ft,fv; bool tl; vector<Pr> o; inline void laddt(int t){ tl=ft=t;o.clear(); if (t==0){cpl=cpr=1;pl=xl;pr=xr;s=s0;o.pb(mp(1,len));} else {cpl=cpr=len;pl=pr=s=s1;o.pb(mp(len,1));} } void laddv(int v){ xl=xr=fv=v; s0=1ll*len*v%mod;s1=powM(v,len); pl=powM(v,cpl);pr=powM(v,cpr); s=0; for (int i=0,buf=1,las=0;i<o.size();i++){ buf=1ll*buf*powM(v,o[i].fir-las)%mod;las=o[i].fir; s=(s+1ll*o[i].sec*buf)%mod; } } }a[MaxN<<2]; inline void up(int u) { int l=u<<1,r=u<<1|1; a[u].tl=a[l].tl; a[u].s0=(a[l].s0+a[r].s0)%mod; a[u].s1=1ll*a[l].s1*a[r].s1%mod; a[u].xl=a[l].xl;a[u].xr=a[r].xr; if (a[r].tl==0){ a[u].s=(a[l].s+a[r].s)%mod; a[u].cpl=a[l].cpl;a[u].cpr=a[r].cpr; a[u].pl=a[l].pl;a[u].pr=a[r].pr; }else { a[u].s=((ll)a[l].s+a[r].s-a[l].pr-a[r].pl+1ll*a[l].pr*a[r].pl)%mod; a[u].cpl=(a[l].cpl==a[l].len) ? a[l].len+a[r].cpl : a[l].cpl; a[u].pl=(a[l].cpl==a[l].len) ? 1ll*a[l].s1*a[r].pl%mod : a[l].pl; a[u].cpr=(a[r].cpr==a[r].len) ? a[r].len+a[l].cpr : a[r].cpr; a[u].pr=(a[r].cpr==a[r].len) ? 1ll*a[r].s1*a[l].pr%mod : a[r].pr; } vector<Pr> &lo=a[l].o,&ro=a[r].o,&o=a[u].o; o.clear(); int p=0; for (int i=0;i<lo.size();i++){ while(p<ro.size()&&lo[i].fir>ro[p].fir)a[u].o.pb(ro[p++]); if (p<ro.size()&&lo[i].fir==ro[p].fir){ o.pb(mp(lo[i].fir,lo[i].sec+ro[p].sec)); p++; }else o.pb(lo[i]); }while(p<ro.size())o.pb(ro[p++]); if (a[r].tl){ Itor it=lower_bound(o.begin(),o.end(),mp(a[l].cpr,0)); it->sec--;if (!it->sec)o.erase(it); it=lower_bound(o.begin(),o.end(),mp(a[r].cpl,0)); it->sec--;if (!it->sec)o.erase(it); int len=a[l].cpr+a[r].cpl; it=lower_bound(o.begin(),o.end(),mp(len,0)); if (it==o.end()||it->fir!=len)o.insert(it,mp(len,1)); else it->sec++; } } void build(int l,int r,int u) { a[u].ft=a[u].fv=-1; a[u].len=r-l+1; if (l==r){ a[u].tl=sop[l-1]; a[u].cpl=a[u].cpr=1; a[u].s=a[u].s0=a[u].s1=a[u].xl=a[u].xr= a[u].pl=a[u].pr=sx[l]; a[u].o.pb(mp(1,1)); return ; }int mid=(l+r)>>1; build(l,mid,u<<1); build(mid+1,r,u<<1|1); up(u); } inline void ladd(int u) { if (a[u].ft!=-1){ a[u<<1].laddt(a[u].ft); a[u<<1|1].laddt(a[u].ft); a[u].ft=-1; } if (a[u].fv!=-1){ a[u<<1].laddv(a[u].fv); a[u<<1|1].laddv(a[u].fv); a[u].fv=-1; } } int wfl,wfr,wfc; void chgt(int l,int r,int u) { if (wfl<=l&&r<=wfr){a[u].laddt(wfc);return ;} int mid=(l+r)>>1;ladd(u); if (wfl<=mid)chgt(l,mid,u<<1); if (mid<wfr)chgt(mid+1,r,u<<1|1); up(u); } void chgv(int l,int r,int u) { if (wfl<=l&&r<=wfr){a[u].laddv(wfc);return ;} int mid=(l+r)>>1;ladd(u); if (wfl<=mid)chgv(l,mid,u<<1); if (mid<wfr)chgv(mid+1,r,u<<1|1); up(u); } int ret,pr; void qry(int l,int r,int u) { if (wfl<=l&&r<=wfr){ if (wfl==l){ret=a[u].s;pr=a[u].pr;} else { if (a[u].tl){ ret=((ll)ret+a[u].s-pr-a[u].pl+1ll*pr*a[u].pl)%mod; pr=(a[u].len==a[u].cpl) ? 1ll*a[u].pr*pr%mod : a[u].pr; }else {ret=(ret+a[u].s)%mod;pr=a[u].pr;} }return ; }int mid=(l+r)>>1;ladd(u); if (wfl<=mid)qry(l,mid,u<<1); if (mid<wfr)qry(mid+1,r,u<<1|1); } int n,m; int main() { scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) {uint sav;scanf("%u",&sav);sx[i]=sav%mod;} for (int i=1;i<n;i++)scanf("%d",&sop[i]); build(1,n,1); for (int i=1,op;i<=m;i++){ scanf("%d%d%d",&op,&wfl,&wfr); if (op==1){uint sav;scanf("%d",&sav);wfc=sav%mod;chgv(1,n,1);} if (op==2){scanf("%d",&wfc);wfl++;wfr++;chgt(1,n,1);} if (op==3){qry(1,n,1);printf("%d\n",ret);} }return 0; } ```