{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Evaluate the performance of a classifier with a confusion matrix\n\nThis example uses a confusion matrix to evaluate the quality\nof the output of a classifier on the iris data set. The\ndiagonal elements represent the number of points for which\nthe predicted label is equal to the true label, while the\noff-diagonal elements are those that are mislabeled by the\nclassifier. The higher the diagonal values of the confusion\nmatrix, the better, as they indicate many correct predictions.\n\nThe figures show the confusion matrix with and without\nnormalization by class support size (the number of elements\nin each class). This kind of normalization can be\nuseful in case of class imbalance, as it makes it easier to\nsee which class is being misclassified.\n\nHere the results are not as good as they could be, because our\nchoice of the regularization parameter `C` was not the best.\nIn real-life applications, this parameter is usually chosen\nby a cross-validated hyperparameter search, for example with\n:class:`~sklearn.model_selection.GridSearchCV`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom sklearn import datasets\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.metrics import ConfusionMatrixDisplay\nfrom sklearn.model_selection import train_test_split\n\n# Load the iris data set\niris = datasets.load_iris()\nX = iris.data\ny = iris.target\nclass_names = iris.target_names\n\n# Split the data into a training set and a test set\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n\n# Run classifier, using a model that is too regularized (C too low) to see\n# the impact on the results\nclassifier = LogisticRegression(C=0.01).fit(X_train, y_train)\n\nnp.set_printoptions(precision=2)\n\n# Plot the confusion matrix without normalization, then normalized over the\n# true labels (each row sums to 1)\ntitles_options = [\n    (\"Confusion matrix, without normalization\", None),\n    (\"Normalized confusion matrix\", \"true\"),\n]\nfor title, normalize in titles_options:\n    disp = ConfusionMatrixDisplay.from_estimator(\n        classifier,\n        X_test,\n        y_test,\n        display_labels=class_names,\n        cmap=plt.cm.Blues,\n        normalize=normalize,\n    )\n    disp.ax_.set_title(title)\n\n    print(title)\n    print(disp.confusion_matrix)\n\nplt.show()" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Binary Classification\n\nFor binary classification, call `.ravel()` on the 2x2 array returned by\n:func:`sklearn.metrics.confusion_matrix` to get the counts of true negatives,\nfalse positives, false negatives, and true positives, in that order.\n\nTo obtain these counts at different decision thresholds, use\n:func:`sklearn.metrics.confusion_matrix_at_thresholds`.\nThese threshold-dependent counts underlie binary classification\nmetrics such as :func:`~sklearn.metrics.roc_auc_score` and\n:func:`~sklearn.metrics.det_curve`.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.datasets import make_classification\nfrom sklearn.metrics import confusion_matrix_at_thresholds\n\nX, y = make_classification(\n    n_samples=100,\n    n_features=20,\n    n_informative=20,\n    n_redundant=0,\n    n_classes=2,\n    random_state=42,\n)\n\nX_train, X_test, y_train, y_test = train_test_split(\n    X, y, test_size=0.3, random_state=42\n)\n\nclassifier = LogisticRegression(C=0.01)\nclassifier.fit(X_train, y_train)\n\n# Use the predicted probability of the positive class as the score\ny_score = classifier.predict_proba(X_test)[:, 1]\n\n# Counts of TNs, FPs, FNs and TPs for each threshold\ntns, fps, fns, tps, thresholds = confusion_matrix_at_thresholds(y_test, y_score)\n\n# Plot TNs, FPs, FNs and TPs vs Thresholds\nplt.figure(figsize=(10, 6))\n\nplt.plot(thresholds, tns, label=\"True Negatives (TNs)\")\nplt.plot(thresholds, fps, label=\"False Positives (FPs)\")\nplt.plot(thresholds, fns, label=\"False Negatives (FNs)\")\nplt.plot(thresholds, tps, 
label=\"True Positives (TPs)\")\nplt.xlabel(\"Thresholds\")\nplt.ylabel(\"Count\")\nplt.title(\"TNs, FPs, FNs and TPs vs Thresholds\")\nplt.legend()\nplt.grid()\n\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.14" } }, "nbformat": 4, "nbformat_minor": 0 }