.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/semi_supervised/plot_semi_supervised_newsgroups.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code,
        or to run this example in your browser via JupyterLite or Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_semi_supervised_plot_semi_supervised_newsgroups.py:


================================================
Semi-supervised Classification on a Text Dataset
================================================

This example demonstrates the effectiveness of semi-supervised learning
for text classification on :class:`TF-IDF ` features when labeled data
is scarce. For this purpose we compare four different approaches:

1. Supervised learning using 100% of labels in the training set (best-case
   scenario)

   - Uses :class:`~sklearn.linear_model.SGDClassifier` with full supervision
   - Represents the best possible performance when labeled data is abundant

2. Supervised learning using 20% of labels in the training set (baseline)

   - Same model as the best-case scenario but trained on a random 20% subset
     of the labeled training data
   - Shows the performance degradation of a fully supervised model due to
     limited labeled data

3. :class:`~sklearn.semi_supervised.SelfTrainingClassifier` (semi-supervised)

   - Uses 20% labeled data + 80% unlabeled data for training
   - Iteratively predicts labels for unlabeled data
   - Demonstrates how self-training can improve performance

4. :class:`~sklearn.semi_supervised.LabelSpreading` (semi-supervised)

   - Uses 20% labeled data + 80% unlabeled data for training
   - Propagates labels through the data manifold
   - Shows how graph-based methods can leverage unlabeled data

The example uses the 20 newsgroups dataset, focusing on five categories.
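
The ``-1`` convention for unlabeled samples, on which approaches 3 and 4
rely, can be illustrated on toy data. The following is a minimal sketch and
not part of the original example: the two-blob dataset, the
``LogisticRegression`` base estimator and all variable names are assumptions
chosen for illustration only.

.. code-block:: Python

    import numpy as np

    from sklearn.datasets import make_blobs
    from sklearn.linear_model import LogisticRegression
    from sklearn.semi_supervised import SelfTrainingClassifier

    # Toy data (hypothetical, not the newsgroups data): two well-separated blobs
    X_toy, y_toy = make_blobs(
        n_samples=200, centers=[[-3, -3], [3, 3]], random_state=0
    )

    # Keep roughly 20% of the labels and mark all other samples as unlabeled (-1)
    rng = np.random.default_rng(0)
    labeled_mask = rng.random(len(y_toy)) < 0.2
    y_semi = y_toy.copy()
    y_semi[~labeled_mask] = -1

    # The wrapped estimator must implement predict_proba
    self_training = SelfTrainingClassifier(LogisticRegression())
    self_training.fit(X_toy, y_semi)

    # transduction_ holds the labels used for the final fit; samples that never
    # received a pseudo-label keep the value -1
    n_pseudo = int(np.sum((self_training.transduction_ != -1) & ~labeled_mask))
    print(f"Initially unlabeled samples: {int(np.sum(~labeled_mask))}")
    print(f"Samples pseudo-labeled during fit: {n_pseudo}")

Only samples whose predicted probability exceeds the classifier's
``threshold`` are pseudo-labeled in each iteration, so some samples may remain
at ``-1`` after fitting.
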
The results show that self-training achieves better performance than
supervised learning with limited labeled data by making use of the unlabeled
samples, whereas LabelSpreading with its default parameters scores below the
supervised baseline in this run.

.. GENERATED FROM PYTHON SOURCE LINES 41-45

.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause


.. GENERATED FROM PYTHON SOURCE LINES 46-112

.. code-block:: Python


    from sklearn.datasets import fetch_20newsgroups
    from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
    from sklearn.linear_model import SGDClassifier
    from sklearn.metrics import f1_score
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.semi_supervised import LabelSpreading, SelfTrainingClassifier

    # Loading dataset containing first five categories
    data = fetch_20newsgroups(
        subset="train",
        categories=[
            "alt.atheism",
            "comp.graphics",
            "comp.os.ms-windows.misc",
            "comp.sys.ibm.pc.hardware",
            "comp.sys.mac.hardware",
        ],
    )

    # Parameters
    sgd_params = dict(alpha=1e-5, penalty="l2", loss="log_loss")
    vectorizer_params = dict(ngram_range=(1, 2), min_df=5, max_df=0.8)

    # Supervised Pipeline
    pipeline = Pipeline(
        [
            ("vect", CountVectorizer(**vectorizer_params)),
            ("tfidf", TfidfTransformer()),
            ("clf", SGDClassifier(**sgd_params)),
        ]
    )
    # SelfTraining Pipeline
    st_pipeline = Pipeline(
        [
            ("vect", CountVectorizer(**vectorizer_params)),
            ("tfidf", TfidfTransformer()),
            ("clf", SelfTrainingClassifier(SGDClassifier(**sgd_params))),
        ]
    )
    # LabelSpreading Pipeline
    ls_pipeline = Pipeline(
        [
            ("vect", CountVectorizer(**vectorizer_params)),
            ("tfidf", TfidfTransformer()),
            ("clf", LabelSpreading()),
        ]
    )


    def eval_and_get_f1(clf, X_train, y_train, X_test, y_test):
        """Evaluate model performance and return F1 score"""
        print(f" Number of training samples: {len(X_train)}")
        print(f" Unlabeled samples in training set: {sum(1 for x in y_train if x == -1)}")
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)
        f1 = f1_score(y_test, y_pred, average="micro")
        print(f" Micro-averaged F1 score on test set: {f1:.3f}")
        print("\n")
        return f1


    X, y = data.data, data.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)


.. GENERATED FROM PYTHON SOURCE LINES 113-116

1. Evaluate a supervised SGDClassifier using 100% of the (labeled) training set.
This represents the best-case performance when the model has full access to all
labeled examples.

.. GENERATED FROM PYTHON SOURCE LINES 116-123

.. code-block:: Python


    f1_scores = {}
    print("1. Supervised SGDClassifier on 100% of the data:")
    f1_scores["Supervised (100%)"] = eval_and_get_f1(
        pipeline, X_train, y_train, X_test, y_test
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    1. Supervised SGDClassifier on 100% of the data:
     Number of training samples: 2117
     Unlabeled samples in training set: 0
     Micro-averaged F1 score on test set: 0.885


.. GENERATED FROM PYTHON SOURCE LINES 124-127

2. Evaluate a supervised SGDClassifier trained on only 20% of the data.
This serves as a baseline to illustrate the performance drop caused by limiting
the training samples.

.. GENERATED FROM PYTHON SOURCE LINES 127-137

.. code-block:: Python


    import numpy as np

    print("2. Supervised SGDClassifier on 20% of the training data:")
    rng = np.random.default_rng(42)
    y_mask = rng.random(len(y_train)) < 0.2
    # X_20 and y_20 are the subset of the train dataset indicated by the mask
    X_20, y_20 = map(list, zip(*((x, y) for x, y, m in zip(X_train, y_train, y_mask) if m)))
    f1_scores["Supervised (20%)"] = eval_and_get_f1(pipeline, X_20, y_20, X_test, y_test)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2. Supervised SGDClassifier on 20% of the training data:
     Number of training samples: 434
     Unlabeled samples in training set: 0
     Micro-averaged F1 score on test set: 0.725


.. GENERATED FROM PYTHON SOURCE LINES 138-142

3. Evaluate a semi-supervised SelfTrainingClassifier using 20% labeled and 80%
unlabeled data.
The remaining 80% of the training labels are masked as unlabeled (-1),
allowing the model to iteratively label and learn from them.

.. GENERATED FROM PYTHON SOURCE LINES 142-152

.. code-block:: Python


    print(
        "3. SelfTrainingClassifier (semi-supervised) using 20% labeled "
        "+ 80% unlabeled data:"
    )
    y_train_semi = y_train.copy()
    y_train_semi[~y_mask] = -1
    f1_scores["SelfTraining"] = eval_and_get_f1(
        st_pipeline, X_train, y_train_semi, X_test, y_test
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    3. SelfTrainingClassifier (semi-supervised) using 20% labeled + 80% unlabeled data:
     Number of training samples: 2117
     Unlabeled samples in training set: 1683
     Micro-averaged F1 score on test set: 0.823


.. GENERATED FROM PYTHON SOURCE LINES 153-157

4. Evaluate a semi-supervised LabelSpreading model using 20% labeled and 80%
unlabeled data.
Like self-training, label spreading infers labels for the unlabeled portion of
the data with the aim of improving performance.

.. GENERATED FROM PYTHON SOURCE LINES 157-162

.. code-block:: Python


    print("4. LabelSpreading (semi-supervised) using 20% labeled + 80% unlabeled data:")
    f1_scores["LabelSpreading"] = eval_and_get_f1(
        ls_pipeline, X_train, y_train_semi, X_test, y_test
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    4. LabelSpreading (semi-supervised) using 20% labeled + 80% unlabeled data:
     Number of training samples: 2117
     Unlabeled samples in training set: 1683
     Micro-averaged F1 score on test set: 0.649


.. GENERATED FROM PYTHON SOURCE LINES 163-171

Plot results
------------
Visualize the performance of the different classification approaches using a
bar chart. This makes it easy to compare the methods on the micro-averaged
:func:`~sklearn.metrics.f1_score`. Micro-averaging pools the true positives,
false positives and false negatives of all classes before computing the score,
which gives a single overall measure of performance in which every test sample
counts equally. For a single-label multiclass problem such as this one, the
micro-averaged F1 score coincides with accuracy.

.. GENERATED FROM PYTHON SOURCE LINES 171-211

.. code-block:: Python


    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 6))

    models = list(f1_scores.keys())
    scores = list(f1_scores.values())

    colors = ["royalblue", "royalblue", "forestgreen", "royalblue"]
    bars = plt.bar(models, scores, color=colors)

    plt.title("Comparison of Classification Approaches")
    plt.ylabel("Micro-averaged F1 Score on test set")
    plt.xticks()

    for bar in bars:
        height = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{height:.2f}",
            ha="center",
            va="bottom",
        )

    plt.figtext(
        0.5,
        0.02,
        "SelfTraining classifier shows improved performance over "
        "supervised learning with limited data",
        ha="center",
        va="bottom",
        fontsize=10,
        style="italic",
    )

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15)
    plt.show()


.. image-sg:: /auto_examples/semi_supervised/images/sphx_glr_plot_semi_supervised_newsgroups_001.png
   :alt: Comparison of Classification Approaches
   :srcset: /auto_examples/semi_supervised/images/sphx_glr_plot_semi_supervised_newsgroups_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 6.771 seconds)


.. _sphx_glr_download_auto_examples_semi_supervised_plot_semi_supervised_newsgroups.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/semi_supervised/plot_semi_supervised_newsgroups.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/index.html?path=auto_examples/semi_supervised/plot_semi_supervised_newsgroups.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_semi_supervised_newsgroups.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_semi_supervised_newsgroups.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_semi_supervised_newsgroups.zip `


.. include:: plot_semi_supervised_newsgroups.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_