数字識別(MNIST)を実践~プログラムその3~
数字識別(MNIST)を実践~プログラムその3~
今回はディープラーニングとは直接関係ありませんが、コマンドラインからファイルのパスを受け取れるようにしてみたいと思います。
プログラム
・eval_mnist_cnn.py
# -*- coding: utf-8 -*- import argparse import chainer import chainer.functions as F import chainer.links as L import chainer.initializers as I from chainer import training from PIL import Image import numpy as np # ニューラル・ネットワークの構造 class CNN(chainer.Chain): # ニューラル・ネットワークの定義 def __init__(self): super(CNN, self).__init__( conv1=L.Convolution2D(1, 16, 5, 1, 0), conv2=L.Convolution2D(16, 32, 5, 1, 0), l3=L.Linear(None, 10), ) def __call__(self, x): h1 = F.max_pooling_2d(F.relu(self.conv1(x)), ksize=2, stride=2) h2 = F.max_pooling_2d(F.relu(self.conv2(h1)), ksize=2, stride=2) y = self.l3(h2) return y def convert_image(img): data = np.array(Image.open(img).convert('L').resize((28, 28)), dtype=np.float32) data = (255.0 - data) / 255.0 data = data.reshape(1, 1, 28, 28) return data def main(): # オプションの設定 parser = argparse.ArgumentParser(description='ChainerMNISTサンプル') parser.add_argument('--inputimage', '-i', default='', help='画像イメージファイル') parser.add_argument('--model', '-m', default='', help='モデルファイル') args = parser.parse_args() model = L.Classifier(CNN()) chainer.serializers.load_npz(args.model, model) img = convert_image(args.inputimage) x = chainer.Variable(np.asarray(img)) y = model.predictor(x) c = F.softmax(y).data.argmax() print('判定結果:{}'.format(c)) if __name__ == '__main__': main()
プログラムの解説
前回との違いは赤文字になっている個所です。
「数字画像ファイル」「学習済みモデルファイル」のパスをコマンドラインから受け取るようにしています。
python eval_mnist_cnn.py --inputimage image/num01.png --model result/CNN.model
もしくは
python eval_mnist_cnn.py -i image/num01.png -m result/CNN.model
でファイルを指定してプログラムを実行できるようになりました。