{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lale and its Impact on the Data Science Workflow\n",
"\n",
"Guillaume Baudart, Martin Hirzel, Kiran Kate, Pari Ram, and Avi Shinnar\n",
"\n",
"27 March 2020\n",
"\n",
"Examples, documentation, code: https://github.com/ibm/lale\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Value Proposition\n",
"\n",
"- **target user**: data scientist familiar with Python and scikit-learn\n",
"- **scope**: data preparation and machine learning (including some deep learning)\n",
"- **value**: consistent API for both manual machine learning and auto-ML\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# !pip install --quiet lale"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Enable schema validation for this notebook\n",
"from lale.settings import set_disable_data_schema_validation\n",
"set_disable_data_schema_validation(False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example Dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape train_X_all (522910, 54), test_X (58102, 54)\n"
]
}
],
"source": [
"import lale.datasets\n",
"(train_X_all, train_y_all), (test_X, test_y) = lale.datasets.covtype_df(test_size=0.1)\n",
"print(f'shape train_X_all {train_X_all.shape}, test_X {test_X.shape}')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape train_X (52291, 54), other_X (470619, 54)\n"
]
}
],
"source": [
"import sklearn.model_selection\n",
"train_X, other_X, train_y, other_y = sklearn.model_selection.train_test_split(\n",
" train_X_all, train_y_all, test_size=0.9)\n",
"print(f'shape train_X {train_X.shape}, other_X {other_X.shape}')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
" y Elevation Aspect Slope Horizontal_Distance_To_Hydrology \\\n",
"484665 3 2277.0 41.0 31.0 228.0 \n",
"451137 1 3273.0 296.0 22.0 371.0 \n",
"239309 1 3062.0 298.0 13.0 408.0 \n",
"406901 2 3195.0 42.0 19.0 376.0 \n",
"379632 2 3003.0 310.0 14.0 182.0 \n",
"510084 1 2898.0 47.0 10.0 30.0 \n",
"96001 2 2221.0 338.0 22.0 242.0 \n",
"39684 1 3289.0 322.0 18.0 285.0 \n",
"227535 2 2890.0 272.0 6.0 376.0 \n",
"85578 1 3340.0 204.0 16.0 510.0 \n",
"\n",
" Vertical_Distance_To_Hydrology Horizontal_Distance_To_Roadways \\\n",
"484665 145.0 1045.0 \n",
"451137 45.0 1740.0 \n",
"239309 78.0 2445.0 \n",
"406901 72.0 3873.0 \n",
"379632 30.0 2573.0 \n",
"510084 -3.0 1865.0 \n",
"96001 72.0 437.0 \n",
"39684 60.0 4012.0 \n",
"227535 43.0 2296.0 \n",
"85578 134.0 1851.0 \n",
"\n",
" Hillshade_9am Hillshade_Noon Hillshade_3pm \\\n",
"484665 207.0 157.0 65.0 \n",
"451137 153.0 227.0 212.0 \n",
"239309 184.0 235.0 191.0 \n",
"406901 220.0 196.0 105.0 \n",
"379632 181.0 230.0 189.0 \n",
"510084 224.0 219.0 129.0 \n",
"96001 168.0 204.0 172.0 \n",
"39684 172.0 219.0 186.0 \n",
"227535 204.0 242.0 176.0 \n",
"85578 210.0 253.0 174.0 \n",
"\n",
" Horizontal_Distance_To_Fire_Points Wilderness_Area1 \\\n",
"484665 1516.0 0.0 \n",
"451137 808.0 0.0 \n",
"239309 1041.0 1.0 \n",
"406901 2935.0 1.0 \n",
"379632 2408.0 0.0 \n",
"510084 1022.0 1.0 \n",
"96001 342.0 0.0 \n",
"39684 1291.0 1.0 \n",
"227535 2460.0 1.0 \n",
"85578 1426.0 0.0 \n",
"\n",
" Wilderness_Area2 Wilderness_Area3 Wilderness_Area4 Soil_Type1 \\\n",
"484665 0.0 0.0 1.0 0.0 \n",
"451137 0.0 1.0 0.0 0.0 \n",
"239309 0.0 0.0 0.0 0.0 \n",
"406901 0.0 0.0 0.0 0.0 \n",
"379632 0.0 1.0 0.0 0.0 \n",
"510084 0.0 0.0 0.0 0.0 \n",
"96001 0.0 0.0 1.0 0.0 \n",
"39684 0.0 0.0 0.0 0.0 \n",
"227535 0.0 0.0 0.0 0.0 \n",
"85578 0.0 1.0 0.0 0.0 \n",
"\n",
" Soil_Type2 Soil_Type3 Soil_Type4 Soil_Type5 Soil_Type6 \\\n",
"484665 0.0 0.0 0.0 0.0 0.0 \n",
"451137 0.0 0.0 0.0 0.0 0.0 \n",
"239309 0.0 0.0 0.0 0.0 0.0 \n",
"406901 0.0 0.0 0.0 0.0 0.0 \n",
"379632 0.0 0.0 0.0 0.0 0.0 \n",
"510084 0.0 0.0 0.0 0.0 0.0 \n",
"96001 0.0 0.0 0.0 0.0 0.0 \n",
"39684 0.0 0.0 0.0 0.0 0.0 \n",
"227535 0.0 0.0 0.0 0.0 0.0 \n",
"85578 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type7 Soil_Type8 Soil_Type9 Soil_Type10 Soil_Type11 \\\n",
"484665 0.0 0.0 0.0 1.0 0.0 \n",
"451137 0.0 0.0 0.0 0.0 0.0 \n",
"239309 0.0 0.0 0.0 0.0 0.0 \n",
"406901 0.0 0.0 0.0 0.0 0.0 \n",
"379632 0.0 0.0 0.0 0.0 0.0 \n",
"510084 0.0 0.0 0.0 0.0 0.0 \n",
"96001 0.0 0.0 0.0 1.0 0.0 \n",
"39684 0.0 0.0 0.0 0.0 0.0 \n",
"227535 0.0 0.0 0.0 0.0 0.0 \n",
"85578 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type12 Soil_Type13 Soil_Type14 Soil_Type15 Soil_Type16 \\\n",
"484665 0.0 0.0 0.0 0.0 0.0 \n",
"451137 0.0 0.0 0.0 0.0 0.0 \n",
"239309 0.0 0.0 0.0 0.0 0.0 \n",
"406901 0.0 0.0 0.0 0.0 0.0 \n",
"379632 0.0 0.0 0.0 0.0 0.0 \n",
"510084 0.0 0.0 0.0 0.0 0.0 \n",
"96001 0.0 0.0 0.0 0.0 0.0 \n",
"39684 0.0 0.0 0.0 0.0 0.0 \n",
"227535 0.0 0.0 0.0 0.0 0.0 \n",
"85578 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type17 Soil_Type18 Soil_Type19 Soil_Type20 Soil_Type21 \\\n",
"484665 0.0 0.0 0.0 0.0 0.0 \n",
"451137 0.0 0.0 0.0 0.0 0.0 \n",
"239309 0.0 0.0 0.0 0.0 0.0 \n",
"406901 0.0 0.0 0.0 0.0 0.0 \n",
"379632 0.0 0.0 0.0 0.0 0.0 \n",
"510084 0.0 0.0 0.0 1.0 0.0 \n",
"96001 0.0 0.0 0.0 0.0 0.0 \n",
"39684 0.0 0.0 0.0 0.0 0.0 \n",
"227535 0.0 0.0 0.0 0.0 0.0 \n",
"85578 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type22 Soil_Type23 Soil_Type24 Soil_Type25 Soil_Type26 \\\n",
"484665 0.0 0.0 0.0 0.0 0.0 \n",
"451137 0.0 0.0 0.0 0.0 0.0 \n",
"239309 0.0 0.0 0.0 0.0 0.0 \n",
"406901 0.0 0.0 0.0 0.0 0.0 \n",
"379632 0.0 0.0 0.0 0.0 0.0 \n",
"510084 0.0 0.0 0.0 0.0 0.0 \n",
"96001 0.0 0.0 0.0 0.0 0.0 \n",
"39684 0.0 0.0 0.0 0.0 0.0 \n",
"227535 0.0 0.0 0.0 0.0 0.0 \n",
"85578 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type27 Soil_Type28 Soil_Type29 Soil_Type30 Soil_Type31 \\\n",
"484665 0.0 0.0 0.0 0.0 0.0 \n",
"451137 0.0 0.0 0.0 0.0 1.0 \n",
"239309 0.0 0.0 1.0 0.0 0.0 \n",
"406901 0.0 0.0 0.0 1.0 0.0 \n",
"379632 0.0 0.0 0.0 0.0 0.0 \n",
"510084 0.0 0.0 0.0 0.0 0.0 \n",
"96001 0.0 0.0 0.0 0.0 0.0 \n",
"39684 0.0 0.0 0.0 0.0 0.0 \n",
"227535 0.0 0.0 1.0 0.0 0.0 \n",
"85578 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type32 Soil_Type33 Soil_Type34 Soil_Type35 Soil_Type36 \\\n",
"484665 0.0 0.0 0.0 0.0 0.0 \n",
"451137 0.0 0.0 0.0 0.0 0.0 \n",
"239309 0.0 0.0 0.0 0.0 0.0 \n",
"406901 0.0 0.0 0.0 0.0 0.0 \n",
"379632 0.0 1.0 0.0 0.0 0.0 \n",
"510084 0.0 0.0 0.0 0.0 0.0 \n",
"96001 0.0 0.0 0.0 0.0 0.0 \n",
"39684 0.0 0.0 0.0 0.0 0.0 \n",
"227535 0.0 0.0 0.0 0.0 0.0 \n",
"85578 1.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type37 Soil_Type38 Soil_Type39 Soil_Type40 \n",
"484665 0.0 0.0 0.0 0.0 \n",
"451137 0.0 0.0 0.0 0.0 \n",
"239309 0.0 0.0 0.0 0.0 \n",
"406901 0.0 0.0 0.0 0.0 \n",
"379632 0.0 0.0 0.0 0.0 \n",
"510084 0.0 0.0 0.0 0.0 \n",
"96001 0.0 0.0 0.0 0.0 \n",
"39684 0.0 0.0 1.0 0.0 \n",
"227535 0.0 0.0 0.0 0.0 \n",
"85578 0.0 0.0 0.0 0.0 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"pd.set_option('display.max_columns', None)\n",
"pd.concat([pd.DataFrame({'y': train_y}, index=train_X.index),\n",
" train_X], axis=1).tail(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Manual Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.decomposition import PCA\n",
"from xgboost import XGBClassifier as XGBoost\n",
"lale.wrap_imported_operators()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"manual_trainable = PCA(n_components=6) >> XGBoost(n_estimators=3)\n",
"manual_trainable.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.39 s, sys: 953 ms, total: 3.34 s\n",
"Wall time: 2.05 s\n"
]
}
],
"source": [
"%%time\n",
"manual_trained = manual_trainable.fit(train_X, train_y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy 67.1%\n"
]
}
],
"source": [
"import sklearn.metrics\n",
"manual_y = manual_trained.predict(test_X)\n",
"print(f'accuracy {sklearn.metrics.accuracy_score(test_y, manual_y):.1%}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter Tuning"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'description': 'Number of trees to fit.',\n",
" 'type': 'integer',\n",
" 'default': 100,\n",
" 'minimumForOptimizer': 50,\n",
" 'maximumForOptimizer': 1000}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"XGBoost.hyperparam_schema('n_estimators')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.pca.html\n"
]
}
],
"source": [
"print(PCA.documentation_url())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from lale.lib.lale import Hyperopt\n",
"import lale.schemas as schemas\n",
"\n",
"CustomPCA = PCA.customize_schema(n_components=schemas.Int(minimum=2, maximum=54))\n",
"CustomXGBoost = XGBoost.customize_schema(n_estimators=schemas.Int(minimum=1, maximum=10))\n",
"\n",
"hpo_planned = CustomPCA >> CustomXGBoost\n",
"hpo_trainable = Hyperopt(estimator=hpo_planned, max_evals=10, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100%|███████| 10/10 [02:15<00:00, 13.53s/trial, best loss: -0.7727907776451675]\n",
"CPU times: user 2min 53s, sys: 19.3 s, total: 3min 13s\n",
"Wall time: 2min 30s\n"
]
}
],
"source": [
"%%time\n",
"hpo_trained = hpo_trainable.fit(train_X, train_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### --- Excursion: Types as Search Spaces ---\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy 77.7%\n"
]
}
],
"source": [
"hpo_y = hpo_trained.predict(test_X)\n",
"print(f'accuracy {sklearn.metrics.accuracy_score(test_y, hpo_y):.1%}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inspecting Automation Results"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"hpo_trained.get_pipeline().visualize()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"```python\n",
"from sklearn.decomposition import PCA as CustomPCA\n",
"from xgboost import XGBClassifier as CustomXGBoost\n",
"import lale\n",
"\n",
"lale.wrap_imported_operators()\n",
"custom_pca = CustomPCA.customize_schema(\n",
" n_components={\"type\": \"integer\", \"minimum\": 2, \"maximum\": 54}\n",
")(n_components=43, svd_solver=\"full\", whiten=True)\n",
"custom_xg_boost = CustomXGBoost.customize_schema(\n",
" n_estimators={\"type\": \"integer\", \"minimum\": 1, \"maximum\": 10}\n",
")(\n",
" gamma=0.42208258595069725,\n",
" learning_rate=0.6558019595096513,\n",
" max_depth=5,\n",
" min_child_weight=13,\n",
" n_estimators=9,\n",
" reg_alpha=0.3590229319214039,\n",
" reg_lambda=0.7978279409450941,\n",
" subsample=0.6209085649172931,\n",
")\n",
"pipeline = custom_pca >> custom_xg_boost\n",
"```"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"hpo_trained.get_pipeline().pretty_print(ipython_display=True, customize_schema=True)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" tid loss time log_loss status\n",
"name \n",
"p0 0 -0.684229 2.293911 1.161776 ok\n",
"p1 1 -0.708057 3.347494 0.950058 ok\n",
"p2 2 -0.631983 3.356443 1.123108 ok\n",
"p3 3 -0.699050 2.606100 1.168528 ok\n",
"p4 4 -0.717428 5.158346 0.690650 ok\n",
"p5 5 -0.759653 7.138689 0.655658 ok\n",
"p6 6 -0.707598 3.555126 0.942210 ok\n",
"p7 7 -0.772791 10.981915 0.555780 ok\n",
"p8 8 -0.653057 2.016587 0.845659 ok\n",
"p9 9 -0.620853 2.155818 1.817853 ok"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hpo_trained.summary()"
]
},
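{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `loss` column above appears to be the negated mean cross-validation accuracy: Hyperopt minimizes, so a lower loss means a better pipeline. That reading is consistent with this run, where the best trial's loss of about -0.773 is close to the 77.7% test accuracy reported earlier. A small pandas sketch that turns the copied loss values back into accuracies, assuming that convention, looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Loss values copied from the hpo_trained.summary() output above\n",
"losses = pd.Series({\n",
"    'p0': -0.684229, 'p1': -0.708057, 'p2': -0.631983, 'p3': -0.699050,\n",
"    'p4': -0.717428, 'p5': -0.759653, 'p6': -0.707598, 'p7': -0.772791,\n",
"    'p8': -0.653057, 'p9': -0.620853,\n",
"}, name='loss')\n",
"\n",
"# Assumption: loss = -(mean cross-validation accuracy), so negating recovers accuracy\n",
"cv_accuracy = -losses\n",
"best, worst = cv_accuracy.idxmax(), cv_accuracy.idxmin()\n",
"print(f'best {best}: {cv_accuracy[best]:.1%}, worst {worst}: {cv_accuracy[worst]:.1%}')"
]
},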
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"p9\n"
]
}
],
"source": [
"summary = hpo_trained.summary()\n",
"# The loss is minimized, so the row with the highest loss is the worst trial;\n",
"# idxmax returns that row label in every pandas version, unlike argmax\n",
"worst_name = summary.loss.idxmax()\n",
"print(worst_name)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"```python\n",
"custom_pca = CustomPCA.customize_schema(\n",
" n_components={\"type\": \"integer\", \"minimum\": 2, \"maximum\": 54}\n",
")(n_components=20, svd_solver=\"full\", whiten=True)\n",
"custom_xg_boost = CustomXGBoost.customize_schema(\n",
" n_estimators={\"type\": \"integer\", \"minimum\": 1, \"maximum\": 10}\n",
")(\n",
" gamma=0.37068548766270437,\n",
" learning_rate=0.02005982973762002,\n",
" max_depth=2,\n",
" min_child_weight=9,\n",
" n_estimators=5,\n",
" reg_alpha=0.8716519284632148,\n",
" reg_lambda=0.7305593001592293,\n",
" subsample=0.9559232064468288,\n",
")\n",
"pipeline = custom_pca >> custom_xg_boost\n",
"```"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"hpo_trained.get_pipeline(worst_name).visualize()\n",
"hpo_trained.get_pipeline(worst_name).pretty_print(\n",
" ipython_display=True, show_imports=False, customize_schema=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Combined Algorithm Selection and Hyperparameter Tuning"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.preprocessing import Normalizer as Norm\n",
"from sklearn.linear_model import LogisticRegression as LR\n",
"from sklearn.tree import DecisionTreeClassifier as Tree\n",
"from sklearn.neighbors import KNeighborsClassifier as KNN\n",
"from lale.lib.lale import NoOp\n",
"lale.wrap_imported_operators()\n",
"\n",
"KNN = KNN.customize_schema(n_neighbors=schemas.Int(minimum=1, maximum=10))\n",
"transp_planned = (Norm | NoOp) >> (Tree | LR(solver='liblinear') | KNN)\n",
"transp_planned.visualize()"
]
},
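{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `|` combinator asks the optimizer to pick one operator, and `>>` composes the two choices, so `transp_planned` describes 2 × 3 = 6 pipeline shapes, each with its own hyperparameters to tune. A plain-Python enumeration of just the shapes (no Lale calls; the names are only labels) makes the count explicit. With `max_evals=3` in the next cell, Hyperopt evaluates at most three configurations, so it cannot try every shape even once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from itertools import product\n",
"\n",
"# Labels for the alternatives in (Norm | NoOp) >> (Tree | LR | KNN)\n",
"preprocessors = ['Norm', 'NoOp']\n",
"classifiers = ['Tree', 'LR', 'KNN']\n",
"\n",
"# Every combination of one preprocessor and one classifier is a distinct shape\n",
"topologies = [f'{p} >> {c}' for p, c in product(preprocessors, classifiers)]\n",
"print(len(topologies))\n",
"for shape in topologies:\n",
"    print(shape)"
]
},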
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100%|█████████| 3/3 [01:24<00:00, 28.25s/trial, best loss: -0.8390927596840342]\n",
"CPU times: user 1min 26s, sys: 953 ms, total: 1min 27s\n",
"Wall time: 1min 26s\n"
]
}
],
"source": [
"%%time\n",
"transp_trained = transp_planned.auto_configure(\n",
" train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### --- Excursion: Bindings as Lifecycle ---\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"```python\n",
"knn = KNN.customize_schema(\n",
" n_neighbors={\"type\": \"integer\", \"minimum\": 1, \"maximum\": 10}\n",
")(algorithm=\"ball_tree\", metric=\"manhattan\", n_neighbors=9)\n",
"pipeline = NoOp() >> knn\n",
"```"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"transp_trained.pretty_print(\n",
" ipython_display=True, show_imports=False, customize_schema=True)\n",
"transp_trained.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy 86.6%\n",
"CPU times: user 51.5 s, sys: 31.2 ms, total: 51.5 s\n",
"Wall time: 52 s\n"
]
}
],
"source": [
"%%time\n",
"transp_y = transp_trained.predict(test_X)\n",
"print(f'accuracy {sklearn.metrics.accuracy_score(test_y, transp_y):.1%}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Non-Linear Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'description': 'Features of forest covertypes dataset (classification).',\n",
" 'documentation_url': 'https://scikit-learn.org/0.20/datasets/index.html#forest-covertypes',\n",
" 'type': 'array',\n",
" 'items': {'type': 'array',\n",
" 'minItems': 54,\n",
" 'maxItems': 54,\n",
" 'items': [{'description': 'Elevation', 'type': 'integer'},\n",
" {'description': 'Aspect', 'type': 'integer'},\n",
" {'description': 'Slope', 'type': 'integer'},\n",
" {'description': 'Horizontal_Distance_To_Hydrology', 'type': 'integer'},\n",
" {'description': 'Vertical_Distance_To_Hydrology', 'type': 'integer'},\n",
" {'description': 'Horizontal_Distance_To_Roadways', 'type': 'integer'},\n",
" {'description': 'Hillshade_9am', 'type': 'integer'},\n",
" {'description': 'Hillshade_Noon', 'type': 'integer'},\n",
" {'description': 'Hillshade_3pm', 'type': 'integer'},\n",
" {'description': 'Horizontal_Distance_To_Fire_Points', 'type': 'integer'},\n",
" {'description': 'Wilderness_Area1', 'enum': [0, 1]},\n",
" {'description': 'Wilderness_Area2', 'enum': [0, 1]},\n",
" {'description': 'Wilderness_Area3', 'enum': [0, 1]},\n",
" {'description': 'Wilderness_Area4', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type1', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type2', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type3', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type4', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type5', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type6', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type7', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type8', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type9', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type10', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type11', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type12', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type13', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type14', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type15', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type16', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type17', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type18', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type19', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type20', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type21', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type22', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type23', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type24', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type25', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type26', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type27', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type28', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type29', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type30', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type31', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type32', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type33', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type34', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type35', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type36', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type37', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type38', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type39', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type40', 'enum': [0, 1]}]},\n",
" 'minItems': 58102,\n",
" 'maxItems': 58102}"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_X.json_schema"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Wilderness_Area1', 'Wilderness_Area2', 'Wilderness_Area3', 'Wilderness_Area4', 'Soil_Type1', 'Soil_Type2', 'Soil_Type3', 'Soil_Type4', 'Soil_Type5', 'Soil_Type6', 'Soil_Type7', 'Soil_Type8', 'Soil_Type9', 'Soil_Type10', 'Soil_Type11', 'Soil_Type12', 'Soil_Type13', 'Soil_Type14', 'Soil_Type15', 'Soil_Type16', 'Soil_Type17', 'Soil_Type18', 'Soil_Type19', 'Soil_Type20', 'Soil_Type21', 'Soil_Type22', 'Soil_Type23', 'Soil_Type24', 'Soil_Type25', 'Soil_Type26', 'Soil_Type27', 'Soil_Type28', 'Soil_Type29', 'Soil_Type30', 'Soil_Type31', 'Soil_Type32', 'Soil_Type33', 'Soil_Type34', 'Soil_Type35', 'Soil_Type36', 'Soil_Type37', 'Soil_Type38', 'Soil_Type39', 'Soil_Type40']\n"
]
}
],
"source": [
"from lale.lib.lale import categorical\n",
"print(categorical(max_values=2)(test_X))"
]
},
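{
"cell_type": "markdown",
"metadata": {},
"source": [
"Judging from the schema printed above, where every one-hot column is declared as `'enum': [0, 1]`, `categorical(max_values=2)` picks out the columns whose declared values number at most two. When a dataset carries no such schema, a rough data-driven approximation is to count distinct values per column. The helper `columns_with_few_values` below is hypothetical and not part of Lale; it runs on a few values copied from the rows shown earlier. Note that counting values in a small sample can misclassify a continuous column that happens to repeat values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def columns_with_few_values(X: pd.DataFrame, max_values: int) -> list:\n",
"    # Hypothetical helper: keep columns with at most max_values distinct values\n",
"    return [col for col in X.columns if X[col].nunique() <= max_values]\n",
"\n",
"# Tiny stand-in frame: two continuous and two one-hot columns from the sample rows\n",
"demo = pd.DataFrame({\n",
"    'Elevation': [2277.0, 3273.0, 3062.0],\n",
"    'Slope': [31.0, 22.0, 13.0],\n",
"    'Wilderness_Area1': [0.0, 0.0, 1.0],\n",
"    'Soil_Type10': [1.0, 0.0, 0.0],\n",
"})\n",
"print(columns_with_few_values(demo, max_values=2))"
]
},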
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from lale.lib.lale import Project\n",
"from lale.lib.lale import ConcatFeatures as Concat\n",
"from sklearn.feature_selection import SelectKBest as FeatSel\n",
"lale.wrap_imported_operators()\n",
"\n",
"binary_prep = Project(columns=categorical(max_values=2)) >> FeatSel\n",
"other_prep = Project(drop_columns=categorical(max_values=2)) >> (Norm | NoOp)\n",
"nonlin_planned = (binary_prep & other_prep) >> Concat >> KNN\n",
"nonlin_planned.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100%|█████████| 3/3 [02:17<00:00, 45.88s/trial, best loss: -0.8620412868709595]\n",
"CPU times: user 2min 18s, sys: 359 ms, total: 2min 19s\n",
"Wall time: 2min 21s\n"
]
}
],
"source": [
"%%time\n",
"nonlin_trained = nonlin_planned.auto_configure(\n",
" train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3, verbose=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### --- Excursion: Combinators ---\n",
"\n",
"| Lale feature | Name | Description | scikit-learn feature |\n",
"| ------------ | ---- | ----------- | -------------------- |\n",
"| `>>` or `make_pipeline` | pipe | feed to next | `make_pipeline` |\n",
"| `&` or `make_union` | and | run both | `make_union` or `ColumnTransformer` |\n",
"| `\\|` or `make_choice` | or | choose one | N/A (specific to given AutoML tool) |\n",
"\n",
"### --- Excursion: Interoperability ---\n",
"\n",
""
]
},
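{
"cell_type": "markdown",
"metadata": {},
"source": [
"The table suggests how a Lale pipeline maps onto plain scikit-learn. The sketch below hand-writes a scikit-learn analogue of the trained pipeline printed by `pretty_print` in the next cell: feature selection on the binary columns, `NoOp` on the rest, then `KNN`. It makes several assumptions. `FeatSel` is `SelectKBest` (as imported above), and `KNN` is assumed to be `KNeighborsClassifier`. `ColumnTransformer` stands in for `Project`. Because scikit-learn has no `|` combinator, the `Norm | NoOp` choice is fixed by hand to `'passthrough'`. Column names and data are made up:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.feature_selection import SelectKBest\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"# Made-up data with two numeric and three binary columns.\n",
"rng = np.random.RandomState(0)\n",
"X = pd.DataFrame({\n",
"    'Elevation': rng.randint(1800, 3900, size=40),\n",
"    'Slope': rng.randint(0, 60, size=40),\n",
"    'Wilderness_Area1': rng.randint(0, 2, size=40),\n",
"    'Soil_Type10': rng.randint(0, 2, size=40),\n",
"    'Soil_Type38': rng.randint(0, 2, size=40),\n",
"})\n",
"y = rng.randint(1, 3, size=40)\n",
"\n",
"binary_cols = ['Wilderness_Area1', 'Soil_Type10', 'Soil_Type38']\n",
"other_cols = ['Elevation', 'Slope']\n",
"\n",
"# Each branch sees only its own columns (like Project), and the branch\n",
"# outputs are concatenated side by side (like & followed by Concat).\n",
"union = ColumnTransformer([\n",
"    ('binary_prep', SelectKBest(k=2), binary_cols),\n",
"    ('other_prep', 'passthrough', other_cols),\n",
"])\n",
"# >> corresponds to chaining steps with make_pipeline.\n",
"pipe = make_pipeline(\n",
"    union, KNeighborsClassifier(algorithm='kd_tree', n_neighbors=7, weights='distance'))\n",
"pipe.fit(X, y)\n",
"print(pipe.predict(X.head(3)).shape)  # (3,)\n",
"```\n",
"\n",
"Here every choice is made by hand. In the Lale version, `auto_configure` searched over `Norm | NoOp` and over the hyperparameters."
]
},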
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"```python\n",
"project_0 = Project(columns=lale.lib.lale.categorical(max_values=2))\n",
"feat_sel = FeatSel(k=8)\n",
"pipeline_0 = make_pipeline(project_0, feat_sel)\n",
"project_1 = Project(drop_columns=lale.lib.lale.categorical(max_values=2))\n",
"pipeline_1 = make_pipeline(project_1, NoOp())\n",
"union = make_union(pipeline_0, pipeline_1)\n",
"knn = KNN(algorithm=\"kd_tree\", n_neighbors=7, weights=\"distance\")\n",
"pipeline = make_pipeline(union, knn)\n",
"```"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"nonlin_trained.visualize()\n",
"nonlin_trained.pretty_print(ipython_display=True, show_imports=False, combinators=False)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy 88.6%\n",
"CPU times: user 4.31 s, sys: 46.9 ms, total: 4.36 s\n",
"Wall time: 4.44 s\n"
]
}
],
"source": [
"%%time\n",
"nonlin_y = nonlin_trained.predict(test_X)\n",
"print(f'accuracy {sklearn.metrics.accuracy_score(test_y, nonlin_y):.1%}')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Wilderness_Area1 Wilderness_Area4 Soil_Type2 Soil_Type3 Soil_Type4 \\\n",
"0 1.0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 0.0 0.0 \n",
"3 1.0 0.0 0.0 0.0 0.0 \n",
"4 1.0 0.0 0.0 0.0 0.0 \n",
"5 0.0 1.0 0.0 0.0 0.0 \n",
"6 1.0 0.0 0.0 0.0 0.0 \n",
"7 1.0 0.0 0.0 0.0 0.0 \n",
"8 1.0 0.0 0.0 0.0 0.0 \n",
"9 1.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type10 Soil_Type38 Soil_Type39 \n",
"0 0.0 1.0 0.0 \n",
"1 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 1.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"7 0.0 1.0 0.0 \n",
"8 0.0 0.0 0.0 \n",
"9 0.0 0.0 0.0 "
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"binary_prep_trainable = Project(columns=categorical(max_values=2)) >> FeatSel(k=8)\n",
"binary_prep_trained = binary_prep_trainable.fit(train_X, train_y)\n",
"binary_prep_trained.transform(test_X.head(10))"
]
},
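{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above keeps 8 of the binary columns. `SelectKBest` scores each feature separately and keeps the `k` highest-scoring ones. By default it scores with `f_classif`, the ANOVA F-value between feature and label. The minimal scikit-learn-only sketch below uses made-up data and reads the kept columns off the fitted selector with `get_support()`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.feature_selection import SelectKBest\n",
"\n",
"rng = np.random.RandomState(0)\n",
"y = rng.randint(0, 2, size=100)\n",
"\n",
"# 'informative' agrees with the label except for the first 10 rows;\n",
"# the two noise columns are independent of the label (made-up data).\n",
"informative = y.copy()\n",
"informative[:10] = 1 - informative[:10]\n",
"X = pd.DataFrame({\n",
"    'informative': informative,\n",
"    'noise_a': rng.randint(0, 2, size=100),\n",
"    'noise_b': rng.randint(0, 2, size=100),\n",
"})\n",
"\n",
"sel = SelectKBest(k=1).fit(X, y)  # score_func defaults to f_classif\n",
"kept = list(X.columns[sel.get_support()])\n",
"print(kept)  # ['informative']\n",
"```"
]
},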
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"- code and documentation: https://github.com/ibm/lale\n",
"- more examples: https://github.com/IBM/lale/tree/master/examples/\n",
"- frequently asked questions: https://github.com/IBM/lale/blob/master/docs/faq.rst\n",
"- arXiv paper: https://arxiv.org/pdf/1906.03957.pdf\n",
"\n",
""
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}