{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# OT for multi-source target shift\n\n

**Note:** Example added in release: 0.7.0.

\n\nThis example introduces a target shift problem with two 2D source domains and one target domain.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: Remi Flamary \n# Ievgen Redko \n#\n# License: MIT License\n\nimport pylab as pl\nimport numpy as np\nimport ot\nfrom ot.datasets import make_data_classif" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Base sample count, noise level (passed as `nz`) and a fixed seed for the global NumPy RNG\nn = 50\nsigma = 0.3\nnp.random.seed(1985)\n\n# For each domain: the `p` parameter (shown in the legends as proportions (1 - p, p))\n# and the 2D offset passed as `bias`, also used as the centre of the domain's axes cross\np1, dec1 = 0.2, [0, 2]\np2, dec2 = 0.9, [0, -2]\npt, dect = 0.4, [4, 0]\n\nxs1, ys1 = make_data_classif(\"2gauss_prop\", n, nz=sigma, p=p1, bias=dec1)\nxs2, ys2 = make_data_classif(\"2gauss_prop\", n + 1, nz=sigma, p=p2, bias=dec2)\nxt, yt = make_data_classif(\"2gauss_prop\", n, nz=sigma, p=pt, bias=dect)\n\n# Source domains gathered in lists, the form handed to the multi-source fit later on\nall_Xr, all_Yr = [xs1, xs2], [ys1, ys2]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "da = 1.5\n\n\ndef plot_ax(dec, name):\n    \"\"\"Draw a semi-transparent black cross of half-length ``da`` centred on ``dec``.\n\n    ``dec`` is indexed as ``dec[0]`` (x) and ``dec[1]`` (y); ``name`` is written as a\n    text label at ``(dec[0] - 0.5, dec[1] + 2)``.\n    \"\"\"\n    cx, cy = dec[0], dec[1]\n    pl.plot([cx, cx], [cy - da, cy + da], \"k\", alpha=0.5)\n    pl.plot([cx - da, cx + da], [cy, cy], \"k\", alpha=0.5)\n    pl.text(cx - 0.5, cy + 2, name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fig 1 : plots source and target samples\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(1)\npl.clf()\nfor centre, tag in ((dec1, \"Source 1\"), (dec2, \"Source 2\"), (dect, \"Target\")):\n    plot_ax(centre, tag)\n\n# Marker options shared by the two source scatter plots\nsrc_kw = dict(s=35, cmap=\"Set1\", vmax=9)\nlabel_s1 = \"Source 1 ({:1.2f}, {:1.2f})\".format(1 - p1, p1)\nlabel_s2 = \"Source 2 ({:1.2f}, {:1.2f})\".format(1 - p2, p2)\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, marker=\"x\", label=label_s1, **src_kw)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, marker=\"+\", label=label_s2, **src_kw)\npl.scatter(\n 
xt[:, 0],\n    xt[:, 1],\n    c=yt,\n    s=35,\n    marker=\"o\",\n    cmap=\"Set1\",\n    vmax=9,\n    label=\"Target ({:1.2f}, {:1.2f})\".format(1 - pt, pt),\n)\npl.title(\"Data\")\n\npl.legend()\npl.axis(\"equal\")\npl.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Instantiate Sinkhorn transport algorithm and fit it for all source domains\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric=\"sqeuclidean\")\n\n\ndef print_G(G, xs, ys, xt, thr=5e-4):\n    \"\"\"Draw a line between every source/target pair whose coupling exceeds ``thr``.\n\n    Parameters\n    ----------\n    G : 2D array\n        Coupling matrix, read as ``G[i, j]`` for source sample ``i`` and target sample ``j``.\n    xs, xt : 2D arrays\n        Source and target coordinates; columns 0 and 1 are used as x and y.\n    ys : sequence\n        Source labels; links from a truthy label are blue, from a falsy label red.\n    thr : float, optional\n        Minimum coupling value for a link to be drawn (default 5e-4).\n    \"\"\"\n    for i in range(G.shape[0]):\n        # The colour depends only on the source sample, so pick it once per row\n        c = \"b\" if ys[i] else \"r\"\n        for j in range(G.shape[1]):\n            if G[i, j] > thr:\n                pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fig 2 : plot optimal couplings and transported samples\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(2)\npl.clf()\nplot_ax(dec1, \"Source 1\")\nplot_ax(dec2, \"Source 2\")\nplot_ax(dect, \"Target\")\n# The same estimator is refit for each source; only the resulting coupling is plotted\nprint_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt)\nprint_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt)\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker=\"x\", cmap=\"Set1\", vmax=9)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker=\"+\", cmap=\"Set1\", vmax=9)\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker=\"o\", cmap=\"Set1\", vmax=9)\n\n# Empty lines that only create the legend entries for the two link colours\npl.plot([], [], \"r\", alpha=0.2, label=\"Mass from Class 1\")\npl.plot([], [], \"b\", alpha=0.2, label=\"Mass from Class 2\")\n\npl.title(\"Independent OT\")\n\npl.legend()\npl.axis(\"equal\")\npl.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Instantiate JCPOT adaptation algorithm and fit it\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "otda = ot.da.JCPOTTransport(\n    reg_e=1, 
max_iter=1000, metric=\"sqeuclidean\", tol=1e-9, verbose=True, log=True\n)\notda.fit(all_Xr, all_Yr, xt)\n\n# One vector per source domain, built from the estimated proportions; each is passed\n# below as the first marginal of a Sinkhorn problem\nws1, ws2 = (otda.proportions_.dot(otda.log_[\"D2\"][k]) for k in (0, 1))\n\npl.figure(3)\npl.clf()\nfor centre, tag in ((dec1, \"Source 1\"), (dec2, \"Source 2\"), (dect, \"Target\")):\n    plot_ax(centre, tag)\n\nG1 = ot.bregman.sinkhorn(ws1, [], otda.log_[\"M\"][0], reg=1e-1)\nprint_G(G1, xs1, ys1, xt)\nG2 = ot.bregman.sinkhorn(ws2, [], otda.log_[\"M\"][1], reg=1e-1)\nprint_G(G2, xs2, ys2, xt)\n\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker=\"x\", cmap=\"Set1\", vmax=9)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker=\"+\", cmap=\"Set1\", vmax=9)\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker=\"o\", cmap=\"Set1\", vmax=9)\n\n# Empty lines that only create the legend entries for the two link colours\npl.plot([], [], \"r\", alpha=0.2, label=\"Mass from Class 1\")\npl.plot([], [], \"b\", alpha=0.2, label=\"Mass from Class 2\")\n\nest_title = \"OT with prop estimation ({:1.3f},{:1.3f})\".format(\n    otda.proportions_[0], otda.proportions_[1]\n)\npl.title(est_title)\n\npl.legend()\npl.axis(\"equal\")\npl.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run oracle transport algorithm with known proportions\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "h_res = np.array([1 - pt, pt])\n\n# Same construction as with the estimated proportions, using the known ones in h_res\nws1, ws2 = (h_res.dot(otda.log_[\"D2\"][k]) for k in (0, 1))\n\npl.figure(4)\npl.clf()\nfor centre, tag in ((dec1, \"Source 1\"), (dec2, \"Source 2\"), (dect, \"Target\")):\n    plot_ax(centre, tag)\n\nG1 = ot.bregman.sinkhorn(ws1, [], otda.log_[\"M\"][0], reg=1e-1)\nprint_G(G1, xs1, ys1, xt)\nG2 = ot.bregman.sinkhorn(ws2, [], otda.log_[\"M\"][1], reg=1e-1)\nprint_G(G2, xs2, ys2, xt)\n\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker=\"x\", cmap=\"Set1\", vmax=9)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker=\"+\", cmap=\"Set1\", vmax=9)\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker=\"o\", cmap=\"Set1\", vmax=9)\n\npl.plot([], [], \"r\", alpha=0.2, label=\"Mass from Class 1\")\npl.plot([], 
[], \"b\", alpha=0.2, label=\"Mass from Class 2\")\n\n# The title reports the known proportions stored in h_res\noracle_title = \"OT with known proportion ({:1.1f},{:1.1f})\".format(h_res[0], h_res[1])\npl.title(oracle_title)\n\npl.legend()\npl.axis(\"equal\")\npl.axis(\"off\")\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }