{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Categorical Feature Support in Gradient Boosting\n\n.. currentmodule:: sklearn\n\nIn this example, we compare the training times and prediction performances of\n:class:`~ensemble.HistGradientBoostingRegressor` with different encoding\nstrategies for categorical features. In particular, we evaluate:\n\n- \"Dropped\": dropping the categorical features;\n- \"One Hot\": using a :class:`~preprocessing.OneHotEncoder`;\n- \"Ordinal\": using an :class:`~preprocessing.OrdinalEncoder` and treating\n categories as ordered, equidistant quantities;\n- \"Target\": using a :class:`~preprocessing.TargetEncoder`;\n- \"Native\": relying on the native category support of the\n :class:`~ensemble.HistGradientBoostingRegressor` estimator.\n\nFor this purpose, we use the Ames Iowa Housing dataset, which consists of\nnumerical and categorical features, where the target is the house sale price.\n\nSee `sphx_glr_auto_examples_ensemble_plot_hgbt_regression.py` for an\nexample showcasing some other features of\n:class:`~ensemble.HistGradientBoostingRegressor`.\n\nSee `sphx_glr_auto_examples_preprocessing_plot_target_encoder.py` for a\ncomparison of encoding strategies in the presence of high-cardinality\ncategorical features.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Ames Housing dataset\nFirst, we load the Ames Housing data as a pandas dataframe. 
The features\nare either categorical or numerical:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.datasets import fetch_openml\n\nX, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True)\n\n# Select only a subset of features of X to make the example faster to run\ncategorical_columns_subset = [\n \"BldgType\",\n \"GarageFinish\",\n \"LotConfig\",\n \"Functional\",\n \"MasVnrType\",\n \"HouseStyle\",\n \"FireplaceQu\",\n \"ExterCond\",\n \"ExterQual\",\n \"PoolQC\",\n]\n\nnumerical_columns_subset = [\n \"3SsnPorch\",\n \"Fireplaces\",\n \"BsmtHalfBath\",\n \"HalfBath\",\n \"GarageCars\",\n \"TotRmsAbvGrd\",\n \"BsmtFinSF1\",\n \"BsmtFinSF2\",\n \"GrLivArea\",\n \"ScreenPorch\",\n]\n\nX = X[categorical_columns_subset + numerical_columns_subset]\nX[categorical_columns_subset] = X[categorical_columns_subset].astype(\"category\")\n\ncategorical_columns = X.select_dtypes(include=\"category\").columns\nn_categorical_features = len(categorical_columns)\nn_numerical_features = X.select_dtypes(include=\"number\").shape[1]\n\nprint(f\"Number of samples: {X.shape[0]}\")\nprint(f\"Number of features: {X.shape[1]}\")\nprint(f\"Number of categorical features: {n_categorical_features}\")\nprint(f\"Number of numerical features: {n_numerical_features}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient boosting estimator with dropped categorical features\nAs a baseline, we create an estimator where the categorical features are\ndropped:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.compose import make_column_selector, make_column_transformer\nfrom sklearn.ensemble import HistGradientBoostingRegressor\nfrom sklearn.pipeline import make_pipeline\n\ndropper = make_column_transformer(\n (\"drop\", make_column_selector(dtype_include=\"category\")), 
remainder=\"passthrough\"\n)\nhist_dropped = make_pipeline(dropper, HistGradientBoostingRegressor(random_state=42))\nhist_dropped" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient boosting estimator with one-hot encoding\nNext, we create a pipeline to one-hot encode the categorical features,\nwhile passing the remaining features through unchanged via\n`remainder=\"passthrough\"`:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.preprocessing import OneHotEncoder\n\none_hot_encoder = make_column_transformer(\n (\n OneHotEncoder(sparse_output=False, handle_unknown=\"ignore\"),\n make_column_selector(dtype_include=\"category\"),\n ),\n remainder=\"passthrough\",\n)\n\nhist_one_hot = make_pipeline(\n one_hot_encoder, HistGradientBoostingRegressor(random_state=42)\n)\nhist_one_hot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient boosting estimator with ordinal encoding\nNext, we create a pipeline that treats categorical features as ordered\nquantities, i.e. 
the categories are encoded as 0, 1, 2, etc., and treated as\ncontinuous features.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n\nfrom sklearn.preprocessing import OrdinalEncoder\n\nordinal_encoder = make_column_transformer(\n (\n OrdinalEncoder(handle_unknown=\"use_encoded_value\", unknown_value=np.nan),\n make_column_selector(dtype_include=\"category\"),\n ),\n remainder=\"passthrough\",\n)\n\nhist_ordinal = make_pipeline(\n ordinal_encoder, HistGradientBoostingRegressor(random_state=42)\n)\nhist_ordinal" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient boosting estimator with target encoding\nAnother possibility is to use the :class:`~preprocessing.TargetEncoder`, which\nencodes each category with a smoothed estimate of the average (training) target\nvalue for the samples in that category, i.e.:\n\n- in regression it uses the mean of `y`;\n- in binary classification, the positive-class rate;\n- in multiclass, a vector of class rates (one per class).\n\nFor each category, it computes these target averages using :term:`cross\nfitting`, meaning that the training data are split into folds: in each fold\nthe averages are calculated only on a subset of data and then applied to the\nheld-out part. 
This way, each sample is encoded using statistics from data it\nwas not part of, preventing information leakage from the target.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.preprocessing import TargetEncoder\n\ntarget_encoder = make_column_transformer(\n (\n TargetEncoder(target_type=\"continuous\", random_state=42),\n make_column_selector(dtype_include=\"category\"),\n ),\n remainder=\"passthrough\",\n)\n\nhist_target = make_pipeline(\n target_encoder, HistGradientBoostingRegressor(random_state=42)\n)\nhist_target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient boosting estimator with native categorical support\nWe now create a :class:`~ensemble.HistGradientBoostingRegressor` estimator\nthat can natively handle categorical features without explicit encoding. Such\nfunctionality can be enabled by setting `categorical_features=\"from_dtype\"`,\nwhich automatically detects features with categorical dtypes, or more explicitly\nby `categorical_features=categorical_columns_subset`.\n\nUnlike previous encoding approaches, the estimator natively deals with the\ncategorical features. At each split, it partitions the categories of such a\nfeature into disjoint sets using a heuristic that sorts them by their effect\non the target variable, see [Split finding with categorical features](https://scikit-learn.org/stable/modules/ensemble.html#split-finding-with-categorical-features)\nfor details.\n\nWhile ordinal encoding may work well for low-cardinality features even if\ncategories have no natural order, reaching meaningful splits requires deeper\ntrees as the cardinality increases. The native categorical support avoids this\nby directly working with unordered categories. 
The advantage over one-hot\nencoding is that no preprocessing is required, and fit and predict times are\ntypically faster.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "hist_native = HistGradientBoostingRegressor(\n random_state=42, categorical_features=\"from_dtype\"\n)\nhist_native" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model comparison\nHere we use :term:`cross validation` to compare the models' performance in\nterms of :func:`~metrics.mean_absolute_percentage_error` and fit times. In the\nupcoming plots, error bars represent 1 standard deviation as computed across\ncross-validation splits.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.model_selection import cross_validate\n\ncommon_params = {\"cv\": 5, \"scoring\": \"neg_mean_absolute_percentage_error\", \"n_jobs\": -1}\n\ndropped_result = cross_validate(hist_dropped, X, y, **common_params)\none_hot_result = cross_validate(hist_one_hot, X, y, **common_params)\nordinal_result = cross_validate(hist_ordinal, X, y, **common_params)\ntarget_result = cross_validate(hist_target, X, y, **common_params)\nnative_result = cross_validate(hist_native, X, y, **common_params)\nresults = [\n (\"Dropped\", dropped_result),\n (\"One Hot\", one_hot_result),\n (\"Ordinal\", ordinal_result),\n (\"Target\", target_result),\n (\"Native\", native_result),\n]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\nimport matplotlib.ticker as ticker\n\n\ndef plot_performance_tradeoff(results, title):\n fig, ax = plt.subplots()\n markers = [\"s\", \"o\", \"^\", \"x\", \"D\"]\n\n for idx, (name, result) in enumerate(results):\n test_error = -result[\"test_score\"]\n mean_fit_time = np.mean(result[\"fit_time\"])\n mean_score = np.mean(test_error)\n std_fit_time = 
np.std(result[\"fit_time\"])\n std_score = np.std(test_error)\n\n ax.scatter(\n result[\"fit_time\"],\n test_error,\n label=name,\n marker=markers[idx],\n )\n ax.scatter(\n mean_fit_time,\n mean_score,\n color=\"k\",\n marker=markers[idx],\n )\n ax.errorbar(\n x=mean_fit_time,\n y=mean_score,\n yerr=std_score,\n c=\"k\",\n capsize=2,\n )\n ax.errorbar(\n x=mean_fit_time,\n y=mean_score,\n xerr=std_fit_time,\n c=\"k\",\n capsize=2,\n )\n\n ax.set_xscale(\"log\")\n\n nticks = 7\n x0, x1 = np.log10(ax.get_xlim())\n ticks = np.logspace(x0, x1, nticks)\n ax.set_xticks(ticks)\n ax.xaxis.set_major_formatter(ticker.FormatStrFormatter(\"%1.1e\"))\n ax.minorticks_off()\n\n ax.annotate(\n \" best\\nmodels\",\n xy=(0.04, 0.04),\n xycoords=\"axes fraction\",\n xytext=(0.09, 0.14),\n textcoords=\"axes fraction\",\n arrowprops=dict(arrowstyle=\"->\", lw=1.5),\n )\n ax.set_xlabel(\"Time to fit (seconds)\")\n ax.set_ylabel(\"Mean Absolute Percentage Error\")\n ax.set_title(title)\n ax.legend()\n plt.show()\n\n\nplot_performance_tradeoff(results, \"Gradient Boosting on Ames Housing\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the plot above, the \"best models\" are those that are closer to the\nlower-left corner, as indicated by the arrow. Those models would indeed\ncorrespond to faster fitting and lower error.\n\nThe model using one-hot encoded data is the slowest. This is to be expected,\nas one-hot encoding creates an additional feature for each category value of\nevery categorical feature, greatly increasing the number of split candidates\nduring training. In theory, we expect the native handling of categorical\nfeatures to be slightly slower than treating categories as ordered quantities\n('Ordinal'), since native handling requires sorting\ncategories. 
Fitting times should however be close when the\nnumber of categories is small, although this difference may not always be\nvisible in practice.\n\nThe time required to fit when using the `TargetEncoder` depends on the\ncross fitting parameter `cv`, as adding splits comes at a computational cost.\n\nIn terms of prediction performance, dropping the categorical features leads to\nthe worst performance. The four models that make use of the categorical\nfeatures have comparable error rates, with a slight edge for the native\nhandling.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Limiting the number of splits\nIn general, one can expect poorer predictions from one-hot-encoded data,\nespecially when the tree depths or the number of nodes are limited: with\none-hot-encoded data, one needs more split points, i.e. more depth, in order\nto recover an equivalent split that could be obtained in one single split\npoint with native handling.\n\nThis is also true when categories are treated as ordinal quantities: if\ncategories are `A..F` and the best split is `ACF - BDE`, the one-hot-encoder\nmodel would need 3 split points (one per category in the left node), and the\nordinal non-native model would need 4 splits: 1 split to isolate `A`, 1 split\nto isolate `F`, and 2 splits to isolate `C` from `BDE`.\n\nHow strongly the models' performances differ in practice depends on the\ndataset and on the flexibility of the trees.\n\nTo see this, let us re-run the same analysis with under-fitting models where\nwe artificially limit the total number of splits by both limiting the number\nof trees and the depth of each tree.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "for pipe in (hist_dropped, hist_one_hot, hist_ordinal, hist_target, hist_native):\n if pipe is hist_native:\n # The native model does not use a pipeline, so we can set the parameters\n # directly.\n pipe.set_params(max_depth=3, max_iter=15)\n 
else:\n pipe.set_params(\n histgradientboostingregressor__max_depth=3,\n histgradientboostingregressor__max_iter=15,\n )\n\ndropped_result = cross_validate(hist_dropped, X, y, **common_params)\none_hot_result = cross_validate(hist_one_hot, X, y, **common_params)\nordinal_result = cross_validate(hist_ordinal, X, y, **common_params)\ntarget_result = cross_validate(hist_target, X, y, **common_params)\nnative_result = cross_validate(hist_native, X, y, **common_params)\nresults_underfit = [\n (\"Dropped\", dropped_result),\n (\"One Hot\", one_hot_result),\n (\"Ordinal\", ordinal_result),\n (\"Target\", target_result),\n (\"Native\", native_result),\n]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_performance_tradeoff(\n results_underfit, \"Gradient Boosting on Ames Housing (few and shallow trees)\"\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results for these underfitting models confirm our previous intuition: the\nnative category handling strategy performs the best when the splitting budget\nis constrained. The three explicit encoding strategies (one-hot, ordinal and\ntarget encoding) lead to slightly larger errors than the estimator's native\nhandling, but still perform better than the baseline model that just dropped\nthe categorical features altogether.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.14" } }, "nbformat": 4, "nbformat_minor": 0 }