Pynote

Python、機械学習、画像処理について

Keras - CNN の畳み込み層の重みや特徴マップを可視化する方法

手順

モジュールを import する。

import matplotlib.pyplot as plt
import numpy as np
from keras import backend as K
from keras.applications.resnet50 import (ResNet50, preprocess_input)
from keras.preprocessing import image

画像を読み込む。

今回使用する画像

img = image.load_img('horse.jpg', target_size=(224, 224))  # 画像を読み込む。
x = image.img_to_array(img)  # PIL オブジェクトを numpy 配列にする。
x = np.expand_dims(x, axis=0)  # ミニバッチにするため、次元を追加する。
x = preprocess_input(x)  # ResNet 用の前処理を行う。

モデルを作成する。

Keras では、ImageNet で学習済みの CNN モデルが簡単に使える。
今回は ResNet-50 を利用する。

pynote.hatenablog.com

# ResNet モデルを作成する。
model = ResNet50(include_top=False)
model.summary()

中間層の特徴マップを返す関数を作成する。

# モデルの3層目の出力を返す関数を作成する。
get_feature_map = K.function([model.input, K.learning_phase()], [model.layers[3].output])

中間層 (畳み込み層) の特徴マップ及び重みを取得する。

# 順伝搬して特徴マップを取得する。
features = get_feature_map([x, False])[0]
print('features.shape', features.shape)  # features.shape (1, 112, 112, 64)

# 重みを取得する。
[weights, bias] = model.layers[1].get_weights()
print('layer.name', model.layers[2].name)  # layer.name bn_conv1
print('weights.shape', weights.shape)  # weights.shape (7, 7, 3, 64)
print('bias.shape', bias.shape)  # bias.shape (64,)

特徴マップ及びカーネルの重みを可視化する。

まず特徴マップ及び重みをカーネルごとに分割する。
さらに特徴マップ及びカーネルの重みは共に実数なので、[0, 255] の整数で表される配列に変換する必要がある。
Keras の array_to_img() を利用すると、簡単に行える。


pynote.hatenablog.com

# 特徴マップをカーネルごとに分割し、画像化する。
feature_imgs = []
for f in np.split(features, 64, axis=3):
    f = np.squeeze(f, axis=0)  # (1, FeatureH, FeatureW, FeatureC) -> (FeatureH, FeatureW, FeatureC)
    f = image.array_to_img(f)  # 特徴マップを画像化する。
    f = np.array(f)  # PIL オブジェクトを numpy 配列にする。
    feature_imgs.append(f)

# 重みをカーネルごとに分割し、画像化する。
weight_imgs = []
for w in np.split(weights, 64, axis=3):
    w = np.squeeze(w, axis=3)  # (KernelH, KernelW, KernelC, 1) -> (KernelH, KernelW, KernelC)
    w = image.array_to_img(w)  # 重みを画像化する。
    w = np.array(w)  # PIL オブジェクトを numpy 配列にする。
    weight_imgs.append(w)

描画する。

カーネルの重み及びそのカーネルで畳み込みを行った出力結果を並べて表示する。

cols = 2
rows = int(len(feature_imgs) / cols)
fig, axes = plt.subplots(rows, cols * 2, figsize=(12, 100))

for r in range(rows):
    for c in range(cols):
        i = r * cols + c
        w_axis, f_axis = axes[r, c * 2], axes[r, c * 2 + 1]

        w_axis.imshow(weight_imgs[i])
        w_axis.set_title('Kernel')
        w_axis.axis('off')

        f_axis.imshow(feature_imgs[i], cmap='gray')
        f_axis.set_title('Feature')
        f_axis.axis('off')

plt.show()