{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lale and its Impact on the Data Science Workflow\n",
"\n",
"Guillaume Baudart, Martin Hirzel, Kiran Kate, Pari Ram, and Avi Shinnar\n",
"\n",
"27 March 2020\n",
"\n",
"Examples, documentation, code: https://github.com/ibm/lale"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Value Proposition\n",
"\n",
"- **target user**: data scientist familiar with Python and scikit-learn\n",
"- **scope**: data preparation and machine learning (including some deep learning)\n",
"- **value**: consistent API for both manual machine learning and AutoML"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# !pip install --quiet lale"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example Dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape train_X_all (522910, 54), test_X (58102, 54)\n"
]
}
],
"source": [
"import lale.datasets\n",
"(train_X_all, train_y_all), (test_X, test_y) = lale.datasets.covtype_df(test_size=0.1)\n",
"print(f'shape train_X_all {train_X_all.shape}, test_X {test_X.shape}')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape train_X (52291, 54), other_X (470619, 54)\n"
]
}
],
"source": [
"import sklearn.model_selection\n",
"# Keep only 10% of the training rows so the later searches run quickly.\n",
"train_X, other_X, train_y, other_y = sklearn.model_selection.train_test_split(\n",
"    train_X_all, train_y_all, test_size=0.9)\n",
"print(f'shape train_X {train_X.shape}, other_X {other_X.shape}')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
" y Elevation Aspect Slope Horizontal_Distance_To_Hydrology \\\n",
"274665 3 2354.0 130.0 23.0 285.0 \n",
"120210 2 2985.0 91.0 18.0 886.0 \n",
"111775 2 3142.0 88.0 20.0 684.0 \n",
"400567 3 2493.0 108.0 14.0 182.0 \n",
"224682 2 2796.0 352.0 9.0 594.0 \n",
"424723 1 3126.0 197.0 13.0 85.0 \n",
"445777 1 2981.0 333.0 16.0 150.0 \n",
"388163 1 3380.0 219.0 6.0 395.0 \n",
"522588 7 3397.0 113.0 15.0 706.0 \n",
"128441 2 2831.0 155.0 21.0 85.0 \n",
"\n",
" Vertical_Distance_To_Hydrology Horizontal_Distance_To_Roadways \\\n",
"274665 80.0 277.0 \n",
"120210 187.0 3180.0 \n",
"111775 -52.0 551.0 \n",
"400567 34.0 666.0 \n",
"224682 84.0 2955.0 \n",
"424723 10.0 5344.0 \n",
"445777 14.0 2704.0 \n",
"388163 88.0 2895.0 \n",
"522588 240.0 1507.0 \n",
"128441 27.0 4235.0 \n",
"\n",
" Hillshade_9am Hillshade_Noon Hillshade_3pm \\\n",
"274665 250.0 220.0 86.0 \n",
"120210 244.0 209.0 88.0 \n",
"111775 245.0 204.0 80.0 \n",
"400567 243.0 223.0 107.0 \n",
"224682 205.0 225.0 158.0 \n",
"424723 216.0 251.0 166.0 \n",
"445777 182.0 218.0 175.0 \n",
"388163 213.0 246.0 169.0 \n",
"522588 245.0 223.0 103.0 \n",
"128441 239.0 236.0 116.0 \n",
"\n",
" Horizontal_Distance_To_Fire_Points Wilderness_Area1 \\\n",
"274665 874.0 0.0 \n",
"120210 828.0 0.0 \n",
"111775 1082.0 0.0 \n",
"400567 1294.0 0.0 \n",
"224682 1471.0 0.0 \n",
"424723 1148.0 1.0 \n",
"445777 655.0 0.0 \n",
"388163 1224.0 0.0 \n",
"522588 1040.0 0.0 \n",
"128441 5071.0 1.0 \n",
"\n",
" Wilderness_Area2 Wilderness_Area3 Wilderness_Area4 Soil_Type1 \\\n",
"274665 0.0 1.0 0.0 0.0 \n",
"120210 0.0 1.0 0.0 0.0 \n",
"111775 1.0 0.0 0.0 0.0 \n",
"400567 0.0 0.0 1.0 0.0 \n",
"224682 0.0 1.0 0.0 0.0 \n",
"424723 0.0 0.0 0.0 0.0 \n",
"445777 0.0 1.0 0.0 0.0 \n",
"388163 0.0 1.0 0.0 0.0 \n",
"522588 0.0 1.0 0.0 0.0 \n",
"128441 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type2 Soil_Type3 Soil_Type4 Soil_Type5 Soil_Type6 \\\n",
"274665 0.0 1.0 0.0 0.0 0.0 \n",
"120210 0.0 0.0 0.0 0.0 0.0 \n",
"111775 0.0 0.0 0.0 0.0 0.0 \n",
"400567 0.0 0.0 0.0 0.0 1.0 \n",
"224682 0.0 0.0 0.0 0.0 0.0 \n",
"424723 0.0 0.0 0.0 0.0 0.0 \n",
"445777 0.0 0.0 0.0 0.0 0.0 \n",
"388163 0.0 0.0 0.0 0.0 0.0 \n",
"522588 0.0 0.0 0.0 0.0 0.0 \n",
"128441 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type7 Soil_Type8 Soil_Type9 Soil_Type10 Soil_Type11 \\\n",
"274665 0.0 0.0 0.0 0.0 0.0 \n",
"120210 0.0 0.0 0.0 1.0 0.0 \n",
"111775 0.0 0.0 0.0 0.0 0.0 \n",
"400567 0.0 0.0 0.0 0.0 0.0 \n",
"224682 0.0 0.0 0.0 0.0 0.0 \n",
"424723 0.0 0.0 0.0 0.0 0.0 \n",
"445777 0.0 0.0 0.0 0.0 0.0 \n",
"388163 0.0 0.0 0.0 0.0 0.0 \n",
"522588 0.0 0.0 0.0 0.0 0.0 \n",
"128441 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type12 Soil_Type13 Soil_Type14 Soil_Type15 Soil_Type16 \\\n",
"274665 0.0 0.0 0.0 0.0 0.0 \n",
"120210 0.0 0.0 0.0 0.0 0.0 \n",
"111775 0.0 0.0 0.0 0.0 0.0 \n",
"400567 0.0 0.0 0.0 0.0 0.0 \n",
"224682 0.0 0.0 0.0 0.0 0.0 \n",
"424723 0.0 0.0 0.0 0.0 0.0 \n",
"445777 0.0 0.0 0.0 0.0 0.0 \n",
"388163 0.0 0.0 0.0 0.0 0.0 \n",
"522588 0.0 0.0 0.0 0.0 0.0 \n",
"128441 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type17 Soil_Type18 Soil_Type19 Soil_Type20 Soil_Type21 \\\n",
"274665 0.0 0.0 0.0 0.0 0.0 \n",
"120210 0.0 0.0 0.0 0.0 0.0 \n",
"111775 0.0 0.0 1.0 0.0 0.0 \n",
"400567 0.0 0.0 0.0 0.0 0.0 \n",
"224682 0.0 0.0 0.0 0.0 0.0 \n",
"424723 0.0 0.0 0.0 0.0 0.0 \n",
"445777 0.0 0.0 0.0 0.0 0.0 \n",
"388163 0.0 0.0 0.0 0.0 0.0 \n",
"522588 0.0 0.0 0.0 0.0 0.0 \n",
"128441 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type22 Soil_Type23 Soil_Type24 Soil_Type25 Soil_Type26 \\\n",
"274665 0.0 0.0 0.0 0.0 0.0 \n",
"120210 0.0 0.0 0.0 0.0 0.0 \n",
"111775 0.0 0.0 0.0 0.0 0.0 \n",
"400567 0.0 0.0 0.0 0.0 0.0 \n",
"224682 0.0 0.0 0.0 0.0 0.0 \n",
"424723 0.0 0.0 0.0 0.0 0.0 \n",
"445777 0.0 0.0 0.0 0.0 0.0 \n",
"388163 0.0 0.0 0.0 0.0 0.0 \n",
"522588 0.0 0.0 0.0 0.0 0.0 \n",
"128441 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type27 Soil_Type28 Soil_Type29 Soil_Type30 Soil_Type31 \\\n",
"274665 0.0 0.0 0.0 0.0 0.0 \n",
"120210 0.0 0.0 0.0 0.0 0.0 \n",
"111775 0.0 0.0 0.0 0.0 0.0 \n",
"400567 0.0 0.0 0.0 0.0 0.0 \n",
"224682 0.0 0.0 0.0 0.0 0.0 \n",
"424723 0.0 0.0 1.0 0.0 0.0 \n",
"445777 0.0 0.0 0.0 0.0 0.0 \n",
"388163 0.0 0.0 0.0 0.0 0.0 \n",
"522588 0.0 0.0 0.0 0.0 0.0 \n",
"128441 0.0 0.0 0.0 1.0 0.0 \n",
"\n",
" Soil_Type32 Soil_Type33 Soil_Type34 Soil_Type35 Soil_Type36 \\\n",
"274665 0.0 0.0 0.0 0.0 0.0 \n",
"120210 0.0 0.0 0.0 0.0 0.0 \n",
"111775 0.0 0.0 0.0 0.0 0.0 \n",
"400567 0.0 0.0 0.0 0.0 0.0 \n",
"224682 1.0 0.0 0.0 0.0 0.0 \n",
"424723 0.0 0.0 0.0 0.0 0.0 \n",
"445777 1.0 0.0 0.0 0.0 0.0 \n",
"388163 1.0 0.0 0.0 0.0 0.0 \n",
"522588 0.0 0.0 0.0 0.0 0.0 \n",
"128441 0.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type37 Soil_Type38 Soil_Type39 Soil_Type40 \n",
"274665 0.0 0.0 0.0 0.0 \n",
"120210 0.0 0.0 0.0 0.0 \n",
"111775 0.0 0.0 0.0 0.0 \n",
"400567 0.0 0.0 0.0 0.0 \n",
"224682 0.0 0.0 0.0 0.0 \n",
"424723 0.0 0.0 0.0 0.0 \n",
"445777 0.0 0.0 0.0 0.0 \n",
"388163 0.0 0.0 0.0 0.0 \n",
"522588 0.0 0.0 0.0 1.0 \n",
"128441 0.0 0.0 0.0 0.0 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"pd.set_option('display.max_columns', None)\n",
"pd.concat([pd.DataFrame({'y': train_y}, index=train_X.index),\n",
" train_X], axis=1).tail(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Manual Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.decomposition import PCA\n",
"from xgboost import XGBClassifier as XGBoost\n",
"lale.wrap_imported_operators()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"manual_trainable = PCA(n_components=6) >> XGBoost(n_estimators=3)\n",
"manual_trainable.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.56 s, sys: 672 ms, total: 4.23 s\n",
"Wall time: 3.88 s\n"
]
}
],
"source": [
"%%time\n",
"manual_trained = manual_trainable.fit(train_X, train_y)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy 75.5%\n"
]
}
],
"source": [
"import sklearn.metrics\n",
"manual_y = manual_trained.predict(test_X)\n",
"print(f'accuracy {sklearn.metrics.accuracy_score(test_y, manual_y):.1%}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter Tuning"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'description': 'Number of trees to fit.',\n",
" 'type': 'integer',\n",
" 'default': 1000,\n",
" 'minimumForOptimizer': 500,\n",
" 'maximumForOptimizer': 1500}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"XGBoost.hyperparam_schema('n_estimators')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.pca.html\n"
]
}
],
"source": [
"print(PCA.documentation_url())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from lale.lib.lale import Hyperopt\n",
"import lale.schemas as schemas\n",
"\n",
"# Narrow the search ranges so that each trial stays cheap.\n",
"CustomPCA = PCA.customize_schema(n_components=schemas.Int(min=2, max=54))\n",
"CustomXGBoost = XGBoost.customize_schema(n_estimators=schemas.Int(min=1, max=10))\n",
"\n",
"# The planned pipeline leaves hyperparameters unbound for Hyperopt to search.\n",
"hpo_planned = CustomPCA >> CustomXGBoost\n",
"hpo_trainable = Hyperopt(estimator=hpo_planned, max_evals=10, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100%|███████| 10/10 [04:22<00:00, 26.22s/trial, best loss: -0.8287659271127307]\n",
"CPU times: user 4min 57s, sys: 20 s, total: 5min 17s\n",
"Wall time: 4min 53s\n"
]
}
],
"source": [
"%%time\n",
"hpo_trained = hpo_trainable.fit(train_X, train_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### --- Excursion: Types as Search Spaces ---"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy 84.2%\n"
]
}
],
"source": [
"hpo_y = hpo_trained.predict(test_X)\n",
"print(f'accuracy {sklearn.metrics.accuracy_score(test_y, hpo_y):.1%}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inspecting Automation Results"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"hpo_trained.get_pipeline().visualize()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"```python\n",
"from sklearn.decomposition import PCA as CustomPCA\n",
"from xgboost import XGBClassifier as CustomXGBoost\n",
"import lale\n",
"\n",
"lale.wrap_imported_operators()\n",
"custom_pca = CustomPCA(n_components=43, svd_solver=\"full\", whiten=True)\n",
"custom_xg_boost = CustomXGBoost(\n",
" gamma=0.42208258595069725,\n",
" learning_rate=0.6558019595096513,\n",
" max_depth=13,\n",
" min_child_weight=13,\n",
" n_estimators=9,\n",
" reg_alpha=0.3590229319214039,\n",
" reg_lambda=0.7978279409450941,\n",
" subsample=0.6209085649172931,\n",
")\n",
"pipeline = custom_pca >> custom_xg_boost\n",
"```"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"hpo_trained.get_pipeline().pretty_print(ipython_display=True)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" tid loss time log_loss status\n",
"name \n",
"p0 0 -0.754298 4.080399 1.039077 ok\n",
"p1 1 -0.774493 7.493949 0.799467 ok\n",
"p2 2 -0.725306 6.744288 0.948600 ok\n",
"p3 3 -0.783175 4.715054 1.036146 ok\n",
"p4 4 -0.759672 8.948971 0.576866 ok\n",
"p5 5 -0.823029 11.589523 0.514666 ok\n",
"p6 6 -0.783404 12.232503 0.765154 ok\n",
"p7 7 -0.828766 20.878259 0.435281 ok\n",
"p8 8 -0.724561 4.045507 0.669205 ok\n",
"p9 9 -0.731828 4.792484 1.780335 ok"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hpo_trained.summary()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"p8\n"
]
}
],
"source": [
"# idxmax returns the index label (trial name) of the highest, i.e. worst, loss.\n",
"worst_name = hpo_trained.summary().loss.idxmax()\n",
"print(worst_name)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"```python\n",
"custom_pca = CustomPCA(n_components=19, svd_solver=\"full\")\n",
"custom_xg_boost = CustomXGBoost(\n",
" gamma=0.025801085053521078,\n",
" learning_rate=0.5793622466253201,\n",
" max_depth=3,\n",
" min_child_weight=8,\n",
" n_estimators=9,\n",
" reg_alpha=0.49646670359671663,\n",
" reg_lambda=0.9280083037935846,\n",
" subsample=0.5479690370134093,\n",
")\n",
"pipeline = custom_pca >> custom_xg_boost\n",
"```"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"hpo_trained.get_pipeline(worst_name).visualize()\n",
"hpo_trained.get_pipeline(worst_name).pretty_print(ipython_display=True, show_imports=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Combined Algorithm Selection and Hyperparameter Tuning"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.preprocessing import Normalizer as Norm\n",
"from sklearn.linear_model import LogisticRegression as LR\n",
"from sklearn.tree import DecisionTreeClassifier as Tree\n",
"from sklearn.neighbors import KNeighborsClassifier as KNN\n",
"from lale.lib.lale import NoOp\n",
"lale.wrap_imported_operators()\n",
"\n",
"KNN = KNN.customize_schema(n_neighbors=schemas.Int(min=1, max=10))\n",
"transp_planned = (Norm | NoOp) >> (Tree | LR(solver='liblinear') | KNN)\n",
"transp_planned.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100%|█████████| 3/3 [01:25<00:00, 28.55s/trial, best loss: -0.8412346112501562]\n",
"CPU times: user 1min 27s, sys: 1.34 s, total: 1min 28s\n",
"Wall time: 1min 27s\n"
]
}
],
"source": [
"%%time\n",
"transp_trained = transp_planned.auto_configure(\n",
" train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### --- Excursion: Bindings as Lifecycle ---"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"```python\n",
"knn = KNN(algorithm=\"ball_tree\", metric=\"manhattan\", n_neighbors=9)\n",
"pipeline = NoOp() >> knn\n",
"```"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"transp_trained.pretty_print(ipython_display=True, show_imports=False)\n",
"transp_trained.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy 86.6%\n",
"CPU times: user 52.4 s, sys: 78.1 ms, total: 52.5 s\n",
"Wall time: 53 s\n"
]
}
],
"source": [
"%%time\n",
"transp_y = transp_trained.predict(test_X)\n",
"print(f'accuracy {sklearn.metrics.accuracy_score(test_y, transp_y):.1%}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Non-Linear Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'description': 'Features of forest covertypes dataset (classification).',\n",
" 'documentation_url': 'https://scikit-learn.org/0.20/datasets/index.html#forest-covertypes',\n",
" 'type': 'array',\n",
" 'items': {'type': 'array',\n",
" 'minItems': 54,\n",
" 'maxItems': 54,\n",
" 'items': [{'description': 'Elevation', 'type': 'integer'},\n",
" {'description': 'Aspect', 'type': 'integer'},\n",
" {'description': 'Slope', 'type': 'integer'},\n",
" {'description': 'Horizontal_Distance_To_Hydrology', 'type': 'integer'},\n",
" {'description': 'Vertical_Distance_To_Hydrology', 'type': 'integer'},\n",
" {'description': 'Horizontal_Distance_To_Roadways', 'type': 'integer'},\n",
" {'description': 'Hillshade_9am', 'type': 'integer'},\n",
" {'description': 'Hillshade_Noon', 'type': 'integer'},\n",
" {'description': 'Hillshade_3pm', 'type': 'integer'},\n",
" {'description': 'Horizontal_Distance_To_Fire_Points', 'type': 'integer'},\n",
" {'description': 'Wilderness_Area1', 'enum': [0, 1]},\n",
" {'description': 'Wilderness_Area2', 'enum': [0, 1]},\n",
" {'description': 'Wilderness_Area3', 'enum': [0, 1]},\n",
" {'description': 'Wilderness_Area4', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type1', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type2', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type3', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type4', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type5', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type6', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type7', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type8', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type9', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type10', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type11', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type12', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type13', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type14', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type15', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type16', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type17', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type18', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type19', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type20', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type21', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type22', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type23', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type24', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type25', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type26', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type27', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type28', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type29', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type30', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type31', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type32', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type33', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type34', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type35', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type36', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type37', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type38', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type39', 'enum': [0, 1]},\n",
" {'description': 'Soil_Type40', 'enum': [0, 1]}]},\n",
" 'minItems': 58102,\n",
" 'maxItems': 58102}"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_X.json_schema"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Wilderness_Area1', 'Wilderness_Area2', 'Wilderness_Area3', 'Wilderness_Area4', 'Soil_Type1', 'Soil_Type2', 'Soil_Type3', 'Soil_Type4', 'Soil_Type5', 'Soil_Type6', 'Soil_Type7', 'Soil_Type8', 'Soil_Type9', 'Soil_Type10', 'Soil_Type11', 'Soil_Type12', 'Soil_Type13', 'Soil_Type14', 'Soil_Type15', 'Soil_Type16', 'Soil_Type17', 'Soil_Type18', 'Soil_Type19', 'Soil_Type20', 'Soil_Type21', 'Soil_Type22', 'Soil_Type23', 'Soil_Type24', 'Soil_Type25', 'Soil_Type26', 'Soil_Type27', 'Soil_Type28', 'Soil_Type29', 'Soil_Type30', 'Soil_Type31', 'Soil_Type32', 'Soil_Type33', 'Soil_Type34', 'Soil_Type35', 'Soil_Type36', 'Soil_Type37', 'Soil_Type38', 'Soil_Type39', 'Soil_Type40']\n"
]
}
],
"source": [
"from lale.lib.lale import categorical\n",
"print(categorical(max_values=2)(test_X))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from lale.lib.lale import Project\n",
"from lale.lib.lale import ConcatFeatures as Concat\n",
"from sklearn.feature_selection import SelectKBest as FeatSel\n",
"lale.wrap_imported_operators()\n",
"\n",
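"# Binary (0/1) columns go through feature selection.\n",
"binary_prep = Project(columns=categorical(max_values=2)) >> FeatSel\n",
"# All remaining columns are optionally normalized.\n",
"other_prep = Project(drop_columns=categorical(max_values=2)) >> (Norm | NoOp)\n",
"# Run both branches, concatenate their outputs, and classify with KNN.\n",
"nonlin_planned = (binary_prep & other_prep) >> Concat >> KNN\n",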
"nonlin_planned.visualize()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100%|█████████| 3/3 [02:30<00:00, 50.20s/trial, best loss: -0.8618882829755578]\n",
"CPU times: user 2min 32s, sys: 344 ms, total: 2min 33s\n",
"Wall time: 2min 35s\n"
]
}
],
"source": [
"%%time\n",
"nonlin_trained = nonlin_planned.auto_configure(\n",
" train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### --- Excursion: Combinators ---\n",
"\n",
"| Lale feature | Name | Description | Scikit-learn feature |\n",
"| ----------------------- | ---- | ------------ | ----------------------------------- |\n",
"| `>>` or `make_pipeline` | pipe | feed to next | `make_pipeline` |\n",
"| `&` or `make_union` | and | run both | `make_union` or `ColumnTransformer` |\n",
"| `\\|` or `make_choice` | or | choose one | N/A (specific to given AutoML tool) |\n",
"\n",
"### --- Excursion: Interoperability ---\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"```python\n",
"project_0 = Project(columns=lale.lib.lale.categorical(max_values=2))\n",
"feat_sel = FeatSel(k=8)\n",
"pipeline_0 = make_pipeline(project_0, feat_sel)\n",
"project_1 = Project(drop_columns=lale.lib.lale.categorical(max_values=2))\n",
"pipeline_1 = make_pipeline(project_1, NoOp())\n",
"union = make_union(pipeline_0, pipeline_1)\n",
"knn = KNN(algorithm=\"kd_tree\", n_neighbors=7, weights=\"distance\")\n",
"pipeline = make_pipeline(union, knn)\n",
"```"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"nonlin_trained.visualize()\n",
"nonlin_trained.pretty_print(ipython_display=True, show_imports=False, combinators=False)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy 88.6%\n",
"CPU times: user 5.02 s, sys: 78.1 ms, total: 5.09 s\n",
"Wall time: 5.13 s\n"
]
}
],
"source": [
"%%time\n",
"nonlin_y = nonlin_trained.predict(test_X)\n",
"print(f'accuracy {sklearn.metrics.accuracy_score(test_y, nonlin_y):.1%}')"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Wilderness_Area1</th>\n",
"      <th>Wilderness_Area4</th>\n",
"      <th>Soil_Type2</th>\n",
"      <th>Soil_Type3</th>\n",
"      <th>Soil_Type4</th>\n",
"      <th>Soil_Type10</th>\n",
"      <th>Soil_Type38</th>\n",
"      <th>Soil_Type39</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"    <tr><th>1</th><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>2</th><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>3</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>4</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>5</th><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>6</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>7</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"    <tr><th>8</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>9</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Wilderness_Area1 Wilderness_Area4 Soil_Type2 Soil_Type3 Soil_Type4 \\\n",
"0 1.0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 0.0 0.0 \n",
"3 1.0 0.0 0.0 0.0 0.0 \n",
"4 1.0 0.0 0.0 0.0 0.0 \n",
"5 0.0 1.0 0.0 0.0 0.0 \n",
"6 1.0 0.0 0.0 0.0 0.0 \n",
"7 1.0 0.0 0.0 0.0 0.0 \n",
"8 1.0 0.0 0.0 0.0 0.0 \n",
"9 1.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Soil_Type10 Soil_Type38 Soil_Type39 \n",
"0 0.0 1.0 0.0 \n",
"1 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 \n",
"5 1.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 \n",
"7 0.0 1.0 0.0 \n",
"8 0.0 0.0 0.0 \n",
"9 0.0 0.0 0.0 "
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"binary_prep_trainable = Project(columns=categorical(max_values=2)) >> FeatSel(k=8)\n",
"binary_prep_trained = binary_prep_trainable.fit(train_X, train_y)\n",
"binary_prep_trained.transform(test_X.head(10))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"- code and documentation: https://github.com/ibm/lale\n",
"- more examples: https://nbviewer.jupyter.org/github/IBM/lale/tree/master/examples/\n",
"- frequently asked questions: https://github.com/IBM/lale/blob/master/docs/faq.rst\n",
"- arXiv paper: https://arxiv.org/pdf/1906.03957.pdf\n",
"\n",
""
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}