{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyCaret 2 Classification Example\n", "This notebook is created using PyCaret 2.0. Last updated : 28-07-2020" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pycaret-nightly-0.39\n" ] } ], "source": [ "# check version\n", "from pycaret.utils import version\n", "version()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Data Repository" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
DatasetData TypesDefault TaskTarget Variable# Instances# AttributesMissing Values
0anomalyMultivariateAnomaly DetectionNone100010N
1franceMultivariateAssociation Rule MiningInvoiceNo, Description85578N
2germanyMultivariateAssociation Rule MiningInvoiceNo, Description94958N
3bankMultivariateClassification (Binary)deposit4521117N
4bloodMultivariateClassification (Binary)Class7485N
5cancerMultivariateClassification (Binary)Class68310N
6creditMultivariateClassification (Binary)default2400024N
7diabetesMultivariateClassification (Binary)Class variable7689N
8electrical_gridMultivariateClassification (Binary)stabf1000014N
9employeeMultivariateClassification (Binary)left1499910N
10heartMultivariateClassification (Binary)DEATH20016N
11heart_diseaseMultivariateClassification (Binary)Disease27014N
12hepatitisMultivariateClassification (Binary)Class15432Y
13incomeMultivariateClassification (Binary)income >50K3256114Y
14juiceMultivariateClassification (Binary)Purchase107015N
15nbaMultivariateClassification (Binary)TARGET_5Yrs134021N
16wineMultivariateClassification (Binary)type649813N
17telescopeMultivariateClassification (Binary)Class1902011N
18glassMultivariateClassification (Multiclass)Type21410N
19irisMultivariateClassification (Multiclass)species1505N
20pokerMultivariateClassification (Multiclass)CLASS10000011N
21questionsMultivariateClassification (Multiclass)Next_Question4994N
22satelliteMultivariateClassification (Multiclass)Class643537N
23asia_gdpMultivariateClusteringNone4011N
24electionsMultivariateClusteringNone319554Y
25facebookMultivariateClusteringNone705012N
26iplMultivariateClusteringNone15325N
27jewelleryMultivariateClusteringNone5054N
28miceMultivariateClusteringNone108082Y
29migrationMultivariateClusteringNone23312N
30perfumeMultivariateClusteringNone2029N
31pokemonMultivariateClusteringNone80013Y
32populationMultivariateClusteringNone25556Y
33public_healthMultivariateClusteringNone22421N
34seedsMultivariateClusteringNone2107N
35wholesaleMultivariateClusteringNone4408N
36tweetsTextNLPtweet85942N
37amazonTextNLP / ClassificationreviewText200002N
38kivaTextNLP / Classificationen68187N
39spxTextNLP / Regressiontext8744N
40wikipediaTextNLP / ClassificationText5003N
41automobileMultivariateRegressionprice20226Y
42bikeMultivariateRegressioncnt1737915N
43bostonMultivariateRegressionmedv50614N
44concreteMultivariateRegressionstrength10309N
45diamondMultivariateRegressionPrice60008N
46energyMultivariateRegressionHeating Load / Cooling Load76810N
47forestMultivariateRegressionarea51713N
48goldMultivariateRegressionGold_T+222558121N
49houseMultivariateRegressionSalePrice146181Y
50insuranceMultivariateRegressioncharges13387N
51parkinsonsMultivariateRegressionPPE587522N
52trafficMultivariateRegressiontraffic_volume482048N
\n", "
" ], "text/plain": [ " Dataset Data Types Default Task \\\n", "0 anomaly Multivariate Anomaly Detection \n", "1 france Multivariate Association Rule Mining \n", "2 germany Multivariate Association Rule Mining \n", "3 bank Multivariate Classification (Binary) \n", "4 blood Multivariate Classification (Binary) \n", "5 cancer Multivariate Classification (Binary) \n", "6 credit Multivariate Classification (Binary) \n", "7 diabetes Multivariate Classification (Binary) \n", "8 electrical_grid Multivariate Classification (Binary) \n", "9 employee Multivariate Classification (Binary) \n", "10 heart Multivariate Classification (Binary) \n", "11 heart_disease Multivariate Classification (Binary) \n", "12 hepatitis Multivariate Classification (Binary) \n", "13 income Multivariate Classification (Binary) \n", "14 juice Multivariate Classification (Binary) \n", "15 nba Multivariate Classification (Binary) \n", "16 wine Multivariate Classification (Binary) \n", "17 telescope Multivariate Classification (Binary) \n", "18 glass Multivariate Classification (Multiclass) \n", "19 iris Multivariate Classification (Multiclass) \n", "20 poker Multivariate Classification (Multiclass) \n", "21 questions Multivariate Classification (Multiclass) \n", "22 satellite Multivariate Classification (Multiclass) \n", "23 asia_gdp Multivariate Clustering \n", "24 elections Multivariate Clustering \n", "25 facebook Multivariate Clustering \n", "26 ipl Multivariate Clustering \n", "27 jewellery Multivariate Clustering \n", "28 mice Multivariate Clustering \n", "29 migration Multivariate Clustering \n", "30 perfume Multivariate Clustering \n", "31 pokemon Multivariate Clustering \n", "32 population Multivariate Clustering \n", "33 public_health Multivariate Clustering \n", "34 seeds Multivariate Clustering \n", "35 wholesale Multivariate Clustering \n", "36 tweets Text NLP \n", "37 amazon Text NLP / Classification \n", "38 kiva Text NLP / Classification \n", "39 spx Text NLP / Regression \n", "40 
wikipedia Text NLP / Classification \n", "41 automobile Multivariate Regression \n", "42 bike Multivariate Regression \n", "43 boston Multivariate Regression \n", "44 concrete Multivariate Regression \n", "45 diamond Multivariate Regression \n", "46 energy Multivariate Regression \n", "47 forest Multivariate Regression \n", "48 gold Multivariate Regression \n", "49 house Multivariate Regression \n", "50 insurance Multivariate Regression \n", "51 parkinsons Multivariate Regression \n", "52 traffic Multivariate Regression \n", "\n", " Target Variable # Instances # Attributes Missing Values \n", "0 None 1000 10 N \n", "1 InvoiceNo, Description 8557 8 N \n", "2 InvoiceNo, Description 9495 8 N \n", "3 deposit 45211 17 N \n", "4 Class 748 5 N \n", "5 Class 683 10 N \n", "6 default 24000 24 N \n", "7 Class variable 768 9 N \n", "8 stabf 10000 14 N \n", "9 left 14999 10 N \n", "10 DEATH 200 16 N \n", "11 Disease 270 14 N \n", "12 Class 154 32 Y \n", "13 income >50K 32561 14 Y \n", "14 Purchase 1070 15 N \n", "15 TARGET_5Yrs 1340 21 N \n", "16 type 6498 13 N \n", "17 Class 19020 11 N \n", "18 Type 214 10 N \n", "19 species 150 5 N \n", "20 CLASS 100000 11 N \n", "21 Next_Question 499 4 N \n", "22 Class 6435 37 N \n", "23 None 40 11 N \n", "24 None 3195 54 Y \n", "25 None 7050 12 N \n", "26 None 153 25 N \n", "27 None 505 4 N \n", "28 None 1080 82 Y \n", "29 None 233 12 N \n", "30 None 20 29 N \n", "31 None 800 13 Y \n", "32 None 255 56 Y \n", "33 None 224 21 N \n", "34 None 210 7 N \n", "35 None 440 8 N \n", "36 tweet 8594 2 N \n", "37 reviewText 20000 2 N \n", "38 en 6818 7 N \n", "39 text 874 4 N \n", "40 Text 500 3 N \n", "41 price 202 26 Y \n", "42 cnt 17379 15 N \n", "43 medv 506 14 N \n", "44 strength 1030 9 N \n", "45 Price 6000 8 N \n", "46 Heating Load / Cooling Load 768 10 N \n", "47 area 517 13 N \n", "48 Gold_T+22 2558 121 N \n", "49 SalePrice 1461 81 Y \n", "50 charges 1338 7 N \n", "51 PPE 5875 22 N \n", "52 traffic_volume 48204 8 N " ] }, "metadata": {}, 
"output_type": "display_data" } ], "source": [ "from pycaret.datasets import get_data\n", "index = get_data('index')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IdPurchaseWeekofPurchaseStoreIDPriceCHPriceMMDiscCHDiscMMSpecialCHSpecialMMLoyalCHSalePriceMMSalePriceCHPriceDiffStore7PctDiscMMPctDiscCHListPriceDiffSTORE
01CH23711.751.990.000.0000.5000001.991.750.24No0.0000000.0000000.241
12CH23911.751.990.000.3010.6000001.691.75-0.06No0.1507540.0000000.241
23CH24511.862.090.170.0000.6800002.091.690.40No0.0000000.0913980.231
34MM22711.691.690.000.0000.4000001.691.690.00No0.0000000.0000000.001
45CH22871.691.690.000.0000.9565351.691.690.00Yes0.0000000.0000000.000
\n", "
" ], "text/plain": [ " Id Purchase WeekofPurchase StoreID PriceCH PriceMM DiscCH DiscMM \\\n", "0 1 CH 237 1 1.75 1.99 0.00 0.0 \n", "1 2 CH 239 1 1.75 1.99 0.00 0.3 \n", "2 3 CH 245 1 1.86 2.09 0.17 0.0 \n", "3 4 MM 227 1 1.69 1.69 0.00 0.0 \n", "4 5 CH 228 7 1.69 1.69 0.00 0.0 \n", "\n", " SpecialCH SpecialMM LoyalCH SalePriceMM SalePriceCH PriceDiff Store7 \\\n", "0 0 0 0.500000 1.99 1.75 0.24 No \n", "1 0 1 0.600000 1.69 1.75 -0.06 No \n", "2 0 0 0.680000 2.09 1.69 0.40 No \n", "3 0 0 0.400000 1.69 1.69 0.00 No \n", "4 0 0 0.956535 1.69 1.69 0.00 Yes \n", "\n", " PctDiscMM PctDiscCH ListPriceDiff STORE \n", "0 0.000000 0.000000 0.24 1 \n", "1 0.150754 0.000000 0.24 1 \n", "2 0.000000 0.091398 0.23 1 \n", "3 0.000000 0.000000 0.00 1 \n", "4 0.000000 0.000000 0.00 0 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = get_data('juice')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. Initialize Setup" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setup Succesfully Completed!\n" ] }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Description Value
0session_id123
1Target TypeBinary
2Label EncodedCH: 0, MM: 1
3Original Data(1070, 19)
4Missing Values False
5Numeric Features 13
6Categorical Features 5
7Ordinal Features False
8High Cardinality Features False
9High Cardinality Method None
10Sampled Data(1070, 19)
11Transformed Train Set(748, 28)
12Transformed Test Set(322, 28)
13Numeric Imputer mean
14Categorical Imputer constant
15Normalize False
16Normalize Method None
17Transformation False
18Transformation Method None
19PCA False
20PCA Method None
21PCA Components None
22Ignore Low Variance False
23Combine Rare Levels False
24Rare Level Threshold None
25Numeric Binning False
26Remove Outliers False
27Outliers Threshold None
28Remove Multicollinearity False
29Multicollinearity Threshold None
30Clustering False
31Clustering Iteration None
32Polynomial Features False
33Polynomial Degree None
34Trignometry Features False
35Polynomial Threshold None
36Group Features False
37Feature Selection False
38Features Selection Threshold None
39Feature Interaction False
40Feature Ratio False
41Interaction Threshold None
42Fix ImbalanceFalse
43Fix Imbalance MethodSMOTE
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from pycaret.classification import *\n", "clf1 = setup(data, target = 'Purchase', session_id=123, log_experiment=False, experiment_name='bank1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. Compare Baseline" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Model Accuracy AUC Recall Prec. F1 Kappa MCC TT (Sec)
0Logistic Regression0.82630.89590.72620.81390.76440.62800.63380.0429
1Linear Discriminant Analysis0.82630.89380.75360.79380.77130.63170.63420.0089
2Ridge Classifier0.82360.00000.74990.79200.76800.62620.62920.0045
3Ada Boost Classifier0.80750.86370.70530.78370.73980.58810.59240.0780
4Gradient Boosting Classifier0.80620.88690.73630.76510.74790.59090.59390.1205
5CatBoost Classifier0.80490.89320.73260.76290.74570.58780.58993.5314
6Extreme Gradient Boosting0.79140.87160.72940.73670.73090.56090.56330.0675
7Light Gradient Boosting Machine0.78610.88060.70530.73930.71950.54710.54970.0898
8Quadratic Discriminant Analysis0.76210.82400.62670.73970.66780.48630.50000.0060
9Random Forest Classifier0.76080.83970.66740.71240.68480.49280.49740.1142
10Decision Tree Classifier0.75940.75190.69110.69700.69070.49430.49750.0051
11Extra Trees Classifier0.74330.82050.67080.67580.66980.46050.46380.1503
12K Neighbors Classifier0.72310.76830.60620.66000.62870.40940.41290.0054
13Naive Bayes0.71400.79520.74660.61000.67080.42270.43010.0032
14SVM - Linear Kernel0.52670.00000.42000.24090.25610.02040.02990.0074
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "top5 = compare_models(n_select=5)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Accuracy AUC Recall Prec. F1 Kappa MCC
00.74670.84260.68970.66670.67800.46930.4695
10.82670.94750.72410.80770.76360.62740.6298
20.76000.82010.65520.70370.67860.48750.4883
30.81330.90700.82760.72730.77420.61620.6200
40.78670.89730.86210.67570.75760.57200.5856
50.81330.89060.72410.77780.75000.60140.6023
60.81330.88330.83330.73530.78120.61960.6233
70.84000.92220.80000.80000.80000.66670.6667
80.79730.87590.68970.76920.72730.56670.5689
90.85140.91110.75860.84620.80000.68230.6849
Mean0.80490.88980.75640.75090.75100.59090.5939
SD0.03140.03540.06730.05620.04200.06590.0662
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "top5_tuned = [tune_model(i) for i in top5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 4. Create Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = create_model('lr', fold = 5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dt = create_model('dt')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rf = create_model('rf', fold = 5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "models()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "models(type='ensemble').index.tolist()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ensembled_models = compare_models(whitelist = models(type='ensemble').index.tolist(), fold = 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 5. Tune Hyperparameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tuned_lr = tune_model(lr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tuned_rf = tune_model(rf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 6. Ensemble Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bagged_dt = ensemble_model(dt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "boosted_dt = ensemble_model(dt, method = 'Boosting')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 7. 
Blend Models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# method='soft' asks for soft voting over the two decision-tree ensembles\n", "blender = blend_models(estimator_list=[boosted_dt, bagged_dt], method='soft')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 8. Stack Models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# three base estimators, with rf passed as the meta model\n", "stacker = stack_models(estimator_list=[boosted_dt, bagged_dt, tuned_rf], meta_model=rf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 9. Analyze Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_model(rf)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_model(rf, plot='confusion_matrix')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_model(rf, plot='boundary')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "evaluate_model(rf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 10. Interpret Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# cross_validation=False: create the model without the cross-validation step\n", "catboost = create_model('catboost', cross_validation=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interpret_model(catboost)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interpret_model(catboost, plot='correlation')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 'reason' plot for a single observation (number 12)\n", "interpret_model(catboost, plot='reason', observation=12)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 11. 
AutoML()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best = automl(optimize = 'Recall')\n", "best" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 12. Predict Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pred_holdouts = predict_model(lr)\n", "pred_holdouts.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# simulate unseen data: the juice frame without its target column ('Purchase', as passed to setup)\n", "new_data = data.drop(columns=['Purchase'])\n", "predict_new = predict_model(lr, data=new_data)\n", "predict_new.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 13. Save / Load Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_model(lr, model_name='best-model')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loaded_bestmodel = load_model('best-model')\n", "print(loaded_bestmodel)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# call set_config through the sklearn module so the name `set_config` from pycaret.classification is not shadowed\n", "import sklearn\n", "sklearn.set_config(display='diagram')\n", "loaded_bestmodel[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# switch sklearn's estimator display back to plain text\n", "import sklearn\n", "sklearn.set_config(display='text')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 14. Deploy Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "deploy_model(lr, model_name = 'best-aws', authentication = {'bucket' : 'pycaret-test'})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 15. 
Get Config / Set Config" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train = get_config('X_train')\n", "X_train.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "get_config('seed')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# import explicitly so that `set_config` is PyCaret's setter, not a same-named function from another library (e.g. sklearn's set_config)\n", "from pycaret.classification import set_config\n", "set_config('seed', 999)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "get_config('seed')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 16. Get System Logs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "get_system_logs()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 17. MLFlow UI" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# experiment logs as a table; run this before the blocking `mlflow ui` cell below\n", "# NOTE: setup() above was called with log_experiment=False - switch it to True, otherwise no runs are logged for get_logs() / the MLflow UI to show\n", "# pass save=True to also write the logs to a csv file (TODO confirm against the installed PyCaret version)\n", "get_logs()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# starts the MLflow tracking UI in the foreground: the cell keeps running until the kernel is interrupted\n", "!mlflow ui" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# End\n", "Thank you. For more information / tutorials on PyCaret, please visit https://www.pycaret.org" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }