Pynote

Python、機械学習、画像処理について

OpenCV - Non Maximum Suppression について

Non Maximum Suppression

物体検出を行うと、1つの物体に対して複数回検出されるということがある。
これらを1つの短形に統合する Non Maximum Suppression という処理がよく用いられる。

短形の表現

短形の左上の座標を (x_1, y_1)、右下の座標を (x_2, y_2)、幅及び高さを (w, h) と参照する。

(x_1, y_1), (x_2, y_2) から (w, h) を求めるには、w = x_2 - x_1 + 1, h = y_2 - y_1 + 1
逆に (x_1, y_1), (w, h) から (x_2, y_2) を求めるには、x_2 = x_1 + w - 1, y_2 = y_1 + h  - 1 となることに注意する。


Overlap Ratio

2つの矩形 a, b があったとき、\frac{area(a \cap b)}{area(a)} で計算できる値を Overlap Ratio という。
2つの矩形が完全に一致していれば OverlapRatio = 1、全く重なっていなかったら OverlapRatio = 0 となる。


Non Maximum Suppression の処理

入力として、スコア付きの矩形が渡される。
1. 入力からスコアが一番高い矩形を選択し、出力に移す。
2. 選択した矩形と入力に残っている各矩形の IOU を計算し、閾値以上のものを入力から削除する。
(選択した矩形とある程度重なっている矩形は同じ物体であると判断する。)
3. 入力が空になるまで1、2を繰り返す。
4. 入力が空になったら、出力にある矩形を結果出力とする。


閾値の設定

重複と判断して削除する Overlap Ratio の閾値[0, 1] の範囲で適切な値に決める必要がある。
値が大きいほど重複と判断する基準が厳しくなり、同一物体に複数の矩形が残ってしまう可能性がある。
逆に値が低いほど重複と判断する基準が緩くなり、異なる物体を示している矩形が同じ物体を示していると判断され、削除される可能性がある。

実装

モジュールを import する。

import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

画像を読み込む。

今回は Lena の画像を使用し、顔の一部をテンプレート画像として用いる。

入力画像

テンプレート画像

短形を描画する関数を用意する。

def draw_boxes(img, boxes, title):
    plt.figure(figsize=(8, 8))
    ax = plt.gca()
    ax.axis('off')
    ax.set_title(title)
    ax.imshow(cv2.cvtColor(drawn, cv2.COLOR_BGR2RGB))

    for box in boxes:
        x, y = box[:2]
        w, h = box[2:] - box[:2] + 1
        
        ax.add_patch(patches.Rectangle(
            (x, y), w, h, linewidth=1, edgecolor='green', fill=None))
    plt.show()

テンプレートマッチングを行う。

テンプレートマッチングを行い、類似度が 0.9 以上の短形を検出されたと判定する。

# テンプレートマッチングを行う。
results = cv2.matchTemplate(img, template, cv2.TM_CCOEFF_NORMED)

# 類似度が 0.6 以上の位置及びスコアを取得する。
positions = np.where(results >= 0.9)
scores = results[positions]

テンプレートマッチングの結果を元に短形一覧を作成する。

各要素が (x_1, y_1, x_2, y_2) である短形一覧を作成する。

# 短形に変換する。
boxes = []
h, w = template.shape[:2]  # テンプレート画像の高さ及び幅
for y, x in zip(*positions):
    boxes.append([x, y, x + w - 1, y + h - 1])
boxes = np.array(boxes)

この時点で検出結果を描画すると、以下のようになる。
1つの物体として検出されているように見えるが、実際は重なっており 25 個検出されてしまっている。

# 検出された短形の数
print('boxes.shape', boxes.shape)  # boxes.shape (25, 4)

# 短形一覧を描画する。
drawn = img.copy()
draw_boxes(drawn, boxes, 'Before Non Maximum Suppression')


Non Maximum Suppression を実装する。

def non_max_suppression(boxes, scores, overlap_thresh):
    '''Non Maximum Suppression (NMS) を行う。

    Args:
        boxes     : (N, 4) の numpy 配列。矩形の一覧。
        overlap_thresh: [0, 1] の実数。閾値。

    Returns:
        boxes : (M, 4) の numpy 配列。Non Maximum Suppression により残った矩形の一覧。
    '''
    if boxes.size == 0:
        return []

    # float 型に変換する。
    boxes = boxes.astype("float")
    # (NumBoxes, 4) の numpy 配列を x1, y1, x2, y2 の一覧を表す4つの (NumBoxes, 1) の numpy 配列に分割する。
    x1, y1, x2, y2 = np.squeeze(np.split(boxes, 4, axis=1))

    # 矩形の面積を計算する。
    area = (x2 - x1 + 1) * (y2 - y1 + 1)

    indices = np.argsort(scores)  # スコアを降順にソートしたインデックス一覧
    selected = []  # NMS により選択されたインデックス一覧

    # indices がなくなるまでループする。
    while len(indices) > 0:
        # indices は降順にソートされているので、一番最後の要素の値 (インデックス) が
        # 残っている中で最もスコアが高い。
        last = len(indices) - 1
        
        selected_index = indices[last]
        remaining_indices = indices[:last]
        selected.append(selected_index)

        # 選択した短形と残りの短形の共通部分の x1, y1, x2, y2 を計算する。
        i_x1 = np.maximum(x1[selected_index], x1[remaining_indices])
        i_y1 = np.maximum(y1[selected_index], y1[remaining_indices])
        i_x2 = np.minimum(x2[selected_index], x2[remaining_indices])
        i_y2 = np.minimum(y2[selected_index], y2[remaining_indices])

        # 選択した短形と残りの短形の共通部分の幅及び高さを計算する。
        # 共通部分がない場合は、幅や高さは負の値になるので、その場合、幅や高さは 0 とする。
        i_w = np.maximum(0, i_x2 - i_x1 + 1)
        i_h = np.maximum(0, i_y2 - i_y1 + 1)

        # 選択した短形と残りの短形の Overlap Ratio を計算する。
        overlap = (i_w * i_h) / area[remaining_indices]

        # 選択した短形及び OVerlap Ratio が閾値以上の短形を indices から削除する。
        indices = np.delete(indices, np.concatenate(([last], np.where(overlap > overlap_thresh)[0])))

    # 選択された短形の一覧を返す。
    return boxes[selected].astype("int")

Non Maximum Suppression を適用する。

Non Maximum Suppression を適用して、描画する。
25個あった検出結果が1個に統合されているのがわかる。

# Non Maximum Suppression を行う。
boxes = non_max_suppression(boxes, scores, overlap_thresh=0.6)
print('boxes.shape', boxes.shape)  # boxes.shape (1, 4)

# NMS 後に残った短形一覧を描画する。
drawn = img.copy()
draw_boxes(drawn, boxes, 'After Non Maximum Suppression')