{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "67febe95-dd8e-4564-8a5e-e641bb16906e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import tensorflow_addons as tfa\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "\n",
    "from tabtransformertf.models.tabtransformer import TabTransformer\n",
    "from tabtransformertf.utils.preprocessing import df_to_dataset, build_categorical_prep"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f89b597-3f3a-4079-bc7c-ad3fbf16dba8",
   "metadata": {},
   "source": [
    "## Download Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e0c2547d-a0b7-478c-ae45-92e99dc6f113",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train dataset shape: (32561, 15)\n",
      "Test dataset shape: (16282, 15)\n"
     ]
    }
   ],
   "source": [
    "# Column names for the raw Adult files; they are read with header=None, so these\n",
    "# names are supplied explicitly and applied in this order to both splits.\n",
    "CSV_HEADER = [\n",
    "    \"age\", \"workclass\", \"fnlwgt\", \"education\", \"education_num\",\n",
    "    \"marital_status\", \"occupation\", \"relationship\", \"race\", \"gender\",\n",
    "    \"capital_gain\", \"capital_loss\", \"hours_per_week\", \"native_country\",\n",
    "    \"income_bracket\",\n",
    "]\n",
    "\n",
    "train_data_url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n",
    "test_data_url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test\"\n",
    "\n",
    "# NOTE(review): adult.test is commonly reported to start with a non-data line\n",
    "# (\"|1x3 Cross validator\") and to label incomes with a trailing period\n",
    "# (\"<=50K.\" / \">50K.\"). The 16282-row test shape printed by this cell is\n",
    "# consistent with that extra line being parsed as a row -- confirm a later cell\n",
    "# drops it and normalizes the labels before evaluation.\n",
    "# Both splits are fetched in order (train first, then test) with identical options.\n",
    "train_data, test_data = (\n",
    "    pd.read_csv(url, header=None, names=CSV_HEADER)\n",
    "    for url in (train_data_url, test_data_url)\n",
    ")\n",
    "\n",
    "print(f\"Train dataset shape: {train_data.shape}\")\n",
    "print(f\"Test dataset shape: {test_data.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "75521624-3f40-4aad-bbd4-cea5c5cb5782",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", " | age | \n", "workclass | \n", "fnlwgt | \n", "education | \n", "education_num | \n", "marital_status | \n", "occupation | \n", "relationship | \n", "race | \n", "gender | \n", "capital_gain | \n", "capital_loss | \n", "hours_per_week | \n", "native_country | \n", "income_bracket | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "39 | \n", "State-gov | \n", "77516 | \n", "Bachelors | \n", "13 | \n", "Never-married | \n", "Adm-clerical | \n", "Not-in-family | \n", "White | \n", "Male | \n", "2174 | \n", "0 | \n", "40 | \n", "United-States | \n", "<=50K | \n", "
1 | \n", "50 | \n", "Self-emp-not-inc | \n", "83311 | \n", "Bachelors | \n", "13 | \n", "Married-civ-spouse | \n", "Exec-managerial | \n", "Husband | \n", "White | \n", "Male | \n", "0 | \n", "0 | \n", "13 | \n", "United-States | \n", "<=50K | \n", "
2 | \n", "38 | \n", "Private | \n", "215646 | \n", "HS-grad | \n", "9 | \n", "Divorced | \n", "Handlers-cleaners | \n", "Not-in-family | \n", "White | \n", "Male | \n", "0 | \n", "0 | \n", "40 | \n", "United-States | \n", "<=50K | \n", "
3 | \n", "53 | \n", "Private | \n", "234721 | \n", "11th | \n", "7 | \n", "Married-civ-spouse | \n", "Handlers-cleaners | \n", "Husband | \n", "Black | \n", "Male | \n", "0 | \n", "0 | \n", "40 | \n", "United-States | \n", "<=50K | \n", "
4 | \n", "28 | \n", "Private | \n", "338409 | \n", "Bachelors | \n", "13 | \n", "Married-civ-spouse | \n", "Prof-specialty | \n", "Wife | \n", "Black | \n", "Female | \n", "0 | \n", "0 | \n", "40 | \n", "Cuba | \n", "<=50K | \n", "