P11186 三目运算 题解

· · 题解

题目传送门

4pts

题目中第一个测试点没有三目运算符,因此直接输出 S 即可。

#include<bits/stdc++.h>
using namespace std;
int main(){
    int m,q;
    cin>>m>>q;
    int p;
    cin>>p;
    while(q--){
        int x;
        cin>>x;
        cout<<p<<"\n";
    }
    return 0;
}

48pts

考虑暴力递归进行模拟。

我们每次递归一个区间 [l,r] ,表示当前答案存在于区间 [l,r] 中。暴力用 x 与当前三目运算符进行运算,如果三目运算成立,往左区间递归,不成立就往右区间递归。

我们以 ? 为左括号, : 为右括号,进行括号匹配暴力找出区间端点即可。

#include<bits/stdc++.h>
#define int long long
using namespace std;
char s[5000005];
int n,q;
bool isdight(char x){
    return '0'<=x&&x<='9';
}//判断数字
int to_num(int u){
    int ans=0;
    for(int i=u;i<=n&&isdight(s[i]);i++)ans=ans*10+s[i]-'0';
    return ans;
}//char转化为数字
bool examine(int u,int x){
    char op=s[u+1];
    int t=to_num(u+2);
    if(op=='<')return x<t;
    else return x>t;
}//三目运算
int turn_l(int u){
    int i;
    for(i=u;i<=n&&s[i]!='?';i++);
    return i;
}//找出对应的 ? 
int turn_r(int u){
    int stk=0;
    for(int i=u;i<=n;i++){
        if(s[i]=='?')stk++;
        if(s[i]==':')stk--;
        if(s[i]==':'&&stk==0)return i;
    }
    return n;
}//找出对应的 : 
int solve(int l,int r,int x){
    if(l>n||r>n||l>r)return 0;
    if(isdight(s[l]))return to_num(l);
    if(examine(l,x))return solve(turn_l(l)+1,turn_r(l)-1,x);
    else return solve(turn_r(l)+1,r,x);
}//查找函数
signed main(){
    int m;
    cin>>m>>q>>(s+1);
    n=strlen(s+1);
    if(isdight(s[1])){
        while(q--)cout<<to_num(1)<<"\n";
        return 0;
    }//特判没有三目运算的情况
    while(q--){
        int x;
        cin>>x;
        cout<<solve(1,n,x)<<"\n";//查询
    }
}

72pts

对于 48 分的代码,我们的 turn_lturn_r 的函数复杂度为 O(n) ,显然可以通过预处理优化为 O(1)

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5000005;
char s[N];
int nxt1[N],nxt2[N];
int n,q;
bool isdight(char x){
    return '0'<=x&&x<='9';
}
int to_num(int u){
    int ans=0;
    for(int i=u;i<=n&&isdight(s[i]);i++)ans=ans*10+s[i]-'0';
    return ans;
}
bool examine(int u,int x){
    char op=s[u+1];
    int t=to_num(u+2);
    if(op=='<')return x<t;
    else return x>t;
}
int turn_l(int u){
    return nxt1[u];
}
int turn_r(int u){
    return nxt2[u];
}
int solve(int l,int r,int x){
    if(l>n||r>n||l>r)return 0;
    if(isdight(s[l]))return to_num(l);
    if(examine(l,x))return solve(turn_l(l)+1,turn_r(l)-1,x);
    else return solve(turn_r(l)+1,r,x);
}//同上
stack<int>Q;
signed main(){
    int m;
    cin>>m>>q>>(s+1);
    n=strlen(s+1);
    nxt1[n+1]=n;
    for(int i=n;i>=1;i--){
        nxt1[i]=nxt1[i+1];
        if(s[i]=='?')nxt1[i]=i;
    }//预处理 ? 的所在位置
    for(int i=1;i<=n;i++){
        if(s[i]==':'){
            if(Q.empty()){
                break;
            }
            int t=Q.top();
            Q.pop();
            nxt2[t]=i;
        }
        if(s[i]=='?')Q.push(i);
    }//括号匹配得到每个 ? 对应的 : 
    nxt2[n+1]=n;
    for(int i=n;i>=1;i--)if(!nxt2[i])nxt2[i]=nxt2[i+1];
    for(int i=1;i<=n;i++)
    if(isdight(s[1])){
        while(q--)cout<<to_num(1)<<"\n";
        return 0;
    }
    while(q--){
        int x;
        cin>>x;
        cout<<solve(1,n,x)<<"\n";
    }
}

100pts

由三目运算的本质可得,答案为 x 的查询数一定在一个区间里,又因为 q \le 10^5 ,我们可以直接在 solve 函数中预处理出所有的 ans ,这样查询的复杂度降为 O(1) ,预处理复杂度为 O(n) ,足以通过本题。

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5000005;
char s[N];
int nxt1[N],nxt2[N];
int n,qry;
bool isdight(char x){
    return '0'<=x&&x<='9';
}
int to_num(int u){
    int ans=0;
    for(int i=u;i<=n&&isdight(s[i]);i++)ans=ans*10+s[i]-'0';
    return ans;
}
int turn_l(int u){
    return nxt1[u];
}
int turn_r(int u){
    return nxt2[u];
}
int ans[N];
void solve(int ql,int qr,int l,int r){
//ql,qr:目前答案区间
    if(isdight(s[l])){
        int num=to_num(l);
        for(int i=ql;i<=qr;i++)ans[i]=num;//统计答案
        return ;
    }
    int num=to_num(l+2);
    char op=s[l+1];
    int l1,r1,l2,r2;
    if(op=='<'){
        l1=ql,r1=num-1;
        l2=num,r2=qr;
    }
    if(op=='>'){
        l1=num+1,r1=qr;
        l2=ql,r2=num;
    }
    solve(l1,r1,turn_l(l)+1,turn_r(l)-1);//左区间
    solve(l2,r2,turn_r(l)+1,r);//右区间
}
stack<int>stk;
signed main(){
    int m;
    cin>>m>>qry>>(s+1);
    n=strlen(s+1);
    nxt1[n+1]=n;
    for(int i=n;i>=1;i--){
        nxt1[i]=nxt1[i+1];
        if(s[i]=='?')nxt1[i]=i;
    }
    for(int i=1;i<=n;i++){
        if(s[i]==':'){
            if(stk.empty()){
                break;
            }
            int t=stk.top();
            stk.pop();
            nxt2[t]=i;
        }
        if(s[i]=='?')stk.push(i);
    }
    nxt2[n+1]=n;
    for(int i=n;i>=1;i--)if(!nxt2[i])nxt2[i]=nxt2[i+1];
    solve(1,m,1,n); 
    while(qry--){
        int x;
        cin>>x;
        x=min(x,m);//当 x>m 时,与 x=m 的结果一致
        cout<<ans[x]<<"\n";//直接查询
    }
}