python競技プログラミングで、二項係数の計算でTLEしたので高速化した話

競技プログラミングAtCoderなど)によくある、二項係数(コンビネーション)を109+7で割った余りを求める話です。

前提条件

Pythonで実装するときに、気をつけるべきポイントを書いておきます。
C++Javaを使ってる人は、この記事を読む意味はありません。細かい実装を気にしなくても余裕で間に合うので。
しかしPythonを使う場合は、コードの細かい違いによってTLEになる危険性が多くあります。

よくやる二項係数 (nCk mod. p)、逆元 (a^-1 mod. p) の求め方 - けんちょんの競プロ精進記録
に記載してある内容を理解していることを前提とします。

また、これは上記記事の2番(nが大きく、kが105程度)のアルゴリズムの実装です。 よく使う方である、上記記事の1番(階乗テーブルを作る方の)解法は書いていません(が、間接的に参考になると思います)。

いろいろ試してみましたが、使い勝手と速度を考えると、ライブラリには「教訓2」の形で入れておくのが良いと思います。

f:id:soratokimitonoaidani:20200223234352p:plain

問題設定

n <= 109, k <= 2*105とする。 nCkを素数pで割った余りを求めたい。

n = 10 ** 9
k = 2 * 10 ** 5
mod = 10**9 + 7

# x ** a をmodで割った余りを、O(log(a))時間で求める。
def power(x, a):
    if a == 0:
        return 1
    elif a == 1:
        return x
    elif a % 2 == 0:
        return power(x, a//2) **2 % mod
    else:
        return power(x, a//2) **2 * x % mod

# xの逆元を求める。フェルマーの小定理より、 x の逆元は x ^ (mod - 2) に等しい。計算時間はO(log(mod))程度。
# https://qiita.com/Yaruki00/items/fd1fc269ff7fe40d09a6
def modinv(x):
    return power(x, mod-2)

modinv_table = [-1] * (k+1)
for i in range(1, k+1):
    modinv_table[i] = modinv(i)

def binomial_coefficients(n, k):
    ans = 1
    for i in range(k):
        ans *= n-i
        ans *= modinv_table[i + 1]
        ans %= mod
    return ans

print(binomial_coefficients(n, k))

こういうコードを書いた。
結論からいうと、アルゴリズムが間違っているわけではないが、このコードはTLEになる。

TLEを確認する

AtCoderのコードテスト上で以下を確認した。
(コード提出と異なり、2秒を超えてもTLEにはならないんだな。一部の問題は実行時間制限が4秒や6秒になっているから、)

n= 10**9, k = 10**5print(binomial_coefficients(n, k))を計算する
→1740ms程度かかった。

n= 10**9, k = 2 * 10**5print(binomial_coefficients(n, k))を計算する
→3480ms程度かかった。

というわけで、制限時間が2秒の場合は、105でギリギリ間に合うが、2*105だと間に合わない。

教訓1 python標準のpow()関数は「xyをpで割った余りを求める」こともできる

まさかそんなことはできないだろうと思って自分で繰り返し二乗法を書いていたのだが、「xyをpで割った余りを求める」こともできる。 これを使えば、上のコードは以下のように書ける。

n = 10 ** 9
k = 2 * 10 ** 5
mod = 10**9 + 7

def modinv(x):
    return pow(x, mod-2, mod)

modinv_table = [-1] * (k+1)
for i in range(1, k+1):
    modinv_table[i] = modinv(i)

def binomial_coefficients(n, k):
    ans = 1
    for i in range(k):
        ans *= n-i
        ans *= modinv_table[i + 1]
        ans %= mod
    return ans

print(binomial_coefficients(n, k))
n= 10**9, k = 2 * 10**5print(binomial_coefficients(n, k))を計算する
→870ms程度かかった。

よ、余裕で間に合うじゃん!!

組み込みのpow()のほうがもちろん最適化されているので、同じ計算をするとしても高速に動作する。

教訓2 高速に逆元を求める方法

フェルマーの小定理を使わずに逆元を求めるという、スゴイ方法がある。この方法を最初に考えた人はどういう頭の構造をしてるんだろうか? 累乗計算をする必要がないので、フェルマーの小定理に比べて高速になる。
証明は、Python で二項係数 nCr を高速に計算したい | Satoooh Blogなどに詳しい。

n = 10 ** 9
k = 2 * 10 ** 5
mod = 10**9 + 7

# def modinv(x):
# xの逆元を求める際に[mod % x]の逆元が必要なので、関数の形でxの逆元を直接求めることは難しい。再帰を使えば行けそうだけど。

modinv_table = [-1] * (k+1)
modinv_table[1] = 1
for i in range(2, k+1):
    modinv_table[i] = (-modinv_table[mod % i] * (mod // i)) % mod

def binomial_coefficients(n, k):
    ans = 1
    for i in range(k):
        ans *= n-i
        ans *= modinv_table[i + 1]
        ans %= mod
    return ans

print(binomial_coefficients(n, k))
n= 10**9, k = 2 * 10**5print(binomial_coefficients(n, k))を計算する
→190ms程度かかった。

よ、余裕で間に合うじゃん!! (2回目)

フェルマーの小定理から逆元を求める方法の計算量は、log(109+7)程度である。しかし、pythonは遅いので、そのlogのせいでTLEになりかねない。

教訓3 逆元を求める回数を極力減らす

逆元を求めるというのは計算量の重い操作である。どうしても剰余演算をしなければいけないからだ。
したがって、同じ計算をするのでも、逆元を求める操作の回数が少ないほうが計算時間は少なくて済む。
……という点に注意して、元のコードを眺めると、何度も逆元(mod上の割り算)を計算していることに気づく。それを一回で済むように書き換えればよい。

n = 10 ** 9
k = 2 * 10 ** 5
mod = 10**9 + 7

# x ** a をmodで割った余りを、O(log(a))時間で求める。
def power(x, a):
    if a == 0:
        return 1
    elif a == 1:
        return x
    elif a % 2 == 0:
        return power(x, a//2) **2 % mod
    else:
        return power(x, a//2) **2 * x % mod

# xの逆元を求める。フェルマーの小定理より、 x の逆元は x ^ (mod - 2) に等しい。計算時間はO(log(mod))程度。
# https://qiita.com/Yaruki00/items/fd1fc269ff7fe40d09a6
def modinv(x):
    return power(x, mod-2)

def binomial_coefficients(n, k):
    numera = 1  # 分子
    denomi = 1  # 分母

    for i in range(k):
        numera *= n-i
        numera %= mod
        denomi *= i+1
        denomi %= mod
    return numera * modinv(denomi) % mod

print(binomial_coefficients(n, k))
n= 10**9, k = 2 * 10**5print(binomial_coefficients(n, k))を計算する
→85ms程度かかった。

よ、余裕で間に合うじゃん!! (3回目)

注意すべきは、テーブルにしていないので毎回の計算に時間がかかることだ。
色々なkに対して二項係数nCkを何度も求める場合は不向き。

余談:python3.8では組み込みpow()関数で逆元が計算できるらしい……

python3.8では、modを取ったときの-1乗や-2乗などもできるらしい。以下は公式ドキュメントに記載されている例だ。
組み込み関数 — Python 3.8.2rc2 ドキュメント

>>> pow(38, -1, mod=97)
23
>>> 23 * 38 % 97 == 1
True

mod 97の上で38の逆元を求めている。
これが使えるようになったら、この記事自体が無用になってしまうんじゃないか?

ただし上記は3.8以降であり、3.7以前では使えない。
手元のPython 3.7.6 ではこのようになった。

>>> pow(38, -1, 97)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: pow() 2nd argument cannot be negative when 3rd argument specified
>>> 

きっかけ

2020年2月22日のABC156 D問題でTLEから抜け出せなくて、結局C++で提出したが、レートが大きく下がってしまった。そこで今回、振り返って分析をした。
ちなみにC++だと、上記の教訓に挙げたことを考えなくても、余裕で間に合う。
Submission #10287127 - AtCoder Beginner Contest 156