{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Lale and its Impact on the Data Science Workflow\n", "\n", "Guillaume Baudart, Martin Hirzel, Kiran Kate, Pari Ram, and Avi Shinnar\n", "\n", "27 March 2020\n", "\n", "Examples, documentation, code: https://github.com/ibm/lale" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Value Proposition\n", "\n", "- **target user**: data scientist familiar with Python and scikit-learn\n", "- **scope**: data preparation and machine learning (including some DL)\n", "- **value**: consistent API for both manual machine learning and AutoML" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# !pip install --quiet lale" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Enable data schema validation for this notebook\n", "from lale.settings import set_disable_data_schema_validation\n", "set_disable_data_schema_validation(False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shape train_X_all (522910, 54), test_X (58102, 54)\n" ] } ], "source": [ "# Load the forest covertype dataset, holding out 10% of the rows as the test set\n", "import lale.datasets\n", "(train_X_all, train_y_all), (test_X, test_y) = lale.datasets.covtype_df(test_size=0.1)\n", "print(f'shape train_X_all {train_X_all.shape}, test_X {test_X.shape}')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shape train_X (52291, 54), other_X (470619, 54)\n" ] } ], "source": [ "# Keep 10% of the remaining rows for training; the other 90% is set aside\n", "import sklearn.model_selection\n", "train_X, other_X, train_y, other_y = sklearn.model_selection.train_test_split(\n", "    train_X_all, train_y_all, test_size=0.9)\n", "print(f'shape train_X {train_X.shape}, other_X {other_X.shape}')" ] }, { "cell_type": "code", "execution_count": 5, 
"metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead>\n",
"<tr><th></th><th>y</th><th>Elevation</th><th>Aspect</th><th>Slope</th><th>Horizontal_Distance_To_Hydrology</th><th>Vertical_Distance_To_Hydrology</th><th>Horizontal_Distance_To_Roadways</th><th>Hillshade_9am</th><th>Hillshade_Noon</th><th>Hillshade_3pm</th><th>Horizontal_Distance_To_Fire_Points</th><th>Wilderness_Area1</th><th>Wilderness_Area2</th><th>Wilderness_Area3</th><th>Wilderness_Area4</th><th>Soil_Type1</th><th>Soil_Type2</th><th>Soil_Type3</th><th>Soil_Type4</th><th>Soil_Type5</th><th>Soil_Type6</th><th>Soil_Type7</th><th>Soil_Type8</th><th>Soil_Type9</th><th>Soil_Type10</th><th>Soil_Type11</th><th>Soil_Type12</th><th>Soil_Type13</th><th>Soil_Type14</th><th>Soil_Type15</th><th>Soil_Type16</th><th>Soil_Type17</th><th>Soil_Type18</th><th>Soil_Type19</th><th>Soil_Type20</th><th>Soil_Type21</th><th>Soil_Type22</th><th>Soil_Type23</th><th>Soil_Type24</th><th>Soil_Type25</th><th>Soil_Type26</th><th>Soil_Type27</th><th>Soil_Type28</th><th>Soil_Type29</th><th>Soil_Type30</th><th>Soil_Type31</th><th>Soil_Type32</th><th>Soil_Type33</th><th>Soil_Type34</th><th>Soil_Type35</th><th>Soil_Type36</th><th>Soil_Type37</th><th>Soil_Type38</th><th>Soil_Type39</th><th>Soil_Type40</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>484665</th><td>3</td><td>2277.0</td><td>41.0</td><td>31.0</td><td>228.0</td><td>145.0</td><td>1045.0</td><td>207.0</td><td>157.0</td><td>65.0</td><td>1516.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>451137</th><td>1</td><td>3273.0</td><td>296.0</td><td>22.0</td><td>371.0</td><td>45.0</td><td>1740.0</td><td>153.0</td><td>227.0</td><td>212.0</td><td>808.0</td> <td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>239309</th><td>1</td><td>3062.0</td><td>298.0</td><td>13.0</td><td>408.0</td><td>78.0</td><td>2445.0</td><td>184.0</td><td>235.0</td><td>191.0</td><td>1041.0</td> <td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>406901</th><td>2</td><td>3195.0</td><td>42.0</td><td>19.0</td><td>376.0</td><td>72.0</td><td>3873.0</td><td>220.0</td><td>196.0</td><td>105.0</td><td>2935.0</td> <td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>379632</th><td>2</td><td>3003.0</td><td>310.0</td><td>14.0</td><td>182.0</td><td>30.0</td><td>2573.0</td><td>181.0</td><td>230.0</td><td>189.0</td><td>2408.0</td> <td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>510084</th><td>1</td><td>2898.0</td><td>47.0</td><td>10.0</td><td>30.0</td><td>-3.0</td><td>1865.0</td><td>224.0</td><td>219.0</td><td>129.0</td><td>1022.0</td> <td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>96001</th><td>2</td><td>2221.0</td><td>338.0</td><td>22.0</td><td>242.0</td><td>72.0</td><td>437.0</td><td>168.0</td><td>204.0</td><td>172.0</td><td>342.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>39684</th><td>1</td><td>3289.0</td><td>322.0</td><td>18.0</td><td>285.0</td><td>60.0</td><td>4012.0</td><td>172.0</td><td>219.0</td><td>186.0</td><td>1291.0</td> <td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"<tr><th>227535</th><td>2</td><td>2890.0</td><td>272.0</td><td>6.0</td><td>376.0</td><td>43.0</td><td>2296.0</td><td>204.0</td><td>242.0</td><td>176.0</td><td>2460.0</td> <td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>85578</th><td>1</td><td>3340.0</td><td>204.0</td><td>16.0</td><td>510.0</td><td>134.0</td><td>1851.0</td><td>210.0</td><td>253.0</td><td>174.0</td><td>1426.0</td> <td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td> <td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " y Elevation Aspect Slope Horizontal_Distance_To_Hydrology \\\n", "484665 3 2277.0 41.0 31.0 228.0 \n", "451137 1 3273.0 296.0 22.0 371.0 \n", "239309 1 3062.0 298.0 13.0 408.0 \n", "406901 2 3195.0 42.0 19.0 376.0 \n", "379632 2 3003.0 310.0 14.0 182.0 \n", "510084 1 2898.0 47.0 10.0 30.0 \n", "96001 2 2221.0 338.0 22.0 242.0 \n", "39684 1 3289.0 322.0 18.0 285.0 \n", "227535 2 2890.0 272.0 6.0 376.0 \n", "85578 1 3340.0 204.0 16.0 510.0 \n", "\n", " Vertical_Distance_To_Hydrology Horizontal_Distance_To_Roadways \\\n", "484665 145.0 1045.0 \n", "451137 45.0 1740.0 \n", "239309 78.0 2445.0 \n", "406901 72.0 3873.0 \n", "379632 30.0 2573.0 \n", "510084 -3.0 1865.0 \n", "96001 72.0 437.0 \n", "39684 60.0 4012.0 \n", "227535 43.0 2296.0 \n", "85578 134.0 1851.0 \n", "\n", " Hillshade_9am Hillshade_Noon Hillshade_3pm \\\n", "484665 207.0 157.0 65.0 \n", "451137 153.0 227.0 212.0 \n", "239309 184.0 235.0 191.0 \n", "406901 220.0 196.0 105.0 \n", "379632 181.0 230.0 189.0 \n", "510084 224.0 219.0 129.0 \n", "96001 168.0 204.0 172.0 \n", "39684 172.0 219.0 186.0 \n", "227535 204.0 242.0 176.0 \n", "85578 210.0 253.0 174.0 \n", "\n", " Horizontal_Distance_To_Fire_Points Wilderness_Area1 \\\n", "484665 1516.0 0.0 \n", "451137 808.0 0.0 \n", "239309 1041.0 1.0 \n", "406901 2935.0 1.0 \n", "379632 2408.0 0.0 \n", "510084 1022.0 1.0 \n", "96001 342.0 0.0 \n", "39684 1291.0 1.0 \n", "227535 2460.0 1.0 \n", "85578 1426.0 0.0 \n", "\n", " Wilderness_Area2 Wilderness_Area3 Wilderness_Area4 Soil_Type1 \\\n", "484665 0.0 0.0 1.0 0.0 \n", "451137 0.0 1.0 0.0 0.0 \n", "239309 0.0 0.0 0.0 0.0 \n", "406901 0.0 0.0 0.0 0.0 \n", "379632 0.0 1.0 0.0 0.0 \n", "510084 0.0 0.0 0.0 0.0 \n", "96001 0.0 0.0 1.0 0.0 \n", "39684 0.0 0.0 0.0 0.0 \n", "227535 0.0 0.0 0.0 0.0 \n", "85578 0.0 1.0 0.0 0.0 \n", "\n", " Soil_Type2 Soil_Type3 Soil_Type4 Soil_Type5 Soil_Type6 \\\n", "484665 0.0 0.0 0.0 0.0 0.0 \n", "451137 0.0 0.0 0.0 0.0 0.0 \n", "239309 0.0 0.0 0.0 0.0 0.0 \n", 
"406901 0.0 0.0 0.0 0.0 0.0 \n", "379632 0.0 0.0 0.0 0.0 0.0 \n", "510084 0.0 0.0 0.0 0.0 0.0 \n", "96001 0.0 0.0 0.0 0.0 0.0 \n", "39684 0.0 0.0 0.0 0.0 0.0 \n", "227535 0.0 0.0 0.0 0.0 0.0 \n", "85578 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type7 Soil_Type8 Soil_Type9 Soil_Type10 Soil_Type11 \\\n", "484665 0.0 0.0 0.0 1.0 0.0 \n", "451137 0.0 0.0 0.0 0.0 0.0 \n", "239309 0.0 0.0 0.0 0.0 0.0 \n", "406901 0.0 0.0 0.0 0.0 0.0 \n", "379632 0.0 0.0 0.0 0.0 0.0 \n", "510084 0.0 0.0 0.0 0.0 0.0 \n", "96001 0.0 0.0 0.0 1.0 0.0 \n", "39684 0.0 0.0 0.0 0.0 0.0 \n", "227535 0.0 0.0 0.0 0.0 0.0 \n", "85578 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type12 Soil_Type13 Soil_Type14 Soil_Type15 Soil_Type16 \\\n", "484665 0.0 0.0 0.0 0.0 0.0 \n", "451137 0.0 0.0 0.0 0.0 0.0 \n", "239309 0.0 0.0 0.0 0.0 0.0 \n", "406901 0.0 0.0 0.0 0.0 0.0 \n", "379632 0.0 0.0 0.0 0.0 0.0 \n", "510084 0.0 0.0 0.0 0.0 0.0 \n", "96001 0.0 0.0 0.0 0.0 0.0 \n", "39684 0.0 0.0 0.0 0.0 0.0 \n", "227535 0.0 0.0 0.0 0.0 0.0 \n", "85578 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type17 Soil_Type18 Soil_Type19 Soil_Type20 Soil_Type21 \\\n", "484665 0.0 0.0 0.0 0.0 0.0 \n", "451137 0.0 0.0 0.0 0.0 0.0 \n", "239309 0.0 0.0 0.0 0.0 0.0 \n", "406901 0.0 0.0 0.0 0.0 0.0 \n", "379632 0.0 0.0 0.0 0.0 0.0 \n", "510084 0.0 0.0 0.0 1.0 0.0 \n", "96001 0.0 0.0 0.0 0.0 0.0 \n", "39684 0.0 0.0 0.0 0.0 0.0 \n", "227535 0.0 0.0 0.0 0.0 0.0 \n", "85578 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type22 Soil_Type23 Soil_Type24 Soil_Type25 Soil_Type26 \\\n", "484665 0.0 0.0 0.0 0.0 0.0 \n", "451137 0.0 0.0 0.0 0.0 0.0 \n", "239309 0.0 0.0 0.0 0.0 0.0 \n", "406901 0.0 0.0 0.0 0.0 0.0 \n", "379632 0.0 0.0 0.0 0.0 0.0 \n", "510084 0.0 0.0 0.0 0.0 0.0 \n", "96001 0.0 0.0 0.0 0.0 0.0 \n", "39684 0.0 0.0 0.0 0.0 0.0 \n", "227535 0.0 0.0 0.0 0.0 0.0 \n", "85578 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type27 Soil_Type28 Soil_Type29 Soil_Type30 Soil_Type31 \\\n", "484665 0.0 0.0 0.0 0.0 0.0 \n", "451137 0.0 0.0 0.0 0.0 1.0 \n", "239309 0.0 0.0 1.0 0.0 
0.0 \n", "406901 0.0 0.0 0.0 1.0 0.0 \n", "379632 0.0 0.0 0.0 0.0 0.0 \n", "510084 0.0 0.0 0.0 0.0 0.0 \n", "96001 0.0 0.0 0.0 0.0 0.0 \n", "39684 0.0 0.0 0.0 0.0 0.0 \n", "227535 0.0 0.0 1.0 0.0 0.0 \n", "85578 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type32 Soil_Type33 Soil_Type34 Soil_Type35 Soil_Type36 \\\n", "484665 0.0 0.0 0.0 0.0 0.0 \n", "451137 0.0 0.0 0.0 0.0 0.0 \n", "239309 0.0 0.0 0.0 0.0 0.0 \n", "406901 0.0 0.0 0.0 0.0 0.0 \n", "379632 0.0 1.0 0.0 0.0 0.0 \n", "510084 0.0 0.0 0.0 0.0 0.0 \n", "96001 0.0 0.0 0.0 0.0 0.0 \n", "39684 0.0 0.0 0.0 0.0 0.0 \n", "227535 0.0 0.0 0.0 0.0 0.0 \n", "85578 1.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type37 Soil_Type38 Soil_Type39 Soil_Type40 \n", "484665 0.0 0.0 0.0 0.0 \n", "451137 0.0 0.0 0.0 0.0 \n", "239309 0.0 0.0 0.0 0.0 \n", "406901 0.0 0.0 0.0 0.0 \n", "379632 0.0 0.0 0.0 0.0 \n", "510084 0.0 0.0 0.0 0.0 \n", "96001 0.0 0.0 0.0 0.0 \n", "39684 0.0 0.0 1.0 0.0 \n", "227535 0.0 0.0 0.0 0.0 \n", "85578 0.0 0.0 0.0 0.0 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "pd.set_option('display.max_columns', None)\n", "pd.concat([pd.DataFrame({'y': train_y}, index=train_X.index),\n", " train_X], axis=1).tail(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Manual Pipeline" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from sklearn.decomposition import PCA\n", "from xgboost import XGBClassifier as XGBoost\n", "lale.wrap_imported_operators()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "\n", "pca\n", "\n", "\n", "PCA\n", "\n", "\n", "\n", "\n", "xg_boost\n", "\n", "\n", "XG-\n", "Boost\n", "\n", "\n", "\n", "\n", "pca->xg_boost\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], 
"source": [ "manual_trainable = PCA(n_components=6) >> XGBoost(n_estimators=3)\n", "manual_trainable.visualize()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.39 s, sys: 953 ms, total: 3.34 s\n", "Wall time: 2.05 s\n" ] } ], "source": [ "%%time\n", "manual_trained = manual_trainable.fit(train_X, train_y)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 67.1%\n" ] } ], "source": [ "import sklearn.metrics\n", "manual_y = manual_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, manual_y):.1%}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter Tuning" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'description': 'Number of trees to fit.',\n", " 'type': 'integer',\n", " 'default': 100,\n", " 'minimumForOptimizer': 50,\n", " 'maximumForOptimizer': 1000}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "XGBoost.hyperparam_schema('n_estimators')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.pca.html\n" ] } ], "source": [ "print(PCA.documentation_url())" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from lale.lib.lale import Hyperopt\n", "import lale.schemas as schemas\n", "\n", "CustomPCA = PCA.customize_schema(n_components=schemas.Int(minimum=2, maximum=54))\n", "CustomXGBoost = XGBoost.customize_schema(n_estimators=schemas.Int(minimum=1, maximum=10))\n", "\n", "hpo_planned = CustomPCA >> CustomXGBoost\n", "hpo_trainable = Hyperopt(estimator=hpo_planned, max_evals=10, cv=3)" ] }, { "cell_type": 
"code", "execution_count": 13, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|███████| 10/10 [02:15<00:00, 13.53s/trial, best loss: -0.7727907776451675]\n", "CPU times: user 2min 53s, sys: 19.3 s, total: 3min 13s\n", "Wall time: 2min 30s\n" ] } ], "source": [ "%%time\n", "hpo_trained = hpo_trainable.fit(train_X, train_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### --- Excursions: Types as Search Spaces ---\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 77.7%\n" ] } ], "source": [ "hpo_y = hpo_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, hpo_y):.1%}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inspecting Automation Results" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "\n", "custom_pca\n", "\n", "\n", "Custom-\n", "PCA\n", "\n", "\n", "\n", "\n", "custom_xg_boost\n", "\n", "\n", "Custom-\n", "XG-\n", "Boost\n", "\n", "\n", "\n", "\n", "custom_pca->custom_xg_boost\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hpo_trained.get_pipeline().visualize()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "```python\n", "from sklearn.decomposition import PCA as CustomPCA\n", "from xgboost import XGBClassifier as CustomXGBoost\n", "import lale\n", "\n", "lale.wrap_imported_operators()\n", "custom_pca = CustomPCA.customize_schema(\n", " n_components={\"type\": \"integer\", \"minimum\": 2, \"maximum\": 54}\n", ")(n_components=43, svd_solver=\"full\", whiten=True)\n", "custom_xg_boost = CustomXGBoost.customize_schema(\n", " 
n_estimators={\"type\": \"integer\", \"minimum\": 1, \"maximum\": 10}\n", ")(\n", " gamma=0.42208258595069725,\n", " learning_rate=0.6558019595096513,\n", " max_depth=5,\n", " min_child_weight=13,\n", " n_estimators=9,\n", " reg_alpha=0.3590229319214039,\n", " reg_lambda=0.7978279409450941,\n", " subsample=0.6209085649172931,\n", ")\n", "pipeline = custom_pca >> custom_xg_boost\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hpo_trained.get_pipeline().pretty_print(ipython_display=True, customize_schema=True)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"<thead>\n",
"<tr><th></th><th>tid</th><th>loss</th><th>time</th><th>log_loss</th><th>status</th></tr>\n",
"<tr><th>name</th><th></th><th></th><th></th><th></th><th></th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>p0</th><td>0</td><td>-0.684229</td><td>2.293911</td><td>1.161776</td><td>ok</td></tr>\n",
"<tr><th>p1</th><td>1</td><td>-0.708057</td><td>3.347494</td><td>0.950058</td><td>ok</td></tr>\n",
"<tr><th>p2</th><td>2</td><td>-0.631983</td><td>3.356443</td><td>1.123108</td><td>ok</td></tr>\n",
"<tr><th>p3</th><td>3</td><td>-0.699050</td><td>2.606100</td><td>1.168528</td><td>ok</td></tr>\n",
"<tr><th>p4</th><td>4</td><td>-0.717428</td><td>5.158346</td><td>0.690650</td><td>ok</td></tr>\n",
"<tr><th>p5</th><td>5</td><td>-0.759653</td><td>7.138689</td><td>0.655658</td><td>ok</td></tr>\n",
"<tr><th>p6</th><td>6</td><td>-0.707598</td><td>3.555126</td><td>0.942210</td><td>ok</td></tr>\n",
"<tr><th>p7</th><td>7</td><td>-0.772791</td><td>10.981915</td><td>0.555780</td><td>ok</td></tr>\n",
"<tr><th>p8</th><td>8</td><td>-0.653057</td><td>2.016587</td><td>0.845659</td><td>ok</td></tr>\n",
"<tr><th>p9</th><td>9</td><td>-0.620853</td><td>2.155818</td><td>1.817853</td><td>ok</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " tid loss time log_loss status\n", "name \n", "p0 0 -0.684229 2.293911 1.161776 ok\n", "p1 1 -0.708057 3.347494 0.950058 ok\n", "p2 2 -0.631983 3.356443 1.123108 ok\n", "p3 3 -0.699050 2.606100 1.168528 ok\n", "p4 4 -0.717428 5.158346 0.690650 ok\n", "p5 5 -0.759653 7.138689 0.655658 ok\n", "p6 6 -0.707598 3.555126 0.942210 ok\n", "p7 7 -0.772791 10.981915 0.555780 ok\n", "p8 8 -0.653057 2.016587 0.845659 ok\n", "p9 9 -0.620853 2.155818 1.817853 ok" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hpo_trained.summary()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "p9\n" ] } ], "source": [ "worst_name = hpo_trained.summary().loss.argmax()\n", "if not isinstance(worst_name, str): #newer pandas argmax returns index\n", " worst_name = hpo_trained.summary().index[worst_name]\n", "print(worst_name)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "\n", "custom_pca\n", "\n", "\n", "Custom-\n", "PCA\n", "\n", "\n", "\n", "\n", "custom_xg_boost\n", "\n", "\n", "Custom-\n", "XG-\n", "Boost\n", "\n", "\n", "\n", "\n", "custom_pca->custom_xg_boost\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/markdown": [ "```python\n", "custom_pca = CustomPCA.customize_schema(\n", " n_components={\"type\": \"integer\", \"minimum\": 2, \"maximum\": 54}\n", ")(n_components=20, svd_solver=\"full\", whiten=True)\n", "custom_xg_boost = CustomXGBoost.customize_schema(\n", " n_estimators={\"type\": \"integer\", \"minimum\": 1, \"maximum\": 10}\n", ")(\n", " gamma=0.37068548766270437,\n", " learning_rate=0.02005982973762002,\n", " max_depth=2,\n", " min_child_weight=9,\n", " n_estimators=5,\n", " 
reg_alpha=0.8716519284632148,\n", " reg_lambda=0.7305593001592293,\n", " subsample=0.9559232064468288,\n", ")\n", "pipeline = custom_pca >> custom_xg_boost\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hpo_trained.get_pipeline(worst_name).visualize()\n", "hpo_trained.get_pipeline(worst_name).pretty_print(\n", " ipython_display=True, show_imports=False, customize_schema=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combined Algorithm Selection and Hyperparameter Tuning" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "cluster:choice_0\n", "\n", "\n", "Choice\n", "\n", "\n", "\n", "cluster:choice_1\n", "\n", "\n", "Choice\n", "\n", "\n", "\n", "\n", "norm\n", "\n", "\n", "Norm\n", "\n", "\n", "\n", "\n", "tree\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "\n", "norm->tree\n", "\n", "\n", "\n", "\n", "no_op\n", "\n", "\n", "No-\n", "Op\n", "\n", "\n", "\n", "\n", "lr\n", "\n", "\n", "LR\n", "\n", "\n", "\n", "\n", "knn\n", "\n", "\n", "KNN\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.preprocessing import Normalizer as Norm\n", "from sklearn.linear_model import LogisticRegression as LR\n", "from sklearn.tree import DecisionTreeClassifier as Tree\n", "from sklearn.neighbors import KNeighborsClassifier as KNN\n", "from lale.lib.lale import NoOp\n", "lale.wrap_imported_operators()\n", "\n", "KNN = KNN.customize_schema(n_neighbors=schemas.Int(minimum=1, maximum=10))\n", "transp_planned = (Norm | NoOp) >> (Tree | LR(solver='liblinear') | KNN)\n", "transp_planned.visualize()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|█████████| 3/3 [01:24<00:00, 28.25s/trial, best loss: 
-0.8390927596840342]\n", "CPU times: user 1min 26s, sys: 953 ms, total: 1min 27s\n", "Wall time: 1min 26s\n" ] } ], "source": [ "%%time\n", "transp_trained = transp_planned.auto_configure(\n", " train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### --- Excursion: Bindings as Lifecycle ---\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "```python\n", "knn = KNN.customize_schema(\n", " n_neighbors={\"type\": \"integer\", \"minimum\": 1, \"maximum\": 10}\n", ")(algorithm=\"ball_tree\", metric=\"manhattan\", n_neighbors=9)\n", "pipeline = NoOp() >> knn\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "\n", "no_op\n", "\n", "\n", "No-\n", "Op\n", "\n", "\n", "\n", "\n", "knn\n", "\n", "\n", "KNN\n", "\n", "\n", "\n", "\n", "no_op->knn\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "transp_trained.pretty_print(\n", " ipython_display=True, show_imports=False, customize_schema=True)\n", "transp_trained.visualize()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 86.6%\n", "CPU times: user 51.5 s, sys: 31.2 ms, total: 51.5 s\n", "Wall time: 52 s\n" ] } ], "source": [ "%%time\n", "transp_y = transp_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, transp_y):.1%}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Non-Linear Pipeline" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "{'description': 'Features of forest covertypes dataset (classification).',\n", " 'documentation_url': 
'https://scikit-learn.org/0.20/datasets/index.html#forest-covertypes',\n", " 'type': 'array',\n", " 'items': {'type': 'array',\n", " 'minItems': 54,\n", " 'maxItems': 54,\n", " 'items': [{'description': 'Elevation', 'type': 'integer'},\n", " {'description': 'Aspect', 'type': 'integer'},\n", " {'description': 'Slope', 'type': 'integer'},\n", " {'description': 'Horizontal_Distance_To_Hydrology', 'type': 'integer'},\n", " {'description': 'Vertical_Distance_To_Hydrology', 'type': 'integer'},\n", " {'description': 'Horizontal_Distance_To_Roadways', 'type': 'integer'},\n", " {'description': 'Hillshade_9am', 'type': 'integer'},\n", " {'description': 'Hillshade_Noon', 'type': 'integer'},\n", " {'description': 'Hillshade_3pm', 'type': 'integer'},\n", " {'description': 'Horizontal_Distance_To_Fire_Points', 'type': 'integer'},\n", " {'description': 'Wilderness_Area1', 'enum': [0, 1]},\n", " {'description': 'Wilderness_Area2', 'enum': [0, 1]},\n", " {'description': 'Wilderness_Area3', 'enum': [0, 1]},\n", " {'description': 'Wilderness_Area4', 'enum': [0, 1]},\n", " {'description': 'Soil_Type1', 'enum': [0, 1]},\n", " {'description': 'Soil_Type2', 'enum': [0, 1]},\n", " {'description': 'Soil_Type3', 'enum': [0, 1]},\n", " {'description': 'Soil_Type4', 'enum': [0, 1]},\n", " {'description': 'Soil_Type5', 'enum': [0, 1]},\n", " {'description': 'Soil_Type6', 'enum': [0, 1]},\n", " {'description': 'Soil_Type7', 'enum': [0, 1]},\n", " {'description': 'Soil_Type8', 'enum': [0, 1]},\n", " {'description': 'Soil_Type9', 'enum': [0, 1]},\n", " {'description': 'Soil_Type10', 'enum': [0, 1]},\n", " {'description': 'Soil_Type11', 'enum': [0, 1]},\n", " {'description': 'Soil_Type12', 'enum': [0, 1]},\n", " {'description': 'Soil_Type13', 'enum': [0, 1]},\n", " {'description': 'Soil_Type14', 'enum': [0, 1]},\n", " {'description': 'Soil_Type15', 'enum': [0, 1]},\n", " {'description': 'Soil_Type16', 'enum': [0, 1]},\n", " {'description': 'Soil_Type17', 'enum': [0, 1]},\n", " {'description': 
'Soil_Type18', 'enum': [0, 1]},\n", " {'description': 'Soil_Type19', 'enum': [0, 1]},\n", " {'description': 'Soil_Type20', 'enum': [0, 1]},\n", " {'description': 'Soil_Type21', 'enum': [0, 1]},\n", " {'description': 'Soil_Type22', 'enum': [0, 1]},\n", " {'description': 'Soil_Type23', 'enum': [0, 1]},\n", " {'description': 'Soil_Type24', 'enum': [0, 1]},\n", " {'description': 'Soil_Type25', 'enum': [0, 1]},\n", " {'description': 'Soil_Type26', 'enum': [0, 1]},\n", " {'description': 'Soil_Type27', 'enum': [0, 1]},\n", " {'description': 'Soil_Type28', 'enum': [0, 1]},\n", " {'description': 'Soil_Type29', 'enum': [0, 1]},\n", " {'description': 'Soil_Type30', 'enum': [0, 1]},\n", " {'description': 'Soil_Type31', 'enum': [0, 1]},\n", " {'description': 'Soil_Type32', 'enum': [0, 1]},\n", " {'description': 'Soil_Type33', 'enum': [0, 1]},\n", " {'description': 'Soil_Type34', 'enum': [0, 1]},\n", " {'description': 'Soil_Type35', 'enum': [0, 1]},\n", " {'description': 'Soil_Type36', 'enum': [0, 1]},\n", " {'description': 'Soil_Type37', 'enum': [0, 1]},\n", " {'description': 'Soil_Type38', 'enum': [0, 1]},\n", " {'description': 'Soil_Type39', 'enum': [0, 1]},\n", " {'description': 'Soil_Type40', 'enum': [0, 1]}]},\n", " 'minItems': 58102,\n", " 'maxItems': 58102}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_X.json_schema" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Wilderness_Area1', 'Wilderness_Area2', 'Wilderness_Area3', 'Wilderness_Area4', 'Soil_Type1', 'Soil_Type2', 'Soil_Type3', 'Soil_Type4', 'Soil_Type5', 'Soil_Type6', 'Soil_Type7', 'Soil_Type8', 'Soil_Type9', 'Soil_Type10', 'Soil_Type11', 'Soil_Type12', 'Soil_Type13', 'Soil_Type14', 'Soil_Type15', 'Soil_Type16', 'Soil_Type17', 'Soil_Type18', 'Soil_Type19', 'Soil_Type20', 'Soil_Type21', 'Soil_Type22', 'Soil_Type23', 'Soil_Type24', 'Soil_Type25', 'Soil_Type26', 'Soil_Type27', 
'Soil_Type28', 'Soil_Type29', 'Soil_Type30', 'Soil_Type31', 'Soil_Type32', 'Soil_Type33', 'Soil_Type34', 'Soil_Type35', 'Soil_Type36', 'Soil_Type37', 'Soil_Type38', 'Soil_Type39', 'Soil_Type40']\n" ] } ], "source": [ "from lale.lib.lale import categorical\n", "print(categorical(max_values=2)(test_X))" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "cluster:choice\n", "\n", "\n", "Choice\n", "\n", "\n", "\n", "\n", "project_0\n", "\n", "\n", "Project\n", "\n", "\n", "\n", "\n", "feat_sel\n", "\n", "\n", "Feat-\n", "Sel\n", "\n", "\n", "\n", "\n", "project_0->feat_sel\n", "\n", "\n", "\n", "\n", "concat\n", "\n", "\n", "Concat\n", "\n", "\n", "\n", "\n", "feat_sel->concat\n", "\n", "\n", "\n", "\n", "project_1\n", "\n", "\n", "Project\n", "\n", "\n", "\n", "\n", "norm\n", "\n", "\n", "Norm\n", "\n", "\n", "\n", "\n", "project_1->norm\n", "\n", "\n", "\n", "\n", "norm->concat\n", "\n", "\n", "\n", "\n", "no_op\n", "\n", "\n", "No-\n", "Op\n", "\n", "\n", "\n", "\n", "knn\n", "\n", "\n", "KNN\n", "\n", "\n", "\n", "\n", "concat->knn\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from lale.lib.lale import Project\n", "from lale.lib.lale import ConcatFeatures as Concat\n", "from sklearn.feature_selection import SelectKBest as FeatSel\n", "lale.wrap_imported_operators()\n", "\n", "binary_prep = Project(columns=categorical(max_values=2)) >> FeatSel\n", "other_prep = Project(drop_columns=categorical(max_values=2)) >> (Norm | NoOp)\n", "nonlin_planned = (binary_prep & other_prep) >> Concat >> KNN\n", "nonlin_planned.visualize()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|█████████| 3/3 [02:17<00:00, 45.88s/trial, best loss: 
-0.8620412868709595]\n", "CPU times: user 2min 18s, sys: 359 ms, total: 2min 19s\n", "Wall time: 2min 21s\n" ] } ], "source": [ "%%time\n", "nonlin_trained = nonlin_planned.auto_configure(\n", " train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3, verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### --- Excursion: Combinators ---\n", "\n", "| Lale feature | Name | Description | Scikit-learn feature |\n", "| ----------------------- | ---- | ------------ | ----------------------------------- |\n", "| >> or `make_pipeline` | pipe | feed to next | `make_pipeline` |\n", "| & or `make_union` | and | run both | `make_union` or `ColumnTransformer` |\n", "| &#124; or `make_choice` | or | choose one | N/A (specific to given AutoML tool) |\n", "\n", "### --- Excursion: Interoperability ---\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "[Figure: trained pipeline graph. project_0 (Project) -> feat_sel (FeatSel) -> concat (Concat); project_1 (Project) -> no_op (NoOp) -> concat (Concat); concat -> knn (KNN).]\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/markdown": [ "```python\n", "project_0 = Project(columns=lale.lib.lale.categorical(max_values=2))\n", "feat_sel = FeatSel(k=8)\n", "pipeline_0 = make_pipeline(project_0, feat_sel)\n", "project_1 = Project(drop_columns=lale.lib.lale.categorical(max_values=2))\n", 
"pipeline_1 = make_pipeline(project_1, NoOp())\n", "union = make_union(pipeline_0, pipeline_1)\n", "knn = KNN(algorithm=\"kd_tree\", n_neighbors=7, weights=\"distance\")\n", "pipeline = make_pipeline(union, knn)\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "nonlin_trained.visualize()\n", "nonlin_trained.pretty_print(ipython_display=True, show_imports=False, combinators=False)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 88.6%\n", "CPU times: user 4.31 s, sys: 46.9 ms, total: 4.36 s\n", "Wall time: 4.44 s\n" ] } ], "source": [ "%%time\n", "nonlin_y = nonlin_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, nonlin_y):.1%}')" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"<thead>\n",
"<tr><th></th><th>Wilderness_Area1</th><th>Wilderness_Area4</th><th>Soil_Type2</th><th>Soil_Type3</th><th>Soil_Type4</th><th>Soil_Type10</th><th>Soil_Type38</th><th>Soil_Type39</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"<tr><th>1</th><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>2</th><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>3</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>4</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>5</th><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>6</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>7</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"<tr><th>8</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"<tr><th>9</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " Wilderness_Area1 Wilderness_Area4 Soil_Type2 Soil_Type3 Soil_Type4 \\\n", "0 1.0 0.0 0.0 0.0 0.0 \n", "1 0.0 0.0 0.0 0.0 0.0 \n", "2 0.0 0.0 0.0 0.0 0.0 \n", "3 1.0 0.0 0.0 0.0 0.0 \n", "4 1.0 0.0 0.0 0.0 0.0 \n", "5 0.0 1.0 0.0 0.0 0.0 \n", "6 1.0 0.0 0.0 0.0 0.0 \n", "7 1.0 0.0 0.0 0.0 0.0 \n", "8 1.0 0.0 0.0 0.0 0.0 \n", "9 1.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type10 Soil_Type38 Soil_Type39 \n", "0 0.0 1.0 0.0 \n", "1 0.0 0.0 0.0 \n", "2 0.0 0.0 0.0 \n", "3 0.0 0.0 0.0 \n", "4 0.0 0.0 0.0 \n", "5 1.0 0.0 0.0 \n", "6 0.0 0.0 0.0 \n", "7 0.0 1.0 0.0 \n", "8 0.0 0.0 0.0 \n", "9 0.0 0.0 0.0 " ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "binary_prep_trainable = Project(columns=categorical(max_values=2)) >> FeatSel(k=8)\n", "binary_prep_trained = binary_prep_trainable.fit(train_X, train_y)\n", "binary_prep_trained.transform(test_X.head(10))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "- code and documentation: https://github.com/ibm/lale\n", "- more examples: https://github.com/IBM/lale/tree/master/examples/\n", "- frequently asked questions: https://github.com/IBM/lale/blob/master/docs/faq.rst\n", "- arXiv paper: https://arxiv.org/pdf/1906.03957.pdf\n", "\n", "" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }