{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "Unbalanced Optimal Transport\n", "=========================================\n", "\n", "*Important:* Please read the [installation page](http://gpeyre.github.io/numerical-tours/installation_python/) for details about how to install the toolboxes.\n", "$\\newcommand{\\dotp}[2]{\\langle #1, #2 \\rangle}$\n", "$\\newcommand{\\enscond}[2]{\\lbrace #1, #2 \\rbrace}$\n", "$\\newcommand{\\pd}[2]{ \\frac{ \\partial #1}{\\partial #2} }$\n", "$\\newcommand{\\umin}[1]{\\underset{#1}{\\min}\\;}$\n", "$\\newcommand{\\umax}[1]{\\underset{#1}{\\max}\\;}$\n", "$\\newcommand{\\umin}[1]{\\underset{#1}{\\min}\\;}$\n", "$\\newcommand{\\uargmin}[1]{\\underset{#1}{argmin}\\;}$\n", "$\\newcommand{\\norm}[1]{\\|#1\\|}$\n", "$\\newcommand{\\abs}[1]{\\left|#1\\right|}$\n", "$\\newcommand{\\choice}[1]{ \\left\\{ \\begin{array}{l} #1 \\end{array} \\right. }$\n", "$\\newcommand{\\pa}[1]{\\left(#1\\right)}$\n", "$\\newcommand{\\diag}[1]{{diag}\\left( #1 \\right)}$\n", "$\\newcommand{\\qandq}{\\quad\\text{and}\\quad}$\n", "$\\newcommand{\\qwhereq}{\\quad\\text{where}\\quad}$\n", "$\\newcommand{\\qifq}{ \\quad \\text{if} \\quad }$\n", "$\\newcommand{\\qarrq}{ \\quad \\Longrightarrow \\quad }$\n", "$\\newcommand{\\ZZ}{\\mathbb{Z}}$\n", "$\\newcommand{\\CC}{\\mathbb{C}}$\n", "$\\newcommand{\\RR}{\\mathbb{R}}$\n", "$\\newcommand{\\EE}{\\mathbb{E}}$\n", "$\\newcommand{\\Zz}{\\mathcal{Z}}$\n", "$\\newcommand{\\Ww}{\\mathcal{W}}$\n", "$\\newcommand{\\Vv}{\\mathcal{V}}$\n", "$\\newcommand{\\Nn}{\\mathcal{N}}$\n", "$\\newcommand{\\NN}{\\mathcal{N}}$\n", "$\\newcommand{\\Hh}{\\mathcal{H}}$\n", "$\\newcommand{\\Bb}{\\mathcal{B}}$\n", "$\\newcommand{\\Ee}{\\mathcal{E}}$\n", "$\\newcommand{\\Cc}{\\mathcal{C}}$\n", "$\\newcommand{\\Gg}{\\mathcal{G}}$\n", "$\\newcommand{\\Ss}{\\mathcal{S}}$\n", "$\\newcommand{\\Pp}{\\mathcal{P}}$\n", "$\\newcommand{\\Ff}{\\mathcal{F}}$\n", "$\\newcommand{\\Xx}{\\mathcal{X}}$\n", 
"$\\newcommand{\\Mm}{\\mathcal{M}}$\n", "$\\newcommand{\\Ii}{\\mathcal{I}}$\n", "$\\newcommand{\\Dd}{\\mathcal{D}}$\n", "$\\newcommand{\\Ll}{\\mathcal{L}}$\n", "$\\newcommand{\\Tt}{\\mathcal{T}}$\n", "$\\newcommand{\\si}{\\sigma}$\n", "$\\newcommand{\\al}{\\alpha}$\n", "$\\newcommand{\\la}{\\lambda}$\n", "$\\newcommand{\\ga}{\\gamma}$\n", "$\\newcommand{\\Ga}{\\Gamma}$\n", "$\\newcommand{\\La}{\\Lambda}$\n", "$\\newcommand{\\si}{\\sigma}$\n", "$\\newcommand{\\Si}{\\Sigma}$\n", "$\\newcommand{\\be}{\\beta}$\n", "$\\newcommand{\\de}{\\delta}$\n", "$\\newcommand{\\De}{\\Delta}$\n", "$\\newcommand{\\phi}{\\varphi}$\n", "$\\newcommand{\\th}{\\theta}$\n", "$\\newcommand{\\om}{\\omega}$\n", "$\\newcommand{\\Om}{\\Omega}$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This numerical tour details how to perform \"unbalanced\" OT, which allows one to compare measures with different total mass, and also leads to more regular transportation plans by enabling creation/destruction of mass. This extension of classicl (\"balanced\") OT is crucial for applications of OT to imaging sciences and machine learning, since it makes OT more robust to noise and outliers. The modification with respect to the usual OT is very minor, since it corresponds a penalization of the mass conservation constraint.\n", "\n", "The original idea can be found in the paper of [Matthias Liero, Alexander Mielke and Giuseppe Savaré](https://arxiv.org/abs/1508.07941). The entropic regularized version with the corresponding Sinkhorn's algorithm can be found in the paper of [Lenaic Chizat, Gabriel Peyré, Bernhard Schmitzer and François-Xavier Vialard](https://arxiv.org/abs/1607.05816)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You need to install [CVXPY](https://www.cvxpy.org/). _Warning:_ seems to not be working on Python 3.7, use rather 3.6." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import cvxpy as cp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Definition of the input measures\n", "------------------------------------------\n", "\n", "For the sake of concreteness and to ease the display, we consider the transport between 1D distributions. But this can be applied to any OT problem.\n", "\n", "We consider two dicretes distributions\n", "$$ \\sum_{i=1}^n a_i \\de_{x_i} \\qandq \n", " \\sum_{j=1}^m b_j \\de_{y_j}, $$\n", "where $n,m$ are the number of points, $\\de_x$ is the Dirac at\n", "location $x$, and $(x_i)_i, (y_j)_j$ are the positions of the diracs (in some metric space, in the following we consider the space to be $\\RR$)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "n = int(120)\n", "m = int(110)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We consider two Gaussian measures sampled on a 1-D grid." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "Gaussian = lambda t0,sigma,N: np.exp(-(np.arange(0,N)/N-t0)**2/(2*sigma**2))\n", "normalize = lambda p: p/np.sum(p)\n", "sigma = .06;\n", "a = Gaussian(.25,sigma,n)\n", "b = Gaussian(.8,sigma,m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add some minimal mass and normalize. Here we do not use the same total mass for $a$ and $b$." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "vmin = .01;\n", "a = 0.95*normalize( a+np.max(a)*vmin)\n", "b = 1.05*normalize( b+np.max(b)*vmin)\n", "x = np.arange(0,n)/n\n", "y = np.arange(0,m)/m" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the histograms." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEtxJREFUeJzt3X+MXedd5/H3B5skgoUUOUZCScwYxUXrFATlEkDi53rpOpWoQWt2HUAEZGEVCH/wOxUqW4L4I6AlEmpWxatEhKyWpEQCRhDkPzZdfqk1GZP+cpClqSlkCBJO7fUqlDR1+90/7nEzezuTe2bmzsy993m/pCufH8+d+zxzz/3Mc557zuNUFZKkNnzBbldAkrRzDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQ/budgVG3XLLLbWwsLDb1ZCkmXLu3LmXqmr/uHJTF/oLCwssLS3tdjUkaaYk+fs+5RzekaSGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ3pFfpJjia5kGQ5yf1r7L8xyZPd/rNJFlbt+9ok709yPslHktw0uepLkjZibOgn2QM8DNwNHAbuSXJ4pNhJ4EpV3QE8BDzYPXcv8D+At1fVncB3Ap+eWO0lSRvSp6d/F7BcVRer6lXgCeDYSJljwGPd8lPAkSQB3gJ8uKo+BFBVn6iqz0ym6pKkjeoT+rcCL6xaX+m2rVmmqq4BV4F9wBuBSnImyd8k+YW1XiDJqSRLSZYuXbq00TZImkfJaw9NTJ/QX+s3Xj3L7AW+FfjB7t/vS3Lk8wpWna6qQVUN9u8fO3WEJGmT+oT+CnD7qvXbgBfXK9ON498MXO62/1lVvVRVnwSeBt681UpLmlPr9e7t8U9Mn9B/FjiU5GCSG4ATwOJImUXg3m75OPBMVRVwBvjaJF/U/TH4DuD5yVRdkrRRY2fZrKprSe5jGOB7gEer6nySB4ClqloEHgEeT7LMsId/onvulSS/yfAPRwFPV9WfbFNbJEljZNghnx6DwaCcWllq1LghnCnLq2mS5FxVDcaV845cSWqIoS9JDTH0JakhU/ffJUpqzEYuxVxd1vH9TbGnL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0JakhTsMgaXds9X/CckqGTbGnL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0JakhvUI/ydEkF5IsJ7l/jf03Jnmy2382yUK3fSHJvyb5YPd4z2SrL0naiLFz7yTZAzwMfDewAjybZLGqnl9V7CRwparuSHICeBD4z92+j1XV10243pKkTejT078LWK6qi1X1KvAEcGykzDHgsW75KeBIstXZlCRJk9Yn9G8FXli1vtJtW7NMVV0DrgL7un0HkzyX5M+SfNsW6ytpliWvPWbh586hPlMrr/VbHJ3HdL0y/wQcqKpPJPkG4A+T3FlV//f/e3JyCjgFcODAgR5V0qRc/4w4M63Uhj49/RXg9lXrtwEvrlcmyV7gZuByVX2qqj4BUFXngI8Bbxx9gao6XVWDqhrs379/462QJPXSJ/SfBQ4lOZjkBuAEsDhSZhG4t1s+DjxTVZVkf/dFMEm+CjgEXJxM1bVZa50Je3YstWHs8E5VXUtyH3AG2AM8WlXnkzwALFXVIvAI8HiSZeAywz8MAN8OPJDkGvAZ4O1VdXk7GiJJGi81ZYO5g8GglpaWdrsac21cb37KDgnNk504lWz0AE5yrqoG48p5R64kNcTQl6SG9LlkU3Oi75m1l3FK88ueviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIl2zOua3cALn6uV6+Kc0He/q
S1BB7+pK2105P3eop6uuypy9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkO8Tn9OTfrSaP9jFWk+2NOXpIYY+pLUEENfkhpi6EtSQwx9SWpIr9BPcjTJhSTLSe5fY/+NSZ7s9p9NsjCy/0CSl5P83GSqLUnajLGhn2QP8DBwN3AYuCfJ4ZFiJ4ErVXUH8BDw4Mj+h4A/3Xp1JUlb0aenfxewXFUXq+pV4Ang2EiZY8Bj3fJTwJFkeGV3ku8FLgLnJ1NlSdJm9Qn9W4EXVq2vdNvWLFNV14CrwL4kXwz8IvArW6+qJGmr+oT+Wvd2jt6XuV6ZXwEeqqqXX/cFklNJlpIsXbp0qUeVJE215LXHNNRDn9NnGoYV4PZV67cBL65TZiXJXuBm4DLwTcDxJL8OvAH4bJJXqurdq59cVaeB0wCDwcAb/SVpm/QJ/WeBQ0kOAv8InAB+YKTMInAv8H7gOPBMVRXwbdcLJHkX8PJo4EuSds7Y0K+qa0nuA84Ae4BHq+p8kgeApapaBB4BHk+yzLCHf2I7K6217cRZrP/ntDTbUlP2yR0MBrW0tLTb1ZhJOz10OWWHjqbJtI2jN3CwJjlXVYNx5bwjV5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JakifWTY15XZrmpPrr9vAtCbS3LCnL0kNMfQlqSEO70ianGmbUvk6/yOIz7GnL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0JakhTsMwo6bpbnfvcJdmR6+efpKjSS4kWU5y/xr7b0zyZLf/bJKFbvtdST7YPT6U5PsmW31J0kaMDf0ke4CHgbuBw8A9SQ6PFDsJXKmqO4CHgAe77R8FBlX1dcBR4LeTeHYhSbukT0//LmC5qi5W1avAE8CxkTLHgMe65aeAI0lSVZ+sqmvd9psAT/4laRf1Cf1bgRdWra9029Ys04X8VWAfQJJvSnIe+Ajw9lV/BD4nyakkS0mWLl26tPFWSJJ66RP6a31lONpjX7dMVZ2tqjuBbwTekeSmzytYdbqqBlU12L9/f48qSZI2o0/orwC3r1q/DXhxvTLdmP3NwOXVBarqb4F/Ad602cpKkramT+g/CxxKcjDJDcAJYHGkzCJwb7d8HHimqqp7zl6AJF8JfDXw8YnUXNJ0SF57zIJZq++Ejb2SpqquJbkPOAPsAR6tqvNJHgCWqmoReAR4PMkywx7+ie7p3wrcn+TTwGeBn6iql7ajIZKk8VJTdjfNYDCopaWl3a7G1JvWTsqUHU7aCdN6MPYxRwdsknNVNRhXzmkYJKkhhr4kNcTQl6SGGPqS1BDnwZkhs/B9mTNuStPNnr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0JekhjgNg6SNm4U5QfpocN4Qe/qS1BBDX5Ia4vDODJjVM+nr9W7krFmaCfb0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkN6hX6So0kuJFlOcv8a+29M8mS3/2yShW77dyc5l+Qj3b//brLVlyRtxNjQT7IHeBi4GzgM3JPk8Eixk8CVqroDeAh4sNv+EvA9VfU1wL3A45OquCRp4/r09O8ClqvqYlW9CjwBHBspcwx4rFt+CjiSJFX1XFW92G0/D9yU5MZJVFyStHF9Qv9W4IVV6yvdtjXLVNU14Cqwb6TMfwSeq6pPba6qkqSt6jMNw1qTAIzeWP+6ZZLcyXDI5y1rvkByCjgFcODAgR5VkiRtRp+e/gpw+6r124AX1yuTZC9wM3C5W78N+APgh6vqY2u9QFWdrqpBVQ3279+/sRZIknrrE/rPAoeSHExyA3ACWBwps8jwi1qA48AzVVVJ3gD8CfCOqvqrSVW6Bclrj1k3T21p2ry/kfPctlXGhn43Rn8fcAb4W+C9VXU+yQNJ3tYVewTYl2QZ+Bng+mWd9wF
3AO9M8sHu8eUTb4UkqZfUlM17OxgMamlpabersevmtcMxZYebNmJeD8pRM3qQJjlXVYNx5bwjV5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGtJnlk3toHm/6fF6+2b0pkdp5tnTl6SG2NOX9Prm/fRz1Or2zuEpqT19SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkO8OWsKtHbvC8z9/S/S1LKnL0kNsacv6fO1ePq5ljk8JbWnL0kNMfQlqSGGviQ1pFfoJzma5EKS5ST3r7H/xiRPdvvPJlnotu9L8r4kLyd592SrLknaqLGhn2QP8DBwN3AYuCfJ4ZFiJ4ErVXUH8BDwYLf9FeCdwM9NrMaSpE3r09O/C1iuqotV9SrwBHBspMwx4LFu+SngSJJU1b9U1V8yDH9J0i7rE/q3Ai+sWl/ptq1ZpqquAVeBfX0rkeRUkqUkS5cuXer7tJmXeGUc+HuQdlKf0F/r4zh6wWqfMuuqqtNVNaiqwf79+/s+TZK0QX1CfwW4fdX6bcCL65VJshe4Gbg8iQpKkianT+g/CxxKcjDJDcAJYHGkzCJwb7d8HHimak5uX5NacX2czbG2tc3J72fsNAxVdS3JfcAZYA/waFWdT/IAsFRVi8AjwONJlhn28E9cf36SjwNfCtyQ5HuBt1TV85NviiRpnF5z71TV08DTI9t+edXyK8D3r/PchS3UT5I0Qd6RK0kNcZbNHTbjw4Hbag4nNJSmjj19SWqIoS9JDTH0JakhjulLLfNLps25/nubwS+f7OlLUkMMfUlqiMM7O8Sz6I2Z4bNnaarZ05ekhhj6ktQQh3ek1jjWODkzeBu5PX1Jaog9/W1kh2rrZrAjJU01e/qS1BBDX5Ia4vCO1ArHG7fXjIxFGvrbwM/W9vCGLWnrHN6RpIbY05fmmaedu2OKh3oM/Qnxs7VzpvjzJE09h3ckqSH29LfIHv7u8svdNXhQTpcpO0gNfWkeGPTTb0rGJQ39TfDzNX2m5PMkTb1eY/pJjia5kGQ5yf1r7L8xyZPd/rNJFlbte0e3/UKS/zC5qu+s5LWHptvcv0+rD8a5b+ycWus93KH3cWzoJ9kDPAzcDRwG7klyeKTYSeBKVd0BPAQ82D33MHACuBM4Cvy37udNpfXeBz9Ts+n13s+pfk9ntuKaBX2Gd+4ClqvqIkCSJ4BjwPOryhwD3tUtPwW8O0m67U9U1aeAv0uy3P2890+m+mvYwofCUYHGmJ9qUJ/hnVuBF1atr3Tb1ixTVdeAq8C+ns+VJO2QPj39tfpDo53i9cr0eS5JTgGnutWXk1zoUa/Xcwvw0hZ/xqyxzfOvtfZCa20ejlRsts1f2adQn9BfAW5ftX4b8OI6ZVaS7AVuBi73fC5VdRo43afCfSRZqqrBpH7eLLDN86+19oJt3g59hneeBQ4lOZjkBoZfzC6OlFkE7u2WjwPPVFV12090V/ccBA4Bfz2ZqkuSNmpsT7+qriW5DzgD7AEerarzSR4AlqpqEXgEeLz7ovYywz8MdOXey/BL32vAT1bVZ7apLZKkMXrdnFVVTwNPj2z75VXLrwDfv85zfw34tS3UcTMmNlQ0Q2zz/GutvWCbJy7l7YuS1Axn2ZSkhsx06G9leohZ1KO9P5Pk+SQfTvK/kvS6hGuajWvzqnLHk1SSmb/So0+bk/yn7r0+n+R/7nQdJ63HsX0gyfuSPNcd32/djXpOSpJHk/xzko+usz9Jfqv7fXw4yZsn9uJVNZMPhl8qfwz4KuAG4EPA4ZEyPwG8p1s+ATy52/Xe5vZ+F/BF3fKPz3J7+7a5K/clwJ8DHwAGu13vHXifDwHPAV/WrX/5btd7B9p8Gvjxbvkw8PHdrvcW2/ztwJuBj66z/63AnzK81+mbgbOTeu1Z7ul/bnqIqnoVuD49xGrHgMe65aeAI930ELNobHur6n1V9clu9QMM74uYZX3eY4BfBX4deGU
nK7dN+rT5x4CHq+oKQFX98w7XcdL6tLmAL+2Wb2aN+31mSVX9OcMrHddzDPjdGvoA8IYkXzGJ157l0N/K9BCzaKNTWpxk2FOYZWPbnOTrgdur6o93smLbqM/7/EbgjUn+KskHkhzdsdptjz5tfhfwQ0lWGF5J+FM7U7Vds21T2MzyfPpbmR5iFvVuS5IfAgbAd2xrjbbf67Y5yRcwnNX1R3aqQjugz/u8l+EQz3cyPJv7iyRvqqr/s8112y592nwP8DtV9V+TfAvD+4LeVFWf3f7q7Ypty65Z7ulvZHoIRqaHmEW9prRI8u+BXwLeVsPZTWfZuDZ/CfAm4H8n+TjDsc/FGf8yt+9x/UdV9emq+jvgAsM/ArOqT5tPAu8FqKr3AzcxnKNmXvX6vG/GLIf+VqaHmEVj29sNdfw2w8Cf9XFeGNPmqrpaVbdU1UJVLTD8HuNtVbW0O9WdiD7H9R8y/NKeJLcwHO65uKO1nKw+bf4H4AhAkn/LMPQv7Wgtd9Yi8MPdVTzfDFytqn+axA+e2eGd2sL0ELOoZ3t/A/g3wO9331f/Q1W9bdcqvUU92zxXerb5DPCWJM8DnwF+vqo+sXu13pqebf5Z4L8n+WmGwxw/MsMdOJL8HsPhuVu67yn+C/CFAFX1HobfW7wVWAY+CfzoxF57hn9vkqQNmuXhHUnSBhn6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ15P8BWEss93VK764AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.bar(x, a, width = 1/n, color = \"b\")\n", "plt.bar(y, b, width = 1/m, color = \"r\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compute the cost matrix, here we use Euclidean distance squared, $C_{i,j} = \\norm{x_i-x_j}^2$." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "C = np.abs(x[:,None]-y[None,:])**2\n", "plt.imshow(C);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Kantorovitch-Hellinger / Wasserstein-Fisher-Rao Transport\n", "------------------------------------------" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The unbalanced OT problem corresponds to\n", "$$\n", " W_\\rho(a,b) \\triangleq \\umin{P \\in \\RR_+^{n \\times m}} \\dotp{P}{C}\n", " + \\rho D_\\phi( P 1_m|a )\n", " + \\rho D_\\phi( P^\\top 1_n|b ), \n", "$$\n", "\n", "where here $D_\\phi$ is a so-called [Cizarr f-divergence](https://en.wikipedia.org/wiki/F-divergence)\n", "$$\n", " D_\\phi(h|b) \\triangleq \\sum_{i} \\phi(h_i/b_i) b_i.\n", "$$\n", "\n", "The most well known are the KL divergence obaind when using $\\phi(s)=s \\log(s)-s+1$ and the total variation for $\\phi(s)=|s-1|$. These are the two examples we will consider in this tour.\n", "\n", "The parameter $\\rho$ controls the amount of mass conservation relaxation. When $\\rho \\rightarrow +\\infty$ one recovers the usual (balanced) OT. When $\\rho \\rightarrow 0$, no transport is performed. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the optimiztion variable $P$, the OT coupling." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "P = cp.Variable((n,m))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first consider the case of $\\phi(s)=s \\log(s)-s+1$, in which case the unbalanced OT problem is called either \"Kantorovitch-Hellinger\" or \"Wasserstein-Fisher-Rao\".\n", "\n", "In this case, assuming $n=m$ and that the sampling points are equal, $x_i=y_i$, one has that $W_\\rho/\\rho$ converges toward the squared Hellinger distance as $\\rho \\to +\\infty$\n", "$$\n", " W_\\rho(a,b)/\\rho \\longrightarrow \\sum_{i} (\\sqrt{a_i}-\\sqrt{b_i})^2.\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the CVXPY problem and solve it." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "u = np.ones((n,1))\n", "v = np.ones((m,1))\n", "q = cp.sum( cp.kl_div(cp.matmul(P,v),a[:,None]) )\n", "r = cp.sum( cp.kl_div(cp.matmul(P.T,u),b[:,None]) )\n", "\n", "constr = [0 <= P]\n", "# uncomment to perform balanced OT\n", "#constr = [0 <= P, cp.matmul(P,u)==a[:,None], cp.matmul(P.T,v)==b[:,None]]\n", "\n", "rho = .1\n", "objective = cp.Minimize( cp.sum(cp.multiply(P,C)) + rho*q + rho*r )\n", "\n", "prob = cp.Problem(objective, constr)\n", "result = prob.solve()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the solution coupling." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASQAAAEyCAYAAABTdq1qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFPhJREFUeJzt3XuwnXV97/H3d+9cCJeYbCwYEmISSFFPTwuYkQRtzRCsSp3ieOlgnRp7mOZcPK3Ftgg95ww4tVNvI+ocDzYVhbYUL+gIQx0dG4mVKoFwGbkZCYGGTQJBbjlySbKTb/94nuy9iTvZYa+99vNba71fM3vWep717L2++yH7y+f3ey4rMhNJKkFf0wVI0n42JEnFsCFJKoYNSVIxbEiSimFDklQMG5KkYrSlIUXEWyJiU0RsjoiL2vEekrpPTPaJkRHRD/wMeBMwCNwKvCcz753UN5LUdaa14We+DticmVsAIuIrwLnAQRvSjL4jclb/MeTQ3jaUI6lpL/Asu3NXjLddOxrSfODhUcuDwBkHbhQRa4A1AEf0Hc2Kue9k78+faEM5kpq2Idcd1nbtmEMaqwv+0rgwM9dm5rLMXDaj74g2lCGp07SjIQ0CJ45aXgBsa8P7SOoy7WhItwJLI2JxRMwAzgOub8P7SOoykz6HlJlDEfE/ge8C/cCXMvOeQ39XQHhKlNTr2jGpTWZ+G/h2O362pO5lLJFUDBuSpGKU05D6xj1nSlKXK6chSep5ZTSkgAgTktTrymhIkoQNSVJBCmlIAf39TRchqWGFNCRJKqUhBdBXRimSmmMXkFSMMhpSBExvy2V1kjpIGQ1JkmjT1f4v2b6EXbvpP3ZgeNXeJ55ssCBJTTAhSSqGDUlSMYoYsuWePQw9so047T+NrHTIJvUcE5KkYhSRkPbbe9T04ed2Sqn3+HcvqRhFJaQ9x4wkpJkN1iGpGSYkScUoKiHNuvGu4ef//51nAHDUNzY0VY6kKWZCklQMG5KkYhQ1ZNv3wgvDz4dmVTf9nzbvFdXy9kcbqUnS1DEhSSpGUQlptJf9480AbP3zMwE44VMmJKnbmZAkFaPYhLTfwqu3APDg/6mS0ol/9aMmy5HURiYkScWwIUkqRvFDtv2H+xd9YS8AD3x0xfBrJ/3TEwDsvfdnU1+YpElnQpJUjOIT0n57H38cgEX/+/Hhdf9+cTXRPf3slwNw/Oec8JY6mQlJUjEiMyf2jREnAn8PvALYB6zNzM9GxADwVWAR8BDwe5n51KF+1uwYyDNi1YTqANhz9msBGFw1A4CBe0Z+p/0nWEpqzoZcx858MsbbrpWENAT8WWa+GlgOfCAiXgNcBKzLzKXAunpZksY14YT0Sz8o4jrg/9ZfKzNze0TMA9Zn5imH+t5WE9KBnnr/yJG4p15VPR7zUPX4K1/48aS9j6TDMxUJaVhELAJOAzYAx2fmdoD68biDfM+aiNgYERv3sGsyypDU4VpuSBFxNPAN4E8zc+fhfl9mrs3MZZm5bLp30JZEi4f9I2I6VTO6OjO/Wa9+LCLmjRqy7Wi1yJdq7pUjw7K59eMvfm85AFs+MTKcyzpAnvQXDuOkEkw4IUVEAFcA92Xmp0e9dD2wun6+Grhu4uVJ6iWtHPZ/A/BD4C6qw/4Af0k1j/Q1YCGwFXh3Zh7yc7Ene1L7pXrgU1V6ok5MM38+0qcX/I0nW0qtOtxJ7QkP2TLzJob/hH9Jc91FUsfqmEtH2umkP3/xyZNDZ712+PnmTy9/0WsznqnS08KPmJykyealI5KKYUIaw7Tv3zb8/OTvH/Di8l8HYPNlyzmYk695dmThlrsOup2kFzMhSSqGDUlSMSbtWrZWNH3Yf7JtvfTM4ee7X7bvxS/Wu/
vkD3kXAvWOKb2WTZImg5PabbDw0vFPCTjwdAKA2B9W68fF141cdNz3wzsmozSpaCYkScVwDqlggxePzEXtGqj+Ox2YokY/X/rJTQDsfeKQV+pIU845JEkdxzmkgr3UC3sf+Kvq1irZXy3H6BS178XrZj1a/c/quP/nJTAqhwlJUjFsSJKK4aR2j9r3xtMAGHzjrOF1sf8czjxgmZGhXlSfaM5R26oXX3a1J3hqfE5qS+o4Tmr3qL4fVCdaLvzBxL4/X38qAI9eMHJqwvDE+f7HvfuXR1J4394Xbzu8PGr7uXdXnxWRd9wzseLUsUxIkophQtKExL/dCcAr/m3yf/bQytMBeO686vKaF81l1WnrwPmuFxdXPeyrT3/YN31k6mLoiOr50JHV4+5jqvXeO70MJiRJxbAhSSqGQzYVp3/97QAcM4XvueVjIx8gOjSnml2fOfA8APPmVpPs8fGXD28z/V9uQ5PPhCSpGCYkCVhy0fgfpz7zByP3p1r5mSo1ff57vw3AyRd4guhkMCFJKoYJSTpMu9746PDz7zIbgLM2VB9zdf6WkTmlS5a8Fk2MCUlSMUxIUgu2nlF9KOiFb//vw+tufOQLAJzyTx8A4KS/GH9+ShUTkqRi2JAkFcP7IUltcuXWmwA4+9b/CsCCd/bu3Qu8H5KkjuOkttQm71/4BgBu+Pe/BeBNV//x8GsnvdcP/hyLCUlSMUxIUpv9t1dWSWnTI1cMr1vx/uqUgLlXekrAaCYkScVoOSFFRD+wEXgkM98WEYuBrwADwO3AH2Tm7lbfR+p058w/ffj5LdsuB+DNV57aVDlFmoyE9EHgvlHLHwcuy8ylwFPA+ZPwHpJ6QEsNKSIWAL8DfLFeDuAs4Np6k6uAt7fyHpJ6R6tDts8AFzJyc79jgaczc6heHgTmt/geUtdZ/K01AFx83w0AfOPVxzVZTjEmnJAi4m3AjswcfS/Psc7EHPNU8IhYExEbI2LjHnaNtYmkHtNKQno98LsRcQ5wBDCbKjHNiYhpdUpaAGwb65szcy2wFqpLR1qoQ+o4v/o/bgFg4eYnANj3m28GoO+HvX3C5IQTUmZenJkLMnMRcB7w/cx8L3Aj8K56s9XAdS1XKakntOPEyA8DX4mIjwJ3AFeMs73Usy47+dUAnL/pWwB8+ZRXNllO4yalIWXmemB9/XwL8LrJ+LmSeouXjkgF2PTCvKZLKIKXjkgqhg1JUjEcskkFWP/h1wMwbd3IRy31rXq4qXIaY0KSVAwTklSAGd+5FYDnL1g8vG5WU8U0yIQkqRgmJKkgO545evh5L54iaUKSVAwbkqRiOGSTCrL3oZEh296V1S1v+9ff3lQ5U86EJKkYJiSpIEsuHPlYpAc+uQKAk9Y3VEwDTEiSimFCkgo1Y+dYd4TubiYkScUwIUmFmv6LpiuYeiYkScWwIUkqhkM2qVDTf9F7nw5mQpJUDBOSVKhpz5uQJKkxJiSpUP27m65g6pmQJBXDhCQVqm+Pc0iS1BgbkqRiOGSTChX7HLJJUmNMSFKhYl/TFUw9E5KkYtiQJBXDhiSpGDYkScWwIUkqhg1JUjFaakgRMSciro2In0bEfRGxIiIGIuJ7EXF//Th3soqVesmMnXuYsXMP0xYtZNqihU2XMyVaTUifBb6Tma8CfgO4D7gIWJeZS4F19bIkjWvCJ0ZGxGzgt4D3A2TmbmB3RJwLrKw3uwpYD3y4lSKlXtT3gzsA2PmOMwA48qGtTZYzJVpJSEuAx4EvR8QdEfHFiDgKOD4ztwPUj8eN9c0RsSYiNkbExj3saqEMSd2ilYY0DTgduDwzTwOe5SUMzzJzbWYuy8xl05nZQhlSd8u+6qsXtPJrDgKDmbmhXr6WqkE9FhHzAOrHHa2VKKlXTLghZeajwMMRcUq9ahVwL3A9sLpetxq4rqUKJfWMVq/2/2Pg6oiYAWwB/pCqyX0tIs4HtgLvbvE9pJ6WfdF0CVOmpYaUmXcCy8Z4aVUrP1dSb/
J+SFLhemVCG7x0RFJBTEhS4UxIktQAE5JUuH39vXOUzYQkqRg2JEnFcMgmFc5JbUlqgAlJKpwJSZIaYEKSCmdCkqQGmJCkwmV/0xVMHROSpGLYkCQVwyGbVDgntSWpASYkqXC9dE9tE5KkYpiQpMLN+9omAHb80Yrhdcf+3Y+bKqetTEiSimFCkgq39+dPADA0q/vnkkxIkophQ5JUDIdsUofohRMke+BXlNQpTEhShzAhSdIUMiFJHcKEJElTyIYkqRgO2aRO0QPxoQd+RUmdwoQkdYjs/kvZTEiSytFSQ4qICyLinoi4OyKuiYgjImJxRGyIiPsj4qsRMWOyipV6WfaNfHWrCf9qETEf+BNgWWb+GtAPnAd8HLgsM5cCTwHnT0ahkrpfq3NI04BZEbEHOBLYDpwF/H79+lXApcDlLb6P1PNOuOn54edP/pfq7pEDX+quO0dOOCFl5iPAp4CtVI3oGeA24OnMHKo3GwTmj/X9EbEmIjZGxMY97JpoGZK6SCtDtrnAucBi4ATgKOCtY2yaY31/Zq7NzGWZuWw6MydahqQu0sqQ7Wzgwcx8HCAivgmcCcyJiGl1SloAbGu9TEl9P7xj+Pmzv3UmAANNFdMmrczXbwWWR8SRERHAKuBe4EbgXfU2q4HrWitRUq9oZQ5pA3AtcDtwV/2z1gIfBj4UEZuBY4ErJqFOSaP10ZVnEbZ0lC0zLwEuOWD1FuB1rfxcSb3JS0ekDtStl5F0YeiT1KlMSFIn6tIo0aW/lqROZEOSVAyHbFIHclJbktrMhCR1oJlP109e95+rx1vuaqyWyWRCklQME5LUgV5x2Y8AeOij1X2RFt3SZDWTx4QkqRgmJKmTddnRNhOSpGLYkCQVwyGb1MG67QRJE5KkYpiQpE5mQpKk9jAhSZ3MhCRJ7WFDklQMh2xSB1u6djsAWy6pPjhy4Ud+1GQ5LTMhSSqGCUnqYENbHqoeZ81rtpBJYkKSVAwTktQNuuTwvwlJUjFsSJKKYUOSVAwbkqRiOKktdYGZT1Wz2vvecOrwur6b7myqnAkzIUkqhglJ6gLzP1ZdMrLlEyuG1y25qalqJs6EJKkYJiSpi3T6PbZNSJKKMW5DiogvRcSOiLh71LqBiPheRNxfP86t10dEfC4iNkfETyLi9HYWL6m7HE5CuhJ4ywHrLgLWZeZSYF29DPBWYGn9tQa4fHLKlNQLxm1ImfmvwJMHrD4XuKp+fhXw9lHr/z4rNwNzIqI77osgqe0mOql9fGZuB8jM7RFxXL1+PvDwqO0G63XbJ16ipMM14+lRs9rLf716vPknzRQzAZN9lG2sOf4cc8OINVTDOo7gyEkuQ1InmmhDeiwi5tXpaB6wo14/CJw4arsFwLaxfkBmrgXWAsyOgTGblqSX5sS/Hrmn9gOfrE6SPOnmpqp56SZ62P96YHX9fDVw3aj176uPti0Hntk/tJOk8YybkCLiGmAl8PKIGAQuAT4GfC0izge2Au+uN/82cA6wGXgO+MM21CzpcETnDTzGbUiZ+Z6DvLRqjG0T+ECrRUnqTZ6pLakYNiRJxbAhSSqGV/tLXWrhd/YAsO3C6mO2T/hE+R+zbUKSVAwTktSlpv/LbQA8/+YV42xZDhOSpGLYkCQVw4YkqRg2JEnFcFJb6nIL1g8BsP1D1eH/eZ8u9/C/CUlSMUxIUpeb+c+3AvDcyvIP/5uQJBXDhCT1iL5qKon+YweG1+194sDP72iWCUlSMUxIUo9YfPGPAbj/EyNzSUsu/HFT5YzJhCSpGDYkScVwyCb1moLv/W9CklQME5LUY5Z+8bHh5w9eWl1OsvDSMi4nMSFJKoYJSeoxe+/fMvJ8xvEA9M+eXS3v3NlITfuZkCQVw4Qk9bDFf1mfLPmx6mTJJRc1e6KkCUlSMWxIkorhkE0Sx922D4An/mjkOrdj/27qh28mJEnFsCFJ4uivb+Dor28g+xj+eu4dZ/DcO86Y0jpsSJKKEZnNX2k3OwbyjF
jVdBmSRnn6D6r5pKh7xMv+8eYJ/6wNuY6d+WSMt50JSVIxPMomaUxz/qE6ytZ36msAePJ9VWI6Zuuu4W36198+qe9pQpJUDBuSpGI4ZJN0SPvuvBeAOXdWy9MWzB9+bdfK06t1v9gNQN+z1XAuXtg98gP2DBGPTj+s9zIhSSpGEYf9I+Jx4Fng503XMgEvp/Pqtuap0Yk1Q3vqfmVm/sp4GxXRkAAiYmNmLmu6jpeqE+u25qnRiTVDs3U7ZJNUDBuSpGKU1JDWNl3ABHVi3dY8NTqxZmiw7mLmkCSppIQkqcfZkCQVo4iGFBFviYhNEbE5Ii5qup6xRMSJEXFjRNwXEfdExAfr9QMR8b2IuL9+nNt0rQeKiP6IuCMibqiXF0fEhrrmr0bEjKZrPFBEzImIayPip/U+X1H6vo6IC+p/G3dHxDURcURp+zoivhQROyLi7lHrxtyvUflc/Xf5k4g4vd31Nd6QIqIf+DzwVuA1wHsi4jXNVjWmIeDPMvPVwHLgA3WdFwHrMnMpsK5eLs0HgftGLX8cuKyu+Sng/EaqOrTPAt/JzFcBv0FVf7H7OiLmA38CLMvMXwP6gfMob19fCbzlgHUH269vBZbWX2uAy9teXWY2+gWsAL47avli4OKm6zqMuq8D3gRsAubV6+YBm5qu7YA6F9T/yM4CbgCC6izcaWPt/xK+gNnAg9QHXUatL3ZfA/OBh4EBqmtEbwDeXOK+BhYBd4+3X4G/Bd4z1nbt+mo8ITHyH3K/wXpdsSJiEXAasAE4PjO3A9SPxzVX2Zg+A1wI7KuXjwWezsyhernE/b0EeBz4cj3U/GJEHEXB+zozHwE+BWwFtgPPALdR/r6Gg+/XKf/bLKEhjXVby2LPRYiIo4FvAH+amc1+EPo4IuJtwI7MvG306jE2LW1/TwNOBy7PzNOornMsZng2lnre5VxgMXACcBTVkOdApe3rQ5nyfyslNKRB4MRRywuAbQ3VckgRMZ2qGV2dmd+sVz8WEfPq1+cBO5qqbwyvB343Ih4CvkI1bPsMMCci9t96psT9PQgMZuaGevlaqgZV8r4+G3gwMx/PzD3AN4EzKX9fw8H365T/bZbQkG4FltZHI2ZQTQRe33BNvyQiArgCuC8zPz3qpeuB1fXz1VRzS0XIzIszc0FmLqLar9/PzPcCNwLvqjcrqmaAzHwUeDgiTqlXrQLupeB9TTVUWx4RR9b/VvbXXPS+rh1sv14PvK8+2rYceGb/0K5tmp5gqyfLzgF+BjwA/K+m6zlIjW+giqs/Ae6sv86hmpNZB9xfPw40XetB6l8J3FA/XwLcAmwGvg7MbLq+Meo9FdhY7+9vAXNL39fAR4CfAncD/wDMLG1fA9dQzXHtoUpA5x9sv1IN2T5f/13eRXUEsa31eemIpGKUMGSTJMCGJKkgNiRJxbAhSSqGDUlSMWxIkophQ5JUjP8A6SjVsL7RZeAAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def remap_plan(P): # boost contrast\n", " return np.log(.001+P)\n", "plt.figure(figsize = (5,5))\n", "plt.imshow(remap_plan(P.value));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the marginals." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFGZJREFUeJzt3X2MXNd93vHvE7Km4DaRC4oBWkkMGYhqS6lBam6kFEiTtKwN2kDNBKUqSjWiFEQJJ2ULNH2T0UZ11ASIUrRCighwmEiITKOVXAZNFw0DAo3SJg1klaTlN8oQsGbUaKUAkUSWgePSMuVf/5hLeTLc1d7dnX2ZOd8PsOCde8/dOYc789xzz9x7JlWFJKkN37LRFZAkrR9DX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktSQrRtdgVE33XRT7dq1a6OrIUkT5dy5c69X1Y6lym260N+1axdnz57d6GpI0kRJ8n/6lHN4R5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDWkV+gnOZDkxSRzSR5cYPu2JE93259Lsmto23cleTbJ+SRfSHLD+KovSVqOJUM/yRbgMeADwF7gviR7R4odAS5V1W3Ao8Aj3b5bgU8CH6mqO4AfBL4+ttpLkpalT0//LmCuqi5U1ZvAU8DBkTIHgSe75ZPA/iQB3g98vqo+B1BVb1TVW+OpuiRpufqE/s3Ay0OP57t1C5apqqvAZWA7cDtQSU4n+UySf77QEyQ5muRskrOvvfbactsgaRqdO/fNH41Nn9DPAuuqZ5mtwPcBf7f794eT7L+uYNXxqpqpqpkdO5acOkKStEJ95t6ZB24denwL8OoiZea7cfwbgYvd+v9ZVa8DJDkFvBf4zVXWW9I0WqxXf239vn3rV5cp1aenfwbYk2R3kncBh4HZkTKzwAPd8iHgmaoq4DTwXUne3R0MfgB4YTxVlyQt15I9/aq6muQYgwDfAjxRVeeTPAycrapZ4HHgRJI5Bj38w92+l5L8ewYHjgJOVdWvr1FbJElL6DW1clWdAk6NrHtoaPkKcM8i+36SwWWbkqQN5h25ktQQQ1+SGmLoS1JDNt3XJUpqzHJuvhou6+WbK2JPX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDnIZB0sZY7XffOiXDitjTl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktSQXqGf5ECSF5PMJXlwge3bkjzdbX8uya5u/a4k/y/JZ7ufj4+3+pKk5Vhy7p0kW4DHgPcB88CZJLNV9cJQsSPApaq6Lclh4BHg3m7bl6vqu8dcb0nSCvTp6d8FzFXVhap6E3gKODhS5iDwZLd8EtifJOOrpiRpHPrMsnkz8PLQ43ng7sXKVNXVJJeB7d223UmeB/4I+FdV9Turq7KkibXamTX7/F5n3HxHfUJ/oR579SzzB8DOqnojyT7g15LcUVV/9Cd2To4CRwF27tzZo0oal2vvFd8nUhv6DO/MA
7cOPb4FeHWxMkm2AjcCF6vqa1X1BkBVnQO+DNw++gRVdbyqZqpqZseOHctvhSSplz49/TPAniS7gVeAw8D9I2VmgQeAZ4FDwDNVVUl2MAj/t5J8J7AHuDC22mtFFjrD9uxYasOSod+N0R8DTgNbgCeq6nySh4GzVTULPA6cSDIHXGRwYAD4fuDhJFeBt4CPVNXFtWiIJGlpvb4usapOAadG1j00tHwFuGeB/X4V+NVV1lGSNCbekStJDTH0JakhvYZ3NB36XiLtZZzS9LKnL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhriJZtTbjUz2TofjzR97OlLUkPs6UtaW2v1xSl9ns9T1OvY05ekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSFepz+lxn1ptF+sIk0He/qS1BBDX5IaYuhLUkMMfUlqiKEvSQ3pFfpJDiR5MclckgcX2L4tydPd9ueS7BrZvjPJV5L80/FUW5K0EkuGfpItwGPAB4C9wH1J9o4UOwJcqqrbgEeBR0a2Pwr8xuqrK0lajT49/buAuaq6UFVvAk8BB0fKHASe7JZPAvuTBCDJDwEXgPPjqbIkaaX6hP7NwMtDj+e7dQuWqaqrwGVge5I/DfwL4KdWX1VJ0mr1uSM3C6yrnmV+Cni0qr7SdfwXfoLkKHAUYOfOnT2qJGlTW+9vy1qMt5Jfp0/ozwO3Dj2+BXh1kTLzSbYCNwIXgbuBQ0l+DngP8I0kV6rqF4Z3rqrjwHGAmZmZ0QOKJGlM+oT+GWBPkt3AK8Bh4P6RMrPAA8CzwCHgmaoq4K9dK5DkY8BXRgNfkrR+lgz9qrqa5BhwGtgCPFFV55M8DJytqlngceBEkjkGPfzDa1lpLWw9zqj9zmlpsvWaZbOqTgGnRtY9NLR8Bbhnid/xsRXUT5I0Rt6RK0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNaTX3Dva3DZq6nKnKpcmjz19SWqIoS9JDXF4R9L4bJavSRzlF0G8zZ6+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIY4DcOE2kx3u3uHuzQ5evX0kxxI8mKSuSQPLrB9W5Knu+3PJdnVrb8ryWe7n88l+eHxVl+StBxLhn6SLcBjwAeAvcB9SfaOFDsCXKqq24BHgUe69V8EZqrqu4EDwC8m8exCkjZIn57+XcBcVV2oqjeBp4CDI2UOAk92yyeB/UlSVV+tqqvd+huAGkelJUkr0yf0bwZeHno8361bsEwX8peB7QBJ7k5yHvgC8JGhg8DbkhxNcjbJ2ddee235rZAk9dIn9LPAutEe+6Jlquq5qroD+B7go0luuK5g1fGqmqmqmR07dvSokiRpJfqE/jxw69DjW4BXFyvTjdnfCFwcLlBVXwL+GLhzpZWVJK1Onw9VzwB7kuwGXgEOA/ePlJkFHgCeBQ4Bz1RVdfu8XFVXk3wH8BeAl8ZVeUmbwGa6friPxq8xXjL0u8A+BpwGtgBPVNX5JA8DZ6tqFngcOJFkjkEP/3C3+/cBDyb5OvAN4Mer6vW1aIgkaWm9Lp+sqlPAqZF1Dw0tXwHuWWC/E8CJVdZRkjQmTsMgSQ0x9CWpIYa+JDXE0JekhjgPzgSZhCvjGr8aTtr07OlLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaojTMEhavkmYE6SPBucNsacvSQ0x9CWpIQ7vTIBJPZO+Vu9GzpqliWBPX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDWkV+gnOZDkxSRzSR5cYPu2JE93259Lsqtb/74k55J8ofv3b4y3+pKk5Vgy9JNsAR4DPgDsBe5Lsnek2BHgUlXdBjwKPNKtfx34W1X1l4EHgBPjqrgkafn69PTvAuaq6kJVvQk8BRwcKXMQeLJbPgnsT5Kqer6qX
u3WnwduSLJtHBWXJC1fn9C/GXh56PF8t27BMlV1FbgMbB8p87eB56vqayurqiRptfpMw5AF1tVyyiS5g8GQz/sXfILkKHAUYOfOnT2qJElaiT49/Xng1qHHtwCvLlYmyVbgRuBi9/gW4L8AP1JVX17oCarqeFXNVNXMjh07ltcCSVJvfXr6Z4A9SXYDrwCHgftHyswy+KD2WeAQ8ExVVZL3AL8OfLSqfnd81Z5+kzrJ2kIanLJ8Ok3Ti3IhjcwQuGRPvxujPwacBr4EfKqqzid5OMmHumKPA9uTzAE/AVy7rPMYcBvwk0k+2/18+9hbIUnqpdfUylV1Cjg1su6hoeUrwD0L7PfTwE+vso6SpDHxjlxJaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWpIrztytX6c3kTSWrKnL0kNsacv6Z2N8/RzZub6dWfPju/3j8OUTwtr6EtaWwsF/WLbN9sBYAo5vCNJDbGnL2nzsNe/5gx9LWqps3Lfk1rUUi8ebRhDX9fx/apNwV7/mjD0Baws6K/t4/tRmhyGvqTxWavTRHsYY2PoN24c71HPwqXJYehvAtM+9cJCpvz+F2nTMvQb5Ae1mlieVq6aoa+x8j05JZZz+jnNvYgpPCU19Bsyze9NSf0Y+pImk6eVK+LcO5LUkF49/SQHgJ8HtgC/XFU/O7J9G/AJYB/wBnBvVb2UZDtwEvge4Feq6tg4K6+lbeSQjpdWTzHHCifWkqGfZAvwGPA+YB44k2S2ql4YKnYEuFRVtyU5DDwC3AtcAX4SuLP7kaTxc6intz7DO3cBc1V1oareBJ4CDo6UOQg82S2fBPYnSVX9cVX9Lwbhr3U0M2NnTNL1+gzv3Ay8PPR4Hrh7sTJVdTXJZWA78HqfSiQ5ChwF2LlzZ59dpkKLN2UtxO/NldZPn9DPAutqBWUWVVXHgeMAMzMzvffTZPDMW9o8+oT+PHDr0ONbgFcXKTOfZCtwI3BxLDVUbw7naE1NygvMXsY76hP6Z4A9SXYDrwCHgftHyswCDwDPAoeAZ6rKHrs0SRxvfGdTcnfukqHfjdEfA04zuGTziao6n+Rh4GxVzQKPAyeSzDHo4R++tn+Sl4BvA96V5IeA949c+SNJWie9rtOvqlPAqZF1Dw0tXwHuWWTfXauon6aM1+5LG8tpGKbApAy1SuvOXsZ1DP115rDp4qZkyHT62KuYKs69I0kNsac/oSa98+VVddLGMPQlTT97GW8z9KWWLfYh06SfSq61CZ47xNCfIL4PJa2WH+RqwzkjqLR+7OmvEy/VXJ4JPnvWZtf4+L6hPwHsBWtd+EJrgqGvTaPxDpi0Lgx9qTWONX7TansaE3gbuaG/SXmmLWktGPpryA7Vyi00T9aEdKQmi72L5njJpiQ1xJ7+JmPHS9ogjUzDbOhrUxs+CPoFnKt0bbzRnsXamJAPdQ39NeBY/trwhi2tiym/dtjQ3wTseElaL4a+Jsa1g6PDPMswfNpp72L9bOKhHkN/TFYypDMzA2fZx2Y9gZxhc45TbeL3k6bNFA71GPqjkhXu+N5l77HZX0JnWdtEXelBxQ93l8He/djsW+F/5WeWUXY9Xs/TF/orDu2VObeCsNfAOA4qyeDAYfgP8UqCNXGOb6b+cs7P39t1bj6zxp2ovqYv9NWUtw8cWU5/iqk7SuzLN4N+OJy0ebx36Mx2Iw8Ahv4K2LvffIb/Jvv6nFCv9IxwDQ8WqzlJLYN+XV07sG7eT+QW1yv0kxwAfh7YAvxyVf3syPZtwCeAfcAbwL1V9VK37aPAEeAt4B9V1emx1X4dGfST49rfqlf4L9cqhw/D6g8axfoOYWpxqx3yud7anwEsGfpJtgCPAe8D5oEzSWar6oWhYkeAS1V1W5LDwCPAvUn2AoeBO4A/D/z3JLdX1Vvjbsg4GOzTZam/55ocFJZgYE+vlR4A1lufnv5dwFxVXQBI8hRwEBgO/YPAx7rlk8AvJEm3/
qmq+hrwe0nmut/37Hiqf70v8RdXvO+7+eoYa6LNbjWvFemdfJIPX7fuw3xyA2pyvT6hfzPw8tDjeeDuxcpU1dUkl4Ht3fpPj+x784prK0kTaqEDwfW+tOb16BP6C52Pjg5MLlamz74kOQoc7R5+JcmLPer1Tm4CXl/l75g0tnn6tdZeaK3Ng8+MVtrm7+hTqE/ozwO3Dj2+BXh1kTLzSbYCNwIXe+5LVR0HjvepcB9JzlZVU5cz2Obp11p7wTavhT5fonIG2JNkd5J3MfhgdnakzCzwQLd8CHimqqpbfzjJtiS7gT3A/x5P1SVJy7VkT78boz8GnGZwyeYTVXU+ycPA2aqaBR4HTnQf1F5kcGCgK/cpBh/6XgX+wWa9ckeSWtDrOv2qOgWcGln30NDyFeCeRfb9GeBnVlHHlRjbUNEEsc3Tr7X2gm0eu9SU3Y4uSVqcX4wuSQ2Z6NBPciDJi0nmkjy4wPZtSZ7utj+XZNf613J8erT3J5K8kOTzSX4zSa9LuDazpdo8VO5Qkkoy8Vd69Glzkr/T/a3PJ/mP613Hcevx2t6Z5LeSPN+9vj+4EfUclyRPJPnDJF9cZHuS/Ifu/+PzScY3XUBVTeQPgw+Vvwx8J/Au4HPA3pEyPw58vFs+DDy90fVe4/b+deDd3fKPTXJ7+7a5K/etwG8zuBFwZqPrvQ5/5z3A88Cf7R5/+0bXex3afBz4sW55L/DSRtd7lW3+fgZfwvHFRbZ/EPgNBvc6fS/w3Liee5J7+m9PD1FVbwLXpocYdhB4sls+CezvpoeYREu2t6p+q6quzSXxaQb3RUyyPn9jgH8D/BxwZT0rt0b6tPnvA49V1SWAqvrDda7juPVpcwHf1i3fyAL3+0ySqvptBlc6LuYg8Ika+DTwniR/bhzPPcmhv9D0EKNTPPyJ6SGAa9NDTKI+7R12hEFPYZIt2eYkfwW4tar+23pWbA31+TvfDtye5HeTfLqbBXeS9Wnzx4APJ5lncCXhP1yfqm2Y5b7fe5vk+fRXMz3EJOrdliQfBmaAH1jTGq29d2xzkm8BHgV+dL0qtA76/J23Mhji+UEGZ3O/k+TOqvq/a1y3tdKnzfcBv1JV/y7JX2VwX9CdVfWNta/ehliz7Jrknv5ypodgZHqISdRrSoskfxP4l8CHajC76SRbqs3fCtwJ/I8kLzEY+5yd8A9z+76u/2tVfb2qfg94kcFBYFL1afMR4FMAVfUscAODOWqmVa/3+0pMcuivZnqISbRke7uhjl9kEPiTPs4LS7S5qi5X1U1VtauqdjH4HONDVbV5JzNfWp/X9a8x+NCeJDcxGO65sK61HK8+bf59YD9Akr/EIPRfW9darq9Z4Ee6q3i+F7hcVX8wjl88scM7tYrpISZRz/b+W+DPAP+5+7z696vqQxtW6VXq2eap0rPNp4H3J3mBwTfS/bOqemPjar06Pdv8T4BfSvKPGQxz/OgEd+BI8p8YDM/d1H1O8a+BPwVQVR9n8LnFB4E54KvA3xvbc0/w/5skaZkmeXhHkrRMhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ35/6pFIh2+irLvAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "a1 = np.sum(P.value, axis=1)\n", "b1 = np.sum(P.value.T, axis=1)\n", "plt.bar(x, a1, width = 1/n, color = \"b\")\n", "plt.bar(y, b1, width = 1/m, color = \"r\")\n", "plt.bar(x, a, width = 1/n, color = \"b\", alpha=.2)\n", "plt.bar(y, b, width = 1/m, color = \"r\", alpha=.2);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Shows the impact of $\\rho$ on the solution." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAB0CAYAAADQOaYgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEFNJREFUeJzt3XtUVPeBB/DvvTO8RhhfRRFhEAWOD7QIlfBK9Wi7bm1O0pjdrM2r2T0xNTmtJ0lPbLTJsSanmhOb2Haz0abtOWlT09VmY7qbZuU0pnXVASWIIijKGwFBBcmgKDD33v2DYPGBzOPO/O6d+X7+kSEzd74nM5fv/O7vzu9KmqaBiIjIaGTRAYiIiG6HBUVERIbEgiIiIkNiQRERkSGxoIiIyJBYUEREZEgsKCIiMiQWFBERGRILioiIDMnqzZ0jrTYtWokOVJawdg1XMKD1S/5uJ1KO1qI1mx6R6DZ6cemipmnx/mwjUo7RorUYvSLRTfR4jYAvXifYAK62oztP/955VVAxEeORqxb5nopGdVjbp8t2Yix25CqLddkW3eoT7f1mf7cRrcWgMPkxuFvb9IhEN9HjNQKGXqeC+H+GcrFLj83RCJ7+veMhPiIRJL8HyxQELCexWFChhn/3TEHtviQ6AnnImpIsOkLY8q6g+KnPBPgamYF65YroCOQptwLL5EmiU4QlrwpKu3oNWmFWoLKQHmQWlFnImbNFRyAPuNvaoXR1Q86cDTkuTnScsOL1IT53jCUQOUgvqsZPeyahVtVA+kqm6BjkIbWqBlLyNNExwor3BRXLgjIybXAQqoM7kVnIZ1qA3PmiY5CHlJNnIC2cB9nGr3IEg9cFFVt1IRA5SEfKuAjREchDissFubKWh85NRKuohpSaDGnhPFji/f66Fd2BV9+DAgClrhF9998F257DgchDOrB+VgNVdAjymHrtGqRDx4Zu5M6HFmGB9eJlKKfrxAajUSnVp4f+/eK2ZU461OjIoRtVtdAGB8QECzFeFxQAtH9VQtoevaOQXtRr12CdlgD3uQ7RUchbR05AAiDNcEArzILcNwitolp0KhqDcqr2+s/ygtmA5e9TIVJdC9TeXhGxTM+ngsr4USVOv5OD9MfL9c5DOml5eCYSf8qCMit3UwukphbIkyfBPeLw3/WRFhmWWllzw215bgakKMcNv5M0Deqxk8GMZUo+FZTa14fUd4Ha32Yj/TtH9c5EOnDsbEDjSwVIfsUpOgr5QenqhnSo+/rtm+eqrOddUGobgh2LvKCcPHPL7zQActbc2z9A06AePxXYUCbhU0EBgHVfOdL3AdKi+ZAG3PwfajDucx1IfmVoBNW6vgBJW1hUoeDmEZQyyv0skyfBnXHjCgjSiEVPLY0dUDrP6x
2PvKDHCErOnD10qtudJp1lAJIETZYhKcotIzwj87mghmllJ4BPp6O+PB+zni/RIxPpLGmLE41b8jGpWsP435eKjkNBoHR1QyrpHv0OczOAlKnXb0rD3TVy5W5Ng6W9i3OZBqZWeVc2st0O97IcDMZaMDhORr9dwuTqa5APVAQooX/8LigA0Ja2Qf1FErIqgGML9dgi6S11fQkuPZ6PhlfzEdcExO/gh4lwdvNhp1EvKDFzBqSF8wDcOAK7/jhJAiwS1AgLBu0R0KwSov5cpnNa0ovicsG6rxxWAMMXfGndUIC+b+dCHpCR9oyxPsDqUlAAkL72MCptNnytqhOfZHI5ECOa+E4JJgK4/GAe6rdyxEtjczc0Xf/5TldFkgBEArDY7eh6NB8DcRIuOzSMrwMm/5rvMyNL2jx0+P/CU/no/igDk+65dc5MFN0KChg6eWK4nBJL4/BIvBOvzeK35I0mdncpYnf//bZ7aQ6a7olA5OcyHJs4V0W+U1wuTHh3qJCmjPh90yv5kAclOF7m+8uo4reXANuBjmcK8J0n9mL3ln/A+J1iR1QBu9xGe14vXps1H5saeCq60Vk/LUfac6VwFF9G3bY81G3L4/I7pKsZL5UgZUkzsC9JdBQaQ8LPnCjOtGP1ix/izI5coVkCfj2ojTNz8O/Nh5BeFhXopyJ/lVYi7dlSpD1bipYVcdfLqu6NPNHJKARoS9tQdy4ey6tcoqOQB3bPScCLi/9baAZdD/GN5vsphWh4NRuHzv4UTzQ8gP7FPCvI6Bw/vvFQTN0beUNnemlA6p/6DXvWDxlb2iMV+M36f8TbDW/i5ZnZouPQGHbPSRga9S5rFfL8QSkoAJj5QgkefaEQQAcu/TkdJVm78JXPHsKU+8xzTn44S3tu9GPRTa/kX/8uhqQBMR0SprzFuQa6vaQtTry8JRu/ajmIb765Domv8b1iaMta8YO6arx9bjF6774Y1KcOWkGNNPGbtViBbEz9tBd7249heSJXcjazGS/deJaWunghzr5YAGiA9EVpxZ5VhU+4krGsdhThr2e3YqnyPKa9zpIystfT5mFz48f40af3Q1vaFrTnFVJQw7SlbViOLBS3H8Nc5yNI/qcqkXFIJ/L+CiTvv/F3WmEWOp4tGBplqYA8qEFWAKiArACTjru4KGoYeji5EGVtP8ey5u9h3Pu8QoKRbUjNxY7mXViDoqA9p9CCGrY8MQtzD3RgV9tRzC95jEUVgqRDx5BwaPT/7l6Sjb5VeUMjLlWDNLyGjwSoFmAgTobbJmEgDly2KcTcO30RDrb/Esvf55EUo1uTUoQzO3KRseZIUJ7PEAUFAL13X8QKZCMZVRj4Swr+OPs9LCl7EtNX8lN1OLD87Sh8+Xp31P4ELJl8BsWZdt0zUfAsT8zCr1oOYrUjeJ/OyTcZa46gOEhTMwE/zdwXkV9vxsPJhYi0uvFW80HRccjA+hd3oDjTzu/bhYB7jq7GuvoTomOQB+b+x9N4pyXwf5sNWVDD4u89jadTirCpoRzte0ZZmp4IwLrnnsLHbUdRvzVfdBTy0bRvncJvzxei4DivRmt0yT9x4uuffRdr6wJ7FrahC2rYxpk5uHxhHJ6qrUPbB/NExyEDivnwCFZMz8aBf9mK1v/ie8SsOvNdiLNcQ/dHGaKj0Bimr6zGrguBXWnCFAUFABlPlmF7ehqmr6xGy8YC5FSoaHiPk6p0o8cdRUh6oBqXHudIyqyKM+2GWrCURteZ78JTtXUB275pCmokxyYnyhfKsJXZMLWEk+N0qyObt4uOQH6Kd04QHYE88HZONu492RWQbZuyoIYlbHOiM98FdV+y8EUNyVhSP3wSD5ziFWPNTNUktG4oEB2DxqC4XNhW8TVc/B/9D8uauqCGycvOYtpfZXR8OAf1r3NhUwIynj4CR0QX1Lt5BU2z6iq8hNx7eVafGcx6uAJzJp+HZeJEXbcbEgUFAHG7SpHwrVOY9YNSuJfl4MxbuWhfx09f4W
xb2hwUvcnVCcysPa8X5/80W3QM8sCFgh7UbU/WdZshU1AjWfeVI+PpI0hw9qHxVU6Wh7O//bAQ6j59dxoKLlfjBHSt5n5sBqmrKnXdXkgW1DD54DGkvlCC1g0F6HmUb/BwFLm3DP1uwyyYQj5Ie6YUXdnK2HekkBPSBTUsabMTkyp70Lm2AH333yU6DgXZ+c9jRUcgP1k/t4iOQB66sEa/wUDYfLRUj5/C1ONDP8uZs9GzYAImVPVAreT1qEJdyoOcaDe79N91gWMoc9Askm7bCosR1M3UqhrY3yuFEhuF3lU864/I6JST/OKuWUT2arptK2xGULcjOY8jzglYZzjgbmoRHYeIyPQirqi6bSssR1A3cy2cBusMh+gYRESmx4LSmW3PYbibWnD5QR7uIyLyR+TeMiBvgS7bYkGNYGu/BktaqugYRESmJg3qc0oLC2oE+eAxuL48RXQMIiJTk9z6HOZjQd0k5iIvlkZE5BdFnzP5WFA3kfdXQIqKEh2DiMi0JJUjqIDpXpUtOgLpjGdpEpkPC+o21AjRCUhvruxpoiMQkZdYULcx5cAFnnIeYjS+002PZ9iGH+62t6GcroMmAVphlugopBPVqt/6YCSGZuPccLhhQY0iblcpWpfYRMcgnXAERWQ+3G3vQI0UnYD0ouq4wjIJIvE1DDcsqDtI2ejkFXlDBEdQRObD3XYMif/nFh2BdKCG9br9IYIjqLDDghpD1MdloiOQDjRekJUoeAbdkKOj/d4MC8oD/SsWiY5AfuIhPqLgUWobIM1K8Xs73G090PbIoOgI5CdN5uEhIrNhQXkgY90FNO+eLzoG+WHKW050ri0QHYP8oHEOKuxw6tgD7tY2TLbzS4Jmx8N8RObCXdZDPfsT0LiZp5ybGQvK5Pj6hR2OoDyUtMWJqP0J6BcdhHzGgiIyF+6yXuhf3AHH4XGiY5CPeKq5uWnl1ZBy5omOQR7SY86QBeUlu/Wq6AjkI46gQgBPlAgr3GW9dOiNXNERyEeJB6+i+984j0hkFiwoL43fWSo6AvlIPlCBK9P4CZzILFhQPvh2TbvoCOQjzkMRBYekw3JHLCgf2GSey2dafMebG+egTEM5XQekzfBrG9xdffCbjFSsrasRHYN8oPHvG5FpsKB8FC1xfT4ziuoBkMtlq4jMgAXlo9dmzcfmxiOiY5CXErY50XRfrOgYROHBz4ZhQfnh3e4CtG7gAqRmw8N85iU3tsMyN0N0DAoSFpQfTuW4see7W0XHIG+xoExLudgFJZYLN4cLFpSfIqCJjkDeYkERmQILyk9rUopERyAvpb5QIjoCEXmABaWDyYcmio5ARGQ46onTfj2eBaWDnzk+Eh2BiMh4NP+mQFhQOth8fjHqX88THYO8oBZliY5AvpI5iRguWFA6OJXjxvsrfw5pEb8AahaN98WIjkC+Kq0E8haITkFBwILSyQ9T78LvP9ghOgYRUchgQelIFR2APBbZI/FTOJHBsaB0tOpf12J9faXoGOSB5J84Uf/AONExiEKedXqiz49lQeko4pNyfH/HGtExyFMSv2RtVpKbxyvMQv3SeJ8fy4LSWeJWp+gI5CHH3kG0r+NaimakfVYlOgIFAQsqABJL40RHIA9EfFKOq/EcRREZFQsqAL43dZ/oCERExqBpPl8JmQUVABtSc7Gu/oToGOSBWc+XoH5rvugY5APrzBmiI5AH1MoayAtm+/RYFlSA/Pj5J5B3nFfdNQPZLToB+WJwqu+T72QOLKgAsX1wGOMtVyHlzBMdhcaQup6rmxMFlI+H+VhQAVScacdDO4tFxyAKSfKAG7LNJjoGecDXw3wsqADbOTuJl6g2gcbN+bDY7aJjkBe08mooX04XHYM8JLlVSBGRXj2GBRUERf95XHQEGkPqhhLUbuDhWLOJaOqEZQ5LygyU6tOQ5s7y6jEsqCD4y7qvImb/VNExaAxTylV0reYZfWbiPtcB9wQe5gtVLKggiNxbhs4dqXD9r3efHi
i4Yv94GJoM9K28S3QU8oK1uhHSQo5+zUA9ccar+7OggsT+h1LYv1EPAGjeVAD17oWCE9HtfOmXJbB9cBg9j+VDWZItOg55QHG5oFVUw5qSLDoKjUVVvDpZwhrAKDSKlI1O9H9jEfofykNUj4Loc32QTh4UHYtGmPC7EliTpqN/STaslwdg6eoFBt3AWdHJaFQDg0PzUSdFB6E7UStrIMVEe3RfSfPimvGSJF0A0OxjLrqzFE3T4v3dCF+jgPP7deJrFHDcl4zPo9fIq4IiIiIKFs5BERGRIbGgiIjIkFhQRERkSCwoIiIyJBYUEREZEguKiIgMiQVFRESGxIIiIiJDYkEREZEh/T/icYUD2rUsTgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "rho_list = np.array([.03, .1, .5, 1])\n", "for k in range(len(rho_list)):\n", " rho = rho_list[k]\n", " objective = cp.Minimize( cp.sum(cp.multiply(P,C)) + rho*q + rho*r )\n", " prob = cp.Problem(objective, constr)\n", " result = prob.solve()\n", " ax = plt.subplot(1,len(rho_list),k+1)\n", " plt.imshow(remap_plan(P.value));\n", " ax.set(xticks=[], yticks=[])\n", "plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADFlJREFUeJzt3bF620YWBlDAX6qt5doss62px8nj5XEIl0mt2q63xRYbrUR+BDgAB8CdmXMqy4IIxMHlP/cCAvtxHDsAiObL0QcAAPcIKABCElAAhCSgAAhJQAEQkoACICQBBUBIAgqAkAQUACH9tmTjl5eX8XQ6bXQoALRgGIZf4zh+fbTdooA6nU7d5XJZf1QANK/v+7eU7RYFFECVhuH66/P5mOPgioAC2nQbSlPfE1aHcZMEACHpoIB2zHVNKT+jm9qVDgqAkHRQQL3WdExLXk9HtSkdFAAhCSgAQjLiA+qSe6yXui/jvuwEVKVSalQ9AZEJqIosXTha/AGRuQYFQEg6qMLlGrfrpijantedUo5BEWWhgwIgJB1UgbZeLFoIAhHooAAISUABEJIRXyGOugZs3EdIEW6KmOOZfVnooAAISUABEJKAAiAk16CCijhidz2KQ0UsilSKZxUdFAAhCSgAQjLiC6SkCYaJBbA1HRQAIemggLhKGiukMn5IJqAOVkP9qTdgC0Z8AIQkoAAIyYgPiKOGmfcSHio7S0AdoOYadD0KyMWID4CQBBQAIRnx7aTmsd4U4z6StFgcUxTNFR0UACHpoDZkYfjBwhBYSkAB+7N6e8yqzogPgJh0UJlZGD7mdxOBFAIK2J6V23MaXdUJqAzU3nOM2oF7BBSwDSu37TSyqhNQK0WsvdfXPK9zueR5nTUaqTu2lKsQUhxZLA0QUAscGUp71lzKvvaoS2FVoL2KZM+CmDN1HHsGV8WFIqAOFqXOlopQlwSwdSDVViBdt22RVHYzhYB6IFf9Tb6hd2WfQO9eu//9Q20VXBUvEsuTO5RKDaG19gyvwgunrYDq+8lvDd33TXdde2PxMGgzvgdNvT2eux8fX4xjvh3SdcPQnZ/4fzjkPAFqNhFe5yzvIB+V8+NBvUYpn35ccCSvr6/jJUfCzwTFFrYOH8pxFWJ7W1j1e5TJ98m4nyZs4soTZB9ugyxXcPV9P4zj+PBEWhRQfd//7Lru7ZkDA6B538Zx/Ppoo0UBBQB78bBYAEISUACEJKAACElAARCSgAIgJAEFQEgCCoCQBBQAIQkoAEISUACEJKAACElAARCSgAIgJAEFQEgCCoCQBBQAIf22ZOOXl5fxdDptdCgAtGAYhl8pn6i7KKBOp1N3ueT9zHsA2tL3/VvKdkZ8AIS0qIMCqMYwpG13Pm97HEzSQQEQkg4KaEdq1zT1M7qpXemgAAhJBwXUa
03HtOT1dFSb0kEBEJKAAiAkAQVASK5BVSpl9G58TpVyX3dK3ZeCyk4HBUBIOqiKLF04WvwBkQmowuWaZggrIBoBBZRvz+tOKcdglZeFa1AAhKSDKtDWi0ULQSACHRQAIemggPJEuOY0xzP7shBQhTiqHo37gKMY8QEQkoACICQjvqAijtiN+4A9CSigDBFXbams7lYx4gMgJB1UICUtEC0Iga3poAAISQcFxFXSWCGV8UMyHRQAIemgDlbDAtGCENiCDgqAkHRQQBw1jBSW8FDZWQLqADXXoHEfkIsRHwAhCSgAQjLi20nNY70pxn0kabE4piiaKzooAEISUACEZMS3IZOLDyYXXNmyOF5f07a7XLY7hhwUjYBqWWodv4tezzRk6cm79DWc7CEIqMwidk05annudZbWst9NZJVcJ/LSfQmrwwioiuxZv1P7VcvctXbldtRJnXoMe53wja7qBFQGe3dNEWp2ypqwMmrn/yKf3PdYnW1KQBWitLrtOrXbvNSVW4kn9z17nvCNrOoE1Epbd0211Oy71NptpO6o7QS/dfvfZ4W2ioA6WO11ek+EkT4bmVq5tXiif7Zld1Xxqk5ALZCra2q9VudM1XFlddcGJ/p9Zt/JBNQDz4SS+nyOsCrAbYE46ZfJPU6o7G4/AXWHUIpn6t91HPc9jmZ9Lgon+T5ydFpTb2aFBNcxAdX3h+y267pu6L5v+voa9n0NCafSufux7sUrTL+50vvefbyZDZ0QCiVxUXBOfgeaXoX/6KbDa++SOCSg/u5+P2K3Xdd13b+6/xy2b46x+nzr/533QAL4K3G7I2uU9f7s/sj6en90f978zb6dVz8uiMS+7392Xfe23eEA0IBv4zh+fbTRooACgL34PCgAQhJQAIQkoAAISUABEJKAAiAkAQVASAIKgJAEFAAhCSgAQhJQAIQkoAAISUABEJKAAiAkAQVASAIKgJAEFAAhLfrI95eXl/F0Om10KAC0YBiGXymfqLsooE6nU3e5XNYfFQDN6/v+LWW7RQEFUKVhuP76fD7mOLgioIA23YbS1PeE1WEEFNCOuVBK+RlhtSt38QEQkoACICQjPqBea0Z6S17PyG9TOigAQhJQAIRkxNcANyHRlNxjvdR9Ka7sBFSlpmpUPQGlEFAVWbpwFFZAZK5BARCSDqpwucbtuimKtud1p5RjUERZ6KAACEkHVaCtF4sWgjTh9fXjzz5GKCQBBbTjcyjN/b3ACkFAAeVJHSNMBdKSn1sTVh6JlIWAKsQzY72UGp2qQeM+mmcUeBgBVamlC0c1SDXWdk2E4y4+AELSQQW1ZqSXa+E41U0Z9xHWXl2Tmyl2JaAKZ5pBMyL8Mu5aVnerCChmuTYFMxTIpgQUUB6jgyYIqEC2/tWOZ00tFk0sgC0IKCCukq47pY77XI9KJqAKYaJB8xRBcwTUwUpaIH5m3AdszS/qAhCSDiqokqYZn491HI87DipQ6kjhVuov9Hqo7CwBdYBaavAe13/JqqSVGtkJqEDUIsAHAUVWxn1whydOrCKgdlLzWG+KcR9JPp8orY8RFM0VAXWwmutRNwU8Q0AB7Mm4L5mA2tDUWK/mrmmKyQVXjPUeUzQCin0Y98EdPgBxloAC9qdrIoGAysxY77G+v/5aR9WAFm9jzanRJ04IKGB7Vmhp3EBxRUBloGt6zueOSjcFvBNQmQkl+MftLJdl5rqpRu7wE1ArGalvQzdVOIWxjbm7/SoOKwG1gFHevoRVgRQDGQmoB4RSDMIqKGO8Y0yN/yq7209A3SGUYhNWBxNKsVR8rarpgBJE5Zt6rxRcTxJCZZp785q7bT1oeB0TUBlP/qH7nu213vntg/INT5xi5+5HvgOBKFasvM8374bDuG+Q9eOCpWbf9z+7rnvb7nAAaMC3cRy/PtpoUUABwF6+HH0AAHCPgAIgJAEFQEgCCoCQBBQAIQkoAEISUACEJKAACElAARCSgAIgJAEFQEgCCoCQBBQAIQkoA
EISUACEJKAACGnRR76/vLyMp9Npo0MBoAXDMPxK+UTdRQF1Op26y+XyeEMAmND3/VvKdosCCqBKw3D99fl8zHFwRUABbboNpanvCavDCCigHXOhlPIzwmpX7uIDICQBBUBIRnxAvdaM9Ja8npHfpnRQAIQkoAAIyYgPqEvusV7qvoz7shNQlUqpUfVE015fp7/niTkhCKiKLF04WvzRnLlQmtpOWB3GNSgAQtJBFS7XuF03RdGmCiG1Y5qT2k0pouwEVKWW1qUpBhCNgCrQ1jcpWQjChNuVn5XdpgRURZ6ZZrgmTDVyjPUIQUAVTi0CtRJQhdjzdw+nuinjPsLYsyDmTBWLZ/ZlIaCA8hklVElAFWjPWvy8r3Hcb78AAiqoKBOMz9zdB+xJQAFl+LxCijjSm7sV1upuFQFViAj1aNwH7Mmz+AAISQcVSMTrTlNMLGCGJ05kIaCCijDSm2Pcx+6iFwXZCSggrpLGCqmMH5IJqIPVUH/qDdiCgAqk1AmGcR+bKbUobnka8yru4gMgJB3UAWoY600x7uMpNRfHPR4qO0tAHayWCcY74z54wLgvmYAC4qhtxcZTBBRwrNbGenPMyK8IqJ1Ef87lFvr+48/GfcBSAgo4Visrtns8EmmWgNpQi13TFN0UsJSAAvbnutNjrkcJKOAArY8UprgF/YqAysxY77HP476uM/ID7hNQwPZuVyU8NtdNNTLyE1AZ6Jqe4wYK4B4BBWxD15TP3O3oFXdTAmolXdM2dFPAOwG1gFDal7AqkK5pH1PXpyrrpgTUA0IpBmEVlEA6XsU3UwioO4RSbMLqYEIprsquVQmofwilMgmrnQilMhU+CmwioIRPG5a+hwq0TvC0JPXNL1CQ9eOCKu37/mfXdW/bHQ4ADfg2juPXRxstCigA2MuXow8AAO4RUACEJKAACElAARCSgAIgJAEFQEgCCoCQBBQAIQkoAEL6Lwk7MSiNVUOUAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for k in range(len(rho_list)):\n", " rho = rho_list[k]\n", " objective = cp.Minimize( cp.sum(cp.multiply(P,C)) + rho*q + rho*r )\n", " prob = cp.Problem(objective, constr)\n", " result = prob.solve()\n", " a1 = np.sum(P.value, axis=1)\n", " b1 = np.sum(P.value.T, axis=1)\n", " ax = plt.subplot(len(rho_list),1,k+1)\n", " plt.bar(x, a1, width = 1/n, color = \"b\")\n", " plt.bar(y, b1, width = 1/m, color = \"r\")\n", " plt.bar(x, a, width = 1/n, color = \"b\", alpha=.2)\n", " plt.bar(y, b, width = 1/m, color = \"r\", alpha=.2)\n", " ax.set(xticks=[], yticks=[])\n", "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Partial Optimal Transport\n", "=========" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can consider other divergences, such as the total variation, which corresponds to the $\\ell^1$ norm of densities, obtained for $\\phi(s)=|s-1|$\n", "$$\n", " D_\\phi(h|a) = \\norm{a-h}_1 = \\sum_i |a_i-h_i|.\n", "$$\n", "The resulting OT problem corresponds to a penalized version of the celebrated partial transport problem. In sharp contrast to the KL problem, this partial OT either transport the mass or detroys it." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "q = cp.sum( cp.abs(cp.matmul(P,v)-a[:,None]) )\n", "r = cp.sum( cp.abs(cp.matmul(P.T,u)-b[:,None]) )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the marginals of the optimal plan. One can see the presence of small spikes, which \n", "are caused by the discretization of the problem. We have displayed up-side down the densities to \n", "highlight that these error actually almost cancel." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAD8CAYAAABzTgP2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEpNJREFUeJzt3X+MZedd3/H3B7sxolD/dmJ2PWwqb9QuqVTwyAb1l1vHjhOJbARO6yDK0jpdidaVimmFkdXYOEFNKNQtIoJuE6tLJLBDJJppMVptnLhUUex6ltA0TjG7mIAHW7GTdS1FlpOafPvHPWvfZ7izc3fOmblzZ94v6WruOeeZOd9nZ2Y/85znOfemqpAk6YxvmXUBkqTtxWCQJDUMBklSw2CQJDUMBklSw2CQJDUMBklSw2CQJDUMBklS4/xZF7ARl112We3bt2/WZUjSXDlx4sRXqury9drNZTDs27eP5eXlWZchSXMlyR9P085LSZKkhsEgSWoYDJKkhsEgSWoYDJKkhsEgSWoYDJKkhsEgSWrM5Q1u2jonTrz2/JprZleHpK1jMGii8UBYvc+AkHY2LyVJkhoGgySp4aUkvWrS5aP12nlZSdp5HDFIkhoGgySpYTBIkhrOMexy084rrPf5zjVIO4cjBklSw2CQJDUMBklSw2CQJDWcfN6l+k46n+3rOREtzbdBRgxJbk7yZJJTSe6ccPyCJA92xx9Lsm/V8YUkX0vyL4eoR5K0cb2DIcl5wIeAtwEHgHcnObCq2W3AC1V1NXAf8MFVx+8DfrtvLZKk/oYYMVwLnKqqp6rqG8ADwMFVbQ4CR7vnHwduSBKAJO8EngKeGKAWSVJPQwTDHuDpse2Vbt/ENlX1CvAicGmSvwj8FPAzA9QhSRrAEMGQCftqyjY/A9xXVV9b9yTJ4STLSZaff/75DZQpSZrGEKuSVoCrxrb3As+s0WYlyfnAhcBp4DrgliQ/B1wEfDPJy1X1S6tPUlVHgCMAi4uLq4NHUxh6JdJ653F1kjSfhgiGx4H9Sd4I/ClwK/DDq9osAYeAzwK3AJ+qqgL+1pkGSe4BvjYpFCRptYxdhyj/VBxU72CoqleS3A4cA84D7q+qJ5LcCyxX1RLwEeCjSU4xGinc2ve8kqTNMcgNblX1EPDQqn3vHXv+MvCudb7GPUPUIknqx5fEkCQ1DAZJUsNgkCQ1fBG9XWCrlqmu5qoRaT45YpAkNQwGSVLDYJAkNQwGSVLDYJAkNQwGSVLDYJAkNbyPYYea1b0LazlzT4P3M0jbnyMGSVLDYJAkNQwGSVLDYJAkNQwGSVLDYJAkNVyuusNst2Wqq/lS3NL254hBktQwGCRJDYNBktQwGCRJDYNBktQwGCRJDZer7gDbfYnqWnzFVWl7csSwCywujh6SNI1BgiHJzUmeTHIqyZ0Tjl+Q5MHu+GNJ9nX7b0xyIsn/7j7+vSHqkSRtXO9gSHIe8CHgbcAB4N1JDqxqdhvwQlVdDdwHfLDb/xXgB6rqrwGHgI/2rUeS1M8QI4ZrgVNV9VRVfQN4ADi4qs1B4Gj3/OPADUlSVZ+rqme6/U8A35rkggFqkiRt0BDBsAd4emx7pds3sU1VvQK8CFy6qs0PAZ+rqq8PUJMkaYOGWJWUCftWrzM5a5sk383o8tJNa54kOQwcBlhYWDj3KiVJUxlixLACXDW2vRd4Zq02Sc4HLgROd9t7gd8EfrSq/nCtk1TVkaparKrFyy+/fICy59+JE/O7VHVc8tpD0uwNEQyPA/uTvDHJ64BbgaVVbZYYTS4D3AJ8qqoqyUXAbwE/XVWfGaAWSVJPvYOhmzO4HTgG/B/gY1X1RJJ7k7yja/YR4NIkp4A7gDNLWm8Hrgb+dZ
Lf6x5X9K1JkrRxg9z5XFUPAQ+t2vfesecvA++a8HnvB94/RA2SpGF457MkqWEwSJIavojenNkJq5DOxveElmbPEYMkqWEwSJIaBoMkqWEwSJIaBoMkqeGqpDmx01cjTeJbf0qz4YhBktQwGCRJDS8laXCLi7OuQFIfjhgkSQ1HDNvYbpxwnsSXydB2s9MXRjhikCQ1DAZJUsNgkCQ1nGPYhpxbWNtOv7YrbQeOGCRJDYNBktTwUtI24eWjc+MSVmnzOGKQJDUcMcyYI4X+nJCWhuWIQZLUcMQwA44SNofzDtIwHDFIkhqOGLaIo4St5byDtHGDjBiS3JzkySSnktw54fgFSR7sjj+WZN/YsZ/u9j+Z5K1D1LOdnDhhKMxS8tpD0nR6jxiSnAd8CLgRWAEeT7JUVV8ca3Yb8EJVXZ3kVuCDwD9IcgC4Ffhu4DuBTyZ5U1X9Wd+6ZmlSEIy/ec3y8tbVotc4ipCmM8SlpGuBU1X1FECSB4CDwHgwHATu6Z5/HPilJOn2P1BVXwf+KMmp7ut9doC6NpWjgPm13ujB4NBuN0Qw7AGeHtteAa5bq01VvZLkReDSbv+jqz53zwA1rWk7XFLYyFtfnhll9HnbzD7n7XvuebIdfkZ0bmb1Pdvq827VHy1DBMOkf5rV5a/VZprPHX2B5DBwGGBhYeFc6lv1xef0t777T3nL/5gdCwP/kJZmbWt+C4cIhhXgqrHtvcAza7RZSXI+cCFwesrPBaCqjgBHABYXFzf+rzOr6wSzXGTvAv/dZZ6HPOcyATfLibsdPmk4RDA8DuxP8kbgTxlNJv/wqjZLwCFGcwe3AJ+qqkqyBPxakn/HaPJ5P/A/B6hJ2r3WCv95Dozd6JprZnbq3sHQzRncDhwDzgPur6onktwLLFfVEvAR4KPd5PJpRuFB1+5jjCaqXwH+2byvSFqTf6lr1sZ/Bg2J7WmGYTBukBvcquoh4KFV+9479vxl4F1rfO7PAj87RB2SpnQmJAyI2dsmYTDOO5+l3cxRxGxswzAY52slSZIajhgkjTh62HzbfKRwhiMGSVLDEYOkP8/Rw3DmZJQwzhGDJKlhMEiSGl5KknR23vNw7ubw8tE4g0HSbO3A1xqad15KkiQ1DAZJUsNLSZKm4xLWs5vzeYVxBoMknasdPi/ipSRJUsNgkCQ1DAZJUsM5Bknnzono1+ygSeczHDFIkhoGgySpYTBIkhoGgySpYTBIkhoGgySp4XJVSf3sxqWrO3CJ6jhHDJKkhsEgSWoYDJKkRq9gSHJJkuNJTnYfL16j3aGuzckkh7p935bkt5L8fpInknygTy2SpGH0HTHcCTxcVfuBh7vtRpJLgLuB64BrgbvHAuTnq+qvAN8D/I0kb+tZjySpp77BcBA42j0/CrxzQpu3Aser6nRVvQAcB26uqpeq6tMAVfUN4HeBvT3rkST11He56uur6lmAqno2yRUT2uwBnh7bXun2vSrJRcAPAP+hZz2SZunM0tWduGx1hy9RHbduMCT5JPCGCYfumvIck35CXl34nOR84NeBX6yqp85Sx2HgMMDCwsKUp5Yknat1g6Gq3rLWsSRfTnJlN1q4EnhuQrMV4Pqx7b3AI2PbR4CTVfXv16njSNeWxcXFOltbSdLG9Z1jWAIOdc8PAZ+Y0OYYcFOSi7tJ55u6fSR5P3Ah8C961iFJGkjfYPgAcGOSk8CN3TZJFpN8GKCqTgPvAx7vHvdW1ekkexldjjoA/G6S30vynp71SJJ66jX5XFVfBW6YsH8ZeM/Y9v3A/avarDB5/kGSNEO+iN5uUE7JSJqeL4khSWo4YpA0vJ3yUty76N6FcY4YJEkNg0GS1DAYJEkNg0GS1DAYJEkNg0GS1DAYJEkNg0GS1PAGN0mba95udtulN7WNc8QgSWoYDJKkhsEgSWoYDJKkhsEgSWoYDJKkhsEgSWoYDJKkhsEgSWp457OkrbOd74L2judXOWKQJDUMBklSw2CQJDUMBklSw2
CQJDV6BUOSS5IcT3Ky+3jxGu0OdW1OJjk04fhSki/0qUWSNIy+I4Y7gYeraj/wcLfdSHIJcDdwHXAtcPd4gCT5QeBrPeuQJA2kbzAcBI52z48C75zQ5q3A8ao6XVUvAMeBmwGSfDtwB/D+nnVI2m2Wl197aFB9g+H1VfUsQPfxiglt9gBPj22vdPsA3gf8AvBSzzokSQNZ987nJJ8E3jDh0F1TnmPS7Y2V5K8DV1fVTyTZN0Udh4HDAAsLC1OeWtK2deYu6BMnZleDdztPtG4wVNVb1jqW5MtJrqyqZ5NcCTw3odkKcP3Y9l7gEeD7gWuSfKmr44okj1TV9UxQVUeAIwCLi4s1qY0kqb++l5KWgDOrjA4Bn5jQ5hhwU5KLu0nnm4BjVfXLVfWdVbUP+JvAH6wVCpKkrdM3GD4A3JjkJHBjt02SxSQfBqiq04zmEh7vHvd2+yRJ21CvV1etqq8CN0zYvwy8Z2z7fuD+s3ydLwFv7lOLJGkY3vksSWoYDJKkhsEgSWoYDJKkhm/tKWm2xm8y24qb3bypbV2OGCRJDYNBktQwGCRJDYNBktQwGCRJDYNBktQwGCRJDYNBktTwBjdJ28dm3ezmTW3nxBGDJKlhMEiSGgaDJKlhMEiSGgaDJKlhMEiSGi5XlbQ9DbF01WWqG+KIQZLUMBgkSQ2DQZLUMBgkSQ2DQZLUcFWSpO3vXFYouRKpt14jhiSXJDme5GT38eI12h3q2pxMcmhs/+uSHEnyB0l+P8kP9alHktRf30tJdwIPV9V+4OFuu5HkEuBu4DrgWuDusQC5C3iuqt4EHAD+e896JEk99Q2Gg8DR7vlR4J0T2rwVOF5Vp6vqBeA4cHN37B8D/wagqr5ZVV/pWY8kqae+cwyvr6pnAarq2SRXTGizB3h6bHsF2JPkom77fUmuB/4QuL2qvtyzJkk72Zk5hPG5BucVBrVuMCT5JPCGCYfumvIcmbCvunPvBT5TVXckuQP4eeAfrlHHYeAwwMLCwpSnlrRjGQabZt1gqKq3rHUsyZeTXNmNFq4EnpvQbAW4fmx7L/AI8FXgJeA3u/2/Adx2ljqOAEcAFhcXa726JUkb03eOYQk4s8roEPCJCW2OATclubibdL4JOFZVBfxXXguNG4Av9qxHktRT32D4AHBjkpPAjd02SRaTfBigqk4D7wMe7x73dvsAfgq4J8nnGV1C+sme9UiSesroD/f5sri4WMvLy7MuQ5LmSpITVbW4XjtfEkOS1DAYJEkNg0GS1DAYJEkNg0GS1JjLVUlJngf+uMeXuAzYba/LZJ93B/u8O2y0z99VVZev12gug6GvJMvTLNnaSezz7mCfd4fN7rOXkiRJDYNBktTYrcFwZNYFzIB93h3s8+6wqX3elXMMkqS17dYRgyRpDTs6GJLcnOTJJKeSTHo/6guSPNgdfyzJvq2vclhT9PmOJF9M8vkkDyf5rlnUOaT1+jzW7pYklWTuV7BM0+ckf7/7Xj+R5Ne2usYhTfFzvZDk00k+1/1sv30WdQ4pyf1JnkvyhTWOJ8kvdv8mn0/yvYOdvKp25AM4j9Hbhf5l4HXA/wIOrGrzT4Ff6Z7fCjw467q3oM9/F/i27vmP74Y+d+2+A/gd4FFgcdZ1b8H3eT/wOeDibvuKWde9yf09Avx49/wA8KVZ1z1Av/828L3AF9Y4/nbgtxm9S+b3AY8Nde6dPGK4FjhVVU9V1TeAB4CDq9ocBI52zz8O3JBk0luRzot1+1xVn66ql7rNRxm9o948m+b7DKP3BPk54OWtLG6TTNPnfwJ8qKpeAKiqSe+uOC+m6W8Bf6l7fiHwzBbWtymq6neA02dpchD41Rp5FLioeyfN3nZyMOwBnh7bXun2TWxTVa8ALwKXbkl1m2OaPo+7jdFfHPNs3T4n+R7gqqr6b1tZ2Caa5vv8JuBNST6T5NEkN29ZdcObpr/3AD+SZAV4CPjnW1PaTJ3r7/vU1n3P5zk26S//1UuwpmkzT6
buT5IfARaBv7OpFW2+s/Y5ybcA9wE/tlUFbYFpvs/nM7qcdD2jUeH/SPLmqvq/m1zbZpimv+8G/nNV/UKS7wc+2vX3m5tf3sxs2v9fO3nEsAJcNba9lz8/vHy1TZLzGQ1BzzZ02+6m6TNJ3gLcBbyjqr6+RbVtlvX6/B3Am4FHknyJ0bXYpTmfgJ72Z/sTVfX/quqPgCcZBcU8mqa/twEfA6iqzwLfyuj1hHayqX7fN2InB8PjwP4kb0zyOkaTy0ur2iwBh7rntwCfqm5WZ06t2+fussp/ZBQK83zd+Yyz9rmqXqyqy6pqX1XtYzSv8o6qmuf3hp3mZ/u/MFpoQJLLGF1aempLqxzONP39E+AGgCR/lVEwPL+lVW69JeBHu9VJ3we8WFXPDvGFd+ylpKp6JcntwDFGqxrur6onktwLLFfVEvARRkPOU4xGCrfOruL+puzzvwW+HfiNbp79T6rqHTMruqcp+7yjTNnnY8BNSb4I/Bnwr6rqq7OreuOm7O9PAv8pyU8wupzyY3P+Rx5Jfp3RpcDLurmTu4G/AFBVv8JoLuXtwCngJeAfDXbuOf+3kyQNbCdfSpIkbYDBIElqGAySpIbBIElqGAySpIbBIElqGAySpIbBIElq/H8v5ic7Orp6lwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "rho =.1\n", "objective = cp.Minimize( cp.sum(cp.multiply(P,C)) + rho*q + rho*r )\n", "prob = cp.Problem(objective, constr)\n", "result = prob.solve()\n", "a1 = np.sum(P.value, axis=1)\n", "b1 = np.sum(P.value.T, axis=1)\n", "plt.bar(x, a1, width = 1/n, color = \"b\")\n", "plt.bar(y, -b1, width = 1/m, color = \"r\")\n", "plt.bar(x, a, width = 1/n, color = \"b\", alpha=.2)\n", "plt.bar(y,-b, width = 1/m, color = \"r\", alpha=.2);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the impact of $\\rho$." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAACmdJREFUeJzt3b1y4kgYBVC0tY/AxKN0Y/P+TwDxxo5n3kGb7JSBskAC/dzuPieyy9hW7Wxz+7stQzcMwwEA0vy19wUAwHcEFACRBBQAkQQUAJEEFACRBBQAkQQUAJEEFACRBBQAkf6e8+Dj8Tj0fb/SpQDQgsvl8nsYhh/PHjcroPq+P5zP59evCoDmdV33OeVxswIKoEqXy+3nHx/7XAc3BBTQpvtQGvuasNqNmyQAiGSCAtrxaGqa8j2mqU2ZoACIZIIC6vXKxDTn55moVmWCAiCSgAIgkooPqMvStd7U36XuW5yAqtSUNWo9AckEVEXmbhxt/oBkzqAAiGSCKtxSdbtpiqJtee405RosokWYoACIZIIq0NqbRRtBIIEJCoBIJiiA0+n2c2/MGkFAFWKvM2B1H5ESbop4xGv2LULFB0AkAQVAJAEFQCRnUKESK3bnUewqcVFMZfG8xAQFQCQBBUAkFV+QkhoMjQWwNhMUAJFMUECukmqFqdQPkwmondWw/qw3YA0CiulOV+lzriBZgWjOoACIZIJKdbrrykwstKCGznsOLyr7kIDaQc1r0HkUsBQVHwCRBBQAkVR8G6m51huj7mOSFhfHGIvmhgkKgEgmqBXZGH6xMQTmElDA9uzenrOrU/EBkMkEtTAbw+f8bSIwhYAC1ne/Kzmdvj4+n7e9lhI1uqsTUAswNb2n674+Hob9rgOiCHEBBazkeufR6BPsahq5gUJAvcjUtA7TFPCHgJpBKG1LWBXo+h/t2nVdxbIqnqYEFPC6sUBiH5XdTCGgnjA1ZTBNBRFK5Sh8uhJQ/xNE5Rh7fhRcKxJK5SswrJoIqEnhc/0Ott69tkhTnkOF2B3B06apO/Kdg6wbZqzYrut+HQ6Hz/UuB4AG/ByG4cezB80KKADYiheLBSCSgAIgkoACIJKAAiCSgAIgkoACIJKAAiCSgAIgkoACIJKAAiCSgAIgkoACIJKAAiCSgAIgkoACIJKAAiDSrLd8Px6PQ9/3K10KAC24XC6/p7yj7qyA6vv+cD6fX78qAJrXdd3nlMep+ACINGuCAqjG5TLtcR8f614H
o0xQAEQyQQHtmDo1jX2PaWpTJigAIpmggHq9MjHN+XkmqlWZoACIJKAAiCSgAIjkDKpSU6p39TlVWvrcaervsqAWZ4ICIJIJqiJzN442f0AyAVW4pdoMYQWkEVBA+bY8d5pyDXZ5i3AGBUAkE1SB1t4s2ggCCUxQAEQyQQHl6brbz9Pe6dtr9i1CQBVirzNgdR9NOJ1uP08LvEap+ACIJKAAiKTiC5XwZx33rmv/YdjvOoA2CCigDPc3RpTEYe5LVHwARDJBBUms9cao+4C1maAAiGSCAnKVfO40xnnUZCYoACKZoHZW0rnTGOdRwBpMUABEMkEBOWo8c3rEi8o+JKB2MKnWO939j3ouowtU9wFLUfEBEElAARBJxbeRGu7Wm0vdxyStnTs94m+kbpigAIgkoACIpOJbUYu13hh1HzfUes+p+0xQAGQyQS3M1PTc/ebZRAV8R0Dx0Om09xVQBZXeexp9xQkBtQBT03ucTwHfEVDAOkxN77muL87n2681cgOFgHqRqWkdpingDwEFLGevqenRtFG7iqcpATWDqWlbpilom4B6QihlEFahnDNlqexuPwH1DaGUTVjtQBCVaezJrJDgai6gRsPn+g0CC3lzQKY9bwqxiYRQOx7twoPCqxtmrN7T6TScFziA/Lf75+2fAcC2/hn+XeTndF13GYbh6csAzAqorut+HQ6Hz3cuDIDm/RyG4cezB80KKADYilczByCSgAIgkoACIJKAAiCSgAIgkoACIJKAAiCSgAIgkoACIJKAAiCSgAIgkoACIJKAAiCSgAIgkoACIJKAAiDS33MefDweh77vV7oUAFpwuVx+T3lH3VkB1ff94Xw+v35VADSv67rPKY+bFVAAVeq628+HYZ/r4IaAAtp0H0pjXxNWuxFQQDsehdKU7xFWm3IXHwCRBBQAkVR8QL1eqfTm/DyV36pMUABEElAARFLxNcBNSDRl6Vpv6u+yuBYnoCo1tkatJ6AUAqoiczeOwgpI5gwKgEgmqMItVbebpijaludOYy6Xr48/Pva7joqYoACIZIIq0NqbRdMUkMAEBUAkExRQnoQzp0euz6MOB2dSLxJQhdhrPar7gL2o+ACIJKAAiKTiC5VYsav7gC0JKKAMibu2qfwR70tUfABEElAARFLxBSmpwXAeBaxNQAG5Stq1TeU8ajIVHwCRTFA7u39FlBKp+4A1mKAAiGSCAnLUeOb0iBeVfUhA7aCGWm+Mug9YiooPgEgCCoBIKr6N1FzrjVH3MUlr506P+BupGyYoACIJKAAiqfhW1GKtN0bdxw213nPqPgEFsIjT6evj83m/66iIig+ASCaohan1nrtvd1R+DVDpvafRV5wwQQEQyQS1AFPTe9xAAd9wpiWggJWo9dbTyB1+AupFpqZ1mKYKJ5S2V3FYCagZhNK2hBW0TUA9IZQyCKtQJqYsld3tJ6C+IZSyCaudCaVyFF7/NR1Qgqh8Y8+VgutNQqg+j57wQsOr+IBaLGROV/9A5+2S6/pOUpbzzvNrdLgJjjLcL+y9bhMfe4K5v56pT6QbB1k3zFiNXdf9OhwOn+tdDgAN+DkMw49nD5oVUACwFS91BEAkAQVAJAEFQCQBBUAkAQVAJAEFQCQBBUAkAQVAJAEFQCQBBUAkAQVAJAEFQCQBBUAkAQVAJAEFQCQBBUCkWW/5fjweh77vV7oUAFpwuVx+T3lH3VkB1ff94Xz/XvYAMEPXdZ9THjcroACq1HW3nw/DPtfBDQEFtOk+lMa+Jqx2I6CAdjwKpSnfI6w25S4+ACIJKAAiqfiAer1S6c35eSq/VZmgAIgkoACIpOID6rJ0rTf1d6n7FiegKjVljVpPQDIBVZG5G0ebPyCZMygAIpmgCrdU3W6aomhbnjtNuQaLaBEmKAAimaAKtPZm0UYQSGCCAiCSgAIgkoqvEHudAav7iJRwU8QjXrNvESYoACIJKAAiqfhCJTYY6j5gSwIKKEPirm0qu7uXqPgAiCSg
AIik4gtSUoOhsQDWZoICIJIJCshVUq0w1eXy9fHHx37XUQABtbMa1p+6D1iDgALY0un09bEd3UPOoACIZILaQQ213hh1H2+peXF85/o86nBwJnXHBAVAJAEFQCQVH7Cv1mq9R9yCfkNAbaTFNeg8CniHig+ASCaoFbU4NY0xTQFzCShge3ZvzzmPUvEBkMkEtTAbw+fu/xup/IDvCChgfXZu812/Zt/5fPu1Rio/AbUAa+89bqAAviOggHXYua2nkRsoBNSLrL11mKaAPwTUDEJpW8KqQBbJ9iqepgTUE9ZbBmEVygLJUtnbdwiob1hz2YTVziyQchQ+XQmo/1lzZRJWG7FAyldgWDURUNZWG6b8OwuxOxZHm+6rwDE7B1k3zFixXdf9OhwOn+tdDgAN+DkMw49nD5oVUACwFS8WC0AkAQVAJAEFQCQBBUAkAQVAJAEFQCQBBUAkAQVAJAEFQKT/AKiVI+cmOu74AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "rho_list = np.array([.05, .1, .2, 5])\n", "for k in range(len(rho_list)):\n", " rho = rho_list[k]\n", " objective = cp.Minimize( cp.sum(cp.multiply(P,C)) + rho*q + rho*r )\n", " prob = cp.Problem(objective, constr)\n", " result = prob.solve()\n", " a1 = np.sum(P.value, axis=1)\n", " b1 = np.sum(P.value.T, axis=1)\n", " ax = plt.subplot(len(rho_list),1,k+1)\n", " plt.bar(x, a1, width = 1/n, color = \"b\")\n", " plt.bar(y, b1, width = 1/m, color = \"r\")\n", " plt.bar(x, a, width = 1/n, color = \"b\", alpha=.2)\n", " plt.bar(y, b, width = 1/m, color = \"r\", alpha=.2)\n", " ax.set(xticks=[], yticks=[])\n", "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can compare several $\\phi$-divergence." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Entropic Regularization and Sinkhorn\n", "=========\n", "\n", "It is possible to regularized the initial problem using entropic regularization and consider\n", "$$\n", "\\umin{P \\in \\RR_+^{n \\times m}} \\dotp{P}{C}\n", " + \\rho KL( P 1_m|a )\n", " + \\rho KL( P^\\top 1_n|b )\n", " + \\varepsilon KL(P|ab^\\top).\n", "$$\n", "Here $\\varepsilon>0$ controls the strength of the regularization, increasing it results in faster algorithms but degrades the approximation." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The solution $P$ can be shown to be of the form \n", "$$\n", " P_{i,j} = e^{ \\frac{f_i+g_j-C_{i,j}}{\\epsilon} } a_i b_j\n", "$$\n", "where the dual variables $f \\in \\RR^n, g \\in \\RR^m$ satisfies the following scaled Sinkhorn iterations (written here in log domain):\n", "$$\n", " f_i = -\\varepsilon \\kappa \\log \\sum_{j} \\exp\\pa{ \\frac{g_j-C_{i,j}}{\\epsilon} } b_j\n", "$$\n", "$$\n", " g_j = -\\varepsilon\\kappa \\log \\sum_{i} \\exp\\pa{ \\frac{f_i-C_{i,j}}{\\epsilon} } a_j\n", "$$\n", "where we noted\n", "$$\n", " \\kappa \\triangleq \\frac{\\rho}{\\varepsilon + \\rho} .\n", "$$\n", "Sinkhorn's algoritm simply iterates these two fixed points. \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the log-sum-exp operator (which corresponds to soft $C$-transforms)." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def mina_u(H,epsilon): return -epsilon*np.log( np.sum(a[:,None] * np.exp(-H/epsilon),0) )\n", "def minb_u(H,epsilon): return -epsilon*np.log( np.sum(b[None,:] * np.exp(-H/epsilon),1) )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " They can be stabilized using the usual [log-sum-exp trick](https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations)." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def mina(H,epsilon): return mina_u(H-np.min(H,0),epsilon) + np.min(H,0);\n", "def minb(H,epsilon): return minb_u(H-np.min(H,1)[:,None],epsilon) + np.min(H,1);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Values of $\\varepsilon, \\rho$ and $\\kappa$." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "epsilon = .001\n", "rho = .2\n", "kappa = rho/(rho+epsilon)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implement Sinkhorn's iterates." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "f = np.zeros(n)\n", "niter = 1000\n", "for it in range(niter):\n", " g = kappa*mina(C-f[:,None],epsilon)\n", " f = kappa*minb(C-g[None,:],epsilon)\n", "# generate the coupling\n", "P = a[:,None] * np.exp((f[:,None]+g[None,:]-C)/epsilon) * b[None,:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the optimal plan." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(remap_plan(P))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the marginals." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFEFJREFUeJzt3X+snFd+1/H3Z23W0UKbrRxXAieuXcUpOFFVNpekSKUthF1lV2LdioR1wqopsrC2JSBRCmQFDWloJVIEEWgjbQ2JSLOCZElFuaKuLLXZ0lJlg+1mfzkrS3fd0Ny4UpPYuEoXb9bZL3/M4+3sZK7vc++d+2PmvF/SlZ95njN3zvGd+cx5zpznTKoKSVIb3rXZFZAkbRxDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktSQ7ZtdgVHXXXdd7d27d7OrIUlT5dSpU69X1a7lym250N+7dy8nT57c7GpI0lRJ8n/6lHN4R5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDWkV+gnuTPJmSQLSR4Yc3xHkme64y8k2Tt07HuTPJ/kdJIvJrlmctWXJK3EsqGfZBvwGPBB4ABwT5IDI8UOAxeq6kbgUeCR7r7bgU8BH6uqm4EfBr4+sdpLklakT0//NmChqs5W1VvA08DBkTIHgSe77WeBO5IE+ADwhar6PEBVvVFVb0+m6pKkleoT+ruBV4ZuL3b7xpapqsvARWAncBNQSY4n+d0k/2TcAyQ5kuRkkpOvvfbaStsgaRadOvUnP5qYPqGfMfuqZ5ntwA8Af7v790eT3PGOglVHq2ququZ27Vp26QhJ0ir1WXtnEbhh6Pb1wLklyix24/jXAue7/f+zql4HSHIMeB/wG2ust6RZtFSv/sr+W2/duLrMqD49/RPA/iT7krwbOATMj5SZB+7rtu8CnquqAo4D35vkPd2bwQ8BL02m6pKklVq2p19Vl5PczyDAtwFPVNXpJA8DJ6tqHngceCrJAoMe/qHuvheS/FsGbxwFHKuqX12ntkiSltFraeWqOgYcG9n34ND2JeDuJe77KQbTNiVJm8wrciWpIYa+JDXE0Jekhmy5r0uU1JiVXHw1XNbpm6tiT1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQ1yGQdLmWOt337okw6rY05ekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUkF6hn+TOJGeSLCR5YMzxHUme6Y6/kGRvt39vkv+X5HPdzycnW31J0kosu/ZOkm3AY8D7gUXgRJL5qnppqNhh4EJV3ZjkEPAI8JHu2Feq6vsmXG9J0ir06enfBixU1dmqegt4Gjg4UuYg8GS3/SxwR5JMrpqSpEnos8rmbuCVoduLwO1Llamqy0kuAju7Y/uSvAj8EfDPq+q311ZlSVNrrStr9vm9rrh5VX1Cf1yPvXqW+QNgT1W9keRW4FeS3FxVf/Qtd06OAEcA9uzZ06NKmpQrrxVfJ1Ib+gzvLAI3DN2+Hji3VJkk24FrgfNV9bWqegOgqk4BXwFuGn2AqjpaVXNVNbdr166Vt0KS1Eufnv4JYH+SfcCrwCHg3pEy88B9wPPAXcBzVVVJdjEI/7
eTfDewHzg7sdprVcadYXt2LLVh2dDvxujvB44D24Anqup0koeBk1U1DzwOPJVkATjP4I0B4AeBh5NcBt4GPlZV59ejIZKk5fX6usSqOgYcG9n34ND2JeDuMff7ZeCX11hHSdKEeEWuJDXE0JekhvQa3tFs6DtF2mmc0uyypy9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5Ia4pTNGbeWlWxdj0eaPfb0Jakh9vQlra/1+uKUPo/nKeo72NOXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0JakhztOfUZOeGu0Xq0izwZ6+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDeoV+kjuTnEmykOSBMcd3JHmmO/5Ckr0jx/ckeTPJT0+m2pKk1Vg29JNsAx4DPggcAO5JcmCk2GHgQlXdCDwKPDJy/FHg19ZeXUnSWvTp6d8GLFTV2ap6C3gaODhS5iDwZLf9LHBHkgAk+RHgLHB6MlWWJK1Wn9DfDbwydHux2ze2TFVdBi4CO5P8aeCfAj+79qpKktaqzxW5GbOvepb5WeDRqnqz6/iPf4DkCHAEYM+ePT2qJGlL2+hvy1qKl5K/Q5/QXwRuGLp9PXBuiTKLSbYD1wLngduBu5L8AvBe4BtJLlXVJ4bvXFVHgaMAc3Nzo28okqQJ6RP6J4D9SfYBrwKHgHtHyswD9wHPA3cBz1VVAX/lSoEkDwFvjga+JGnjLBv6VXU5yf3AcWAb8ERVnU7yMHCyquaBx4Gnkiww6OEfWs9Ka7yNOKP2O6el6dZrlc2qOgYcG9n34ND2JeDuZX7HQ6uonyRpgrwiV5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JakivtXe0tW3W0uUuVS5NH3v6ktQQQ1+SGuLwjqTJ2SpfkzjKL4L4Jnv6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhriMgxTaitd7e4V7tL06NXTT3JnkjNJFpI8MOb4jiTPdMdfSLK3239bks91P59P8qOTrb4kaSWWDf0k24DHgA8CB4B7khwYKXYYuFBVNwKPAo90+78EzFXV9wF3Ar+YxLMLSdokfXr6twELVXW2qt4CngYOjpQ5CDzZbT8L3JEkVfXVqrrc7b8GqElUWpK0On1CfzfwytDtxW7f2DJdyF8EdgIkuT3JaeCLwMeG3gS+KcmRJCeTnHzttddW3gpJUi99Qj9j9o322JcsU1UvVNXNwF8CPp7kmncUrDpaVXNVNbdr164eVZIkrUaf0F8Ebhi6fT1wbqky3Zj9tcD54QJV9WXgj4FbVltZSdLa9PlQ9QSwP8k+4FXgEHDvSJl54D7geeAu4Lmqqu4+r1TV5STfBXwP8PKkKi9pC9hK84f7aHyO8bKh3wX2/cBxYBvwRFWdTvIwcLKq5oHHgaeSLDDo4R/q7v4DwANJvg58A/jJqnp9PRoiSVper+mTVXUMODay78Gh7UvA3WPu9xTw1BrrKEmaEJdhkKSGeKGUpM01N/cn2ydPbl49GmHoS9ocw2E/bp9vAOvC4R1Jaog9/SkyDTPjGp8NJ215hr6kjTNuSKdPWYd6JsbhHUlqiKEvSQ0x9PUOc3MrOwuXND0c0xfg7DmpFfb0JW19nn5OjD199XLl9WaPXytmWG8phn7jfD1KbXF4R5IaYuhLUkMc3tGKOKNHwHSsCdJHg+uGGPoNchxfapehL2ny1qtn4anmmhn6U2Baz6Sv1LuRs2ZpKvhBriQ1xNDXqnmRpDR9DH1Jaohj+g2xVy7Jnr4kNcSevqTp5PTNVenV009yZ5IzSRaSPDDm+I4kz3THX0iyt9v//iSnknyx+/evTbb6krYUP93f8pbt6SfZBjwGvB9YBE4kma+ql4aKHQYuVNWNSQ4BjwAfAV4H/kZVnUtyC3Ac2D3pRmhz2eGSpkefnv5twEJVna2qt4CngYMjZQ4CT3
bbzwJ3JElVvVhV57r9p4FrkuyYRMUlSSvXZ0x/N/DK0O1F4PalylTV5SQXgZ0MevpX/E3gxar62uqrq5XyTFvSsD6hnzH7aiVlktzMYMjnA2MfIDkCHAHYs2dPjypJklajz/DOInDD0O3rgXNLlUmyHbgWON/dvh74b8CPVdVXxj1AVR2tqrmqmtu1a9fKWiBJ6q1PT/8EsD/JPuBV4BBw70iZeeA+4HngLuC5qqok7wV+Ffh4Vf3O5Ko9+6Z1kbVxGlyyfDbN0pNynEZWCFy2p19Vl4H7Gcy8+TLw6ao6neThJB/uij0O7EyyAPwUcGVa5/3AjcDPJPlc9/OdE2+FJKmXXhdnVdUx4NjIvgeHti8Bd4+5388BP7fGOmqKOH1Tm8InXm9ekStpbZwiNlUM/Rnl61DSOC64JkkNMfQlqSGGviQ1xNCXpIYY+pLUEGfvaN1cmUHktGltKOfsX5WhP0OcpilpOYb+FuPyJpoK9jCmlmP6ktQQe/qSrm7WTz9HzfiysPb0Jakhhr4kNcThHa07Z9BJW4ehPwOcSCEtwYtF3sHhHUlqiD19Sf14SjkT7OlLUkMMfUlqiMM7W0Br177AzF//Im1Z9vS1oebmHBqWNpM9/SllcGpdtXj6Oc4MnpIa+pJmn1cIfpPDO5LUEHv6kq7OscSZ0qunn+TOJGeSLCR5YMzxHUme6Y6/kGRvt39nks8keTPJJyZbdUnSSi0b+km2AY8BHwQOAPckOTBS7DBwoapuBB4FHun2XwJ+BvjpidVYkrRqfYZ3bgMWquosQJKngYPAS0NlDgIPddvPAp9Ikqr6Y+B/JblxclVu1yydZQ+3pWrz6iG1pk/o7wZeGbq9CNy+VJmqupzkIrATeL1PJZIcAY4A7Nmzp89dZoKz4gb83lxp4/QJ/YzZN9o361NmSVV1FDgKMDc3Z79P0vppfPpmnw9yF4Ebhm5fD5xbqkyS7cC1wPlJVFCSNDl9evongP1J9gGvAoeAe0fKzAP3Ac8DdwHPVTlSK00Vxxuvbkauzl029Lsx+vuB48A24ImqOp3kYeBkVc0DjwNPJVlg0MM/dOX+SV4Gvh14d5IfAT5QVS+NPo6WNksf4GpK+KSbWb0uzqqqY8CxkX0PDm1fAu5e4r5711A/NeBKvnhuKK0/l2GQpIa4DMMGc9h0aTMyZKpp0uBMHnv6ktQQe/pblJ+jSVoPhr6kAXsaTTD0tWW4Hs8m8EOm1ZnitUMc05ckaOYLnA19SWqIwzsbpO9ZdAMdjV6m+OxZ2tIMfall9jKaY+hrS3JpBml9GPpSa5yxc3UruUp3Ci8jN/S3AM+wJW0UQ38d2aFauynsSElbmqGvLc0LttaBp5ZNM/Q3ka89aYubwVU4DX2pFY43rq8pGYs09NeBr631kQz+dZhnlTy1FIb+hvN1J02pGRnqMfQ1da70+MFe/7I87dwcW3iox9CfEF9b2pI8tdQIQ3+D+NqTZsiVF/QUDvMY+mtkD39z+eHuGMPjX9p8W2zJWEN/Hdm73zjNj/Mb9JtjCtfpmb3Q34An/yne16vc9J34bT1zrPxUyt6/NsXQG8Cty7z6f3eJ/RvxnO0V+knuBP4dsA34j1X1r0aO7wB+CbgVeAP4SFW93B37OHAYeBv4B1V1fGK130B9g16TdZI19Iiy1EtrDTb4nWRcH6awVz/t3rdkZ2b9zwCWDf0k24DHgPcDi8CJJPNV9dJQscPAhaq6Mckh4BHgI0kOAIeAm4E/B/x6kpuq6u1JN2QSDPbZstzf89Yl+1tXscHDKJ6sTKdTvHNsd7ne/0bp09O/DVioqrMASZ4GDgLDoX8QeKjbfhb4RJJ0+5+uqq8Bv5dkoft9z0+m+u/0Zf78qu/7Hr46wZpoq1vLc0VaqU/x0bH7P8qnNrQefUJ/N/DK0O1F4PalylTV5SQXgZ3d/s+O3Hf3qmsrSTPmW98Mvrzuj9cn9Medz46edS5Vps99SXIEONLdfDPJmR71up
rrgNfX+DumjW2efa21F1pr82D4cLVt/q4+hfqE/iJww9Dt64FzS5RZTLIduBY43/O+VNVR4GifCveR5GRVNTVh0jbPvtbaC7Z5PbyrR5kTwP4k+5K8m8EHs/MjZeaB+7rtu4Dnqqq6/YeS7EiyD9gP/O/JVF2StFLL9vS7Mfr7geMMpmw+UVWnkzwMnKyqeeBx4Knug9rzDN4Y6Mp9msGHvpeBv7dVZ+5IUgt6zdOvqmPAsZF9Dw5tXwLuXuK+Pw/8/BrquBoTGyqaIrZ59rXWXrDNE5fyskVJakafMX1J0oyY6tBPcmeSM0kWkjww5viOJM90x19Isnfjazk5Pdr7U0leSvKFJL+RpNcUrq1suTYPlbsrSSWZ+pkefdqc5G91f+vTSf7zRtdx0no8t/ck+UySF7vn94c2o56TkuSJJH+Y5EtLHE+Sf9/9f3whyeSWC6iqqfxh8KHyV4DvBt4NfB44MFLmJ4FPdtuHgGc2u97r3N6/Cryn2/6JaW5v3zZ35b4N+C0GFwLObXa9N+DvvB94EfiO7vZ3bna9N6DNR4Gf6LYPAC9vdr3X2OYfBN4HfGmJ4x8Cfo3BtU7fD7wwqcee5p7+N5eHqKq3gCvLQww7CDzZbT8L3NEtDzGNlm1vVX2mqq6sJfFZBtdFTLM+f2OAfwn8AnBpIyu3Tvq0+e8Cj1XVBYCq+sMNruOk9WlzAd/ebV/LmOt9pklV/RaDmY5LOQj8Ug18Fnhvkj87icee5tAftzzE6BIP37I8BHBleYhp1Ke9ww4z6ClMs2XbnOQvAjdU1f/YyIqtoz5/55uAm5L8TpLPdqvgTrM+bX4I+GiSRQYzCf/+xlRt06z09d7bNK+nv5blIaZR77Yk+SgwB/zQutZo/V21zUneBTwK/PhGVWgD9Pk7b2cwxPPDDM7mfjvJLVX1f9e5buulT5vvAf5TVf2bJH+ZwXVBt1TVN9a/epti3bJrmnv6K1kegpHlIaZRryUtkvx14J8BH67B6qbTbLk2fxtwC/CbSV5mMPY5P+Uf5vZ9Xv/3qvp6Vf0ecIbBm8C06tPmw8CnAarqeeAaBmvUzKper/fVmObQX8vyENNo2fZ2Qx2/yCDwp32cF5Zpc1VdrKrrqmpvVe1l8DnGh6tqayxcvjp9nte/wuBDe5Jcx2C45+yG1nKy+rT594E7AJL8BQah/9qG1nJjzQM/1s3i+X7gYlX9wSR+8dQO79QaloeYRj3b+6+BPwP81+7z6t+vqg9vWqXXqGebZ0rPNh8HPpDkJQbfSPePq+qNzav12vRs8z8C/kOSf8hgmOPHp7gDR5L/wmB47rruc4p/AfwpgKr6JIPPLT4ELABfBf7OxB57iv/fJEkrNM3DO5KkFTL0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqyP8HpQMHIzl6wRMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Marginals of the computed plan P (opaque bars) versus the input measures a and b (transparent bars).\n", "a1 = np.sum(P, axis=1)\n", "b1 = np.sum(P, axis=0)\n", "plt.bar(x, a1, width = 1/n, color = \"b\")\n", "plt.bar(y, b1, width = 1/m, color = \"r\")\n", "plt.bar(x, a, width = 1/n, color = \"b\", alpha=.2)\n", "plt.bar(y, b, width = 1/m, color = \"r\", alpha=.2);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise (easy):** Experiment with different values of $\\varepsilon$ and $\\rho$. Study the rate of convergence of the method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise (hard):** Extend Sinkhorn to other types of divergence, starting with total variation (TV)." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.1" } }, "nbformat": 4, "nbformat_minor": 1 }