Pynote

Python、機械学習、画像処理について

scikit-learn - matplotlib を使って分類問題の決定境界を描画する

概要

matplotlib で scikit-learn の学習したモデルの決定境界を可視化する方法について

学習する。

iris データセットを用いる。
特徴量としては、Sepal Length、Sepal Width、Petal Length、Petal Width の4つのうち、Sepal Length、Petal Length の2変数を使用する。

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# データを取得
iris = datasets.load_iris()
data = iris.data[:, [0, 2]]
label = iris.target

学習データとテストデータに 8:2 の割合で分割する。

# 学習データとテストデータに分割する。
X_train, X_test, Y_train, Y_test = train_test_split(
    data, label, test_size=0.2, stratify=label, random_state=42)

今回はロジスティック回帰モデル LogisticRegression を使用して、学習する。

# ロジスティック回帰モデルで学習する。
model = LogisticRegression(solver='lbfgs', multi_class='multinomial')
model.fit(X_train, Y_train)

# テストデータを推論し、精度を出力する。
Y_pred = model.score(X_test, Y_test)
print('test accuracy: {:.2%}'.format(Y_pred))

決定境界を描画する。

サンプルを描画する。

fig, ax = plt.subplots(figsize=(8, 6))

# タイトル、x 軸、y 軸のラベルを設定する。
ax.set_title('classification data using LogisticRegression')
ax.set_xlabel('Sepal length')
ax.set_ylabel('Petal length')

# サンプルを描画する。
ax.scatter(data[:, 0], data[:, 1], c=label, s=7, cmap='tab10')

次にグラフの x 軸、y 軸の範囲はそれぞれ Axes.get_xlim(), Axes.get_ylim() で得られるので、
この範囲に格子状に点を numpy.meshgrid() で作成する。

numpy.meshgrid() の使い方は以下を参照。

pynote.hatenablog.com

print('xlim', ax.get_xlim())  # xlim (4.116740225759217, 8.083259774240783)
print('ylim', ax.get_ylim())  # ylim (0.7005384988315995, 7.199461501168401)
X, Y = np.meshgrid(np.linspace(*ax.get_xlim(), 1000),
                   np.linspace(*ax.get_ylim(), 1000))

そして、作成した各点をモデルで推論し、その点のラベルを得る。
predict() 関数の引数は (サンプル数, 特徴量) という2次元配列を想定しているため、形状を変更する。
推論結果は、X, Y と同じ形状に戻す。

# 推論する。
XY = np.column_stack([X.ravel(), Y.ravel()])
Z = model.predict(XY).reshape(X.shape)

以上で、サンプル (x, y) とその推論ラベル z のデータを作成することができた。
これは次のような関数と考えることができる。

from mpl_toolkits.mplot3d import Axes3D

classes = np.unique(label)

fig = plt.figure(figsize=(9, 9))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='Accent', shade=False)
ax.set_zticks(classes)
ax.set_zticklabels(['class {}'.format(c) for c in classes])
plt.show()

この関数を上から見た図、つまり等高線を考えるとこれが分類境界になる。

# 等高線を描画する。
ax.contourf(X, Y, Z, alpha=0.4, cmap='Paired')
plt.show()