{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Lale: Auto-ML and Types for Scikit-learn\n", "\n", "This notebook is an introductory guide to\n", "[Lale](https://github.com/ibm/lale) for scikit-learn users.\n", "[Scikit-learn](https://scikit-learn.org) is a popular, easy-to-use,\n", "and comprehensive data science library for Python. This notebook aims\n", "to show how Lale can make scikit-learn even better in two areas:\n", "auto-ML and type checking. First, if you do not want to manually\n", "select all algorithms or tune all hyperparameters, you can leave it to\n", "Lale to do that for you automatically. Second, when you pass\n", "hyperparameters or datasets to scikit-learn, Lale checks that these\n", "are type-correct. For both auto-ML and type-checking, Lale uses a\n", "single source of truth: machine-readable schemas associated with\n", "scikit-learn compatible transformers and estimators. Rather than\n", "invent a new schema specification language, Lale uses [JSON\n", "Schema](https://json-schema.org/understanding-json-schema/), because\n", "it is popular, widely-supported, and makes it easy to store or send\n", "hyperparameters as JSON objects. Furthermore, by using the same\n", "schemas both for auto-ML and for type-checking, Lale ensures that\n", "auto-ML is consistent with type checking while also reducing the\n", "maintenance burden to a single set of schemas.\n", "\n", "Lale is an open-source Python library and you can install it by doing\n", "`pip install lale`. See\n", "[installation](https://github.com/IBM/lale/blob/master/docs/installation.rst)\n", "for further instructions. Lale uses the term *operator* to refer to\n", "what scikit-learn calls machine-learning transformer or estimator.\n", "Lale provides schemas for 180\n", "[operators](https://github.com/IBM/lale/tree/master/lale/lib). Most of\n", "these operators come from scikit-learn itself, but there are also\n", "operators from other frameworks such as XGBoost or PyTorch.\n", "If Lale does not yet support your favorite operator, you can add it\n", "yourself by following this\n", "[guide](https://github.com/IBM/lale/blob/master/examples/docs_new_operators.ipynb).\n", "If you do add a new operator, please consider contributing it back to\n", "Lale!\n", "\n", "The rest of this notebook first demonstrates auto-ML, then reveals\n", "some of the schemas that make that possible, and finally demonstrates\n", "how to also use the very same schemas for type checking." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Auto-ML with Lale\n", "\n", "Lale serves as an interface for two Auto-ML tasks: hyperparameter tuning\n", "and algorithm selection. Rather than provide new implementations for\n", "these tasks, Lale reuses existing implementations. The next few cells\n", "demonstrate how to use Hyperopt and GridSearchCV from Lale. Lale also\n", "supports additional optimizers, not shown in this notebook. In all\n", "cases, the syntax for specifying the search space is the same.\n", "\n", "### 1.1 Hyperparameter Tuning with Lale and Hyperopt\n", "\n", "Let's start by looking at hyperparameter tuning, which is an important\n", "subtask of auto-ML. To demonstrate it, we first need a dataset.\n", "Therefore, we load the California Housing dataset and display the\n", "first few rows to get a feeling for the data. Lale can process both\n", "Pandas dataframes and Numpy ndarrays; here we use dataframes." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
MedIncHouseAgeAveRoomsAveBedrmsPopulationAveOccupLatitudeLongitudetarget
03.259633.05.0176571.0064212300.03.69181432.71-117.031.030
13.812549.04.4735451.0410051314.01.73809533.77-118.163.821
24.15634.05.6458330.985119915.02.72321434.66-120.481.726
31.942536.04.0028171.0338031418.03.99436632.69-117.110.934
43.554243.06.2684211.134211874.02.30000036.78-119.800.965
\n", "
" ], "text/plain": [ " MedInc HouseAge AveRooms AveBedrms Population AveOccup Latitude \\\n", "0 3.2596 33.0 5.017657 1.006421 2300.0 3.691814 32.71 \n", "1 3.8125 49.0 4.473545 1.041005 1314.0 1.738095 33.77 \n", "2 4.1563 4.0 5.645833 0.985119 915.0 2.723214 34.66 \n", "3 1.9425 36.0 4.002817 1.033803 1418.0 3.994366 32.69 \n", "4 3.5542 43.0 6.268421 1.134211 874.0 2.300000 36.78 \n", "\n", " Longitude target \n", "0 -117.03 1.030 \n", "1 -118.16 3.821 \n", "2 -120.48 1.726 \n", "3 -117.11 0.934 \n", "4 -119.80 0.965 " ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "import lale.datasets\n", "(train_X, train_y), (test_X, test_y) = lale.datasets.california_housing_df()\n", "pd.concat([train_X.head(), train_y.head()], axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the target column is a continuous number, indicating\n", "that this is a regression task. Besides the target, there are eight\n", "feature columns, which are also all continuous numbers. That means\n", "many scikit-learn operators will work out of the box on this data\n", "without needing to preprocess it first. Next, we need to import a few\n", "operators. `PCA` (principal component analysis) is a transformer from\n", "scikit-learn for linear dimensionality reduction.\n", "`DecisionTreeRegressor` is an estimator from scikit-learn that can\n", "predict the target column. `Pipeline` is how scikit-learn composes\n", "operators into a sequence. `Hyperopt` is a Lale wrapper for\n", "the [hyperopt](http://hyperopt.github.io/hyperopt/) auto-ML library.\n", "And finally, `wrap_imported_operators` augments `PCA`, `Tree`, and\n", "`Pipeline` with schemas to enable Lale to tune their hyperparameters." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from sklearn.decomposition import PCA\n", "from sklearn.tree import DecisionTreeRegressor as Tree\n", "from sklearn.pipeline import Pipeline\n", "from lale.lib.lale import Hyperopt\n", "lale.wrap_imported_operators()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we create a two-step pipeline of `PCA` and `Tree`. This code\n", "looks almost like in scikit-learn. The only difference is that since\n", "we want Lale to tune the hyperparameters for us, we do\n", "not specify them by hand. Specifically, we just write `PCA` instead of\n", "`PCA(...)`, omitting the hyperparameters for `PCA`. Analogously, we\n", "just write `Tree` instead of `Tree(...)`, omitting the hyperparameters\n", "for `Tree`. Rather than binding hyperparameters by hand, we leave them\n", "free to be tuned by hyperopt." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "pca_tree_planned = Pipeline(steps=[('tfm', PCA), ('estim', Tree)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use `auto_configure` on the pipeline and pass `Hyperopt` as an optimizer. This will use the pipeline's search space to find the best pipeline. In this case, the search uses 10 trials. Each\n", "trial draws values for the hyperparameters from the ranges specified\n", "by the JSON schemas associated with the operators in the pipeline." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|████████| 10/10 [00:14<00:00, 1.42s/trial, best loss: -0.4141076900047944]\n", "CPU times: user 22.4 s, sys: 40.8 s, total: 1min 3s\n", "Wall time: 15 s\n" ] } ], "source": [ "%%time\n", "pca_tree_trained = pca_tree_planned.auto_configure(\n", " train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=10, verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, Hyperopt uses k-fold cross-validation\n", "to evaluate each trial and a default scoring metric based on the task. The end result is the pipeline that\n", "performed best out of all trials. In addition to the cross-validation score,\n", "we can also evaluate this best pipeline against the test data. We\n", "simply use the existing R2 score metric from scikit-learn for this\n", "purpose." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "R2 score 0.37\n" ] } ], "source": [ "import sklearn.metrics\n", "predicted = pca_tree_trained.predict(test_X)\n", "print(f'R2 score {sklearn.metrics.r2_score(test_y, predicted):.2f}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.2 Inspecting the Results of Automation\n", "\n", "In the previous example, the automation picked hyperparameter values\n", "for PCA and the decision tree. We know the values were valid and we\n", "know how well the pipeline performed with them. But we might also want\n", "to know exactly which values were picked. One way to do that is by\n", "visualizing the pipeline and using tooltips. If you are looking at\n", "this notebook in a viewer that supports tooltips, you can hover the\n", "mouse pointer over either one of the operators to see its\n", "hyperparameters." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "(pipeline visualization: PCA followed by Tree)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pca_tree_trained.visualize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another way to view the results of hyperparameter tuning in Lale is by\n", "pretty-printing the pipeline as Python source code. Calling the\n", "`pretty_print` method with `ipython_display=True` prints the code with\n", "syntax highlighting in a Jupyter notebook. The pretty-printed code\n", "contains the hyperparameters."
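, "\n", "Without `ipython_display=True`, `pretty_print` instead returns the source code as a plain string (a hedged note; this reflects the default behavior in recent Lale versions), which is convenient for saving the best pipeline to a file:\n", "\n", "```python\n", "# sketch: persist the pretty-printed pipeline as a Python script\n", "code_str = pca_tree_trained.pretty_print()\n", "with open('best_pipeline.py', 'w') as f:\n", "    f.write(code_str)\n", "```"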
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "```python\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.decomposition import PCA\n", "from sklearn.tree import DecisionTreeRegressor as Tree\n", "import lale\n", "\n", "lale.wrap_imported_operators()\n", "pca = PCA(svd_solver=\"full\", whiten=True)\n", "tree = Tree(\n", " criterion=\"friedman_mse\",\n", " min_samples_leaf=0.09016751753288961,\n", " min_samples_split=0.47029117142535803,\n", ")\n", "pipeline = Pipeline(steps=[(\"tfm\", pca), (\"estim\", tree)])\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pca_tree_trained.pretty_print(ipython_display=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.3 Hyperparameter Tuning with Lale and GridSearchCV\n", "\n", "Lale supports multiple auto-ML tools, not just hyperopt. For instance,\n", "you can also use\n", "[GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)\n", "from scikit-learn. You could use the exact same `pca_tree_planned`\n", "pipeline for this as we did with the hyperopt tool.\n", "However, to avoid running for a long time, here we simplify the space:\n", "for `PCA`, we bind the `svd_solver` so that only the remaining hyperparameters\n", "are searched, and for `Tree`, we call `freeze_trainable()` to bind\n", "all hyperparameters to their defaults. Lale again uses the schemas\n", "attached to the operators in the pipeline to generate a suitable search grid.\n", "Here, instead of scikit-learn's `Pipeline(...)` API, we use the\n", "`make_pipeline` function. This function exists in both scikit-learn and\n", "Lale; the Lale version yields a Lale pipeline that supports `auto_configure`.\n", "Note that, to be compatible with scikit-learn, `lale.lib.lale.GridSearchCV`\n", "can also take a `param_grid` argument if the user chooses to use a\n", "handcrafted grid instead of the one generated automatically." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 11.2 s, sys: 20 s, total: 31.2 s\n", "Wall time: 6.43 s\n" ] } ], "source": [ "%%time\n", "from lale.lib.lale import GridSearchCV\n", "from lale.operators import make_pipeline\n", "grid_search_planned = make_pipeline(\n", " PCA(svd_solver='auto'), Tree().freeze_trainable())\n", "grid_search_result = grid_search_planned.auto_configure(\n", " train_X, train_y, optimizer=GridSearchCV, cv=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just like we saw earlier with hyperopt, you can use the best pipeline\n", "found for scoring and evaluate the quality of the predictions." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "R2 score 0.49\n" ] } ], "source": [ "predicted = grid_search_result.predict(test_X)\n", "print(f'R2 score {sklearn.metrics.r2_score(test_y, predicted):.2f}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similarly, to inspect the results of grid search, you have the same\n", "options as demonstrated earlier for hyperopt. For instance, you can\n", "pretty-print the best pipeline found by grid search back as Python\n", "source code, and then look at its hyperparameters."
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "(pipeline visualization: PCA followed by Tree)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/markdown": [ "```python\n", "from sklearn.decomposition import PCA\n", "from sklearn.tree import DecisionTreeRegressor as Tree\n", "from lale.operators import make_pipeline\n", "\n", "pca = PCA(whiten=True)\n", "pipeline = make_pipeline(pca, Tree())\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "grid_search_result.visualize()\n", "grid_search_result.pretty_print(ipython_display=True, combinators=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we do not pretty-print with `combinators=False`, the pretty-printed\n", "code is rendered slightly differently, using `>>` instead of `make_pipeline`." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "```python\n", "from sklearn.decomposition import PCA\n", "from sklearn.tree import DecisionTreeRegressor as Tree\n", "import lale\n", "\n", "lale.wrap_imported_operators()\n", "pca = PCA(whiten=True)\n", "pipeline = pca >> Tree()\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "grid_search_result.pretty_print(ipython_display=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.4 Pipeline Combinators\n", "\n", "We already saw that `>>` is syntactic sugar for `make_pipeline`. Lale\n", "refers to `>>` as the *pipe combinator*. Besides `>>`, Lale supports\n", "two additional combinators. Before we introduce them, let's import a\n", "few more things." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from lale.lib.lale import NoOp, ConcatFeatures\n", "from sklearn.linear_model import LinearRegression as LinReg\n", "from xgboost import XGBRegressor as XGBoost\n", "lale.wrap_imported_operators()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lale emulates the scikit-learn APIs for composing pipelines using\n", "functions. We already saw `make_pipeline`. Another function in\n", "scikit-learn is `make_union`, which composes multiple sub-pipelines to\n", "run on the same data, then concatenates the features. In other words,\n", "`make_union` produces a horizontal stack of the data transformed by\n", "its sub-pipelines. To support auto-ML, Lale introduces a third\n", "function, `make_choice`, which does not exist in scikit-learn. The\n", "`make_choice` function specifies an algorithmic choice for auto-ML to\n", "resolve. In other words, `make_choice` creates a search space for\n", "automated algorithm selection."
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "(pipeline visualization: PCA and NoOp feeding into ConcatFeatures, followed by a Choice of Tree, LinReg, or XGBoost)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dag_with_functions = lale.operators.make_pipeline(\n", " lale.operators.make_union(PCA, NoOp),\n", " lale.operators.make_choice(Tree, LinReg, XGBoost(booster='gbtree')))\n", "dag_with_functions.visualize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The visualization shows `make_union` as multiple sub-pipelines feeding\n", "into `ConcatFeatures`, and it shows `make_choice` using an `|`\n", "combinator. Operators shown in white are already fully trained; in\n", "this case, these operators actually do not have any learnable\n", "coefficients, nor do they have hyperparameters. For each of the three\n", "functions `make_pipeline`, `make_choice`, and `make_union`, Lale also\n", "provides a corresponding combinator. We already saw the pipe\n", "combinator (`>>`) and the choice combinator (`|`). To get the effect\n", "of `make_union`, use the *and combinator* (`&`) with the\n", "`ConcatFeatures` operator. The next example shows the exact same\n", "pipeline as before, but written using combinators instead of\n", "functions." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "(same pipeline visualization as above: PCA and NoOp into ConcatFeatures, then a Choice of Tree, LinReg, or XGBoost)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dag_with_combinators = (\n", " (PCA(svd_solver='full') & NoOp)\n", " >> ConcatFeatures\n", " >> (Tree | LinReg | XGBoost(booster='gbtree')))\n", "dag_with_combinators.visualize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.5 Combined Algorithm Selection and Hyperparameter Optimization\n", "\n", "Since the `dag_with_functions` specifies an algorithm choice, when we\n", "feed it to `Hyperopt`, hyperopt will do algorithm selection\n", "for us. 
And since some of the operators in the dag do not have all\n", "their hyperparameters bound, hyperopt will also tune their free\n", "hyperparameters for us. Note that `booster` for `XGBoost` is fixed to `gbtree`, so Hyperopt will not tune it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[18:37:48] WARNING: ../src/objective/regression_obj.cu:188: reg:linear is now deprecated in favor of reg:squarederror.\n", " 10%|▉ | 1/10 [00:00<00:04, 1.93trial/s, best loss: -0.6110921251096774]" ] } ], "source": [ "%%time\n", "multi_alg_trained = dag_with_functions.auto_configure(\n", " train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Visualizing the best estimator reveals which algorithm\n", "hyperopt chose." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "multi_alg_trained.visualize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pretty-printing the best estimator reveals how hyperopt tuned the\n", "hyperparameters. For instance, we can see that a `randomized` `svd_solver` was chosen for PCA." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "multi_alg_trained.pretty_print(ipython_display=True, show_imports=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Of course, the trained pipeline can be used for predictions as usual,\n", "and we can use scikit-learn metrics to evaluate those predictions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predicted = multi_alg_trained.predict(test_X)\n", "print(f'R2 score {sklearn.metrics.r2_score(test_y, predicted):.2f}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Viewing and Customizing Schemas\n", "\n", "This section reveals more of what happens behind the scenes for\n", "auto-ML with Lale. In particular, it shows the JSON Schemas used for\n", "auto-ML, and demonstrates how to customize them if desired.\n", "\n", "### 2.1 Looking at Schemas from a Notebook\n", "\n", "When writing data science code, I often don't remember all the API\n", "information about what hyperparameters and datasets an operator\n", "expects. Lale attaches this information to the operators and uses it\n", "for auto-ML as demonstrated above. The same information can also be\n", "useful as interactive documentation in a notebook. Most individual\n", "operators in the visualizations shown earlier in this notebook actually\n", "contain a hyperlink to the excellent online documentation of\n", "scikit-learn. We can also retrieve that hyperlink using a method call." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(Tree.documentation_url())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lale's helper function `ipython_display` pretty-prints JSON documents\n", "and JSON schemas in a Jupyter notebook. You can get a quick overview\n", "of the constructor arguments of an operator by calling the\n", "`get_defaults` method."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from lale.pretty_print import ipython_display\n", "ipython_display(dict(Tree.get_defaults()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hyperparameters can be categorical (meaning they accept a few\n", "discrete values) or continuous (integers or real numbers).\n", "As an example of a categorical hyperparameter, let's look at\n", "`criterion`. JSON Schema can encode categoricals as an `enum`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ipython_display(Tree.hyperparam_schema('criterion'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As an example of a continuous hyperparameter, let's look at\n", "`max_depth`. The decision tree regressor in scikit-learn accepts\n", "either an integer or `None` for it, where `None` has its own meaning.\n", "JSON Schema can express these two choices as an `anyOf`, and\n", "encodes the Python `None` as a JSON `null`. Also, while\n", "any positive integer is a valid value, in the context of auto-ML,\n", "Lale specifies a bounded range for the optimizer to search over." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ipython_display(Tree.hyperparam_schema('max_depth'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides hyperparameter schemas, Lale also provides dataset schemas.\n", "For example, NMF, which stands for non-negative matrix factorization,\n", "requires a non-negative matrix as `X`. In JSON Schema, we express this\n", "as an array of arrays of numbers with `minimum: 0`. While NMF also\n", "accepts a second argument `y`, it does not use that argument.\n", "Therefore, Lale gives `y` the schema `{'laleType': 'Any'}`, which permits any\n", "values." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.decomposition import NMF\n", "lale.wrap_imported_operators()\n", "ipython_display(NMF.input_schema_fit())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Customizing Schemas from a Notebook\n", "\n", "While you can use Lale schemas as-is, you can also customize the\n", "schemas to exert more control over the automation. As one example, it is common to tune XGBoost to use a large value for `n_estimators`. However, you might want to\n", "reduce the number of trees in an XGBoost forest to reduce memory\n", "consumption or to improve explainability. As another example, you\n", "might want to hand-pick one of the boosters to reduce the search space\n", "and thus hopefully speed up the search." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import lale.schemas as schemas\n", "Grove = XGBoost.customize_schema(\n", " n_estimators=schemas.Int(minimum=2, maximum=6),\n", " booster=schemas.Enum(['gbtree'], default='gbtree'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As this example demonstrates, Lale provides a simple Python API for\n", "writing schemas, which it then converts to JSON Schema internally. The\n", "result of customization is a new copy of the operator that can be used\n", "in the same way as any other operator in Lale. In particular, it can\n", "be part of a pipeline as before."
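, "\n", "For intuition, the customization above corresponds roughly to the following JSON Schema fragments (an illustrative sketch; the schemas Lale actually generates carry additional fields such as descriptions and optimizer hints):\n", "\n", "```python\n", "# roughly what schemas.Int(minimum=2, maximum=6) and\n", "# schemas.Enum(['gbtree'], default='gbtree') translate to\n", "n_estimators_schema = {'type': 'integer', 'minimum': 2, 'maximum': 6}\n", "booster_schema = {'enum': ['gbtree'], 'default': 'gbtree'}\n", "```"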
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "grove_planned = lale.operators.make_pipeline(\n", " lale.operators.make_union(PCA, NoOp),\n", " Grove)\n", "grove_planned.visualize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given this new planned pipeline, we use hyperopt as before to search\n", "for a good trained pipeline." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "grove_trained = grove_planned.auto_configure(\n", " train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As with all trained Lale pipelines, we can evaluate `grove_trained`\n", "with metrics to see how well it does. Also, we can pretty-print\n", "it back as Python code to double-check whether hyperopt obeyed the\n", "customized schemas for `n_estimators` and `booster`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predicted = grove_trained.predict(test_X)\n", "print(f'R2 score {sklearn.metrics.r2_score(test_y, predicted):.2f}')\n", "grove_trained.pretty_print(ipython_display=True, show_imports=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Type-Checking with Lale\n", "\n", "The rest of this notebook gives examples of how the same schemas\n", "that serve for auto-ML can also serve for error checking. We will\n", "give comparative examples for error checking in scikit-learn (without\n", "schemas) and in Lale (with schemas). To make it clear which version\n", "of an operator is being used, all of the following examples use\n", "fully-qualified names (e.g., `sklearn.feature_selection.RFE`). The\n", "fully-qualified names are for presentation purposes only; in typical\n", "usage of either scikit-learn or Lale, these would be simple names\n", "(e.g., just `RFE`).\n", "\n", "### 3.1 Hyperparameter Error Example in Scikit-Learn\n", "\n", "First, we import a few things." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import sklearn\n", "from sklearn import pipeline, feature_selection, ensemble, tree" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use `make_pipeline` to compose a pipeline of two steps: an RFE\n", "transformer and a decision tree regressor. RFE performs recursive\n", "feature elimination, keeping only those features of the input data\n", "that are the most useful for its `estimator` argument. For RFE's\n", "estimator argument, the following code uses a random forest with 10\n", "trees." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sklearn_hyperparam_error = sklearn.pipeline.make_pipeline(\n", " sklearn.feature_selection.RFE(\n", " estimator=sklearn.ensemble.RandomForestRegressor(n_estimators=10)),\n", " sklearn.tree.DecisionTreeRegressor(max_depth=-1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `max_depth` argument for a decision tree cannot be a\n", "negative number. Hence, the above code actually contains a bug: it\n", "sets `max_depth=-1`. Scikit-learn does not check for this mistake in\n", "the `__init__` method; otherwise, we would have seen an error message\n", "already. 
Instead, scikit-learn checks for this mistake during `fit`.\n", "Unfortunately, it takes a few seconds to get the exception, because\n", "scikit-learn first trains the RFE transformer and uses it to transform\n", "the data. Only then does it pass the data to the decision tree." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "try:\n", " sklearn_hyperparam_error.fit(train_X, train_y)\n", "except ValueError as e:\n", " message = str(e)\n", "print(message, file=sys.stderr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fortunately, this error message is pretty clear. Scikit-learn\n", "implements the error check imperatively, using Python if-statements\n", "to raise an exception when hyperparameters are configured wrong.\n", "This notebook is part of Lale's regression test suite and gets run\n", "automatically when changes are pushed to the Lale source code\n", "repository. The assertion in the following cell is a test that the\n", "error check indeed behaves as expected and documented here." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert message.startswith(\"The 'max_depth' parameter of DecisionTreeRegressor must be an int in the range [1, inf) or None. Got -1 instead.\") or message.startswith(\"max_depth must be greater than zero.\") or message.startswith(\"max_depth == -1, must be >= 1.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2 Checking Hyperparameters with Types\n", "\n", "Lale performs the same error checks, but using JSON Schema validation\n", "instead of Python if-statements and raise-statements. First, we import\n", "the `jsonschema` validator so we can catch its exceptions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jsonschema\n", "# enable schema validation explicitly for the notebook\n", "from lale.settings import set_disable_data_schema_validation, set_disable_hyperparams_schema_validation\n", "set_disable_data_schema_validation(False)\n", "set_disable_hyperparams_schema_validation(False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is the exact same pipeline as before, but written in Lale\n", "instead of directly in scikit-learn. In both cases, the underlying\n", "implementation is in scikit-learn; Lale only adds thin wrappers to\n", "support type checking and auto-ML." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "try:\n", " lale_hyperparam_error = lale.operators.make_pipeline(\n", " lale.lib.sklearn.RFE(\n", " estimator=lale.lib.sklearn.RandomForestRegressor(n_estimators=10)),\n", " lale.lib.sklearn.DecisionTreeRegressor(max_depth=-1))\n", "except jsonschema.ValidationError as e:\n", " message = e.message\n", "print(message, file=sys.stderr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert message.startswith(\"Invalid configuration for DecisionTreeRegressor(max_depth=-1)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just like in the scikit-learn example, the error message in the Lale\n", "example also pinpoints the problem as passing `max_depth=-1` to the\n", "decision tree. It does so in a more stylized way, printing the\n", "relevant JSON schema for this hyperparameter. 
Lale detects the error\n", "as soon as the wrong hyperparameter is passed as an argument,\n", "thus reducing the amount of code you have to look at to find the root\n", "cause. Furthermore, Lale takes only tens of milliseconds to detect\n", "the error, because it does not attempt to train the RFE transformer\n", "first. In this example, that only saves a few seconds, which may not\n", "be significant. But there are situations with larger time savings,\n", "such as when using larger datasets, slower operators, or when auto-ML\n", "tries out many pipelines." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.3 Dataset Error Example in Scikit-Learn\n", "\n", "Above, we saw an example of detecting a hyperparameter error in\n", "scikit-learn and in Lale. Next, we look at an analogous example for a\n", "dataset error. Again, let's first look at the experience with\n", "scikit-learn and then the same thing with Lale." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn import decomposition" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use scikit-learn to compose a pipeline of two steps: an RFE\n", "transformer as before, this time followed by an NMF transformer." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sklearn_dataset_error = sklearn.pipeline.make_pipeline(\n", " sklearn.feature_selection.RFE(\n", " estimator=sklearn.ensemble.RandomForestRegressor(n_estimators=10)),\n", " sklearn.decomposition.NMF())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "NMF, or non-negative matrix factorization, does not allow any negative\n", "numbers in its input matrix. The California Housing dataset contains\n", "some negative numbers and the RFE does not eliminate those features.\n", "To detect the mistake, scikit-learn must first train the RFE and\n", "transform the data with it, which takes a few seconds. Then, NMF\n", "detects the error and raises an exception." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "try:\n", " sklearn_dataset_error.fit(train_X, train_y)\n", "except ValueError as e:\n", " message = str(e)\n", "print(message, file=sys.stderr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert message.startswith(\"Negative values in data passed to NMF (input X)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.4 Types for Dataset Checking\n", "\n", "Lale uses types (as expressed using JSON schemas) to check\n", "dataset-related mistakes. Below is the same pipeline as before, using\n", "thin Lale wrappers around scikit-learn operators. We redefine the\n", "pipeline to enable Lale type-checking for it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lale_dataset_error = lale.operators.make_pipeline(\n", " lale.lib.sklearn.RFE(\n", " estimator=lale.lib.sklearn.RandomForestRegressor(n_estimators=10)),\n", " lale.lib.sklearn.NMF())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we call `fit` on the pipeline, before doing the actual training,\n", "Lale can check that the\n", "schema is correct at each step of the pipeline. In other words, it\n", "checks whether the schema of the input data is valid for the first\n", "step of the pipeline, and whether the schema of the output from each step\n", "is valid for the next step. 
Because this check skips training the RFE, it\n", "completes in tens of milliseconds instead of the seconds we saw before." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# enable data schema validation in Lale settings\n", "from lale.settings import set_disable_data_schema_validation\n", "set_disable_data_schema_validation(False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "try:\n", " lale_dataset_error.fit(train_X, train_y)\n", "except ValueError as e:\n", " message = str(e)\n", "print(message, file=sys.stderr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert message.startswith('NMF.fit() invalid X, the schema of the actual data is not a subschema of the expected schema of the argument.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, the schemas for `X` differ: whereas the data is an\n", "array of arrays of unconstrained numbers, NMF expects an array of\n", "arrays of only non-negative numbers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.5 Hyperparameter Constraint Example in Scikit-Learn\n", "\n", "Sometimes, the validity of hyperparameters cannot be checked in\n", "isolation. Instead, the value of one hyperparameter can restrict\n", "which values are valid for another hyperparameter. For example,\n", "scikit-learn imposes a conditional hyperparameter constraint between\n", "the `svd_solver` and `n_components` arguments to PCA." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sklearn_constraint_error = sklearn.pipeline.make_pipeline(\n", " sklearn.feature_selection.RFE(\n", " estimator=sklearn.ensemble.RandomForestRegressor(n_estimators=10)),\n", " sklearn.decomposition.PCA(svd_solver='arpack', n_components='mle'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above notebook cell completed successfully, because scikit-learn\n", "did not yet check for the constraint. To observe the error message\n", "with scikit-learn, we must attempt to fit the pipeline." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "message=None\n", "try:\n", " sklearn_constraint_error.fit(train_X, train_y)\n", "except ValueError as e:\n", " message = str(e)\n", "print(message, file=sys.stderr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert message.startswith(\"n_components='mle' cannot be a string with svd_solver='arpack'\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Scikit-learn implements constraint-checking as Python code with\n", "if-statements and raise-statements. After a few seconds, we get an\n", "exception, and the error message explains what went wrong." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.6 Types for Constraint Checking\n", "\n", "Lale specifies constraints using JSON Schemas. When you configure an\n", "operator with actual hyperparameters, Lale immediately validates them\n", "against their schema, including constraints."
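, "\n", "For reference, the PCA constraint checked below can be written in JSON Schema roughly as follows (an illustrative sketch, not the exact schema shipped with Lale):\n", "\n", "```python\n", "# 'anyOf' encodes an 'or': either n_components is not 'mle',\n", "# or svd_solver is 'full' or 'auto'\n", "constraint = {'anyOf': [\n", "    {'type': 'object',\n", "     'properties': {'n_components': {'not': {'enum': ['mle']}}}},\n", "    {'type': 'object',\n", "     'properties': {'svd_solver': {'enum': ['full', 'auto']}}}]}\n", "```"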
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "try:\n", " lale_constraint_error = lale.operators.make_pipeline(\n", " lale.lib.sklearn.RFE(\n", " estimator=lale.lib.sklearn.RandomForestRegressor(n_estimators=10)),\n", " PCA(svd_solver='arpack', n_components='mle'))\n", "except jsonschema.ValidationError as e:\n", " message = str(e)\n", "print(message, file=sys.stderr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert message.startswith(\"Invalid configuration for PCA(svd_solver='arpack', n_components='mle')\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lale reports the error more quickly than scikit-learn, taking only tens of\n", "milliseconds instead of multiple seconds. The error message contains\n", "both a natural-language description of the constraint and its formal\n", "representation in JSON Schema. The `'anyOf'` implements an 'or', so\n", "you can read the constraint as\n", "\n", "```python\n", "(not (n_components in ['mle'])) or (svd_solver in ['full', 'auto'])\n", "```\n", "\n", "By basic Boolean algebra, this is equivalent to an implication\n", "\n", "```python\n", "(n_components in ['mle']) implies (svd_solver in ['full', 'auto'])\n", "```\n", "\n", "Since the constraint is specified declaratively in the schema, it gets\n", "applied wherever the schema gets used. Specifically, the constraint\n", "gets applied both during auto-ML and during type checking. In the\n", "context of auto-ML, the constraint prunes the search space: it\n", "eliminates some hyperparameter combinations so that the auto-ML tool\n", "does not have to try them out. We have observed cases where this\n", "pruning makes a big difference in search convergence." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Conclusion\n", "\n", "This notebook showed additions to scikit-learn that simplify auto-ML\n", "as well as error checking. The common foundation for both of these\n", "additions is schemas for operators. For further reading, return to the\n", "Lale GitHub [repository](https://github.com/ibm/lale), where you can\n", "find installation instructions, an FAQ, and links to further\n", "documentation, notebooks, talks, etc." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }