CSP2023-S T4题解

· · 题解

思路

题目中说最少需要多少天,不难发现天数具有单调性,可以想到二分。对于二分出的天数 mid ,对于第 i 块地可以二分计算出 t_i 表示它最迟哪一天开始种。然后我们考虑如何安排扩散顺序最优。对于一个时间 d ,所有 t_i \leq d 的点都应该扩散完成,所以可以按照 t_i 排序。对于排序过后的数组,每一个点暴力一直向父亲节点跳,直到跳到一个被扩散过的节点。显然,整个向上跳的过程复杂度是均摊 O(1) 的。综上所述,记 w = 10^9, 本题总复杂度是 O(n\log^2w)注意:在计算 t_i 时会爆 long long

Code

#include<bits/stdc++.h>
#define int long long
#define M 100005
using namespace std;
int a[M], b[M], c[M], lt[M], n, f[M], vis[M];
vector<int> v[M];
struct node{
    int id, x;
}seq[M];
bool cmp(node x, node y)
{
    return x.x < y.x;
}
//58156328888996813 791817301 56
void dfs(int x, int fa)
{
    f[x] = fa;
    for(auto &t : v[x]) if(t != fa) dfs(t, x);
} 
bool check1(int d, int id, int ed)
{
    int pt = ed - d + 1;
    if(c[id] > 0)
    {
        __int128 t1 = c[id];
        __int128 t2 = b[id];
        __int128 t = t2 * pt + t1 * (d + ed) * pt / 2;
        //if(id == 14) print(t);
        return (t >= a[id]);
    }
    if(c[id] == 0)
    {
        return (b[id] * pt >= a[id]);
    }
    if(c[id] < 0)
    {
        int ng = max(b[id] + d * c[id], 1ll);
        if(ng == 1) return (pt >= a[id]);
        int dd = d + (ng - 1) / (-c[id]) + 1;
        //cout << ng << endl;
        if(dd > ed)
        {
            __int128 t1 = c[id];
            __int128 t2 = b[id];
            __int128 t = t2 * pt + t1 * (d + ed) * pt / 2;
            return (t >= a[id]);
        }
        else
        {
            int pt1 = dd - d;
            __int128 t1 = c[id];
            __int128 t2 = b[id];
            __int128 t = t2 * pt1 + t1 * (d + dd - 1) * pt1 / 2 + (ed - dd + 1);
            return (t >= a[id]);
        }
    }
}
bool check(int x)
{
    for(int i = 1; i <= n; i++)
    {
        int l = 0, r = x, mid, ans = -1;
        while(l <= r)
        {
            mid = (l + r) >> 1;
            if(check1(mid, i, x)) l = mid + 1, ans = mid;
            else r = mid - 1;
        }
        //cout << x << " " << i << " " << a[i] << endl;
        if(ans == -1) return 0;
        lt[i] = ans;
        seq[i].id = i, seq[i].x = lt[i];
        //c
    }
    dfs(1, 0);
    memset(vis, 0, sizeof(vis));
    sort(seq + 1, seq + n + 1, cmp);
    int dy = 0;
    for(int i = 1; i <= n; i++)
    {
        if(vis[seq[i].id]) continue;
        int x = seq[i].id;
        //cout << seq[i].id << " " <<
        while(1)
        {
            dy++;
            vis[x] = 1;
            x = f[x];
            if(vis[x] || (x == 0)) break;
        }
        if(lt[seq[i].id] < dy) return 0;
    }
    return 1;
}
signed main()
{
//    freopen("tree.in", "r", stdin);
//    freopen("tree.out", "w", stdout);
    scanf("%lld", &n);
    for(int i = 1; i <= n; i++) scanf("%lld%lld%lld", &a[i], &b[i], &c[i]);
    for(int i = 1, x, y; i < n; i++) scanf("%lld%lld", &x, &y), v[x].push_back(y), v[y].push_back(x);
    //cout << check1(1, 4, ) << endl;
    int l = 0, r = 1e9, mid, ans = -1;
    while(l <= r)
    {
        mid = (l + r) >> 1;
        //cout << l << " " << r << " " << mid << endl;
        if(check(mid)) r = mid - 1, ans = mid;
        else l = mid + 1;
    }
    cout << ans << endl;
    return 0;
}