{ "cells": [ { "cell_type": "markdown", "id": "594dfbd3", "metadata": {}, "source": [ "# ⭐ Scaling Machine Learning in Three Weeks course ⭐\n", "\n", "## Intro to MLflow\n", "\n", "In this exercise, you will use:\n", "* MLflow\n", "* Tracking of runs and experiments\n", "* The MLflow CLI\n", "* ElasticNet from scikit-learn\n", "* Training of a simple model to understand MLflow's tracking capabilities.\n", "\n", "\n", "This exercise is part of the [Scaling Machine Learning with Spark book](https://learning.oreilly.com/library/view/scaling-machine-learning/9781098106812/)\n", "available on the O'Reilly platform or on [Amazon](https://amzn.to/3WgHQvd)." ] }, { "cell_type": "code", "execution_count": 1, "id": "b89b3a00", "metadata": {}, "outputs": [], "source": [ "# The data set used in this example is from http://archive.ics.uci.edu/ml/datasets/Wine+Quality\n", "# P. Cortez, A. Cerdeira, F. Almeida, T. Matos and J. Reis.\n", "# Modeling wine preferences by data mining from physicochemical properties. In Decision Support Systems, Elsevier, 47(4):547-553, 2009.\n", "\n", "import os\n", "import warnings\n", "import sys\n", "\n", "import pandas as pd\n", "import numpy as np\n", "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.linear_model import ElasticNet\n", "from urllib.parse import urlparse\n", "import mlflow\n", "import mlflow.sklearn\n", "\n", "import logging" ] }, { "cell_type": "code", "execution_count": 2, "id": "ab0d501e", "metadata": {}, "outputs": [], "source": [ "logging.basicConfig(level=logging.WARNING)\n", "logger = logging.getLogger(__name__)" ] }, { "cell_type": "markdown", "id": "811b51c7", "metadata": {}, "source": [ "## Set eval metrics\n", "We are using RMSE, MAE, and R2.\n", "\n", "\n", "RMSE - Root Mean Squared Error\n", "\n", "MAE - Mean Absolute Error\n", "\n", "**RMSE and MAE** - Lower values of MAE, MSE, and RMSE imply higher accuracy of a regression 
model.\n", "\n", "> In our case, ElasticNet is part of the linear regression family, where the x (input) and y (output) are assumed to have a linear relationship.\n", "\n", "\n", "\n", "**R2** - A higher value of R-squared is considered desirable. R-squared and adjusted R-squared are used to explain how well the independent variables in the linear regression model explain the variability in the dependent variable.\n", "\n", "### MAE\n", "Mean Absolute Error - In the context of machine learning, absolute error refers to the magnitude of the difference between the prediction for an observation and the true value of that observation. MAE is the mean of these absolute errors over all observations.\n", "\n", "![text](../figures/mae.jpeg)" ] }, { "cell_type": "markdown", "id": "58cb4edd", "metadata": {}, "source": [ "### RMSE\n", "Root Mean Squared Error - The square root of the mean of the squared differences between the values predicted by a model and the actual values. Because the errors are squared before averaging, large errors are penalized more heavily than in MAE.\n", "\n", "It provides an estimate of how well the model is able to predict the target value (accuracy).\n" ] }, { "cell_type": "markdown", "id": "834128c7", "metadata": {}, "source": [ "### R2 or R-squared\n", "\n", "A statistical measure that represents the goodness of fit of a regression model. \n", "\n", "The ideal value for R-squared is **1**. 
\n", "\n", "The closer the value of R-squared is to 1, the better the model fits the data.\n", "\n", "![text](../figures/rsquare.jpeg)" ] }, { "cell_type": "code", "execution_count": 3, "id": "0b3e8c28", "metadata": {}, "outputs": [], "source": [ "def eval_metrics(actual, pred):\n", "    # Return RMSE, MAE, and R2 for the true and predicted values\n", "    rmse = np.sqrt(mean_squared_error(actual, pred))\n", "    mae = mean_absolute_error(actual, pred)\n", "    r2 = r2_score(actual, pred)\n", "    return rmse, mae, r2" ] }, { "cell_type": "code", "execution_count": 4, "id": "f6c65f3d", "metadata": {}, "outputs": [], "source": [ "# Read the wine-quality csv file from path\n", "csv_path = \"../datasets/winequality-red.csv\"\n", "try:\n", "    data = pd.read_csv(csv_path, sep=\";\")\n", "except Exception as e:\n", "    logger.exception(\"Error: %s\", e)\n", "    # Re-raise so later cells do not run with `data` undefined\n", "    raise" ] }, { "cell_type": "code", "execution_count": 5, "id": "343dae80", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 0 | 7.4 | 0.700 | 0.00 | 1.9 | 0.076 | 11.0 | 34.0 | 0.99780 | 3.51 | 0.56 | 9.4 | 5 |\n",
"| 1 | 7.8 | 0.880 | 0.00 | 2.6 | 0.098 | 25.0 | 67.0 | 0.99680 | 3.20 | 0.68 | 9.8 | 5 |\n",
"| 2 | 7.8 | 0.760 | 0.04 | 2.3 | 0.092 | 15.0 | 54.0 | 0.99700 | 3.26 | 0.65 | 9.8 | 5 |\n",
"| 3 | 11.2 | 0.280 | 0.56 | 1.9 | 0.075 | 17.0 | 60.0 | 0.99800 | 3.16 | 0.58 | 9.8 | 6 |\n",
"| 4 | 7.4 | 0.700 | 0.00 | 1.9 | 0.076 | 11.0 | 34.0 | 0.99780 | 3.51 | 0.56 | 9.4 | 5 |\n",
"| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |\n",
"| 1594 | 6.2 | 0.600 | 0.08 | 2.0 | 0.090 | 32.0 | 44.0 | 0.99490 | 3.45 | 0.58 | 10.5 | 5 |\n",
"| 1595 | 5.9 | 0.550 | 0.10 | 2.2 | 0.062 | 39.0 | 51.0 | 0.99512 | 3.52 | 0.76 | 11.2 | 6 |\n",
"| 1596 | 6.3 | 0.510 | 0.13 | 2.3 | 0.076 | 29.0 | 40.0 | 0.99574 | 3.42 | 0.75 | 11.0 | 6 |\n",
"| 1597 | 5.9 | 0.645 | 0.12 | 2.0 | 0.075 | 32.0 | 44.0 | 0.99547 | 3.57 | 0.71 | 10.2 | 5 |\n",
"| 1598 | 6.0 | 0.310 | 0.47 | 3.6 | 0.067 | 18.0 | 42.0 | 0.99549 | 3.39 | 0.66 | 11.0 | 6 |\n",
"\n",
"1599 rows × 12 columns\n", "