[DS记录]P5608 [Ynoi2013] 文化课
command_block
2021-07-03 08:46:01
**题意** : 维护一个长度为 $n$ 的,只有 整数,加号和乘号 的算式。
支持下列操作 :
- 将区间内的数值全部修改为 $x$。
- 将区间内的运算符全部修改为 $opt$。
- 查询区间 $[l,r]$ 取出后,按照运算符的优先级计算出的结果。答案对 $10^9+7$ 取模。
$n,m\leq 10^5$ ,时限$\texttt{1.5s}$,空限$\texttt{64M}$。
------------
- 先不考虑修改。
使用线段树维护。
对于每个线段树节点,维护区间算式的结果。
合并两个区间时,若中间的运算符是加法,则简单相加。
若中间的运算符是乘法,找出连乘前缀后缀,相乘即可。这需要维护最长连乘前后缀的积。
- 接着考虑修改运算符。
若区间改为加法,直接将算式结果替换为区间和。
若区间改为乘法,直接将算式结果替换为区间积。
注意要更新最长连乘前后缀。
- 最后考虑区间修改数值。
记 $s_i$ 为区间内长为 $i$ 的极长连乘段的个数,则将区间数值修改为 $x$ 后算式结果为 :
$$\sum\limits_{i}s_ix^i$$
注意到有值的 $s_i$ 只有 $O(\sqrt{len})$ 个(自然根号),复杂度为 $O(\sqrt{n}+\sqrt{n/2}+\sqrt{n/4}+...)=O(\sqrt{n})$。
维护 $s$ ,`pushup` 时用归并合并,时间复杂度同上。
空间复杂度为 $T(n)=O(\sqrt{n})+2T(n/2)=O(n)$。
还有个小问题,我们计算上式时需要计算 $x$ 的幂,直接使用快速幂会使得复杂度带个 $\log$。
注意到我们维护了有序的 $s_i$ ,对于相邻两个有值的 $s_i,s_j$ ,快速幂计算 $x^{j-i}$ 以从 $x^i$ 得到 $x^j$。可以证明这样优化后复杂度变为 $O(\sqrt{n})$。
实际维护的东西有亿点点多……
交两次就过了,爽!
```cpp
#include<algorithm>
#include<cstdio>
#include<vector>
#define uint unsigned int
#define ll long long
#define pb emplace_back
#define Pr pair<int,int>
#define fir first
#define sec second
#define mp make_pair
#define Itor vector<Pr>::iterator
#define MaxN 100500
using namespace std;
const int mod=1000000007;
ll powM(ll a,int t){
ll ret=1;
while(t){
if (t&1)ret=ret*a%mod;
a=a*a%mod;t>>=1;
}return ret;
}
int sx[MaxN],sop[MaxN];
struct Node{
//算式结果,左侧值,右侧值,和,积,区间长度
//左连乘长度,右连乘长度,左连乘值,右连乘值,符号标记,值标记
//区间左侧符号
int s,xl,xr,s0,s1,len
,cpl,cpr,pl,pr,ft,fv;
bool tl;
vector<Pr> o;
inline void laddt(int t){
tl=ft=t;o.clear();
if (t==0){cpl=cpr=1;pl=xl;pr=xr;s=s0;o.pb(mp(1,len));}
else {cpl=cpr=len;pl=pr=s=s1;o.pb(mp(len,1));}
}
void laddv(int v){
xl=xr=fv=v;
s0=1ll*len*v%mod;s1=powM(v,len);
pl=powM(v,cpl);pr=powM(v,cpr);
s=0;
for (int i=0,buf=1,las=0;i<o.size();i++){
buf=1ll*buf*powM(v,o[i].fir-las)%mod;las=o[i].fir;
s=(s+1ll*o[i].sec*buf)%mod;
}
}
}a[MaxN<<2];
inline void up(int u)
{
int l=u<<1,r=u<<1|1;
a[u].tl=a[l].tl;
a[u].s0=(a[l].s0+a[r].s0)%mod;
a[u].s1=1ll*a[l].s1*a[r].s1%mod;
a[u].xl=a[l].xl;a[u].xr=a[r].xr;
if (a[r].tl==0){
a[u].s=(a[l].s+a[r].s)%mod;
a[u].cpl=a[l].cpl;a[u].cpr=a[r].cpr;
a[u].pl=a[l].pl;a[u].pr=a[r].pr;
}else {
a[u].s=((ll)a[l].s+a[r].s-a[l].pr-a[r].pl+1ll*a[l].pr*a[r].pl)%mod;
a[u].cpl=(a[l].cpl==a[l].len) ? a[l].len+a[r].cpl : a[l].cpl;
a[u].pl=(a[l].cpl==a[l].len) ? 1ll*a[l].s1*a[r].pl%mod : a[l].pl;
a[u].cpr=(a[r].cpr==a[r].len) ? a[r].len+a[l].cpr : a[r].cpr;
a[u].pr=(a[r].cpr==a[r].len) ? 1ll*a[r].s1*a[l].pr%mod : a[r].pr;
}
vector<Pr> &lo=a[l].o,&ro=a[r].o,&o=a[u].o;
o.clear();
int p=0;
for (int i=0;i<lo.size();i++){
while(p<ro.size()&&lo[i].fir>ro[p].fir)a[u].o.pb(ro[p++]);
if (p<ro.size()&&lo[i].fir==ro[p].fir){
o.pb(mp(lo[i].fir,lo[i].sec+ro[p].sec));
p++;
}else o.pb(lo[i]);
}while(p<ro.size())o.pb(ro[p++]);
if (a[r].tl){
Itor it=lower_bound(o.begin(),o.end(),mp(a[l].cpr,0));
it->sec--;if (!it->sec)o.erase(it);
it=lower_bound(o.begin(),o.end(),mp(a[r].cpl,0));
it->sec--;if (!it->sec)o.erase(it);
int len=a[l].cpr+a[r].cpl;
it=lower_bound(o.begin(),o.end(),mp(len,0));
if (it==o.end()||it->fir!=len)o.insert(it,mp(len,1));
else it->sec++;
}
}
void build(int l,int r,int u)
{
a[u].ft=a[u].fv=-1;
a[u].len=r-l+1;
if (l==r){
a[u].tl=sop[l-1];
a[u].cpl=a[u].cpr=1;
a[u].s=a[u].s0=a[u].s1=a[u].xl=a[u].xr=
a[u].pl=a[u].pr=sx[l];
a[u].o.pb(mp(1,1));
return ;
}int mid=(l+r)>>1;
build(l,mid,u<<1);
build(mid+1,r,u<<1|1);
up(u);
}
inline void ladd(int u)
{
if (a[u].ft!=-1){
a[u<<1].laddt(a[u].ft);
a[u<<1|1].laddt(a[u].ft);
a[u].ft=-1;
}
if (a[u].fv!=-1){
a[u<<1].laddv(a[u].fv);
a[u<<1|1].laddv(a[u].fv);
a[u].fv=-1;
}
}
int wfl,wfr,wfc;
void chgt(int l,int r,int u)
{
if (wfl<=l&&r<=wfr){a[u].laddt(wfc);return ;}
int mid=(l+r)>>1;ladd(u);
if (wfl<=mid)chgt(l,mid,u<<1);
if (mid<wfr)chgt(mid+1,r,u<<1|1);
up(u);
}
void chgv(int l,int r,int u)
{
if (wfl<=l&&r<=wfr){a[u].laddv(wfc);return ;}
int mid=(l+r)>>1;ladd(u);
if (wfl<=mid)chgv(l,mid,u<<1);
if (mid<wfr)chgv(mid+1,r,u<<1|1);
up(u);
}
int ret,pr;
void qry(int l,int r,int u)
{
if (wfl<=l&&r<=wfr){
if (wfl==l){ret=a[u].s;pr=a[u].pr;}
else {
if (a[u].tl){
ret=((ll)ret+a[u].s-pr-a[u].pl+1ll*pr*a[u].pl)%mod;
pr=(a[u].len==a[u].cpl) ? 1ll*a[u].pr*pr%mod : a[u].pr;
}else {ret=(ret+a[u].s)%mod;pr=a[u].pr;}
}return ;
}int mid=(l+r)>>1;ladd(u);
if (wfl<=mid)qry(l,mid,u<<1);
if (mid<wfr)qry(mid+1,r,u<<1|1);
}
int n,m;
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
{uint sav;scanf("%u",&sav);sx[i]=sav%mod;}
for (int i=1;i<n;i++)scanf("%d",&sop[i]);
build(1,n,1);
for (int i=1,op;i<=m;i++){
scanf("%d%d%d",&op,&wfl,&wfr);
if (op==1){uint sav;scanf("%d",&sav);wfc=sav%mod;chgv(1,n,1);}
if (op==2){scanf("%d",&wfc);wfl++;wfr++;chgt(1,n,1);}
if (op==3){qry(1,n,1);printf("%d\n",ret);}
}return 0;
}
```