yukicoder No.2164 Equal Balls

yukicoder.me

解説というよりかは、ACするまでに考えたこととそれからの高速化について考えたことのメモです。
具体的には、私の初AC時は${\rm O}(NM + M ^ {3} \log M)$ だったのですが、そこから ${\rm O} ( NM + M ^ {2} \log^{2} M ) $への高速化を行いました。
また、計算量を考えるうえで $ M = {\rm max} A = {\rm max} B $ を暗黙のうちに用いています。

ACするまでに考えたこと

まず、Aの和、Bの和をそれぞれ持ったDPを考えたくなりますが、$N = 10 ^ 5 $ のため計算量的に無理そうだと悟ります。そこで、(Aの和) $-$ (Bの和)といった変換を行うことで、和について持つkeyは1つで良くなりそうだと気づきますが、これでもまだ厳しそうです(実際は直前 $ M $ 個の情報全部持たなければいけないためおそらくkeyは1つでは無理)。以降、 $s = $ (Aの和) $-$ (Bの和)とします。
DPの遷移について考えてみると、$M - 1$ 項目までは$s$ はなんでもよいですが、 $ M $ 項目以降は直前の $ M $ 個の和が完全に一致している($ s = 0 $ である)必要があることに気づきます。そして、1項目が $ s = 1 $ だった場合、$ M + 1 $に遷移する際は直前 $ M $ 個の $ s $ が1減ることになります。すなわち、 $ M $ 個の和について常に $ s = 0 $ が成り立っていなければならないので、この均衡を保つためには、 $ M + 1 $ でも $ s = 1 $ とならなければいけないことが分かります。

したがって、要素番号 $ i $ の $ i \bmod M $ が等しいものについてはすべてにおいて $ s $ が等しいものを選択し、すべての余剰 $ 0 \leq x \leq M - 1 $についての $ s $ の和が0となっているものを数え上げればいいことになります。これならば、マイナスを考慮しても状態が$ 300 \times 600 $程度に抑えられそうなため、遷移に $ {\rm O} (M) $ 程度かけても間に合いそうだと考えました。

方針がたったため、各要素 $ i $ について$\bmod M $ でまとめます。すなわち、余剰が同じものについて $ s $ が同じものについての積をとります。しかし、ここで少し困りました。 $ A[i] $の中から $ a $ 個選ぶ、 $ B[i] $ の中から $ b $ 個選ぶ場合、$ s = a - b $ となり、それぞれ $ A[i] $ と $ B[i] $ についての二項係数の積で求めることができますが、同じ$s $ の中でも($s = 1$を例に挙げると、$(a, b) = (1, 0), (2, 1), (3, 2)$などのように) 複数のものが同じ$s$に集められる場合があることに気がつきます。愚直にすべての積をとってしまうと、$ {\rm O} (M ^ 2)$かかってしまい、状態数を考えると$ {\rm O} (NM ^ 2) $まで膨れ上がるため、高速化を考える必要がありました。まず、思いついたものとしてそれぞれを$ {\rm O} (M \log M) $で畳み込むことですが、それでも $ M = 600 $程度の $ {\rm O} (N M \log M) $ が間に合うとは思えません。そこで、DEGwerさんの数え上げpdfに二項係数の畳み込みみたいなの載ってないかなと思い、見に行くと30ページの公式集に見つかります。

drive.google.com

$$ \sum_{i = 0}^{k} \binom{n}{i} \binom{m}{k -i } = \binom{n + m}{k} $$

らしいです。$ (1 + x) ^ {N}$の係数が二項係数になるので、

$$ (1 + x) ^ {A[i]} (1 + x) ^ {B[i]} = (1 + x) ^ {A[i] + B[i]} $$

から自明だった...。ただ、$a + b$ではなく,$a - b$で使用しているためこの公式ができるか確信が持てなかったのですが,パスカルの三角形は左右対称的で$B[i]$から$b $ 個選ぶ場合の数は、$B[i]$から $b$ 個選ばないものを選ぶ場合の数に等しいことからこの公式が使えそうだと思いいたります。 このようにして、各余剰について$s$に対する二項係数の積を求めるパートは$ {\rm O} (NM )$でできました。

続いて、最後の和の数え上げパートですが、ここではMの和を情報に持つ必要があるため、状態数が $ {\rm O} (M^{2}) $ 、遷移が $ {\rm O} (M)$ 、遷移回数が $ {\rm O} ( M ) $ 、全体が$ {\rm O} (M ^ {4})$の計算量になり、ここでも高速化が必要になりました。ここの高速化は畳み込みしか思いつかなかったのもありますが、今回は長さ $ M ^ 2 $ 程度のものを $ M $ 回畳み込めばよいため、間に合いそうだと思いました。よって、計算量 $ {\rm O} ( M^{3} \log M ) $ 程度で間に合いました。

高速化について

無事ACできたのですが、初提出の実行時間は4041 msで本当に想定解なのか疑わしいです...。解説を見ようにもこの日はACとっても見れない仕様になっているようでした。
最終的には、以下の点を改良することで1000 msを切る程度まで実行時間を改良することができました。

  • 二項係数の計算を階乗前計算から和の漸化式による計算を切り替える(パスカルの三角形からの計算)
  • $ \bmod M $ が等しいものをまとめるパートで毎回余剰を求めないようにする
  • $ \bmod M $ が異なるものについて全体の畳み込みを行う際にサイズが近いもので畳み込みされるようにする

前2つはなるべく $ \bmod $ の計算回数を減らそうというものなのでいいとして、最後のものについて軽く触れます。

これは、端的に言えば、先頭から畳み込み演算を行っていたものを、並列的な畳み込み演算に切り替えたことになります。配列のサイズを $ N $ とすると、どちらも $N - 1 $ 回の計算であることには変わりないですが、畳み込みのサイズが後者の方が抑えられそうな気がしました。

この発想に至った経緯について触れます。畳み込みの計算量は配列 $ A $ のサイズを $ | A | $ 、配列 $ B $ のサイズを $ | B | $ とすると、 $ {\rm O} ( ( |A| + |B| ) \log ( |A| + |B| ) ) $になります。$ \log $の計算量評価は難しいため、ひとまず $\log$ なしで考えてみると畳み込みの累積演算は以下の問題に帰着されるかと思います。

$ N $ 個の要素からなる配列 $ A_i $ が与えられます。要素$i$、$j$について合体させるには、コストが $ A_i + A_j $ かかり、合体後のサイズは$ A_i + A_j $ となります。最終的に配列の要素が1つになるまで合体させるのにかかる最小コストはいくつですか?

これは、AtCoderのEducational DP Contest / DP まとめコンテスト 「N - Slimes 」の隣接選択の条件が外れただけのものになります。

atcoder.jp

隣接選択の条件が外れたより近い問題としてはAtCoder Beginner Contest 252の「F - Bread」やプログラミングコンテストチャレンジブック第2版49ページの問題などがります。いずれもなるべく小さいサイズのもの2つ選んでマージするという貪欲法によって最適解が得られました。以上の経験から、先頭から累積的に計算するよりも並列的に計算した方が計算量が良くなるのではないかと予想しました。

atcoder.jp

実際に評価してみると、8個の要素からなる配列を先頭からマージしていく場合、

  • $ a_1 a_2 a_3 a_4 a_5 a_6 a_7 a_8 $  
    合計コスト $0$
  • $ ( a_1 + a_2 ) a_3 a_4 a_5 a_6 a_7 a_8 $  
    合計コスト $a_1+a_2$
  • $ ( ( a_1+a_2)+a_3) a_4 a_5 a_6 a_7 a_8 $  
    合計コスト $2(a_1+a_2) + a_3$
  • $ ( ( ( a_1+a_2)+a_3)+a_4) a_5 a_6 a_7 a_8 $  
    合計コスト $3(a_1+a_2) + 2a_3 + a_4$
  • $ ( ( ( ( a_1+a_2)+a_3)+a_4)+a_5) a_6 a_7 a_8 $  
    合計コスト $4(a_1+a_2) + 3a_3 + 2a_4 + a_5$
  • $ ( ( ( ( ( a_1+a_2)+a_3)+a_4)+a_5)+a_6) a_7 a_8 $  
    合計コスト $5(a_1+a_2) + 4a_3 + 3a_4 + 2a_5 + a_6$
  • $ ( ( ( ( ( ( a_1+a_2)+a_3)+a_4)+a_5)+a_6)+a_7) a_8 $  
    合計コスト $6(a_1+a_2) + 5a_3 + 4a_4 + 3a_5 + 2a_6 + a_7$
  • $ ( ( ( ( ( ( ( a_1+a_2)+a_3)+a_4)+a_5)+a_6)+a_7)+a_8) $  
    合計コスト $7(a_1+a_2) + 6a_3 + 5a_4 + 4a_5 + 3a_6 + 2 a_7 + a_8$

のようになります。配列 $ A $ のサイズを $ M $ 、 $ A $ の要素をすべて $ 1 $ とすると、合計コストは $ M ^ 2 ( M + 1 ) / 2 $ 程度になります。($ A $ の要素が$ 1 $の場合は、$ M (M + 1) / 2$ 程度)

8個の要素からなる配列を並列的にマージしていく場合、

  • $ a_1 a_2 a_3 a_4 a_5 a_6 a_7 a_8 $  
    合計コスト $0$
  • $ (a_1+a_2) (a_3+a_4) (a_5+a_6) (a_7+a_8) $  
    合計コスト $a_1 + a_2 +a_3 + a_4 + a_5 + a_6 + a_7 + a_8$
  • $ ( ( a_1+a_2)+(a_3+a_4)) ( ( a_5+a_6)+(a_7+a_8)) $  
    合計コスト $2(a_1 + a_2 +a_3 + a_4 + a_5 + a_6 + a_7 + a_8)$
  • $ ( ( ( a_1+a_2)+(a_3+a_4))+ ( ( a_5+a_6)+(a_7+a_8))) $  
    合計コスト $3(a_1 + a_2 +a_3 + a_4 + a_5 + a_6 + a_7 + a_8)$

となります。配列 $ A $ のサイズを $ M $ 、 $ A $ の要素をすべて $ M $ とすると、合計コストは $ M^ 2 \log M $ 程度になります。($ A $ の要素が$ 1 $の場合は、$ M \log M $ 程度)

若干疑わしかったため、念のためEDPC - TのACコードに

8
1 1 1 1 1 1 1 1

を投げてみます。 すると無事 $ 8 \times \log_2 8 = 8 \times 3 = 24 $ が返ってきました。

畳み込みの計算量を $ {\rm O} (M_1 \log M_2 ) $ とすると、ここでは $ M_1 $ の部分を考えたことになります。 $ M_2 $ の部分は $ M ^ {2} $ になっても定数倍を考慮すると、$ \log M $ に落ちます。したがって、先頭から累積的にマージさせると$ {\rm O} (NM + M ^ {3} \log M) $ であったのに対し、並列的にマージすることによって計算量を $ {\rm O}( NM + M^{2} \log ^{2} M ) $ まで落とすことができました。 $ NM $ は二項係数の積を求めるパートに依存する計算量です。
C++の標準ライブラリでは、std::accumulateが先頭から累積的にマージするもので、C++17以降で使えるstd::reduceが並列的に計算を行うものになります。C++に馴染みのない方へ簡単に説明すると、accumulate的なマージがデータ構造でいうと累積和にあたり、reduce的なマージがデータ構造でいうとセグメント木にあたるかと思います。

実際に、accumulateとreduceを使って実装してみるとaccumulateからreduceに切り替えることによって、2948 msから996 msへとなり2秒近く早くなりました。こんなことによってオーダーレベルで改善されてるとは思わなかった...。writer様やtester様は添え字をstd::queueに持ってマージしているようです。速い解法の人はみんなstd::queue、std::deque、std::priority_queueなどを使ったいたので、強い人には常識なのかな...?

std::accmulateを用いた実装 (2948 ms)

#include <atcoder/all>
using namespace std;
using namespace atcoder;
using mint = modint998244353;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    int N, M, r = 600;
    cin >> N >> M;
    vector<int> A(N), B(N);
    for(auto &&v:A)cin >> v;
    for(auto &&v:B)cin >> v;

    vector<vector<mint>> comb(r + 1), tb(M, vector<mint>(r + 1, 1));
    for(int i = 1; i <= r; i++){
        comb[i].resize(i + 1);
        comb[i][0] = comb[i][i] = 1;
        for(int j = 1; j < i; j++) comb[i][j] = comb[i - 1][j - 1] + comb[i - 1][j];
    }

    for(int i = 0, rem = 0; i < N; i++){
        for(int j = -B[i]; j <= A[i]; j++) tb[rem][j + 300] *= comb[A[i] + B[i]][B[i] + j];
        for(int j = A[i] + 1; j <= 300; j++) tb[rem][j + 300] = 0;
        for(int j = B[i] + 1; j <= 300; j++) tb[rem][-j + 300] = 0;
        if(++rem >= M)rem -= M;
    }

    cout << accumulate(tb.begin(), tb.end(), vector<mint>(1, 1), [](vector<mint> lhs, vector<mint> rhs){
        return convolution(lhs, rhs);
    })[300 * M].val() << '\n';
}

std::reduceを用いた実装 (996 ms)

#include <atcoder/all>
using namespace std;
using namespace atcoder;
using mint = modint998244353;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    int N, M, r = 600;
    cin >> N >> M;
    vector<int> A(N), B(N);
    for(auto &&v:A)cin >> v;
    for(auto &&v:B)cin >> v;

    vector<vector<mint>> comb(r + 1), tb(M, vector<mint>(r + 1, 1));
    for(int i = 1; i <= r; i++){
        comb[i].resize(i + 1);
        comb[i][0] = comb[i][i] = 1;
        for(int j = 1; j < i; j++) comb[i][j] = comb[i - 1][j - 1] + comb[i - 1][j];
    }

    for(int i = 0, rem = 0; i < N; i++){
        for(int j = -B[i]; j <= A[i]; j++) tb[rem][j + 300] *= comb[A[i] + B[i]][B[i] + j];
        for(int j = A[i] + 1; j <= 300; j++) tb[rem][j + 300] = 0;
        for(int j = B[i] + 1; j <= 300; j++) tb[rem][-j + 300] = 0;
        if(++rem >= M)rem -= M;
    }

    cout << reduce(tb.begin(), tb.end(), vector<mint>(1, 1), [](vector<mint> lhs, vector<mint> rhs){
        return convolution(lhs, rhs);
    })[300 * M].val() << '\n';
}