{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Lale and its Impact on the Data Science Workflow\n", "\n", "Guillaume Baudart, Martin Hirzel, Kiran Kate, Pari Ram, and Avi Shinnar\n", "\n", "27 March 2020\n", "\n", "Examples, documentation, code: https://github.com/ibm/lale\n", "\n", "\"logo\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Value Proposition\n", "\n", "- **target user**: data scientist familiar with Python and scikit-learn\n", "- **scope**: data preparation and machine learning (including some DL)\n", "- **value**: consistent API for both manual machine learning and auto-ML\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: lale in /home/hirzel/python3.6venv/lib/python3.6/site-packages (0.3.5)\n", "Requirement already satisfied: lightgbm in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (2.2.3)\n", "Requirement already satisfied: astunparse in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (1.6.2)\n", "Requirement already satisfied: hyperopt==0.2.3 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (0.2.3)\n", "Requirement already satisfied: pandas<=0.25.3 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (0.25.0)\n", "Requirement already satisfied: xgboost in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (0.90)\n", "Requirement already satisfied: jsonsubschema in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (0.0.0)\n", "Requirement already satisfied: jsonschema in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (3.2.0)\n", "Requirement already satisfied: h5py in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (2.9.0)\n", "Requirement already satisfied: scikit-learn==0.20.3 in 
/home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (0.20.3)\n", "Requirement already satisfied: scipy in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (1.3.0)\n", "Requirement already satisfied: numpy in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (1.17.0)\n", "Requirement already satisfied: graphviz in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (0.11.1)\n", "Requirement already satisfied: decorator in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from lale) (4.4.0)\n", "Requirement already satisfied: six<2.0,>=1.6.1 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from astunparse->lale) (1.12.0)\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from astunparse->lale) (0.33.4)\n", "Requirement already satisfied: future in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from hyperopt==0.2.3->lale) (0.17.1)\n", "Requirement already satisfied: cloudpickle in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from hyperopt==0.2.3->lale) (1.3.0)\n", "Requirement already satisfied: tqdm in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from hyperopt==0.2.3->lale) (4.32.2)\n", "Requirement already satisfied: networkx==2.2 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from hyperopt==0.2.3->lale) (2.2)\n", "Requirement already satisfied: pytz>=2017.2 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from pandas<=0.25.3->lale) (2019.1)\n", "Requirement already satisfied: python-dateutil>=2.6.1 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from pandas<=0.25.3->lale) (2.8.0)\n", "Requirement already satisfied: python-intervals in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from jsonsubschema->lale) (1.8.0)\n", "Requirement already satisfied: greenery in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from jsonsubschema->lale) (3.1)\n", 
"Requirement already satisfied: pyrsistent>=0.14.0 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from jsonschema->lale) (0.15.7)\n", "Requirement already satisfied: setuptools in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from jsonschema->lale) (41.0.1)\n", "Requirement already satisfied: attrs>=17.4.0 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from jsonschema->lale) (19.1.0)\n", "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from jsonschema->lale) (1.3.0)\n", "Requirement already satisfied: zipp>=0.5 in /home/hirzel/python3.6venv/lib/python3.6/site-packages (from importlib-metadata; python_version < \"3.8\"->jsonschema->lale) (0.5.2)\n" ] } ], "source": [ "!pip install lale" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shape train_X_all (522910, 54), test_X (58102, 54)\n" ] } ], "source": [ "import lale.datasets\n", "(train_X_all, train_y_all), (test_X, test_y) = lale.datasets.covtype_df(test_size=0.1)\n", "print(f'shape train_X_all {train_X_all.shape}, test_X {test_X.shape}')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shape train_X (52291, 54), other_X (470619, 54)\n" ] } ], "source": [ "import sklearn.model_selection\n", "train_X, other_X, train_y, other_y = sklearn.model_selection.train_test_split(\n", " train_X_all, train_y_all, test_size=0.9)\n", "print(f'shape train_X {train_X.shape}, other_X {other_X.shape}')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
<p>Last 10 rows of the training data: label y followed by the 54 feature columns. The text/plain output of this cell lists the values.</p>
" ], "text/plain": [ " y Elevation Aspect Slope Horizontal_Distance_To_Hydrology \\\n", "325384 2 3064.0 86.0 25.0 702.0 \n", "442177 1 3277.0 31.0 15.0 454.0 \n", "185316 2 3138.0 257.0 14.0 228.0 \n", "189541 3 2317.0 150.0 8.0 150.0 \n", "428374 2 2970.0 47.0 25.0 319.0 \n", "234638 1 3278.0 335.0 5.0 360.0 \n", "172207 1 3175.0 343.0 17.0 162.0 \n", "240801 1 3355.0 346.0 16.0 180.0 \n", "435277 1 3154.0 316.0 26.0 339.0 \n", "297100 7 3344.0 313.0 20.0 0.0 \n", "\n", " Vertical_Distance_To_Hydrology Horizontal_Distance_To_Roadways \\\n", "325384 259.0 721.0 \n", "442177 70.0 1570.0 \n", "185316 30.0 5649.0 \n", "189541 42.0 644.0 \n", "428374 100.0 1919.0 \n", "234638 35.0 5763.0 \n", "172207 3.0 4395.0 \n", "240801 6.0 1922.0 \n", "435277 122.0 2688.0 \n", "297100 0.0 4317.0 \n", "\n", " Hillshade_9am Hillshade_Noon Hillshade_3pm \\\n", "325384 247.0 189.0 56.0 \n", "442177 215.0 206.0 124.0 \n", "185316 185.0 248.0 200.0 \n", "189541 231.0 240.0 141.0 \n", "428374 220.0 178.0 80.0 \n", "234638 209.0 233.0 163.0 \n", "172207 183.0 212.0 166.0 \n", "240801 188.0 213.0 163.0 \n", "435277 143.0 209.0 201.0 \n", "297100 163.0 221.0 196.0 \n", "\n", " Horizontal_Distance_To_Fire_Points Wilderness_Area1 \\\n", "325384 1714.0 1.0 \n", "442177 2754.0 0.0 \n", "185316 3051.0 1.0 \n", "189541 781.0 0.0 \n", "428374 3060.0 0.0 \n", "234638 646.0 1.0 \n", "172207 2965.0 1.0 \n", "240801 4906.0 0.0 \n", "435277 2720.0 1.0 \n", "297100 4092.0 1.0 \n", "\n", " Wilderness_Area2 Wilderness_Area3 Wilderness_Area4 Soil_Type1 \\\n", "325384 0.0 0.0 0.0 0.0 \n", "442177 0.0 1.0 0.0 0.0 \n", "185316 0.0 0.0 0.0 0.0 \n", "189541 0.0 0.0 1.0 0.0 \n", "428374 0.0 1.0 0.0 0.0 \n", "234638 0.0 0.0 0.0 0.0 \n", "172207 0.0 0.0 0.0 0.0 \n", "240801 1.0 0.0 0.0 0.0 \n", "435277 0.0 0.0 0.0 0.0 \n", "297100 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type2 Soil_Type3 Soil_Type4 Soil_Type5 Soil_Type6 \\\n", "325384 0.0 0.0 0.0 0.0 0.0 \n", "442177 0.0 0.0 0.0 0.0 0.0 \n", "185316 0.0 0.0 0.0 0.0 0.0 
\n", "189541 0.0 0.0 1.0 0.0 0.0 \n", "428374 0.0 0.0 0.0 0.0 0.0 \n", "234638 0.0 0.0 0.0 0.0 0.0 \n", "172207 0.0 0.0 0.0 0.0 0.0 \n", "240801 0.0 0.0 0.0 0.0 0.0 \n", "435277 0.0 0.0 0.0 0.0 0.0 \n", "297100 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type7 Soil_Type8 Soil_Type9 Soil_Type10 Soil_Type11 \\\n", "325384 0.0 0.0 0.0 0.0 0.0 \n", "442177 0.0 0.0 0.0 0.0 0.0 \n", "185316 0.0 0.0 0.0 0.0 0.0 \n", "189541 0.0 0.0 0.0 0.0 0.0 \n", "428374 0.0 0.0 0.0 0.0 0.0 \n", "234638 0.0 0.0 0.0 0.0 0.0 \n", "172207 0.0 0.0 0.0 0.0 0.0 \n", "240801 0.0 0.0 0.0 0.0 0.0 \n", "435277 0.0 0.0 0.0 0.0 0.0 \n", "297100 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type12 Soil_Type13 Soil_Type14 Soil_Type15 Soil_Type16 \\\n", "325384 0.0 0.0 0.0 0.0 0.0 \n", "442177 0.0 0.0 0.0 0.0 0.0 \n", "185316 0.0 0.0 0.0 0.0 0.0 \n", "189541 0.0 0.0 0.0 0.0 0.0 \n", "428374 0.0 0.0 0.0 0.0 0.0 \n", "234638 0.0 0.0 0.0 0.0 0.0 \n", "172207 0.0 0.0 0.0 0.0 0.0 \n", "240801 0.0 0.0 0.0 0.0 0.0 \n", "435277 0.0 0.0 0.0 0.0 0.0 \n", "297100 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type17 Soil_Type18 Soil_Type19 Soil_Type20 Soil_Type21 \\\n", "325384 0.0 0.0 0.0 0.0 0.0 \n", "442177 0.0 0.0 0.0 0.0 0.0 \n", "185316 0.0 0.0 0.0 0.0 0.0 \n", "189541 0.0 0.0 0.0 0.0 0.0 \n", "428374 0.0 0.0 0.0 0.0 0.0 \n", "234638 0.0 0.0 0.0 0.0 0.0 \n", "172207 0.0 0.0 0.0 0.0 0.0 \n", "240801 0.0 0.0 0.0 0.0 0.0 \n", "435277 0.0 0.0 0.0 0.0 0.0 \n", "297100 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type22 Soil_Type23 Soil_Type24 Soil_Type25 Soil_Type26 \\\n", "325384 0.0 0.0 0.0 0.0 0.0 \n", "442177 0.0 0.0 0.0 0.0 0.0 \n", "185316 1.0 0.0 0.0 0.0 0.0 \n", "189541 0.0 0.0 0.0 0.0 0.0 \n", "428374 0.0 0.0 0.0 0.0 0.0 \n", "234638 1.0 0.0 0.0 0.0 0.0 \n", "172207 0.0 1.0 0.0 0.0 0.0 \n", "240801 0.0 0.0 0.0 0.0 0.0 \n", "435277 0.0 0.0 0.0 0.0 0.0 \n", "297100 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type27 Soil_Type28 Soil_Type29 Soil_Type30 Soil_Type31 \\\n", "325384 0.0 0.0 0.0 1.0 0.0 \n", "442177 0.0 0.0 0.0 0.0 0.0 \n", 
"185316 0.0 0.0 0.0 0.0 0.0 \n", "189541 0.0 0.0 0.0 0.0 0.0 \n", "428374 0.0 0.0 0.0 0.0 0.0 \n", "234638 0.0 0.0 0.0 0.0 0.0 \n", "172207 0.0 0.0 0.0 0.0 0.0 \n", "240801 0.0 0.0 0.0 0.0 0.0 \n", "435277 0.0 0.0 1.0 0.0 0.0 \n", "297100 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type32 Soil_Type33 Soil_Type34 Soil_Type35 Soil_Type36 \\\n", "325384 0.0 0.0 0.0 0.0 0.0 \n", "442177 1.0 0.0 0.0 0.0 0.0 \n", "185316 0.0 0.0 0.0 0.0 0.0 \n", "189541 0.0 0.0 0.0 0.0 0.0 \n", "428374 1.0 0.0 0.0 0.0 0.0 \n", "234638 0.0 0.0 0.0 0.0 0.0 \n", "172207 0.0 0.0 0.0 0.0 0.0 \n", "240801 0.0 0.0 0.0 0.0 0.0 \n", "435277 0.0 0.0 0.0 0.0 0.0 \n", "297100 0.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type37 Soil_Type38 Soil_Type39 Soil_Type40 \n", "325384 0.0 0.0 0.0 0.0 \n", "442177 0.0 0.0 0.0 0.0 \n", "185316 0.0 0.0 0.0 0.0 \n", "189541 0.0 0.0 0.0 0.0 \n", "428374 0.0 0.0 0.0 0.0 \n", "234638 0.0 0.0 0.0 0.0 \n", "172207 0.0 0.0 0.0 0.0 \n", "240801 0.0 0.0 1.0 0.0 \n", "435277 0.0 0.0 0.0 0.0 \n", "297100 0.0 1.0 0.0 0.0 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "pd.set_option('display.max_columns', None)\n", "pd.concat([pd.DataFrame({'y': train_y}, index=train_X.index),\n", " train_X], axis=1).tail(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Manual Pipeline" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sklearn.decomposition import PCA\n", "from xgboost import XGBClassifier as XGBoost\n", "lale.wrap_imported_operators()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "\n", "pca\n", "\n", "\n", "PCA\n", "\n", "\n", "\n", "\n", "xg_boost\n", "\n", "\n", "XG-\n", "Boost\n", "\n", "\n", "\n", "\n", "pca->xg_boost\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, 
"output_type": "display_data" } ], "source": [ "manual_trainable = PCA(n_components=6) >> XGBoost(n_estimators=3)\n", "manual_trainable.visualize()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.34 s, sys: 1.2 s, total: 3.55 s\n", "Wall time: 2.05 s\n" ] } ], "source": [ "%%time\n", "manual_trained = manual_trainable.fit(train_X, train_y)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 64.5%\n" ] } ], "source": [ "import sklearn.metrics\n", "manual_y = manual_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, manual_y):.1%}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter Tuning" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'description': 'Number of trees to fit.',\n", " 'type': 'integer',\n", " 'default': 100,\n", " 'minimumForOptimizer': 10,\n", " 'maximumForOptimizer': 1500}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "XGBoost.hyperparam_schema('n_estimators')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.pca.html\n" ] } ], "source": [ "print(PCA.documentation_url())" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from lale.lib.lale import Hyperopt\n", "import lale.schemas as schemas\n", "\n", "CustomPCA = PCA.customize_schema(n_components=schemas.Int(min=2, max=54))\n", "CustomXGBoost = XGBoost.customize_schema(n_estimators=schemas.Int(min=1, max=10))\n", "\n", "hpo_planned = CustomPCA >> CustomXGBoost\n", "hpo_trainable = Hyperopt(estimator=hpo_planned, max_evals=10, cv=3)" ] }, { 
"cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|███████| 10/10 [01:20<00:00, 6.64s/trial, best loss: -0.7885106540569516]\n", "CPU times: user 1min 50s, sys: 22.2 s, total: 2min 12s\n", "Wall time: 1min 28s\n" ] } ], "source": [ "%%time\n", "hpo_trained = hpo_trainable.fit(train_X, train_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### --- Excursions: Types as Search Spaces ---\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 80.1%\n" ] } ], "source": [ "hpo_y = hpo_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, hpo_y):.1%}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inspecting Automation Results" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "\n", "pca\n", "\n", "\n", "PCA\n", "\n", "\n", "\n", "\n", "xg_boost\n", "\n", "\n", "XG-\n", "Boost\n", "\n", "\n", "\n", "\n", "pca->xg_boost\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hpo_trained.get_pipeline().visualize()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "```python\n", "from lale.lib.sklearn import PCA\n", "from lale.lib.xgboost.xgb_classifier import XGBoost\n", "import lale\n", "lale.wrap_imported_operators()\n", "\n", "pca = PCA(n_components=39, svd_solver='full')\n", "xg_boost = XGBoost(colsample_bylevel=0.6016063807304212, colsample_bytree=0.7763972782064467, learning_rate=0.16389357351003786, max_depth=10, min_child_weight=5, n_estimators=5, reg_alpha=0.10485915855270356, reg_lambda=0.9268502695024392, subsample=0.4503841871781402)\n", "pipeline 
= pca >> xg_boost\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hpo_trained.get_pipeline().pretty_print(ipython_display=True)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>tid</th><th>loss</th><th>time</th><th>log_loss</th><th>status</th></tr>\n",
"    <tr><th>name</th><th></th><th></th><th></th><th></th><th></th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>p0</th><td>0</td><td>-0.667916</td><td>1.532263</td><td>1.250336</td><td>ok</td></tr>\n",
"    <tr><th>p1</th><td>1</td><td>-0.635559</td><td>1.395001</td><td>1.120280</td><td>ok</td></tr>\n",
"    <tr><th>p2</th><td>2</td><td>-0.670229</td><td>2.745617</td><td>1.087269</td><td>ok</td></tr>\n",
"    <tr><th>p3</th><td>3</td><td>-0.788511</td><td>5.876360</td><td>1.049096</td><td>ok</td></tr>\n",
"    <tr><th>p4</th><td>4</td><td>-0.718938</td><td>3.725537</td><td>0.661428</td><td>ok</td></tr>\n",
"    <tr><th>p5</th><td>5</td><td>-0.482052</td><td>1.952195</td><td>1.241045</td><td>ok</td></tr>\n",
"    <tr><th>p6</th><td>6</td><td>-0.482052</td><td>1.209477</td><td>1.338511</td><td>ok</td></tr>\n",
"    <tr><th>p7</th><td>7</td><td>-0.669484</td><td>2.106700</td><td>0.844174</td><td>ok</td></tr>\n",
"    <tr><th>p8</th><td>8</td><td>-0.632346</td><td>1.612136</td><td>0.925707</td><td>ok</td></tr>\n",
"    <tr><th>p9</th><td>9</td><td>-0.622306</td><td>1.474229</td><td>1.882534</td><td>ok</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " tid loss time log_loss status\n", "name \n", "p0 0 -0.667916 1.532263 1.250336 ok\n", "p1 1 -0.635559 1.395001 1.120280 ok\n", "p2 2 -0.670229 2.745617 1.087269 ok\n", "p3 3 -0.788511 5.876360 1.049096 ok\n", "p4 4 -0.718938 3.725537 0.661428 ok\n", "p5 5 -0.482052 1.952195 1.241045 ok\n", "p6 6 -0.482052 1.209477 1.338511 ok\n", "p7 7 -0.669484 2.106700 0.844174 ok\n", "p8 8 -0.632346 1.612136 0.925707 ok\n", "p9 9 -0.622306 1.474229 1.882534 ok" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hpo_trained.summary()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "p5\n" ] } ], "source": [ "# loss is the negated accuracy, so the largest loss marks the worst trial;\n", "# idxmax returns its index label (the trial name) rather than a position\n", "worst_name = hpo_trained.summary().loss.idxmax()\n", "print(worst_name)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<!-- pipeline graph: PCA -> XGBoost -->\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/markdown": [ "```python\n", "pca = PCA(n_components=48, svd_solver='full', whiten=True)\n", "xg_boost = XGBoost(booster='gblinear', colsample_bylevel=0.41777546097517426, colsample_bytree=0.6852556915729863, learning_rate=0.4299362917360751, max_depth=15, min_child_weight=18, n_estimators=7, reg_alpha=0.5266202371276923, reg_lambda=0.494226267796831, subsample=0.8015579071911012)\n", "pipeline = pca >> xg_boost\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hpo_trained.get_pipeline(worst_name).visualize()\n", "hpo_trained.get_pipeline(worst_name).pretty_print(ipython_display=True, show_imports=False)" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "## Combined Algorithm Selection and Hyperparameter Tuning" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<!-- pipeline graph: choice(Norm | NoOp) -> choice(Tree | LR | KNN) -->\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.preprocessing import Normalizer as Norm\n", "from sklearn.linear_model import LogisticRegression as LR\n", "from sklearn.tree import DecisionTreeClassifier as Tree\n", "from sklearn.neighbors import KNeighborsClassifier as KNN\n", "from lale.lib.lale import NoOp\n", "lale.wrap_imported_operators()\n", "\n", "KNN = KNN.customize_schema(n_neighbors=schemas.Int(min=1, max=10))\n", "transp_planned = (Norm | NoOp) >> (Tree | LR(dual=True) | KNN)\n", "transp_planned.visualize()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|█████████| 3/3 [01:48<00:00, 32.59s/trial, best loss: -0.8376392446578157]\n", "CPU times: user 1min 50s, sys: 1.12 s, total: 1min 51s\n", "Wall time: 1min 49s\n" ] } ], "source": [ "%%time\n", "transp_trained = transp_planned.auto_configure(\n", "    train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### --- Excursion: Bindings as Lifecycle ---\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ 
"```python\n", "knn = KNN(algorithm='ball_tree', metric='manhattan', n_neighbors=9)\n", "pipeline = NoOp() >> knn\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "<!-- pipeline graph: NoOp -> KNN -->\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "transp_trained.pretty_print(ipython_display=True, show_imports=False)\n", "transp_trained.visualize()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 86.6%\n", "CPU times: user 50.6 s, sys: 15.6 ms, total: 50.6 s\n", "Wall time: 50.7 s\n" ] } ], "source": [ "%%time\n", "transp_y = transp_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, transp_y):.1%}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Non-Linear Pipeline" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'description': 'Features of forest covertypes dataset (classification).',\n", " 'documentation_url': 'https://scikit-learn.org/0.20/datasets/index.html#forest-covertypes',\n", " 'type': 'array',\n", " 'items': {'type': 'array',\n", " 'minItems': 54,\n", " 'maxItems': 54,\n", " 'items': [{'description': 'Elevation', 'type': 'integer'},\n", " {'description': 'Aspect', 'type': 'integer'},\n", " {'description': 'Slope', 'type': 'integer'},\n", " {'description': 'Horizontal_Distance_To_Hydrology', 'type': 'integer'},\n", " {'description': 'Vertical_Distance_To_Hydrology', 'type': 'integer'},\n", " {'description': 'Horizontal_Distance_To_Roadways', 'type': 'integer'},\n", " {'description': 'Hillshade_9am', 'type': 
'integer'},\n", " {'description': 'Hillshade_Noon', 'type': 'integer'},\n", " {'description': 'Hillshade_3pm', 'type': 'integer'},\n", " {'description': 'Horizontal_Distance_To_Fire_Points', 'type': 'integer'},\n", " {'description': 'Wilderness_Area1', 'enum': [0, 1]},\n", " {'description': 'Wilderness_Area2', 'enum': [0, 1]},\n", " {'description': 'Wilderness_Area3', 'enum': [0, 1]},\n", " {'description': 'Wilderness_Area4', 'enum': [0, 1]},\n", " {'description': 'Soil_Type1', 'enum': [0, 1]},\n", " {'description': 'Soil_Type2', 'enum': [0, 1]},\n", " {'description': 'Soil_Type3', 'enum': [0, 1]},\n", " {'description': 'Soil_Type4', 'enum': [0, 1]},\n", " {'description': 'Soil_Type5', 'enum': [0, 1]},\n", " {'description': 'Soil_Type6', 'enum': [0, 1]},\n", " {'description': 'Soil_Type7', 'enum': [0, 1]},\n", " {'description': 'Soil_Type8', 'enum': [0, 1]},\n", " {'description': 'Soil_Type9', 'enum': [0, 1]},\n", " {'description': 'Soil_Type10', 'enum': [0, 1]},\n", " {'description': 'Soil_Type11', 'enum': [0, 1]},\n", " {'description': 'Soil_Type12', 'enum': [0, 1]},\n", " {'description': 'Soil_Type13', 'enum': [0, 1]},\n", " {'description': 'Soil_Type14', 'enum': [0, 1]},\n", " {'description': 'Soil_Type15', 'enum': [0, 1]},\n", " {'description': 'Soil_Type16', 'enum': [0, 1]},\n", " {'description': 'Soil_Type17', 'enum': [0, 1]},\n", " {'description': 'Soil_Type18', 'enum': [0, 1]},\n", " {'description': 'Soil_Type19', 'enum': [0, 1]},\n", " {'description': 'Soil_Type20', 'enum': [0, 1]},\n", " {'description': 'Soil_Type21', 'enum': [0, 1]},\n", " {'description': 'Soil_Type22', 'enum': [0, 1]},\n", " {'description': 'Soil_Type23', 'enum': [0, 1]},\n", " {'description': 'Soil_Type24', 'enum': [0, 1]},\n", " {'description': 'Soil_Type25', 'enum': [0, 1]},\n", " {'description': 'Soil_Type26', 'enum': [0, 1]},\n", " {'description': 'Soil_Type27', 'enum': [0, 1]},\n", " {'description': 'Soil_Type28', 'enum': [0, 1]},\n", " {'description': 'Soil_Type29', 'enum': [0, 
1]},\n", " {'description': 'Soil_Type30', 'enum': [0, 1]},\n", " {'description': 'Soil_Type31', 'enum': [0, 1]},\n", " {'description': 'Soil_Type32', 'enum': [0, 1]},\n", " {'description': 'Soil_Type33', 'enum': [0, 1]},\n", " {'description': 'Soil_Type34', 'enum': [0, 1]},\n", " {'description': 'Soil_Type35', 'enum': [0, 1]},\n", " {'description': 'Soil_Type36', 'enum': [0, 1]},\n", " {'description': 'Soil_Type37', 'enum': [0, 1]},\n", " {'description': 'Soil_Type38', 'enum': [0, 1]},\n", " {'description': 'Soil_Type39', 'enum': [0, 1]},\n", " {'description': 'Soil_Type40', 'enum': [0, 1]}]},\n", " 'minItems': 58102,\n", " 'maxItems': 58102}" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_X.json_schema" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "other columns: Elevation, Aspect, Slope, Horizontal_Distance_To_Hydrology, Vertical_Distance_To_Hydrology, Horizontal_Distance_To_Roadways, Hillshade_9am, Hillshade_Noon, Hillshade_3pm, Horizontal_Distance_To_Fire_Points\n" ] } ], "source": [ "area_columns = [f'Wilderness_Area{i}' for i in range(1, 5)]\n", "soil_columns = [f'Soil_Type{i}' for i in range(1, 41)]\n", "binary_columns = area_columns + soil_columns\n", "other_columns = [c for c in train_X.columns if c not in binary_columns]\n", "print(f'other columns: {\", \".join(other_columns)}')" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "cluster:choice\n", "\n", "\n", "Choice\n", "\n", "\n", "\n", "\n", "project_0\n", "\n", "\n", "Project\n", "\n", "\n", "\n", "\n", "feat_sel\n", "\n", "\n", "Feat-\n", "Sel\n", "\n", "\n", "\n", "\n", "project_0->feat_sel\n", "\n", "\n", "\n", "\n", "concat\n", "\n", "\n", "Concat\n", "\n", "\n", "\n", "\n", "feat_sel->concat\n", "\n", "\n", "\n", 
"\n", "project_1\n", "\n", "\n", "Project\n", "\n", "\n", "\n", "\n", "norm\n", "\n", "\n", "Norm\n", "\n", "\n", "\n", "\n", "project_1->norm\n", "\n", "\n", "\n", "\n", "norm->concat\n", "\n", "\n", "\n", "\n", "no_op\n", "\n", "\n", "No-\n", "Op\n", "\n", "\n", "\n", "\n", "knn\n", "\n", "\n", "KNN\n", "\n", "\n", "\n", "\n", "concat->knn\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from lale.lib.lale import Project\n", "from lale.lib.lale import ConcatFeatures as Concat\n", "from sklearn.feature_selection import SelectKBest as FeatSel\n", "lale.wrap_imported_operators()\n", "\n", "binary_prep = Project(columns=binary_columns) >> FeatSel\n", "other_prep = Project(columns=other_columns) >> (Norm | NoOp)\n", "nonlin_planned = (binary_prep & other_prep) >> Concat >> KNN\n", "nonlin_planned.visualize()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100%|█████████| 3/3 [02:08<00:00, 34.88s/trial, best loss: -0.8584651324517477]\n", "CPU times: user 2min 9s, sys: 62.5 ms, total: 2min 9s\n", "Wall time: 2min 10s\n" ] } ], "source": [ "%%time\n", "nonlin_trained = nonlin_planned.auto_configure(\n", " train_X, train_y, optimizer=Hyperopt, cv=3, max_evals=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### --- Excursion: Combinators ---\n", "\n", "| Lale feature | Name | Description | Scikit-learn feature |\n", "| ----------------------- | ---- | ------------ | ----------------------------------- |\n", "| >> or `make_pipeline` | pipe | feed to next | `make_pipeline` |\n", "| & or `make_union` | and | run both | `make_union` or `ColumnTransformer` |\n", "| \\| or `make_choice` | or | choose one | N/A (specific to given AutoML tool) |\n", "\n", "### --- Excursion: Interoperability ---\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { 
"image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "cluster:(root)\n", "\n", "\n", "\n", "\n", "\n", "project_0\n", "\n", "\n", "Project\n", "\n", "\n", "\n", "\n", "feat_sel\n", "\n", "\n", "Feat-\n", "Sel\n", "\n", "\n", "\n", "\n", "project_0->feat_sel\n", "\n", "\n", "\n", "\n", "concat\n", "\n", "\n", "Concat\n", "\n", "\n", "\n", "\n", "feat_sel->concat\n", "\n", "\n", "\n", "\n", "project_1\n", "\n", "\n", "Project\n", "\n", "\n", "\n", "\n", "no_op\n", "\n", "\n", "No-\n", "Op\n", "\n", "\n", "\n", "\n", "project_1->no_op\n", "\n", "\n", "\n", "\n", "no_op->concat\n", "\n", "\n", "\n", "\n", "knn\n", "\n", "\n", "KNN\n", "\n", "\n", "\n", "\n", "concat->knn\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/markdown": [ "```python\n", "project_0 = Project(columns=['Wilderness_Area1', 'Wilderness_Area2', 'Wilderness_Area3', 'Wilderness_Area4', 'Soil_Type1', 'Soil_Type2', 'Soil_Type3', 'Soil_Type4', 'Soil_Type5', 'Soil_Type6', 'Soil_Type7', 'Soil_Type8', 'Soil_Type9', 'Soil_Type10', 'Soil_Type11', 'Soil_Type12', 'Soil_Type13', 'Soil_Type14', 'Soil_Type15', 'Soil_Type16', 'Soil_Type17', 'Soil_Type18', 'Soil_Type19', 'Soil_Type20', 'Soil_Type21', 'Soil_Type22', 'Soil_Type23', 'Soil_Type24', 'Soil_Type25', 'Soil_Type26', 'Soil_Type27', 'Soil_Type28', 'Soil_Type29', 'Soil_Type30', 'Soil_Type31', 'Soil_Type32', 'Soil_Type33', 'Soil_Type34', 'Soil_Type35', 'Soil_Type36', 'Soil_Type37', 'Soil_Type38', 'Soil_Type39', 'Soil_Type40'])\n", "feat_sel = FeatSel(k=8)\n", "pipeline_0 = make_pipeline(project_0, feat_sel)\n", "project_1 = Project(columns=['Elevation', 'Aspect', 'Slope', 'Horizontal_Distance_To_Hydrology', 'Vertical_Distance_To_Hydrology', 'Horizontal_Distance_To_Roadways', 'Hillshade_9am', 'Hillshade_Noon', 'Hillshade_3pm', 'Horizontal_Distance_To_Fire_Points'])\n", "pipeline_1 = make_pipeline(project_1, NoOp())\n", "union = make_union(pipeline_0, pipeline_1)\n", "knn = 
KNN(algorithm='kd_tree', n_neighbors=7, weights='distance')\n", "pipeline = make_pipeline(union, knn)\n", "```" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "nonlin_trained.visualize()\n", "nonlin_trained.pretty_print(ipython_display=True, show_imports=False, combinators=False)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy 88.6%\n", "CPU times: user 4.12 s, sys: 93.8 ms, total: 4.22 s\n", "Wall time: 4.19 s\n" ] } ], "source": [ "%%time\n", "nonlin_y = nonlin_trained.predict(test_X)\n", "print(f'accuracy {sklearn.metrics.accuracy_score(test_y, nonlin_y):.1%}')" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"  <thead>\n",
"    <tr><th></th><th>Wilderness_Area1</th><th>Wilderness_Area4</th><th>Soil_Type2</th><th>Soil_Type3</th><th>Soil_Type4</th><th>Soil_Type10</th><th>Soil_Type38</th><th>Soil_Type39</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"    <tr><th>1</th><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>2</th><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>3</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>4</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>5</th><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>6</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>7</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"    <tr><th>8</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>9</th><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " Wilderness_Area1 Wilderness_Area4 Soil_Type2 Soil_Type3 Soil_Type4 \\\n", "0 1.0 0.0 0.0 0.0 0.0 \n", "1 0.0 0.0 0.0 0.0 0.0 \n", "2 0.0 0.0 0.0 0.0 0.0 \n", "3 1.0 0.0 0.0 0.0 0.0 \n", "4 1.0 0.0 0.0 0.0 0.0 \n", "5 0.0 1.0 0.0 0.0 0.0 \n", "6 1.0 0.0 0.0 0.0 0.0 \n", "7 1.0 0.0 0.0 0.0 0.0 \n", "8 1.0 0.0 0.0 0.0 0.0 \n", "9 1.0 0.0 0.0 0.0 0.0 \n", "\n", " Soil_Type10 Soil_Type38 Soil_Type39 \n", "0 0.0 1.0 0.0 \n", "1 0.0 0.0 0.0 \n", "2 0.0 0.0 0.0 \n", "3 0.0 0.0 0.0 \n", "4 0.0 0.0 0.0 \n", "5 1.0 0.0 0.0 \n", "6 0.0 0.0 0.0 \n", "7 0.0 1.0 0.0 \n", "8 0.0 0.0 0.0 \n", "9 0.0 0.0 0.0 " ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "binary_prep_trainable = Project(columns=binary_columns) >> FeatSel(k=8)\n", "binary_prep_trained = binary_prep_trainable.fit(train_X, train_y)\n", "binary_prep_trained.transform(test_X.head(10))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "- code and documentation: https://github.com/ibm/lale\n", "- more examples: https://nbviewer.jupyter.org/github/IBM/lale/tree/master/examples/\n", "- frequently asked questions: https://github.com/IBM/lale/blob/master/docs/faq.rst\n", "- arXiv paper: https://arxiv.org/pdf/1906.03957.pdf\n", "\n", "" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }