{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n# Wasserstein Discriminant Analysis\n\n> **Note**\n>\n> Example added in release: 0.3.0.\n\nThis example illustrates the use of WDA as proposed in [11].\n\n\n[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).\nWasserstein Discriminant Analysis.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Remi Flamary\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pylab as pl\n\nfrom ot.dr import wda, fda"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data\n\nWe draw a source (training) set and an independent target (test) set. In each set, the first two dimensions hold three noisy concentric circles whose radius (1, 2 or 3) is the class label, and the remaining dimensions are pure Gaussian noise.\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n = 1000  # nb samples in source and target datasets\nnz = 0.2  # std of the Gaussian noise added to the circle coordinates\n\nnp.random.seed(1)\n\n# generate circle dataset: labels 1, 2, 3 (one third of the samples each)\n# are used as the radius of the circle the sample lies on\nt = np.random.rand(n) * 2 * np.pi\nys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxs = np.concatenate((np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\n# independent draw of the target (test) set with the same construction\nt = np.random.rand(n) * 2 * np.pi\nyt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxt = np.concatenate((np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\n# number of extra, non-discriminant (pure noise) dimensions\nnbnoise = 8\n\nxs = np.hstack((xs, np.random.randn(n, nbnoise)))\nxt = np.hstack((xt, np.random.randn(n, nbnoise)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot data\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pl.figure(1, figsize=(6.4, 3.5))\n\n# source samples colored by their own labels ys\npl.subplot(1, 2, 1)\npl.scatter(xs[:, 0], xs[:, 1], c=ys, marker=\"+\", label=\"Source samples\")\npl.legend(loc=0)\npl.title(\"Discriminant dimensions\")\n\npl.subplot(1, 2, 2)\npl.scatter(xs[:, 2], xs[:, 3], c=ys, marker=\"+\", label=\"Source samples\")\npl.legend(loc=0)\npl.title(\"Other dimensions\")\npl.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Fisher Discriminant Analysis\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "p = 2  # dimension of the projection subspace\n\nPfda, projfda = fda(xs, ys, p)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Wasserstein Discriminant Analysis\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "p = 2  # dimension of the projection subspace\nreg = 1e0\nk = 10\nmaxiter = 100\n\n# random initial projection matrix (n_features x p) with unit-norm columns\nP0 = np.random.randn(xs.shape[1], p)\n\nP0 /= np.sqrt(np.sum(P0**2, 0, keepdims=True))\n\nPwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter, P0=P0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot 2D projections\n\nBoth projections are learned on the source (training) samples and applied to the target (test) samples.\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "xsp = projfda(xs)\nxtp = projfda(xt)\n\nxspw = projwda(xs)\nxtpw = projwda(xt)\n\npl.figure(2)\n\npl.subplot(2, 2, 1)\npl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker=\"+\", label=\"Projected samples\")\npl.legend(loc=0)\npl.title(\"Projected training samples FDA\")\n\npl.subplot(2, 2, 2)\npl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker=\"+\", label=\"Projected samples\")\npl.legend(loc=0)\npl.title(\"Projected test samples FDA\")\n\npl.subplot(2, 2, 3)\npl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker=\"+\", label=\"Projected samples\")\npl.legend(loc=0)\npl.title(\"Projected training samples WDA\")\n\npl.subplot(2, 2, 4)\npl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker=\"+\", label=\"Projected samples\")\npl.legend(loc=0)\npl.title(\"Projected test samples WDA\")\npl.tight_layout()\n\npl.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}