{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Restricted Boltzmann Machine features for digit classification\n\nFor greyscale image data where pixel values can be interpreted as degrees of\nblackness on a white background, like handwritten digit recognition, the\nBernoulli Restricted Boltzmann machine model (:class:`BernoulliRBM\n`) can perform effective non-linear\nfeature extraction.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\nIn order to learn good latent representations from a small dataset, we\nartificially generate more labeled data by perturbing the training data with\nlinear shifts of 1 pixel in each direction.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\nfrom scipy.ndimage import convolve\n\nfrom sklearn import datasets\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.preprocessing import minmax_scale\n\n\ndef nudge_dataset(X, Y):\n \"\"\"\n This produces a dataset 5 times bigger than the original one,\n by moving the 8x8 images in X around by 1px to left, right, down, up\n \"\"\"\n direction_vectors = [\n [[0, 1, 0], [0, 0, 0], [0, 0, 0]],\n [[0, 0, 0], [1, 0, 0], [0, 0, 0]],\n [[0, 0, 0], [0, 0, 1], [0, 0, 0]],\n [[0, 0, 0], [0, 0, 0], [0, 1, 0]],\n ]\n\n def shift(x, w):\n return convolve(x.reshape((8, 8)), mode=\"constant\", weights=w).ravel()\n\n X = np.concatenate(\n [X] + [np.apply_along_axis(shift, 1, X, vector) for vector in direction_vectors]\n )\n Y = np.concatenate([Y for _ in range(5)], axis=0)\n return X, Y\n\n\nX, y = datasets.load_digits(return_X_y=True)\nX = np.asarray(X, \"float32\")\nX, Y = nudge_dataset(X, y)\nX = minmax_scale(X, feature_range=(0, 1)) # 0-1 scaling\n\nX_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Models definition\n\nWe build a classification pipeline with a BernoulliRBM feature extractor and\na :class:`LogisticRegression `\nclassifier.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn import linear_model\nfrom sklearn.neural_network import BernoulliRBM\nfrom sklearn.pipeline import Pipeline\n\nlogistic = linear_model.LogisticRegression(solver=\"newton-cg\", tol=1)\nrbm = BernoulliRBM(random_state=0, verbose=True)\n\nrbm_features_classifier = Pipeline(steps=[(\"rbm\", rbm), (\"logistic\", logistic)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training\n\nThe hyperparameters of the entire model (learning rate, hidden layer size,\nregularization) were optimized by grid search, but the search is not\nreproduced here because of runtime constraints.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.base import clone\n\n# Hyper-parameters. These were set by cross-validation,\n# using a GridSearchCV. 
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.base import clone\n\n# Hyper-parameters. These were set by cross-validation using a GridSearchCV;\n# the full search is not rerun here to save time.\nrbm.learning_rate = 0.06\nrbm.n_iter = 10\n\n# More components tend to give better prediction performance, but also a\n# longer fitting time.\nrbm.n_components = 100\nlogistic.C = 6000\n\n# Training the RBM-Logistic Pipeline\nrbm_features_classifier.fit(X_train, Y_train)\n\n# Training the Logistic regression classifier directly on the pixel values\nraw_pixel_classifier = clone(logistic)\nraw_pixel_classifier.C = 100.0\nraw_pixel_classifier.fit(X_train, Y_train)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn import metrics\n\nY_pred = rbm_features_classifier.predict(X_test)\nprint(\n    \"Logistic regression using RBM features:\\n%s\\n\"\n    % (metrics.classification_report(Y_test, Y_pred))\n)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "Y_pred = raw_pixel_classifier.predict(X_test)\nprint(\n    \"Logistic regression using raw pixel features:\\n%s\\n\"\n    % (metrics.classification_report(Y_test, Y_pred))\n)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The features extracted by the BernoulliRBM help improve the classification\naccuracy compared to logistic regression on the raw pixels.\n\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting\n\nEach of the 100 RBM components is a 64-dimensional weight vector that can be\nreshaped into an 8x8 image and interpreted as a learned feature detector.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n\nplt.figure(figsize=(4.2, 4))\nfor i, comp in enumerate(rbm.components_):\n    plt.subplot(10, 10, i + 1)\n    plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r, interpolation=\"nearest\")\n    plt.xticks(())\n    plt.yticks(())\nplt.suptitle(\"100 components extracted by RBM\", fontsize=16)\nplt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)\n\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }