{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n",
        "# Spherical Sliced-Wasserstein Embedding on Sphere\n",
        "\n",
        "Here, we aim at transforming samples into a uniform\n",
        "distribution on the sphere by minimizing SSW:\n",
        "\n",
        "\\begin{align}\\min_{x} SSW_2(\\nu, \\frac{1}{n}\\sum_{i=1}^n \\delta_{x_i})\\end{align}\n",
        "\n",
        "where $\\nu=\\mathrm{Unif}(S^{d-1})$.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Author: Cl\u00e9ment Bonet\n",
        "#\n",
        "# License: MIT License\n",
        "\n",
        "# sphinx_gallery_thumbnail_number = 3\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as pl\n",
        "import matplotlib.animation as animation\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "import ot"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Data generation\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "torch.manual_seed(1)\n",
        "\n",
        "# Draw N points uniformly in [0, 1)^3 and project them onto the unit sphere,\n",
        "# so the initial samples only cover the non-negative orthant.\n",
        "N = 500\n",
        "x0 = torch.rand(N, 3)\n",
        "x0 = F.normalize(x0, dim=-1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Plot data\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def plot_sphere(ax):\n",
        "    \"\"\"Draw a gray wireframe of the unit sphere on the 3D axes ``ax``.\"\"\"\n",
        "    # Angular grid: phi in [0, 2*pi], theta in [0, pi]\n",
        "    phi = np.linspace(0, 2 * np.pi, 100)\n",
        "    theta = np.linspace(0, np.pi, 100)\n",
        "    phi, theta = np.meshgrid(phi, theta)\n",
        "\n",
        "    # Convert the spherical coordinates to Cartesian coordinates\n",
        "    X = np.sin(theta) * np.cos(phi)\n",
        "    Y = np.sin(theta) * np.sin(phi)\n",
        "    Z = np.cos(theta)\n",
        "\n",
        "    # Plot the wireframe\n",
        "    ax.plot_wireframe(X, Y, Z, color=\"gray\", alpha=0.3)\n",
        "\n",
        "\n",
        "# plot the distributions\n",
        "pl.figure(1)\n",
        "ax = pl.axes(projection=\"3d\")\n",
        "plot_sphere(ax)\n",
        "ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label=\"Data samples\", alpha=0.5)\n",
        "ax.set_title(\"Data distribution\")\n",
        "ax.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Gradient descent\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x = x0.clone()\n",
        "x.requires_grad_(True)\n",
        "\n",
        "n_iter = 100\n",
        "lr = 150\n",
        "\n",
        "losses = []\n",
        "xvisu = torch.zeros(n_iter, N, 3)\n",
        "\n",
        "for i in range(n_iter):\n",
        "    sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)\n",
        "    grad_x = torch.autograd.grad(sw, x)[0]\n",
        "\n",
        "    # Gradient step with a step size decaying as 1 / sqrt(i / 10 + 1) ...\n",
        "    x = x - lr * grad_x / np.sqrt(i / 10 + 1)\n",
        "    # ... then projection back onto the sphere. The result is detached so that\n",
        "    # the autograd graph does not keep growing from one iteration to the next.\n",
        "    x = F.normalize(x, p=2, dim=1).detach().requires_grad_(True)\n",
        "\n",
        "    losses.append(sw.item())\n",
        "    xvisu[i, :, :] = x.detach().clone()\n",
        "\n",
        "    if i % 100 == 0:\n",
        "        print(\"Iter: {:3d}, loss={}\".format(i, losses[-1]))\n",
        "\n",
        "# Figure 1 already holds the 3D data scatter; use a separate figure for the loss.\n",
        "pl.figure(2)\n",
        "pl.semilogy(losses)\n",
        "pl.grid()\n",
        "pl.title(\"SSW\")\n",
        "pl.xlabel(\"Iterations\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Plot trajectories of generated samples along iterations\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80]\n",
        "\n",
        "# One 3D panel per selected iteration, laid out on a 3 x 3 grid\n",
        "fig = pl.figure(3, (10, 10))\n",
        "for k, it in enumerate(ivisu):\n",
        "    ax = fig.add_subplot(3, 3, k + 1, projection=\"3d\")\n",
        "    plot_sphere(ax)\n",
        "    ax.scatter(\n",
        "        xvisu[it, :, 0],\n",
        "        xvisu[it, :, 1],\n",
        "        xvisu[it, :, 2],\n",
        "        label=\"Data samples\",\n",
        "        alpha=0.5,\n",
        "    )\n",
        "    ax.set_title(\"Iter. {}\".format(it))\n",
        "    if k == 0:\n",
        "        ax.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Animate trajectories of generated samples along iterations\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pl.figure(4, (8, 8))\n",
        "\n",
        "\n",
        "def _draw_frame(it):\n",
        "    \"\"\"Draw the samples stored in ``xvisu[it]`` on new 3D axes.\"\"\"\n",
        "    ax = pl.axes(projection=\"3d\")\n",
        "    plot_sphere(ax)\n",
        "    ax.scatter(\n",
        "        xvisu[it, :, 0],\n",
        "        xvisu[it, :, 1],\n",
        "        xvisu[it, :, 2],\n",
        "        label=\"Data samples\",\n",
        "        alpha=0.5,\n",
        "    )\n",
        "    ax.axis(\"off\")\n",
        "    ax.set_xlim((-1.5, 1.5))\n",
        "    ax.set_ylim((-1.5, 1.5))\n",
        "    ax.set_title(\"Iter. {}\".format(it))\n",
        "\n",
        "\n",
        "def _update_plot(i):\n",
        "    \"\"\"Animation callback: clear the figure and show iteration ``3 * i``.\"\"\"\n",
        "    pl.clf()\n",
        "    _draw_frame(3 * i)\n",
        "    return 1\n",
        "\n",
        "\n",
        "print(xvisu.shape)\n",
        "\n",
        "# Initial frame\n",
        "_draw_frame(0)\n",
        "\n",
        "ani = animation.FuncAnimation(\n",
        "    pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.18"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}