Pynote

Python、機械学習、画像処理について

数学 - 最小二乗法の理論と Python での実装方法について

概要

最小二乗法の仕組み及び Python の実装について

サンプル (\boldsymbol{x}_1, y_1), (\boldsymbol{x}_2, y_2), \cdots, (\boldsymbol{x}_n, y_n) が与えられたとする。

このデータを次の m 個の基底関数 g_{i}, (i = 1, 2, .\cdots, m) の線形結合で表される関数 f で近似することを考える。
 \displaystyle
f(\boldsymbol{x}) = \sum_{j = 1}^{m} a_{j} g_{j}(\boldsymbol{x}) \qquad (1)

このとき、すべてのサンプルの二乗誤差を損失関数として定義する。
\frac{1}{2}微分した際に出てくる定数2を打ち消して、式を簡潔にするためのもので、とくに意味はない。

 \displaystyle
J(a_1, a_2, \cdots, a_m) = \frac{1}{2} \sum_{i = 1}^{n} (f(\boldsymbol{x}_i) - y_i)^2 \qquad (2)

GG_{ij} = g_j(x_i) である n \times m 行列、\boldsymbol{a} = (a_1, a_2, \cdots, a_m)^T, \boldsymbol{y} = (y_1, y_2, \cdots, y_n)^T とすると、

 \displaystyle
\begin{align}
J(\boldsymbol{a})
& = \frac{1}{2} \sum_{i = 1}^{n} (f(\boldsymbol{x}_i) - y_i)^2 \\
& = \frac{1}{2} \sum_{i = 1}^{n} (\sum_{j = 1}^{m} a_{j} g_{j}(\boldsymbol{x}_{i}) - \boldsymbol{y}_i)^2 \\
& = \frac{1}{2} (G \boldsymbol{a} - \boldsymbol{y})^T (G \boldsymbol{a} - \boldsymbol{y}) \\
& = \frac{1}{2} \|G \boldsymbol{a} - \boldsymbol{y}\|^2
\end{align}

この関数 J の最小化問題を考える。
J は凸関数であるので、極小値が最小値となることが保証される。
極小値を求めるには、\frac{\partial}{\partial \boldsymbol{a}} = 0 を満たす解 \boldsymbol{a} を求めればよい。


関数 J(\boldsymbol{a})微分

 \displaystyle
\begin{align}
\frac{\partial}{\partial \boldsymbol{a}} \frac{1}{2} \|G \boldsymbol{a} - \boldsymbol{y}\|^2
&= G^T G \boldsymbol{a} - G^T \boldsymbol{y}
\end{align}

より、G^T G \boldsymbol{a} = G^T \boldsymbol{y} を解けばよいことがわかる。
この式を正規方程式という。

損失関数が凸関数であることの証明

\boldsymbol{a}^T G^T \boldsymbol{y} = (G \boldsymbol{a})^T \boldsymbol{y} = \boldsymbol{y}^T (G \boldsymbol{a}) に注意すると、

 \displaystyle
\begin{align}
\phi(\boldsymbol{a})
&= (G \boldsymbol{a} - \boldsymbol{y})^T (G \boldsymbol{a} - \boldsymbol{y}) \\
&= \boldsymbol{a}^T G^T  G \boldsymbol{a} - \boldsymbol{a}^T G^T \boldsymbol{y} - \boldsymbol{y}^T G \boldsymbol{a} - \boldsymbol{y}^T \boldsymbol{y}\\
&= \boldsymbol{a}^T G^T  G \boldsymbol{a} - 2 \boldsymbol{y}^T G \boldsymbol{a} - \|\boldsymbol{y}\|^2
\end{align}

\phi(\boldsymbol{a}) が凸関数であることを示すには、任意の \boldsymbol{a}_1, \boldsymbol{a}_2 \in \mathbb{R}^m, t \in [0, 1] に対して、

 \displaystyle
\phi(t \boldsymbol{a}_1 + (1 - t) \boldsymbol{a}_2) - (t \phi(\boldsymbol{a}_1) + (1 - t)\phi(\boldsymbol{a}_2)) \le 0

を示せばよい。
左辺を展開して整理すると、

 \displaystyle
\begin{align}
& t^2 \boldsymbol{a}_1 G^T G \boldsymbol{a}_1 +
2t(1 - t) \boldsymbol{a}_1 G^T G \boldsymbol{a}_2 +
(1 - t)^2 \boldsymbol{a}_2 G^T G \boldsymbol{a}_2 -
t \boldsymbol{a}_1 G^T G \boldsymbol{a}_1 -
(1 - t) \boldsymbol{a}_2 G^T G \boldsymbol{a}_2 \\
& = -t(1 - t)(\boldsymbol{a}_1 G^T G \boldsymbol{a}_1 -2
\boldsymbol{a}_1 G^T G \boldsymbol{a}_2 + \boldsymbol{a}_2 G^T G \boldsymbol{a}_2) \\
&= -t(1 - t)((\boldsymbol{a}_1 - \boldsymbol{a}_2)^T G^T G (\boldsymbol{a}_1 - \boldsymbol{a}_2) \\
&= -t(1 - t)( (G (\boldsymbol{a}_1 - \boldsymbol{a}_2) )^T (G (\boldsymbol{a}_1 - \boldsymbol{a}_2))) \\
&= -t(1 - t)\| G (\boldsymbol{a}_1 - \boldsymbol{a}_2) \|^2
\end{align}

ここで、\| G (\boldsymbol{a}_1 - \boldsymbol{a}_2) \|^2 \ge 0t \in [0, 1] により -t(1 - t) \le 0 であるから、

 \displaystyle
 -t(1 - t)\| G (\boldsymbol{a}_1 - \boldsymbol{a}_2) \|^2 \le 0

よって、\phi(\boldsymbol{a}) が凸関数である。
よって、J(\boldsymbol{a}) = \frac{1}{2} \phi(\boldsymbol{a}) も凸関数である。

1次近似

理論

サンプル (x_1, y_1), (x_2, y_2), \cdots, (x_n, y_n) が与えられたとする。

g_1(x) = x, g_2(x) = 1 として、f(x) = a_1 x + a_2、つまり直線近似することを考える。

このとき、\boldsymbol{a} = (a_1, a_2)^T, \boldsymbol{y} = (y_1, y_2, \cdots, y_n)^T
 \displaystyle
G = \begin{pmatrix}
x_1 & 1 \\
x_2 & 1 \\
\vdots & \vdots \\
x_n & 1 \\
\end{pmatrix}

であるから、正規方程式は

 \displaystyle
\begin{align}
G^T G \boldsymbol{a} &= G^T \boldsymbol{y} \\
\begin{pmatrix}
\sum_{i=1}^n x_i^2 & \sum_{i=1}^n x_i \\
\sum_{i=1}^n x_i & n \\
\end{pmatrix}
\begin{pmatrix}
a_1 \\
a_2 \\
\end{pmatrix}
& =
\begin{pmatrix}
\sum_{i=1}^n x_i y_i \\
\sum_{i=1}^n y_i \\
\end{pmatrix}
\end{align}

となる。

例題

実験によると鉱石の密度 $x (g/cm^3)$ と鉄含有量 $y (\%)$ の間には次の関係があった。

\displaystyle
\begin{array}{|c|c|}
\hline
x & 2.8 & 2.9 & 3.0 & 3.1 & 3.2 & 3.2 & 3.2 & 3.3 & 3.4 \\
\hline
y & 30 & 26 & 33 & 31 & 33 & 35 & 37 & 36 & 33 \\
\hline
\end{array}

これから密度が 3.25 g/cm^3 の鉱石の鉄含有量はいくらであると推定されるか。

import matplotlib.pyplot as plt
import numpy as np

xs = np.array([2.8, 2.9, 3.0, 3.1, 3.2, 3.2, 3.2, 3.3, 3.4])
ys = np.array([30, 26, 33, 31, 33, 35, 37, 36, 33])

# Ga = y を解く。
G = np.array([[(xs ** 2).sum(), xs.sum()],
              [xs.sum(), len(xs)]])
y = np.array([(xs * ys).sum(), ys.sum()])
Ginv = np.linalg.inv(G)

a1, a2 = Ginv.dot(y)
print('a1={:.2f}, a2={:.2f}'.format(a1, a2))  # a1=10.46, a2=0.02
# 密度が 3.25 g/cm^3 の鉱石の鉄含有量: 34.02

# 近似直線の式
def f(x):
    return a1 * x + a2

print('密度が {} g/cm^3 の鉱石の鉄含有量: {:.2f}'.format(3.25, f(3.25)))

# 直線上の点を作成する。
line_X = np.linspace(xs.min(), xs.max(), 100)
line_Y = f(line_X)

# データ及び近似した直線を作成する。
fig, ax = plt.subplots(figsize=(7, 7))
ax.set_xlabel('x')
ax.set_ylabel('y')

ax.plot(line_X, line_Y, 'r-')
ax.scatter(xs, ys)