{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Different gradient computations for regularized optimal transport\n\nThis example illustrates the differences in terms of computation time between the gradient options for the Sinkhorn solver.\n\n

Note

Example added in release: 0.9.6

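\nThe `grad` option of `ot.solve` selects how gradients are computed: `autodiff` differentiates through all Sinkhorn iterations and provides gradients for every output (plan, value, value_linear); `envelope` returns the gradient of the value only, using the envelope theorem; `last_step` differentiates through only the last Sinkhorn iteration. As a minimal sketch (mirroring the code below, with `M`, `a` and `b` torch tensors that require gradients):\n\n```python\nres = ot.solve(M, a, b, reg=0.03, grad=\"envelope\")\nres.value.backward()  # populates M.grad, a.grad and b.grad\n```\n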
\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Sonia Mazelet \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.backend import torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Time comparison of the Sinkhorn solver for different gradient options\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_trials = 10\ntimes_autodiff = torch.zeros(n_trials)\ntimes_envelope = torch.zeros(n_trials)\ntimes_last_step = torch.zeros(n_trials)\n\nn_samples_s = 300\nn_samples_t = 300\nn_features = 5\nreg = 0.03\n\n# Time required for the Sinkhorn solver and gradient computations, for different gradient options over multiple Gaussian distributions\nfor i in range(n_trials):\n x = torch.rand((n_samples_s, n_features))\n y = torch.rand((n_samples_t, n_features))\n a = ot.utils.unif(n_samples_s)\n b = ot.utils.unif(n_samples_t)\n M = ot.dist(x, y)\n\n a = torch.tensor(a, requires_grad=True)\n b = torch.tensor(b, requires_grad=True)\n M = M.clone().detach().requires_grad_(True)\n\n # autodiff provides the gradient for all the outputs (plan, value, value_linear)\n ot.tic()\n res_autodiff = ot.solve(M, a, b, reg=reg, grad=\"autodiff\")\n res_autodiff.value.backward()\n times_autodiff[i] = ot.toq()\n\n a = a.clone().detach().requires_grad_(True)\n b = b.clone().detach().requires_grad_(True)\n M = M.clone().detach().requires_grad_(True)\n\n # envelope provides the gradient for value\n ot.tic()\n res_envelope = ot.solve(M, a, b, reg=reg, grad=\"envelope\")\n res_envelope.value.backward()\n times_envelope[i] = ot.toq()\n\n a = a.clone().detach().requires_grad_(True)\n b = b.clone().detach().requires_grad_(True)\n M = M.clone().detach().requires_grad_(True)\n\n # last_step provides the gradient for all the outputs, but only for the last iteration of the Sinkhorn algorithm\n ot.tic()\n res_last_step = ot.solve(M, a, b, reg=reg, grad=\"last_step\")\n res_last_step.value.backward()\n times_last_step[i] = ot.toq()\n\npl.figure(1, figsize=(5, 3))\npl.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0, 0))\npl.boxplot(\n ([times_autodiff, times_envelope, times_last_step]),\n tick_labels=[\"autodiff\", \"envelope\", \"last_step\"],\n showfliers=False,\n)\npl.ylabel(\"Time (s)\")\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }