{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Partial Wasserstein in 1D\n",
    "\n",
    "This notebook demonstrates how to compute and visualize the Partial Wasserstein distance between two 1D discrete distributions using `ot.partial.partial_wasserstein_1d`.\n",
    "\n",
    "We illustrate the intermediate transport plans for all `k = 1...n`, where `n = min(len(x_a), len(x_b))`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# sphinx_gallery_thumbnail_number = 5\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from ot.partial import partial_wasserstein_1d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting helper\n",
    "\n",
    "The two point sets are drawn on parallel horizontal lines; a partial transport plan is shown as dashed segments joining the matched points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def plot_partial_transport(\n",
    "    ax, x_a, x_b, indices_a=None, indices_b=None, marginal_costs=None\n",
    "):\n",
    "    \"\"\"Draw two 1D point sets on parallel lines, optionally with a partial plan.\n",
    "\n",
    "    Points of ``x_a`` are drawn at height ``+1`` and points of ``x_b`` at\n",
    "    height ``-1``. When both index arrays are given, the selected points of\n",
    "    each set are sorted and paired in that order, and every pair is joined\n",
    "    by a dashed segment.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax : matplotlib.axes.Axes\n",
    "        Axes to draw on.\n",
    "    x_a, x_b : numpy.ndarray\n",
    "        Positions of the two point sets (must be non-empty, since their\n",
    "        min/max define the horizontal extent of the plot).\n",
    "    indices_a, indices_b : array-like of int, optional\n",
    "        Indices into ``x_a`` / ``x_b`` of the transported points. Segments\n",
    "        are drawn only when both are provided.\n",
    "    marginal_costs : sequence of float, optional\n",
    "        Costs of the plotted pairs. Its length is reported as ``k`` and its\n",
    "        sum as the cumulative cost in the title. When omitted, a generic\n",
    "        title is used instead.\n",
    "    \"\"\"\n",
    "    # Horizontal extent shared by the guide lines and the x-limits.\n",
    "    x_lo = min(x_a.min(), x_b.min()) - 1\n",
    "    x_hi = max(x_a.max(), x_b.max()) + 1\n",
    "\n",
    "    # Thin guide lines carrying each point set.\n",
    "    for level in (1, -1):\n",
    "        ax.plot([x_lo, x_hi], [level, level], \"k-\", lw=0.5, alpha=0.5)\n",
    "\n",
    "    # Transport segments: both selections are sorted, then paired in order.\n",
    "    if indices_a is not None and indices_b is not None:\n",
    "        matched_a = np.sort(x_a[indices_a])\n",
    "        matched_b = np.sort(x_b[indices_b])\n",
    "        for src, dst in zip(matched_a, matched_b):\n",
    "            ax.plot([src, dst], [1, -1], \"k--\", alpha=0.7)\n",
    "\n",
    "    # All points of both sets, whether transported or not.\n",
    "    ax.plot(x_a, np.ones_like(x_a), \"o\", color=\"C0\", label=\"x_a\", markersize=8)\n",
    "    ax.plot(x_b, -np.ones_like(x_b), \"o\", color=\"C1\", label=\"x_b\", markersize=8)\n",
    "\n",
    "    if marginal_costs is not None:\n",
    "        k = len(marginal_costs)\n",
    "        total_cost = float(np.sum(marginal_costs))\n",
    "        ax.set_title(\n",
    "            f\"Partial Transport - k = {k}, Cumulative Cost = {total_cost:.2f}\",\n",
    "            fontsize=16,\n",
    "        )\n",
    "    else:\n",
    "        ax.set_title(\"Original 1D Discrete Distributions\", fontsize=16)\n",
    "\n",
    "    ax.legend(loc=\"upper right\", fontsize=14)\n",
    "    ax.set_yticks([])\n",
    "    ax.set_xticks([])\n",
    "    ax.set_ylim(-2, 2)\n",
    "    ax.set_xlim(x_lo, x_hi)\n",
    "    ax.axis(\"off\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulate two 1D discrete distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Fixed seed so the sampled points, and every figure below, are reproducible.\n",
    "np.random.seed(0)\n",
    "n = 6\n",
    "x_a = np.sort(np.random.uniform(0, 10, size=n))\n",
    "x_b = np.sort(np.random.uniform(0, 10, size=n))\n",
    "\n",
    "# Plot original distributions\n",
    "fig, ax = plt.subplots(figsize=(6, 2))\n",
    "plot_partial_transport(ax, x_a, x_b)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solve the partial problem and plot every intermediate plan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "indices_a, indices_b, marginal_costs = partial_wasserstein_1d(x_a, x_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Visualize all partial transport plans: plan k uses the first k entries of\n",
    "# the solver output. The loop is bounded by that output rather than by the\n",
    "# sample size `n`, so it stays valid if x_a and x_b have different lengths.\n",
    "for k in range(1, len(marginal_costs) + 1):\n",
    "    fig, ax = plt.subplots(figsize=(6, 2))\n",
    "    plot_partial_transport(\n",
    "        ax,\n",
    "        x_a,\n",
    "        x_b,\n",
    "        indices_a[:k],\n",
    "        indices_b[:k],\n",
    "        marginal_costs[:k],\n",
    "    )\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}