{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Low rank Sinkhorn\n\n

**Note:** Example added in release: 0.9.2.

\n\nThis example illustrates the computation of Low Rank Sinkhorn [65].\n\n[65] Scetbon, M., Cuturi, M., & Peyr\u00e9, G. (2021).\n\"Low-rank Sinkhorn factorization\". In International Conference on Machine Learning.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Laur\u00e8ne David \n#\n# License: MIT License\n#\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pylab as pl\n\n# Importing the ot.plot submodule also binds the top-level `ot` package used below.\nimport ot.plot\nfrom ot.datasets import make_1D_gauss as gauss" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Number of source (n) and target (m) bins\nn = 100\nm = 120\n\n# Source weights a: sum of two 1D Gaussian bumps, normalized to sum to 1\na = gauss(n, m=int(n / 3), s=25 / np.sqrt(2)) + 1.5 * gauss(\n    n, m=int(5 * n / 6), s=15 / np.sqrt(2)\n)\na = a / np.sum(a)\n\n# Target weights b: built the same way on m bins\nb = 2 * gauss(m, m=int(m / 5), s=30 / np.sqrt(2)) + gauss(\n    m, m=int(m / 2), s=35 / np.sqrt(2)\n)\nb = b / np.sum(b)\n\n# Source and target sample positions, as column vectors of shape (n, 1) and (m, 1)\nX = np.arange(n).reshape(-1, 1)\nY = np.arange(m).reshape(-1, 1)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Solve Low rank sinkhorn\n\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Solve low rank sinkhorn with rank 10 and plot the resulting OT matrix.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "Q, R, g, log = ot.lowrank_sinkhorn(\n    X,\n    Y,\n    a,\n    b,\n    rank=10,\n    init=\"random\",\n    gamma_init=\"rescale\",\n    rescale_cost=True,\n    warn=False,\n    log=True,\n)\n\n# The plan is read from the log; [:] is used to get it as a full matrix for plotting\n# (presumably the stored \"lazy_plan\" is a lazy object -- see the POT LazyTensor docs).\nP = log[\"lazy_plan\"][:]\n\not.plot.plot1D_mat(a, b, P, \"OT matrix Low rank\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Sinkhorn vs Low Rank Sinkhorn\n\nCompare Sinkhorn and Low rank sinkhorn with different regularizations and ranks.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Compute cost matrix for sinkhorn OT, rescaled so that its largest entry is 1\nM = ot.dist(X, Y)\nM = M / np.max(M)\n\n# Solve sinkhorn with different regularizations using ot.solve\nlist_reg = [0.05, 0.005, 0.001]\nlist_P_Sin = []\n\nfor reg in list_reg:\n    P = ot.solve(M, a, b, reg=reg, max_iter=2000, tol=1e-8).plan\n    list_P_Sin.append(P)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Solve low rank sinkhorn with different ranks using ot.solve_sample\nlist_rank = [3, 10, 50]\nlist_P_LR = []\n\nfor rank in list_rank:\n    P = ot.solve_sample(X, Y, a, b, method=\"lowrank\", rank=rank).plan\n    P = P[:]\n    list_P_LR.append(P)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Plot sinkhorn vs low rank sinkhorn\n# Top row: Sinkhorn plans, one per regularization in list_reg.\n# Bottom row: low rank plans, one per rank in list_rank.\n# Titles are derived from the lists so they cannot drift from the values actually solved.\npanels = [(plan, f\"Sinkhorn (reg={reg})\") for plan, reg in zip(list_P_Sin, list_reg)]\npanels += [(plan, f\"Low rank (rank={rank})\") for plan, rank in zip(list_P_LR, list_rank)]\n\npl.figure(1, figsize=(10, 8))\n\nfor idx, (plan, title) in enumerate(panels, start=1):\n    pl.subplot(2, 3, idx)\n    pl.imshow(plan, interpolation=\"nearest\", cmap=\"gray_r\")\n    pl.axis(\"off\")\n    pl.title(title)\n\npl.tight_layout()\n\n# Show only once, after all six panels are drawn: calling pl.show() between the two\n# rows would render the figure early in a notebook and push the bottom row onto a new figure.\npl.show()" ] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python",
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }