{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Sliced Wasserstein Distance on 2D distributions\n\n

Note: Example added in release: 0.8.0.

\n\nThis example illustrates the computation of the sliced Wasserstein distance as\nproposed in [31].\n\n[31] Bonneel, Nicolas, et al. \"Sliced and Radon Wasserstein barycenters of\nmeasures.\" Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Adrien Corenflos\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport matplotlib.pylab as pl\nimport numpy as np\n\nimport ot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n = 200  # number of samples\n\n# source: isotropic Gaussian centered at the origin\nmu_s = np.array([0, 0])\ncov_s = np.array([[1, 0], [0, 1]])\n\n# target: negatively correlated Gaussian centered at (4, 4)\nmu_t = np.array([4, 4])\ncov_t = np.array([[1, -0.8], [-0.8, 1]])\n\nxs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)\nxt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)\n\na, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform distribution on samples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(1)\npl.plot(xs[:, 0], xs[:, 1], \"+b\", label=\"Source samples\")\npl.plot(xt[:, 0], xt[:, 1], \"xr\", label=\"Target samples\")\npl.legend(loc=0)\npl.title(\"Source and target distributions\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sliced Wasserstein distance for different seeds and numbers of projections\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_seed = 20\n# 10 log-spaced projection counts between 1 and 1000\nn_projections_arr = np.logspace(0, 3, 10, dtype=int)\nres = np.empty((n_seed, len(n_projections_arr)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# estimate the distance for every seed and every number of projections\nfor seed in range(n_seed):\n    for i, n_projections in enumerate(n_projections_arr):\n        res[seed, i] = ot.sliced_wasserstein_distance(\n            xs, xt, a, b, n_projections, seed=seed\n        )\n\n# mean and standard deviation over seeds\nres_mean = np.mean(res, axis=0)\nres_std = np.std(res, axis=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot Sliced Wasserstein Distance\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(2)\npl.plot(n_projections_arr, res_mean, label=\"SWD\")\n# shaded band: mean +/- 2 standard deviations over seeds\npl.fill_between(\n    n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5\n)\n\npl.legend()\npl.xscale(\"log\")\n\npl.xlabel(\"Number of projections\")\npl.ylabel(\"Distance\")\npl.title(\"Sliced Wasserstein Distance, mean +/- 2 std over seeds (approx. 95% band)\")\n\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }