実例から学ぶ Python競技プログラミングの定数倍高速化シリーズ1:徒競走

競技プログラミングAtCoderの問題をPythonを使って解き、定数倍高速化した結果をまとめる。

定数倍高速化とは何か

定数倍高速化 - MonoBook より

定数倍高速化とは、アルゴリズムの改善による高速化とは違い、計算処理の方法を改善することにより計算量のオーダーを変えずに処理を高速化することである。

そしてこの記事では、一般論ではなく、具体的な問題に対する最適化を取り上げる。つまり、競技プログラミングの特定の問題を色々なコードで解いてみて、その結果を整理してそこから教訓を得ようというものだ。

(初回なのにタイトルに「シリーズ」とか入れてるけど、大丈夫なんだろうか。私が記事を書くやる気が続くだろうか……?)

今回取り上げるのは、AtCoder Beginner Contest 041 D問題 徒競走である。

注意事項

本検証はPython 3.8.2に言語がバージョンアップされた後に実施している。
そのため、処理時間はPython 3.8.2, PyPy3 (7.3.0)である。

処理時間は1回提出した結果であり、複数回の平均ではない。

解法

解法の細かい説明はしない。そこはこの記事の主眼ではないので。
というか、解説を読んで解法を理解してからACしたので、解法は公式解説PDFと全く同じである。

処理時間まとめ

以下の表は、Python3 / PyPy3の処理時間をまとめたものである。最初のコードから始めて、高速化を順次実行したときの、実行時間である。*1

コード Python (3.8.2) PyPy3 (7.3.0)
最初 TLE 750ms
高速化(1) 早期break TLE 324ms
高速化(2) 計算量の削減 TLE 243ms
高速化(3) sys.stdin.readline TLE 230ms
高速化(4) 真偽値判定を簡略化 TLE(3262ms) 230ms
高速化(5) 2**nを1<<nに変更 1626ms 169ms
高速化(6) indexをsetで管理 1032ms 270ms

最初

最初の提出コード:GitHub
最初の提出結果:AtCoder

正直、提出したときはPyPyでも通らないだろうと思っていた。 「これは計算量を落とせる(後述の高速化(2)のこと)のに、わざわざ計算量が大きい解法で解いているので、おそらくTLEになるだろう。でも一応提出してみるか」と思っていたら通ってしまった。
計算量は外側のループから順に、2n * n * mである。最大値を代入すると、 2^16 * 16 * (16*15)/ 2 = 125829120
あれ1億2500万だぞ……!? しかもこのコードは途中でbreakしないから、ループが全部回るぞ……何で通ったの???

一応ACは取れたので良いのだが、Pythonでも通るようにできないかと思って、色々と定数倍高速化を試みた。
なおこの問題の制限時間は3秒である。

高速化(1) 早期break

コード差分:GitHub

とりあえずパッと思いつくものを入れた。 フラグを最初にTrueにしておいて、どれか一つでも条件を満たすならFalseになるというパターンのコードである。だから、Falseとわかった時点でその後の計算は不要であり、ループをbreakできる。

計算量は……えーっとどうなるんだ。 最悪のケースになる入力は、多分以下の場合だろう。

16 120
1 2
1 3
2 3
1 4
2 4
3 4
1 5
……(後略)

そうすると、

  • 頂点2が1位になりえないのは、1個目の辺を見た段階でわかる(のでここでbreak)
  • 頂点3が1位になりえないのは、2個目の辺を見た段階でわかる(のでここでbreak)
  • 頂点4が1位になりえないのは、4個目の辺を見た段階でわかる(のでここでbreak)

……というふうに、各頂点に対する辺のループの数は三角数+1になる。毎回120個すべてを見ていた場合と比べて、計算量は大雑把に3分の1かな。多分。
三角数はxの2次関数なので、x2区間[0,1]で積分するイメージである)

PyPyの実行時間では半分以下と著しく速くなっているが、Pythonでは依然としてTLEである。

高速化(2) 計算量の削減 /(1)を上書き

コード差分:GitHub

定数倍高速化と散々言ってるが、ここだけは計算量の式が変わる高速化です(=定数倍高速化ではない)。
may_be_first(1位になる可能性があるか否か)を最初にまとめて計算しておけば、繰り返し計算する必要はない。
計算量が2n * n * m → 2n * (n + m) に削減された。
数字の上では 1億2500万→891万 で10分の1以下になってるけど、PyPyの時間はそこまで短くなってはいない。なんでやろ。

高速化(3) 入力 input = sys.stdin.readline

コード差分:GitHub

Python競技プログラミングをするならまずやろうと言われるinput = sys.stdin.readline
簡単にできるから入れてみたものの、あまり変化なし。
(これのおかげでTLEが消えて無事に通ったという経験がない。個人的にはあまり効果を感じていないので、普段はinput = sys.stdin.readlineを書いていない。まぁ、今回は入力が最大でも121行だから、入力を高速化しても効果が殆ど無いのだろう。)

高速化(4) 真偽値判定を簡略化

コード差分:GitHub

bit AND を取った結果が0か非0かを判定するためにx & y > 0と書いていたが、0より大きいか小さいか判定する必要はないので単にx & yでよい。

TLEのとき、表示されている時間は本当にかかった時間とは限らない。制限時間を少し過ぎたところでコードを強制終了している。
1つ前までは、3.3秒あたりで強制終了しているようにみえる(本当のところは不明です)。 しかし今回は3262msである。
つまり今までは3.3秒を過ぎて強制終了されたが、 今回は3.262秒で計算が終わったが、それが3秒より大きいのでTLEであると推測される(本当のところは不明です)。
あと一息だ。

高速化(5) 2**nを1<<nに変更

コード差分:GitHub

ついにPythonでも

俺の中では「2**nと書こうが1<<nと書こうが同じじゃん」と思ってたけど、 3262ms→1626msだ! 時間がなんと半分以下になった!!
コンパイラの最適化で勝手に書き換えられるかなと思っていたけど、最適化は走らないんだね……一般のa**nの場合はこの書き換えは使えない。だから、冷静に考えたら、一部の数のときだけ適用できる書き換えをするのは難しいんだろうな……)

マジかよ。累乗計算をビット演算に変えるだけでこんなに変わるの!?

特定桁のbit ANDを取るとき、つい癖でif x & 2**digit > 0と書いていたけど、これがそれほどまでに動作の遅い書き方だとは知らなかった。

高速化(6) indexをsetで管理する /(4)(5)を上書き

コード差分:GitHub

idxと特定桁のbit ANDを取る処理を何度かやっている。
「これ……どの桁にビットが立っているかを最初に集合(set)に入れて、あとはdigitが集合に含まれるか判定するほうが速いかもしれない?」と考えた。
ここの集合の作り方はリスト内包表記で簡潔に書くと動作が速い。

Pythonだと1626→1032msと速くなった。
一方で、PyPyだと169ms→270msと逆に遅くなった。PyPyだと時間がかかる処理なんだろうか。

まとめ

2**n は非常に遅い書き方なので 1<<n にしましょう。

ただし演算子の優先順位には注意!!
今回も2**n1<<n に置換した後、確認せずにsubmitして、Runtime Errorを1回出した。
idx - (1<<digit)とすべきところでidx - 1<<digitと書いて、減算を最初に計算してしまったからである。
以前も同じミスにハマって、そのときに書いた記事が以下である。

linus-mk.hatenablog.com

それでは。

*1:余談になるが、markdownで表を作るの面倒だなぁと探したら以下のWebツールが見つかった。「MarkdownテーブルをExcelライクな操作で簡単に作成できるツールです」と書いてある。便利。 https://notepm.jp/markdown-table-tool