{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 402 RNN\n", "\n", "View more, visit my tutorial page: https://morvanzhou.github.io/tutorials/\n", "My Youtube Channel: https://www.youtube.com/user/MorvanZhou\n", "\n", "Dependencies:\n", "* torch: 0.1.11\n", "* matplotlib\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.autograd import Variable\n", "import torchvision.datasets as dsets\n", "import torchvision.transforms as transforms\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.manual_seed(1) # reproducible" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Hyper Parameters\n", "EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch\n", "BATCH_SIZE = 64\n", "TIME_STEP = 28 # rnn time step / image height\n", "INPUT_SIZE = 28 # rnn input size / image width\n", "LR = 0.01 # learning rate\n", "DOWNLOAD_MNIST = True # set to True if haven't download the data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Mnist digital dataset\n", "train_data = dsets.MNIST(\n", " root='./mnist/',\n", " train=True, # this is training data\n", " transform=transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to\n", " # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]\n", " download=DOWNLOAD_MNIST, # download it if you don't have it\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([60000, 28, 28])\n", "torch.Size([60000])\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADrpJREFUeJzt3X2sVHV+x/HPp6hpxAekpkhYLYsxGDWWbRAbQ1aNYX2I\nRlFjltSERiP7hyRu0pAa+sdqWqypD81SzQY26kKzdd1EjehufKiobGtCvCIq4qKu0SzkCjWIAj5Q\nuN/+cYftXb3zm8vMmTnD/b5fyeTOnO+cOd+c8OE8zvwcEQKQz5/U3QCAehB+ICnCDyRF+IGkCD+Q\nFOEHkiL8QFKEH6Oy/aLtL23vaTy21N0TqkX4UbI4Io5pPGbW3QyqRfiBpAg/Sv7Z9se2/9v2BXU3\ng2qZe/sxGtvnStosaZ+k70u6T9KsiPhdrY2hMoQfY2L7aUm/ioh/q7sXVIPdfoxVSHLdTaA6hB/f\nYHuS7Ytt/6ntI2z/jaTvSnq67t5QnSPqbgB96UhJ/yTpdEkHJP1W0lUR8U6tXaFSHPMDSbHbDyRF\n+IGkCD+QFOEHkurp2X7bnF0EuiwixnQ/RkdbftuX2N5i+z3bt3byWQB6q+1LfbYnSHpH0jxJWyW9\nImlBRGwuzMOWH+iyXmz550h6LyLej4h9kn4h6coOPg9AD3US/mmSfj/i9dbGtD9ie5HtAdsDHSwL\nQMW6fsIvIlZKWimx2w/0k062/NsknTzi9bca0wAcBjoJ/yuSTrP9bdtHafgHH9ZU0xaAbmt7tz8i\n9tteLOkZSRMkPRgRb1XWGYCu6um3+jjmB7qvJzf5ADh8EX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQf\nSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKE\nH0iK8ANJEX4gKcIPJEX4gaQIP5BU20N04/AwYcKEYv3444/v6vIXL17ctHb00UcX5505c2axfvPN\nNxfrd999d9PaggULivN++eWXxfqdd95ZrN9+++3Fej/oKPy2P5C0W9IBSfsjYnYVTQHoviq2/BdG\nxMcVfA6AHuKYH0iq0/CHpGdtv2p70WhvsL3I9oDtgQ6XBaBCne72z42Ibbb/XNJztn8bEetGviEi\nVkpaKUm2o8PlAahIR1v+iNjW+LtD0uOS5lTRFIDuazv8tifaPvbgc0nfk7SpqsYAdFcnu/1TJD1u\n++Dn/EdEPF1JV+PMKaecUqwfddRRxfp5551XrM+dO7dpbdKkScV5r7nmmmK9Tlu3bi3Wly9fXqzP\nnz+/aW337t3FeV9//fVi/aWXXirWDwdthz8i3pf0lxX2AqCHuNQHJEX4gaQIP5AU4QeSIvxAUo7o\n3U134/UOv1mzZhXra9euLda7/bXafjU0NFSs33DDDcX6nj172l724OBgsf7JJ58U61u2bGl72d0W\nER7L+9jyA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBSXOevwOTJk4v19evXF+szZsyosp1Ktep9165d\nxfqFF17YtLZv377ivFnvf+gU1/kBFBF+ICnCDyRF+IGkCD+QFOEHkiL8QFIM0V2BnTt3FutLliwp\n1i+//PJi/bXXXivWW/2EdcnGjRuL9Xnz5hXre/fuLdbPPPPMprVbbrmlOC+6iy0/kBThB5Ii/EBS\nhB9IivADSRF+ICnCDyTF9/n7wHHHHVestxpOesWKFU1rN954Y3He66+/vlh/+OGHi3X0n8q+z2/7\nQds7bG8aMW2y7edsv9v4e0InzQLovbHs9v9M0iVfm3arpOcj4jRJzzdeAziMtAx/RKyT9PX7V6+U\ntKrxfJWkqyruC0CXtXtv/5SIODjY2UeSpjR7o+1Fkha1uRwAXdLxF3siIkon8iJipaSVEif8gH7S\n7qW+7banSlLj747qWgLQC+2Gf42khY3nCyU9UU07AHql5W6/7YclXSDpRNtbJf1I0p2Sfmn7Rkkf\nSrqum02Od5999llH83/66adtz3vTTTcV64888kixPjQ01PayUa+W4Y+IBU1KF1XcC4Ae4vZeICnC\nDyRF+IGkCD+QFOEHkuIrvePAxIkTm9aefPLJ4rznn39+sX7ppZcW688++2yxjt5jiG4ARYQfSIrw\nA0kRfiApwg8kRfiBpAg/kBTX+ce5U089tVjfsGFDsb5r165i/YUXXijWBwYGmtbuv//+4ry9/Lc5\nnnCdH0AR4QeSIvxAUoQfSIrwA0kRfiApwg8kxXX+5ObPn1+sP/TQQ8X6scce2/ayly5dWqyvXr26\nWB8cHCzWs+I6P4Aiwg8kRfiBpAg/kBThB5Ii/EBShB9Iiuv8KDrrrLOK9XvvvbdYv+ii9gdzXrFi\nRbG+bNmyYn3btm1tL/twVtl1ftsP2t5he9OIabfZ3mZ7Y+NxWSfNAui9sez2/0zSJaNM/9eImNV4\n/LratgB0W8vwR8Q6STt70AuAHurkhN9i2280DgtOaPYm24tsD9hu/mNuAHqu3fD/RNKpkmZJGpR0\nT7M3RsTKiJgdEbPbXBaALmgr/BGxPSIORMSQpJ9KmlNtWwC6ra3w25464uV8SZuavRdAf2p5nd/2\nw5IukHSipO2SftR4PUtSSPpA0g8iouWXq7nOP/5MmjSpWL/iiiua1lr9VoBdvly9du3aYn3evHnF\n+ng11uv8R4zhgxaMMvmBQ+4IQF/h9l4gKcIPJEX4gaQIP5AU4QeS4iu9qM1XX31VrB9xRPli1P79\n+4v1iy++uGntxRdfLM57OOOnuwEUEX4gKcIPJEX4gaQIP5AU4QeSIvxAUi2/1Yfczj777GL92muv\nLdbPOeecprVW1/Fb2bx5c7G+bt26jj5/vGPLDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcZ1/nJs5\nc2axvnjx4mL96quvLtZPOumkQ+5prA4cOFCsDw6Wfy1+aGioynbGHbb8QFKEH0iK8ANJEX4gKcIP\nJEX4gaQIP5BUy+v8tk+WtFrSFA0Pyb0yIn5se7KkRyRN1/Aw3ddFxCfdazWvVtfSFywYbSDlYa2u\n40+fPr2dlioxMDBQrC9btqxYX7NmTZXtpDOWLf9+SX8XEWdI+mtJN9s+Q9Ktkp6PiNMkPd94DeAw\n0TL8ETEYERsaz3dLelvSNElXSlrVeNsqSVd1q0kA1TukY37b0yV9R9J6SVMi4uD9lR9p+LAAwGFi\nzPf22z5G0qOSfhgRn9n/PxxYRESzcfhsL5K0qNNGAVRrTFt+20dqOPg/j4jHGpO3257aqE+VtGO0\neSNiZUTMjojZVTQMoBotw+/hTfwDkt6OiHtHlNZIWth4vlDSE9W3B6BbWg7RbXuupN9IelPSwe9I\nLtXwcf8vJZ0i6UMNX+rb2eKzUg7RPWVK+XTIGWecUazfd999xfrpp59+yD1VZf369cX6XXfd1bT2\nxBPl7QVfyW3PWIfobnnMHxH/JanZh110KE0B6B/c4QckRfiBpAg/kBThB5Ii/EBShB9Iip/uHqPJ\nkyc3ra1YsaI476xZs4r1GTNmtNVTFV5++eVi/Z577inWn3nmmWL9iy++OOSe0Bts+YGkCD+QFOEH\nkiL8QFKEH0iK8ANJEX4gqTTX+c8999xifcmSJcX6nDlzmtamTZvWVk9V+fzzz5vWli9fXpz3jjvu\nKNb37t3bVk/of2z5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpNNf558+f31G9E5s3by7Wn3rqqWJ9\n//79xXrpO/e7du0qzou82PIDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKOiPIb7JMlrZY0RVJIWhkR\nP7Z9m6SbJP1P461LI+LXLT6rvDAAHYsIj+V9Ywn/VElTI2KD7WMlvSrpKknXSdoTEXePtSnCD3Tf\nWMPf8g6/iBiUNNh4vtv225Lq/ekaAB07pGN+29MlfUfS+sakxbbfsP2g7ROazLPI9oDtgY46BVCp\nlrv9f3ijfYyklyQti4jHbE+R9LGGzwP8o4YPDW5o8Rns9gNdVtkxvyTZPlLSU5KeiYh7R6lPl/RU\nRJzV4nMIP9BlYw1/y91+25b0gKS3Rwa/cSLwoPmSNh1qkwDqM5az/XMl/UbSm5KGGpOXSlogaZaG\nd/s/kPSDxsnB0mex5Qe6rNLd/qoQfqD7KtvtBzA+EX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrw\nA0kRfiApwg8kRfiBpAg/kBThB5Lq9RDdH0v6cMTrExvT+lG/9tavfUn01q4qe/uLsb6xp9/n/8bC\n7YGImF1bAwX92lu/9iXRW7vq6o3dfiApwg8kVXf4V9a8/JJ+7a1f+5LorV219FbrMT+A+tS95QdQ\nE8IPJFVL+G1fYnuL7fds31pHD83Y/sD2m7Y31j2+YGMMxB22N42YNtn2c7bfbfwddYzEmnq7zfa2\nxrrbaPuymno72fYLtjfbfsv2LY3pta67Ql+1rLeeH/PbniDpHUnzJG2V9IqkBRGxuaeNNGH7A0mz\nI6L2G0Jsf1fSHkmrDw6FZvtfJO2MiDsb/3GeEBF/3ye93aZDHLa9S701G1b+b1XjuqtyuPsq1LHl\nnyPpvYh4PyL2SfqFpCtr6KPvRcQ6STu/NvlKSasaz1dp+B9PzzXprS9ExGBEbGg83y3p4LDyta67\nQl+1qCP80yT9fsTrrapxBYwiJD1r+1Xbi+puZhRTRgyL9pGkKXU2M4qWw7b30teGle+bddfOcPdV\n44TfN82NiL+SdKmkmxu7t30pho/Z+ula7U8knarhMRwHJd1TZzONYeUflfTDiPhsZK3OdTdKX7Ws\ntzrCv03SySNef6sxrS9ExLbG3x2SHtfwYUo/2X5whOTG3x019/MHEbE9Ig5ExJCkn6rGddcYVv5R\nST+PiMcak2tfd6P1Vdd6qyP8r0g6zfa3bR8l6fuS1tTQxzfYntg4ESPbEyV9T/039PgaSQsbzxdK\neqLGXv5Ivwzb3mxYedW87vpuuPuI6PlD0mUaPuP/O0n/UEcPTfqaIen1xuOtunuT9LCGdwP/V8Pn\nRm6U9GeSnpf0rqT/lDS5j3r7dw0P5f6GhoM2tabe5mp4l/4NSRsbj8vqXneFvmpZb9zeCyTFCT8g\nKcIPJEX4gaQIP5AU4QeSIvxAUoQfSOr/AH6evjIXWuv8AAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# plot one example\n", "print(train_data.train_data.size()) # (60000, 28, 28)\n", "print(train_data.train_labels.size()) # (60000)\n", "plt.imshow(train_data.train_data[0].numpy(), cmap='gray')\n", "plt.title('%i' % train_data.train_labels[0])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Data Loader for easy mini-batch return in training\n", "train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# convert test data into Variable, pick 2000 samples to speed up testing\n", "test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())\n", "test_x = Variable(test_data.test_data, volatile=True).type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)\n", "test_y = test_data.test_labels.numpy().squeeze()[:2000] # covert to numpy array" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class RNN(nn.Module):\n", " def __init__(self):\n", " super(RNN, self).__init__()\n", "\n", " self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns\n", " input_size=INPUT_SIZE,\n", " hidden_size=64, # rnn hidden unit\n", " num_layers=1, # number of rnn layer\n", " batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)\n", " )\n", "\n", " self.out = nn.Linear(64, 10)\n", "\n", " def forward(self, x):\n", " # x shape (batch, time_step, input_size)\n", " # r_out shape (batch, time_step, output_size)\n", " # h_n shape (n_layers, batch, hidden_size)\n", " # h_c shape (n_layers, batch, hidden_size)\n", " r_out, (h_n, h_c) = self.rnn(x, None) # None represents zero initial hidden state\n", "\n", " # choose r_out at the last time step\n", " out = self.out(r_out[:, -1, :])\n", " return out" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNN (\n", " (rnn): LSTM(28, 64, batch_first=True)\n", " (out): Linear (64 -> 10)\n", ")\n" ] } ], "source": [ "rnn = RNN()\n", "print(rnn)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters\n", "loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0 | train loss: 2.3127 | test accuracy: 0.14\n", "Epoch: 0 | train loss: 1.1182 | test accuracy: 0.56\n", "Epoch: 0 | train loss: 0.9374 | test accuracy: 0.68\n", "Epoch: 0 | train loss: 0.6279 | test accuracy: 0.75\n", "Epoch: 0 | train loss: 0.8004 | test accuracy: 0.75\n", "Epoch: 0 | train loss: 0.3407 | test accuracy: 0.88\n", "Epoch: 0 | train loss: 0.3342 | test accuracy: 0.89\n", "Epoch: 0 | train loss: 0.3200 | test accuracy: 0.91\n", "Epoch: 0 | train loss: 0.3430 | test accuracy: 0.92\n", "Epoch: 0 | train loss: 0.1234 | test accuracy: 0.93\n", "Epoch: 0 | train loss: 0.2714 | test accuracy: 0.92\n", "Epoch: 0 | train loss: 0.1043 | test accuracy: 0.93\n", "Epoch: 0 | train loss: 0.2893 | test accuracy: 0.93\n", "Epoch: 0 | train loss: 0.0635 | test accuracy: 0.94\n", "Epoch: 0 | train loss: 0.2412 | test accuracy: 0.94\n", "Epoch: 0 | train loss: 0.1203 | test accuracy: 0.92\n", "Epoch: 0 | train loss: 0.0807 | test accuracy: 0.94\n", "Epoch: 0 | train loss: 0.1307 | test accuracy: 0.94\n", "Epoch: 0 | train loss: 0.1166 | test accuracy: 0.95\n" ] } ], "source": [ "# training and testing\n", "for epoch in range(EPOCH):\n", " for step, (x, y) in enumerate(train_loader): # gives batch data\n", " b_x = Variable(x.view(-1, 28, 28)) # reshape x to (batch, time_step, input_size)\n", " b_y = Variable(y) # batch y\n", "\n", " output = rnn(b_x) # rnn output\n", " loss = loss_func(output, b_y) # cross entropy loss\n", " optimizer.zero_grad() # clear gradients for this training step\n", " loss.backward() # backpropagation, compute gradients\n", " optimizer.step() # apply gradients\n", "\n", " if step % 50 == 0:\n", " test_output = rnn(test_x) # (samples, time_step, input_size)\n", " pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()\n", " accuracy = sum(pred_y == test_y) / float(test_y.size)\n", " print('Epoch: ', epoch, '| train loss: %.4f' % loss.data[0], '| test accuracy: %.2f' % accuracy)\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[7 2 1 0 4 1 4 9 5 9] prediction number\n", "[7 2 1 0 4 1 4 9 5 9] real number\n" ] } ], "source": [ "# print 10 predictions from test data\n", "test_output = rnn(test_x[:10].view(-1, 28, 28))\n", "pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()\n", "print(pred_y, 'prediction number')\n", "print(test_y[:10], 'real number')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }