class ModelSelectionHelper:
    """Run one ``GridSearchCV`` per named model and keep the fitted searches.

    ``models`` and ``params`` are mappings that share the same keys:
    ``models[key]`` is the estimator handed to ``GridSearchCV`` and
    ``params[key]`` is the parameter grid searched for that estimator.
    Fitted searches are stored in ``self.grid_searches`` under the same key.
    """

    def __init__(self, models, params):
        """Store the model and parameter-grid mappings; no fitting happens here.

        Parameters
        ----------
        models : mapping
            Model name -> estimator passed to ``GridSearchCV``.
        params : mapping
            Model name -> parameter grid passed to ``GridSearchCV``.
        """
        print('init model selection helper notebook')
        self.models = models
        self.params = params
        # Live view of the model names; reflects later changes to ``models``.
        self.keys = models.keys()
        # Filled by ``fit``: model name -> fitted GridSearchCV object.
        self.grid_searches = {}

    def fit(self, X, y, scoring=None, n_jobs=3, refit=True, cv=5, verbose=1, return_train_score=True):
        """Fit a ``GridSearchCV`` for every registered model.

        Each fitted search is stored in ``self.grid_searches[key]``. A banner
        and the model key are printed before each search, followed by a short
        summary of the winning candidate when one was selected.

        Parameters
        ----------
        X, y
            Forwarded unchanged to ``GridSearchCV.fit``.
        scoring, n_jobs, refit, cv, verbose, return_train_score
            Forwarded unchanged to the ``GridSearchCV`` constructor.

        Raises
        ------
        KeyError
            If a model key has no entry in ``params``. This is checked before
            any search starts so a misconfigured grid does not surface only
            after earlier (potentially long) searches have already run.
        """
        missing = [key for key in self.keys if key not in self.params]
        if missing:
            raise KeyError(
                'no parameter grid supplied for model(s): {}'.format(missing))

        for key in self.keys:
            print('---------------------------------------------------------------------------')
            print(key)

            grid_search = GridSearchCV(self.models[key], self.params[key],
                                       cv=cv, n_jobs=n_jobs, verbose=verbose,
                                       scoring=scoring, refit=refit,
                                       return_train_score=return_train_score)
            grid_search.fit(X, y)
            self.grid_searches[key] = grid_search

            # ``refit`` may be True, a scorer name (multi-metric scoring) or a
            # callable; every truthy value makes GridSearchCV pick a best
            # candidate, so test truthiness rather than ``== True``.
            if refit:
                print('best_params_: ', grid_search.best_params_)
                # GridSearchCV does not set best_score_ when ``refit`` is a
                # callable, so only report it when it exists.
                if hasattr(grid_search, 'best_score_'):
                    print('best_score_: ', grid_search.best_score_)

    def get_model_best_estimator(self, key):
        """Return ``best_estimator_`` of the fitted search stored under ``key``.

        Raises ``KeyError`` if ``fit`` has not stored a search for ``key``.
        ``best_estimator_`` itself is only set by ``GridSearchCV`` when the
        search was run with refitting enabled.
        """
        return self.grid_searches[key].best_estimator_
"outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }