{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Spherical Sliced Wasserstein on distributions in S^2\n",
    "\n",
    "**Note:** Example added in release: 0.8.0.\n",
    "\n",
    "This example illustrates the computation of the spherical sliced Wasserstein discrepancy as\n",
    "proposed in [46].\n",
    "\n",
    "[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). \"Spherical Sliced-Wasserstein\". International Conference on Learning Representations.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Cl\u00e9ment Bonet\n",
    "#\n",
    "# License: MIT License\n",
    "\n",
    "# sphinx_gallery_thumbnail_number = 1\n",
    "\n",
    "import matplotlib.pylab as pl\n",
    "import numpy as np\n",
    "\n",
    "import ot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n = 200  # nb samples\n",
    "\n",
    "# Local, seeded generator so the sampled point clouds (and therefore the\n",
    "# figures below) are identical on every Restart & Run All.\n",
    "rng = np.random.RandomState(0)\n",
    "xs = rng.randn(n, 3)\n",
    "xt = rng.randn(n, 3)\n",
    "\n",
    "# Rescale every sample to unit Euclidean norm so it lies on the sphere\n",
    "xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True))\n",
    "xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True))\n",
    "\n",
    "a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform distribution on samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fig = pl.figure(figsize=(10, 10))\n",
    "ax = pl.axes(projection=\"3d\")\n",
    "ax.grid(False)\n",
    "\n",
    "# Parametrize the unit sphere to draw it as a faint surface plus wireframe\n",
    "u, v = np.mgrid[0 : 2 * np.pi : 30j, 0 : np.pi : 30j]\n",
    "x = np.cos(u) * np.sin(v)\n",
    "y = np.sin(u) * np.sin(v)\n",
    "z = np.cos(v)\n",
    "ax.plot_surface(x, y, z, color=\"gray\", alpha=0.03)\n",
    "ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color=\"gray\")\n",
    "\n",
    "ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label=\"Source\")\n",
    "ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label=\"Target\")\n",
    "\n",
    "fs = 10\n",
    "# Labels\n",
    "ax.set_xlabel(\"x\", fontsize=fs)\n",
    "ax.set_ylabel(\"y\", fontsize=fs)\n",
    "ax.set_zlabel(\"z\", fontsize=fs)\n",
    "\n",
    "ax.view_init(20, 120)\n",
    "ax.set_xlim(-1.5, 1.5)\n",
    "ax.set_ylim(-1.5, 1.5)\n",
    "ax.set_zlim(-1.5, 1.5)\n",
    "\n",
    "# Ticks\n",
    "ax.set_xticks([-1, 0, 1])\n",
    "ax.set_yticks([-1, 0, 1])\n",
    "ax.set_zticks([-1, 0, 1])\n",
    "\n",
    "pl.legend(loc=0)\n",
    "pl.title(\"Source and Target distribution\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spherical Sliced Wasserstein for different seeds and number of projections\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n_seed = 20\n",
    "n_projections_arr = np.logspace(0, 3, 10, dtype=int)\n",
    "# One row per seed, one column per projection count (sized from the array\n",
    "# itself so the two cannot drift apart)\n",
    "res = np.empty((n_seed, len(n_projections_arr)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "for seed in range(n_seed):\n",
    "    for i, n_projections in enumerate(n_projections_arr):\n",
    "        res[seed, i] = ot.sliced_wasserstein_sphere(\n",
    "            xs, xt, a, b, n_projections, seed=seed, p=1\n",
    "        )\n",
    "\n",
    "# Aggregate over seeds (axis 0) for each projection count\n",
    "res_mean = np.mean(res, axis=0)\n",
    "res_std = np.std(res, axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Spherical Sliced Wasserstein\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pl.figure(2)\n",
    "pl.plot(n_projections_arr, res_mean, label=r\"$SSW_1$\")\n",
    "# Shaded band: mean +/- 2 standard deviations across seeds\n",
    "pl.fill_between(\n",
    "    n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5\n",
    ")\n",
    "\n",
    "pl.legend()\n",
    "pl.xscale(\"log\")\n",
    "\n",
    "pl.xlabel(\"Number of projections\")\n",
    "pl.ylabel(\"Distance\")\n",
    "pl.title(\"Spherical Sliced Wasserstein Distance with 95% confidence interval\")\n",
    "\n",
    "pl.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}