{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Entropic Wasserstein Component Analysis\n\n

> **Note:** Example added in release: 0.9.1.

\n\nThis example illustrates the use of EWCA as proposed in [52].\n\n\n[52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023).\nEntropic Wasserstein Component Analysis.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Antoine Collas \n#\n# License: MIT License\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pylab as pl\nfrom ot.dr import ewca\nfrom sklearn.datasets import make_blobs\nfrom matplotlib import ticker as mticker\nimport matplotlib.patches as patches\nimport matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_samples = 20\nesp = 0.8\ncenters = np.array([[esp, esp], [-esp, -esp]])\ncluster_std = 0.4\n\nrng = np.random.RandomState(42)\nX, y = make_blobs(\n n_samples=n_samples,\n n_features=2,\n centers=centers,\n cluster_std=cluster_std,\n shuffle=False,\n random_state=rng,\n)\nX = X - X.mean(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig = pl.figure(figsize=(4, 4))\ncmap = matplotlib.colormaps.get_cmap(\"tab10\")\npl.scatter(\n X[: n_samples // 2, 0],\n X[: n_samples // 2, 1],\n color=[cmap(y[i] + 1) for i in range(n_samples // 2)],\n alpha=0.4,\n label=\"Class 1\",\n zorder=30,\n s=50,\n)\npl.scatter(\n X[n_samples // 2 :, 0],\n X[n_samples // 2 :, 1],\n color=[cmap(y[i] + 1) for i in range(n_samples // 2, n_samples)],\n alpha=0.4,\n label=\"Class 2\",\n zorder=30,\n s=50,\n)\nx_y_lim = 2.5\nfs = 15\npl.xlim(-x_y_lim, x_y_lim)\npl.xticks([])\npl.ylim(-x_y_lim, x_y_lim)\npl.yticks([])\npl.legend(fontsize=fs)\npl.title(\"Data\", fontsize=fs)\npl.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 
Compute EWCA\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pi, U = ewca(X, k=2, reg=0.5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data, first component, and projected data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig = pl.figure(figsize=(4, 4))\n\nscale = 3\nu = U[:, 0]\npl.plot(\n [scale * u[0], -scale * u[0]],\n [scale * u[1], -scale * u[1]],\n color=\"grey\",\n linestyle=\"--\",\n lw=3,\n alpha=0.3,\n label=r\"$\\mathbf{U}$\",\n)\nX1 = X @ u[:, None] @ u[:, None].T\n\nfor i in range(n_samples):\n for j in range(n_samples):\n v = pi[i, j] / pi.max()\n if v >= 0.15 or (i, j) == (n_samples - 1, n_samples - 1):\n pl.plot(\n [X[i, 0], X1[j, 0]],\n [X[i, 1], X1[j, 1]],\n alpha=v,\n linestyle=\"-\",\n c=\"C0\",\n label=r\"$\\pi_{ij}$\"\n if (i, j) == (n_samples - 1, n_samples - 1)\n else None,\n )\npl.scatter(\n X[:, 0],\n X[:, 1],\n color=[cmap(y[i] + 1) for i in range(n_samples)],\n alpha=0.4,\n label=r\"$\\mathbf{x}_i$\",\n zorder=30,\n s=50,\n)\npl.scatter(\n X1[:, 0],\n X1[:, 1],\n color=[cmap(y[i] + 1) for i in range(n_samples)],\n alpha=0.9,\n s=50,\n marker=\"+\",\n label=r\"$\\mathbf{U}\\mathbf{U}^{\\top}\\mathbf{x}_i$\",\n zorder=30,\n)\npl.title(\"Data and projections\", fontsize=fs)\npl.xlim(-x_y_lim, x_y_lim)\npl.xticks([])\npl.ylim(-x_y_lim, x_y_lim)\npl.yticks([])\npl.legend(fontsize=fs, loc=\"upper left\")\npl.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot transport plan\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig = pl.figure(figsize=(5, 5))\n\nnorm = matplotlib.colors.PowerNorm(0.5, vmin=0, vmax=100)\nim = pl.imshow(n_samples * pi * 100, cmap=pl.cm.Blues, norm=norm, aspect=\"auto\")\ncb = fig.colorbar(im, orientation=\"vertical\", 
shrink=0.8)\nticks_loc = cb.ax.get_yticks().tolist()\ncb.ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))\ncb.ax.set_yticklabels([f\"{int(i)}%\" for i in cb.get_ticks()])\ncb.ax.tick_params(labelsize=fs)\nfor i, class_ in enumerate(np.sort(np.unique(y))):\n indices = y == class_\n idx_min = np.min(np.arange(len(y))[indices])\n idx_max = np.max(np.arange(len(y))[indices])\n width = idx_max - idx_min + 1\n rect = patches.Rectangle(\n (idx_min - 0.5, idx_min - 0.5),\n width,\n width,\n linewidth=1,\n edgecolor=\"r\",\n facecolor=\"none\",\n )\n pl.gca().add_patch(rect)\n\npl.title(\"OT plan\", fontsize=fs)\npl.ylabel(r\"($\\mathbf{x}_1, \\cdots, \\mathbf{x}_n$)\")\nx_label = r\"($\\mathbf{U}\\mathbf{U}^{\\top}\\mathbf{x}_1, \\cdots,\"\nx_label += r\"\\mathbf{U}\\mathbf{U}^{\\top}\\mathbf{x}_n$)\"\npl.xlabel(x_label)\npl.tight_layout()\npl.axis(\"scaled\")\n\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }