{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "20cc39b3", "metadata": {}, "outputs": [], "source": [ "from pprint import pprint\n", "\n", "from sklearn.ensemble import RandomForestRegressor\n", "\n", "from mlflow import MlflowClient" ] }, { "cell_type": "markdown", "id": "bd00304d", "metadata": {}, "source": [ "### Initializing the MLflow Client\n", "\n", "Depending on where you are running this notebook, your configuration may vary for how you initialize the MLflow Client in the following cell. \n", "\n", "For this example, we're using a locally running tracking server, but other options are available (The easiest is to use the free managed service within [Databricks Community Edition](https://community.cloud.databricks.com/)). \n", "\n", "Please see [the guide to running notebooks here](https://www.mlflow.org/docs/latest/getting-started/running-notebooks/index.html) for more information on setting the tracking server uri and configuring access to either managed or self-managed MLflow tracking servers." ] }, { "cell_type": "code", "execution_count": 2, "id": "ac741989", "metadata": {}, "outputs": [], "source": [ "# NOTE: review the links mentioned above for guidance on connecting to a managed tracking server, such as the free Databricks Community Edition\n", "\n", "client = MlflowClient(tracking_uri=\"http://127.0.0.1:8080\")" ] }, { "cell_type": "markdown", "id": "6129354a", "metadata": {}, "source": [ "#### Search Experiments with the MLflow Client API\n", "\n", "Let's take a look at the Default Experiment that is created for us.\n", "\n", "This safe 'fallback' experiment will store Runs that we create if we don't specify a \n", "new experiment. 
" ] }, { "cell_type": "code", "execution_count": 3, "id": "f208d9b8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[]\n" ] } ], "source": [ "# Search experiments without providing query terms behaves effectively as a 'list' action\n", "\n", "all_experiments = client.search_experiments()\n", "\n", "print(all_experiments)" ] }, { "cell_type": "code", "execution_count": 4, "id": "9b1e1914", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'lifecycle_stage': 'active', 'name': 'Default'}\n" ] } ], "source": [ "# Extract the experiment name and lifecycle_stage\n", "\n", "default_experiment = [\n", " {\"name\": experiment.name, \"lifecycle_stage\": experiment.lifecycle_stage}\n", " for experiment in all_experiments\n", " if experiment.name == \"Default\"\n", "][0]\n", "\n", "pprint(default_experiment)" ] }, { "cell_type": "markdown", "id": "81c37836", "metadata": {}, "source": [ "### Creating a new Experiment\n", "\n", "In this section, we'll:\n", "\n", "* create a new MLflow Experiment\n", "* apply metadata in the form of Experiment Tags" ] }, { "cell_type": "code", "execution_count": 5, "id": "b07c851f", "metadata": {}, "outputs": [], "source": [ "experiment_description = (\n", " \"This is the grocery forecasting project. 
\"\n", " \"This experiment contains the produce models for apples.\"\n", ")\n", "\n", "experiment_tags = {\n", " \"project_name\": \"grocery-forecasting\",\n", " \"store_dept\": \"produce\",\n", " \"team\": \"stores-ml\",\n", " \"project_quarter\": \"Q3-2023\",\n", " \"mlflow.note.content\": experiment_description,\n", "}\n", "\n", "produce_apples_experiment = client.create_experiment(name=\"Apple_Models\", tags=experiment_tags)" ] }, { "cell_type": "code", "execution_count": 6, "id": "3858e72a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Use search_experiments() to search on the project_name tag key\n", "\n", "apples_experiment = client.search_experiments(\n", " filter_string=\"tags.`project_name` = 'grocery-forecasting'\"\n", ")\n", "\n", "pprint(apples_experiment[0])" ] }, { "cell_type": "code", "execution_count": 7, "id": "181a5545", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "stores-ml\n" ] } ], "source": [ "# Access individual tag data\n", "\n", "print(apples_experiment[0].tags[\"team\"])" ] }, { "cell_type": "markdown", "id": "91c66551", "metadata": {}, "source": [ "### Running our first model training\n", "\n", "In this section, we'll:\n", "\n", "* create a synthetic data set that is relevant to a simple demand forecasting task\n", "* start an MLflow run\n", "* log metrics, parameters, and tags to the run\n", "* save the model to the run\n", "* register the model during model logging" ] }, { "cell_type": "markdown", "id": "5faffa16", "metadata": {}, "source": [ "#### Synthetic data generator for demand of apples\n", "\n", "Keep in mind that this is purely for demonstration purposes. \n", "\n", "The demand value is purely artificial and is deliberately covariant with the features. This is not a particularly realistic real-world scenario (if it were, we wouldn't need Data Scientists!). 
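To see what it means for the demand to be covariant with a feature, the minimal sketch below uses only the standard library. It borrows the base demand of 1000, the weekend uplift of 300 and the noise scale of 50 from the generator in the next cell, and drops everything else. `pearson` is a hypothetical helper, and the 7-day weekend cycle is a simplifying assumption.

```python
import math
import random


def pearson(xs, ys):
    # Pearson correlation coefficient of two equal-length sequences.
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


rng = random.Random(9999)  # fixed seed so the sketch is reproducible

# Assumed toy setup: days 5 and 6 of every 7-day cycle are the weekend
weekend = [1 if day % 7 >= 5 else 0 for day in range(1000)]

# Demand = base + weekend uplift + Gaussian noise
demand = [1000 + 300 * flag + rng.gauss(0, 50) for flag in weekend]

print(round(pearson(weekend, demand), 3))
```

Because the uplift is large relative to the noise, the printed correlation should sit close to 1, which is why a model can recover the relationship so easily from this kind of data.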
" ] }, { "cell_type": "code", "execution_count": 8, "id": "2268a1cb", "metadata": {}, "outputs": [], "source": [ "from datetime import datetime, timedelta\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "\n", "def generate_apple_sales_data_with_promo_adjustment(base_demand: int = 1000, n_rows: int = 5000):\n", " \"\"\"\n", " Generates a synthetic dataset for predicting apple sales demand with seasonality and inflation.\n", "\n", " This function creates a pandas DataFrame with features relevant to apple sales.\n", " The features include date, average_temperature, rainfall, weekend flag, holiday flag,\n", " promotional flag, price_per_kg, and the previous day's demand. The target variable,\n", " 'demand', is generated based on a combination of these features with some added noise.\n", "\n", " Args:\n", " base_demand (int, optional): Base demand for apples. Defaults to 1000.\n", " n_rows (int, optional): Number of rows (days) of data to generate. Defaults to 5000.\n", "\n", " Returns:\n", " pd.DataFrame: DataFrame with features and target variable for apple sales prediction.\n", "\n", " Example:\n", " >>> df = generate_apple_sales_data_with_seasonality(base_demand=1200, n_rows=6000)\n", " >>> df.head()\n", " \"\"\"\n", "\n", " # Set seed for reproducibility\n", " np.random.seed(9999)\n", "\n", " # Create date range\n", " dates = [datetime.now() - timedelta(days=i) for i in range(n_rows)]\n", " dates.reverse()\n", "\n", " # Generate features\n", " df = pd.DataFrame(\n", " {\n", " \"date\": dates,\n", " \"average_temperature\": np.random.uniform(10, 35, n_rows),\n", " \"rainfall\": np.random.exponential(5, n_rows),\n", " \"weekend\": [(date.weekday() >= 5) * 1 for date in dates],\n", " \"holiday\": np.random.choice([0, 1], n_rows, p=[0.97, 0.03]),\n", " \"price_per_kg\": np.random.uniform(0.5, 3, n_rows),\n", " \"month\": [date.month for date in dates],\n", " }\n", " )\n", "\n", " # Introduce inflation over time (years)\n", " 
df[\"inflation_multiplier\"] = 1 + (df[\"date\"].dt.year - df[\"date\"].dt.year.min()) * 0.03\n", "\n", " # Incorporate seasonality due to apple harvests\n", " df[\"harvest_effect\"] = np.sin(2 * np.pi * (df[\"month\"] - 3) / 12) + np.sin(\n", " 2 * np.pi * (df[\"month\"] - 9) / 12\n", " )\n", "\n", " # Modify the price_per_kg based on harvest effect\n", " df[\"price_per_kg\"] = df[\"price_per_kg\"] - df[\"harvest_effect\"] * 0.5\n", "\n", " # Adjust promo periods to coincide with periods lagging peak harvest by 1 month\n", " peak_months = [4, 10] # months following the peak availability\n", " df[\"promo\"] = np.where(\n", " df[\"month\"].isin(peak_months),\n", " 1,\n", " np.random.choice([0, 1], n_rows, p=[0.85, 0.15]),\n", " )\n", "\n", " # Generate target variable based on features\n", " base_price_effect = -df[\"price_per_kg\"] * 50\n", " seasonality_effect = df[\"harvest_effect\"] * 50\n", " promo_effect = df[\"promo\"] * 200\n", "\n", " df[\"demand\"] = (\n", " base_demand\n", " + base_price_effect\n", " + seasonality_effect\n", " + promo_effect\n", " + df[\"weekend\"] * 300\n", " + np.random.normal(0, 50, n_rows) # add random noise\n", " ) * df[\"inflation_multiplier\"] # scale by the inflation multiplier\n", "\n", " # Add previous day's demand\n", " df[\"previous_days_demand\"] = df[\"demand\"].shift(1)\n", " # Back-fill the first row, which has no previous day (assignment avoids the deprecated fillna(method=...) form)\n", " df[\"previous_days_demand\"] = df[\"previous_days_demand\"].bfill()\n", "\n", " # Drop temporary columns\n", " df.drop(columns=[\"inflation_multiplier\", \"harvest_effect\", \"month\"], inplace=True)\n", "\n", " return df" ] }, { "cell_type": "code", "execution_count": 9, "id": "2924d135", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr style=\"text-align: right;\"><th></th><th>date</th><th>average_temperature</th><th>rainfall</th><th>weekend</th><th>holiday</th><th>price_per_kg</th><th>promo</th><th>demand</th><th>previous_days_demand</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>980</th><td>2023-09-14 11:13:56.948267</td><td>34.130183</td><td>1.454065</td><td>0</td><td>0</td><td>1.449177</td><td>0</td><td>971.802447</td><td>1001.085782</td></tr>\n",
"<tr><th>981</th><td>2023-09-15 11:13:56.948267</td><td>32.353643</td><td>9.462859</td><td>0</td><td>0</td><td>2.856503</td><td>0</td><td>818.951553</td><td>971.802447</td></tr>\n",
"<tr><th>982</th><td>2023-09-16 11:13:56.948266</td><td>18.816833</td><td>0.391470</td><td>1</td><td>0</td><td>1.326429</td><td>0</td><td>1281.352029</td><td>818.951553</td></tr>\n",
"<tr><th>983</th><td>2023-09-17 11:13:56.948265</td><td>34.533012</td><td>2.120477</td><td>1</td><td>0</td><td>0.970131</td><td>0</td><td>1357.385504</td><td>1281.352029</td></tr>\n",
"<tr><th>984</th><td>2023-09-18 11:13:56.948265</td><td>23.057202</td><td>2.365705</td><td>0</td><td>0</td><td>1.049931</td><td>0</td><td>991.427049</td><td>1357.385504</td></tr>\n",
"<tr><th>985</th><td>2023-09-19 11:13:56.948264</td><td>34.810165</td><td>3.089005</td><td>0</td><td>0</td><td>2.035149</td><td>0</td><td>974.971149</td><td>991.427049</td></tr>\n",
"<tr><th>986</th><td>2023-09-20 11:13:56.948263</td><td>29.208905</td><td>3.673292</td><td>0</td><td>0</td><td>2.518098</td><td>0</td><td>1056.249547</td><td>974.971149</td></tr>\n",
"<tr><th>987</th><td>2023-09-21 11:13:56.948263</td><td>16.428676</td><td>4.077782</td><td>0</td><td>0</td><td>1.268979</td><td>0</td><td>1063.118915</td><td>1056.249547</td></tr>\n",
"<tr><th>988</th><td>2023-09-22 11:13:56.948262</td><td>32.067512</td><td>2.734454</td><td>0</td><td>0</td><td>0.762317</td><td>0</td><td>1040.492007</td><td>1063.118915</td></tr>\n",
"<tr><th>989</th><td>2023-09-23 11:13:56.948261</td><td>31.938203</td><td>13.883486</td><td>1</td><td>0</td><td>1.153301</td><td>0</td><td>1285.040470</td><td>1040.492007</td></tr>\n",
"<tr><th>990</th><td>2023-09-24 11:13:56.948261</td><td>18.024055</td><td>7.544061</td><td>1</td><td>0</td><td>0.610703</td><td>0</td><td>1366.644564</td><td>1285.040470</td></tr>\n",
"<tr><th>991</th><td>2023-09-25 11:13:56.948260</td><td>20.681067</td><td>18.820490</td><td>0</td><td>0</td><td>1.533488</td><td>0</td><td>973.934924</td><td>1366.644564</td></tr>\n",
"<tr><th>992</th><td>2023-09-26 11:13:56.948259</td><td>16.010132</td><td>7.705941</td><td>0</td><td>0</td><td>1.632498</td><td>1</td><td>1188.291256</td><td>973.934924</td></tr>\n",
"<tr><th>993</th><td>2023-09-27 11:13:56.948259</td><td>18.766455</td><td>6.274840</td><td>0</td><td>0</td><td>2.806554</td><td>0</td><td>930.089438</td><td>1188.291256</td></tr>\n",
"<tr><th>994</th><td>2023-09-28 11:13:56.948258</td><td>27.948793</td><td>23.705246</td><td>0</td><td>0</td><td>0.829464</td><td>0</td><td>1060.576311</td><td>930.089438</td></tr>\n",
"<tr><th>995</th><td>2023-09-29 11:13:56.948257</td><td>28.661072</td><td>10.329865</td><td>0</td><td>0</td><td>2.290591</td><td>0</td><td>910.690776</td><td>1060.576311</td></tr>\n",
"<tr><th>996</th><td>2023-09-30 11:13:56.948256</td><td>10.821693</td><td>3.575645</td><td>1</td><td>0</td><td>0.897473</td><td>0</td><td>1306.363801</td><td>910.690776</td></tr>\n",
"<tr><th>997</th><td>2023-10-01 11:13:56.948256</td><td>21.108560</td><td>6.221089</td><td>1</td><td>0</td><td>1.093864</td><td>1</td><td>1564.422372</td><td>1306.363801</td></tr>\n",
"<tr><th>998</th><td>2023-10-02 11:13:56.948254</td><td>29.451301</td><td>5.021463</td><td>0</td><td>0</td><td>2.493085</td><td>1</td><td>1164.303256</td><td>1564.422372</td></tr>\n",
"<tr><th>999</th><td>2023-10-03 11:13:56.948248</td><td>19.261458</td><td>0.438381</td><td>0</td><td>0</td><td>2.610422</td><td>1</td><td>1067.963448</td><td>1164.303256</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " date average_temperature rainfall weekend \n", "980 2023-09-14 11:13:56.948267 34.130183 1.454065 0 \\\n", "981 2023-09-15 11:13:56.948267 32.353643 9.462859 0 \n", "982 2023-09-16 11:13:56.948266 18.816833 0.391470 1 \n", "983 2023-09-17 11:13:56.948265 34.533012 2.120477 1 \n", "984 2023-09-18 11:13:56.948265 23.057202 2.365705 0 \n", "985 2023-09-19 11:13:56.948264 34.810165 3.089005 0 \n", "986 2023-09-20 11:13:56.948263 29.208905 3.673292 0 \n", "987 2023-09-21 11:13:56.948263 16.428676 4.077782 0 \n", "988 2023-09-22 11:13:56.948262 32.067512 2.734454 0 \n", "989 2023-09-23 11:13:56.948261 31.938203 13.883486 1 \n", "990 2023-09-24 11:13:56.948261 18.024055 7.544061 1 \n", "991 2023-09-25 11:13:56.948260 20.681067 18.820490 0 \n", "992 2023-09-26 11:13:56.948259 16.010132 7.705941 0 \n", "993 2023-09-27 11:13:56.948259 18.766455 6.274840 0 \n", "994 2023-09-28 11:13:56.948258 27.948793 23.705246 0 \n", "995 2023-09-29 11:13:56.948257 28.661072 10.329865 0 \n", "996 2023-09-30 11:13:56.948256 10.821693 3.575645 1 \n", "997 2023-10-01 11:13:56.948256 21.108560 6.221089 1 \n", "998 2023-10-02 11:13:56.948254 29.451301 5.021463 0 \n", "999 2023-10-03 11:13:56.948248 19.261458 0.438381 0 \n", "\n", " holiday price_per_kg promo demand previous_days_demand \n", "980 0 1.449177 0 971.802447 1001.085782 \n", "981 0 2.856503 0 818.951553 971.802447 \n", "982 0 1.326429 0 1281.352029 818.951553 \n", "983 0 0.970131 0 1357.385504 1281.352029 \n", "984 0 1.049931 0 991.427049 1357.385504 \n", "985 0 2.035149 0 974.971149 991.427049 \n", "986 0 2.518098 0 1056.249547 974.971149 \n", "987 0 1.268979 0 1063.118915 1056.249547 \n", "988 0 0.762317 0 1040.492007 1063.118915 \n", "989 0 1.153301 0 1285.040470 1040.492007 \n", "990 0 0.610703 0 1366.644564 1285.040470 \n", "991 0 1.533488 0 973.934924 1366.644564 \n", "992 0 1.632498 1 1188.291256 973.934924 \n", "993 0 2.806554 0 930.089438 1188.291256 \n", "994 0 0.829464 0 1060.576311 930.089438 \n", 
"995 0 2.290591 0 910.690776 1060.576311 \n", "996 0 0.897473 0 1306.363801 910.690776 \n", "997 0 1.093864 1 1564.422372 1306.363801 \n", "998 0 2.493085 1 1164.303256 1564.422372 \n", "999 0 2.610422 1 1067.963448 1164.303256 " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Generate the dataset!\n", "\n", "data = generate_apple_sales_data_with_promo_adjustment(base_demand=1_000, n_rows=1_000)\n", "\n", "data[-20:]" ] }, { "cell_type": "markdown", "id": "e076a312", "metadata": {}, "source": [ "### Train and log the model\n", "\n", "We're now ready to import our model class and train a ``RandomForestRegressor``" ] }, { "cell_type": "code", "execution_count": 10, "id": "6e354900", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n", "from sklearn.model_selection import train_test_split\n", "\n", "import mlflow\n", "\n", "# Use the fluent API to set the tracking uri and the active experiment\n", "mlflow.set_tracking_uri(\"http://127.0.0.1:8080\")\n", "\n", "# Sets the current active experiment to the \"Apple_Models\" experiment and returns the Experiment metadata\n", "apple_experiment = mlflow.set_experiment(\"Apple_Models\")\n", "\n", "# Define a run name for this iteration of training.\n", "# If this is not set, a unique name will be auto-generated for your run.\n", "run_name = \"apples_rf_test\"\n", "\n", "# Define an artifact path that the model will be saved to.\n", "artifact_path = \"rf_apples\"" ] }, { "cell_type": "code", "execution_count": 11, "id": "ae02e54b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/mlflow/models/signature.py:333: UserWarning: Hint: Inferred schema contains integer column(s). Integer columns in Python cannot represent missing values. 
If your input data contains missing values at inference time, it will be encoded as floats and will cause a schema enforcement error. The best way to avoid this problem is to infer the model schema based on a realistic data sample (training dataset) that includes missing values. Alternatively, you can declare integer columns as doubles (float64) whenever these columns may have missing values. See `Handling Integers With Missing Values `_ for more details.\n", " input_schema = _infer_schema(input_ex)\n", "/Users/benjamin.wilson/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils.\n", " warnings.warn(\"Setuptools is replacing distutils.\")\n" ] } ], "source": [ "# Separate the features from the target, dropping the date column, which is not used as a feature\n", "X = data.drop(columns=[\"date\", \"demand\"])\n", "y = data[\"demand\"]\n", "\n", "# Split the data into training and validation sets\n", "X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n", "\n", "params = {\n", " \"n_estimators\": 100,\n", " \"max_depth\": 6,\n", " \"min_samples_split\": 10,\n", " \"min_samples_leaf\": 4,\n", " \"bootstrap\": True,\n", " \"oob_score\": False,\n", " \"random_state\": 888,\n", "}\n", "\n", "# Instantiate the RandomForestRegressor with the hyperparameters above\n", "rf = RandomForestRegressor(**params)\n", "\n", "# Fit the model on the training data\n", "rf.fit(X_train, y_train)\n", "\n", "# Predict on the validation set\n", "y_pred = rf.predict(X_val)\n", "\n", "# Calculate error metrics\n", "mae = mean_absolute_error(y_val, y_pred)\n", "mse = mean_squared_error(y_val, y_pred)\n", "rmse = np.sqrt(mse)\n", "r2 = r2_score(y_val, y_pred)\n", "\n", "# Assemble the metrics we're going to write into a collection\n", "metrics = {\"mae\": mae, \"mse\": mse, \"rmse\": rmse, \"r2\": r2}\n", "\n", "# Initiate the MLflow run context\n", "with mlflow.start_run(run_name=run_name) as run:\n", " # Log the 
parameters used for the model fit\n", " mlflow.log_params(params)\n", "\n", " # Log the error metrics that were calculated during validation\n", " mlflow.log_metrics(metrics)\n", "\n", " # Log an instance of the trained model for later use\n", " mlflow.sklearn.log_model(sk_model=rf, input_example=X_val, artifact_path=artifact_path)" ] }, { "cell_type": "markdown", "id": "84c06abe", "metadata": {}, "source": [ "#### Success!\n", "\n", "You've just logged your first MLflow model! \n", "\n", "Navigate to the MLflow UI to see the run that was just created (named \"apples_rf_test\", logged to the Experiment \"Apple_Models\"). " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 5 }