P1118 数字三角形

· · 个人记录

这是一道不错的题目,首先暴力是想都不要去想的,不现实,因为生成最多12位数的全排列需要O(N!)的时间复杂度,那么我们就要想办法进行剪枝优化。

那怎么剪枝呢?我们是否可以在生成全排列的同时判断这组全排列相加到最后能够不大于最后剩下的那个数呢?

事实上这是可行的,我们假设n为一个比较小的数(比如,按样例,4),设第一行的n个数分别为a,b,c,...(我这里是a,b,c,d),然后模拟加一下,就会发现sum是……

如果n为4,那么sum是a+3b+3c+d。

如果n为5,那么sum是a+4b+6c+4d+e。

如果n为6,那么sum是a+5b+10c+10d+5e+f。

不知大家发现没有这与杨辉三角形有关。

杨辉三角形:

        1
       1 1
      1 2 1
     1 3 3 1
    1 4 6 4 1
   1 5 10 10 5 1
  1 6 15 20 15 6 1
 1 7 21 35 35 21 7 1
1 8 28 56 70 56 28 8 1

那么就简单了,我们只需要先根据所给的n来构造一个杨辉三角形的系数即可,而计算杨辉三角形每一行的系数可以使用组合数来求。 第n行的第k个数字为组合数C(n, k)的值。

代码:

#include <bits/stdc++.h>

using namespace std;
typedef long long int lli;
/**
 *  Created with IntelliJ Clion.
 *  @author  wanyu
 *  @Date: 2018-04-27
 *  @Time: 19:23
 *  To change this template use File | Settings | File Templates.
 * 
 */

#define mset(t, x) memset(t,x,sizeof(t))
#define lson index<<1
#define rson (index<<1) +1
#define loop(a, b, c) for(int a=b;a<=c;a++)
#define loop2(a, b, c) for(int a=b;a>=c;a--)
#define loop3(a, b, c) for(int a=b;a<c;a++)
#define loop4(a, b, c) for(int a=b;a>c;a--)
#define maxn 30
#define maxm 10005
int n, sum;
int st[maxn];
bool vis[maxn];
bool flag;
int c[maxn];

inline lli fac(int x) {//阶乘计算
    lli ans = 1;
    if (x == 0) {
        return ans;
    }
    loop(i, 2, x) {
        ans *= i;
    }
    return ans;

}

void caculate(int m) {//组合数计算
    //系数计算
    //C(n,m)=n!/[m!(n-m)!]
    lli a = fac(m);
    loop(i, 0, m) {
        c[i] = a / (fac(i) * fac(m - i));
    }
}

bool judge(int x) {//判断当前生成的数列相加后是否小于等于最后剩下的那个
    int temp = 0;
    loop(i, 0, x) {
        temp += (st[i] * c[i]);
        if (temp > sum) {//大于
            return false;
        }
    }
    if (x == n - 1) {
        return temp == sum;//如果生成了最后一个数
    }
    return true;
}

void dfs(int index) {
    if (index == n) {
        flag = true;
        loop3(i, 0, n) {//输出答案
            cout << st[i];
            if (i != n) {
                cout << " ";
            }
        }
        return;
    }
    loop(i, 1, n) {
        if (!vis[i]) {
            vis[i] = 1;
            st[index] = i;
            if (judge(index)) {//判断生成的这个数是否合法
                dfs(index + 1);
            } else {
                vis[i] = 0;//回溯
                return;
            }
            if (flag) {//找到结果了
                return;
            }
            vis[i] = 0;
        }
    }

}

int main() {
    cin >> n >> sum;
    mset(vis, 0);
    mset(st, 0);
    mset(c, 0);
    flag = 0;
    caculate(n - 1);
    dfs(0);
    return 0;
}