{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "VuJvL5osUlcB" }, "source": [ "# Breaking ADAM\n", "\n", "In case you haven't heard, one of the top papers at [ICLR 2018](https://iclr.cc/Conferences/2018) (pronounced:\n", "eye-clear, who knew?) was [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ). In the\n", "paper, the authors identify a flaw in the convergence proof of the ubiquitous ADAM optimizer. They also give an example\n", "of a simple function for which ADAM does not converge to the correct solution. We've seen before how torchbearer can be used\n", "for simple function optimisation, and we can do something similar to reproduce the results\n", "from the paper. Note that this isn't a suggestion that you should necessarily use AMSGrad (the fix proposed in the paper): most work still uses either SGD with momentum or vanilla Adam, and other approaches, such as [AdamW](https://arxiv.org/abs/1711.05101), have been proposed which may provide benefits. Our intention here is simply to show how this interesting failure case can be demonstrated with torchbearer.\n", "\n", "**Note**: The easiest way to use this tutorial is as a Colab notebook, which allows you to dive in with no setup. We recommend you enable a free GPU with\n", "\n", "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n", "\n", "## Install Torchbearer\n", "\n", "First we install torchbearer if needed." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4.0.dev\n" ] } ], "source": [ "try:\n", "    import torchbearer\n", "except ImportError:\n", "    !pip install -q torchbearer\n", "    import torchbearer\n", "\n", "print(torchbearer.__version__)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "c8usgF33UuCu" }, "source": [ "Online Optimization\n", "-----------------------------------\n", "\n", "Online learning means learning from one example at a time, in sequence. The function given in the paper\n", "has a unique minimum at $x=-1$ and is defined as follows:\n", "\n", "\\begin{equation}\n", "f_t(x) = \\begin{cases}1010x, & \\text{for } t \\; \\texttt{mod} \\; 101 = 1 \\\\ -10x, & \\text{otherwise}\\end{cases}\n", "\\end{equation}\n", "\n", "Over each period of 101 steps the gradients sum to $1010 - 100 \\times 10 = 10 > 0$, so the average loss on the feasible set $[-1, 1]$ is minimised at $x=-1$.\n", "\n", "We can then write this as a PyTorch model whose forward is a function of its parameters with the following:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "X-g1QIZpU3dZ" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "class Online(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # the single scalar parameter we are optimising\n", "        self.x = nn.Parameter(torch.zeros(1))\n", "\n", "    def forward(self, _, state):\n", "        \"\"\"\n", "        function to be minimised:\n", "        f(x) = 1010x if t mod 101 = 1, else -10x\n", "        \"\"\"\n", "        # state[torchbearer.BATCH] holds the index of the current training step\n", "        if state[torchbearer.BATCH] % 101 == 1:\n", "            res = 1010 * self.x\n", "        else:\n", "            res = -10 * self.x\n", "\n", "        return res" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8QMiqLWKU4ml" }, "source": [ "We now define a loss (which simply returns the model output) and a metric which returns the current value of our parameter `x`:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "gTeTSbbcU7qD" }, "outputs": [], "source": [ "def loss(y_pred, _):\n", "    return y_pred\n", "\n", "\n", 
"@torchbearer.metrics.to_dict\n", "class est(torchbearer.metrics.Metric):\n", "    def __init__(self):\n", "        super().__init__('est')\n", "\n", "    def process(self, state):\n", "        return state[torchbearer.MODEL].x.data.item()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0ni3L0jrU9Z6" }, "source": [ "In the paper, `x` is constrained to the interval `[-1, 1]`. We don't strictly need to enforce this, but we can write\n", "a callback that greedily clips `x` back into this range after each training step:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "ntahEfxTVCZ9" }, "outputs": [], "source": [ "@torchbearer.callbacks.on_step_training\n", "def greedy_update(state):\n", "    # project x back onto [-1, 1] after each optimiser step\n", "    if state[torchbearer.MODEL].x > 1:\n", "        state[torchbearer.MODEL].x.data.fill_(1)\n", "    elif state[torchbearer.MODEL].x < -1:\n", "        state[torchbearer.MODEL].x.data.fill_(-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ideally, we would like to generate the graphs from the paper. To do this, we can use the TensorBoard callback from torchbearer. However, we might also want to see a live graph directly in the notebook. For that, we can use a callback which runs at the end of each training step." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "try:\n", "    import google.colab\n", "    IN_COLAB = True\n", "except ImportError:\n", "    IN_COLAB = False\n", "\n", "# The interactive `notebook` backend doesn't work properly in Colab, so fall back to static inline plots there\n", "if IN_COLAB:\n", "    %matplotlib inline\n", "else:\n", "    %matplotlib notebook\n", "\n", "from matplotlib import pyplot as plt\n", "\n", "class Plotter():\n", "    def __init__(self):\n", "        # each entry is (steps, estimates, colour, label)\n", "        self.plots = []\n", "        self.fig = plt.figure(figsize=(5, 5))\n", "\n", "    def make_plotter(self, name, c='k', step_size=100):\n", "        idx = len(self.plots)\n", "        self.plots.append(([], [], c, name))\n", "\n", "        # record the current estimate of x every step_size steps\n", "        @torchbearer.callbacks.on_step_training\n", "        @torchbearer.callbacks.only_if(lambda state: state[torchbearer.BATCH] % step_size == 0)\n", "        def store(state):\n", "            self.plots[idx][0].append(len(self.plots[idx][0]) * step_size)\n", "            self.plots[idx][1].append(state[torchbearer.METRICS]['est'])\n", "\n", "        # redraw every stored series every 2 * step_size steps\n", "        @torchbearer.callbacks.on_step_training\n", "        @torchbearer.callbacks.only_if(lambda state: state[torchbearer.BATCH] % (2 * step_size) == 0)\n", "        def plot(state):\n", "            plt.clf()\n", "            for steps, estimates, colour, label in self.plots:\n", "                plt.plot(steps, estimates, colour, label=label)\n", "            plt.legend()\n", "            plt.xlabel('Step')\n", "            plt.ylabel('Estimate')\n", "            self.fig.canvas.draw()\n", "\n", "        return torchbearer.callbacks.CallbackList([store, plot])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "yzS_TWlHVCw9" }, "source": [ "The callbacks returned by `make_plotter` record the estimate of `x` every `step_size` steps and redraw the figure every `2 * step_size` steps, although in Colab the graph will only be shown at the end. 
Finally, we can train on the stochastic version of this problem from the paper, in which $f_t(x) = 1010x$ with probability $0.01$ and $f_t(x) = -10x$ otherwise; the expected gradient is $0.01 \\times 1010 - 0.99 \\times 10 = 0.2 > 0$, so the optimum is still $x=-1$. We train twice, once with ADAM and once with AMSGrad (included in PyTorch as the `amsgrad=True` option of `torch.optim.Adam`), with just a few\n", "lines of code (this will take at least a few minutes on a GPU):" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "id": "WibotVw8VEKg" }, "outputs": [], "source": [ "import importlib.util\n", "import random\n", "\n", "import torch\n", "\n", "# Use torchbearer's TensorBoard callback if tensorboardX is installed; otherwise replace it with a do-nothing mock\n", "if importlib.util.find_spec('tensorboardX') is not None:\n", "    from torchbearer.callbacks import TensorBoard\n", "else:\n", "    from unittest import mock\n", "    TensorBoard = mock.MagicMock()\n", "\n", "\n", "class Stochastic(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.x = nn.Parameter(torch.zeros(1))\n", "\n", "    def forward(self, _):\n", "        \"\"\"\n", "        stochastic setting from the paper:\n", "        f(x) = 1010x with probability 0.01, else -10x\n", "        \"\"\"\n", "        if random.random() < 0.01:\n", "            res = 1010 * self.x\n", "        else:\n", "            res = -10 * self.x\n", "\n", "        return res\n", "\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "training_steps = 100000\n", "\n", "plt.ion()\n", "plotter = Plotter()\n", "\n", "# First run: vanilla Adam\n", "model = Stochastic()\n", "optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.99])\n", "tbtrial = torchbearer.Trial(model, optim, loss, [est()],\n", "                            callbacks=[\n", "                                greedy_update,\n", "                                plotter.make_plotter('Adam', c='g', step_size=1000),\n", "                                TensorBoard(comment='adam', write_graph=False, write_batch_metrics=True, write_epoch_metrics=False)\n", "                            ]).to(device)\n", "tbtrial.for_train_steps(training_steps).run(verbose=0)\n", "\n", "# Second run: identical settings, but with the AMSGrad variant enabled\n", "model = Stochastic()\n", "optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.99], amsgrad=True)\n", "tbtrial = torchbearer.Trial(model, optim, loss, [est()],\n", "                            callbacks=[\n", "                                greedy_update,\n", "                                plotter.make_plotter('AMSGrad', c='r', step_size=1000),\n", "                                TensorBoard(comment='amsgrad', write_graph=False, write_batch_metrics=True, 
write_epoch_metrics=False)\n", "                            ]).to(device)\n", "tbtrial.for_train_steps(training_steps).run(verbose=0)\n", "plt.ioff()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4YrgIproWZCK" }, "source": [ "After execution has finished, running `tensorboard --logdir logs` (this requires tensorboardX, since otherwise the `TensorBoard` callbacks above are mocked out and nothing is logged) and navigating to\n", "[localhost:6006](http://localhost:6006), we see a graph similar to that of the stochastic setting in Figure 1 of\n", "the paper, where the top line is with ADAM and the bottom with AMSGrad:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", "id": "_weTn-llXB1q" }, "outputs": [], "source": [ "from IPython.display import Image\n", "Image('https://raw.githubusercontent.com/ecs-vlc/torchbearer/master/docs/_static/img/ams_grad_stochastic.png', width=500)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Zad9d8_NXGQ2" }, "source": [ "Conclusions\n", "------------------------------------\n", "\n", "So, whatever your thoughts on the AMSGrad optimizer in practice, it's probably the sign of a good paper that you can\n", "re-implement the example and get very similar results without having to try too hard and, thanks to torchbearer, with only\n", "a small amount of code. 
The paper includes some more complex, 'real-world' examples; can you re-implement those?\n" ] } ], "metadata": { "colab": { "name": "amsgrad.ipynb", "provenance": [], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 1 }