{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "XV25GIUMruLl" }, "source": [ "# Setup Environment\n", "\n", "We need to install the [interpret](https://github.com/interpretml/interpret/) and [gamchanger](https://github.com/interpretml/gam-changer/) packages." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "jYksosEEsGwU" }, "outputs": [], "source": [ "# Install the `interpret` and `gamchanger` packages.\n", "!pip install --upgrade interpret gamchanger" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "p9CB8dx2B67Y" }, "source": [ "# Model Training\n", "\n", "We will train a simple [EBM model](https://interpret.ml/docs/ebm.html) to predict whether an individual's income is above 50K using the [census dataset](https://archive.ics.uci.edu/ml/datasets/census+income)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "Fhad39L2rb4k" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GAM Changer version: 0.1.12\n", "Interpret version: 0.4.2\n" ] } ], "source": [ "import pandas as pd\n", "import numpy as np\n", "import gamchanger as gc\n", "import interpret\n", "\n", "from json import load\n", "from sklearn.model_selection import train_test_split\n", "from interpret.glassbox import ExplainableBoostingClassifier\n", "\n", "print('GAM Changer version: ', gc.__version__)\n", "print('Interpret version: ', interpret.__version__)\n", "\n", "df = pd.read_csv(\n", " \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\",\n", " header=None)\n", "\n", "df.columns = [\n", " \"Age\", \"WorkClass\", \"fnlwgt\", \"Education\", \"EducationNum\",\n", " \"MaritalStatus\", \"Occupation\", \"Relationship\", \"Race\", \"Gender\",\n", " \"CapitalGain\", \"CapitalLoss\", \"HoursPerWeek\", \"NativeCountry\", \"Income\"\n", "]\n", "\n", "train_cols = df.columns[0:-1]\n", "label = df.columns[-1]\n", "X = df[train_cols]\n", "y = df[label].apply(lambda x: 0 if x == \" <=50K\" else 1)  # Encode the response as 0 (<=50K) or 1 (>50K)\n", "\n", "seed = 1\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rltttrkyrkpr", "outputId": "3ed091ac-2f14-4246-fa08-488af015ab42" }, "outputs": [ { "data": { "text/html": [ "
ExplainableBoostingClassifier(n_jobs=-1, random_state=1)