{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "import urllib.request\n", "import pandas as pd\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import ktrain\n", "from ktrain import tabular" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Income Prediction from Census Dataset\n", "\n", "In this notebook, we use Census data to predict which individuals earn more than $50K. This is the same dataset used in the [AutoGluon tabular prediction example](https://autogluon.mxnet.io/tutorials/tabular_prediction/tabular-quickstart.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Collect Training and Test Sets\n", "\n", "The original dataset is available from the [UCI Machine Learning Repository](http://archive.ics.uci.edu/ml/datasets/Adult), but we will download it from the AutoGluon website." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('/tmp/train.csv', )" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# training set\n", "urllib.request.urlretrieve('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv', \n", " '/tmp/train.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Load and Preprocess Data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "processing train: 35179 rows x 15 columns\n", "\n", "The following integer column(s) are being treated as categorical variables:\n", "['education-num']\n", "To treat any of these column(s) as numerical, cast the column to float in DataFrame or CSV\n", " and re-run tabular_from* function.\n", "\n", "processing test: 3894 rows x 15 columns\n" ] } ], "source": [ "trn, val, preproc = tabular.tabular_from_csv('/tmp/train.csv', label_columns='class', random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Create Model and Wrap in `Learner`" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? False\n", "done.\n" ] } ], "source": [ "model = tabular.tabular_classifier('mlp', trn)\n", "learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Estimate LR\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... 
this may take a few moments...\n", "Train for 274 steps\n", "Epoch 1/1024\n", "274/274 [==============================] - 8s 28ms/step - loss: 0.7151 - accuracy: 0.3431\n", "Epoch 2/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.6359 - accuracy: 0.6889\n", "Epoch 3/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.4145 - accuracy: 0.8113\n", "Epoch 4/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.3268 - accuracy: 0.8486\n", "Epoch 5/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.6269 - accuracy: 0.7968\n", "Epoch 6/1024\n", "274/274 [==============================] - 7s 25ms/step - loss: 0.5543 - accuracy: 0.7589\n", "Epoch 7/1024\n", " 50/274 [====>.........................] - ETA: 7s - loss: 47426.1304 - accuracy: 0.7517\n", "\n", "done.\n", "Visually inspect loss plot and select learning rate associated with falling loss\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learner.lr_find(show_plot=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Train" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "early_stopping automatically enabled at patience=5\n", "reduce_on_plateau automatically enabled at patience=2\n", "\n", "\n", "begin training using triangular learning rate policy with max lr of 0.001...\n", "Train for 275 steps, validate for 122 steps\n", "Epoch 1/1024\n", "275/275 [==============================] - 10s 38ms/step - loss: 0.3674 - accuracy: 0.8285 - val_loss: 0.2957 - val_accuracy: 0.8624\n", "Epoch 2/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3157 - accuracy: 0.8549 - val_loss: 0.2962 - val_accuracy: 0.8652\n", "Epoch 3/1024\n", "269/275 [============================>.] - ETA: 0s - loss: 0.3128 - accuracy: 0.8558\n", "Epoch 00003: Reducing Max LR on Plateau: new max lr will be 0.0005 (if not early_stopping).\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.3127 - accuracy: 0.8559 - val_loss: 0.2994 - val_accuracy: 0.8621\n", "Epoch 4/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3086 - accuracy: 0.8574 - val_loss: 0.2951 - val_accuracy: 0.8652\n", "Epoch 5/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3078 - accuracy: 0.8586 - val_loss: 0.2953 - val_accuracy: 0.8654\n", "Epoch 6/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3057 - accuracy: 0.8595 - val_loss: 0.2933 - val_accuracy: 0.8659\n", "Epoch 7/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.3045 - accuracy: 0.8595 - val_loss: 0.2928 - val_accuracy: 0.8634\n", "Epoch 8/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.3033 - accuracy: 0.8605 - 
val_loss: 0.2927 - val_accuracy: 0.8649\n", "Epoch 9/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.3037 - accuracy: 0.8605 - val_loss: 0.2931 - val_accuracy: 0.8624\n", "Epoch 10/1024\n", "269/275 [============================>.] - ETA: 0s - loss: 0.3016 - accuracy: 0.8612\n", "Epoch 00010: Reducing Max LR on Plateau: new max lr will be 0.00025 (if not early_stopping).\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.3015 - accuracy: 0.8611 - val_loss: 0.2931 - val_accuracy: 0.8659\n", "Epoch 11/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2993 - accuracy: 0.8624 - val_loss: 0.2924 - val_accuracy: 0.8641\n", "Epoch 12/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2986 - accuracy: 0.8625 - val_loss: 0.2925 - val_accuracy: 0.8636\n", "Epoch 13/1024\n", "269/275 [============================>.] - ETA: 0s - loss: 0.2982 - accuracy: 0.8636\n", "Epoch 00013: Reducing Max LR on Plateau: new max lr will be 0.000125 (if not early_stopping).\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2982 - accuracy: 0.8636 - val_loss: 0.2926 - val_accuracy: 0.8634\n", "Epoch 14/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2958 - accuracy: 0.8641 - val_loss: 0.2923 - val_accuracy: 0.8636\n", "Epoch 15/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2950 - accuracy: 0.8642 - val_loss: 0.2920 - val_accuracy: 0.8654\n", "Epoch 16/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2944 - accuracy: 0.8645 - val_loss: 0.2938 - val_accuracy: 0.8608\n", "Epoch 17/1024\n", "272/275 [============================>.] 
- ETA: 0s - loss: 0.2940 - accuracy: 0.8640\n", "Epoch 00017: Reducing Max LR on Plateau: new max lr will be 6.25e-05 (if not early_stopping).\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2943 - accuracy: 0.8638 - val_loss: 0.2924 - val_accuracy: 0.8641\n", "Epoch 18/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2929 - accuracy: 0.8651 - val_loss: 0.2919 - val_accuracy: 0.8647\n", "Epoch 19/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2929 - accuracy: 0.8644 - val_loss: 0.2926 - val_accuracy: 0.8649\n", "Epoch 20/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2924 - accuracy: 0.8651 - val_loss: 0.2917 - val_accuracy: 0.8647\n", "Epoch 21/1024\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2921 - accuracy: 0.8658 - val_loss: 0.2918 - val_accuracy: 0.8652\n", "Epoch 22/1024\n", "274/275 [============================>.] - ETA: 0s - loss: 0.2908 - accuracy: 0.8667\n", "Epoch 00022: Reducing Max LR on Plateau: new max lr will be 3.125e-05 (if not early_stopping).\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2912 - accuracy: 0.8665 - val_loss: 0.2919 - val_accuracy: 0.8652\n", "Epoch 23/1024\n", "275/275 [==============================] - 10s 35ms/step - loss: 0.2909 - accuracy: 0.8663 - val_loss: 0.2920 - val_accuracy: 0.8652\n", "Epoch 24/1024\n", "272/275 [============================>.] - ETA: 0s - loss: 0.2905 - accuracy: 0.8667\n", "Epoch 00024: Reducing Max LR on Plateau: new max lr will be 1.5625e-05 (if not early_stopping).\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2904 - accuracy: 0.8670 - val_loss: 0.2921 - val_accuracy: 0.8649\n", "Epoch 25/1024\n", "269/275 [============================>.] 
- ETA: 0s - loss: 0.2899 - accuracy: 0.8668Restoring model weights from the end of the best epoch.\n", "275/275 [==============================] - 9s 34ms/step - loss: 0.2900 - accuracy: 0.8666 - val_loss: 0.2921 - val_accuracy: 0.8649\n", "Epoch 00025: early stopping\n", "Weights from best epoch have been loaded into model.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.autofit(1e-3)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " <=50K 0.89 0.94 0.91 3013\n", " >50K 0.74 0.62 0.67 881\n", "\n", " accuracy 0.86 3894\n", " macro avg 0.82 0.78 0.79 3894\n", "weighted avg 0.86 0.86 0.86 3894\n", "\n" ] }, { "data": { "text/plain": [ "array([[2824, 189],\n", " [ 338, 543]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.validate(class_names=preproc.get_classes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate Model on Unseen Test Data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# download test dataset\n", "urllib.request.urlretrieve('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv', \n", " '/tmp/test.csv')\n", "\n", "\n", "test_df = pd.read_csv('/tmp/test.csv')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "0 31 Private 169085 11th 7 \n", "1 17 Self-emp-not-inc 226203 12th 8 \n", "2 47 Private 54260 Assoc-voc 11 \n", "3 21 Private 176262 Some-college 10 \n", "4 17 Private 241185 12th 8 \n", "\n", " marital-status occupation relationship race sex \\\n", "0 Married-civ-spouse Sales Wife White Female \n", "1 Never-married Sales Own-child White Male \n", "2 Married-civ-spouse Exec-managerial Husband White Male \n", "3 Never-married Exec-managerial Own-child White Female \n", "4 Never-married Prof-specialty Own-child White Male \n", "\n", " capital-gain capital-loss hours-per-week native-country class \n", "0 0 0 20 United-States <=50K \n", "1 0 0 45 United-States <=50K \n", "2 0 1887 60 United-States >50K \n", "3 0 0 30 United-States <=50K \n", "4 0 0 20 United-States <=50K " ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `learner.evaluate` method is just an alias to `learner.validate`. By default, it was validate/evaluate\n", "`learner.val_data`, but both can accept a test set as an argument in the form of a `TabularDataset`.\n", "\n", "We use `learner.evaluate` here to compute test set metrics." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "processing test: 9769 rows x 15 columns\n", " precision recall f1-score support\n", "\n", " <=50K 0.88 0.94 0.91 7451\n", " >50K 0.76 0.61 0.67 2318\n", "\n", " accuracy 0.86 9769\n", " macro avg 0.82 0.77 0.79 9769\n", "weighted avg 0.85 0.86 0.85 9769\n", "\n" ] }, { "data": { "text/plain": [ "array([[6996, 455],\n", " [ 914, 1404]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.evaluate(preproc.preprocess_test(test_df), class_names=preproc.get_classes())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Generating Test Results\n", "\n", "Let's generate a DataFrame showing the test set predictions for each instance:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['<=50K', '>50K']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preproc.get_classes()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "predictor = ktrain.get_predictor(learner.model, preproc)\n", "preds = predictor.predict(test_df)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "df = test_df.copy()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "df['predicted_class'] = preds" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age workclass fnlwgt education education-num \\\n", "0 31 Private 169085 11th 7 \n", "1 17 Self-emp-not-inc 226203 12th 8 \n", "2 47 Private 54260 Assoc-voc 11 \n", "3 21 Private 176262 Some-college 10 \n", "4 17 Private 241185 12th 8 \n", "\n", " marital-status occupation relationship race sex \\\n", "0 Married-civ-spouse Sales Wife White Female \n", "1 Never-married Sales Own-child White Male \n", "2 Married-civ-spouse Exec-managerial Husband White Male \n", "3 Never-married Exec-managerial Own-child White Female \n", "4 Never-married Prof-specialty Own-child White Male \n", "\n", " capital-gain capital-loss hours-per-week native-country class \\\n", "0 0 0 20 United-States <=50K \n", "1 0 0 45 United-States <=50K \n", "2 0 1887 60 United-States >50K \n", "3 0 0 30 United-States <=50K \n", "4 0 0 20 United-States <=50K \n", "\n", " predicted_class \n", "0 <=50K \n", "1 <=50K \n", "2 >50K \n", "3 <=50K \n", "4 <=50K " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }