{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Particle Systems", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "1e24c9aa62954c4099b28c89d46d55ca": { "model_module": "@jupyter-widgets/controls", "model_name": "VBoxModel", "state": { "_view_name": "VBoxView", "_dom_classes": [ "widget-interact" ], "_model_name": "VBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_ebd017bc9dca41099c4fa560e9f68922", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_d34c46ec2e8f4170af80954555f8a2b7", "IPY_MODEL_1dbff5ccf90247b69e19cb78cbe86e25" ] } }, "ebd017bc9dca41099c4fa560e9f68922": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, 
"d34c46ec2e8f4170af80954555f8a2b7": { "model_module": "@jupyter-widgets/controls", "model_name": "IntSliderModel", "state": { "_view_name": "IntSliderView", "style": "IPY_MODEL_2d5a3a03f7cb421cbb0cae446cbdfdb1", "_dom_classes": [], "description": "t", "step": 1, "_model_name": "IntSliderModel", "orientation": "horizontal", "max": 19, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 13, "_view_count": null, "disabled": false, "_view_module_version": "1.5.0", "min": 0, "continuous_update": true, "readout_format": "d", "description_tooltip": null, "readout": true, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_f17ea6f4ce5245009f234130c9273294" } }, "1dbff5ccf90247b69e19cb78cbe86e25": { "model_module": "@jupyter-widgets/output", "model_name": "OutputModel", "state": { "_view_name": "OutputView", "msg_id": "", "_dom_classes": [], "_model_name": "OutputModel", "outputs": [ { "output_type": "display_data", "metadata": { "tags": [], "needs_background": "light" }, "image/png": "\n", "text/plain": "
" } ], "_view_module": "@jupyter-widgets/output", "_model_module_version": "1.0.0", "_view_count": null, "_view_module_version": "1.0.0", "layout": "IPY_MODEL_b974db90bbbf44c482c1dd12f61013a3", "_model_module": "@jupyter-widgets/output" } }, "2d5a3a03f7cb421cbb0cae446cbdfdb1": { "model_module": "@jupyter-widgets/controls", "model_name": "SliderStyleModel", "state": { "_view_name": "StyleView", "handle_color": null, "_model_name": "SliderStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "f17ea6f4ce5245009f234130c9273294": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "b974db90bbbf44c482c1dd12f61013a3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": 
"@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "43db062d32ac4b39b5f554d12c335620": { "model_module": "@jupyter-widgets/controls", "model_name": "VBoxModel", "state": { "_view_name": "VBoxView", "_dom_classes": [ "widget-interact" ], "_model_name": "VBoxModel", "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.5.0", "box_style": "", "layout": "IPY_MODEL_a4dc0e31bfea45bcb9f064488a7c40ed", "_model_module": "@jupyter-widgets/controls", "children": [ "IPY_MODEL_7b0b70baaaaa4da68b7e71469ee49731", "IPY_MODEL_a6fb1ecdfa004558849b54cd293b6645" ] } }, "a4dc0e31bfea45bcb9f064488a7c40ed": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": 
null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "7b0b70baaaaa4da68b7e71469ee49731": { "model_module": "@jupyter-widgets/controls", "model_name": "IntSliderModel", "state": { "_view_name": "IntSliderView", "style": "IPY_MODEL_197250af0ee648579e97f82dd8006c14", "_dom_classes": [], "description": "t", "step": 1, "_model_name": "IntSliderModel", "orientation": "horizontal", "max": 19, "_view_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "value": 10, "_view_count": null, "disabled": false, "_view_module_version": "1.5.0", "min": 0, "continuous_update": true, "readout_format": "d", "description_tooltip": null, "readout": true, "_model_module": "@jupyter-widgets/controls", "layout": "IPY_MODEL_0071e8e86e8949f088f9d82f1e5fa3e6" } }, "a6fb1ecdfa004558849b54cd293b6645": { "model_module": "@jupyter-widgets/output", "model_name": "OutputModel", "state": { "_view_name": "OutputView", "msg_id": "", "_dom_classes": [], "_model_name": "OutputModel", "outputs": [ { "output_type": "display_data", "metadata": { "tags": [], "needs_background": "light" }, "image/png": "\n", "text/plain": "
" } ], "_view_module": "@jupyter-widgets/output", "_model_module_version": "1.0.0", "_view_count": null, "_view_module_version": "1.0.0", "layout": "IPY_MODEL_242932679844407388684f5af4adcdda", "_model_module": "@jupyter-widgets/output" } }, "197250af0ee648579e97f82dd8006c14": { "model_module": "@jupyter-widgets/controls", "model_name": "SliderStyleModel", "state": { "_view_name": "StyleView", "handle_color": null, "_model_name": "SliderStyleModel", "description_width": "", "_view_module": "@jupyter-widgets/base", "_model_module_version": "1.5.0", "_view_count": null, "_view_module_version": "1.2.0", "_model_module": "@jupyter-widgets/controls" } }, "0071e8e86e8949f088f9d82f1e5fa3e6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": "@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } }, "242932679844407388684f5af4adcdda": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_view_name": "LayoutView", "grid_template_rows": null, "right": null, "justify_content": null, "_view_module": 
"@jupyter-widgets/base", "overflow": null, "_model_module_version": "1.2.0", "_view_count": null, "flex_flow": null, "width": null, "min_width": null, "border": null, "align_items": null, "bottom": null, "_model_module": "@jupyter-widgets/base", "top": null, "grid_column": null, "overflow_y": null, "overflow_x": null, "grid_auto_flow": null, "grid_area": null, "grid_template_columns": null, "flex": null, "_model_name": "LayoutModel", "justify_items": null, "grid_row": null, "max_height": null, "align_content": null, "visibility": null, "align_self": null, "height": null, "min_height": null, "padding": null, "grid_auto_rows": null, "grid_gap": null, "max_width": null, "order": null, "_view_module_version": "1.2.0", "grid_template_areas": null, "object_position": null, "object_fit": null, "grid_auto_columns": null, "margin": null, "display": null, "left": null } } } }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "C9UDWHVz7Cnf" }, "source": [ "# Particle system simulation and fitting\n", "\n", "This tour shows how to use [PyTorch](https://pytorch.org/) and [KeOps](https://www.kernel-operations.io/keops/index.html) to compute the evolution of a particle system with a simple interaction energy. 
" ] }, { "cell_type": "code", "metadata": { "id": "aJNoYic02Cg_" }, "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from time import time\n", "import progressbar" ], "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "u1ofi5ng2nPB" }, "source": [ "Check whether CUDA is available (in Colab, activate the GPU in the *Runtime / Change runtime type* menu); otherwise all the computations run on the CPU." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HLmCgwm-2kkL", "outputId": "4a6dfd9a-177d-4436-d542-0466bbe466f5" }, "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ], "execution_count": 2, "outputs": [ { "output_type": "stream", "text": [ "cuda\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "Oh1WY5ttkR9g" }, "source": [ "Draw two random point clouds $X$ ($n$ points) and $Y$ ($m$ points), uniformly distributed in $[0,1]^d$. The seed is fixed so that the results are reproducible." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CGKSujUO2w8X", "outputId": "d33c7c97-742d-42b7-bcb3-6cc9f35c7f1c" }, "source": [ "n = 10000 # number of points in the first cloud\n", "m = 10100 # number of points in the second cloud\n", "d = 2 # dimension\n", "torch.manual_seed(0) # fix the seed for reproducibility\n", "X = torch.rand(n,d)\n", "Y = torch.rand(m,d)\n", "print( X.is_cuda )\n", "X = X.to(device) # move to the selected device (GPU if available)\n", "Y = Y.to(device)\n", "print( X.is_cuda )" ], "execution_count": 3, "outputs": [ { "output_type": "stream", "text": [ "False\n", "True\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "e5f6IgJp04gn" }, "source": [ "Handling boundary conditions: either no boundary (free space), or periodic boundary conditions on the unit cube $[0,1]^d$ (a torus)." 
] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 299 }, "id": "gBG00zg003qq", "outputId": "baf4071f-5cbf-4be1-a14e-a89bf565bd4f" }, "source": [ "boundary = 'per' # 'no': no boundary condition, 'per': periodic boundary condition\n", "if boundary=='no':\n", "    print('No boundary.')\n", "    def bc_pos(X): return X\n", "    def bc_diff(D): return D\n", "else:\n", "    print('Periodic boundary.')\n", "    # wrap positions back into [0,1)\n", "    def bc_pos(X): return torch.remainder(X,1.0)\n", "    # map each difference to its shortest periodic representative, in [-1/2,1/2)\n", "    def bc_diff(D): return torch.remainder(D-.5,1.0)-.5\n", "\n", "t = torch.tensor(np.linspace(-1.5,1.5,1000))\n", "plt.plot( t, bc_pos( t ), label='bc_pos' )\n", "plt.plot( t, bc_diff( t ), '--', label='bc_diff' )\n", "plt.legend();" ], "execution_count": 4, "outputs": [ { "output_type": "stream", "text": [ "Periodic boundary.\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "[]" ] }, "metadata": { "tags": [] }, "execution_count": 4 }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": "mgiD8_ci3mzG" }, "source": [ "Example of how to compute a pairwise distance matrix\n", "$$\n", " D_{i,j} = \\|x_i-y_j\\|^2\n", "$$\n", "efficiently by broadcasting: indexing with `None` inserts a singleton axis, so that `X[:,None,:] - Y[None,:,:]` is the $(n,m,d)$ array of all the differences $x_i-y_j$. With periodic boundary conditions, each difference is first mapped by `bc_diff` to its shortest periodic representative." ] }, { "cell_type": "code", "metadata": { "id": "_abJlmdr29Ql" }, "source": [ "def distmat_square(X,Y):\n", "    \"\"\"Matrix of squared distances ||x_i-y_j||^2 (through bc_diff); expects X of shape (n,d) and Y of shape (m,d).\"\"\"\n", "    # (n,1,d) - (1,m,d) -> (n,m,d) differences, then sum of squares over the last axis -> (n,m)\n", "    return torch.sum( bc_diff(X[:,None,:] - Y[None,:,:])**2, axis=2 )" ], "execution_count": 5, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 286 }, "id": "wWJaya0J2vbv", "outputId": "fed0fb35-825d-41c6-ffd1-9f0beccae609" }, "source": [ "plt.imshow( distmat_square(t[:,None],t[:,None]) )\n", "plt.colorbar();" ], "execution_count": 6, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 6 }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": "2G4vkyh8z7_0" }, "source": [ "A more memory-efficient way (especially in high dimension $d$) to compute the distance matrix is to expand the squares, $\\|x_i-y_j\\|^2 = \\|x_i\\|^2 + \\|y_j\\|^2 - 2\\langle x_i, y_j\\rangle$, which avoids storing the $(n,m,d)$ array of differences. Note that this formula ignores the periodic boundary conditions (it does not use `bc_diff`), and that round-off errors can make it slightly negative, hence the clamping to $0$." ] }, { "cell_type": "code", "metadata": { "id": "3F7QiKf0zeo2" }, "source": [ "def distmat_square2(X, Y):\n", "    \"\"\"Squared distances ||x_i||^2 + ||y_j||^2 - 2 x_i.y_j, shape (n,m); does not apply bc_diff.\"\"\"\n", "    X_sq = (X ** 2).sum(axis=-1)\n", "    Y_sq = (Y ** 2).sum(axis=-1)\n", "    cross_term = X.matmul(Y.T)\n", "    # cancellation can give tiny negative values: clamp so that a later sqrt stays valid\n", "    return (X_sq[:, None] + Y_sq[None, :] - 2 * cross_term).clamp(min=0)" ], "execution_count": 7, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LiWlq3PtP1e9", "outputId": "8a30b20a-d157-4586-be8e-2823228e9437" }, "source": [ "def time_sync(f, *args):\n", "    \"\"\"Wall-clock time of f(*args); CUDA kernels run asynchronously, so wait for them before reading the clock.\"\"\"\n", "    if device.type == 'cuda': torch.cuda.synchronize()\n", "    t0 = time()\n", "    f(*args)\n", "    if device.type == 'cuda': torch.cuda.synchronize()\n", "    return time() - t0\n", "\n", "print(time_sync(distmat_square, X, Y))\n", "print(time_sync(distmat_square2, X, Y))" ], "execution_count": 8, "outputs": [ { "output_type": "stream", "text": [ "0.027776718139648438\n", "0.04250526428222656\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "dBoeQ8xT3VrZ" }, "source": [ "# Maximum Mean Discrepancies" ] }, { "cell_type": "markdown", "metadata": { "id": "h-m3t5it57Ds" }, "source": [ "When training ML models on point clouds, it is important to be able to compare two clouds. 
A simple loss function is to use an MMD norm\n", "$$\n", " \\text{MMD}(X,Y) \\triangleq\n", " \\frac{1}{n^2} \\sum_{i,i'} k(x_i,x_{i'}) \n", " +\n", " \\frac{1}{m^2} \\sum_{j,j'} k(y_j,y_{j'})\n", " -2\n", " \\frac{1}{nm} \\sum_{i,j} k(x_i,y_{j}).\n", "$$\n", "We use here the energy distance kernel, so that MMD$(X,Y)$ is a dual Sobolev norm between the discrete measures $\\frac{1}{n}\\sum_i \\delta_{x_i}$ and $\\frac{1}{m}\\sum_j \\delta_{y_j}$\n", "$$\n", " k(x,y) = -\\|x-y\\|.\n", "$$" ] }, { "cell_type": "code", "metadata": { "id": "_oQaky5t55zK" }, "source": [ "def kernel(X,Y):\n", "    \"\"\"Energy distance kernel matrix k(x_i,y_j) = -||x_i-y_j||, with distances computed by distmat_square.\"\"\"\n", "    return -torch.sqrt( distmat_square(X,Y) )\n", "\n", "def MMD(X,Y):\n", "    \"\"\"MMD between the uniform discrete measures on the rows of X and of Y, returned as a Python float.\"\"\"\n", "    n = X.shape[0]\n", "    m = Y.shape[0]\n", "    a = torch.sum( kernel(X,X) )/n**2 + \\\n", "        torch.sum( kernel(Y,Y) )/m**2 - \\\n", "        2*torch.sum( kernel(X,Y) )/(n*m)\n", "    # .item() returns a float detached from the graph: use it for monitoring, not as a differentiable loss\n", "    return a.item()" ], "execution_count": 9, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qxnErW9v6oMr", "outputId": "f49df341-a7a7-4b79-a811-caad4073a63e" }, "source": [ "print( MMD(X,X) ) # should be 0\n", "print( MMD(X,Y) ) # should be >0" ], "execution_count": 10, "outputs": [ { "output_type": "stream", "text": [ "0.0\n", "6.395578384399414e-05\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "CihPkBXe3blG" }, "source": [ "# Gradient flow with PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "m16xCozj4I1s" }, "source": [ "We compute a speed vector field to advance the particles using an interaction kernel of the form\n", "$$\n", " v(X)_i = \\frac{1}{n} \\sum_j \\psi(\\|x_i-x_j\\|^2) (x_i-x_j).\n", "$$\n", "Note that the evolution\n", "$$\n", " \\dot X = -v(X)\n", "$$\n", "is the gradient flow (the particle version of the Wasserstein gradient flow) of the energy\n", "$$\n", " E(X) = \\frac{1}{n} \\sum_{i,j} \\phi(\\|x_i-x_j\\|^2 ),\n", "$$\n", "i.e. $v(X) = \\nabla E(X)$, when defining $\\psi(r)=4\\phi'(r)$ (the factor $4$ comes from the symmetry of the double sum and from the derivative of the square).\n", "\n", "If $\\phi$ is decreasing (resp. increasing), the flow is repulsive (resp. attractive). Below, the energy uses $\\phi(r) = -e^{-r/(2\\sigma^2)}$, which is increasing, so the flow is attractive (the function `psi` in the code is $e^{-r/(2\\sigma^2)}$, so that $\\psi = \\frac{2}{\\sigma^2}$ `psi`)." 
] }, { "cell_type": "code", "metadata": { "id": "P37QnQCe4H-e" }, "source": [ "sigma = .1 # width of the Gaussian interaction kernel\n", "def psi(r):\n", "    \"\"\"Gaussian kernel exp(-r/(2 sigma^2)), applied to squared distances r.\"\"\"\n", "    return torch.exp( -r/(2*sigma**2) )" ], "execution_count": 11, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "neZBQfE53EqM" }, "source": [ "def Speed(X):\n", "    \"\"\"Hand-computed gradient, shape (n,d), of L(X) = -1/n sum_{i,j} psi(||x_i-x_j||^2).\"\"\"\n", "    # (n,n,1) kernel weights times (n,n,d) differences x_i-x_j (through bc_diff), summed over j\n", "    return 2/X.shape[0] * 1/sigma**2 * torch.sum( psi(distmat_square(X,X))[:,:,None] * bc_diff( X[:,None,:] - X[None,:,:] ), axis=1 )" ], "execution_count": 12, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "nsipu8UblUt5" }, "source": [ "Discretize the evolution $\\dot X = -v(X)$ with an explicit Euler scheme of step $\\tau$, projecting the positions back on the domain with `bc_pos` after each step." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "v69IX-Ov0L9B", "outputId": "0e29574d-7e9e-4899-fb6d-ad1e75c78a1e" }, "source": [ "tau = 1/500 if boundary=='no' else 1/200 # time step\n", "niter = 200\n", "save_per = 10 # periodicity of saving\n", "Zsvg = torch.zeros((n,d,niter//save_per)) # to store the intermediate positions\n", "Z = X.detach() # X may require grad (set by the autograd cells): do not track the flow\n", "for it in progressbar.progressbar(range(niter)):\n", "    if it % save_per == 0:\n", "        Zsvg[:,:,it//save_per] = Z # copied into Zsvg (on the CPU), for later display\n", "    Z = bc_pos( Z - tau*Speed(Z) )" ], "execution_count": 13, "outputs": [ { "output_type": "stream", "text": [ "100% (200 of 200) |######################| Elapsed Time: 0:00:15 Time: 0:00:15\n" ], "name": "stderr" } ] }, { "cell_type": "markdown", "metadata": { "id": "S8vh1YvTlaB0" }, "source": [ "Display the evolution: the slider selects the saved frame, colored from blue (start) to red (end)." 
] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 301, "referenced_widgets": [ "1e24c9aa62954c4099b28c89d46d55ca", "ebd017bc9dca41099c4fa560e9f68922", "d34c46ec2e8f4170af80954555f8a2b7", "1dbff5ccf90247b69e19cb78cbe86e25", "2d5a3a03f7cb421cbb0cae446cbdfdb1", "f17ea6f4ce5245009f234130c9273294", "b974db90bbbf44c482c1dd12f61013a3" ] }, "id": "Lwvb9cOCgeaA", "outputId": "15abf4e5-35de-4afa-f49b-cc187cdc4c63" }, "source": [ "import ipywidgets as widgets\n", "@widgets.interact(t=(0,niter//save_per-1))\n", "def display_frame(t=0):\n", "    \"\"\"Scatter plot of the saved frame t (iteration t*save_per), colored from blue (t=0) to red (last frame).\"\"\"\n", "    s = t/max(niter//save_per-1, 1) # guard against a single saved frame\n", "    plt.scatter(Zsvg[:,0,t], Zsvg[:,1,t], color=[s,0,1-s])\n", "    plt.axis('equal')\n", "    plt.axis([0,1,0,1])\n", "    plt.title('PyTorch flow, iteration %d' % (t*save_per))" ], "execution_count": 14, "outputs": [ { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1e24c9aa62954c4099b28c89d46d55ca", "version_minor": 0, "version_major": 2 }, "text/plain": [ "interactive(children=(IntSlider(value=0, description='t', max=19), Output()), _dom_classes=('widget-interact',…" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "markdown", "metadata": { "id": "1-lalpcK3ind" }, "source": [ "# Computing Gradient with Auto-diff\n", "\n", "Instead of computing \"by hand\" the gradient of the interaction energy, one can directly rely on PyTorch auto-diff functionality. This simplifies the code and reduces the risk of bugs." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "akQKlOnh3xpu", "outputId": "1de5abb3-5ee7-4775-8e04-e0bec140ef32" }, "source": [ "Xg = X.detach().requires_grad_(True) # leaf copy tracked by autograd, so that X itself is left untouched\n", "L = -1/Xg.shape[0] * torch.sum( psi(distmat_square(Xg,Xg)), axis=(0,1) )\n", "[g] = torch.autograd.grad(L, [Xg])\n", "# compare with the \"by hand\" computation (relative error)\n", "with torch.no_grad():\n", "    err = torch.norm( g-Speed(Xg) ) / torch.norm( g )\n", 
"print( 'Difference \"hand\" vs. \"pytorch\" : ' + str( err.item() ) )" ], "execution_count": 15, "outputs": [ { "output_type": "stream", "text": [ "Difference \"hand\" vs. pytorch\" : 1.7333064950325788e-07\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "2DVdojtPpSyd" }, "source": [ "# Computations using KeOps\n", "\n", "In order to speed up (by a factor 10 to 100) the evaluation of the kernel, it is possible to use [KeOps](https://www.kernel-operations.io/keops/index.html), which allows one to define \"lazy tensors\". This both reduces the memory footprint (by evaluating the kernel on the fly) and accelerates the computation by a careful mapping of the evaluations of the tensor entries on the tiles of the GPU." ] }, { "cell_type": "code", "metadata": { "id": "D5LIYwHWpVSc" }, "source": [ "!pip install \"pykeops[colab]\" > install.log" ], "execution_count": 16, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "C8iPf7bTpdd9", "outputId": "24ab29e7-60c0-48e3-8bf2-00b55f23be64" }, "source": [ "import pykeops\n", "import pykeops.torch as keops\n", "pykeops.clean_pykeops() # just in case old build files are still present\n", "pykeops.test_torch_bindings() # perform the compilation" ], "execution_count": 17, "outputs": [ { "output_type": "stream", "text": [ "Cleaning /root/.cache/pykeops-1.5-cpython-37/...\n", "[pyKeOps] Initializing build folder for dtype=float32 and lang=torch in /root/.cache/pykeops-1.5-cpython-37 ... done.\n", "[pyKeOps] Compiling libKeOpstorch180bebcc11 in /root/.cache/pykeops-1.5-cpython-37:\n", " formula: Sum_Reduction(SqNorm2(x - y),1)\n", " aliases: x = Vi(0,3); y = Vj(1,3); \n", " dtype : float32\n", "... \n", "[pyKeOps] Compiling pybind11 template libKeOps_template_574e4b20be in /root/.cache/pykeops-1.5-cpython-37 ... 
done.\n", "Done.\n", "\n", "pyKeOps with torch bindings is working!\n", "\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "YXtRegSD-A8O" }, "source": [ "Compute the gradient of the same interaction energy with KeOps: `keops.Vi(X)` (resp. `keops.Vj(X)`) is a symbolic tensor indexed by $i$ (resp. $j$), so that `D` represents all the differences $x_i-x_j$ without storing an $(n,n,d)$ array." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jN_iunlV4mRZ", "outputId": "e1426001-4794-4600-a990-b9131364eaaf" }, "source": [ "Xg = X.detach().requires_grad_(True) # leaf copy tracked by autograd, so that X itself is left untouched\n", "D = keops.Vi(Xg) - keops.Vj(Xg)\n", "if boundary=='per':\n", "    D1 = (D-.5).mod(1.0)-.5 # for periodic BC (same formula as bc_diff)\n", "else:\n", "    D1 = D\n", "D2 = ( D1 ** 2 ).sum( dim=2 )\n", "K = ( -D2 / (2*sigma**2) ).exp()\n", "L = -1/Xg.shape[0] * (K.sum(dim=1)**1).sum() # NOTE: '**1' is the author's workaround for a KeOps issue, kept as is\n", "[g] = torch.autograd.grad(L, [Xg])\n", "# compare with the \"by hand\" computation (relative error)\n", "with torch.no_grad():\n", "    err = torch.norm( g-Speed(Xg) ) / torch.norm( g )\n", "print( 'Difference \"hand\" vs. \"keops+pytorch\" : ' + str( err.item() ) )" ], "execution_count": 18, "outputs": [ { "output_type": "stream", "text": [ "[pyKeOps] Compiling libKeOpstorchea57037698 in /root/.cache/pykeops-1.5-cpython-37:\n", " formula: Sum_Reduction(Exp((Minus(Sum(Square((Mod(((Var(0,2,0) - Var(1,2,1)) - Var(2,1,2)), Var(3,1,2), IntCst(0)) - Var(4,1,2))))) / Var(5,1,2))),0)\n", " aliases: Var(0,2,0); Var(1,2,1); Var(2,1,2); Var(3,1,2); Var(4,1,2); Var(5,1,2); \n", " dtype : float32\n", "... \n", "Done.\n", "[pyKeOps] Compiling libKeOpstorch17a5f3fca7 in /root/.cache/pykeops-1.5-cpython-37:\n", " formula: Grad_WithSavedForward(Sum_Reduction(Exp((Minus(Sum(Square((Mod(((Var(0,2,0) - Var(1,2,1)) - Var(2,1,2)), Var(3,1,2), IntCst(0)) - Var(4,1,2))))) / Var(5,1,2))),0), Var(0,2,0), Var(6,1,0), Var(7,1,0))\n", " aliases: Var(0,2,0); Var(1,2,1); Var(2,1,2); Var(3,1,2); Var(4,1,2); Var(5,1,2); Var(6,1,0); Var(7,1,0); \n", " dtype : float32\n", "... 
\n", "Done.\n", "[pyKeOps] Compiling libKeOpstorchfdfa2e7b49 in /root/.cache/pykeops-1.5-cpython-37:\n", " formula: Grad_WithSavedForward(Sum_Reduction(Exp((Minus(Sum(Square((Mod(((Var(0,2,0) - Var(1,2,1)) - Var(2,1,2)), Var(3,1,2), IntCst(0)) - Var(4,1,2))))) / Var(5,1,2))),0), Var(1,2,1), Var(6,1,0), Var(7,1,0))\n", " aliases: Var(0,2,0); Var(1,2,1); Var(2,1,2); Var(3,1,2); Var(4,1,2); Var(5,1,2); Var(6,1,0); Var(7,1,0); \n", " dtype : float32\n", "... \n", "Done.\n", "Difference \"hand\" vs. \"keops+pytorch\" : 2.946057011364191e-07\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "7x3ZP-176ONd", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "c3b14b63-f2bc-4d3a-c9a1-bf6da9f680b5" }, "source": [ "Zsvg = torch.zeros((n,d,niter//save_per)) # to store the intermediate positions\n", "Z = X.detach().clone().requires_grad_(True) # leaf copy: X itself is left untouched\n", "for it in progressbar.progressbar(range(niter)):\n", "    if it % save_per == 0:\n", "        Zsvg[:,:,it//save_per] = Z.detach() # copied into Zsvg (on the CPU), for later display\n", "    D = keops.Vi(Z) - keops.Vj(Z)\n", "    if boundary=='per':\n", "        D1 = (D-.5).mod(1.0)-.5 # for periodic BC\n", "    else:\n", "        D1 = D\n", "    D2 = ( D1 ** 2 ).sum( dim=2 )\n", "    K = ( -D2 / (2*sigma**2) ).exp()\n", "    L = 1/Z.shape[0] * (K.sum(dim=1)**1).sum() # There is a bug, I needed to add **1 here !!\n", "    [g] = torch.autograd.grad(L, [Z])\n", "    # ascent on L, i.e. the same flow as Z - tau*Speed(Z); detach so that each step\n", "    # starts a fresh autograd graph instead of chaining onto all the previous steps\n", "    Z = bc_pos( Z + tau*g ).detach().requires_grad_(True)" ], "execution_count": 19, "outputs": [ { "output_type": "stream", "text": [ "100% (200 of 200) |######################| Elapsed Time: 0:00:01 Time: 0:00:01\n" ], "name": "stderr" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 301, "referenced_widgets": [ "43db062d32ac4b39b5f554d12c335620", "a4dc0e31bfea45bcb9f064488a7c40ed", "7b0b70baaaaa4da68b7e71469ee49731", "a6fb1ecdfa004558849b54cd293b6645", "197250af0ee648579e97f82dd8006c14", "0071e8e86e8949f088f9d82f1e5fa3e6", "242932679844407388684f5af4adcdda" ] }, "id": "2PtiN6706_GG", 
"outputId": "33f30928-e7ae-4c60-ad4a-be1868bc6091" }, "source": [ "import ipywidgets as widgets\n", "@widgets.interact(t=(0,niter//save_per-1))\n", "def display_frame(t=0):\n", "    \"\"\"Scatter plot of the saved frame t (iteration t*save_per), colored from blue (t=0) to red (last frame).\"\"\"\n", "    s = t/max(niter//save_per-1, 1) # guard against a single saved frame\n", "    plt.scatter(Zsvg[:,0,t], Zsvg[:,1,t], color=[s,0,1-s])\n", "    plt.axis('equal')\n", "    plt.axis([0,1,0,1])\n", "    plt.title('KeOps flow, iteration %d' % (t*save_per))" ], "execution_count": 20, "outputs": [ { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "43db062d32ac4b39b5f554d12c335620", "version_minor": 0, "version_major": 2 }, "text/plain": [ "interactive(children=(IntSlider(value=0, description='t', max=19), Output()), _dom_classes=('widget-interact',…" ] }, "metadata": { "tags": [] } } ] } ] }