{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# TODO\n", "\n", "- Update commentary to be more intensive\n", "- Finish hyperparameter tuning\n", "- Add note on overfitting with random forests & gradient boosted trees\n", "\n", "# Ensemble Methods\n", "\n", "One quote I heard in the beginning of my machine learning journey is \"if you don't have a favorite algorithm yet, pick the random forest\". I'm glad that I heard that early on, because it has proven itself multiple times along with the other algorithms in the ensemble family.\n", "\n", "There are a few good reasons why ensemble methods are my favorite family of models. Not only are they often extremely powerful in predictive performance (often topping the leaderboards on Kaggle competitions for structured data), but they still maintain some semblance of interpretability and can often be parallelized to utilize all cores of a CPU.\n", "\n", "I recently gave a [talk on ensemble methods](https://github.com/JeffMacaluso/Talks/blob/master/EnsembleMethods/EnsembleMethods.pptx). at a local meetup group, and this is a more flushed out version of the hands-on portion that includes a more difficult dataset and more complete hyperparameter tuning.\n", "\n", "## Overview\n", "\n", "In this post, we'll train a few ensemble models on an artificial dataset for binary classification. We'll use scikit-learn to compare a few different types of ensemble methods, and then use XGBoost and LightGBM for more specialized implementations of gradient boosting. Additionally, we'll go over hyperparameter tuning and discuss a few strategies for tuning ensemble models.\n", "\n", "## Setup\n", "\n", "The setup here is largely a series of import statements, creating an artificial classification dataset with [scikit-learn's make_classification function](http://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html), and then creating a function to train our models and gather various metrics." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T00:39:58.966450Z", "start_time": "2018-07-12T00:39:37.655238Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2018/07/11 19:39\n", "OS: win32\n", "Python: 3.5.5 | packaged by conda-forge | (default, Apr 6 2018, 16:03:44) [MSC v.1900 64 bit (AMD64)]\n", "NumPy: 1.12.1\n", "Pandas: 0.23.1\n" ] } ], "source": [ "import sys\n", "import time\n", "import scipy\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from sklearn import datasets\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import RandomizedSearchCV\n", "from sklearn import ensemble\n", "from sklearn import linear_model\n", "from sklearn import metrics\n", "\n", "print(time.strftime('%Y/%m/%d %H:%M'))\n", "print('OS:', sys.platform)\n", "print('Python:', sys.version)\n", "print('NumPy:', np.__version__)\n", "print('Pandas:', pd.__version__)\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creating an artificial data set with [scikit-learn's make_classification function](http://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html).\n", "\n", "**TODO: Increase the size of the dataset and re-run**" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:04:01.008869Z", "start_time": "2018-07-12T23:04:00.696389Z" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...212223242526272829label
0-2.0990070.787135-0.686102-1.2288300.723556-0.311079-0.9528260.867668-1.6425710.624871...0.987041-0.3777890.242981-0.792679-1.715772-0.420792-1.7291521.308256-0.7005611
1-0.801402-3.8517691.5388050.5657830.5261720.7525650.014558-0.240075-1.4791381.819075...-0.153594-0.324068-2.0602201.5414811.297861-1.2289420.4946062.1679530.1784361
2-5.662407-5.8531611.625716-0.3905931.1992841.888906-1.019720-1.3926503.012919-1.139037...-0.335733-0.468439-1.9960232.419778-1.558457-0.539612-1.1595663.362889-0.8912730
31.5692565.984285-1.678201-0.6013250.4708340.688409-2.392620-0.946743-2.7136601.422514...-0.161714-1.189745-0.837363-0.825927-1.654660-0.3395401.2169200.1454221.4598361
42.8818645.945795-1.627379-1.361672-0.773137-0.071754-4.0750991.901061-4.2944250.795730...0.656566-0.235963-2.1461110.5940492.2904430.330266-0.019847-5.7707430.8155811
\n", "

5 rows × 31 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 \\\n", "0 -2.099007 0.787135 -0.686102 -1.228830 0.723556 -0.311079 -0.952826 \n", "1 -0.801402 -3.851769 1.538805 0.565783 0.526172 0.752565 0.014558 \n", "2 -5.662407 -5.853161 1.625716 -0.390593 1.199284 1.888906 -1.019720 \n", "3 1.569256 5.984285 -1.678201 -0.601325 0.470834 0.688409 -2.392620 \n", "4 2.881864 5.945795 -1.627379 -1.361672 -0.773137 -0.071754 -4.075099 \n", "\n", " 7 8 9 ... 21 22 23 \\\n", "0 0.867668 -1.642571 0.624871 ... 0.987041 -0.377789 0.242981 \n", "1 -0.240075 -1.479138 1.819075 ... -0.153594 -0.324068 -2.060220 \n", "2 -1.392650 3.012919 -1.139037 ... -0.335733 -0.468439 -1.996023 \n", "3 -0.946743 -2.713660 1.422514 ... -0.161714 -1.189745 -0.837363 \n", "4 1.901061 -4.294425 0.795730 ... 0.656566 -0.235963 -2.146111 \n", "\n", " 24 25 26 27 28 29 label \n", "0 -0.792679 -1.715772 -0.420792 -1.729152 1.308256 -0.700561 1 \n", "1 1.541481 1.297861 -1.228942 0.494606 2.167953 0.178436 1 \n", "2 2.419778 -1.558457 -0.539612 -1.159566 3.362889 -0.891273 0 \n", "3 -0.825927 -1.654660 -0.339540 1.216920 0.145422 1.459836 1 \n", "4 0.594049 2.290443 0.330266 -0.019847 -5.770743 0.815581 1 \n", "\n", "[5 rows x 31 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Creating an artificial dataset to test algorithms on\n", "data = datasets.make_classification(#n_samples=300000,\n", " n_samples=3000,\n", " n_classes=2,\n", " n_features=30,\n", " n_informative=10,\n", " n_redundant=5, # Superfluous features working as noise for the algorithms\n", " flip_y=0.5, # Introduces additional noise\n", " class_sep=0.7, \n", " n_clusters_per_class=10,\n", " random_state=46)\n", "\n", "# Assigning features/labels to variables for ease of use\n", "X = data[0] # Features\n", "y = data[1] # Label\n", "\n", "# Train/test split\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=46)\n", "\n", "# Putting into a dataframe for 
viewing\n", "df = pd.DataFrame(X)\n", "df['label'] = y\n", "\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to adhere to [DRY typing](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself), we'll create a function to train our models and gather the accuracy, [AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve), [log loss](https://en.wikipedia.org/wiki/Cross_entropy), and model training time." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:04:07.289737Z", "start_time": "2018-07-12T23:04:07.242865Z" } }, "outputs": [], "source": [ "# Data frame for gathering results \n", "results = pd.DataFrame(columns=['Accuracy', 'LogLoss', 'AUC', 'TrainingTime'])\n", "tuned_results = pd.DataFrame(columns=['Accuracy', 'LogLoss', 'AUC', 'TrainingTime', 'NumIterations'])\n", "\n", "# Function for training a model and retrieving the results\n", "def train_model_get_results(model, model_name):\n", " '''\n", " Trains a model and appends the results to the results dataframe\n", " \n", " Input:\n", " - model: The model with specified hyperparameters to be trained\n", " - model_name: The name of the model to be used as the index\n", " - is_tuned: A binary flag for if hyperparameter tuning has been performed\n", " \n", " Output: The results dataframe with the model results added\n", " \n", " Note: Only works with scikit-learn models and frameworks that integrate \n", " with the scikit-learn API\n", " '''\n", " \n", " # Collecting training time for results\n", " start_time = time.time()\n", " \n", " print('Training the model')\n", " model.fit(X_train, y_train)\n", " \n", " end_time = time.time()\n", " total_training_time = end_time - start_time\n", " print('Completed')\n", " \n", " # Calculating the testing set accuracy with the score method\n", " accuracy = model.score(X_test, y_test)\n", " \n", " # Calcuating the AUC and log loss with predicted 
probabilities\n", " class_probabilities = model.predict_proba(X_test)\n", " log_loss = metrics.log_loss(y_test, class_probabilities)\n", " auc = metrics.roc_auc_score(y_test, class_probabilities[:, 1])\n", " \n", " # Adding the model results to the results dataframe\n", " model_results = [accuracy, log_loss, auc, total_training_time]\n", " results.loc[model_name] = model_results\n", " \n", " print('\\n', 'Non-tuned results:')\n", " return results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Baseline\n", "\n", "It's always useful to have a baseline to compare against and let us know generally how difficult a problem is going to be. I like to use linear or logistic regression due to each them being extremely fast to train." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:04:10.477054Z", "start_time": "2018-07-12T23:04:10.414547Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training the model\n", "Completed\n", "\n", " Non-tuned results:\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AccuracyLogLossAUCTrainingTime
Logistic Regression0.5055560.6979550.4912370.015628
\n", "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.505556 0.697955 0.491237 0.015628" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Instantiating the model\n", "logistic_regression = linear_model.LogisticRegression()\n", "\n", "# Using our user defined function to train the model and return the results\n", "train_model_get_results(model=logistic_regression, model_name='Logistic Regression')" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2018-06-30T21:52:49.624668Z", "start_time": "2018-06-30T21:52:49.593513Z" } }, "source": [ "## Bagging\n", "\n", "Bagging (bootstrap aggregating) is the technique that aggregates models built with bootstrapping, or sampling with replacement, via a majority vote or by averaging the predictions. The trees are independent of each other and can be built in parallel. \n", "\n", "Bagging models tend to decrease variance.\n", "\n", "### Random Forest\n", "\n", "The most popular bagging algorithm is the **random forest**. This algorithm works by building a series of decision trees where each tree uses a random selection of variables, and then decision trees vote on the final answer. 
\n", "\n", "More specifically, for each tree:\n", "\n", "- Use a different training sample with replacement (bootstrapping) for the data\n", "- For each node, choose a number of random attributes and find the best split\n", "- Typically is not pruned in order to have a smaller bias\n", "\n", "Once these trees are grown, a majority vote among all of the trees will be used to make predictions.\n", "\n", "The main ideas here are that the randomness makes a set of diverse models that helps improve accuracy and using random subsets of features to consider at each split helps make it more efficient to train.\n", "\n", "\n", "\n", "**Advantages:**\n", "- Robustness against over-fitting\n", " - Since the model is created through dense randomness, the generalization is typically better, and you can usually increase the accuracy with the number of trees up until a saturation point\n", "- Able to parallelize training multiple trees at once and thus speed up training time" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:04:48.256008Z", "start_time": "2018-07-12T23:04:47.834157Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training the model\n", "Completed\n", "\n", " Non-tuned results:\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AccuracyLogLossAUCTrainingTime
Logistic Regression0.5055560.6979550.4912370.015628
Random Forest0.5233330.8106270.5323430.171863
\n", "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.505556 0.697955 0.491237 0.015628\n", "Random Forest 0.523333 0.810627 0.532343 0.171863" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random_forest = ensemble.RandomForestClassifier(n_jobs=-1) # n_jobs=-1 uses all available cores\n", "\n", "train_model_get_results(random_forest, model_name='Random Forest')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Boosting\n", "\n", "Boosting methods train a sequence of weak learners (a learner that is barely better than random chance) where each successive model focuses on the parts that the previous model got wrong. The trees have to be built in a sequence and generally cannot be built in parallel without clever tricks.\n", "\n", "Boosting models tend to decrease bias.\n", "\n", "### Gradient Boosting\n", "\n", "While there are a few different boosting algorithms, gradient boosting is arguably the most popular. It's main differentiation from the others is that it uses gradient descent to decide what to focus on in order to minimize loss for the new trees being built in the sequence. 
This typically gives it performance advantages over other boosting algorithms.\n", "\n", "\n", "\n", "*Source: [BigML](https://blog.bigml.com/2017/03/14/introduction-to-boosted-trees/)*\n", "\n", "**Advantages:**\n", "- Can often outperform random forests when properly tuned\n", "\n", "**Disadvantages:**\n", "- Typically overfits easier than bagging\n", "- Sensitive to noise & extreme values\n", "- Has to be built sequentially, so cannot parallelize without tricks" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:04:49.927780Z", "start_time": "2018-07-12T23:04:48.849720Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training the model\n", "Completed\n", "\n", " Non-tuned results:\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AccuracyLogLossAUCTrainingTime
Logistic Regression0.5055560.6979550.4912370.015628
Random Forest0.5233330.8106270.5323430.171863
Gradient Boosted Trees0.5466670.6934030.5610601.031190
\n", "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.505556 0.697955 0.491237 0.015628\n", "Random Forest 0.523333 0.810627 0.532343 0.171863\n", "Gradient Boosted Trees 0.546667 0.693403 0.561060 1.031190" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gradient_boosting = ensemble.GradientBoostingClassifier()\n", "\n", "train_model_get_results(gradient_boosting, model_name='Gradient Boosted Trees')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Note on interpretability\n", "\n", "It's possible to obtain \"feature importance\" from both bagging and boosting methods. These are not as interpretable as coefficients from linear/logistic regressions, but can still give us an idea of what is happening. \n", "\n", "Note that the multicollinearity assumption applies here - these interpretations will be misleading if the features are heavily correlated with each other." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:04:50.084020Z", "start_time": "2018-07-12T23:04:49.675Z" } }, "outputs": [], "source": [ "def feature_importance(model):\n", " '''\n", " Plots the feature importance for an ensemble model from scikit-learn\n", " '''\n", " feature_importance = model.feature_importances_\n", " feature_importance = 100.0 * (feature_importance / feature_importance.max())\n", " sorted_idx = np.argsort(feature_importance)\n", " pos = np.arange(sorted_idx.shape[0]) + .5\n", " plt.figure(figsize=(15, 8))\n", " plt.subplot(1, 2, 2)\n", " plt.barh(pos, feature_importance[sorted_idx], align='center')\n", " plt.yticks(pos, sorted_idx)\n", " plt.xlabel('Relative Importance')\n", " plt.title('Variable Importance')\n", " plt.show()\n", " \n", "\n", "feature_importance(gradient_boosting)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stacking\n", "\n", "Stacking is the final ensemble technique where we combine several 
different models into a chain of sorts. It is structured similarly to a neural network where layers of models provide predictions that the next layer then uses as inputs. Ultimately, the meta-classifier creates a final prediction.\n", "\n", "\n", "*Source: [Anshul Joshi](https://www.quora.com/What-is-stacking-in-machine-learning)*\n", "\n", "This is a little more nuanced than blending models (averaging their predictions for a final prediction) as the meta-learner learns how useful each of the models are.\n", "\n", "**Advantages:**\n", "- Can be more performant when properly tuned\n", "\n", "**Disadvantages:**\n", "- Much more computationally costly\n", "- More difficult to tune\n", "- Complete loss of interpretability\n", "\n", "We'll need another function that is similar to our previous one for training the models and getting the results. In this case, we'll deal with one layer of classifiers and use a logistic regression for the meta-learner. We'll use five different algorithms for the first layer, but this function is designed to accept any number of models." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:04:52.990094Z", "start_time": "2018-07-12T23:04:51.005839Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training the model\n", "Completed\n", "\n", "Coefficients for models\n", "Model 1: -2.9603327700929984\n", "Model 2: 6.827364713943086\n", "Model 3: 6.967087517500586\n", "Model 4: 0.7215877036971426\n", "Model 5: 0.49496178638295857\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AccuracyLogLossAUCTrainingTime
Logistic Regression0.5055560.6979550.4912370.015628
Random Forest0.5233330.8106270.5323430.171863
Gradient Boosted Trees0.5466670.6934030.5610601.031190
Stacking0.5200000.9940010.5334641.593773
\n", "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.505556 0.697955 0.491237 0.015628\n", "Random Forest 0.523333 0.810627 0.532343 0.171863\n", "Gradient Boosted Trees 0.546667 0.693403 0.561060 1.031190\n", "Stacking 0.520000 0.994001 0.533464 1.593773" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def train_stacking_get_results(list_of_models):\n", " '''\n", " Trains a stacking classifier and appends the results to the rsults dataframe\n", " \n", " Input: list_of_models: a list of untrained scikit-learn models\n", " \n", " Output: The results dataframe with the model results added\n", " \n", " Note: Only works with scikit-learn models and frameworks that integrate \n", " with the scikit-learn API\n", " '''\n", " # The meta learner is the one that takes the outputs from\n", " # the other models as input before final classification\n", " meta_learner = linear_model.LogisticRegression()\n", "\n", " # Collecting training time for results\n", " start_time = time.time()\n", " print('Training the model')\n", "\n", " # Fitting the first layer models\n", " for model in list_of_models:\n", " model.fit(X_train, y_train)\n", "\n", " # Collecting the predictions from the models for training\n", " model_output = []\n", "\n", " for model in list_of_models:\n", " class_probabilities = model.predict_proba(X_train)[:, 1]\n", " model_output.append(class_probabilities)\n", "\n", " # Re-shaping before passing to the meta learner\n", " X_train_meta = np.array(model_output).transpose()\n", "\n", " # Fitting the meta learner\n", " meta_learner.fit(X_train_meta, y_train)\n", "\n", " end_time = time.time()\n", " total_time = end_time - start_time\n", " print('Completed')\n", "\n", " # Collecting the predictions from the models for testing\n", " model_output = []\n", "\n", " for model in list_of_models:\n", " class_probabilities = model.predict_proba(X_test)[:, 1]\n", " 
model_output.append(class_probabilities)\n", "\n", " # Re-shaping before passing to the meta learner\n", " X_test_meta = np.array(model_output).transpose()\n", "\n", " # Collecting the accuracy from the meta learner\n", " accuracy = meta_learner.score(X_test_meta, y_test)\n", "\n", " # Calcuating the log loss with predicted probabilities\n", " class_probabilities = meta_learner.predict_proba(X_test_meta)\n", " log_loss = metrics.log_loss(y_test, class_probabilities)\n", " auc = metrics.roc_auc_score(y_test, class_probabilities[:, 1])\n", "\n", " # Printing coefficients of models\n", " print()\n", " print('Coefficients for models')\n", " for i, coef in enumerate(meta_learner.coef_[0]):\n", " print('Model {0}: {1}'.format( i+1, coef))\n", " \n", " model_results = [accuracy, log_loss, auc, total_time]\n", " results.loc['Stacking'] = model_results\n", "\n", " return results\n", "\n", "\n", "# Adding extra imports for additional models\n", "from sklearn import neighbors\n", "\n", "# Defining the learners for the first layer\n", "model_1 = linear_model.LogisticRegression()\n", "model_2 = ensemble.RandomForestClassifier(n_jobs=-1)\n", "model_3 = ensemble.RandomForestClassifier(n_jobs=-1)\n", "model_4 = ensemble.GradientBoostingClassifier()\n", "model_5 = neighbors.KNeighborsClassifier(n_jobs=-1)\n", "\n", "# Putting the models in a list to iterate through in the function\n", "models = [model_1, model_2, model_3, model_4, model_5]\n", "\n", "# Running our function to build a stacking model\n", "train_stacking_get_results(models)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameter Tuning\n", "\n", "There two main methodologies for hyperparameter tuning: \n", "1. Manually testing hypotheses on how changing certain hyperparameters will impact the performance of the model\n", "2. 
Automatically checking a bunch of different combinations of hyperparameters using either a grid search or a randomized search\n", "\n", "For this post, we will discuss a few strategies for the first option, and then go with the second option by using a randomized search. \n", "\n", "Between grid search and random search, grid search generally makes more intuitive sense. However, research from [James Bergstra and Yoshua Bengio](http://jmlr.csail.mit.edu/papers/volume13/bergstra12a/bergstra12a.pdf) have shown that random search tends to converge to good hyperparameters faster than grid search. Here's a graphic from their paper that gives an intuitive example of how random search can potentially cover more ground when there are hyperparameters that aren't as important:\n", "\n", "\n", "\n", "*Source: [James Bergstra & Yoshua Bengio](http://jmlr.csail.mit.edu/papers/volume13/bergstra12a/bergstra12a.pdf)*\n", "\n", "## Hyperparameters & Decision Tree Structure\n", "\n", "Because both random forests and gradient boosted trees use decision trees for their underlying structures, their hyperparameters are largely the same. 
Here's a recap of the decision tree structure and a quick summary of each of the hyperparameters we'll be tuning:\n", "\n", "\n", "\n", "*Source: [Murtuza Morbiwala](http://insightfromdata.blogspot.com/2012/06/decision-tree-unembellished.html)*\n", "\n", "\n", "### Hyperparameters\n", "\n", "This list is not all-inclusive, but it covers most of the common hyperparameters:\n", "\n", "- **Number of Estimators:** The number of decision trees to be trained\n", "    - A higher number typically means better predictions (at the cost of computational power) up until a saturation point where the model begins to overfit\n", "- **Max Depth:** How deep a tree can be\n", "    - This should ideally be low for gradient boosting and large (or none) for random forests\n", "- **Minimum Samples per Split:** The minimum number of samples required to split a node\n", "    - A higher number constrains how finely the tree can split, which can help prevent overfitting\n", "- **Minimum Samples per Leaf:** The minimum number of samples required to be a leaf node\n", "    - A lower number could potentially result in more noise being captured\n", "- **Max Features:** The number of features to consider when looking for the best split\n", "    - A lower number typically reduces variance/increases bias and improves computational efficiency\n", "- **Max Leaf Nodes:** The maximum number of leaf nodes for the tree\n", "    - A smaller number could help prevent overfitting\n", "- **Learning Rate (gradient boosting only):** The adjustment/step size for each iteration\n", "    - A larger step size can help get better performance in fewer iterations, but will plateau at a lower performance\n", "    - A smaller step size will require more iterations (number of estimators) but will ultimately achieve a better performance\n", "\n", "Here is an illustration of what a learning rate is and how too small or too large a learning rate can have adverse impacts:\n", "\n", "\n", "*Source: [Jeremy 
Jordan](https://www.jeremyjordan.me/nn-learning-rate/)*\n", "\n", "Here is a more visual version of these hyperparameters on a tree: \n", "\n", "\n", "*Source: [Analytics Vidhya](https://www.analyticsvidhya.com/blog/2016/02/complete-guide-parameter-tuning-gradient-boosting-gbm-python/)*\n", "\n", "### General Strategies\n", "\n", "Most strategies are specific to either random forests or gradient boosting, but there are a few strategies that apply to both.\n", "\n", "- Increase the number of estimators until either just before overfitting begins to start occurring or there are severely diminishing returns in performance\n", " - Compare the performance against the default parameters to see if this helps and how much\n", "- Further adjust the model complexity (starting with tree depth)\n", " - Decrease the complexity of the trees if you suspect the model is suffering from high variance\n", " - Increase the complexity of the trees if you suspect the model is suffering from high bias\n", " \n", "Remember that hyperparameter tuning is all about controlling model complexity in order to achieve the optimal state in the bias-variance tradeoff:\n", "\n", "\n", "*Source: [Satya Mallick](https://www.learnopencv.com/bias-variance-tradeoff-in-machine-learning/)*\n", "\n", "### Setup\n", "\n", "In order to do the actual hyperparameter tuning we need to create our third and final function. This will take a model, a dictionary of parameters, perform a random search for the number of iterations, and then give us our results." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:04:57.458575Z", "start_time": "2018-07-12T23:04:57.427465Z" } }, "outputs": [], "source": [ "def hyperparameter_tune_get_results(model, parameters, model_name, num_rounds=30):\n", " '''\n", " Performs a random search to find optimal hyperparameters and append the results\n", " to the tuned_results dataframe\n", " \n", " Input: \n", " - model: A scikit-learn model\n", " - parameters: A dictionary of parameters for the model\n", " - model_name: A string of the model name for the tuned_results dataframe\n", " - num_rounds: The number of rounds to try different hyperparameters\n", " \n", " Output: The tuned_results dataframe with the results appended\n", " '''\n", " \n", " # Reporting the default parameters before tuning\n", " print('Default Parameters:', '\\n')\n", " print(model, '\\n')\n", " \n", " # Defining the random search cross validation\n", " random_search = RandomizedSearchCV(model,\n", " param_distributions=parameters,\n", " n_iter=num_rounds, n_jobs=-1, cv=3,\n", " return_train_score=True, random_state=46,\n", " verbose=10) # Set to 20 to print the status of each completed fit\n", " \n", " print('Beginning hyperparameter tuning')\n", " start_time = time.time()\n", " random_search.fit(X_train, y_train)\n", " end_time = time.time()\n", " total_training_time = end_time - start_time\n", " print('Completed')\n", " \n", " # Calculating the testing set accuracy on the best estimator with the score method\n", " accuracy = random_search.best_estimator_.score(X_test, y_test)\n", " \n", " # Calcuating the log loss with predicted probabilities\n", " class_probabilities = random_search.best_estimator_.predict_proba(X_test)\n", " log_loss = metrics.log_loss(y_test, class_probabilities)\n", " auc = metrics.roc_auc_score(y_test, class_probabilities[:, 1])\n", " \n", " # Adding the model results to the results dataframe\n", " model_results = [accuracy, log_loss, auc, 
total_training_time, num_rounds]\n", " tuned_results.loc[model_name] = model_results\n", " \n", " # Plotting the mean training accuracy from the different iterations\n", " sns.distplot(random_search.cv_results_['mean_test_score'])\n", " plt.title('Mean test score')\n", " \n", " print('Best estimator:', '\\n')\n", " print(random_search.best_estimator_)\n", " \n", " print()\n", " print('Accuracy before tuning:', results.loc[model_name]['Accuracy'])\n", " print('Accuracy after tuning:', tuned_results.loc[model_name]['Accuracy'])\n", " \n", " print('\\n', 'Tuned results:')\n", " return tuned_results" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2018-06-30T23:13:57.881984Z", "start_time": "2018-06-30T23:13:57.866361Z" } }, "source": [ "## Baseline\n", "\n", "For our logistic regression model, we're just going to tune the regularization parameter. One of the advantages of simpler models like this is that they are easier to tune because we don't have nearly as many hyperparameters to worry about.\n", "\n", "**Note: The number of rounds is being kept small in these examples to keep within time limits for the talk, but increase them in a real-world scenario for more effective hyperparameter tuning**" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:05:07.442344Z", "start_time": "2018-07-12T23:04:58.770997Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Default Parameters: \n", "\n", "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", " verbose=0, warm_start=False) \n", "\n", "Beginning hyperparameter tuning\n", "Fitting 3 folds for each of 10 candidates, totalling 30 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Done 5 tasks | elapsed: 7.0s\n", 
"[Parallel(n_jobs=-1)]: Done 10 tasks | elapsed: 7.1s\n", "[Parallel(n_jobs=-1)]: Done 17 tasks | elapsed: 7.3s\n", "[Parallel(n_jobs=-1)]: Done 27 out of 30 | elapsed: 7.8s remaining: 0.8s\n", "[Parallel(n_jobs=-1)]: Done 30 out of 30 | elapsed: 7.9s finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Completed\n", "Best estimator: \n", "\n", "LogisticRegression(C=6.8421729839625272, class_weight=None, dual=False,\n", " fit_intercept=True, intercept_scaling=1, max_iter=100,\n", " multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,\n", " solver='liblinear', tol=0.0001, verbose=0, warm_start=False)\n", "\n", "Accuracy before tuning: 0.505555555556\n", "Accuracy after tuning: 0.505555555556\n", "\n", " Tuned results:\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime NumIterations\n", "Logistic Regression 0.505556 0.697967 0.491222 8.296371 10.0" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEICAYAAABWJCMKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8VOW9x/HPLwlJCASyEJYk7IR9E4KAu1CLWCtoxVqrUq+Wq1VvW9te7W5rW7t469qquBWta10Rtyoq4gKyyL6GsCSEJSEsCZCEJM/9Yw41YoCQZHImOd/36zWvmTnnOXN+T5b5zlnmOeacQ0REgifK7wJERMQfCgARkYBSAIiIBJQCQEQkoBQAIiIBpQAQEQkoBYCISEApAMRXZrbJzCrMrMMR05eYmTOzHk1cz1lmlt9Ir/W+mV3TGK8lEg4KAIkEG4FvHX5iZkOA1v6V0zKYWYzfNUhkUwBIJHgCuLLG86nA4zUbmFmcmd1hZlvMbIeZPWBmrb15yWY2y8wKzWy39zizxrLvm9ltZvaRmZWY2b+P3OLw2rUB3gDSzazUu6WbWZSZ3WJmG8xsl5k9Z2Yp3jLxZvZPb/oeM1tgZp3M7PfA6cB93uvcV8v6al3Wm5diZo+ZWYHXp5drLPddM8sxs2Izm2lm6TXmOTO73szWA+u9af3N7G2v/Vozu+TEf0XSEikAJBLMA9qZ2QAziwa+CfzziDZ/AvoCw4E+QAbwK29eFPAY0B3oBhwEjnzDvQy4CugIxAI/PrII59x+YCJQ4Jxr690KgP8BJgNnAunAbuBv3mJTgfZAVyAVuBY46Jz7OTAXuMF7nRtq6Xety3rzngASgEFezXcCmNk44HbgEqALsBl45ojXnQyMBgZ6ofY28JT3Ot8C/m5mg2qpRwJGASCR4vBWwDnAGmDr4RlmZsB3gR8654qdcyXAH4BLAZxzu5xzLzjnDnjzfk/ozbqmx5xz65xzB4HnCAVJXf038HPnXL5zrhy4FbjY28VyiNCbdx/nXJVzbpFzbl8dX7fWZc2sC6EgutY5t9s5d8g5N8db5tvAo865xV4tPwXGHnGs5Hbv53QQOB/Y5Jx7zDlX6ZxbDLwAXHwC/ZcWSvsIJVI8AXwA9OSI3T9AGqFPw4tCWQCAAdEAZpZA6BPyuUCyNz/RzKKdc1Xe8+01Xu8A0PYEausOvGRm1TWmVQGdvLq7As+YWRKhLZefO+cO1eF1a13Wm1bsnNtdyzLpwOLDT5xzpWa2i9AW0SZvct4RtY82sz01psV465aA0xaARATn3GZCB4PPA148YnYRoV0jg5xzSd6tvXPu8Jv4j4B+wGjnXDvgDG+6ceJqGx43D5hYY91Jzrl459xW79P5b5xzA4FTCH3ivvIYr/X5io6+bB6Q4oXCkQoIvakD/zlukUqNLaYj1psHzDmi9rbOueuOVZsEgwJAIsnVwDhvX/x/OOeqgYeAO82sI4CZZZjZBK9JIqGA2OMdnP11A2rYAaSaWfsa0x4Afm9m3b11p5nZJO/x2WY2xDt2sY/Qbp2qGq/V62grOtqyzrlthA5G/907wN3KzA6H2lPAVWY23MziCO0Km++c23SU1cwC+prZFd7rtDKzUWY24ER/MNLyKAAkYjjnNjjnFh5l9s1ADjDPzPYB7xD61A9wF6HTRosIHVB+swE1rAGeBnK9M3PSgbuBmcC/zazEW8dob5HOwPOE3sBXA3P4/AD23YSOFew2s3tqWd2xlr2CUCCsAXYCP/Dqmw38ktB+/G1Ab7xjIUfpTwnwVa9NAaFdYX8C4ur8Q5EWy3RBGBGRYNIWgIhIQCkAREQCSg
EgIhJQCgARkYCK6C+CdejQwfXo0cPvMkREmpVFixYVOefSjtcuogOgR48eLFx4tLMCRUSkNma2uS7ttAtIRCSgjhsAZvaome00sxU1pv3FzNaY2TIze6nmV9bN7KfeULVra3xTEzM715uWY2a3NH5XRETkRNRlC+AfhAbZqultYLBzbiiwjtCIhJjZQELfOBzkLfN3M4v2vur+N0IjHA4EvuW1FRERnxw3AJxzHwDFR0z7t3Ou0ns6Dzh88Y1JwDPOuXLn3EZCX90/2bvlOOdynXMVhMYvn9RIfRARkXpojGMA/0Vo4CoIDUlbcyjafG/a0aZ/iZlNM7OFZrawsLCwEcoTEZHaNCgAzOznQCXw5OFJtTRzx5j+5YnOTXfOZTvnstPSjnsWk4iI1FO9TwM1s6mExi8f7z4fUS6f0MUsDsskNAIhx5guIiI+qNcWgJmdS2h43guccwdqzJoJXGqhC3j3BLKAT4EFQJaZ9TSzWEIHimc2rHQREWmI424BmNnTwFlABzPLJ3SxjZ8SGk/8be8SffOcc9c651aa2XPAKkK7hq4/fEk+M7sBeIvQZfwedc6tDEN/RESkjiL6egDZ2dlO3wQWqbun5m/xu4QvuWx0N79LCBwzW+Scyz5eO30TWEQkoBQAIiIBpQAQEQkoBYCISEApAEREAkoBICISUAoAEZGAUgCIiASUAkBEJKAUACIiAaUAEBEJKAWAiEhAKQBERAJKASAiElAKABGRgFIAiIgElAJARCSgFAAiIgGlABARCSgFgIhIQCkAREQCSgEgIhJQCgARkYBSAIiIBJQCQEQkoBQAIiIBpQAQEQmo4waAmT1qZjvNbEWNaSlm9raZrffuk73pZmb3mFmOmS0zsxE1lpnqtV9vZlPD0x0REamrumwB/AM494hptwCznXNZwGzvOcBEIMu7TQPuh1BgAL8GRgMnA78+HBoiIuKP4waAc+4DoPiIyZOAGd7jGcDkGtMfdyHzgCQz6wJMAN52zhU753YDb/PlUBERkSZU32MAnZxz2wC8+47e9Awgr0a7fG/a0aZ/iZlNM7OFZrawsLCwnuWJiMjxNPZBYKtlmjvG9C9PdG66cy7bOZedlpbWqMWJiMjn6hsAO7xdO3j3O73p+UDXGu0ygYJjTBcREZ/UNwBmAofP5JkKvFJj+pXe2UBjgL3eLqK3gK+aWbJ38Per3jQREfFJzPEamNnTwFlABzPLJ3Q2zx+B58zsamALMMVr/jpwHpADHACuAnDOFZvZbcACr91vnXNHHlgWEZEmdNwAcM596yizxtfS1gHXH+V1HgUePaHqREQkbPRNYBGRgFIAiIgElAJARCSgFAAiIgGlABARCSgFgIhIQCkAREQCSgEgIhJQCgARkYBSAIiIBJQCQEQkoBQAIiIBpQAQEQkoBYCISEApAEREAkoBICISUAoAEZGAUgCIiASUAkBEJKAUACIiAaUAEBEJKAWAiEhAKQBERAJKASAiElAKABGRgFIAiIgEVIMCwMx+aGYrzWyFmT1tZvFm1tPM5pvZejN71sxivbZx3vMcb36PxuiAiIjUT70DwMwygP8Bsp1zg4Fo4FLgT8CdzrksYDdwtbfI1cBu51wf4E6vnYiI+KShu4BigNZmFgMkANuAccDz3vwZwGTv8STvOd788WZmDVy/iIjUU70DwDm3FbgD2ELojX8vsAjY45yr9JrlAxne4wwgz1u20mufeuTrmtk0M1toZgsLCwvrW56IiBxHQ3YBJRP6VN8TSAfaABNraeoOL3KMeZ9PcG66cy7bOZedlpZW3/JEROQ4GrIL6CvARudcoXPuEPAicAqQ5O0SAsgECrzH+UBXAG9+e6C4AesXEZEGaEgAbAHGmFmCty9/PLAKeA+42GszFXjFezzTe443/13n3Je2AEREpGk05BjAfEIHcxcDy73Xmg7cDNxkZjmE9vE/4i3yCJDqTb8JuKUBdYuISAPFHL/J0Tnnfg38+ojJucDJtb
QtA6Y0ZH0iItJ49E1gEZGAUgCIiASUAkBEJKAUACIiAaUAEBEJKAWAiEhAKQBERAJKASAiElAKABGRgFIAiIgElAJARCSgFAAiIgGlABARCSgFgIhIQCkAREQCSgEgIhJQCgARkYBSAIiIBJQCQEQkoBQAIiIBpQAQEQkoBYCISEApAEREAkoBICISUAoAEZGAUgCIiARUgwLAzJLM7HkzW2Nmq81srJmlmNnbZrbeu0/22pqZ3WNmOWa2zMxGNE4XRESkPhq6BXA38KZzrj8wDFgN3ALMds5lAbO95wATgSzvNg24v4HrFhGRBqh3AJhZO+AM4BEA51yFc24PMAmY4TWbAUz2Hk8CHnch84AkM+tS78pFRKRBGrIF0AsoBB4zs8/M7GEzawN0cs5tA/DuO3rtM4C8Gsvne9O+wMymmdlCM1tYWFjYgPJEpDGUV1ax7+AhdpWWc6Ci0u9ypBHFNHDZEcCNzrn5ZnY3n+/uqY3VMs19aYJz04HpANnZ2V+aLyLh5ZxjY9F+lm/dy6Zd+9mxr/wL89vFx5Ce1JqTuiUzoEsiMVE6l6S5akgA5AP5zrn53vPnCQXADjPr4pzb5u3i2Vmjfdcay2cCBQ1Yv4g0Iuccy7buZe66Qgr2lhEbHUX31ASGZLSnTVwMraKjKC2rZMe+MnKL9rPm0y20iY3mjL5pnNK7A9FRtX3Gk0hW7wBwzm03szwz6+ecWwuMB1Z5t6nAH737V7xFZgI3mNkzwGhg7+FdRSLir6LScl7+bCu5RftJS4zjwuEZDO+WRKvo2j/dVzvH+h2lfLyhiDdWbOezLXuYPDydbqltmrhyaYiGbAEA3Ag8aWaxQC5wFaHjCs+Z2dXAFmCK1/Z14DwgBzjgtRURn83fuIvXlm0jJtqYPDyD7B7JRNmxP81HmdGvcyJ9O7Vl9bZ9vLpsG9Pn5nL+0HTG9EptosqloRoUAM65JUB2LbPG19LWAdc3ZH0i0ngqq6t5dek2Fmwqpm+ntlw0IpN28a1O6DXMjIHp7emV1pZnF+Qxc2kBO0vK+NqQdO0SagYaugUgIs1Q+aEqHp+3mY1F+zmzbxrnDOx03E/9xxLfKporxnbnrRXbmZtTxMGKKqZkd23Qa0r4KQBEAuZgRRX/+HgjW/cc5JLsTIZ3TW6U140yY+KQLrSOjebfq3bQKjqKySd96UxviSAKAJEAOVhRxSMf5bJjbzmXndyNgentG30dZ/XryKGqat5bW0h8q2guH9O90dchjUMn8IoExKGqap6Yt4kde8u5fEz3sLz5H/aVAZ0Y0yuVD3OKeH5RftjWIw2jABAJgGrneG5hHpt2HWBKdib9OieGdX1mxteGdKFXWht+9tJyluTtCev6pH4UACIB8Mbybaws2MfXhnRhaGZSk6wzOsr41qhudEyM49onFlG8v6JJ1it1pwAQaeEWbd7NRxt2cUrvVE7t06FJ190mLoYHLh9J8f4KfvriMkJng0ukUACItGB5xQd4eclWeqe1YeJgfwbfHZzRnh9P6MtbK3fwLx0PiCgKAJEWqrS8kifnb6ZdfAzfGtXN1y9mXXNaL8b2SuU3M1eyedd+3+qQL1IAiLRA1c7x/KI8DlRU8e3R3UmI8/eM76go4/8uGUZUlHHLC8u1KyhCKABEWqCPcopYt6OU84Z0IT2ptd/lAJCe1Jqbz+3PJ7m7eGHxVr/LERQAIi1OXvEB3lq5nUHp7RjdM8Xvcr7gspO7MbJ7Mr9/bZXOCooACgCRFqSisprnFuaRGN+Ki07KxCJsLJ6oKOMPFw6hpKyS3722yu9yAk8BINKCvLFiG7v2V3DxyExax0b7XU6t+nVOZNoZvXhx8VYWb9ntdzmBpgAQaSHmrCtk/sZiTuvTgd5pbf0u55iuP7sPHRPj+M2rq6iu1gFhvygARFqAkrJD3Pz8MjomxnHOwE5+l3NcbeJiuPnc/izN28NLn+mAsF8UACItwB/fWMPOkjK+MSLzqJdxjDQXnp
TBsK5J/OnNNewvr/S7nEBqHn8pInJU83N38eT8LVx1ak+6piT4XU6dRUUZvzp/IDtLynlobq7f5QSSAkCkGSs7VMVPX1xO15TW/Oirff0u54SN7J7MuYM689AHuRSVlvtdTuAoAESasXtmrye3aD+3XziUhNjmeX2nn5zbj7LKau57N8fvUgJHASDSTK0s2MuDH+Ry8chMTstq2lE+G1PvtLZckp3Jk/M3s2XXAb/LCRQFgEgzVFlVzc0vLCM5IZZffG2A3+U02PfH9yU6yrjznXV+lxIoCgCRZuixjzaxYus+fnPBIJISYv0up8E6t4/nyrE9eGXJVjYUlvpdTmAoAESamYI9B7nznXWM79+R84Z09rucRjPtjF7ExURz7+z1fpcSGAoAkWbmtlmrqHaOWy8YFHFj/TREh7ZxXDm2OzOXFpCzU1sBTUEBINKMvL92J2+s2M6N47Ka1Tn/dfVdbyvgvne1FdAUGhwAZhZtZp+Z2SzveU8zm29m683sWTOL9abHec9zvPk9GrpukSApO1TFr2eupFeHNlxzek+/ywkLbQU0rcbYAvg+sLrG8z8BdzrnsoDdwNXe9KuB3c65PsCdXjsRqaMH5mxg864D/HbSYOJiInOkz8agrYCm06AAMLNM4GvAw95zA8YBz3tNZgCTvceTvOd488dbS9qBKRJGm3ft5+/vb+Drw9Kb9Tn/daGtgKbT0C2Au4D/Baq956nAHufc4ZGd8oEM73EGkAfgzd/rtf8CM5tmZgvNbGFhYWEDyxNp/pxz/OqVlcRGR7WIc/7r4j9nBGkrIKzqHQBmdj6w0zm3qObkWpq6Osz7fIJz051z2c657LS0tPqWJ9JivLVyO3PWFXLTOX3p1C7e73KaRGrbOK485fBWQInf5bRYDdkCOBW4wMw2Ac8Q2vVzF5BkZocHJckECrzH+UBXAG9+e6C4AesXafEOVFTy21dX0b9zIleO7e53OU1q2um9iI+J5m/vbfC7lBar3gHgnPupcy7TOdcDuBR41zn3beA94GKv2VTgFe/xTO853vx3nXO6FJDIMdz3bg4Fe8u4bfJgYprJOP+NJbVtHJeP6cYrS7ayedd+v8tpkcLxF3UzcJOZ5RDax/+IN/0RINWbfhNwSxjWLdJi5BaW8tDcXL4xIpNRPVL8LscX3z29FzHRUdz/vrYCwqFRxo91zr0PvO89zgVOrqVNGTClMdYn0tI55/j1zJXEt4rmlon9/S7HNx3bxXPpqK48/ekWbhyfRUZSa79LalGCtU0p0ky8uWI7c9cX8aNz+pKWGOd3Ob767zN74xxMn6OtgMamABCJMAcqKvntrFUM6NKOy8cE68BvbTKSWvONEZk8vSCPnfvK/C6nRVEAiESYe9/NYdveMm6bNChwB36P5rqzelNZVa1rBzcy/XWJRJCcnaU8PDd0la/sgB74rU2PDm2YNDyDf87bQvH+Cr/LaTEUACIRwjnHrTrwe1TfO6s3ZZVVPPrhRr9LaTEUACIR4vXl2/kwp4gff7UfHdoG+8BvbbI6JTJxcGdmfLyJvQcP+V1Oi9Aop4FKy/HU/C1+l1Cry0Z387uEsNpfXslts1YxsEs7vt3C+9oQ15/dh9eXb+fxjzdx4/gsv8tp9rQFIBIB7nl3Pdv3lXHbZB34PZZB6e0Z378jj3y0kf3llcdfQI5Jf2kiPsvZWcIjczcyZWQmI7vrwO/xXD+uD3sOHOLJ+Zv9LqXZUwCI+OjwN34TYqO5WQd+62REt2RO69OB6R9spOxQld/lNGsKABEfvbZ8Gx/l7OLHE3Tg90TcMK4PRaXlPLsgz+9SmjUFgIhPSsoOcdusVQxKb8e3R+sbvydiTK9UTu6RwgNzNlBRWX38BaRWCgARn9zx1lp2lpTzu8mDiY7S1VFP1A3j+rBtbxkvLM73u5RmSwEg4oPFW3bz+LzNXDmmOyd1S/a7nGbp9KwODMtsz9/fz6GySlsB9aEAEGlih6qq+dmLy+mUGM
+PJ/Tzu5xmy8y4YVwWecUHmbm04PgLyJcoAESa2MNzN7Jmewm/mTSIxPhWfpfTrI3v35H+nRP523s5VFXrAoMnSgEg0oQ279rPXe+sY8KgTkwY1Nnvcpq9qCjjhnF92FC4nzdXbPe7nGZHASDSRJxz/OLlFbSKjuI3Fwz2u5wWY+LgLvRKa8O9765Hlxk/MQoAkSbyypIC5q4v4icT+tG5fbzf5bQY0VHGDWf3Yc32Emav3ul3Oc2KAkCkCewsKePWV1cyvGuSrvIVBhcMS6drSmvufS9HWwEnQAEgEmbOOX724goOVFRxx5ShOuc/DGKio/jeWX1YmreHueuL/C6n2VAAiITZi4u38s7qHfzkq/3o0zHR73JarItGZJCR1Jo7/r1WWwF1pAAQCaNtew9y66srye6ezH+d1tPvclq0uJhofnhOX5bl7+UNnRFUJwoAkTBxznHLC8uprHLcMWWYdv00gQtPyqBvp7bc8dZafTu4DhQAImHy7II85qwr5JaJ/enRoY3f5QRCdJTxkwn9yS3az3MLNUbQ8SgARMIgf/cBfvfaasb2SuUKnfXTpL4yoCMjuydz9+x1HKzQ9QKOpd4BYGZdzew9M1ttZivN7Pve9BQze9vM1nv3yd50M7N7zCzHzJaZ2YjG6oRIJKmsquYHzyzBOcefLx5KlHb9NCkz4+Zz+7NjXzn/+HiT3+VEtIZsAVQCP3LODQDGANeb2UDgFmC2cy4LmO09B5gIZHm3acD9DVi3SMT669vrWLh5N3+4aAhdUxL8LieQTu6Zwrj+Hbn//Rz2HjjkdzkRq94B4Jzb5pxb7D0uAVYDGcAkYIbXbAYw2Xs8CXjchcwDksysS70rF4lAc9YV8vf3N3DpqK5MGp7hdzmB9pMJ/Sgpr+T+ORv8LiViNcoxADPrAZwEzAc6Oee2QSgkgI5eswyg5vXb8r1pIi3Cjn1l3PTsEvp1SuTXXx/kdzmBN6BLOyYPz+Cxjzaybe9Bv8uJSDENfQEzawu8APzAObfP7Kj7O2ub8aVva5jZNEK7iOjWrVtDy5OjqK52rN9ZymdbdpNbtJ/te8so3l/B9n1lGNA2LoakhFjSEmPpntKGpIRWHON3G3hV1Y7vP/MZByqquO+yk2gdG+13SQLcdE5fXlu+jT+9sYa7Lj3J73IiToMCwMxaEXrzf9I596I3eYeZdXHObfN28RwenSkf6Fpj8UzgS1dxcM5NB6YDZGdn6+t8jaiispq56wt5bfk23lm1g31llQDExkTRpX08KW1iqaispto5dpaUs+/gnv8kdPvWrRjQpR3DMtvTNSWBKIXBF9z77nrm5Rbzl4uHktVJ3/aNFF1TEph2ei/uey+HK8Z2Z2T3FL9Liij1DgALfRx8BFjtnPtrjVkzganAH737V2pMv8HMngFGA3sP7yqS8CoqLeep+Vt4Yt5mCkvKaRcfwzkDO3NK71RGdE+mR2rCfz7dPzV/y3+Wq6p27CwpY9OuA2zYWcrCTcXMy91FaptYxvZOZWS3ZOJa6ZPuh+uLuHv2ei4akcGU7K7HX0Ca1PfO7s3zi/K5deYqXrn+VJ2VVUNDtgBOBa4AlpvZEm/azwi98T9nZlcDW4Ap3rzXgfOAHOAAcFUD1i11sOdABffP2cCMjzdRdqiaM/umceVF3Tk9K43YmOMf/omOMrq0b02X9q0Z2yuV8kNVrNq2j/kbi5m1bBvvrN7BaX3SOLV3amCDYPOu/Vz/1GKyOrbltkka4z8SJcTG8NPz+vP9Z5bwr0V5fHOUdi0fVu8AcM59SO379QHG19LeAdfXd31Sd4eqqpnx8Sbunr2e0vJKJg/P4Pqz+9CnY9sGvW5cq2hO6pbMSd2SySs+wPvrCnln9Q4+3lDEOQM7MapHSqB2DZWWV/LdxxdiBg9fOYo2cQ0+pCZhcsGwdJ74ZDN/fGMN5wzsTEqbWL9Ligj6i21hFm4q5ucvrWDtjhLO7JvGT8/rT//O7Rp9PV1TErhiTH
fyig/w5srtvLKkgMWbdzNpeAbpSa0bfX2RpqracfH9H5Ozs5TvnNKTD3OKQtu2EpHMjN9fOISv3TOXP7y+mjumDPO7pIigoSBaiPLKKv74xhqmPPgJpeWVPHjFSP5x1aiwvPnX1DUlgWtO68kl2ZkUHzjE397L4bVlBZQfarlfwXfOcevMlazZXsL5Q9MbvGUlTaNf50SmndGL5xfl88mGXX6XExEUAC3A2u0lTP7bxzwwZwOXjurGv394BhMGdW6y0zbNjOFdk7npK30Z1TOFjzfs4s531rFm274mWX9Te/CDXJ6Yt5nTszowpleq3+XICbhxXBbdUhL4+UvLKWvBH1LqSgHQjFVXOx6em8vX7/2QwpIyHpmaze0XDfFtX3Tr2GgmD8/g2jN7kxAbw+PzNvP8orwWNSDXcwvy+OMbazh/aBcmDOrsdzlyglrHRvP7CweTW7Sfv769zu9yfKcAaKaK91cw9bFP+d1rqzmjbxpv/uAMxg/o5HdZQGi30PfO7s3Z/dJYkreHu2evY+32Er/LarCZSwu4+cVlnJ7Vgf+7ZFigDni3JKdnpfHt0d14aG4uCzYV+12OrxQAzdCSvD2cf89c5ucW8/sLB/PQlSPp0DbO77K+ICYqinMGdubaM3sT3yqaGZ9s4sXF+c12s/utldu56dkljOqewvQrsomLCeZpry3Fz84bQGZya378r6UcqKj0uxzfKACaEecc/5y3mUse+AQz4/nrxvLt0d0jeoiGzOQErj+7D2f2TWPR5t3cPXs963c2r62BV5Zs5XtPLmZwRnse+U62hnloAdrExfCXi4expfgAv311ld/l+EYB0EwcrKjiR88t5Rcvr2Bs71Rm3XgaQzOT/C6rTlpFRzFhUGhroFV0FI99tImXP9vaLM4UenbBFn7w7BKyuyfzz2tGkxjfyu+SpJGM6ZXKdWf25pkFebyyZKvf5fhC3wNoBjYW7ee6fy5i7Y4SfviVvtw4rk+z/Dp715QEbhzXh3dW7eDDnCLW7SzhGyMy6Z0WeadROuf469vruPfdHM7om8aDl4/UJ/8W6KZz+vLpxmJ+9uJyhmS0p1cE/i2Gk7YAItxbK7dzwb0fsn1fGY99ZxTf/0pWs3zzP6xVdBQTh3Rh2hm9iDbjkQ83MnPpVsorI2droOxQFd9/Zgn3vpvDJdmZPHyldvu0VDHRUdzzrZNoFRPF955czP7yYB0PUABEqMqqam5/YzX//cQieqa1YdaNp3FWv47HX7CZ6J7ahhvHZXGXhHlOAAAMEklEQVRq71Tm5xZz77s5bCza73dZ5BaWMvlvHzFzaQH/e24//vSNoXUaN0mar/Sk1tz1zeGs21HCj/+1lOrq4AxCrL/sCLSzpIzLH5nPg3NyuWx0N/517Vgyk1vepQVjY6L42tB0rjm9FwAPz81l1rICKiqrm7wW5xwvLs7n6/d+yI59ZTx21Si+d1afiD7ALo3nrH4d+dl5A3hjxXbunr3e73KajI4BRJgP1xfxg2c/o7S8kjumDOPikZl+lxR2PTu04X/GZfHmyu18vGEXKwv28dWBnRjWNalJzrUv2HOQn7+0nPfWFpLdPZl7vnVSIMYzki+6+rSerNlewt2z19M9NYGLRrT8/z0FQISorKrm7tnrue+9HPqkteWp746hb4AuLBIbE8UFw9IZktGe15dv41+L8vkwp4iz+3VkYHp4xjPaX17Jw3M3Mv2DDVQ7+NX5A5l6Sg+im/ExFqm/0IBxgynYc5CfPL+MpIRWjOsfGV+uDBcFQATYsa+MG5/+jE83FjNlZCa/mTSIhNhg/mp6dmjDdWf1Zln+Xmav3sFTn26hY2IcUWZcMDydto0wzMWeAxU8syCPh+dupKi0nImDO/Oz8wbQNaXl7WaTExMXE830K7O57KF5XPfPxTxx9WhO7tlyryJmoWH6I1N2drZbuHCh32WE1axlBfzy5RWUV1bzu8mDfd/srHlFML9VVTtWbN3LnHWFbN9XRp
vYaM4bEhqD57SsDsSfwEVoKiqr+XhDEa8v38bMpQWUHarmlN6p/HhCP0Z0S65XfZH0s4pkl41ufhdg2VVazpQHP2H73jIenprNKb07+F3SCTGzRc657OO1C+bHzAhQvL+CX768gteWb2NYZnv+75LhGlb4CNFRxrCuSQzNbM+A9HY8NX8Lb67Yzr8W5RPfKoqhGUkM69qerI6JdEmKp0PbOKKjDOdCP9+dJWXk7CxlSd4elmzZQ0l5JW3jYrhgWDrfOaVn2HYtSfOX2jaOZ6aN4fKH53PVYwt48IqRLeosvMMUAD54c8V2fvHycvYePMRPJvTjv8/oRUy0Tsg6GjNjRLdkRnRL5g8XDmH+xl28t6aQz/J2M+PjzVRUHf2soSiDfp3b8fXh6Yzv35FT+5zYloMEV8fEeJ6ZNpbLH57PNTMW8ocLh3DJqJZ1zWcFQBPauucgf3h9Na8t28bgjHb885rRYb9gS0sTGxPF6VlpnJ6VBoR27ezYV0bBnoMU76+g2oHDkZIQS8d2cWQkJehLXFJvKW1ieXraGG54ajH/+8IyNhSW8r/n9m8xJwooAJrAwYoqHpizgQc/2IBz8KNz+nLtWaFxcaRhYmOi6JqSoAO4EjbtW7fise+M4jevruLBD3JZWbCPv35zGB0T4/0urcEUAGHknOPVZdu4/fXVbNtbxteHpXPLxP5k6BxzkWYlJjqK2yYPZlB6O259dSUT75rLHVOGcXb/5n1cQAEQBs45Zq/eyT3vrmdZ/l4GZ7Tj7ktPatGnk4kEwaUnd2Nk92RueOozrvrHAi4Yls4vzx9IWmJkXY+jrhQAjaiisprXlhfw8NyNrCzYR7eUBP5y8VAuGpHZYvYZigRdVqdEZt54Kve/v4G/v7eB99fu5Pqz+zD1lB7N7gQDBUAjyN99gOcW5vPsgi3s2FdOn45t+fPFQ7nwpAzt5xdpgeJiovnBV/py/tB0bpu1itvfWMNjH23i+rN7c/HIrs3mxAMFQD3tOVDBmyu2M2vZNj7aUATAGVlp/OkbPTgjK61ZD9ksInXTp2NbZvzXyXyyYRd/eWsNv3xlJf/39jouO7kbU7K70rNDG79LPCYFQB1VVlWzNH8PH6wr4sOcIpbk7aGq2tEjNYEbx2VxSXZmixyxU0SOb2zvVF647hQWbd7NQ3NzeWDOBv7+/gZGdk/mvCFdOGdAJ7qlRt77gwKgFs458ncfZGXBPlYV7GVFwT4WbCympLwSMxiamcR1Z/bm3MGdGZTeTkMGiwhmRnaPFLJ7pLBjXxkvfbaVlxZv5bZZq7ht1ip6pbVhdM9URvdMYUCXdvTs0Mb3a000eQCY2bnA3UA08LBz7o9NtW7nHAcqqigpq6Sk7BD7yirZc6CCrXsOsnX3QfK9+w2FpZSUha4MFGXQO60tXxvahTP6pnFK71SSEmKbqmQRaYY6tYvn2jN7c+2Zvdmy6wBvr97BRzlFzFpawNOfhsaQahVt9OrQlr6dE+nVoQ0d28XRoW0caYlxpHn34T6o3KQBYGbRwN+Ac4B8YIGZzXTOrWrM9ewqLWfM7bOJMiM6yjDgULWjsqqao13sJzY6ivSkeDKSW/P1YekMSm/HoPT29OuU2GwO6IhI5OmWmsDVp/Xk6tN6UlXtWLu9hHU7Sli7o4T1O0pYkrebV5cWfGm5wRntmHXj6WGtram3AE4GcpxzuQBm9gwwCWjUAIhvFc01p/eiutpRVe1wQEy00SoqirbxMSTGx9AuvhWJ8TEkJcSSnhRPhzZxOnArImEVHWUMTG/3pYEID1VVs6u0gqLScgpLyiksLad1E5xS2tQBkAHk1XieD4yu2cDMpgHTvKelZra2iWqLBB2AIr+L8Mkx+/7tJizEBy36936M312L7vdxhLvv3evSqKkDoLaP2F/YKeOcmw5Mb5pyIouZLazLGN4tkfoevL4Htd8QOX1v6kPQ+UDN8VQzgS/v/BIRkbBr6gBYAGSZWU8ziwUuBWY2cQ0iIk
IT7wJyzlWa2Q3AW4ROA33UObeyKWuIcIHc9eVR34MnqP2GCOl7RF8TWEREwkcjlYmIBJQCQEQkoBQAYWRm55rZWjPLMbNbjtHuYjNzZpbtPY81s8fMbLmZLTWzs2q0/aaZLTOzlWb25yboRr0cr+9m9h0zKzSzJd7tmhrzpprZeu82tcb0kd7PJMfM7rEIHIQpTP3+vZnlmVlpU/WjPhq772aWYGavmdka7++9yYaNOVFh+r2/6f3/rzSzB7yRFBqXc063MNwIHeTeAPQCYoGlwMBa2iUCHwDzgGxv2vXAY97jjsAiQmGdCmwB0rx5M4Dxfve1Pn0HvgPcV8uyKUCud5/sPU725n0KjCX0fZI3gIl+97WJ+j0G6AKU+t3Hpuw7kACc7bWJBeZG2u88zL/3dt69AS8AlzZ27doCCJ//DHvhnKsADg97caTbgD8DZTWmDQRmAzjndgJ7gGxCf2DrnHOFXrt3gG+Ep/wGqWvfazMBeNs5V+yc2w28DZxrZl0I/UN84kL/FY8Dk8NRfAM0er8BnHPznHPbwlJx42n0vjvnDjjn3gPwXnMxoe8ORZpw/d73eW1iCAVLo5+xowAIn9qGvcio2cDMTgK6OudmHbHsUmCSmcWYWU9gJKEv0OUA/c2sh5nFEHoD7ErkOW7fPd/wdmc9b2aH+3G0ZTO8x8d7TT+Fo9/NRVj7bmZJwNfxPhhFmLD13czeAnYCJcDzjVo1CoBwOuawF2YWBdwJ/KiWdo8S+kNYCNwFfAxUep8QrgOeJbQ5vAmobNSqG8dxh/wAXgV6OOeGEtqSmXGcZevymn4LR7+bi7D13fuw8zRwj/MGkowwYeu7c24Cod1/ccC4hpf6RQqA8DnesBeJwGDgfTPbRGg/70wzy3bOVTrnfuicG+6cmwQkAesBnHOvOudGO+fGAmsPT48wxx3ywzm3yzlX7j19iNBWzrGWzeeLm/+ROIxIOPrdXISz79OB9c65uxq14sYT1t+7c66M0IgJdd2tVHd+H0BpqTdC++1ygZ58fmBo0DHav8/nB4ETgDbe43OAD2q06+jdJwNLgL5+97U+fQe61Hh8ITDPe5wCbPT6l+w9TvHmLSAUlIcPAp/nd1+bot812kfyQeBw/c5/R+gAaJTffWzKvgNtDy/jvf6zwA2NXrvfP7yWfAPOA9YROkPg59603wIX1NK2ZgD0IPTpfjWhzcXuNdo9Tej6CasIw1kBTdV34HZgpffP8h7Qv8ay/0XoeEcOcFWN6dnACu8178P7Jnsk3cLU7z8T+qRY7d3f6nc/m6LvhD4NO+//YIl3u8bvfjZR3zsR+sCzzFvuXiCmsevWUBAiIgGlYwAiIgGlABARCSgFgIhIQCkAREQCSgEgIhJQCgARkYBSAIiIBNT/A8MHAeI1Vks8AAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "parameters = {'C': scipy.stats.uniform(0, 10), # Uniform distribution between 0 and 10\n", " 'penalty': ['l1', 'l2']\n", " }\n", "\n", "hyperparameter_tune_get_results(model=logistic_regression, parameters=parameters,\n", " model_name='Logistic Regression', num_rounds=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random Forests\n", "\n", "Because random forests are generally robust to overfitting and there aren't as many parameters to control as there 
are in gradient boosting, our hyperparameter tuning strategy doesn't have to be as nuanced.\n", "\n", "I've found that increasing the number of trees has the most direct impact on performance. Because random forests are relatively resistant to overfitting from additional trees (performance tends to saturate rather than degrade), we can usually increase them until our models take too long to train or there isn't much of a performance gain from using more trees. Scikit-learn's random forest implementation uses only 10 trees by default, but R's [randomForest](https://cran.r-project.org/web/packages/randomForest/randomForest.pdf#page=17) package uses 500 by default.\n", "\n", "That's the first level of complexity to control. After the number of trees, the next knob is the max depth, which governs overall model complexity; whether to raise or lower it depends on whether we need to reduce bias or variance.\n", "\n", "We can also control a few other components, like the number of features considered for each split or the minimum samples required for each split/leaf, but these usually have a smaller impact than the number of estimators or the max depth." 
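The strategy above can be sketched with scikit-learn's `RandomizedSearchCV` (the same mechanism our `hyperparameter_tune_get_results` helper wraps). This is a minimal, illustrative sketch: the dataset and the parameter ranges are stand-in assumptions, not the exact ones used in this post.

```python
import scipy.stats
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

# Small artificial dataset standing in for the one created earlier
X, y = make_classification(n_samples=500, n_features=20, random_state=0)

param_distributions = {
    'n_estimators': scipy.stats.randint(50, 300),   # first knob: number of trees
    'max_depth': [None, 5, 10, 30],                 # second knob: overall complexity
    'max_features': ['sqrt', 'log2'],               # features considered per split
    'min_samples_leaf': scipy.stats.randint(1, 5),  # minimum samples per leaf
}

random_search = RandomizedSearchCV(
    RandomForestClassifier(random_state=0),
    param_distributions,
    n_iter=4,   # kept tiny here; increase it in a real-world scenario
    cv=3,
    n_jobs=-1,
    random_state=0,
)
random_search.fit(X, y)
print(random_search.best_params_)
```

Note that `n_estimators` and `min_samples_leaf` are drawn from distributions while `max_depth` and `max_features` are sampled from discrete lists; `RandomizedSearchCV` accepts either form for each parameter.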
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:10:23.063855Z", "start_time": "2018-07-12T23:05:12.582658Z" }, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Default Parameters: \n", "\n", "RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=-1,\n", " oob_score=False, random_state=None, verbose=0,\n", " warm_start=False) \n", "\n", "Beginning hyperparameter tuning\n", "Fitting 3 folds for each of 30 candidates, totalling 90 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Done 5 tasks | elapsed: 16.1s\n", "[Parallel(n_jobs=-1)]: Done 10 tasks | elapsed: 56.8s\n", "[Parallel(n_jobs=-1)]: Done 17 tasks | elapsed: 1.2min\n", "[Parallel(n_jobs=-1)]: Done 24 tasks | elapsed: 1.5min\n", "[Parallel(n_jobs=-1)]: Done 33 tasks | elapsed: 1.9min\n", "[Parallel(n_jobs=-1)]: Done 42 tasks | elapsed: 2.7min\n", "[Parallel(n_jobs=-1)]: Done 53 tasks | elapsed: 3.2min\n", "[Parallel(n_jobs=-1)]: Done 64 tasks | elapsed: 3.6min\n", "[Parallel(n_jobs=-1)]: Done 77 tasks | elapsed: 4.6min\n", "[Parallel(n_jobs=-1)]: Done 90 out of 90 | elapsed: 5.0min finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Completed\n", "Best estimator: \n", "\n", "RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=30, max_features='log2', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=2, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=831, n_jobs=-1,\n", " oob_score=False, random_state=None, verbose=0,\n", " warm_start=False)\n", "\n", "Accuracy before tuning: 
0.523333333333\n", "Accuracy after tuning: 0.554444444444\n", "\n", " Tuned results:\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime NumIterations\n", "Logistic Regression 0.505556 0.697967 0.491222 8.296371 10.0\n", "Random Forest 0.554444 0.683471 0.574590 307.981350 30.0" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xd4XOWZ9/Hvrd57saol94KxwTbGGLAx4ECWYEgIJQScEEJYkpCQsmGzbzabTdlk08iGNAIBhwRClmYwCcUGm9AM7kWyLVdZxZJsq1td9/vHHBOtI1sja0Znyv25Ll0zc+acmVvH45+eec45zyOqijHGmOAX4XYBxhhjfMMC3RhjQoQFujHGhAgLdGOMCREW6MYYEyIs0I0xJkRYoBtjTIiwQDc+IyIHRKRbRLJOWr5ZRFRESka5nkUiUuWj11ojIrf74rWM8RcLdONr+4GbTjwQkRlAvHvlhAYRiXK7BhP4LNCNrz0K3Drg8TLg9wNXEJFYEfmRiFSKSJ2I/FpE4p3n0kVkpYg0iEijc79wwLZrROTbIvKmiLSKyMsnfyNw1ksE/grki0ib85MvIhEicq+I7BWRoyLyZxHJcLaJE5E/OMubROQ9EckVke8CFwH3O69z/yDvN+i2znMZIvKwiNQ4v9OzA7b7tIjsEZFjIvKciOQPeE5F5LMiUgFUOMumiMgrzvq7ROT64f8TmVBlgW587R0gRUSmikgkcAPwh5PW+QEwCZgFTAAKgH93nosAHgbGAsVAB3BygH4M+CSQA8QAXzm5CFVtB64EalQ1yfmpAe4GrgEWAvlAI/ALZ7NlQCpQBGQCdwIdqvpvwN+Azzmv87lBfu9Bt3WeexRIAKY7Nf8UQEQWA/8FXA/kAQeBP530utcA84Bpzh+pV4DHnNe5CfiliEwfpB4ThizQjT+caKVfDuwEqk88ISICfBq4R1WPqWor8D3gRgBVPaqqT6nqcee57+IJ34EeVtXdqtoB/BnPHwZvfQb4N1WtUtUu4D+A65wujR48YTxBVftUdYOqtnj5uoNuKyJ5eP6w3Kmqjarao6prnW1uBn6nqhudWv4VmH/SsYb/cvZTB3AVcEBVH1bVXlXdCDwFXDeM39+EMOuXM/7wKPA6UMpJ3S1ANp7W6gZPtgMgQCSAiCTgacFeAaQ7zyeLSKSq9jmPDw94veNA0jBqGws8IyL9A5b1AblO3UXAn0QkDc83i39T1R4vXnfQbZ1lx1S1cZBt8oGNJx6oapuIHMXzjeWAs/jQSbXPE5GmAcuinPc2xlroxvdU9SCeg6MfBJ4+6ekjeLoipqtqmvOTqqonQvnLwGRgnqqmABc7y4XhG2wo0UPAlQPeO01V41S12mk9f0tVpwEX4GkR33qa1/r7G51620NAhhPyJ6vBE9LA+/3+mQz4RnPS+x4C1p5Ue5Kq/vPpajPhwwLd+MungMVOX/b7VLUf+C3wUxHJARCRAhH5gLNKMp7Ab3IOVn5zBDXUAZkikjpg2a+B74rIWOe9s0VkqXP/EhGZ4fT9t+DpRukb8FrjTvVGp9pWVWvxHJz9pXPAN1pETvyRegz4pIjMEpFYPF1P61T1wCneZiUwSURucV4nWkTmisjU4e4YE5os0I1fqOpeVV1/iqe/BuwB3hGRFmAVnlY5wH14TnM8gucA64sjqGEn8DiwzznzJB/4GfAc8LKItDrvMc/ZZAzwJJ5ALgfW8vcDuj/D09feKCL/M8jbnW7bW/AE/E6gHviiU99q4Bt4+sFrgfE
4xxJO8fu0AkucdWrwdD39AIj1eqeYkCY2wYUxxoQGa6EbY0yIsEA3xpgQYYFujDEhwgLdGGNCxKheWJSVlaUlJSWj+ZbGGBP0NmzYcERVs4dab1QDvaSkhPXrT3UmmzHGmMGIyEFv1rMuF2OMCREW6MYYEyIs0I0xJkRYoBtjTIiwQDfGmBBhgW6MMSHCAt0YY0KEBboxxoQIC3RjjAkRNqeoCSiPrat05X0/Nq/Ylfc1xpeshW6MMSHCq0AXkXtEZIeIbBeRx0UkTkRKRWSdiFSIyBMiEuPvYo0xxpzakIEuIgXA3cAcVT0LiMQzp+EPgJ+q6kSgEc+kwMYYY1zibZdLFBAvIlFAAp4JbRfjmRQXYDlwje/LM8YY460hA11Vq4EfAZV4grwZ2AA0qWqvs1oVUOCvIo0xxgzNmy6XdGApUArkA4nAlYOsqqfY/g4RWS8i6xsaGkZSqzHGmNPwpsvlMmC/qjaoag/wNHABkOZ0wQAUAjWDbayqD6jqHFWdk5095IQbxhhjzpA3gV4JnC8iCSIiwKVAGfAacJ2zzjJghX9KNMYY4w1v+tDX4Tn4uRHY5mzzAPA14EsisgfIBB7yY53GGGOG4NWVoqr6TeCbJy3eB5zn84qMMcacEbtS1BhjQoQFujHGhAgLdGOMCREW6MYYEyIs0I0xJkRYoBtjTIiwQDfGmBBhgW6MMSHCAt0YY0KEBboxxoQIC3RjjAkRFujGGBMiLNCNMSZEWKAbY0yIsEA3xpgQYYFujDEhwgLdGGNCxJCBLiKTRWTzgJ8WEfmiiGSIyCsiUuHcpo9GwcYYYwbnzZyiu1R1lqrOAmYDx4FngHuB1ao6EVjtPDbGGOOS4Xa5XArsVdWDwFJgubN8OXCNLwszxhgzPMMN9BuBx537uapaC+Dc5gy2gYjcISLrRWR9Q0PDmVdqjDHmtLwOdBGJAa4G/nc4b6CqD6jqHFWdk52dPdz6jDHGeGk4LfQrgY2qWuc8rhORPADntt7XxRljjPHecAL9Jv7e3QLwHLDMub8MWOGroowxxgyfV4EuIgnA5cDTAxZ/H7hcRCqc577v+/KMMcZ4K8qblVT1OJB50rKjeM56McYYEwDsSlFjjAkRFujGGBMiLNCNMSZEWKAbY0yI8OqgqDHGPx5bV+nK+35sXrEr72v8y1roxhgTIizQjTEmRFigG2NMiLA+dGOAmqYO3tp7lPrWTo61dZMQG0VhejwTc5I4uzCNyAhxu0RjhmSBbsKWqrL5UBPr9h/j689se395fHQknb19qHoepydEs2hyDtfNLuSC8ZmIWLibwGSBbsJSQ2sXz26uZv+RdnKSY/nKkkksmT6GovQE4mMi6e7tp7a5g61Vzby2s55Xd9XzzKZqJuQkcfuFpXxkdiHRkdZjaQKLBboJO2U1LfzpvUqiIoVrZhUwpySdj58/9v+sExMVwdjMRMZmJvKhmfl09vSxcmstD7+5n3uf3sav1+7lnssn8aGz84mw7hgTIKyJYcLKpspGHnv3IHmpcdxz2STOK80gwosulLjoSK6bXcjKz1/Ig7fOIS46ki/8aTPX/+ZtymtbRqFyY4ZmgW7CxsbKRv53QxUlWYnctqCU5LjoYb+GiHDZtFz+cvdF/PdHzmbfkXau+vkbfHtlGa2dPX6o2hjvWaCbsFDVeJxnN1UzLjuRZfNLiI2OHNHrRUQI188t4tUvL+SGuUX87s39XPrjtazcWoOeOJpqzCizQDchr62rlz+uqyQpLoqb5hb79GBmWkIM37t2Bs/ctYCclFg+99gm7vzDBhpau3z2HsZ4y9sZi9JE5EkR2Ski5SIyX0QyROQVEalwbtP9Xawxw6Wq/Hn9Idq7erl53lgSY/1zHsCsojRWfPZC7r1yCq/tamDJT9fy3BZrrZvR5W1T5WfAi6o6BZgJlAP3AqtVdSKw2nlsTEDZWNnEnvo2/unsPArS4v36XpERwp0Lx/OXuy+kODORux/fxD//YSNH2qy
1bkbHkIEuIinAxcBDAKrarapNwFJgubPacuAafxVpzJlo7+rlr9trGZuRwNySjFF73wk5yTx153z+5YrJvLqznivue501u+pH7f1N+PKmhT4OaAAeFpFNIvKgiCQCuapaC+Dc5gy2sYjcISLrRWR9Q0ODzwo3Zigvbj9MZ08fS88p8OrURF+KiozgrkUTeP7zF5KZGMsnHn6P76wso6u3b1TrMOHFm0CPAs4FfqWq5wDtDKN7RVUfUNU5qjonOzv7DMs0ZngOHm1nQ2UjF07IZkxKnGt1TB6TzIrPLeDW+WN58I39fORXb7Gvoc21ekxo8ybQq4AqVV3nPH4ST8DXiUgegHNr3ylNQFBVXi6rIyk2isVTBv3iOKrioiP5z6Vn8cAts6lq7OCqn7/Bs5uq3S7LhKAhA11VDwOHRGSys+hSoAx4DljmLFsGrPBLhcYM096GdvYfaWfR5GxiogLnzNwl08fw1y9cxFkFqXzxic186/kd9PXbWTDGd7w9h+vzwB9FJAbYB3wSzx+DP4vIp4BK4KP+KdEY76kqq8rrSI2PHtUDod7KS43nj7fP43t/KefhNw9QmpXIzfOKSYixYZXMyHn1KVLVzcCcQZ661LflGDMyu+taqTx2nKWz8gN2NMToyAi++aHpzChI5atPbuXXa/eybH4JmUmxbpdmglxgfuKNOUOv7qwnPSGa2WMD/zq3D59byG0LSmnv6uNXa/dy6Nhxt0syQc4C3YSMyqPtHGrsYMGELKIiguOjXZqVyD8vGk9sVAQPvbmffUfsDBhz5oLjU2+MF97Ye5S46IigaJ0PlJUUyx0Xjyc1PppH3jzA7rpWt0syQcoC3YSExuPd7KhuZm5JBrFRIxtJ0Q2p8dF8+qJx5CTH8od3DrLXzlU3Z8AC3YSEt/ceRQTmj8t0u5QzlhQbxW0LSslMiuHRtw9SebTd7ZJMkLFAN0Gvq7eP9QePMT0/lbSEGLfLGZEEJ9ST46J45O0D1LV0ul2SCSIW6CbobatqprOnnwvGB2/rfKDkuGhuu7CU6IgIlr99wGZCMl6zQDdBb/3BRrKTYynOSHC7FJ9JT4jhlvljae/q5dF3DtLd2+92SSYIWKCboFbX0knlsePMHZuOjPKIiv5WmJ7ADXOKqG7s4JlNVTZZhhmSBboJausPHCNShFnFwXWqorem5ady2bRctlQ1887+Y26XYwKcBboJWr19/WysbGJqfgpJfppaLhAsnJTN5Nxk/rK1lkq7mtSchgW6CVpltS109PQxN8guJBquCBGun1NESnwUj79bSUe3TZJhBmeBboLWpsomUuOjGZ+T5HYpfhcfE8lN5xXT2tnDs5urrT/dDMoC3QSltq5eKupbmVmYNurTy7mlMD2BS6fmsq26mS1VTW6XYwKQBboJStuqm+lXmFWU5nYpo2rhpGzGZiSwYnMNTce73S7HBBgLdBOUNlc2MiYljjGp7s0X6oYIET46pwhVrOvF/AOvAl1EDojINhHZLCLrnWUZIvKKiFQ4t6F9ZMoEjKNtXRxq7GBmmLXOT8hIjGHJ9Fx217Wx+ZB1vZi/G04L/RJVnaWqJ2YuuhdYraoTgdXOY2P87kT/8czCVJcrcc/54zIpSo/nhW21tHX1ul2OCRAj6XJZCix37i8Hrhl5Ocacnqqy5VAzpVmJQT8Q10hEiPDhcwvp6unnL9tq3S7HBAhvA12Bl0Vkg4jc4SzLVdVaAOc2Z7ANReQOEVkvIusbGhpGXrEJa3UtXTS0dXF2GLfOT8hNiePiSdlsPtRkMx0ZwPtAX6Cq5wJXAp8VkYu9fQNVfUBV56jqnOzs7DMq0pgTtlU3IcD0fAt08Jz1kpYQzfNbaujrtwOk4c6rQFfVGue2HngGOA+oE5E8AOe23l9FGgOe7pZt1S2UZiWG9KX+wxETFcFVM/Koa+ninX1H3S7HuGzIQBeRRBFJPnEfWAJsB54DljmrLQNW+KtIY8DT3XKkrYuzCqx1PtDUvBQ
m5SaxqrzOxk4Pc9600HOBN0RkC/Au8IKqvgh8H7hcRCqAy53HxvjNtupmp7slxe1SAoqIcNWMfHr6+lldbl+Uw9mQ31tVdR8wc5DlR4FL/VGUMSdTVbZXN1OSlUhyXLTb5QScrORY5o3L5J29R5k/PpPclPC64Mp42JWiJijUtXrObplh3S2ndOnkHGKjI/jrdjuNMVxZoJugUFbj6W6ZZt0tp5QQG8XiyTnsrmujoq7V7XKMCyzQTVAoq22hKCOBFOtuOa3zx2WSnhDNS2WHbZyXMGSBbgJe0/Fuapo6mZpnrfOhREVGcOmUXGqaOtlR0+J2OWaUWaCbgFde6wmmaRboXplVnEZ2Uiyryuvot1Z6WLFANwGvrLaF7KRYspNj3S4lKESIcNm0XOpbu9hiozGGFQt0E9A6uvvYf6TduluGaXp+CnmpcazeWW9DAoQRC3QT0HbVtdCvdnbLcEWIcPm0XI61d7PhYKPb5ZhRYoFuAlpZTQvJsVEUpse7XUrQmZybTHFGAq/tqqenr9/tcswosEA3Aaunr5/d9W1MyUsJm4mgfUmcVnpzRw/v7j/mdjlmFFigm4C1r6GN7t5+O7tlBMZnJzEuO5E1u+rp6u1zuxzjZxboJmCV1bYQExXB+OxEt0sJakumjaG9u491+6yVHuos0E1A6lelvLaVSbnJREXax3QkijMSmJCTxN/2HLG+9BBn/1NMQKo6dpy2rl7rbvGRSybn0N7Vy3sHrJUeyizQTUAqq20hQjxnapiRK81KpCQzgb9VHKHXWukhywLdBKSy2lbGZSURHxPpdikh45LJOTR39LCp0q4eDVVeT8woIpHAeqBaVa8SkVLgT0AGsBG4RVW7/VOmCSf1rZ0caeti/vjMUXvPx9ZVjtp7uWVCThKF6fGs2V1Pb1+/HZsIQcP5F/0CUD7g8Q+An6rqRKAR+JQvCzPhq9wZJXDqGOtu8SUR4ZLJOTQe72HF5hq3yzF+4FWgi0gh8E/Ag85jARYDTzqrLAeu8UeBJvyU1bZQkBZPWkKM26WEnCljkslLjeMXa/bYGC8hyNsW+n3AvwAnjqZkAk2q2us8rgIKBttQRO4QkfUisr6hoWFExZrQ19LZw6HGDhuMy09EhEWTc9jX0G5T1YWgIQNdRK4C6lV1w8DFg6w66J97VX1AVeeo6pzs7OwzLNOEi/fHPrfBuPxmen4K47MTuf/VPfRbKz2keNNCXwBcLSIH8BwEXYynxZ4mIicOqhYC1ilnRqy8toWMxBhybexzv4kQ4a5FE9h5uJVXd9a7XY7xoSEDXVX/VVULVbUEuBF4VVVvBl4DrnNWWwas8FuVJiy0dvawt6GdaXkpiA3G5VdXz8qnIC2eX67ZY3OPhpCRnLf0NeBLIrIHT5/6Q74pyYSrtbsb6OtX6z8fBdGREXxm4Tg2VjaxzkZiDBnDCnRVXaOqVzn396nqeao6QVU/qqpd/inRhIuXdtSREBPJ2MwEt0sJC9fPKSIrKYZfrtnrdinGR+zKAhMQunr7eG1nPdNs7PNRExcdyScXlPL67ga2VTW7XY7xAQt0ExDe2nvUMxiXnd0yqm6ZP5bk2Ch+tXaP26UYH7BANwHh5R2HSYyJZHx2ktulhJWUuGhumT+Wv24/zN6GNrfLMSNkgW5c19evvFJWx6IpOUTb+CKj7rYLS4mJjOA3a60vPdjZ/x7juo2VjRxp6+YD08e4XUpYykqK5Ya5RTyzqZra5g63yzEjYIFuXPfS9sPEREZwyWS7ktgtn75oHP0Kv319v9ulmBGwQDeuUlVeKjvMggmZJMdFu11O2CrKSGDpzHwef7eSY+02CnawskA3riqrbeHQsQ7rbgkAdy4aT0dPH4+8aa30YGWBblz10o46IgQum5brdilhb1JuMkum5fLIWwdo6+odegMTcCzQjate3nGYOWMzyEqywbgCwV2XTKCls5fH1h10uxRzBizQjWsOHm1n5+FWlky31nmgmFWUxgXjM3n
wb/vp7OlzuxwzTBboxjUv7TgMYP3nAeauRROob+3iqY1VbpdihsnrSaKN8bWXdtQxPT+FogwbjGu0nW5SbFWlIC2eH7+8m/5+iIzw3dg6H5tX7LPXMv/IWujGFfWtnWysbLTWeQDyTFOXzbH2brZX26BdwcQC3bjixe2HUbXulkA1NS+F7ORY1u5usAkwgogFunHFyi21TMxJYvKYZLdLMYOIEGHhxGwOt3Syq67V7XKMl7yZJDpORN4VkS0iskNEvuUsLxWRdSJSISJPiEiM/8s1oeBwcyfvHTzGVWfnu12KOY2ZRWmkxUezZpe10oOFNy30LmCxqs4EZgFXiMj5wA+An6rqRKAR+JT/yjSh5IVttajCVTPz3C7FnEZkhHDhxCwqjx3nwNHjbpdjvODNJNGqqicGSo52fhRYDDzpLF8OXOOXCk3IWbm1hql5KTb2eRCYMzaDxJhI1u6ud7sU4wWv+tBFJFJENgP1wCvAXqBJVU9cH1wFFJxi2ztEZL2IrG9oaPBFzSaIVTUeZ1NlEx+y1nlQiImKYMGELHbXtVHTZEPrBjqvAl1V+1R1FlAInAdMHWy1U2z7gKrOUdU52dk2PGq4e2FrLQBXzbD+82AxrzST2KgI1u62BlmgG9ZZLqraBKwBzgfSROTEhUmFQI1vSzOhaMXmGmYWplKcaRcTBYv4mEjmlWayvbqZI61dbpdjTsObs1yyRSTNuR8PXAaUA68B1zmrLQNW+KtIExrKa1soq23hw+cWul2KGaYFEzKJjBBer7BWeiDzpoWeB7wmIluB94BXVHUl8DXgSyKyB8gEHvJfmSYUPLOpmqgI4UMzrbsl2CTHRTN7bDqbKpto7uhxuxxzCkOO5aKqW4FzBlm+D09/ujFD6utXnt1UzaLJOWQk2iULweiiidm8d+AYb1Q08E92DUFAsitFzah4c88R6lu7+Mi5g54MZYJARmIMZxem8d6BRo7bBBgByQLdjIqnN1aREhfF4qk5bpdiRmDhpGy6+/p5a99Rt0sxg7BAN37X2tnDizsOc9XMfGKjIt0ux4xAbkocU8ck8/beo3TZBBgBxwLd+N2zm2vo7Onno7Pt7JZQsHByDh09fbx74JjbpZiTWKAbv1JVHltXybS8FGYVpbldjvGB4owExmUl8uaeI/T29btdjhnAAt341eZDTZTXtvCxecWI+G7mG+OuhZOzaensZVNlk9ulmAEs0I1f/XFdJQkxkSydZae5hZIJ2UkUpMXzekUD/Ta0bsCwQDd+09zRw8qtNSydlU9yXLTb5RgfEhEWTsrmqE1TF1BskugAd7rJfP3JF5P5PrWhis6efj523lgfVGQCzbT8FLKSPNPUzShItS61AGAtdOMXvX39PPzWfs4tTmNGYarb5Rg/iBBh4aQsaps72W3T1AUEC3TjFy/tqOPQsQ7uuHi826UYP5pZlEZqfDRrbGjdgGCBbnxOVXng9b2UZCZw+bRct8sxfhQVEcFFE7M4ePQ4+460Db2B8SsLdONz7x1oZEtVM5+6aByREdavGurmjM0gOTaKVWX1Npm0yyzQjc898PpeMhJjuM7GPQ8LMVERLJyczYGj7extaHe7nLBmgW58asuhJlaV1/OJC0qIj7FxW8LFeSUZpMZH80rZYWulu8gC3fjUj17eRUZiDLddWOp2KWYURUVGsHhyDocaO9hlZ7y4xpsp6IpE5DURKReRHSLyBWd5hoi8IiIVzm26/8s1geytvUf4W8UR7lo0nqRYu8Qh3Jw7Np2MxBhWldVZK90l3rTQe4Evq+pUPJNDf1ZEpgH3AqtVdSKw2nlswpSq8qOXdjEmJY6Pn28XEoWjyAhh8ZQcapo72VHT4nY5YWnIQFfVWlXd6NxvxTNBdAGwFFjurLYcuMZfRZrA9+L2w2ysbOLzl04gLtr6zsPVrKI0spJiWVVeZ2O8uGBYfegiUoJnftF1QK6q1oIn9AGbiiZMtXX18q3ny5ial8INc4rcLse4KEK
Ey6bmUN/axbYqG+NltHkd6CKSBDwFfFFVvf4+JSJ3iMh6EVnf0GBXk4WiH7+8i7rWTr537VlERdpx9nB3VkEqY1LiWFVeR1+/tdJHk1f/+0QkGk+Y/1FVn3YW14lInvN8HlA/2Laq+oCqzlHVOdnZ2b6o2QSQ7dXNLH/rADfPK+acYjsubjyt9Mun5XK0vZv3bFajUeXNWS4CPASUq+pPBjz1HLDMub8MWOH78kwg6+ju48t/3kJGYixf/cAUt8sxAWTKmGRKMhNYvbPe5h4dRd600BcAtwCLRWSz8/NB4PvA5SJSAVzuPDZh5BsrtrO7vpWfXD+T1Hgb79z8nYhw5Vl5tHf18nqFdbWOliFPFlbVN4BTDchxqW/LMcHiz+sP8eSGKu6+dCIXT7KuNPOPijISOLswlTf2HOG80kz7oz8K7AiWGbZ39h3lG89u54LxmXzh0olul2MC2JJpY+hXWFVW53YpYcEC3QzLxspGbnvkPYozEvj5TefYaIrmtDISY5g/LpONlY3UNne4XU7Is0A3XttY2ciy371LTnIsf7x9HplJsW6XZILAosnZxEVH8uL2w26XEvIs0M2QVJXH1lVy42/eIT0hhj9++nxyUuLcLssEiYSYKC6ZnE1FfRuv28xGfmWBbk6robWLL/15C19/Zhvnj8/kuc8toCAt3u2yTJA5f1wm6QnRfPeFcnr7+t0uJ2RZoJtBNXf08Ms1e7jkR2t4fksNd186kYc/MZe0hBi3SzNBKCoygivPymNXXSuPvHXA7XJClo1xat7X0d3HgaPtbKtu5lvP76Crt5/Lpuby9Q9OYVx2ktvlmSA3PT+FhZOyuW9VBR+amU+uddv5nAV6mFBVOrr7aO7soaWjl5aOHud+Dy2dPTQd76GhtQsFYqMiuG52ITfMLeLswjS3SzchQkT41tXTWXLf63znhXJ+ftM5bpcUcizQQ0xvXz+1zZ0cajxOfWsXR9u6aDzuCe7ekwZKEiAxNoqU+CgyEmOYUZBKaXYiRekJLLugxJX6TWgryUrkzoXj+Z/VFdw4t4gFE7LcLimkWKCHgJaOHspqWyivbWH/kfb3gzsuOoKspFgK0+NJzUshJT7a8xMXRUp8NMlxUURF2GEUM7ruWjSeZzdV840V23nxCxcTE2WfQV+xQA9SqkpFfRvr9h9j1+EW+hUyE2OYV5rB2MxEijISSImLwjO2mjGBIy46km9dPZ1PPvIev/3bPj57yQS3SwoZFuhBRlXZebiV1eV11DR3khgTyYUTsjm3OI3s5FgLcBMULpmSw5Jpufz81QqWzsqnMD3B7ZJCgn3XCSI1TR385vUBWYokAAAL+UlEQVR9PPrOQTp7+/nIuYV87copXHHWGHJS4izMTVD59w9NA+CbK3bYpNI+Yi30INDV28dLO+pYt+8oCTGRfPicAs4pTrdxVExQK0xP4CtLJvOdF8p5dnM1155T6HZJQc8CPcAdONLOkxuraGzvZt64DC6fOob4GJuE2YSGTy4o5S/bavmP58pYMD7LhpQYIetyCVD9/cp9q3bz27/tQ1W5/aJxXD2zwMLchJTICOGHH51JR08f/+/Z7db1MkLeTEH3OxGpF5HtA5ZliMgrIlLh3Npkkj7UfLyH23+/nvtWVTCrKI27F0+kNCvR7bKM8Yvx2Ul8ZckkXi6r48kNVW6XE9S8aaE/Alxx0rJ7gdWqOhFY7Tw2PrDzcAtX/+INXt/dwLeXTue62YXERlur3IS2T104jnmlGfzHczs4eLTd7XKC1pCBrqqvAydP3b0UWO7cXw5c4+O6wtKKzdVc+4u36Oju44nPnM8t80vszBUTFiIjhJ/cMIuICOGeJzbbiIxn6Ez70HNVtRbAuc3xXUnhR1X5ycu7+MKfNjOjIJWVd1/I7LEZbpdlzKgqSIvnu9fOYGNlE/etqnC7nKDk97NcROQO4A6A4uJif79d0Onu7efep7by9KZqbphTxHeuPYvoSDtWbcLT1TPzeaO
igftf28PsknQumWxtxeE40+SoE5E8AOe2/lQrquoDqjpHVedkZ9vs8AM1d/Sw7Hfv8vSmar6yZBLf/8gMC3MT9v5z6VlMGZPMPU9sprrJ5iEdjjNNj+eAZc79ZcAK35QTPqoaj/PRX7/F+oPH+OkNM/nc4onWX24MnrFefvXx2fT2KXf9YQOdPX1ulxQ0vDlt8XHgbWCyiFSJyKeA7wOXi0gFcLnz2Hhpe3Uz1/7yLWqbO1l+23l2hZwxJynNSuTH189kS1Uz9z611c5P99KQfeiqetMpnrrUx7WEhdd21vPZxzZ6Jlu+fR6TcpPdLsmYgPSB6WP46gcm88OXdjFpTDJ3LbJRGYdil/6Poj+8c5BvPreDqXnJ/G7ZXLvM2Zgh3LVoPLsOt/LDl3YxNiORfzo7z+2SApoF+ijo61e+95dyHnpjP4un5PDzm84hMdZ2vTFDERH++7qzqW3u4J4nNpOeEM0FNsvRKdkpFX7W3tXLZx7dwENv7OcTF5Tw21vnWJgbMwxx0ZE8eOtcSrMSuePRDWyvbna7pIBlge5Hh5s7uf43b/Pqzjq+dfV0/uPq6TbkrTFnIDUhmuW3nUdqfDQff2idhfopWKD7yabKRq75xZscONLOQ8vm2qTLxozQmNQ4Hv/0+SRER1qon4IFuo+pKo++fYDrf/M2UZHC/955AZdMsavdjPGF4swE/nTHfBJjorj5wXWsP3DyMFPhzQLdh4539/LFJzbzjRU7uGhiNis/fyHT8lPcLsuYkOIJ9fPJSIzh5gfX8fKOw26XFDAs0H1kT30bS+9/k+e31PDVD0zmwVvnkJYQ43ZZxoSkoowEnrxzPlPyUrjzD56TDuziIwv0EVNV/rjuIFff/wZH27v5/W3z+OwlE4iwg5/G+FVmUiyPf3oel03N5dsry/jiE5vp6A7vYQLs/LkRqG3u4N6ntrF2dwMXTsjihx89m7zUeLfLMiZsJMRE8euPz+YXr+3hJ6t2s+twK/fdOIspY8Kzq9Na6Gegt6+fh97Yz2U/Xsu6/Uf5z6XT+f1t51mYG+OCiAjh85dO5OFPzOVIWxdX//xNHvzbPvr6w68LxgJ9mN6oOMKH7n+Tb68sY25pBq/cs5Bb55dYF4sxLls0OYeXvngxCydn850Xyrn2l2+ytarJ7bJGlXW5eGlrVRM/eWU3a3Y1UJgezy9vPpcrzxpjQ94aE0Ayk2J54JbZPLelhu+8UM7SX7zJ9bOL+MJlE8lPC/1v0Bbop6GqrNt/jF+v3cuaXQ2kxkfz9Q9O4db5JcTZxM3GBCQRYemsAi6ZksPPVlXw6NsHeWZzNTfPK+b2i8ZREMLBboE+iOaOHlZureHRtw+y83ArGYkx/MsVk7nl/LEkx0W7XZ4xxgspcdF846ppfHJBCT9bVcHv3z7I798+yAdn5HHzvGLmlWaE3DdsC3RHS2cPa3c18Nfttawqr6e7t5+peSn84CMzuHpmAfEx1iI3JhgVpifww4/O5J7LJ/HIWwd4fF0lz2+poTgjgWvOKeAD03OZlpcSEuEetoHe3tXLjpoW3tl3lLf2HmH9gUZ6+5WspBg+dl4xHz63gBkFqSHxj2yMgfy0eL7+wancc9kkXtxRy/+ur+Lnr1bwP6srKEyP56KJWcwfn8W5xWkUpMUH5f/9EQW6iFwB/AyIBB5U1YCbiq61s4fqpg6qGzvYVdfKjpoWymta2H+0HVUQgWl5KXzqolKWTMtlVlG6jYhoTAiLj4nk2nMKufacQo60dbGqrI7VO+tZubWWx989BEBaQjTT8lKYnp/C9PxUijMTKEyLJyspNqDPaDvjQBeRSOAXeOYUrQLeE5HnVLXMV8WdsLehjcb2brp7++nq66e7d8BPXz8d3X00dfTQfLybpo4emo73UNfSSU1TBy2dvf/ntQrT45men8LSWQVMz09h9th00hPtEn1jwlFWUiw3nlfMjecV09vXz46aFrZWNVFW28KOmhaWv32Q7t7+99ePiYw
gPy2OvNR40hOjSY2PITU+mrSEaJLjooiJjCAmKuLvt8796KgIpuWl+P1kipG00M8D9qjqPgAR+ROwFPB5oH97ZRlrdjWcdh0Rz0GQtIRo0uKjyU+LZ25JBgXp8RSkxVOQHs/47CRS4+2gpjHmH0VFRjCzKI2ZRWnvL+vp62f/kXYOHTvu+abvfNs/3NzJ7ro2mo730NzRTU/f0BcxrfrSQibkJPnzVxhRoBcAhwY8rgLmnbySiNwB3OE8bBORXSN4z2CTBRxxu4gzcfPovl3Q7qdRFvT7aZQ+VwG5nyb+YESbj/VmpZEE+mAdSf/wZ0pVHwAeGMH7BC0RWa+qc9yuI9DZfvKO7SfvhPN+Gsml/1VA0YDHhUDNyMoxxhhzpkYS6O8BE0WkVERigBuB53xTljHGmOE64y4XVe0Vkc8BL+E5bfF3qrrDZ5WFhrDsajoDtp+8Y/vJO2G7n8Rm+TDGmNBgw+caY0yIsEA3xpgQYYF+BkTkChHZJSJ7ROTeQZ7/hIg0iMhm5+d2Z/lYEdngLNshIneOfvWj50z304DnU0SkWkTuH72qR99I9pOI9A1YHtInJYxwPxWLyMsiUi4iZSJSMpq1jxpVtZ9h/OA5ALwXGAfEAFuAaSet8wng/kG2jQFinftJwAEg3+3fKdD204DnfwY8drp1gv1npPsJaHP7dwiS/bQGuNy5nwQkuP07+ePHWujD9/6QB6raDZwY8mBIqtqtql3Ow1hC+xvSGe8nABGZDeQCL/upvkAxov0URs54P4nINCBKVV8BUNU2VT3uv1LdE8qB4i+DDXlQMMh6HxGRrSLypIi8fwGWiBSJyFbnNX6gqqF6MdYZ7ycRiQB+DHzV/2W6bkSfJyBORNaLyDsico1fK3XXSPbTJKBJRJ4WkU0i8kNncMGQY4E+fN4MefA8UKKqZwOrgOXvr6h6yFk+AVgmIrl+q9RdI9lPdwF/UdVDhL4RfZ6AYvVc5v4x4D4RGe+fMl03kv0UBVwEfAWYi6fb5hP+KdNdFujDN+SQB6p6dEDXym+B2Se/iNMy34HngxaKRrKf5gOfE5EDwI+AW0Uk4Mba95ERfZ5OfMNTz6ina4Bz/Fmsi0ayn6qATU53TS/wLHCun+t1hQX68A055IGI5A14eDVQ7iwvFJF45346sAAI1dEnz3g/qerNqlqsqiV4WlW/V9V/OKshRIzk85QuIrHO/Sw8nyefD18dIM54PznbpotItvN4MSG6n8J2CrozpacY8kBE/hNYr6rPAXeLyNVAL3CMv3+9mwr8WEQUz1fIH6nqtlH/JUbBCPdT2PDB5+k3ItKPp3H2ffXDBDOBYCT7SVX7ROQrwGoREWADnhZ8yLFL/40xJkRYl4sxxoQIC3RjjAkRFujGGBMiLNCNMSZEWKAbY0yIsEA3xpgQYYFujDEh4v8DQEOe60p209IAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Creating the dictionary of parameters to use in the search\n", "parameters = {'n_estimators': scipy.stats.randint(low=10, high=1000), # Uniform distribution\n", " 'max_depth': [None, 10, 30], # Maximum number of levels in a tree\n", " 'max_features': ['auto', 'log2', None], # Number of features to consider at each split\n", " 'min_samples_split': [2, 5, 10], # Minimum number of samples required to split a node\n", " 'min_samples_leaf': [1, 2, 
4] # Minimum number of samples required at each leaf node\n", " }\n", "\n", "hyperparameter_tune_get_results(random_forest, parameters, 'Random Forest', num_rounds=30)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As I mentioned earlier, random forests are relatively robust to overfitting. We can usually find good results by increasing the number of trees until we hit a point with diminishing returns on performance.\n", "\n", "However, let's test for overfitting ourselves. I'll plot a [validation curve](https://chrisalbon.com/machine_learning/model_evaluation/plot_the_validation_curve/) (using the code from that post) to see how increasing the number of trees affects the cross-validation accuracy." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "start_time": "2018-07-12T23:23:01.269Z" } }, "outputs": [], "source": [ "from sklearn import model_selection\n", "\n", "rf = ensemble.RandomForestClassifier(n_jobs=-1)\n", "\n", "trees_to_try = [10, 30, 60, 100, 300, 600, 1000, 3000, 6000, 10000, 30000]\n", "\n", "# Unpacking the train/test scores so the plotting cell below can use them\n", "train_scores, test_scores = model_selection.validation_curve(\n", " estimator=rf,\n", " X=X,\n", " y=y,\n", " cv=3,\n", " param_name='n_estimators',\n", " param_range=trees_to_try,\n", " n_jobs=-1)\n", "\n", "# TODO: combine these two cells" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "ExecuteTime": { "end_time": "2018-07-12T23:21:23.817433Z", "start_time": "2018-07-12T23:21:23.473703Z" } }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xt0VNX99/H3lySSiEAQ+IkSblqlxdyMMZYHWkUUsUvFiihY611EUVurVKjWC65aXP35qAhiqQr2pwTjBQqtmOcBrY83rhVBQCRKlABKCAQshpLLfv44Q5wkEzKTTC5z8nmtlZU5++w5s/ec5DNn9jmzx5xziIiIv3Ro7QaIiEj0KdxFRHxI4S4i4kMKdxERH1K4i4j4kMJdRMSHFO4iIj6kcBcR8SGFu4iID8W31gP36NHD9e/fv7UeXkQkJq1Zs2a3c65nQ/VaLdz79+/P6tWrW+vhRURikpl9GU49DcuIiPiQwl1ExIcU7iIiPqRwFxHxIYW7iIgPNXi1jJk9D1wI7HLOpYZYb8CTwM+A74BrnXP/inZDj2ThR9v5U/5mdpSWcUJyEpPOH8glp/Vus9tty9pjn0VaSkv+f4VzKeRcYAbw13rWXwCcHPg5E5gV+N0iFn60nSmvr6esvBKA7aVlTHl9PUCTnrTm2m5b1h77LNJSWvr/q8Fwd879PzPrf4Qqo4C/Ou/7+pabWbKZHe+c2xmlNh7Rn/I3Vz9Zh5WVV/L7hZ/wRfG/G73dOe8XNst227L22GeRllLf/9ef8je3TriHoTewLWi5KFBWJ9zNbDwwHqBv375ReGjYUVoWsvzb/1Tw1NsFjd5ufV8t29TttmXtsc8iLaW+/6/6MqypohHuFqIsZDecc7OB2QDZ2dlR+Wbunp07suvb/9Qp752cxPuTz2n0dodMe4vtIZ70pm63LWuPfRZpKfX9f52QnNQsjxeNq2WKgD5ByynAjihst0HLvyhhX9mhOuVJCXFMOn9gk7Y96fyBJCXERX27bVl77LNIS2np/69oHLkvAm4zs/l4J1L3Ndd4e/CZ5m6dEtj3XTkDeh7DL87sy7Pvbo3qGejD929PV460xz6LtJSW/v8yV99A0OEKZrnA2UAP4BvgASABwDn3TOBSyBnASLxLIa9zzjU4I1h2draLZOKw2meavbbBHy5J5coz+4W9HRGRWGZma5xz2Q3VC+dqmXENrHfAxAja1iihropxDma+/bnCXUSklpj5hGp9Z5Sb60yziEgsi5lwr++McnOdaRYRiWUxE+66kkNEJHyt9k1MkdKVHCIi4YuZcAcv4BXmIiINi5lhGRERCZ/CXUTEhxTuIiI+pHAXEfEhhbuIiA8p3EVEfEjhLiLiQwp3EREfUriLiPiQwl1ExIcU7iIiPqRwFxHxIYW7iIgPKdxFRHxI4S4i4kMKdxERH1K4i4j4kMJdRMSHFO4iIj6kcBcR8SGFu4iIDyncRUR8SOEuIuJDCncRER9SuIuI+JDCXUTEhxTuIiI+FFa4m9lIM9tsZgVmNjnE+n5mtszM1pnZP80sJfpNFRGRcDUY7mYWB8wELgAGAePMbFCtav8N/NU5lw5MBf4Y7YaKiEj4wjlyzwEKnHNfOOcOAfOBUbXqDAKWBW6/HWK9iIi0oHDCvTewLWi5KFAW7GNgdOD2z4HOZta99obMbLyZrTaz1cXFxY1pr4iIhCGccLcQZa7W8t3AWWb2EXAWsB2oqHMn52Y757Kdc9k9e/aMuLEiIhKe+DDqFAF9gpZTgB3BFZxzO4BLAczsGGC0c25ftBopIiKRCefIfRVwspkNMLOjgLHAouAKZtbDzA5vawrwfHSbKSIikWgw3J1zFcBtQD6wCchzzm0ws6lmdnGg2tnAZjP7DDgO+EMztVdERMJgztUePm8Z2dnZbvXq1a3y2CIiscrM1jjnshuqp0+oioj4kMJdRMSHFO4iIj6kcBcR8SGFu4iIDyncRUR8SOEuIuJDCncRER9SuIuI+JDCXUTEhxTuIiI+pHAXEfEhhbuIiA8p3EVEfEjhLiLiQwp3EREfUriLiPiQwl1ExIcU7iI
iPqRwFxHxIYW7iIgPKdxFRHxI4S4i4kMKdxERH1K4i4j4kMJdRMSHFO4iIj6kcBcR8SGFu4iIDyncRUR8SOEuIuJDCncRER8KK9zNbKSZbTazAjObHGJ9XzN728w+MrN1Zvaz6DdVRETC1WC4m1kcMBO4ABgEjDOzQbWq3QfkOedOA8YCT0e7oSIiEr5wjtxzgALn3BfOuUPAfGBUrToO6BK43RXYEb0miohIpMIJ997AtqDlokBZsAeBq8ysCHgDuD3UhsxsvJmtNrPVxcXFjWiuiIiEIz6MOhaizNVaHgfMdc49ZmaDgf8xs1TnXFWNOzk3G5gNkJ2dXXsbIlKP8vJyioqKOHjwYGs3RVpIYmIiKSkpJCQkNOr+4YR7EdAnaDmFusMuNwAjAZxzH5pZItAD2NWoVolIDUVFRXTu3Jn+/ftjFup4S/zEOUdJSQlFRUUMGDCgUdsIZ1hmFXCymQ0ws6PwTpguqlXnK2A4gJn9CEgENO4iEiUHDx6ke/fuCvZ2wszo3r17k96pNRjuzrkK4DYgH9iEd1XMBjObamYXB6rdBdxkZh8DucC1zjkNu4hEkYK9fWnq/g7rOnfn3BvOuVOccyc55/4QKLvfObcocHujc26Icy7DOZfpnPs/TWqViLQpJSUlZGZmkpmZSa9evejdu3f18qFDh8LaxnXXXcfmzZuPWGfmzJm89NJL0WgyAN988w3x8fE899xzUdtmrLDWOsDOzs52q1evbpXHFok1mzZt4kc/+lHY9Rd+tJ0/5W9mR2kZJyQnMen8gVxyWu2L3BrnwQcf5JhjjuHuu++uUe6cwzlHhw5t54Pv06dP55VXXqFjx44sXbq02R6noqKC+PhwTmFGJtR+N7M1zrnshu7bdvaCiETFwo+2M+X19WwvLcMB20vLmPL6ehZ+tD3qj1VQUEBqaioTJkwgKyuLnTt3Mn78eLKzszn11FOZOnVqdd2hQ4eydu1aKioqSE5OZvLkyWRkZDB48GB27fKuvbjvvvt44oknqutPnjyZnJwcBg4cyAcffADAgQMHGD16NBkZGYwbN47s7GzWrl0bsn25ubk88cQTfPHFF3z99dfV5f/4xz/IysoiIyODESNGAPDtt99yzTXXkJaWRnp6OgsXLqxu62Hz58/nxhtvBOCqq67irrvuYtiwYfzud79j+fLlDB48mNNOO40hQ4awZcsWwAv+O++8k9TUVNLT03n66afJz89nzJgx1dtdsmQJl19+eZP3R7Dov9SISLN6aPEGNu7YX+/6j74q5VBljauQKSuv5LevriN35Vch7zPohC48cNGpjWrPxo0bmTNnDs888wwA06ZN49hjj6WiooJhw4Zx2WWXMWhQzQ+179u3j7POOotp06bxm9/8hueff57Jk+vMbIJzjpUrV7Jo0SKmTp3Km2++yVNPPUWvXr147bXX+Pjjj8nKygrZrsLCQvbu3cvpp5/OZZddRl5eHnfccQdff/01t9xyC++++y79+vVjz549gPeOpGfPnqxfvx7nHKWlpQ32/fPPP2fZsmV06NCBffv28d577xEXF8ebb77Jfffdx8svv8ysWbPYsWMHH3/8MXFxcezZs4fk5GTuuOMOSkpK6N69O3PmzOG6666L9Kk/Ih25i/hM7WBvqLypTjrpJM4444zq5dzcXLKyssjKymLTpk1s3Lixzn2SkpK44IILADj99NMpLCwMue1LL720Tp333nuPsWPHApCRkcGpp4Z+UcrNzeWKK64AYOzYseTm5gLw4YcfMmzYMPr16wfAscceC8DSpUuZOHEi4J3M7NatW4N9HzNmTPUwVGlpKZdeeimpqancfffdbNiwoXq7EyZMIC4urvrxOnTowJVXXsm8efPYs2cPa9asqX4HES06cheJMQ0dYQ+Z9hbbS8vqlPdOTuLlmwdHvT2dOnWqvr1lyxaefPJJVq5cSXJyMldddVXIy/mOOuqo6ttxcXFUVFSE3HbHjh3r1An3PGFubi4
lJSW88MILAOzYsYOtW7finAt5JUqo8g4dOtR4vNp9Ce77vffey/nnn8+tt95KQUEBI0eOrHe7ANdffz2jR48G4IorrqgO/2jRkbuIz0w6fyBJCTWDIikhjknnD2z2x96/fz+dO3emS5cu7Ny5k/z8/Kg/xtChQ8nLywNg/fr1Id8ZbNy4kcrKSrZv305hYSGFhYVMmjSJ+fPnM2TIEN566y2+/PJLgOphmREjRjBjxgzAC+S9e/fSoUMHunXrxpYtW6iqqmLBggX1tmvfvn307u2dtJ47d251+YgRI5g1axaVlZU1Hq9Pnz706NGDadOmce211zbtSQlB4S7iM5ec1ps/XppG7+QkDO+I/Y+XpkXtapkjycrKYtCgQaSmpnLTTTcxZMiQqD/G7bffzvbt20lPT+exxx4jNTWVrl271qgzb948fv7zn9coGz16NPPmzeO4445j1qxZjBo1ioyMDH7xi18A8MADD/DNN9+QmppKZmYm7777LgCPPvooI0eOZPjw4aSkpNTbrnvuuYdJkybV6fPNN99Mr169SE9PJyMjo/qFCeDKK69kwIABnHLKKU16TkLRpZAiMSDSSyH9rKKigoqKChITE9myZQsjRoxgy5YtzXIpYnObMGECgwcP5pprrgm5vimXQsbesyEi7dq///1vhg8fTkVFBc45/vznP8dksGdmZtKtWzemT5/eLNuPvWdERNq15ORk1qxZ09rNaLL6rs2PFo25i4j4kMJdRMSHFO4iIj6kcBcR8SGFu4g0KBpT/gI8//zzNSbwCmca4Ei88sormBkFBQVR22asUriL+NG6PHg8FR5M9n6vy2v4PkfQvXt31q5dy9q1a5kwYQJ33nln9XLwVAINqR3uc+bMYeDA6H1yNjc3l6FDhzJ//vyobTOU+qZLaEsU7iJ+sy4PFt8B+7YBzvu9+I4mB3x9XnjhBXJycsjMzOTWW2+lqqqKiooKfvnLX5KWlkZqairTp0/n5ZdfZu3atVxxxRXVR/zhTAO8ZcsWzjzzTHJycvj9739fYwreYPv372fFihX85S9/qZ4k7LBHHnmEtLQ0MjIyuPfeewH47LPPOOecc8jIyCArK4vCwkKWLl3KJZdcUn2/CRMm8OKLLwKQkpLCww8/zJAhQ1iwYAHPPPMMZ5xxBhkZGYwZM4ayMm8+n6+//ppRo0ZVfyJ1xYoVTJkyhZkzZ1Zv95577uHpp5+O3k4IQde5i8SaJZPh6/X1ry9aBZX/qVlWXgZ/uw3WvBD6Pr3S4IJpETflk08+YcGCBXzwwQfEx8czfvx45s+fz0knncTu3btZv95rZ2lpKcnJyTz11FPMmDGDzMzMOtuqbxrg22+/nbvvvpsxY8ZUz/0Syuuvv86FF17ID3/4Qzp16sS6detIT09n8eLFLFmyhJUrV5KUlFQ9t8u4ceN48MEHueiiizh48CBVVVUNDud06tSJ999/H/CGqiZMmADA5MmTmTt3LrfccgsTJ07kvPPO47bbbqOiooLvvvuOHj16MHbsWCZOnEhlZSWvvPJKs1+rryN3Eb+pHewNlTfB0qVLWbVqFdnZ2WRmZvLOO+/w+eef84Mf/IDNmzfzq1/9ivz8/Dpzv4RS3zTAK1asqJ498corr6z3/rm5udVTAQdP8bt06VKuv/56kpKSAG/K3b1797J7924uuugiABITEzn66KMbbOPhKYQB1q1bx09+8hPS0tKYP39+9RS///znP7n55psBiI+Pp0uXLpx00kl07tyZ9evXs2TJEnJycsKaUrgpdOQuEmsaOsJ+PDUwJFNL1z5w3T+i2hTnHNdffz0PP/xwnXXr1q1jyZIlTJ8+nddee43Zs2cfcVvhTgMcSnFxMe+88w6ffvopZkZFRQUJCQk88sgj9U65G6osPj6eqqrv570/0hS/V199NUuWLCE1NZVnn32W5cuXH3HbN9xwA3PnzqWwsLA6/JuTjtxF/Gb4/ZCQVLMsIckrj7Jzzz2XvLw8du/
eDXhDFV999RXFxcU45xgzZgwPPfQQ//rXvwDo3Lkz3377bUSPkZOTUz3Vbn0nSvPy8rjhhhv48ssvKSwspKioiBNOOIHly5czYsQInnvuueox8T179tCtWzd69OjB4sWLAS/Ev/vuO/r168eGDRs4dOgQe/fu5a233qq3XQcOHKBXr16Ul5czb9686vJhw4ZVfytVZWUl+/d735o1evRoFi9ezNq1azn33HMjeg4aQ+Eu4jfpl8NF070jdcz7fdF0rzzK0tLSeOCBBzj33HNJT09nxIgRfPPNN2zbto2f/vSnZGZmctNNN/HII48A3qWPN954Y0SXUE6fPp1HH32UnJwcdu3aFXKIJzc3t94pfi+88EJGjhxZPXT0+OOPA/DSSy/x2GOPkZ6eztChQykuLmbAgAFccsklpKWlcfXVV9f7FX4AU6dOJScnh/POO6/G1wjOmDGD/Px80tLSyM7O5tNPPwW8oZ+f/vSnjBs3rkW+RFxT/orEgPY85e+BAwc4+uijMTNefPFFFixYwGuvvdbazYpYVVUVmZmZLFy4kBNPPDGs+2jKXxHxrVWrVvHrX/+aqqoqunXrxpw5c1q7SRFbv349F198MWPGjAk72JtK4S4ibdrZZ5/d7NPjNre0tDS2bt3aoo+pMXcRER9SuIvEiNY6Pyato6n7W+EuEgMSExMpKSlRwLcTzjlKSkpITExs9DY05i4SA1JSUigqKqK4uLi1myItJDExkZSUlEbfX+EuEgMSEhIYMGBAazdDYoiGZUREfEjhLiLiQ2GFu5mNNLPNZlZgZpNDrH/czNYGfj4zs9LoN1VERMLV4Ji7mcUBM4HzgCJglZktcs5tPFzHOXdnUP3bgdOaoa0iIhKmcI7cc4AC59wXzrlDwHxg1BHqjwNyj7BeRESaWTjh3hsInhy6KFBWh5n1AwYA9c+TKSIizS6ccK876zzU90mKscCrzrnKkBsyG29mq81sta7XFRFpPuGEexHQJ2g5BdhRT92xHGFIxjk32zmX7ZzL7tmzZ/itFBGRiIQT7quAk81sgJkdhRfgi2pXMrOBQDfgw+g2UUREItVguDvnKoDbgHxgE5DnnNtgZlPN7OKgquOA+U6TX4iItLqwph9wzr0BvFGr7P5ayw9Gr1kiItIU+oSqiIgPKdxFRHxI4S4i4kMKdxERH1K4i4j4kMJdRMSHFO4iIj6kcBcR8SGFu4iIDyncRUR8SOEuIuJDCncRER9SuIuI+JDCXUTEhxTuIiI+pHAXEfEhhbuIiA8p3EVEfEjhLiLiQwp3EREfUriLiPiQwl1ExIcU7iIiPqRwFxHxIYW7iIgPKdxFRHxI4S4i4kMKdxERH1K4i4j4kMJdRMSHFO4iIj6kcBcR8aGwwt3MRprZZjMrMLPJ9dS53Mw2mtkGM5sX3WaKiEgk4huqYGZxwEzgPKAIWGVmi5xzG4PqnAxMAYY45/aa2X81V4NFRKRh4Ry55wAFzrkvnHOHgPnAqFp1bgJmOuf2AjjndkW3mSIiEolwwr03sC1ouShQFuwU4BQze9/MlpvZyGg1UEREItfgsAxgIcpciO2cDJwNpADvmlmqc660xobMxgPjAfr27RtxY0VEJDzhHLkXAX2CllOAHSHq/M05V+6c2wpsxgv7Gpxzs51z2c657J49eza2zSIi0oBwwn0VcLKZDTCzo4CxwKJadRYCwwDMrAfeMM0X0WyoiIiEr8Fwd85VALcB+cAmIM85t8HMpprZxYFq+UCJmW0E3gYmOedKmqvRIiJyZOZc7eHzlpGdne1Wr17dKo8tIhKrzGyNcy67oXr6hKqIiA8p3EVEfEjhLiLiQwp3EREfUrhLTevy4PFUeDDZ+70ur7VbJCKNEM4nVKW9WJcHi++A8jJved82bxkg/fLWa5eIREzhLt9b+tD3wX5YeRm8ORniE6FDfOAnzvsdl1BzuUNCPesPlwUv602jSHNSuLcnB/d7R+Ol2wK/vwy6/RUcKA59v+9
KIO+XUW6MNRD+cSFePA6/gAQtxyW08PpGbsM6gIWapkmkeSjc/cI5KNvrhfThAK++HQjxg6U17xPXEbqmQHJfGHgBbPgb/Gdf3W0f0wuuehWqKqCq0vtdWV5zuar2chh1Il4f9FNxMFCnou664J8a68tbZl/Up7EvDh3iIS6+1gtMfDOvP9KLbTjr4/ViFsq6PFg2FfYVef97w+9vtiFPhXuscM47si796vuf2iF+6N8175PQyQvu5D6QkvP97a59vdudetYcHun/k5pj7gAJSTDiYeiV1jL9bG51Xigqg15EwnmBCGd9ZdALVajHaMT68rIwXwyD+1cOrqr1nms70gtUiBeHiIb6orW+se/mQj1GA0ONLXxOyx/h3oKvhs2mqhK+3Vn/kMm+Iu9oNVhishfWx54IJ54dCO4+gRDvC0ndIjt6OvycxfpzeSQd4rwfOrZ2S1pGVRW4I70TqvWCEOoFIpL1jXoxDLG+4hBUfde4F8NWY0cO/2+3e20NVl7m/b8p3EOIlSs8Ksu9wKwzZBI4Ct+/ve4fZqeeXlgfd6o3bHL4iPtwiCd2iX470y9vW8+bNE2HDkAHL3ASklq7Nc3POe/dSkPvhI44LBjmu6lIhxY/ruerpfcVNctTEfvhvmxq6Cs8lvwWOnaBTj3g6O5eUB7VKbIj2UjeEZSXefXqGzL5dmett8gGnY/3grpPTuCIO3DU3bWv93hHHR3x0yHSrpkFhoPiWrsldRW+62VBbV1TmuXhYj/c63vVK9sLuVfULItPhKN7QKfugd896lnuAVvfhfzJNd8RLLoddq6Dbv1qHXlvgwO1vjbW4qBrby+oB5wVFNyBEO+SAvFHRf/5EJG2afj9oc9pDb+/WR4u9sO9a0roV8POx8MVL8F3u+HAbu9k5He74UDJ92UlW7zl8gPhPVbFQfjwKe92jStNRtYdMul8vHcCSUQEWvycVuynz/D74W8TofLQ92UJSXDeVEg5PbxtlJd5YR8c/gturqeywV2b615pIiLSkBY8pxX76ZR+OXTp7Z2Nxryj5oumR/YEJiR5R9wnnAYnnwsZY73thNI1BTofp2AXkTYt9o/cd6yFvVvh/D/C4Fujt90WHh8TEYmm2D/8XPUsxCdB5pXR3W765d47gK59aPQ7AhGRVhLbR+5le2H9q5A+BpKSo799XfMtIjEqto/c1+ZCRRmccVNrt0REpE2J3XCvqvKGZFJy4Pj01m6NiEibErvhvvWfsOdzOOPG1m6JiEibE7vhvvJZb1qBUy9p7ZaIiLQ5sRnupdvgsyWQdTXEt5PZ/UREIhBbV8tUT+QVmG7gmONatz0iIm1U7IR77al9AZY95A3N6HJFEZEaYmdYpr6pfZdNbZ32iIi0YbET7vVN7dtME92LiMSy2An3+ia0b6aJ7kVEYlnshPvw++t+TZgm8hIRCSl2wl0TeYmIhC2sq2XMbCTwJBAHPOucm1Zr/bXAn4DtgaIZzrlno9hOjybyEhEJS4PhbmZxwEzgPKAIWGVmi5xzG2tVfdk5d1sztFFERCIUzrBMDlDgnPvCOXcImA+Mat5miYhIU4QT7r2B4G+gLgqU1TbazNaZ2atmVs931ImISEsIJ9wtRJmrtbwY6O+cSweWAi+E3JDZeDNbbWari4uLI2upiIiELZxwLwKCj8RTgB3BFZxzJc65/wQW/wKcHmpDzrnZzrls51x2z549G9NeEREJQzhXy6wCTjazAXhXw4wFanxhqZkd75zbGVi8GNjU0EbXrFmz28y+jKCtPYDdEdT3i/bY7/bYZ2if/W6PfYam9btfOJUaDHfnXIWZ3Qbk410K+bxzboOZTQVWO+cWAXeY2cVABbAHuDaM7UZ06G5mq51z2ZHcxw/aY7/bY5+hffa7PfYZWqbfYV3n7px7A3ijVtn9QbenAFOi2zQREWms2PmEqoiIhC2Wwn12azeglbTHfrfHPkP77Hd77DO0QL/NudpXNYqISKyLpSN3EREJU0y
Eu5mNNLPNZlZgZpNbuz3RYmZ9zOxtM9tkZhvM7FeB8mPN7P+a2ZbA726BcjOz6YHnYZ2ZZbVuDxrPzOLM7CMz+3tgeYCZrQj0+WUzOypQ3jGwXBBY3781290UZpYc+AT3p4F9Ptjv+9rM7gz8bX9iZrlmlujHfW1mz5vZLjP7JKgs4n1rZtcE6m8xs2ua0qY2H+5BE5ddAAwCxpnZoNZtVdRUAHc5534E/BiYGOjbZGCZc+5kYFlgGbzn4OTAz3hgVss3OWp+Rc3PQzwKPB7o817ghkD5DcBe59wPgMcD9WLVk8CbzrkfAhl4/fftvjaz3sAdQLZzLhXvUuqx+HNfzwVG1iqLaN+a2bHAA8CZeHN6PXD4BaFRnHNt+gcYDOQHLU8BprR2u5qpr3/Dm31zM3B8oOx4YHPg9p+BcUH1q+vF0g/ep5yXAecAf8eb4mI3EF97n+N9vmJw4HZ8oJ61dh8a0ecuwNbabffzvub7eamODey7vwPn+3VfA/2BTxq7b4FxwJ+DymvUi/SnzR+5E/7EZTEt8Bb0NGAFcJwLfOI38Pu/AtX88lw8AfwWqAosdwdKnXMVgeXgflX3ObB+X6B+rDkRKAbmBIajnjWzTvh4XzvntgP/DXwF7MTbd2vw/74+LNJ9G9V9HgvhHs7EZTHNzI4BXgN+7Zzbf6SqIcpi6rkwswuBXc65NcHFIaq6MNbFknggC5jlnDsNOMD3b9NDifl+B4YURgEDgBOATnhDErX5bV83pL5+RrX/sRDuDU5cFsvMLAEv2F9yzr0eKP7GzI4PrD8e2BUo98NzMQS42MwK8b4b4By8I/lkMzv8iengflX3ObC+K94UF7GmCChyzq0ILL+KF/Z+3tfnAludc8XOuXLgdeB/4f99fVik+zaq+zwWwr164rLAWfWxwKJWblNUmJkBzwGbnHP/O2jVIuDwmfJr8MbiD5dfHTjb/mNgn/t+wraY4Jyb4pxLcc71x9uXbznnfgG8DVwWqFa7z4efi8sC9WPuaM459zWwzczQmsVYAAAA6UlEQVQGBoqGAxvx8b7GG475sZkdHfhbP9xnX+/rIJHu23xghJl1C7zrGREoa5zWPgkR5omKnwGfAZ8D97Z2e6LYr6F4b7vWAWsDPz/DG2dcBmwJ/D42UN/wrhz6HFiPdxVCq/ejCf0/G/h74PaJwEqgAHgF6BgoTwwsFwTWn9ja7W5CfzOB1YH9vRDo5vd9DTwEfAp8AvwP0NGP+xrIxTuvUI53BH5DY/YtcH2g/wXAdU1pkz6hKiLiQ7EwLCMiIhFSuIuI+JDCXUTEhxTuIiI+pHAXEfEhhbuIiA8p3EVEfEjhLiLiQ/8fSBwIvy7nVDoAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Calculate mean and standard deviation for training set scores\n", "train_mean = np.mean(train_scores, axis=1)\n", "train_std = np.std(train_scores, axis=1)\n", "\n", "# Calculate mean and standard deviation for test set scores\n", "test_mean = np.mean(test_scores, axis=1)\n", "test_std = np.std(test_scores, axis=1)\n", "\n", "# TODO: Make this plot bigger, add y log scale, add title, despine if possible, add standard deviation?\n", "plt.plot(trees_to_try, train_mean, marker='o', label='Training Accuracy')\n", 
"plt.plot(trees_to_try, test_mean, marker='o', label='Testing Accuracy')\n", "plt.legend()\n", "# ax.set_yscale('log')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient Boosted Trees\n", "\n", "Because gradient boosting is used for kaggle-style competitions more commonly than random forests, there are quite a few more established strategies out there. These models can often be more difficult to tune than random forests, but it is a little more nuanced than simply cranking up the number of trees and crossing your fingers.\n", "\n", "One crucial hyperparameter that is introduced to gradient boosting is the learning rate. As previously mentioned, this tells us how drastic the adjustments are on our new trees being built. One peculiarity is that learning rates suffer pretty heavily from the [Goldilocks principle](https://en.wikipedia.org/wiki/Goldilocks_principle) - it has to be just right to have the optimal performance. It also highly depends on the number of trees we're training. Here is a chart that shows the relationship between the number of trees and the learning rate:\n", "\n", "\n", "*Source: [Synced](https://medium.com/syncedreview/tree-boosting-with-xgboost-why-does-xgboost-win-every-machine-learning-competition-ca8034c0b283)*\n", "\n", "Generally speaking, if we have a low number of trees and a high learning rate, we will get to a good performance faster but we will have a lower top-end performance. Conversely, we can get a better performance with a low learning rate and a lot of trees, but it will take much longer to get there.\n", "\n", "Most of the other hyperparameters are either similar to or are the same as those in random forests. However, we'll want to use different value ranges for them because the trees between the two algorithms are inherently different. Random forests use larger, relatively unconstrained trees, but boosting methods use weak learners. 
These week learners are by definition much less complex, so they are smaller, simpler trees.\n", "\n", "There are a variety of tuning guides (several are listed [here](https://machinelearningmastery.com/configure-gradient-boosting-algorithm/)), but my favorite is this guide from Zhonghua Zhang, the former \\#1 Kaggler in the world:\n", "\n", "\n", "*Source: [Zhonghua Zhang](https://www.slideshare.net/ShangxuanZhang/winning-data-science-competitions-presented-by-owen-zhang)*\n", "\n", "Note that this does include several hyperparameters specifically for XGBoost that are not included in the scikit-learn implementation, but we will ignore those for now." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "start_time": "2018-07-11T00:47:19.836Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Default Parameters: \n", "\n", "GradientBoostingClassifier(criterion='friedman_mse', init=None,\n", " learning_rate=0.1, loss='deviance', max_depth=3,\n", " max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=100,\n", " presort='auto', random_state=None, subsample=1.0, verbose=0,\n", " warm_start=False) \n", "\n", "Beginning hyperparameter tuning\n", "Fitting 3 folds for each of 30 candidates, totalling 90 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Done 5 tasks | elapsed: 2.2min\n", "[Parallel(n_jobs=-1)]: Done 10 tasks | elapsed: 8.0min\n", "[Parallel(n_jobs=-1)]: Done 17 tasks | elapsed: 11.0min\n", "[Parallel(n_jobs=-1)]: Done 24 tasks | elapsed: 15.9min\n", "[Parallel(n_jobs=-1)]: Done 33 tasks | elapsed: 21.8min\n", "[Parallel(n_jobs=-1)]: Done 42 tasks | elapsed: 24.5min\n", "[Parallel(n_jobs=-1)]: Done 53 tasks | elapsed: 30.9min\n", "[Parallel(n_jobs=-1)]: Done 64 tasks | elapsed: 41.6min\n", "[Parallel(n_jobs=-1)]: Done 77 
tasks | elapsed: 54.8min\n", "[Parallel(n_jobs=-1)]: Done 90 out of 90 | elapsed: 66.5min finished\n" ] } ], "source": [ "# Creating the dictionary of parameters to use in the search\n", "parameters = {'n_estimators': scipy.stats.randint(low=100, high=1000), # Uniform distribution between 100 and 1000\n", " 'learning_rate': [0.003, 0.01, 0.03, 0.1, 0.3], # How drastic updates are\n", " 'subsample': [0.5, 0.75, 1.0], # The portion of rows to use in updates\n", " 'max_depth': [3, 6, 8, 10], # Maximum number of levels in a tree\n", " 'min_samples_split': [2, 5, 10], # Minimum number of samples required to split a node\n", " 'min_samples_leaf': [1, 2, 4] # Minimum number of samples required at each leaf node\n", " }\n", "\n", "hyperparameter_tune_get_results(gradient_boosting, parameters, 'Gradient Boosted Trees', num_rounds=30)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "ExecuteTime": { "end_time": "2018-07-11T01:56:24.482360Z", "start_time": "2018-07-11T01:56:24.451113Z" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AccuracyLogLossAUCTrainingTime
Logistic Regression0.6886670.6210840.71936242.689128
Random Forest0.7130000.6004200.7314663247.834107
Gradient Boosted Trees0.7125560.6199080.7347924069.814399
\n", "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.688667 0.621084 0.719362 42.689128\n", "Random Forest 0.713000 0.600420 0.731466 3247.834107\n", "Gradient Boosted Trees 0.712556 0.619908 0.734792 4069.814399" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tuned_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stacking\n", "\n", "Stacking is more of a special case because we have to worry about tuning the hyperparameters of the individual models within the ensemble. We can borrow our tuned ensemble models for part of it, but will have to tune the *k*-NN model separately." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "ExecuteTime": { "end_time": "2018-07-11T02:35:46.570323Z", "start_time": "2018-07-11T02:33:04.825669Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training the model\n", "Completed\n", "\n", "Coefficients for models\n", "Model 1: -4.734082188838665\n", "Model 2: 13.070197844351487\n", "Model 3: 13.093441698919188\n", "Model 4: -7.996204564151691\n", "Model 5: 1.098772263424389\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AccuracyLogLossAUCTrainingTime
Logistic Regression0.6890000.6211180.7193240.124991
Random Forest0.6704440.9800890.7050071.015584
Gradient Boosted Trees0.7033330.6020380.72921311.014955
Stacking0.7082220.9736270.725960146.808059
LightGBM0.7125560.6063740.7348030.187471
\n", "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.689000 0.621118 0.719324 0.124991\n", "Random Forest 0.670444 0.980089 0.705007 1.015584\n", "Gradient Boosted Trees 0.703333 0.602038 0.729213 11.014955\n", "Stacking 0.708222 0.973627 0.725960 146.808059\n", "LightGBM 0.712556 0.606374 0.734803 0.187471" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Defining the learners for the first layer\n", "model_1 = linear_model.LogisticRegression()\n", "model_2 = ensemble.RandomForestClassifier(n_estimators=620, max_depth=30,\n", " min_samples_split=10, n_jobs=-1)\n", "model_3 = ensemble.RandomForestClassifier(n_estimators=620, max_depth=30,\n", " min_samples_split=10, n_jobs=-1)\n", "model_4 = ensemble.GradientBoostingClassifier()\n", "model_5 = neighbors.KNeighborsClassifier(n_jobs=-1)\n", "\n", "# Putting the models in a list to iterate through in the function\n", "models = [model_1, model_2, model_3, model_4, model_5]\n", "\n", "# Running our function to build a stacking model\n", "train_stacking_get_results(models)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Additional Frameworks\n", "\n", "We have been using scikit-learn up until now for our models, but there are more specialized frameworks for gradient boosting in particular. Scikit-learn's gradient boosting algorithm is good, but lacks additional optimization and a few components and options that can be useful\n", "\n", "Specifically, we're going to focus on **XGBoost** and **LightGBM**. 
We'll go into more specifics for each, but both frameworks are focused on speed and performance and have the following advantages & disadvantages:\n", "\n", "#### Advantages\n", "- Ability to parallelize training\n", "- Ability to use GPUs\n", "- Additional under-the-hood optimization\n", "- Can specify loss functions\n", "- Additional tuning parameters\n", "- Distributed computing options\n", "- Native handling of missing values\n", "\n", "#### Disadvantages\n", "- Relatively difficult to install\n", "- Less unified integration in older versions\n", "\n", "So generally speaking, XGBoost and LightGBM are able to train better models faster, but can be more difficult to set up and use." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### XGBoost\n", "\n", "[XGBoost](https://github.com/dmlc/xgboost) is an extremely popular framework for gradient boosted trees created by Tianqi Chen, a Ph.D. student at the University of Washington. It was initially released in 2014, but did not become popular until it started dominating competitions on Kaggle a few years later. It has implementations in several languages, but we will be focusing on the Python implementation. For more background, Tianqi posted [this blog post](https://homes.cs.washington.edu/~tqchen/2016/03/10/story-and-lessons-behind-the-evolution-of-xgboost.html) about the history, philosophy, and lessons learned behind creating XGBoost.\n", "\n", "As I mentioned, both XGBoost and LightGBM use a series of clever tricks and under-the-hood optimizations, not included in the scikit-learn implementation, that let them train better models faster. One example is that XGBoost uses second derivatives to find the optimal constant in each terminal node, whereas other implementations just use the first derivative. This is nearly impossible to unpack without getting into the math, but it should give an idea of the type of under-the-hood optimization that is happening. 
If you are interested, [here is the XGBoost white paper](https://arxiv.org/abs/1603.02754) that explains a lot of the optimizations.\n", "\n", "#### Installation\n", "\n", "The [installation guide](https://xgboost.readthedocs.io/en/latest/build.html) states that there is only a wheel file on PyPI for the 64-bit version of Linux, so things get a little more complicated for Windows & OSX users. Specifically, you have to build the library from source.\n", "\n", "However, I do have a workaround for Windows users (sorry OSX users!) that I borrowed from [this blog post](https://medium.com/@rakshithvasudev/how-i-installed-xgboost-after-a-lot-of-hassels-on-my-windows-machine-c53e972e801e). Download the wheel file for your version of Windows and Python [here](https://www.lfd.uci.edu/~gohlke/pythonlibs/#xgboost) (cp27/35/36/37 are the versions of Python, and win32/\\_amd64 are the versions of Windows), navigate a command window to the directory where you downloaded it, and do a pip install in your command prompt with `pip install xgboost-0.72-cp35-cp35m-win_amd64.whl` using whichever wheel file you downloaded.\n", "\n", "If you don't know your version of Windows or Python, run the code block below." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2018-07-07T16:23:15.333691Z", "start_time": "2018-07-07T16:23:15.326695Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python: 3.5.5 | packaged by conda-forge | (default, Apr 6 2018, 16:03:44) [MSC v.1900 64 bit (AMD64)]\n", "('64bit', 'WindowsPE')\n" ] } ], "source": [ "import sys\n", "import platform\n", "\n", "print('Python:', sys.version)\n", "print(platform.architecture())" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "ExecuteTime": { "end_time": "2018-07-07T16:23:52.000357Z", "start_time": "2018-07-07T16:23:51.243785Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training the model\n", "Completed\n", "\n", " Non-tuned results:\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AccuracyLogLossAUCTrainingTime
Logistic Regression0.6960.6045430.7434980.023969
Random Forest0.6541.4734350.6849390.148897
Gradient Boosted Trees0.7240.5962530.7333900.577691
Stacking0.6561.0429050.6936955.193073
XGBoost0.7220.5953990.7369510.262850
\n", "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.696 0.604543 0.743498 0.023969\n", "Random Forest 0.654 1.473435 0.684939 0.148897\n", "Gradient Boosted Trees 0.724 0.596253 0.733390 0.577691\n", "Stacking 0.656 1.042905 0.693695 5.193073\n", "XGBoost 0.722 0.595399 0.736951 0.262850" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import xgboost as xgb\n", "\n", "xgboost = xgb.XGBClassifier(n_jobs=-1) # n_jobs=-1 uses all available cores\n", "\n", "# Due to the scikit-learn API option, XGBoost works with our function!\n", "train_model_get_results(xgboost, 'XGBoost')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Hyperparameter Tuning\n", "\n", "XGBoost has additional hyperparameters that can be tuned - [here is the full list](https://xgboost.readthedocs.io/en/latest/parameter.html#parameters-for-tree-booster). For the purposes of this demonstration, we'll stick with mostly the same hyperparameters that we used for our previous gradient boosting example.\n", "\n", "I mentioned this above, but below is the tuning guide from Owen Zhang, the former \\#1 Kaggler in the world. 
Additionally, [here](https://machinelearningmastery.com/configure-gradient-boosting-algorithm/) is a blog post containing other tuning strategies that are primarily focused on XGBoost.\n", "\n", "\n", "*Source: [Owen Zhang](https://www.slideshare.net/ShangxuanZhang/winning-data-science-competitions-presented-by-owen-zhang)*" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2018-07-07T16:25:43.511719Z", "start_time": "2018-07-07T16:24:52.360557Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Default Parameters: \n", "\n", "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n", " colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,\n", " max_depth=3, min_child_weight=1, missing=None, n_estimators=100,\n", " n_jobs=-1, nthread=None, objective='binary:logistic',\n", " random_state=0, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,\n", " seed=None, silent=True, subsample=1) \n", "\n", "Beginning hyperparameter tuning\n", "Fitting 3 folds for each of 5 candidates, totalling 15 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Done 1 tasks | elapsed: 7.9s\n", "[Parallel(n_jobs=-1)]: Done 2 tasks | elapsed: 10.2s\n", "[Parallel(n_jobs=-1)]: Done 3 tasks | elapsed: 11.9s\n", "[Parallel(n_jobs=-1)]: Done 4 tasks | elapsed: 12.7s\n", "[Parallel(n_jobs=-1)]: Done 5 tasks | elapsed: 14.3s\n", "[Parallel(n_jobs=-1)]: Done 6 tasks | elapsed: 15.4s\n", "[Parallel(n_jobs=-1)]: Done 7 tasks | elapsed: 24.3s\n", "[Parallel(n_jobs=-1)]: Done 8 tasks | elapsed: 25.5s\n", "[Parallel(n_jobs=-1)]: Done 9 out of 15 | elapsed: 27.5s remaining: 18.3s\n", "[Parallel(n_jobs=-1)]: Done 10 out of 15 | elapsed: 33.3s remaining: 16.6s\n", "[Parallel(n_jobs=-1)]: Done 11 out of 15 | elapsed: 38.3s remaining: 13.9s\n", "[Parallel(n_jobs=-1)]: Done 12 out of 15 | elapsed: 39.2s remaining: 9.7s\n", 
"[Parallel(n_jobs=-1)]: Done 13 out of 15 | elapsed: 41.3s remaining: 6.3s\n", "[Parallel(n_jobs=-1)]: Done 15 out of 15 | elapsed: 45.2s remaining: 0.0s\n", "[Parallel(n_jobs=-1)]: Done 15 out of 15 | elapsed: 45.2s finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Completed\n", "Best estimator: \n", "\n", "XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n", " colsample_bytree=1, gamma=0, learning_rate=0.01, max_delta_step=0,\n", " max_depth=4, min_child_weight=1, missing=None, n_estimators=899,\n", " n_jobs=-1, nthread=None, objective='binary:logistic',\n", " random_state=0, reg_alpha=1, reg_lambda=0, scale_pos_weight=1,\n", " seed=None, silent=True, subsample=1.0)\n", "\n", "Accuracy before tuning: 0.722\n", "Accuracy after tuning: 0.719333333333\n", "\n", " Tuned results:\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.696667 0.604583 0.743372 5.783737\n", "Random Forest 0.720667 0.600349 0.734453 29.119582\n", "Gradient Boosted Trees 0.721333 0.591319 0.742431 68.062629\n", "XGBoost 0.719333 0.603150 0.732412 50.254662" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xd4VOeZ/vHvM0W994KQECB6B1PccMcFG/fektjJxs7G+WWzm91kk/1tNmun2bHXKXYSlzhxr9imGGMMtgMG0QUIiSIkod6QhAC1d/+YwSsTAUKa0Zk5ej7XpUujmTNzboTOraP3nHmPGGNQSikV/BxWB1BKKeUbWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWujKZ0SkRETaRSTphPu3iIgRkZxBzjNfRMp99Fofi8jXfPFaSvmLFrrytf3Arce/EJFJQLh1cexBRFxWZ1CBTwtd+doLwF09vr4b+HPPBUQkVER+KSKlIlItIr8XkXDvY/Ei8p6I1IpIo/f2sB7P/VhEfiIin4lIi4h8cOJfBN7lIoGlQIaItHo/MkTEISLfF5G9IlIvIq+KSIL3OWEi8hfv/U0iskFEUkXkp8C5wJPe13myl/X1+lzvYwki8qyIVHj/TW/3eN59IrJHRBpEZLGIZPR4zIjIAyJSDBR77xsrIiu8y+8WkZvO/L9I2ZUWuvK1dUCMiIwTESdwM/CXE5b5GZAHTAVGAZnAj7yPOYBngWxgOHAEOLFAbwPuBVKAEOCfTgxhjDkMXA5UGGOivB8VwD8Ci4DzgQygEfiN92l3A7FAFpAIfAM4Yoz5AfAJ8KD3dR7s5d/d63O9j70ARAATvJkfAxCRC4GHgZuAdOAA8PIJr7sImA2M9/6SWgG86H2dW4HfisiEXvKoIUgLXfnD8b30S4BC4ODxB0REgPuA7xhjGowxLcB/A7cAGGPqjTFvGGPavI/9FE/59vSsMabIGHMEeBXPL4a++jrwA2NMuTHmGPAfwA3eIY0OPGU8yhjTZYzZaIxp7uPr9vpcEUnH84vlG8aYRmNMhzFmtfc5twPPGGM2ebP8KzD3hGMND3u/T0eAq4ASY8yzxphOY8wm4A3ghjP49ysb03E55Q8vAGuAEZww3AIk49lb3ejpdgAEcAKISASePdgFQLz38WgRcRpjurxfV/V4vTYg6gyyZQNviUh3j/u6gFRv7izgZRGJw/OXxQ+MMR19eN1en+u9r8EY09jLczKATce/MMa0ikg9nr9YSrx3l52QfbaINPW4z+Vdt1K6h658zxhzAM/B0SuAN094uA7PUMQEY0yc9yPWGHO8lL8LjAFmG2NigPO89wtnrrepRMuAy3usO84YE2aMOejde/7/xpjxwDw8e8R3neK1/m9FJ39uGZDgLfkTVeApaeCLcf9EevxFc8J6y4DVJ2SPMsb8w6myqaFDC135y1eBC71j2V8wxnQDfwAeE5EUABHJFJHLvItE4yn8Ju/Byh8PIEM1kCgisT3u+z3wUxHJ9q47WUSu8d6+QEQmecf+m/EMo3T1eK3ck63oZM81xlTiOTj7W+8BX7eIHP8l9SJwr4hMFZFQPENPnxtjSk6ymveAPBG50/s6bhGZJSLjzvQbo+xJC135hTFmrzEm/yQP/wuwB1gnIs3Ah3j2ygF+j
ec0xzo8B1iXDSBDIfASsM975kkG8DiwGPhARFq865jtfUoa8DqeQt4FrOb/Dug+jmesvVFEnuhldad67p14Cr4QqAEe8uZbCfw7nnHwSmAk3mMJJ/n3tACXepepwDP09DMgtM/fFGVrohe4UEope9A9dKWUsgktdKWUsgktdKWUsgktdKWUsolBfWNRUlKSycnJGcxVKqVU0Nu4cWOdMSb5dMsNaqHn5OSQn3+yM9mUUkr1RkQO9GU5HXJRSimb0EJXSimb0EJXSimb0EJXSimb0EJXSimb0EJXSimb0EJXSimb0EJXSimb0EJXSimb0GuKql69+Hmp1REG1W2zh1sdQakB0z10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyCS10pZSyidMWuohkicgqEdklIjtE5Nve+xNEZIWIFHs/x/s/rlJKqZPpyx56J/BdY8w4YA7wgIiMB74PrDTGjAZWer9WSillkdMWujGm0hizyXu7BdgFZALXAM97F3seWOSvkEoppU7vjMbQRSQHmAZ8DqQaYyrBU/pAykmec7+I5ItIfm1t7cDSKqWUOqk+F7qIRAFvAA8ZY5r7+jxjzNPGmJnGmJnJycn9yaiUUqoP+lToIuLGU+Z/Nca86b27WkTSvY+nAzX+iaiUUqov+nKWiwB/AnYZYx7t8dBi4G7v7buBd3wfTymlVF+5+rDM2cCdwHYR2eK979+AR4BXReSrQClwo38iKqWU6ovTFrox5lNATvLwRb6No5RSqr/0naJKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTWuhKKWUTfZltUamTMsbQeqyTprYOWo91EhfhJiU6DKfjZPO5KaX8RQtd9UtHVzdby5pYu6+eykNHv/SYyyHkJkdyfl4KI5IiLUqo1NCjha7OWEndYV7JL+PQkQ7SYsK4YmIaiVGhRIW6aDjczsGmI2wua+IPn+wjJzGC66YPIykq1OrYStmeFrrqs25j+KS4jhU7q4iLCOErZ49gZHIknotaeWQlRDAlK45LxqeSX9LAh7tqeHLVHq6blsnkYXEWplfK/rTQVZ8YY3h/WyVr99UzMTOW66ZlEuZ2nnR5t9PB3JFJjEuP4eUNZby8oYyKpqNcNiH1S78AlFK+o2e5qD5ZWVjD2n31nD0ykVtnZZ2yzHuKiwjhvnNzOSsngTXFtSwrqMIY4+e0Sg1NuoeuTmvt3jo+KqxhxvB4rpiUfsZ72E6HcM3UDBwO+GRPHQCXT0r3R1SlhjQtdHVK5Y1tvL+9krFp0Syaltnv4RIRYeHkDIzxlHp8ZAhzchN9nFapoU0LXZ3UsY4uXtlQRnSYmxtnZA343HIRYeGUDJraOnhvWwXJ0aGMTI7yUVqllI6hq5N6d1sFDYfbuWlmFuEhfRszPx2HCDfPyiIxKpQXPy+l4XC7T15XKaWFrk5id1ULm0qbmD8m2edvDgpzO7lrTjYGw6v5ZXR160FSpXxBC139nWOdXby3rYKkqBAuGJvil3UkRoVyzdRMShvaWLW7xi/rUGqo0UJXf+ePn+yn/nA7Cydn4HL470dkyrA4pmXFsaqwhgP1h/22HqWGCi109SUVTUd48qM9jE+PYXRqtN/Xd/WUDOIjQ3htYzntnd1+X59SdqaFrr7kVx8U0W0MVw7SeeKhbifXTc+k4XA7KwurB2WdStmVFrr6wr7aVt7aXM6dc7KJjwwZtPXmJkUxKyeeT4vrONh4ZNDWq5TdaKGrLzyxsphQl5Ovnz9y0Ne9YEI6UaEu3txcrme9KNVPWugKgD01LbyztYK75mWTHD34U92GhzhZOCWDykNHWbevftDXr5QdaKErAH79YTERbidfP2/w986Pm5ARw+iUKD7cVU3L0Q7Lc
igVrLTQFSV1h3l/eyV3zs0hYRDHzk90fL6Xzi7D8h1VluVQKlhpoSue+Ww/Lodw79k5VkchKTqUc0Ynsam0Sc9NV+oMaaEPcU1t7byWX87VUzJJjQmzOg4A88ckExPm4v3tlXTr3OlK9ZkW+hD3189LOdLRxdfOHWF1lC+EupxcOj6N8sYjbD94yOo4SgUNLfQhrL2zm+f/VsK5oz2XigskU4fHkR4bxvIdVXR06TtIleoLLfQh7P3tFdS0HONr5+ZaHeXvOES4fGI6TW0dehqjUn2khT6E/XVdKblJkZw3OsnqKL0alRLFmNRoVu2uoe1Yp9VxlAp4py10EXlGRGpEpKDHff8hIgdFZIv34wr/xlS+VljVTP6BRm6bPbzfl5UbDAsmpnGso5uPdIpdpU6rL3vozwELern/MWPMVO/HEt/GUv7213WlhLgcXD99mNVRTik1JoyZOQms21dPXesxq+MoFdBOW+jGmDVAwyBkUYPk8LFO3tp8kKsmpQ/qJFz9dfG4FFwOh77ZSKnTGMgY+oMiss07JBPvs0TK7xZvraD1WCe3zxludZQ+iQ5zc15eEjsqmvXNRkqdQn8L/XfASGAqUAn86mQLisj9IpIvIvm1tbX9XJ3ypZfWlzImNZrpw4Pn9/A5o5KJDnOxrKAKo282UqpX/Sp0Y0y1MabLGNMN/AE46xTLPm2MmWmMmZmcnNzfnMpHdle1sK38EDfNygrog6EnCnE5uGhsKgca2iisarE6jlIBqV+FLiI9L2dzLVBwsmVVYHljUzkuh7BoaobVUc7YjOx4kqJCWLajSudMV6oXrtMtICIvAfOBJBEpB34MzBeRqYABSoCv+zGj8pGOrm7e3HSQC8emkBg1+HOeD5TTIVw6Po0X15eyubSRmTkJPnvtFz8v9dlrBYvbZgfHMRTVd6ctdGPMrb3c/Sc/ZFF+tqaolrrWY9wwI7BPVTyVCRkxZMWH8+GuaiYPiyPEpe+NU+o43RqGkNc3lpMYGcIFY1OsjtJvIsKCiek0H+1krU4JoNSXaKEPEU1t7Xy4q5pF0zJxO4P7v31EUiRj06JZXVRDW7tOCaDUccG9Zas+W1pQRUeX4dppmVZH8YlLJ3imBPh4t54Kq9RxWuhDxOItFeQmRTIhI7Cmye2vtJgwpg2PZ+2+ehrb2q2Oo1RA0EIfAmqaj7Jufz0Lp2QE1bnnp3PxuBQEWLmr2uooSgUELfQh4L1tlRgDC6cE37nnpxIXEcLckYlsLm2i6tBRq+MoZTkt9CHg3W0VjE+PYVRKlNVRfO78vGRC3Tpxl1KghW57ZQ1tbC5tst3e+XERIS7m56Wwu7qFfbWtVsdRylJa6Da3eGsFAFdNTj/NksFr7shEYsPdLNuhE3epoU0L3ebe3VrB9OFxZCVEWB3Fb9xOBxePS6G88Qg7KpqtjqOUZbTQbay4uoXCqhautulwS0/ThseTEh3KBzt14i41dGmh29i7WytwCFxh4+GW4xwiXDYhjbrWdvIP6AW21NCkhW5TxhgWb61g7shEUqLDrI4zKMamRZOTGMHKXTUc7eiyOo5Sg04L3aYKDjZTUt82JIZbjhMRrpiUTuuxTj7eXWN1HKUGnRa6TS3eehC3U1gwwf7DLT0Ni49g+vB4PttTT13rMavjKDWotNBtqLvb8P62Ss4bnUxshNvqOIPusgmpOJ3Cku2VVkdRalBpodvQtoOHqDh0lCsmDa298+Oiw9xcOCaFwqoWiqr1+qNq6NBCt6GlBZW4HMLF41KtjmKZeSMTSYwM4f3tlXoaoxoytNBtxhjDsoIq5o1KGpLDLce5nA6umJRObcsx1umVjdQQoYVuM4VVLRyob2PBhDSro1hubFo0o1OiWFlYTesxvbKRsj8tdJtZWlCFCFw6YegOtxwnIlw5KZ32zm5W7NTZGJX9aaHbzPKCKmblJJAUFWp1lICQEhPGvJFJbChp5ED9YavjKOVXWug2sq+2ld3VLVw+UYdberpoX
Aqx4W7e3nJQD5AqW9NCt5Fl3os8XKbj518S6nKycHIG1c3H+GxPndVxlPIbLXQbWVZQxZSsODLiwq2OEnDGZ8QwLj2GlYXV1Os7SJVNaaHbRHljG9vKD+lwyylcPSUDhwhvbT5It14IQ9mQFrpNLN9RDaCnK55CbLibKyams6/uMBtKdIpdZT9a6DaxrKDSM31sUqTVUQLazJx4RiVHsbSgisa2dqvjKOVTWug2UNNylPwDjSzQ4ZbTEhGunZYJwBubynXoRdmKFroNfLCjGmPg8olDczKuMxUfGcJVk9LZV3uYv+lZL8pGtNBtYPmOKnKTIslLjbI6StCYkR3P+PQYlu+spurQUavjKOUTWuhBrqmtnbV767lsYhoiYnWcoCEiLJqWSbjbySv5pbR3dlsdSakB00IPcit2VtPZbfR0xX6ICnVxw4xhVDcf471tFVbHUWrAtNCD3PIdVWTGhTMpM9bqKEEpLzWa+XnJ5B9oZHNpo9VxlBoQLfQg1nqskzXFdVw2QYdbBuKicankJEby9paDVDfreLoKXlroQWxVYQ3tnd16uuIAOR3CLbOyCHU5eWHdAdrade50FZy00IPYsoIqkqJCmZEdb3WUoBcT7ub22cM51NbByxvKdFZGFZS00IPU0Y4uVu2u8Vzh3qHDLb6QnRjJNVMz2FPTytKCSqvjKHXGTlvoIvKMiNSISEGP+xJEZIWIFHs/6y7iIFtTVEtbe5cOt/jYzJwE5o1M5G9763WqXRV0+rKH/hyw4IT7vg+sNMaMBlZ6v1aDaFlBFbHhbubkJlodxXaumJTO+PQYlmyvZPvBQ1bHUarPTlvoxpg1wIlT010DPO+9/TywyMe51Cm0d3bz4a5qLh6Xitupo2a+5hDh5llZZCVE8Fp+GXtqWq2OpFSf9LcNUo0xlQDezym+i6ROZ+2+epqPduqbifzI7XRw15xskqJCeWFdCSV1ej1SFfj8vnsnIveLSL6I5NfW1vp7dUPCsoIqIkOcnDM6yeoothYR6uLes3OIDQ/h+bUllDW0WR1JqVPqb6FXi0g6gPdzzckWNMY8bYyZaYyZmZyc3M/VqeO6ug0rdlZxwdgUwtxOq+PYXnSYm6+eM4LIUBd/+mw/++p0+EUFrv4W+mLgbu/tu4F3fBNHnc6GkgbqWtt1qtxBFBvu5v5zc4kLd/PcZyXsrmqxOpJSverLaYsvAWuBMSJSLiJfBR4BLhGRYuAS79dqECwrqCLU5WD+GP1rZzDFhLu579xcUmI8Y+r5egk7FYBcp1vAGHPrSR66yMdZ1Gl0dxuWFlRyfl4ykaGn/a9TPhYZ6uJr5+Ty0vpS3tx8kPrD7VwyPhWHzqOjAoSe8xZENpc1Ut18jCsn63CLVcLcTu6am8OsnARWF9Xyl3UHONLeZXUspQAt9KCyZHsVIU4HF47Vs0St5HQIi6ZmsHByOkXVLfzm4z1UNB2xOpZSWujBorvbsHR7JeflJREd5rY6zpAnIswdmcR95+bS2dXN71bv5dPiWr3otLKUFnqQ2FreRMWho1wxSYdbAkl2YiQPXjiavNRolhRU8cxn+2k43G51LDVEaaEHiaUFVbidwkXjUq2Ook4QFerijtnDuXZaJgcbj/D4yiLWFNXqFLxq0OmpEkHAGMOS7ZWcMyqJ2HAdbglEIsKsnATyUqNZvLWCZTuq2FTayBWT0slLjbY6nhoidA89CGw/eIjyxiNcrsMtAS823M2dc7K5c042nd2G5/5WwrOf7edAvc4Fo/xP99CDwJLtVbgcwqXjdbglWIxLj2F0ShRr99WzuqiWp9bsIzc5kgvGpJCbFKnXgFV+oYUe4IzxvJlo3qgk4iJCrI6jzoDL6eDc0cnMHpHI+v31fFJcx58+3U92QgTn5SWTlxqtV5tSPqWFHuB2VjZzoL6Nfzh/pNVRVD+FuBycMzqZ2bmJ5B9oZE1RLS+sO0BsuJsZ2fHMzI7XX9bKJ7TQA9yS7ZU4HcKlE3Tu82Dnd
jqYm5vIWTkJFFY1s35/A6sKa1hVWMOYtGimD49nTFq0XrRE9ZsWegDznN1SxdzcRBIidQ/OLpwOYUJGLBMyYmk43E5+SQMbDzRSWNVCiMvB+PQYJg+LZVRKFC6HlrvqOy30AFZY1cL+usN87dwRVkdRfpIQGcKlE9K4aFwq++sOs628iR0VzWwpayLc7WRCRgyTh8UxIilSx9vVaWmhB7B3t1bgdAiX6XCL7TkdwqiUKEalRHH11G721LSyrfwQ2w4eIv9AI1GhLiZlxjIlK46s+HA9S0b1Sgs9QBljeGdLBeeMSiIpKtTqOGoQuRwOxqbFMDYtho6ubnZXtbC1vIkNJQ2s3VdPfISbKcPimJIVR2pMmNVxVQDRQg9Qm0obOdh0hO9emmd1FGUht9PBxMxYJmbGcrSji50VzWwtb2J1US0fF9WSFhPG1Kw4ZmbHE6Fz5A95+hMQoN7ZUkGoy6Fnt6gvhLmdTM+OZ3p2PC1HOyg4eIgtZU0s21HFh7uqmTwsjjm5CQyLj7A6qrKIFnoA6ujq5v1tlVw8PpUo3etSvYgOczN3ZBJzRyZRdego6/bXs6W0iU2ljQyLD+fsUUlMyozVqykNMdoWAeizPXXUH27nmikZVkdRQSAtNoxFUzNZMCGNTaWNrNvXwCsbyli5q5r5Y1KYmhWnxT5EaKEHoMVbKogJc3G+XghanYEwt5N5I5OYk5vIjopmPt5dw+sby/mkuJbLJqQxJjVaz46xOS30AHOkvYvlO6pYOCWDUJfT6jgqCDlEmJQZy8SMGAoqmvlgRxV/XnuA0SlRLJycQVK0njVlV/o2tACzsrCaw+1dXD1Vh1vUwIi32B+6OI+rJqdT2tDG4x8Vs2JnFZ1d3VbHU36ge+gB5p0tFaTGhDJ7RKLVUZRNOB3CvJGeg6RLC6pYtbuWnZXNTBsez6RhsVbHUz6ke+gB5FBbBx/vrmHh5Ax9m7fyuegwNzfNzOKuOdm0tXex6Lef8ZtVe+jWS+XZhu6hB5ClBZV0dBmumZppdRRlY2PTY3goMZJNpY38Yvlu1u2r59GbppKsY+tBT/fQA8g7WyrITYpkYmaM1VGUzYWHOHnytmk8fN0k1u9vYOH/fMqWsiarY6kB0kIPEJWHjrBufz0Lp2ToqWVqUIgIt541nLe+eTYup3DTU2t5Lb/M6lhqALTQA8Sbmw5iDFw/fZjVUdQQMz4jhncfPIdZOfF87/Vt/HL5bozRcfVgpIUeAIwxvJZfxuwRCQxP1Hk41OCLjwzhuXvP4pZZWTy5ag/feWULxzq7rI6lzpAeFA0AGw80UlLfxgMXjLI6ihrC3E4HD183iayECH6xfDeNbR08decMwtz6BrdgoXvoAeC1/HIiQpxcMSnd6ihqiBMRHrhgFI9cN4k1xbV85bkNtLV3Wh1L9ZEWusXa2jt5b1sFV05KJ1JnVlQB4pazhvPoTVNYt6+eu59ZT8vRDqsjqT7QQrfY0u1VHG7v4oYZejBUBZZrpw3jf26dzubSJu7403oOtWmpBzotdIu9vKGUnMQIzhqRYHUUpf7OlZPT+e3t09lV0cxtf1xHU1u71ZHUKWihW6iouoUNJY3cetZwPfdcBaxLJ6Tx9F0zKK5u5d7nNnD4mI6pByotdAu9tL4Ut1N0uEUFvPljUnji1mlsLWviG3/ZqKc0BigtdIsc7ejijY3lXDYhjcQonUNDBb4FE9P42fWT+aS4jode3qJT8AYgLXSLLNleSfPRTm6bPdzqKEr12Y0zs/j3q8aztKCKf3tru76jNMDoeXIWefHzUkYkRTI3V+c9V8Hlq+eM4FBbO098tIfYcDc/uHK81ZGU14AKXURKgBagC+g0xsz0RSi7Kzh4iPwDjfzwynF6MFQFpe9ckkfTkQ7+8Ml+0mPD+co5I6yOpPDNHvoFxpg6H7zOkPH830oIdzu5cWaW1VGU6hcR4ccLJ1DdfJSfvL+TjLhwFkxMszrWkKdj6IOsvvUY72yt4PoZm
cSGu62Oo1S/OR3Cr2+expRhcTz0ymY2lzZaHWnIG2ihG+ADEdkoIvf3toCI3C8i+SKSX1tbO8DVBb+XN5TR3tnN3XNzrI6i1ICFhzj5490zSYkO42vP51Na32Z1pCFtoIV+tjFmOnA58ICInHfiAsaYp40xM40xM5OTkwe4uuDW2dXNX9Yd4JxRSYxOjbY6jlI+kRQVyrP3zqLLGO55br2+m9RCAyp0Y0yF93MN8BZwli9C2dXSgioqDx3lnnk5VkdRyqdGJkfx9J0zKW84wv1/1jceWaXfhS4ikSISffw2cClQ4KtgdmOM4ak1e8lNjuTCsSlWx1HK584akcAvbpzM+pIG/vUNPUfdCgM5yyUVeMt72p0LeNEYs8wnqWzosz31FBxs5mfXT8Lh0FMVlT1dMzWTA/VtPLqiiBFJkXzrotFWRxpS+l3oxph9wBQfZrG1p9bsJTk6lEXTMq2OopRffevCUZTUHeZXK4oYnhjBNVP1Z36w6GmLg6Dg4CE+Ka7jK2ePINSll/NS9iYiPHz9JM7KSeB7r29j44EGqyMNGVrog+B3q/cSFerSeVvUkBHqcvLUnTPIiA3j/j9v1NMZB4kWup8VVbewZHsld83N1jcSqSElPjKEZ+6ZRWe34d7n1nPoiF7xyN+00P3s8ZXFRLid3HdurtVRlBp0uclRPHXnDEob2vjmXzfSoVPu+pUWuh/trvLsnd9zdg7xkSFWx1HKEnNyE3n4usl8tqeef3+7QE9n9COdPtePnviomMgQF187R/fO1dB2w4xhlNQd5slVexiRFMnXzx9pdSRb0kL3kx0Vh1iyvZJvzh+pe+dKAf/vkjz21x/mkWWFZCdGsGBiutWRbEeHXPzAGMPDSwqJC3dz/3m6J6IUgMMh/OrGKUzNiuOhV7awrbzJ6ki2o4XuB2uK6/h0Tx3funC0ntmiVA9hbidP3zmTpKhQvvp8PgebjlgdyVa00H2sq9vw8JJdDE+I4I452VbHUSrgJEeH8sw9szja3sVXn9tAy1E9ndFXtNB97PWNZRRWtfDPC8YQ4tJvr1K9yUuN5rd3TKe4ppUHXtyspzP6iDaODzW1tfOzZbuZkR3PlZP0gI9Sp3Lu6GT++9qJrCmq5Z9e20p3t57OOFB6losP/WL5bpra2vnJNbP14s9K9cHNs4ZT19rOL5bvJj4ihB8vHK/bzgBoofvI1rImXlxfyj3zchifEWN1HKWCxjfnj6S+tZ1nPttPUlQID16oU+72lxa6D3R2dfPDtwtIigrlO5fkWR1HqaAiIvzwynE0trXzyw+KSIgM1Yns+kkL3QeeWrOP7QcP8ZvbphMTpqcpKnWmHA7h5zdMpqmtnR++vZ2YcBdXTc6wOlbQ0YOiA1RY1cyvPyziysnpXDlZD4Qq1V9up4Pf3j6DGdnxfPtLGXnzAAAKtElEQVTlLby/rdLqSEFHC30AOrq6+e6rW4kNd/OTayZaHUepoBce4uTZe89i+vA4/vHlzby3rcLqSEFFC30AfrF8NzsqmvmvRZNI0PlalPKJqFAXz917FjOGe/bU392qpd5XWuj99OHOap5es4875gxnwcQ0q+MoZSuRoS6evXeWd/hlM4u11PtEC70fyhvb+O5rW5mQEcMPrxxvdRylbCky1MWz98xiVk4CD728mVc3lFkdKeBpoZ+htvZOvvGXjXR1G35z23TC3HrRZ6X85fie+tmjkvjnN7bx2IoivUDGKWihn4HubsNDL29hZ0UzT9w6lZykSKsjKWV7ESEunrlnFjfOGMbjK4v53uvbdO6Xk9Dz0M/Az5YX8sHOan501XguHJtqdRylhgy308HPb5hMZnw4v/6wmOrmo/z29ulE6/s+vkT30PvoqdV7eWr1Pm6fPZx7z86xOo5SQ46I8NDFefz8+sn8bW89Nz21jrKGNqtjBRQt9D54/m8lPLy0kKsmp/Of10zUyYOUstBNs7J45p5ZlDe0sfDJT1lTVGt1pIChhX4af15bw
o8X7+CS8ak8dvNUnA4tc6Wsdn5eMou/dQ6p0WHc/ex6fv1hEZ06rq6FfjLGGB79YDc/emcHF49L5cnbpuF26rdLqUAxIimStx6Yx6Kpmfz6w2Ju/cO6IX9JO22oXhzr7OJf39zOEx/t4aaZw/j9HdMJdenpiUoFmogQF4/dPJXHbp7CzopmFjy2hlc2lA7ZUxu10E9Qdegotzy9jpc3lPHgBaP42fWTcemeuVIB7dppw1jy7XMZnxHDv7yxnbueWU9J3WGrYw06baoeVhXWcNX/fEJRVQu/u306/3TZGD0AqlSQyE6M5KX75vBfiyay6UAjlz62hl8u301be6fV0QaNnocOtBzt4Kfv7+LlDWWMSY3myfumMTo12upYSqkz5HAId8zJ5pLxqTyytJAnV+3htY1lfPuiPG6cOcz2x8Hs/a87je5uw+sby7nwV6t5Nb+Mf5g/ksXfOlvLXKkglxoTxmM3T+X1b8xlWHwE//bWdi55dDWvbiijvdO+Z8MMyT10Ywyri2p5bEURW8sPMSUrjj/cNZOpWXFWR1NK+dDMnARe/8ZcVu6q4dEVRZ75YD4s4u55Odw0M8t2014PqUJv7+xm+Y4q/vjpfraWNZEZF84vb5zCddMycej55UrZkohw8fhULhqXwuqiWn6/ei+PLC3k0RVFXDExjaunZnDOqGRCXME/YGH7QjfGsLu6hXe2VPBafjl1rccYnhDBI9dN4rrpw2zxn6iUOj0RYf6YFOaPSWF3VQsvrCth8ZYK3t5SQWy4m8smpHLV5AzmjUwM2jPbbFnoRzu62HSgkdVFtawsrGFPTStOh3B+XjJ3zsnm/Lxk3SNXaggbkxbNfy2axI+umsAnxbW8t62SJdureDW/nKhQFzNz4pmTm8ic3EQmZsQETcEHfaE3HG5nb20re2paKa5uZUtZIwUHm2nv6sbtFGblJHD3vBwun5hGUlSo1XGVUgEkxOXgonGpXDQulaMdXXy8u5ZP99Sybl8DjywtBCAyxMn4jBjyUqMZkxZNXqrnIz7CHXCnNQ+o0EVkAfA44AT+aIx5xCepTrB2bz1byppoOHyMhsMd3s/tlDUeoeFw+xfLhbocTMyM5d6zc5iVk8DckYlEhgb97yyl1CAIcztZMDHti0tK1rYc4/P99azf38Cuymbe3VrBXz/v7LG8g9SYsC8+kqNCiQx1EhHiIiLESXiIk9AeQ7pzchNJjQnz67+h320nIk7gN8AlQDmwQUQWG2N2+ircccsKKnl+7QFCXQ4SI0NIiAohITKUcekxjEqJYmRKFKOSo8iMC9ehFKWUTyRHh3LV5AyumpwBeI7HVTcfo6i6heKaVqoOHaGq+RjVzUfZVt5EXcsx2jq6ONmsA8/dOytwCx04C9hjjNkHICIvA9cAPi/07142hn9eMJaIEGfA/YmjlBoaRIS02DDSYsM4Ly+512WMMRzt6KatvZO29i7au7o53lhpsf4tcxhYoWcCPa/aWg7MPnEhEbkfuN/7ZauI7D6DdSQBdf1OOPiCLS8EX2bN6yO39353wOY9hWDL3J+82X1ZaCCF3tuu8t/9sWGMeRp4ul8rEMk3xszsz3OtEGx5Ifgya17/Cra8EHyZ/Zl3IOfilANZPb4eBlQMLI5SSqn+GkihbwBGi8gIEQkBbgEW+yaWUkqpM9XvIRdjTKeIPAgsx3Pa4jPGmB0+S+bRr6EaCwVbXgi+zJrXv4ItLwRfZr/llaF6ZQ+llLKb4Hg/q1JKqdPSQldKKZuwrNBFZIGI7BaRPSLy/ZMsc5OI7BSRHSLyove+qSKy1nvfNhG5OZDz9ngsRkQOisiTgZ5XRIaLyAcissv7eE6A5/25975dIvKEDNK7z06XWUQeE5Et3o8iEWnq8djdIlLs/bg7kPMG6jZ3qu+v9/FB3eYGmtkn250xZtA/8BxE3QvkAiHAVmD8CcuMBjYD8d6vU7yf84DR3tsZQCUQF6h5ezz+OPAi8GQgf3+9tz8GLvHej
gIiAjUvMA/4zPsaTmAtMD8QvscnLP8tPCcOACQA+7yf47234wM4b0BucyfL2+O+QdvmfJHZF9udVXvoX0wbYIxpB45PG9DTfcBvjDGNAMaYGu/nImNMsfd2BVAD9P4+3ADICyAiM4BU4AM/5xxwXhEZD7iMMSu897caY9oCNS+eN7OF4dmAQgE3UO3nvH3N3NOtwEve25cBK4wxDd5/zwpggV/TDiBvAG9zveYFS7Y5GEBmX213VhV6b9MGZJ6wTB6QJyKficg68czs+CUichaeDXmv35J69DuviDiAXwHf83PGngby/c0DmkTkTRHZLCK/EM9EbAGZ1xizFliFZ6+xElhujNnl57x9zQyAiGQDI4CPzvS5PjSQvD0fC6Rt7nimL+W1aJuDgX2PfbLdWTW3bF+mDXDh+TN7Pp53oX4iIhONMcfH9dKBF4C7jTH+vuprv/MCdwBLjDFlgzS0CwPL6wLOBaYBpcArwD3An/yUFQaWNwkY570PYIWInGeMWeOnrMf1aeoLr1uA140xXf14rq8MJK/nBQJvmzvuxLzfZPC3ORhYZp9sd1btofdl2oBy4B1jTIcxZj+wG88GjYjEAO8DPzTGrAvwvHOBB0WkBPglcJeI+GXeeB/lLQc2e/9s7ATeBqYHcN5rgXXeP1FbgaXAHD/n7Wvm426hx3DAGT7XVwaSN1C3ueNOzGvFNgcD/5kY+HY3GAcLejkY4MJzIGgE/3fwYMIJyywAnvfeTsLzp0yid/mVwEPBkPeEZe5hcA6KDuT76/Qun+x97FnggQDOezPwofc13N6fjYWB8D32LjcGKMH7Jj7vfQnAfjwHROO9txMCOG9AbnMny3vC44Oyzfnge+yT7W5Q/nNO8o+/AijCMxb3A+99/wlc7b0twKN45lffDtzivf8OoAPY0uNjaqDmtfCHq9958Vy0ZJv3/ueAkEDN690QngJ2eR97NFB+hr1f/wfwSC/P/Qqwx/txbyDnDdRt7lTf3x6PD9o254OfiQFvd/rWf6WUsgl9p6hSStmEFrpSStmEFrpSStmEFrpSStmEFrpSStmEFrpSStmEFrpSStnE/wIbnuCBvyj7ZwAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "parameters = {'n_estimators': scipy.stats.randint(low=100, high=1000), # Uniform distribution between 10 and 1000\n", " 'learning_rate': [0.01, 0.03, 0.1, 0.3],\n", " 'max_depth': [4, 6, 8, 10],\n", " 'subsample': [0.5, 0.75, 1.0],\n", " 'reg_alpha': [0, 1], # L1 regularization\n", " 'reg_lambda': [0, 1] # L2 regularization\n", " }\n", "\n", "hyperparameter_tune_get_results(xgboost, parameters, 'XGBoost', num_rounds=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### LightGBM\n", "\n", "[LightGBM](https://github.com/Microsoft/LightGBM) is a project from [Microsoft Research Asia](https://www.microsoft.com/en-us/research/lab/microsoft-research-asia/) that is focused around training gradient boosted trees in a highly efficient and distributed manner. 
It's generally comparable to XGBoost, but is not as popular because it is much newer: LightGBM was released in December 2016, after XGBoost had already become the de facto framework for Kaggle competitions.\n", "\n", "One of the fundamental differences between LightGBM and other implementations of gradient boosted trees is that it grows its trees leaf-wise rather than level-wise, which reportedly lets it achieve lower loss than level-wise trees:\n", "\n", "\n", "\n", "\n", "\n", "*Source: [LightGBM](https://github.com/Microsoft/LightGBM/blob/master/docs/Features.rst)*\n", "\n", "Additionally, LightGBM uses a histogram-based algorithm that discretizes continuous variables into buckets to speed up training and reduce memory requirements. XGBoost has added the same idea in recent versions (the `tree_method='hist'` option), but it is not enabled by default.\n", "\n", "There are several other optimizations happening under the hood (listed [here](https://github.com/Microsoft/LightGBM/blob/master/docs/Features.rst)), but those are a few of the main differences from other implementations.\n", "\n", "#### Installation\n", "\n", "[The documentation on GitHub](https://github.com/Microsoft/LightGBM/tree/master/python-package#installation) has installation instructions for LightGBM. It can be installed from PyPI with `pip install lightgbm`, but it has a few prerequisites - check the documentation for your OS." ] }, { "cell_type": "code", "execution_count": 104, "metadata": { "ExecuteTime": { "end_time": "2018-07-11T03:03:27.823937Z", "start_time": "2018-07-11T03:03:25.527201Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training the model\n", "Completed\n", "\n", " Non-tuned results:\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.530467 0.690443 0.542193 2.046665\n", "Random Forest 0.576922 0.779316 0.608052 13.342987\n", "LightGBM 0.585022 0.678384 0.616754 1.859263\n", "LightGBM AUC 0.585022 0.678384 0.616754 1.765515" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import lightgbm as lgb\n", "\n", "lightGBM = lgb.LGBMClassifier(nthread=-1) # nthread=-1 uses all available cores\n", "\n", "# Because LightGBM provides a scikit-learn-compatible API, it works with our function!\n", "train_model_get_results(lightGBM, 'LightGBM')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Hyperparameter Tuning\n", "\n", "Because LightGBM is so similar to XGBoost, the same tuning guidelines largely apply. There are also a few LightGBM-specific recommendations in the official [LightGBM Parameter Tuning Guide](http://lightgbm.readthedocs.io/en/latest/Parameters-Tuning.html)." ] }, { "cell_type": "code", "execution_count": 114, "metadata": { "ExecuteTime": { "end_time": "2018-07-11T03:39:42.598836Z", "start_time": "2018-07-11T03:11:49.527668Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Default Parameters: \n", "\n", "LGBMClassifier(boosting_type='gbdt', colsample_bytree=1, learning_rate=0.1,\n", " max_bin=255, max_depth=-1, min_child_samples=10,\n", " min_child_weight=5, min_split_gain=0, n_estimators=100, nthread=-1,\n", " num_leaves=31, objective='binary', reg_alpha=0, reg_lambda=0,\n", " seed=0, silent=True, subsample=1, subsample_for_bin=50000,\n", " subsample_freq=1) \n", "\n", "Beginning hyperparameter tuning\n", "Fitting 3 folds for each of 30 candidates, totalling 90 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Done 5 tasks | elapsed: 1.5min\n", "[Parallel(n_jobs=-1)]: Done 10 tasks | elapsed: 3.4min\n", "[Parallel(n_jobs=-1)]: 
Done 17 tasks | elapsed: 5.6min\n", "[Parallel(n_jobs=-1)]: Done 24 tasks | elapsed: 7.9min\n", "[Parallel(n_jobs=-1)]: Done 33 tasks | elapsed: 9.8min\n", "[Parallel(n_jobs=-1)]: Done 42 tasks | elapsed: 13.4min\n", "[Parallel(n_jobs=-1)]: Done 53 tasks | elapsed: 16.3min\n", "[Parallel(n_jobs=-1)]: Done 64 tasks | elapsed: 19.2min\n", "[Parallel(n_jobs=-1)]: Done 77 tasks | elapsed: 24.3min\n", "[Parallel(n_jobs=-1)]: Done 90 out of 90 | elapsed: 27.2min finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Completed\n", "Best estimator: \n", "\n", "LGBMClassifier(boosting_type='gbdt', colsample_bytree=1, learning_rate=0.03,\n", " max_bin=255, max_depth=6, min_child_samples=10, min_child_weight=5,\n", " min_split_gain=0, n_estimators=935, nthread=-1, num_leaves=31,\n", " objective='binary', reg_alpha=1, reg_lambda=1, seed=0, silent=True,\n", " subsample=0.5, subsample_for_bin=50000, subsample_freq=1)\n", "\n", "Accuracy before tuning: 0.585022222222\n", "Accuracy after tuning: 0.635922222222\n", "\n", " Tuned results:\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "LightGBM 0.635922 0.647806 0.676993 1664.384194" ] }, "execution_count": 114, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8XPV57/HPo9G+75sl2ZYt27KNscEbEMxiFgNhSUvClsQkFJK0uW1v2rSkaZsuyW24uc3SJCSlIamBEMKFJGwBYhyb3TveJe+SLEuydlnWLs1z/5gjYnxlNJJmdGZGz/v1mtfMnDnL1+PRM2d+5/x+R1QVY4wx4S/K7QDGGGMCwwq6McZECCvoxhgTIaygG2NMhLCCbowxEcIKujHGRAgr6MYYEyGsoJuAEZEqEekXkexzpu8SERWRGZOc50oRqQ3QujaJyJ8EYl3GBIsVdBNox4G7hp+IyAVAgntxIoOIRLudwYQ+K+gm0B4HPn3W87XAY2fPICJxIvJ/RKRGRE6JyI9FJMF5LUNEXhSRJhFpcx4XnbXsJhH5VxF5W0Q6ReR35/4icOZLAl4GCkXkjHMrFJEoEXlQRI6KSIuIPC0imc4y8SLyhDO9XUS2iUieiHwDuBz4gbOeH4ywvRGXdV7LFJGfiUid82/6zVnL3S8iR0SkVUSeF5HCs15TEfkzETkMHHamzROR9c78B0XkE2P/LzKRygq6CbTNQKqIlIuIB7gDeOKceR4C5gCLgdnANOAfndeigJ8B04ESoAc4t4DeDXwGyAVigb8+N4SqdgE3AHWqmuzc6oA/B24DrgAKgTbgh85ia4E0oBjIAj4P9KjqV4E3gS866/niCP/uEZd1XnscSAQWOJm/AyAiVwP/BnwCKACqgafOWe9twApgvvMltR540lnPXcDDIrJghDxmCrKCboJheC/9WqASODn8gogIcD/wP1W1VVU7gf8F3Amgqi2q+qyqdjuvfQNf8T3bz1T1kKr2AE/j+2Lw1+eAr6pqrar2Af8E3O40aQzgK8azVXVIVXeo6mk/1zvisiJSgO+L5fOq2qaqA6r6urPMPcBPVXWnk+UrwCXnHGv4N+d96gE+ClSp6s9UdVBVdwLPAreP4d9vIpi1y5lgeBx4A5jJOc0tQA6+vdUdvtoOgAAeABFJxLcHuwbIcF5PERGPqg45zxvOWl83kDyGbNOBX4uI96xpQ0Cek7sYeEpE0vH9sviqqg74sd4Rl3Wmtapq2wjLFAI7h5+o6hkRacH3i6XKmXzinOwrRKT9rGnRzraNsT10E3iqWo3v4OiNwK/OebkZX1PEAlVNd25pqjpclP8KmAusUNVUYJUzXRi7kYYSPQHccNa201U1XlVPOnvP/6yq84FL8e0Rf/pD1vWHDZ1/2RNAplPkz1WHr0gD77f7Z3HWL5pztnsCeP2c7Mmq+oUPy2amDivoJljuA6522rLfp6pe4L+A74hILoCITBOR651ZUvAV/HbnYOXXJpDhFJAlImlnTfsx8A0Rme5sO0dEbnUeXyUiFzht/6fxNaMMnbWu0vNt6HzLqmo9voOzDzsHfGNEZPhL6kngMyKyWETi8DU9bVHVqvNs5kVgjoh8yllPjIgsE5Hysb4xJjJZQTdBoapHVXX7eV7+W+AIsFlETgOv4dsrB/guvtMcm/EdYH1lAhkqgV8Ax5wzTwqB7wHPA78TkU5nGyucRfKBZ/AV5Argdf5wQPd7+Nra20TkP0bY3Ict+yl8Bb4SaAT+0sm3AfgHfO3g9cAsnGMJ5/n3dALXOfPU4Wt6egiI8/tNMRFN7AIXxhgTGWwP3RhjIoQVdGOMiRBW0I0xJkJYQTfGmAgxqR2
LsrOzdcaMGZO5SWOMCXs7duxoVtWc0eab1II+Y8YMtm8/35lsxhhjRiIi1f7MZ00uxhgTIaygG2NMhLCCbowxEcIKujHGRAgr6MYYEyGsoBtjTISwgm6MMRHCCroxxkQIK+jGGBMh7JqixoSRJ7fUTPo2715RMunbNONje+jGGBMhrKAbY0yEsIJujDERwgq6McZECCvoxhgTIaygG2NMhPCroItIuog8IyKVIlIhIpeISKaIrBeRw859RrDDGmOMOT9/99C/B7yiqvOAC4EK4EFgg6qWARuc58YYY1wyakEXkVRgFfAogKr2q2o7cCuwzpltHXBbsEIaY4wZnT976KVAE/AzEXlPRH4iIklAnqrWAzj3uSMtLCIPiMh2Edne1NQUsODGGGM+yJ+CHg1cBPxIVZcAXYyheUVVH1HVpaq6NCdn1ItWG2OMGSd/CnotUKuqW5znz+Ar8KdEpADAuW8MTkRjjDH+GLWgq2oDcEJE5jqTVgMHgOeBtc60tcBzQUlojDHGL/6Otvg/gJ+LSCxwDPgMvi+Dp0XkPqAG+HhwIhpjjPGHXwVdVXcBS0d4aXVg4xhjjBkv6ylqjDERwgq6McZECCvoxhgTIaygG2NMhLCCbowxEcIKujHGRAgr6MYYEyGsoBtjTISwgm6MMRHCCroxxkQIK+jGGBMhrKAbY0yEsIJujDERwgq6McZECCvoxhgTIaygG2NMhLCCbowxEcIKujHGRAgr6MYYEyGsoBtjTISwgm6MMRHCCroxxkQIK+jGGBMhov2ZSUSqgE5gCBhU1aUikgn8EpgBVAGfUNW24MQ0xhgzmrHsoV+lqotVdanz/EFgg6qWARuc58YYY1wykSaXW4F1zuN1wG0Tj2OMMWa8/C3oCvxORHaIyAPOtDxVrQdw7nODEdAYY4x//GpDBy5T1ToRyQXWi0ilvxtwvgAeACgpKRlHRGOMMf7wq6Crap1z3ygivwaWA6dEpEBV60WkAGg8z7KPAI8ALF26VAMT2xgzFkNe5dCpTrZVtdIzMERWUhwFafEsn5lJjMdOdosUo/5PikiSiKQMPwauA/YBzwNrndnWAs8FK6QxZvxOtHbz7+sP8vjmauraexDgSGMnL+2t50ebjlLf0eN2RBMg/uyh5wG/FpHh+Z9U1VdEZBvwtIjcB9QAHw9eTGPMeBxpPMMTm6tJivNw9/ISygtS8UQJAAcbOvnVzloe3niU25cWcWFRustpzUSNWtBV9Rhw4QjTW4DVwQhljJm4yobT/HxLDTnJcdx72QxS42M+8Prc/BT+fHUZP99SzbM7aslKiqUoI9GltCYQrPHMmAjU1t3P09tPkJcax/2Xl/5/xXxYUlw096yYTnJ8NE9sruZ078AkJzWBZAXdmAgz5FWe3nYCVbh7+XQSYj0fOn9SXDSfWjmdnoEhntpag1ft3IVwZQXdmAiz6VAj1a3d3Lq4kMykWL+WKUhL4JYLC6lq6WbXifYgJzTBYgXdmAjSeLqXjZWNLC5OZ3FxxpiWXVKSQVFGAq/ub6BvYChICU0wWUE3JoK8vK+BGE8UN15QMOZlo0S4eVEhnb2DbDrUFIR0JtisoBsTIQ43dnLwVCdXzc0lOc7fTuAfVJyZyJLidN460kxrV3+AE5pgs4JuTATwqvLy3gYyEmO4ZFbWhNZ1/YJ8AN44bHvp4cYKujER4L2adhpO97JmYcGEu/KnJsSwpDidndVtdNppjGHFCroxYc6ryqaDjRSmxbOwMDUg61xVlsOQV3n3WEtA1mcmhxV0Y8LcvpMdtHT1c+XcXJwhOiYsOyWO+YWpbD7Wwpm+wYCs0wSfFXRjwpiqsulgEzlOAQ6kVWU59A54eWprTUDXa4LHCroxYayyoZOG071cOSeHqADtnQ8rzkxkRlYij2+uxuu13qPhwAq6MWHs9UNNZCTGsChIIyUun5lJdUs3m60tPSxYQTcmTJ1s66GmtZtLZ2W/PyRuoC0oTCMtIYYnrdklLFh
BNyZMbT7WQqwniounj62L/1jEeKL42JJp/G7/KetoFAasoBsThrr7Btld287iknTiYz58NMWJumt5Cf1DXn61szao2zETZwXdmDC0vbqNQa+ysnRivUL9MTc/hYtK0nlyaw1qQ+uGNCvoxoQZrypbjrcwMzuJ/NT4SdnmnctKONbUZUPrhjgr6MaEmUMNnbR1D0zK3vmwNRfkExsdxXO76iZtm2bsrKAbE2bePdZCanw08wsC25How6TGx7B6Xi4v7qljcMg7ads1Y2MF3Zgw0tzZx+HGMyyfmRm0UxXP59bF02g+08/bR+2c9FBlBd2YMLLleAseEZbNyJz0bV81L4fU+Giee+/kpG/b+McKujFhort/kB01bSyYlkpKfMykbz8u2sONFxTw6v4GevrtEnWhyAq6MWHiN+/V0Tvg5ZJJPBh6rlsWF9LVP8RrFadcy2DOz++CLiIeEXlPRF50ns8UkS0iclhEfiki/l1e3BgzZqrKY+9WUZgWT0lmoms5Vs7MIjcljt/urXctgzm/seyh/wVQcdbzh4DvqGoZ0AbcF8hgxpg/2Hq8lcqGTlaWZgVszPPxiIoS1izMZ+PBRrr7bZz0UONXQReRIuAm4CfOcwGuBp5xZlkH3BaMgMYYeGxzNWkJwRtVcSxuWFhA74CXTQftmqOhxt899O8CfwMMn4CaBbSr6vBXdC0wbaQFReQBEdkuItubmuwDYMxYNXT08uq+Bj6xtIjYaPcPey2fmUlWUqw1u4SgUT8dIvJRoFFVd5w9eYRZRxzkQVUfUdWlqro0JydnnDGNmbqe3FrDkCqfXDnd7SgAeKKE6xfm8/vKRnoH7GyXUOLP1/1lwC0iUgU8ha+p5btAuohEO/MUAdYn2JgA6x/08outNVw1N5fpWUlux3nfjQsL6O4f4vVD9qs7lIxa0FX1K6papKozgDuB36vqPcBG4HZntrXAc0FLacwU9cr+Bpo6+/jUJaGxdz5sRWkmGYkxvGzNLiFlIg1yfwt8SUSO4GtTfzQwkYwxwx57p4rpWYlcURZazZUxniiunZ/HhopG+gdtbJdQMaaCrqqbVPWjzuNjqrpcVWer6sdVtS84EY2ZmvbXdbC9uo1PrZxO1CSP2+KP6+bn09k3yJbjNrZLqHD/kLkxZkSPv1tNfEwUH7+42O0oI7psdjbxMVG8dsB6jYYKK+jGhKCO7gF+s+skH1syjbTEyR+3xR8JsR4uL8th/YFTdiWjEGEF3ZgQ9MvtNfQOePnUyhluR/lQ15bnUdfRy/66025HMVhBNybkDA55WfdONStLM5lfOHkXsRiPq8tzEcEG6woRVtCNCTGv7G/gZHsP932k1O0oo8pOjuOikgzWWzt6SIgefRZjzPk8uaUm4Ov80aYjZCXFcup0b1DWH2jXzs/jmy9XUtfeQ2F6gttxpjTbQzcmhNS0dHGirYdLZ2UR5eKoimNxTXkeYM0uocAKujEh5K2jLcTHRHHR9Ay3o/htdm4ypdlJ1uwSAqygGxMi2rr72X+yg+UzMomL9rgdZ0yunZ/H5mMtnO4dcDvKlGYF3ZgQ8e7RFkRgpYuXmBuva+bnMTCkvG5jpLvKCroxIaBvYIhtVa0snJZGemL4Xc3xopIMMpNirdnFZVbQjQkB26vb6Bv0ctmsbLejjIsnSrh6Xi4bDzYyMGSDdbnFCroxLvOq8s7RZqZnJlLs4gWgJ+ra+Xl09g6y9Xir21GmLCvoxrhs38kO2roHuGx2eO6dD7u8LJu46ChrdnGRFXRjXORVZePBRnJS4kK+m/9oEmOjuWx2NhsqbbAut1hBN8ZFB+pOc+p0H1fNzQ2bjkQfZnV5LidaezjceMbtKFOSFXRjXKKqbDrYSFZSLIuK0tyOExCr51mvUTdZQTfGJQcbOqnr6OXKCNk7B8hPi2dBYSobKhrdjjIlWUE3xgWqyu8PNpKRGMPi4nS34wTU6vI8dta00drV73aUKccKujEuONx4htq2Hq6ck4s
nBK8XOhHXlOeiChsrbS99sllBN2aSqSq/r2wkLSGGJdMja+8cYGFhGrkpcWyotHb0yWYF3ZhJdqy5i5rWblbNySE6KvL+BKOcXqNvHGqmf9B6jU4mu8CFMZPs95WNpMRHszRMhsgdz0U2YjxRnOkb5JsvVzI7N3lMy969omTM2zM+kbd7YEwIO9J4huPNXawqyyHGE7l/frNykomOEioa7OLRk2nUT5SIxIvIVhHZLSL7ReSfnekzRWSLiBwWkV+KSPgNEWfMJFJVXt3fQHpCDCtmZrodJ6hio6OYlZNMZf1p6zU6ifzZRegDrlbVC4HFwBoRWQk8BHxHVcuANuC+4MU0JvztqzvNyfYerinPIzqC986HzStIoa17gMbOPrejTBmjfqrUZ7gfb4xzU+Bq4Bln+jrgtqAkNCYCDHmV9QcayE2JY3FJ5J3ZMpJ5+b6xaSobOl1OMnX4tZsgIh4R2QU0AuuBo0C7qg46s9QC086z7AMisl1Etjc12dVMzNS0s7qN5jP9XDc/P2J6hY4mLSGGwrR4KuutHX2y+FXQVXVIVRcDRcByoHyk2c6z7COqulRVl+bk5Iw/qTFhamDIy4bKU5RkJlJekOJ2nEk1Nz+VmtZuuvoGR5/ZTNiYGvJUtR3YBKwE0kVk+LTHIqAusNGMiQzvHm3hdO8g1y/IR6bI3vmw8oIUFDh0yppdJoM/Z7nkiEi68zgBuAaoADYCtzuzrQWeC1ZIY8JVT/8Qrx9qYk5eMjOzk9yOM+kK0xNIiYumwtrRJ4U/HYsKgHUi4sH3BfC0qr4oIgeAp0Tk68B7wKNBzGlMWHrjcBM9A0NcNz/f7SiuiBJhbn4Ke092MOj1RmTP2FAyakFX1T3AkhGmH8PXnm6MGUF7dz9vH2nmwqI0CtMT3I7jmnn5qWyvbqOquXvMvUbN2NjXpTFBMnxtzam6dz5sdq6v12il9RoNOivoxgTByfYe3jvRzqWzsshImtqdqN/vNdrQab1Gg8wKujEBpqr8dm89ibEerpyb63ackDA3P4XWrn7rNRpkVtCNCbDKhk6ON3exujyP+BiP23FCwrx83/n3B+1sl6Cygm5MAA15lZf3NZCdHMfyGZE9ANdYpCfGUpAWb6MvBpkVdGMCaFtVK81n+rhhYX7EXVpuoublp1LT0k239RoNGivoxgRI78AQr1WcYmZ20vtNDOYP5uX7eo0etF6jQWMF3ZgAef1QE939Q9y4sGDKdfH3x7QMX69RG30xeKygGxMAbU4nosXF6UzLmLqdiD7McK/RQ6c6GfTatUaDwQq6MQHwh05EeS4nCW3z8lPpG/RS1dztdpSIZAXdmAmqbetm14l2PjI7m/TEqd2JaDTDvUYP2tkuQWEF3ZgJGO5ElBQXzao5Nt7/aGKjoyjNSaLCeo0GhRV0Yyagov40VS3dXFOea52I/DQvP5XWrn6arNdowFlBN2ac+ge9vLyvgZyUOJZOt05E/iov8F1r9IBdmi7grKAbM05PbqmmpavfOhGNUVpCDEUZCVbQg8AKujHj0NEzwPc2HKY0J4m5edaJaKzmF6RS29ZDR8+A21EiihV0Y8bhPzYcpr1nwDoRjdP8QqfZpa7D5SSRxQq6MWN0pLGTde9Uceey4il9JaKJyE2JJyc5jv3W7BJQVtCNGQNV5Z9fOEBCrIe/vm6u23HC2vzCVKqau2ywrgCygm7MGLxW0cibh5v5n9fMISs5zu04YW1BYSpexcZ2CSAr6Mb4qXdgiH998QBlucl86pLpbscJe9PSE0hLiGG/taMHjBV0Y/z06FvHqWnt5ms3LyDGY386EyUizC9I5XDjGfoHbbCuQLBPpTF+aOjo5Ycbj3D9gjw+UpbtdpyIMb8wlUGvcsjGSA8IK+jG+OGbL1cw6FX+/qb5bkeJKDOykkiM9VgnowAZtaCLSLGIbBSRChHZLyJ/4UzPFJH1InLYuc8IflxjJt/mYy38Zlcdn1tVSnFmottxIoonSpiXn0p
lw2kbIz0A/NlDHwT+SlXLgZXAn4nIfOBBYIOqlgEbnOfGRJS+wSH+7td7Kc5M4E+vnO12nIi0oDCV3gEvx5u63I4S9kYt6Kpar6o7ncedQAUwDbgVWOfMtg64LVghjXHLjzYd5VhTF1+/7QISYm00xWCYnZtMrCfKOhkFwJja0EVkBrAE2ALkqWo9+Io+kHueZR4Qke0isr2pqWliaY2ZREebzvDwxqPccmEhV9hY50ET44miLC+ZirrTeG2M9Anxu6CLSDLwLPCXqur3V6mqPqKqS1V1aU6O/VGY8DDkVf7mmT0kxHr4h4/agdBgW1iYRmffINUtdmm6ifCroItIDL5i/nNV/ZUz+ZSIFDivFwCNwYlozOT7rzePsaO6jX+5dQE5KdYjNNjmFaQQ4xH21La7HSWs+XOWiwCPAhWq+u2zXnoeWOs8Xgs8F/h4xky+gw2dfPt3h7hhYT63XFjodpwpIS7aw9z8VPad7GBwyM52GS9/9tAvAz4FXC0iu5zbjcA3gWtF5DBwrfPcmLDWOzDEl57eRUp8NF+/baENjTuJFk1Lo6t/iHePtbgdJWxFjzaDqr4FnO9TvTqwcYxx1zdeqmB/3Wl+8umlNvjWJJubn0JcdBQv7q7n8jI73jYe1lPUGMfzu+t4fHM1D6wq5Zr5eW7HmXJiPFGUF6Ty8r56G9tlnKygG4PvohVfeXYPF0/P4MvX2zjnbllUlMbp3kHePGynOI+HFXQz5TV19nHvz7aREBvND+5eYiMpumh2bjJpCTG8sLvO7ShhyT65Zkrr7h/kvnXbaDnTz0/vXUpBml1Szk3RUVHcsDCf9QdO0Tsw5HacsGMF3UxZvQND/OnPd7LvZAffv2sJi4rS3Y5kgI8uKqSrf4iNlda1ZaysoJspqad/iPsf286mg01842MX2EHQELKyNJPs5Fhe2GPNLmNlBd1MOR3dA9z7s628daSZb92+iLuWl7gdyZwl2hPFjRcUsKGikTN2AekxsYJuppTKhtPc8sO32FnTxnfvWMzHlxa7HcmM4KOLCukb9LKh4pTbUcLKqB2LjJmIJ7fUTOr27l4x8t62qvJ/t9fytef3kxIfzVMPXMLF0+2aLKFq6fQM8lPjeX5XHbcunuZ2nLBhBd1EvKrmLv7u13t552gLK2Zm8v27lpCbGu92LPMhoqKEWxYX8tO3jtNyps967frJmlxMxKpr7+Grv97Ltd95nb21HXzjYwv5xf0rrZiHiT++qIhBr/LcLjs46i/bQzcRRVXZVtXGL7bW8NKeehTljmXF/I+ry8izQh5W5uansHBaKs/urOWzH5npdpywYAXdRITuvkF2nmjn0beOcbSpi5S4aO5cXsz9l9uFncPZ7RcV8U8vHKCi/jTlBaluxwl5VtBN2FJVqlq62VbVyt6THQx5lSUl6Xzr9kXctKiAxFj7eIe7WxZP4xu/reDZHbX8vV05alT2iTdhp6d/iJ01bWytaqWps4+46CiWzchg2YxM/uo6G1grkmQmxXL1vFx+s6uOB2+YR7SNs/OhrKCbsNHTP8SbR5p452gL/YNeijIS+KMl01hUlE5stP2hR6o/vqiIV/efYuPBJq61Hr0fygq6CXleVTYfa+G1ilP0Dni5YFoaV8zJoTDdBtKaCq6al0tuShxPba2xgj4KK+gmpNV39PCrnSc52d5DWW4yaxbm24iIU0yMJ4pPLC3m4U1HqGvvsS/yD2G/U01IUmev/OFNR+noGeCOZcXce+kMK+ZT1B3LilHg6e0n3I4S0qygm5AzMOTll9tP8PzuOmblJPEXq8u4sCjdLtg8hRVnJnJ5WQ6/3HaCIa+6HSdkWUE3IaW7b5BH3zrO3toOrpufx6cvmUFSnLUMGrh7eTH1Hb28fsjGST8fK+gmZLR19/PjN45R197DXctLuHJuLlG2V24cq8vzyEmJ4/F3q92OErKsoJuQ0Nbdz0/ePMaZvgE+e9lMFk5LczuSCTExnijuXl7CxoNNHG/
ucjtOSLKCblzX0TPAo28dp2dgiPsuK2VGdpLbkUyIumdlCTEeYd07VW5HCUmjFnQR+amINIrIvrOmZYrIehE57NzbwNJmXLr6Bnn0rWN09Q3ymUtnMi3DzmIx55ebEs9HFxXyzI5aOnsH3I4TcvzZQ/9vYM050x4ENqhqGbDBeW7MmPQPenns3Srauwe499IZNoiW8cu9l87gTN8gz+yodTtKyBm1oKvqG0DrOZNvBdY5j9cBtwU4l4lwQ17lqW011Lb1cOeyYqZnWTOL8c+FxeksKUln3TtVdgrjOcbbhp6nqvUAzn3u+WYUkQdEZLuIbG9qahrn5kykeXlfPZUNndx8YSHzC+0AqBmb+z4yk6qWbn63v8HtKCEl6AdFVfURVV2qqktzcnKCvTkTBrYdb+Wdoy1cNiuLlaVZbscxYeiGhQXMyErk4U1HUbW99GHjLeinRKQAwLm3M/2NX443d/Hc7pPOuCwFbscxYcoTJXzuilnsPdnBW0ea3Y4TMsZb0J8H1jqP1wLPBSaOiWTt3f08uaWazKQ47lxWgifKOg2Z8fuji6aRlxrHwxuPuh0lZPhz2uIvgHeBuSJSKyL3Ad8ErhWRw8C1znNjzmtgyMuTW2sY9CqfXFlCQqzH7UgmzMVFe/iTj5Ty7rEWdta0uR0nJPhzlstdqlqgqjGqWqSqj6pqi6quVtUy5/7cs2CM+YAXdtdR29bD7RcXkZtiF2s2gXH3ihIyk2L5zvpDbkcJCdZT1ATd1uOtbK9u48o5OSywM1pMACXFRfOnV87izcPNvHPU2tKtoJugqmnt5oXddZTlJnONXW3GBMEnV04nPzWeb716cMqf8WIF3QRNU2cfT26pJjUhmjuWFdvIiSYo4mM8/MU1ZbxX085rFVP7hDsr6CYo+ge9fOGJHXT3D3HPiukkxtqY5iZ4br+4iBlZifzvVyoZGPK6Hcc1VtBNwKkq//CbfWyvbuP2i4vsGpAm6GI8UXzlxnION57hsSk8XroVdBNw696p4pfbT/BnV81iUVG623HMFHHd/DxWzcnhu+sP0dTZ53YcV1hBNwH19pFm/vWlCq4pz+Wvrp3rdhwzhYgIX7t5Pr2DQzz0SqXbcVxhBd0ETHVLF3/6852UZifxnTsWE2U9Qc0km5WTzGc/MpNndtSy5ViL23EmnRV0ExCdvQP8ybrtAPxk7VJS4mNcTmSmqj+/uoySzES+/MweuvoG3Y4zqaygmwnrGxzic4+zf8p7AAAMJ0lEQVTv4HhzFw/fc5GNbW5clRQXzbduX8SJtm6++fLUanqxc8lc9uSWmknd3t0rSgK6Pq9X+dLTu3nnaAvf/sSFXDY7O6DrH6vJfj9NaFpRmsVnLp3JT98+zvUL8vlImbufy8lie+hm3FSVf3x+Hy/tqecrN8zjjy4qcjuSMe/7mzVzmZWTxF/+8j0aOnrdjjMprKCbcVFVvvb8fp7YXMPnVpXywKpStyMZ8wHxMR5+/MmL6e4f4gs/30H/YOR3OLKCbsbM6/UV88fereb+y2fy4A3zEOvWb0JQWV4K37r9Qt6raedfXtzvdpygszZ0Mya9A0P89f/dzYt76rn/8pn83Y3lVsxNSLtpUQG7a0t55I1jTEtP5AtXznI7UtBYQQ8yVeVkew+HT53heHMXtW09tPf009E9wJAqDR29REcJSXHRJMdHk5UUR25KHDkpccR4QusHVFtXP597fAdbq1r5yg3zeGBVqRVzExb+ds086tp7eOiVStITY7hreWBPDggVVtCD4ERrNxsPNvLGoSZ2nWin+Uz/+68lxnrISIwlLSGGaI9wuneA/kGlq6+LnoGh9+eLEihIS6AkK5Gy3GRKs5OJjXavwG+rauXPf/EeLWf6+f5dS7j5wkLXshgzVp4o4dufWExn7yBf/fVe4qKjIvIgvhX0APB6lW1VrWyobOT3lY0caTwDwPSsRK6Yk8viknTK81OYkZ1EVlLsB/Zqzz7NbnDIS0tXP42dfdS
193CitZvtVa28e7SF6ChhZnYSc/JSmJuXQnZK3KT82/oGh/jRpqN8//dHKMpI4NkvXMoFRXaRChN+YqOj+NEnL+K+/97Ol57eTfOZPh5YFVnNL1bQx0lVOVB/mud21fH8rjoaTvcS4xFWzMziruUlXDU3h9Kc5DGtM9oTRV5qPHmp8VwwzVc0B4a8VLV0caihk0OnzvDS3npe2ltPdnIc5QUpzC9IpTgzMShjjb9+qIl/en4/x5u7uHVxIV+/baH1ADVhLTE2mv/+7DK+9PRu/tdvK6lr7+Xvbix39ddvIFlBH6MTrd08t+skz+2q43DjGaKjhCvm5PCVG+exujyP5LjAvqUxnijKclMoy03hJnzt2JWnOqmoP807R1p483AzSbEe5uWnUl6QwuzclAl9OL1e5bWKU/znG8fYUd1GaXYSj312Oavm5ATuH2WMi+KiPXz/ziXkp8bz6FvHee9EO9+/cwklWYluR5swK+h+aOzs5bd76nlhTz07qn1XF182I4N/vW0hN11QQGZS7KRlyUiK5ZLSLC4pzaJ3YIhDTnHfX9/Bjpo2PFFCcUYipTlJFGckUJCeQOooe9UdPQPsre1g/YEGXt1/iobTvRRlJPBPN8/nrhUlxEV7JulfZ8zkiIoS/uGj81k6PYO/eXYPN/3Hm3x5zVzuWTEdTxgPKmcF/Txau/p5ZV8DL+yuY8vxFrwKc/NS+PL1c7nlwkKKM93/No+P8bCoKJ1FRekMefX9ppljzV1srGxk+OqKCTEe0hNjSEuI4e2jzcR5ougdHKK1q5+69l5qWrsBiIuO4oo5Ofz94nLWLMgnOsTOsjEm0G64oICF09J48Fd7+Mfn9vPU1hN89aZyLp2VFZZncFlBd3i9yv6602w82MjGg43sPtGOV6E0O4kvXl3GzYsKKMtLcTvmeXmihFk5ycxy2u17B4ao7+ilrr2HpjN9dHQP0NEzQEX9afoHvcTHeMhMjOWCaWncsayYBYWpLJuRSVKAm4yMCXXFmYk8cd8Kfru3ga+/dIB7frKFxcXpfP6KWawuzw2504c/zJT96+0f9FJRf5r3atrYWdPOO0dbaD7ju8rJhUVpfPHqMq6bn8eCwtSw/KaOj/EwMzuJmdkfHPkw0INzGRMJRISbFhWwujyXZ3bU8uPXj/L5J3aQlRTLrYuncd2CPC6enhHyxX1CBV1E1gDfAzzAT1T1mwFJFUCDQ15q23o43tzF0SZf557Khk72nux4f2yHvNQ4LpmVxVVzc1g1J4fs5Mk5JdAYE1riYzx8cuV07lxWzKaDTTy7s5YnNlfz07ePkxIXzYrSTBYX+5o5Z2YnUZieEFJt7uMu6CLiAX4IXAvUAttE5HlVPRCocMOON3fR0TPAkNfL4JAy5FUGvUrPwBBdfYN09fvuz/QO0tLVR1PnH26NnX0MevX9daUlxFCWm8zaS6azpCSDJSXpFKTZRYyNMX8Q7Ynimvl5XDM/j87eAd4+0symg01srWrltYrG9+eL8QhFGYmUZCZSmJ5AWoLvWFVqQjRpCTEkxUYTFx1FbHQUCwrTSIgN7gkGE9lDXw4cUdVjACLyFHArEPCC/s8v7GfTwaZR54sSyEzydZvPSYljdm4KualxzMxKojQnidKcZDISY8KyCcUY446U+BjWLCxgzcICADq6B9hf30F1Szc1rd3UtHRT3drF/roOOnoGGBjSEdfz2pdWMTs3uMfhJlLQpwEnznpeC6w4dyYReQB4wHl6RkQOTmCbozo++izZQHMwMwRYQPPeE6gVfbgp/R5PgnDLC2PIPEmf0dEE/D0ue2hCi0/3Z6aJFPSRdnP/v68mVX0EeGQC2wkoEdmuqkvdzuGvcMsL4ZfZ8gZfuGUOt7zDJnLIthYoPut5EVA3sTjGGGPGayIFfRtQJiIzRSQWuBN4PjCxjDHGjNW4m1xUdVBEvgi8iu+0xZ+qajhcEiRkmn/8FG55IfwyW97gC7fM4ZYXAFEd+YisMcaY8BL
a3Z6MMcb4zQq6McZEiIgp6CKyRkQOisgREXlwhNfvFZEmEdnl3P7krNf+t4jsF5EKEfkPmaSeRxPM/JCI7HNud4RCXmeeT4jIAef9fPKs6WtF5LBzWzsZeQOQ+RURaReRF0M9r4gsFpF3nWl7Qv0zISLTRWSH87neLyKfn4y8E8l81mupInJSRH4wOYnHQFXD/obvoOxRoBSIBXYD88+Z517gByMseynwtrMOD/AucGWIZ74JWI/voHYSsB1IDYG8ZcB7QIbzPNe5zwSOOfcZzuOMEHmPR8zsPF4N3Ay8GEKf4/O9x3OAMudxIVAPpIdw3lggznmcDFQBhaH8Hp/1+veAJ0f623T7Fil76O8PQ6Cq/cDwMAT+UCAe5wMGxACngpLygyaSeT7wuqoOqmoXvg/lmiDlHOZP3vuBH6pqG4CqDg96cT2wXlVbndfWT0LeiWZGVTcAnZOQc9i486rqIVU97DyuAxqBYF9maiJ5+1W1z5knjslrLZjQZ0JELgbygN9NUt4xiZSCPtIwBNNGmO+PnZ+jz4hIMYCqvgtsxLdHUw+8qqoVwQ7MBDLjK+A3iEiiiGQDV/HBTl7B4E/eOcAcEXlbRDY7o3H6u2wwTCSzGwKSV0SW49tBORq0pD4TyisixSKyx1nHQ84XUbCNO7OIRAH/Dnx5EnKOS6SMh+7PMAQvAL9Q1T6nvW4dcLWIzAbK8fV0BVgvIqtU9Y3gxQUmkFlVfyciy4B3gCZ8zUSDQU3rX95ofD9Xr8T3fr4pIgv9XDYYxp1ZVduDnG0kE84rIgXA48BaVfUGMStMMK+qngAWiUgh8BsReUZVg/3reCKf408Cv1XVE5N0mG3MImUPfdRhCFS15ayfeP8FXOw8/hiwWVXPqOoZ4GVgZZDzwsQyo6rfUNXFqnotvg/pYbfzOvM8p6oDqnocOIjvD8OtYSImktkNE8orIqnAS8Dfq+rmUM87zNkz3w9cHsSsZ+cZb+ZLgC+KSBXwf4BPi0hoXQPC7Ub8QNzwfaMeA2byhwMdC86Zp+Csx8NFHOAO4DVnHTHABuDmEM/sAbKcx4uAfUB0CORdA6xzHmfj+2mbhe9g6HF8B0QznMeZIfIej5j5rNevZPIOik7kPY51Prt/ORlZA5C3CEhwpmcAh4ALQjnzOfPcSwgeFHU9QAD/o250PhRHga860/4FuMV5/G/49gJ242szn+dM9wD/CVTgG8v922GQOd7JegDYDCwOkbwCfNvJtRe486xlPwsccW6fCaH3+MMyv4mvSasH317b9aGaF19zwACw66xb0D8XE8h7LbDH+WzvAR4Ih8/EWeu4lxAs6Nb13xhjIkSktKEbY8yUZwXdGGMihBV0Y4yJEFbQjTEmQlhBN8aYCGEF3RhjIoQVdGOMiRD/DwgR/yovUYsgAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "parameters = {'n_estimators': scipy.stats.randint(low=10, high=1000), # Uniform distribution between 10 and 1000\n", " 'learning_rate': [0.01, 0.03, 0.1, 0.3],\n", " 'max_depth': [4, 6, 8, 10],\n", " 'subsample': [0.5, 0.75, 1.0],\n", " 'reg_alpha': [0, 1], # L1 regularization\n", " 'reg_lambda': [0, 1] # L2 regularization\n", " }\n", "\n", "hyperparameter_tune_get_results(lightGBM, parameters, 'LightGBM', num_rounds=30)" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "### Blending\n", "\n", "While not technically one of the three core ensemble methods, blending is a popular technique for combining the predictions of multiple models by averaging them. It's easy to program and typically gives good results.\n", "\n", "It has similar downsides to stacking - it requires more computational power, and any semblance of interpretability goes out the window. Here we average the predicted class probabilities of ten identically configured LightGBM models." ] }, { "cell_type": "code", "execution_count": 115, "metadata": { "ExecuteTime": { "end_time": "2018-07-11T03:51:48.316770Z", "start_time": "2018-07-11T03:45:54.891909Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training model number 1\n", "Training model number 2\n", "Training model number 3\n", "Training model number 4\n", "Training model number 5\n", "Training model number 6\n", "Training model number 7\n", "Training model number 8\n", "Training model number 9\n", "Training model number 10\n", "\n", "Accuracy: 0.633433333333\n", "Log Loss: 0.649578716595\n", "AUC: 0.674961877427\n" ] } ], "source": [ "num_models = 10\n", "\n", "class_probabilities = []\n", "\n", "for model_num in range(num_models):\n", "    \n", "    # Progress printing for every 10% of completion\n", "    if (model_num + 1) % max(1, num_models // 10) == 0:\n", "        print('Training model number', model_num + 1)\n", "    \n", "    model = lgb.LGBMClassifier(nthread=-1, n_estimators=935, learning_rate=0.03)\n", "    model.fit(X_train, y_train)\n", "    model_prediction = model.predict_proba(X_test)\n", "    class_probabilities.append(model_prediction)\n", "    \n", "# Averaging the predictions for output\n", "class_probabilities = np.asarray(class_probabilities).mean(axis=0)\n", "predictions = np.where(class_probabilities[:, 1] > 0.5, 1, 0)\n", "\n", "print()\n", "print('Accuracy:', metrics.accuracy_score(y_test, predictions))\n", "print('Log Loss:', metrics.log_loss(y_test, class_probabilities))\n", "print('AUC:', metrics.roc_auc_score(y_test, class_probabilities[:, 1]))\n", "\n" ] }, { "cell_type": "code",
"execution_count": 116, "metadata": { "ExecuteTime": { "end_time": "2018-07-11T03:55:11.077734Z", "start_time": "2018-07-11T03:55:11.030862Z" } }, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>Accuracy</th>\n", "      <th>LogLoss</th>\n", "      <th>AUC</th>\n", "      <th>TrainingTime</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>LightGBM</th>\n", "      <td>0.635922</td>\n", "      <td>0.647806</td>\n", "      <td>0.676993</td>\n", "      <td>1664.384194</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "LightGBM 0.635922 0.647806 0.676993 1664.384194" ] }, "execution_count": 116, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tuned_results" ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "ExecuteTime": { "end_time": "2018-07-11T02:47:38.954431Z", "start_time": "2018-07-11T02:47:38.923322Z" } }, "outputs": [ { "data": { "text/plain": [ "0.71255555555555561" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics.accuracy_score(y_test, predictions)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": { "ExecuteTime": { "end_time": "2018-07-11T02:47:40.782445Z", "start_time": "2018-07-11T02:47:40.751205Z" } }, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>Accuracy</th>\n", "      <th>LogLoss</th>\n", "      <th>AUC</th>\n", "      <th>TrainingTime</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>Logistic Regression</th>\n", "      <td>0.688667</td>\n", "      <td>0.621084</td>\n", "      <td>0.719362</td>\n", "      <td>42.689128</td>\n", "    </tr>\n", "    <tr>\n", "      <th>Random Forest</th>\n", "      <td>0.713000</td>\n", "      <td>0.600420</td>\n", "      <td>0.731466</td>\n", "      <td>3247.834107</td>\n", "    </tr>\n", "    <tr>\n", "      <th>Gradient Boosted Trees</th>\n", "      <td>0.712556</td>\n", "      <td>0.619908</td>\n", "      <td>0.734792</td>\n", "      <td>4069.814399</td>\n", "    </tr>\n", "    <tr>\n", "      <th>LightGBM</th>\n", "      <td>0.708889</td>\n", "      <td>0.598696</td>\n", "      <td>0.734768</td>\n", "      <td>629.372810</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Accuracy LogLoss AUC TrainingTime\n", "Logistic Regression 0.688667 0.621084 0.719362 42.689128\n", "Random Forest 0.713000 0.600420 0.731466 3247.834107\n", "Gradient Boosted Trees 0.712556 0.619908 0.734792 4069.814399\n", "LightGBM 0.708889 0.598696 0.734768 629.372810" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tuned_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Summary\n", "\n", "And there we have it! We looked at different types of ensemble methods, how to tune them, and a few different frameworks for using them." ] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.5" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }