{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "true" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require 'numo/narray'\n", "require 'numo/gnuplot'" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "true" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require_relative 'dataset/mnist'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[, , , ]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_train, t_train, x_test, t_test = MNIST.load_mnist(\n", " flatten: true, normalize: false)\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[60000, 784]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_train.shape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[60000]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t_train.shape" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = x_train[0]\n", "label = t_train[0]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[784]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img.shape" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "Gnuplot\n", "Produced by GNUPLOT 4.6 patchlevel 6 \n", "\n", "\n", "\n", "\n", "\n", "\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t \n", "\t \n", "\t\n", "\n", "\n", "\n", "\n", "\t\t\n", "\t\t 0\n", "\t\n", "\t\t\n", "\t\t 5\n", "\t\n", "\t\t\n", "\t\t 10\n", "\t\n", "\t\t\n", "\t\t 15\n", "\t\n", "\t\t\n", "\t\t 20\n", "\t\n", "\t\t\n", "\t\t 25\n", "\t\n", "\t\t\n", "\t\t 0\n", "\t\n", "\t\t\n", "\t\t 5\n", "\t\n", "\t\t\n", "\t\t 10\n", "\t\n", "\t\t\n", "\t\t 15\n", "\t\n", "\t\t\n", "\t\t 20\n", "\t\n", "\t\t\n", "\t\t 25\n", "\t\n", "\t\n", "\tgnuplot_plot_1\n", "\n", "\t\n", "\t\t'-' binary array=(28,28) format='%uint8'\n", "\t\n", ";\n", "\n", "\t\n", "\n", "\t\n", "\n", "\n", "\n" ], "text/plain": [ "#>" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Numo::noteplot do\n", " # set :key\n", " set xrange: -0.5...27.5\n", " set yrange: -0.5...27.5\n", " set :tic, :scale, 0\n", " unset :colorbox\n", " unset :cbtics\n", " set cbrange: 0..255\n", " set :palette, \"gray\"\n", " set :view, \"map\" # <- 利かない;;\n", " plot img.reshape(28, 28).reverse(0), with: \"image\"\n", "end" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":get_data" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def get_data()\n", " x_train, t_train, x_test, t_test = MNIST::load_mnist(\n", " normalize: true, flatten: true, one_hot_label: false)\n", " return x_test, t_test\n", "end" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[, ]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, t = get_data" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":init_network" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def init_network()\n", " require_relative 'dataset/load_npy'\n", " \n", " {\n", " \"W1\" => load_npy(\"dataset/W1.npy\"),\n", " \"b1\" => load_npy(\"dataset/b1.npy\"),\n", " \"W2\" => load_npy(\"dataset/W2.npy\"),\n", " \"b2\" => load_npy(\"dataset/b2.npy\"),\n", " \"W3\" => load_npy(\"dataset/W3.npy\"),\n", " \"b3\" => load_npy(\"dataset/b3.npy\")\n", " }\n", "end" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":sigmoid" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# シグモイド関数\n", "def sigmoid(x)\n", " 1 / (1 + Numo::NMath.exp(-x)) # Numo::DFloat を返す\n", "end" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":softmax" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# ソフトマックス関数\n", "def softmax(a)\n", " c = a.max\n", " exp_a = Numo::NMath.exp(a - c)\n", " sum_exp_a = exp_a.sum\n", " return exp_a / sum_exp_a\n", "end" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":predict" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def predict(network, x)\n", " w1, w2, w3 = network[\"W1\"], network[\"W2\"], network[\"W3\"]\n", " b1, b2, b3 = network[\"b1\"], network[\"b2\"], network[\"b3\"]\n", "\n", " a1 = x.dot(w1) + b1\n", " z1 = sigmoid(a1)\n", " a2 = z1.dot(w2) + b2\n", " z2 = sigmoid(a2)\n", " a3 = z2.dot(w3) + b3\n", " y = softmax(a3)\n", "\n", " return y\n", "end" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{\"W1\"=>Numo::SFloat(view)#shape=[784,50]\n", "[[-0.00741249, -0.00790439, -0.013075, 0.0185257, -0.00153461, ...], \n", " [-0.0102975, -0.0161665, -0.0122838, -0.0179263, 0.00339878, -0.0707078, ...], \n", " [-0.0130918, -0.00244747, -0.0177224, -0.0242403, -0.0220408, ...], \n", " [-0.0100084, 0.0195861, -0.00561698, 0.0383074, -0.0525067, -0.0235683, ...], \n", " [0.0220735, 0.00640837, -0.0283758, -0.0206995, -0.0271802, 0.00562039, ...], \n", " [0.0342826, 0.0356078, 0.0200793, 0.018984, 0.0210607, -0.0258285, ...], \n", " [0.0181678, 0.0759545, 0.0803725, 0.0126271, 0.0308478, 0.0182223, ...], \n", " [-0.0101596, -0.0128061, 0.0148865, 0.0142262, -0.0173277, 0.00898263, ...], \n", " [-0.0256382, -0.0372087, -0.0188893, 0.0143945, -0.0177627, 0.00180663, ...], \n", " [0.0306505, -0.0347321, -0.0334835, -0.0294925, 0.0147203, -0.00265436, ...], \n", " [-0.000140208, -0.052682, -0.0868047, 0.0210401, -0.0335705, 0.0398818, ...], \n", " [0.0378605, 0.00899352, -0.0673496, -0.0215917, 0.0231716, 0.0377583, ...], \n", " [-0.00906321, 0.0310108, -0.0135522, 0.0436055, -0.048058, 0.0194065, ...], \n", " [-0.02914, -0.0334625, 0.0237653, 0.0205169, -0.00135183, -0.0299549, ...], \n", " [-0.0494768, -0.0120735, -0.0185211, -0.00764567, -0.00143256, ...], \n", " [0.0230975, -0.0524216, -0.0278221, 0.0389236, 0.00804125, -0.0121782, ...], \n", " [-0.0247755, -0.0209219, 0.0311249, -0.0288237, 0.043576, -0.00481695, ...], \n", " [0.00931242, -0.0279488, 0.00211007, -0.00783962, 0.002434, -0.0244446, ...], \n", " [-0.00778972, -0.0122575, 0.0242902, -0.0184587, 0.0401073, 0.0157794, ...], \n", " [-0.0128085, 0.0446507, 0.00917025, 0.0146358, -0.00542548, 0.00721991, ...], \n", " ..., \"b1\"=>Numo::SFloat#shape=[50]\n", "[-0.0675032, 0.0695926, -0.0273047, 0.0225609, -0.220015, -0.220388, ...], \"W2\"=>Numo::SFloat(view)#shape=[50,100]\n", "[[-0.10694, 0.0159125, -0.443499, -0.147301, 0.110948, -0.267226, ...], \n", " [0.299116, -0.0332223, -0.0890222, 0.321936, 0.138116, -0.162876, ...], \n", " [0.0657665, 0.633045, 0.0232534, -0.122434, -0.0698603, 0.288259, ...], \n", " [0.0938811, 0.156378, -0.082346, 0.281713, -0.0161627, -0.12231, ...], \n", " [0.0480178, -0.239025, -0.0819087, 0.298241, 0.159125, -0.494149, ...], \n", " [-0.347745, -0.0468489, -0.0243375, 0.187181, 0.371477, -0.0329522, ...], \n", " [-0.246683, 0.0135599, 0.253191, 0.388198, -0.012781, 0.0365826, ...], \n", " [-0.338457, -0.455199, 0.441927, -0.140775, -0.0888904, -0.0440813, ...], \n", " [-0.0443381, -0.254269, 0.428783, -0.384618, 0.400534, 0.123745, ...], \n", " [0.164806, -0.0508704, -0.269874, -0.00801593, -0.226576, -0.10573, ...], \n", " [-0.238172, -0.345074, 0.520948, 0.146176, 0.116399, 0.133306, ...], \n", " [-0.0957638, 0.373424, 0.0782192, -0.347973, -0.168531, 0.342131, ...], \n", " [-0.156165, 0.183239, 0.345875, 0.211726, 0.00293521, 0.372232, 0.395342, ...], \n", " [0.290566, -0.24914, -0.204629, 0.161673, 0.143613, -0.0869808, ...], \n", " [0.0938107, -0.119691, 0.630243, 0.166359, -0.038692, 0.0963781, ...], \n", " [0.13873, -0.0787728, -0.212793, -0.0618759, -0.231768, -0.0665394, ...], \n", " [-0.134357, -0.0556394, 0.359851, 0.358701, 0.0799517, 0.10007, ...], \n", " [-0.29369, 0.323558, -0.408008, -0.192063, -0.189301, 0.0766154, ...], \n", " [0.442743, -0.297204, -0.259713, -0.159787, -0.247488, -0.150099, ...], \n", " [-0.102148, 0.293672, 0.131912, -0.123288, -0.22005, -0.0637278, ...], \n", " ..., \"b2\"=>Numo::SFloat#shape=[100]\n", "[-0.0147111, -0.0721513, -0.00155692, 0.121997, 0.116033, -0.00754946, ...], \"W3\"=>Numo::SFloat(view)#shape=[100,10]\n", "[[-0.421736, 0.689445, 0.087851, -0.483838, -0.195892, -0.311136, ...], \n", " [-0.524321, -0.143625, -0.00442161, 0.417746, 0.215626, -0.256584, ...], \n", " [0.6828, -0.512037, -0.441084, -0.082171, 0.319506, 1.08094, 0.296021, ...], \n", " [0.155144, 0.0678902, 0.947823, -0.016843, -0.580457, 0.0327762, ...], \n", " [0.505435, -0.27301, -0.0386345, 0.162854, -0.688129, 0.332544, 0.202865, ...], \n", " [-0.196135, -0.186595, -0.712957, 0.186612, 0.471211, 0.242333, ...], \n", " [-0.340777, 1.12584, -0.078864, -0.240078, -0.188947, 0.27763, 0.0558598, ...], \n", " [-0.599525, 0.777958, -0.634839, -0.00466428, 0.076268, -0.522314, ...], \n", " [-0.376633, 0.510983, -0.0223147, -0.252364, -0.346367, 0.264313, ...], \n", " [0.650214, -0.307582, -0.182832, 0.361862, -0.220673, 0.23675, 0.14689, ...], \n", " [0.98053, -0.283191, 0.437025, 0.336802, -0.704811, 0.514312, 0.363066, ...], \n", " [-0.514112, -0.0521245, 0.0873818, 0.0649854, -0.0679774, 0.0575621, ...], \n", " [-0.12203, -0.597061, -0.16789, -0.757287, 0.332458, 0.426246, -0.521181, ...], \n", " [0.154342, -0.368901, 0.582112, 0.515596, -0.564042, -0.499414, ...], \n", " [0.45486, -0.300265, 0.681925, -0.264147, -0.218347, 0.188041, 0.381713, ...], \n", " [-0.525331, 1.18984, 0.392173, 0.612659, -0.609629, -0.164489, -0.421452, ...], \n", " [0.66484, -0.477979, -0.265062, 0.636109, -1.42232, 0.0343148, -1.16742, ...], \n", " [0.0901859, 0.187217, 0.24068, 0.217941, -0.456552, 0.117313, 0.16936, ...], \n", " [0.491135, -0.900797, -0.536716, -0.789842, 0.383926, 0.27248, -0.204327, ...], \n", " [0.584113, -0.551996, 0.650859, 0.869456, -0.185319, -0.697416, ...], \n", " ..., \"b3\"=>Numo::SFloat#shape=[10]\n", "[-0.0602398, 0.00932628, -0.0135995, 0.0216713, 0.0107372, 0.066197, ...]}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "network = init_network()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":max_index" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def max_index(a)\n", " idx = 0\n", " val = a[0]\n", " a.each_with_index do |v, i|\n", " if v > val\n", " val = v\n", " idx = i\n", " end\n", " end\n", " idx\n", "end" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9352" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_cnt = 0\n", "\n", "x.each_with_index do |xi, i|\n", " y = predict(network, xi)\n", " # p = y.max_index # なぜか kernel が死んで落ちる(>_<)ので\n", " p = max_index(y)\n", " if p == t[i]\n", " accuracy_cnt += 1\n", " end\n", "end\n", "\n", "Float(accuracy_cnt) / x.length" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## バッチ処理(ミニバッチ)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "100" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size = 100" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":softmax" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# ソフトマックス関数(再定義)\n", "def softmax(a)\n", " if a.is_a?(Numo::NArray) && a.ndim > 1\n", " return Numo::NArray[*a.shape[0].times.map{|i|softmax(a[i,true])}]\n", " end\n", " c = a.max\n", " exp_a = Numo::NMath.exp(a - c)\n", " sum_exp_a = exp_a.sum\n", " return exp_a / sum_exp_a\n", "end" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9352" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_cnt = 0\n", "\n", "0.step(x.length-1, batch_size) do |i|\n", " x_batch = x[i...i+batch_size]\n", " y_batch = predict(network, x_batch)\n", " # ps = y_batch.max_index(1) # ←挙動おかしい\n", " ps = Numo::NArray[*y_batch.shape[0].times.map { |i| max_index(y_batch[i,true]) }]\n", " accuracy_cnt += ps.eq(t[i...i+batch_size]).count\n", "end\n", "\n", "Float(accuracy_cnt) / x.length" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Ruby 2.4.1", "language": "ruby", "name": "ruby" }, "language_info": { "file_extension": ".rb", "mimetype": "application/x-ruby", "name": "ruby", "version": "2.4.1" } }, "nbformat": 4, "nbformat_minor": 2 }