Pynote

Python、機械学習、画像処理について

scikit-learn - データセットを学習データとテストデータに分割する。

train_test_split()

sklearn.model_selection.train_test_split(*arrays, **options)[source]

引数

  • arrays: 分割対象の配列 (複数個指定可)。
    • 配列の型は リスト、numpy 配列、scipy sparse 行列、pandas dataframes に対応
  • test_size: テストに利用する割合。float, int, None の指定が可能。
    • デフォルト: train_size を指定した場合は None、そうでない場合は 0.25。
    • float: 分割の割合 (0~1) を指定する。
    • int: テストに使用するデータ数を指定する。
    • None: train_size の指定に従い、決める。
  • train_size: 学習に利用する割合。float, int, None の指定が可能。
    • デフォルト: test_size を指定した場合は None。
    • float: 分割の割合 (0~1) を指定する。
    • int: テストに使用するデータ数を指定する。
    • None: test_size の指定に従い、決める。
  • random_state: 分割に使用する乱数に関する設定。
    • デフォルト: None
    • int: 乱数のシードに利用
    • None: np.random.RandomState オブジェクトを利用する。
  • shuffle: 分割前にシャッフルするかどうかを指定する。
    • デフォルト: True
  • stratify: 各クラスごとに同じ割合で分割したい場合は、クラス一覧を指定する。
    • デフォルト: None

返り値

  • 分割した配列

基本的な使い方

分割する割合または数を指定する。

test_size (train_size) でテストデータ (学習データ) に使用する割合または数を指定する。一方を指定すれば、もう一方も決まる。
例えば、学習データを 0.75 と指定した場合、残りの 0.25 がテストデータになる。

  • テストデータ25%、学習データ75%と分割したい場合: test_size=0.25
  • 全データ数が100であるとき、テストデータ10、学習データ90と分割したい場合: test_size=10

分割する配列を指定する。

引数に2つの配列 array1, array2 を渡した場合、各配列が学習データとテストデータに分割されるので、
返り値は4個の配列のリストになる。

array1_train, array1_test, array2_train, array2_test = \
    train_test_split([array1, array2])

サンプルコード

基本的な使い方

import numpy as np
from sklearn.model_selection import train_test_split

# 分割対象の配列を作成する。
data = np.arange(20)
labels = np.arange(20)

# array1, array2 をそれぞれ学習データ75%、テストデータ25%の割合で分割する。
data_train, data_test, labels_train, labels_test = \
    train_test_split(data, labels, train_size=0.75)

print('data train', data_train)
print('data test', data_test)
print('labels train', labels_train)
print('labels test', labels_test)
data train [10  4  9 17 14  6  8  7 16  3 18 19  1 15  0]
data test [13  2 12  5 11]
labels train [10  4  9 17 14  6  8  7 16  3 18 19  1 15  0]
labels test [13  2 12  5 11]

シャッフルするが実行するたびに分割結果を同じにしたい場合

# 乱数のシードを10で固定することで、シャッフルする結果が毎回同じになる。
data_train, data_test, labels_train, labels_test = train_test_split(
    data, labels, train_size=0.75, random_state=10)

print('data train', data_train)
print('data test', data_test)
print('labels train', labels_train)
print('labels test', labels_test)