sklearn の StratifiedKFold のサンプル数に対する挙動の変化
はじめに
scikit-learn の StratifiedKFold の引数 n_splits と、データのサンプル数による挙動と、分割不能になる数を確認します。
Stratified KFold とは
データセットのクラスを保持しながらデータを分割する方法です。
たとえば、データセット内の、クラス 1 とクラス 2 のデータの比率が 9:1 であった場合、テストデータにおいてもクラスの比率が 9:1 であると、実運用段階に近く、よい正確な精度の見積もりができる可能性が高いと考えられます。
このような比率を保持した分割を行うのが StratifiedKFold です。
scikit-learn
疑問点
scikit-learn の StratifiedKFold は可能な限りデータのクラスの比率を保持しますが、そのままでは比率の保持ができないデータを渡した場合はどのような挙動をするのでしょうか?
例えばクラス 1 に属するデータは 5 個しかないにも関わらず、学習・試験データを 6 セット求められた場合です。
以降、順に挙動を確認します。
1. n_splits <= 最小クラスのデータ数
確認のための、目的変数(クラス)は、クラス 3 のデータが 3 個、クラス 4 のデータが 4 個、クラス 5 のデータが 5 個とします。
ここから、3 つの学習・試験データの分割を得ることを考えましょう。この場合データは十分にあるので、素直な挙動になります。
各データが 1 度ずつ登場し、3 つの分割方法が得られます。
2. 最小クラスのデータ数 < n_splits <= 最大クラスのデータ数
先の場合と同じデータを用いて、5 つの学習・試験データの分割を得ることを考えます。
最小クラスは 3 データしかないので、データの分割はできなさそうですが、分割できてしまいます。
内容を見ると、試験データにおいては、データ数が少ないクラスのデータが含まれないことがあるようです。
また、この際、「UserWarning: The least populated class in y has only 3 members, which is less than n_splits=5.」 と警告が出ます。
警告は出ますが、実行はできるので注意が必要です。
3. n_splits > 最大クラスのデータ数
先の場合と同じデータを用いて、6 セット得ようとしてみます。
この場合は、最大クラスのデータ数以上にデータを分割することを求められているので当然動作しません。
"ValueError: n_splits=6 cannot be greater than the number of members in each class."
とエラーが出ます。
私は今まで、この each class というのは、「任意のクラスのデータよりも~」という意味だと勘違いしていました。
ただ、n_splits の数は最初クラスのデータ数まで、と説明している記事もあったので、scikit-learn のバージョンにより挙動が異なるのかもしれません。
一応 scikit-learn の最新版のコードを読んで、最大クラスのデータ数で制限をかけていることを確認しました。
まとめ:n_splits の数に対する挙動
- 最もサンプル数の少ないクラスのサンプル数以下である場合
- そのまま動作する
- 最もサンプル数の少ないクラスのサンプル数より多く、最もサンプル数の多いクラスのサンプル数以下である場合
- UserWarning を出しながら動作する
- 最もサンプル数の多いクラスのサンプル数より多い場合
- Error
動作環境
- python: 3.8.11
- scikit-learn: 0.24.2