{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "raw", "source": [ "---\n", "title: \"Closures in Numba\"\n", "date: '2022-10-24'\n", "categories:\n", " - python\n", "---" ], "metadata": { "id": "g-ETSiSNS1os" } }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Mfhpi2BuMjEJ", "outputId": "301d17fb-af39-4f88-94ee-5e3d67c44add" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: numba in /usr/local/lib/python3.7/dist-packages (0.56.3)\n", "Requirement already satisfied: llvmlite<0.40,>=0.39.0dev0 in /usr/local/lib/python3.7/dist-packages (from numba) (0.39.1)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba) (57.4.0)\n", "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from numba) (4.13.0)\n", "Requirement already satisfied: numpy<1.24,>=1.18 in /usr/local/lib/python3.7/dist-packages (from numba) (1.21.6)\n", "Requirement already satisfied: typing-extensions>=3.6.4 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->numba) (4.1.1)\n", "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->numba) (3.9.0)\n" ] } ], "source": [ "!pip install numba" ] }, { "cell_type": "code", "source": [ "from numba.core import types\n", "from numba.typed import Dict\n", "from numba import njit" ], "metadata": { "id": "0xidhIxWMqc9" }, "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Closures in Numba\n", "\n", "Closures can be used to dynamically create different versions of a 
function based on some parameter; similar functionality can be achieved with `functools.partial`.\n", "\n", "My specific use case is creating filters for [RDataFrame](https://root.cern/doc/master/classROOT_1_1RDataFrame.html), so I need to create `numba`-optimizable functions that take no arguments other than the input data.\n", "\n", "## Parameter is a simple type\n", "\n", "In this case `numba` supports the closure without any issue. Here we have a function which defines a cut on an array, and we can create different versions of this function dynamically." ], "metadata": { "id": "9yvkxKRyMzts" } },
{ "cell_type": "code", "source": [ "def MB_cut_factory(limit):\n", "    # `limit` is captured from the enclosing scope by the inner function\n", "    def cut(value):\n", "        return value < limit\n", "    return cut" ], "metadata": { "id": "kgaHXd-pNcn1" }, "execution_count": 5, "outputs": [] },
{ "cell_type": "code", "source": [ "MB_cut_factory(4)(3)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CE0wsUXMNeUt", "outputId": "d2b6d019-c4d9-415f-a7cc-dbcea94faaab" }, "execution_count": 6, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 6 } ] },
{ "cell_type": "code", "source": [ "njit(MB_cut_factory(4))(3)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QwXsXue0NemR", "outputId": "3ba522c5-8d43-425d-d8c7-e970ff301757" }, "execution_count": 8, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 8 } ] },
{ "cell_type": "markdown", "source": [ "## Parameter is a complex type\n", "\n", "If the parameter is a more complex type, such as a `numba.typed.Dict`, `numba` unfortunately raises a `NumbaNotImplementedError`:" ], "metadata": { "id": "0P7jc6pcO0zi" } },
{ "cell_type": "code", "source": [ "dict_ranges = Dict.empty(\n", "    key_type=types.int64,\n", "    value_type=types.Tuple((types.float64, types.float64))\n", ")\n", "\n", "dict_ranges[3] = (1, 3)\n", "\n", "def MB_cut_factory(dict_ranges):\n", "    def cut(series, value):\n", "        return dict_ranges[series][0] < value < dict_ranges[series][1]\n", "    return cut\n", "\n", "MB_cut_factory(dict_ranges)(3,2)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aQeIh9tZMksJ", "outputId": "eccebfcc-b093-47ec-fdc2-81812cca0663" }, "execution_count": 9, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 9 } ] },
{ "cell_type": "code", "source": [ "njit(MB_cut_factory(dict_ranges))(3,2)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 378 }, "id": "mGWzGledMmWN", "outputId": "ab6de018-11d7-4bb5-c36e-8fd65eb7bed4" }, "execution_count": 10, "outputs": [ { "output_type": "error", "ename": "NumbaNotImplementedError", "evalue": "ignored", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNumbaNotImplementedError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mMB_cut_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdict_ranges\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/numba/core/dispatcher.py\u001b[0m in \u001b[0;36m_compile_for_args\u001b[0;34m(self, *args, **kws)\u001b[0m\n\u001b[1;32m 466\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpatch_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 468\u001b[0;31m 
\u001b[0merror_rewrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'typing'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 469\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnsupportedError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 470\u001b[0m \u001b[0;31m# Something unsupported is present in the user code, add help info\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/numba/core/dispatcher.py\u001b[0m in \u001b[0;36merror_rewrite\u001b[0;34m(e, issue_type)\u001b[0m\n\u001b[1;32m 407\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 408\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 409\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 410\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 411\u001b[0m \u001b[0margtypes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mNumbaNotImplementedError\u001b[0m: Failed in nopython mode pipeline (step: native lowering)\n, (DictType[int64,UniTuple(float64 x 2)],)\nDuring: lowering \"$2load_deref.0 = freevar(dict_ranges: {3: (1.0, 3.0)})\" at (10)" ] } ] }, { "cell_type": "markdown", "source": [ "## The ugly workaround\n", "\n", "Using `exec` we can brutally create the function definition injecting the dictionary as a string into the function definition 
itself.\n", "\n", "It is ugly, but it works, and it gives back a function that can be tested in pure Python before passing it to `numba` for optimization.\n", "\n", "Note that we need to pass `globals()` to `exec` so that the `cut` function is defined in the module namespace, where `return cut` can find it." ], "metadata": { "id": "q276pG7DRhen" } },
{ "cell_type": "code", "source": [ "def MB_cut_factory(dict_ranges):\n", "    # Inline the dictionary as a literal in the source of `cut`, then define it in globals().\n", "    exec(\n", "        \"def cut(series, value):\\n\"\n", "        \"    dict_ranges = \" + str(dict_ranges) + \"\\n\"\n", "        \"    return dict_ranges[series][0] < value < dict_ranges[series][1]\",\n", "        globals(),\n", "    )\n", "    return cut" ], "metadata": { "id": "IB3vfkSoMsV0" }, "execution_count": 21, "outputs": [] },
{ "cell_type": "code", "source": [ "MB_cut_factory(dict_ranges)(3,2)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "oswgfUnfQGdr", "outputId": "3e824991-defa-4719-d970-b3f3dae12069" }, "execution_count": 22, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 22 } ] },
{ "cell_type": "code", "source": [ "njit(MB_cut_factory(dict_ranges))(3,2)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QDgZwGW-QG2P", "outputId": "3f913623-138b-497c-afb4-791e48d3c326" }, "execution_count": 23, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 23 } ] },
{ "cell_type": "code", "source": [], "metadata": { "id": "IT_YgrdDQvhi" }, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Questions on Stack Overflow\n", "\n", "While looking for solutions, I posted two related questions on Stack Overflow; please contribute there if you have better suggestions:\n", "\n", "* [numba and variables defined in a closure](https://stackoverflow.com/questions/74160505/numba-and-variables-defined-in-a-closure)\n", "* [Transform a partial function into a normal 
function](https://stackoverflow.com/questions/74166161/transform-a-partial-function-into-a-normal-function)" ], "metadata": { "id": "ruO3ODKYR69u" } } ] }