{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MatchVariables\n", "\n", "\n", "MatchVariables() ensures that the columns in the test set are identical to those\n", "in the train set.\n", "\n", "If the test set contains additional columns, they are dropped. Alternatively, if the\n", "test set lacks columns that were present in the train set, they will be added with a\n", "value determined by the user, for example np.nan." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "from feature_engine.preprocessing import MatchVariables" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load titanic dataset from OpenML\n", "\n", "def load_titanic():\n", " data = pd.read_csv('https://www.openml.org/data/get_csv/16826755/phpMYEkMl')\n", " data = data.replace('?', np.nan)\n", " data['cabin'] = data['cabin'].astype(str).str[0]\n", " data['pclass'] = data['pclass'].astype('O')\n", " data['age'] = data['age'].astype('float')\n", " data['fare'] = data['fare'].astype('float')\n", " data['embarked'].fillna('C', inplace=True)\n", " data.drop(\n", " labels=['name', 'ticket', 'boat', 'body', 'home.dest'],\n", " axis=1, inplace=True,\n", " )\n", " return data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclasssurvivedsexagesibspparchfarecabinembarked
011female29.000000211.3375BS
111male0.916712151.5500CS
210female2.000012151.5500CS
310male30.000012151.5500CS
410female25.000012151.5500CS
\n", "
" ], "text/plain": [ " pclass survived sex age sibsp parch fare cabin embarked\n", "0 1 1 female 29.0000 0 0 211.3375 B S\n", "1 1 1 male 0.9167 1 2 151.5500 C S\n", "2 1 0 female 2.0000 1 2 151.5500 C S\n", "3 1 0 male 30.0000 1 2 151.5500 C S\n", "4 1 0 female 25.0000 1 2 151.5500 C S" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = load_titanic()\n", "\n", "data.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((1000, 9), (309, 9))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# separate the dataset into train and test\n", "\n", "train = data.iloc[0:1000, :]\n", "test = data.iloc[1000:, :]\n", "\n", "train.shape, test.shape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MatchVariables(missing_values='ignore')" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# set up the transformer\n", "match_cols = MatchVariables(missing_values=\"ignore\")\n", "\n", "# learn the variables in the train set\n", "match_cols.fit(train)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['pclass',\n", " 'survived',\n", " 'sex',\n", " 'age',\n", " 'sibsp',\n", " 'parch',\n", " 'fare',\n", " 'cabin',\n", " 'embarked']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the transformer stores the input variables\n", "\n", "match_cols.input_features_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1 - Some columns are missing in the test set" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclasssurvivedsibspparchfarecabinembarked
100031007.7500nQ
1001312023.2500nQ
1002312023.2500nQ
1003312023.2500nQ
100431007.7875nQ
\n", "
" ], "text/plain": [ " pclass survived sibsp parch fare cabin embarked\n", "1000 3 1 0 0 7.7500 n Q\n", "1001 3 1 2 0 23.2500 n Q\n", "1002 3 1 2 0 23.2500 n Q\n", "1003 3 1 2 0 23.2500 n Q\n", "1004 3 1 0 0 7.7875 n Q" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's drop some columns in the test set for the demo\n", "test_t = test.drop([\"sex\", \"age\"], axis=1)\n", "\n", "test_t.head()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclasssurvivedsibspparchfarecabinembarked
100031007.7500nQ
1001312023.2500nQ
1002312023.2500nQ
1003312023.2500nQ
100431007.7875nQ
\n", "
" ], "text/plain": [ " pclass survived sibsp parch fare cabin embarked\n", "1000 3 1 0 0 7.7500 n Q\n", "1001 3 1 2 0 23.2500 n Q\n", "1002 3 1 2 0 23.2500 n Q\n", "1003 3 1 2 0 23.2500 n Q\n", "1004 3 1 0 0 7.7875 n Q" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's drop some columns in the test set for the demo\n", "test_t = test.drop([\"sex\", \"age\"], axis=1)\n", "\n", "test_t.head()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The following variables are added to the DataFrame: ['sex', 'age']\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclasssurvivedsexagesibspparchfarecabinembarked
100031NaNNaN007.7500nQ
100131NaNNaN2023.2500nQ
100231NaNNaN2023.2500nQ
100331NaNNaN2023.2500nQ
100431NaNNaN007.7875nQ
\n", "
" ], "text/plain": [ " pclass survived sex age sibsp parch fare cabin embarked\n", "1000 3 1 NaN NaN 0 0 7.7500 n Q\n", "1001 3 1 NaN NaN 2 0 23.2500 n Q\n", "1002 3 1 NaN NaN 2 0 23.2500 n Q\n", "1003 3 1 NaN NaN 2 0 23.2500 n Q\n", "1004 3 1 NaN NaN 0 0 7.7875 n Q" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the transformer adds the columns back\n", "test_tt = match_cols.transform(test_t)\n", "\n", "print()\n", "test_tt.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note how the missing columns were added back to the transformed test set, with\n", "missing values, in the position (i.e., order) in which they were in the train set.\n", "\n", "Similarly, if the test set contained additional columns, those would be removed:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test set contains variables not present in train set" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclasssurvivedsibspparchfarecabinembarkednew_col1new_col2
100031007.7500nQ5test
1001312023.2500nQ5test
1002312023.2500nQ5test
1003312023.2500nQ5test
100431007.7875nQ5test
\n", "
" ], "text/plain": [ " pclass survived sibsp parch fare cabin embarked new_col1 new_col2\n", "1000 3 1 0 0 7.7500 n Q 5 test\n", "1001 3 1 2 0 23.2500 n Q 5 test\n", "1002 3 1 2 0 23.2500 n Q 5 test\n", "1003 3 1 2 0 23.2500 n Q 5 test\n", "1004 3 1 0 0 7.7875 n Q 5 test" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_t.loc[:, \"new_col1\"] = 5\n", "test_t.loc[:, \"new_col2\"] = \"test\"\n", "\n", "test_t.head()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MatchVariables(fill_value=0, missing_values='ignore')" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# set up the transformer with different\n", "# fill value\n", "match_cols = MatchVariables(\n", " fill_value=0, missing_values=\"ignore\",\n", ")\n", "\n", "# learn the variables in the train set\n", "match_cols.fit(train)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The following variables are added to the DataFrame: ['sex', 'age']\n", "The following variables are dropped from the DataFrame: ['new_col2', 'new_col1']\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pclasssurvivedsexagesibspparchfarecabinembarked
10003100007.7500nQ
100131002023.2500nQ
100231002023.2500nQ
100331002023.2500nQ
10043100007.7875nQ
\n", "
" ], "text/plain": [ " pclass survived sex age sibsp parch fare cabin embarked\n", "1000 3 1 0 0 0 0 7.7500 n Q\n", "1001 3 1 0 0 2 0 23.2500 n Q\n", "1002 3 1 0 0 2 0 23.2500 n Q\n", "1003 3 1 0 0 2 0 23.2500 n Q\n", "1004 3 1 0 0 0 0 7.7875 n Q" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_tt = match_cols.transform(test_t)\n", "\n", "print()\n", "test_tt.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note how the columns that were present in the test set but not in train set were dropped. And now, the missing variables were added back into the dataset with the value 0." ] } ], "metadata": { "interpreter": { "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" }, "kernelspec": { "display_name": "fenotebook", "language": "python", "name": "fenotebook" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }