{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Image denoising using dictionary learning\n\nAn example comparing the effect of reconstructing noisy fragments\nof a raccoon face image by first learning a dictionary online with\n`MiniBatchDictionaryLearning` and then applying various transform methods.\n\nThe dictionary is fitted on the distorted left half of the image, and\nsubsequently used to reconstruct the right half. Note that even better\nperformance could be achieved by fitting to an undistorted (i.e.\nnoiseless) image, but here we start from the assumption that it is not\navailable.\n\nA common practice for evaluating the results of image denoising is to look\nat the difference between the reconstruction and the original image. If the\nreconstruction is perfect, this will look like Gaussian noise.\n\nIt can be seen from the plots that the result of `omp` with two\nnon-zero coefficients is a bit less biased than when keeping only one\n(the edges look less prominent). It is also closer to the ground\ntruth in Frobenius norm.\n\nThe result of `least_angle_regression` is much more strongly biased: the\ndifference is reminiscent of the local intensity value of the original image.\n\nThresholding is clearly not useful for denoising, but it is here to show that\nit can produce a suggestive output with very high speed, and thus be useful\nfor other tasks such as object classification, where performance is not\nnecessarily related to visualisation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate distorted image\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n\ntry: # SciPy >= 1.10\n from scipy.datasets import face\nexcept ImportError:\n from scipy.misc import 
face\n\nraccoon_face = face(gray=True)\n\n# Convert from uint8 representation with values between 0 and 255 to\n# a floating point representation with values between 0 and 1.\nraccoon_face = raccoon_face / 255.0\n\n# downsample for higher speed\nraccoon_face = (\n raccoon_face[::4, ::4]\n + raccoon_face[1::4, ::4]\n + raccoon_face[::4, 1::4]\n + raccoon_face[1::4, 1::4]\n)\nraccoon_face /= 4.0\nheight, width = raccoon_face.shape\n\n# Distort the right half of the image\nprint(\"Distorting image...\")\ndistorted = raccoon_face.copy()\ndistorted[:, width // 2 :] += 0.075 * np.random.randn(height, width // 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Display the distorted image\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n\n\ndef show_with_diff(image, reference, title):\n \"\"\"Helper function to display denoising\"\"\"\n plt.figure(figsize=(5, 3.3))\n plt.subplot(1, 2, 1)\n plt.title(\"Image\")\n plt.imshow(image, vmin=0, vmax=1, cmap=plt.cm.gray, interpolation=\"nearest\")\n plt.xticks(())\n plt.yticks(())\n plt.subplot(1, 2, 2)\n difference = image - reference\n\n plt.title(\"Difference (norm: %.2f)\" % np.sqrt(np.sum(difference**2)))\n plt.imshow(\n difference, vmin=-0.5, vmax=0.5, cmap=plt.cm.PuOr, interpolation=\"nearest\"\n )\n plt.xticks(())\n plt.yticks(())\n plt.suptitle(title, size=16)\n plt.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2)\n\n\nshow_with_diff(distorted, raccoon_face, \"Distorted image\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Extract reference patches\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from time import time\n\nfrom sklearn.feature_extraction.image import extract_patches_2d\n\n# Extract all reference patches from the left half of the image\nprint(\"Extracting reference patches...\")\nt0 = 
time()\npatch_size = (7, 7)\ndata = extract_patches_2d(distorted[:, : width // 2], patch_size)\ndata = data.reshape(data.shape[0], -1)\ndata -= np.mean(data, axis=0)\ndata /= np.std(data, axis=0)\nprint(f\"{data.shape[0]} patches extracted in {time() - t0:.2f}s.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learn the dictionary from reference patches\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.decomposition import MiniBatchDictionaryLearning\n\nprint(\"Learning the dictionary...\")\nt0 = time()\ndico = MiniBatchDictionaryLearning(\n # increase to 300 for higher quality results at the cost of slower\n # training times.\n n_components=50,\n batch_size=200,\n alpha=1.0,\n max_iter=10,\n)\nV = dico.fit(data).components_\ndt = time() - t0\nprint(f\"{dico.n_iter_} iterations / {dico.n_steps_} steps in {dt:.2f}s.\")\n\nplt.figure(figsize=(4.2, 4))\nfor i, comp in enumerate(V[:100]):\n plt.subplot(10, 10, i + 1)\n plt.imshow(comp.reshape(patch_size), cmap=plt.cm.gray_r, interpolation=\"nearest\")\n plt.xticks(())\n plt.yticks(())\nplt.suptitle(\n \"Dictionary learned from face patches\\n\"\n + \"Train time %.1fs on %d patches\" % (dt, len(data)),\n fontsize=16,\n)\nplt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Extract noisy patches and reconstruct them using the dictionary\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.feature_extraction.image import reconstruct_from_patches_2d\n\nprint(\"Extracting noisy patches... 
\")\nt0 = time()\ndata = extract_patches_2d(distorted[:, width // 2 :], patch_size)\ndata = data.reshape(data.shape[0], -1)\nintercept = np.mean(data, axis=0)\ndata -= intercept\nprint(\"done in %.2fs.\" % (time() - t0))\n\ntransform_algorithms = [\n (\"Orthogonal Matching Pursuit\\n1 atom\", \"omp\", {\"transform_n_nonzero_coefs\": 1}),\n (\"Orthogonal Matching Pursuit\\n2 atoms\", \"omp\", {\"transform_n_nonzero_coefs\": 2}),\n (\"Least-angle regression\\n4 atoms\", \"lars\", {\"transform_n_nonzero_coefs\": 4}),\n (\"Thresholding\\n alpha=0.1\", \"threshold\", {\"transform_alpha\": 0.1}),\n]\n\nreconstructions = {}\nfor title, transform_algorithm, kwargs in transform_algorithms:\n print(title + \"...\")\n reconstructions[title] = raccoon_face.copy()\n t0 = time()\n dico.set_params(transform_algorithm=transform_algorithm, **kwargs)\n code = dico.transform(data)\n patches = np.dot(code, V)\n\n patches += intercept\n patches = patches.reshape(len(data), *patch_size)\n if transform_algorithm == \"threshold\":\n patches -= patches.min()\n patches /= patches.max()\n reconstructions[title][:, width // 2 :] = reconstruct_from_patches_2d(\n patches, (height, width // 2)\n )\n dt = time() - t0\n print(\"done in %.2fs.\" % dt)\n show_with_diff(reconstructions[title], raccoon_face, title + \" (time: %.1fs)\" % dt)\n\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }