题解:P15385 回文回文回 IV / paliniv

· · 题解

题意

给定一个长度为 n 的整数数组 a,求有多少种排列 b 满足其前缀和数组是回文串。答案对 998244353 取模。

思路

设前缀和数组为 p,回文条件为 p_i = p_{n+1-i}

S = \sum a_i

i=1,有 p_1=p_n,所以 b_1 只能是所有数加起来的和。也就是 b_1=S

不难发现,所有满足 l+r-1=n 的数对 (l,r),有 b_l=-b_r

n 为偶数,中间位置 b_{\frac{n}{2}+1} 必须满足 2 \times p_{\frac{n}{2}} = S。结合回文性质可推出该位置值为 0

因此问题等价于:将数组元素分配到 n 个位置上,使得:

  1. 位置 1 的值恰好为 S
  2. 关于 \frac{n}{2} 对称位置的元素之和为 0

从数组中取出一个值为 S 的元素放在第一位,有 tot_S 种选法,n 为偶数时再取一个 0 放中间(有 mid 种选法)。剩余元素中,每个正值 v 的个数必须等于 -v 的个数,0 的个数必须为偶数,否则答案为零。

设剩余元素中值 v > 0c_v 对,值 02c_0 个。总配度数 k=\sum_{v>0} c_v+c_0

考虑带标号排列的计数:

  1. k 组元素分配到 k 个对称位置对上,有 k! 种可能。
  2. 对于每个组内,把正放前面还是负放前面,有 2^{c_v} 种。
  3. 对于每个相同的 v 放到不同组有 c_v! 种。
  4. 对于零的对,把 2c_00 放入 c_0 对个位置,由于 0=-0 无方向性,所以贡献为 \frac{(2c_0)!}{c_0!}(即从 2c_0 个零中选配对的方案数)。

最终式子

ans = tot_S \times mid \times k! \times \prod_{v>0}(2^{c_v}\times c_v!) \times \frac{(2c_0)!}{c_0!}

code

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod = 998244353;
const int N = 200005;
int fac[N],ifac[N],poww[N];
int qpow(int a,int b){
    int res = 1;
    while(b){
        if(b & 1) res = 1LL * res * a % mod;
        a = 1LL * a * a % mod;
        b >>= 1;
    }
    return res;
}
void pre(){
    fac[0] = ifac[0] = poww[0] = 1;
    for(int i = 1;i < N;i++){
        fac[i] = 1LL * fac[i - 1] * i % mod;
        poww[i] = 1LL * poww[i - 1] * 2 % mod;
    }
    ifac[N - 1] = qpow(fac[N - 1],mod - 2);
    for(int i = N - 2;i >= 1;i--){
        ifac[i] = 1LL * ifac[i + 1] *(i + 1) % mod;
    }
}
int t;
int a[N];
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(nullptr);
    pre();
    cin >> t;
    while(t--){
        int n;
        cin >> n;
        map<int,int> cnt;
        long long sum = 0;
        for(int i = 1;i <= n;i++){
            cin >> a[i];
            cnt[a[i]]++;
            sum += a[i];
        }
        if(cnt.find((int)sum) == cnt.end() || cnt[(int)sum] == 0){
            cout << 0 << "\n";
            continue;
        }
        int cntt = cnt[(int)sum];
        cnt[(int)sum]--;
        if(cnt[(int)sum] == 0) cnt.erase((int)sum);
        int mid = 1;
        if(n % 2 == 0){
            if(cnt.find(0) == cnt.end() || cnt[0] == 0){
                cout << 0 << "\n";
                continue;
            }
            mid = cnt[0];
            cnt[0]--;
            if(cnt[0] == 0) cnt.erase(0);
        }
        bool ok = true;
        int k = 0;
        set<int>tot;
        for(auto it = cnt.begin();it != cnt.end();++it){
            tot.insert(abs(it->first));
        }
        long long ans = 1LL * cntt % mod * mid % mod;
        int c0 = 0;
        for(auto it = tot.begin();it != tot.end();++it){
            int v = *it;
            if(v == 0){
                int z = cnt[0];
                if(z % 2 != 0){
                    ok = false;break;
                }
                c0 = z / 2;
                k += c0;
            } else{
                int cp = cnt.count(v) ? cnt[v] : 0;
                int cn = cnt.count(-v) ? cnt[-v] : 0;
                if(cp != cn){
                    ok = false;
                    break;
                }
                k += cp;
                ans = 1LL * ans * poww[cp] % mod * fac[cp] % mod;
            }
        }
        if(!ok){
            cout << 0 << "\n";
            continue;
        }
        ans = 1LL * ans * fac[k] % mod;
        if(c0 > 0){
            ans = 1LL * ans * fac[2 * c0] % mod * ifac[c0] % mod;
        }
        cout << ans << "\n";
    }
    return 0;
}