{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# BinaryNet on CIFAR10\n", "\n", "In this example we demonstrate how to use Larq to build and train BinaryNet on the CIFAR10 dataset, reaching a validation accuracy of approximately 83% on laptop hardware.\n", "On an Nvidia GTX 1050 Ti Max-Q, training for 100 epochs takes approximately 200 minutes. For simplicity, and unlike the original papers [BinaryConnect: Training Deep Neural Networks with binary weights during propagations](https://arxiv.org/abs/1511.00363) and [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830), we do not implement learning rate scaling or image whitening." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install larq" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import larq as lq\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import CIFAR10 Dataset\n", "\n", "We download the CIFAR10 dataset, scale its pixel values to the range [-1, 1] and one-hot encode the labels." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\n", "170500096/170498071 [==============================] - 38s 0us/step\n" ] } ], "source": [ "num_classes = 10\n", "\n", "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()\n", "\n", "train_images = train_images.reshape((50000, 32, 32, 3)).astype(\"float32\")\n", "test_images = test_images.reshape((10000, 32, 32, 3)).astype(\"float32\")\n", "\n", "# Normalize pixel values to be between -1 and 1\n", "train_images, test_images = train_images / 127.5 - 1, test_images / 127.5 - 1\n", "\n", "train_labels = tf.keras.utils.to_categorical(train_labels, num_classes)\n", "test_labels = tf.keras.utils.to_categorical(test_labels, num_classes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build BinaryNet\n", "\n", "Here we build the BinaryNet model layer by layer using the [Keras Sequential API](https://www.tensorflow.org/guide/keras)." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# All quantized layers except the first will use the same options\n", "kwargs = dict(input_quantizer=\"ste_sign\",\n", " kernel_quantizer=\"ste_sign\",\n", " kernel_constraint=\"weight_clip\",\n", " use_bias=False)\n", "\n", "model = tf.keras.models.Sequential([\n", " # In the first layer we only quantize the weights and not the input\n", " lq.layers.QuantConv2D(128, 3,\n", " kernel_quantizer=\"ste_sign\",\n", " kernel_constraint=\"weight_clip\",\n", " use_bias=False,\n", " input_shape=(32, 32, 3)),\n", " tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),\n", "\n", " lq.layers.QuantConv2D(128, 3, padding=\"same\", **kwargs),\n", " tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),\n", " tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),\n", "\n", " lq.layers.QuantConv2D(256, 3, padding=\"same\", **kwargs),\n", " tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),\n", "\n", " lq.layers.QuantConv2D(256, 3, padding=\"same\", **kwargs),\n", " tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),\n", " tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),\n", "\n", " lq.layers.QuantConv2D(512, 3, padding=\"same\", **kwargs),\n", " tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),\n", "\n", " lq.layers.QuantConv2D(512, 3, padding=\"same\", **kwargs),\n", " tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),\n", " tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),\n", " tf.keras.layers.Flatten(),\n", "\n", " lq.layers.QuantDense(1024, **kwargs),\n", " tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),\n", "\n", " lq.layers.QuantDense(1024, **kwargs),\n", " tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),\n", "\n", " lq.layers.QuantDense(10, **kwargs),\n", " tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),\n", " 
tf.keras.layers.Activation(\"softmax\")\n", "])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One can output a summary of the model:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+sequential stats---------------------------------------------------------------------------------------------+\n", "| Layer Input prec. Outputs # 1-bit # 32-bit Memory 1-bit MACs 32-bit MACs |\n", "| (bit) x 1 x 1 (kB) |\n", "+-------------------------------------------------------------------------------------------------------------+\n", "| quant_conv2d - (-1, 30, 30, 128) 3456 0 0.42 0 3110400 |\n", "| batch_normalization - (-1, 30, 30, 128) 0 256 1.00 0 0 |\n", "| quant_conv2d_1 1 (-1, 30, 30, 128) 147456 0 18.00 132710400 0 |\n", "| max_pooling2d - (-1, 15, 15, 128) 0 0 0 0 0 |\n", "| batch_normalization_1 - (-1, 15, 15, 128) 0 256 1.00 0 0 |\n", "| quant_conv2d_2 1 (-1, 15, 15, 256) 294912 0 36.00 66355200 0 |\n", "| batch_normalization_2 - (-1, 15, 15, 256) 0 512 2.00 0 0 |\n", "| quant_conv2d_3 1 (-1, 15, 15, 256) 589824 0 72.00 132710400 0 |\n", "| max_pooling2d_1 - (-1, 7, 7, 256) 0 0 0 0 0 |\n", "| batch_normalization_3 - (-1, 7, 7, 256) 0 512 2.00 0 0 |\n", "| quant_conv2d_4 1 (-1, 7, 7, 512) 1179648 0 144.00 57802752 0 |\n", "| batch_normalization_4 - (-1, 7, 7, 512) 0 1024 4.00 0 0 |\n", "| quant_conv2d_5 1 (-1, 7, 7, 512) 2359296 0 288.00 115605504 0 |\n", "| max_pooling2d_2 - (-1, 3, 3, 512) 0 0 0 0 0 |\n", "| batch_normalization_5 - (-1, 3, 3, 512) 0 1024 4.00 0 0 |\n", "| flatten - (-1, 4608) 0 0 0 0 0 |\n", "| quant_dense 1 (-1, 1024) 4718592 0 576.00 4718592 0 |\n", "| batch_normalization_6 - (-1, 1024) 0 2048 8.00 0 0 |\n", "| quant_dense_1 1 (-1, 1024) 1048576 0 128.00 1048576 0 |\n", "| batch_normalization_7 - (-1, 1024) 0 2048 8.00 0 0 |\n", "| quant_dense_2 1 (-1, 10) 10240 0 1.25 10240 0 |\n", "| batch_normalization_8 - (-1, 10) 0 20 0.08 0 0 |\n", "| 
activation - (-1, 10) 0 0 0 ? ? |\n", "+-------------------------------------------------------------------------------------------------------------+\n", "| Total 10352000 7700 1293.75 510961664 3110400 |\n", "+-------------------------------------------------------------------------------------------------------------+\n", "+sequential summary---------------------------+\n", "| Total params 10.4 M |\n", "| Trainable params 10.4 M |\n", "| Non-trainable params 7.7 k |\n", "| Model size 1.26 MiB |\n", "| Model size (8-bit FP weights) 1.24 MiB |\n", "| Float-32 Equivalent 39.52 MiB |\n", "| Compression Ratio of Memory 0.03 |\n", "| Number of MACs 514 M |\n", "| Ratio of MACs that are binarized 0.9939 |\n", "+---------------------------------------------+\n" ] } ], "source": [ "lq.models.summary(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Training\n", "\n", "We compile the model with the Adam optimizer (initial learning rate 0.01 with inverse-time decay) and then train it for 100 epochs with a batch size of 50, using the test set for validation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [], "source": [ "model.compile(\n", "    # Inverse-time decay with decay_steps=1 matches the legacy Keras `decay=0.0001` argument\n", "    tf.keras.optimizers.Adam(\n", "        learning_rate=tf.keras.optimizers.schedules.InverseTimeDecay(0.01, decay_steps=1, decay_rate=0.0001)\n", "    ),\n", "    loss=\"categorical_crossentropy\",\n", "    metrics=[\"accuracy\"],\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "outputExpanded": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 50000 samples, validate on 10000 samples\n", "Epoch 1/100\n", "50000/50000 [==============================] - 131s 3ms/step - loss: 1.5733 - acc: 0.4533 - val_loss: 1.6368 - val_acc: 0.4244\n", "Epoch 2/100\n", "50000/50000 [==============================] - 125s 3ms/step - loss: 1.1485 - acc: 0.6387 - val_loss: 1.8497 - val_acc: 0.3764\n", "Epoch 3/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.9641 - acc: 0.7207 - val_loss: 1.5696 - val_acc: 0.4794\n", "Epoch 4/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.8452 - acc: 0.7728 -
val_loss: 1.5765 - val_acc: 0.4669\n", "Epoch 5/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.7553 - acc: 0.8114 - val_loss: 1.0653 - val_acc: 0.6928\n", "Epoch 6/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.6841 - acc: 0.8447 - val_loss: 1.0944 - val_acc: 0.6880\n", "Epoch 7/100\n", "50000/50000 [==============================] - 125s 3ms/step - loss: 0.6356 - acc: 0.8685 - val_loss: 0.9909 - val_acc: 0.7317\n", "Epoch 8/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.5907 - acc: 0.8910 - val_loss: 0.9453 - val_acc: 0.7446\n", "Epoch 9/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.5610 - acc: 0.9043 - val_loss: 0.9441 - val_acc: 0.7460\n", "Epoch 10/100\n", "50000/50000 [==============================] - 125s 3ms/step - loss: 0.5295 - acc: 0.9201 - val_loss: 0.8892 - val_acc: 0.7679\n", "Epoch 11/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.5100 - acc: 0.9309 - val_loss: 0.8808 - val_acc: 0.7818\n", "Epoch 12/100\n", "50000/50000 [==============================] - 126s 3ms/step - loss: 0.4926 - acc: 0.9397 - val_loss: 0.8404 - val_acc: 0.7894\n", "Epoch 13/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.4807 - acc: 0.9470 - val_loss: 0.8600 - val_acc: 0.7928\n", "Epoch 14/100\n", "50000/50000 [==============================] - 126s 3ms/step - loss: 0.4661 - acc: 0.9529 - val_loss: 0.9046 - val_acc: 0.7732\n", "Epoch 15/100\n", "50000/50000 [==============================] - 125s 3ms/step - loss: 0.4588 - acc: 0.9571 - val_loss: 0.8505 - val_acc: 0.7965\n", "Epoch 16/100\n", "50000/50000 [==============================] - 126s 3ms/step - loss: 0.4558 - acc: 0.9593 - val_loss: 0.8748 - val_acc: 0.7859\n", "Epoch 17/100\n", "50000/50000 [==============================] - 126s 3ms/step - loss: 0.4434 - acc: 0.9649 - val_loss: 0.9109 - val_acc: 0.7656\n", "Epoch 
18/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.4449 - acc: 0.9643 - val_loss: 0.8532 - val_acc: 0.7971\n", "Epoch 19/100\n", "50000/50000 [==============================] - 126s 3ms/step - loss: 0.4349 - acc: 0.9701 - val_loss: 0.8677 - val_acc: 0.7951\n", "Epoch 20/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.4351 - acc: 0.9698 - val_loss: 0.9145 - val_acc: 0.7740\n", "Epoch 21/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.4268 - acc: 0.9740 - val_loss: 0.8308 - val_acc: 0.8065\n", "Epoch 22/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.4243 - acc: 0.9741 - val_loss: 0.8229 - val_acc: 0.8075\n", "Epoch 23/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.4201 - acc: 0.9764 - val_loss: 0.8411 - val_acc: 0.8062\n", "Epoch 24/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.4190 - acc: 0.9769 - val_loss: 0.8649 - val_acc: 0.7951\n", "Epoch 25/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.4139 - acc: 0.9787 - val_loss: 0.8257 - val_acc: 0.8071\n", "Epoch 26/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.4154 - acc: 0.9779 - val_loss: 0.8041 - val_acc: 0.8205\n", "Epoch 27/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.4128 - acc: 0.9798 - val_loss: 0.8296 - val_acc: 0.8115\n", "Epoch 28/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.4121 - acc: 0.9798 - val_loss: 0.8241 - val_acc: 0.8074\n", "Epoch 29/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.4093 - acc: 0.9807 - val_loss: 0.8575 - val_acc: 0.7913\n", "Epoch 30/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.4048 - acc: 0.9826 - val_loss: 0.8118 - val_acc: 0.8166\n", "Epoch 31/100\n", "50000/50000 
[==============================] - 126s 3ms/step - loss: 0.4041 - acc: 0.9837 - val_loss: 0.8375 - val_acc: 0.8082\n", "Epoch 32/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.4045 - acc: 0.9831 - val_loss: 0.8604 - val_acc: 0.8091\n", "Epoch 33/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.4047 - acc: 0.9823 - val_loss: 0.8797 - val_acc: 0.7931\n", "Epoch 34/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.4023 - acc: 0.9842 - val_loss: 0.8694 - val_acc: 0.8020\n", "Epoch 35/100\n", "50000/50000 [==============================] - 125s 3ms/step - loss: 0.3995 - acc: 0.9858 - val_loss: 0.8161 - val_acc: 0.8186\n", "Epoch 36/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.3976 - acc: 0.9859 - val_loss: 0.8495 - val_acc: 0.7988\n", "Epoch 37/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.4021 - acc: 0.9847 - val_loss: 0.8542 - val_acc: 0.8062\n", "Epoch 38/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.3939 - acc: 0.9869 - val_loss: 0.8347 - val_acc: 0.8122\n", "Epoch 39/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.3955 - acc: 0.9856 - val_loss: 0.8521 - val_acc: 0.7993\n", "Epoch 40/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.3907 - acc: 0.9885 - val_loss: 0.9023 - val_acc: 0.7992\n", "Epoch 41/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.3911 - acc: 0.9873 - val_loss: 0.8597 - val_acc: 0.8010\n", "Epoch 42/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.3917 - acc: 0.9885 - val_loss: 0.8968 - val_acc: 0.7936\n", "Epoch 43/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.3931 - acc: 0.9874 - val_loss: 0.8318 - val_acc: 0.8169\n", "Epoch 44/100\n", "50000/50000 [==============================] - 123s 2ms/step 
- loss: 0.3897 - acc: 0.9893 - val_loss: 0.8811 - val_acc: 0.7988\n", "Epoch 45/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.3876 - acc: 0.9888 - val_loss: 0.8453 - val_acc: 0.8094\n", "Epoch 46/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.3876 - acc: 0.9889 - val_loss: 0.8195 - val_acc: 0.8179\n", "Epoch 47/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3891 - acc: 0.9890 - val_loss: 0.8373 - val_acc: 0.8137\n", "Epoch 48/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3902 - acc: 0.9888 - val_loss: 0.8457 - val_acc: 0.8120\n", "Epoch 49/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3864 - acc: 0.9903 - val_loss: 0.9012 - val_acc: 0.7907\n", "Epoch 50/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3859 - acc: 0.9903 - val_loss: 0.8291 - val_acc: 0.8053\n", "Epoch 51/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3830 - acc: 0.9915 - val_loss: 0.8494 - val_acc: 0.8139\n", "Epoch 52/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3828 - acc: 0.9907 - val_loss: 0.8447 - val_acc: 0.8135\n", "Epoch 53/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3823 - acc: 0.9910 - val_loss: 0.8539 - val_acc: 0.8120\n", "Epoch 54/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3832 - acc: 0.9905 - val_loss: 0.8592 - val_acc: 0.8098\n", "Epoch 55/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3823 - acc: 0.9908 - val_loss: 0.8585 - val_acc: 0.8087\n", "Epoch 56/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3817 - acc: 0.9911 - val_loss: 0.8840 - val_acc: 0.7889\n", "Epoch 57/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3827 - acc: 0.9914 - val_loss: 0.8205 - 
val_acc: 0.8250\n", "Epoch 58/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3818 - acc: 0.9912 - val_loss: 0.8571 - val_acc: 0.8051\n", "Epoch 59/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3811 - acc: 0.9919 - val_loss: 0.8155 - val_acc: 0.8254\n", "Epoch 60/100\n", "50000/50000 [==============================] - 125s 3ms/step - loss: 0.3803 - acc: 0.9919 - val_loss: 0.8617 - val_acc: 0.8040\n", "Epoch 61/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.3793 - acc: 0.9926 - val_loss: 0.8212 - val_acc: 0.8192\n", "Epoch 62/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.3825 - acc: 0.9912 - val_loss: 0.8139 - val_acc: 0.8277\n", "Epoch 63/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.3784 - acc: 0.9923 - val_loss: 0.8304 - val_acc: 0.8121\n", "Epoch 64/100\n", "50000/50000 [==============================] - 125s 2ms/step - loss: 0.3809 - acc: 0.9918 - val_loss: 0.7961 - val_acc: 0.8289\n", "Epoch 65/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.3750 - acc: 0.9930 - val_loss: 0.8676 - val_acc: 0.8110\n", "Epoch 66/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3789 - acc: 0.9928 - val_loss: 0.8308 - val_acc: 0.8148\n", "Epoch 67/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3783 - acc: 0.9929 - val_loss: 0.8595 - val_acc: 0.8097\n", "Epoch 68/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3758 - acc: 0.9935 - val_loss: 0.8359 - val_acc: 0.8065\n", "Epoch 69/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3784 - acc: 0.9927 - val_loss: 0.8189 - val_acc: 0.8255\n", "Epoch 70/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3786 - acc: 0.9924 - val_loss: 0.8754 - val_acc: 0.8001\n", "Epoch 71/100\n", 
"50000/50000 [==============================] - 122s 2ms/step - loss: 0.3749 - acc: 0.9936 - val_loss: 0.8188 - val_acc: 0.8262\n", "Epoch 72/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3758 - acc: 0.9932 - val_loss: 0.8540 - val_acc: 0.8169\n", "Epoch 73/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3740 - acc: 0.9934 - val_loss: 0.8127 - val_acc: 0.8258\n", "Epoch 74/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3749 - acc: 0.9932 - val_loss: 0.8662 - val_acc: 0.8018\n", "Epoch 75/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3721 - acc: 0.9941 - val_loss: 0.8359 - val_acc: 0.8213\n", "Epoch 76/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3746 - acc: 0.9937 - val_loss: 0.8462 - val_acc: 0.8178\n", "Epoch 77/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3741 - acc: 0.9936 - val_loss: 0.8983 - val_acc: 0.7972\n", "Epoch 78/100\n", "50000/50000 [==============================] - 122s 2ms/step - loss: 0.3751 - acc: 0.9933 - val_loss: 0.8525 - val_acc: 0.8173\n", "Epoch 79/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.3762 - acc: 0.9931 - val_loss: 0.8190 - val_acc: 0.8201\n", "Epoch 80/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.3737 - acc: 0.9940 - val_loss: 0.8441 - val_acc: 0.8196\n", "Epoch 81/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.3729 - acc: 0.9935 - val_loss: 0.8151 - val_acc: 0.8267\n", "Epoch 82/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.3735 - acc: 0.9938 - val_loss: 0.8405 - val_acc: 0.8163\n", "Epoch 83/100\n", "50000/50000 [==============================] - 123s 2ms/step - loss: 0.3723 - acc: 0.9939 - val_loss: 0.8225 - val_acc: 0.8243\n", "Epoch 84/100\n", "50000/50000 [==============================] - 
123s 2ms/step - loss: 0.3738 - acc: 0.9938 - val_loss: 0.8413 - val_acc: 0.8115\n", "Epoch 85/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.3714 - acc: 0.9947 - val_loss: 0.9080 - val_acc: 0.7932\n", "Epoch 86/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.3744 - acc: 0.9942 - val_loss: 0.8467 - val_acc: 0.8135\n", "Epoch 87/100\n", "50000/50000 [==============================] - 124s 2ms/step - loss: 0.3705 - acc: 0.9948 - val_loss: 0.8491 - val_acc: 0.8163\n", "Epoch 88/100\n", "50000/50000 [==============================] - 128s 3ms/step - loss: 0.3733 - acc: 0.9944 - val_loss: 0.8005 - val_acc: 0.8214\n", "Epoch 89/100\n", "50000/50000 [==============================] - 134s 3ms/step - loss: 0.3693 - acc: 0.9949 - val_loss: 0.7791 - val_acc: 0.8321\n", "Epoch 90/100\n", "50000/50000 [==============================] - 135s 3ms/step - loss: 0.3724 - acc: 0.9942 - val_loss: 0.8458 - val_acc: 0.8124\n", "Epoch 91/100\n", "50000/50000 [==============================] - 128s 3ms/step - loss: 0.3732 - acc: 0.9947 - val_loss: 0.8315 - val_acc: 0.8164\n", "Epoch 92/100\n", "50000/50000 [==============================] - 127s 3ms/step - loss: 0.3699 - acc: 0.9950 - val_loss: 0.8140 - val_acc: 0.8226\n", "Epoch 93/100\n", "50000/50000 [==============================] - 131s 3ms/step - loss: 0.3694 - acc: 0.9950 - val_loss: 0.8342 - val_acc: 0.8210\n", "Epoch 94/100\n", "50000/50000 [==============================] - 134s 3ms/step - loss: 0.3698 - acc: 0.9946 - val_loss: 0.8938 - val_acc: 0.8019\n", "Epoch 95/100\n", "50000/50000 [==============================] - 133s 3ms/step - loss: 0.3698 - acc: 0.9946 - val_loss: 0.8771 - val_acc: 0.8066\n", "Epoch 96/100\n", "50000/50000 [==============================] - 164s 3ms/step - loss: 0.3712 - acc: 0.9946 - val_loss: 0.8396 - val_acc: 0.8211\n", "Epoch 97/100\n", "50000/50000 [==============================] - 155s 3ms/step - loss: 0.3689 - acc: 0.9949 - 
val_loss: 0.8728 - val_acc: 0.8112\n", "Epoch 98/100\n", "50000/50000 [==============================] - 133s 3ms/step - loss: 0.3663 - acc: 0.9953 - val_loss: 0.9615 - val_acc: 0.7902\n", "Epoch 99/100\n", "50000/50000 [==============================] - 133s 3ms/step - loss: 0.3714 - acc: 0.9944 - val_loss: 0.8414 - val_acc: 0.8188\n", "Epoch 100/100\n", "50000/50000 [==============================] - 138s 3ms/step - loss: 0.3682 - acc: 0.9956 - val_loss: 0.8055 - val_acc: 0.8266\n" ] } ], "source": [ "trained_model = model.fit(\n", "    train_images,\n", "    train_labels,\n", "    batch_size=50,\n", "    epochs=100,\n", "    validation_data=(test_images, test_labels),\n", "    shuffle=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Output\n", "\n", "We can now plot the training and validation accuracy and loss for each epoch, and print the best value reached for each:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9956000019311905\n", "0.8320999944210052\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Older Keras versions log `acc`/`val_acc`, newer ones `accuracy`/`val_accuracy`\n", "acc_key = 'acc' if 'acc' in trained_model.history else 'accuracy'\n", "\n", "plt.plot(trained_model.history[acc_key])\n", "plt.plot(trained_model.history['val_' + acc_key])\n", "plt.title('model accuracy')\n", "plt.ylabel('accuracy')\n", "plt.xlabel('epoch')\n", "plt.legend(['train', 'test'], loc='upper left')\n", "\n", "print(np.max(trained_model.history[acc_key]))\n", "print(np.max(trained_model.history['val_' + acc_key]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.3663262344896793\n", "0.7790719392895699\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(trained_model.history['loss'])\n", "plt.plot(trained_model.history['val_loss'])\n", "plt.title('model loss')\n", "plt.ylabel('loss')\n", "plt.xlabel('epoch')\n", "plt.legend(['train', 'test'], loc='upper left')\n", "\n", "print(np.min(trained_model.history['loss']))\n", "print(np.min(trained_model.history['val_loss']))" ] } ], "metadata": { "kernel_info": { "name": "python3" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "nteract": { "version": "0.14.3" } }, "nbformat": 4, "nbformat_minor": 4 }