{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "j_LlXHYcmRaC" }, "source": [ "# Lookahead Optimizer on MNIST\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/lookahead_mnist.ipynb)\n", "\n", "This notebook trains a simple Convolution Neural Network (CNN) for hand-written digit recognition (MNIST dataset) using the [Lookahead optimizer](https://arxiv.org/pdf/1907.08610v1.pdf)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9cu0kFNrnJj7" }, "outputs": [], "source": [ "from flax import linen as nn\n", "import jax\n", "import jax.numpy as jnp\n", "import optax\n", "\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2Adl_l_uZs1d" }, "outputs": [], "source": [ "# @markdown The learning rate for the fast optimizer:\n", "FAST_LEARNING_RATE = 0.002 # @param{type:\"number\"}\n", "# @markdown The learning rate for the slow optimizer:\n", "SLOW_LEARNING_RATE = 0.1 # @param{type:\"number\"}\n", "# @markdown Number of fast optimizer steps to take before synchronizing parameters:\n", "SYNC_PERIOD = 5 # @param{type:\"integer\"}\n", "# @markdown Number of samples in each batch:\n", "BATCH_SIZE = 256 # @param{type:\"integer\"}\n", "# @markdown Total number of epochs to train for:\n", "N_EPOCHS = 1 # @param{type:\"integer\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "ZZej3FcOhuRE" }, "source": [ "MNIST is a dataset of 28x28 images with 1 channel. We now load the dataset using `tensorflow_datasets`, apply min-max normalization to images, shuffle the data in the train set and create batches of size `BATCH_SIZE`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xPZ0paOehHWg" }, "outputs": [], "source": [ "(train_loader, test_loader), info = tfds.load(\n", " \"mnist\", split=[\"train\", \"test\"], as_supervised=True, with_info=True\n", ")\n", "NUM_CLASSES = info.features[\"label\"].num_classes\n", "IMG_SIZE = info.features[\"image\"].shape\n", "\n", "min_max_rgb = lambda image, label: (tf.cast(image, tf.float32) / 255., label)\n", "train_loader = train_loader.map(min_max_rgb)\n", "test_loader = test_loader.map(min_max_rgb)\n", "\n", "train_loader_batched = train_loader.shuffle(\n", " buffer_size=10_000, reshuffle_each_iteration=True\n", ").batch(BATCH_SIZE, drop_remainder=True)\n", "\n", "test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "XkLaC2MlbAqa" }, "source": [ "The data is ready! Next let's define a model. Optax is agnostic to which (if any) neural network library is used. Here we use Flax to implement a simple CNN." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RppusWrcaXzX" }, "outputs": [], "source": [ "class CNN(nn.Module):\n", " \"\"\"A simple CNN model.\"\"\"\n", "\n", " @nn.compact\n", " def __call__(self, x):\n", " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", " x = nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", " x = nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = x.reshape((x.shape[0], -1)) # flatten\n", " x = nn.Dense(features=256)(x)\n", " x = nn.relu(x)\n", " x = nn.Dense(features=10)(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DKOi55MgdPyp" }, "outputs": [], "source": [ "net = CNN()\n", "\n", "@jax.jit\n", "def predict(params, inputs):\n", " return net.apply({'params': params}, inputs)\n", "\n", "\n", "@jax.jit\n", "def loss_accuracy(params, data):\n", " \"\"\"Computes loss and accuracy over a mini-batch.\n", "\n", " Args:\n", " params: parameters of the model.\n", " bn_params: state of the model.\n", " data: tuple of (inputs, labels).\n", " is_training: if true, uses train mode, otherwise uses eval mode.\n", "\n", " Returns:\n", " loss: float\n", " \"\"\"\n", " inputs, labels = data\n", " logits = predict(params, inputs)\n", " loss = optax.softmax_cross_entropy_with_integer_labels(\n", " logits=logits, labels=labels\n", " ).mean()\n", " accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)\n", " return loss, {\"accuracy\": accuracy}\n", "\n", "@jax.jit\n", "def update_model(state, grads):\n", " return state.apply_gradients(grads=grads)" ] }, { "cell_type": "markdown", "metadata": { "id": "0eB2dhIpjTIi" }, "source": [ "Next we need to initialize CNN parameters and solver state. We also define a convenience function `dataset_stats` that we'll call once per epoch to collect the loss and accuracy of our solver over the test set. We will be using the Lookahead optimizer.\n", "Its wrapper keeps a pair of slow and fast parameters. To\n", "initialize them, we create a pair of synchronized parameters from the\n", "initial model parameters.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PBnbq7gui34L" }, "outputs": [], "source": [ "fast_solver = optax.adam(FAST_LEARNING_RATE)\n", "solver = optax.lookahead(fast_solver, SYNC_PERIOD, SLOW_LEARNING_RATE)\n", "rng = jax.random.PRNGKey(0)\n", "dummy_data = jnp.ones((1,) + IMG_SIZE, dtype=jnp.float32)\n", "\n", "params = net.init({\"params\": rng}, dummy_data)[\"params\"]\n", "\n", "# Initializes the lookahead optimizer with the initial model parameters.\n", "params = optax.LookaheadParams.init_synced(params)\n", "solver_state = solver.init(params)\n", "\n", "def dataset_stats(params, data_loader):\n", " \"\"\"Computes loss and accuracy over the dataset `data_loader`.\"\"\"\n", " all_accuracy = []\n", " all_loss = []\n", " for batch in data_loader.as_numpy_iterator():\n", " batch_loss, batch_aux = loss_accuracy(params, batch)\n", " all_loss.append(batch_loss)\n", " all_accuracy.append(batch_aux[\"accuracy\"])\n", " return {\"loss\": np.mean(all_loss), \"accuracy\": np.mean(all_accuracy)}" ] }, { "cell_type": "markdown", "metadata": { "id": "4H6GWNJf0XTY" }, "source": [ "Finally, we do the actual training. The next cell train the model for `N_EPOCHS`. Within each epoch we iterate over the batched loader `train_loader_batched`, and once per epoch we also compute the test set accuracy and loss." 
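, "\n", "\n", "As a reminder of what happens inside `solver.update`: the fast parameters are updated by the inner Adam optimizer on every step, and every `SYNC_PERIOD` steps the slow parameters take a step of size `SLOW_LEARNING_RATE` toward them, after which the fast parameters restart from the updated slow ones. A minimal sketch of that synchronization rule (illustrative only; the real logic lives inside `optax.lookahead`):\n", "\n", "```python\n", "def lookahead_sync(slow, fast, alpha):\n", "  # The slow parameters move a fraction `alpha` toward the fast parameters ...\n", "  new_slow = jax.tree_util.tree_map(lambda s, f: s + alpha * (f - s), slow, fast)\n", "  # ... and the fast parameters restart from the updated slow parameters.\n", "  return new_slow, new_slow\n", "```"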
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DeQr0urBjoDj" }, "outputs": [], "source": [ "train_accuracy = []\n", "train_losses = []\n", "\n", "# Computes test set loss and accuracy at initialization.\n", "test_stats = dataset_stats(params.slow, test_loader_batched)\n", "test_accuracy = [test_stats[\"accuracy\"]]\n", "test_losses = [test_stats[\"loss\"]]\n", "\n", "\n", "@jax.jit\n", "def train_step(params, solver_state, batch):\n", "  # Performs a single update step; gradients are taken with respect to\n", "  # the fast parameters.\n", "  (loss, aux), grad = jax.value_and_grad(loss_accuracy, has_aux=True)(\n", "      params.fast, batch\n", "  )\n", "  updates, solver_state = solver.update(grad, solver_state, params)\n", "  params = optax.apply_updates(params, updates)\n", "  return params, solver_state, loss, aux\n", "\n", "\n", "for epoch in range(N_EPOCHS):\n", "  train_accuracy_epoch = []\n", "  train_losses_epoch = []\n", "\n", "  for step, train_batch in enumerate(train_loader_batched.as_numpy_iterator()):\n", "    params, solver_state, train_loss, train_aux = train_step(\n", "        params, solver_state, train_batch\n", "    )\n", "    train_accuracy_epoch.append(train_aux[\"accuracy\"])\n", "    train_losses_epoch.append(train_loss)\n", "    if step % 20 == 0:\n", "      print(\n", "          f\"step {step}, train loss: {train_loss:.2e}, train accuracy:\"\n", "          f\" {train_aux['accuracy']:.2f}\"\n", "      )\n", "\n", "  # Evaluation is done on the slow lookahead parameters.\n", "  test_stats = dataset_stats(params.slow, test_loader_batched)\n", "  test_accuracy.append(test_stats[\"accuracy\"])\n", "  test_losses.append(test_stats[\"loss\"])\n", "  train_accuracy.append(np.mean(train_accuracy_epoch))\n", "  train_losses.append(np.mean(train_losses_epoch))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "yyS1oRZBtytP" }, "outputs": [], "source": [ "f\"Improved accuracy on the test set from {test_accuracy[0]:.4f} to {test_accuracy[-1]:.4f}\"" ] } ], "metadata": { "colab": { "provenance": [] }, "execution": { "timeout": -1 }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }