\n\nThis example shows how to use SSNB [58] in POT.\nSSNB computes an l-strongly convex potential $\\varphi$ with an L-Lipschitz gradient such that\n$\\nabla \\varphi \\# \\mu \\approx \\nu$. This regularity can be enforced only on the components of a partition\nof the ambient space, which is a relaxation compared to imposing global regularity.\n\nIn this example, the source measure $\\mu_s$ is the uniform measure on the square $[-1, 1]^2$ in\n$\\mathbb{R}^2$, and the target measure $\\mu_t$ is the image of $\\mu_s$ by\n$T(x_1, x_2) = (x_1 + 2\\,\\mathrm{sign}(x_1), 2 x_2)$. The map $T$ is non-smooth, and we wish to approximate\nit using a \"Brenier-style\" map $\\nabla \\varphi$ which is regular on each component of the partition\n$\\lbrace x_1 < 0, x_1 \\geq 0\\rbrace$, which is well adapted to this particular dataset.\n\nWe represent the gradients of the \"bounding potentials\" $\\varphi_l, \\varphi_u$ (from [59], Theorem 3.14),\nwhich bound any SSNB potential that is optimal in the sense of [58], Definition 1:\n\n\\begin{align}\\varphi \\in \\mathrm{argmin}_{\\varphi \\in \\mathcal{F}}\\ \\mathrm{W}_2(\\nabla \\varphi \\#\\mu_s, \\mu_t),\\end{align}\n\nwhere $\\mathcal{F}$ is the space of functions that are, on every set $E_k$, l-strongly convex\nwith an L-Lipschitz gradient, given a partition $(E_k)_{k \\in [K]}$ of the ambient source space.\n\nWe perform the optimisation on a small number of fitting samples and with few iterations,\nsince solving the SSNB problem is quite computationally expensive.\n\nTHIS EXAMPLE REQUIRES CVXPY\n\n.. [58] Fran\u00e7ois-Pierre Paty, Alexandre d\u2019Aspremont, and Marco Cuturi. Regularity as regularization:\n Smooth and strongly convex Brenier potentials in optimal transport. In International Conference\n on Artificial Intelligence and Statistics, pages 1222\u20131232. PMLR, 2020.\n\n.. [59] Adrien B Taylor. Convex interpolation and performance estimation of first-order methods for\n convex optimization. 
PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium,\n 2017.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Author: Eloi Tanguy \n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 3\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport ot"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generating the fitting data\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"n_fitting_samples = 30\nrng = np.random.RandomState(seed=0)\n# source: uniform samples on the square [-1, 1]^2\nXs = rng.uniform(-1, 1, size=(n_fitting_samples, 2))\n# partition labels: class 1 if x_1 < 0, class 0 otherwise\nXs_classes = (Xs[:, 0] < 0).astype(int)\n# target: shift x_1 by 2 * sign(x_1) and scale x_2 by 2\nXt = np.stack([Xs[:, 0] + 2 * np.sign(Xs[:, 0]), 2 * Xs[:, 1]], axis=-1)\n\nplt.scatter(\n    Xs[Xs_classes == 0, 0], Xs[Xs_classes == 0, 1], c=\"blue\", label=\"source class 0\"\n)\nplt.scatter(\n    Xs[Xs_classes == 1, 0],\n    Xs[Xs_classes == 1, 1],\n    c=\"dodgerblue\",\n    label=\"source class 1\",\n)\nplt.scatter(Xt[:, 0], Xt[:, 1], c=\"red\", label=\"target\")\nplt.axis(\"equal\")\nplt.title(\"Splitting square dataset\")\nplt.legend(loc=\"upper right\")\nplt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fitting the Nearest Brenier Potential\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# the target map scales x_2 by 2, so we need L >= 2 (the default is 1.4)\nL = 3\n# phi: fitted potential values at Xs, G: fitted gradients (images of Xs)\nphi, G = ot.mapping.nearest_brenier_potential_fit(\n    Xs, Xt, Xs_classes, its=10, init_method=\"barycentric\", gradient_lipschitz_constant=L\n)"
]
},
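{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short worked check of the Lipschitz constraint used in the fit, offered as a sketch rather than as part of the original example. On each component of the partition encoded by `Xs_classes`, the map that generates `Xt` is the gradient of a quadratic potential (the names $\\varphi_+$ and $\\varphi_-$ are introduced here for illustration only):\n\n\\begin{align}\\varphi_\\pm(x_1, x_2) = \\frac{x_1^2}{2} \\pm 2 x_1 + x_2^2, \\qquad \\nabla \\varphi_\\pm(x_1, x_2) = (x_1 \\pm 2,\\ 2 x_2),\\end{align}\n\nwith $\\varphi_+$ on $x_1 > 0$ and $\\varphi_-$ on $x_1 < 0$. The Hessian is constant, $\\nabla^2 \\varphi_\\pm = \\mathrm{diag}(1, 2)$, so its eigenvalues are 1 and 2: each $\\varphi_\\pm$ is 1-strongly convex with a 2-Lipschitz gradient. The generating map is therefore feasible for the SSNB problem whenever $l \\leq 1$ and $L \\geq 2$, which is why a gradient Lipschitz constant of 1.4 would be too small for this dataset.\n\n"
]
},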
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plotting the images of the source data\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plt.clf()\nplt.scatter(Xs[:, 0], Xs[:, 1], c=\"dodgerblue\", label=\"source\")\nplt.scatter(Xt[:, 0], Xt[:, 1], c=\"red\", label=\"target\")\nfor i in range(n_fitting_samples):\n plt.plot([Xs[i, 0], G[i, 0]], [Xs[i, 1], G[i, 1]], color=\"black\", alpha=0.5)\nplt.title(\"Images of in-data source samples by the fitted SSNB\")\nplt.legend(loc=\"upper right\")\nplt.axis(\"equal\")\nplt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Computing the predictions (images by $\\nabla \\varphi$) for random samples of the source distribution\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
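"n_predict_samples = 50\n# new samples from the source distribution, with their partition labels\nYs = rng.uniform(-1, 1, size=(n_predict_samples, 2))\nYs_classes = (Ys[:, 0] < 0).astype(int)\n# index 0 holds the lower-bounding potential, index 1 the upper-bounding one\nphi_lu, G_lu = ot.mapping.nearest_brenier_potential_predict_bounds(\n    Xs, phi, G, Ys, Xs_classes, Ys_classes, gradient_lipschitz_constant=L\n)"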
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot predictions for the gradient of the lower-bounding potential\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plt.clf()\nplt.scatter(Xs[:, 0], Xs[:, 1], c=\"dodgerblue\", label=\"source\")\nplt.scatter(Xt[:, 0], Xt[:, 1], c=\"red\", label=\"target\")\nfor i in range(n_predict_samples):\n plt.plot(\n [Ys[i, 0], G_lu[0, i, 0]], [Ys[i, 1], G_lu[0, i, 1]], color=\"black\", alpha=0.5\n )\nplt.title(\"Images of new source samples by $\\\\nabla \\\\varphi_l$\")\nplt.legend(loc=\"upper right\")\nplt.axis(\"equal\")\nplt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot predictions for the gradient of the upper-bounding potential\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plt.clf()\nplt.scatter(Xs[:, 0], Xs[:, 1], c=\"dodgerblue\", label=\"source\")\nplt.scatter(Xt[:, 0], Xt[:, 1], c=\"red\", label=\"target\")\nfor i in range(n_predict_samples):\n plt.plot(\n [Ys[i, 0], G_lu[1, i, 0]], [Ys[i, 1], G_lu[1, i, 1]], color=\"black\", alpha=0.5\n )\nplt.title(\"Images of new source samples by $\\\\nabla \\\\varphi_u$\")\nplt.legend(loc=\"upper right\")\nplt.axis(\"equal\")\nplt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 0
}