Pynote

Python、機械学習、画像処理について

数学 - 勾配について Python で可視化して理解する。

概要

機械学習Deep Learning など最適化問題を解く際に勾配法が広く使われている。
この記事では勾配法に出てくる勾配について、定義及び性質を示したあと、Python を使ってグラフに描画して理解する。

勾配とは

[定義] 勾配

\mathbb{R}^n の開集合 \Omega 上で関数 f: \Omega \to \mathbb{R}^n が定義されているとする。
f\boldsymbol{a} \in \Omega微分可能であるとき、(df(\boldsymbol{a}))^T を点 \boldsymbol{a}f勾配といい、\nabla(\boldsymbol{a}) と表す。

 \displaystyle
\nabla(\boldsymbol{a}) = \left(
    \frac{\partial f}{\partial x_1}(\boldsymbol{a}),
    \frac{\partial f}{\partial x_2}(\boldsymbol{a}),
    \cdots,
    \frac{\partial f}{\partial x_n}(\boldsymbol{a})
\right)^T

[定理] 勾配ベクトルは傾斜が最も急な方向を表す。

\mathbb{R}^n の開集合 \Omega 上で関数 f: \Omega \to \mathbb{R}^n が定義されているとする。
f\boldsymbol{a} \in \Omega微分可能であるとき、勾配 \nabla f(\boldsymbol{a}) は点 \boldsymbol{a} において傾斜が最も急な方向を表す。

証明)

テイラー展開より
 \displaystyle
f(\boldsymbol{a} + \boldsymbol{h}) - f(\boldsymbol{a})
    = \langle \nabla f(\boldsymbol{a}), \boldsymbol{h} \rangle
    + o(\boldsymbol{h}), \quad (\boldsymbol{h} \to \boldsymbol{0})

右辺の第2項は高位の無限小だから、右辺の第1項が f の増分の主部といえる。
h\nabla f のなす角を \theta \in [0, \pi] とすると、

 \displaystyle
\langle \nabla f(\boldsymbol{a}), \boldsymbol{h} \rangle
= \|\nabla f(\boldsymbol{a})\| \|\boldsymbol{h}\| \cos \theta \\
これは、\theta = 0 のとき最大となる。
よって、\nabla f(\boldsymbol{a})\boldsymbol{h} が同じ方向のとき、関数値の増加が最大となる。

Python で可視化する。

f(x, y) = x^2 + y^2 + xy という変数を定義したとき、その勾配は

 \displaystyle
\nabla f
= \left(\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}\right)^T
 = \left(2x + y, 2y + x\right)^T

となる。

関数を定義する。

関数 f 及びその勾配 \nabla fPython で定義すると、以下のようになる。

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D


def f(x, y):
    return x**2 + y**2 + x * y


def gradient(x, y):
    df_x = 2 * x + y
    df_y = 2 * y + x
    return df_x, df_y

以下で次の3つのグラフを matplotlib で作成する。

  • 関数の 3D グラフ
  • ある点の勾配
  • 各点の勾配を表すベクトル図

matplotlib の使い方については以下を参照されたい。

pynote.hatenablog.com
pynote.hatenablog.com

関数を描画する。

# 描画する。
X, Y = np.mgrid[-10:11, -10:11]
Z = f(X, Y)

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111, projection='3d')
ax.set_xlabel('$x$', fontsize=15)
ax.set_ylabel('$y$', fontsize=15)
ax.set_zlabel('$z$', fontsize=15)

# 3Dグラフを作成する。
surf = ax.plot_surface(X, Y, Z, alpha=0.5, edgecolor='black')
plt.show()


勾配を描画する。

# 点 (1, 4) における勾配を計算する。
x, y = 1, 4
u, v = gradient(x, y)

# 描画する。
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlabel('$x$', fontsize=15)
ax.set_ylabel('$y$', fontsize=15)

# 点 (1, 4) 及び勾配方向の矢印を作成する。
ax.plot(x, y, 'ro')
ax.text(x - 3, y, '({}, {})'.format(x, y), fontsize=15)
ax.arrow(x, y, u * 0.4, v * 0.4, width=0.1, color='purple')

# 等高線を作成する。
contours = ax.contour(X, Y, Z)
ax.clabel(contours, inline=1, fontsize=10, fmt='%.2f')

plt.show()


各点での勾配をベクトル図で描画する。

X, Y = np.mgrid[-10:11, -10:11]
Z = f(X, Y)  # 各点での関数 f の値を計算する。
U, V = gradient(X, Y)  # 各点での関数 f の勾配を計算する。

# 描画する。
fig = plt.figure(figsize=(10, 5))

# 勾配のベクトル図を作成する。
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlabel('$x$', fontsize=15)
ax.set_ylabel('$y$', fontsize=15)
ax.quiver(X, Y, U, V)
plt.show()