Pynote

Python、機械学習、画像処理について

matplotlib - 図にグリッド状にグラフを作成する方法

subplot

matplotlib において、figure の中に複数の axes がある場合、それらを subplot という。
subplot のレイアウトは入れ子であったり、グリッドであったり様々である。

グリッド状に配置された subplot を作成する。

グリッド状に配置された subplot を作成するための関数が pyplot インターフェイスまたは Figure クラスのメソッドとして提供されている。

pyplot.subplot

ax = subplot(nrows=1, ncols=1, index=1)

pyplot.subplot は、current figure を (nrows, ncols) に等分割し、指定した index のセルに Axes を1つ作成する。
分割した各セルは row-major order で 1 から nrows * ncols のインデックスが振られており、Axes を作成する場所のインデックスを index に指定する。


import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0., 10., 0.1)
y1 = np.sin(x)
y2 = np.cos(x)

# current figure が存在しないので、figure を作成し、
# current figure に設定する。
# current figure を 2x2 に等分割した場合に index が1の位置に Axes を
# 作成し、current axes に設定する。
plt.subplot(2, 2, 1)
# current axes の Axes.plot() を呼び出して、データを描画する。
plt.plot(x, y1)

# current figure を 2x2 に分割した場合に index が4の位置に Axes を
# 作成し、current axes に設定する。
plt.subplot(2, 2, 4)
# current axes の Axes.plot() を呼び出して、データを描画する。
plt.plot(x, y2)

plt.show()

nrows, ncols, index がすべて一桁であることが保証される (nrow * ncols < 10) 場合は、nrows, ncols, index を引数1つで指定する。

plt.subplot(224)  # plt.subplot(2, 2, 4) と同じ

途中で分割を変更した場合は、以前に作成した axes は削除される。

ax1 = add_subplot(2, 2, 1)
ax2 = add_subplot(1, 2, 2)  # 分割数を変更したので、ax1 は削除された。

Figure.add_subplot

Figure クラスの Figure.add_subplot() でも同じことができる。

import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0., 10., 0.1)
y1 = np.sin(x)
y2 = np.cos(x)

fig = plt.figure()

# fig を 2x2 に分割し、index=1 の位置に Axes を作成する。
ax1 = fig.add_subplot(2, 2, 1)
ax1.plot(x, y1)

# fig を 2x2 に分割し、index=2 の位置に Axes を作成する。
ax2 = fig.add_subplot(2, 2, 4)
ax2.plot(x, y1)

plt.show()

pyplot.subplots

fig, ax = subplots(nrows=1, ncols=1)

pyplot.subplots は、current figure を (nrows, ncols) に分割し、各セルに Axes を一度に作成して、numpy 配列で返す。


import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0., 10., 0.1)
y1 = np.sin(x)
y2 = np.cos(x)

fig, ax = plt.subplots(2, 1)
ax[0].plot(x, y1)
ax[1].plot(x, y2)

plt.show()

返り値 ax について

nrows=n, ncols=1 の場合: (n,) の numpy 配列
nrows=1, ncols=n の場合: (n,) の numpy 配列
nrows=n, ncols=m の場合: (n, m) の numpy 配列
nrows=1, ncols=1 の場合: Axes

fig, ax = plt.subplots(3, 1)
print(ax.shape) # (3,)

fig, ax = plt.subplots(1, 3)
print(ax.shape) # (3,)

fig, ax = plt.subplots(3, 3)
print(ax.shape) # (3, 3)

fig, ax = plt.subplots(1, 1)
print(type(ax))  # <class 'matplotlib.axes._subplots.AxesSubplot'>

nrows, ncols のデフォルト引数は1なので、省略した場合は Axes を1つ作成する。

fig, ax = plt.subplots()
print(type(ax))  # <class 'matplotlib.axes._subplots.AxesSubplot'>

Figure.subplots

Figure.subplots() でも同じことができる。

import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0., 10., 0.1)
y1 = np.sin(x)
y2 = np.cos(x)

fig = plt.figure()

ax = fig.subplots(2, 1)
ax[0].plot(x, y1)
ax[1].plot(x, y2)

plt.show()

Axes の配置を調整する。

目盛りのラベルなどが重ならないように Axes のレイアウトを調整するには、subplots_adjust で設定する。


import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 8))
fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
                    wspace=0.5, hspace=0.5)

for i in range(1, 10):
    ax = fig.add_subplot(3, 3, i)
    ax.text(0.5, 0.5, 'ax{}'.format(i),
            fontsize=20, ha='center', va='center')

plt.show()


自動で調整する。

tight_layout を呼び出すと、目盛りなどが重ならないように自動で間隔を調整できる。

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 8))

for i in range(1, 10):
    ax = fig.add_subplot(3, 3, i)
    ax.text(0.5, 0.5, 'ax{}'.format(i),
            fontsize=20, ha='center', va='center')
fig.tight_layout()

plt.show()