{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Monitoring optimisation\n", "\n", "In this notebook we'll demo how to use `gpflow.training.monitor` for logging the optimisation of a GPflow model." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": false }, "outputs": [], "source": [ "import itertools\n", "import os\n", "import numpy as np\n", "import gpflow\n", "import gpflow.training.monitor as mon\n", "import numbers\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating the GPflow model\n", "We first generate some random data and create a GPflow model.\n", "\n", "Under the hood, GPflow gives a unique name to each model which is used to name the Variables it creates in the TensorFlow graph containing a random identifier. This is useful in interactive sessions, where people may create a few models, to prevent variables with the same name conflicting. However, when loading the model, we need to make sure that the names of all the variables are exactly the same as in the checkpoint. This is why we pass name=\"SVGP\" to the model constructor, and why we use gpflow.defer_build()." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "np.random.seed(0)\n", "X = np.random.rand(10000, 1) * 10\n", "Y = np.sin(X) + np.random.randn(*X.shape)\n", "Xt = np.random.rand(10000, 1) * 10\n", "Yt = np.sin(Xt) + np.random.randn(*Xt.shape)\n", "\n", "with gpflow.defer_build():\n", " m = gpflow.models.SVGP(X, Y, gpflow.kernels.RBF(1), gpflow.likelihoods.Gaussian(),\n", " Z=np.linspace(0, 10, 5)[:, None],\n", " minibatch_size=100, name=\"SVGP\")\n", " m.likelihood.variance = 0.01\n", "m.compile()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's compute log likelihood before the optimisation." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LML before the optimisation: -1271605.621944\n" ] } ], "source": [ "print('LML before the optimisation: %f' % m.compute_log_likelihood())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will be using a TensorFlow optimiser. All TensorFlow optimisers have a support for `global_step` variable. Its purpose is to track how many optimisation steps have occurred. It is useful to keep this in a TensorFlow variable as this allows it to be restored together with all the parameters of the model.\n", "\n", "The code below creates this variable using a monitor's helper function. It is important to create it before building the monitor in case the monitor includes a checkpoint task. This is because the checkpoint internally uses the TensorFlow Saver which creates a list of variables to save. Therefore all variables expected to be saved by the checkpoint task should exist by the time the task is created." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "session = m.enquire_session()\n", "global_step = mon.create_global_step(session)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need to create our optimiser before building the monitor to make sure that it is restored correctly. Adam, for example, keeps track of certain gradient moments that have been accumulated over time. Momentum is another example of a state in the optimiser that may need to be restored. 
We must also call `minimize` at this point to create and initialise all the variables that the optimiser uses. We run it for zero iterations, so no actual optimisation is done. This is a slight hack." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "optimiser = gpflow.train.AdamOptimizer(0.01)\n", "# maxiter=0 performs no optimisation steps but creates the optimiser's variables\n", "optimiser.minimize(m, maxiter=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Construct the monitor\n", "\n", "Next we construct the monitor. `gpflow.training.monitor` provides classes that serve as building blocks for the monitor. Essentially, a monitor is a function that is provided as a callback to an optimiser. It consists of a number of tasks that may be executed at each step, subject to their running conditions.\n", "\n", "In this example, we want to:\n", "- log certain scalar parameters to TensorBoard,\n", "- log the full optimisation objective (log marginal likelihood bound) periodically, even though we optimise with minibatches,\n", "- store a backup of the optimisation process periodically,\n", "- log performance on a test set periodically.\n", "\n", "We define these tasks as follows:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "print_task = mon.PrintTimingsTask().with_name('print')\\\n", "    .with_condition(mon.PeriodicIterationCondition(10))\\\n", "    .with_exit_condition(True)\n", "\n", "# A task that simply sleeps briefly at every iteration\n", "sleep_task = mon.SleepTask(0.01).with_name('sleep')\n", "\n", "saver_task = mon.CheckpointTask('./monitor-saves').with_name('saver')\\\n", "    .with_condition(mon.PeriodicIterationCondition(10))\\\n", "    .with_exit_condition(True)\n", "\n", "file_writer = mon.LogdirWriter('./model-tensorboard')\n", "\n", "model_tboard_task = mon.ModelToTensorBoardTask(file_writer, m).with_name('model_tboard')\\\n", "    .with_condition(mon.PeriodicIterationCondition(10))\\\n", "    .with_exit_condition(True)\n", "\n", "lml_tboard_task = mon.LmlToTensorBoardTask(file_writer, m).with_name('lml_tboard')\\\n", "    .with_condition(mon.PeriodicIterationCondition(100))\\\n", "    .with_exit_condition(True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the above code shows, each task can be assigned a name and running conditions. The name will be shown in the task timing summary.\n", "\n", "There are two different types of running conditions: `with_condition` controls execution of the task at each iteration of the optimisation loop, while `with_exit_condition` is a simple boolean flag indicating that the task should also run at the end of optimisation.\n", "In this example we want to run our tasks periodically: at every iteration, or at every 10th or 100th iteration.\n", "\n", "Notice that the two TensorBoard tasks write events into the same file: it is possible to share a file writer between multiple tasks. However, it is not possible to share the same event location between multiple file writers; an attempt to open two writers with the same location will result in an error.\n" ] },
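{ "cell_type": "markdown", "metadata": {}, "source": [ "As a brief illustration (a minimal sketch; the path `./model-tensorboard-extra` is just an arbitrary example name), a second writer pointing at a *different* location can be opened freely, and closing a writer releases its location so that it can be reused later:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A second LogdirWriter is fine as long as it uses a different event location;\n", "# only two open writers for the *same* location would raise an error.\n", "extra_writer = mon.LogdirWriter('./model-tensorboard-extra')\n", "\n", "# Closing a writer releases its location, allowing another writer to take it.\n", "extra_writer.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom tasks\n", "We may also want to perform certain tasks that do not have pre-defined `Task` classes. For example, we may want to compute the performance on a test set. Here we create such a class by extending `BaseTensorBoardTask` to log the test-set benchmarks in addition to all the scalar parameters."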
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class CustomTensorBoardTask(mon.BaseTensorBoardTask):\n", " def __init__(self, file_writer, model, Xt, Yt):\n", " super().__init__(file_writer, model)\n", " self.Xt = Xt\n", " self.Yt = Yt\n", " self._full_test_err = tf.placeholder(gpflow.settings.float_type, shape=())\n", " self._full_test_nlpp = tf.placeholder(gpflow.settings.float_type, shape=())\n", " self._summary = tf.summary.merge([tf.summary.scalar(\"test_rmse\", self._full_test_err),\n", " tf.summary.scalar(\"test_nlpp\", self._full_test_nlpp)])\n", " \n", " def run(self, context: mon.MonitorContext, *args, **kwargs) -> None:\n", " minibatch_size = 100\n", " preds = np.vstack([self.model.predict_y(Xt[mb * minibatch_size:(mb + 1) * minibatch_size, :])[0]\n", " for mb in range(-(-len(Xt) // minibatch_size))])\n", " test_err = np.mean((Yt - preds) ** 2.0)**0.5\n", " self._eval_summary(context, {self._full_test_err: test_err, self._full_test_nlpp: 0.0})\n", "\n", " \n", "custom_tboard_task = CustomTensorBoardTask(file_writer, m, Xt, Yt).with_name('custom_tboard')\\\n", " .with_condition(mon.PeriodicIterationCondition(100))\\\n", " .with_exit_condition(True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can put all these tasks into a monitor." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "monitor_tasks = [print_task, model_tboard_task, lml_tboard_task, custom_tboard_task, saver_task, sleep_task]\n", "monitor = mon.Monitor(monitor_tasks, session, global_step)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running the optimisation\n", "We finally get to running the optimisation.\n", "\n", "We may want to continue a previously run optimisation by resotring the TensorFlow graph from the latest checkpoint. Otherwise skip this step." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./monitor-saves/cp-900\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./monitor-saves/cp-900\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "LML after loading: -32615.572401\n" ] } ], "source": [ "if os.path.isdir('./monitor-saves'):\n", " mon.restore_session(session, './monitor-saves')\n", " print('LML after loading: %f' % m.compute_log_likelihood())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To check that the model has been correctly restored, we print out a model hyperparameter (1 at initialisation) and an optimiser variable (zeros at initialisation)." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n", "[[ 401.49689376]\n", " [ 288.26263405]\n", " [ -51.83913957]\n", " [ 236.47567674]\n", " [-100.38292958]]\n" ] } ], "source": [ "print(m.kern.lengthscales.value)\n", "print(session.run(optimiser.optimizer.variables()[0]))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 10\ttotal itr.rate 11.40/s\trecent itr.rate 11.40/s\topt.step 910\ttotal opt.rate 12.72/s\trecent opt.rate 12.72/s\n", "Iteration 20\ttotal itr.rate 16.96/s\trecent itr.rate 33.13/s\topt.step 920\ttotal opt.rate 25.02/s\trecent opt.rate 780.35/s\n", "Iteration 30\ttotal itr.rate 21.50/s\trecent itr.rate 46.32/s\topt.step 930\ttotal opt.rate 36.29/s\trecent opt.rate 366.01/s\n", "Iteration 40\ttotal itr.rate 25.17/s\trecent itr.rate 51.48/s\topt.step 940\ttotal opt.rate 47.29/s\trecent opt.rate 519.86/s\n", "Iteration 50\ttotal itr.rate 27.82/s\trecent itr.rate 48.07/s\topt.step 950\ttotal opt.rate 57.95/s\trecent opt.rate 587.19/s\n", "Iteration 60\ttotal itr.rate 29.92/s\trecent itr.rate 48.13/s\topt.step 960\ttotal opt.rate 68.14/s\trecent opt.rate 563.82/s\n", "Iteration 70\ttotal itr.rate 31.63/s\trecent itr.rate 48.06/s\topt.step 970\ttotal opt.rate 77.43/s\trecent opt.rate 426.77/s\n", "Iteration 80\ttotal itr.rate 33.01/s\trecent itr.rate 47.65/s\topt.step 980\ttotal opt.rate 86.89/s\trecent opt.rate 599.23/s\n", "Iteration 90\ttotal itr.rate 34.22/s\trecent itr.rate 48.28/s\topt.step 990\ttotal opt.rate 96.12/s\trecent opt.rate 637.83/s\n", "Iteration 100\ttotal itr.rate 35.33/s\trecent itr.rate 50.04/s\topt.step 1000\ttotal opt.rate 104.69/s\trecent opt.rate 530.72/s\n", "Computing full lml...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:00<00:00, 465.36it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 110\ttotal itr.rate 29.51/s\trecent itr.rate 11.14/s\topt.step 1010\ttotal opt.rate 112.54/s\trecent opt.rate 450.99/s\n", "Iteration 120\ttotal itr.rate 30.43/s\trecent itr.rate 46.38/s\topt.step 1020\ttotal opt.rate 119.94/s\trecent opt.rate 432.72/s\n", "Iteration 130\ttotal itr.rate 31.23/s\trecent itr.rate 45.47/s\topt.step 1030\ttotal opt.rate 127.31/s\trecent opt.rate 484.54/s\n", "Iteration 140\ttotal itr.rate 31.95/s\trecent itr.rate 45.64/s\topt.step 1040\ttotal opt.rate 133.81/s\trecent opt.rate 397.59/s\n", "Iteration 150\ttotal itr.rate 32.56/s\trecent itr.rate 44.57/s\topt.step 1050\ttotal opt.rate 140.46/s\trecent opt.rate 462.15/s\n", "Iteration 160\ttotal itr.rate 32.96/s\trecent itr.rate 40.48/s\topt.step 1060\ttotal opt.rate 146.42/s\trecent opt.rate 402.79/s\n", "Iteration 170\ttotal itr.rate 33.52/s\trecent itr.rate 45.79/s\topt.step 1070\ttotal opt.rate 152.74/s\trecent opt.rate 493.06/s\n", "Iteration 180\ttotal itr.rate 34.10/s\trecent itr.rate 48.55/s\topt.step 1080\ttotal opt.rate 158.83/s\trecent opt.rate 493.83/s\n", "Iteration 190\ttotal itr.rate 34.56/s\trecent itr.rate 45.54/s\topt.step 1090\ttotal opt.rate 164.66/s\trecent opt.rate 484.71/s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:00<00:00, 600.72it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 200\ttotal itr.rate 35.00/s\trecent itr.rate 46.11/s\topt.step 1100\ttotal opt.rate 170.24/s\trecent opt.rate 479.07/s\n", 
"Computing full lml...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 210\ttotal itr.rate 33.57/s\trecent itr.rate 18.50/s\topt.step 1110\ttotal opt.rate 176.20/s\trecent opt.rate 587.12/s\n", "Iteration 220\ttotal itr.rate 33.98/s\trecent itr.rate 45.70/s\topt.step 1120\ttotal opt.rate 181.26/s\trecent opt.rate 456.23/s\n", "Iteration 230\ttotal itr.rate 34.34/s\trecent itr.rate 44.49/s\topt.step 1130\ttotal opt.rate 186.52/s\trecent opt.rate 516.53/s\n", "Iteration 240\ttotal itr.rate 34.74/s\trecent itr.rate 47.68/s\topt.step 1140\ttotal opt.rate 191.25/s\trecent opt.rate 459.12/s\n", "Iteration 250\ttotal itr.rate 35.10/s\trecent itr.rate 46.61/s\topt.step 1150\ttotal opt.rate 197.06/s\trecent opt.rate 727.69/s\n", "Iteration 260\ttotal itr.rate 35.11/s\trecent itr.rate 35.50/s\topt.step 1160\ttotal opt.rate 201.54/s\trecent opt.rate 466.21/s\n", "Iteration 270\ttotal itr.rate 35.45/s\trecent itr.rate 47.03/s\topt.step 1170\ttotal opt.rate 205.52/s\trecent opt.rate 422.49/s\n", "Iteration 280\ttotal itr.rate 35.76/s\trecent itr.rate 47.23/s\topt.step 1180\ttotal opt.rate 209.48/s\trecent opt.rate 436.70/s\n", "Iteration 290\ttotal itr.rate 36.03/s\trecent itr.rate 45.63/s\topt.step 1190\ttotal opt.rate 213.75/s\trecent opt.rate 498.71/s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:00<00:00, 622.80it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 300\ttotal itr.rate 36.36/s\trecent itr.rate 49.47/s\topt.step 1200\ttotal opt.rate 218.43/s\trecent opt.rate 596.23/s\n", "Computing full lml...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 310\ttotal itr.rate 35.28/s\trecent itr.rate 18.66/s\topt.step 1210\ttotal opt.rate 222.25/s\trecent opt.rate 468.74/s\n", "Iteration 320\ttotal itr.rate 35.54/s\trecent itr.rate 46.01/s\topt.step 1220\ttotal opt.rate 225.08/s\trecent opt.rate 371.33/s\n", "Iteration 330\ttotal itr.rate 35.76/s\trecent itr.rate 44.41/s\topt.step 1230\ttotal opt.rate 229.69/s\trecent opt.rate 667.65/s\n", "Iteration 340\ttotal itr.rate 35.88/s\trecent itr.rate 40.34/s\topt.step 1240\ttotal opt.rate 232.81/s\trecent opt.rate 421.93/s\n", "Iteration 350\ttotal itr.rate 36.08/s\trecent itr.rate 44.39/s\topt.step 1250\ttotal opt.rate 236.52/s\trecent opt.rate 516.11/s\n", "Iteration 360\ttotal itr.rate 36.31/s\trecent itr.rate 47.22/s\topt.step 1260\ttotal opt.rate 240.28/s\trecent opt.rate 540.74/s\n", "Iteration 370\ttotal itr.rate 36.52/s\trecent itr.rate 46.15/s\topt.step 1270\ttotal opt.rate 244.51/s\trecent opt.rate 670.20/s\n", "Iteration 380\ttotal itr.rate 36.72/s\trecent itr.rate 45.60/s\topt.step 1280\ttotal opt.rate 247.75/s\trecent opt.rate 485.35/s\n", "Iteration 390\ttotal itr.rate 36.91/s\trecent itr.rate 46.12/s\topt.step 1290\ttotal opt.rate 251.01/s\trecent opt.rate 501.65/s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:00<00:00, 580.38it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 400\ttotal itr.rate 37.09/s\trecent itr.rate 45.67/s\topt.step 1300\ttotal opt.rate 254.89/s\trecent opt.rate 641.85/s\n", "Computing full lml...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 410\ttotal itr.rate 36.22/s\trecent itr.rate 18.68/s\topt.step 
1310\ttotal opt.rate 257.04/s\trecent opt.rate 387.75/s\n", "Iteration 420\ttotal itr.rate 36.45/s\trecent itr.rate 49.45/s\topt.step 1320\ttotal opt.rate 260.32/s\trecent opt.rate 546.02/s\n", "Iteration 430\ttotal itr.rate 36.64/s\trecent itr.rate 47.14/s\topt.step 1330\ttotal opt.rate 263.74/s\trecent opt.rate 588.43/s\n", "Iteration 440\ttotal itr.rate 36.82/s\trecent itr.rate 46.54/s\topt.step 1340\ttotal opt.rate 265.82/s\trecent opt.rate 403.08/s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/100 [00:00