Pynote

Python、機械学習、画像処理について

matplotlib - 図に複数のグラフを追加する方法

概要

matplotlib で図に複数のグラフを追加する方法について紹介する。
一つの Figure に複数の Axes を追加するには、plt.subplot()、plt.subplots()、Figure.add_subplot() または Figure.subplots() を使用する。

plt.subplot()

axes = subplot(nrows=1, ncols=1, index=1)

subplot() は、current figure を (nrows, ncols) 分割し、指定した index のセルに Axes を1つ作成する。
分割した各セルは row-major order で 1 から nrows * ncols の index が振られている。


import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0., 10., 0.1)
y1 = np.sin(x)
y2 = np.cos(x)

# current figure が存在しないので、Figure を作成し、current figure に設定する。
# current figure を 2x2 に分割した場合に index が1の位置に Axes を作成し、current axes に設定する。
plt.subplot(2, 2, 1)
# current axes の Axes.plot() を呼び出して、データを描画する。
plt.plot(x, y1)

# current figure を 2x2 に分割した場合に index が4の位置に Axes を作成し、current axes に設定する。
plt.subplot(2, 2, 4)
# current axes の Axes.plot() を呼び出して、データを描画する。
plt.plot(x, y2)

plt.show()

途中で分割を変更した場合は、以前に作成した Axes は削除される。

axes1 = add_subplot(2, 2, 1)
axes2 = add_subplot(1, 2, 2)  # 分割数を変更したので、axes1 は削除された。

nrows, ncols, index を引数1つで指定する。

nrows, ncols, index が一桁であることが保証される、つまり nrow * ncols < 10 の場合は、以下のようにできる。

plt.subplot(224)  # plt.subplot(2, 2, 4) と同じ

オブジェクト指向な使い方

plt.subplot() は作成した Axes を返すので、plt.plot() の代わりにそのオブジェクトの Axes.plot() を直接呼び出すこともできる。

import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0., 10., 0.1)
y1 = np.sin(x)
y2 = np.cos(x)

axes1 = plt.subplot(2, 2, 1)
axes1.plot(x, y1)
axes2 = plt.subplot(2, 2, 4)
axes2.plot(x, y2)

plt.show()

plt.subplots()

fig, axes_list = subplots(nrows=1, ncols=1)

subplots() は、current figure を (nrows, ncols) に分割し、各セルに Axes を作成して、numpy 配列で返す。


import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0., 10., 0.1)
y1 = np.sin(x)
y2 = np.cos(x)

fig, axes_list = plt.subplots(2, 1, 1)
axes_list[0].plot(x, y1)
axes_list[1].plot(x, y2)

plt.show()


返り値 axes_list について

nrows=n, ncols=1 の場合: (n,) の numpy 配列
nrows=1, ncols=n の場合: (n,) の numpy 配列
nrows=n, ncols=m の場合: (n, m) の numpy 配列
nrows=1, ncols=1 の場合: Axes

fig, axes_list = plt.subplots(3, 1)
print(axes_list.shape) # (3,)

fig, axes_list = plt.subplots(1, 3)
print(axes_list.shape) # (3,)

fig, axes_list = plt.subplots(3, 3)
print(axes_list.shape) # (3, 3)

fig, axes = plt.subplots(1, 1)
print(type(axes))  # <class 'matplotlib.axes._subplots.AxesSubplot'>

Figure.add_subplot()

plt.subplot() は Figure.add_subplot() でも同じことができる。

import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0., 10., 0.1)
y1 = np.sin(x)
y2 = np.cos(x)

# current figure が存在しないので、Figure を作成し、current figure に設定する。
fig = plt.gcf()

# fig を 2x2 に分割し、index=1 の位置に Axes を作成する。
axes1 = fig.add_subplot(2, 2, 1)
axes1.plot(x, y1)

# fig を 2x2 に分割し、index=2 の位置に Axes を作成する。
axes1 = fig.add_subplot(2, 2, 4)
axes1.plot(x, y1)

plt.show()

Figure.subplots()

plt.subplots() は Figure.subplots() でも同じことができる。

import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0., 10., 0.1)
y1 = np.sin(x)
y2 = np.cos(x)

# current figure が存在しないので、Figure を作成し、current figure に設定する。
fig = plt.gcf()

axes_list = fig.subplots(2, 1)
axes_list[0].plot(x, y1)
axes_list[1].plot(x, y2)

plt.show()