{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Advanced Topics on Sinkhorn Algorithm\n", "=====================================\n", "\n", "*Important:* Please read the [installation page](http://gpeyre.github.io/numerical-tours/installation_matlab/) for details about how to install the toolboxes.\n", "$\\newcommand{\\dotp}[2]{\\langle #1, #2 \\rangle}$\n", "$\\newcommand{\\enscond}[2]{\\lbrace #1, #2 \\rbrace}$\n", "$\\newcommand{\\pd}[2]{ \\frac{ \\partial #1}{\\partial #2} }$\n", "$\\newcommand{\\umin}[1]{\\underset{#1}{\\min}\\;}$\n", "$\\newcommand{\\umax}[1]{\\underset{#1}{\\max}\\;}$\n", "$\\newcommand{\\umin}[1]{\\underset{#1}{\\min}\\;}$\n", "$\\newcommand{\\uargmin}[1]{\\underset{#1}{argmin}\\;}$\n", "$\\newcommand{\\norm}[1]{\\|#1\\|}$\n", "$\\newcommand{\\abs}[1]{\\left|#1\\right|}$\n", "$\\newcommand{\\choice}[1]{ \\left\\{ \\begin{array}{l} #1 \\end{array} \\right. }$\n", "$\\newcommand{\\pa}[1]{\\left(#1\\right)}$\n", "$\\newcommand{\\diag}[1]{{diag}\\left( #1 \\right)}$\n", "$\\newcommand{\\qandq}{\\quad\\text{and}\\quad}$\n", "$\\newcommand{\\qwhereq}{\\quad\\text{where}\\quad}$\n", "$\\newcommand{\\qifq}{ \\quad \\text{if} \\quad }$\n", "$\\newcommand{\\qarrq}{ \\quad \\Longrightarrow \\quad }$\n", "$\\newcommand{\\ZZ}{\\mathbb{Z}}$\n", "$\\newcommand{\\CC}{\\mathbb{C}}$\n", "$\\newcommand{\\RR}{\\mathbb{R}}$\n", "$\\newcommand{\\EE}{\\mathbb{E}}$\n", "$\\newcommand{\\Zz}{\\mathcal{Z}}$\n", "$\\newcommand{\\Ww}{\\mathcal{W}}$\n", "$\\newcommand{\\Vv}{\\mathcal{V}}$\n", "$\\newcommand{\\Nn}{\\mathcal{N}}$\n", "$\\newcommand{\\NN}{\\mathcal{N}}$\n", "$\\newcommand{\\Hh}{\\mathcal{H}}$\n", "$\\newcommand{\\Bb}{\\mathcal{B}}$\n", "$\\newcommand{\\Ee}{\\mathcal{E}}$\n", "$\\newcommand{\\Cc}{\\mathcal{C}}$\n", "$\\newcommand{\\Gg}{\\mathcal{G}}$\n", "$\\newcommand{\\Ss}{\\mathcal{S}}$\n", "$\\newcommand{\\Pp}{\\mathcal{P}}$\n", "$\\newcommand{\\Ff}{\\mathcal{F}}$\n", "$\\newcommand{\\Xx}{\\mathcal{X}}$\n", 
"$\\newcommand{\\Mm}{\\mathcal{M}}$\n", "$\\newcommand{\\Ii}{\\mathcal{I}}$\n", "$\\newcommand{\\Dd}{\\mathcal{D}}$\n", "$\\newcommand{\\Ll}{\\mathcal{L}}$\n", "$\\newcommand{\\Tt}{\\mathcal{T}}$\n", "$\\newcommand{\\si}{\\sigma}$\n", "$\\newcommand{\\al}{\\alpha}$\n", "$\\newcommand{\\la}{\\lambda}$\n", "$\\newcommand{\\ga}{\\gamma}$\n", "$\\newcommand{\\Ga}{\\Gamma}$\n", "$\\newcommand{\\La}{\\Lambda}$\n", "$\\newcommand{\\si}{\\sigma}$\n", "$\\newcommand{\\Si}{\\Sigma}$\n", "$\\newcommand{\\be}{\\beta}$\n", "$\\newcommand{\\de}{\\delta}$\n", "$\\newcommand{\\De}{\\Delta}$\n", "$\\newcommand{\\phi}{\\varphi}$\n", "$\\newcommand{\\th}{\\theta}$\n", "$\\newcommand{\\om}{\\omega}$\n", "$\\newcommand{\\Om}{\\Omega}$\n", "$\\newcommand{\\eqdef}{\\equiv}$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This numerical tour explore several extensions of the basic Sinkhorn\n", "method." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "addpath('solutions/optimaltransp_6_entropic_adv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Log-domain Sinkhorn \n", "--------------------\n", "For simplicity, we consider uniform distributions on point clouds, so\n", "that the associated histograms are $ (a,b) \\in \\RR^n \\times \\RR^m$\n", "being constant $1/n$ and $1/m$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n = 100;\n", "m = 200; \n", "d = 2; % Dimension of the clouds.\n", "a = ones(n,1)/n; \n", "b = ones(1,m)/m;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Point clouds $x$ and $y$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = rand(2,n)-.5;\n", "theta = 2*pi*rand(1,m);\n", "r = .8 + .2*rand(1,m);\n", "y = [cos(theta).*r; sin(theta).*r];" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display of the two clouds." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plotp = @(x,col)plot(x(1,:)', x(2,:)', 'o', 'MarkerSize', 9, 'MarkerEdgeColor', 'k', 'MarkerFaceColor', col, 'LineWidth', 2);\n", "clf; hold on;\n", "plotp(x, 'b');\n", "plotp(y, 'r');\n", "axis('off'); axis('equal');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cost matrix $C_{i,j} = \\norm{x_i-y_j}^2$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "% Squared Euclidean distances between the columns of x and those of y.\n", "distmat = @(x,y)sum(x.^2,1)' + sum(y.^2,1) - 2*x.'*y;\n", "C = distmat(x,y);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Sinkhorn algorithm is originally implemented using matrix-vector\n", "multiplications, which are numerically unstable for small $\\epsilon$.\n", "Here we consider a log-domain implementation, which operates by\n", "iteratively updating the so-called Kantorovich dual potentials $ (f,g) \\in \\RR^n \\times \\RR^m $.\n", "\n", "\n", "The updates are obtained by a regularized c-transform, which reads\n", "$$ f_i \\leftarrow {\\min}_\\epsilon^b( C_{i,\\cdot} - g ) $$\n", "$$ g_j \\leftarrow {\\min}_\\epsilon^a( C_{\\cdot,j} - f ), $$\n", "where the regularized minimum operator reads\n", "$$ {\\min}_\\epsilon^a(h) \\eqdef -\\epsilon \\log \\sum_i a_i e^{-h_i/\\epsilon}. 
$$" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "% mina reduces over rows (index i, weights a) and returns a 1-by-m row.\n", "mina = @(H,epsilon)-epsilon*log( sum(a .* exp(-H/epsilon),1) );\n", "% minb reduces over columns (index j, weights b) and returns an n-by-1 column.\n", "minb = @(H,epsilon)-epsilon*log( sum(b .* exp(-H/epsilon),2) );" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The regularized min operator defined this way is numerically unstable, but it can\n", "be stabilized using the celebrated log-sum-exp trick, which relies on the\n", "fact that for any constant $c \\in \\RR$, one has\n", "$$ {\\min}_\\epsilon^a(h+c) = {\\min}_\\epsilon^a(h) + c, $$\n", "and stabilization is achieved by evaluating ${\\min}_\\epsilon^a(h-c)+c$ with $c=\\min(h)$." 
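, "\n", "\n", "As a short worked check (a sketch using only the definition of ${\\min}_\\epsilon^a$ given above), the shift identity follows by factoring the constant out of the sum:\n", "$$ {\\min}_\\epsilon^a(h+c) = -\\epsilon \\log \\sum_i a_i e^{-(h_i+c)/\\epsilon} = -\\epsilon \\log\\pa{ e^{-c/\\epsilon} \\sum_i a_i e^{-h_i/\\epsilon} } = c + {\\min}_\\epsilon^a(h). $$\n", "Equivalently, ${\\min}_\\epsilon^a(h) = {\\min}_\\epsilon^a(h-c) + c$. With $c=\\min(h)$, every exponent $-(h_i-c)/\\epsilon$ is non-positive and at least one of them is zero, so, assuming $a$ is a probability vector with positive entries (as is the case here),\n", "$$ 0 < \\min_i a_i \\leq \\sum_i a_i e^{-(h_i-c)/\\epsilon} \\leq 1, $$\n", "and the logarithm is applied to a number that stays bounded away from zero, however small $\\epsilon$ is."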
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "% Stabilized operators: subtract the column-wise (resp. row-wise) minimum\n", "% before exponentiating, then add it back (log-sum-exp trick).\n", "minaa = @(H,epsilon)mina(H-min(H,[],1),epsilon) + min(H,[],1);\n", "minbb = @(H,epsilon)minb(H-min(H,[],2),epsilon) + min(H,[],2);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Exercise 1__\n", "\n", "Implement Sinkhorn in log domain." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "exo1()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%% Insert your code here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Exercise 2__\n", "\n", "Study the impact of $\\epsilon$ on the convergence rate of the algorithm." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "exo2()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%% Insert your code here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Wasserstein Flow for Matching\n", "------------------------------\n", "We aim at performing a \"Lagrangian\" gradient descent (also called a Wasserstein flow) of the Wasserstein\n", "distance, in order to perform a non-parametric fitting. This corresponds\n", "to minimizing the energy function\n", "$$ \\Ee(z) \\eqdef W_\\epsilon\\pa{ \\frac{1}{n}\\sum_i \\de_{z_i}, \\frac{1}{m}\\sum_i \\de_{y_i} }. $$\n", "\n", "\n", "Here we have denoted the Sinkhorn score as\n", "$$ W_\\epsilon(\\al,\\be) \\eqdef \\dotp{P}{C} + \\epsilon \\text{KL}(P|ab^\\top)$$\n", "where $P$ is the optimal coupling, and $\\al=\\frac{1}{n}\\sum_i \\de_{x_i}$ and\n", "$\\be=\\frac{1}{m}\\sum_i \\de_{y_i}$ are the measures (beware that $C$\n", "depends on the positions of the points)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "z = x; % initialization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The gradient of this energy reads (up to a constant factor $2$, which can be absorbed into the step size)\n", "$$ ( \\nabla \\Ee(z) )_i = \\sum_j P_{i,j}(z_i-y_j) = a_i z_i - \\sum_j P_{i,j} y_j, $$\n", "where $P$ is the optimal coupling. It is better to consider a renormalized gradient, which corresponds\n", "to using the inner product associated to the measure $a$ on the\n", "deformation field, in which case\n", "$$ ( \\bar\\nabla \\Ee(z) )_i = z_i - \\bar y_i \\qwhereq \\bar y_i \\eqdef \\frac{\\sum_j P_{i,j} y_j}{a_i}. $$\n", "Here $\\bar y_i$ is often called the \"barycentric projection\" associated\n", "to the coupling matrix $P$.\n", "\n", "\n", "First run Sinkhorn. Beware that the cost matrix must be recomputed at each step of the flow, since it depends on $z$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "epsilon = .01;\n", "niter = 300;\n", "C = distmat(z,y);\n", "f = zeros(n,1); % initialize the dual potential before the first update\n", "for it=1:niter\n", " g = mina(C-f,epsilon);\n", " f = minb(C-g,epsilon);\n", "end\n", "% Optimal coupling recovered from the dual potentials.\n", "P = a .* exp((f+g-C)/epsilon) .* b;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compute the gradient." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "G = z - ( y*P' ) ./ a';" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the gradient field." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "clf; hold on;\n", "plotp(z, 'b');\n", "plotp(y, 'r');\n", "for i=1:n\n", " plot([z(1,i), z(1,i)-G(1,i)], [z(2,i), z(2,i)-G(2,i)], 'k');\n", "end\n", "axis('off'); axis('equal');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set the descent step size." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tau = .1;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Update the point cloud." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "z = z - tau * G;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Exercise 3__\n", "\n", "Implement the gradient flow." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "exo3()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%% Insert your code here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Exercise 4__\n", "\n", "Show the evolution of the fit as $\\epsilon$ increases. What do you observe?\n", "Then replace the Sinkhorn score $W_\\epsilon(\\al,\\be)$ by the Sinkhorn divergence\n", "$W_\\epsilon(\\al,\\be)-W_\\epsilon(\\al,\\al)/2-W_\\epsilon(\\be,\\be)/2$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "exo4()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%% Insert your code here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generative Model Fitting\n", "------------------------\n", "The Wasserstein flow is a non-parametric idealization which does not\n", "correspond to any practical application. We consider here a simple toy\n", "example of density fitting, where the goal is to find a parameter $\\theta$\n", "to fit a deformed point cloud of the form $ (g_\\theta(z_i))_i $ using a\n", "Sinkhorn cost. 
This is often called a generative model in the machine\n", 
"learning literature, and corresponds to the problem of shape\n", "registration in imaging.\n", "\n", "\n", "The matching is achieved by solving\n", "$$ \\min_\\th \\Ff(\\th) \\eqdef \\Ee(G_\\th(z)) = W_\\epsilon\\pa{ \\frac{1}{n}\\sum_i \\de_{g_\\th(z_i)}, \\frac{1}{m}\\sum_i \\de_{y_i} }, $$\n", "where the function $G_\\th(z)=( g_\\th(z_i) )_i$ operates independently on\n", "each point.\n", "\n", "\n", "The gradient reads\n", "$$ \\nabla \\Ff(\\th) = \\sum_i \\partial g_\\th(z_i)^*[ \\nabla \\Ee(G_\\th(z))_i ], $$\n", "where $\\partial g_\\th(z_i)^*$ is the adjoint of the Jacobian of\n", "$\\th \\mapsto g_\\th(z_i)$.\n", "\n", "\n", "We consider here a simple model of affine transformation, where\n", "$\\th=(A,h) \\in \\RR^{d \\times d} \\times \\RR^d $\n", "and $g_\\th(z_i)=Az_i+h$.\n", "\n", "\n", "Denoting $ v_i = \\nabla \\Ee(G_\\th(z))_i $ the gradient of the Sinkhorn\n", "loss (which is computed as in the previous section), the gradient with\n", "respect to the parameter reads\n", "$$ \\nabla_A \\Ff(\\th) = \\sum_i v_i z_i^\\top\n", " \\qandq \\nabla_h \\Ff(\\th) = \\sum_i v_i. $$\n", "\n", "\n", "Generate the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "z = randn(2,n)*.2;\n", "y = randn(2,m)*.2; y(1,:) = y(1,:)*.05 + 1;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initialize the parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "A = eye(2); h = zeros(2,1);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the clouds." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "clf; hold on;\n", "plotp(A*z+h, 'b'); plotp(y, 'r');\n", "axis('off'); axis('equal');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compute the gradient with respect to parameters." 
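, "\n", "\n", "As a short justification of the formulas for $\\nabla_A \\Ff$ and $\\nabla_h \\Ff$ stated above (a sketch, writing $\\dotp{U}{V}=\\sum_{k,l} U_{k,l}V_{k,l}$ for matrices, which is a notational assumption made here), perturb the parameters by $(\\delta A,\\delta h)$. Each point $g_\\th(z_i)=Az_i+h$ then moves by $\\delta A\\, z_i + \\delta h$, so that at first order\n", "$$ \\delta \\Ff = \\sum_i \\dotp{v_i}{\\delta A\\, z_i + \\delta h} = \\dotp{\\sum_i v_i z_i^\\top}{\\delta A} + \\dotp{\\sum_i v_i}{\\delta h}, $$\n", "using $\\dotp{v_i}{\\delta A\\, z_i} = \\dotp{v_i z_i^\\top}{\\delta A}$. Reading off the two terms gives $\\nabla_A \\Ff(\\th)$ and $\\nabla_h \\Ff(\\th)$. In matrix form, with $v$ and $z$ stored as $d \\times n$ arrays, these are $v z^\\top$ and the sum of the columns of $v$."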
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = A*z+h; % deformed cloud g_theta(z)\n", "C = distmat(x,y);\n", "f = zeros(n,1);\n", "for it=1:niter\n", "\tg = mina(C-f,epsilon);\n", "\tf = minb(C-g,epsilon);\n", "end\n", "P = a .* exp((f+g-C)/epsilon) .* b;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Gradient with respect to positions, evaluated at the deformed points $x_i = g_\\th(z_i)$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = a' .* x - ( y*P' );" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Gradient with respect to parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "nabla_A = v*z';\n", "nabla_h = sum(v,2);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Exercise 5__\n", "\n", "Implement the gradient descent." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "exo5()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%% Insert your code here." ] } ], "metadata": { "kernelspec": { "display_name": "Matlab", "language": "matlab", "name": "matlab_kernel" }, "language_info": { "file_extension": ".m", "help_links": [ { "text": "MetaKernel Magics", "url": "https://github.com/calysto/metakernel/blob/master/metakernel/magics/README.md" } ], "mimetype": "text/x-matlab", "name": "matlab" } }, "nbformat": 4, "nbformat_minor": 1 }