{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "2To_0Q86toj4" }, "source": [ "# Optimistic Gradient Descent in a Bilinear Min-Max Problem\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/ogda_example.ipynb)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MMvXmgsvTmcl" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import optax\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": { "id": "VqpOVQIgTuJ0" }, "source": [ "Consider the following min-max problem:\n", "\n", "$$\n", "\\min_{x \\in \\mathbb R^m} \\max_{y\\in\\mathbb R^n} f(x,y),\n", "$$\n", "\n", "where $f: \\mathbb R^m \\times \\mathbb R^n \\to \\mathbb R$ is a convex-concave function. The solution to such a problem is a saddle-point $(x^\\star, y^\\star)\\in \\mathbb R^m \\times \\mathbb R^n$ such that\n", "\n", "$$\n", "f(x^\\star, y) \\leq f(x^\\star, y^\\star) \\leq f(x, y^\\star).\n", "$$\n", "\n", "Standard gradient descent-ascent (GDA) updates $x$ and $y$ according to the following update rule at step $k$: \n", "\n", "$$\n", "x_{k+1} = x_k - \\eta_k \\nabla_x f(x_k, y_k) \\\\\n", "y_{k+1} = y_k + \\eta_k \\nabla_y f(x_k, y_k),\n", "$$\n", "\n", "where $\\eta_k$ is a step size. However, it's well-documented that GDA can fail to converge in this setting. This is an important issue because gradient-based min-max optimisation is increasingly prevalent in machine learning (e.g., GANs, constrained RL). *Optimistic* GDA (OGDA) addresses this shortcoming by introducing a form of memory-based negative momentum: \n", "\n", "$$\n", "x_{k+1} = x_k - 2 \\eta_k \\nabla_x f(x_k, y_k) + \\eta_k \\nabla_x f(x_{k-1}, y_{k-1}) \\\\\n", "y_{k+1} = y_k + 2 \\eta_k \\nabla_y f(x_k, y_k) - \\eta_k \\nabla_y f(x_{k-1}, y_{k-1})).\n", "$$\n", "\n", "Thus, to implement OGD (or OGA), the optimiser needs to keep track of the gradient from the previous step. OGDA has been formally shown to converge to the optimum $(x_k, y_k) \\to (x^\\star, y^\\star)$ in this setting. The generalised form of the OGDA update rule is given by\n", "\n", "$$\n", "x_{k+1} = x_k - (\\alpha + \\beta) \\eta_k \\nabla_x f(x_k, y_k) + \\beta \\eta_k \\nabla_x f(x_{k-1}, y_{k-1}) \\\\\n", "y_{k+1} = y_k + (\\alpha + \\beta) \\eta_k \\nabla_y f(x_k, y_k) - \\beta \\eta_k \\nabla_y f(x_{k-1}, y_{k-1})),\n", "$$\n", "\n", "which recovers standard OGDA when $\\alpha=\\beta=1$. See [Mokhtari et al., 2019](https://arxiv.org/abs/1901.08511v2) for more details. " ] }, { "cell_type": "markdown", "metadata": { "id": "WxcGo14yWhWn" }, "source": [ "$$\n", "\\pi^{k+1} = \\pi^k - \\tau_\\pi^k \\nabla_\\pi \\mathcal L(\\pi^k, \\mu^k) \\\\\n", "\\mu^{k+1} = \\mu^k + \\tau_\\mu^k \\nabla_\\mu \\mathcal L(\\pi^k_k, \\mu^k),\n", "$$\n", "\n", "$$\n", "\\pi^{k+1} = \\pi^k - 2\\tau_\\pi^k \\nabla_\\pi \\mathcal L(\\pi^k, \\mu^k) + \\tau_\\pi^k \\nabla_\\pi \\mathcal L(\\pi^{k-1}, \\mu^{k-1})\\\\\n", "\\mu^{k+1} = \\mu^k + 2\\tau_\\mu^k \\nabla_\\mu \\mathcal L(\\pi^k_k, \\mu^k)+ \\tau_\\mu^k \\nabla_\\mu \\mathcal L(\\pi^{k-1}, \\mu^{k-1})\n", "$$\n", "\n", "where $\\eta_k$ is a step size. However, it's well-documented that GDA can fail to converge in this setting. This is an important issue because gradient-based min-max optimisation is increasingly prevalent in machine learning (e.g., GANs, constrained RL). 
{ "cell_type": "markdown", "metadata": { "id": "nSyJyTSXszQ0" }, "source": [ "Define a bilinear min-max objective function: $\\min_x \\max_y xy$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "snDy575-iDXw" }, "outputs": [], "source": [ "def f(params: optax.Params) -> jnp.ndarray:\n", "  \"\"\"Objective: min_x max_y xy.\"\"\"\n", "  return params[\"x\"] * params[\"y\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "G-4JMKlgs-Lr" }, "source": [ "Define an optimisation loop." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MXXxtGs0qlfy" }, "outputs": [], "source": [ "def optimise(\n", "    params: optax.Params,\n", "    x_optimiser: optax.GradientTransformation,\n", "    y_optimiser: optax.GradientTransformation,\n", "    n_steps: int = 1000,\n", "    display_every: int = 100,\n", ") -> tuple[list[optax.Params], list[jnp.ndarray]]:\n", "  \"\"\"An optimisation loop that minimises f over x and maximises it over y.\"\"\"\n", "\n", "  x_opt_state = x_optimiser.init(params[\"x\"])\n", "  y_opt_state = y_optimiser.init(params[\"y\"])\n", "  param_hist = [params]\n", "  f_hist = []\n", "\n", "  @jax.jit\n", "  def step(params, x_opt_state, y_opt_state):\n", "    f_value, grads = jax.value_and_grad(f)(params)\n", "    x_update, x_opt_state = x_optimiser.update(grads[\"x\"], x_opt_state, params[\"x\"])\n", "    # Note: we're maximising over y, so we feed the negative gradient into its optimiser.\n", "    y_update, y_opt_state = y_optimiser.update(-grads[\"y\"], y_opt_state, params[\"y\"])\n", "    updates = {\"x\": x_update, \"y\": y_update}\n", "    params = optax.apply_updates(params, updates)\n", "    return params, x_opt_state, y_opt_state, f_value\n", "\n", "  for k in range(n_steps):\n", "    params, x_opt_state, y_opt_state, f_value = step(params, x_opt_state, y_opt_state)\n", "    param_hist.append(params)\n", "    f_hist.append(f_value)\n", "    if k % display_every == 0:\n", "      print(f\"step {k}, f(x, y) = {f_value}, (x, y) = ({params['x']}, {params['y']})\")\n", "\n", "  return param_hist, f_hist" ] }, { "cell_type": "markdown", "metadata": { "id": "gDtB7gJdtPZj" }, "source": [ "Initialise $x$ and $y$, as well as an optimiser for each." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JvlhEUMat1PX" }, "outputs": [], "source": [ "initial_params = {\n", "    \"x\": jnp.array(1.0),\n", "    \"y\": jnp.array(1.0),\n", "}\n", "\n", "# GDA\n", "x_gd_optimiser = optax.sgd(learning_rate=0.1)\n", "y_ga_optimiser = optax.sgd(learning_rate=0.1)\n", "\n", "# OGDA\n", "x_ogd_optimiser = optax.optimistic_gradient_descent(learning_rate=0.1)\n", "y_oga_optimiser = optax.optimistic_gradient_descent(learning_rate=0.1)" ] }, { "cell_type": "markdown", "metadata": { "id": "DF8oQEjLRO3a" }, "source": [ "Run each method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E-0_CtpjuEGi" }, "outputs": [], "source": [ "gda_hist, gda_f_hist = optimise(initial_params, x_gd_optimiser, y_ga_optimiser)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wBeMQEILwKsJ" }, "outputs": [], "source": [ "ogda_hist, ogda_f_hist = optimise(initial_params, x_ogd_optimiser, y_oga_optimiser)" ] },
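{ "cell_type": "markdown", "metadata": {}, "source": [ "Before plotting, a quick numerical sanity check: compare how far each method's final iterate is from the saddle point $(0, 0)$. The helper `distance_to_saddle` below is purely illustrative. With the settings above, GDA's iterates are expected to drift away from the origin, while OGDA's should approach it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def distance_to_saddle(params):\n", "  \"\"\"Euclidean distance of an iterate from the saddle point (0, 0).\"\"\"\n", "  return jnp.sqrt(params[\"x\"] ** 2 + params[\"y\"] ** 2)\n", "\n", "\n", "print(\"initial:   \", distance_to_saddle(initial_params))\n", "print(\"GDA final: \", distance_to_saddle(gda_hist[-1]))\n", "print(\"OGDA final:\", distance_to_saddle(ogda_hist[-1]))" ] },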
{ "cell_type": "markdown", "metadata": { "id": "S504XrZrtXNe" }, "source": [ "Visualise the optimisation trajectories. The optimal solution is the saddle point $(x^\\star, y^\\star) = (0, 0)$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "S-XvwF9HujRT" }, "outputs": [], "source": [ "gda_xs, gda_ys = [p[\"x\"] for p in gda_hist], [p[\"y\"] for p in gda_hist]\n", "ogda_xs, ogda_ys = [p[\"x\"] for p in ogda_hist], [p[\"y\"] for p in ogda_hist]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WDgiytoDvX8N" }, "outputs": [], "source": [ "plt.plot(gda_xs, gda_ys, alpha=0.6, color=\"C0\", label=\"GDA\")\n", "plt.plot(ogda_xs, ogda_ys, alpha=0.6, color=\"C1\", label=\"OGDA\")\n", "plt.scatter([1], [1], color=\"r\", label=r\"$(x_0, y_0)$\", s=30)\n", "plt.scatter([0], [0], color=\"k\", label=r\"$(x^\\star, y^\\star)$\", s=30)\n", "plt.xlim([-2.0, 2.0])\n", "plt.ylim([-2.0, 2.0])\n", "plt.legend(loc=\"lower right\")\n", "plt.show()" ] } ], "metadata": { "colab": { "collapsed_sections": [], "last_runtime": { "build_target": "", "kind": "private" }, "name": "ogda_example.ipynb", "private_outputs": true, "provenance": [ { "file_id": "15hB4sFTdcHM7tf7l03PHZdwaP7AfdI8K", "timestamp": 1658099029490 }, { "file_id": "1Orjeh6PdEz2Vuj_XGGAStvkyck_zzT72", "timestamp": 1658096315530 }, { "file_id": "1v26CV4ivf38ZnYyd18PDnRFX4WMi-MX9", "timestamp": 1657803636509 } ] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }