{ "metadata": { "name": "" }, "nbformat": 3, "nbformat_minor": 0, "worksheets": [ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Entropic Regularization of Optimal Transport", "============================================", "", "*Important:* Please read the [installation page](http://gpeyre.github.io/numerical-tours/installation_python/) for details about how to install the toolboxes.", "$\\newcommand{\\dotp}[2]{\\langle #1, #2 \\rangle}$", "$\\newcommand{\\enscond}[2]{\\lbrace #1, #2 \\rbrace}$", "$\\newcommand{\\pd}[2]{ \\frac{ \\partial #1}{\\partial #2} }$", "$\\newcommand{\\umin}[1]{\\underset{#1}{\\min}\\;}$", "$\\newcommand{\\umax}[1]{\\underset{#1}{\\max}\\;}$", "$\\newcommand{\\umin}[1]{\\underset{#1}{\\min}\\;}$", "$\\newcommand{\\uargmin}[1]{\\underset{#1}{argmin}\\;}$", "$\\newcommand{\\norm}[1]{\\|#1\\|}$", "$\\newcommand{\\abs}[1]{\\left|#1\\right|}$", "$\\newcommand{\\choice}[1]{ \\left\\{ \\begin{array}{l} #1 \\end{array} \\right. }$", "$\\newcommand{\\pa}[1]{\\left(#1\\right)}$", "$\\newcommand{\\diag}[1]{{diag}\\left( #1 \\right)}$", "$\\newcommand{\\qandq}{\\quad\\text{and}\\quad}$", "$\\newcommand{\\qwhereq}{\\quad\\text{where}\\quad}$", "$\\newcommand{\\qifq}{ \\quad \\text{if} \\quad }$", "$\\newcommand{\\qarrq}{ \\quad \\Longrightarrow \\quad }$", "$\\newcommand{\\ZZ}{\\mathbb{Z}}$", "$\\newcommand{\\CC}{\\mathbb{C}}$", "$\\newcommand{\\RR}{\\mathbb{R}}$", "$\\newcommand{\\EE}{\\mathbb{E}}$", "$\\newcommand{\\Zz}{\\mathcal{Z}}$", "$\\newcommand{\\Ww}{\\mathcal{W}}$", "$\\newcommand{\\Vv}{\\mathcal{V}}$", "$\\newcommand{\\Nn}{\\mathcal{N}}$", "$\\newcommand{\\NN}{\\mathcal{N}}$", "$\\newcommand{\\Hh}{\\mathcal{H}}$", "$\\newcommand{\\Bb}{\\mathcal{B}}$", "$\\newcommand{\\Ee}{\\mathcal{E}}$", "$\\newcommand{\\Cc}{\\mathcal{C}}$", "$\\newcommand{\\Gg}{\\mathcal{G}}$", "$\\newcommand{\\Ss}{\\mathcal{S}}$", "$\\newcommand{\\Pp}{\\mathcal{P}}$", "$\\newcommand{\\Ff}{\\mathcal{F}}$", "$\\newcommand{\\Xx}{\\mathcal{X}}$", 
"$\\newcommand{\\Mm}{\\mathcal{M}}$", "$\\newcommand{\\Ii}{\\mathcal{I}}$", "$\\newcommand{\\Dd}{\\mathcal{D}}$", "$\\newcommand{\\Ll}{\\mathcal{L}}$", "$\\newcommand{\\Tt}{\\mathcal{T}}$", "$\\newcommand{\\si}{\\sigma}$", "$\\newcommand{\\al}{\\alpha}$", "$\\newcommand{\\la}{\\lambda}$", "$\\newcommand{\\ga}{\\gamma}$", "$\\newcommand{\\Ga}{\\Gamma}$", "$\\newcommand{\\La}{\\Lambda}$", "$\\newcommand{\\si}{\\sigma}$", "$\\newcommand{\\Si}{\\Sigma}$", "$\\newcommand{\\be}{\\beta}$", "$\\newcommand{\\de}{\\delta}$", "$\\newcommand{\\De}{\\Delta}$", "$\\newcommand{\\phi}{\\varphi}$", "$\\newcommand{\\th}{\\theta}$", "$\\newcommand{\\om}{\\omega}$", "$\\newcommand{\\Om}{\\Omega}$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This numerical tours exposes the general methodology of regularizing the", "optimal transport (OT) linear program using entropy. This allows to", "derive fast computation algorithm based on iterative projections", "according to a Kulback-Leiber divergence.", "$$ \\DeclareMathOperator{\\KL}{KL}", "\\newcommand{\\KLdiv}[2]{\\KL\\pa{#1 | #2}}", "\\newcommand{\\KLproj}{P^{\\tiny\\KL}}", "\\def\\ones{\\mathbb{I}} $$" ] }, { "cell_type": "code", "collapsed": false, "input": [ "from __future__ import division", "import nt_toolbox as nt", "from nt_solutions import optimaltransp_5_entropic as solutions", "%matplotlib inline", "%load_ext autoreload", "%autoreload 2" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Entropic Regularization of Optimal Transport", "--------------------------------------------", "We consider two input histograms $p,q \\in \\Si_N$, where we denote the simplex in $\\RR^N$", "$$ \\Si_{N} = \\enscond{ p \\in (\\RR^+)^N }{ \\sum_i p_i = 1 }. $$", "We consider the following discrete regularized transport", "$$ W_\\ga(p,q) = \\umin{\\pi \\in \\Pi(p,q)} \\dotp{C}{\\pi} - \\ga E(\\pi). 
$$", "where the polytope of couplings is defined as", "$$ \\Pi(p,q) = \\enscond{\\pi \\in (\\RR^+)^{N \\times N}}{ \\pi \\ones = p, \\pi^* \\ones = q }, $$", "and for $f \\in (\\RR^+)^{P}$ for some $P > 0$, we define its entropy as", "$$ E(f) = - \\sum_{i=1}^P f_i ( \\log(f_i) - 1). $$", "", "", "When $\\ga=0$ one recovers the classical (discrete) optimal transport.", "We refer to the monograph [Villani](#biblio) for more details about OT.", "The idea of regularizing transport to allow for faster computation was", "introduced in [Cuturi](#biblio).", "", "", "", "Here the matrix $C \\in (\\RR^+)^{N \\times N} $ defines the ground cost, i.e.", "$C_{i,j}$ is the cost of moving mass from a bin indexed by $i$ to a bin indexed by $j$.", "", "", "The regularized transportation problem can be re-written as a projection", "$$ W_\\ga(p,q) = \\ga \\umin{\\pi \\in \\Pi(p,q)} \\KLdiv{\\pi}{\\bar \\pi}", "\t\\qwhereq", "\t\\bar\\pi_{i,j} = e^{ -\\frac{C_{i,j}}{\\ga} } $$", "of $\\bar\\pi$ according to the Kullback-Leibler divergence. The Kullback-Leibler divergence between $f, \\bar f \\in (\\RR^+)^P$ is", "$$ \\KLdiv{f}{\\bar f} = \\sum_{i=1}^P f_{i} \\pa{ \\log\\pa{ \\frac{f_i}{\\bar f_i} } - 1}. $$", "With a slight abuse of notation, we extend these definitions to matrices $\\pi \\in \\RR^{N \\times N}$ (and also to higher $d$-dimensional tensor arrays), so that $P=N^2$ (or more generally $P=N^d$), by replacing the sum over the entries $f_i$ by the sum over the entries $\\pi_{i,j}$ with $i,j=1,\\ldots,N$.", "Indeed, since $\\log(\\bar\\pi_{i,j}) = -C_{i,j}/\\ga$, expanding the logarithm gives $\\ga \\KLdiv{\\pi}{\\bar\\pi} = \\dotp{C}{\\pi} - \\ga E(\\pi)$.", "", "", "Given a convex set $\\Cc \\subset \\RR^P$, the projection according to the Kullback-Leibler divergence is defined as", "$$ \\KLproj_\\Cc(\\bar f) = \\uargmin{ f \\in \\Cc } \\KLdiv{f}{\\bar f}. $$", "", "Iterative Bregman Projection Algorithm", "--------------------------------------", "Given affine constraint sets $ (\\Cc_1,\\ldots,\\Cc_K) $, we aim to compute", "$$ \\KLproj_\\Cc(\\bar \\pi) \\qwhereq \\Cc = \\Cc_1 \\cap \\ldots \\cap \\Cc_K. 
$$", "", "", "This can be achieved, starting from $\\pi_0=\\bar\\pi$, by iterating", "$$ \\forall \\ell \\geq 0, \\quad \\pi_{\\ell+1} = \\KLproj_{\\Cc_\\ell}(\\pi_\\ell), $$", "where the index of the constraints should be understood modulo $K$,", "i.e. we set $ \\Cc_{\\ell+K}=\\Cc_\\ell $.", "", "", "One can indeed show that $\\pi_\\ell \\rightarrow \\KLproj_\\Cc(\\bar \\pi)$.", "We refer to [BauschkeLewis](#biblio) for more details about this", "algorithm and its extension to the projection onto an intersection of", "general convex sets (Dykstra's algorithm).", "", "Iterative Projection for Regularized Transport", "----------------------------------------------", "We can re-cast the regularized optimal transport problem within this", "framework by introducing", "$$ \\Cc_1 = \\enscond{\\pi \\in (\\RR^+)^{N \\times N} }{\\pi \\ones = p}", "\\qandq", " \\Cc_2 = \\enscond{\\pi \\in (\\RR^+)^{N \\times N} }{\\pi^* \\ones = q}$$", "so that $\\Pi(p,q) = \\Cc_1 \\cap \\Cc_2$.", "", "", "The KL projection onto $\\Cc_1$ is easily computed by divisive", "normalization of the rows. Indeed, denoting", "$ \\pi = \\KLproj_{\\Cc_1}(\\bar \\pi) $, one has", "$$ \\forall (i,j), \\quad", " \\pi_{i,j} = \\frac{ p_i \\bar\\pi_{i,j} }{ \\sum_{s} \\bar\\pi_{i,s} } $$", "and similarly for $\\KLproj_{\\Cc_2}(\\bar \\pi) $ by replacing rows by", "columns.", "", "", "Size $N$ of the histograms." ] }, { "cell_type": "code", "collapsed": false, "input": [ "import numpy as np", "N = 200" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define $\\KLproj_{\\Cc_1}$." ] }, { "cell_type": "code", "collapsed": false, "input": [ "# Rescale row i of pi so that it sums to p[i] (p is a 1-D array of length N).", "ProjC1 = lambda pi, p: pi * (p / np.maximum(pi.sum(axis=1), 1e-10))[:, np.newaxis]" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define $\\KLproj_{\\Cc_2}$." 
] }, { "cell_type": "code", "collapsed": false, "input": [ "# Rescale column j of pi so that it sums to q[j] (q is a 1-D array of length N).", "ProjC2 = lambda pi, q: pi * (q / np.maximum(pi.sum(axis=0), 1e-10))[np.newaxis, :]" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use here a 1-D squared Euclidean distance as ground cost." ] }, { "cell_type": "code", "collapsed": false, "input": [ "x = np.arange(N) / float(N); y = x", "X, Y = np.meshgrid(x, y, indexing='ij')  # X[i, j] = x[i], Y[i, j] = y[j]", "C = np.abs(X - Y)**2" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the histograms $p$ and $q$." ] }, { "cell_type": "code", "collapsed": false, "input": [ "Gaussian = lambda x0, sigma: np.exp(-(x - x0)**2 / (2 * sigma**2))", "normalize = lambda p: p / np.sum(p)", "x0 = .2; y0 = .8; sigma = .07", "p = Gaussian(x0, sigma)", "q = Gaussian(y0, sigma)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add some minimal mass and normalize." ] }, { "cell_type": "code", "collapsed": false, "input": [ "vmin = .02", "p = normalize(p + np.max(p) * vmin)", "q = normalize(q + np.max(q) * vmin)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display them." ] }, { "cell_type": "code", "collapsed": false, "input": [ "import matplotlib.pyplot as plt", "plt.subplot(2, 1, 1)", "plt.bar(x, p, width=1.0 / N, color='k'); plt.axis('tight')", "plt.subplot(2, 1, 2)", "plt.bar(y, q, width=1.0 / N, color='k'); plt.axis('tight')" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Exercise 1__", "", "Perform the iterations, and display the decay of the errors", "$$ \\norm{\\pi_\\ell \\ones - p}", " \\qandq", " \\norm{\\pi_\\ell^* \\ones - q} $$", "in log scale." ] }, { "cell_type": "code", "collapsed": false, "input": [ "solutions.exo1()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "## Insert your code here." 
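, "",
"# A minimal sketch of one possible answer to Exercise 1, assuming NumPy and",
"# matplotlib are available; it is not the reference solutions.exo1().",
"# It rebuilds N, C, p and q locally (same values as in the text) so that this",
"# cell runs on its own. gamma and niter are illustrative assumptions.",
"import numpy as np",
"import matplotlib.pyplot as plt",
"N = 200",
"x = np.arange(N) / float(N)",
"C = (x[:, None] - x[None, :])**2",
"gauss = lambda m, s: np.exp(-(x - m)**2 / (2 * s**2))",
"p = gauss(.2, .07); p = p + .02 * p.max(); p = p / p.sum()",
"q = gauss(.8, .07); q = q + .02 * q.max(); q = q / q.sum()",
"gamma, niter = .01, 300  # assumed values, not taken from the text",
"pi = np.exp(-C / gamma)  # pi_0 = bar pi",
"err_p, err_q = [], []",
"for it in range(niter):",
"    # KL projection onto C_1: rescale the rows so that they sum to p",
"    pi = pi * (p / np.maximum(pi.sum(axis=1), 1e-10))[:, None]",
"    err_q.append(np.linalg.norm(pi.sum(axis=0) - q))",
"    # KL projection onto C_2: rescale the columns so that they sum to q",
"    pi = pi * (q / np.maximum(pi.sum(axis=0), 1e-10))[None, :]",
"    err_p.append(np.linalg.norm(pi.sum(axis=1) - p))",
"# Decay of both marginal errors, in log scale",
"plt.semilogy(err_p, label='row marginal error')",
"plt.semilogy(err_q, label='column marginal error')",
"plt.legend()"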
], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the optimal $\\pi$." ] }, { "cell_type": "code", "collapsed": false, "input": [ "import matplotlib.pyplot as plt", "plt.imshow(pi, cmap='gray')  # matplotlib's imshow used in place of imageplot" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For visualization purposes, normalize each column by its maximum so that", "the optimal map is more clearly visible." ] }, { "cell_type": "code", "collapsed": false, "input": [ "# Divide each column of pi by its maximum entry.", "normalizeMax = lambda pi: pi / pi.max(axis=0, keepdims=True)", "", "plt.imshow(normalizeMax(pi), cmap='gray')" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Exercise 2__", "", "Display the transport map for several values of $\\gamma$." ] }, { "cell_type": "code", "collapsed": false, "input": [ "solutions.exo2()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "## Insert your code here." ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Bibliography", "------------", "", "", "", "* [Villani] Villani, C. (2009). Optimal transport: old and new, volume 338. Springer Verlag.", "* [Cuturi] Cuturi, M. (2013). Sinkhorn distances: Lightspeed computation of optimal transport. In Burges, C. J. C., Bottou, L., Ghahramani, Z., and Weinberger, K. Q., editors, Proc. NIPS, pages 2292-2300.", "* [AguehCarlier] Agueh, M. and Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM J. on Mathematical Analysis, 43(2):904-924.", "* [CuturiDoucet] Cuturi, M. and Doucet, A. (2014). Fast computation of Wasserstein barycenters. In Proc. ICML.", "* [BauschkeLewis] Bauschke, H. H. and Lewis, A. S. (2000). Dykstra's algorithm with Bregman projections: a convergence proof. Optimization, 48(4):409-427." ] } ] } ] }