{ "cells": [ { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def numerical_gradient(f, x):\n", " h = 1e-4 # 0.0001\n", " grad = np.zeros_like(x)\n", " \n", " for idx in range(x.size):\n", " tmp_val = x[idx]\n", " # f(x+h)\n", " x[idx] = tmp_val + h\n", " fxh1 = f(x)\n", " #f(x-h)\n", " x[idx] = tmp_val - h\n", " fxh2 = f(x)\n", " \n", " grad[idx] = (fxh1 - fxh2) / (2*h)\n", " x[idx] = tmp_val # restore\n", " \n", " return grad" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def function_2(x):\n", " return x[0]**2 + x[1]**2" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 6., 8.])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numerical_gradient(function_2, np.array([3.0, 4.0]))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0., 4.])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numerical_gradient(function_2, np.array([0.0, 2.0]))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 6., 0.])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numerical_gradient(function_2, np.array([3.0, 0.0]))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def gradient_descent(f, init_x, lr=0.01, step_num=100):\n", " x = init_x\n", " \n", " for i in range(step_num):\n", " grad = numerical_gradient(f, x)\n", " x -= lr * grad\n", " \n", " return x" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ -6.11110793e-10, 8.14814391e-10])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "init_x = np.array([-3.0, 4.0])\n", "gradient_descent(function_2, init_x=init_x, lr=0.1, step_num=100)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ -2.58983747e+13, -1.29524862e+12])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "init_x = np.array([-3.0, 4.0])\n", "gradient_descent(function_2, init_x=init_x, lr=10.0, step_num=100)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-2.99999994, 3.99999992])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "init_x = np.array([-3.0, 4.0])\n", "gradient_descent(function_2, init_x=init_x, lr=1e-10, step_num=100)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from common.functions import softmax, cross_entropy_error\n", "from common.gradient import numerical_gradient" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class SimpleNet:\n", " def __init__(self):\n", " self.W = np.random.randn(2, 3) # Gaussian Distribution\n", " \n", " def predict(self, x):\n", " return np.dot(x, self.W)\n", " \n", " def loss(self, x, t):\n", " z = self.predict(x)\n", " y = softmax(z)\n", " loss = cross_entropy_error(y, t)\n", " \n", " return loss" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "net = SimpleNet()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 1.57097266 -0.05146236 1.43420101]\n", " [ 0.02408947 0.49283932 2.22659858]]\n" ] } ], "source": [ "print(net.W)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0.96426412, 0.41267797, 2.86445932])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = np.array([0.6, 0.9])\n", "p = net.predict(x)\n", "p" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.21162096444663459" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = np.array([0, 0, 1])\n", "net.loss(x, t)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def f(W):\n", " return net.loss(x, t)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.07261079, 0.04182638, -0.11443717],\n", " [ 0.10891618, 0.06273958, -0.17165576]])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dW = numerical_gradient(f, net.W)\n", "dW" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from common.functions import *" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class TwoLayerNet:\n", " def __init__(self, input_size, hidden_size, output_size, \n", " weight_init_std=0.01):\n", " # init weights\n", " self.params = {}\n", " self.params['W1'] = weight_init_std * \\\n", " np.random.randn(input_size, hidden_size)\n", " self.params['b1'] = np.zeros(hidden_size)\n", " self.params['W2'] = weight_init_std * \\\n", " np.random.randn(hidden_size, output_size)\n", " self.params['b2'] = np.zeros(output_size)\n", " \n", " def predict(self, x):\n", " W1, W2 = self.params['W1'], self.params['W2']\n", " b1, b2 = self.params['b1'], self.params['b2']\n", " \n", " a1 = np.dot(x, W1) + b1\n", " z1 = sigmoid(a1)\n", " a2 = np.dot(z1, W2) + b2\n", " y = softmax(a2)\n", " \n", " return y\n", " \n", " def loss(self, x, t):\n", " y = self.predict(x)\n", " \n", " return cross_entropy_error(y, t)\n", " \n", " def accuracy(self, x, t):\n", " y = self.predict(x)\n", " y = np.argmax(y, axis=1)\n", " t = np.argmax(t, axis=1)\n", " \n", " accuracy = np.sum(y == t) / float(x.shape[0])\n", " return accuracy\n", " \n", " def numerical_gradient(self, x, t):\n", " loss_W = lambda W: self.loss(x, t)\n", " \n", " grads = {}\n", " grads['W1'] = numerical_gradient(loss_W, self.params['W1'])\n", " grads['b1'] = numerical_gradient(loss_W, self.params['b1'])\n", " grads['W2'] = numerical_gradient(loss_W, self.params['W2'])\n", " grads['b2'] = numerical_gradient(loss_W, self.params['b2'])\n", " \n", " return grads\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": true }, "outputs": [], "source": [ "net = TwoLayerNet(input_size=784, hidden_size=100, output_size=10)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(784, 100)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.params['W1'].shape" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(100,)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.params['b1'].shape" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(100, 10)" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.params['W2'].shape" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(10,)" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.params['b2'].shape" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "collapsed": true }, "outputs": [], "source": [ "x = np.random.rand(100, 784)\n", "y = net.predict(x)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.0929079 , 0.10545567, 0.1011332 , 0.10213417, 0.100048 ,\n", " 0.09859305, 0.09589364, 0.1050598 , 0.09630168, 0.10247288],\n", " [ 0.09301324, 0.10553867, 0.10122337, 0.10208667, 0.09934299,\n", " 0.09852167, 0.09627567, 0.10514601, 0.09635366, 0.10249805],\n", " [ 0.09320152, 0.10543257, 0.10134838, 0.10227634, 0.09925217,\n", " 0.09846688, 0.09630402, 0.10498337, 0.0964722 , 0.10226256],\n", " [ 0.09266878, 0.10569335, 0.10127206, 0.10266394, 0.09956575,\n", " 0.09868581, 0.09608076, 0.10515584, 0.09590158, 0.10231212],\n", " [ 0.09305302, 0.10550946, 0.10122928, 0.10210634, 0.09984373,\n", " 0.09886459, 0.09600813, 0.1050553 , 0.09597674, 0.1023534 ],\n", " [ 0.09331485, 0.10531989, 0.10143237, 0.10204317, 0.09978208,\n", " 0.09857535, 0.09602764, 0.10528953, 0.09594862, 0.10226651],\n", " [ 0.0928379 , 0.10561995, 0.10143794, 0.1022184 , 0.09939102,\n", " 0.09871511, 0.09615843, 0.10507205, 0.0965413 , 0.10200789],\n", " [ 0.09281353, 0.10542469, 0.10155212, 0.10231861, 0.09961348,\n", " 0.09841559, 0.09607444, 0.10513555, 0.09629261, 0.10235936],\n", " [ 0.09293245, 0.10550783, 0.10142119, 0.10205802, 0.09958523,\n", " 0.09837237, 0.09633467, 0.10541461, 0.09611976, 0.10225386],\n", " [ 0.09289986, 0.1056334 , 0.10139139, 0.10164815, 0.09977585,\n", " 0.09895961, 0.09611974, 0.10512153, 0.09619731, 0.10225317],\n", " [ 0.09258278, 0.10547153, 0.10120106, 0.10232865, 0.09973687,\n", " 0.09880785, 0.09592924, 0.10521748, 0.09631695, 0.10240762],\n", " [ 0.09303644, 0.10531411, 0.10116376, 0.10245105, 0.09958393,\n", " 0.09865234, 0.0962649 , 0.10505575, 0.09615674, 0.10232098],\n", " [ 0.09260302, 0.10539303, 0.10131367, 0.10230459, 0.09994525,\n", " 0.09846959, 0.09608537, 0.10530974, 0.09604824, 0.10252751],\n", " [ 0.09295006, 0.10551502, 0.10136872, 0.10241844, 0.09967822,\n", " 0.09854984, 0.09622399, 0.10534094, 0.09592552, 0.10202925],\n", " [ 0.09327448, 0.10512344, 0.10113221, 0.10239402, 0.09986431,\n", " 0.09886895, 0.09601458, 0.10515902, 0.09598039, 0.10218859],\n", " [ 0.09299233, 0.10546107, 0.10142519, 0.10248919, 0.0996598 ,\n", " 0.09842925, 0.0959193 , 0.10503413, 0.096241 , 0.10234874],\n", " [ 0.0930229 , 0.1055508 , 0.10138768, 0.10238677, 0.09961899,\n", " 0.09874945, 0.09612063, 0.10485567, 0.09611624, 0.10219087],\n", " [ 0.09303769, 0.10532885, 0.10138388, 0.10252985, 0.09952819,\n", " 0.09858531, 0.09609863, 0.10505714, 0.09594898, 0.10250149],\n", " [ 0.0926607 , 0.10543339, 0.10154983, 0.10206987, 0.09957329,\n", " 0.09870852, 0.09622484, 0.10509715, 0.09608962, 0.10259278],\n", " [ 0.09308478, 0.10532966, 0.10148963, 0.10197841, 0.09980095,\n", " 0.09853303, 0.09636073, 0.10518206, 0.09595418, 0.10228657],\n", " [ 0.09273289, 0.10562945, 0.10157872, 0.10211403, 0.09954218,\n", " 0.09859246, 0.09604085, 0.10510426, 0.09611166, 0.10255349],\n", " [ 0.09304671, 0.10539205, 0.1010679 , 0.10209672, 0.09990028,\n", " 0.09843365, 0.09650144, 0.10509071, 0.0963799 , 0.10209063],\n", " [ 0.09317444, 0.10541322, 0.10113702, 0.1023813 , 0.09998182,\n", " 0.09852654, 0.09601359, 0.10518327, 0.09624309, 0.10194572],\n", " [ 0.09299863, 0.10547787, 0.10130434, 0.10247235, 0.09938361,\n", " 0.09870566, 0.09604828, 0.10496676, 0.09647397, 0.10216853],\n", " [ 0.09300635, 0.10545705, 0.10117339, 0.10228962, 0.09952985,\n", " 0.09842535, 0.09666665, 0.10484469, 0.09625436, 0.10235271],\n", " [ 0.09280799, 0.1055721 , 0.1012542 , 0.10228066, 0.09943349,\n", " 0.09839138, 0.0962981 , 0.10526674, 0.09603784, 0.1026575 ],\n", " [ 0.09285692, 0.10538053, 0.10113571, 0.10242615, 0.09976774,\n", " 0.09863012, 0.09583437, 0.10511881, 0.09607782, 0.10277183],\n", " [ 0.09324169, 0.10535371, 0.10114292, 0.10203163, 0.09973731,\n", " 0.09852199, 0.09633753, 0.10548286, 0.09621965, 0.10193072],\n", " [ 0.0929222 , 0.10587545, 0.10121776, 0.10205133, 0.09963414,\n", " 0.09857782, 0.09641187, 0.10508684, 0.09611636, 0.10210624],\n", " [ 0.09293693, 0.10536257, 0.10146791, 0.10237827, 0.09966499,\n", " 0.098599 , 0.09596902, 0.10537673, 0.09628863, 0.10195595],\n", " [ 0.0930968 , 0.10544525, 0.10144709, 0.10222534, 0.10002021,\n", " 0.09859486, 0.09577037, 0.10495168, 0.0962442 , 0.10220419],\n", " [ 0.09301212, 0.1056458 , 0.1015525 , 0.10212306, 0.09958425,\n", " 0.09828803, 0.09621234, 0.1051464 , 0.09623905, 0.10219644],\n", " [ 0.09281217, 0.10564033, 0.10137498, 0.10246494, 0.09971789,\n", " 0.09862356, 0.09602677, 0.10478378, 0.09653601, 0.10201957],\n", " [ 0.09307743, 0.1055587 , 0.10115507, 0.10207163, 0.09968943,\n", " 0.09863473, 0.09649056, 0.10516224, 0.09593227, 0.10222794],\n", " [ 0.09288073, 0.10528063, 0.10120155, 0.10232977, 0.0992267 ,\n", " 0.09857593, 0.09629453, 0.105383 , 0.09632317, 0.10250399],\n", " [ 0.09319098, 0.1052189 , 0.10133501, 0.10240542, 0.09952278,\n", " 0.09844207, 0.0960576 , 0.10515718, 0.09627831, 0.10239174],\n", " [ 0.09279781, 0.10540509, 0.10115933, 0.10212229, 0.09991968,\n", " 0.09871653, 0.09625242, 0.1050344 , 0.09624077, 0.10235168],\n", " [ 0.09297014, 0.10563907, 0.10143534, 0.10248328, 0.09975236,\n", " 0.09845218, 0.09578064, 0.10496744, 0.09624466, 0.10227489],\n", " [ 0.09267269, 0.10535607, 0.10128425, 0.10258506, 0.09982448,\n", " 0.09862899, 0.09617901, 0.10491605, 0.0960848 , 0.10246861],\n", " [ 0.0931228 , 0.10531965, 0.10136179, 0.10208117, 0.09973373,\n", " 0.09863552, 0.09634061, 0.10495159, 0.09598117, 0.10247196],\n", " [ 0.0927896 , 0.10537427, 0.10144334, 0.10227418, 0.09983793,\n", " 0.09856592, 0.09607543, 0.1048856 , 0.09633397, 0.10241975],\n", " [ 0.09267682, 0.10537043, 0.10159381, 0.10234836, 0.0997502 ,\n", " 0.0985187 , 0.09614119, 0.1050154 , 0.09631927, 0.10226582],\n", " [ 0.09294601, 0.1055481 , 0.10109421, 0.10226993, 0.09969694,\n", " 0.09852159, 0.09615979, 0.10551661, 0.09600873, 0.10223809],\n", " [ 0.09299171, 0.10545486, 0.1014576 , 0.1020845 , 0.09997446,\n", " 0.09826487, 0.09604613, 0.10505405, 0.09631063, 0.10236118],\n", " [ 0.09281308, 0.1053601 , 0.10148651, 0.10216482, 0.09969029,\n", " 0.09884875, 0.09579968, 0.10533915, 0.09606161, 0.10243601],\n", " [ 0.09314119, 0.10519844, 0.10138125, 0.10258859, 0.0994507 ,\n", " 0.09836353, 0.09608168, 0.10536495, 0.09630019, 0.10212948],\n", " [ 0.0932179 , 0.10509352, 0.10136455, 0.10247288, 0.09962402,\n", " 0.0984643 , 0.09623498, 0.10489334, 0.0962968 , 0.10233771],\n", " [ 0.0932496 , 0.10508594, 0.10158745, 0.10185015, 0.09945968,\n", " 0.09867447, 0.09599495, 0.10542168, 0.09631065, 0.10236543],\n", " [ 0.09239054, 0.10572909, 0.10145072, 0.10229395, 0.09964291,\n", " 0.09860416, 0.09627258, 0.1049183 , 0.09602399, 0.10267376],\n", " [ 0.09286374, 0.10509078, 0.10153034, 0.10244374, 0.09973043,\n", " 0.0983913 , 0.09596187, 0.10492631, 0.09643686, 0.10262463],\n", " [ 0.09290857, 0.10561684, 0.10091624, 0.10184458, 0.09958773,\n", " 0.09884173, 0.09636506, 0.10506259, 0.09632778, 0.10252888],\n", " [ 0.0930671 , 0.10547467, 0.1014752 , 0.10223337, 0.09963028,\n", " 0.09844618, 0.09630269, 0.10496758, 0.09596564, 0.10243729],\n", " [ 0.09291944, 0.10550189, 0.10128225, 0.10217023, 0.09961167,\n", " 0.09876585, 0.09622561, 0.10508077, 0.09603126, 0.10241103],\n", " [ 0.09293365, 0.10554515, 0.10149409, 0.10237709, 0.09955002,\n", " 0.09850161, 0.09622409, 0.10495057, 0.09600157, 0.10242216],\n", " [ 0.09262504, 0.10532547, 0.10123559, 0.10233704, 0.09965131,\n", " 0.09874683, 0.09625427, 0.10523021, 0.09628124, 0.10231301],\n", " [ 0.09266085, 0.10563015, 0.10149327, 0.10240379, 0.09954113,\n", " 0.09837556, 0.09624771, 0.10490928, 0.09627523, 0.10246304],\n", " [ 0.09279005, 0.10548137, 0.10122423, 0.10206681, 0.09984561,\n", " 0.09858248, 0.09638864, 0.10527409, 0.0961083 , 0.10223842],\n", " [ 0.09289807, 0.10523262, 0.10125267, 0.10216016, 0.09988487,\n", " 0.09866724, 0.09635009, 0.10525241, 0.09591358, 0.10238829],\n", " [ 0.09306998, 0.10504237, 0.10124571, 0.10220867, 0.09969836,\n", " 0.09877888, 0.09593485, 0.10520911, 0.09625314, 0.10255893],\n", " [ 0.09298017, 0.10550293, 0.10131011, 0.10225988, 0.09926442,\n", " 0.09874761, 0.09608259, 0.10537594, 0.09633204, 0.10214433],\n", " [ 0.09312334, 0.10564509, 0.1011713 , 0.10214632, 0.09939295,\n", " 0.09856929, 0.09638521, 0.10508584, 0.09612203, 0.10235862],\n", " [ 0.09316266, 0.10569992, 0.10148025, 0.10200224, 0.09984939,\n", " 0.09863175, 0.09609 , 0.10476663, 0.09603222, 0.10228494],\n", " [ 0.09291322, 0.10573389, 0.10095287, 0.10209558, 0.09976758,\n", " 0.09878431, 0.09605869, 0.10518378, 0.09624866, 0.10226144],\n", " [ 0.09299327, 0.10508258, 0.10112968, 0.10239027, 0.0996423 ,\n", " 0.09868273, 0.09609696, 0.10521983, 0.09599603, 0.10276635],\n", " [ 0.09316544, 0.10537806, 0.10134244, 0.10233228, 0.09970105,\n", " 0.09859353, 0.09612214, 0.10503006, 0.09603183, 0.10230317],\n", " [ 0.09291367, 0.10547921, 0.10154821, 0.10208022, 0.09979968,\n", " 0.09867101, 0.09612427, 0.10481417, 0.09635232, 0.10221725],\n", " [ 0.09308582, 0.10543164, 0.10113064, 0.10237458, 0.09951527,\n", " 0.09865144, 0.09598345, 0.1051388 , 0.09654628, 0.10214208],\n", " [ 0.09262663, 0.10533614, 0.10126861, 0.10195978, 0.09962621,\n", " 0.09888551, 0.09637737, 0.10527574, 0.09630532, 0.10233868],\n", " [ 0.09283884, 0.10560654, 0.10170607, 0.10210664, 0.09935601,\n", " 0.09864766, 0.09610845, 0.10508418, 0.09619765, 0.10234797],\n", " [ 0.09299104, 0.1054915 , 0.10161734, 0.10213074, 0.09944416,\n", " 0.09884851, 0.09614409, 0.10504494, 0.09616468, 0.10212301],\n", " [ 0.0932404 , 0.10525778, 0.10144113, 0.10249581, 0.0996119 ,\n", " 0.09848667, 0.09590917, 0.10519368, 0.09619361, 0.10216985],\n", " [ 0.09312855, 0.10532665, 0.10127102, 0.10238315, 0.09945666,\n", " 0.09833208, 0.09608495, 0.10495088, 0.09658753, 0.10247851],\n", " [ 0.09297056, 0.10541209, 0.10118944, 0.10244048, 0.09941505,\n", " 0.09860308, 0.09599168, 0.10501315, 0.09615737, 0.10280709],\n", " [ 0.09309037, 0.1055418 , 0.10132575, 0.10200752, 0.09996425,\n", " 0.09881251, 0.096235 , 0.10496613, 0.09615279, 0.10190387],\n", " [ 0.0929902 , 0.10537968, 0.10140437, 0.10231478, 0.09947537,\n", " 0.09863353, 0.09610326, 0.10512386, 0.09618994, 0.102385 ],\n", " [ 0.09316271, 0.10587491, 0.10104561, 0.10184887, 0.09979631,\n", " 0.0987707 , 0.0961894 , 0.10498153, 0.09624669, 0.10208327],\n", " [ 0.09267883, 0.10566397, 0.10139152, 0.10239437, 0.09941428,\n", " 0.09858267, 0.09583061, 0.10489618, 0.09643213, 0.10271545],\n", " [ 0.09267466, 0.10571734, 0.10132375, 0.10233848, 0.09957975,\n", " 0.09855079, 0.0960124 , 0.10503947, 0.09635051, 0.10241285],\n", " [ 0.09300987, 0.10560717, 0.10121133, 0.10232379, 0.09972364,\n", " 0.09866175, 0.09583634, 0.10474633, 0.09604491, 0.10283486],\n", " [ 0.09289735, 0.10530253, 0.10133054, 0.10221849, 0.0999344 ,\n", " 0.09865658, 0.09627921, 0.10508753, 0.09605576, 0.10223761],\n", " [ 0.09303492, 0.10530824, 0.10129613, 0.10225694, 0.09924018,\n", " 0.09884628, 0.0959959 , 0.10498526, 0.09654719, 0.10248896],\n", " [ 0.09272553, 0.10553886, 0.10108639, 0.10254428, 0.09964681,\n", " 0.09883252, 0.09611471, 0.10495217, 0.09608657, 0.10247217],\n", " [ 0.09275718, 0.10499128, 0.10124875, 0.10224824, 0.0996036 ,\n", " 0.098677 , 0.09625443, 0.1052669 , 0.09641795, 0.10253467],\n", " [ 0.09288568, 0.10549854, 0.10152235, 0.10226628, 0.09977237,\n", " 0.09836485, 0.09609393, 0.10507111, 0.09606076, 0.10246413],\n", " [ 0.09310356, 0.10577722, 0.10139805, 0.10204105, 0.09974502,\n", " 0.09861955, 0.09632951, 0.10487334, 0.09634328, 0.10176941],\n", " [ 0.09299005, 0.10532383, 0.10134905, 0.10230313, 0.09941665,\n", " 0.09858453, 0.09640385, 0.10505999, 0.09615628, 0.10241265],\n", " [ 0.09282008, 0.10537676, 0.10152507, 0.10240875, 0.09958684,\n", " 0.09842165, 0.09599903, 0.1049121 , 0.09594467, 0.10300506],\n", " [ 0.09267134, 0.10518258, 0.10127418, 0.10212827, 0.09945626,\n", " 0.09896402, 0.09615197, 0.10535166, 0.09627322, 0.1025465 ],\n", " [ 0.09319332, 0.10559174, 0.10133866, 0.10241742, 0.09967932,\n", " 0.09836049, 0.09606928, 0.1051732 , 0.09608128, 0.10209529],\n", " [ 0.09301218, 0.10534202, 0.10163597, 0.1024004 , 0.09948353,\n", " 0.09854612, 0.0964687 , 0.10497611, 0.09586819, 0.10226679],\n", " [ 0.0931067 , 0.1054424 , 0.10124272, 0.10258414, 0.09947893,\n", " 0.09844222, 0.09616333, 0.10511197, 0.0959703 , 0.10245728],\n", " [ 0.0931574 , 0.10575916, 0.1012452 , 0.10219702, 0.09973439,\n", " 0.09844335, 0.09621249, 0.10497318, 0.09615282, 0.102125 ],\n", " [ 0.09288593, 0.10545119, 0.10153591, 0.1023867 , 0.09982341,\n", " 0.09830605, 0.09612848, 0.10498039, 0.0960782 , 0.10242374],\n", " [ 0.09283482, 0.10571069, 0.10116512, 0.10252098, 0.09904207,\n", " 0.09873041, 0.09659869, 0.10502962, 0.09592377, 0.10244382],\n", " [ 0.09292705, 0.10529934, 0.10112663, 0.10187067, 0.09998368,\n", " 0.09878433, 0.09658129, 0.10516566, 0.09596034, 0.10230101],\n", " [ 0.09326483, 0.10514446, 0.10147883, 0.10233054, 0.09955384,\n", " 0.09862022, 0.09608851, 0.1051716 , 0.09612902, 0.10221815],\n", " [ 0.09255552, 0.10541584, 0.10093231, 0.10254851, 0.09945366,\n", " 0.09857041, 0.09630318, 0.1052671 , 0.09628015, 0.10267332],\n", " [ 0.09297027, 0.10564282, 0.10133211, 0.10231576, 0.09953203,\n", " 0.09846255, 0.09633847, 0.10499048, 0.09604892, 0.10236659],\n", " [ 0.09292853, 0.10571311, 0.10161256, 0.10172631, 0.09960786,\n", " 0.09868909, 0.09633054, 0.10498871, 0.09638428, 0.10201901],\n", " [ 0.09285064, 0.10554955, 0.10126082, 0.10240733, 0.09963301,\n", " 0.09885336, 0.09580484, 0.10526589, 0.09595257, 0.10242199]])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "collapsed": true }, "outputs": [], "source": [ "x = np.random.rand(100, 784)\n", "t = np.random.rand(100, 10)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "collapsed": true }, "outputs": [], "source": [ "grads = net.numerical_gradient(x, t)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(784, 100)" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grads['W1'].shape" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(100,)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grads['b1'].shape" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(100, 10)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grads['W2'].shape" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(10,)" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grads['b2'].shape" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from dataset.mnist import load_mnist" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "collapsed": true }, "outputs": [], "source": [ "(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "collapsed": true }, "outputs": [], "source": [ "train_loss_list = []" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "collapsed": true }, "outputs": [], "source": [ "train_acc_list = []\n", "test_acc_list = []" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "collapsed": true }, "outputs": [], "source": [ "iters_num = 10000\n", "batch_size = 100\n", "learning_rate = 0.1" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(60000, 784)" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_train.shape" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "iter_per_epoch = max(x_train.shape[0] / batch_size, 1)" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "collapsed": true }, "outputs": [], "source": [ "network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "ename": "AxisError", "evalue": "axis 1 is out of bounds for array of dimension 1", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAxisError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0miter_per_epoch\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mtrain_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnetwork\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0mtest_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnetwork\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mtrain_acc_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_acc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36maccuracy\u001b[0;34m(self, x, t)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0maccuracy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/opt/conda/lib/python3.6/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36margmax\u001b[0;34m(a, axis, out)\u001b[0m\n\u001b[1;32m 961\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 962\u001b[0m \"\"\"\n\u001b[0;32m--> 963\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_wrapfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'argmax'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 964\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 965\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/opt/conda/lib/python3.6/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36m_wrapfunc\u001b[0;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_wrapfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;31m# An AttributeError occurs if the object does not have\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mAxisError\u001b[0m: axis 1 is out of bounds for array of dimension 1" ] } ], "source": [ "for i in range(iters_num):\n", " batch_mask = np.random.choice(x_train.shape[0], batch_size)\n", " x_batch = x_train[batch_mask]\n", " t_batch = t_train[batch_mask]\n", " \n", " grad = network.numerical_gradient(x_batch, t_batch)\n", " \n", " for key in ('W1', 'b1', 'W2', 'b2'):\n", " network.params[key] -= learning_rate * grad[key]\n", " \n", " loss = network.loss(x_batch, t_batch)\n", " train_loss_list.append(loss)\n", " \n", " if i % iter_per_epoch == 0:\n", " train_acc = network.accuracy(x_train, t_train)\n", " test_acc = network.accuracy(x_test, t_test)\n", " train_acc_list.append(train_acc)\n", " test_acc_list.append(test_acc)\n", " print(\"train acc, test acc | {}, {}\".format(train_acc, test_acc))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.0" } }, "nbformat": 4, "nbformat_minor": 2 }