{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Lale and its Impact on the Data Science Workflow\n", "\n", "Guillaume Baudart, Martin Hirzel, Kiran Kate, Pari Ram, and Avi Shinnar\n", "\n", "27 March 2020\n", "\n", "Examples, documentation, code: https://github.com/ibm/lale" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Value Proposition\n", "\n", "- **target user**: data scientist familiar with Python and scikit-learn\n", "- **scope**: data preparation and machine learning (including some DL)\n", "- **value**: consistent API for both manual machine learning and auto-ML\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# !pip install --quiet lale" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shape train_X_all (522910, 54), test_X (58102, 54)\n" ] } ], "source": [ "import lale.datasets\n", "(train_X_all, train_y_all), (test_X, test_y) = lale.datasets.covtype_df(test_size=0.1)\n", "print(f'shape train_X_all {train_X_all.shape}, test_X {test_X.shape}')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shape train_X (52291, 54), other_X (470619, 54)\n" ] } ], "source": [ "import sklearn.model_selection\n", "# keep only a 10% sample of the training data so the demo runs fast\n", "train_X, other_X, train_y, other_y = sklearn.model_selection.train_test_split(\n", " train_X_all, train_y_all, test_size=0.9)\n", "print(f'shape train_X {train_X.shape}, other_X {other_X.shape}')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " y Elevation Aspect Slope Horizontal_Distance_To_Hydrology \\\n", "274665 3 2354.0 130.0 23.0 285.0 \n", "120210 2 2985.0 91.0 18.0 886.0 \n", "111775 2 3142.0 88.0 20.0 684.0 \n", "400567 3 2493.0 108.0 14.0 182.0 \n", "224682 2 2796.0 352.0 9.0 594.0 \n", "424723 1 3126.0 197.0 13.0 85.0 \n", "445777 1 2981.0 333.0 16.0 150.0 \n", "388163 1 3380.0 219.0 6.0 395.0 \n", "522588 7 3397.0 113.0 15.0 706.0 \n", "128441 2 2831.0 155.0 21.0 85.0 \n", "\n", " Vertical_Distance_To_Hydrology Horizontal_Distance_To_Roadways \\\n", "274665 80.0 277.0 \n", "120210 187.0 3180.0 \n", "111775 -52.0 551.0 \n", "400567 34.0 666.0 \n", "224682 84.0 2955.0 \n", "424723 10.0 5344.0 \n", "445777 14.0 2704.0 \n", "388163 88.0 2895.0 \n", "522588 240.0 1507.0 \n", "128441 27.0 4235.0 \n", "\n", " Hillshade_9am Hillshade_Noon Hillshade_3pm \\\n", "274665 250.0 220.0 86.0 \n", "120210 244.0 209.0 88.0 \n", "111775 245.0 204.0 80.0 \n", "400567 243.0 223.0 107.0 \n", "224682 205.0 225.0 158.0 \n", "424723 216.0 251.0 166.0 \n", "445777 182.0 218.0 175.0 \n", "388163 213.0 246.0 169.0 \n", "522588 245.0 223.0 103.0 \n", "128441 239.0 236.0 116.0 \n", "\n", " Horizontal_Distance_To_Fire_Points Wilderness_Area1 \\\n", "274665 874.0 0.0 \n", "120210 828.0 0.0 \n", "111775 1082.0 0.0 \n", "400567 1294.0 0.0 \n", "224682 1471.0 0.0 \n", "424723 1148.0 1.0 \n", "445777 655.0 0.0 \n", "388163 1224.0 0.0 \n", "522588 1040.0 0.0 \n", "128441 5071.0 1.0 \n", "\n", " Wilderness_Area2 Wilderness_Area3 Wilderness_Area4 Soil_Type1 \\\n", "274665 0.0 1.0 0.0 0.0 \n", "120210 0.0 1.0 0.0 0.0 \n", "111775 1.0 0.0 0.0 0.0 \n", "400567 0.0 0.0 1.0 0.0 \n", "224682 0.0 1.0 0.0 0.0 \n", "424723 0.0 0.0 0.0 0.0 \n", "445777 0.0 1.0 0.0 0.0 \n", "388163 0.0 1.0 0.0 0.0 \n", "522588 0.0 1.0 0.0 0.0 \n", "128441 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type2 Soil_Type3 Soil_Type4 Soil_Type5 Soil_Type6 \\\n", "274665 0.0 1.0 0.0 0.0 0.0 \n", "120210 0.0 0.0 0.0 0.0 0.0 \n", "111775 0.0 0.0 0.0 0.0 0.0 
\n", "400567 0.0 0.0 0.0 0.0 1.0 \n", "224682 0.0 0.0 0.0 0.0 0.0 \n", "424723 0.0 0.0 0.0 0.0 0.0 \n", "445777 0.0 0.0 0.0 0.0 0.0 \n", "388163 0.0 0.0 0.0 0.0 0.0 \n", "522588 0.0 0.0 0.0 0.0 0.0 \n", "128441 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type7 Soil_Type8 Soil_Type9 Soil_Type10 Soil_Type11 \\\n", "274665 0.0 0.0 0.0 0.0 0.0 \n", "120210 0.0 0.0 0.0 1.0 0.0 \n", "111775 0.0 0.0 0.0 0.0 0.0 \n", "400567 0.0 0.0 0.0 0.0 0.0 \n", "224682 0.0 0.0 0.0 0.0 0.0 \n", "424723 0.0 0.0 0.0 0.0 0.0 \n", "445777 0.0 0.0 0.0 0.0 0.0 \n", "388163 0.0 0.0 0.0 0.0 0.0 \n", "522588 0.0 0.0 0.0 0.0 0.0 \n", "128441 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type12 Soil_Type13 Soil_Type14 Soil_Type15 Soil_Type16 \\\n", "274665 0.0 0.0 0.0 0.0 0.0 \n", "120210 0.0 0.0 0.0 0.0 0.0 \n", "111775 0.0 0.0 0.0 0.0 0.0 \n", "400567 0.0 0.0 0.0 0.0 0.0 \n", "224682 0.0 0.0 0.0 0.0 0.0 \n", "424723 0.0 0.0 0.0 0.0 0.0 \n", "445777 0.0 0.0 0.0 0.0 0.0 \n", "388163 0.0 0.0 0.0 0.0 0.0 \n", "522588 0.0 0.0 0.0 0.0 0.0 \n", "128441 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type17 Soil_Type18 Soil_Type19 Soil_Type20 Soil_Type21 \\\n", "274665 0.0 0.0 0.0 0.0 0.0 \n", "120210 0.0 0.0 0.0 0.0 0.0 \n", "111775 0.0 0.0 1.0 0.0 0.0 \n", "400567 0.0 0.0 0.0 0.0 0.0 \n", "224682 0.0 0.0 0.0 0.0 0.0 \n", "424723 0.0 0.0 0.0 0.0 0.0 \n", "445777 0.0 0.0 0.0 0.0 0.0 \n", "388163 0.0 0.0 0.0 0.0 0.0 \n", "522588 0.0 0.0 0.0 0.0 0.0 \n", "128441 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type22 Soil_Type23 Soil_Type24 Soil_Type25 Soil_Type26 \\\n", "274665 0.0 0.0 0.0 0.0 0.0 \n", "120210 0.0 0.0 0.0 0.0 0.0 \n", "111775 0.0 0.0 0.0 0.0 0.0 \n", "400567 0.0 0.0 0.0 0.0 0.0 \n", "224682 0.0 0.0 0.0 0.0 0.0 \n", "424723 0.0 0.0 0.0 0.0 0.0 \n", "445777 0.0 0.0 0.0 0.0 0.0 \n", "388163 0.0 0.0 0.0 0.0 0.0 \n", "522588 0.0 0.0 0.0 0.0 0.0 \n", "128441 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type27 Soil_Type28 Soil_Type29 Soil_Type30 Soil_Type31 \\\n", "274665 0.0 0.0 0.0 0.0 0.0 \n", "120210 0.0 0.0 0.0 0.0 0.0 \n", 
"111775 0.0 0.0 0.0 0.0 0.0 \n", "400567 0.0 0.0 0.0 0.0 0.0 \n", "224682 0.0 0.0 0.0 0.0 0.0 \n", "424723 0.0 0.0 1.0 0.0 0.0 \n", "445777 0.0 0.0 0.0 0.0 0.0 \n", "388163 0.0 0.0 0.0 0.0 0.0 \n", "522588 0.0 0.0 0.0 0.0 0.0 \n", "128441 0.0 0.0 0.0 1.0 0.0 \n", "\n", " Soil_Type32 Soil_Type33 Soil_Type34 Soil_Type35 Soil_Type36 \\\n", "274665 0.0 0.0 0.0 0.0 0.0 \n", "120210 0.0 0.0 0.0 0.0 0.0 \n", "111775 0.0 0.0 0.0 0.0 0.0 \n", "400567 0.0 0.0 0.0 0.0 0.0 \n", "224682 1.0 0.0 0.0 0.0 0.0 \n", "424723 0.0 0.0 0.0 0.0 0.0 \n", "445777 1.0 0.0 0.0 0.0 0.0 \n", "388163 1.0 0.0 0.0 0.0 0.0 \n", "522588 0.0 0.0 0.0 0.0 0.0 \n", "128441 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type37 Soil_Type38 Soil_Type39 Soil_Type40 \n", "274665 0.0 0.0 0.0 0.0 \n", "120210 0.0 0.0 0.0 0.0 \n", "111775 0.0 0.0 0.0 0.0 \n", "400567 0.0 0.0 0.0 0.0 \n", "224682 0.0 0.0 0.0 0.0 \n", "424723 0.0 0.0 0.0 0.0 \n", "445777 0.0 0.0 0.0 0.0 \n", "388163 0.0 0.0 0.0 0.0 \n", "522588 0.0 0.0 0.0 1.0 \n", "128441 0.0 0.0 0.0 0.0 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "pd.set_option('display.max_columns', None)\n", "pd.concat([pd.DataFrame({'y': train_y}, index=train_X.index),\n", " train_X], axis=1).tail(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Manual Pipeline" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sklearn.decomposition import PCA\n", "from xgboost import XGBClassifier as XGBoost\n", "lale.wrap_imported_operators()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "[pipeline graph: PCA -> XGBoost]" ], "text/plain": [ "" ] }, "metadata": {}, 
"output_type": "display_data" } ], "source": [ "manual_trainable = PCA(n_components=6) >> XGBoost(n_estimators=3)\n", "manual_trainable.visualize()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3.56 s, sys: 672 ms, total: 4.23 s\n", "Wall time: 3.88 s\n" ] } ], "source": [ "%%time\n", "manual_trained = manual_trainable.fit(train_X, train_y)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 75.5%\n" ] } ], "source": [ "import sklearn.metrics\n", "manual_y = manual_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, manual_y):.1%}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter Tuning" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'description': 'Number of trees to fit.',\n", " 'type': 'integer',\n", " 'default': 1000,\n", " 'minimumForOptimizer': 500,\n", " 'maximumForOptimizer': 1500}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "XGBoost.hyperparam_schema('n_estimators')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.pca.html\n" ] } ], "source": [ "print(PCA.documentation_url())" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from lale.lib.lale import Hyperopt\n", "import lale.schemas as schemas\n", "\n", "CustomPCA = PCA.customize_schema(n_components=schemas.Int(min=2, max=54))\n", "CustomXGBoost = XGBoost.customize_schema(n_estimators=schemas.Int(min=1, max=10))\n", "\n", "hpo_planned = CustomPCA >> CustomXGBoost\n", "hpo_trainable = Hyperopt(estimator=hpo_planned, max_evals=10, cv=3)" ] }, { 
"cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|███████| 10/10 [04:22<00:00, 26.22s/trial, best loss: -0.8287659271127307]\n", "CPU times: user 4min 57s, sys: 20 s, total: 5min 17s\n", "Wall time: 4min 53s\n" ] } ], "source": [ "%%time\n", "hpo_trained = hpo_trainable.fit(train_X, train_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### --- Excursion: Types as Search Spaces ---\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 84.2%\n" ] } ], "source": [ "hpo_y = hpo_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, hpo_y):.1%}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inspecting Automation Results" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "[pipeline graph: CustomPCA -> CustomXGBoost]" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hpo_trained.get_pipeline().visualize()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "```python\n", "from sklearn.decomposition import PCA as CustomPCA\n", "from xgboost import XGBClassifier as CustomXGBoost\n", "import lale\n", "\n", "lale.wrap_imported_operators()\n", "custom_pca = CustomPCA(n_components=43, svd_solver=\"full\", whiten=True)\n", "custom_xg_boost = CustomXGBoost(\n", " gamma=0.42208258595069725,\n", " learning_rate=0.6558019595096513,\n", " max_depth=13,\n", " min_child_weight=13,\n", " n_estimators=9,\n", "
reg_alpha=0.3590229319214039,\n", " reg_lambda=0.7978279409450941,\n", " subsample=0.6209085649172931,\n", ")\n", "pipeline = custom_pca >> custom_xg_boost\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hpo_trained.get_pipeline().pretty_print(ipython_display=True)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " tid loss time log_loss status\n", "name \n", "p0 0 -0.754298 4.080399 1.039077 ok\n", "p1 1 -0.774493 7.493949 0.799467 ok\n", "p2 2 -0.725306 6.744288 0.948600 ok\n", "p3 3 -0.783175 4.715054 1.036146 ok\n", "p4 4 -0.759672 8.948971 0.576866 ok\n", "p5 5 -0.823029 11.589523 0.514666 ok\n", "p6 6 -0.783404 12.232503 0.765154 ok\n", "p7 7 -0.828766 20.878259 0.435281 ok\n", "p8 8 -0.724561 4.045507 0.669205 ok\n", "p9 9 -0.731828 4.792484 1.780335 ok" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hpo_trained.summary()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "p8\n" ] } ], "source": [ "# loss is the negated cross-validation score, so the largest loss marks the worst trial\n", "worst_name = hpo_trained.summary().loss.idxmax()\n", "print(worst_name)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "[pipeline graph: CustomPCA -> CustomXGBoost]" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/markdown": [ "```python\n", "custom_pca = CustomPCA(n_components=19, svd_solver=\"full\")\n", "custom_xg_boost = CustomXGBoost(\n", " gamma=0.025801085053521078,\n", " learning_rate=0.5793622466253201,\n", " max_depth=3,\n", " min_child_weight=8,\n", " n_estimators=9,\n", " reg_alpha=0.49646670359671663,\n", " reg_lambda=0.9280083037935846,\n", " subsample=0.5479690370134093,\n", ")\n", "pipeline = custom_pca >> custom_xg_boost\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } 
], "source": [ "hpo_trained.get_pipeline(worst_name).visualize()\n", "hpo_trained.get_pipeline(worst_name).pretty_print(ipython_display=True, show_imports=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combined Algorithm Selection and Hyperparameter Tuning" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "[pipeline graph: Choice(Norm, NoOp) -> Choice(Tree, LR, KNN)]" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.preprocessing import Normalizer as Norm\n", "from sklearn.linear_model import LogisticRegression as LR\n", "from sklearn.tree import DecisionTreeClassifier as Tree\n", "from sklearn.neighbors import KNeighborsClassifier as KNN\n", "from lale.lib.lale import NoOp\n", "lale.wrap_imported_operators()\n", "\n", "KNN = KNN.customize_schema(n_neighbors=schemas.Int(min=1, max=10))\n", "transp_planned = (Norm | NoOp) >> (Tree | LR(solver='liblinear') | KNN)\n", "transp_planned.visualize()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|█████████| 3/3 [01:25<00:00, 28.55s/trial, best loss: -0.8412346112501562]\n", "CPU times: user 1min 27s, sys: 1.34 s, total: 1min 28s\n", "Wall time: 1min 27s\n" ] } ], "source": [ "%%time\n", "transp_trained = transp_planned.auto_configure(\n", " train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3)" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "### --- Excursion: Bindings as Lifecycle ---\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "```python\n", "knn = KNN(algorithm=\"ball_tree\", metric=\"manhattan\", n_neighbors=9)\n", "pipeline = NoOp() >> knn\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "[pipeline graph: NoOp -> KNN]" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "transp_trained.pretty_print(ipython_display=True, show_imports=False)\n", "transp_trained.visualize()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 86.6%\n", "CPU times: user 52.4 s, sys: 78.1 ms, total: 52.5 s\n", "Wall time: 53 s\n" ] } ], "source": [ "%%time\n", "transp_y = transp_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, transp_y):.1%}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Non-Linear Pipeline" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "{'description': 'Features of forest covertypes dataset (classification).',\n", " 'documentation_url': 'https://scikit-learn.org/0.20/datasets/index.html#forest-covertypes',\n", " 'type': 'array',\n", " 'items': {'type': 'array',\n", " 'minItems': 54,\n", " 'maxItems': 54,\n", " 'items': [{'description': 'Elevation', 'type': 'integer'},\n", " {'description': 'Aspect', 'type': 'integer'},\n", " {'description': 'Slope', 'type': 'integer'},\n", " {'description': 'Horizontal_Distance_To_Hydrology', 'type': 
'integer'},\n", " {'description': 'Vertical_Distance_To_Hydrology', 'type': 'integer'},\n", " {'description': 'Horizontal_Distance_To_Roadways', 'type': 'integer'},\n", " {'description': 'Hillshade_9am', 'type': 'integer'},\n", " {'description': 'Hillshade_Noon', 'type': 'integer'},\n", " {'description': 'Hillshade_3pm', 'type': 'integer'},\n", " {'description': 'Horizontal_Distance_To_Fire_Points', 'type': 'integer'},\n", " {'description': 'Wilderness_Area1', 'enum': [0, 1]},\n", " {'description': 'Wilderness_Area2', 'enum': [0, 1]},\n", " {'description': 'Wilderness_Area3', 'enum': [0, 1]},\n", " {'description': 'Wilderness_Area4', 'enum': [0, 1]},\n", " {'description': 'Soil_Type1', 'enum': [0, 1]},\n", " {'description': 'Soil_Type2', 'enum': [0, 1]},\n", " {'description': 'Soil_Type3', 'enum': [0, 1]},\n", " {'description': 'Soil_Type4', 'enum': [0, 1]},\n", " {'description': 'Soil_Type5', 'enum': [0, 1]},\n", " {'description': 'Soil_Type6', 'enum': [0, 1]},\n", " {'description': 'Soil_Type7', 'enum': [0, 1]},\n", " {'description': 'Soil_Type8', 'enum': [0, 1]},\n", " {'description': 'Soil_Type9', 'enum': [0, 1]},\n", " {'description': 'Soil_Type10', 'enum': [0, 1]},\n", " {'description': 'Soil_Type11', 'enum': [0, 1]},\n", " {'description': 'Soil_Type12', 'enum': [0, 1]},\n", " {'description': 'Soil_Type13', 'enum': [0, 1]},\n", " {'description': 'Soil_Type14', 'enum': [0, 1]},\n", " {'description': 'Soil_Type15', 'enum': [0, 1]},\n", " {'description': 'Soil_Type16', 'enum': [0, 1]},\n", " {'description': 'Soil_Type17', 'enum': [0, 1]},\n", " {'description': 'Soil_Type18', 'enum': [0, 1]},\n", " {'description': 'Soil_Type19', 'enum': [0, 1]},\n", " {'description': 'Soil_Type20', 'enum': [0, 1]},\n", " {'description': 'Soil_Type21', 'enum': [0, 1]},\n", " {'description': 'Soil_Type22', 'enum': [0, 1]},\n", " {'description': 'Soil_Type23', 'enum': [0, 1]},\n", " {'description': 'Soil_Type24', 'enum': [0, 1]},\n", " {'description': 'Soil_Type25', 'enum': [0, 
1]},\n", " {'description': 'Soil_Type26', 'enum': [0, 1]},\n", " {'description': 'Soil_Type27', 'enum': [0, 1]},\n", " {'description': 'Soil_Type28', 'enum': [0, 1]},\n", " {'description': 'Soil_Type29', 'enum': [0, 1]},\n", " {'description': 'Soil_Type30', 'enum': [0, 1]},\n", " {'description': 'Soil_Type31', 'enum': [0, 1]},\n", " {'description': 'Soil_Type32', 'enum': [0, 1]},\n", " {'description': 'Soil_Type33', 'enum': [0, 1]},\n", " {'description': 'Soil_Type34', 'enum': [0, 1]},\n", " {'description': 'Soil_Type35', 'enum': [0, 1]},\n", " {'description': 'Soil_Type36', 'enum': [0, 1]},\n", " {'description': 'Soil_Type37', 'enum': [0, 1]},\n", " {'description': 'Soil_Type38', 'enum': [0, 1]},\n", " {'description': 'Soil_Type39', 'enum': [0, 1]},\n", " {'description': 'Soil_Type40', 'enum': [0, 1]}]},\n", " 'minItems': 58102,\n", " 'maxItems': 58102}" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_X.json_schema" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Wilderness_Area1', 'Wilderness_Area2', 'Wilderness_Area3', 'Wilderness_Area4', 'Soil_Type1', 'Soil_Type2', 'Soil_Type3', 'Soil_Type4', 'Soil_Type5', 'Soil_Type6', 'Soil_Type7', 'Soil_Type8', 'Soil_Type9', 'Soil_Type10', 'Soil_Type11', 'Soil_Type12', 'Soil_Type13', 'Soil_Type14', 'Soil_Type15', 'Soil_Type16', 'Soil_Type17', 'Soil_Type18', 'Soil_Type19', 'Soil_Type20', 'Soil_Type21', 'Soil_Type22', 'Soil_Type23', 'Soil_Type24', 'Soil_Type25', 'Soil_Type26', 'Soil_Type27', 'Soil_Type28', 'Soil_Type29', 'Soil_Type30', 'Soil_Type31', 'Soil_Type32', 'Soil_Type33', 'Soil_Type34', 'Soil_Type35', 'Soil_Type36', 'Soil_Type37', 'Soil_Type38', 'Soil_Type39', 'Soil_Type40']\n" ] } ], "source": [ "from lale.lib.lale import categorical\n", "print(categorical(max_values=2)(test_X))" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { 
"image/svg+xml": [ "[pipeline graph: Project -> FeatSel -> Concat; Project -> Choice(Norm, NoOp) -> Concat; Concat -> KNN]" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from lale.lib.lale import Project\n", "from lale.lib.lale import ConcatFeatures as Concat\n", "from sklearn.feature_selection import SelectKBest as FeatSel\n", "lale.wrap_imported_operators()\n", "\n", "binary_prep = Project(columns=categorical(max_values=2)) >> FeatSel\n", "other_prep = Project(drop_columns=categorical(max_values=2)) >> (Norm | NoOp)\n", "nonlin_planned = (binary_prep & other_prep) >> Concat >> KNN\n", "nonlin_planned.visualize()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|█████████| 3/3 [02:30<00:00, 50.20s/trial, best loss: -0.8618882829755578]\n", "CPU times: user 2min 32s, sys: 344 ms, total: 2min 33s\n", "Wall time: 2min 35s\n" ] } ], "source": [ "%%time\n", "nonlin_trained = nonlin_planned.auto_configure(\n", " train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### --- Excursion: Combinators ---\n", "\n", "| Lale feature | Name | Description | Scikit-learn feature |\n", "| ----------------------- | ---- | ------------ | ----------------------------------- |\n", "| `>>` or `make_pipeline` | pipe | feed to next | `make_pipeline` |\n", "| `&` or `make_union` | and | run both | `make_union` or `ColumnTransformer` |\n", "| `\\|` or `make_choice` | or | choose one | N/A (specific to each AutoML tool) |\n", "\n", "### --- Excursion: Interoperability ---\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "[pipeline graph: Project -> FeatSel -> Concat; Project -> NoOp -> Concat; Concat -> KNN]" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/markdown": [ "```python\n", "project_0 = Project(columns=lale.lib.lale.categorical(max_values=2))\n", "feat_sel = FeatSel(k=8)\n", "pipeline_0 = make_pipeline(project_0, feat_sel)\n", "project_1 = Project(drop_columns=lale.lib.lale.categorical(max_values=2))\n", "pipeline_1 = make_pipeline(project_1, NoOp())\n", "union = make_union(pipeline_0, pipeline_1)\n", "knn = KNN(algorithm=\"kd_tree\", n_neighbors=7, weights=\"distance\")\n", "pipeline = make_pipeline(union, knn)\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "nonlin_trained.visualize()\n", "nonlin_trained.pretty_print(ipython_display=True, show_imports=False, combinators=False)" ] }, { 
"cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 88.6%\n", "CPU times: user 5.02 s, sys: 78.1 ms, total: 5.09 s\n", "Wall time: 5.13 s\n" ] } ], "source": [ "%%time\n", "nonlin_y = nonlin_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, nonlin_y):.1%}')" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n",
 "    <tr style=\"text-align: right;\"><th></th><th>Wilderness_Area1</th><th>Wilderness_Area4</th><th>Soil_Type2</th><th>Soil_Type3</th><th>Soil_Type4</th><th>Soil_Type10</th><th>Soil_Type38</th><th>Soil_Type39</th></tr>\n",
 "  </thead>\n", "  <tbody>\n",
 "    <tr><th>0</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
 "    <tr><th>1</th><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
 "    <tr><th>2</th><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
 "    <tr><th>3</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
 "    <tr><th>4</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
 "    <tr><th>5</th><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td></tr>\n",
 "    <tr><th>6</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
 "    <tr><th>7</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
 "    <tr><th>8</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
 "    <tr><th>9</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
 "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Wilderness_Area1 Wilderness_Area4 Soil_Type2 Soil_Type3 Soil_Type4 \\\n", "0 1.0 0.0 0.0 0.0 0.0 \n", "1 0.0 0.0 0.0 0.0 0.0 \n", "2 0.0 0.0 0.0 0.0 0.0 \n", "3 1.0 0.0 0.0 0.0 0.0 \n", "4 1.0 0.0 0.0 0.0 0.0 \n", "5 0.0 1.0 0.0 0.0 0.0 \n", "6 1.0 0.0 0.0 0.0 0.0 \n", "7 1.0 0.0 0.0 0.0 0.0 \n", "8 1.0 0.0 0.0 0.0 0.0 \n", "9 1.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type10 Soil_Type38 Soil_Type39 \n", "0 0.0 1.0 0.0 \n", "1 0.0 0.0 0.0 \n", "2 0.0 0.0 0.0 \n", "3 0.0 0.0 0.0 \n", "4 0.0 0.0 0.0 \n", "5 1.0 0.0 0.0 \n", "6 0.0 0.0 0.0 \n", "7 0.0 1.0 0.0 \n", "8 0.0 0.0 0.0 \n", "9 0.0 0.0 0.0 " ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "binary_prep_trainable = Project(columns=categorical(max_values=2)) >> FeatSel(k=8)\n", "binary_prep_trained = binary_prep_trainable.fit(train_X, train_y)\n", "binary_prep_trained.transform(test_X.head(10))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "- code and documentation: https://github.com/ibm/lale\n", "- more examples: https://nbviewer.jupyter.org/github/IBM/lale/tree/master/examples/\n", "- frequently asked questions: https://github.com/IBM/lale/blob/master/docs/faq.rst\n", "- arXiv paper: https://arxiv.org/pdf/1906.03957.pdf\n", "\n", "" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }