{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter error examples\n", "\n", "Since the schema of the `C` hyperparameter of `LR` specifies an\n", "exclusive minimum of zero, passing zero is not valid. Lale internally\n", "calls an off-the-shelf JSON Schema validator when an operator gets\n", "configured with concrete hyperparameter values." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression as LR\n", "import lale\n", "lale.wrap_imported_operators()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Invalid configuration for LR(C=0.0) due to invalid value C=0.0.\n", "Schema of argument C: {\n", " 'description': 'Inverse regularization strength. Smaller values specify stronger regularization.',\n", " 'type': 'number',\n", " 'distribution': 'loguniform',\n", " 'minimum': 0.0,\n", " 'exclusiveMinimum': true,\n", " 'default': 1.0,\n", " 'minimumForOptimizer': 0.03125,\n", " 'maximumForOptimizer': 32768}\n", "Value: 0.0\n" ] } ], "source": [ "import jsonschema\n", "import sys\n", "try:\n", " LR(C=0.0)\n", "except jsonschema.ValidationError as e:\n", " message = e.message\n", "print(message, file=sys.stderr)\n", "assert message.startswith('Invalid configuration for LR(C=0.0)')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides per-hyperparameter types, there are also conditional\n", "inter-hyperparameter constraints. These are checked using the\n", "same call to an off-the-shelf JSON Schema validator." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Invalid configuration for LR(solver='sag', penalty='l1') due to constraint the newton-cg, sag, and lbfgs solvers support only l2 penalties.\n", "Schema of constraint 1: {\n", " 'description': 'The newton-cg, sag, and lbfgs solvers support only l2 penalties.',\n", " 'anyOf': [\n", " { 'type': 'object',\n", " 'properties': {\n", " 'solver': {\n", " 'not': {\n", " 'enum': ['newton-cg', 'sag', 'lbfgs']}}}},\n", " { 'type': 'object',\n", " 'properties': {\n", " 'penalty': {\n", " 'enum': ['l2']}}}]}\n", "Value: {'solver': 'sag', 'penalty': 'l1', 'dual': False, 'C': 1.0, 'tol': 0.0001, 'fit_intercept': True, 'intercept_scaling': 1.0, 'class_weight': None, 'random_state': None, 'max_iter': 100, 'multi_class': 'ovr', 'verbose': 0, 'warm_start': False, 'n_jobs': None}\n" ] } ], "source": [ "try:\n", " LR(LR.solver.sag, LR.penalty.l1)\n", "except jsonschema.ValidationError as e:\n", " message = e.message\n", "print(message, file=sys.stderr)\n", "assert message.find('support only l2 penalties') != -1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are even constraints that affect three different hyperparameters." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Invalid configuration for LR(penalty='l2', solver='sag', dual=True) due to constraint the dual formulation is only implemented for l2 penalty with the liblinear solver.\n", "Schema of constraint 2: {\n", " 'description': 'The dual formulation is only implemented for l2 penalty with the liblinear solver.',\n", " 'anyOf': [\n", " { 'type': 'object',\n", " 'properties': {\n", " 'dual': {\n", " 'enum': [false]}}},\n", " { 'type': 'object',\n", " 'properties': {\n", " 'penalty': {\n", " 'enum': ['l2']},\n", " 'solver': {\n", " 'enum': ['liblinear']}}}]}\n", "Value: {'penalty': 'l2', 'solver': 'sag', 'dual': True, 'C': 1.0, 'tol': 0.0001, 'fit_intercept': True, 'intercept_scaling': 1.0, 'class_weight': None, 'random_state': None, 'max_iter': 100, 'multi_class': 'ovr', 'verbose': 0, 'warm_start': False, 'n_jobs': None}\n" ] } ], "source": [ "try:\n", " LR(LR.penalty.l2, LR.solver.sag, dual=True)\n", "except jsonschema.ValidationError as e:\n", " message = e.message\n", "print(message, file=sys.stderr)\n", "assert message.find('dual formulation is only implemented for') != -1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset error example for individual operator\n", "\n", "Lale uses JSON Schema validation not only for hyperparameters but also\n", "for data. The dataset `train_X` is multimodal: some columns contain\n", "text strings whereas others contain numbers." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
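<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr>\n",
"      <th></th>\n",
"      <th>drugName</th>\n",
"      <th>condition</th>\n",
"      <th>review</th>\n",
"      <th>date</th>\n",
"      <th>usefulCount</th>\n",
"      <th>rating</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>Valsartan</td>\n",
"      <td>Left Ventricular Dysfunction</td>\n",
"      <td>\"It has no side effect, I take it in combinati...</td>\n",
"      <td>May 20, 2012</td>\n",
"      <td>27</td>\n",
"      <td>9.0</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>Guanfacine</td>\n",
"      <td>ADHD</td>\n",
"      <td>\"My son is halfway through his fourth week of ...</td>\n",
"      <td>April 27, 2010</td>\n",
"      <td>192</td>\n",
"      <td>8.0</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>Lybrel</td>\n",
"      <td>Birth Control</td>\n",
"      <td>\"I used to take another oral contraceptive, wh...</td>\n",
"      <td>December 14, 2009</td>\n",
"      <td>17</td>\n",
"      <td>5.0</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>Ortho Evra</td>\n",
"      <td>Birth Control</td>\n",
"      <td>\"This is my first time using any form of birth...</td>\n",
"      <td>November 3, 2015</td>\n",
"      <td>10</td>\n",
"      <td>8.0</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>Buprenorphine / naloxone</td>\n",
"      <td>Opiate Dependence</td>\n",
"      <td>\"Suboxone has completely turned my life around...</td>\n",
"      <td>November 27, 2016</td>\n",
"      <td>37</td>\n",
"      <td>9.0</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>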
" ], "text/plain": [ " drugName condition \\\n", "0 Valsartan Left Ventricular Dysfunction \n", "1 Guanfacine ADHD \n", "2 Lybrel Birth Control \n", "3 Ortho Evra Birth Control \n", "4 Buprenorphine / naloxone Opiate Dependence \n", "\n", " review date \\\n", "0 \"It has no side effect, I take it in combinati... May 20, 2012 \n", "1 \"My son is halfway through his fourth week of ... April 27, 2010 \n", "2 \"I used to take another oral contraceptive, wh... December 14, 2009 \n", "3 \"This is my first time using any form of birth... November 3, 2015 \n", "4 \"Suboxone has completely turned my life around... November 27, 2016 \n", "\n", " usefulCount rating \n", "0 27 9.0 \n", "1 192 8.0 \n", "2 17 5.0 \n", "3 10 8.0 \n", "4 37 9.0 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "from lale.datasets.uci.uci_datasets import fetch_drugscom\n", "train_X, train_y, test_X, test_y = fetch_drugscom()\n", "pd.concat([train_X.head(), train_y.head()], axis=1)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/markdown": [ "```python\n", "{\n", " \u0001'type': 'array',\n", " 'items': {\n", " 'type': 'array',\n", " 'minItems': 5,\n", " 'maxItems': 5,\n", " 'items': [\n", " { 'description': 'drugName',\n", " 'type': 'string'},\n", " { 'description': 'condition',\n", " 'anyOf': [\n", " { 'type': 'string'},\n", " { 'enum': [NaN]}]},\n", " { 'description': 'review',\n", " 'type': 'string'},\n", " { 'description': 'date',\n", " 'type': 'string'},\n", " { 'description': 'usefulCount',\n", " 'type': 'integer',\n", " 'minimum': 0}]},\n", " 'minItems': 161297,\n", " 'maxItems': 161297}\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from lale.pretty_print import ipython_display\n", "ipython_display(lale.datasets.data_schemas.to_schema(train_X))" ] }, { "cell_type": "markdown", "metadata": {}, "source": 
[ "Since `train_X` contains strings but `LR` expects only numbers,\n", "validating it against the input schema of `LR.fit` reports a type error." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "LR.fit() invalid X: Expected sub to be a subschema of super.\n", "sub = {\n", " 'type': 'array',\n", " 'items': {\n", " 'type': 'array',\n", " 'minItems': 5,\n", " 'maxItems': 5,\n", " 'items': [\n", " { 'description': 'drugName',\n", " 'type': 'string'},\n", " { 'description': 'condition',\n", " 'anyOf': [\n", " { 'type': 'string'},\n", " { 'enum': [NaN]}]},\n", " { 'description': 'review',\n", " 'type': 'string'},\n", " { 'description': 'date',\n", " 'type': 'string'},\n", " { 'description': 'usefulCount',\n", " 'type': 'integer',\n", " 'minimum': 0}]},\n", " 'minItems': 161297,\n", " 'maxItems': 161297}\n", "super = {\n", " 'description': 'Features; the outer array is over samples.',\n", " 'type': 'array',\n", " 'items': {\n", " 'type': 'array',\n", " 'items': {\n", " 'type': 'number'}}}\n" ] } ], "source": [ "trainable_lr = LR()\n", "try:\n", " LR.validate_schema(train_X, train_y)\n", "except ValueError as e:\n", " message = str(e)\n", "print(message, file=sys.stderr)\n", "assert message.startswith('LR.fit() invalid X')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load a purely numerical dataset instead." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
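<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr>\n",
"      <th></th>\n",
"      <th>sepal length (cm)</th>\n",
"      <th>sepal width (cm)</th>\n",
"      <th>petal length (cm)</th>\n",
"      <th>petal width (cm)</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>5.0</td>\n",
"      <td>3.4</td>\n",
"      <td>1.6</td>\n",
"      <td>0.4</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>6.3</td>\n",
"      <td>3.3</td>\n",
"      <td>4.7</td>\n",
"      <td>1.6</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>5.1</td>\n",
"      <td>3.4</td>\n",
"      <td>1.5</td>\n",
"      <td>0.2</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>4.8</td>\n",
"      <td>3.0</td>\n",
"      <td>1.4</td>\n",
"      <td>0.1</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>6.7</td>\n",
"      <td>3.1</td>\n",
"      <td>4.7</td>\n",
"      <td>1.5</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>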
" ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n", "0 5.0 3.4 1.6 0.4\n", "1 6.3 3.3 4.7 1.6\n", "2 5.1 3.4 1.5 0.2\n", "3 4.8 3.0 1.4 0.1\n", "4 6.7 3.1 4.7 1.5" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from lale.datasets import load_iris_df\n", "(train_X, train_y), (test_X, test_y) = load_iris_df()\n", "train_X.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training LR with the Iris dataset works fine." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [], "source": [ "trained_lr = trainable_lr.fit(train_X, train_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Lifecycle error example\n", "\n", "Lale encourages separating the lifecycle states, here represented\n", "by `trainable_lr` vs. `trained_lr`. The `predict` method should\n", "only be called on a trained model." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test_y [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n", "predicted [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n" ] } ], "source": [ "predicted = trained_lr.predict(test_X)\n", "print(f'test_y {[*test_y]}')\n", "print(f'predicted {[*predicted]}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On the other hand, the `predict` method should not be called on a trainable model." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test_y [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n", "predicted [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "The `predict` method is deprecated on a trainable operator, because the learned coefficients could be accidentally overwritten by retraining. Call `predict` on the trained operator returned by `fit` instead.\n" ] } ], "source": [ "import warnings\n", "warnings.filterwarnings(\"error\", category=DeprecationWarning)\n", "try:\n", " predicted = trainable_lr.predict(test_X)\n", "except DeprecationWarning as w:\n", " message = str(w)\n", "print(message, file=sys.stderr)\n", "assert message.startswith('The `predict` method is deprecated on a trainable')\n", "print(f'test_y {[*test_y]}')\n", "print(f'predicted {[*predicted]}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Delegate error example\n", "\n", "LogisticRegression is an estimator and therefore does not have a\n", "transform method, even when trained." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "AttributeError\n" ] } ], "source": [ "try:\n", " trained_lr.transform(train_X)\n", "except AttributeError as e:\n", " message = 'AttributeError'\n", " print(message, file=sys.stderr)\n", "assert message.startswith('AttributeError')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }