{
"cells": [
{
"cell_type": "markdown",
"id": "e48e9e8b",
"metadata": {},
"source": [
"# `@overload_classmethod` for NumPy Array subclasses\n",
"\n",
"In this release, experimental support is added for specializing the allocator in NumPy `ndarray` subclasses. Two key enhancements enable this:\n",
"\n",
"- `@overload_classmethod`, which permits specializing a `classmethod` on specific types; and\n",
"- [`Array._allocate`](https://github.com/numba/numba/blob/0.54.0/numba/np/arrayobj.py#L3531-L3537), which is now exposed as an overloadable `classmethod` on Numba's `Array` type.\n",
"\n",
"The rest of this notebook demonstrates the use of `@overload_classmethod` to override the allocator for a custom NumPy `ndarray` subclass."
]
},
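{
"cell_type": "markdown",
"id": "7c3e1a01",
"metadata": {},
"source": [
"To see why a classmethod hook is enough to swap the allocator, consider a plain-Python analogy (this is not Numba code). A base class routes its construction path through `cls._allocate(...)`, so a subclass can change the allocation strategy by overriding only that classmethod. The names `BaseArray`, `TrackedArray` and `empty` are hypothetical. Roughly speaking, `@overload_classmethod` lets the same kind of dispatch happen at compile time on Numba types, here for `Array._allocate`.\n",
"\n",
"```python\n",
"class BaseArray:\n",
"    # Hypothetical stand-in for an array type with an allocation hook\n",
"\n",
"    @classmethod\n",
"    def _allocate(cls, allocsize, align):\n",
"        # Default allocator: a zero-filled buffer (align is ignored here)\n",
"        return bytearray(allocsize)\n",
"\n",
"    @classmethod\n",
"    def empty(cls, nitems, itemsize=8, align=32):\n",
"        # The generic construction path only calls cls._allocate, so a\n",
"        # subclass can specialize allocation without touching this method\n",
"        return cls._allocate(nitems * itemsize, align)\n",
"\n",
"\n",
"class TrackedArray(BaseArray):\n",
"    log = []\n",
"\n",
"    @classmethod\n",
"    def _allocate(cls, allocsize, align):\n",
"        # Specialized allocator: record the request, then defer to the base\n",
"        cls.log.append((allocsize, align))\n",
"        return super()._allocate(allocsize, align)\n",
"\n",
"\n",
"buf = TrackedArray.empty(4)\n",
"print(len(buf), TrackedArray.log)  # 32 [(32, 32)]\n",
"```\n",
"\n",
"`BaseArray.empty()` keeps using the default allocator, and `TrackedArray.empty()` picks up the override without any change to `empty()` itself."
]
},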
{
"cell_type": "code",
"execution_count": 1,
"id": "d94a2aa0",
"metadata": {},
"outputs": [],
"source": [
"# All necessary imports\n",
"import builtins\n",
"import ctypes\n",
"from numbers import Number\n",
"\n",
"import numpy as np\n",
"\n",
"# We'll need to write some LLVM IR\n",
"from llvmlite import ir\n",
"\n",
"from numba import njit\n",
"from numba.core import cgutils, types, typing\n",
"from numba.core.datamodel import models\n",
"from numba.extending import (\n",
"    overload_classmethod,\n",
"    typeof_impl,\n",
"    register_model,\n",
"    intrinsic,\n",
")\n",
"from numba.np import numpy_support"
]
},
{
"cell_type": "markdown",
"id": "b753d88e",
"metadata": {},
"source": [
"## Define a NumPy array subclass\n",
"\n",
"Make a NumPy `ndarray` subclass called `MyArray`. It needs to override [``__array_ufunc__``](https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__) to specialize how certain ufuncs are handled."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ae9f3635",
"metadata": {},
"outputs": [],
"source": [
"class MyArray(np.ndarray):\n",
"    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):\n",
"        # This \"magic\" method lets ndarray subclasses override\n",
"        # the behavior of NumPy's ufuncs.\n",
"        if method == \"__call__\":\n",
"            N = None\n",
"            args = []  # ufunc inputs, with MyArray unwrapped to plain ndarray\n",
"            for inp in inputs:\n",
"                # If scalar?\n",
"                if isinstance(inp, Number):\n",
"                    args.append(inp)\n",
"                # If array?\n",
"                elif isinstance(inp, np.ndarray):\n",
"                    if isinstance(inp, type(self)):\n",
"                        # Plain-ndarray view (works for strided inputs too), so\n",
"                        # the ufunc call below does not dispatch back into this method\n",
"                        args.append(inp.view(np.ndarray))\n",
"                    else:\n",
"                        args.append(inp)\n",
"                    # Guard shape: all array inputs must have the same shape\n",
"                    if N is not None:\n",
"                        if N != inp.shape:\n",
"                            raise TypeError(\"inconsistent sizes\")\n",
"                    else:\n",
"                        N = inp.shape\n",
"                # If unknown type?\n",
"                else:\n",
"                    return NotImplemented\n",
"            print(f\"NumPy: {type(self)}.__array_ufunc__ method={method} inputs={inputs}\")\n",
"            ret = ufunc(*args, **kwargs)\n",
"            # Wrap the plain ndarray result as an instance of this subclass\n",
"            return self.__class__(ret.shape, ret.dtype, ret)\n",
"        else:\n",
"            return NotImplemented"
]
},
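{
"cell_type": "markdown",
"id": "7c3e1a02",
"metadata": {},
"source": [
"At its core, `MyArray.__array_ufunc__` follows a three-step pattern. It unwraps subclass inputs to plain `ndarray` objects, calls the ufunc, and wraps the result back into the subclass. The following sketch keeps only that pattern and omits the shape check and logging. `TaggedArray` is a hypothetical name. Rejecting `out=` arguments is a simplification specific to this sketch and is not something the cell above does.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"class TaggedArray(np.ndarray):\n",
"    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):\n",
"        if method != '__call__' or 'out' in kwargs:\n",
"            # Let NumPy try other operands (or raise TypeError)\n",
"            return NotImplemented\n",
"        # Unwrap TaggedArray inputs so the ufunc call below does not\n",
"        # dispatch back into this method\n",
"        args = [x.view(np.ndarray) if isinstance(x, TaggedArray) else x\n",
"                for x in inputs]\n",
"        result = ufunc(*args, **kwargs)\n",
"        # Re-wrap the plain ndarray result as the subclass\n",
"        return result.view(TaggedArray)\n",
"\n",
"\n",
"a = np.arange(4).view(TaggedArray)\n",
"b = a * 2 + a\n",
"print(type(b).__name__, b.view(np.ndarray))  # TaggedArray [0 3 6 9]\n",
"```"
]
},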
{
"cell_type": "markdown",
"id": "0c5eb363",
"metadata": {},
"source": [
"## Register the new NumPy subclass type in Numba\n"
]
},
{
"cell_type": "markdown",
"id": "85bb7d59",
"metadata": {},
"source": [
"Make a subclass of the Numba `Array` type to represent `MyArray` as a Numba type. Like the NumPy `ndarray` subclass, the Numba type has an `__array_ufunc__` method. The difference is that it operates in the Numba _typing domain_. Concretely, it receives ``inputs`` that are the argument types, not the argument values, and it returns the type of the result (here, the Numba type class `MyArrayType`), not the result itself."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1c704009",
"metadata": {},
"outputs": [],
"source": [
"class MyArrayType(types.Array):\n",
"    def __init__(self, dtype, ndim, layout, readonly=False, aligned=True):\n",
"        name = f\"MyArray({ndim}, {dtype}, {layout})\"\n",
"        super().__init__(dtype, ndim, layout, readonly=readonly,\n",
"                         aligned=aligned, name=name)\n",
"\n",
"    # Tell Numba typing how to combine MyArrayType with other ndarray types.\n",
"    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):\n",
"        \"\"\"\n",
"        This is the counterpart of NumPy's __array_ufunc__, but it operates on\n",
"        Numba types. NumPy's __array_ufunc__ performs the calculation; here we\n",
"        only produce the return type.\n",
"        \"\"\"\n",
"        if method == \"__call__\":\n",
"            for inp in inputs:\n",
"                if not isinstance(inp, (types.Array, types.Number)):\n",
"                    return NotImplemented\n",
"            print(f\"Numba: {self}.__array_ufunc__ method={method} inputs={inputs}\")\n",
"            # Return the Numba type class (not an instance) for the result\n",
"            return MyArrayType\n",
"        else:\n",
"            return NotImplemented"
]
},
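{
"cell_type": "markdown",
"id": "7c3e1a03",
"metadata": {},
"source": [
"NumPy's own API offers a loose analogy for the split between the two `__array_ufunc__` methods. The following example uses only NumPy, not Numba's typing machinery. `np.multiply(a, 2)` needs actual values and produces an actual value. `np.result_type`, in contrast, receives only input *types* and reports the output type without computing anything. That second role is roughly what `MyArrayType.__array_ufunc__` plays during Numba type inference.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"a = np.arange(4, dtype=np.int64)\n",
"\n",
"# Value domain: the ufunc needs actual data and returns actual data\n",
"value = np.multiply(a, 2)\n",
"print(value, value.dtype)\n",
"\n",
"# Type domain: only the input types go in, only the result type comes out\n",
"result_dtype = np.result_type(np.dtype(np.int64), np.dtype(np.float32))\n",
"print(result_dtype)  # float64\n",
"```"
]
},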
{
"cell_type": "markdown",
"id": "056946fb",
"metadata": {},
"source": [
"We need to teach Numba that ``MyArray`` corresponds to the Numba type ``MyArrayType``. This is done by registering the implementation of `typeof` for `MyArray`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "194ca90a",
"metadata": {},
"outputs": [],
"source": [
"@typeof_impl.register(MyArray)\n",
"def typeof_ta_ndarray(val, c):\n",
"    # Determine dtype\n",
"    try:\n",
"        dtype = numpy_support.from_dtype(val.dtype)\n",
"    except NotImplementedError:\n",
"        raise ValueError(\"Unsupported array dtype: %s\" % (val.dtype,))\n",
"    # Determine memory layout\n",
"    layout = numpy_support.map_layout(val)\n",
"    # Determine writeability\n",
"    readonly = not val.flags.writeable\n",
"    return MyArrayType(dtype, val.ndim, layout, readonly=readonly)"
]
},
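{
"cell_type": "markdown",
"id": "7c3e1a04",
"metadata": {},
"source": [
"The `typeof` implementation registered above builds the Numba type from four properties of the runtime value: its dtype, number of dimensions, memory layout, and writeability. The sketch below reads the same properties with plain NumPy. `describe_array` is a hypothetical helper. Its `'C'`/`'F'`/`'A'` rule is meant to mirror `numpy_support.map_layout`, but that correspondence is an assumption about Numba's internals.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def describe_array(val):\n",
"    # Hypothetical helper: collects the properties a typeof\n",
"    # implementation for an ndarray subclass typically inspects\n",
"    if val.flags.c_contiguous:\n",
"        layout = 'C'\n",
"    elif val.flags.f_contiguous:\n",
"        layout = 'F'\n",
"    else:\n",
"        layout = 'A'  # any layout, e.g. a strided slice\n",
"    return val.dtype, val.ndim, layout, not val.flags.writeable\n",
"\n",
"\n",
"a = np.arange(12).reshape(3, 4)\n",
"print(describe_array(a))          # C-contiguous and writable\n",
"print(describe_array(a.T))        # the transpose of a C array is F-contiguous\n",
"print(describe_array(a[:, ::2]))  # strided slice: neither C nor F\n",
"```"
]
},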
{
"cell_type": "markdown",
"id": "e9f4f609",
"metadata": {},
"source": [
"We also need to teach Numba how `MyArrayType` is represented in memory. For our purposes, it is the same as the basic `Array` type. This is done by registering a `datamodel` for `MyArrayType`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6102793e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"numba.core.datamodel.models.ArrayModel"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"register_model(MyArrayType)(models.ArrayModel)"
]
},
{
"cell_type": "markdown",
"id": "cba0a91c",
"metadata": {},
"source": [
"## Override the allocator in the subclass"
]
},
{
"cell_type": "markdown",
"id": "f674efce",
"metadata": {},
"source": [
"We define a new allocator to use inside Numba for `MyArray`. Numba exposes an API for external code to register a new allocator table. The C structure for the allocator table is defined below:\n",
"\n",
"(From: https://github.com/numba/numba/blob/0.54.0/numba/core/runtime/nrt_external.h#L10-L19)\n",
"\n",
"```C\n",
"\n",
"typedef void *(*NRT_external_malloc_func)(size_t size, void *opaque_data);\n",
"typedef void *(*NRT_external_realloc_func)(void *ptr, size_t new_size, void *opaque_data);\n",
"typedef void (*NRT_external_free_func)(void *ptr, void *opaque_data);\n",
"\n",
"\n",
"struct ExternalMemAllocator {\n",
" NRT_external_malloc_func malloc;\n",
" NRT_external_realloc_func realloc;\n",
" NRT_external_free_func free;\n",
" void *opaque_data;\n",
"};\n",
"```\n",
"\n",
"In the following, we use `ctypes` to expose Python functions as C-functions (using `ctypes.CFUNCTYPE`). These functions will be used as the allocator and deallocator. Then, we put the pointers to these functions into a `ctypes.Structure` that matches the `ExternalMemAllocator` structure shown above.\n",
"\n",
"As this is not a performance-focused implementation, we use Python functions as the allocator/deallocator so that we can `print()` when they are invoked. For production use, users are expected to write the allocator/deallocator in native code.\n",
"\n",
"**WARNING: DO NOT rerun** the following cells. Doing so can cause a segfault, because the old deallocator (`free_func()`) may be garbage-collected before all the dynamic memory Numba allocated with it is released."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "348f3881",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"allocator_table: 0x7faad81b2fb0\n"
]
}
],
"source": [
"# Look up malloc/free among the symbols already loaded in this process\n",
"lib = ctypes.CDLL(None)\n",
"lib.malloc.argtypes = [ctypes.c_size_t]\n",
"lib.malloc.restype = ctypes.c_size_t\n",
"lib.free.argtypes = [ctypes.c_void_p]\n",
"\n",
"@ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_size_t, ctypes.c_void_p)\n",
"def malloc_func(size, data):\n",
"    \"\"\"\n",
"    The allocator. Numba passes the table's opaque data as a void* in the second argument.\n",
"    \"\"\"\n",
"    # Call underlying C malloc()\n",
"    out = lib.malloc(size)\n",
"    print(f\">>> Malloc size={size} data={data} -> {hex(np.uintp(out))}\")\n",
"    return out\n",
"\n",
"\n",
"@ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p)\n",
"def free_func(ptr, data):\n",
"    \"\"\"\n",
"    The deallocator. Numba passes the table's opaque data as a void* in the second argument.\n",
"    \"\"\"\n",
"    if lib is None:\n",
"        # Note: in practice, guard against the global being removed during interpreter shutdown\n",
"        return\n",
"    print(f\">>> Free ptr={hex(ptr)} data={data}\")\n",
"    # Call underlying C free()\n",
"    lib.free(ptr)\n",
"    return\n",
"\n",
"\n",
"class ExternalMemAllocator(ctypes.Structure):\n",
"    \"\"\"\n",
"    This defines a struct for the allocator table.\n",
"    Its fields must match ExternalMemAllocator defined in `nrt_external.h`\n",
"    (same order and same types).\n",
"    \"\"\"\n",
"    _fields_ = [\n",
"        (\"malloc_func\", ctypes.c_void_p),\n",
"        (\"realloc_func\", ctypes.c_void_p),\n",
"        (\"free_func\", ctypes.c_void_p),\n",
"        (\"data\", ctypes.c_void_p),\n",
"    ]\n",
"\n",
"# Instantiate the allocator table\n",
"allocator_table = ExternalMemAllocator(\n",
"    malloc_func=ctypes.cast(malloc_func, ctypes.c_void_p),\n",
"    realloc_func=None,  # unused; skipped for demo purposes\n",
"    free_func=ctypes.cast(free_func, ctypes.c_void_p),\n",
"    data=None,  # no extra data needed\n",
")\n",
"# Inspect the address of the table\n",
"print(\"allocator_table:\", hex(ctypes.addressof(allocator_table)))"
]
},
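{
"cell_type": "markdown",
"id": "7c3e1a05",
"metadata": {},
"source": [
"The allocator cell above relies on three ctypes mechanics:\n",
"\n",
"- wrapping a Python function with `ctypes.CFUNCTYPE`;\n",
"- storing the callback's address in a `c_void_p` field of a `ctypes.Structure`; and\n",
"- keeping the callback object referenced for as long as native code may call it.\n",
"\n",
"The self-contained sketch below exercises those mechanics without Numba or `malloc`. `CALLBACK`, `add_one`, `Table` and `call_through_table` are hypothetical names. Python plays the role of the native caller by casting the stored address back to a function pointer.\n",
"\n",
"```python\n",
"import ctypes\n",
"\n",
"# C signature: long (*)(long, void *)\n",
"CALLBACK = ctypes.CFUNCTYPE(ctypes.c_long, ctypes.c_long, ctypes.c_void_p)\n",
"\n",
"\n",
"@CALLBACK\n",
"def add_one(x, data):\n",
"    # 'data' mirrors the opaque_data pointer; unused here\n",
"    return x + 1\n",
"\n",
"\n",
"class Table(ctypes.Structure):\n",
"    # Same idea as ExternalMemAllocator: function pointers stored as void*\n",
"    _fields_ = [\n",
"        ('func', ctypes.c_void_p),\n",
"        ('data', ctypes.c_void_p),\n",
"    ]\n",
"\n",
"\n",
"# The callback object must stay referenced (here through the module-level\n",
"# name 'add_one'); if it were garbage-collected, the stored address would dangle.\n",
"table = Table(func=ctypes.cast(add_one, ctypes.c_void_p), data=None)\n",
"\n",
"\n",
"def call_through_table(tbl, x):\n",
"    # What native code would do: reinterpret the void* as a function pointer\n",
"    fn = ctypes.cast(tbl.func, CALLBACK)\n",
"    return fn(x, tbl.data)\n",
"\n",
"\n",
"print(call_through_table(table, 41))  # 42\n",
"```\n",
"\n",
"Storing the pointers as `c_void_p` rather than as typed function-pointer fields matches the allocator cell above. The layout only has to agree with the C struct in field order and pointer size."
]
},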
{
"cell_type": "markdown",
"id": "5783ee53",
"metadata": {},
"source": [
"Now we override the memory allocator for this array subclass.\n",
"\n",
"Note: For demonstration purposes, the allocator references the allocator table by its dynamic runtime address. This disables several Numba features, including caching and AOT compilation."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "de2cf2bd",
"metadata": {},
"outputs": [],
"source": [
"@overload_classmethod(MyArrayType, \"_allocate\")\n",
"def _ol_array_allocate(cls, allocsize, align):\n",
"    \"\"\"Implements a Numba-only classmethod on the array type.\n",
"    \"\"\"\n",
"    def impl(cls, allocsize, align):\n",
"        # The bulk of the work is implemented in the intrinsic below.\n",
"        return allocator_MyArray(allocsize, align)\n",
"\n",
"    return impl\n",
"\n",
"@intrinsic\n",
"def allocator_MyArray(typingctx, allocsize, align):\n",
"    def impl(context, builder, sig, args):\n",
"        context.nrt._require_nrt()\n",
"        size, align = args\n",
"\n",
"        mod = builder.module\n",
"        u32 = ir.IntType(32)\n",
"        voidptr = cgutils.voidptr_t\n",
"\n",
"        # We will use our custom allocator table here.\n",
"        # The table is referenced by its dynamic runtime address.\n",
"        addr = ctypes.addressof(allocator_table)\n",
"        ext_alloc = context.add_dynamic_addr(builder, addr, info='custom_alloc_table')\n",
"\n",
"        # Invoke the allocator routine that uses our custom allocator\n",
"        fnty = ir.FunctionType(voidptr, [cgutils.intp_t, u32, voidptr])\n",
"        fn = cgutils.get_or_insert_function(\n",
"            mod, fnty, name=\"NRT_MemInfo_alloc_safe_aligned_external\"\n",
"        )\n",
"        fn.return_value.add_attribute(\"noalias\")\n",
"\n",
"        if isinstance(align, builtins.int):\n",
"            align = context.get_constant(types.uint32, align)\n",
"        else:\n",
"            assert align.type == u32, \"align must be a uint32\"\n",
"\n",
"        call = builder.call(fn, [size, align, ext_alloc])\n",
"        return call\n",
"\n",
"    mip = types.MemInfoPointer(types.voidptr)  # return untyped pointer\n",
"    sig = typing.signature(mip, allocsize, align)\n",
"    return sig, impl\n"
]
},
{
"cell_type": "markdown",
"id": "fc2f602d",
"metadata": {},
"source": [
"## Testing\n",
"\n",
"To test, we define a simple function that computes `a * 2 + a`:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9f0d1a4b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MyArray([0, 1, 2, 3])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def foo(a):\n",
"    return a * 2 + a\n",
"\n",
"buf = np.arange(4)\n",
"a = MyArray(buf.shape, buf.dtype, buf)\n",
"a"
]
},
{
"cell_type": "markdown",
"id": "741ddb83",
"metadata": {},
"source": [
"When `foo()` is executed without Numba compilation, we can see that the `MyArray.__array_ufunc__` method is used for the `*` and `+` operations."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "011af807",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NumPy: <class '__main__.MyArray'>.__array_ufunc__ method=__call__ inputs=(MyArray([0, 1, 2, 3]), 2)\n",
"NumPy: <class '__main__.MyArray'>.__array_ufunc__ method=__call__ inputs=(MyArray([0, 2, 4, 6]), MyArray([0, 1, 2, 3]))\n"
]
},
{
"data": {
"text/plain": [
"MyArray([0, 3, 6, 9])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"foo(a)"
]
},
{
"cell_type": "markdown",
"id": "15243909",
"metadata": {},
"source": [
"Below is the Numba JIT version:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c761b309",
"metadata": {},
"outputs": [],
"source": [
"jit_foo = njit(foo)"
]
},
{
"cell_type": "markdown",
"id": "63ddb6dd",
"metadata": {},
"source": [
"When `jit_foo()` is executed, the `MyArrayType.__array_ufunc__` method is used to compute the types of the `*` and `+` operations. Note that type inference invokes `__array_ufunc__` several times due to details of its algorithm. The allocator (`malloc_func()`) and deallocator (`free_func()`) also print a series of messages to `stdout`. They show two allocations, one for each result of `*` and `+`, and one deallocation, for the intermediate result of `*`."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0435ef04",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)\n",
"Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))\n",
"Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)\n",
"Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))\n",
"Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)\n",
"Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))\n",
">>> Malloc size=144 data=None -> 0x7faa85511aa0\n",
">>> Malloc size=144 data=None -> 0x7faa85561760\n",
">>> Free ptr=0x7faa85511aa0 data=None\n"
]
},
{
"data": {
"text/plain": [
"array([0, 3, 6, 9])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jit_foo(a)"
]
},
{
"cell_type": "markdown",
"id": "f8d961ea",
"metadata": {},
"source": [
"Lastly, we can observe the use of the `MyArray` type in the annotated IR."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "94ebd691",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"foo (MyArray(1, int64, C),)\n",
"--------------------------------------------------------------------------------\n",
"# File: \n",
"# --- LINE 1 --- \n",
"\n",
"def foo(a):\n",
"\n",
" # --- LINE 2 --- \n",
" # label 0\n",
" # a = arg(0, name=a) :: MyArray(1, int64, C)\n",
" # $const4.1 = const(int, 2) :: Literal[int](2)\n",
" # $6binary_multiply.2 = a * $const4.1 :: MyArray(1, int64, C)\n",
" # del $const4.1\n",
" # $10binary_add.4 = $6binary_multiply.2 + a :: MyArray(1, int64, C)\n",
" # del a\n",
" # del $6binary_multiply.2\n",
" # $12return_value.5 = cast(value=$10binary_add.4) :: MyArray(1, int64, C)\n",
" # del $10binary_add.4\n",
" # return $12return_value.5\n",
"\n",
" return a * 2 + a\n",
"\n",
"\n",
"================================================================================\n"
]
}
],
"source": [
"jit_foo.inspect_types()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}