Pynote

Python、機械学習、画像処理について

TensorFlow / Keras - ImageDataGenerator を使った画像分類モデルの学習方法

概要

ImageDataGenerator を使用して画像分類の学習を行うチュートリアル。

Jupyter Notebook

本記事のコード全体は以下。

keras-image-data-generator-usage.ipynb

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing import image
from tensorflow.keras.utils import get_file

flower_photos

flower_photos は、花の画像のデータセットである。
tf.keras.utils.get_file でダウンロードする。この関数はデフォルトでは、ダウンロードして解凍したファイルは ~/.keras/datasets 以下に保存する。

# データセットを取得する。
dataset_dir = get_file(
    "flower_photos",
    "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
    untar=True
)
print(dataset_dir)  # /root/.keras/datasets/flower_photos

num_classes = 5  # クラス数

データセットは以下のディレクトリ構成になっている。
各クラスごとにディレクトリがあり、そのクラスに属する画像が保存されている。

flower_photos
├── LICENSE.txt
├── daisy
├── dandelion
├── roses
├── sunflowers
└── tulips

クラス名 画像枚数
sunflowers (ヒマワリ) 699
dandelion (タンポポ) 898
daisy (ヒナギク) 633
roses (バラ) 641
tulips (チューリップ) 799

モデルを作成する。

CNN は 、ImageNet で学習したVGG16 を利用して、転移学習する。
tf.keras.applications.VGG16 で VGG16 モデルを作成できる。

  • include_top: VGG16 モデルは畳み込み層及びプーリング層で構成される画像から特徴抽出を行う部分と全結合層で構成されるクラスの識別を行う部分からなる。このうち、前者のみ利用するので、include_top=False とする。
  • input_shape: モデルの入力サイズは ImageNet を学習した際と同じ (224, 224, 3) にしておく。
  • 転移学習なので、この部分は学習中に重みを変更しないため、trainable=False としておく。
# VGG16 モデルを作成する。
vgg16 = VGG16(include_top=False, input_shape=(224, 224, 3))
vgg16.trainable = False  # 重みをフリーズする。

識別用の全結合層を3つ追加する。

model = Sequential(
    [
        vgg16,
        Flatten(),
        Dense(500, activation="relu"),
        Dropout(0.5),
        Dense(500, activation="relu"),
        Dropout(0.5),
        Dense(num_classes, activation="softmax"),
    ]
)

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Model)                (None, 7, 7, 512)         14714688  
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
dense (Dense)                (None, 500)               12544500  
_________________________________________________________________
dropout (Dropout)            (None, 500)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 500)               250500    
_________________________________________________________________
dropout_1 (Dropout)          (None, 500)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 5)                 2505      
=================================================================
Total params: 27,512,193
Trainable params: 12,797,505
Non-trainable params: 14,714,688
_________________________________________________________________

モデルをコンパイルする。

最適化手法は Adam を指定する。
損失関数は多クラス分類問題でラベルは 0, 1, \cdots, 4 の sparse 形式で与えられるので、sparse_categorical_crossentropy を指定する。
指標はクラス分類問題なので、損失関数の値の他に精度も確認したいので accuracy を指定する。

# コンパイルする。
model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)

ImageDataGenerator を作成する。

tf.keras.preprocessing.image.ImageDataGenerator でオーグメンテーションした入力を生成する ImageDataGenerator を作成する。

学習済みの重みを利用する上でその重みを学習した際に行った前処理と同じ処理を行うことが好ましい。
この前処理は tf.keras.applications.vgg16.preprocess_input で行えるので、preprocessing_function に指定する。
オーグメンテーションのパラメータとしては以下の3つを指定する。

  • horizontal_flip: ランダムに左右反転する。
  • brightness_range: ランダムに明るさを変更する。

またデータセット全体を学習データ8割、バリデーションデータ2割となるように利用したいので、validation_split に 0.2 を指定する。

# ハイパーパラメータ
batch_size = 64  # バッチサイズ
num_epochs = 30  # エポック数

# ImageDataGenerator を作成する。
datagen_params = {
    "preprocessing_function": preprocess_input,
    "horizontal_flip": True,
    "brightness_range": (0.7, 1.3),
    "validation_split": 0.2,
}
datagen = image.ImageDataGenerator(**datagen_params)

次に ImageDataGenerator.flow_from_directory でディレクトリから画像を読み込み、データを生成するジェネレーターを学習用、バリデーション用にそれぞれ作成する。

  • directory: データセットのディレクトリパスを指定する。
  • target_size: モデルに入力する画像の形状 (高さ, 幅) を指定する。モデルの入力サイズは input_shape で (N, H, W, C) で得られるので、target_size=model.input_shape[1:3] とすればよい。
  • class_mode: 画像と一緒に生成するラベルの形式を指定する。sparse とした場合、対応するクラス ID を表す 0, 1, \cdots という整数値のラベルを返す。
  • subset: 学習データを生成するジェネレーターの場合は "training"、バリデーションデータを生成するジェネレーターの場合は "validation" を指定する。
# 学習データを生成するジェネレーターを作成する。
train_generator = datagen.flow_from_directory(
    dataset_dir,
    target_size=model.input_shape[1:3],
    batch_size=batch_size,
    class_mode="sparse",
    subset="training"
)

# バリデーションデータを生成するジェネレーターを作成する。
val_generator = datagen.flow_from_directory(
    dataset_dir,
    target_size=model.input_shape[1:3],
    batch_size=batch_size,
    class_mode="sparse",
    subset="validation",
)

# クラス ID とクラス名の対応関係
print(train_generator.class_indices)
# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

学習する。

ジェネレーターで学習する場合は tf.keras.Model.fit_generator を使用する。

generator、validation_data にはそれぞれ学習データ、バリデーションデータを生成するジェネレーターを指定する。
steps_per_epoch、validation_steps はそれぞれ学習データ、バリデーションデータをすべて生成するのに必要な反復回数を指定する。
ジェネレーターは無限にデータを生成できるので、何回データを生成したら1エポックが完了したのかがわからないため、明示的に指定する必要がある。
samples 属性でジェネレーターが生成する際の元となる画像枚数が取得できるので、画像枚数 // バッチサイズでこの値を計算できる。

# 学習する。
history = model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.samples // batch_size,
    validation_data=val_generator,
    validation_steps=val_generator.samples // batch_size,
    epochs=num_epochs,
)
Epoch 1/30
45/45 [==============================] - 14s 312ms/step - loss: 11.4499 - acc: 0.6289 - val_loss: 2.5192 - val_acc: 0.7997
Epoch 2/30
45/45 [==============================] - 14s 309ms/step - loss: 3.5395 - acc: 0.7791 - val_loss: 1.3187 - val_acc: 0.8324
Epoch 3/30
45/45 [==============================] - 14s 309ms/step - loss: 2.2268 - acc: 0.8073 - val_loss: 1.1119 - val_acc: 0.8381
Epoch 4/30
45/45 [==============================] - 14s 312ms/step - loss: 1.9359 - acc: 0.8143 - val_loss: 1.1818 - val_acc: 0.8210
Epoch 5/30
45/45 [==============================] - 14s 314ms/step - loss: 1.8972 - acc: 0.8153 - val_loss: 1.0083 - val_acc: 0.8338
以下略

30エポックで約83%の精度が出た。

損失関数の値、精度の履歴を可視化する。

epochs = np.arange(1, num_epochs + 1)

fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 4))

# 損失関数の履歴を可視化する。
ax1.plot(epochs, history.history["loss"], label="loss")
ax1.plot(epochs, history.history["val_loss"], label="validation loss")
ax1.set_xlabel("epochs")
ax1.legend()

# 精度の履歴を可視化する。
ax2.plot(epochs, history.history["acc"], label="accuracy")
ax2.plot(epochs, history.history["val_acc"], label="validation accuracy")
ax2.set_xlabel("epochs")
ax2.legend()

plt.show()


評価する。

# 評価する。
test_loss, test_acc = model.evaluate_generator(val_generator)

print(f"test loss: {test_loss:.2f}, test accuracy: {test_acc:.2%}")
# test loss: 1.07, test accuracy: 83.04%

推論する。

バリデーションデータから数枚推論して、結果を表示する。

class_names = list(val_generator.class_indices.keys())


def plot_prediction(img, prediction, label):
    pred_label = np.argmax(prediction)

    fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5), facecolor="w")

    ax1.imshow(img)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_xlabel(
        f"{class_names[pred_label]} {prediction[pred_label]:.2%} ({class_names[label]})",
        fontsize=15,
    )

    bar_xs = np.arange(len(class_names))  # 棒の位置
    ax2.bar(bar_xs, prediction)
    ax2.set_xticks(bar_xs)
    ax2.set_xticklabels(class_names, rotation="vertical", fontsize=15)


# バリデーションデータから3サンプル推論して、結果を表示する。
for i in val_generator.index_array[:3]:
    img_path = val_generator.filepaths[i]
    label = val_generator.labels[i]

    # 画像を読み込む。
    img = Image.open(img_path)
    # モデルの入力サイズにリサイズする。
    img = img.resize(reversed(val_generator.target_size))
    # PIL -> numpy 配列
    img = np.array(img)
    # バッチ次元を追加する。
    x = np.expand_dims(img, axis=0)
    # 前処理を行う。
    x = preprocess_input(x)

    # 推論する。
    prediction = model.predict(x)

    # 推論結果を可視化する。
    plot_prediction(img, prediction[0], label)