题解 P1005 【矩阵取数游戏】

· · 题解

唉。心好累。

刘汝佳骗了我。《算法竞赛入门经典》骗了我。

vector实现的高精度不好写,而且慢的要死。疯狂地压到了15位才勉强通过倒数第二个测试点。

痛定思痛。改用数组。

贼快。

不过里面保留了我之前用的优化。

我都忘了这是一道动态规划的题目。

那就忘了吧。

#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
#define A 1000000000000000
using namespace std;

struct bint{
    long long s[10];

    bint(long long num = 0){
        *this=num;
    }

    bint operator = (long long num){
        memset(s, 0, sizeof(s));
        s[0]=num;
        return *this;
    }

    bint operator + (const bint& b) const {
        bint c;
        unsigned long long g = 0;
        for (int i=0; i<9 ; i++){
            unsigned long long x = g;
            x += (unsigned long long)s[i]+b.s[i];
            c.s[i]=x%A;
            g=x/A;
        }
        return c;
    }

    bint operator* (const bint& b) const {
        bint c;
        unsigned long long g=0;
        for (int i=0; i<9; i++){
            unsigned long long x=g;
            for (int j=0; j<=i; j++){
                int k=i-j;
                x+=(unsigned long long)s[k]*b.s[j];
            }
            c.s[i]=x%A;
            g=x/A;
        }
        return c;

    }

    bool operator < (const bint& b) const {
        for (int i=9; i>=0; i--){
            if (s[i]<b.s[i]) return 1;
            if (s[i]>b.s[i]) return 0;
        }
        return 0;   
    }

    void print(){
        char buf[200];
        for (int i=9; i>=0; i--){
            sprintf(buf+(9-i)*15, "%015lld", s[i]);
        }
        bool flag=0;
        for (int i=0; i<150; i++){
            if (buf[i]>'0') flag=1;
            if (flag) printf("%c", buf[i]);
        }

        if (!flag) printf("0");

    }

}; 

long long a[100]; bint dp[100][100];
bint ans;
bint two[82];

inline void ini(){
    two[0]=1;
    for (int i=1; i<=80; i++){
        two[i]=two[i-1]*2;
    }
}

bint at[100][100];
bool used[100][100];

inline bint multi(int i, int p){
    if (used[i][p]) return at[i][p];
    used[i][p]=true;
    return at[i][p]=(bint)a[i]*two[p];
}

int main(){

    freopen("1005.cpp", "r", stdin);
    ini();
    int n, m;
    cin>>n>>m;
    ans=0;
    for (int w=0; w<n; w++){
        memset(dp, 0, sizeof(dp));
        memset(used, false, sizeof(used));
        for (int i=1; i<=m; i++) scanf("%d", a+i);
        bint anst=0;

        for (int t=0; t<m; t++){    
            for (int i=1; i+t<=m; i++){
                int j=i+t;
                int p=m-t;
                bint s = dp[i+1][j]+multi(i,p);
                bint t = dp[i][j-1]+multi(j,p);
                if (s<t) dp[i][j]=t;
                else dp[i][j]=s;

            }
        }

        ans=ans+dp[1][m];
    }
    ans.print();
    cout<<endl;
}