{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter error examples\n",
"\n",
"Since the schema of the `C` hyperparameter of `LR` specifies an\n",
"exclusive minimum of zero, passing zero is not valid. Lale internally\n",
"calls an off-the-shelf JSON Schema validator when an operator gets\n",
"configured with concrete hyperparameter values."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression as LR\n",
"import lale\n",
"lale.wrap_imported_operators()\n",
"from lale.settings import set_disable_data_schema_validation, set_disable_hyperparams_schema_validation\n",
"# Enable schema validation explicitly for this notebook.\n",
"set_disable_data_schema_validation(False)\n",
"set_disable_hyperparams_schema_validation(False)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Invalid configuration for LR(C=0.0) due to invalid value C=0.0.\n",
"Some possible fixes include:\n",
"- set C=1.0\n",
"Schema of argument C: {\n",
" \"description\": \"Inverse regularization strength. Smaller values specify stronger regularization.\",\n",
" \"type\": \"number\",\n",
" \"distribution\": \"loguniform\",\n",
" \"minimum\": 0.0,\n",
" \"exclusiveMinimum\": True,\n",
" \"default\": 1.0,\n",
" \"minimumForOptimizer\": 0.03125,\n",
" \"maximumForOptimizer\": 32768,\n",
"}\n",
"Invalid value: 0.0\n"
]
}
],
"source": [
"import jsonschema\n",
"import sys\n",
"try:\n",
"    LR(C=0.0)\n",
"except jsonschema.ValidationError as e:\n",
"    message = e.message\n",
"print(message, file=sys.stderr)\n",
"assert message.startswith('Invalid configuration for LR(C=0.0)')"
]
},
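{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the rule concrete, here is a minimal pure-Python sketch of how\n",
"`minimum` combined with a boolean `exclusiveMinimum` (as printed in the\n",
"schema of `C` above) rejects zero. The helper `check_number` and the\n",
"trimmed-down `c_schema` are hypothetical illustrations; they are not part\n",
"of Lale or of the JSON Schema validator it calls.\n",
"\n",
"```python\n",
"def check_number(value, schema):\n",
"    # Return a list of violation messages; an empty list means valid.\n",
"    if isinstance(value, bool) or not isinstance(value, (int, float)):\n",
"        return [f'{value!r} is not a number']\n",
"    errors = []\n",
"    minimum = schema.get('minimum')\n",
"    if minimum is not None:\n",
"        if schema.get('exclusiveMinimum', False):\n",
"            # Exclusive bound: the value must be strictly greater.\n",
"            if value <= minimum:\n",
"                errors.append(f'{value!r} is not greater than {minimum!r}')\n",
"        elif value < minimum:\n",
"            errors.append(f'{value!r} is less than {minimum!r}')\n",
"    return errors\n",
"\n",
"# Hypothetical excerpt of the schema of C shown in the error message.\n",
"c_schema = {'type': 'number', 'minimum': 0.0, 'exclusiveMinimum': True}\n",
"\n",
"print(check_number(1.0, c_schema))\n",
"print(check_number(0.0, c_schema))\n",
"```\n",
"\n",
"Without `exclusiveMinimum`, the same `minimum` would accept zero, which is\n",
"why the schema needs both keywords."
]
},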
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides per-hyperparameter types, there are also conditional\n",
"inter-hyperparameter constraints. These are checked using the\n",
"same call to an off-the-shelf JSON Schema validator."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Invalid configuration for LR(solver='sag', penalty='l1') due to constraint the newton-cg, sag, and lbfgs solvers support only l2 or no penalties.\n",
"Some possible fixes include:\n",
"- set penalty='l2'\n",
"Schema of failing constraint: https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.logistic_regression.html#constraint-1\n",
"Invalid value: {'solver': 'sag', 'penalty': 'l1', 'dual': False, 'C': 1.0, 'tol': 0.0001, 'fit_intercept': True, 'intercept_scaling': 1.0, 'class_weight': None, 'random_state': None, 'max_iter': 100, 'multi_class': 'auto', 'verbose': 0, 'warm_start': False, 'n_jobs': None, 'l1_ratio': None}\n"
]
}
],
"source": [
"try:\n",
"    LR(LR.enum.solver.sag, LR.enum.penalty.l1)\n",
"except jsonschema.ValidationError as e:\n",
"    message = e.message\n",
"print(message, file=sys.stderr)\n",
"assert message.find('support only l2 or no penalties') != -1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are even constraints that affect three different hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Invalid configuration for LR(penalty='l2', solver='sag', dual=True) due to constraint the dual formulation is only implemented for l2 penalty with the liblinear solver.\n",
"Some possible fixes include:\n",
"- set dual=False\n",
"Schema of failing constraint: https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.logistic_regression.html#constraint-2\n",
"Invalid value: {'penalty': 'l2', 'solver': 'sag', 'dual': True, 'C': 1.0, 'tol': 0.0001, 'fit_intercept': True, 'intercept_scaling': 1.0, 'class_weight': None, 'random_state': None, 'max_iter': 100, 'multi_class': 'auto', 'verbose': 0, 'warm_start': False, 'n_jobs': None, 'l1_ratio': None}\n"
]
}
],
"source": [
"try:\n",
"    LR(LR.enum.penalty.l2, LR.enum.solver.sag, dual=True)\n",
"except jsonschema.ValidationError as e:\n",
"    message = e.message\n",
"print(message, file=sys.stderr)\n",
"assert message.find('dual formulation is only implemented for') != -1"
]
},
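{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two constraints reported above can be paraphrased as plain Python\n",
"predicates. This is only a sketch of their logic as worded in the error\n",
"messages: Lale expresses them declaratively in JSON Schema, and the\n",
"function `violated_constraints`, the rule names, and the spelling `'none'`\n",
"for the absence of a penalty are assumptions made for illustration.\n",
"\n",
"```python\n",
"# Solvers that, per the first error message, support only l2 or no penalty.\n",
"L2_ONLY_SOLVERS = ('newton-cg', 'sag', 'lbfgs')\n",
"\n",
"def violated_constraints(hp):\n",
"    # hp is a dict of hyperparameter values; returns names of broken rules.\n",
"    broken = []\n",
"    if hp['solver'] in L2_ONLY_SOLVERS and hp['penalty'] not in ('l2', 'none'):\n",
"        broken.append('solver/penalty')\n",
"    if hp['dual'] and not (hp['penalty'] == 'l2' and hp['solver'] == 'liblinear'):\n",
"        broken.append('dual')\n",
"    return broken\n",
"\n",
"print(violated_constraints({'solver': 'sag', 'penalty': 'l1', 'dual': False}))\n",
"print(violated_constraints({'solver': 'sag', 'penalty': 'l2', 'dual': True}))\n",
"print(violated_constraints({'solver': 'liblinear', 'penalty': 'l2', 'dual': True}))\n",
"```\n",
"\n",
"Each predicate reads several hyperparameters at once, which is what sets\n",
"these constraints apart from the per-hyperparameter check on `C`."
]
},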
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset error example for individual operator\n",
"\n",
"Lale uses JSON Schema validation not only for hyperparameters but also\n",
"for data. The dataset `train_X` is multimodal: some columns contain\n",
"text strings whereas others contain numbers."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" urlDrugName effectiveness sideEffects \\\n",
"2202 enalapril Highly Effective Mild Side Effects \n",
"3117 ortho-tri-cyclen Highly Effective Severe Side Effects \n",
"1146 ponstel Highly Effective No Side Effects \n",
"3947 prilosec Marginally Effective Mild Side Effects \n",
"1951 lyrica Marginally Effective Severe Side Effects \n",
"\n",
" condition \\\n",
"2202 management of congestive heart failure \n",
"3117 birth prevention \n",
"1146 menstrual cramps \n",
"3947 acid reflux \n",
"1951 fibromyalgia \n",
"\n",
" benefitsReview \\\n",
"2202 slowed the progression of left ventricular dys... \n",
"3117 Although this type of birth control has more c... \n",
"1146 I was used to having cramps so badly that they... \n",
"3947 The acid reflux went away for a few months aft... \n",
"1951 I think that the Lyrica was starting to help w... \n",
"\n",
" sideEffectsReview \\\n",
"2202 cough, hypotension , proteinuria, impotence , ... \n",
"3117 Heavy Cycle, Cramps, Hot Flashes, Fatigue, Lon... \n",
"1146 Heavier bleeding and clotting than normal. \n",
"3947 Constipation, dry mouth and some mild dizzines... \n",
"1951 I felt extremely drugged and dopey. Could not... \n",
"\n",
" commentsReview rating \n",
"2202 monitor blood pressure , weight and asses for ... 4 \n",
"3117 I Hate This Birth Control, I Would Not Suggest... 1 \n",
"1146 I took 2 pills at the onset of my menstrual cr... 10 \n",
"3947 I was given Prilosec prescription at a dose of... 3 \n",
"1951 See above 2 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"from lale.datasets.uci import fetch_drugslib\n",
"train_X, train_y, test_X, test_y = fetch_drugslib()\n",
"pd.concat([train_X.head(), train_y.head()], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Enable schema validation for data.\n",
"from lale.settings import set_disable_data_schema_validation\n",
"set_disable_data_schema_validation(False)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"```python\n",
"{\n",
" \"$schema\": \"http://json-schema.org/draft-04/schema#\",\n",
" \"type\": \"array\",\n",
" \"items\": {\n",
" \"type\": \"array\",\n",
" \"minItems\": 7,\n",
" \"maxItems\": 7,\n",
" \"items\": [\n",
" {\"description\": \"urlDrugName\", \"type\": \"string\"},\n",
" {\n",
" \"description\": \"effectiveness\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"sideEffects\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"condition\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"benefitsReview\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"sideEffectsReview\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"commentsReview\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" ],\n",
" },\n",
" \"minItems\": 3107,\n",
" \"maxItems\": 3107,\n",
"}\n",
"```"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from lale.pretty_print import ipython_display\n",
"ipython_display(lale.datasets.data_schemas.to_schema(train_X))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `train_X` contains strings but `LR` expects only numbers,\n",
"validating it against the input schema of `fit` reports a type error."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"LR.fit() invalid X, the schema of the actual data is not a subschema of the expected schema of the argument.\n",
"actual_schema = {\n",
" \"$schema\": \"http://json-schema.org/draft-04/schema#\",\n",
" \"type\": \"array\",\n",
" \"items\": {\n",
" \"type\": \"array\",\n",
" \"minItems\": 7,\n",
" \"maxItems\": 7,\n",
" \"items\": [\n",
" {\"description\": \"urlDrugName\", \"type\": \"string\"},\n",
" {\n",
" \"description\": \"effectiveness\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"sideEffects\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"condition\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"benefitsReview\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"sideEffectsReview\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" {\n",
" \"description\": \"commentsReview\",\n",
" \"anyOf\": [{\"type\": \"string\"}, {\"enum\": [float(\"nan\")]}],\n",
" },\n",
" ],\n",
" },\n",
" \"minItems\": 3107,\n",
" \"maxItems\": 3107,\n",
"}\n",
"expected_schema = {\n",
" \"description\": \"Features; the outer array is over samples.\",\n",
" \"type\": \"array\",\n",
" \"items\": {\"type\": \"array\", \"items\": {\"type\": \"number\"}},\n",
"}\n"
]
}
],
"source": [
"trainable_lr = LR(max_iter=1000)\n",
"try:\n",
"    LR.validate_schema(train_X, train_y)\n",
"except ValueError as e:\n",
"    message = str(e)\n",
"print(message, file=sys.stderr)\n",
"assert message.startswith('LR.fit() invalid X')"
]
},
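{
"cell_type": "markdown",
"metadata": {},
"source": [
"The idea behind the failed check can be sketched in a few lines: every\n",
"column of the actual schema must allow only numbers. The code below is a\n",
"deliberately simplified stand-in, not Lale's subschema checker. The helper\n",
"names and the `rating` column in the example schema are hypothetical, and\n",
"only the `type`, `enum`, and `anyOf` keywords are handled.\n",
"\n",
"```python\n",
"def column_is_numeric(col_schema):\n",
"    # True if every value the column schema allows is a number.\n",
"    if 'anyOf' in col_schema:\n",
"        return all(column_is_numeric(s) for s in col_schema['anyOf'])\n",
"    if 'enum' in col_schema:\n",
"        return all(isinstance(v, (int, float)) and not isinstance(v, bool)\n",
"                   for v in col_schema['enum'])\n",
"    return col_schema.get('type') in ('number', 'integer')\n",
"\n",
"def non_numeric_columns(actual_schema):\n",
"    # Collect the descriptions of the columns that may hold non-numbers.\n",
"    columns = actual_schema['items']['items']\n",
"    return [c.get('description', str(i))\n",
"            for i, c in enumerate(columns) if not column_is_numeric(c)]\n",
"\n",
"# Shortened schema in the style printed above; 'rating' is made up.\n",
"actual = {\n",
"    'type': 'array',\n",
"    'items': {\n",
"        'type': 'array',\n",
"        'items': [\n",
"            {'description': 'urlDrugName', 'type': 'string'},\n",
"            {'description': 'effectiveness',\n",
"             'anyOf': [{'type': 'string'}, {'enum': [float('nan')]}]},\n",
"            {'description': 'rating', 'type': 'integer'},\n",
"        ],\n",
"    },\n",
"}\n",
"\n",
"print(non_numeric_columns(actual))\n",
"```\n",
"\n",
"A column with an `anyOf` passes only if all of its alternatives are\n",
"numeric, so a single string alternative is enough to reject it."
]
},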
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load a purely numerical dataset instead."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
" sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n",
"0 5.0 3.4 1.6 0.4\n",
"1 6.3 3.3 4.7 1.6\n",
"2 5.1 3.4 1.5 0.2\n",
"3 4.8 3.0 1.4 0.1\n",
"4 6.7 3.1 4.7 1.5"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from lale.datasets import load_iris_df\n",
"(train_X, train_y), (test_X, test_y) = load_iris_df()\n",
"train_X.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training `LR` on the Iris dataset works fine."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"trained_lr = trainable_lr.fit(train_X, train_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Lifecycle error example\n",
"\n",
"Lale encourages separating the lifecycle states, here represented\n",
"by `trainable_lr` vs. `trained_lr`. The `predict` method should\n",
"only be called on a trained model."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test_y [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n",
"predicted [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n"
]
}
],
"source": [
"predicted = trained_lr.predict(test_X)\n",
"print(f'test_y {[*test_y]}')\n",
"print(f'predicted {[*predicted]}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On the other hand, the `predict` method should not be called on a trainable model."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test_y [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n",
"predicted [2, 1, 1, 0, 2, 0, 1, 1, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 1, 0, 0, 2, 2, 0, 2, 1, 0, 2, 1, 0]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The `predict` method is deprecated on a trainable operator, because the learned coefficients could be accidentally overwritten by retraining. Call `predict` on the trained operator returned by `fit` instead.\n"
]
}
],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"error\", category=DeprecationWarning)\n",
"try:\n",
"    predicted = trainable_lr.predict(test_X)\n",
"except DeprecationWarning as w:\n",
"    message = str(w)\n",
"print(message, file=sys.stderr)\n",
"assert message.startswith('The `predict` method is deprecated on a trainable')\n",
"print(f'test_y {[*test_y]}')\n",
"print(f'predicted {[*predicted]}')"
]
},
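{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above changes the warning filter for the whole session. As a\n",
"hedged alternative, the standard `warnings.catch_warnings` context manager\n",
"can limit the escalation to a single call. The classes `TrainableSketch`\n",
"and `TrainedSketch` and the helper `predict_strictly` below are\n",
"hypothetical toys that only mimic the lifecycle split; they are not Lale\n",
"classes.\n",
"\n",
"```python\n",
"import warnings\n",
"\n",
"class TrainedSketch:\n",
"    def __init__(self, label):\n",
"        self.label = label\n",
"\n",
"    def predict(self, X):\n",
"        return [self.label for _ in X]\n",
"\n",
"class TrainableSketch:\n",
"    def fit(self, X, y):\n",
"        # Toy training: remember the most frequent label.\n",
"        labels = list(y)\n",
"        return TrainedSketch(max(set(labels), key=labels.count))\n",
"\n",
"    def predict(self, X):\n",
"        warnings.warn('predict is deprecated on a trainable operator',\n",
"                      DeprecationWarning, stacklevel=2)\n",
"        return []\n",
"\n",
"def predict_strictly(op, X):\n",
"    # Escalate DeprecationWarning to an exception only inside this block.\n",
"    with warnings.catch_warnings():\n",
"        warnings.simplefilter('error', category=DeprecationWarning)\n",
"        return op.predict(X)\n",
"\n",
"trained = TrainableSketch().fit([[0], [1], [2]], [1, 1, 0])\n",
"print(predict_strictly(trained, [[5]]))\n",
"try:\n",
"    predict_strictly(TrainableSketch(), [[5]])\n",
"except DeprecationWarning as w:\n",
"    print(f'rejected: {w}')\n",
"```\n",
"\n",
"Once the `with` block exits, the previous warning filters are restored,\n",
"so later cells are unaffected."
]
},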
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Delegate error example\n",
"\n",
"`LogisticRegression` is an estimator and therefore does not have a\n",
"`transform` method, even when trained."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"AttributeError\n"
]
}
],
"source": [
"try:\n",
"    trained_lr.transform(train_X)\n",
"except AttributeError as e:\n",
"    message = 'AttributeError'\n",
"    print(message, file=sys.stderr)\n",
"assert message.startswith('AttributeError')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}