hogecoder

tsutaj 競技プログラミングの記録

FaceBook Hacker Cup 2021 Qualification Round C: Gold Mine

問題概要

問題原文 → Problem C2: Gold Mine - Chapter 2 | Facebook Hacker Cup - 2021 - Qualification Round

 N 頂点の木があり、それぞれの頂点  i には重み  C_i がついている。木から高々  K 個の辺素パスを取るとき、パスに含まれる重みの和の最大値を求めよ (ただしそれぞれの頂点の重みは一度のみ加算するものとする)。

※元の問題は C1, C2 に分かれていますが、C1 では  K = 1 で固定です。この記事では C2 について解説しています。

解法

頂点  1 を根として、根付き木として考察します。動的計画法 (DP) で解くことができます。

 \mathrm{dp}[i][j][s] := 頂点  i を根とした部分木の中で、すでに完成した辺素パスが  j 個あり、以下で表す状態が  s であるときの答え

  •  s = 0: 頂点  i はパスに含まれていない
  •  s = 1: 頂点  i i を端点とするパスに含まれており、それは今後親方向に伸びる可能性がある
  •  s = 2: 頂点  i はパスに含まれており、親方向に伸びる可能性はない

DP の定義より、答えは  \max_{k=0}^{K} \mathrm{dp}[1][k][2] に等しいです。

 s = 0 の場合は比較的簡単で、頂点  i が含まれない状態を漏れのないように列挙して遷移させればよいです。

 s = 1 の場合は、主に以下の 3 パターンがあります。

  •  i ですでに  s = 1 であるパターン (子要素が来てもその状態を保持しておく)
  •  i の子  c について  s = 1 であるものが伸びてくるパターン
  •  i からパスを始めるパターン

 s = 2 の場合は、主に以下の 3 パターンがあります。

  •  i ですでに  s = 2 であるパターン (子要素が来てもその状態を保持しておく)
  •  i の子  c_1, c_2 について  s = 1 であるものが伸びてくるパターン
    • この場合、頂点  i は端点となることはなく、 i を通過するようなパスが出来ます
  •  i からパスを始め、そこで即終了するパターン
    •  1 頂点のみを取りに行く場合に相当します

遷移がややこしいですが、これらを網羅すると解くことができます。

ソースコード

merge() 関数は、ある頂点 cur とその子 nxt に関して、DP の情報をマージするための関数です。この関数では子が存在しない場合 (つまり、その頂点が葉である場合) に対応できないため、別途その処理を書く必要があります。

また、 s = 1, 2 において  C_{\mathrm{cur}} の値を DP の中で加算してしまおうとすると、パスは点素ではないため何度加算されるかわかりません。すでに加算したかどうかのフラグを DP に持たせても良いですが、次元が増えてしまうので、DP 計算後にループで足してしまうほうが楽でしょう。

template<typename T> void chmax(T &a, T b) {a = max(a, b);}

// dp[頂点][ドリル回数][パス伸びてない, パス伸びている, その頂点を通るパスがある]
ll dp[55][55][3], buf[55][3];

void solve(int case_num) {
    int N; scanf("%d", &N);

    int K; scanf("%d", &K);
    // const int K = 1;
    
    vector<ll> C(N);
    for(auto &e : C) cin >> e;

    vector< vector<int> > G(N);
    for(int i=0; i+1<N; i++) {
        int a, b; scanf("%d%d", &a, &b);
        a--; b--;
        G[a].emplace_back(b);
        G[b].emplace_back(a);
    }
    
    fill(dp[0][0], dp[N+1][0], -LONGINF);
    
    auto merge = [&](int cur, int nxt) {
                     for(int i=K; i>=0; i--) {
                         for(int j=0; i+j<=K; j++) {
                             ll ma_dp_nxt = max(dp[nxt][j][0], dp[nxt][j][2]);
                             ll ma_dp_cur = max(dp[cur][i][0], dp[cur][i][2]);

                             // << 両方伸びている場合 >>
                             // 伸びているものを合体させる
                             // それ以前に、スタンバイさせてるやつが存在
                             chmax(buf[i+j+1][2], dp[cur][i][1] + dp[nxt][j][1]);

                             // << nxt が伸びている場合 >>
                             // cur が空いている場合に nxt を伸ばす (cur で切る場合もある)
                             chmax(buf[i+j][1], ma_dp_cur + dp[nxt][j][1]);
                             chmax(buf[i+j+1][2], ma_dp_cur + dp[nxt][j][1]);

                             // 伸びているものをスタンバイさせる
                             chmax(buf[i+j][1], ma_dp_cur + dp[nxt][j][1]);                  
                             
                             // << 両方伸びていない場合 >>
                             // cur が空いている場合に cur から伸ばす場合あり (cur で即切る場合もあり)
                             chmax(buf[i+j][1], ma_dp_cur + ma_dp_nxt);
                             chmax(buf[i+j+1][2], ma_dp_cur + ma_dp_nxt);

                             // nxt について、cur を使わない
                             chmax(buf[i+j][0], dp[cur][i][0] + ma_dp_nxt);
                             chmax(buf[i+j][1], dp[cur][i][1] + ma_dp_nxt);
                             chmax(buf[i+j][2], dp[cur][i][2] + ma_dp_nxt);                             
                         }
                     }

                     for(int i=K; i>=0; i--) {
                         dp[cur][i][0] = buf[i][0];
                         dp[cur][i][1] = buf[i][1];
                         dp[cur][i][2] = buf[i][2];
                     }
                 };

    ll ans = C[0];
    auto dfs = [&](auto &&self, int cur, int par) -> void {
                   for(auto to : G[cur]) {
                       if(to == par) continue;
                       self(self, to, cur);
                   }
                   fill(buf[0], buf[K+1], -LONGINF);
                   buf[0][0] = 0;
                   dp[cur][0][0] = 0;
                   for(auto to : G[cur]) {
                       if(to == par) continue;
                       merge(cur, to);
                   }

                   // 葉から始める場合
                   chmax(dp[cur][0][1], 0LL);
                   chmax(dp[cur][1][0], C[cur]);
                   
                   for(int i=0; i<=K; i++) {
                       for(int j=0; j<3; j++) {
                           ll &c = dp[cur][i][j];
                           c += (j >= 1) * C[cur];
                           if(cur == 0 and j == 2) chmax(ans, c);
                       }
                   }
               };
    dfs(dfs, 0, -1);
    printf("Case #%d: %lld\n", case_num, ans);
}

int main() {
    int T; scanf("%d", &T);
    for(int i=1; i<=T; i++) solve(i);
    return 0;
}

自分で解いているときは上記のような問題設定とは微妙に異なるもので考えてしまっていて、パスであって頂点 1 を通るものに関してはすべて頂点 1 を端点として考えることにしてしまっていました。こうすると、許容されるパスの数が  K 本だったり  K + 1 本だったりするので迷走していました。いい感じの問題に落とし込む力も重要ですね。