RetinaNetで顔検出

前回(写真の人の顔に自動的にぼかしを入れる (OpenCV) - 情報関連の備忘録)opencvでは満足のいくdetectが出来なかったので、DNNでごり押しします。 kerasを使ってretinanetを利用します。学習に使うデータはすでにアノテーションされているものを利用します。なんと30000枚ほどの画像が揃ってる! 学習にかかる時間はえぐいが、検出力はなかなかの物。(注意:ここでは性能について検証していません。)

ネットワークについての元の論文は [1708.02002] Focal Loss for Dense Object Detection 。損失関数をいい感じに入れてるのが肝らしい。

手順

1. データの取得

WIDER FACE: A Face Detection BenchmarkからImageデータをダウンロードします。さらにanotationしたtxtファイルもあるので同時にダウンロード。 wider_face_train_bbx_gt.txtだけでも十分そうな気がする。

学習できるようにフォーマットを変更する。元のデータは[left, top, width, height]の順番.

1列が path/to/image.jpg,x1,y1,x2,y2,class_name となるように変更する。imageファイルのパスに注意する。

元のデータはファイル名やらなんやらがまとめて入っている。 枠のx1, x2が同じ値ものは省く。おそらくdetectしていない。

with open("wider_face_split/wider_face_train_bbx_gt.txt", "r") as f:
    f_out = open("face_annotate-all.csv", "w")
    line = f.readline()
    while line:
        # pathを見つけたら書き出す。
        if "jpg" in line:
            # 始めがいらない部分
            img_path = line.rstrip('\n')
            # 数
            line = f.readline()
            N = int(line)
        else:
            line = line.split(" ")
            if (int(line[0]) == int(line[0])+int(line[2])) or (int(line[1]) == int(line[1])+int(line[3])):
                print(img_path, "not object")
            else:
                f_out.write("WIDER_train/images/%s,%s,%s,%d,%d,face\n" % (img_path, line[0], line[1], int(line[0])+int(line[2]), int(line[1])+int(line[3])))
        line = f.readline()
    f_out.close()

また、class fileも作成する。 class_name,idのデータが並んだファイル。 今回はface, 0だけ。

2. ネットワークを入手

GitHub - fizyr/keras-retinanet: Keras implementation of RetinaNet object detection.を利用します。 git cloneで手元に入れる。 tensorlfow, kerasの準備をしておく。 初期化が必要です。 python setup.py build_ext --inplace

エラーが出たらそのほかに必要なmoduleを入れる。

3. 学習

以下のコマンドを走らせれば学習できる。

$ keras_retinanet/bin/train.py --backbone=resnet152 csv /path/to/csv/file/containing/annotations /path/to/csv/file/containing/classes

かなり時間がかかる。

モデルをdetect用にconvertする。

$ keras_retinanet/bin/convert_model.py /path/to/training/model.h5 /path/to/save/inference/model.h5

4. 検出

モデルをloadする。

from keras_retinanet.keras_retinanet.models import load_model
model = load_model('/path/to/model.h5', backbone_name='resnet50')

画像をreadしてretinanetの下にある関数を利用する(らしい)。

img = cv.imread("image_path.jpg")
# preprocess image for network
img = preprocess_image(img)
img, scale = resize_image(img)
boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))

# correct for image scale
boxes /= scale

その後は好きなscore以上のboxをとればよろしい。

ぼかしを入れる時に使ったコード

パスを追加してkeras-retinanet以下の関数を参照できるようにしています。

import glob
import re
import matplotlib.pyplot as plt
import sys
import pprint
import os
import numpy as np
import time
import cv2 as cv
import shutil
sys.path.append('my-project/keras-retinanet')
pprint.pprint(sys.path)
from keras_retinanet.models import load_model
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image

THRESHOLD = 0.4

def img2blured(img, boxes):
    print("detect area : ", len(boxes))
    blured = img.copy()
    for x1, y1, x2, y2 in boxes:
        x1, x2, y1, y2 = int(x1), int(x2), int(y1), int(y2)
        blured = bluring(blured, x1, y1, x2, y2) 
    return blured
def bluring(img, x1, y1, x2, y2):
    tmp = img.copy()
    tmp[y1:y2, x1:x2] = cv.blur(tmp[y1:y2, x1:x2], (int((y2-y1)/3), int((x2-x1)/3)))
    return tmp
   
if __name__ == "__main__":
    print("model loading...")
    model = load_model("resnet50_detect.h5", backbone_name='resnet50')

    files = glob.glob("original/*.jpg")

    print(len(files), "files exists.")

    savedir = "detected/"
    for count, file in enumerate(files):
        # load
        print(count, file, "Load. ", end=" ")
        img = read_image_bgr(file) 
        original = img.copy()
        # preprocess image for network
        img = preprocess_image(img)
        img, scale = resize_image(img)
        # process image
        start = time.time()
        print("detecting...", end=" ")
        boxes, scores, labels = model.predict_on_batch(np.expand_dims(img, axis=0))
        print("end. processing time: %.2f sec" % (time.time() - start), end=" ")
        # correct for image scale
        boxes /= scale
        _, tmp = os.path.split(file)
        if (scores>THRESHOLD).sum() == 0: # non face
            print("non face")
            shutil.copy2(file, os.path.join(savedir, tmp))
        else:
            # bluring
            blured = img2blured(original, boxes[scores > THRESHOLD])
            # save
            cv.imwrite(os.path.join(savedir, tmp), blured)

次はYoloでも使ってみます。