{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Stochastic examples\n\nThis example is designed to show how to use the stochastic optimization\nalgorithms for discrete and semi-continuous measures from the POT library.\n\n[18] Genevay, A., Cuturi, M., Peyr\u00e9, G. & Bach, F.\nStochastic Optimization for Large-scale Optimal Transport.\nAdvances in Neural Information Processing Systems (2016).\n\n[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. &\nBlondel, M. Large-scale Optimal Transport and Mapping Estimation.\nInternational Conference on Learning Representation (2018)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Kilian Fatras \n#\n# License: MIT License\n\nimport matplotlib.pylab as pl\nimport numpy as np\nimport ot\nimport ot.plot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute the Transportation Matrix for the Semi-Dual Problem\n\n### Discrete case\n\nSample two discrete measures for the discrete case and compute their cost\nmatrix c.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_source = 7\nn_target = 4\nreg = 1\nnumItermax = 1000\n\na = ot.utils.unif(n_source)\nb = ot.utils.unif(n_target)\n\nrng = np.random.RandomState(0)\nX_source = rng.randn(n_source, 2)\nY_target = rng.randn(n_target, 2)\nM = ot.dist(X_source, Y_target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Call the \"SAG\" method to find the transportation matrix in the discrete case\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "method = \"SAG\"\nsag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)\nprint(sag_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Semi-Continuous Case\n\nSample one general measure a, one discrete 
measure b for the semi-continuous\ncase, define the points supporting the source and target measures, and compute\nthe cost matrix M.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_source = 7\nn_target = 4\nreg = 1\nnumItermax = 1000\nlog = True\n\na = ot.utils.unif(n_source)\nb = ot.utils.unif(n_target)\n\nrng = np.random.RandomState(0)\nX_source = rng.randn(n_source, 2)\nY_target = rng.randn(n_target, 2)\nM = ot.dist(X_source, Y_target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Call the \"ASGD\" method to find the transportation matrix in the semi-continuous\ncase.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "method = \"ASGD\"\nasgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(\n    a, b, M, reg, method, numItermax, log=log\n)\nprint(log_asgd[\"alpha\"], log_asgd[\"beta\"])\nprint(asgd_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compare the results with the Sinkhorn algorithm.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sinkhorn_pi = ot.sinkhorn(a, b, M, reg)\nprint(sinkhorn_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot Transportation Matrices\n\nFor SAG\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sag_pi, \"semi-dual : OT matrix SAG\")\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For ASGD\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, asgd_pi, \"semi-dual : OT matrix ASGD\")\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For Sinkhorn\n\n" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sinkhorn_pi, \"OT matrix Sinkhorn\")\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute the Transportation Matrix for the Dual Problem\n\n### Semi-continuous case\n\nSample one general measure a, one discrete measures b for the semi-continuous\ncase and compute the cost matrix c.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_source = 7\nn_target = 4\nreg = 1\nnumItermax = 100000\nlr = 0.1\nbatch_size = 3\nlog = True\n\na = ot.utils.unif(n_source)\nb = ot.utils.unif(n_target)\n\nrng = np.random.RandomState(0)\nX_source = rng.randn(n_source, 2)\nY_target = rng.randn(n_target, 2)\nM = ot.dist(X_source, Y_target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Call the \"SGD\" dual method to find the transportation matrix in the\nsemi-continuous case\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(\n a, b, M, reg, batch_size, numItermax, lr, log=log\n)\nprint(log_sgd[\"alpha\"], log_sgd[\"beta\"])\nprint(sgd_dual_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare the results with the Sinkhorn algorithm\n\nCall the Sinkhorn algorithm from POT\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sinkhorn_pi = ot.sinkhorn(a, b, M, reg)\nprint(sinkhorn_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot Transportation Matrices\n\nFor SGD\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sgd_dual_pi, \"dual : OT matrix 
SGD\")\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For Sinkhorn\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sinkhorn_pi, \"OT matrix Sinkhorn\")\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }