{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyCaret 2 Classification Example\n",
"This notebook is created using PyCaret 2.0. Last updated : 28-07-2020"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pycaret-nightly-0.39\n"
]
}
],
"source": [
"# check version\n",
"from pycaret.utils import version\n",
"version()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Data Repository"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Dataset | \n",
" Data Types | \n",
" Default Task | \n",
" Target Variable | \n",
" # Instances | \n",
" # Attributes | \n",
" Missing Values | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" anomaly | \n",
" Multivariate | \n",
" Anomaly Detection | \n",
" None | \n",
" 1000 | \n",
" 10 | \n",
" N | \n",
"
\n",
" \n",
" 1 | \n",
" france | \n",
" Multivariate | \n",
" Association Rule Mining | \n",
" InvoiceNo, Description | \n",
" 8557 | \n",
" 8 | \n",
" N | \n",
"
\n",
" \n",
" 2 | \n",
" germany | \n",
" Multivariate | \n",
" Association Rule Mining | \n",
" InvoiceNo, Description | \n",
" 9495 | \n",
" 8 | \n",
" N | \n",
"
\n",
" \n",
" 3 | \n",
" bank | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" deposit | \n",
" 45211 | \n",
" 17 | \n",
" N | \n",
"
\n",
" \n",
" 4 | \n",
" blood | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" Class | \n",
" 748 | \n",
" 5 | \n",
" N | \n",
"
\n",
" \n",
" 5 | \n",
" cancer | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" Class | \n",
" 683 | \n",
" 10 | \n",
" N | \n",
"
\n",
" \n",
" 6 | \n",
" credit | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" default | \n",
" 24000 | \n",
" 24 | \n",
" N | \n",
"
\n",
" \n",
" 7 | \n",
" diabetes | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" Class variable | \n",
" 768 | \n",
" 9 | \n",
" N | \n",
"
\n",
" \n",
" 8 | \n",
" electrical_grid | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" stabf | \n",
" 10000 | \n",
" 14 | \n",
" N | \n",
"
\n",
" \n",
" 9 | \n",
" employee | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" left | \n",
" 14999 | \n",
" 10 | \n",
" N | \n",
"
\n",
" \n",
" 10 | \n",
" heart | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" DEATH | \n",
" 200 | \n",
" 16 | \n",
" N | \n",
"
\n",
" \n",
" 11 | \n",
" heart_disease | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" Disease | \n",
" 270 | \n",
" 14 | \n",
" N | \n",
"
\n",
" \n",
" 12 | \n",
" hepatitis | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" Class | \n",
" 154 | \n",
" 32 | \n",
" Y | \n",
"
\n",
" \n",
" 13 | \n",
" income | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" income >50K | \n",
" 32561 | \n",
" 14 | \n",
" Y | \n",
"
\n",
" \n",
" 14 | \n",
" juice | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" Purchase | \n",
" 1070 | \n",
" 15 | \n",
" N | \n",
"
\n",
" \n",
" 15 | \n",
" nba | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" TARGET_5Yrs | \n",
" 1340 | \n",
" 21 | \n",
" N | \n",
"
\n",
" \n",
" 16 | \n",
" wine | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" type | \n",
" 6498 | \n",
" 13 | \n",
" N | \n",
"
\n",
" \n",
" 17 | \n",
" telescope | \n",
" Multivariate | \n",
" Classification (Binary) | \n",
" Class | \n",
" 19020 | \n",
" 11 | \n",
" N | \n",
"
\n",
" \n",
" 18 | \n",
" glass | \n",
" Multivariate | \n",
" Classification (Multiclass) | \n",
" Type | \n",
" 214 | \n",
" 10 | \n",
" N | \n",
"
\n",
" \n",
" 19 | \n",
" iris | \n",
" Multivariate | \n",
" Classification (Multiclass) | \n",
" species | \n",
" 150 | \n",
" 5 | \n",
" N | \n",
"
\n",
" \n",
" 20 | \n",
" poker | \n",
" Multivariate | \n",
" Classification (Multiclass) | \n",
" CLASS | \n",
" 100000 | \n",
" 11 | \n",
" N | \n",
"
\n",
" \n",
" 21 | \n",
" questions | \n",
" Multivariate | \n",
" Classification (Multiclass) | \n",
" Next_Question | \n",
" 499 | \n",
" 4 | \n",
" N | \n",
"
\n",
" \n",
" 22 | \n",
" satellite | \n",
" Multivariate | \n",
" Classification (Multiclass) | \n",
" Class | \n",
" 6435 | \n",
" 37 | \n",
" N | \n",
"
\n",
" \n",
" 23 | \n",
" asia_gdp | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 40 | \n",
" 11 | \n",
" N | \n",
"
\n",
" \n",
" 24 | \n",
" elections | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 3195 | \n",
" 54 | \n",
" Y | \n",
"
\n",
" \n",
" 25 | \n",
" facebook | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 7050 | \n",
" 12 | \n",
" N | \n",
"
\n",
" \n",
" 26 | \n",
" ipl | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 153 | \n",
" 25 | \n",
" N | \n",
"
\n",
" \n",
" 27 | \n",
" jewellery | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 505 | \n",
" 4 | \n",
" N | \n",
"
\n",
" \n",
" 28 | \n",
" mice | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 1080 | \n",
" 82 | \n",
" Y | \n",
"
\n",
" \n",
" 29 | \n",
" migration | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 233 | \n",
" 12 | \n",
" N | \n",
"
\n",
" \n",
" 30 | \n",
" perfume | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 20 | \n",
" 29 | \n",
" N | \n",
"
\n",
" \n",
" 31 | \n",
" pokemon | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 800 | \n",
" 13 | \n",
" Y | \n",
"
\n",
" \n",
" 32 | \n",
" population | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 255 | \n",
" 56 | \n",
" Y | \n",
"
\n",
" \n",
" 33 | \n",
" public_health | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 224 | \n",
" 21 | \n",
" N | \n",
"
\n",
" \n",
" 34 | \n",
" seeds | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 210 | \n",
" 7 | \n",
" N | \n",
"
\n",
" \n",
" 35 | \n",
" wholesale | \n",
" Multivariate | \n",
" Clustering | \n",
" None | \n",
" 440 | \n",
" 8 | \n",
" N | \n",
"
\n",
" \n",
" 36 | \n",
" tweets | \n",
" Text | \n",
" NLP | \n",
" tweet | \n",
" 8594 | \n",
" 2 | \n",
" N | \n",
"
\n",
" \n",
" 37 | \n",
" amazon | \n",
" Text | \n",
" NLP / Classification | \n",
" reviewText | \n",
" 20000 | \n",
" 2 | \n",
" N | \n",
"
\n",
" \n",
" 38 | \n",
" kiva | \n",
" Text | \n",
" NLP / Classification | \n",
" en | \n",
" 6818 | \n",
" 7 | \n",
" N | \n",
"
\n",
" \n",
" 39 | \n",
" spx | \n",
" Text | \n",
" NLP / Regression | \n",
" text | \n",
" 874 | \n",
" 4 | \n",
" N | \n",
"
\n",
" \n",
" 40 | \n",
" wikipedia | \n",
" Text | \n",
" NLP / Classification | \n",
" Text | \n",
" 500 | \n",
" 3 | \n",
" N | \n",
"
\n",
" \n",
" 41 | \n",
" automobile | \n",
" Multivariate | \n",
" Regression | \n",
" price | \n",
" 202 | \n",
" 26 | \n",
" Y | \n",
"
\n",
" \n",
" 42 | \n",
" bike | \n",
" Multivariate | \n",
" Regression | \n",
" cnt | \n",
" 17379 | \n",
" 15 | \n",
" N | \n",
"
\n",
" \n",
" 43 | \n",
" boston | \n",
" Multivariate | \n",
" Regression | \n",
" medv | \n",
" 506 | \n",
" 14 | \n",
" N | \n",
"
\n",
" \n",
" 44 | \n",
" concrete | \n",
" Multivariate | \n",
" Regression | \n",
" strength | \n",
" 1030 | \n",
" 9 | \n",
" N | \n",
"
\n",
" \n",
" 45 | \n",
" diamond | \n",
" Multivariate | \n",
" Regression | \n",
" Price | \n",
" 6000 | \n",
" 8 | \n",
" N | \n",
"
\n",
" \n",
" 46 | \n",
" energy | \n",
" Multivariate | \n",
" Regression | \n",
" Heating Load / Cooling Load | \n",
" 768 | \n",
" 10 | \n",
" N | \n",
"
\n",
" \n",
" 47 | \n",
" forest | \n",
" Multivariate | \n",
" Regression | \n",
" area | \n",
" 517 | \n",
" 13 | \n",
" N | \n",
"
\n",
" \n",
" 48 | \n",
" gold | \n",
" Multivariate | \n",
" Regression | \n",
" Gold_T+22 | \n",
" 2558 | \n",
" 121 | \n",
" N | \n",
"
\n",
" \n",
" 49 | \n",
" house | \n",
" Multivariate | \n",
" Regression | \n",
" SalePrice | \n",
" 1461 | \n",
" 81 | \n",
" Y | \n",
"
\n",
" \n",
" 50 | \n",
" insurance | \n",
" Multivariate | \n",
" Regression | \n",
" charges | \n",
" 1338 | \n",
" 7 | \n",
" N | \n",
"
\n",
" \n",
" 51 | \n",
" parkinsons | \n",
" Multivariate | \n",
" Regression | \n",
" PPE | \n",
" 5875 | \n",
" 22 | \n",
" N | \n",
"
\n",
" \n",
" 52 | \n",
" traffic | \n",
" Multivariate | \n",
" Regression | \n",
" traffic_volume | \n",
" 48204 | \n",
" 8 | \n",
" N | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Dataset Data Types Default Task \\\n",
"0 anomaly Multivariate Anomaly Detection \n",
"1 france Multivariate Association Rule Mining \n",
"2 germany Multivariate Association Rule Mining \n",
"3 bank Multivariate Classification (Binary) \n",
"4 blood Multivariate Classification (Binary) \n",
"5 cancer Multivariate Classification (Binary) \n",
"6 credit Multivariate Classification (Binary) \n",
"7 diabetes Multivariate Classification (Binary) \n",
"8 electrical_grid Multivariate Classification (Binary) \n",
"9 employee Multivariate Classification (Binary) \n",
"10 heart Multivariate Classification (Binary) \n",
"11 heart_disease Multivariate Classification (Binary) \n",
"12 hepatitis Multivariate Classification (Binary) \n",
"13 income Multivariate Classification (Binary) \n",
"14 juice Multivariate Classification (Binary) \n",
"15 nba Multivariate Classification (Binary) \n",
"16 wine Multivariate Classification (Binary) \n",
"17 telescope Multivariate Classification (Binary) \n",
"18 glass Multivariate Classification (Multiclass) \n",
"19 iris Multivariate Classification (Multiclass) \n",
"20 poker Multivariate Classification (Multiclass) \n",
"21 questions Multivariate Classification (Multiclass) \n",
"22 satellite Multivariate Classification (Multiclass) \n",
"23 asia_gdp Multivariate Clustering \n",
"24 elections Multivariate Clustering \n",
"25 facebook Multivariate Clustering \n",
"26 ipl Multivariate Clustering \n",
"27 jewellery Multivariate Clustering \n",
"28 mice Multivariate Clustering \n",
"29 migration Multivariate Clustering \n",
"30 perfume Multivariate Clustering \n",
"31 pokemon Multivariate Clustering \n",
"32 population Multivariate Clustering \n",
"33 public_health Multivariate Clustering \n",
"34 seeds Multivariate Clustering \n",
"35 wholesale Multivariate Clustering \n",
"36 tweets Text NLP \n",
"37 amazon Text NLP / Classification \n",
"38 kiva Text NLP / Classification \n",
"39 spx Text NLP / Regression \n",
"40 wikipedia Text NLP / Classification \n",
"41 automobile Multivariate Regression \n",
"42 bike Multivariate Regression \n",
"43 boston Multivariate Regression \n",
"44 concrete Multivariate Regression \n",
"45 diamond Multivariate Regression \n",
"46 energy Multivariate Regression \n",
"47 forest Multivariate Regression \n",
"48 gold Multivariate Regression \n",
"49 house Multivariate Regression \n",
"50 insurance Multivariate Regression \n",
"51 parkinsons Multivariate Regression \n",
"52 traffic Multivariate Regression \n",
"\n",
" Target Variable # Instances # Attributes Missing Values \n",
"0 None 1000 10 N \n",
"1 InvoiceNo, Description 8557 8 N \n",
"2 InvoiceNo, Description 9495 8 N \n",
"3 deposit 45211 17 N \n",
"4 Class 748 5 N \n",
"5 Class 683 10 N \n",
"6 default 24000 24 N \n",
"7 Class variable 768 9 N \n",
"8 stabf 10000 14 N \n",
"9 left 14999 10 N \n",
"10 DEATH 200 16 N \n",
"11 Disease 270 14 N \n",
"12 Class 154 32 Y \n",
"13 income >50K 32561 14 Y \n",
"14 Purchase 1070 15 N \n",
"15 TARGET_5Yrs 1340 21 N \n",
"16 type 6498 13 N \n",
"17 Class 19020 11 N \n",
"18 Type 214 10 N \n",
"19 species 150 5 N \n",
"20 CLASS 100000 11 N \n",
"21 Next_Question 499 4 N \n",
"22 Class 6435 37 N \n",
"23 None 40 11 N \n",
"24 None 3195 54 Y \n",
"25 None 7050 12 N \n",
"26 None 153 25 N \n",
"27 None 505 4 N \n",
"28 None 1080 82 Y \n",
"29 None 233 12 N \n",
"30 None 20 29 N \n",
"31 None 800 13 Y \n",
"32 None 255 56 Y \n",
"33 None 224 21 N \n",
"34 None 210 7 N \n",
"35 None 440 8 N \n",
"36 tweet 8594 2 N \n",
"37 reviewText 20000 2 N \n",
"38 en 6818 7 N \n",
"39 text 874 4 N \n",
"40 Text 500 3 N \n",
"41 price 202 26 Y \n",
"42 cnt 17379 15 N \n",
"43 medv 506 14 N \n",
"44 strength 1030 9 N \n",
"45 Price 6000 8 N \n",
"46 Heating Load / Cooling Load 768 10 N \n",
"47 area 517 13 N \n",
"48 Gold_T+22 2558 121 N \n",
"49 SalePrice 1461 81 Y \n",
"50 charges 1338 7 N \n",
"51 PPE 5875 22 N \n",
"52 traffic_volume 48204 8 N "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from pycaret.datasets import get_data\n",
"index = get_data('index')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Id | \n",
" Purchase | \n",
" WeekofPurchase | \n",
" StoreID | \n",
" PriceCH | \n",
" PriceMM | \n",
" DiscCH | \n",
" DiscMM | \n",
" SpecialCH | \n",
" SpecialMM | \n",
" LoyalCH | \n",
" SalePriceMM | \n",
" SalePriceCH | \n",
" PriceDiff | \n",
" Store7 | \n",
" PctDiscMM | \n",
" PctDiscCH | \n",
" ListPriceDiff | \n",
" STORE | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" CH | \n",
" 237 | \n",
" 1 | \n",
" 1.75 | \n",
" 1.99 | \n",
" 0.00 | \n",
" 0.0 | \n",
" 0 | \n",
" 0 | \n",
" 0.500000 | \n",
" 1.99 | \n",
" 1.75 | \n",
" 0.24 | \n",
" No | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.24 | \n",
" 1 | \n",
"
\n",
" \n",
" 1 | \n",
" 2 | \n",
" CH | \n",
" 239 | \n",
" 1 | \n",
" 1.75 | \n",
" 1.99 | \n",
" 0.00 | \n",
" 0.3 | \n",
" 0 | \n",
" 1 | \n",
" 0.600000 | \n",
" 1.69 | \n",
" 1.75 | \n",
" -0.06 | \n",
" No | \n",
" 0.150754 | \n",
" 0.000000 | \n",
" 0.24 | \n",
" 1 | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" CH | \n",
" 245 | \n",
" 1 | \n",
" 1.86 | \n",
" 2.09 | \n",
" 0.17 | \n",
" 0.0 | \n",
" 0 | \n",
" 0 | \n",
" 0.680000 | \n",
" 2.09 | \n",
" 1.69 | \n",
" 0.40 | \n",
" No | \n",
" 0.000000 | \n",
" 0.091398 | \n",
" 0.23 | \n",
" 1 | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" MM | \n",
" 227 | \n",
" 1 | \n",
" 1.69 | \n",
" 1.69 | \n",
" 0.00 | \n",
" 0.0 | \n",
" 0 | \n",
" 0 | \n",
" 0.400000 | \n",
" 1.69 | \n",
" 1.69 | \n",
" 0.00 | \n",
" No | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.00 | \n",
" 1 | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" CH | \n",
" 228 | \n",
" 7 | \n",
" 1.69 | \n",
" 1.69 | \n",
" 0.00 | \n",
" 0.0 | \n",
" 0 | \n",
" 0 | \n",
" 0.956535 | \n",
" 1.69 | \n",
" 1.69 | \n",
" 0.00 | \n",
" Yes | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.00 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Id Purchase WeekofPurchase StoreID PriceCH PriceMM DiscCH DiscMM \\\n",
"0 1 CH 237 1 1.75 1.99 0.00 0.0 \n",
"1 2 CH 239 1 1.75 1.99 0.00 0.3 \n",
"2 3 CH 245 1 1.86 2.09 0.17 0.0 \n",
"3 4 MM 227 1 1.69 1.69 0.00 0.0 \n",
"4 5 CH 228 7 1.69 1.69 0.00 0.0 \n",
"\n",
" SpecialCH SpecialMM LoyalCH SalePriceMM SalePriceCH PriceDiff Store7 \\\n",
"0 0 0 0.500000 1.99 1.75 0.24 No \n",
"1 0 1 0.600000 1.69 1.75 -0.06 No \n",
"2 0 0 0.680000 2.09 1.69 0.40 No \n",
"3 0 0 0.400000 1.69 1.69 0.00 No \n",
"4 0 0 0.956535 1.69 1.69 0.00 Yes \n",
"\n",
" PctDiscMM PctDiscCH ListPriceDiff STORE \n",
"0 0.000000 0.000000 0.24 1 \n",
"1 0.150754 0.000000 0.24 1 \n",
"2 0.000000 0.091398 0.23 1 \n",
"3 0.000000 0.000000 0.00 1 \n",
"4 0.000000 0.000000 0.00 0 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data = get_data('juice')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Initialize Setup"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setup Succesfully Completed!\n"
]
},
{
"data": {
"text/html": [
" | Description | Value |
\n",
" \n",
" 0 | \n",
" session_id | \n",
" 123 | \n",
"
\n",
" \n",
" 1 | \n",
" Target Type | \n",
" Binary | \n",
"
\n",
" \n",
" 2 | \n",
" Label Encoded | \n",
" CH: 0, MM: 1 | \n",
"
\n",
" \n",
" 3 | \n",
" Original Data | \n",
" (1070, 19) | \n",
"
\n",
" \n",
" 4 | \n",
" Missing Values | \n",
" False | \n",
"
\n",
" \n",
" 5 | \n",
" Numeric Features | \n",
" 13 | \n",
"
\n",
" \n",
" 6 | \n",
" Categorical Features | \n",
" 5 | \n",
"
\n",
" \n",
" 7 | \n",
" Ordinal Features | \n",
" False | \n",
"
\n",
" \n",
" 8 | \n",
" High Cardinality Features | \n",
" False | \n",
"
\n",
" \n",
" 9 | \n",
" High Cardinality Method | \n",
" None | \n",
"
\n",
" \n",
" 10 | \n",
" Sampled Data | \n",
" (1070, 19) | \n",
"
\n",
" \n",
" 11 | \n",
" Transformed Train Set | \n",
" (748, 28) | \n",
"
\n",
" \n",
" 12 | \n",
" Transformed Test Set | \n",
" (322, 28) | \n",
"
\n",
" \n",
" 13 | \n",
" Numeric Imputer | \n",
" mean | \n",
"
\n",
" \n",
" 14 | \n",
" Categorical Imputer | \n",
" constant | \n",
"
\n",
" \n",
" 15 | \n",
" Normalize | \n",
" False | \n",
"
\n",
" \n",
" 16 | \n",
" Normalize Method | \n",
" None | \n",
"
\n",
" \n",
" 17 | \n",
" Transformation | \n",
" False | \n",
"
\n",
" \n",
" 18 | \n",
" Transformation Method | \n",
" None | \n",
"
\n",
" \n",
" 19 | \n",
" PCA | \n",
" False | \n",
"
\n",
" \n",
" 20 | \n",
" PCA Method | \n",
" None | \n",
"
\n",
" \n",
" 21 | \n",
" PCA Components | \n",
" None | \n",
"
\n",
" \n",
" 22 | \n",
" Ignore Low Variance | \n",
" False | \n",
"
\n",
" \n",
" 23 | \n",
" Combine Rare Levels | \n",
" False | \n",
"
\n",
" \n",
" 24 | \n",
" Rare Level Threshold | \n",
" None | \n",
"
\n",
" \n",
" 25 | \n",
" Numeric Binning | \n",
" False | \n",
"
\n",
" \n",
" 26 | \n",
" Remove Outliers | \n",
" False | \n",
"
\n",
" \n",
" 27 | \n",
" Outliers Threshold | \n",
" None | \n",
"
\n",
" \n",
" 28 | \n",
" Remove Multicollinearity | \n",
" False | \n",
"
\n",
" \n",
" 29 | \n",
" Multicollinearity Threshold | \n",
" None | \n",
"
\n",
" \n",
" 30 | \n",
" Clustering | \n",
" False | \n",
"
\n",
" \n",
" 31 | \n",
" Clustering Iteration | \n",
" None | \n",
"
\n",
" \n",
" 32 | \n",
" Polynomial Features | \n",
" False | \n",
"
\n",
" \n",
" 33 | \n",
" Polynomial Degree | \n",
" None | \n",
"
\n",
" \n",
" 34 | \n",
" Trignometry Features | \n",
" False | \n",
"
\n",
" \n",
" 35 | \n",
" Polynomial Threshold | \n",
" None | \n",
"
\n",
" \n",
" 36 | \n",
" Group Features | \n",
" False | \n",
"
\n",
" \n",
" 37 | \n",
" Feature Selection | \n",
" False | \n",
"
\n",
" \n",
" 38 | \n",
" Features Selection Threshold | \n",
" None | \n",
"
\n",
" \n",
" 39 | \n",
" Feature Interaction | \n",
" False | \n",
"
\n",
" \n",
" 40 | \n",
" Feature Ratio | \n",
" False | \n",
"
\n",
" \n",
" 41 | \n",
" Interaction Threshold | \n",
" None | \n",
"
\n",
" \n",
" 42 | \n",
" Fix Imbalance | \n",
" False | \n",
"
\n",
" \n",
" 43 | \n",
" Fix Imbalance Method | \n",
" SMOTE | \n",
"
\n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from pycaret.classification import *\n",
"clf1 = setup(data, target = 'Purchase', session_id=123, log_experiment=False, experiment_name='juice1')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Compare Baseline"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
" | Model | Accuracy | AUC | Recall | Prec. | F1 | Kappa | MCC | TT (Sec) |
\n",
" \n",
" 0 | \n",
" Logistic Regression | \n",
" 0.8263 | \n",
" 0.8959 | \n",
" 0.7262 | \n",
" 0.8139 | \n",
" 0.7644 | \n",
" 0.6280 | \n",
" 0.6338 | \n",
" 0.0429 | \n",
"
\n",
" \n",
" 1 | \n",
" Linear Discriminant Analysis | \n",
" 0.8263 | \n",
" 0.8938 | \n",
" 0.7536 | \n",
" 0.7938 | \n",
" 0.7713 | \n",
" 0.6317 | \n",
" 0.6342 | \n",
" 0.0089 | \n",
"
\n",
" \n",
" 2 | \n",
" Ridge Classifier | \n",
" 0.8236 | \n",
" 0.0000 | \n",
" 0.7499 | \n",
" 0.7920 | \n",
" 0.7680 | \n",
" 0.6262 | \n",
" 0.6292 | \n",
" 0.0045 | \n",
"
\n",
" \n",
" 3 | \n",
" Ada Boost Classifier | \n",
" 0.8075 | \n",
" 0.8637 | \n",
" 0.7053 | \n",
" 0.7837 | \n",
" 0.7398 | \n",
" 0.5881 | \n",
" 0.5924 | \n",
" 0.0780 | \n",
"
\n",
" \n",
" 4 | \n",
" Gradient Boosting Classifier | \n",
" 0.8062 | \n",
" 0.8869 | \n",
" 0.7363 | \n",
" 0.7651 | \n",
" 0.7479 | \n",
" 0.5909 | \n",
" 0.5939 | \n",
" 0.1205 | \n",
"
\n",
" \n",
" 5 | \n",
" CatBoost Classifier | \n",
" 0.8049 | \n",
" 0.8932 | \n",
" 0.7326 | \n",
" 0.7629 | \n",
" 0.7457 | \n",
" 0.5878 | \n",
" 0.5899 | \n",
" 3.5314 | \n",
"
\n",
" \n",
" 6 | \n",
" Extreme Gradient Boosting | \n",
" 0.7914 | \n",
" 0.8716 | \n",
" 0.7294 | \n",
" 0.7367 | \n",
" 0.7309 | \n",
" 0.5609 | \n",
" 0.5633 | \n",
" 0.0675 | \n",
"
\n",
" \n",
" 7 | \n",
" Light Gradient Boosting Machine | \n",
" 0.7861 | \n",
" 0.8806 | \n",
" 0.7053 | \n",
" 0.7393 | \n",
" 0.7195 | \n",
" 0.5471 | \n",
" 0.5497 | \n",
" 0.0898 | \n",
"
\n",
" \n",
" 8 | \n",
" Quadratic Discriminant Analysis | \n",
" 0.7621 | \n",
" 0.8240 | \n",
" 0.6267 | \n",
" 0.7397 | \n",
" 0.6678 | \n",
" 0.4863 | \n",
" 0.5000 | \n",
" 0.0060 | \n",
"
\n",
" \n",
" 9 | \n",
" Random Forest Classifier | \n",
" 0.7608 | \n",
" 0.8397 | \n",
" 0.6674 | \n",
" 0.7124 | \n",
" 0.6848 | \n",
" 0.4928 | \n",
" 0.4974 | \n",
" 0.1142 | \n",
"
\n",
" \n",
" 10 | \n",
" Decision Tree Classifier | \n",
" 0.7594 | \n",
" 0.7519 | \n",
" 0.6911 | \n",
" 0.6970 | \n",
" 0.6907 | \n",
" 0.4943 | \n",
" 0.4975 | \n",
" 0.0051 | \n",
"
\n",
" \n",
" 11 | \n",
" Extra Trees Classifier | \n",
" 0.7433 | \n",
" 0.8205 | \n",
" 0.6708 | \n",
" 0.6758 | \n",
" 0.6698 | \n",
" 0.4605 | \n",
" 0.4638 | \n",
" 0.1503 | \n",
"
\n",
" \n",
" 12 | \n",
" K Neighbors Classifier | \n",
" 0.7231 | \n",
" 0.7683 | \n",
" 0.6062 | \n",
" 0.6600 | \n",
" 0.6287 | \n",
" 0.4094 | \n",
" 0.4129 | \n",
" 0.0054 | \n",
"
\n",
" \n",
" 13 | \n",
" Naive Bayes | \n",
" 0.7140 | \n",
" 0.7952 | \n",
" 0.7466 | \n",
" 0.6100 | \n",
" 0.6708 | \n",
" 0.4227 | \n",
" 0.4301 | \n",
" 0.0032 | \n",
"
\n",
" \n",
" 14 | \n",
" SVM - Linear Kernel | \n",
" 0.5267 | \n",
" 0.0000 | \n",
" 0.4200 | \n",
" 0.2409 | \n",
" 0.2561 | \n",
" 0.0204 | \n",
" 0.0299 | \n",
" 0.0074 | \n",
"
\n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"top5 = compare_models(n_select=5)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
" | Accuracy | AUC | Recall | Prec. | F1 | Kappa | MCC |
\n",
" \n",
" 0 | \n",
" 0.7467 | \n",
" 0.8426 | \n",
" 0.6897 | \n",
" 0.6667 | \n",
" 0.6780 | \n",
" 0.4693 | \n",
" 0.4695 | \n",
"
\n",
" \n",
" 1 | \n",
" 0.8267 | \n",
" 0.9475 | \n",
" 0.7241 | \n",
" 0.8077 | \n",
" 0.7636 | \n",
" 0.6274 | \n",
" 0.6298 | \n",
"
\n",
" \n",
" 2 | \n",
" 0.7600 | \n",
" 0.8201 | \n",
" 0.6552 | \n",
" 0.7037 | \n",
" 0.6786 | \n",
" 0.4875 | \n",
" 0.4883 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.8133 | \n",
" 0.9070 | \n",
" 0.8276 | \n",
" 0.7273 | \n",
" 0.7742 | \n",
" 0.6162 | \n",
" 0.6200 | \n",
"
\n",
" \n",
" 4 | \n",
" 0.7867 | \n",
" 0.8973 | \n",
" 0.8621 | \n",
" 0.6757 | \n",
" 0.7576 | \n",
" 0.5720 | \n",
" 0.5856 | \n",
"
\n",
" \n",
" 5 | \n",
" 0.8133 | \n",
" 0.8906 | \n",
" 0.7241 | \n",
" 0.7778 | \n",
" 0.7500 | \n",
" 0.6014 | \n",
" 0.6023 | \n",
"
\n",
" \n",
" 6 | \n",
" 0.8133 | \n",
" 0.8833 | \n",
" 0.8333 | \n",
" 0.7353 | \n",
" 0.7812 | \n",
" 0.6196 | \n",
" 0.6233 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.8400 | \n",
" 0.9222 | \n",
" 0.8000 | \n",
" 0.8000 | \n",
" 0.8000 | \n",
" 0.6667 | \n",
" 0.6667 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.7973 | \n",
" 0.8759 | \n",
" 0.6897 | \n",
" 0.7692 | \n",
" 0.7273 | \n",
" 0.5667 | \n",
" 0.5689 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.8514 | \n",
" 0.9111 | \n",
" 0.7586 | \n",
" 0.8462 | \n",
" 0.8000 | \n",
" 0.6823 | \n",
" 0.6849 | \n",
"
\n",
" \n",
" Mean | \n",
" 0.8049 | \n",
" 0.8898 | \n",
" 0.7564 | \n",
" 0.7509 | \n",
" 0.7510 | \n",
" 0.5909 | \n",
" 0.5939 | \n",
"
\n",
" \n",
" SD | \n",
" 0.0314 | \n",
" 0.0354 | \n",
" 0.0673 | \n",
" 0.0562 | \n",
" 0.0420 | \n",
" 0.0659 | \n",
" 0.0662 | \n",
"
\n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"top5_tuned = [tune_model(i) for i in top5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Create Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lr = create_model('lr', fold = 5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dt = create_model('dt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rf = create_model('rf', fold = 5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"models()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"models(type='ensemble').index.tolist()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ensembled_models = compare_models(whitelist = models(type='ensemble').index.tolist(), fold = 3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. Tune Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tuned_lr = tune_model(lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tuned_rf = tune_model(rf)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6. Ensemble Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bagged_dt = ensemble_model(dt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"boosted_dt = ensemble_model(dt, method = 'Boosting')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 7. Blend Models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"blender = blend_models(estimator_list = [boosted_dt, bagged_dt], method = 'soft')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 8. Stack Models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"stacker = stack_models(estimator_list = [boosted_dt,bagged_dt,tuned_rf], meta_model=rf)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 9. Analyze Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_model(rf)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_model(rf, plot = 'confusion_matrix')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_model(rf, plot = 'boundary')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"evaluate_model(rf)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 10. Interpret Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"catboost = create_model('catboost', cross_validation=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interpret_model(catboost)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interpret_model(catboost, plot = 'correlation')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"interpret_model(catboost, plot = 'reason', observation = 12)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 11. AutoML()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"best = automl(optimize = 'Recall')\n",
"best"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 12. Predict Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_holdouts = predict_model(lr)\n",
"pred_holdouts.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"new_data = data.copy()\n",
"new_data.drop(['Purchase'], axis=1, inplace=True)\n",
"predict_new = predict_model(lr, data=new_data)\n",
"predict_new.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 13. Save / Load Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_model(lr, model_name='best-model')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loaded_bestmodel = load_model('best-model')\n",
"print(loaded_bestmodel)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import set_config\n",
"set_config(display='diagram')\n",
"loaded_bestmodel[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import set_config\n",
"set_config(display='text')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 14. Deploy Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"deploy_model(lr, model_name = 'best-aws', authentication = {'bucket' : 'pycaret-test'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 15. Get Config / Set Config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_train = get_config('X_train')\n",
"X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"get_config('seed')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pycaret.classification import set_config\n",
"set_config('seed', 999)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"get_config('seed')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 16. Get System Logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"get_system_logs()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 17. MLFlow UI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!mlflow ui"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# to generate csv file with experiment logs\n",
"get_logs()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# End\n",
"Thank you. For more information / tutorials on PyCaret, please visit https://www.pycaret.org"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}