P13233 题解

· · 题解

一个简单的 DP 是设 dp_{l,r} 表示在 [l,r] 内进行查找最坏情况下所需要的最小代价,转移显然是 dp_{l,r}=\displaystyle\min_{i\in[l,r]}(a_i+\max(dp_{l,i-1},dp_{i + 1 ,r})),边界条件 dp_{i,i}=a_i

这个式子没有前途,考虑从值域很小下手,发现代价的一个显然上界是 9\log n,你只需要做正常的二分查找即可。

所以考虑转换 DP 的值和下标,等价于对每个代价看看能扩展多长,具体的,设 lft_{val,i} 表示在区间最劣情况下花费为 val 时,i 能扩展到的最左边的左端点是什么,同理记录一个 rt_{val,i} 表示能扩展到的最右边的右端点。

对于转移考虑像朴素 DP 时一样根据中转点进行转移,对于一个中转点 i 和一个代价 val,明显左右区间代价不超过 val-a_i 就 OK,所以转移显然:

lft_{val,R}=L,rt_{val,L}=R

其中 L=lft_{val-a_i,i-1},R = rt_{val-a_i,i+1}

然而这个转移你枚举的区间是极大的,这很不好,所以一会需要在扫一遍将 lft 更新为后缀最小值以及 rt 更新为前缀最大值。

当然两个数组初始时都要继承 val-1 时的取值。

当第一次满足 rt_{val,1}=n 时就找到了最小代价。

发现实际我需要用到的值域区间大小是 10,所以可以滚动,空间就开下了。

单次复杂度 O(n\log n)

#include "bits/stdc++.h"
#define f(i ,m ,n ,x) for (int i = (m) ,i##END = (n) ; i <= i##END ; i += (x))

const int N = 1e6 + 25 ,mod = 10 ;
int n ,a[N] ,lft[10][N] ,rt[10][N] ;
std :: vector < int > pos[10] ;

int main (void) {
    std :: ios :: sync_with_stdio (false) ,
    std :: cin.tie (nullptr) ,std :: cout.tie (nullptr) ;

    int T ; std :: cin >> T ; f (t ,1 ,T ,1) {
        std :: string str ; std :: cin >> str ;
        n = str.length () ; f (i ,1 ,n ,1) a[i] = str[i - 1] ^ 48 ;
        f (i ,1 ,n ,1) pos[a[i]].emplace_back (i) ;

        f (i ,0 ,9 ,1) 
            __builtin_memset (lft[i] ,0x3f ,sizeof (int) * (n + 1)) ,
            __builtin_memset (rt[i] ,0 ,sizeof (int) * (n + 1)) ;

        int lim = std :: ceil (std :: log2 (n + 1)) * 9 ;
        f (s ,1 ,lim ,1) {
            int u = (s - 1) % mod ,v = (u - 1 + mod) % mod ;
            __builtin_memset (lft[u] ,0x3f ,sizeof (int) * (n + 1)) ,
            __builtin_memset (rt[u] ,0 ,sizeof (int) * (n + 1)) ;

            f (i ,1 ,n ,1)
                lft[u][i] = lft[v][i] ,rt[u][i] = rt[v][i] ;
            if (s <= 9) for (auto p : pos[s]) lft[u][p] = rt[u][p] = p ;

            f (i ,1 ,n ,1) {
                if (a[i] > s) continue ;
                int q = (u - a[i] + mod) % mod ;
                int l = lft[q][i - 1] ,r = rt[q][i + 1] ;
                if (i == 1) l = 1 ;
                if (i == n) r = n ;
                if (l == 0x3f3f3f3f || !r) continue ;
                lft[u][r] = std :: min (lft[u][r] ,l) ,
                rt[u][l] = std :: max (rt[u][l] ,r) ;
            } 

            int min = 0x3f3f3f3f ; 
            for (int i = n ; i ; i--)
                min = lft[u][i] = std :: min (lft[u][i] ,min) ;

            int max = 0 ;
            f (i ,1 ,n ,1)
                max = rt[u][i] = std :: max (rt[u][i] ,max) ;

            if (rt[u][1] == n) 
            { std :: cout << "Case #" << t << ": " << s << '\n' ; break ;}
        } f (i ,1 ,n ,1) pos[a[i]].clear () ;
    }
    return 0 ;
}