{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom methods in `DropCorrelatedFeatures`\n", "\n", "In this tutorial we show how to pass a custom method to `DropCorrelatedFeatures` using the association measure [Distance Correlation](https://m-clark.github.io/docs/CorrelationComparison.pdf) from the python package [dcor](https://dcor.readthedocs.io/en/latest/index.html)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import dcor\n", "import warnings\n", "\n", "from sklearn.datasets import make_classification\n", "from feature_engine.selection import DropCorrelatedFeatures\n", "\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
var_0var_1var_2var_3var_4var_5var_6var_7var_8var_9var_10var_11
0-0.718421-0.3064300.4773371.6626511.621889-0.2260392.089741-2.1450332.6167780.0744771.4026621.599289
10.584286-0.8718701.4902903.6449213.584239-0.750463-0.024631-4.5250425.5185341.7885933.0777933.188758
2-1.644619-0.3919610.8911212.2327052.175168-0.278656-1.145170-2.8977883.535246-0.7966621.8832992.178584
31.795776-2.6453681.5683211.4494911.754788-3.2269230.6263740.238043-0.3102981.2472121.256478-2.376344
4-0.683522-1.420178-0.1201771.0198031.171396-1.708503-0.114110-0.2234240.2622470.3226120.877768-0.972715
.......................................
9950.379855-0.529128-0.0933612.6685572.608481-0.410322-1.343059-3.4097124.159278-1.2875482.2518012.507712
9960.410435-1.5903860.3015890.9620021.140932-1.9310620.0100150.011464-0.025811-1.1249700.831563-1.315063
9970.562542-0.173591-0.5513231.4569961.407670-0.077131-1.215225-1.9638632.3965591.6787601.2278211.551989
9980.187248-0.355866-1.3855391.3041381.288720-0.3244600.260543-1.5801151.926655-1.3300301.1018431.071300
9990.105134-2.9828150.3096572.0856682.406926-3.593946-0.339890-0.3875220.451001-0.2218391.796291-2.113529
\n", "

1000 rows × 12 columns

\n", "
" ], "text/plain": [ " var_0 var_1 var_2 var_3 var_4 var_5 var_6 \\\n", "0 -0.718421 -0.306430 0.477337 1.662651 1.621889 -0.226039 2.089741 \n", "1 0.584286 -0.871870 1.490290 3.644921 3.584239 -0.750463 -0.024631 \n", "2 -1.644619 -0.391961 0.891121 2.232705 2.175168 -0.278656 -1.145170 \n", "3 1.795776 -2.645368 1.568321 1.449491 1.754788 -3.226923 0.626374 \n", "4 -0.683522 -1.420178 -0.120177 1.019803 1.171396 -1.708503 -0.114110 \n", ".. ... ... ... ... ... ... ... \n", "995 0.379855 -0.529128 -0.093361 2.668557 2.608481 -0.410322 -1.343059 \n", "996 0.410435 -1.590386 0.301589 0.962002 1.140932 -1.931062 0.010015 \n", "997 0.562542 -0.173591 -0.551323 1.456996 1.407670 -0.077131 -1.215225 \n", "998 0.187248 -0.355866 -1.385539 1.304138 1.288720 -0.324460 0.260543 \n", "999 0.105134 -2.982815 0.309657 2.085668 2.406926 -3.593946 -0.339890 \n", "\n", " var_7 var_8 var_9 var_10 var_11 \n", "0 -2.145033 2.616778 0.074477 1.402662 1.599289 \n", "1 -4.525042 5.518534 1.788593 3.077793 3.188758 \n", "2 -2.897788 3.535246 -0.796662 1.883299 2.178584 \n", "3 0.238043 -0.310298 1.247212 1.256478 -2.376344 \n", "4 -0.223424 0.262247 0.322612 0.877768 -0.972715 \n", ".. ... ... ... ... ... 
\n", "995 -3.409712 4.159278 -1.287548 2.251801 2.507712 \n", "996 0.011464 -0.025811 -1.124970 0.831563 -1.315063 \n", "997 -1.963863 2.396559 1.678760 1.227821 1.551989 \n", "998 -1.580115 1.926655 -1.330030 1.101843 1.071300 \n", "999 -0.387522 0.451001 -0.221839 1.796291 -2.113529 \n", "\n", "[1000 rows x 12 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, _ = make_classification(\n", " n_samples=1000,\n", " n_features=12,\n", " n_redundant=6,\n", " n_clusters_per_class=1,\n", " weights=[0.50],\n", " class_sep=2,\n", " random_state=1,\n", ")\n", "\n", "colnames = [\"var_\" + str(i) for i in range(12)]\n", "X = pd.DataFrame(X, columns=colnames)\n", "\n", "X" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
var_0var_1var_2var_3var_6var_7var_9
0-0.718421-0.3064300.4773371.6626512.089741-2.1450330.074477
10.584286-0.8718701.4902903.644921-0.024631-4.5250421.788593
2-1.644619-0.3919610.8911212.232705-1.145170-2.897788-0.796662
31.795776-2.6453681.5683211.4494910.6263740.2380431.247212
4-0.683522-1.420178-0.1201771.019803-0.114110-0.2234240.322612
........................
9950.379855-0.529128-0.0933612.668557-1.343059-3.409712-1.287548
9960.410435-1.5903860.3015890.9620020.0100150.011464-1.124970
9970.562542-0.173591-0.5513231.456996-1.215225-1.9638631.678760
9980.187248-0.355866-1.3855391.3041380.260543-1.580115-1.330030
9990.105134-2.9828150.3096572.085668-0.339890-0.387522-0.221839
\n", "

1000 rows × 7 columns

\n", "
" ], "text/plain": [ " var_0 var_1 var_2 var_3 var_6 var_7 var_9\n", "0 -0.718421 -0.306430 0.477337 1.662651 2.089741 -2.145033 0.074477\n", "1 0.584286 -0.871870 1.490290 3.644921 -0.024631 -4.525042 1.788593\n", "2 -1.644619 -0.391961 0.891121 2.232705 -1.145170 -2.897788 -0.796662\n", "3 1.795776 -2.645368 1.568321 1.449491 0.626374 0.238043 1.247212\n", "4 -0.683522 -1.420178 -0.120177 1.019803 -0.114110 -0.223424 0.322612\n", ".. ... ... ... ... ... ... ...\n", "995 0.379855 -0.529128 -0.093361 2.668557 -1.343059 -3.409712 -1.287548\n", "996 0.410435 -1.590386 0.301589 0.962002 0.010015 0.011464 -1.124970\n", "997 0.562542 -0.173591 -0.551323 1.456996 -1.215225 -1.963863 1.678760\n", "998 0.187248 -0.355866 -1.385539 1.304138 0.260543 -1.580115 -1.330030\n", "999 0.105134 -2.982815 0.309657 2.085668 -0.339890 -0.387522 -0.221839\n", "\n", "[1000 rows x 7 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dcor_tr = DropCorrelatedFeatures(\n", " variables=None, method=dcor.distance_correlation, threshold=0.8\n", ")\n", "\n", "X_dcor = dcor_tr.fit_transform(X)\n", "\n", "X_dcor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the next example, we use the function [sklearn.feature_selection.mutual_info_regression](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.mutual_info_regression.html#sklearn.feature_selection.mutual_info_regression) to calculate the Mutual Information between two numerical variables, dropping any feature whose score with another feature is above 0.8.\n", "\n", "Remember that the callable should take two 1d ndarrays as input and return a float value, so we define a custom function that wraps the sklearn method." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from sklearn.feature_selection import mutual_info_regression\n", "\n", "\n", "def custom_mi(x, y):\n", "    \"\"\"Return the mutual information between two 1d arrays as a float.\n", "\n", "    Only `x` is reshaped to a single-column 2d matrix; `y` is kept 1d,\n", "    which is the target shape `mutual_info_regression` expects.\n", "    \"\"\"\n", "    x = x.reshape(-1, 1)  # feature matrix: one column, one row per sample\n", "    y = y.ravel()  # a column-vector target would raise a DataConversionWarning\n", "    # Pin random_state: the estimator adds small random noise to continuous\n", "    # variables, so unseeded calls give slightly different scores on each run.\n", "    return float(mutual_info_regression(x, y, random_state=0)[0])" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
var_0var_1var_2var_3var_6var_7var_9
0-0.718421-0.3064300.4773371.6626512.089741-2.1450330.074477
10.584286-0.8718701.4902903.644921-0.024631-4.5250421.788593
2-1.644619-0.3919610.8911212.232705-1.145170-2.897788-0.796662
31.795776-2.6453681.5683211.4494910.6263740.2380431.247212
4-0.683522-1.420178-0.1201771.019803-0.114110-0.2234240.322612
........................
9950.379855-0.529128-0.0933612.668557-1.343059-3.409712-1.287548
9960.410435-1.5903860.3015890.9620020.0100150.011464-1.124970
9970.562542-0.173591-0.5513231.456996-1.215225-1.9638631.678760
9980.187248-0.355866-1.3855391.3041380.260543-1.580115-1.330030
9990.105134-2.9828150.3096572.085668-0.339890-0.387522-0.221839
\n", "

1000 rows × 7 columns

\n", "
" ], "text/plain": [ " var_0 var_1 var_2 var_3 var_6 var_7 var_9\n", "0 -0.718421 -0.306430 0.477337 1.662651 2.089741 -2.145033 0.074477\n", "1 0.584286 -0.871870 1.490290 3.644921 -0.024631 -4.525042 1.788593\n", "2 -1.644619 -0.391961 0.891121 2.232705 -1.145170 -2.897788 -0.796662\n", "3 1.795776 -2.645368 1.568321 1.449491 0.626374 0.238043 1.247212\n", "4 -0.683522 -1.420178 -0.120177 1.019803 -0.114110 -0.223424 0.322612\n", ".. ... ... ... ... ... ... ...\n", "995 0.379855 -0.529128 -0.093361 2.668557 -1.343059 -3.409712 -1.287548\n", "996 0.410435 -1.590386 0.301589 0.962002 0.010015 0.011464 -1.124970\n", "997 0.562542 -0.173591 -0.551323 1.456996 -1.215225 -1.963863 1.678760\n", "998 0.187248 -0.355866 -1.385539 1.304138 0.260543 -1.580115 -1.330030\n", "999 0.105134 -2.982815 0.309657 2.085668 -0.339890 -0.387522 -0.221839\n", "\n", "[1000 rows x 7 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mi_tr = DropCorrelatedFeatures(\n", " variables=None, method=custom_mi, threshold=0.8\n", ")\n", "\n", "X_mi = mi_tr.fit_transform(X)\n", "X_mi" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }