Pynote

Python、機械学習、画像処理について

Keras - Keras の ImageDataGenerator を使って学習画像を増やす

概要

CNN の学習を行う場合にオーグメンテーション (augmentation) を行い、学習データのバリエーションを増やすことで精度向上ができる場合がある。
Keras の preprocessing.image モジュールに含まれる ImageDataGenerator を使用すると、リアルタイムにオーグメンテーションを行いながら、学習が行える。

キーワード

  • ImageDataGenerator
  • オーグメンテーション (augmentation)

ImageDataGenerator

通常の学習では、データセットから指定した枚数だけ画像を選択し、ミニバッチを作成する。
一方、ImageDataGenerator を使用すると、画像を選択したあと、各画像にオーグメンテーションを行い、ミニバッチを作成する。

どのようなオーグメンテーションを行うかはインスタンス生成時のコンストラクタの引数で指定する。
変換はリアルタイムで行われ、予め保存するわけではないので、ディスク容量を圧迫しない。

基本的な使い方

必要なモジュールを import する。

import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing import image
ImportError: Could not import PIL.Image. The use of `array_to_img` requires PIL.

このようなエラーが出る場合は、pillow がないのでインストールする。

pip install pillow

image.ImageDataGenerator オブジェクトを作成する。
コンストラクタの引数でどのような変換を適用するか指定する。

# 画像データ生成器を作成する。
params = {
    'rotation_range': 20,
    'width_shift_range': 0.4
}
datagen = image.ImageDataGenerator(**params)

flow() でミニバッチを生成する Python ジェネレーターを作成する。
サンプル x (形状が (NumSamples, Height, Width, Channels) である4次元データ) 及びジェネレーターが生成するミニバッチ数 batch_size を指定する。

# ミニバッチを生成するジェネレーターを作成する。
gen = datagen.flow(x, batch_size=16)

サンプル x からミニバッチを作成して学習する場合は fit() を使用したが、ジェネレーターを利用してミニバッチを作成する場合は fit_generator() を使用する。

#  学習する。
model.fit_generator(gen, steps_per_epoch=len(x_train) / 32, epochs=16)

オーグメンテーションの種類

ImageDataGenerator で指定できるオーグメンテーションの種類を紹介する。

1枚の画像を使用して、それを元に ImageDataGenerator() でどのようなデータが生成されるのか可視化してみる。

import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing import image

# 画像を読み込む。
img = image.load_img('dog.jpg')
img = np.array(img)

plt.imshow(img)
plt.show()

# 画像データ生成器を作成する。
# -20° ~ 20° の範囲でランダムに回転を行う。
datagen = image.ImageDataGenerator(rotation_range=20)

# ミニバッチを生成する Python ジェネレーターを作成する。
x = img[np.newaxis]  #  (Height, Width, Channels)  -> (1, Height, Width, Channels) 
gen = datagen.flow(x, batch_size=1)  # 1枚しかないので、ミニバッチ数は1

# Python ジェネレーターで9枚生成して、表示する。
plt.figure(figsize=(10, 8))
for i in range(9):
    batches = next(gen)  # (NumBatches, Height, Width, Channels) の4次元データを返す。
    # 画像として表示するため、3次元データにし、float から uint8 にキャストする。
    gen_img = batches[0].astype(np.uint8)

    plt.subplot(3, 3, i + 1)
    plt.imshow(gen_img)
    plt.axis('off')
plt.show()

回転する。

引数 rotation_range に int で回転する範囲を指定する。
rotation_range=20 とした場合、-20° ~ 20° の範囲でランダムに回転する。

# -20° ~ 20° の範囲でランダムに回転する。
datagen = image.ImageDataGenerator(rotation_range=20)


上下反転する。

引数 vertical_flip に bool で上下反転するかどうかを指定する。
vertical_flip=True とした場合、ランダムに上下反転する。

# ランダムに上下反転する。
datagen = image.ImageDataGenerator(vertical_flip=True) 


左右反転する。

引数 horizontal_flip に bool で上下反転するかどうかを指定する。
horizontal_flip=True とした場合、ランダムに左右反転する。

# ランダムに左右反転する。
datagen = image.ImageDataGenerator(horizontal_flip=True)


上下平行移動する。

引数 height_shift_range に範囲 [0, 1] の float で上下平行移動する範囲を指定する。
height_shift_range=0.3 とした場合、[-0.3 * Height, 0.3 * Height] の範囲でランダムに上下平行移動する。

# [-0.3 * Height, 0.3 * Height] の範囲でランダムに上下平行移動する。
datagen = image.ImageDataGenerator(height_shift_range=0.3)


左右平行移動する。

引数 width_shift_range に範囲 [0, 1] の float で左右平行移動する範囲を指定する。
width_shift_range=0.3 とした場合、[-0.3 * Width, 0.3 * Width] の範囲でランダムに左右平行移動する。

# [-0.3 * Width, 0.3 * Width] の範囲でランダムに左右平行移動する。
datagen = image.ImageDataGenerator(width_shift_range=0.3)


せん断 (shear transformation) する。

せん断とは以下の変換である。

引数 shear_range に int でせん断する際の角度の範囲を指定する。
shear_range=5 とした場合、-5° ~ 5° の範囲でランダムにせん断する。

# -5° ~ 5° の範囲でランダムにせん断する。 
datagen = image.ImageDataGenerator(shear_range=5)


拡大縮小する。

引数 zoom_range に範囲 [0, 1] の float で拡大縮小する範囲を指定する。
zoom_range=0.3 とした場合、[1 - 0.3, 1 + 0.3] つまり [0.7, 1.3] の範囲でランダムに拡大縮小する。

# [1 - 0.3, 1 + 0.3] の範囲でランダムに拡大縮小する。
datagen = image.ImageDataGenerator(zoom_range=0.3)

zoom_range=[0.5, 1.2] とした場合、[0.5, 1.2] の範囲でランダムに拡大縮小する。

# [0.5, 1.2] の範囲でランダムに拡大縮小する。
datagen = image.ImageDataGenerator(zoom_range=[0.5, 1.2])


各画素値に値を足す。

引数 channel_shift_range に float で各画素値に値を足す範囲を指定する。
channel_shift_range=5. とした場合、[-5.0, 5.0] の範囲でランダムに画素値に値を足す。

# [-5.0, 5.0] の範囲でランダムに画素値に値を足す。
datagen = image.ImageDataGenerator(channel_shift_range=5.)


明度を変更する。

引数 brightness_range に範囲 [0, 1] の float で明度を変更する範囲を指定する。
0に近いほど暗くなり、1に近いほど元画像の明度に近くなる。
brightness_range=[0.3, 1.0] とした場合、[0.3, 1.0] の範囲でランダムに明度を変更する。

# [0.3, 1.0] の範囲でランダムに明度を変更する。
datagen = image.ImageDataGenerator(brightness_range=[0.3, 1.0])


外挿方法

引数 fill_mode で回転や平行移動等の結果、値がないピクセルをどのように埋めるかを指定する。

  • 'constant': 引数 cval で指定した値で埋める。(例: kkkkkkkk|abcd|kkkkkkkk, cval=k)
  • 'nearest': 最も近い値で埋める。(例: aaaaaaaa|abcd|dddddddd)
  • 'reflect': 折り返すようにして埋める。(abcddcba|abcd|dcbaabcd)
  • 'wrap': 繰り返すようにして埋める。(abcdabcd|abcd|abcdabcd)


定数倍する。

引数 rescale に指定した値で、各変換を行う前に画素値を rescale 倍する。
[0, 255] で表される画素値を [0, 1] に正規化する場合などに使用する。

# 画素値を [0, 255] から [0, 1] に変更する。
datagen = image.ImageDataGenerator(rescale=1. / 255)

コールバック関数による前処理を行う。

引数 preprocessing_function に3次元データを受け取り、3次元データを返す関数を指定することで、コールバック関数による前処理が行える。

def preprocess(x):
    x /= 255.
    return x

# 指定した前処理を行う。
datagen = image.ImageDataGenerator(preprocessing_function=preprocess)

データセット全体で各チャンネルごとの画素値の平均を0にする。

引数 featurewise_center に True を指定すると、データセット全体で各チャンネルごとの画素値の平均を0にする。

# データセット全体で各チャンネルごとの画素値の平均を0にする。
datagen = image.ImageDataGenerator(featurewise_center=True)
datagen.fit()  # データセット全体の統計量を予め計算する必要がある。

データセット全体で各チャンネルごとの画素値の分散を1にする。

引数 featurewise_std_normalization に True を指定すると、データセット全体で各チャンネルごとの画素値の分散を1にする。
featurewise_std_normalization=True にした場合、featurewise_center=True も指定しなければならない。

# データセット全体で各チャンネルごとの画素値の分散を1にする。
datagen = image.ImageDataGenerator(
    featurewise_center=True, featurewise_std_normalization=True)
datagen.fit()  # データセット全体の統計量を予め計算する必要がある。

白色化を行う。

引数 zca_epsilon に True を指定すると、白色化を行う。
zca_epsilon=True にした場合、featurewise_center=True も指定しなければならない。

# 白色化を行う。
datagen = image.ImageDataGenerator(zca_whitening=True)
datagen.fit()  # データセット全体の統計量を予め計算する必要がある。

サンプルごとの画素値の平均を0にする。

引数 samplewise_center に True を指定すると、サンプルごとの画素値の平均を0にする。

# サンプルごとの画素値の平均を0にする。
datagen = image.ImageDataGenerator(samplewise_center=True)

サンプルごとの画素値の分散を1にする。

引数 samplewise_std_normalization に True を指定すると、サンプルごとの画素値の分散を1にする。
samplewise_std_normalization=True にした場合、samplewise_center=True も指定しなければならない。

# サンプルごとの画素値の分散を1にする。
datagen = image.ImageDataGenerator(
    samplewise_center=True, samplewise_std_normalization=True)

生成した画像をフォルダに保存する。

flow() の引数 save_to_dir に保存するディレクトリ、save_prefix にファイル名の接頭辞、save_format にファイル形式 ('png' or 'jpeg') を指定することで、ジェネレーターで生成した画像がディレクトリに保存される。
どのような画像が生成されたかを確認したい場合に利用できる。

save_path = 'output'  # 保存ディレクトリのパス

# 指定したディレクトリが存在しないとエラーになるので、
# 予め作成しておく。
import os
os.makedirs(save_path, exist_ok=True)

# -5° ~ 5° の範囲でランダムにせん断する。 
datagen = image.ImageDataGenerator(shear_range=5)

# ミニバッチを生成するジェネレーターを作成する。
x = img[np.newaxis]  #  (Height, Width, Channels)  -> (1, Height, Width, Channels) 
gen = datagen.flow(x, batch_size=1, save_to_dir=save_path,
                   save_prefix='generated', save_format='png')

# ジェネレーターで9枚生成する。
plt.figure(figsize=(10, 8))
for i in range(9):
    # ミニバッチを生成したタイミングでディレクトリに
    # 画像が保存される。
    next(gen)