{ "cells": [ { "cell_type": "markdown", "metadata": { "_cell_guid": "84d4608d-4cc3-fcbb-57fb-61f07ad7d020", "_uuid": "6407080d145a62b4803d7f159c00118a056a7b5f" }, "source": [ "# Deep Neural Network training on MNIST\n", "\n", "This notebook is based on this [Kaggle project](https://www.kaggle.com/kernels/scriptcontent/4482867/download), adapted to fit into Cloudera [AI to EDGE demo](https://github.com/paulvid/ai_to_edge)" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "654456b6-e648-0379-0d66-1cc97af6d00d", "_uuid": "6b48ce0e361bdb67689dd2f254ecedd9ade1f5ff" }, "source": [ "**Import all required libraries**\n", "===============================" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "_cell_guid": "e5b02688-c589-5a89-e11c-837c6a99eb6e", "_uuid": "f043e48097bfd98e41710142dd8aac41fa88a801" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "test.csv\n", "test.csv.gz\n", "train.csv.gz\n", "\n" ] } ], "source": [ "# This Python 3 environment comes with many helpful analytics libraries installed\n", "# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python\n", "# For example, here's several helpful packages to load in \n", "\n", "import numpy as np # linear algebra\n", "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n", "\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "from keras.models import Sequential\n", "from keras.layers import Dense , Dropout , Lambda, Flatten\n", "from keras.optimizers import Adam ,RMSprop\n", "from sklearn.model_selection import train_test_split\n", "from keras import backend as K\n", "from keras.preprocessing.image import ImageDataGenerator\n", "\n", "\n", "import gzip\n", "import shutil\n", "\n", "# Input data files are available in the \"./input/\" directory.\n", "# For example, running this (by clicking run or pressing Shift+Enter) will list the files in the input directory\n", "\n", "\n", "from subprocess import check_output\n", "print(check_output([\"ls\", \"./input\"]).decode(\"utf8\"))\n", "\n", "# adding Open Neural Network Exchange (ONNX) \n", "import onnxruntime" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "22a7fd70-ab61-432d-24cb-93e558414495", "_uuid": "62fbd0fe9c338b7ac0b04e688c8ee7947e6170f7" }, "source": [ "**Load Train and Test data**\n", "============================" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# unzipping files\n", " \n", "with gzip.open('./input/train.csv.gz', 'rb') as f_in:\n", " with open('./input/train.csv', 'wb') as f_out:\n", " shutil.copyfileobj(f_in, f_out)\n", " \n", "with gzip.open('./input/test.csv.gz', 'rb') as f_in:\n", " with open('./input/test.csv', 'wb') as f_out:\n", " shutil.copyfileobj(f_in, f_out)\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "_cell_guid": "05226b08-226a-1a00-044d-a0e6b2101388", "_uuid": "4eff577bcd43479a3b7e91180393cbad9fcfca33" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(42000, 785)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
labelpixel0pixel1pixel2pixel3pixel4pixel5pixel6pixel7pixel8...pixel774pixel775pixel776pixel777pixel778pixel779pixel780pixel781pixel782pixel783
01000000000...0000000000
10000000000...0000000000
21000000000...0000000000
34000000000...0000000000
40000000000...0000000000
\n", "

5 rows × 785 columns

\n", "
" ], "text/plain": [ " label pixel0 pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 \\\n", "0 1 0 0 0 0 0 0 0 0 \n", "1 0 0 0 0 0 0 0 0 0 \n", "2 1 0 0 0 0 0 0 0 0 \n", "3 4 0 0 0 0 0 0 0 0 \n", "4 0 0 0 0 0 0 0 0 0 \n", "\n", " pixel8 ... pixel774 pixel775 pixel776 pixel777 pixel778 pixel779 \\\n", "0 0 ... 0 0 0 0 0 0 \n", "1 0 ... 0 0 0 0 0 0 \n", "2 0 ... 0 0 0 0 0 0 \n", "3 0 ... 0 0 0 0 0 0 \n", "4 0 ... 0 0 0 0 0 0 \n", "\n", " pixel780 pixel781 pixel782 pixel783 \n", "0 0 0 0 0 \n", "1 0 0 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", "[5 rows x 785 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create the training & test sets, skipping the header row with [1:]\n", "train = pd.read_csv(\"./input/train.csv\")\n", "print(train.shape)\n", "train.head()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "_cell_guid": "2ec570a6-b41a-2139-5e0e-4941c4f0a9d0", "_uuid": "67f0854ad0d812a1395130144a0adef9966fec88" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(28000, 784)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pixel0pixel1pixel2pixel3pixel4pixel5pixel6pixel7pixel8pixel9...pixel774pixel775pixel776pixel777pixel778pixel779pixel780pixel781pixel782pixel783
00000000000...0000000000
10000000000...0000000000
20000000000...0000000000
30000000000...0000000000
40000000000...0000000000
\n", "

5 rows × 784 columns

\n", "
" ], "text/plain": [ " pixel0 pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 pixel8 \\\n", "0 0 0 0 0 0 0 0 0 0 \n", "1 0 0 0 0 0 0 0 0 0 \n", "2 0 0 0 0 0 0 0 0 0 \n", "3 0 0 0 0 0 0 0 0 0 \n", "4 0 0 0 0 0 0 0 0 0 \n", "\n", " pixel9 ... pixel774 pixel775 pixel776 pixel777 pixel778 pixel779 \\\n", "0 0 ... 0 0 0 0 0 0 \n", "1 0 ... 0 0 0 0 0 0 \n", "2 0 ... 0 0 0 0 0 0 \n", "3 0 ... 0 0 0 0 0 0 \n", "4 0 ... 0 0 0 0 0 0 \n", "\n", " pixel780 pixel781 pixel782 pixel783 \n", "0 0 0 0 0 \n", "1 0 0 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", "[5 rows x 784 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test= pd.read_csv(\"./input/test.csv\")\n", "print(test.shape)\n", "test.head()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "_cell_guid": "1ae10fe0-dde9-7659-f53d-1a1bd625cfb1", "_uuid": "bdffbed77ce62da528c60e43f2b1bea9f57fcdbc" }, "outputs": [], "source": [ "X_train = (train.iloc[:,1:].values).astype('float32') # all pixel values\n", "y_train = train.iloc[:,0].values.astype('int32') # only labels i.e targets digits\n", "X_test = test.values.astype('float32')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "_cell_guid": "250b1126-ce1d-6d3f-9736-2504f7a1e098", "_uuid": "5e3e1e3574c3e019eadfd14e4dda41fd15b4de2a" }, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "_cell_guid": "e0f15f8a-ac08-540a-58db-dab989cc687c", "_uuid": "4c96cf1c9cdc364ae3faff6b8c3c97aa7fa982d4" }, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 1, ..., 7, 6, 9], dtype=int32)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "c2c91588-5547-353a-7f92-39600027438e", "_uuid": "f01a969286e62fa5ffe37031ed6d4aea947b59a8" }, "source": [ "The output variable is an integer from 0 to 9. This is a **multiclass** classification problem." ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "60957d82-c76f-4822-28ff-def7011a34fa", "_uuid": "da0573528e8c6c3dd2b0e0cf33c600ad3f14466d" }, "source": [ "## Data Visualization\n", "Lets look at 3 images from data set with their labels." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "_cell_guid": "1541678d-a08b-d2b2-1e1e-eabf882baaec", "_uuid": "7998af1ce3c065c4a54a73cd97fed9afebde7e96" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU4AAABvCAYAAACD1ClOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAQEklEQVR4nO3de4wUZbrH8e+jIip4QxQJgmy8oVFklXg7umK84o2LBiGKaFTUHBNRYxaJxqCeiJooR42u4BLA27oqIBqjKNEoJiAXMbuAuIoLXlBEj4pwlCM854/pt6p6GGampqu7unt+n4RMdVVP10M/02+/9b5vva+5OyIi0no75B2AiEitUcEpIpKSCk4RkZRUcIqIpKSCU0QkJRWcIiIpqeAUEUmp7gtOM/ul0b8tZvZI3nFJ6czsaTNba2Y/m9knZnZ13jFJ6czsHTP7NfGZXZl3TI3VfcHp7p3DP2B/4H+BF3IOS7JxL9Db3fcALgTuMbNjc45JsnFD4rN7WN7BNFb3BWcjFwHrgPfyDkRK5+7L3P238LDw76AcQ5J2or0VnKOA6a77TOuGmT1mZpuAj4G1wGs5hyTZuNfM1pvZ+2Y2IO9gGrP2UoaY2YHAKuBgd/8873gkO2a2I3AiMAC4z93/L9+IpBRmdjywHNgMDAceBfq5+2e5BpbQnmqcI4F5KjTrj7tvcfd5wAHA9XnHI6Vx9wXuvsHdf3P3acD7wLl5x5XUngrOy4FpeQchZbUTauOsRw5Y3kEktYuC08xOAnqg3vS6YWb7mdlwM+tsZjua2dnACGBu3rFJ25nZXmZ2tpntYmY7mdmlwJ+A1/OOLWmnvAOokFHADHffkHcgkhmn4bL8LzRUAFYDY9x9dq5RSak6APcAfYAtNHT6DXb3T3KNqpF20zkkIpKVdnGpLiKSJRWcIiIplVRwmtk5ZrbSzD41s7FZBSX5Ul7rl3KbjTa3cRYGHX8CnAl8CSwERrj78uzCk0pTXuuXcpudUmqcxwGfuvsqd98M/A0YlE1YkiPltX4ptxkpZThSD+CLxOMvgeOb+4WuXbt67969SzhlbVu8ePF6d9837zhaoLymVCN5hZS5VV63n9eyj+M0s9HAaIBevXqxaNGicp+yapnZ6rxjyIryGlNe61NzeS3lUv0roGfi8QGFfUXcfZK793f3/vvuWwtfyu2e8lq/Wsyt8to6pRScC4FDzOwPZrYzDbOY6K6N2qe81i/lNiNtvlR399/N7AbgDWBHYIq7L8ssMsmF8lq/lNvslNTG6e6voYlj647yWr+U22zoziERkZRUcIqIpKSCU0QkJRWcIiIpqeAUEUmpvcwALyI1YsWKFQA88sgjAPz222/RsXXr1gHw6quvFv3OcccdF20PHToUgIEDBwLQt2/fzGNUjVNEJCUVnCIiKelSXarSd999B8SXa/PmzQPg7bff3ua5HTp0AOC8886L9vXp0weAww47rOi5gwcPjrY7d+4MwE476WOQlw0bGtZPHDduXLRv+vTpRceSwvzBZsWrBS9cuHCb7fHjxwMwbNiw6NjUqVMziFo1ThGR1Griq3bGjBkAzJkzB4AhQ4ZEx7p27Vr03F69egGwfv36aN+mTZtaPMe7774LwKxZswA4/PDDo2Ph2zC8tmTj66+/BuKG/hdffDE69uabbxY9t2PHjgA0NT/k1q1bgTh3zbnyyiuj7X79+gEwatQoAG644YbomGqh5bV6dcOMbaeeeioAa9as2eY55557LgA777xztG97Nc6mfPjhhwA8//zz0b4999wTgAceeGCb105DNU4RkZRq4mv1448/BmDSpEkATJ48OTrW+BuoqRrnxo0bi56TXGep8b7wOJwTittfJDuhTXLp0qXbHLvwwgsBOPnkk4seN26zBJg/fz4AAwYMiPY9/PDDQPEwFYAFCxZE28899xwAN910EwDffvttdOzee+9N8T+R1gpDi0aMGAHENc9kDXL48OEAPP300wDssEPb6ne//PILAM8++2y0L1y9hqtQ1ThFRCqkxYLTzKaY2Toz+2diXxcze9PM/lX4uXd5w5SsKa/1S7ktvxaXBzazPwG/ANPd/cjCvvuBH9x9QmFt5r3d/c8tnax///7eljVM7rnnHgD2228/AE455ZTo2HvvvZf69ZLCMJennnoKiC8ZxowZEz3nwQcfLOkcgZktdvf+mbxYiaohr+FS7PvvvweKhxMdfPDBrX6d119/HShunrnsssta/L1wKXfkkUcCsMcee0THFi9eDMRDnZpTTXmF7HLb1rw257rrrgPiZrdQ/owcOTJ6zsSJEwHo0qVLpudOq7m8tljjdPd3gR8a7R4ETCtsTwMGIzVFea1fym35tbVzqJu7ry1sfwN0yyieJr388ssAXHPNNUDxUKHkdlvMnDkTiGuaRxxxBNBuO4QqmtfW1Apb45xzzmnxOUuWLAHiDiGIOxl/+uknAObOnRsda01Ns8ZUNLfb89JLLwFxTTMMD3vooYei54QhQ9Ws5M4hb3gHtnu9b2ajzWyRmS0Kd4NI9VNe61dzuVVeW6etNc5vzay7u681s+7Auu090d0nAZOgoc2kjecDiocIlSIMT4J4OET4Bhw7diyw7cD6diKXvGYlOYtOaJd+8sknAVi1ahUAnTp1ip5zzDHHAPDKK68AtVHTKUGrcluOvL72WrzEUajdhyu8UNNs7r3/8ccfo+3ff/+96Pf32WefLEJMra01ztnAqML2KODlbMKRnCmv9Uu5zVCLNU4zew4YAHQ1sy+BO4EJwN/N7CpgNTBs+6/QNmFOPohrmqGNs1TJmuvKlSuBeA6/8LPe5ZXXpvz6669AXDuEuGbRWPfu3aPttWsbmuzC7XrJORrDvrPPPhuAJ554Aohvs4T6vaqoltyGK4C777472tc4r03VNENeH3/88aKfEI+c2GWXXQAYPXo0EN9CCW0f1J5GiwWnu4/YzqHTM45FKkh5rV/KbfnpziERkZRq4l71rC+pksNgQqdQuKTbbbfdMj2XtCzMhJS80eDzzz9v9e+H+Qluu+22aN9pp50GNH1vu1RGmE8zOT9AcMEFFwBx88x9990XHQu9+T///PN2Xzs0A4T5WpNlxB133FFK2K2iGqeISEpVW+NMDmxPzu6chdAhBK2b10/KK9Q+Tj89boILi3I1Z8qUKQC88MILADzzzDPRsRNPPDHLEKUNQsdPU7dIhyFgs2fPBpr+HIaZrY466qhtjoW5W8PwpmQHUugw6tatfGP8VeMUEUmpamucSVm1cYZZ3pua2CT5rSj5SLYvNzXTe2N33XUXALfffjtQXOsIt2GGmmeYBbwOb6WsWuG9DpP0AJxxxhkAbN68GYgnVrn00kuj54SbUJpbcSFMzhMGx4chTACfffYZoBqniEhVUcEpIpJSTVyqZyXcMZRsiL7ooouA0mdZktb56KOPou2ePXsCpc+7GO4UufHGG6N9YXhZuDQ8/vjjgbgjCeCggw4q6bzSOslmsGXLlgGwZcsWAHbddVeg7Qshhs9ysjmvR48ebXqtNFTjFBFJqV3VOMNQiGTn0KBBg/IKp10Jw4vOPPPMaN8777wDlGem7z59+gDxsJWrr74aiAfGA7z11lsAHHrooZmfX5qWZmb/poQ5LMLS0sGxxx4bbR944IElnaM1VOMUEUmpXdU4m2rjDDO+S3mFORnPP//8aF8l3vsTTjih6Pyh7RPg+uuvB+JZlUJ7m1SvK664Aohv5wyGDBlS0ThU4xQRSak183H2BKbTsEaJA5Pc/b/NrAvwPNAb+DcwzN3/p3yhtl1YsTD8bGllz/Ygr7zutddeWb1UKqHXdvz48dG+Sy65BID3338fiHvga1k9fF4bS07+8sEHHwDxVWNouw5rF1VKa2qcvwO3uPsRwAnAf5rZEcBYYK67HwLMLTyW2qG81ifltQJaszzwWndfUtjeAKwAeqDlRmua8lqflNfKSNU5ZGa9gT8CC6iS5UbT0ExITatEXsOSF4899li0L8xsU8lF0gYPjsuLMGQpLFlbD5fqSbX+eQ1zS9xyyy3bHNt9992BeJ6CSs9B0OrOITPrDLwEjHH3ohlGtdxo7VJe65PyWl6tqnGaWQcakvCMu88o7M5tudG2Cp1C6hxqUMm8htvuvvjii2jfG2+8AcDFF18MwA47lH+QR3Ihr/333x+A+fPnl/28lVSLn9dNmzZF248++igQL8CWvFIMNcv7778faPutmqVq8S/VGqL+K7DC3R9MHNJyozVMea1PymtltKbG+R/ASOAfZra0sG8cOS0lW4rwzRXathpvtzMVzWuYazPUFAAuv/xyIJ74Ydy4cdGxjh07ZnHabSSXkV26tOG/feedd5blXDmpus9rcs2hcKtkGLA+adIkIF47COK/h6bcfPPNAFx77bWZx5lGa5YHngdsr1dFy43WKOW1PimvlaE7h0REUmoX96pPnjwZiDuFklP5azngyho5cmS0HfIRFteaNWtWdGzChAlA3KnUuXPnNp1v+fLlQLysRnI41K233grkf9lX77755ptoOzTPhHkBQs99U0MFw6xVV111VbQv5CxvqnGKiKTULmqcM2fOBOJvtaFDh+YZjhSE2kffvn0BmDhxYnQsdAKExbgGDhwIxEOXIL5aWLNmDRDfcw4wZ84cAL766isgnu09DHWBeHYkKa/k/Jhh5vf169cXPefoo4+OtkPHUahpVmJG97RU4xQRSalua5zJux7C7OO65bI69evXD4CpU6dG+zZu3AjEw5fC7P1hPkaIa5yrV68Gite2GTFiBAAnnXQSAGeddRZQPABeKiPkF4oHutcy1ThFRFKq2xpnsnYZtjXbe+3o1KkTUDx/pki1UI1TRCQlFZwiIinV7aV6coH6rVu35hiJiNQb1ThFRFKySs5NaWbfARuB9S09twp1pfS4D3T3fbMIppoor8prFSprXitacAKY2SJ371/Rk2agVuOulFp9f2o17kqp1fen3HHrUl1EJCUVnCIiKeVRcE7K4ZxZqNW4K6VW359ajbtSavX9KWvcFW/jFBGpdbpUFxFJqWIFp5mdY2YrzexTMxtbqfOmZWY9zextM1tuZsvM7MbC/i5m9qaZ/avwc++8Y60WtZBb5TU95bWZ81biUt3MdgQ+Ac4EvgQWAiPcfXnZT55SYc3p7u6+xMx2BxYDg4ErgB/cfULhj2hvd/9zjqFWhVrJrfKajvLavErVOI8DPnX3Ve6+GfgbMKhC507F3de6+5LC9gZgBdCDhninFZ42jYbkSI3kVnlNTXltRqUKzh7AF4nHXxb2VTUz6w38EVgAdHP3tYVD3wDdcgqr2tRcbpXXVlFem6HOoe0ws87AS8AYd/85ecwb2jc0HKEGKa/1qdJ5rVTB+RXQM/H4gMK+qmRmHWhIwjPuPqOw+9tCe0poV1mXV3xVpmZyq7ymorw2o1IF50LgEDP7g5ntDAwHZlfo3KlYw3TxfwVWuPuDiUOzgVGF7VHAy5WOrUrVRG6V19SU1+bOW6kB8GZ2LjAR2BGY4u7/VZETp2RmJwPvAf8AwkSe42hoN/k70AtYDQxz9x9yCbLK1EJuldf0lNdmzqs7h0RE0lHnkIhISio4RURSUsEpIpKSCk4RkZRUcIqIpKSCU0QkJRWcIiIpqeAUEUnp/wHCI2BIq3BQBAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#Convert train datset to (num_images, img_rows, img_cols) format \n", "X_train = X_train.reshape(X_train.shape[0], 28, 28)\n", "\n", "for i in range(6, 9):\n", " plt.subplot(330 + (i+1))\n", " plt.imshow(X_train[i], cmap=plt.get_cmap('gray_r'))\n", " plt.title(y_train[i]);" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "_cell_guid": "6be2f3e9-42eb-85b6-9162-c25e4d706155", "_uuid": "4051a0e6612b8e6d4b8aef6a4d131be621cd3a14" }, "outputs": [ { "data": { "text/plain": [ "(42000, 28, 28, 1)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#expand 1 more dimention as 1 for colour channel gray\n", "X_train = X_train.reshape(X_train.shape[0], 28, 28,1)\n", "X_train.shape" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "_cell_guid": "6949468c-fd27-19c5-15c7-0b357a961003", "_uuid": "6d4c1323f1fa3f89a16532fed893ec5d72051bcb" }, "outputs": [ { "data": { "text/plain": [ "(28000, 28, 28, 1)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test = X_test.reshape(X_test.shape[0], 28, 28,1)\n", "X_test.shape" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "1232c385-3cb2-56fd-4d1d-f027df7bc78e", "_uuid": "185d620525e041eb61aabce19e8536614ab50870" }, "source": [ "**Preprocessing the digit images**\n", "==================================" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "6fcc1f9e-1586-e393-49ba-50c73564e0ed", "_uuid": "b8847f48f7408c93ce795db16f30c1b7c6a8cf89" }, "source": [ "**Feature Standardization**\n", "-------------------------------------\n", "\n", "It is important preprocessing step.\n", "It is used to centre the data around zero mean and unit variance." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "_cell_guid": "a3f837ef-0373-8d91-46e6-30992cf73166", "_uuid": "528a370b381c91b73131a8c7a4217968278696c8" }, "outputs": [], "source": [ "mean_px = X_train.mean().astype(np.float32)\n", "std_px = X_train.std().astype(np.float32)\n", "\n", "def standardize(x): \n", " return (x-mean_px)/std_px" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "725c55fc-9742-a63c-9822-c67ab0c773ee", "_uuid": "532d3f3bd26b0dfb42bc0c96e9710269234fae9b" }, "source": [ "*One Hot encoding of labels.*\n", "-----------------------------\n", "\n", "A one-hot vector is a vector which is 0 in most dimensions, and 1 in a single dimension. In this case, the nth digit will be represented as a vector which is 1 in the nth dimension. \n", "\n", "For example, 3 would be [0,0,0,1,0,0,0,0,0,0]." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "_cell_guid": "c879f076-b3dd-6cb1-e2d9-2f404f2ed132", "_uuid": "41bb3082e71111d73dd0432f9f60261f5be05e15" }, "outputs": [ { "data": { "text/plain": [ "10" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from keras.utils.np_utils import to_categorical\n", "y_train= to_categorical(y_train)\n", "num_classes = y_train.shape[1]\n", "num_classes" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "4d76fb04-57fc-e802-6d91-06ece552686b", "_uuid": "429e528f5bf36152cd9e0b2acaa457525a202171" }, "source": [ "Lets plot 10th label." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "_cell_guid": "1c927e75-08d2-d539-54f3-71ab0308fec1", "_uuid": "b3ad8362611417de16de730cc55a6ff6309766f4" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/site-packages/matplotlib/text.py:1150: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n", " if s != self._text:\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.title(y_train[9])\n", "plt.plot(y_train[9])\n", "plt.xticks(range(10));" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "4e130661-9f09-d9a9-d49b-7274ef13927f", "_uuid": "40aecbff9b92269d438c384ac429bfa47ab37dda" }, "source": [ "Oh its 3 !" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "6a89dcdd-7b68-6ed1-2c39-b3a1edb3e7be", "_uuid": "dc7ece2b7ee08767b664149d67d922d8c1d0bbb1" }, "source": [ "**Designing Neural Network Architecture**\n", "=========================================" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "_cell_guid": "39107235-d87a-af4d-44fb-80c9c3aa0212", "_uuid": "1070353d05490ccec23933c62f11cdfd2d7e5032" }, "outputs": [], "source": [ "# fix random seed for reproducibility\n", "seed = 43\n", "np.random.seed(seed)" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "a8b65f54-398b-267f-e31a-313210450f54", "_uuid": "62606ecbb1d7e259850aebf8a8514e54263a2a06" }, "source": [ "*Linear Model*\n", "--------------" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "_cell_guid": "5dbe450c-845f-aaa2-dbde-21414a91d8c1", "_uuid": "5f54b59d89cd4e43dd129d9b133950ba83b5cad8" }, "outputs": [], "source": [ "from keras.models import Sequential\n", "from keras.layers.core import Lambda , Dense, Flatten, Dropout\n", "from keras.callbacks import EarlyStopping\n", "from keras.layers import BatchNormalization, Convolution2D , MaxPooling2D" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "5c3f674f-f3fc-9614-f2d4-056c3e3ad633", "_uuid": "ff25a88562237e84e44f20b38079f2b44a394d2c" }, "source": [ "Lets create a simple model from Keras Sequential layer.\n", "\n", "1. Lambda layer performs simple arithmetic operations like sum, average, exponentiation etc.\n", "\n", " In 1st layer of the model we have to define input dimensions of our data in (rows,columns,colour channel) format.\n", " (In theano colour channel comes first)\n", "\n", "\n", "2. Flatten will transform input into 1D array.\n", "\n", "\n", "3. Dense is fully connected layer that means all neurons in previous layers will be connected to all neurons in fully connected layer.\n", " In the last layer we have to specify output dimensions/classes of the model.\n", " Here it's 10, since we have to output 10 different digit labels." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "_cell_guid": "a2c27783-3cfa-e907-4749-1e340a513f26", "_uuid": "fb79b4558335446a722542c8bc06288e96781423" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: Logging before flag parsing goes to stderr.\n", "W0624 09:59:34.535813 4634289600 deprecation_wrapper.py:119] From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n", "\n", "W0624 09:59:34.550372 4634289600 deprecation_wrapper.py:119] From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n", "\n", "W0624 09:59:34.632287 4634289600 deprecation_wrapper.py:119] From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "input shape (None, 28, 28, 1)\n", "output shape (None, 10)\n" ] } ], "source": [ "model= Sequential()\n", "model.add(Lambda(standardize,input_shape=(28,28,1)))\n", "model.add(Flatten())\n", "model.add(Dense(10, activation='softmax'))\n", "print(\"input shape \",model.input_shape)\n", "print(\"output shape \",model.output_shape)" ] }, { "cell_type": "markdown", "metadata": { "_cell_guid": "260645fb-61b7-68e9-6826-047b97436c14", "_uuid": "2dd7f371688dd2590de94a94814799654595e55d" }, "source": [ "***Compile network***\n", "-------------------\n", "\n", "Before making network ready for training we have to make sure to add below things:\n", "\n", " 1. A loss function: to measure how good the network is\n", " \n", " 2. An optimizer: to update network as it sees more data and reduce loss\n", " value\n", " \n", " 3. Metrics: to monitor performance of network" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "_cell_guid": "9d1d1af9-b2a8-e3b9-6eaf-100d08fe83aa", "_uuid": "4bb75be10b9eec8bcdfe48665f639bf326a5f5fc" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "W0624 09:59:37.087071 4634289600 deprecation_wrapper.py:119] From /usr/local/lib/python3.7/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n", "\n", "W0624 09:59:37.093498 4634289600 deprecation_wrapper.py:119] From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3295: The name tf.log is deprecated. Please use tf.math.log instead.\n", "\n" ] } ], "source": [ "from keras.optimizers import RMSprop\n", "model.compile(optimizer=RMSprop(lr=0.001),\n", " loss='categorical_crossentropy',\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "_cell_guid": "db3b4be6-4f72-c6cc-65cd-b45978db2462", "_uuid": "51f82558d87e95fa5c146b0469ab6c8b42e13bcf" }, "outputs": [], "source": [ "from keras.preprocessing import image\n", "gen = image.ImageDataGenerator()" ] }, { "cell_type": "markdown", "metadata": { "_uuid": "841a6f3b78b607e142f3e18d88bd7957202e4dcb" }, "source": [ "## Cross Validation " ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "_cell_guid": "9071d720-da50-8530-e9f3-1f0c37aac7ff", "_uuid": "0cff7e02b1ee8894b4ee9080b9268558aaa4e7c5" }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "X = X_train\n", "y = y_train\n", "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.10, random_state=42)\n", "batches = gen.flow(X_train, y_train, batch_size=64)\n", "val_batches=gen.flow(X_val, y_val, batch_size=64)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "_cell_guid": "20e08e2a-a394-bb70-69f1-be0fdab4f9ab", "_uuid": "6b23c282e2772b1ee482596131d6f1d3494c3bce" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "W0624 09:59:43.208245 4634289600 deprecation.py:323] From /usr/local/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", "W0624 09:59:43.251929 4634289600 deprecation_wrapper.py:119] From /usr/local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:986: The name tf.assign_add is deprecated. Please use tf.compat.v1.assign_add instead.\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "37800/37800 [==============================] - 64s 2ms/step - loss: 0.2401 - acc: 0.9342 - val_loss: 0.3294 - val_acc: 0.9145\n", "Epoch 2/3\n", "37800/37800 [==============================] - 63s 2ms/step - loss: 0.2156 - acc: 0.9417 - val_loss: 0.3547 - val_acc: 0.9072\n", "Epoch 3/3\n", "37800/37800 [==============================] - 63s 2ms/step - loss: 0.2099 - acc: 0.9437 - val_loss: 0.3724 - val_acc: 0.9072\n" ] } ], "source": [ "history=model.fit_generator(generator=batches, steps_per_epoch=batches.n, epochs=3, \n", " validation_data=val_batches, validation_steps=val_batches.n)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "_cell_guid": "9f344366-c372-0b04-b7e0-860778d4bfd3", "_uuid": "6900e38c62028692b9f101b94730c527129675cc" }, "outputs": [ { "data": { "text/plain": [ "dict_keys(['val_loss', 'val_acc', 'loss', 'acc'])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "history_dict = history.history\n", "history_dict.keys()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "_cell_guid": "df40f5fc-586a-1fae-025e-ee508a8d9b71", "_uuid": "c4b26ff79e0f186212266b60d03611ad58d0d5e3" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "loss_values = history_dict['loss']\n", "val_loss_values = history_dict['val_loss']\n", "epochs = range(1, len(loss_values) + 1)\n", "\n", "# \"bo\" is for \"blue dot\"\n", "plt.plot(epochs, loss_values, 'bo')\n", "# b+ is for \"blue crosses\"\n", "plt.plot(epochs, val_loss_values, 'b+')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Loss')\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "_cell_guid": "1ed6b756-00c2-d08c-c596-0ce496ec3d04", "_uuid": "fc9be5b885360ca9972e0d6b5da1ea36dc12cb5f" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.clf() # clear figure\n", "acc_values = history_dict['acc']\n", "val_acc_values = history_dict['val_acc']\n", "\n", "plt.plot(epochs, acc_values, 'bo')\n", "plt.plot(epochs, val_acc_values, 'b+')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Accuracy')\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": { "_uuid": "64ec304e056ec0c9e33fe94ea2315cbf65a7fbff" }, "source": [ "## Fully Connected Model\n", "\n", "Neurons in a fully connected layer have full connections to all activations in the previous layer, as seen in regular Neural Networks. \n", "Adding another Dense Layer to model." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "_uuid": "9556f3de5bd370bcddc70a81910eb2104624e3a3" }, "outputs": [], "source": [ "def get_fc_model():\n", " model = Sequential([\n", " Lambda(standardize, input_shape=(28,28,1)),\n", " Flatten(),\n", " Dense(512, activation='relu'),\n", " Dense(10, activation='softmax')\n", " ])\n", " model.compile(optimizer='Adam', loss='categorical_crossentropy',\n", " metrics=['accuracy'])\n", " return model" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "_uuid": "1901b6805f878ac6ed4efafbaf15bf003d505654" }, "outputs": [], "source": [ "fc = get_fc_model()\n", "fc.optimizer.lr=0.01" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "_uuid": "5fb346c542c8920fac61ddc5df44b2136969a6e9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/1\n", "37800/37800 [==============================] - 178s 5ms/step - loss: 0.1598 - acc: 0.9715 - val_loss: 0.5633 - val_acc: 0.9488\n" ] } ], "source": [ "history=fc.fit_generator(generator=batches, steps_per_epoch=batches.n, epochs=1, \n", " validation_data=val_batches, validation_steps=val_batches.n)" ] }, { "cell_type": "markdown", "metadata": { "_uuid": "46b81b17854a98f2b380da694691502c1e583bfb" }, "source": [ "## Convolutional Neural Network\n", "CNNs are extremely efficient for images.\n" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "_uuid": "fd0ab0de6bdbf7addf7515c8d7b59d8d17fef8e7" }, "outputs": [], "source": [ "from keras.layers import Convolution2D, MaxPooling2D\n", "\n", "def get_cnn_model():\n", " model = Sequential([\n", " Lambda(standardize, input_shape=(28,28,1)),\n", " Convolution2D(32,(3,3), activation='relu'),\n", " Convolution2D(32,(3,3), activation='relu'),\n", " MaxPooling2D(),\n", " Convolution2D(64,(3,3), activation='relu'),\n", " Convolution2D(64,(3,3), activation='relu'),\n", " MaxPooling2D(),\n", " Flatten(),\n", " Dense(512, activation='relu'),\n", " Dense(10, activation='softmax')\n", " ])\n", " model.compile(Adam(), loss='categorical_crossentropy',\n", " metrics=['accuracy'])\n", " return model" ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "_uuid": "b5baeaed50c9f03c68900461555f4b7831a7e7c7" }, "outputs": [], "source": [ "model= get_cnn_model()\n", "model.optimizer.lr=0.01" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "_uuid": "a3d15eeead8b26c41b53d084dc111ae9e667a169" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/1\n", "42000/42000 [==============================] - 3044s 72ms/step - loss: 0.6667 - acc: 0.9467 - val_loss: 5.9925 - val_acc: 0.6282\n" ] } ], "source": [ "history=model.fit_generator(generator=batches, steps_per_epoch=batches.n, epochs=1, \n", " validation_data=val_batches, validation_steps=val_batches.n)" ] }, { "cell_type": "markdown", "metadata": { "_uuid": "e2891c7e434a2022ee182a0e9bd243a876532dcc" }, "source": [ "## Data Augmentation\n", "It is tehnique of showing slighly different or new images to neural network to avoid overfitting. And to achieve better generalization.\n", "In case you have very small dataset, you can use different kinds of data augmentation techniques to increase your data size. Neural networks perform better if you provide them more data.\n", "\n", "Different data aumentation techniques are as follows:\n", "1. Cropping\n", "2. Rotating\n", "3. Scaling\n", "4. Translating\n", "5. Flipping \n", "6. Adding Gaussian noise to input images etc.\n" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "_uuid": "daa409b92678202cf7c751371b7ba17fb14aa2ac" }, "outputs": [], "source": [ "gen =ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,\n", " height_shift_range=0.08, zoom_range=0.08)\n", "batches = gen.flow(X_train, y_train, batch_size=64)\n", "val_batches = gen.flow(X_val, y_val, batch_size=64)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "_uuid": "f21ba7b8d77a37bee6e8238a8f517b654ae3f0a0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/1\n", "37800/37800 [==============================] - 2865s 76ms/step - loss: 7.0282 - acc: 0.5639 - val_loss: 8.1427 - val_acc: 0.4948\n" ] } ], "source": [ "model.optimizer.lr=0.001\n", "history=model.fit_generator(generator=batches, steps_per_epoch=batches.n, epochs=1, \n", " validation_data=val_batches, validation_steps=val_batches.n)" ] }, { "cell_type": "markdown", "metadata": { "_uuid": "538f504c44e14d389c70b2f35b7225de61b9015d" }, "source": [ "## Adding Batch Normalization\n", "\n", "BN helps to fine tune hyperparameters more better and train really deep neural networks." ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "_uuid": "8b72580fbb06f5f4f769c514cb0d7d2f15aa2c2f" }, "outputs": [], "source": [ "from keras.layers.normalization import BatchNormalization\n", "\n", "def get_bn_model():\n", " model = Sequential([\n", " Lambda(standardize, input_shape=(28,28,1)),\n", " Convolution2D(32,(3,3), activation='relu'),\n", " BatchNormalization(axis=1),\n", " Convolution2D(32,(3,3), activation='relu'),\n", " MaxPooling2D(),\n", " BatchNormalization(axis=1),\n", " Convolution2D(64,(3,3), activation='relu'),\n", " BatchNormalization(axis=1),\n", " Convolution2D(64,(3,3), activation='relu'),\n", " MaxPooling2D(),\n", " Flatten(),\n", " BatchNormalization(),\n", " Dense(512, activation='relu'),\n", " BatchNormalization(),\n", " Dense(10, activation='softmax')\n", " ])\n", " model.compile(Adam(), loss='categorical_crossentropy', metrics=['accuracy'])\n", " return model" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "_uuid": "78e382d0b3de14312e762edc480b5d215be82269" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/1\n", "37800/37800 [==============================] - 4002s 106ms/step - loss: 0.0344 - acc: 0.9902 - val_loss: 0.0510 - val_acc: 0.9914\n" ] } ], "source": [ "model= get_bn_model()\n", "model.optimizer.lr=0.01\n", "history=model.fit_generator(generator=batches, steps_per_epoch=batches.n, epochs=1, \n", " validation_data=val_batches, validation_steps=val_batches.n)" ] }, { "cell_type": "markdown", "metadata": { "_uuid": "8e4b16516a57e152a911f6e7ba7f4d70ff204512" }, "source": [ "## Saving model to ONNX format" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "_uuid": "0fc055b482971b36561aaf9421c8a9c53df2900b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "42000/42000 [==============================] - 4374s 104ms/step - loss: 0.0140 - acc: 0.9981\n", "Epoch 2/3\n", "42000/42000 [==============================] - 4278s 102ms/step - loss: 0.0214 - acc: 0.9983\n", "Epoch 3/3\n", "42000/42000 [==============================] - 57188s 1s/step - loss: 0.0288 - acc: 0.9980\n" ] } ], "source": [ "model.optimizer.lr=0.01\n", "gen = image.ImageDataGenerator()\n", "batches = gen.flow(X, y, batch_size=64)\n", "history=model.fit_generator(generator=batches, steps_per_epoch=batches.n, epochs=3)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "_cell_guid": "c2841d54-f3dd-1ee8-a30d-4457dec0a67a", "_uuid": "4262c6bfb15ec96993e83bd2a2552eadf14fb33d" }, "outputs": [], "source": [ "# Convert into ONNX format with onnxmltools\n", "import keras2onnx\n", "onnx_model = keras2onnx.convert_keras(model, model.name)\n", "\n", "import onnx\n", "temp_model_file = 'model.onnx'\n", "\n", "onnx.save_model(onnx_model, temp_model_file)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ir_version: 5\n", "producer_name: \"keras2onnx\"\n", "producer_version: \"1.5.0\"\n", "domain: \"onnx\"\n", "model_version: 0\n", "doc_string: \"\"\n", "graph {\n", " node {\n", " input: \"lambda_5_input_01\"\n", " input: \"TFNodes1_lambda_5_sub_y_0\"\n", " output: \"TFNodes1_lambda_5_sub_0\"\n", " name: \"TFNodes1_lambda_5_sub\"\n", " op_type: \"Sub\"\n", " doc_string: \"\"\n", " domain: \"\"\n", " }\n", " node {\n", " input: \"TFNodes1_lambda_5_sub_0\"\n", " input: \"TFNodes1_lambda_5_truediv_y_0\"\n", " output: \"TFNodes1_lambda_5_truediv_0\"\n", " name: \"TFNodes1_lambda_5_truediv\"\n", "\n", "...\n" ] } ], "source": [ "print(str(onnx_model)[:500] + \"\\n...\")" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "content = onnx_model.SerializeToString()\n", "sess = onnxruntime.InferenceSession('model.onnx')" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "onnx.checker.check_model(onnx_model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "_change_revision": 0, "_is_fork": false, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 1 }