{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Sample grouping\n", "In this notebook we present the concept of **sample groups**. We use the\n", "handwritten digits dataset to highlight some surprising results." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_digits\n", "\n", "digits = load_digits()\n", "data, target = digits.data, digits.target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create a model consisting of a logistic regression classifier with a\n", "preprocessor to scale the data.\n", "\n", "
<div class=\"admonition note alert alert-info\">\n", "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n", "<p class=\"last\">Here we use a MinMaxScaler as we know that each pixel's gray-scale is\n", "strictly bounded between 0 (white) and 16 (black). This makes MinMaxScaler\n", "more suited in this case than StandardScaler, as some pixels consistently\n", "have low variance (pixels at the borders might almost always be zero if most\n", "digits are centered in the image). Using StandardScaler can then produce\n", "very high scaled values due to division by a small standard deviation.</p>\n", "</div>" ] },
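{ "cell_type": "markdown", "metadata": {}, "source": [ "We can make this concern concrete with a quick check. The following cell is a\n", "minimal sketch: it prints the smallest non-zero per-pixel standard deviation\n", "and compares the largest absolute values produced by StandardScaler and\n", "MinMaxScaler on this dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n", "\n", "# standard deviation of each of the 64 pixel features\n", "pixel_std = data.std(axis=0)\n", "print(f\"Smallest non-zero pixel std: {pixel_std[pixel_std > 0].min():.3f}\")\n", "\n", "# dividing by such a small std inflates the scaled values, whereas\n", "# MinMaxScaler keeps every feature within [0, 1]\n", "max_standard = np.abs(StandardScaler().fit_transform(data)).max()\n", "max_minmax = np.abs(MinMaxScaler().fit_transform(data)).max()\n", "print(f\"Largest absolute value after StandardScaler: {max_standard:.1f}\")\n", "print(f\"Largest absolute value after MinMaxScaler: {max_minmax:.1f}\")" ] },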
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import MinMaxScaler\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.pipeline import make_pipeline\n", "\n", "model = make_pipeline(MinMaxScaler(), LogisticRegression(max_iter=1_000))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The idea is to compare the estimated generalization performance using\n", "different cross-validation techniques, and to see how these estimates are\n", "affected by the underlying structure of the data. We first use a `KFold`\n", "cross-validation without shuffling the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_score, KFold\n", "\n", "cv = KFold(shuffle=False)\n", "test_score_no_shuffling = cross_val_score(model, data, target, cv=cv, n_jobs=2)\n", "print(\n", "    \"The average accuracy is \"\n", "    f\"{test_score_no_shuffling.mean():.3f} \u00b1 \"\n", "    f\"{test_score_no_shuffling.std():.3f}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's repeat the experiment by shuffling the data within the\n", "cross-validation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cv = KFold(shuffle=True)\n", "test_score_with_shuffling = cross_val_score(\n", "    model, data, target, cv=cv, n_jobs=2\n", ")\n", "print(\n", "    \"The average accuracy is \"\n", "    f\"{test_score_with_shuffling.mean():.3f} \u00b1 \"\n", "    f\"{test_score_with_shuffling.std():.3f}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We observe that shuffling the data improves the mean accuracy. We can go a\n", "little further and plot the distribution of the test scores. For this purpose\n", "we gather the scores from both experiments in a dataframe." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "all_scores = pd.DataFrame(\n", "    [test_score_no_shuffling, test_score_with_shuffling],\n", "    index=[\"KFold without shuffling\", \"KFold with shuffling\"],\n", ").T" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now plot the score distributions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "all_scores.plot.hist(bins=16, edgecolor=\"black\", alpha=0.7)\n", "plt.xlim([0.8, 1.0])\n", "plt.xlabel(\"Accuracy score\")\n", "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n", "_ = plt.title(\"Distribution of the test scores\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Shuffling the data results in a higher cross-validated test accuracy with less\n", "variance than when the data is not shuffled. This means that, without\n", "shuffling, some specific folds lead to a low score." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(test_score_no_shuffling)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Shuffling the data breaks the underlying structure and therefore makes the\n", "classification task easier for our model. To get a better understanding, we\n", "can read the dataset description in more detail:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(digits.DESCR)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we read carefully, `load_digits` loads a copy of the **test set** of the\n", "UCI ML hand-written digits dataset, which consists of 1797 images by\n", "**13 different writers**. Thus, each writer wrote the same digits several\n", "times. Let's suppose the dataset is ordered by writer. Consequently, not\n", "shuffling the data keeps all of a writer's samples together, either in the\n", "training set or in the testing set. Shuffling the data breaks this structure,\n", "and therefore digits written by the same writer end up in both the training\n", "and testing sets.\n", "\n", "Besides, a writer usually tends to write digits in the same manner. Thus, our\n", "model learns to identify each writer's pattern for each digit instead of\n", "recognizing the digit itself.\n", "\n", "We can solve this problem by ensuring that all samples from a given writer\n", "belong either to the training set or to the testing set. In other words, we\n", "want to group the samples by writer.\n", "\n", "We can recover the groups by looking at the target variable." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "target[:200]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It might not be obvious at first, but there is a structure in the target: a\n", "repetitive pattern that always starts with a series of ordered digits from 0\n", "to 9, followed by seemingly random digits. If we look in detail, we see that\n", "there are 14 such patterns, each with around 130 samples.\n", "\n", "Even though this does not exactly match the 13 writers in the documentation\n", "(maybe one writer wrote two series of digits), we can hypothesize that each of\n", "these patterns corresponds to a different writer and thus a different group." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from itertools import count\n", "import numpy as np\n", "\n", "# lower and upper bounds of the sample indices for each writer\n", "writer_boundaries = [\n", "    0,\n", "    130,\n", "    256,\n", "    386,\n", "    516,\n", "    646,\n", "    776,\n", "    915,\n", "    1029,\n", "    1157,\n", "    1287,\n", "    1415,\n", "    1545,\n", "    1667,\n", "    1797,\n", "]\n", "groups = np.zeros_like(target)\n", "lower_bounds = writer_boundaries[:-1]\n", "upper_bounds = writer_boundaries[1:]\n", "\n", "# assign one group id per writer index range\n", "for group_id, lb, up in zip(count(), lower_bounds, upper_bounds):\n", "    groups[lb:up] = group_id" ] },
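{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check of this hypothesis, we can count the number of samples\n", "attributed to each presumed writer; each count should be close to the 130\n", "samples noted above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# number of samples in each presumed writer group\n", "print(np.diff(writer_boundaries))" ] },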
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can check the grouping by plotting the writer index associated with each\n", "sample index." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(groups)\n", "plt.yticks(np.unique(groups))\n", "plt.xticks(writer_boundaries, rotation=90)\n", "plt.xlabel(\"Target index\")\n", "plt.ylabel(\"Writer index\")\n", "_ = plt.title(\"Underlying writer groups existing in the target\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we group the digits by writer, we can incorporate this information into\n", "the cross-validation process by using group-aware variations of the strategies\n", "we have explored in this course, for example, the `GroupKFold` strategy." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import GroupKFold\n", "\n", "cv = GroupKFold()\n", "test_score = cross_val_score(\n", "    model, data, target, groups=groups, cv=cv, n_jobs=2\n", ")\n", "print(\n", "    f\"The average accuracy is {test_score.mean():.3f} \u00b1 {test_score.std():.3f}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that this strategy leads to lower generalization performance than the\n", "other two techniques. However, it is the most reliable estimate if our goal is\n", "to evaluate the model's ability to generalize to new, unseen writers. In\n", "contrast, shuffling the dataset (or, alternatively, using the writers' ids as\n", "a new feature) would lead the model to memorize each writer's particular\n", "handwriting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_scores = pd.DataFrame(\n", "    [test_score_no_shuffling, test_score_with_shuffling, test_score],\n", "    index=[\n", "        \"KFold without shuffling\",\n", "        \"KFold with shuffling\",\n", "        \"KFold with groups\",\n", "    ],\n", ").T" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_scores.plot.hist(bins=16, edgecolor=\"black\", alpha=0.7)\n", "plt.xlim([0.8, 1.0])\n", "plt.xlabel(\"Accuracy score\")\n", "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n", "_ = plt.title(\"Distribution of the test scores\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In conclusion, accounting for any sample grouping patterns is crucial when\n", "assessing a model\u2019s ability to generalize to new groups. Without this\n", "consideration, the results may appear overly optimistic compared to the actual\n", "performance.\n", "\n", "The interested reader can learn about other group-aware cross-validation\n", "techniques in the [scikit-learn user\n", "guide](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data)." ] } ], "metadata": { "jupytext": { "main_language": "python" }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }