{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Introduction to Optimal Transport with Python\n\nThis example gives an introduction to using Optimal Transport in Python.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Remi Flamary, Nicolas Courty, Aurelie Boisbunon\n#\n# License: MIT License\n# sphinx_gallery_thumbnail_number = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## POT Python Optimal Transport Toolbox\n\n### POT installation\n\n* Install with pip:\n\n      pip install pot\n* Install with conda:\n\n      conda install -c conda-forge pot\n\n### Import the toolbox\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np  # always need it\nimport pylab as pl  # do the plots\n\nimport ot  # ot\n\nimport time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Getting help\n\nOnline documentation: [https://pythonot.github.io/all.html](https://pythonot.github.io/all.html)\n\nOr inline help:\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "help(ot.dist)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First OT Problem\n\nWe will solve the Bakery/Caf\u00e9s problem of transporting croissants from a\nnumber of Bakeries to Caf\u00e9s in a City (in this case Manhattan). We did a\nquick Google Maps search in Manhattan for bakeries and Caf\u00e9s:\n\n(Figure: bakery-cafe-manhattan)\n\nWe extracted from this search their positions and generated fictional\nproduction and sales numbers (both sum to the same value).\n\nWe have access to the position of Bakeries ``bakery_pos`` and their\nrespective production ``bakery_prod`` which describe the source\ndistribution. 
The Caf\u00e9s where the croissants are sold are also defined by\ntheir position ``cafe_pos`` and their sales ``cafe_prod``, which describe the target\ndistribution. For fun we also provide a\nmap ``Imap`` that will illustrate the position of these shops in the city.\n\n\nNow we load the data\n\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "data = np.load(\"../data/manhattan.npz\")\n\nbakery_pos = data[\"bakery_pos\"]\nbakery_prod = data[\"bakery_prod\"]\ncafe_pos = data[\"cafe_pos\"]\ncafe_prod = data[\"cafe_prod\"]\nImap = data[\"Imap\"]\n\nprint(\"Bakery production: {}\".format(bakery_prod))\nprint(\"Cafe sale: {}\".format(cafe_prod))\nprint(\"Total croissants : {}\".format(cafe_prod.sum()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting bakeries in the city\n\nNext we plot the position of the bakeries and caf\u00e9s on the map. The size of\neach circle is proportional to the production (bakeries) or sales (caf\u00e9s) of the shop.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(1, (7, 6))\npl.clf()\npl.imshow(Imap, interpolation=\"bilinear\")  # plot the map\npl.scatter(\n    bakery_pos[:, 0], bakery_pos[:, 1], s=bakery_prod, c=\"r\", ec=\"k\", label=\"Bakeries\"\n)\npl.scatter(cafe_pos[:, 0], cafe_pos[:, 1], s=cafe_prod, c=\"b\", ec=\"k\", label=\"Caf\u00e9s\")\npl.legend()\npl.title(\"Manhattan Bakeries and Caf\u00e9s\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cost matrix\n\n\nWe can now compute the cost matrix between the bakeries and the caf\u00e9s, which\nwill be the transport cost matrix. 
This can be done using the\n[ot.dist](https://pythonot.github.io/all.html#ot.dist) function, which\ndefaults to the squared Euclidean distance but also supports other metrics, such as\n``cityblock`` (the Manhattan distance), through its ``metric`` argument.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "C = ot.dist(bakery_pos, cafe_pos)\n\nlabels = [str(i) for i in range(len(bakery_prod))]\nf = pl.figure(2, (14, 7))\npl.clf()\npl.subplot(121)\npl.imshow(Imap, interpolation=\"bilinear\")  # plot the map\nfor i in range(len(cafe_pos)):\n    pl.text(\n        cafe_pos[i, 0],\n        cafe_pos[i, 1],\n        labels[i],\n        color=\"b\",\n        fontsize=14,\n        fontweight=\"bold\",\n        ha=\"center\",\n        va=\"center\",\n    )\nfor i in range(len(bakery_pos)):\n    pl.text(\n        bakery_pos[i, 0],\n        bakery_pos[i, 1],\n        labels[i],\n        color=\"r\",\n        fontsize=14,\n        fontweight=\"bold\",\n        ha=\"center\",\n        va=\"center\",\n    )\npl.title(\"Manhattan Bakeries and Caf\u00e9s\")\n\nax = pl.subplot(122)\nim = pl.imshow(C, cmap=\"coolwarm\")\npl.title(\"Cost matrix\")\ncbar = pl.colorbar(im, ax=ax, shrink=0.5, use_gridspec=True)\ncbar.ax.set_ylabel(\"cost\", rotation=-90, va=\"bottom\")\n\npl.xlabel(\"Caf\u00e9s\")\npl.ylabel(\"Bakeries\")\npl.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The red cells in the matrix image show the bakeries and caf\u00e9s that are\nfarther apart, and thus more costly to transport from one to the other, while\nthe blue ones show those that are very close to each other, with respect to\nthe squared Euclidean distance.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solving the OT problem with [ot.emd](https://pythonot.github.io/all.html#ot.emd)\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "start = time.time()\not_emd = ot.emd(bakery_prod, cafe_prod, C)\ntime_emd = time.time() - start" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The 
function returns the transport matrix, which we can then visualize (next section).\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Transportation plan visualization\n\nA good visualization of the OT matrix in the 2D plane is to denote the\ntransportation of mass between a Bakery and a Caf\u00e9 by a line. This can easily\nbe done with a double ``for`` loop.\n\nTo make it more interpretable, one can scale the line width (the ``lw``\nparameter of plot, as in the code below) or the ``alpha`` parameter by\n``G[i,j]/G.max()``, where ``G`` is the transport matrix.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Plot the matrix and the map\nf = pl.figure(3, (14, 7))\npl.clf()\npl.subplot(121)\npl.imshow(Imap, interpolation=\"bilinear\")  # plot the map\nfor i in range(len(bakery_pos)):\n    for j in range(len(cafe_pos)):\n        pl.plot(\n            [bakery_pos[i, 0], cafe_pos[j, 0]],\n            [bakery_pos[i, 1], cafe_pos[j, 1]],\n            \"-k\",\n            lw=3.0 * ot_emd[i, j] / ot_emd.max(),\n        )\nfor i in range(len(cafe_pos)):\n    pl.text(\n        cafe_pos[i, 0],\n        cafe_pos[i, 1],\n        labels[i],\n        color=\"b\",\n        fontsize=14,\n        fontweight=\"bold\",\n        ha=\"center\",\n        va=\"center\",\n    )\nfor i in range(len(bakery_pos)):\n    pl.text(\n        bakery_pos[i, 0],\n        bakery_pos[i, 1],\n        labels[i],\n        color=\"r\",\n        fontsize=14,\n        fontweight=\"bold\",\n        ha=\"center\",\n        va=\"center\",\n    )\npl.title(\"Manhattan Bakeries and Caf\u00e9s\")\n\nax = pl.subplot(122)\nim = pl.imshow(ot_emd)\nfor i in range(len(bakery_prod)):\n    for j in range(len(cafe_prod)):\n        text = ax.text(\n            j, i, \"{0:g}\".format(ot_emd[i, j]), ha=\"center\", va=\"center\", color=\"w\"\n        )\npl.title(\"Transport matrix\")\n\npl.xlabel(\"Caf\u00e9s\")\npl.ylabel(\"Bakeries\")\npl.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The transport matrix gives the number of croissants to transport\nfrom each bakery to each caf\u00e9. 
We can see that the bakeries only need to\ntransport croissants to one or two caf\u00e9s, the transport matrix being very\nsparse.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## OT loss and dual variables\n\nThe resulting Wasserstein loss is of the form:\n\n\\begin{align}W=\\sum_{i,j}\\gamma_{i,j}C_{i,j}\\end{align}\n\nwhere $\\gamma$ is the optimal transport matrix.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "W = np.sum(ot_emd * C)\nprint(\"Wasserstein loss (EMD) = {0:.2f}\".format(W))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Regularized OT with Sinkhorn\n\nThe Sinkhorn algorithm is very simple to code. You can implement it directly\nusing the following pseudo-code\n\n(Figure: Sinkhorn pseudo-code)\n\nIn this algorithm, $\\oslash$ corresponds to the element-wise division.\n\nAn alternative is to use the POT toolbox with\n[ot.sinkhorn](https://pythonot.github.io/all.html#ot.sinkhorn)\n\nBe careful of numerical problems. 
A good pre-processing for Sinkhorn is to\ndivide the cost matrix ``C`` by its maximum value.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Algorithm\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Compute Sinkhorn transport matrix from algorithm\nreg = 0.1\nK = np.exp(-C / C.max() / reg)  # Gibbs kernel of the normalized cost\nnit = 100\nu = np.ones((len(bakery_prod),))\nfor i in range(1, nit):\n    # alternate scaling updates to match the two marginals\n    v = cafe_prod / np.dot(K.T, u)\n    u = bakery_prod / (np.dot(K, v))\not_sink_algo = np.atleast_2d(u).T * (\n    K * v.T\n)  # Equivalent to np.dot(np.diag(u), np.dot(K, np.diag(v)))\n\n# Compute Sinkhorn transport matrix with POT\not_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg, M=C / C.max())\n\n# Sum of squared differences between the two solutions\nprint(\n    \"Difference between algo and ot.sinkhorn = {0:.2g}\".format(\n        np.sum(np.power(ot_sink_algo - ot_sinkhorn, 2))\n    )\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot the matrix and the map\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(\"Min. 
of Sinkhorn's transport matrix = {0:.2g}\".format(np.min(ot_sinkhorn)))\n\nf = pl.figure(4, (13, 6))\npl.clf()\npl.subplot(121)\npl.imshow(Imap, interpolation=\"bilinear\")  # plot the map\nfor i in range(len(bakery_pos)):\n    for j in range(len(cafe_pos)):\n        pl.plot(\n            [bakery_pos[i, 0], cafe_pos[j, 0]],\n            [bakery_pos[i, 1], cafe_pos[j, 1]],\n            \"-k\",\n            lw=3.0 * ot_sinkhorn[i, j] / ot_sinkhorn.max(),\n        )\nfor i in range(len(cafe_pos)):\n    pl.text(\n        cafe_pos[i, 0],\n        cafe_pos[i, 1],\n        labels[i],\n        color=\"b\",\n        fontsize=14,\n        fontweight=\"bold\",\n        ha=\"center\",\n        va=\"center\",\n    )\nfor i in range(len(bakery_pos)):\n    pl.text(\n        bakery_pos[i, 0],\n        bakery_pos[i, 1],\n        labels[i],\n        color=\"r\",\n        fontsize=14,\n        fontweight=\"bold\",\n        ha=\"center\",\n        va=\"center\",\n    )\npl.title(\"Manhattan Bakeries and Caf\u00e9s\")\n\nax = pl.subplot(122)\nim = pl.imshow(ot_sinkhorn)\nfor i in range(len(bakery_prod)):\n    for j in range(len(cafe_prod)):\n        text = ax.text(\n            j, i, np.round(ot_sinkhorn[i, j], 1), ha=\"center\", va=\"center\", color=\"w\"\n        )\npl.title(\"Transport matrix\")\n\npl.xlabel(\"Caf\u00e9s\")\npl.ylabel(\"Bakeries\")\npl.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We notice right away that the matrix is not sparse at all with Sinkhorn,\neach bakery delivering croissants to all 5 caf\u00e9s with that solution. Also,\nthis solution gives a transport with fractions, which does not make sense\nin the case of croissants. 
This was not the case with EMD.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Varying the regularization parameter in Sinkhorn\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "reg_parameter = np.logspace(-3, 0, 20)\nW_sinkhorn_reg = np.zeros((len(reg_parameter),))\ntime_sinkhorn_reg = np.zeros((len(reg_parameter),))\n\nf = pl.figure(5, (14, 5))\npl.clf()\nmax_ot = 100  # plot matrices with the same colorbar\nfor k in range(len(reg_parameter)):\n    start = time.time()\n    ot_sinkhorn = ot.sinkhorn(\n        bakery_prod, cafe_prod, reg=reg_parameter[k], M=C / C.max()\n    )\n    time_sinkhorn_reg[k] = time.time() - start\n\n    if k % 4 == 0 and k > 0:  # we only plot a few\n        ax = pl.subplot(1, 5, k // 4)\n        im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot)\n        pl.title(\"reg={0:.2g}\".format(reg_parameter[k]))\n        pl.xlabel(\"Caf\u00e9s\")\n        pl.ylabel(\"Bakeries\")\n\n    # Compute the Wasserstein loss for Sinkhorn, and compare with EMD\n    W_sinkhorn_reg[k] = np.sum(ot_sinkhorn * C)\npl.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This series of graphs shows that the solution of Sinkhorn starts with something\nvery similar to EMD (although not sparse) for very small values of the\nregularization parameter, and tends to a more uniform solution as the\nregularization parameter increases.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Wasserstein loss and computational time\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Plot the Wasserstein loss as a function of the regularization parameter\nf = pl.figure(6, (4, 4))\npl.clf()\npl.title(\"Comparison between Sinkhorn and EMD\")\n\npl.plot(reg_parameter, W_sinkhorn_reg, \"o\", label=\"Sinkhorn\")\nXLim = pl.xlim()\npl.plot(XLim, [W, W], \"--k\", label=\"EMD\")\npl.legend()\npl.xlabel(\"reg\")\npl.ylabel(\"Wasserstein loss\")" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "In this last graph, we show the impact of the regularization parameter on\nthe Wasserstein loss. We can see that higher\nvalues of ``reg`` lead to a much higher Wasserstein loss.\n\nThe Wasserstein loss of EMD is displayed for\ncomparison. For low values of ``reg``, the Wasserstein loss of Sinkhorn is close\nto that of EMD (it can even appear a little lower when the Sinkhorn solution only\napproximately satisfies the marginal constraints), but it quickly gets much higher\nas ``reg`` increases.\n\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }