ディープラーニングで笑顔を自動検知したい

ディープラーニングで笑顔を自動検知するまでの学習過程を綴っていきます。

数字識別(MNIST)を実践~プログラムその2~

数字識別(MNIST)を実践~プログラムその2~

今回は前回作成した学習済みモデルを使って推論を行ってみるところまでをやってみたいと思います。

プログラム

プログラム全文が短いのでここに貼り付けてみたいと思います。

学習済みモデル、推論対象の画像ファイルのパスは一旦固定にしています。

・eval_mnist_cnn.py

# -*- coding: utf-8 -*-
import chainer
import chainer.functions as F
import chainer.links as L
import chainer.initializers as I
from chainer import training
from PIL import Image
import numpy as np

# (1)
class CNN(chainer.Chain):
    def __init__(self):
        super(CNN, self).__init__(
            conv1=L.Convolution2D(1, 16, 5, 1, 0),
            conv2=L.Convolution2D(16, 32, 5, 1, 0),
            l3=L.Linear(None, 10),
        )

    def __call__(self, x):
        h1 = F.max_pooling_2d(F.relu(self.conv1(x)), ksize=2, stride=2)
        h2 = F.max_pooling_2d(F.relu(self.conv2(h1)), ksize=2, stride=2) 
        y = self.l3(h2)
        return y

# (2)
def convert_image(img):
    data = np.array(Image.open(img).convert('L').resize((28, 28)), dtype=np.float32)
    data = (255.0 - data) / 255.0
    data = data.reshape(1, 1, 28, 28)
    return data

def main():

    # (3)
    model = L.Classifier(CNN())    
    chainer.serializers.load_npz('result/CNN.model', model)

    # (4)
    img = convert_image('image/num01.png')
    x = chainer.Variable(np.asarray(img))
    
    # (5)
    y = model.predictor(x)
    c = F.softmax(y).data.argmax()    
    print('判定結果:{}'.format(c))        

if __name__ == '__main__':
    main()

プログラムの解説

(1)ニューラルネットワークの定義
前回同様、ニューラルネットワークの定義を行います。前回と全く同じ内容です。


(2)画像変換処理
今回、推論対象となる画像ファイルを読み込むのですが、モデルに合うように変換を行います。
具体的には、ファイル読み込み、グレースケールに変換、28x28ピクセルにリサイズ、白黒反転して正規化等を行っています。


(3)モデルのインスタンス化と読み込み
(1)で定義したCNN()のインスタンス化を行い、前回作成した『CNN.model』ファイルを読み込みます。


(4)推論対象の画像ファイルの読み込み
推論させたい対象となる画像ファイルを読み込んで、convert_image()を使って変換します。
変換したら、それをchainerで扱うVariableという型に変換します。


(5)推論
(4)で読み込んだ画像が0~9のどれに該当するか推論を行って結果を出力します。

プログラムを動かす前の事前準備

(1)推論対象画像の準備
eval_mnist_cnn.pyと同階層に『image』というフォルダを作成して、フォルダ内に『num01.png』という名称で画像を配置します。
以下、筆者が使用した画像ファイルになります。(Windowsのペイントで作成しました)

f:id:amiami05:20190312205811p:plain

(2)学習済みモデル
前回の記事で作成した学習済みモデルファイルが必要になります。

(3)ライブラリのインストール
今回、画像ファイル処理をおこなっているため『pillow』というライブラリのインストールが必要になります。