{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Linear OT mapping estimation\n\n

**Note:** Example updated in release: 0.9.1.

\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Remi Flamary \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import os\nfrom pathlib import Path\n\nimport numpy as np\nfrom matplotlib import pyplot as plt\nimport ot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n = 1000\nd = 2\nsigma = 0.1\n\nrng = np.random.RandomState(42)\n\n# source samples\nangles = rng.rand(n, 1) * 2 * np.pi\nxs = np.concatenate((np.sin(angles), np.cos(angles)), axis=1) + sigma * rng.randn(n, 2)\nxs[: n // 2, 1] += 2\n\n\n# target samples\nanglet = rng.rand(n, 1) * 2 * np.pi\nxt = np.concatenate((np.sin(anglet), np.cos(anglet)), axis=1) + sigma * rng.randn(n, 2)\nxt[: n // 2, 1] += 2\n\n\nA = np.array([[1.5, 0.7], [0.7, 1.5]])\nb = np.array([[4, 2]])\nxt = xt.dot(A) + b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.figure(1, (5, 5))\nplt.plot(xs[:, 0], xs[:, 1], \"+\")\nplt.plot(xt[:, 0], xt[:, 1], \"o\")\nplt.legend((\"Source\", \"Target\"))\nplt.title(\"Source and target distributions\")\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimate linear mapping and transport\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Gaussian (linear) Monge mapping estimation\nAe, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)\n\nxst = xs.dot(Ae) + be\n\n# Gaussian (linear) GW mapping estimation\nAgw, bgw = ot.gaussian.empirical_gaussian_gromov_wasserstein_mapping(xs, 
xt)\n\nxstgw = xs.dot(Agw) + bgw" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot transported samples\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.figure(2, (10, 5))\nplt.clf()\nplt.subplot(1, 2, 1)\nplt.plot(xs[:, 0], xs[:, 1], \"+\")\nplt.plot(xt[:, 0], xt[:, 1], \"o\")\nplt.plot(xst[:, 0], xst[:, 1], \"+\")\nplt.legend((\"Source\", \"Target\", \"Transp. Monge\"), loc=0)\nplt.title(\"Transported samples with Monge\")\nplt.subplot(1, 2, 2)\nplt.plot(xs[:, 0], xs[:, 1], \"+\")\nplt.plot(xt[:, 0], xt[:, 1], \"o\")\nplt.plot(xstgw[:, 0], xstgw[:, 1], \"+\")\nplt.legend((\"Source\", \"Target\", \"Transp. GW\"), loc=0)\nplt.title(\"Transported samples with Gaussian GW\")\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load image data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def im2mat(img):\n \"\"\"Converts an image to a matrix (one pixel per line)\"\"\"\n return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))\n\n\ndef mat2im(X, shape):\n \"\"\"Converts back a matrix to an image\"\"\"\n return X.reshape(shape)\n\n\ndef minmax(img):\n return np.clip(img, 0, 1)\n\n\n# Loading images\nthis_file = os.path.realpath(\"__file__\")\ndata_path = os.path.join(Path(this_file).parent.parent.parent, \"data\")\n\nI1 = plt.imread(os.path.join(data_path, \"ocean_day.jpg\")).astype(np.float64) / 256\nI2 = plt.imread(os.path.join(data_path, \"ocean_sunset.jpg\")).astype(np.float64) / 256\n\n\nX1 = im2mat(I1)\nX2 = im2mat(I2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimate mapping and adapt\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Monge mapping\nmapping = ot.da.LinearTransport()\nmapping.fit(Xs=X1, Xt=X2)\n\n\nxst = mapping.transform(Xs=X1)\nxts = 
mapping.inverse_transform(Xt=X2)\n\nI1t = minmax(mat2im(xst, I1.shape))\nI2t = minmax(mat2im(xts, I2.shape))\n\n# gaussian GW mapping\n\nmapping = ot.da.LinearGWTransport()\nmapping.fit(Xs=X1, Xt=X2)\n\n\nxstgw = mapping.transform(Xs=X1)\nxtsgw = mapping.inverse_transform(Xt=X2)\n\nI1tgw = minmax(mat2im(xstgw, I1.shape))\nI2tgw = minmax(mat2im(xtsgw, I2.shape))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot transformed images\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.figure(3, figsize=(14, 7))\n\nplt.subplot(2, 3, 1)\nplt.imshow(I1)\nplt.axis(\"off\")\nplt.title(\"Im. 1\")\n\nplt.subplot(2, 3, 4)\nplt.imshow(I2)\nplt.axis(\"off\")\nplt.title(\"Im. 2\")\n\nplt.subplot(2, 3, 2)\nplt.imshow(I1t)\nplt.axis(\"off\")\nplt.title(\"Monge mapping Im. 1\")\n\nplt.subplot(2, 3, 5)\nplt.imshow(I2t)\nplt.axis(\"off\")\nplt.title(\"Inverse Monge mapping Im. 2\")\n\nplt.subplot(2, 3, 3)\nplt.imshow(I1tgw)\nplt.axis(\"off\")\nplt.title(\"Gaussian GW mapping Im. 1\")\n\nplt.subplot(2, 3, 6)\nplt.imshow(I2tgw)\nplt.axis(\"off\")\nplt.title(\"Inverse Gaussian GW mapping Im. 2\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }