In [1]:
!pip install numba



In [4]:
from numba.core import types
from numba.typed import Dict
from numba import njit

# Closures in Numba

Closures can be used to dynamically create different versions of a function based on some parameter; similar functionality can be achieved with `functools.partial`.
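
Here is a minimal pure-Python comparison of the two approaches. The names `cut_with_limit` and `make_cut` are illustrative. Both produce a one-argument callable, but `functools.partial` returns a `functools.partial` object rather than a plain function.

```python
from functools import partial


def cut_with_limit(value, limit):
    """Two-argument version: the threshold is an explicit parameter."""
    return value < limit


def make_cut(limit):
    """Closure version: the returned function remembers `limit`."""
    def cut(value):
        return value < limit
    return cut


cut_closure = make_cut(4)                       # a plain function of one argument
cut_partial = partial(cut_with_limit, limit=4)  # a functools.partial object

print(cut_closure(3), cut_partial(3))  # True True
print(type(cut_partial).__name__)      # partial
```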

My specific use case is creating filters for [RDataFrame](https://root.cern/doc/master/classROOT_1_1RDataFrame.html), so I need `numba`-compilable functions whose only arguments are the input data.

## Parameter is a simple type

In this case `numba` supports the closure without any issue. Here we have a function that defines a cut (a simple threshold) on its input, and we can create different versions of it dynamically.

In [5]:
def MB_cut_factory(limit):
    def cut(value):
        return value < limit
    return cut

In [6]:
MB_cut_factory(4)(3)

True

In [8]:
njit(MB_cut_factory(4))(3)

True
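
Because the closure variable is a plain integer, the factory can be called in a loop to build one compiled filter per threshold. This is a sketch: the thresholds are arbitrary, and the `njit` fallback is only there so it also runs where `numba` is not installed. As far as I know, `numba` treats closure variables as compile-time constants, so each filter is compiled with its own `limit` baked in.

```python
try:
    from numba import njit
except ImportError:  # fallback: run as plain Python if numba is missing
    def njit(f):
        return f


def MB_cut_factory(limit):
    def cut(value):
        return value < limit
    return cut


# One jitted filter per threshold; each closure captures its own `limit`.
cuts = {limit: njit(MB_cut_factory(limit)) for limit in (2, 4, 8)}

print([cuts[lim](3) for lim in (2, 4, 8)])  # [False, True, True]
```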

## Parameter is a complex type

If the parameter is a more complex type, such as a `numba` typed `Dict`, the factory still works in pure Python, but `numba` unfortunately raises a `NumbaNotImplementedError`:

In [9]:
dict_ranges = Dict.empty(
    key_type=types.int64,
    value_type=types.Tuple((types.float64, types.float64))
)

dict_ranges[3] = (1, 3)

def MB_cut_factory(dict_ranges):
    def cut(series, value):
        return dict_ranges[series][0] < value < dict_ranges[series][1]
    return cut

MB_cut_factory(dict_ranges)(3,2)

True

In [10]:
njit(MB_cut_factory(dict_ranges))(3,2)

NumbaNotImplementedError: ignored

## The ugly workaround

Using `exec`, we can brute-force the function definition by building it from a string, injecting the dictionary's string representation directly into the function body.

It is ugly, but it works, and it returns a plain function that can be tested in pure Python before passing it to `numba` for compilation.

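Note that we pass `globals()` to `exec` so that `cut` is defined in the module namespace, where the `return cut` statement can find it. As a side effect, every call to the factory overwrites the global name `cut`.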
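
A variant of the same `exec` trick, sketched under the assumption that the keys are integers and the ranges are pairs of floats: executing the generated source in a fresh dictionary instead of `globals()` keeps `cut` out of the module namespace. The name `make_range_cut` is hypothetical, and the `njit` fallback only exists so the sketch also runs where `numba` is not installed.

```python
try:
    from numba import njit
except ImportError:  # fallback: run as plain Python if numba is missing
    def njit(f):
        return f


def make_range_cut(ranges):
    """Build cut(series, value) with `ranges` baked into its source code.

    Assumes int keys and (float, float) values, so that repr() of the
    normalized dict is a valid Python literal.
    """
    # Normalize to a plain dict of float pairs so repr() gives a clean literal.
    literal = repr({int(k): (float(lo), float(hi)) for k, (lo, hi) in ranges.items()})
    source = (
        "def cut(series, value):\n"
        f"    ranges = {literal}\n"
        "    return ranges[series][0] < value < ranges[series][1]\n"
    )
    namespace = {}
    exec(source, namespace)  # `cut` lands in `namespace`, not in globals()
    return namespace["cut"]


range_cut = njit(make_range_cut({3: (1, 3)}))
print(range_cut(3, 2), range_cut(3, 5))  # True False
```

Because the generated function carries its data as a literal, it can still be called in pure Python first, exactly like the version above.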

In [21]:
def MB_cut_factory(dict_ranges):
    exec("def cut(series, value):\n dict_ranges=" +\
         dict_ranges.__str__() +\
         "\n return dict_ranges[series][0] < value < dict_ranges[series][1]", globals())
    return cut

In [22]:
MB_cut_factory(dict_ranges)(3,2)

True

In [23]:
njit(MB_cut_factory(dict_ranges))(3,2)

True

## Questions on Stackoverflow

While looking for solutions, I posted two related questions on Stack Overflow; please contribute there if you have better suggestions:

* [numba and variables defined in a closure](https://stackoverflow.com/questions/74160505/numba-and-variables-defined-in-a-closure)
* [Transform a partial function into a normal function](https://stackoverflow.com/questions/74166161/transform-a-partial-function-into-a-normal-function)