{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Wasserstein unmixing with PyTorch\n\n
Note: Example added in release 0.8.0.
\n\nIn this example we estimate the mixing parameters of a set of distributions by\nminimizing a Wasserstein distance. In other words, we suppose that a target\ndistribution $\\mu^t$ can be expressed as a weighted sum of source\ndistributions $\\mu^s_k$ with the following model:\n\n\\begin{align}\\mu^t = \\sum_{k=1}^K w_k\\mu^s_k\\end{align}\n\nwhere $\\mathbf{w}$ is a vector of size $K$ that belongs to the\nprobability simplex $\\Delta_K$.\n\nTo estimate this weight vector, we minimize the Wasserstein\ndistance between the model and the observed $\\mu^t$ with respect to\n$\\mathbf{w}$. This leads to the following optimization problem:\n\n\\begin{align}\\min_{\\mathbf{w}\\in\\Delta_K} \\quad W \\left(\\mu^t,\\sum_{k=1}^K w_k\\mu^s_k\\right)\\end{align}\n\nIn this example the minimization is done with a simple projected gradient\ndescent in PyTorch. We use the automatic backend of POT, which lets us\ncompute the Wasserstein distance with `ot.emd2` as a\ndifferentiable loss.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Remi Flamary\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "nt = 100\nnt1 = 10  # target samples centered at (1, 0), like source 1\n\nns1 = 50\nns = 2 * ns1\n\nrng = np.random.RandomState(2)\n\nxt = rng.randn(nt, 2) * 0.2\nxt[:nt1, 0] += 1\nxt[nt1:, 1] += 1\n\n\nxs1 = rng.randn(ns1, 2) * 0.2\nxs1[:, 0] += 1\nxs2 = rng.randn(ns1, 2) * 0.2\nxs2[:, 1] += 1\n\nxs = np.concatenate((xs1, xs2))\n\n# Sample reweighting matrix H\nH = np.zeros((ns, 2))\nH[:ns1, 0] = 1 / ns1\nH[ns1:, 1] = 1 / ns1\n# each column sums to 1 and has weights only for samples from the\n# corresponding source distribution\n\nM = ot.dist(xs, xt)  # squared Euclidean cost matrix (default metric)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(1)\npl.scatter(xt[:, 0], xt[:, 1], label=r\"Target $\\mu^t$\", alpha=0.5)\npl.scatter(xs1[:, 0], xs1[:, 1], label=r\"Source $\\mu^s_1$\", alpha=0.5)\npl.scatter(xs2[:, 0], xs2[:, 1], label=r\"Source $\\mu^s_2$\", alpha=0.5)\npl.title(\"Source and target distributions\")\npl.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optimization of the model with respect to the Wasserstein distance\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# convert numpy arrays to torch tensors\nH2 = torch.tensor(H)\nM2 = torch.tensor(M)\n\n# weights for the source distributions\nw = torch.tensor(ot.unif(2), requires_grad=True)\n\n# uniform weights for target\nb = torch.tensor(ot.unif(nt))\n\nlr = 2e-3  # learning rate\nniter = 500  # number of iterations\nlosses = []  # loss along the iterations\n\n# loss for the minimal Wasserstein estimator\n\n\ndef get_loss(w):\n    a = torch.mv(H2, w)  # distribution reweighting\n    return ot.emd2(a, b, M2)  # squared 2-Wasserstein distance\n\n\nfor i in range(niter):\n    loss = get_loss(w)\n    losses.append(float(loss))\n\n    loss.backward()\n\n    with torch.no_grad():\n        w -= lr * w.grad  # gradient step\n        w[:] = ot.utils.proj_simplex(w)  # projection on the simplex\n\n    w.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimated weights and convergence of the objective\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "we = w.detach().numpy()\nprint(\"Estimated mixture:\", we)\n\npl.figure(2)\npl.semilogy(losses)\npl.grid()\npl.title(\"Wasserstein distance\")\npl.xlabel(\"Iterations\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting the reweighted source distribution\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(3)\n\n# compute source weights\nws = H.dot(we)\n\npl.scatter(xt[:, 0], xt[:, 1], label=r\"Target $\\mu^t$\", alpha=0.5)\npl.scatter(\n    xs[:, 0],\n    xs[:, 1],\n    color=\"C3\",\n    s=ws * 20 * ns,\n    label=r\"Weighted sources $\\sum_{k} w_k\\mu^s_k$\",\n    alpha=0.5,\n)\npl.title(\"Target and reweighted source distributions\")\npl.legend()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }