{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Sample grouping\n", "In this notebook, we take a closer look at the concept of sample groups. As in\n", "the previous section, we will give an example to highlight some surprising\n", "results. This time, we will use the handwritten digits dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_digits\n", "\n", "digits = load_digits()\n", "data, target = digits.data, digits.target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will recreate the same model used in the previous notebook: a logistic\n", "regression classifier with a preprocessor to scale the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import MinMaxScaler\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.pipeline import make_pipeline\n", "\n", "model = make_pipeline(MinMaxScaler(), LogisticRegression(max_iter=1_000))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first evaluate this model with a `KFold` cross-validation, without\n", "shuffling the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_score, KFold\n", "\n", "cv = KFold(shuffle=False)\n", "test_score_no_shuffling = cross_val_score(model, data, target, cv=cv, n_jobs=2)\n", "print(\n", "    \"The average accuracy is \"\n", "    f\"{test_score_no_shuffling.mean():.3f} \u00b1 \"\n", "    f\"{test_score_no_shuffling.std():.3f}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's repeat the experiment by shuffling the data within the\n", "cross-validation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cv = KFold(shuffle=True)\n", "test_score_with_shuffling = cross_val_score(\n", "    model, data, target, cv=cv, n_jobs=2\n", ")\n", "print(\n", "    \"The average accuracy is \"\n", "    f\"{test_score_with_shuffling.mean():.3f} \u00b1 \"\n", "    f\"{test_score_with_shuffling.std():.3f}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We observe that shuffling the data improves the mean accuracy. We can go a\n", "little further and plot the distribution of the test scores. We first\n", "concatenate the test scores." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "all_scores = pd.DataFrame(\n", "    [test_score_no_shuffling, test_score_with_shuffling],\n", "    index=[\"KFold without shuffling\", \"KFold with shuffling\"],\n", ").T" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the distribution now." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "all_scores.plot.hist(bins=10, edgecolor=\"black\", alpha=0.7)\n", "plt.xlim([0.8, 1.0])\n", "plt.xlabel(\"Accuracy score\")\n", "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n", "_ = plt.title(\"Distribution of the test scores\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cross-validation test scores obtained with shuffling have less variance\n", "than those obtained without shuffling. This means that, without shuffling,\n", "some specific folds lead to a low score." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(test_score_no_shuffling)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thus, there is an underlying structure in the data that shuffling breaks,\n", "which leads to better results. 
To get a better understanding, we should read the\n", "documentation shipped with the dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(digits.DESCR)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we read carefully, 13 writers wrote the digits of our dataset, accounting\n", "for a total of 1797 samples. Thus, each writer wrote the same digits several\n", "times. Let's suppose that the samples from each writer are grouped together in\n", "the dataset. In that case, not shuffling the data will tend to keep the\n", "samples of a given writer together, either in the training or the testing set.\n", "Shuffling the data will break this structure, and therefore digits written by\n", "the same writer will be available in both the training and testing sets.\n", "\n", "Besides, a writer usually tends to write digits in the same manner. Thus, our\n", "model can learn to identify a writer's pattern for each digit instead of\n", "recognizing the digit itself.\n", "\n", "We can solve this problem by ensuring that the data associated with a writer\n", "belongs either to the training set or to the testing set, but never to both.\n", "Thus, we want to group samples by writer.\n", "\n", "We can recover the groups by looking at the target variable." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "target[:200]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "It might not be obvious at first, but there is a structure in the target:\n", "there is a repetitive pattern that always starts with a series of ordered\n", "digits from 0 to 9, followed by random digits at a certain point. 
If we look in\n", "detail, we see that there are 14 such patterns, always with around 130 samples\n", "each.\n", "\n", "Even though this does not exactly match the 13 writers in the documentation\n", "(maybe one writer wrote two series of digits), we can make the hypothesis that\n", "each of these patterns corresponds to a different writer and thus a different\n", "group." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from itertools import count\n", "import numpy as np\n", "\n", "# define the lower and upper bounds of the sample indices\n", "# for each writer\n", "writer_boundaries = [\n", "    0,\n", "    130,\n", "    256,\n", "    386,\n", "    516,\n", "    646,\n", "    776,\n", "    915,\n", "    1029,\n", "    1157,\n", "    1287,\n", "    1415,\n", "    1545,\n", "    1667,\n", "    1797,\n", "]\n", "groups = np.zeros_like(target)\n", "lower_bounds = writer_boundaries[:-1]\n", "upper_bounds = writer_boundaries[1:]\n", "\n", "# assign the same group id to all samples between two consecutive boundaries\n", "for group_id, lb, ub in zip(count(), lower_bounds, upper_bounds):\n", "    groups[lb:ub] = group_id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can check the grouping by plotting the writer id assigned to each sample\n", "index." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(groups)\n", "plt.yticks(np.unique(groups))\n", "plt.xticks(writer_boundaries, rotation=90)\n", "plt.xlabel(\"Target index\")\n", "plt.ylabel(\"Writer index\")\n", "_ = plt.title(\"Underlying writer groups existing in the target\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we group the digits by writer, we can use a cross-validation strategy\n", "that takes this information into account: a class whose name contains `Group`\n", "should be used." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import GroupKFold\n", "\n", "cv = GroupKFold()\n", "test_score = cross_val_score(\n", "    model, data, target, groups=groups, cv=cv, n_jobs=2\n", ")\n", "print(\n", "    f\"The average accuracy is {test_score.mean():.3f} \u00b1 {test_score.std():.3f}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that this strategy is less optimistic regarding the model's\n", "generalization performance. However, it gives the most reliable estimate if\n", "our goal is to make handwritten digit recognition independent of the writer.\n", "Besides, we can also see that the standard deviation is reduced." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_scores = pd.DataFrame(\n", "    [test_score_no_shuffling, test_score_with_shuffling, test_score],\n", "    index=[\n", "        \"KFold without shuffling\",\n", "        \"KFold with shuffling\",\n", "        \"KFold with groups\",\n", "    ],\n", ").T" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_scores.plot.hist(bins=10, edgecolor=\"black\", alpha=0.7)\n", "plt.xlim([0.8, 1.0])\n", "plt.xlabel(\"Accuracy score\")\n", "plt.legend(bbox_to_anchor=(1.05, 0.8), loc=\"upper left\")\n", "_ = plt.title(\"Distribution of the test scores\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In conclusion, it is really important to take any sample grouping pattern\n", "into account when evaluating a model. Otherwise, the results obtained will be\n", "over-optimistic with respect to reality." ] } ], "metadata": { "jupytext": { "main_language": "python" }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }