{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8b66fbef",
   "metadata": {},
   "source": [
    "(tune-horovod-example)=\n",
    "\n",
    "# Using Horovod with Tune\n",
    "\n",
    "```{image} /images/horovod.png\n",
    ":align: center\n",
    ":alt: Horovod Logo\n",
    ":height: 120px\n",
    ":target: https://horovod.ai/\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82188b4b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import torch\n",
    "\n",
    "import ray\n",
    "from ray import train, tune\n",
    "from ray.train.horovod import HorovodTrainer\n",
    "from ray.train import ScalingConfig\n",
    "from ray.tune.tune_config import TuneConfig\n",
    "from ray.tune.tuner import Tuner\n",
    "\n",
    "\n",
    "def sq(x):\n",
    "    m2 = 1.0\n",
    "    m1 = -20.0\n",
    "    m0 = 50.0\n",
    "    return m2 * x * x + m1 * x + m0\n",
    "\n",
    "\n",
    "def qu(x):\n",
    "    m3 = 10.0\n",
    "    m2 = 5.0\n",
    "    m1 = -20.0\n",
    "    m0 = -5.0\n",
    "    return m3 * x * x * x + m2 * x * x + m1 * x + m0\n",
    "\n",
    "\n",
    "class Net(torch.nn.Module):\n",
    "    def __init__(self, mode=\"sq\"):\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "        if mode == \"square\":\n",
    "            self.mode = 0\n",
    "            self.param = torch.nn.Parameter(torch.FloatTensor([1.0, -1.0]))\n",
    "        else:\n",
    "            self.mode = 1\n",
    "            self.param = torch.nn.Parameter(torch.FloatTensor([1.0, -1.0, 1.0]))\n",
    "\n",
    "    def forward(self, x):\n",
    "        if ~self.mode:\n",
    "            return x * x + self.param[0] * x + self.param[1]\n",
    "        else:\n",
    "            return_val = 10 * x * x * x\n",
    "            return_val += self.param[0] * x * x\n",
    "            return_val += self.param[1] * x + self.param[2]\n",
    "            return return_val\n",
    "\n",
    "\n",
    "def train_loop_per_worker(config):\n",
    "    \"\"\"Horovod training loop executed by every HorovodTrainer worker.\n",
    "\n",
    "    Reads ``mode``, ``x_max`` and ``lr`` from ``config``, fits ``Net`` to the\n",
    "    matching target (``sq`` for \"square\", ``qu`` otherwise) on 1-element\n",
    "    inputs drawn uniformly from [-x_max, x_max), and reports the loss to\n",
    "    Ray Train after every step.\n",
    "    \"\"\"\n",
    "    # Imported locally (presumably so horovod is only needed inside workers).\n",
    "    import torch\n",
    "    import horovod.torch as hvd\n",
    "\n",
    "    hvd.init()\n",
    "    # Per-rank NumPy seed: every worker draws its own training inputs.\n",
    "    np.random.seed(1 + hvd.rank())\n",
    "    torch.manual_seed(1234)\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    mode = config[\"mode\"]\n",
    "    target_fn = sq if mode == \"square\" else qu\n",
    "    net = Net(mode).to(device)\n",
    "    optimizer = hvd.DistributedOptimizer(\n",
    "        torch.optim.SGD(net.parameters(), lr=config[\"lr\"])\n",
    "    )\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    # Start all workers from rank 0's parameters and optimizer state.\n",
    "    hvd.broadcast_parameters(net.state_dict(), root_rank=0)\n",
    "    hvd.broadcast_optimizer_state(optimizer, root_rank=0)\n",
    "    print(f\"Horovod world size: {hvd.size()}\")\n",
    "\n",
    "    num_steps = 5\n",
    "    x_max = config[\"x_max\"]\n",
    "    start = time.time()\n",
    "    for _ in range(num_steps):\n",
    "        features = torch.Tensor(np.random.rand(1) * 2 * x_max - x_max).to(device)\n",
    "        labels = target_fn(features)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(features)\n",
    "        loss = loss_fn(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        time.sleep(0.1)\n",
    "        train.report(dict(loss=loss.item()))\n",
    "    total = time.time() - start\n",
    "    print(f\"Took {total:0.3f} s. Avg: {total / num_steps:0.3f} s.\")\n",
    "\n",
    "\n",
    "def tune_horovod(num_workers, num_samples, use_gpu, mode=\"square\", x_max=1.0):\n",
    "    \"\"\"Search the SGD learning rate of a HorovodTrainer with Ray Tune.\n",
    "\n",
    "    Args:\n",
    "        num_workers: Horovod workers per trial.\n",
    "        num_samples: Number of trials (learning rates) to sample.\n",
    "        use_gpu: Forwarded to ``ScalingConfig(use_gpu=...)``.\n",
    "        mode: \"square\" for the quadratic target; anything else is cubic.\n",
    "        x_max: Half-width of the interval training inputs are drawn from.\n",
    "\n",
    "    Prints the config of the best trial (lowest reported ``loss``).\n",
    "    \"\"\"\n",
    "    # The trainer itself is given 0 CPUs.\n",
    "    scaling = ScalingConfig(\n",
    "        trainer_resources={\"CPU\": 0}, num_workers=num_workers, use_gpu=use_gpu\n",
    "    )\n",
    "    loop_config = {\"mode\": mode, \"x_max\": x_max}\n",
    "    trainer = HorovodTrainer(\n",
    "        train_loop_per_worker=train_loop_per_worker,\n",
    "        scaling_config=scaling,\n",
    "        train_loop_config=loop_config,\n",
    "    )\n",
    "\n",
    "    # ``lr`` is not in loop_config; it is sampled by Tune into train_loop_config.\n",
    "    search_space = {\"train_loop_config\": {\"lr\": tune.uniform(0.1, 1)}}\n",
    "    tune_cfg = TuneConfig(mode=\"min\", metric=\"loss\", num_samples=num_samples)\n",
    "    results = Tuner(\n",
    "        trainer, param_space=search_space, tune_config=tune_cfg\n",
    "    ).fit()\n",
    "\n",
    "    best_config = results.get_best_result().config\n",
    "    print(\"Best hyperparameters found were: \", best_config)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    import argparse\n",
    "\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\n",
    "        \"--mode\", type=str, default=\"square\", choices=[\"square\", \"cubic\"]\n",
    "    )\n",
    "    # NOTE: parsed but not used below; the learning rate is searched by Tune.\n",
    "    parser.add_argument(\n",
    "        \"--learning_rate\", type=float, default=0.1, dest=\"learning_rate\"\n",
    "    )\n",
    "    parser.add_argument(\"--x_max\", type=float, default=1.0, dest=\"x_max\")\n",
    "    parser.add_argument(\"--gpu\", action=\"store_true\")\n",
    "    parser.add_argument(\n",
    "        \"--smoke-test\", action=\"store_true\", help=(\"Finish quickly for testing.\")\n",
    "    )\n",
    "    parser.add_argument(\"--num-workers\", type=int, default=2)\n",
    "    # parse_known_args ignores unrecognized argv entries instead of failing.\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    num_samples = 2 if args.smoke_test else 10\n",
    "    if args.smoke_test:\n",
    "        # Size the local cluster so every smoke-test trial fits at once:\n",
    "        # num_samples x num_workers x 1 CPU/worker (trainers get 0 CPUs).\n",
    "        # A fixed count would leave trials unschedulable for larger\n",
    "        # --num-workers values.\n",
    "        ray.init(num_cpus=num_samples * args.num_workers)\n",
    "\n",
    "    tune_horovod(\n",
    "        num_workers=args.num_workers,\n",
    "        num_samples=num_samples,\n",
    "        use_gpu=args.gpu,\n",
    "        mode=args.mode,\n",
    "        x_max=args.x_max,\n",
    "    )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ray_dev_py38",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) \n[Clang 12.0.1 ]"
  },
  "orphan": true,
  "vscode": {
   "interpreter": {
    "hash": "265d195fda5292fe8f69c6e37c435a5634a1ed3b6799724e66a975f68fa21517"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}