{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "18AF5Ab4p6VL" }, "source": [ "##### Copyright 2018 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "crfqaJOyp8bq" }, "source": [ "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", "https://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "B_XlLLpcWjkA" }, "source": [ "# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n", "\n", "_Forked from_ `neural_network_and_data_loading.ipynb`\n", "\n", "_Dougal Maclaurin, Peter Hawkins, Matthew Johnson, Roy Frostig, Alex Wiltschko, Chris Leary_\n", "\n", "![JAX](https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png)\n", "\n", "Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/master/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for builidng our model." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "-8OFzj9TqXof" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting jaxlib\n", " Using cached https://files.pythonhosted.org/packages/06/af/c0d5f539820e97e8ec27f05a0ee50327fe34a35369e4e02ea45ce2a45c01/jaxlib-0.1.8-cp36-none-manylinux1_x86_64.whl\n", "Collecting scipy (from jaxlib)\n", " Using cached https://files.pythonhosted.org/packages/67/e6/6d4edaceee6a110ecf6f318482f5229792f143e468b34a631f5a0899f56d/scipy-1.2.0-cp36-cp36m-manylinux1_x86_64.whl\n", "Requirement already satisfied: protobuf>=3.6.0 in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from jaxlib) (3.6.1)\n", "Requirement already satisfied: six in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from jaxlib) (1.12.0)\n", "Requirement already satisfied: numpy>=1.12 in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from jaxlib) (1.16.1)\n", "Requirement already satisfied: absl-py in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from jaxlib) (0.7.0)\n", "Requirement already satisfied: setuptools in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from protobuf>=3.6.0->jaxlib) (40.8.0)\n", "Installing collected packages: scipy, jaxlib\n", "Successfully installed jaxlib-0.1.6 scipy-1.2.0\n", "Requirement already up-to-date: jax in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (0.1.16)\n", "Requirement already satisfied, skipping upgrade: six in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from jax) (1.12.0)\n", "Requirement already satisfied, skipping upgrade: opt-einsum in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from jax) (2.3.2)\n", "Requirement already satisfied, skipping upgrade: protobuf>=3.6.0 in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from jax) (3.6.1)\n", "Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from jax) (1.16.1)\n", "Requirement already satisfied, skipping upgrade: absl-py in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from jax) (0.7.0)\n", "Requirement already satisfied, skipping upgrade: setuptools in /usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages (from protobuf>=3.6.0->jax) (40.8.0)\n" ] } ], "source": [ "!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.13-cp36-none-linux_x86_64.whl\n", "!pip install --upgrade -q jax" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "OksHydJDtbbI" }, "outputs": [], "source": [ "from __future__ import print_function, division, absolute_import\n", "import jax.numpy as np\n", "from jax import grad, jit, vmap\n", "from jax import random" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MTVcKi-ZYB3R" }, "source": [ "### Hyperparameters\n", "Let's get a few bookkeeping items out of the way." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "-fmWA06xYE7d" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages/jax/lib/xla_bridge.py:146: UserWarning: No GPU found, falling back to CPU.\n", " warnings.warn('No GPU found, falling back to CPU.')\n" ] } ], "source": [ "# A helper function to randomly initialize weights and biases\n", "# for a dense neural network layer\n", "def random_layer_params(m, n, key, scale=1e-2):\n", " w_key, b_key = random.split(key)\n", " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", "\n", "# Initialize all layers for a fully-connected neural network with sizes \"sizes\"\n", "def init_network_params(sizes, key):\n", " keys = random.split(key, len(sizes))\n", " return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", "\n", "layer_sizes = [784, 512, 512, 10]\n", "param_scale = 0.1\n", "step_size = 0.0001\n", "num_epochs = 10\n", "batch_size = 128\n", "n_targets = 10\n", "params = init_network_params(layer_sizes, random.PRNGKey(0))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "BtoNk_yxWtIw" }, "source": [ "### Auto-batching predictions\n", "\n", "Let us first define our prediction function. Note that we're defining this for a _single_ image example. We're going to use JAX's `vmap` function to automatically handle mini-batches, with no performance penalty." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "id": "7APc6tD7TiuZ" }, "outputs": [], "source": [ "from jax.scipy.misc import logsumexp\n", "\n", "def relu(x):\n", " return np.maximum(0, x)\n", "\n", "def predict(params, image):\n", " # per-example predictions\n", " activations = image\n", " for w, b in params[:-1]:\n", " outputs = np.dot(w, activations) + b\n", " activations = relu(outputs)\n", " \n", " final_w, final_b = params[-1]\n", " logits = np.dot(final_w, activations) + final_b\n", " return logits - logsumexp(logits)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "dRW_TvCTWgaP" }, "source": [ "Let's check that our prediction function only works on single images." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "4sW2A5mnXHc5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10,)\n" ] } ], "source": [ "# This works on single examples\n", "random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))\n", "preds = predict(params, random_flattened_image)\n", "print(preds.shape)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": {}, "colab_type": "code", "id": "PpyQxuedXfhp" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Invalid shapes!\n" ] } ], "source": [ "# Doesn't work with a batch\n", "random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))\n", "try:\n", " preds = predict(params, random_flattened_images)\n", "except TypeError:\n", " print('Invalid shapes!')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": {}, "colab_type": "code", "id": "oJOOncKMXbwK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10, 10)\n" ] } ], "source": [ "# Let's upgrade it to handle batches using `vmap`\n", "\n", "# Make a batched version of the `predict` function\n", "batched_predict = vmap(predict, in_axes=(None, 0))\n", "\n", "# `batched_predict` has the same call signature as `predict`\n", "batched_preds = batched_predict(params, random_flattened_images)\n", "print(batched_preds.shape)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "elsG6nX03BvW" }, "source": [ "At this point, we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of `predict`, which we should be able to use in a loss function. We should be able to use `grad` to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use `jit` to speed up everything." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "NwDuFqc9X7ER" }, "source": [ "### Utility and loss functions" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", "id": "6lTI6I4lWdh5" }, "outputs": [], "source": [ "def one_hot(x, k, dtype=np.float32):\n", " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", " return np.array(x[:, None] == np.arange(k), dtype)\n", " \n", "def accuracy(params, images, targets):\n", " target_class = np.argmax(targets, axis=1)\n", " predicted_class = np.argmax(batched_predict(params, images), axis=1)\n", " return np.mean(predicted_class == target_class)\n", "\n", "def loss(params, images, targets):\n", " preds = batched_predict(params, images)\n", " return -np.sum(preds * targets)\n", "\n", "@jit\n", "def update(params, x, y):\n", " grads = grad(loss)(params, x, y)\n", " return [(w - step_size * dw, b - step_size * db)\n", " for (w, b), (dw, db) in zip(params, grads)]" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "umJJGZCC2oKl" }, "source": [ "### Data Loading with `tensorflow/datasets`\n", "\n", "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": {}, "colab_type": "code", "id": "gEvWt8_u2pqG" }, "outputs": [], "source": [ "# Install tensorflow-datasets\n", "# TODO(rsepassi): Switch to stable version on release\n", "!pip install -q --upgrade tfds-nightly tf-nightly" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import tensorflow_datasets as tfds\n", "\n", "data_dir = '/tmp/tfds'\n", "\n", "# Fetch full datasets for evaluation\n", "# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)\n", "# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy\n", "mnist_data, info = tfds.load(name=\"mnist\", batch_size=-1, data_dir=data_dir, with_info=True)\n", "mnist_data = tfds.as_numpy(mnist_data)\n", "train_data, test_data = mnist_data['train'], mnist_data['test']\n", "num_labels = info.features['label'].num_classes\n", "h, w, c = info.features['image'].shape\n", "num_pixels = h * w * c\n", "\n", "# Full train set\n", "train_images, train_labels = train_data['image'], train_data['label']\n", "train_images = np.reshape(train_images, (len(train_images), num_pixels))\n", "train_labels = one_hot(train_labels, num_labels)\n", "\n", "# Full test set\n", "test_images, test_labels = test_data['image'], test_data['label']\n", "test_images = np.reshape(test_images, (len(test_images), num_pixels))\n", "test_labels = one_hot(test_labels, num_labels)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train: (60000, 784) (60000, 10)\n", "Test: (10000, 784) (10000, 10)\n" ] } ], "source": [ "print('Train:', train_images.shape, train_labels.shape)\n", "print('Test:', test_images.shape, test_labels.shape)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xxPd6Qw3Z98v" }, "source": [ "### Training Loop" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": {}, "colab_type": "code", "id": "X2DnZo3iYj18" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0 in 4.93 sec\n", "Training set accuracy 0.9690666794776917\n", "Test set accuracy 0.9631999731063843\n", "Epoch 1 in 3.91 sec\n", "Training set accuracy 0.9807999730110168\n", "Test set accuracy 0.97079998254776\n", "Epoch 2 in 4.02 sec\n", "Training set accuracy 0.9878833293914795\n", "Test set accuracy 0.9763000011444092\n", "Epoch 3 in 4.03 sec\n", "Training set accuracy 0.992733359336853\n", "Test set accuracy 0.9787999987602234\n", "Epoch 4 in 3.95 sec\n", "Training set accuracy 0.9907500147819519\n", "Test set accuracy 0.9745000004768372\n", "Epoch 5 in 4.01 sec\n", "Training set accuracy 0.9953666925430298\n", "Test set accuracy 0.9782000184059143\n", "Epoch 6 in 3.90 sec\n", "Training set accuracy 0.9984833598136902\n", "Test set accuracy 0.9815000295639038\n", "Epoch 7 in 3.93 sec\n", "Training set accuracy 0.9991166591644287\n", "Test set accuracy 0.9824000000953674\n", "Epoch 8 in 4.16 sec\n", "Training set accuracy 0.999833345413208\n", "Test set accuracy 0.982200026512146\n", "Epoch 9 in 4.03 sec\n", "Training set accuracy 0.999916672706604\n", "Test set accuracy 0.9829999804496765\n" ] } ], "source": [ "import time\n", "\n", "def get_train_batches():\n", " # as_supervised=True gives us the (image, label) as a tuple instead of a dict\n", " ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)\n", " # You can build up an arbitrary 
tf.data input pipeline\n", "  ds = ds.batch(128).prefetch(1)\n", "  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays\n", "  return tfds.as_numpy(ds)\n", "\n", "for epoch in range(num_epochs):\n", "  start_time = time.time()\n", "  for x, y in get_train_batches():\n", "    x = np.reshape(x, (len(x), num_pixels))\n", "    y = one_hot(y, num_labels)\n", "    params = update(params, x, y)\n", "  epoch_time = time.time() - start_time\n", "\n", "  train_acc = accuracy(params, train_images, train_labels)\n", "  test_acc = accuracy(params, test_images, test_labels)\n", "  print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n", "  print(\"Training set accuracy {}\".format(train_acc))\n", "  print(\"Test set accuracy {}\".format(test_acc))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xC1CMcVNYwxm" }, "source": [ "We've now used the whole of the JAX API: `grad` for derivatives, `jit` for speedups, and `vmap` for auto-vectorization.\n", "We used NumPy to specify all of our computation, borrowed the great data loaders from `tensorflow/datasets`, and ran the whole thing on the GPU (falling back to the CPU when no GPU is available)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "neural-network-and-data-loading.ipynb", "provenance": [], "toc_visible": true, "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 1 }