题解:P15546 「Stoi2037」七里香

· · 题解

五年级 xxs 赛时被卡溢出了呜呜呜。

本文同步发表于博客园。

推式子

这个推式子推了我半小时,感觉推复杂了,不过也在赛后补题的时候过了。

什么文字的太多了,直接看形式化题意,让我们求

\sum_{1 \le l<r \le n} \sum_{l \le i<j \le r} [(j-1)k+a_j'-(i-1)k+a_i']

的最大值。

我们定义第 x 列的行号

f(x)=(x-1)k+a_x'

则原式变为

\sum_{1 \le l<r \le n} \sum_{1 \le i<j \le r} [f(j)-f(i)]。

我们令原式为 F

f(j)-f(i)=(j-i)k+(a_j'-a_i')

\begin{align*} F&=\sum_{1 \le l<r \le n} \sum_{l \le i<j \le r} [f(j)-f(i)]\\ &=\sum_{1 \le l<r \le n} \sum_{l \le i<j \le r} [(j-i)k+(a_j'-a_i')]\\ &=k\sum_{1 \le l<r \le n} \sum_{l \le i<j \le r}(j-i)+\sum_{1 \le l<r \le n} \sum_{l \le i<j \le r}(a_j'-a_i')\\ \end{align*}

方便计算,记

S=\sum_{1 \le l<r \le n} \sum_{l \le i<j \le r} (j-i) T=\sum_{1 \le l<r \le n} \sum_{l \le i<j \le r} (a_j'-a_i')

F=k \cdot S+T。

下面分别计算 ST

化简 S

注意到 S 为恒定常数。

对于固定的 i<j,其在区间 [l,r] 中出现当且仅当 l \le ij \le r,共有

i \cdot (n-j+1)

个区间,则

S=\sum_{1 \le i<j \le n} i(j-i)(n-j+1)。

d=j-i,则 j=d+i,从而

S=\sum_{d=1}^{n-1}\sum_{i=1}^{n-d} id(n-i-d+1)。

d 从第二个 \sum 中提出来:

S=\sum_{d=1}^{n-1} d \cdot \sum_{i=1}^{n-d}i(n-i-d+1)。

m=n-d,则第二个 \sum

\sum_{i=1}^m i(m-i+1)=\frac16 m(m+1)(m+2)

于是

S=\frac16\sum_{d=1}^{n-1}d(n-d)(n-d+1)(n-d+2)。

化简 T

同上一部分,可交换求和顺序并化简

T=\sum_{1 \le i<j \le n} i(a_j'-a_i')(n-j+1)。

拆开

T=\sum_{1 \le i<j \le n} i(n-j+1) \cdot a_j'-\sum_{1 \le i<j \le n} i(n-j+1) \cdot a_i'。

对于第一个 \sum,固定 j 对于 i<j 求和

\sum_{i=1}^{j-1} i(n-j+1) \cdot a_j'=\frac12 j \cdot a_j'(j-1)(n-j+1)。

对于第二个 \sum,固定 i 对于 j>i 求和

\begin{align*} \sum_{j=i+1}^n i(n-j+1) \cdot a_i'&=i \cdot a_i'\sum_{j=i+1}^n (n-j+1)\\ &=\frac12 i \cdot a_i'(n-i)(n-i+1)。 \end{align*}

于是

\begin{align*} T&= \sum_{i=1}^n a_i'[\frac12 i(i+1)(n-i+1)-\frac12 i(n-i)(n-i+1)]\\ &=\sum_{i=1}^n a_i'[\frac12i(n-i+1)(i+1-n+i)]\\ &=\frac12\sum_{i=1}^ni \cdot a_i'(n-i+1)(2i-n-1)。 \end{align*}

w_i=\frac12 i(n-i+1)(2i-n-1)

T=\sum_{i=1}^n a_i'w_i。

排序不等式

排序不等式(知道了之后你甚至可以切蓝题)

\sum_{i=1}^n a_ib_i \ge \sum_{i=1}^n a_ib_{j_i} \ge \sum_{i=1}^n a_ib_{n-i+1}

其中

\begin{cases} a_1 \le a_2 \le \dots \le a_n\\ b_1 \le b_2 \le \dots \le b_n\\ \end{cases}

\{j_i\}1,2,\dots,n 的一个排列。

排序不等式证明

P=\sum_{i=1}^n a_ib_{j_i}

如果 j_n \ne n,则设此时 b_n 所在的项是 a_{j_m}b_n,则由

(b_n-b_{j_n})(a_n-a_{j_m}) \ge 0

a_nb_n+a_{j_m}b_{j_n} \ge a_{j_m}b_n+a_nb_{j_n}

也就说明 j_n \ne n 时,调换 Pb_nb_{j_n} 的位置,可以得到 a_nb_n 项,使得 P 变成 P_1,使得 P_1 \ge P

同理,可以得到 a_{n-1}b_{n-1} 项,使得 P_1 变为 P_2,使得 P_2 \ge P_1

重复这个过程,经过最多 n-1 次变换,可以得到 \sum_{i=1}^n a_ib_i,故

P \le \sum_{i=1}^n a_ib_i

同理可知

P \ge \sum_{i=1}^n a_1b_{n-i+1}

\{a_n\}\{b_n\} 全相等时,显然等号成立。

\{a_n\}\{b_n\} 不全相等时,必然

a_1 \ne a_n,b_1 \ne b_n

于是

a_1b_1+a_nb_n>a_1b_n+a_nb_1

\sum_{i=2}^{n-1} a_ib_i \ge \sum_{i=2}^{n-1}a_ib_{n-i+1}

从而

\sum_{i=1}^n a_ib_i>\sum_{i=1}^{n} a_ib_{n-i+1}

故这两个等式中必有一个不成立。

所以当且仅当 \{a_n\} 全相等,\{b_n\} 全相等时取等。

最大化 T

已知 \{a_i'\}\{a_i\} 的一个重排,即每个数值 v\in[1,k]a' 中出现的次数等于它在 a 中出现的次数 c_v

由排序不等式:要使 \sum_{i=1}^n w_i a_i' 最大,应将 a_i' 按与 w_i 相同的顺序排列,即 w_i 越大,分配的 a_i' 也应越大。

因此我们按 w_i 降序排序索引,同时将可用的数值(从大到小,每个值 vc_v 份)依次赋给这些位置。

代码

::::success[code]

#include <bits/stdc++.h>
#define pub public:
#define pri private:
#define fri friend:
#define Ofile(s) freopen(s".in", "r", stdin), freopen (s".out", "w", stdout)
#define Cfile(s) fclose(stdin), fclose(stdout)
#define fast ios::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
using namespace std;

using ll = long long;
using ull = unsigned long long;
using i128 = __int128;
using ui128 = unsigned __int128;
using lb = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using pil = pair<int, ll>;
using pli = pair<ll, int>;

constexpr int mod = 998244353;
constexpr int maxn = 1e5 + 5;
constexpr int maxk = 1e6 + 5;

ll n, k, cur_val, r, out, total;
i128 S, T, tmp, ans;

ll sum[maxk], idx[maxn];
i128 w[maxn];

void print_unsigned(ui128 t){
    if (t >= 10)
        print_unsigned(t / 10);
    putchar('0' + t % 10);
}

void print(i128 t){
    if (t < 0){
        putchar('-');
        ui128 tmpp = -(ui128)t;
        print_unsigned(tmpp);
    }
    else if (!t)
        putchar('0');
    else 
        print_unsigned((ui128) t);
}

int main() {
    freopen("contest.in", "r", stdin);
    freopen("contest.out", "w", stdout);
    cin >> n >> k;
    for (int i = 1, x; i <= n; i++)
        cin >> x, sum[x]++;
    for (ll d = 1; d <= n - 1; d++){
        tmp = (i128)d * (n - d) * (n - d + 1) * (n - d + 2) / 6;
        S += tmp;
    }
    S *= k;
    for (ll i = 1; i <= n; i++)
        w[i] = (i128)i * (n - i + 1) * (2 * i - n - 1) / 2;
    for (int i = 1; i <= n; i++) 
        idx[i] = i;
    sort(idx + 1, idx + n + 1, [&](int x, int y) {
        return w[x] > w[y];
    });
    cur_val = k;
    while (cur_val >= 1 && sum[cur_val] == 0) 
        cur_val--;
    r = sum[cur_val];
    for (ll i = 1; i <= n; i++){
        if (total >= n) break;
        ll pos = idx[i];
        T += w[pos] * cur_val;
        total++;
        r--;
        while (!r && cur_val > 1) 
            cur_val--, r = sum[cur_val];
    }
    ans = S + T;
    print(ans);
    return 0;
}

::::