機械学習モデルを解釈する方法 Permutation Importance / Partial Dependence Plot

Machine Learning for Insights Challengeで勉強した「Permutation Importance」と「Partial Dependence Plot」についてまとめる。

Machine Learning for Insights Challengeとは

9月18~21日に、kaggleの「Machine Learning for Insights Challenge」という講座が開催された。
1日に1通メールが来て、機械学習関連の話を学べる。
この記事では、4日のうち前半2日間の部分をまとめる。

なお、教材は公開されているので、今からでも同じように学習できる。
それぞれの日の説明には演習問題がついていて、kaggle kernel上で実行できる。

さてInsightsを日本語に直すと「洞察、物事の本質を見抜くこと」となるが、具体的には何なんだろうか。
「コースの始めに」に相当するUse Cases for Model Insights | Kaggleを見ると、このコースでは以下について学べると書いてある。

データのどの機能が最も重要であるとモデルは判断したか?
モデルからの単一の予測について、データ内のそれぞれの特徴量がその特定の予測にどのように影響したか?
全体的に見て、それぞれの特徴量はモデルの予測にどのような影響を与えるのか(多数の起こりうる予測を考慮した場合、典型的な効果は何か?)

以下の記事に出てくるような「解釈性(Interpretaility)」の話題と近いと考えておけば良さそうだ。
【記事更新】私のブックマーク「機械学習における解釈性(Interpretability in Machine Learning)」 - 人工知能学会 (The Japanese Society for Artificial Intelligence)

Permutation Importance

資料はこちら。下の方のリンクから演習問題に行ける。
Permutation Importance | Kaggle

機械学習のモデルを構築したときに「どの特徴量が重要なのか(feature importance)」を知る方法の一つが「Permutation Importance」である。

  1. モデルの訓練をする
  2. 1つの列の値をシャッフルして、その結果のデータセットを使用して予測を行う。 これらの予測値と真の目標値を使用して、シャッフルすることで損失関数がどれだけ悪くなったかを計算する。 そのパフォーマンス低下は、あなたが今シャッフルした変数の重要性を測定している。
  3. データを元の順序に戻す(手順2のシャッフルを元に戻す)。データセットのそれぞれの列で手順2を繰り返して、各列の重要度を計算する。

演習問題では「タクシーの料金」の予測をしていた。
乗車地点の緯度・経度、降車地点の緯度・経度、乗客の人数というデータを元に機械学習モデルを作成する。
学習用のデータの一部は以下の通り。

fare_amount pickup_longitude pickup_latitude dropoff_longitude dropoff_latitude passenger_count
2 5.7 -73.982738 40.761270 -73.991242 40.750562 2
3 7.7 -73.987130 40.733143 -73.991567 40.758092 1
4 5.3 -73.968095 40.768008 -73.956655 40.783762 1
6 7.5 -73.980002 40.751662 -73.973802 40.764842 1
7 16.5 -73.951300 40.774138 -73.990095 40.751048 1

バリデーション用データのうち、ある特徴量を互いに入れ替えてから、予測を実行して、損失関数の値がどう変わるかを観察する。
例えば、乗車地点の緯度を入れ替えてから予測をすると、予測結果は多く変わってくるから、損失関数の値も大きくなる、
一方で、乗客の人数は料金に影響しないので、乗客の人数を入れ替えても結果は殆ど変わらないだろうと予想される。

すなわち、精度が下がったらモデル構築の上で重要な特徴量、精度が下がらなかったらモデル構築の上で重要でない特徴量と判断できる。

Partial Dependence Plot (PDP)

2日目のテーマはPartial Dependence Plotである。
Partial Dependence Plotは、それぞれの特徴量が予測にどのような影響を与えるかを知るのに役に立つ。

資料はこちら。下の方のリンクから演習問題に行ける。
Partial Plots | Kaggle

モデルの訓練をしたあとに、
検証データのある点を元にして、影響を見たい1つの特徴量だけを変化させて、学習済みのモデルで予測を実施する。すると「1つの特徴量が変化すると予測はどう変わるのか?」が分かる。
検証データの中から多数の点について同様の操作を実施し、その平均を描画する。

演習問題の最後の方は、1日目のPermutation Importanceと2日目のPartial Dependence Plotを組み合わせた複合問題だった。これによって両者の挙動の違いが分かった。

問題:予測可能な特徴量が特徴量AとBの2つしかない場合を考える。どちらも最小値は-1、最大値は1である。
特徴量AのPartial Dependence Plotを描画すると、全ての範囲にわたって急激に増加する。一方、特徴量BのPartial Dependence Plotは、全ての範囲にわたってよりゆっくり(急激ではなく)増加する。
このとき、特徴量AのPermutation Importanceは必ず特徴量BのPermutation Importanceより必ず大きいと言えるか?

ちょっと考えると正しそうかなーっと思ったけど、実はこれは必ずしも正しくない。
たとえば、特徴量Aが変化する場合には大きな影響を与えるが、ほとんどのデータでは全く同じ値を取る、という場合が考えられる。この場合、Permutation Importanceを計算してもほとんどの値は変わらないままであり、特徴量AのPermutation Importanceの値はそれほど大きくならない。