{ "cells": [ { "cell_type": "code", "execution_count": 14, "id": "67febe95-dd8e-4564-8a5e-e641bb16906e", "metadata": {}, "outputs": [], "source": [ "import math\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "import tensorflow_addons as tfa\n", "import matplotlib.pyplot as plt\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score\n", "from tensorflow.keras.callbacks import EarlyStopping\n", "\n", "from tabtransformertf.models.tabtransformer import TabTransformer\n", "from tabtransformertf.utils.preprocessing import df_to_dataset, build_categorical_prep" ] }, { "cell_type": "markdown", "id": "3f89b597-3f3a-4079-bc7c-ad3fbf16dba8", "metadata": {}, "source": [ "## Download Data" ] }, { "cell_type": "code", "execution_count": 2, "id": "e0c2547d-a0b7-478c-ae45-92e99dc6f113", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train dataset shape: (32561, 15)\n", "Test dataset shape: (16282, 15)\n" ] } ], "source": [
 "# Column names for the Adult CSV files, which are read below without a header row.\n",
 "CSV_HEADER = [\n",
 "    \"age\", \"workclass\", \"fnlwgt\", \"education\", \"education_num\",\n",
 "    \"marital_status\", \"occupation\", \"relationship\", \"race\", \"gender\",\n",
 "    \"capital_gain\", \"capital_loss\", \"hours_per_week\", \"native_country\",\n",
 "    \"income_bracket\",\n",
 "]\n",
 "\n",
 "train_data_url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n",
 "test_data_url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test\"\n",
 "\n",
 "# header=None keeps the first line of each file as data; CSV_HEADER supplies the names.\n",
 "train_data = pd.read_csv(train_data_url, names=CSV_HEADER, header=None)\n",
 "test_data = pd.read_csv(test_data_url, names=CSV_HEADER, header=None)\n",
 "\n",
 "print(f\"Train dataset shape: {train_data.shape}\")\n",
 "print(f\"Test dataset shape: {test_data.shape}\")"
 ] }, { "cell_type": "code", "execution_count": 3, "id": "75521624-3f40-4aad-bbd4-cea5c5cb5782", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducation_nummarital_statusoccupationrelationshipracegendercapital_gaincapital_losshours_per_weeknative_countryincome_bracket
039State-gov77516Bachelors13Never-marriedAdm-clericalNot-in-familyWhiteMale2174040United-States<=50K
150Self-emp-not-inc83311Bachelors13Married-civ-spouseExec-managerialHusbandWhiteMale0013United-States<=50K
238Private215646HS-grad9DivorcedHandlers-cleanersNot-in-familyWhiteMale0040United-States<=50K
353Private23472111th7Married-civ-spouseHandlers-cleanersHusbandBlackMale0040United-States<=50K
428Private338409Bachelors13Married-civ-spouseProf-specialtyWifeBlackFemale0040Cuba<=50K
\n", "
" ], "text/plain": [ " age workclass fnlwgt education education_num \\\n", "0 39 State-gov 77516 Bachelors 13 \n", "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", "2 38 Private 215646 HS-grad 9 \n", "3 53 Private 234721 11th 7 \n", "4 28 Private 338409 Bachelors 13 \n", "\n", " marital_status occupation relationship race gender \\\n", "0 Never-married Adm-clerical Not-in-family White Male \n", "1 Married-civ-spouse Exec-managerial Husband White Male \n", "2 Divorced Handlers-cleaners Not-in-family White Male \n", "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", "4 Married-civ-spouse Prof-specialty Wife Black Female \n", "\n", " capital_gain capital_loss hours_per_week native_country income_bracket \n", "0 2174 0 40 United-States <=50K \n", "1 0 0 13 United-States <=50K \n", "2 0 0 40 United-States <=50K \n", "3 0 0 40 United-States <=50K \n", "4 0 0 40 Cuba <=50K " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data.head()" ] }, { "cell_type": "markdown", "id": "89bff333-3750-46c9-b687-3f06a8f43845", "metadata": {}, "source": [ "## Preprocess" ] }, { "cell_type": "code", "execution_count": 4, "id": "420201d9-eb7d-4873-bae8-2296fef046c4", "metadata": {}, "outputs": [], "source": [
 "# Column information\n",
 "LABEL = 'income_bracket'\n",
 "\n",
 "# Split columns by dtype and drop the label by name. Dropping by name (instead of\n",
 "# slicing off the last non-numeric column) keeps both lists unchanged when this\n",
 "# cell is re-run after the label has been encoded as an integer.\n",
 "NUMERIC_FEATURES = train_data.select_dtypes(include=np.number).columns.drop(LABEL, errors='ignore')\n",
 "CATEGORICAL_FEATURES = train_data.select_dtypes(exclude=np.number).columns.drop(LABEL, errors='ignore')\n",
 "\n",
 "FEATURES = list(NUMERIC_FEATURES) + list(CATEGORICAL_FEATURES)"
 ] }, { "cell_type": "code", "execution_count": 5, "id": "aaa2d685-44be-410a-89cf-acdd4a069adb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.2408095574460244, 0.23621176759611842)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# encoding as binary target\n",
 "def encode_income(labels, positive):\n",
 "    \"\"\"Return integer labels: 1 where `labels` equals `positive`, 0 elsewhere.\n",
 "\n",
 "    Labels that are already numeric are returned unchanged, so re-running this\n",
 "    cell cannot silently reset every target value to 0.\n",
 "    \"\"\"\n",
 "    if pd.api.types.is_numeric_dtype(labels):\n",
 "        return labels\n",
 "    return (labels == positive).astype(int)\n",
 "\n",
 "# The positive class is written with a trailing period in the test file.\n",
 "train_data[LABEL] = encode_income(train_data[LABEL], ' >50K')\n",
 "test_data[LABEL] = encode_income(test_data[LABEL], ' >50K.')\n",
 "train_data[LABEL].mean(), test_data[LABEL].mean()"
 ] }, { "cell_type": "code", "execution_count": 6, "id": "677298dc-05fe-47ef-b40a-f5b1d62e8187", "metadata": {}, "outputs": [], "source": [
 "# Drop the invalid first row of the test file. Dropping index label 0 (rather than\n",
 "# slicing by position) is a no-op on re-run, and `drop` returns a new DataFrame,\n",
 "# so the column assignments in the next cell do not write into a slice.\n",
 "test_data = test_data.drop(index=0, errors='ignore')"
 ] }, { "cell_type": "code", "execution_count": 7, "id": "25986888-5dc1-497c-96c7-bc5e5411c105", "metadata": {}, "outputs": [], "source": [ "# Set data types\n", "train_data[CATEGORICAL_FEATURES] = train_data[CATEGORICAL_FEATURES].astype(str)\n", "test_data[CATEGORICAL_FEATURES] = test_data[CATEGORICAL_FEATURES].astype(str)\n", "\n", "train_data[NUMERIC_FEATURES] = train_data[NUMERIC_FEATURES].astype(float)\n", "test_data[NUMERIC_FEATURES] = test_data[NUMERIC_FEATURES].astype(float)" ] }, { "cell_type": "code", "execution_count": 8, "id": "ad185875-9f9e-43b2-9422-40fe9ff66d7d", "metadata": {}, "outputs": [], "source": [
 "# Train/validation split: hold out 20% of the training data for validation.\n",
 "# A fixed seed keeps the split identical across kernel restarts.\n",
 "SEED = 42\n",
 "X_train, X_val = train_test_split(train_data, test_size=0.2, random_state=SEED)"
 ] }, { "cell_type": "markdown", "id": "a07ad92f-f1b9-46ad-9b61-831cbc903e22", "metadata": {}, "source": [ "## Modelling Prep" ] }, { "cell_type": "code", "execution_count": 9, "id": "1592a15a-cf05-4a35-ba46-69e877998d32", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/8 [00:00 0.5), 4))" ] } ], "metadata": { "kernelspec": { "display_name": "blog", "language": "python", "name": "blog" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.4" } }, "nbformat": 4, "nbformat_minor": 5 }