# `@overload_classmethod` for NumPy Array subclasses

This release adds experimental support for specializing the allocator in NumPy `ndarray` subclasses. Two enhancements enable this:

- `@overload_classmethod` permits specializing a `classmethod` for specific types; and
- [`Array._allocate`](https://github.com/numba/numba/blob/0.54.0/numba/np/arrayobj.py#L3531-L3537) is exposed as an overloadable `classmethod` on Numba's `Array` type.

The rest of this notebook demonstrates the use of `@overload_classmethod` to override the allocator for a custom NumPy `ndarray` subclass.

In [1]:
# All necessary imports
import builtins
import ctypes
from numbers import Number

import numpy as np

# We'll need to write some LLVM IR
from llvmlite import ir

from numba import njit
from numba.extending import (
    overload_classmethod,
    typeof_impl,
    register_model,
    intrinsic,
)
from numba.core import cgutils, types, typing
from numba.core.datamodel import models
from numba.np import numpy_support

## Define a NumPy array subclass

Make a NumPy `ndarray` subclass called `MyArray`. It needs to override [``__array_ufunc__``](https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__) to specialize how certain ufuncs are handled.

In [2]:
class MyArray(np.ndarray):
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # This is a "magic" method in NumPy subclasses to override
        # the behavior of NumPy’s ufuncs.
        if method == "__call__":
            N = None
            scalars = []
            for inp in inputs:
                # If scalar?
                if isinstance(inp, Number):
                    scalars.append(inp)
                # If array?
                elif isinstance(inp, (type(self), np.ndarray)):
                    if isinstance(inp, type(self)):
                        scalars.append(np.ndarray(inp.shape, inp.dtype, inp))
                    else:
                        scalars.append(inp)
                    # Guard shape
                    if N is not None:
                        if N != inp.shape:
                            raise TypeError("inconsistent sizes")
                    else:
                        N = inp.shape
                # If unknown type?
                else:
                    return NotImplemented
            print(f"NumPy: {type(self)}.__array_ufunc__ method={method} inputs={inputs}")
            ret = ufunc(*scalars, **kwargs)
            return self.__class__(ret.shape, ret.dtype, ret)
        else:
            return NotImplemented

## Register the new NumPy subclass type in Numba


Make a subclass of the Numba `Array` type to represent `MyArray` as a Numba type. Similar to the NumPy `ndarray` subclass, the Numba type also has a `__array_ufunc__` method, but the difference is that it operates in the Numba _typing domain_. Concretely, it receives ``inputs`` that are the argument types, not the argument values, and it returns the type of the returned value, not the return value itself.

In [3]:
class MyArrayType(types.Array):
    def __init__(self, dtype, ndim, layout, readonly=False, aligned=True):
        name = f"MyArray({ndim}, {dtype}, {layout})"
        super().__init__(dtype, ndim, layout, readonly=readonly,
                         aligned=aligned, name=name)

    # Tell Numba typing how to combine MyArrayType with other ndarray types.
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """
        This parallels NumPy's __array_ufunc__ but operates on Numba types.
        NumPy's __array_ufunc__ performs the calculation; here we only
        produce the return type.
        """
        if method == "__call__":
            for inp in inputs:
                if not isinstance(inp, (types.Array, types.Number)):
                    return NotImplemented
            print(f"Numba: {self}.__array_ufunc__ method={method} inputs={inputs}")
            return MyArrayType
        else:
            return NotImplemented
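
The split between the value domain and the typing domain can be illustrated without Numba. The following is a toy sketch only: `scale_value` and `scale_type` are hypothetical names, and the promotion rule (any `float` operand gives `float`) is a simplification that has nothing to do with Numba's actual type inference.

```python
def scale_value(x, factor):
    """Value domain: receives values and returns the computed value."""
    return x * factor


def scale_type(x_type, factor_type):
    """Typing domain: receives types and returns the type of the result.

    Hypothetical toy rule: only int and float are understood.
    """
    for t in (x_type, factor_type):
        if t not in (int, float):
            # Same convention as __array_ufunc__: signal "not handled here".
            return NotImplemented
    return float if float in (x_type, factor_type) else int


# The typing function predicts the type that the value function produces.
predicted = scale_type(int, float)
actual = type(scale_value(3, 2.0))
```

Here `predicted` and `actual` are both `float`, while `scale_type(str, int)` returns `NotImplemented`, mirroring how the Numba-side `__array_ufunc__` rejects input types it does not understand.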

We need to teach Numba that ``MyArray`` corresponds to the Numba type ``MyArrayType``. This is done by registering the implementation of `typeof` for `MyArray`.

In [4]:
@typeof_impl.register(MyArray)
def typeof_ta_ndarray(val, c):
    # Determine dtype
    try:
        dtype = numpy_support.from_dtype(val.dtype)
    except NotImplementedError:
        raise ValueError("Unsupported array dtype: %s" % (val.dtype,))
    # Determine memory layout
    layout = numpy_support.map_layout(val)
    # Determine writeability
    readonly = not val.flags.writeable
    return MyArrayType(dtype, val.ndim, layout, readonly=readonly)
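
The register-by-Python-type pattern used above behaves much like `functools.singledispatch` from the standard library. The sketch below is an analogy under that assumption, not Numba code: `toy_typeof` and `Celsius` are hypothetical names, and the returned strings merely stand in for Numba type objects.

```python
from functools import singledispatch


class Celsius(float):
    """Hypothetical user-defined subclass of a built-in type."""


@singledispatch
def toy_typeof(val):
    # Fallback when no implementation is registered for type(val).
    raise TypeError(f"no type known for {type(val).__name__}")


@toy_typeof.register(int)
def _(val):
    return "int64"


@toy_typeof.register(float)
def _(val):
    return "float64"


@toy_typeof.register(Celsius)
def _(val):
    # The most specific registration wins, so the subclass gets its own type
    # instead of falling back to the one registered for float.
    return "Celsius(float64)"
```

With these registrations, `toy_typeof(2.5)` yields `"float64"` while `toy_typeof(Celsius(2.5))` yields `"Celsius(float64)"`, which is the same reason `MyArray` needs its own registration to avoid being treated as a plain `ndarray`.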

We also need to teach Numba how `MyArrayType` is represented in memory. For our purpose, it is the same as the basic `Array` type. This is done by registering a `datamodel` for `MyArrayType`.

In [5]:
register_model(MyArrayType)(models.ArrayModel)

numba.core.datamodel.models.ArrayModel

## Override the allocator in the subclass

We define a new allocator to use inside Numba for `MyArray`. Numba exposes an API for external code to register a new allocator table. The C structure for the allocator table is defined below:

(From: https://github.com/numba/numba/blob/0.54.0/numba/core/runtime/nrt_external.h#L10-L19)

```C

typedef void *(*NRT_external_malloc_func)(size_t size, void *opaque_data);
typedef void *(*NRT_external_realloc_func)(void *ptr, size_t new_size, void *opaque_data);
typedef void (*NRT_external_free_func)(void *ptr, void *opaque_data);


struct ExternalMemAllocator {
 NRT_external_malloc_func malloc;
 NRT_external_realloc_func realloc;
 NRT_external_free_func free;
 void *opaque_data;
};
```

In the following, we use `ctypes` to expose Python functions as C-functions (using `ctypes.CFUNCTYPE`). These functions will be used as the allocator and deallocator. Then, we put the pointers to these functions into a `ctypes.Structure` that matches the `ExternalMemAllocator` structure shown above.

As this is not a performance-focused implementation, we use Python functions as the allocator/deallocator so that we can `print()` when they are invoked. For production use, users are expected to write the allocator/deallocator in native code.

**WARNING: DO NOT rerun** the following cells. Doing so will cause a segfault because the deallocator (`free_func()`) can be removed before all the Numba dynamic memory is released.

In [6]:
lib = ctypes.CDLL(None)
lib.malloc.argtypes = [ctypes.c_size_t]
lib.malloc.restype = ctypes.c_size_t
lib.free.argtypes = [ctypes.c_void_p]

@ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_size_t, ctypes.c_void_p)
def malloc_func(size, data):
    """
    The allocator. Numba takes opaque data as a void* in the second argument.
    """
    # Call underlying C malloc
    out = lib.malloc(size)
    print(f">>> Malloc size={size} data={data} -> {hex(np.uintp(out))}")
    return out


@ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p)
def free_func(ptr, data):
    """
    The deallocator. Numba takes opaque data as a void* in the second argument.
    """
    if lib is None:
        # Note: in practice guard against global being removed during interpreter shutdown
        return
    print(f">>> Free ptr={hex(ptr)} data={data}")
    # Call underlying C free()
    lib.free(ptr)
    return


class ExternalMemAllocator(ctypes.Structure):
    """
    This defines a struct for the allocator table.
    Its fields must match ExternalMemAllocator defined in `nrt_external.h`
    """
    _fields_ = [
        ("malloc_func", ctypes.c_void_p),
        ("realloc_func", ctypes.c_void_p),
        ("free_func", ctypes.c_void_p),
        ("data", ctypes.c_void_p),
    ]

# Instantiate the allocator table
allocator_table = ExternalMemAllocator(
    malloc_func=ctypes.cast(malloc_func, ctypes.c_void_p),
    realloc_func=None,  # unused; skipped for demo purpose
    free_func=ctypes.cast(free_func, ctypes.c_void_p),
    data=None,  # no extra data needed
)
# Inspect the address of the table
print("allocator_table:", hex(ctypes.addressof(allocator_table)))

allocator_table: 0x7faad81b2fb0


Now to override the memory allocator for this array subclass...

Note: For demonstration purposes, the allocator references the dynamic runtime address of the allocator table. This disables several features of Numba, including caching and AOT compilation.
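
To see why a runtime address does not survive outside the current process, it helps to look at what `ctypes.addressof` returns. The sketch below uses a hypothetical `ToyTable` struct and only standard `ctypes` calls. It shows that the address is a plain integer tied to one live object in this process.

```python
import ctypes


class ToyTable(ctypes.Structure):
    # Hypothetical one-field struct, used only for this illustration.
    _fields_ = [("data", ctypes.c_void_p)]


first = ToyTable()
second = ToyTable()

addr_first = ctypes.addressof(first)
addr_second = ctypes.addressof(second)

# The address is stable for as long as the object is alive...
stable = addr_first == ctypes.addressof(first)
# ...but every live object has its own address.
distinct = addr_first != addr_second

# Within the same process, the integer is enough to reach the struct again.
alias = ToyTable.from_address(addr_first)
first.data = 42
seen_through_alias = alias.data
```

Compiled code that embeds such an integer is only valid while the object stays alive in the process that produced it, so it cannot be written to an on-disk cache or built ahead of time.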

In [7]:
@overload_classmethod(MyArrayType, "_allocate")
def _ol_array_allocate(cls, allocsize, align):
    """Implements a Numba-only classmethod on the array type.
    """
    def impl(cls, allocsize, align):
        # The bulk of the work is implemented in the intrinsic below.
        return allocator_MyArray(allocsize, align)

    return impl

@intrinsic
def allocator_MyArray(typingctx, allocsize, align):
    def impl(context, builder, sig, args):
        context.nrt._require_nrt()
        size, align = args

        mod = builder.module
        u32 = ir.IntType(32)
        voidptr = cgutils.voidptr_t

        # We will use our custom allocator table here.
        # The table is referenced by its dynamic runtime address.
        addr = ctypes.addressof(allocator_table)
        ext_alloc = context.add_dynamic_addr(builder, addr, info='custom_alloc_table')

        # Invoke the allocator routine that uses our custom allocator
        fnty = ir.FunctionType(voidptr, [cgutils.intp_t, u32, voidptr])
        fn = cgutils.get_or_insert_function(
            mod, fnty, name="NRT_MemInfo_alloc_safe_aligned_external"
        )
        fn.return_value.add_attribute("noalias")

        if isinstance(align, builtins.int):
            align = context.get_constant(types.uint32, align)
        else:
            assert align.type == u32, "align must be a uint32"

        call = builder.call(fn, [size, align, ext_alloc])
        return call

    mip = types.MemInfoPointer(types.voidptr)  # return untyped pointer
    sig = typing.signature(mip, allocsize, align)
    return sig, impl


## Testing

To test, we define a simple function that computes `a * 2 + a`:

In [8]:
def foo(a):
    return a * 2 + a

buf = np.arange(4)
a = MyArray(buf.shape, buf.dtype, buf)
a

MyArray([0, 1, 2, 3])

When `foo()`, which is not Numba-compiled, is executed, we can see that the `MyArray.__array_ufunc__` method is used for the `*` and `+` operations.

In [9]:
foo(a)

NumPy: <class '__main__.MyArray'>.__array_ufunc__ method=__call__ inputs=(MyArray([0, 1, 2, 3]), 2)
NumPy: <class '__main__.MyArray'>.__array_ufunc__ method=__call__ inputs=(MyArray([0, 2, 4, 6]), MyArray([0, 1, 2, 3]))


MyArray([0, 3, 6, 9])

Below is the Numba JIT version:

In [10]:
jit_foo = njit(foo)

When `jit_foo()` is executed, the `MyArrayType.__array_ufunc__` method is used to compute the result types of the `*` and `+` operations. Type inference invokes `__array_ufunc__` multiple times because of how the algorithm works, so each message appears more than once. The output also shows a series of prints to `stdout` from the allocator (`malloc_func()`) and the deallocator (`free_func()`): two allocations, one each for the results of `*` and `+`, and one deallocation for the intermediate result of `*`.

In [11]:
jit_foo(a)

Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)
Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))
>>> Malloc size=144 data=None -> 0x7faa85511aa0
>>> Malloc size=144 data=None -> 0x7faa85561760
>>> Free ptr=0x7faa85511aa0 data=None


array([0, 3, 6, 9])

Lastly, we can observe the use of the `MyArray` type in the annotated IR.

In [13]:
jit_foo.inspect_types()

foo (MyArray(1, int64, C),)
--------------------------------------------------------------------------------
# File: 
# --- LINE 1 --- 

def foo(a):

 # --- LINE 2 --- 
 # label 0
 # a = arg(0, name=a) :: MyArray(1, int64, C)
 # $const4.1 = const(int, 2) :: Literal[int](2)
 # $6binary_multiply.2 = a * $const4.1 :: MyArray(1, int64, C)
 # del $const4.1
 # $10binary_add.4 = $6binary_multiply.2 + a :: MyArray(1, int64, C)
 # del a
 # del $6binary_multiply.2
 # $12return_value.5 = cast(value=$10binary_add.4) :: MyArray(1, int64, C)
 # del $10binary_add.4
 # return $12return_value.5

 return a * 2 + a


