{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Sparse coding with a precomputed dictionary\n",
    "\n",
    "Transform a signal as a sparse combination of Ricker wavelets. This example\n",
    "visually compares different sparse coding methods using the\n",
    "`sklearn.decomposition.SparseCoder` estimator. The Ricker wavelet (also known\n",
    "as the Mexican hat, or the second derivative of a Gaussian) is not a\n",
    "particularly good kernel to represent piecewise constant signals like this\n",
    "one. The example therefore shows how much adding atoms of different widths\n",
    "matters, which motivates learning the dictionary to best fit your type of\n",
    "signals.\n",
    "\n",
    "The richer dictionary on the right is not larger in size: heavier subsampling\n",
    "is performed in order to stay on the same order of magnitude.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: The scikit-learn developers\n",
    "# SPDX-License-Identifier: BSD-3-Clause\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.decomposition import SparseCoder\n",
    "\n",
    "\n",
    "def ricker_function(resolution, center, width):\n",
    "    \"\"\"Sample a Ricker (Mexican hat) wavelet on the grid 0..resolution-1.\n",
    "\n",
    "    The wavelet is centered at ``center`` and scaled by ``width``; the\n",
    "    returned array has ``resolution`` entries.\n",
    "    \"\"\"\n",
    "    grid = np.linspace(0, resolution - 1, resolution)\n",
    "    squared_offset = (grid - center) ** 2\n",
    "    amplitude = 2 / (np.sqrt(3 * width) * np.pi**0.25)\n",
    "    return (\n",
    "        amplitude\n",
    "        * (1 - squared_offset / width**2)\n",
    "        * np.exp(-squared_offset / (2 * width**2))\n",
    "    )\n",
    "\n",
    "\n",
    "def ricker_matrix(width, resolution, n_components):\n",
    "    \"\"\"Build a dictionary of unit-norm Ricker wavelets, one atom per row.\n",
    "\n",
    "    The ``n_components`` atoms share the same ``width`` and have centers\n",
    "    evenly spaced over the grid 0..resolution-1.\n",
    "    \"\"\"\n",
    "    atoms = np.empty((n_components, resolution))\n",
    "    centers = np.linspace(0, resolution - 1, n_components)\n",
    "    for row, center in zip(atoms, centers):\n",
    "        # ``row`` is a view into ``atoms``, so this fills the matrix in place.\n",
    "        row[:] = ricker_function(resolution, center, width)\n",
    "    # Rescale every atom (row) to unit Euclidean norm.\n",
    "    atoms /= np.sqrt(np.sum(atoms**2, axis=1, keepdims=True))\n",
    "    return atoms\n",
    "\n",
    "\n",
    "resolution = 1024\n",
    "subsampling = 3  # subsampling factor\n",
    "width = 100\n",
    "n_components = resolution // subsampling\n",
    "\n",
    "# Compute a wavelet dictionary\n",
    "D_fixed = ricker_matrix(width=width, resolution=resolution, n_components=n_components)\n",
    "# Stack five smaller dictionaries, one per width, on top of each other.\n",
    "D_multi = np.concatenate(\n",
    "    [\n",
    "        ricker_matrix(width=w, resolution=resolution, n_components=n_components // 5)\n",
    "        for w in (10, 50, 100, 500, 1000)\n",
    "    ],\n",
    "    axis=0,\n",
    ")\n",
    "\n",
    "# Generate a piecewise constant signal: 3 on the first quarter, -1 elsewhere\n",
    "first_quarter = np.arange(resolution) < resolution / 4\n",
    "y = np.where(first_quarter, 3.0, -1.0)\n",
    "\n",
    "# List the different sparse coding methods in the following format:\n",
    "# (title, transform_algorithm, transform_alpha,\n",
    "#  transform_n_nonzero_coefs, color)\n",
    "estimators = [\n",
    "    (\"OMP\", \"omp\", None, 15, \"navy\"),\n",
    "    (\"Lasso\", \"lasso_lars\", 2, None, \"turquoise\"),\n",
    "]\n",
    "lw = 2\n",
    "\n",
    "dictionaries = [(\"fixed width\", D_fixed), (\"multiple widths\", D_multi)]\n",
    "signal_row = y.reshape(1, -1)  # SparseCoder expects a 2D (n_samples, n_features) input\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(13, 6))\n",
    "for ax, (dict_name, D) in zip(axes, dictionaries):\n",
    "    ax.set_title(\"Sparse coding against %s dictionary\" % dict_name)\n",
    "    ax.plot(y, lw=lw, linestyle=\"--\", label=\"Original signal\")\n",
    "\n",
    "    # Do a wavelet approximation\n",
    "    for method, algo, alpha, n_nonzero, color in estimators:\n",
    "        coder = SparseCoder(\n",
    "            dictionary=D,\n",
    "            transform_algorithm=algo,\n",
    "            transform_alpha=alpha,\n",
    "            transform_n_nonzero_coefs=n_nonzero,\n",
    "        )\n",
    "        code = coder.transform(signal_row)\n",
    "        density = np.count_nonzero(code)\n",
    "        reconstruction = np.ravel(np.dot(code, D))\n",
    "        squared_error = np.sum((y - reconstruction) ** 2)\n",
    "        ax.plot(\n",
    "            reconstruction,\n",
    "            color=color,\n",
    "            lw=lw,\n",
    "            label=\"%s: %s nonzero coefs,\\n%.2f error\"\n",
    "            % (method, density, squared_error),\n",
    "        )\n",
    "\n",
    "    # Soft thresholding debiasing\n",
    "    coder = SparseCoder(\n",
    "        dictionary=D, transform_algorithm=\"threshold\", transform_alpha=20\n",
    "    )\n",
    "    code = coder.transform(signal_row)\n",
    "    support = np.flatnonzero(code[0])\n",
    "    # Debias: refit the coefficients of the selected atoms by least squares.\n",
    "    code[0, support] = np.linalg.lstsq(D[support].T, y, rcond=None)[0]\n",
    "    reconstruction = np.ravel(np.dot(code, D))\n",
    "    squared_error = np.sum((y - reconstruction) ** 2)\n",
    "    ax.plot(\n",
    "        reconstruction,\n",
    "        color=\"darkorange\",\n",
    "        lw=lw,\n",
    "        label=\"Thresholding w/ debiasing:\\n%d nonzero coefs, %.2f error\"\n",
    "        % (len(support), squared_error),\n",
    "    )\n",
    "    ax.axis(\"tight\")\n",
    "    ax.legend(shadow=False, loc=\"best\")\n",
    "\n",
    "fig.subplots_adjust(\n",
    "    left=0.04, bottom=0.07, right=0.97, top=0.90, wspace=0.09, hspace=0.2\n",
    ")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}