{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "uJHywE_oL3j2" }, "source": [ "# Differentially private convolutional neural network on MNIST.\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/differentially_private_sgd.ipynb)\n", "\n", "A large portion of this code is forked from the differentially private SGD\n", "example in the [JAX repo](\n", "https://github.com/google/jax/blob/main/examples/differentially_private_sgd.py).\n", "\n", "[Differentially Private Stochastic Gradient Descent](https://arxiv.org/abs/1607.00133) requires clipping the per-example parameter\n", "gradients, which is non-trivial to implement efficiently for convolutional\n", "neural networks. The JAX XLA compiler shines in this setting by optimizing the\n", "minibatch-vectorized computation for convolutional architectures. Training\n", "takes a few seconds per epoch on a commodity GPU."
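, "\n", "\n", "As a rough sketch of the clip-and-noise step described in the paper linked above, with symbols introduced here only for illustration: $g_i$ is the gradient of example $i$, $C$ plays the role of `L2_NORM_CLIP`, $\\sigma$ of `NOISE_MULTIPLIER` and $B$ of `BATCH_SIZE`. The exact aggregation performed by the optimizer may differ in details.\n", "\n", "$$\\bar{g}_i = g_i \\cdot \\min\\left(1, \\frac{C}{\\lVert g_i \\rVert_2}\\right), \\qquad \\tilde{g} = \\frac{1}{B}\\left(\\sum_{i=1}^{B} \\bar{g}_i + \\mathcal{N}\\left(0, \\sigma^2 C^2 I\\right)\\right)$$\n", "\n", "Each clipped gradient $\\bar{g}_i$ has norm at most $C$, which bounds the influence of any single example and is why the gradients must be computed per example rather than per batch."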
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VaYIiCnjL3j3" }, "outputs": [], "source": [ "import warnings\n", "import dp_accounting\n", "import jax\n", "import jax.numpy as jnp\n", "from optax import contrib\n", "from optax import losses\n", "import optax\n", "from jax.example_libraries import stax\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "import matplotlib.pyplot as plt\n", "\n", "# Shows on which platform JAX is running.\n", "print(\"JAX running on\", jax.devices()[0].platform.upper())" ] }, { "cell_type": "markdown", "metadata": { "id": "t7Dn8L_Uw0Yb" }, "source": [ "This table contains hyperparameters and the corresponding expected test accuracy.\n", "\n", "\n", "| DPSGD | LEARNING_RATE | NOISE_MULTIPLIER | L2_NORM_CLIP | BATCH_SIZE | NUM_EPOCHS | DELTA | FINAL TEST ACCURACY |\n", "| ------ | ------------- | ---------------- | ------------ | ---------- | ---------- | ----- | ------------------- |\n", "| False | 0.1 | NA | NA | 256 | 20 | NA | ~99% |\n", "| True | 0.25 | 1.3 | 1.5 | 256 | 15 | 1e-5 | ~95% |\n", "| True | 0.15 | 1.1 | 1.0 | 256 | 60 | 1e-5 | ~96.6% |\n", "| True | 0.25 | 0.7 | 1.5 | 256 | 45 | 1e-5 | ~97% |" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jve2h810L3j3" }, "outputs": [], "source": [ "# Whether to use DP-SGD or vanilla SGD:\n", "DPSGD = True\n", "# Learning rate for the optimizer:\n", "LEARNING_RATE = 0.25\n", "# Noise multiplier for DP-SGD optimizer:\n", "NOISE_MULTIPLIER = 1.3\n", "# L2 norm clip:\n", "L2_NORM_CLIP = 1.5\n", "# Number of samples in each batch:\n", "BATCH_SIZE = 256\n", "# Number of epochs:\n", "NUM_EPOCHS = 15\n", "# Probability of information leakage:\n", "DELTA = 1e-5" ] }, { "cell_type": "markdown", "metadata": { "id": "iLGeV4y4DBkL" }, "source": [ "MNIST is composed of 28x28 grayscale images (a single channel) of handwritten digits. We'll now load the dataset using `tensorflow_datasets`, scale the pixel values to [0, 1], and batch the training split."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zynvtk4wDBkL" }, "outputs": [], "source": [ "(train_loader, test_loader), info = tfds.load(\n", " \"mnist\", split=[\"train\", \"test\"], as_supervised=True, with_info=True\n", ")\n", "\n", "min_max_rgb = lambda image, label: (tf.cast(image, tf.float32) / 255., label)\n", "train_loader = train_loader.map(min_max_rgb)\n", "test_loader = test_loader.map(min_max_rgb)\n", "\n", "train_loader_batched = train_loader.shuffle(\n", " buffer_size=10_000, reshuffle_each_iteration=True\n", ").batch(BATCH_SIZE, drop_remainder=True)\n", "\n", "NUM_EXAMPLES = info.splits[\"test\"].num_examples\n", "test_batch = next(test_loader.batch(NUM_EXAMPLES, drop_remainder=True).as_numpy_iterator())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o6In7oQ-0EhG" }, "outputs": [], "source": [ "init_random_params, predict = stax.serial(\n", " stax.Conv(16, (8, 8), padding=\"SAME\", strides=(2, 2)),\n", " stax.Relu,\n", " stax.MaxPool((2, 2), (1, 1)),\n", " stax.Conv(32, (4, 4), padding=\"VALID\", strides=(2, 2)),\n", " stax.Relu,\n", " stax.MaxPool((2, 2), (1, 1)),\n", " stax.Flatten,\n", " stax.Dense(32),\n", " stax.Relu,\n", " stax.Dense(10),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "j2OUgc6J0Jsl" }, "source": [ "This function computes the privacy parameter epsilon for the given number of steps and probability of information leakage `DELTA`." 
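, "\n", "\n", "As a sketch of the quantities the accountant works with, assuming batches are modeled as Poisson samples of the training set (the symbols are introduced here for illustration): with $N$ training examples, batch size $B$ and $E$ completed epochs, the sampling probability $q$ and the number of composed steps $T$ are\n", "\n", "$$q = \\frac{B}{N}, \\qquad T = E \\left\\lfloor \\frac{N}{B} \\right\\rfloor$$\n", "\n", "For example, assuming the 60,000-image MNIST training split, $B = 256$ and $E = 15$ give $q \\approx 0.0043$ and $T = 15 \\times 234 = 3510$ steps."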
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "43177TofzuOa" }, "outputs": [], "source": [ "# Privacy accounting is based on the training split, since that is the\n", "# population each batch is drawn from.\n", "NUM_TRAIN_EXAMPLES = info.splits[\"train\"].num_examples\n", "\n", "\n", "def compute_epsilon(steps):\n", "  if NUM_TRAIN_EXAMPLES * DELTA \u003e 1.:\n", "    warnings.warn(\"Your delta might be too high.\")\n", "  # Sampling probability of a single training example.\n", "  q = BATCH_SIZE / float(NUM_TRAIN_EXAMPLES)\n", "  orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))\n", "  accountant = dp_accounting.rdp.RdpAccountant(orders)\n", "  accountant.compose(dp_accounting.PoissonSampledDpEvent(\n", "      q, dp_accounting.GaussianDpEvent(NOISE_MULTIPLIER)), steps)\n", "  return accountant.get_epsilon(DELTA)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "W9mPtPvB0D3X" }, "outputs": [], "source": [ "@jax.jit\n", "def loss_fn(params, batch):\n", "  images, labels = batch\n", "  logits = predict(params, images)\n", "  return losses.softmax_cross_entropy_with_integer_labels(logits, labels).mean(), logits\n", "\n", "\n", "@jax.jit\n", "def test_step(params, batch):\n", "  images, labels = batch\n", "  logits = predict(params, images)\n", "  loss = losses.softmax_cross_entropy_with_integer_labels(logits, labels).mean()\n", "  accuracy = (logits.argmax(1) == labels).mean()\n", "  return loss, accuracy * 100" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "vOet-_860ysL" }, "outputs": [], "source": [ "if DPSGD:\n", "  tx = contrib.dpsgd(\n", "      learning_rate=LEARNING_RATE, l2_norm_clip=L2_NORM_CLIP,\n", "      noise_multiplier=NOISE_MULTIPLIER, seed=1337)\n", "else:\n", "  tx = optax.sgd(learning_rate=LEARNING_RATE)\n", "\n", "_, params = init_random_params(jax.random.PRNGKey(1337), (-1, 28, 28, 1))\n", "opt_state = tx.init(params)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "b-NmP7g01EdA" }, "outputs": [], "source": [ "@jax.jit\n", "def train_step(params, opt_state, batch):\n", "  grad_fn = jax.grad(loss_fn, has_aux=True)\n", "  if DPSGD:\n", "    # Inserts a dimension in axis 1 to use jax.vmap over the batch.\n", "    batch = jax.tree_util.tree_map(lambda x: x[:, None], batch)\n", "    # Uses jax.vmap across the batch to extract per-example gradients.\n", "    grad_fn = jax.vmap(grad_fn, in_axes=(None, 0))\n", "\n", "  grads, _ = grad_fn(params, batch)\n", "  updates, new_opt_state = tx.update(grads, opt_state, params)\n", "  new_params = optax.apply_updates(params, updates)\n", "  return new_params, new_opt_state" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "QMl9dnbJ1OtQ" }, "outputs": [], "source": [ "accuracy, loss, epsilon = [], [], []\n", "\n", "for epoch in range(NUM_EPOCHS):\n", "  for batch in train_loader_batched.as_numpy_iterator():\n", "    params, opt_state = train_step(params, opt_state, batch)\n", "\n", "  # Evaluates test accuracy.\n", "  test_loss, test_acc = test_step(params, test_batch)\n", "  accuracy.append(test_acc)\n", "  loss.append(test_loss)\n", "  print(f\"Epoch {epoch + 1}/{NUM_EPOCHS}, test accuracy: {test_acc}\")\n", "\n", "  # Computes the privacy budget spent after this epoch.\n", "  if DPSGD:\n", "    steps = (1 + epoch) * (NUM_TRAIN_EXAMPLES // BATCH_SIZE)\n", "    eps = compute_epsilon(steps)\n", "    epsilon.append(eps)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "9nsV-9_b2qca" }, "outputs": [], "source": [ "if DPSGD:\n", "  _, axs = plt.subplots(ncols=3, figsize=(9, 3))\n", "else:\n", "  _, axs = plt.subplots(ncols=2, figsize=(6, 3))\n", "\n", "axs[0].plot(accuracy)\n", "axs[0].set_title(\"Test accuracy\")\n", "axs[1].plot(loss)\n", "axs[1].set_title(\"Test loss\")\n", "\n", "if DPSGD:\n", "  axs[2].plot(epsilon)\n", "  axs[2].set_title(\"Epsilon\")\n", "\n", "plt.tight_layout()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "1ubOEWod3OPj" }, "outputs": [], "source": [ "print(f'Final accuracy: {accuracy[-1]}')" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/grp/tools/ml_python:ml_notebook", "kind": "private" }, "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name":
"python" } }, "nbformat": 4, "nbformat_minor": 0 }