{"cells":[{"cell_type":"markdown","metadata":{"application/vnd.databricks.v1+cell":{"inputWidgets":{},"nuid":"48e4f2cb-f8c1-4511-a4ff-5cb614a5a435","showTitle":false,"title":""}},"source":["MRMR Feature Selection by Maykon Schots & Matheus Rugollo"]},{"cell_type":"markdown","metadata":{"application/vnd.databricks.v1+cell":{"inputWidgets":{},"nuid":"95086c54-0214-4ffd-b75c-0b181f9f494c","showTitle":false,"title":""}},"source":["
\n"," | feat_0 | \n","feat_1 | \n","feat_2 | \n","feat_3 | \n","feat_4 | \n","feat_5 | \n","feat_6 | \n","feat_7 | \n","feat_8 | \n","feat_9 | \n","feat_10 | \n","feat_11 | \n","feat_12 | \n","feat_13 | \n","feat_14 | \n","feat_15 | \n","feat_16 | \n","feat_17 | \n","feat_18 | \n","feat_19 | \n","feat_20 | \n","feat_21 | \n","feat_22 | \n","feat_23 | \n","feat_24 | \n","feat_25 | \n","feat_26 | \n","feat_27 | \n","feat_28 | \n","feat_29 | \n","
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n","-0.855484 | \n","0.147268 | \n","0.931779 | \n","0.514342 | \n","-1.160567 | \n","1.548771 | \n","-0.814841 | \n","-0.551259 | \n","-0.559991 | \n","-0.588341 | \n","-0.050275 | \n","-0.418050 | \n","-1.883618 | \n","0.627457 | \n","-1.430858 | \n","2.414968 | \n","0.965552 | \n","-0.427275 | \n","1.575164 | \n","1.386096 | \n","1.635452 | \n","0.171738 | \n","-0.988298 | \n","2.129211 | \n","1.502658 | \n","1.451028 | \n","-0.159674 | \n","1.999985 | \n","-0.088617 | \n","2.173897 | \n","
1 | \n","-0.233143 | \n","-0.253765 | \n","0.740147 | \n","0.816885 | \n","0.169952 | \n","1.594853 | \n","-0.432744 | \n","-0.118465 | \n","-0.245680 | \n","0.318794 | \n","-0.307883 | \n","-0.391124 | \n","1.137396 | \n","-1.150796 | \n","-1.642437 | \n","2.108450 | \n","1.168464 | \n","-0.840639 | \n","1.788039 | \n","0.134206 | \n","-0.498645 | \n","-0.199357 | \n","-1.260340 | \n","2.202738 | \n","1.859376 | \n","1.392850 | \n","-0.521572 | \n","1.733263 | \n","0.399896 | \n","2.562269 | \n","
2 | \n","-1.004689 | \n","-2.525201 | \n","-0.066370 | \n","1.708870 | \n","-1.836983 | \n","-0.200895 | \n","2.617520 | \n","0.494441 | \n","2.118569 | \n","-1.165920 | \n","-1.576650 | \n","0.674405 | \n","-0.687477 | \n","0.020001 | \n","-0.859114 | \n","-2.652188 | \n","0.951480 | \n","0.646614 | \n","0.821866 | \n","1.885669 | \n","-0.361392 | \n","-2.347812 | \n","-1.371678 | \n","-0.213288 | \n","1.733786 | \n","-0.814734 | \n","0.175997 | \n","-2.276026 | \n","3.047572 | \n","-1.306679 | \n","
3 | \n","-0.039887 | \n","-2.002593 | \n","-0.137059 | \n","1.352156 | \n","-0.283016 | \n","-0.166958 | \n","2.079097 | \n","0.316852 | \n","1.682286 | \n","0.208227 | \n","-1.249645 | \n","-0.460059 | \n","-0.315972 | \n","-0.031462 | \n","-0.673953 | \n","-2.114534 | \n","0.749490 | \n","0.440944 | \n","0.643708 | \n","0.300287 | \n","-1.195257 | \n","-1.862089 | \n","-1.082490 | \n","-0.179667 | \n","1.366995 | \n","-0.653095 | \n","0.799519 | \n","-1.814268 | \n","2.416413 | \n","0.303902 | \n","
4 | \n","1.650810 | \n","0.594325 | \n","0.268430 | \n","0.154022 | \n","-0.626083 | \n","1.439684 | \n","-1.215922 | \n","-1.373762 | \n","-0.893985 | \n","0.043517 | \n","0.242237 | \n","0.462302 | \n","-1.067515 | \n","-0.677758 | \n","-1.139388 | \n","2.671788 | \n","0.701012 | \n","1.002342 | \n","1.276911 | \n","0.194160 | \n","1.528878 | \n","0.584114 | \n","-0.644920 | \n","1.967760 | \n","1.044779 | \n","1.463181 | \n","0.461469 | \n","2.227199 | \n","-0.636540 | \n","0.948157 | \n","
\n | feat_0 | \nfeat_1 | \nfeat_2 | \nfeat_3 | \nfeat_4 | \nfeat_5 | \nfeat_6 | \nfeat_7 | \nfeat_8 | \nfeat_9 | \nfeat_10 | \nfeat_11 | \nfeat_12 | \nfeat_13 | \nfeat_14 | \nfeat_15 | \nfeat_16 | \nfeat_17 | \nfeat_18 | \nfeat_19 | \nfeat_20 | \nfeat_21 | \nfeat_22 | \nfeat_23 | \nfeat_24 | \nfeat_25 | \nfeat_26 | \nfeat_27 | \nfeat_28 | \nfeat_29 | \n
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n-0.855484 | \n0.147268 | \n0.931779 | \n0.514342 | \n-1.160567 | \n1.548771 | \n-0.814841 | \n-0.551259 | \n-0.559991 | \n-0.588341 | \n-0.050275 | \n-0.418050 | \n-1.883618 | \n0.627457 | \n-1.430858 | \n2.414968 | \n0.965552 | \n-0.427275 | \n1.575164 | \n1.386096 | \n1.635452 | \n0.171738 | \n-0.988298 | \n2.129211 | \n1.502658 | \n1.451028 | \n-0.159674 | \n1.999985 | \n-0.088617 | \n2.173897 | \n
1 | \n-0.233143 | \n-0.253765 | \n0.740147 | \n0.816885 | \n0.169952 | \n1.594853 | \n-0.432744 | \n-0.118465 | \n-0.245680 | \n0.318794 | \n-0.307883 | \n-0.391124 | \n1.137396 | \n-1.150796 | \n-1.642437 | \n2.108450 | \n1.168464 | \n-0.840639 | \n1.788039 | \n0.134206 | \n-0.498645 | \n-0.199357 | \n-1.260340 | \n2.202738 | \n1.859376 | \n1.392850 | \n-0.521572 | \n1.733263 | \n0.399896 | \n2.562269 | \n
2 | \n-1.004689 | \n-2.525201 | \n-0.066370 | \n1.708870 | \n-1.836983 | \n-0.200895 | \n2.617520 | \n0.494441 | \n2.118569 | \n-1.165920 | \n-1.576650 | \n0.674405 | \n-0.687477 | \n0.020001 | \n-0.859114 | \n-2.652188 | \n0.951480 | \n0.646614 | \n0.821866 | \n1.885669 | \n-0.361392 | \n-2.347812 | \n-1.371678 | \n-0.213288 | \n1.733786 | \n-0.814734 | \n0.175997 | \n-2.276026 | \n3.047572 | \n-1.306679 | \n
3 | \n-0.039887 | \n-2.002593 | \n-0.137059 | \n1.352156 | \n-0.283016 | \n-0.166958 | \n2.079097 | \n0.316852 | \n1.682286 | \n0.208227 | \n-1.249645 | \n-0.460059 | \n-0.315972 | \n-0.031462 | \n-0.673953 | \n-2.114534 | \n0.749490 | \n0.440944 | \n0.643708 | \n0.300287 | \n-1.195257 | \n-1.862089 | \n-1.082490 | \n-0.179667 | \n1.366995 | \n-0.653095 | \n0.799519 | \n-1.814268 | \n2.416413 | \n0.303902 | \n
4 | \n1.650810 | \n0.594325 | \n0.268430 | \n0.154022 | \n-0.626083 | \n1.439684 | \n-1.215922 | \n-1.373762 | \n-0.893985 | \n0.043517 | \n0.242237 | \n0.462302 | \n-1.067515 | \n-0.677758 | \n-1.139388 | \n2.671788 | \n0.701012 | \n1.002342 | \n1.276911 | \n0.194160 | \n1.528878 | \n0.584114 | \n-0.644920 | \n1.967760 | \n1.044779 | \n1.463181 | \n0.461469 | \n2.227199 | \n-0.636540 | \n0.948157 | \n
Going further: to implement a more robust feature selection with MRMR, we can use the process explained above to iterate over a range of correlation thresholds and choose the one that best fits our needs, instead of relying on a single score evaluation!
"]},{"cell_type":"code","execution_count":null,"metadata":{"application/vnd.databricks.v1+cell":{"inputWidgets":{},"nuid":"f07796e9-8eb2-4462-86df-9bd309b8360f","showTitle":false,"title":""}},"outputs":[{"data":{"text/html":["\n",""]},"metadata":{"application/vnd.databricks.v1+output":{"addedWidgets":{},"arguments":{},"data":"","datasetInfos":[],"metadata":{},"removedWidgets":[],"type":"html"}},"output_type":"display_data"}],"source":["# Repeat df from example.\n","\n","import warnings\n","\n","import pandas as pd\n","from sklearn.datasets import make_classification\n","\n","warnings.filterwarnings('ignore')\n","\n","X, y = make_classification(\n"," n_samples=5000,\n"," n_features=30,\n"," n_redundant=15,\n"," n_clusters_per_class=1,\n"," weights=[0.50],\n"," class_sep=2,\n"," random_state=42\n",")\n","\n","cols = []\n","for i in range(len(X[0])):\n"," cols.append(f\"feat_{i}\")\n","X = pd.DataFrame(X, columns=cols)\n","y = pd.DataFrame({\"y\": y})\n"]},{"cell_type":"code","execution_count":null,"metadata":{"application/vnd.databricks.v1+cell":{"inputWidgets":{},"nuid":"abb44eed-4177-4525-a609-8b03d2b3c687","showTitle":false,"title":""}},"outputs":[{"data":{"text/html":["\n",""]},"metadata":{"application/vnd.databricks.v1+output":{"addedWidgets":{},"arguments":{},"data":"","datasetInfos":[],"metadata":{},"removedWidgets":[],"type":"html"}},"output_type":"display_data"}],"source":["# Functions to iterate over accepted threshold\n","from sklearn.feature_selection import (\n"," SelectKBest,\n"," mutual_info_classif,\n"," mutual_info_regression,\n",")\n","import os\n","import multiprocessing\n","\n","from sklearn.ensemble import RandomForestClassifier\n","from sklearn.model_selection import StratifiedKFold, cross_validate\n","\n","import pandas as pd\n","from feature_engine.selection import SmartCorrelatedSelection\n","\n","\n","def select_features_clf(X: pd.DataFrame, y: pd.DataFrame, corr_threshold: float) -> list:\n"," \"\"\" Function will select a set of features 
with minimum redundance and maximum relevante based on the set correlation threshold \"\"\"\n"," # Setup Smart Selector /// Tks feature_engine\n"," feature_selector = SmartCorrelatedSelection(\n"," variables=None,\n"," method=\"spearman\",\n"," threshold=corr_threshold,\n"," missing_values=\"ignore\",\n"," selection_method=\"variance\",\n"," estimator=None,\n"," )\n"," feature_selector.fit_transform(X)\n"," ### Setup a list of correlated clusters as lists and a list of uncorrelated features\n"," correlated_sets = feature_selector.correlated_feature_sets_\n"," correlated_clusters = [list(feature) for feature in correlated_sets]\n"," correlated_features = [feature for features in correlated_clusters for feature in features]\n"," uncorrelated_features = [feature for feature in X if feature not in correlated_features]\n"," top_features_cluster = []\n"," for cluster in correlated_clusters:\n"," selector = SelectKBest(score_func=mutual_info_classif, k=1) # selects the top feature (k=1) regarding target mutual information\n"," selector = selector.fit(X[cluster], y)\n"," top_features_cluster.append(\n"," list(selector.get_feature_names_out())[0]\n"," )\n"," return top_features_cluster + uncorrelated_features\n","\n","def get_clf_model_scores(X: pd.DataFrame, y: pd.DataFrame, scoring: str, selected_features:list):\n"," \"\"\" \"\"\"\n"," cv = StratifiedKFold(shuffle=True, random_state=42) \n"," model_result = cross_validate(\n"," RandomForestClassifier(),\n"," X[selected_features],\n"," y,\n"," cv=cv,\n"," scoring=scoring,\n"," groups=None,\n"," error_score=\"raise\",\n"," )\n"," return model_result[\"test_score\"].mean(), model_result[\"fit_time\"].mean(), model_result[\"score_time\"].mean()\n","\n","def evaluate_clf_feature_selection_range(X: pd.DataFrame, y: pd.DataFrame, scoring:str, corr_range: int, corr_starting_point: float = .98) -> pd.DataFrame:\n"," \"\"\" Evaluates feature selection for every .01 on corr threshold \"\"\"\n"," evaluation_data = {\n"," 
\"corr_threshold\": [],\n"," scoring: [],\n"," \"n_features\": [],\n"," \"fit_time\": [],\n"," \"score_time\": []\n"," }\n"," for i in range(corr_range):\n"," current_corr_threshold = corr_starting_point - (i / 100) ## Reduces .01 on corr_threshold for every iteration\n"," selected_features = select_features_clf(X, y, corr_threshold=current_corr_threshold)\n"," score, fit_time, score_time = get_clf_model_scores(X, y, scoring, selected_features)\n"," evaluation_data[\"corr_threshold\"].append(current_corr_threshold)\n"," evaluation_data[scoring].append(score)\n"," evaluation_data[\"n_features\"].append(len(selected_features))\n"," evaluation_data[\"fit_time\"].append(fit_time)\n"," evaluation_data[\"score_time\"].append(score_time)\n"," \n"," return pd.DataFrame(evaluation_data)\n"," \n"]},{"cell_type":"code","execution_count":null,"metadata":{"application/vnd.databricks.v1+cell":{"inputWidgets":{},"nuid":"4bcd731c-c905-4480-bd63-88ba2871e7a1","showTitle":false,"title":""}},"outputs":[{"data":{"text/html":["\n",""]},"metadata":{"application/vnd.databricks.v1+output":{"addedWidgets":{},"arguments":{},"data":"","datasetInfos":[],"metadata":{},"removedWidgets":[],"type":"html"}},"output_type":"display_data"}],"source":["evaluation_df = evaluate_clf_feature_selection_range(X, y, \"f1\", 15)"]},{"cell_type":"code","execution_count":null,"metadata":{"application/vnd.databricks.v1+cell":{"inputWidgets":{},"nuid":"10a32ffa-c69d-49f3-abe5-b16ba4994022","showTitle":false,"title":""}},"outputs":[],"source":["%pip install hiplot"]},{"cell_type":"code","execution_count":null,"metadata":{"application/vnd.databricks.v1+cell":{"inputWidgets":{},"nuid":"bf6d3834-7549-4e7a-80b0-69e59ee6da5a","showTitle":false,"title":""}},"outputs":[{"data":{"text/html":["\n","\n","\n","\n","\n","