{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Optimal Transport with different ground metrics\n",
    "\n",
    "2D OT on empirical distributions with different ground metrics.\n",
    "\n",
    "Figure idea borrowed from Figs. 1 and 2 of\n",
    "https://arxiv.org/pdf/1706.07650.pdf\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Remi Flamary \n",
    "#\n",
    "# License: MIT License\n",
    "\n",
    "# sphinx_gallery_thumbnail_number = 3\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pylab as pl\n",
    "import ot\n",
    "import ot.plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset 1 : uniform sampling\n",
    "\n",
    "Source samples are spread along the x-axis and target samples along the\n",
    "y-axis, all with uniform weights.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n = 20  # nb samples\n",
    "\n",
    "# Source samples: x = 1..n, with a tiny negative y offset\n",
    "xs = np.zeros((n, 2))\n",
    "xs[:, 0] = np.arange(n) + 1\n",
    "xs[:, 1] = (np.arange(n) + 1) * -0.001  # to make it strictly convex...\n",
    "\n",
    "# Target samples: y = 1..n, x = 0\n",
    "xt = np.zeros((n, 2))\n",
    "xt[:, 1] = np.arange(n) + 1\n",
    "\n",
    "a, b = ot.unif(n), ot.unif(n)  # uniform distribution on samples\n",
    "\n",
    "# Euclidean loss matrix, rescaled so that its largest entry is 1\n",
    "M1 = ot.dist(xs, xt, metric=\"euclidean\")\n",
    "M1 /= M1.max()\n",
    "\n",
    "# Squared Euclidean loss matrix, same rescaling\n",
    "M2 = ot.dist(xs, xt, metric=\"sqeuclidean\")\n",
    "M2 /= M2.max()\n",
    "\n",
    "# L1 (cityblock) loss matrix, same rescaling\n",
    "Mp = ot.dist(xs, xt, metric=\"cityblock\")\n",
    "Mp /= Mp.max()\n",
    "\n",
    "# Data\n",
    "pl.figure(1, figsize=(7, 3))\n",
    "pl.clf()\n",
    "pl.plot(xs[:, 0], xs[:, 1], \"+b\", label=\"Source samples\")\n",
    "pl.plot(xt[:, 0], xt[:, 1], \"xr\", label=\"Target samples\")\n",
    "pl.axis(\"equal\")\n",
    "pl.title(\"Source and target distributions\")\n",
    "\n",
    "# Cost matrices, one panel per ground metric\n",
    "pl.figure(2, figsize=(7, 3))\n",
    "cost_panels = [\n",
    "    (M1, \"Euclidean cost\"),\n",
    "    (M2, \"Squared Euclidean cost\"),\n",
    "    (Mp, \"L1 (cityblock) cost\"),\n",
    "]\n",
    "for panel, (M, title) in enumerate(cost_panels, start=1):\n",
    "    pl.subplot(1, 3, panel)\n",
    "    pl.imshow(M, interpolation=\"nearest\")\n",
    "    pl.title(title)\n",
    "pl.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset 1 : Plot OT Matrices\n",
    "\n",
    "Solve the exact OT problem once per cost matrix and draw the resulting plans.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Exact OT plan (EMD) for each ground metric\n",
    "G1 = ot.emd(a, b, M1)\n",
    "G2 = ot.emd(a, b, M2)\n",
    "Gp = ot.emd(a, b, Mp)\n",
    "\n",
    "# OT matrices, one panel per ground metric\n",
    "pl.figure(3, figsize=(7, 3))\n",
    "ot_panels = [\n",
    "    (G1, \"OT Euclidean\"),\n",
    "    (G2, \"OT squared Euclidean\"),\n",
    "    (Gp, \"OT L1 (cityblock)\"),\n",
    "]\n",
    "for panel, (G, title) in enumerate(ot_panels, start=1):\n",
    "    pl.subplot(1, 3, panel)\n",
    "    # Plan first, then the samples on top of it\n",
    "    ot.plot.plot2D_samples_mat(xs, xt, G, c=[0.5, 0.5, 1])\n",
    "    pl.plot(xs[:, 0], xs[:, 1], \"+b\", label=\"Source samples\")\n",
    "    pl.plot(xt[:, 0], xt[:, 1], \"xr\", label=\"Target samples\")\n",
    "    pl.axis(\"equal\")\n",
    "    pl.title(title)\n",
    "pl.tight_layout()\n",
    "\n",
    "pl.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset 2 : Partial circle\n",
    "\n",
    "Samples lie on an arc of the unit circle; the target samples are the source\n",
    "samples shifted by one position along the arc.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n = 20  # nb samples\n",
    "\n",
    "# n + 1 points on an arc of the unit circle (angles computed once)\n",
    "theta = (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi\n",
    "xtot = np.zeros((n + 1, 2))\n",
    "xtot[:, 0] = np.cos(theta)\n",
    "xtot[:, 1] = np.sin(theta)\n",
    "\n",
    "# Source = first n points, target = last n points\n",
    "xs = xtot[:n, :]\n",
    "xt = xtot[1:, :]\n",
    "\n",
    "a, b = ot.unif(n), ot.unif(n)  # uniform distribution on samples\n",
    "\n",
    "# Euclidean loss matrix, rescaled so that its largest entry is 1\n",
    "M1 = ot.dist(xs, xt, metric=\"euclidean\")\n",
    "M1 /= M1.max()\n",
    "\n",
    "# Squared Euclidean loss matrix, same rescaling\n",
    "M2 = ot.dist(xs, xt, metric=\"sqeuclidean\")\n",
    "M2 /= M2.max()\n",
    "\n",
    "# L1 (cityblock) loss matrix, same rescaling\n",
    "Mp = ot.dist(xs, xt, metric=\"cityblock\")\n",
    "Mp /= Mp.max()\n",
    "\n",
    "# Data\n",
    "pl.figure(4, figsize=(7, 3))\n",
    "pl.clf()\n",
    "pl.plot(xs[:, 0], xs[:, 1], \"+b\", label=\"Source samples\")\n",
    "pl.plot(xt[:, 0], xt[:, 1], \"xr\", label=\"Target samples\")\n",
    "pl.axis(\"equal\")\n",
    "pl.title(\"Source and target distributions\")\n",
    "\n",
    "# Cost matrices, one panel per ground metric\n",
    "pl.figure(5, figsize=(7, 3))\n",
    "cost_panels = [\n",
    "    (M1, \"Euclidean cost\"),\n",
    "    (M2, \"Squared Euclidean cost\"),\n",
    "    (Mp, \"L1 (cityblock) cost\"),\n",
    "]\n",
    "for panel, (M, title) in enumerate(cost_panels, start=1):\n",
    "    pl.subplot(1, 3, panel)\n",
    "    pl.imshow(M, interpolation=\"nearest\")\n",
    "    pl.title(title)\n",
    "pl.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset 2 : Plot OT Matrices\n",
    "\n",
    "Same comparison as for Dataset 1, on the partial-circle samples.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Exact OT plan (EMD) for each ground metric\n",
    "G1 = ot.emd(a, b, M1)\n",
    "G2 = ot.emd(a, b, M2)\n",
    "Gp = ot.emd(a, b, Mp)\n",
    "\n",
    "# OT matrices, one panel per ground metric\n",
    "pl.figure(6, figsize=(7, 3))\n",
    "ot_panels = [\n",
    "    (G1, \"OT Euclidean\"),\n",
    "    (G2, \"OT squared Euclidean\"),\n",
    "    (Gp, \"OT L1 (cityblock)\"),\n",
    "]\n",
    "for panel, (G, title) in enumerate(ot_panels, start=1):\n",
    "    pl.subplot(1, 3, panel)\n",
    "    # Plan first, then the samples on top of it\n",
    "    ot.plot.plot2D_samples_mat(xs, xt, G, c=[0.5, 0.5, 1])\n",
    "    pl.plot(xs[:, 0], xs[:, 1], \"+b\", label=\"Source samples\")\n",
    "    pl.plot(xt[:, 0], xt[:, 1], \"xr\", label=\"Target samples\")\n",
    "    pl.axis(\"equal\")\n",
    "    pl.title(title)\n",
    "pl.tight_layout()\n",
    "\n",
    "pl.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}