{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Multilabel classification\n\nThis example simulates a multi-label document classification problem. The\ndataset is generated randomly based on the following process:\n\n- pick the number of labels: n ~ Poisson(n_labels)\n- n times, choose a class c: c ~ Multinomial(theta)\n- pick the document length: k ~ Poisson(length)\n- k times, choose a word: w ~ Multinomial(theta_c)\n\nIn the above process, rejection sampling is used to make sure that n is never\nmore than 2 (the number of classes) and, when unlabeled samples are not allowed,\nnever zero, and that the document length is never zero. Likewise, we reject\nclasses which have already been chosen. The documents that are assigned to both\nclasses are plotted surrounded by two colored circles.\n\nFor visualisation purposes, the data is first projected onto the first two\ncomponents found by PCA and by CCA. The\n:class:`~sklearn.multiclass.OneVsRestClassifier` metaclassifier then fits two\nSVCs with linear kernels, learning one discriminative model per class.\nNote that PCA is used to perform an unsupervised dimensionality reduction,\nwhile CCA is used to perform a supervised one.\n\nNote: in the plot, \"unlabeled samples\" does not mean that we don't know the\nlabels (as in semi-supervised learning) but that the samples simply do *not*\nhave a label.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom sklearn.cross_decomposition import CCA\nfrom sklearn.datasets import make_multilabel_classification\nfrom sklearn.decomposition import PCA\nfrom sklearn.multiclass import OneVsRestClassifier\nfrom sklearn.svm import SVC\n\n\ndef plot_hyperplane(clf, min_x, max_x, linestyle, label):\n    # get the separating hyperplane: w[0] * x + w[1] * y + intercept = 0\n    w = clf.coef_[0]\n    a = -w[0] / w[1]\n    xx = np.linspace(min_x - 5, max_x + 5)  # make sure the line is long enough\n    yy = a * xx - (clf.intercept_[0]) / w[1]\n    plt.plot(xx, yy, linestyle, label=label)\n\n\ndef plot_subfigure(X, Y, subplot, title, transform):\n    # project the data to 2D: PCA ignores Y, CCA uses it\n    if transform == \"pca\":\n        X = PCA(n_components=2).fit_transform(X)\n    elif transform == \"cca\":\n        X = CCA(n_components=2).fit(X, Y).transform(X)\n    else:\n        raise ValueError\n\n    min_x = np.min(X[:, 0])\n    max_x = np.max(X[:, 0])\n\n    min_y = np.min(X[:, 1])\n    max_y = np.max(X[:, 1])\n\n    # fit one linear SVC per class\n    classif = OneVsRestClassifier(SVC(kernel=\"linear\"))\n    classif.fit(X, Y)\n\n    plt.subplot(2, 2, subplot)\n    plt.title(title)\n\n    zero_class = np.where(Y[:, 0])\n    one_class = np.where(Y[:, 1])\n    plt.scatter(X[:, 0], X[:, 1], s=40, c=\"gray\", edgecolors=(0, 0, 0))\n    plt.scatter(\n        X[zero_class, 0],\n        X[zero_class, 1],\n        s=160,\n        edgecolors=\"b\",\n        facecolors=\"none\",\n        linewidths=2,\n        label=\"Class 1\",\n    )\n    plt.scatter(\n        X[one_class, 0],\n        X[one_class, 1],\n        s=80,\n        edgecolors=\"orange\",\n        facecolors=\"none\",\n        linewidths=2,\n        label=\"Class 2\",\n    )\n\n    plot_hyperplane(\n        classif.estimators_[0], min_x, max_x, \"k--\", \"Boundary\\nfor class 1\"\n    )\n    plot_hyperplane(\n        classif.estimators_[1], min_x, max_x, \"k-.\", \"Boundary\\nfor class 2\"\n    )\n    plt.xticks(())\n    plt.yticks(())\n\n    plt.xlim(min_x - 0.5 * max_x, max_x + 0.5 * max_x)\n    plt.ylim(min_y - 0.5 * max_y, max_y + 0.5 * max_y)\n    if subplot == 2:\n        plt.xlabel(\"First principal component\")\n        plt.ylabel(\"Second principal component\")\n        plt.legend(loc=\"upper left\")\n\n\nplt.figure(figsize=(8, 6))\n\nX, Y = make_multilabel_classification(\n    n_classes=2, n_labels=1, allow_unlabeled=True, random_state=1\n)\n\nplot_subfigure(X, Y, 1, \"With unlabeled samples + CCA\", \"cca\")\nplot_subfigure(X, Y, 2, \"With unlabeled samples + PCA\", \"pca\")\n\nX, Y = make_multilabel_classification(\n    n_classes=2, n_labels=1, allow_unlabeled=False, random_state=1\n)\n\nplot_subfigure(X, Y, 3, \"Without unlabeled samples + CCA\", \"cca\")\nplot_subfigure(X, Y, 4, \"Without unlabeled samples + PCA\", \"pca\")\n\nplt.subplots_adjust(0.04, 0.02, 0.97, 0.94, 0.09, 0.2)\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.21"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}