{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qLGkt5qiyz4E" }, "source": [ "# Introduction to BNNs with Larq\n", "\n", "This tutorial demonstrates how to train a simple binarized Convolutional Neural Network (CNN) to classify MNIST digits. The network reaches approximately 98% accuracy on the MNIST test set. The tutorial uses Larq and the [Keras Sequential API](https://www.tensorflow.org/guide/keras), so creating and training the model requires only a few lines of code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install larq" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "iAve6DCL4JH4" }, "outputs": [], "source": [ "import tensorflow as tf\n", "import larq as lq" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jRFxccghyMVo" }, "source": [ "### Download and prepare the MNIST dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "JWoEqyMuXFF4" }, "outputs": [], "source": [ "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n", "\n", "train_images = train_images.reshape((60000, 28, 28, 1))\n", "test_images = test_images.reshape((10000, 28, 28, 1))\n", "\n", "# Normalize pixel values to be between -1 and 1\n", "train_images, test_images = train_images / 127.5 - 1, test_images / 127.5 - 1" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Oewp-wYg31t9" }, "source": [ "### Create the model\n", "\n", "The following code creates a simple binarized CNN.\n", "\n", "The quantization function\n", "$$\n", "q(x) = \\begin{cases}\n", " -1 & x < 0 \\\\\\\n", " 1 & x \\geq 0\n", "\\end{cases}\n", "$$\n", "is used in the forward pass to binarize the activations and the latent full-precision weights. 
The gradient of this function is zero almost everywhere, which prevents the model from learning.\n", "\n", "To train the model, the gradient is instead estimated using the Straight-Through Estimator (STE)\n", "(the binarization is essentially replaced by a clipped identity on the backward pass):\n", "$$\n", "\\frac{\\partial q(x)}{\\partial x} = \\begin{cases}\n", " 1 & \\left|x\\right| \\leq 1 \\\\\\\n", " 0 & \\left|x\\right| > 1\n", "\\end{cases}\n", "$$\n", "\n", "In Larq, this is done by setting `input_quantizer=\"ste_sign\"` and `kernel_quantizer=\"ste_sign\"`.\n", "Additionally, the latent full-precision weights are clipped to the range [-1, 1] using `kernel_constraint=\"weight_clip\"`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "L9YmGQBQPrdn" }, "outputs": [], "source": [ "# All quantized layers except the first will use the same options\n", "kwargs = dict(input_quantizer=\"ste_sign\",\n", " kernel_quantizer=\"ste_sign\",\n", " kernel_constraint=\"weight_clip\")\n", "\n", "model = tf.keras.models.Sequential()\n", "\n", "# In the first layer we only quantize the weights and not the input\n", "model.add(lq.layers.QuantConv2D(32, (3, 3),\n", " kernel_quantizer=\"ste_sign\",\n", " kernel_constraint=\"weight_clip\",\n", " use_bias=False,\n", " input_shape=(28, 28, 1)))\n", "model.add(tf.keras.layers.MaxPooling2D((2, 2)))\n", "model.add(tf.keras.layers.BatchNormalization(scale=False))\n", "\n", "model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))\n", "model.add(tf.keras.layers.MaxPooling2D((2, 2)))\n", "model.add(tf.keras.layers.BatchNormalization(scale=False))\n", "\n", "model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))\n", "model.add(tf.keras.layers.BatchNormalization(scale=False))\n", "model.add(tf.keras.layers.Flatten())\n", "\n", "model.add(lq.layers.QuantDense(64, use_bias=False, **kwargs))\n", 
"model.add(tf.keras.layers.BatchNormalization(scale=False))\n", "model.add(lq.layers.QuantDense(10, use_bias=False, **kwargs))\n", "model.add(tf.keras.layers.BatchNormalization(scale=False))\n", "model.add(tf.keras.layers.Activation(\"softmax\"))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "lvDVFkg-2DPm" }, "source": [ "Almost all parameters in the network are binarized, so either -1 or 1. This makes the network extremely fast if it would be deployed on custom BNN hardware.\n", "\n", " Here is the complete architecture of our model:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "8-C4XBg4UTJy" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+sequential stats------------------------------------------------------------------------------------------+\n", "| Layer Input prec. Outputs # 1-bit # 32-bit Memory 1-bit MACs 32-bit MACs |\n", "| (bit) x 1 x 1 (kB) |\n", "+----------------------------------------------------------------------------------------------------------+\n", "| quant_conv2d - (-1, 26, 26, 32) 288 0 0.04 0 194688 |\n", "| max_pooling2d - (-1, 13, 13, 32) 0 0 0 0 0 |\n", "| batch_normalization - (-1, 13, 13, 32) 0 64 0.25 0 0 |\n", "| quant_conv2d_1 1 (-1, 11, 11, 64) 18432 0 2.25 2230272 0 |\n", "| max_pooling2d_1 - (-1, 5, 5, 64) 0 0 0 0 0 |\n", "| batch_normalization_1 - (-1, 5, 5, 64) 0 128 0.50 0 0 |\n", "| quant_conv2d_2 1 (-1, 3, 3, 64) 36864 0 4.50 331776 0 |\n", "| batch_normalization_2 - (-1, 3, 3, 64) 0 128 0.50 0 0 |\n", "| flatten - (-1, 576) 0 0 0 0 0 |\n", "| quant_dense 1 (-1, 64) 36864 0 4.50 36864 0 |\n", "| batch_normalization_3 - (-1, 64) 0 128 0.50 0 0 |\n", "| quant_dense_1 1 (-1, 10) 640 0 0.08 640 0 |\n", "| batch_normalization_4 - (-1, 10) 0 20 0.08 0 0 |\n", "| activation - (-1, 10) 0 0 0 ? ? 
|\n", "+----------------------------------------------------------------------------------------------------------+\n", "| Total 93088 468 13.19 2599552 194688 |\n", "+----------------------------------------------------------------------------------------------------------+\n", "+sequential summary----------------------------+\n", "| Total params 93.6 k |\n", "| Trainable params 93.1 k |\n", "| Non-trainable params 468 |\n", "| Model size 13.19 KiB |\n", "| Model size (8-bit FP weights) 11.82 KiB |\n", "| Float-32 Equivalent 365.45 KiB |\n", "| Compression Ratio of Memory 0.04 |\n", "| Number of MACs 2.79 M |\n", "| Ratio of MACs that are binarized 0.9303 |\n", "+----------------------------------------------+\n" ] } ], "source": [ "lq.models.summary(model)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "P3odqfHP4M67" }, "source": [ "### Compile and train the model\n", "\n", "Note: This may take a few minutes depending on your system." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", "id": "MdDzI75PUXrG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/6\n", "938/938 [==============================] - 20s 21ms/step - loss: 0.4231 - accuracy: 0.9771\n", "Epoch 2/6\n", "938/938 [==============================] - 20s 21ms/step - loss: 0.4186 - accuracy: 0.9788\n", "Epoch 3/6\n", "938/938 [==============================] - 21s 22ms/step - loss: 0.4170 - accuracy: 0.9787\n", "Epoch 4/6\n", "938/938 [==============================] - 21s 23ms/step - loss: 0.4126 - accuracy: 0.9793\n", "Epoch 5/6\n", "938/938 [==============================] - 21s 23ms/step - loss: 0.4142 - accuracy: 0.9803\n", "Epoch 6/6\n", "938/938 [==============================] - 21s 22ms/step - loss: 0.4124 - accuracy: 0.9814\n", "313/313 [==============================] - 2s 5ms/step - loss: 0.4507 - accuracy: 0.9762\n" ] } ], "source": [ "model.compile(optimizer='adam',\n", " 
loss='sparse_categorical_crossentropy',\n", " metrics=['accuracy'])\n", "\n", "model.fit(train_images, train_labels, batch_size=64, epochs=6)\n", "\n", "test_loss, test_acc = model.evaluate(test_images, test_labels)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jKgyC5K_4O0d" }, "source": [ "### Evaluate the model" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": {}, "colab_type": "code", "id": "0LvwaKhtUdOo" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy 97.62 %\n" ] } ], "source": [ "print(f\"Test accuracy {test_acc * 100:.2f} %\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8cfJ8AR03gT5" }, "source": [ "As you can see, our simple binarized CNN has achieved a test accuracy of around 98 %. Not bad for a few lines of code!\n", "\n", "For information on converting Larq models to an optimized format and using or benchmarking them on Android or ARM devices, have a look at [this guide](https://docs.larq.dev/compute-engine/end_to_end/)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "intro_to_cnns.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true, "version": "0.3.2" }, "kernel_info": { "name": "python3" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "nteract": { "version": "0.14.3" } }, "nbformat": 4, "nbformat_minor": 4 }