Pynote

Python、機械学習、画像処理について

Keras - ImageNet の学習済みモデルを利用して画像分類を行う。

概要



Keras では VGG、GoogLeNet、ResNet などの有名な CNN モデルの学習済みモデルが簡単に利用できるようになっている。
今回は ImageNet で学習済みの VGG16 モデルを使った画像分類を行う方法を紹介する。

手順

モデルを構築する。

keras.applications.vgg16.VGG16() で VGG16 モデルを作成できる。
今回は ImageNet で学習したモデルで画像分類を行うので、コンストラクタの引数は以下のように指定する。いずれもデフォルト引数なので、省略している。

  • include_top=True: モデルの top 側にある分類用の全結合層を含める。


  • weights='imagenet': ImageNet で学習した重みを使用する。
import numpy as np
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing import image

# VGG16 を構築する。
model = VGG16()
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         (None, 224, 224, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

初回だけ重みをダウンロードする必要があるため、少し時間がかかる。

画像を読み込む。

画像の読み込みには、keras.preprocessing.image モジュールを使用する。

pynote.hatenablog.com

1. keras.preprocessing.image.load_img() で画像を読み込む。
この関数は、指定したパスから画像を読み込み、target_size 引数で指定した大きさにリサイズした PIL.Image オブジェクトを返す。モデルの入力を確認すると、画像の大きさは (224, 224) が要求されていることがわかるので、この値を指定する。

2. keras.preprocessing.image.img_to_array() で PIL.Image オブジェクトから np.float32 型の numpy 配列に変換する。

3. モデルは (BatchSize, Height, Width, Channels) の numpy 配列を要求するので、numpy,expand_dims() でバッチ用の次元を追加する。

4. keras.applications.vgg16.preprocess_input() で前処理を行う。
VGG16 の ImageNet の学習は前処理として、データセットの各チャンネルの平均値を引くということを行っているので、推論時もこの処理を行う必要がある。

サンプル画像

print('model.input_shape', model.input_shape)  # model.input_shape (None, 224, 224, 3)

# 画像を読み込み、モデルの入力サイズでリサイズする。
img_path = 'sample1.jpg'
img = image.load_img(img_path, target_size=model.input_shape[1:3])

# PIL.Image オブジェクトを np.float32 型の numpy 配列に変換する。
x = image.img_to_array(img)
print('x.shape: {}, x.dtype: {}'.format(x.shape, x.dtype))
# x.shape: (224, 224, 3), x.dtype: float32

# 配列の形状を (Height, Width, Channels) から (1, Height, Width, Channels) に変更する。
x = np.expand_dims(x, axis=0)
print('x.shape: {}'.format(x.shape))  # x.shape: (1, 224, 224, 3)

# VGG16 用の前処理を行う。
x = preprocess_input(x)

推論する。

Model.predict() で推論を行える。返り値として、各サンプルごとの 1000 クラス分の確率値が返ってくる。

preds = model.predict(x)
print('preds.shape: {}'.format(preds.shape))  # preds.shape: (1, 1000)

このままでは、確率値がどのクラスに対応するかわからないので、keras.applications.vgg16.decode_predictions() で結果をデコードする。
すると確率が高い順に上位 top までの ImageNet のID、ラベル名、確率値が得られる。

result = decode_predictions(preds, top=3)[0]
print(result)
# [('n02326432', 'hare', 0.5343197),
#  ('n02356798', 'fox_squirrel', 0.27990195),
#  ('n02325366', 'wood_rabbit', 0.10779714)]

for _, name, score in result:
    print('{}: {:.2%}'.format(name, score))
# fox_squirrel: 53.47%
# wood_rabbit: 27.78%
# hare: 9.58%

日本語に訳すと、
キツネリス: 53.47%
小ウサギ: 27.78%
野うさぎ: 9.58%

リスが一番確率が高いので正解である。

日本語のラベル名で表示する。

有志の方が翻訳してくださったImageNet の日本語のラベルデータがあったので、それを使って日本語名で表示してみる。

imagenet_class_index.json をダウンロードし、スクリプトを実行しているパスに配置する。

まずラベルデータを読み込む。

import json

# ImageNet のラベル一覧を読み込む。
with open('imagenet_class_index.json') as f:
    data = json.load(f)
    class_names = np.array([row['ja'] for row in data])

Model.predict() の返り値は (BatchSize, 1000) の numpy 配列であった。
今回は推論した画像は一枚であり、(1, 1000) の numpy 配列が返ってくるので index=0 を取り出し、score とする。
さらに、ndarray.argsort() で降順ソートする。すると、配列には確率が低い順にインデックスが入っているので、最後の3つを取り出すと、確率が高い上位3位のインデックスが取得できる。最後に配列を反転させることで、確率が高い順に対応するインデックスが取得できる。

# 推論する。
scores = model.predict(x)[0]
top3_classes = scores.argsort()[-3:][::-1]

# 推論結果を表示する。
for name, score in zip(class_names[top3_classes], scores[top3_classes]):
    print('{}: {:.2%}'.format(name, score))
# キツネリス: 53.47%
# 木のウサギ: 27.78%
# 野ウサギ: 9.58%

いろんな画像を推論してみる。

推論結果を簡単に確認できるように関数を作っておく。

def infer_img(img_path):
    img = image.load_img(img_path, target_size=model.input_shape[1:3])
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    
   # 推論する。
    scores = model.predict(x)[0]
    top3_classes = scores.argsort()[-3:][::-1]
    print(top3_classes)

    # 推論結果を描画する。
    plt.axis('off')
    plt.imshow(img)

    y_pos = [250, 275, 300]
    for y, name, score in zip(y_pos, class_names[top3_classes], scores[top3_classes]):
        text = '{}: {:.2%}'.format(name, score)
        plt.text(0, y, text, fontsize=13)

    plt.show()

各クラスの確率は ImageNet の1000 クラスしか出てこないので、当然そこに含まれないクラスは推論できない。