{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "using IJuliaPortrayals" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "using MXNet" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "MXNet.mx.SymbolicNode(MXNet.mx.MX_SymbolHandle(Ptr{Void} @0x00007f828448a390))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# CNN構築\n", "\n", "# input\n", "data = mx.Variable(:data)\n", "\n", "# first conv\n", "conv1 = @mx.chain mx.Convolution(data=data, kernel=(5,5), num_filter=32) =>\n", " mx.Activation(act_type=:relu) =>\n", " mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))\n", "\n", "# second conv\n", "conv2 = @mx.chain mx.Convolution(data=conv1, kernel=(5,5), num_filter=64) =>\n", " mx.Activation(act_type=:relu) =>\n", " mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))\n", "\n", "# first fully-connected\n", "fc1 = @mx.chain mx.Flatten(data=conv2) =>\n", " mx.FullyConnected(num_hidden=1024) =>\n", " mx.Activation(act_type=:relu)\n", "\n", "dp_fc1 = mx.Dropout(fc1, p=0.5)\n", "\n", "# second fully-connected\n", "fc2 = mx.FullyConnected(data=dp_fc1, num_hidden=10)\n", "\n", "# softmax loss\n", "cnn = mx.SoftmaxOutput(data=fc2, name=:softmax)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "Network Visualization\n", "\n", "\n", "convolution0\n", "\n", "convolution0\n", "Convolution\n", "kernel=(5,5)\n", "stride=(1,1)\n", "n-filter=32\n", "\n", "\n", "activation0\n", "\n", "activation0\n", "Activation\n", "act-type=relu\n", "\n", "\n", "activation0->convolution0\n", "\n", "\n", "\n", "\n", "pooling0\n", "\n", "pooling0\n", "Pooling\n", "type=max\n", "kernel=(2,2)\n", "stride=(2,2)\n", "\n", "\n", 
"pooling0->activation0\n", "\n", "\n", "\n", "\n", "convolution1\n", "\n", "convolution1\n", "Convolution\n", "kernel=(5,5)\n", "stride=(1,1)\n", "n-filter=64\n", "\n", "\n", "convolution1->pooling0\n", "\n", "\n", "\n", "\n", "activation1\n", "\n", "activation1\n", "Activation\n", "act-type=relu\n", "\n", "\n", "activation1->convolution1\n", "\n", "\n", "\n", "\n", "pooling1\n", "\n", "pooling1\n", "Pooling\n", "type=max\n", "kernel=(2,2)\n", "stride=(2,2)\n", "\n", "\n", "pooling1->activation1\n", "\n", "\n", "\n", "\n", "flatten0\n", "\n", "flatten0\n", "Flatten\n", "\n", "\n", "flatten0->pooling1\n", "\n", "\n", "\n", "\n", "fullyconnected0\n", "\n", "fullyconnected0\n", "FullyConnected\n", "num-hidden=1024\n", "\n", "\n", "fullyconnected0->flatten0\n", "\n", "\n", "\n", "\n", "activation2\n", "\n", "activation2\n", "Activation\n", "act-type=relu\n", "\n", "\n", "activation2->fullyconnected0\n", "\n", "\n", "\n", "\n", "dropout0\n", "\n", "dropout0\n", "Dropout\n", "\n", "\n", "dropout0->activation2\n", "\n", "\n", "\n", "\n", "fullyconnected1\n", "\n", "fullyconnected1\n", "FullyConnected\n", "num-hidden=10\n", "\n", "\n", "fullyconnected1->dropout0\n", "\n", "\n", "\n", "\n", "softmax\n", "\n", "softmax\n", "SoftmaxOutput\n", "\n", "\n", "softmax->fullyconnected1\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "IJuliaPortrayals.GraphViz(\"digraph \\\"Network Visualization\\\" {\\nnode [fontsize=10];\\nedge [fontsize=10];\\n\\\"convolution0\\\" [label=\\\"convolution0\\\\nConvolution\\\\nkernel=(5,5)\\\\nstride=(1,1)\\\\nn-filter=32\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#fb8072\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#941305\\\"];\\n\\\"activation0\\\" [label=\\\"activation0\\\\nActivation\\\\nact-type=relu\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#ffffb3\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#999900\\\"];\\n\\\"pooling0\\\" 
[label=\\\"pooling0\\\\nPooling\\\\ntype=max\\\\nkernel=(2,2)\\\\nstride=(2,2)\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#80b1d3\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#275372\\\"];\\n\\\"convolution1\\\" [label=\\\"convolution1\\\\nConvolution\\\\nkernel=(5,5)\\\\nstride=(1,1)\\\\nn-filter=64\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#fb8072\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#941305\\\"];\\n\\\"activation1\\\" [label=\\\"activation1\\\\nActivation\\\\nact-type=relu\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#ffffb3\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#999900\\\"];\\n\\\"pooling1\\\" [label=\\\"pooling1\\\\nPooling\\\\ntype=max\\\\nkernel=(2,2)\\\\nstride=(2,2)\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#80b1d3\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#275372\\\"];\\n\\\"flatten0\\\" [label=\\\"flatten0\\\\nFlatten\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#fdb462\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#975102\\\"];\\n\\\"fullyconnected0\\\" [label=\\\"fullyconnected0\\\\nFullyConnected\\\\nnum-hidden=1024\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#fb8072\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#941305\\\"];\\n\\\"activation2\\\" [label=\\\"activation2\\\\nActivation\\\\nact-type=relu\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#ffffb3\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#999900\\\"];\\n\\\"dropout0\\\" [label=\\\"dropout0\\\\nDropout\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#fccde5\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#90094e\\\"];\\n\\\"fullyconnected1\\\" 
[label=\\\"fullyconnected1\\\\nFullyConnected\\\\nnum-hidden=10\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#fb8072\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#941305\\\"];\\n\\\"softmax\\\" [label=\\\"softmax\\\\nSoftmaxOutput\\\",style=\\\"rounded,filled\\\",fixedsize=true,width=1.3,fillcolor=\\\"#b3de69\\\",shape=box,penwidth=2,height=0.8034,color=\\\"#597d1c\\\"];\\n\\\"activation0\\\" -> \\\"convolution0\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"pooling0\\\" -> \\\"activation0\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"convolution1\\\" -> \\\"pooling0\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"activation1\\\" -> \\\"convolution1\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"pooling1\\\" -> \\\"activation1\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"flatten0\\\" -> \\\"pooling1\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"fullyconnected0\\\" -> \\\"flatten0\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"activation2\\\" -> \\\"fullyconnected0\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"dropout0\\\" -> \\\"activation2\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"fullyconnected1\\\" -> \\\"dropout0\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n\\\"softmax\\\" -> \\\"fullyconnected1\\\" [arrowtail=open,color=\\\"#737373\\\",dir=back];\\n}\\n\",\"dot\",\"svg\")" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "GraphViz(mx.to_graphviz(cnn))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "MXNet.mx.MXDataProvider(MXNet.mx.MX_DataIterHandle(Ptr{Void} @0x00007f82848a5910),Tuple{Symbol,Tuple}[(:data,(28,28,1,100))],Tuple{Symbol,Tuple}[(:softmax_label,(100,))],100,true,true)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 
データ取得(データプロバイダ生成)\n", "batch_size = 100\n", "# include(Pkg.dir(\"MXNet\", \"examples\", \"mnist\", \"mnist-data.jl\"))\n", "# train_provider, eval_provider = get_mnist_providers(batch_size)\n", "data_name = :data\n", "label_name = :softmax_label\n", "flat=false\n", "train_provider = mx.MNISTProvider(image=\"MNIST_data/train-images-idx3-ubyte\",\n", " label=\"MNIST_data/train-labels-idx1-ubyte\",\n", " data_name=data_name, label_name=label_name,\n", " batch_size=batch_size, shuffle=true, flat=flat, silent=true)\n", "eval_provider = mx.MNISTProvider(image=\"MNIST_data/t10k-images-idx3-ubyte\",\n", " label=\"MNIST_data/t10k-labels-idx1-ubyte\",\n", " data_name=data_name, label_name=label_name,\n", " batch_size=batch_size, shuffle=false, flat=flat, silent=true)\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO: Start training on [CPU0]\n", "INFO: Initializing parameters...\n", "INFO: Creating KVStore...\n", "INFO: Start training...\n", "INFO: == Epoch 001 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.7142\n", "INFO: time = 223.0584 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9740\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0001.params'\n", "INFO: == Epoch 002 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9804\n", "INFO: time = 208.5880 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9865\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0002.params'\n", "INFO: == Epoch 003 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9869\n", "INFO: time = 211.7859 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9892\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0003.params'\n", "INFO: == Epoch 004 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9898\n", "INFO: time = 206.5210 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9915\n", 
"INFO: Saved checkpoint to 'MNIST_CNN3-0004.params'\n", "INFO: == Epoch 005 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9917\n", "INFO: time = 206.9616 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9890\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0005.params'\n", "INFO: == Epoch 006 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9928\n", "INFO: time = 206.5133 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9909\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0006.params'\n", "INFO: == Epoch 007 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9936\n", "INFO: time = 206.3797 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9908\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0007.params'\n", "INFO: == Epoch 008 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9945\n", "INFO: time = 206.2362 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9901\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0008.params'\n", "INFO: == Epoch 009 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9951\n", "INFO: time = 206.4188 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9910\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0009.params'\n", "INFO: == Epoch 010 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9962\n", "INFO: time = 206.4152 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9913\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0010.params'\n", "INFO: == Epoch 011 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9966\n", "INFO: time = 206.2151 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9908\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0011.params'\n", "INFO: == Epoch 012 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9964\n", "INFO: time = 206.3259 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9920\n", "INFO: Saved 
checkpoint to 'MNIST_CNN3-0012.params'\n", "INFO: == Epoch 013 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9967\n", "INFO: time = 206.3057 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9907\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0013.params'\n", "INFO: == Epoch 014 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9974\n", "INFO: time = 206.4475 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9909\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0014.params'\n", "INFO: == Epoch 015 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9979\n", "INFO: time = 206.1776 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9921\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0015.params'\n", "INFO: == Epoch 016 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9979\n", "INFO: time = 206.1897 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9908\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0016.params'\n", "INFO: == Epoch 017 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9980\n", "INFO: time = 206.0748 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9921\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0017.params'\n", "INFO: == Epoch 018 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9983\n", "INFO: time = 205.9644 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9931\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0018.params'\n", "INFO: == Epoch 019 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9982\n", "INFO: time = 208.8835 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9887\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0019.params'\n", "INFO: == Epoch 020 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9976\n", "INFO: time = 205.8203 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9903\n", "INFO: Saved checkpoint to 
'MNIST_CNN3-0020.params'\n", "INFO: == Epoch 021 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9984\n", "INFO: time = 206.7411 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9925\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0021.params'\n", "INFO: == Epoch 022 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9988\n", "INFO: time = 206.3165 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9914\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0022.params'\n", "INFO: == Epoch 023 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9983\n", "INFO: time = 207.6620 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9923\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0023.params'\n", "INFO: == Epoch 024 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9977\n", "INFO: time = 205.8147 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9927\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0024.params'\n", "INFO: == Epoch 025 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9986\n", "INFO: time = 205.6267 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9919\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0025.params'\n", "INFO: == Epoch 026 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9986\n", "INFO: time = 1883.2102 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9924\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0026.params'\n", "INFO: == Epoch 027 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9992\n", "INFO: time = 208.2260 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9930\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0027.params'\n", "INFO: == Epoch 028 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9987\n", "INFO: time = 207.1775 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9925\n", "INFO: Saved checkpoint to 
'MNIST_CNN3-0028.params'\n", "INFO: == Epoch 029 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9993\n", "INFO: time = 206.7774 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9932\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0029.params'\n", "INFO: == Epoch 030 ==========\n", "INFO: ## Training summary\n", "INFO: accuracy = 0.9995\n", "INFO: time = 206.6229 seconds\n", "INFO: ## Validation summary\n", "INFO: accuracy = 0.9928\n", "INFO: Saved checkpoint to 'MNIST_CNN3-0030.params'\n" ] } ], "source": [ "# モデル構築・最適化\n", "\n", "# モデル setup\n", "model = mx.FeedForward(cnn, context=mx.cpu())\n", "\n", "# optimization algorithm\n", "optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)\n", "\n", "# save-checkpoint callback\n", "save_checkpoint = mx.do_checkpoint(\"MNIST_CNN3\")\n", "\n", "# fit parameters\n", "mx.fit(model, optimizer, train_provider, n_epoch=30, eval_data=eval_provider, callbacks=[save_checkpoint])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "10x10000 Array{Float32,2}:\n", " 6.6912e-25 1.84421e-19 3.34017e-14 … 6.27268e-23 5.81799e-18\n", " 7.64353e-21 6.56743e-20 1.0 8.11345e-25 2.53667e-19\n", " 5.96325e-20 1.0 7.18634e-14 1.18043e-28 3.24815e-19\n", " 7.11701e-20 5.54137e-23 1.69543e-18 2.38322e-18 1.05279e-21\n", " 1.03662e-19 1.21309e-25 8.94513e-13 1.3275e-28 1.67418e-19\n", " 2.30987e-25 1.73122e-29 1.34319e-14 … 1.0 5.61854e-19\n", " 1.45137e-27 7.20162e-20 9.96349e-14 2.55676e-18 1.0 \n", " 1.0 1.75189e-23 1.66512e-12 7.50446e-26 2.32328e-25\n", " 6.29969e-28 4.04601e-22 9.14715e-14 8.67362e-15 5.61445e-20\n", " 2.57974e-17 1.27717e-27 3.43039e-16 4.33827e-21 3.61943e-22" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 予測\n", "probs = mx.predict(model, eval_provider)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "Accuracy on eval set: 99.43%\n" ] } ], "source": [ "# Check prediction accuracy\n", "\n", "# collect all labels from the eval data, one batch at a time\n", "label_batches = Array[]\n", "for batch in eval_provider\n", "    push!(label_batches, copy(mx.get(eval_provider, batch, :softmax_label)))\n", "end\n", "labels = cat(1, label_batches...)\n", "# probs has one column per eval sample, so the two counts must agree\n", "@assert size(probs, 2) == length(labels)\n", "\n", "# Now compute the accuracy\n", "correct = 0\n", "for i = 1:length(labels)\n", "    # labels are 0...9, while indmax returns a 1-based row index\n", "    if indmax(probs[:,i]) == labels[i]+1\n", "        correct += 1\n", "    end\n", "end\n", "accuracy = 100correct/length(labels)\n", "println(mx.format(\"Accuracy on eval set: {1:.2f}%\", accuracy))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "MXNet.mx.ArrayDataProvider(Array{Float32,N}[\n", "784x1 Array{Float32,2}:\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " ⋮ \n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0\n", " 0.0],[:data],Array{Float32,N}[],Symbol[],1,1,false,0.0f0,0.0f0,[mx.NDArray(28,28,1,1)],MXNet.mx.NDArray[])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch = first(eval_provider)\n", "# images0 = copy(mx.get(eval_provider, batch, :data))\n", "# size(images0)\n", "image = copy(mx.get(eval_provider, batch, :data))[:,:,:,1:1]\n", "\n", "# all(x->0.0<=x<=1.0,vec(image))\n", "# => true\n", "\n", "# provider = mx.ArrayDataProvider(images[:,1:1])\n", "provider = mx.ArrayDataProvider(image)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "10x1 Array{Float32,2}:\n", " 6.6912e-25 \n", " 7.64353e-21\n", " 5.96325e-20\n", " 7.11701e-20\n", " 1.03662e-19\n", " 2.30987e-25\n", " 1.45137e-27\n", " 1.0 \n", " 6.29969e-28\n", " 2.57974e-17" ] }, "execution_count": 10, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "mx.predict(model, provider)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "classify (generic function with 3 methods)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import JSON\n", "\n", "\"\"\"\n", "    classify(a::Vector{Float32}, net=model)\n", "\n", "Run `net` (the trained global `model` by default) on one flattened 28x28\n", "image `a` and return its 10 class probabilities as a JSON array string.\n", "Entry `i` of the array corresponds to digit `i-1`.\n", "\"\"\"\n", "function classify(a::Vector{Float32}, net=model)\n", "    # fail early with a readable message instead of reshape's generic one\n", "    length(a) == 28*28 || throw(DimensionMismatch(\"expected 784 pixels (28x28), got $(length(a))\"))\n", "    # same 28x28x1 shape as the providers' :data, with a batch size of 1\n", "    image = reshape(a, (28, 28, 1, 1))\n", "    result = mx.predict(net, mx.ArrayDataProvider(image))\n", "    return JSON.json(vec(result))\n", "end\n", "\n", "# accept any numeric vector (e.g. the result of JSON.parse) by converting to Float32\n", "function classify(a::Vector, net=model)\n", "    classify(convert(Vector{Float32}, a), net)\n", "end\n", "\n", "# accept a JSON-encoded array of pixel values\n", "function classify(s::AbstractString, net=model)\n", "    classify(JSON.parse(s), net)\n", "end" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "Sorry, your browser doesn't support canvas technology.\n", "

\n", "\n", "\n", "

\n", "Result:
\n", "

\n", "
\n", "" ], "text/plain": [ "HTML{ASCIIString}(\"\\n
\\nSorry, your browser doesn't support canvas technology.\\n

\\n\\n\\n

\\nResult:
\\n

\\n
\\n\")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "HTML(open(readall, \"classify_canvas.html\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Julia 0.4.5", "language": "julia", "name": "julia-0.4" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "0.4.5" } }, "nbformat": 4, "nbformat_minor": 0 }