{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Detect MNIST numbers Using Keras and PlaidML" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using plaidml.keras.backend backend.\n" ] } ], "source": [ "import os\n", "import time\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "\n", "os.environ[\"KERAS_BACKEND\"] = \"plaidml.keras.backend\" # Using PlaidML Backend\n", "\n", "import keras\n", "from keras.datasets import mnist\n", "from keras.models import Sequential, load_model\n", "from keras.layers import Dense, Dropout, Flatten\n", "from keras.layers import Conv2D, MaxPooling2D\n", "from keras import backend as K" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1) Prepare the data\n", "ทำการโหลด Data MNIST จากเว็บไซต์\n", "http://yann.lecun.com/exdb/mnist/ มาไว้ที่ Folder เดียวกันกับ Notebook และทำการแตก gzip" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "image_size = 28\n", "num_images = 60000\n", "\n", "with open('./data/mnist/train-images-idx3-ubyte', 'rb') as fp:\n", " fp.read(16)\n", " buf = fp.read(image_size * image_size * num_images)\n", " data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)\n", " x_train = data.reshape(num_images, image_size, image_size, 1)\n", " \n", "with open('./data/mnist/train-labels-idx1-ubyte', 'rb') as fp:\n", " fp.read(8)\n", " buf = fp.read(num_images)\n", " y_train = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)\n", " \n", "num_images = 10000\n", " \n", "with open('./data/mnist/t10k-images-idx3-ubyte', 'rb') as fp:\n", " fp.read(16)\n", " buf = fp.read(image_size * image_size * num_images)\n", " data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)\n", " x_test = data.reshape(num_images, image_size, image_size, 1)\n", " \n", "with open('./data/mnist/t10k-labels-idx1-ubyte', 'rb') as fp:\n", " fp.read(8)\n", " buf = fp.read(num_images)\n", " y_test = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAQNklEQVR4nO3de5Rd5V3G8e+TyZAbwU64NeQClMYqbW3QMUFxCUgLNDULsAabJTS1SPgDarvkIouqsFzaha2ASAs1NUiglRYFJEpWbZrWRaHcBkwhNNyEQEJCEgi3AAnJ5OcfZ8cO6Zz3zJx9bpn3+aw1a87Zv7P3/s2ZeWafc959zquIwMxGvlHtbsDMWsNhN8uEw26WCYfdLBMOu1kmHHazTDjsGZP035L+uNXrWns47COApDWSPtruPoZC0g8khaTR7e4lNw67tYykPwQc8jZx2EcwST2S/lPSZkmvFJen7nGzIyQ9IOk1SXdImjRg/aMl/VjSq5J+Ium4Er38AnApcFG927ByHPaRbRTwz8ChwHTgbeCre9zm08BngUOAncA/AEiaAtwJ/DUwCbgAuFXSgXvuRNL04h/C9EQvXwKuA14s8wNZ/Rz2ESwiXo6IWyPirYh4A/gb4Ng9bnZTRKyKiDeBvwBOl9QFnAEsi4hlEbErIpYDfcCcQfbzfES8JyKeH6wPSb3AMcA1DfzxbJj8/GkEkzQeuAo4GegpFk+U1BUR/cX1tQNWeQ7oBg6g8mhgnqS5A+rdwA+H2cMo4Frg8xGxU9LwfxBrCId9ZDsf+AAwOyJelDQT+B9gYOKmDbg8HdgBvETln8BNEXF2yR72A3qB7xRB7yqWr5M0LyJ+VHL7NkQO+8jRLWnsgOs7gYlUnqe/Wrzwdukg650h6UZgDfBXwL9FRL+kbwIPSjoJ+D6Vo/rRwNMRsW4Yfb1G5fWA3aYBDwC/BmwexnasJD9nHzmWUQn27q/LgL8HxlE5Ut8HfHeQ9W4CbqDywtlY4E8AImItcApwCZVQrgUuZJC/meIFuq2DvUAXFS/u/uJnAd8YEe/U+8Pa8MkfXmGWBx/ZzTLhsJtlwmE3y4TDbpaJlg697aMxMZYJrdylWVa28SbvxPZBz1wqFXZJJwNXUzlR4p8i4vLU7ccygdk6ocwuzSzh/lhRtVb3w/ji/OmvAR8HjgTmSzqy3u2ZWXOVec4+i8rZVM8UJ0d8m8pJGGbWgcqEfQrvfhPFumLZu0haKKlPUt8OtpfYnZmVUSbsg70I8HOn40XEoojojYjebsaU2J2ZlVEm7Ot49zumpgLry7VjZs1SJuwPAjMkHS5pH+BTwNLGtGVmjVb30FvxQQTnAf9FZejt+oh4rGGdmVlDlRpnj4hlVN5aaWYdzqfLmmXCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJkrN4mo2esohyfqZP7i3am3evi8n1537weOT9f5XX0vW7d1KhV3SGuANoB/YGRG9jWjKzBqvEUf24yPipQZsx8yayM/ZzTJRNuwBfE/SQ5IWDnYDSQsl9Unq28H2krszs3qVfRh/TESsl3QQsFzS4xFx18AbRMQiYBHAfpoUJfdnZnUqdWSPiPXF903A7cCsRjRlZo1Xd9glTZA0cfdl4ERgVaMaM7PGKvMw/mDgdkm7t/MvEfHdhnRle42f/uWUZP2T+1YfqFn+9vjkurFjZ1092eDqDntEPAN8pIG9mFkTeejNLBMOu1kmHHazTDjsZplw2M0y4be4WtKT16bPk3r6d7+erO9K1L50wYLkuuPefCBZt+Hxkd0sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TH2Uc4de+TrD9xzcxk/fG5X6uxh65hdvQzYzf5Y8payUd2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTHmcfATS6+q+x1jj6k3Ovq7F1Hw9GCv8mzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMeJx9L9B15C8m66sv2K9q7cmTao2jN9c927qr1ka//GZy3f5GN5O5mkd2SddL2iRp1YBlkyQtl/RU8b2nuW2aWVlDeRh/A3DyHssuBlZExAxgRXHdzDpYzbBHxF3Alj0WnwIsKS4vAU5tcF9m1mD1vkB3cERsACi+H1TthpIWSuqT1LcDf+aYWbs0/dX4iFgUEb0R0dvNmGbvzsyqqDfsGyVNBii+b2pcS2bWDPWGfSmwe77dBcAdjWnHzJql5ji7pJuB44ADJK0DLgUuB26RdBbwPDCvmU2OdFv+6DeS9bMvSv8vvWO/tVVrr+3allx39r+en6xfMfebyfrc8a8n69duOL5qrf+Jp5PrWmPVDHtEzK9SOqHBvZhZE/l0WbNMOOxmmXDYzTLhsJtlwmE3y4Tf4toC6y/6zWT9ws/ekqzPn7gxWU8Nr51x2jnJdd/fd1+yvu0T1d+iOhTPvrp/1dokXi61bRseH9nNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0x4nH2IRk+bWrW2+XemJddd8bmvJOs9o8Ym62eu+Viyvv4r769aG9f3QHLd7XN+PVk/dtzdyTqMT1ZH3VJ9nN1ay0d2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTHmcvdB14YLJ+6G0vVa39+yG1PjY/PY7+4Xs+k6wfcf4ryfq4temx9JS1H+1K1g/oGlf3tvdmW+fNTta3HJm+3w6/9qlkvX/z5mH3VJaP7GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJjzOXpj8H28n69cc8uNEVcl1N/W/laz33D4hWd/VsyNZH9XzS1VremFTct3Zs59Ib7vGz7b87fQ4/KSViXMEfqV630Px9tSJyfpzvx9Va5/48KPJdb865R+T9XU7tybrn3z2wmT9PTd24Di7pOslbZK0asCyyyS9IGll8TWnuW2aWVlDeRh/A3DyIMuvioiZxdeyxrZlZo1WM+wRcRewpQW9mFkTlXmB7jxJjxQP83uq3UjSQkl9kvp2sL3E7sysjHrDfh1wBDAT2ABcUe2GEbEoInojorebMXXuzszKqivsEbExIvojYhfwDWBWY9sys0arK+ySJg+4ehqwqtptzawzKKL6WCSApJuB44ADgI3ApcX1mUAAa4BzImJDrZ3tp0kxWyeUarhZ1n4xPYf67837UdXapQeubHQ7DfO3L38wWT9xYnq8+ah90seDWuPwu0j/fXWqr7/6vmR9yZXp0eb9F9/byHaG7P5YweuxZdBfSs2TaiJi/iCLF5fuysxayqfLmmXCYTfLhMNulgmH3SwTDrtZJmoOvTVSJw+91dLVU/WMYF78g/RbNbcem36L658fdWeyPn/ixmS9nZo59PZSf/ptx8fec27d2566OD0QNW7VumR954ud+TtJDb35yG6WCYfdLB
MOu1kmHHazTDjsZplw2M0y4bCbZcLj7B1g1PjxybpGp8eEn/3TD1WtbZv+TnLdJ09Kf2RyLUvfrH7+AcDi3pmltp/S//rrTdv23srj7GbmsJvlwmE3y4TDbpYJh90sEw67WSYcdrNMeMrmDrDrrfT73WuZfln16aS7Dj4oue41s2Yk65/reSpZ3xFdybrHwjuHj+xmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSZqjrNLmgbcCLwX2AUsioirJU0CvgMcRmXa5tMj4pXmtWr16N+4KVl/4LXD0huoMc5+yfdPT9ZncH96+9YyQzmy7wTOj4hfBo4GzpV0JHAxsCIiZgAriutm1qFqhj0iNkTEw8XlN4DVwBTgFGBJcbMlwKnNatLMyhvWc3ZJhwFHAfcDB0fEBqj8QwDS52WaWVsNOeyS9gVuBb4QEUM+4VnSQkl9kvp2sL2eHs2sAYYUdkndVIL+rYi4rVi8UdLkoj4ZGPSVoIhYFBG9EdHbzZhG9GxmdagZdkkCFgOrI+LKAaWlwILi8gLgjsa3Z2aNMpS3uB4DnAk8KmllsewS4HLgFklnAc8D85rTopUx+tBpyfqc/e8ttf3pd+4qtb61Ts2wR8TdUHUSbn8IvNlewmfQmWXCYTfLhMNulgmH3SwTDrtZJhx2s0z4o6RHuB1T90/W50/cWGr7Y5Y9WGp9ax0f2c0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTPj97CNc10OPJ+sfWHF2sn7mRzzl8kjhI7tZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNuloma4+ySpgE3Au8FdgGLIuJqSZcBZwObi5teEhHLmtWo1WfXtm3J+oxPP5ys30d3I9uxNhrKSTU7gfMj4mFJE4GHJC0valdFxN81rz0za5SaYY+IDcCG4vIbklYDU5rdmJk11rCes0s6DDgK2H0O5XmSHpF0vaSeKusslNQnqW8H20s1a2b1G3LYJe0L3Ap8ISJeB64DjgBmUjnyXzHYehGxKCJ6I6K3mzENaNnM6jGksEvqphL0b0XEbQARsTEi+iNiF/ANYFbz2jSzsmqGXZKAxcDqiLhywPLJA252GrCq8e2ZWaMM5dX4Y4AzgUclrSyWXQLMlzQTCGANcE5TOjSzhhjKq/F3Axqk5DF1s72Iz6Azy4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmVBEtG5n0mbguQGLDgBealkDw9OpvXVqX+De6tXI3g6NiAMHK7Q07D+3c6kvInrb1kBCp/bWqX2Be6tXq3rzw3izTDjsZplod9gXtXn/KZ3aW6f2Be6tXi3pra3P2c2sddp9ZDezFnHYzTLRlrBLOlnSE5KelnRxO3qoRtIaSY9KWimpr829XC9pk6RVA5ZNkrRc0lPF90Hn2GtTb5dJeqG471ZKmtOm3qZJ+qGk1ZIek/T5Ynlb77tEXy2531r+nF1SF/Ak8DFgHfAgMD8iftrSRqqQtAbojYi2n4Ah6beBrcCNEfGhYtmXgS0RcXnxj7InIv6sQ3q7DNja7mm8i9mKJg+cZhw4FfgMbbzvEn2dTgvut3Yc2WcBT0fEMxHxDvBt4JQ29NHxIuIuYMsei08BlhSXl1D5Y2m5Kr11hIjYEBEPF5ffAHZPM97W+y7RV0u0I+xTgLUDrq+js+Z7D+B7kh6StLDdzQzi4IjYAJU/HuCgNvezp5rTeLfSHtOMd8x9V8/052W1I+yDTSXVSeN/x0TErwIfB84tHq7a0AxpGu9WGWSa8Y5Q7/TnZbUj7OuAaQOuTwXWt6GPQUXE+uL7JuB2Om8q6o27Z9Atvm9qcz//r5Om8R5smnE64L5r5/Tn7Qj7g8AMSYdL2gf4FLC0DX38HEkTihdOkDQBOJHOm4p6KbCguLwAuKONvbxLp0zjXW2acdp837V9+vOIaPkXMIfKK/L/C3yxHT1U6et9wE+Kr8fa3RtwM5WHdTuoPCI6C9gfWAE8VXyf1EG93QQ8CjxCJViT29Tbb1F5avgIsLL4mtPu+y7RV0vuN58ua5YJn0FnlgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2Xi/wAvJuOMRRR42gAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "n = 150\n", "plt.title(f'Label: {y_train[n]}')\n", "plt.imshow(x_train[n].squeeze())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2) Train the model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_train shape: (60000, 28, 28, 1)\n", "60000 train samples\n", "10000 test samples\n", "Train on 60000 samples, validate on 10000 samples\n", "Epoch 1/20\n", "60000/60000 [==============================] - 16s 268us/step - loss: 0.2619 - acc: 0.9199 - val_loss: 0.0571 - val_acc: 0.9817\n", "Epoch 2/20\n", "60000/60000 [==============================] - 12s 203us/step - loss: 0.0910 - acc: 0.9731 - val_loss: 0.0431 - val_acc: 0.9852\n", "Epoch 3/20\n", "60000/60000 [==============================] - 12s 205us/step - loss: 0.0670 - acc: 0.9800 - val_loss: 0.0364 - val_acc: 0.9870\n", "Epoch 4/20\n", "60000/60000 [==============================] - 12s 200us/step - loss: 0.0529 - acc: 0.9839 - val_loss: 0.0332 - val_acc: 0.9894\n", "Epoch 5/20\n", "60000/60000 [==============================] - 12s 205us/step - loss: 0.0481 - acc: 0.9852 - val_loss: 0.0364 - val_acc: 0.9887\n", "Epoch 6/20\n", "60000/60000 [==============================] - 12s 204us/step - loss: 0.0419 - acc: 0.9873 - val_loss: 0.0295 - val_acc: 0.9905\n", "Epoch 7/20\n", "60000/60000 [==============================] - 12s 202us/step - loss: 0.0394 - acc: 0.9878 - val_loss: 0.0289 - val_acc: 0.9902\n", "Epoch 8/20\n", "60000/60000 [==============================] - 12s 202us/step - loss: 0.0347 - acc: 0.9895 - val_loss: 0.0290 - val_acc: 0.9904\n", "Epoch 9/20\n", "60000/60000 [==============================] - 12s 201us/step - loss: 0.0322 - acc: 0.9896 - val_loss: 0.0288 - val_acc: 0.9904\n", "Epoch 10/20\n", "60000/60000 [==============================] - 12s 202us/step - loss: 0.0282 - acc: 0.9913 - val_loss: 0.0300 - val_acc: 0.9919\n", "Epoch 11/20\n", "60000/60000 [==============================] - 12s 204us/step - loss: 0.0276 - acc: 0.9917 - val_loss: 0.0300 - val_acc: 0.9911\n", "Epoch 12/20\n", "60000/60000 [==============================] - 12s 202us/step - loss: 0.0252 - acc: 0.9918 - val_loss: 0.0260 - val_acc: 0.9921\n", "Epoch 13/20\n", "60000/60000 [==============================] - 12s 201us/step - loss: 0.0235 - acc: 0.9925 - val_loss: 0.0294 - val_acc: 0.9905\n", "Epoch 14/20\n", "60000/60000 [==============================] - 12s 205us/step - loss: 0.0225 - acc: 0.9926 - val_loss: 0.0348 - val_acc: 0.9899\n", "Epoch 15/20\n", "60000/60000 [==============================] - 12s 204us/step - loss: 0.0220 - acc: 0.9927 - val_loss: 0.0264 - val_acc: 0.9922\n", "Epoch 16/20\n", "60000/60000 [==============================] - 12s 205us/step - loss: 0.0211 - acc: 0.9932 - val_loss: 0.0337 - val_acc: 0.9908\n", "Epoch 17/20\n", "60000/60000 [==============================] - 12s 198us/step - loss: 0.0213 - acc: 0.9936 - val_loss: 0.0289 - val_acc: 0.9917\n", "Epoch 18/20\n", "60000/60000 [==============================] - 12s 201us/step - loss: 0.0222 - acc: 0.9926 - val_loss: 0.0275 - val_acc: 0.9917\n", "Epoch 19/20\n", "60000/60000 [==============================] - 12s 203us/step - loss: 0.0197 - acc: 0.9940 - val_loss: 0.0291 - val_acc: 0.9912\n", "Epoch 20/20\n", "60000/60000 [==============================] - 12s 201us/step - loss: 0.0179 - acc: 0.9939 - val_loss: 0.0334 - val_acc: 0.9907\n", "Test loss: 
0.03335125161111355\n", "Test accuracy: 0.9907\n" ] } ], "source": [ "batch_size = 128\n", "num_classes = 10\n", "epochs = 20\n", "\n", "img_rows, img_cols = image_size, image_size\n", "\n", "# Reshape to (samples, rows, cols, channels) and define the model's input shape\n", "x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n", "x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n", "input_shape = (img_rows, img_cols, 1)\n", "\n", "# Normalize pixel values from [0, 255] to [0, 1]\n", "x_train = x_train.astype('float32')\n", "x_test = x_test.astype('float32')\n", "x_train /= 255\n", "x_test /= 255\n", "print('x_train shape:', x_train.shape)\n", "print(x_train.shape[0], 'train samples')\n", "print(x_test.shape[0], 'test samples')\n", "\n", "# Convert class vectors to one-hot (binary class) matrices\n", "y_train = keras.utils.to_categorical(y_train, num_classes)\n", "y_test = keras.utils.to_categorical(y_test, num_classes)\n", "\n", "# Build the CNN\n", "model = Sequential()\n", "model.add(Conv2D(32, kernel_size=(3, 3),\n", "                 activation='relu',\n", "                 input_shape=input_shape))\n", "model.add(Conv2D(64, (3, 3), activation='relu'))\n", "model.add(MaxPooling2D(pool_size=(2, 2)))\n", "model.add(Dropout(0.25))\n", "model.add(Flatten())\n", "model.add(Dense(128, activation='relu'))\n", "model.add(Dropout(0.5))\n", "model.add(Dense(num_classes, activation='softmax'))\n", "\n", "model.compile(loss=keras.losses.categorical_crossentropy,\n", "              optimizer=keras.optimizers.Adadelta(),\n", "              metrics=['accuracy'])\n", "\n", "model.fit(x_train, y_train,\n", "          batch_size=batch_size,\n", "          epochs=epochs,\n", "          verbose=1,\n", "          validation_data=(x_test, y_test))\n", "score = model.evaluate(x_test, y_test, verbose=0)\n", "print('Test loss:', score[0])\n", "print('Test accuracy:', score[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3) Predict an image" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Predict the class of a single test image (batch of one)\n", "np.argmax(model.predict(x_test[6].reshape(1, 28, 28, 1)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4) Save the model for later usage" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "model.save('mnist_cnn.h5')" ] },
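 { "cell_type": "markdown", "metadata": {}, "source": [ "## 5) Load the saved model again (sketch)\n", "A minimal sketch of the \"later usage\": reload `mnist_cnn.h5` with `load_model` (imported at the top of this notebook) and repeat the prediction from section 3. The variable name `restored` is only illustrative." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch: reload the model saved above and predict the same test digit\n", "restored = load_model('mnist_cnn.h5')\n", "np.argmax(restored.predict(x_test[6].reshape(1, 28, 28, 1)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }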