k 近傍法 (KNN)
k-近傍法
k-近傍法 (k-nearest neighbor algorithm, KNN) は次の規則で行う分類アルゴリズムである。
クラスの訓練集合を とする。
ただし、 は 次元の点、 はその点に対応するラベルを で表す。
分類したい点 が与えられたとき、訓練データから 個の近傍の点を探し、それらの点が属するクラスの多数決で 点 が属するクラスを決定する。
同票数の場合、多数決で決められなくなる問題があるが、 を奇数にすることで回避できる。
距離の計算には、ユークリッド距離がよく使われる。
のとき、最も近い点を探すことを意味するので、最近傍法 (nearest neighbor search, NNS) になる。
3クラスの2次元の訓練集合
クラス分類を行いたい点 (赤) が与えられたとする。
k = 7 の場合、距離が近い上位7個の点を調べる。
クラス0に属する点は6個、クラス2に属する点は1個なので、多数決で点 (赤) はクラス0と分類する。
アルゴリズムの特徴
- 学習時は訓練データを記録するだけなので、計算が不要である。このことを怠惰学習 (lazy learner) という。訓練データの追加も容易である。
- すべての訓練データを記録しておく必要があるので、メモリが必要となる。
- 推論時は近傍の点を探すために訓練データの各点との距離を計算する必要があるため、計算量が大きい。
KNeighborsClassifier
scikit-learn の sklearn.neighbors.KNeighborsClassifier で 近傍法を使用できる。
が偶数で同票数の場合、その中で一番小さい値のクラスに割り当てられる。(例: k=4 でクラス0が2個、クラス1が2個の場合、クラス0と判定する。*1
import numpy as np from sklearn.datasets import make_blobs from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier # データセットを作成する。 X, y = make_blobs(n_samples=60, centers=3, random_state=0) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0) # k 近傍法を学習する。 knn = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train) # テストデータに対する精度を計算する。 score = knn.score(X_test, y_test) print(f"score: {score:.2%}")
決定境界を描画する。
import matplotlib.pyplot as plt import numpy as np from sklearn.datasets import make_blobs from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier def plot_decision_regions(X, y, classifier, ax): # データセットを描画する。 scatter = ax.scatter(X[:, 0], X[:, 1], c=y, s=20, cmap="Paired") ax.legend(*scatter.legend_elements(), title="Classes") # 推論する。 xx, yy = np.meshgrid( np.linspace(*ax.get_xlim(), 100), np.linspace(*ax.get_ylim(), 100) ) xy = np.column_stack([xx.ravel(), yy.ravel()]) zz = classifier.predict(xy).reshape(xx.shape) # 決定境界 ax.contourf(xx, yy, zz, alpha=0.3) # データセットを作成する。 X, y = make_blobs(n_samples=100, centers=3, random_state=0) # k 近傍法を学習する。 knn = KNeighborsClassifier(n_neighbors=7).fit(X, y) # 決定境界を描画する。 fig, ax = plt.subplots(facecolor="w") plot_decision_regions(X, y, knn, ax) plt.show()
パラメータ k の選択
はハイパーパラメータであり、最適な値は学習するデータによって異なる。
が大きいほど、分類にあたってのノイズの影響を低減できるが、クラス間の境界が明確にならない傾向がある。
import matplotlib.pyplot as plt import numpy as np from sklearn.datasets import make_blobs from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier def plot_decision_regions(X, y, classifier, ax): # データセットを描画する。 scatter = ax.scatter(X[:, 0], X[:, 1], c=y, s=20, cmap="Paired") # 推論する。 xx, yy = np.meshgrid( np.linspace(*ax.get_xlim(), 100), np.linspace(*ax.get_ylim(), 100) ) xy = np.column_stack([xx.ravel(), yy.ravel()]) zz = classifier.predict(xy).reshape(xx.shape) # 決定境界 ax.contourf(xx, yy, zz, alpha=0.3) # データセットを作成する。 X, y = make_blobs(n_samples=100, centers=3, random_state=0) # k の値による決定境界の違い fig = plt.figure(figsize=(9, 9), facecolor="w") for k in range(1, 10): knn = KNeighborsClassifier(n_neighbors=k).fit(X, y) # 決定境界を描画する。 ax = fig.add_subplot(3, 3, k) ax.set_title(f"k = {k}") plot_decision_regions(X, y, knn, ax) plt.show()
グリッドサーチでいくつかの値での学習を試して、最適な値を採用するとよい。
scikit-learn の sklearn.model_selection.GridSearchCV を使用すると、この探索を簡単に行える。
import matplotlib.pyplot as plt import numpy as np import pandas as pd from sklearn.datasets import make_blobs from sklearn.model_selection import GridSearchCV, train_test_split from sklearn.neighbors import KNeighborsClassifier # データセットを作成する。 X, y = make_blobs(n_samples=100, centers=2, random_state=0) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0) # 試行するパラメータとその値 params = {"n_neighbors": np.arange(1, 11)} # グリッドサーチする。 clf = GridSearchCV( KNeighborsClassifier(), params, cv=5, return_train_score=False, iid=False ) clf.fit(X_train, y_train) # 最も精度がいいモデルを取得する。 best_clf = clf.best_estimator_ best_score = best_clf.score(X_test, y_test) print(f"score: {best_score:.2%}") # score: 100.00% cv_result = pd.DataFrame(clf.cv_results_) for row in cv_result.itertuples(): print(f"k = {row.params['n_neighbors']}: {row.mean_test_score:.2%}") # k = 1: 95.00% # k = 2: 92.50% # k = 3: 96.25% # k = 4: 95.00% # k = 5: 93.75% # k = 6: 95.00% # k = 7: 93.75% # k = 8: 91.25% # k = 9: 92.50% # k = 10: 91.25%