{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "aa1c2614",
   "metadata": {},
   "source": [
    "(tune-rllib-example)=\n",
    "\n",
    "# Using RLlib with Tune\n",
    "\n",
    "```{image} /rllib/images/rllib-logo.png\n",
    ":align: center\n",
    ":alt: RLlib Logo\n",
    ":height: 120px\n",
    ":target: https://docs.ray.io\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Example\n",
    "\n",
    "This example tunes an RLlib PPO agent with a Tune scheduler ([Population Based Training](tune-scheduler-pbt)).\n",
    "\n",
    "This example specifies `num_workers=4`, `num_cpus=1`, and `num_gpus=0`, which means that each\n",
    "PPO trial uses 5 CPUs: 1 for training plus 4 for sample collection.\n",
    "This example runs 2 trials, so at least 10 CPUs must be available in the cluster resources\n",
    "to run both trials concurrently. Otherwise, the PBT scheduler alternates\n",
    "between the trials (round-robin), which is less efficient.\n",
    "\n",
    "To run this example with GPUs, set `num_gpus` to the number of GPUs each trial should use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4621a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import ray\n",
    "from ray import train, tune\n",
    "from ray.tune.schedulers import PopulationBasedTraining\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    import argparse\n",
    "\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\n",
    "        \"--smoke-test\", action=\"store_true\", help=\"Finish quickly for testing\"\n",
    "    )\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    # Postprocess the perturbed config to ensure it's still valid\n",
    "    def explore(config):\n",
    "        \"\"\"Clamp PBT-perturbed hyperparameters so the PPO config stays usable.\n",
    "\n",
    "        Modifies ``config`` in place and also returns it.\n",
    "        \"\"\"\n",
    "        # Ensure each train batch holds at least two SGD minibatches.\n",
    "        if config[\"train_batch_size\"] < config[\"sgd_minibatch_size\"] * 2:\n",
    "            config[\"train_batch_size\"] = config[\"sgd_minibatch_size\"] * 2\n",
    "        # Ensure we run at least one SGD iteration.\n",
    "        if config[\"num_sgd_iter\"] < 1:\n",
    "            config[\"num_sgd_iter\"] = 1\n",
    "        return config\n",
    "\n",
    "    hyperparam_mutations = {\n",
    "        \"lambda\": lambda: random.uniform(0.9, 1.0),\n",
    "        \"clip_param\": lambda: random.uniform(0.01, 0.5),\n",
    "        \"lr\": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],\n",
    "        \"num_sgd_iter\": lambda: random.randint(1, 30),\n",
    "        \"sgd_minibatch_size\": lambda: random.randint(128, 16384),\n",
    "        \"train_batch_size\": lambda: random.randint(2000, 160000),\n",
    "    }\n",
    "\n",
    "    pbt = PopulationBasedTraining(\n",
    "        time_attr=\"time_total_s\",\n",
    "        # Measured in units of `time_attr`, i.e. seconds of training time.\n",
    "        perturbation_interval=120,\n",
    "        resample_probability=0.25,\n",
    "        # Specifies the mutations of these hyperparams\n",
    "        hyperparam_mutations=hyperparam_mutations,\n",
    "        custom_explore_fn=explore,\n",
    "    )\n",
    "\n",
    "    # Stop when we've either used up the training-iteration budget or reached\n",
    "    # reward=300. The smoke test only runs a couple of iterations so it\n",
    "    # actually finishes quickly.\n",
    "    stopping_criteria = {\n",
    "        \"training_iteration\": 2 if args.smoke_test else 100,\n",
    "        \"episode_reward_mean\": 300,\n",
    "    }\n",
    "\n",
    "    tuner = tune.Tuner(\n",
    "        \"PPO\",\n",
    "        tune_config=tune.TuneConfig(\n",
    "            metric=\"episode_reward_mean\",\n",
    "            mode=\"max\",\n",
    "            scheduler=pbt,\n",
    "            num_samples=1 if args.smoke_test else 2,\n",
    "        ),\n",
    "        param_space={\n",
    "            \"env\": \"Humanoid-v2\",\n",
    "            \"kl_coeff\": 1.0,\n",
    "            \"num_workers\": 4,\n",
    "            \"num_cpus\": 1,  # number of CPUs to use per trial\n",
    "            \"num_gpus\": 0,  # number of GPUs to use per trial\n",
    "            \"model\": {\"free_log_std\": True},\n",
    "            # These params are tuned from a fixed starting value.\n",
    "            \"lambda\": 0.95,\n",
    "            \"clip_param\": 0.2,\n",
    "            \"lr\": 1e-4,\n",
    "            # These params start off randomly drawn from a set.\n",
    "            \"num_sgd_iter\": tune.choice([10, 20, 30]),\n",
    "            \"sgd_minibatch_size\": tune.choice([128, 512, 2048]),\n",
    "            \"train_batch_size\": tune.choice([10000, 20000, 40000]),\n",
    "        },\n",
    "        run_config=train.RunConfig(\n",
    "            stop=stopping_criteria,\n",
    "            # Save a final checkpoint so the best trial can be restored\n",
    "            # afterwards with `Algorithm.from_checkpoint`.\n",
    "            checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True),\n",
    "        ),\n",
    "    )\n",
    "    results = tuner.fit()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "8cd3cc70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best performing trial's final set of hyperparameters:\n",
      "\n",
      "{'clip_param': 0.2,\n",
      " 'lambda': 0.95,\n",
      " 'lr': 0.0001,\n",
      " 'num_sgd_iter': 30,\n",
      " 'sgd_minibatch_size': 2048,\n",
      " 'train_batch_size': 20000}\n",
      "\n",
      "Best performing trial's final reported metrics:\n",
      "\n",
      "{'episode_len_mean': 61.09146341463415,\n",
      " 'episode_reward_max': 567.4424113245353,\n",
      " 'episode_reward_mean': 310.36948184391935,\n",
      " 'episode_reward_min': 87.74736189944105}\n"
     ]
    }
   ],
   "source": [
    "import pprint\n",
    "\n",
    "best_result = results.get_best_result()\n",
    "\n",
    "print(\"Best performing trial's final set of hyperparameters:\\n\")\n",
    "# Only show the hyperparameters that PBT was allowed to mutate.\n",
    "best_hyperparams = {\n",
    "    name: value\n",
    "    for name, value in best_result.config.items()\n",
    "    if name in hyperparam_mutations\n",
    "}\n",
    "pprint.pprint(best_hyperparams)\n",
    "\n",
    "print(\"\\nBest performing trial's final reported metrics:\\n\")\n",
    "\n",
    "metrics_to_print = [\n",
    "    \"episode_reward_mean\",\n",
    "    \"episode_reward_max\",\n",
    "    \"episode_reward_min\",\n",
    "    \"episode_len_mean\",\n",
    "]\n",
    "best_metrics = {\n",
    "    name: value\n",
    "    for name, value in best_result.metrics.items()\n",
    "    if name in metrics_to_print\n",
    "}\n",
    "pprint.pprint(best_metrics)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4cc4685",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.rllib.algorithms.algorithm import Algorithm\n",
    "\n",
    "# Rebuild the PPO algorithm from the best trial's checkpoint.\n",
    "best_checkpoint = best_result.checkpoint\n",
    "loaded_ppo = Algorithm.from_checkpoint(best_checkpoint)\n",
    "\n",
    "# Pull the policy out of the restored algorithm, e.g. to compute actions.\n",
    "loaded_policy = loaded_ppo.get_policy()\n",
    "\n",
    "# See your trained policy in action, for example:\n",
    "# loaded_policy.compute_single_action(...)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db534c4e",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## More RLlib Examples\n",
    "\n",
    "- {doc}`/tune/examples/includes/pb2_ppo_example`:\n",
    "  Example of optimizing a distributed RLlib algorithm (PPO) with the PB2 scheduler.\n",
    "  It uses a small population size of 4, so it can train on a laptop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d4fb61",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}