{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Nystr\u00f6m approximation for OT\n\nShows how to use the Nystr\u00f6m kernel approximation to approximate the Sinkhorn algorithm in linear time.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Titouan Vayer \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nfrom ot.lowrank import kernel_nystroem, sinkhorn_low_rank_kernel\nfrom ot.bregman import empirical_sinkhorn_nystroem\nimport math\nimport ot\nimport matplotlib.pyplot as plt\nfrom matplotlib.colors import LogNorm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "offset = 1\nn_samples_per_blob = 500  # number of samples in each 2D Gaussian blob\nrandom_state = 42\nstd = 0.2  # standard deviation of each blob\nnp.random.seed(random_state)\n\ncenters = np.array(\n    [\n        [-offset, -offset],  # Class 0 - blob 1\n        [-offset, offset],  # Class 0 - blob 2\n        [offset, -offset],  # Class 1 - blob 1\n        [offset, offset],  # Class 1 - blob 2\n    ]\n)\n\nX_list = []\ny_list = []\n\nfor i, center in enumerate(centers):\n    blob_points = np.random.randn(n_samples_per_blob, 2) * std + center\n    label = 0 if i < 2 else 1\n    X_list.append(blob_points)\n    y_list.append(np.full(n_samples_per_blob, label))\n\nX = np.vstack(X_list)\ny = np.concatenate(y_list)\nXs = X[y == 0]  # source data (class 0)\nXt = X[y == 1]  # target data (class 1)\n\n# each distribution is made of two blobs of 2D points\nassert Xs.shape == Xt.shape == (2 * n_samples_per_blob, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.scatter(Xs[:, 0], Xs[:, 1], label=\"Source\")\nplt.scatter(Xt[:, 0], Xt[:, 1], label=\"Target\")\nplt.title(\"Source and target samples\")\nplt.legend()\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute the Nystr\u00f6m approximation of the Gaussian kernel\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "reg = 5.0  # entropic regularization, proportional to the variance of the Gaussian kernel (sigma**2 = reg / 2)\nanchors = 10  # number of anchor points for the Nystr\u00f6m approximation\not.tic()\nleft_factor, right_factor = kernel_nystroem(\n    Xs, Xt, anchors=anchors, sigma=math.sqrt(reg / 2.0), random_state=random_state\n)\not.toc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use this approximation in a Sinkhorn algorithm with a low-rank kernel.\nEach matrix/vector product in the Sinkhorn iterations is accelerated\nsince $Kv = K_1 (K_2^\\top v)$ can be computed in $O(nr)$ time,\nwhere $r$ is the rank of the factorization $K \\approx K_1 K_2^\\top$,\ninstead of $O(n^2)$.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "numItermax = 1000\nstopThr = 1e-7\nverbose = True\na, b = None, None\nwarn = True\nwarmstart = None\not.tic()\nu, v, dict_log = sinkhorn_low_rank_kernel(\n    K1=left_factor,\n    K2=right_factor,\n    a=a,\n    b=b,\n    numItermax=numItermax,\n    stopThr=stopThr,\n    verbose=verbose,\n    log=True,\n    warn=warn,\n    warmstart=warmstart,\n)\not.toc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compare with Sinkhorn\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "M = ot.dist(Xs, Xt)\not.tic()\nG, log_ = ot.sinkhorn(\n    a=[],\n    b=[],\n    M=M,\n    reg=reg,\n    numItermax=numItermax,\n    stopThr=stopThr,  # same stopping threshold as the low-rank run, for a fair comparison\n    verbose=verbose,\n    log=True,\n    warn=warn,\n    warmstart=warmstart,\n)\not.toc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use ot.bregman.empirical_sinkhorn_nystroem directly\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ot.tic()\nG_nys = empirical_sinkhorn_nystroem(\n    Xs,\n    Xt,\n    anchors=anchors,\n    reg=reg,\n    numItermax=numItermax,\n    verbose=verbose,\n    random_state=random_state,\n)[:]  # NOTE(review): [:] presumably densifies a lazy low-rank plan into an array\not.toc()" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ot.tic()\nG_sinkh = ot.bregman.empirical_sinkhorn(\n    Xs, Xt, reg=reg, numIterMax=numItermax, verbose=verbose\n)\not.toc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compare OT plans\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig, ax = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)\n# LogNorm requires vmin > 0: use the smallest strictly positive entry so that\n# non-positive entries (if any) are masked instead of breaking the normalization\nvmin = min(G_sinkh[G_sinkh > 0].min(), G_nys[G_nys > 0].min())\nvmax = max(G_sinkh.max(), G_nys.max())\nnorm = LogNorm(vmin=vmin, vmax=vmax)\nax[0].imshow(G_sinkh, norm=norm, cmap=\"coolwarm\")\nim = ax[1].imshow(G_nys, norm=norm, cmap=\"coolwarm\")\nfig.colorbar(im, ax=ax, orientation=\"vertical\", fraction=0.046, pad=0.04)\nax[0].set_title(\"OT plan Sinkhorn\")\nax[1].set_title(\"OT plan Nystr\u00f6m Sinkhorn\")\n# dedicated loop name so the variable `a` defined earlier is not overwritten\nfor axis in ax:\n    axis.set_xticks([])\n    axis.set_yticks([])\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }