{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "02d29398",
   "metadata": {},
   "source": [
    "(mmt-tune)=\n",
    "\n",
    "# Batch training & tuning on Ray Tune"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2780b3da",
   "metadata": {},
   "source": [
    "**Batch training and tuning** are common tasks in simple machine learning use-cases such as time series forecasting. They require fitting of simple models on data batches corresponding to different locations, products, etc. Batch training can take less time to process all the data at once, but only if those batches can run in parallel!\n",
    "\n",
    "This notebook showcases how to conduct batch regression with algorithms from XGBoost and Scikit-learn with **[Ray Tune](tune-main)**. **XGBoost** is a popular open-source library used for regression and classification. **Scikit-learn** is a popular open-source library with a vast assortment of well-known ML algorithms.\n",
    "\n",
    "![Batch training diagram](../../data/examples/images/batch-training.svg)\n",
    "\n",
    "For the data, we will use the [NYC Taxi dataset](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page). This popular tabular dataset contains historical taxi pickups by timestamp and location in NYC.\n",
    "\n",
    "For the training, we will train separate regression models to predict `trip_duration`, with a different model for each dropoff location in NYC. Specifically, we will conduct an experiment for each `dropoff_location_id`, to find the best either XGBoost or Scikit-learn model, per location."
   ]
  },
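  {
   "cell_type": "markdown",
   "id": "1a2b3c4d",
   "metadata": {},
   "source": [
    "To make the pattern concrete before diving in, here is a minimal, self-contained sketch of batch training on a toy dataframe. The column names and values below are purely illustrative, not part of the NYC Taxi workflow: the point is simply that one independent model is fit per data partition.\n",
    "\n",
    "```python\n",
    "# A minimal sketch of batch training: fit one model per data partition.\n",
    "# The toy dataframe below is illustrative only.\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "toy_df = pd.DataFrame(\n",
    "    {\n",
    "        \"location\": [1, 1, 1, 2, 2, 2],\n",
    "        \"distance\": [1.0, 2.0, 3.0, 1.0, 2.0, 3.0],\n",
    "        \"duration\": [300, 610, 920, 390, 810, 1180],\n",
    "    }\n",
    ")\n",
    "\n",
    "models = {}\n",
    "for location, partition in toy_df.groupby(\"location\"):\n",
    "    # Each partition gets its own independent model; this per-partition\n",
    "    # loop is what Ray Tune will run in parallel across a cluster below.\n",
    "    models[location] = LinearRegression().fit(\n",
    "        partition[[\"distance\"]], partition[\"duration\"]\n",
    "    )\n",
    "```"
   ]
  },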
  {
   "cell_type": "markdown",
   "id": "c261b2bd",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Contents\n",
    "\n",
    "In this this tutorial, you will learn how to:\n",
    " 1. [Define how to load and prepare Parquet data](#prepare_data)\n",
    " 2. [Define a Trainable (callable) function](#define_trainable)\n",
    " 3. [Run batch training and inference with Ray Tune](#run_tune_search)\n",
    " 4. [Load a model from checkpoint and perform batch prediction](#load_checkpoint)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "604e8c44",
   "metadata": {},
   "source": [
    "# Walkthrough\n",
    "\n",
    "```{tip}\n",
    "Prerequisite for this notebook: Read the [Key Concepts](tune-60-seconds) page for Ray Tune.\n",
    "```\n",
    "First, let's make sure we have all Python packages we need installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6160f20e",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "!pip install -q \"ray[tune]\" scikit-learn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2147f97d",
   "metadata": {},
   "source": [
    "Next, let's import a few required libraries, including open-source Ray itself!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c37d1b39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of CPUs in this system: 8\n",
      "numpy: 1.21.6\n",
      "pyarrow: 10.0.0\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pickle\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "print(f\"Number of CPUs in this system: {os.cpu_count()}\")\n",
    "from typing import Tuple, List, Union, Optional, Callable\n",
    "import time\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "print(f\"numpy: {np.__version__}\")\n",
    "import pyarrow\n",
    "import pyarrow.parquet as pq\n",
    "import pyarrow.dataset as pds\n",
    "\n",
    "print(f\"pyarrow: {pyarrow.__version__}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8a2ad12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ray\n",
    "\n",
    "if ray.is_initialized():\n",
    "    ray.shutdown()\n",
    "ray.init()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3563fed9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'memory': 451212691046.0, 'object_store_memory': 175243542524.0, 'node:172.31.206.67': 1.0, 'CPU': 152.0, 'node:172.31.138.114': 1.0, 'node:172.31.221.253': 1.0, 'node:172.31.144.75': 1.0, 'node:172.31.169.100': 1.0, 'node:172.31.136.199': 1.0, 'node:172.31.251.87': 1.0, 'node:172.31.249.240': 1.0, 'node:172.31.252.125': 1.0, 'node:172.31.211.165': 1.0}\n"
     ]
    }
   ],
   "source": [
    "print(ray.cluster_resources())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0341b265",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sklearn: 1.2.0\n",
      "xgboost: 1.3.3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ray/anaconda3/lib/python3.8/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n",
      "  from pandas import MultiIndex, Int64Index\n"
     ]
    }
   ],
   "source": [
    "# import standard sklearn libraries\n",
    "import sklearn\n",
    "from sklearn.base import BaseEstimator\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "\n",
    "print(f\"sklearn: {sklearn.__version__}\")\n",
    "import xgboost as xgb\n",
    "\n",
    "print(f\"xgboost: {xgb.__version__}\")\n",
    "# import ray libraries\n",
    "from ray import train, tune\n",
    "from ray.train import Checkpoint\n",
    "\n",
    "# set global random seed for sklearn models\n",
    "np.random.seed(415)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4881e9ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For benchmarking purposes, we can print the times of various operations.\n",
    "# In order to reduce clutter in the output, this is set to False by default.\n",
    "PRINT_TIMES = False\n",
    "\n",
    "\n",
    "def print_time(msg: str):\n",
    "    if PRINT_TIMES:\n",
    "        print(msg)\n",
    "\n",
    "\n",
    "# To speed things up, we’ll only use a small subset of the full dataset consisting of two last months of 2019.\n",
    "# You can choose to use the full dataset for 2018-2019 by setting the SMOKE_TEST variable to False.\n",
    "SMOKE_TEST = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43545104",
   "metadata": {
    "tags": []
   },
   "source": [
    "(prepare_data)=\n",
    "## Define how to load and prepare Parquet data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c5e5428",
   "metadata": {},
   "source": [
    "First, we need to load some data. Since the NYC Taxi dataset is fairly large, we will filter files first into a PyArrow dataset. And then in the next cell after, we will filter the data on read into a PyArrow table and convert that to a pandas dataframe.\n",
    "\n",
    "```{tip}\n",
    "Use PyArrow dataset and table for reading or writing large parquet files, since its native multithreaded C++ adapter is faster than pandas read_parquet, even using engine=pyarrow.\n",
    "```"
   ]
  },
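  {
   "cell_type": "markdown",
   "id": "2b3c4d5e",
   "metadata": {},
   "source": [
    "As a rough illustration of that tip, the sketch below contrasts a plain pandas read (filtering after loading) with a PyArrow read where the same filter is pushed down into the scan. The file path is a placeholder, not part of this notebook's data.\n",
    "\n",
    "```python\n",
    "# A sketch contrasting pandas and PyArrow Parquet reads.\n",
    "# \"example.parquet\" is a placeholder path.\n",
    "import pandas as pd\n",
    "import pyarrow.parquet as pq\n",
    "\n",
    "# pandas: load the whole file first, then filter rows in Python.\n",
    "pdf = pd.read_parquet(\"example.parquet\", engine=\"pyarrow\")\n",
    "pdf = pdf[pdf[\"passenger_count\"] > 0]\n",
    "\n",
    "# PyArrow: the filter is pushed down into the native, multithreaded C++\n",
    "# scan, so non-matching row groups can be skipped before reaching pandas.\n",
    "table = pq.read_table(\"example.parquet\", filters=[(\"passenger_count\", \">\", 0)])\n",
    "pdf = table.to_pandas()\n",
    "```"
   ]
  },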
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "65e8465b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NYC Taxi using 1 file(s)!\n",
      "s3_files: ['s3://anonymous@air-example-data/ursa-labs-taxi-data/by_year/2019/06/data.parquet/ab5b9d2b8cc94be19346e260b543ec35_000000.parquet']\n",
      "Locations: [141, 229, 173]\n"
     ]
    }
   ],
   "source": [
    "# Define some global variables.\n",
    "TARGET = \"trip_duration\"\n",
    "s3_partitions = pds.dataset(\n",
    "    \"s3://anonymous@air-example-data/ursa-labs-taxi-data/by_year/\",\n",
    "    partitioning=[\"year\", \"month\"],\n",
    ")\n",
    "s3_files = [f\"s3://anonymous@{file}\" for file in s3_partitions.files]\n",
    "\n",
    "# Obtain all location IDs\n",
    "all_location_ids = (\n",
    "    pq.read_table(s3_files[0], columns=[\"dropoff_location_id\"])[\"dropoff_location_id\"]\n",
    "    .unique()\n",
    "    .to_pylist()\n",
    ")\n",
    "# drop [264, 265]\n",
    "all_location_ids.remove(264)\n",
    "all_location_ids.remove(265)\n",
    "\n",
    "# Use smoke testing or not.\n",
    "starting_idx = -1 if SMOKE_TEST else 0\n",
    "# TODO: drop location 199 to test error-handling before final git checkin\n",
    "sample_locations = [141, 229, 173] if SMOKE_TEST else all_location_ids\n",
    "\n",
    "# Display what data will be used.\n",
    "s3_files = s3_files[starting_idx:]\n",
    "print(f\"NYC Taxi using {len(s3_files)} file(s)!\")\n",
    "print(f\"s3_files: {s3_files}\")\n",
    "print(f\"Locations: {sample_locations}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "92e5cc73",
   "metadata": {},
   "outputs": [],
   "source": [
    "############\n",
    "# STEP 1.  Define Python functions to\n",
    "#          a) read and prepare a segment of data.\n",
    "############\n",
    "\n",
    "# Function to read a pyarrow.Table object using pyarrow parquet\n",
    "def read_data(file: str, sample_id: np.int32) -> pd.DataFrame:\n",
    "\n",
    "    df = pq.read_table(\n",
    "        file,\n",
    "        filters=[\n",
    "            (\"passenger_count\", \">\", 0),\n",
    "            (\"trip_distance\", \">\", 0),\n",
    "            (\"fare_amount\", \">\", 0),\n",
    "            (\"pickup_location_id\", \"not in\", [264, 265]),\n",
    "            (\"dropoff_location_id\", \"not in\", [264, 265]),\n",
    "            (\"dropoff_location_id\", \"=\", sample_id),\n",
    "        ],\n",
    "        columns=[\n",
    "            \"pickup_at\",\n",
    "            \"dropoff_at\",\n",
    "            \"pickup_location_id\",\n",
    "            \"dropoff_location_id\",\n",
    "            \"passenger_count\",\n",
    "            \"trip_distance\",\n",
    "            \"fare_amount\",\n",
    "        ],\n",
    "    ).to_pandas()\n",
    "\n",
    "    return df\n",
    "\n",
    "\n",
    "# Function to transform a pandas dataframe\n",
    "def transform_df(input_df: pd.DataFrame) -> pd.DataFrame:\n",
    "    df = input_df.copy()\n",
    "\n",
    "    # calculate trip_duration\n",
    "    df[\"trip_duration\"] = (df[\"dropoff_at\"] - df[\"pickup_at\"]).dt.seconds\n",
    "    # filter trip_durations > 1 minute and less than 24 hours\n",
    "    df = df[df[\"trip_duration\"] > 60]\n",
    "    df = df[df[\"trip_duration\"] < 24 * 60 * 60]\n",
    "    # keep only necessary columns\n",
    "    df = df[\n",
    "        [\"dropoff_location_id\", \"passenger_count\", \"trip_distance\", \"trip_duration\"]\n",
    "    ].copy()\n",
    "    df[\"dropoff_location_id\"] = df[\"dropoff_location_id\"].fillna(-1)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55b6c727",
   "metadata": {},
   "source": [
    "(define_trainable)=\n",
    "## Define a Trainable (callable) function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd7ab2d0",
   "metadata": {},
   "source": [
    "Next, we define a trainable function, called `train_model()`, in order to train and evaluate a model on a data partition. This function will be called *in parallel for every permutation* in the Tune search space! \n",
    "\n",
    "Inside this trainable function:\n",
    "- 📖 The input must include a `config` argument. \n",
    "- 📈 Inside the function, the tuning metric (a model's loss or error) must be calculated and reported using `ray.train.report()`.\n",
    "- ✔️ Optionally [checkpoint](train-checkpointing) (save) the model for fault tolerance and easy deployment later.\n",
    "\n",
    "```{tip}\n",
    "Ray Tune has two ways of [defining a trainable](tune_60_seconds_trainables), namely the Function API and the Class API. Both are valid ways of defining a trainable, but *the Function API is generally recommended*.\n",
    "```"
   ]
  },
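  {
   "cell_type": "markdown",
   "id": "3c4d5e6f",
   "metadata": {},
   "source": [
    "The contract boils down to a function that accepts a `config` dict and reports its metric(s). A minimal skeleton looks like the sketch below, where the `param` key and zero metric are placeholders; the full `train_model()` in the next cell fills in the data loading, training, evaluation, and checkpointing.\n",
    "\n",
    "```python\n",
    "# Minimal shape of a Tune Function-API trainable; values are placeholders.\n",
    "from ray import train\n",
    "\n",
    "\n",
    "def my_trainable(config: dict) -> None:\n",
    "    # 1. Read hyperparameters / partition keys from the config.\n",
    "    param = config[\"param\"]\n",
    "    # 2. Train and evaluate a model here, producing a metric...\n",
    "    error = 0.0\n",
    "    # 3. ...then report the metric (and optionally a checkpoint) to Tune.\n",
    "    train.report({\"error\": error})\n",
    "```"
   ]
  },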
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5b59bb62",
   "metadata": {},
   "outputs": [],
   "source": [
    "############\n",
    "# STEP 1.  Define Python functions to\n",
    "#          b) train and evaluate a model on a segment of data.\n",
    "############\n",
    "def train_model(config: dict) -> None:\n",
    "\n",
    "    algorithm = config[\"algorithm\"]\n",
    "    sample_location_id = config[\"location\"]\n",
    "\n",
    "    # Load data.\n",
    "    df_list = [read_data(f, sample_location_id) for f in s3_files]\n",
    "    df_raw = pd.concat(df_list, ignore_index=True)\n",
    "\n",
    "    # Transform data.\n",
    "    df = transform_df(df_raw)\n",
    "\n",
    "    # We need at least 10 rows to create a train / test split.\n",
    "    if df.shape[0] < 10:\n",
    "        print_time(f\"Location {sample_location_id} has only {df.shape[0]} rows.\")\n",
    "        train.report(dict(error=None))\n",
    "        return None\n",
    "\n",
    "    # Train/valid split.\n",
    "    train_df, valid_df = train_test_split(df, test_size=0.2, shuffle=True)\n",
    "    train_X = train_df[[\"passenger_count\", \"trip_distance\"]]\n",
    "    train_y = train_df[TARGET]\n",
    "    valid_X = valid_df[[\"passenger_count\", \"trip_distance\"]]\n",
    "    valid_y = valid_df[TARGET]\n",
    "\n",
    "    # Train model.\n",
    "    model = algorithm.fit(train_X, train_y)\n",
    "    pred_y = model.predict(valid_X)\n",
    "\n",
    "    # Evaluate.\n",
    "    error = sklearn.metrics.mean_absolute_error(valid_y, pred_y)\n",
    "\n",
    "    # Define a model checkpoint using Ray Train API.\n",
    "    state_dict = {\"model\": algorithm, \"location_id\": sample_location_id}\n",
    "    \n",
    "    with TemporaryDirectory() as tmpdir:\n",
    "        with open(os.path.join(tmpdir, \"ckpt.pkl\"), 'wb') as file:\n",
    "            pickle.dump(state_dict, file)\n",
    "    \n",
    "        checkpoint = Checkpoint.from_directory(tmpdir)\n",
    "\n",
    "        # Save checkpoint and report back metrics, using ray.train.report()\n",
    "        # The metrics you specify here will appear in Tune summary table.\n",
    "        # They will also be recorded in Tune results under `metrics`.\n",
    "        metrics = dict(error=error)\n",
    "        train.report(metrics, checkpoint=checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d59fbfab",
   "metadata": {},
   "source": [
    "(run_tune_search)=\n",
    "## Run batch training on Ray Tune"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4db1c6bd",
   "metadata": {},
   "source": [
    "**Recall what we are doing, high level, is training several different models per pickup location.** We are using Ray Tune so we can *run all these trials in parallel* on a Ray cluster. At the end, we will inspect the results of the experiment and deploy only the best model per pickup location.\n",
    "\n",
    "**Step 1. Define Python functions to read and prepare a segment of data and train and evaluate one or many models per segment of data**.  We already did this, above.\n",
    "\n",
    "**Step 2. Scaling**:\n",
    "Below, we use the default resources config which is 1 CPU core for each task. For more information about configuring resource allocations, see [A Guide To Parallelism and Resources](tune-parallelism). \n",
    "\n",
    "**Step 3. Search Space**:\n",
    "Below, we define our [Tune search space](tune-key-concepts-search-spaces), which consists of:\n",
    "- Different algorithms:\n",
    "  - XGBoost\n",
    "  - Scikit-learn LinearRegression\n",
    "- Some or all NYC taxi drop-off locations. \n",
    "\n",
    "**Step 4. Search Algorithm or Strategy**:\n",
    "Below, our Tune jobs will be defined using a search space and simple grid search. \n",
    "> The typical use case for Tune search spaces is for hyperparameter tuning. In our case, we are defining the Tune search space in order to run distributed tuning jobs automatically.  Each training job will use a different data partition (taxi pickup location), different algorithm, and the compute resources we defined in the Scaling config.\n",
    "\n",
    "**Step 5. Now we are ready to kick off a Ray Tune experiment!** \n",
    "- Define a `tuner` object.\n",
    "- Put the training function `train_model()` inside the `tuner` object.\n",
    "- Run the experiment using `tuner.fit()`.\n",
    "\n",
    "💡 After you run the cell below, right-click on it and choose \"Enable Scrolling for Outputs\"! This will make it easier to view, since tuning output can be very long!\n",
    "\n",
    "**Setting SMOKE_TEST=False, running on Anyscale:  518 models, using 18 NYC Taxi S3 files dating from 2018/01 to 2019/06 (split into partitions approx 1GiB each), simultaneously trained on a 10-node AWS cluster of [m5.4xlarges](https://aws.amazon.com/ec2/instance-types/m5/). Total data reading and train time was 37 minutes.**"
   ]
  },
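  {
   "cell_type": "markdown",
   "id": "4d5e6f7a",
   "metadata": {},
   "source": [
    "Two notes before running the experiment. First, because the search space holds two `tune.grid_search()` entries, Tune runs the cross product of their values: 2 algorithms × 3 locations = 6 trials in the smoke test. Second, if a trial ever needs more than the default 1 CPU core, the trainable can be wrapped with `tune.with_resources()`; the sketch below uses 2 CPUs purely as an example value and is not part of this experiment.\n",
    "\n",
    "```python\n",
    "# Optional sketch: overriding per-trial resources (2 CPUs is arbitrary).\n",
    "from ray import tune\n",
    "\n",
    "trainable_with_resources = tune.with_resources(train_model, {\"cpu\": 2})\n",
    "# Passing trainable_with_resources to tune.Tuner(...) instead of train_model\n",
    "# would schedule each trial with 2 CPU cores instead of the default 1.\n",
    "```"
   ]
  },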
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4acad940",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"tuneStatus\">\n",
       "  <div style=\"display: flex;flex-direction: row\">\n",
       "    <div style=\"display: flex;flex-direction: column;\">\n",
       "      <h3>Tune Status</h3>\n",
       "      <table>\n",
       "<tbody>\n",
       "<tr><td>Current time:</td><td>2023-01-10 16:26:11</td></tr>\n",
       "<tr><td>Running for: </td><td>00:00:20.45        </td></tr>\n",
       "<tr><td>Memory:      </td><td>3.0/30.9 GiB       </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    </div>\n",
       "    <div class=\"vDivider\"></div>\n",
       "    <div class=\"systemInfo\">\n",
       "      <h3>System Info</h3>\n",
       "      Using FIFO scheduling algorithm.<br>Resources requested: 0/152 CPUs, 0/0 GPUs, 0.0/420.22 GiB heap, 0.0/163.21 GiB objects\n",
       "    </div>\n",
       "    \n",
       "  </div>\n",
       "  <div class=\"hDivider\"></div>\n",
       "  <div class=\"trialStatus\">\n",
       "    <h3>Trial Status</h3>\n",
       "    <table>\n",
       "<thead>\n",
       "<tr><th>Trial name             </th><th>status    </th><th>loc                 </th><th>algorithm           </th><th style=\"text-align: right;\">  location</th><th style=\"text-align: right;\">  iter</th><th style=\"text-align: right;\">  total time (s)</th><th style=\"text-align: right;\">   error</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>train_model_7fd9c_00000</td><td>TERMINATED</td><td>172.31.211.165:3629 </td><td>LinearRegression()  </td><td style=\"text-align: right;\">       141</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">         1.90341</td><td style=\"text-align: right;\"> 500.005</td></tr>\n",
       "<tr><td>train_model_7fd9c_00001</td><td>TERMINATED</td><td>172.31.252.125:17717</td><td>XGBRegressor(ba_9dc0</td><td style=\"text-align: right;\">       141</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">         2.41094</td><td style=\"text-align: right;\"> 523.611</td></tr>\n",
       "<tr><td>train_model_7fd9c_00002</td><td>TERMINATED</td><td>172.31.251.87:4579  </td><td>LinearRegression()  </td><td style=\"text-align: right;\">       229</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">         1.86279</td><td style=\"text-align: right;\"> 568.826</td></tr>\n",
       "<tr><td>train_model_7fd9c_00003</td><td>TERMINATED</td><td>172.31.138.114:11079</td><td>XGBRegressor(ba_0040</td><td style=\"text-align: right;\">       229</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">         2.53176</td><td style=\"text-align: right;\"> 583.261</td></tr>\n",
       "<tr><td>train_model_7fd9c_00004</td><td>TERMINATED</td><td>172.31.221.253:3999 </td><td>LinearRegression()  </td><td style=\"text-align: right;\">       173</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">         1.8416 </td><td style=\"text-align: right;\"> 950.346</td></tr>\n",
       "<tr><td>train_model_7fd9c_00005</td><td>TERMINATED</td><td>172.31.136.199:12355</td><td>XGBRegressor(ba_0160</td><td style=\"text-align: right;\">       173</td><td style=\"text-align: right;\">     1</td><td style=\"text-align: right;\">         2.02936</td><td style=\"text-align: right;\">2046.04 </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "  </div>\n",
       "</div>\n",
       "<style>\n",
       ".tuneStatus {\n",
       "  color: var(--jp-ui-font-color1);\n",
       "}\n",
       ".tuneStatus .systemInfo {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus td {\n",
       "  white-space: nowrap;\n",
       "}\n",
       ".tuneStatus .trialStatus {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus h3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       ".tuneStatus .hDivider {\n",
       "  border-bottom-width: var(--jp-border-width);\n",
       "  border-bottom-color: var(--jp-border-color0);\n",
       "  border-bottom-style: solid;\n",
       "}\n",
       ".tuneStatus .vDivider {\n",
       "  border-left-width: var(--jp-border-width);\n",
       "  border-left-color: var(--jp-border-color0);\n",
       "  border-left-style: solid;\n",
       "  margin: 0.5em 1em 0.5em 1em;\n",
       "}\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"trialProgress\">\n",
       "  <h3>Trial Progress</h3>\n",
       "  <table>\n",
       "<thead>\n",
       "<tr><th>Trial name             </th><th style=\"text-align: right;\">   error</th><th>should_checkpoint  </th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>train_model_7fd9c_00000</td><td style=\"text-align: right;\"> 500.005</td><td>True               </td></tr>\n",
       "<tr><td>train_model_7fd9c_00001</td><td style=\"text-align: right;\"> 523.611</td><td>True               </td></tr>\n",
       "<tr><td>train_model_7fd9c_00002</td><td style=\"text-align: right;\"> 568.826</td><td>True               </td></tr>\n",
       "<tr><td>train_model_7fd9c_00003</td><td style=\"text-align: right;\"> 583.261</td><td>True               </td></tr>\n",
       "<tr><td>train_model_7fd9c_00004</td><td style=\"text-align: right;\"> 950.346</td><td>True               </td></tr>\n",
       "<tr><td>train_model_7fd9c_00005</td><td style=\"text-align: right;\">2046.04 </td><td>True               </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "</div>\n",
       "<style>\n",
       ".trialProgress {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  color: var(--jp-ui-font-color1);\n",
       "}\n",
       ".trialProgress h3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       ".trialProgress td {\n",
       "  white-space: nowrap;\n",
       "}\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-01-10 16:26:11,740\tINFO tune.py:762 -- Total run time: 22.07 seconds (20.27 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of models: 6\n",
      "TOTAL TIME TAKEN: 0.37 minutes\n"
     ]
    }
   ],
   "source": [
    "############\n",
    "# STEP 2. Customize distributed compute scaling.\n",
    "############\n",
    "# Use Ray Tune default resources config which is 1 CPU core for each task.\n",
    "\n",
    "############\n",
    "# STEP 3. Define a search space dict of all config parameters.\n",
    "############\n",
    "search_space = {\n",
    "    \"algorithm\": tune.grid_search(\n",
    "        [LinearRegression(fit_intercept=True), xgb.XGBRegressor(max_depth=4)]\n",
    "    ),\n",
    "    \"location\": tune.grid_search(sample_locations),\n",
    "}\n",
    "\n",
    "# Optional STEP 4. Specify the hyperparameter tuning search strategy.\n",
    "\n",
    "############\n",
    "# STEP 5. Run the experiment with Ray Tune APIs.\n",
    "# https://docs.ray.io/en/latest/tune/examples/tune-pytorch-lightning.html\n",
    "############\n",
    "start = time.time()\n",
    "\n",
    "# Define a tuner object.\n",
    "tuner = tune.Tuner(\n",
    "    train_model,\n",
    "    param_space=search_space,\n",
    "    run_config=train.RunConfig(\n",
    "        # redirect logs to relative path instead of default ~/ray_results/\n",
    "        name=\"batch_tuning\",\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Fit the tuner object.\n",
    "results = tuner.fit()\n",
    "\n",
    "total_time_taken = time.time() - start\n",
    "print(f\"Total number of models: {len(results)}\")\n",
    "print(f\"TOTAL TIME TAKEN: {total_time_taken/60:.2f} minutes\")\n",
    "\n",
    "# Total number of models: 6\n",
    "# TOTAL TIME TAKEN: 0.37 minutes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ae0b413",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "**After the Tune experiment has finished, select the best model per dropoff location.**\n",
    "\n",
    "We can assemble the {doc}`Tune results </tune/examples/tune_analyze_results>` into a pandas dataframe, then sort by minimum error, to select the best model per dropoff location."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "945b3bc2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>location_id</th>\n",
       "      <th>error</th>\n",
       "      <th>algorithm</th>\n",
       "      <th>checkpoint</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>141</td>\n",
       "      <td>500.005318</td>\n",
       "      <td>LinearRegression()</td>\n",
       "      <td>Checkpoint(local_path=/home/ray/christy-air/fo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>141</td>\n",
       "      <td>523.610705</td>\n",
       "      <td>XGBRegressor(base_score=0.5, booster='gbtree',...</td>\n",
       "      <td>Checkpoint(local_path=/home/ray/christy-air/fo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>229</td>\n",
       "      <td>568.826123</td>\n",
       "      <td>LinearRegression()</td>\n",
       "      <td>Checkpoint(local_path=/home/ray/christy-air/fo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>229</td>\n",
       "      <td>583.261077</td>\n",
       "      <td>XGBRegressor(base_score=0.5, booster='gbtree',...</td>\n",
       "      <td>Checkpoint(local_path=/home/ray/christy-air/fo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>173</td>\n",
       "      <td>950.345817</td>\n",
       "      <td>LinearRegression()</td>\n",
       "      <td>Checkpoint(local_path=/home/ray/christy-air/fo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>173</td>\n",
       "      <td>2046.043927</td>\n",
       "      <td>XGBRegressor(base_score=0.5, booster='gbtree',...</td>\n",
       "      <td>Checkpoint(local_path=/home/ray/christy-air/fo...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   location_id        error  \\\n",
       "0          141   500.005318   \n",
       "1          141   523.610705   \n",
       "2          229   568.826123   \n",
       "3          229   583.261077   \n",
       "4          173   950.345817   \n",
       "5          173  2046.043927   \n",
       "\n",
       "                                           algorithm  \\\n",
       "0                                 LinearRegression()   \n",
       "1  XGBRegressor(base_score=0.5, booster='gbtree',...   \n",
       "2                                 LinearRegression()   \n",
       "3  XGBRegressor(base_score=0.5, booster='gbtree',...   \n",
       "4                                 LinearRegression()   \n",
       "5  XGBRegressor(base_score=0.5, booster='gbtree',...   \n",
       "\n",
       "                                          checkpoint  \n",
       "0  Checkpoint(local_path=/home/ray/christy-air/fo...  \n",
       "1  Checkpoint(local_path=/home/ray/christy-air/fo...  \n",
       "2  Checkpoint(local_path=/home/ray/christy-air/fo...  \n",
       "3  Checkpoint(local_path=/home/ray/christy-air/fo...  \n",
       "4  Checkpoint(local_path=/home/ray/christy-air/fo...  \n",
       "5  Checkpoint(local_path=/home/ray/christy-air/fo...  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# get a list of training loss errors\n",
    "errors = [i.metrics.get(\"error\", 10000.0) for i in results]\n",
    "\n",
    "# get a list of checkpoints\n",
    "checkpoints = [i.checkpoint for i in results]\n",
    "\n",
    "# get a list of locations\n",
    "locations = [i.config[\"location\"] for i in results]\n",
    "\n",
    "# get a list of model params\n",
    "algorithms = [i.config[\"algorithm\"] for i in results]\n",
    "\n",
    "# Assemble a pandas dataframe from Tune results\n",
    "results_df = pd.DataFrame(\n",
    "    zip(locations, errors, algorithms, checkpoints),\n",
    "    columns=[\"location_id\", \"error\", \"algorithm\", \"checkpoint\"],\n",
    ")\n",
    "results_df.head(8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d5d049af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>error</th>\n",
       "      <th>algorithm</th>\n",
       "      <th>checkpoint</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>location_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>141</th>\n",
       "      <td>500.005318</td>\n",
       "      <td>LinearRegression()</td>\n",
       "      <td>Checkpoint(local_path=/home/ray/christy-air/fo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>229</th>\n",
       "      <td>568.826123</td>\n",
       "      <td>LinearRegression()</td>\n",
       "      <td>Checkpoint(local_path=/home/ray/christy-air/fo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>173</th>\n",
       "      <td>950.345817</td>\n",
       "      <td>LinearRegression()</td>\n",
       "      <td>Checkpoint(local_path=/home/ray/christy-air/fo...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  error           algorithm  \\\n",
       "location_id                                   \n",
       "141          500.005318  LinearRegression()   \n",
       "229          568.826123  LinearRegression()   \n",
       "173          950.345817  LinearRegression()   \n",
       "\n",
       "                                                    checkpoint  \n",
       "location_id                                                     \n",
       "141          Checkpoint(local_path=/home/ray/christy-air/fo...  \n",
       "229          Checkpoint(local_path=/home/ray/christy-air/fo...  \n",
       "173          Checkpoint(local_path=/home/ray/christy-air/fo...  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only 1 model per location_id with minimum error\n",
    "final_df = results_df.copy()\n",
    "final_df = final_df.loc[(final_df.error > 0), :]\n",
    "final_df = final_df.loc[final_df.groupby(\"location_id\")[\"error\"].idxmin()]\n",
    "final_df.sort_values(by=[\"error\"], inplace=True)\n",
    "final_df.set_index(\"location_id\", inplace=True, drop=True)\n",
    "final_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "00ec0f8d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "algorithm         \n",
       "LinearRegression()    1.0\n",
       "dtype: float64"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_df[[\"algorithm\"]].astype(\"str\").value_counts(normalize=True)\n",
    "\n",
    "# 0.67 XGB\n",
    "# 0.33 Linear Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbc62da1",
   "metadata": {},
   "source": [
    "(load_checkpoint)=\n",
    "## Load a model from checkpoint and perform batch prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "249bf4d3",
   "metadata": {},
   "source": [
    "```{tip}\n",
    "Ray Predictors make batch inference easy since they have internal logic to parallelize the inference.\n",
    "```\n",
    "\n",
    "Finally, we will restore the best and worst models from checkpoint and make predictions. \n",
    "\n",
    "- We will easily obtain Checkpoint objects from the Tune results. \n",
    "- We will restore a regression model directly from checkpoint, and demonstrate it can be used for prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ed0e8140",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "141"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Choose a dropoff location\n",
    "sample_location_id = final_df.index[0]\n",
    "sample_location_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "221cb8ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "algorithm type:: <class 'sklearn.linear_model._base.LinearRegression'>\n",
      "checkpoint type:: <class 'ray.air.checkpoint.Checkpoint'>\n"
     ]
    }
   ],
   "source": [
    "# Get the algorithm used\n",
    "sample_algorithm = final_df.loc[[sample_location_id]].algorithm.values[0]\n",
    "print(f\"algorithm type:: {type(sample_algorithm)}\")\n",
    "\n",
    "# Get a checkpoint directly from the pandas dataframe of Tune results\n",
    "checkpoint = final_df.checkpoint[sample_location_id]\n",
    "print(f\"checkpoint type:: {type(checkpoint)}\")\n",
    "\n",
    "# Restore a model from checkpoint\n",
    "with checkpoint.as_directory() as tmpdir:\n",
    "    with open(os.path.join(tmpdir, \"ckpt.pkl\"), \"rb\") as fin:\n",
    "        state_dict = pickle.load(fin)\n",
    "sample_model = state_dict[\"model\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "12770a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create some test data\n",
    "df_list = [read_data(f, sample_location_id) for f in s3_files[:1]]\n",
    "df_raw = pd.concat(df_list, ignore_index=True)\n",
    "df = transform_df(df_raw)\n",
    "_, test_df = train_test_split(df, test_size=0.2, shuffle=True)\n",
    "test_X = test_df[[\"passenger_count\", \"trip_distance\"]]\n",
    "test_y = np.array(test_df.trip_duration)  # actual values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a4e1ce5a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pred_y</th>\n",
       "      <th>trip_duration</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1153.574219</td>\n",
       "      <td>1174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>870.131592</td>\n",
       "      <td>299</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1065.683105</td>\n",
       "      <td>1206</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>591.070801</td>\n",
       "      <td>566</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>766.853149</td>\n",
       "      <td>630</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1037.557861</td>\n",
       "      <td>852</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1540.295410</td>\n",
       "      <td>1596</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>827.835510</td>\n",
       "      <td>801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1871.982422</td>\n",
       "      <td>1363</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>960.105408</td>\n",
       "      <td>715</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        pred_y  trip_duration\n",
       "0  1153.574219           1174\n",
       "1   870.131592            299\n",
       "2  1065.683105           1206\n",
       "3   591.070801            566\n",
       "4   766.853149            630\n",
       "5  1037.557861            852\n",
       "6  1540.295410           1596\n",
       "7   827.835510            801\n",
       "8  1871.982422           1363\n",
       "9   960.105408            715"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Perform batch prediction using restored model from checkpoint\n",
    "pred_y = sample_model.predict(test_X)\n",
    "\n",
    "# Zip together predictions and actuals to visualize\n",
    "pd.DataFrame(zip(pred_y, test_y), columns=[\"pred_y\", TARGET])[0:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad2ef857",
   "metadata": {},
   "source": [
    "**Compare validation and test error.**\n",
    "\n",
    "During model training we reported error on \"validation\" data (random sample). Below, we will report error on a pretend \"test\" data set (a different random sample).\n",
    "\n",
    "Do a quick validation that both errors are reasonably close together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "89cb9b79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test error: 513.4911755733472\n"
     ]
    }
   ],
   "source": [
    "# Evaluate restored model on test data.\n",
    "error = sklearn.metrics.mean_absolute_error(test_y, pred_y)\n",
    "print(f\"Test error: {error}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f80b8a57",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation error: 500.0053176600036\n"
     ]
    }
   ],
   "source": [
    "# Compare test error with training validation error\n",
    "print(f\"Validation error: {final_df.error[sample_location_id]}\")\n",
    "\n",
    "# Validation and test errors should be reasonably close together."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orphan": true,
  "vscode": {
   "interpreter": {
    "hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}