P1949 [NOI 2001] 聪明的打字员 解

· · 个人记录

写一种有点乱搞的 IDA* 做法.

“有点乱搞”是因为它的正确性论述有个部分太过琐碎,没想到严格的数学证明.

由于此题的可行状态数太多,搜索树非常深,我考虑用 IDA*.

如果你不知道 IDA* ,这指的是使用估值函数剪枝的迭代加深搜索算法.

如果你不知道迭代加深搜索,这是一种设置了递归深度上限的回溯搜索,采用了逐个试验的逻辑寻找最优解(最小的可能深度). 比如假设最小可能深度为 15,那么递归时任意分支只要深度超过 15 就剪掉;如果这样没有找到合法解,说明 15 不对,再假定 16 为答案并重新搜,直到找到解.

如果你不知道估值函数,这是指在每个状态都估算一下当前状态离目标状态最少还需要几层递归. 很明显,如果 当前深度+估算的值 超过了设定的上限,那么当前的分支也可以直接剪掉.

所以我们要做两件事:设计搜索算法的逻辑、设计估值函数.

先看看怎么搜索.

这很显然,此题的一个状态就是当前的 6 位密码以及光标位置,状态间的转移就是一次操作.

但这中间可以剪很多枝:除了光标在最左边时左移、值为 9 时增加这些明显的浪费以外,一个位置的值如果和目标值一样也不必做改值/交换操作.

这道题最神奇的是估值函数. 这里估值函数的定义是从一个 6 位密码 code 变化到另一个密码 target 需要的操作数的一个下界,并且这个估值效果得很紧,不然估不出 000000999999 这种 59 层递归就直接爆炸.

我们把操作数叫做成本,然后把光标移动耗费的叫做移动成本,其它的叫修改成本. 很明显如果 codetargetk 位不匹配,那么 k-1 是移动成本的下界. 关键在于修改成本,我想了三个方式来衡量.

第一种是算 code 的数字和跟 target 的数字和之差的绝对值 C_1,这是显然的:一次改值操作最多让这个差减小 1,而交换操作不影响这个差,最终差应为 0,所以 C_1 是正确的下界.

第一种显然不够用,999111555555 的数字和之差为 0 但修改成本显然比 0 大得多,所以引入第二种估值:计算两个密码的直径,直径定义为最大三位之和减去最小三位之和,然后求直径之差的绝对值 C_2. 这个方法完美解决 999111/555555 这种情况,而且可以类似第一种证明是正确的.

但验证发现第二种还不够用,例如 900000444565 这种情况下它会炸:修改成本应该在 29 左右,但是最好的第一种方案估出来也是 19,并不理想. 于是我想了第三种怪招:先把两个密码升序排序,然后对于 code 里每一位,如果它target 里出现,就求它和排序后 target 对应项的差的绝对值,结果相加得 C_3.

我的动机是,900000/444565 之所以炸就是因为 9 没有被正确处理,它应该跟 45 求差的绝对值而不是像第一种方案直接减掉 9 本身,即这种没出现在 target 里的数应该拿来求差的绝对值而不是差,而出现在 target 里的数由于可能被交换故不考虑;另一方面也不能直接算 |9-4|,因为此题的交换操作非常神奇,4 如果能和 6 交换就会出现更优的 |9-6|. 于是我想了一种折衷方案,让比较大的数和比较大的数求差的绝对值,即对排序后的两个密码里的对应项求差的绝对值. 这样得到的 C_3 是不是正确的下界呢?不敢说是,但这个估计已经很精了,我有把握说几乎是. 就算存在 Hack 近似度也会很高.

于是,如果 C_1,C_2,C_3 都是修改成本的下界,我们直接取 max\{C_1,C_2,C_3\} 再加上移动成本的下界 k-1 就完成了估值函数的计算.

IDA* 算法设计完了.

俺过了搞 OI 的时候了,因此码风比较重表义而不是编写速度. 代码如下.

#include <bits/stdc++.h>

#define IOS_SPEED std::ios::sync_with_stdio(false)

using std::cin, std::cout;
using std::array, std::unordered_set;
using std::swap, std::sort, std::abs, std::lower_bound, std::max;

constexpr int len = 6, upper = 9;
array<int, len> code, target, target_o; // xxx_o := 排序后的 xxx
int pos = 0;
int sum_target = 0, diff_target = 0; // sum_xxx := xxx 的数位和, diff_xxx := xxx 的直径

inline int rest_lowerbound(){ // 估值函数
    int answer = 0;
    for(int i=0; i<len; ++i){
        if(code[i]!=target[i]) ++ answer;
    }
    int sum_code = 0, diff_code = 0, dist_sum = 0;
    auto code_o = code;
    sort(code_o.begin(), code_o.end());
    for(int i=0; i<len; ++i){
        sum_code += code_o[i]; diff_code += (i<len/2)? (-code_o[i]): code_o[i];
        auto loc = lower_bound(target_o.begin(), target_o.end(), code_o[i]);
        if(loc==target_o.end()||*loc!=code_o[i])
            dist_sum += abs(code_o[i]-target_o[i]);
    }
    answer += max(max(abs(diff_target-diff_code), abs(sum_code-sum_target)), dist_sum);
    // abs(sum_target-sum_code) 即 C_1, abs(diff_target-diff_code) 即 C_2, dist_sum 即 C_3
    -- answer; if(answer<0) answer = 0;
    return answer;
}

bool enumerate(int depth, int stop){
    if(depth+rest_lowerbound()>stop) return false;
    if(code==target) return true;
    if(pos>0){
        -- pos; if(enumerate(depth+1, stop)) return true; ++ pos;
    }
    if(pos<len-1){
        ++ pos; if(enumerate(depth+1, stop)) return true; -- pos;
    }
    if(code[pos]>0&&code[pos]!=target[pos]){
        -- code[pos]; if(enumerate(depth+1, stop)) return true; ++ code[pos];
    }
    if(code[pos]<upper&&code[pos]!=target[pos]){
        ++ code[pos]; if(enumerate(depth+1, stop)) return true; -- code[pos];
    }
    if(pos>0&&code[pos]!=target[pos]){
        swap(code[pos], code[0]); if(enumerate(depth+1, stop)) return true; swap(code[pos], code[0]);
    }
    if(pos<len-1&&code[pos]!=target[pos]){
        swap(code[pos], code[len-1]); if(enumerate(depth+1, stop)) return true; swap(code[pos], code[len-1]);
    }
    return false;
}

int min_presses(){
    target_o = target;
    sort(target_o.begin(), target_o.end());
    for(int i=0; i<len; ++i){
        sum_target += target_o[i]; diff_target += (i<len/2)? (-target_o[i]): target_o[i];
    }
    int attempt = rest_lowerbound();
    while(!enumerate(0, attempt)) ++ attempt;
    return attempt;
}

void interface(){
    IOS_SPEED;
    char new_bit;
    for(int i=0; i<len; ++i){cin >> new_bit; code[i] = new_bit-'0';}
    for(int i=0; i<len; ++i){cin >> new_bit; target[i] = new_bit-'0';}
    cout << min_presses() << "\n"; 
}

int main()
{
    interface();
    return 0;
}

神奇地以 61ms 成为最快解.