{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter error examples\n", "\n", "Since the schema of the `C` hyperparameter of `LR` specifies an\n", "exclusive minimum of zero, passing zero is not valid. Lale internally\n", "calls an off-the-shelf JSON Schema validator when an operator gets\n", "configured with concrete hyperparameter values." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression as LR\n", "import lale\n", "lale.wrap_imported_operators()\n", "from lale.settings import set_disable_data_schema_validation, set_disable_hyperparams_schema_validation\n", "#enable schema validation explicitly for the notebook\n", "set_disable_data_schema_validation(False)\n", "set_disable_hyperparams_schema_validation(False)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Invalid configuration for LR(C=0.0) due to invalid value C=0.0.\n", "Some possible fixes include:\n", "- set C=1.0\n", "Schema of argument C: {\n", " \"description\": \"Inverse regularization strength. Smaller values specify stronger regularization.\",\n", " \"type\": \"number\",\n", " \"distribution\": \"loguniform\",\n", " \"minimum\": 0.0,\n", " \"exclusiveMinimum\": true,\n", " \"default\": 1.0,\n", " \"minimumForOptimizer\": 0.03125,\n", " \"maximumForOptimizer\": 32768,\n", "}\n", "Invalid value: 0.0\n" ] } ], "source": [ "import jsonschema\n", "import sys\n", "try:\n", " LR(C=0.0)\n", "except jsonschema.ValidationError as e:\n", " message = e.message\n", "print(message, file=sys.stderr)\n", "assert message.startswith('Invalid configuration for LR(C=0.0)')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides per-hyperparameter types, there are also conditional\n", "inter-hyperparameter constraints. 
These are checked using the\n", "same call to an off-the-shelf JSON Schema validator." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Invalid configuration for LR(solver='sag', penalty='l1') due to constraint the newton-cg, sag, and lbfgs solvers support only l2 or no penalties.\n", "Some possible fixes include:\n", "- set penalty='l2'\n", "Schema of failing constraint: https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.logistic_regression.html#constraint-1\n", "Invalid value: {'solver': 'sag', 'penalty': 'l1', 'dual': False, 'C': 1.0, 'tol': 0.0001, 'fit_intercept': True, 'intercept_scaling': 1.0, 'class_weight': None, 'random_state': None, 'max_iter': 100, 'multi_class': 'auto', 'verbose': 0, 'warm_start': False, 'n_jobs': None, 'l1_ratio': None}\n" ] } ], "source": [ "try:\n", " LR(LR.enum.solver.sag, LR.enum.penalty.l1)\n", "except jsonschema.ValidationError as e:\n", " message = e.message\n", "print(message, file=sys.stderr)\n", "assert message.find('support only l2 or no penalties') != -1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are even constraints that affect three different hyperparameters." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Invalid configuration for LR(penalty='l2', solver='sag', dual=True) due to constraint the dual formulation is only implemented for l2 penalty with the liblinear solver.\n", "Some possible fixes include:\n", "- set dual=False\n", "Schema of failing constraint: https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.logistic_regression.html#constraint-2\n", "Invalid value: {'penalty': 'l2', 'solver': 'sag', 'dual': True, 'C': 1.0, 'tol': 0.0001, 'fit_intercept': True, 'intercept_scaling': 1.0, 'class_weight': None, 'random_state': None, 'max_iter': 100, 'multi_class': 'auto', 'verbose': 0, 'warm_start': False, 'n_jobs': None, 'l1_ratio': None}\n" ] } ], "source": [ "try:\n", " LR(LR.enum.penalty.l2, LR.enum.solver.sag, dual=True)\n", "except jsonschema.ValidationError as e:\n", " message = e.message\n", "print(message, file=sys.stderr)\n", "assert message.find('dual formulation is only implemented for') != -1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset error example for individual operator\n", "\n", "Lale uses JSON Schema validation not only for hyperparameters but also\n", "for data. The dataset `train_X` is multimodal: some columns contain\n", "text strings whereas others contain numbers." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>drugName</th><th>condition</th><th>review</th><th>date</th><th>usefulCount</th><th>rating</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>Valsartan</td><td>Left Ventricular Dysfunction</td><td>\"It has no side effect, I take it in combinati...</td><td>May 20, 2012</td><td>27</td><td>9.0</td></tr>\n",
"    <tr><th>1</th><td>Guanfacine</td><td>ADHD</td><td>\"My son is halfway through his fourth week of ...</td><td>April 27, 2010</td><td>192</td><td>8.0</td></tr>\n",
"    <tr><th>2</th><td>Lybrel</td><td>Birth Control</td><td>\"I used to take another oral contraceptive, wh...</td><td>December 14, 2009</td><td>17</td><td>5.0</td></tr>\n",
"    <tr><th>3</th><td>Ortho Evra</td><td>Birth Control</td><td>\"This is my first time using any form of birth...</td><td>November 3, 2015</td><td>10</td><td>8.0</td></tr>\n",
"    <tr><th>4</th><td>Buprenorphine / naloxone</td><td>Opiate Dependence</td><td>\"Suboxone has completely turned my life around...</td><td>November 27, 2016</td><td>37</td><td>9.0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " drugName condition \\\n", "0 Valsartan Left Ventricular Dysfunction \n", "1 Guanfacine ADHD \n", "2 Lybrel Birth Control \n", "3 Ortho Evra Birth Control \n", "4 Buprenorphine / naloxone Opiate Dependence \n", "\n", " review date \\\n", "0 \"It has no side effect, I take it in combinati... May 20, 2012 \n", "1 \"My son is halfway through his fourth week of ... April 27, 2010 \n", "2 \"I used to take another oral contraceptive, wh... December 14, 2009 \n", "3 \"This is my first time using any form of birth... November 3, 2015 \n", "4 \"Suboxone has completely turned my life around... November 27, 2016 \n", "\n", " usefulCount rating \n", "0 27 9.0 \n", "1 192 8.0 \n", "2 17 5.0 \n", "3 10 8.0 \n", "4 37 9.0 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "from lale.datasets.uci.uci_datasets import fetch_drugscom\n", "train_X, train_y, test_X, test_y = fetch_drugscom()\n", "pd.concat([train_X.head(), train_y.head()], axis=1)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "#Enable the schema validation for data \n", "from lale.settings import set_disable_data_schema_validation\n", "set_disable_data_schema_validation(False)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "```python\n", "{\n", " \"$schema\": \"http://json-schema.org/draft-04/schema#\",\n", " \"type\": \"array\",\n", " \"items\": {\n", " \"type\": \"array\",\n", " \"minItems\": 5,\n", " \"maxItems\": 5,\n", " \"items\": [\n", " {\"description\": \"drugName\", \"type\": \"string\"},\n", " {\n", " \"description\": \"condition\",\n", " \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [NaN]}],\n", " },\n", " {\"description\": \"review\", \"type\": \"string\"},\n", " {\"description\": \"date\", \"type\": \"string\"},\n", " {\"description\": \"usefulCount\", \"type\": \"integer\", \"minimum\": 0},\n", " ],\n", " 
},\n", " \"minItems\": 161297,\n", " \"maxItems\": 161297,\n", "}\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from lale.pretty_print import ipython_display\n", "ipython_display(lale.datasets.data_schemas.to_schema(train_X))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since `train_X` contains strings but `LR` expects only numbers,\n", "validating it against the input schema of `fit` reports a type error." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "LR.fit() invalid X, the schema of the actual data is not a subschema of the expected schema of the argument.\n", "actual_schema = {\n", " \"$schema\": \"http://json-schema.org/draft-04/schema#\",\n", " \"type\": \"array\",\n", " \"items\": {\n", " \"type\": \"array\",\n", " \"minItems\": 5,\n", " \"maxItems\": 5,\n", " \"items\": [\n", " {\"description\": \"drugName\", \"type\": \"string\"},\n", " {\n", " \"description\": \"condition\",\n", " \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [NaN]}],\n", " },\n", " {\"description\": \"review\", \"type\": \"string\"},\n", " {\"description\": \"date\", \"type\": \"string\"},\n", " {\"description\": \"usefulCount\", \"type\": \"integer\", \"minimum\": 0},\n", " ],\n", " },\n", " \"minItems\": 161297,\n", " \"maxItems\": 161297,\n", "}\n", "expected_schema = {\n", " \"description\": \"Features; the outer array is over samples.\",\n", " \"type\": \"array\",\n", " \"items\": {\"type\": \"array\", \"items\": {\"type\": \"number\"}},\n", "}\n" ] } ], "source": [ "trainable_lr = LR(max_iter=1000)\n", "try:\n", " LR.validate_schema(train_X, train_y)\n", "except ValueError as e:\n", " message = str(e)\n", "print(message, file=sys.stderr)\n", "assert message.startswith('LR.fit() invalid X')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load a purely numerical dataset instead." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>sepal length (cm)</th><th>sepal width (cm)</th><th>petal length (cm)</th><th>petal width (cm)</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>5.0</td><td>3.4</td><td>1.6</td><td>0.4</td></tr>\n",
"    <tr><th>1</th><td>6.3</td><td>3.3</td><td>4.7</td><td>1.6</td></tr>\n",
"    <tr><th>2</th><td>5.1</td><td>3.4</td><td>1.5</td><td>0.2</td></tr>\n",
"    <tr><th>3</th><td>4.8</td><td>3.0</td><td>1.4</td><td>0.1</td></tr>\n",
"    <tr><th>4</th><td>6.7</td><td>3.1</td><td>4.7</td><td>1.5</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n", "0 5.0 3.4 1.6 0.4\n", "1 6.3 3.3 4.7 1.6\n", "2 5.1 3.4 1.5 0.2\n", "3 4.8 3.0 1.4 0.1\n", "4 6.7 3.1 4.7 1.5" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from lale.datasets import load_iris_df\n", "(train_X, train_y), (test_X, test_y) = load_iris_df()\n", "train_X.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training LR with the Iris dataset works fine." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [], "source": [ "trained_lr = trainable_lr.fit(train_X, train_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Lifecycle error example\n", "\n", "Lale encourages separating the lifecycle states, here represented\n", "by `trainable_lr` vs. `trained_lr`. The `predict` method should\n", "only be called on a trained model." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test_y [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n", "predicted [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n" ] } ], "source": [ "predicted = trained_lr.predict(test_X)\n", "print(f'test_y {[*test_y]}')\n", "print(f'predicted {[*predicted]}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On the other hand, the `predict` method should not be called on a trainable model." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test_y [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n", "predicted [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "The `predict` method is deprecated on a trainable operator, because the learned coefficients could be accidentally overwritten by retraining. Call `predict` on the trained operator returned by `fit` instead.\n" ] } ], "source": [ "import warnings\n", "warnings.filterwarnings(\"error\", category=DeprecationWarning)\n", "try:\n", " predicted = trainable_lr.predict(test_X)\n", "except DeprecationWarning as w:\n", " message = str(w)\n", "print(message, file=sys.stderr)\n", "assert message.startswith('The `predict` method is deprecated on a trainable')\n", "print(f'test_y {[*test_y]}')\n", "print(f'predicted {[*predicted]}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Delegate error example\n", "\n", "LogisticRegression is an estimator and therefore does not have a\n", "transform method, even when trained." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "try:\n", " trained_lr.transform(train_X)\n", "except AttributeError as e:\n", " message = 'AttributeError'\n", " print(message, file=sys.stderr)\n", "assert message.startswith('AttributeError')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }