hogecoder

tsutaj 競技プログラミングの記録

Codeforces Round #576 Div.1 D: Rectangle Painting 1

本番は嘘解法で通してしまったのでちゃんと書きます。

問題概要

原文 → Problem - D - Codeforces

 N \times N のグリッド状のマスが与えられ、それぞれは白または黒に塗られている。

このマスの任意の  h \times w 長方形領域に対して、その領域内にあるマスを全て白にするという操作を何度でもすることができる。この操作のコストは  \max(h, w) である。

全てのマスを白くするために必要なコストの合計の最小値を求めよ。

解説

まず、操作を行う長方形領域が重ならないし、接しもしないことが重要である。そうなることを簡単に証明する。

ふたつの長方形領域があって、その大きさがそれぞれ  (h_1, w_1), (h_2, w_2) であり、これらが重なっているまたは接しているものとする。この長方形領域をどちらも含むような長方形領域であって幅と高さそれぞれ最小のものをとり、そのサイズを  (H, W) とおくと、重なっているまたは接しているという仮定から  H \leq h_1 + h_2 および  W \leq w_1 + w_2 が成り立つ。

 (H, W) の長方形を取ることで損をしないことを示す (これが示せれば重なるケースや接するケースを考慮しなくて良い)。 (h_1, w_1), (h_2, w_2) について  \max をどこでとるかで場合分けして考える。

  •  \max(h_1, w_1) = h_1, \max(h_2, w_2) = h_2 のとき
    • 上の不等式から明らかに成立
  •  \max(h_1, w_1) = w_1, \max(h_2, w_2) = w_2 のとき
    • 上の不等式から明らかに成立
  •  \max(h_1, w_1) = h_1, \max(h_2, w_2) = w_2 のとき
    • 別々に取ると  h_1 + w_2 だけコストがかかるが、 w_1 \leq h_1 より  W \leq w_1 + w_2 \leq h_1 + w_2 であることがわかり、 h_2 \leq w_2 より  H \leq h_1 + h_2 \leq h_1 + w_2 であることもわかるので、  (H, W) の長方形を取って損しないことが分かる
  •  \max(h_1, w_1) = w_1, \max(h_2, w_2) = h_2 のとき
    • 上と同様の理由で成立

よって長方形領域が重なる・接することはない。

あとはどのようにして解くかだが、ある長方形領域に着目したときの最適な値をメモ化再帰で求めればよい。ある行または列に対して、はじめから全て白であればそこで長方形領域が分かれるとみなしてよく、より小さい問題にして解くことができる。

ソースコード

int dp[51][51][51][51];
int sum[51][51];

int get_sum(int lx, int ly, int rx, int ry) {
    return sum[rx][ry] + sum[lx][ly] - sum[lx][ry] - sum[rx][ly];
}

int solve(int lx, int ly, int rx, int ry) {
    if(dp[lx][ly][rx][ry] >= 0) return dp[lx][ly][rx][ry];
    int bx = (rx - lx), by = (ry - ly);
    if(bx == 0 or by == 0) return 0;
    
    int res = max(bx, by);
    for(int i=lx; i<rx; i++) {
        if(get_sum(i, ly, i+1, ry) == 0) {
            res = min(res, solve(lx, ly, i, ry) + solve(i+1, ly, rx, ry));
        }
    }
    for(int i=ly; i<ry; i++) {
        if(get_sum(lx, i, rx, i+1) == 0) {
            res = min(res, solve(lx, ly, rx, i) + solve(lx, i+1, rx, ry));
        }
    }
    return dp[lx][ly][rx][ry] = res;
}

int main() {
    int N; scanf("%d", &N);
    for(int i=1; i<=N; i++) {
        for(int j=1; j<=N; j++) {
            char c; scanf(" %c", &c);
            if(c == '#') sum[i][j]++;
        }
    }

    for(int i=0; i<=N; i++) {
        for(int j=1; j<=N; j++) {
            sum[i][j] += sum[i][j-1];
        }
    }
    for(int j=0; j<=N; j++) {
        for(int i=1; i<=N; i++) {
            sum[i][j] += sum[i-1][j];
        }
    }

    fill(dp[0][0][0], dp[N+1][0][0], -1);
    cout << solve(0, 0, N, N) << endl;
    return 0;
}