StratifiedKFold のサンプル数に対する挙動

はじめに

scikit-learn の StratifiedKFold の引数 n_splits と、データのサンプル数による挙動と、分割不能になる数の確認です。

Stratified KFold とは

データセットの型よりを保持しながらデータを分割する方法です。
たとえば、クラス1とクラス2のデータの比率が9:1 であった場合、テストデータにおいても9:1の比率であると、実運用段階に近い、よい精度の見積もりができる可能性が高いと考えられます。
このような比率を保持した分割を行うのが StratifiedKFold です。
https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html#sklearn.model_selection.StratifiedKFold

疑問点

scikit-learn の StratifiedKFold は可能な限りデータのクラスの比率を保持しますが、そのままでは比率の保持ができないデータを渡した場合はどのような挙動をするのでしょうか?

例えばクラス1に属するデータは5個しかないにも関わらず、学習・試験データを6セット求められた場合です。
以降、順に挙動を確認します。

1. n_splits <= 最小クラスのデータ数

確認のための、目的変数(クラス)は、クラス3のデータが3個、クラス4のデータが4個、クラス5のデータが5個とします。

ここから、3つの学習・試験データの分割を得ることを考えましょう。この場合データは十分にあるので、素直な挙動になります。
各データが1度ずつ登場し、3つの分割方法が得られます。

2. 最小クラスのデータ数 < n_splits <= 最大クラスのデータ数

先の場合と同じデータを用いて、5つの学習・試験データの分割を得ることを考えます。

最小クラスは3データしかないので、データの分割はできなさそうですが、分割できてしまいます。
内容を見ると、試験データにおいては、データ数が少ないクラスのデータが含まれないことがあるようです。
また、この際、「UserWarning: The least populated class in y has only 3 members, which is less than n_splits=5.」 と警告が出ます。
警告は出ますが、実行はできるので注意が必要です。

3. n_splits > 最大クラスのデータ数

先の場合と同じデータを用いて、6セット得ようとしてみます。

この場合は、最大クラスのデータ数以上にデータを分割することを求められているので当然動作しません。「ValueError: n_splits=6 cannot be greater than the number of members in each class.

とエラーが出ます。」
私は今まで、この each class というのは、「任意のクラスのデータよりも~」という意味だと勘違いしていました。
ただ、n_splits の数は最初クラスのデータ数まで、と説明している記事もあったので、scikit-learnのバージョンにより挙動が異なるのかもしれません。
一応最新版のソースを読んで、最大クラスのデータ数で制限をかけていることを確認しました。

まとめ:n_splits の数に対する挙動

  • 最もサンプル数の少ないクラスのサンプル数以下である場合
    => そのまま動作する
  • 最もサンプル数の少ないクラスのサンプル数より多く、最もサンプル数の多いクラスのサンプル数以下である場合
    => UserWarning を出しながら動作する
  • 最もサンプル数の多いクラスのサンプル数より多い場合
    => Error を出す

動作環境

python: 3.8.11

sklearn: 0.24.2

コメント

タイトルとURLをコピーしました