{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "2438a1d7-6564-4a4d-bb8a-5ae7f3eba552", "metadata": { "tags": [] }, "source": [ "# Visualizing Population Based Training (PBT) Hyperparameter Optimization\n", "\n", "**Assumptions:** The reader has a basic understanding of the [PBT algorithm](https://www.deepmind.com/blog/population-based-training-of-neural-networks) and wants to dive deeper and verify the underlying algorithm behavior with [Ray's PBT implementation](tune-scheduler-pbt). [This guide](pbt-guide-ref) provides resources for gaining some context.\n", "\n", "This tutorial will go through a simple example that will help you develop a better understanding of what PBT is doing under the hood when using it to tune your algorithms. Follow along by launching the notebook with the rocket 🚀 icon above.\n", "\n", "We will learn how to:\n", "\n", "1. **Set up checkpointing and loading for PBT** with the function trainable interface\n", "2. **Configure Tune and PBT scheduler parameters**\n", "3. **Visualize PBT algorithm behavior** to gain some intuition\n", "\n", "## Set up Toy the Example\n", "\n", "The toy example optimization problem we will use comes from the [PBT paper](https://arxiv.org/pdf/1711.09846.pdf) (see Figure 2 for more details). The goal is to find parameters that maximize an quadratic function, while only having access to an estimator that depends on a set of hyperparameters. A practical example of this is maximizing the (unknown) generalization capabilities of a model across all possible inputs with only access to the empirical loss of your model, which depends on hyperparameters in order to optimize.\n", "\n", "We'll start with some imports." ] }, { "cell_type": "code", "execution_count": 1, "id": "49b2e7ba-532b-431e-aa81-1467cb2b4e70", "metadata": {}, "outputs": [], "source": [ "!pip install -U \"ray[tune]\"" ] }, { "attachments": {}, "cell_type": "markdown", "id": "efec7627-fd60-48e9-8214-0b4fbb8e4402", "metadata": {}, "source": [ "Note: If you're running on Colab, please copy {doc}`this helper file ` into your Colab mount as `pbt_visualization_utils.py` using the file explorer on the left." ] }, { "cell_type": "code", "execution_count": 1, "id": "90471b91", "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import time\n", "\n", "import ray\n", "from ray import train, tune\n", "from ray.train import Checkpoint, FailureConfig, RunConfig\n", "from ray.tune.schedulers import PopulationBasedTraining\n", "from ray.tune.tune_config import TuneConfig\n", "from ray.tune.tuner import Tuner\n", "\n", "from pbt_visualization_utils import (\n", " get_init_theta, plot_parameter_history,\n", " plot_Q_history, make_animation\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a223d6a2-a7d5-40a1-8e12-2a5a1a0a0070", "metadata": {}, "source": [ "Concretely, we will use the definitions (with very minor modifications) provided in the [paper](https://arxiv.org/pdf/1711.09846.pdf) for the function we are trying to optimize, and the estimator we are given.\n", "\n", "Here is a list of the concepts we will use for the example, and what they might be analagous to in practice:\n", "\n", "| Concept within this example | Description | Practical analogy |\n", "|---------|-------------|-------------------|\n", "|`theta = [theta0, theta1]`|The model parameters that we will update in our training loop.|Neural network parameters|\n", "|`h = [h0, h1]`|The hyperparameters that PBT will optimize.|Learning rate, batch size, etc.|\n", "|`Q(theta)`|The quadratic function we are trying to maximize.|Generalization capability over all inputs|\n", "|`Qhat(theta \\| h)`|The estimator we are given as our training objective, depends (`\\|`) on `h`.|Empirical loss/reward|\n", "\n", "Below are the implementations in code." ] }, { "cell_type": "code", "execution_count": 2, "id": "a75e75db", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial parameter values: theta = [0.9 0.9]\n" ] } ], "source": [ "def Q(theta):\n", " return 1.2 - (3/4 * theta[0] ** 2 + theta[1] ** 2)\n", "\n", "def Qhat(theta, h):\n", " return 1.2 - (h[0] * theta[0] ** 2 + h[1] * theta[1] ** 2)\n", "\n", "def grad_Qhat(theta, h):\n", " theta_grad = -2 * h * theta\n", " theta_grad[0] *= 3/4\n", " h_grad = -np.square(theta)\n", " h_grad[0] *= 3/4\n", " return {\"theta\": theta_grad, \"h\": h_grad}\n", "\n", "theta_0 = get_init_theta()\n", "print(\"Initial parameter values: theta = \", theta_0)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "0ee21632-9be6-4f80-ac80-c71696cb0f4f", "metadata": {}, "source": [ "## Defining the Function Trainable\n", "\n", "We will define the training loop:\n", "1. Load the hyperparameter configuration\n", "2. Initialize the model, **resuming from a checkpoint if one exists (this is important for PBT, since the scheduler will pause and resume trials frequently when trials get exploited).**\n", "3. Run the training loop and **checkpoint.**" ] }, { "cell_type": "code", "execution_count": 3, "id": "2d1a9fb5", "metadata": {}, "outputs": [], "source": [ "def train_func(config):\n", " # Load the hyperparam config passed in by the Tuner\n", " h0 = config.get(\"h0\")\n", " h1 = config.get(\"h1\")\n", " h = np.array([h0, h1]).astype(float)\n", " \n", " lr = config.get(\"lr\")\n", " train_step = 1\n", " checkpoint_interval = config.get(\"checkpoint_interval\", 1)\n", " logging_interval = config.get(\"logging_interval\", 10)\n", " \n", " # Initialize the model parameters\n", " theta = get_init_theta()\n", " \n", " # Load a checkpoint if it exists\n", " # This checkpoint could be a trial's own checkpoint to resume,\n", " # or another trial's checkpoint placed by PBT that we will exploit\n", " if train.get_checkpoint():\n", " checkpoint_dict = train.get_checkpoint().to_dict()\n", " # Load in model (theta)\n", " theta = checkpoint_dict[\"theta\"]\n", " last_step = checkpoint_dict[\"train_step\"]\n", " train_step = last_step + 1\n", " \n", " # Main training loop (trial stopping is configured later)\n", " while True:\n", " # Perform gradient ascent steps\n", " param_grads = grad_Qhat(theta, h)\n", " theta_grad = np.asarray(param_grads[\"theta\"])\n", " theta = theta + lr * theta_grad\n", " \n", " # Checkpoint every `checkpoint_interval` steps\n", " checkpoint = None\n", " should_checkpoint = train_step % checkpoint_interval == 0\n", " if should_checkpoint:\n", " checkpoint = Checkpoint.from_dict({\n", " \"h\": h,\n", " \"train_step\": train_step,\n", " \"theta\": theta,\n", " })\n", "\n", " # Define which custom metrics we want in our trial result\n", " result = {\n", " \"Q\": Q(theta),\n", " \"theta0\": theta[0], \"theta1\": theta[1],\n", " \"h0\": h0, \"h1\": h1,\n", " \"train_step\": train_step,\n", " }\n", " \n", " # Report metric for this training iteration, and include the\n", " # trial checkpoint that contains the current parameters if we\n", " # saved it this train step\n", " train.report(result, checkpoint=checkpoint)\n", " \n", " train_step += 1" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5bdc96e0-b4bf-4a7a-9f15-e94de6f4d21b", "metadata": {}, "source": [ "```{note}\n", "Since PBT will keep restoring from latest checkpoints, it's important to save and load `train_step` correctly in a function trainable. **Make sure you increment the loaded `train_step` by one as shown above.** This avoids repeating an iteration and causing the checkpoint and perturbation intervals to be out of sync.\n", "\n", "```" ] }, { "attachments": {}, "cell_type": "markdown", "id": "caa002e2-1d68-404c-84bd-99b8d8119dac", "metadata": {}, "source": [ "## Configure PBT and Tuner\n", "\n", "We start by initializing ray (shutting it down if a session existed previously)." ] }, { "cell_type": "code", "execution_count": 4, "id": "f68445a3-958f-49a0-a9f9-03121c3c731c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-09-14 11:43:32,337\tINFO worker.py:1517 -- Started a local Ray instance.\n" ] }, { "data": { "text/html": [ "
| Python version: | \n", "3.8.13 | \n", "
| Ray version: | \n", "3.0.0.dev0 | \n", "