{ "cells": [ { "cell_type": "markdown", "id": "e48e9e8b", "metadata": {}, "source": [ "# `@overload_classmethod` for NumPy Array subclasses\n", "\n", "In this release, experimental support is added for specializing the allocator in NumPy `ndarray` subclasses. Two key enhancements were added to enable this:\n", "\n", "- `@overload_classmethod`, which permits specializing a `classmethod` on specific types; and\n", "- [`Array._allocate`](https://github.com/numba/numba/blob/0.54.0/numba/np/arrayobj.py#L3531-L3537), now exposed as an overloadable `classmethod` on Numba's `Array` type.\n", "\n", "The rest of this notebook demonstrates the use of `@overload_classmethod` to override the allocator for a custom NumPy `ndarray` subclass." ] }, { "cell_type": "code", "execution_count": 1, "id": "d94a2aa0", "metadata": {}, "outputs": [], "source": [ "# All necessary imports\n", "import builtins\n", "import ctypes\n", "from numbers import Number\n", "\n", "import numpy as np\n", "\n", "# We'll need to write some LLVM IR\n", "from llvmlite import ir\n", "\n", "from numba import njit\n", "from numba.extending import (\n", " overload_classmethod,\n", " typeof_impl,\n", " register_model,\n", " intrinsic,\n", ")\n", "from numba.core import cgutils, types, typing\n", "from numba.core.datamodel import models\n", "from numba.np import numpy_support" ] }, { "cell_type": "markdown", "id": "b753d88e", "metadata": {}, "source": [ "## Define a NumPy array subclass\n", "\n", "Make a NumPy `ndarray` subclass called `MyArray`. It needs to override [``__array_ufunc__``](https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__) to specialize how certain ufuncs are handled." 
] }, { "cell_type": "code", "execution_count": 2, "id": "ae9f3635", "metadata": {}, "outputs": [], "source": [ "class MyArray(np.ndarray):\n", " def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):\n", " # This is a \"magic\" method in NumPy subclasses to override \n", " # the behavior of NumPy’s ufuncs.\n", " if method == \"__call__\":\n", " N = None\n", " scalars = []\n", " for inp in inputs:\n", " # If scalar?\n", " if isinstance(inp, Number):\n", " scalars.append(inp)\n", " # If array?\n", " elif isinstance(inp, (type(self), np.ndarray)):\n", " if isinstance(inp, type(self)):\n", " scalars.append(np.ndarray(inp.shape, inp.dtype, inp))\n", " else:\n", " scalars.append(inp)\n", " # Guard shape\n", " if N is not None:\n", " if N != inp.shape:\n", " raise TypeError(\"inconsistent sizes\")\n", " else:\n", " N = inp.shape\n", " # If unknown type?\n", " else:\n", " return NotImplemented\n", " print(f\"NumPy: {type(self)}.__array_ufunc__ method={method} inputs={inputs}\")\n", " ret = ufunc(*scalars, **kwargs)\n", " return self.__class__(ret.shape, ret.dtype, ret)\n", " else:\n", " return NotImplemented" ] }, { "cell_type": "markdown", "id": "0c5eb363", "metadata": {}, "source": [ "## Register the new NumPy subclass type in Numba\n" ] }, { "cell_type": "markdown", "id": "85bb7d59", "metadata": {}, "source": [ "Make a subclass of the Numba `Array` type to represent `MyArray` as a Numba type. Similar to the NumPy `ndarray` subclass, the Numba type also has a `__array_ufunc__` method, but the difference is that it operates in the Numba _typing domain_. Concretely, it receives ``inputs`` that are the argument types, not the argument values, and it returns the type of the returned value, not the return value itself." 
] }, { "cell_type": "code", "execution_count": 3, "id": "1c704009", "metadata": {}, "outputs": [], "source": [ "class MyArrayType(types.Array):\n", " def __init__(self, dtype, ndim, layout, readonly=False, aligned=True):\n", " name = f\"MyArray({ndim}, {dtype}, {layout})\"\n", " super().__init__(dtype, ndim, layout, readonly=readonly,\n", " aligned=aligned, name=name)\n", " \n", " # Tell Numba typing how to combine MyArrayType with other ndarray types.\n", " def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):\n", " \"\"\"\n", " This is the parallel for NumPy's __array_ufunc__ but operates on Numba types instead.\n", " In NumPy's __array_ufunc__, this performs the calculation, but here we \n", " only produce the return type.\n", " \"\"\"\n", " if method == \"__call__\":\n", " for inp in inputs:\n", " if not isinstance(inp, (types.Array, types.Number)):\n", " return NotImplemented\n", " print(f\"Numba: {self}.__array_ufunc__ method={method} inputs={inputs}\")\n", " return MyArrayType\n", " else:\n", " return NotImplemented" ] }, { "cell_type": "markdown", "id": "056946fb", "metadata": {}, "source": [ "We need to teach Numba that ``MyArray`` corresponds to the Numba type ``MyArrayType``. This is done by registering the implementation of `typeof` for `MyArray`." 
] }, { "cell_type": "code", "execution_count": 4, "id": "194ca90a", "metadata": {}, "outputs": [], "source": [ "@typeof_impl.register(MyArray)\n", "def typeof_ta_ndarray(val, c):\n", " # Determine dtype\n", " try:\n", " dtype = numpy_support.from_dtype(val.dtype)\n", " except NotImplementedError:\n", " raise ValueError(\"Unsupported array dtype: %s\" % (val.dtype,))\n", " # Determine memory layout\n", " layout = numpy_support.map_layout(val)\n", " # Determine writeability\n", " readonly = not val.flags.writeable\n", " return MyArrayType(dtype, val.ndim, layout, readonly=readonly)" ] }, { "cell_type": "markdown", "id": "e9f4f609", "metadata": {}, "source": [ "We also need to teach Numba how `MyArrayType` is represented in memory. For our purpose, it is the same as the basic `Array` type. This is done by registering a `datamodel` for `MyArrayType`." ] }, { "cell_type": "code", "execution_count": 5, "id": "6102793e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "numba.core.datamodel.models.ArrayModel" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "register_model(MyArrayType)(models.ArrayModel)" ] }, { "cell_type": "markdown", "id": "cba0a91c", "metadata": {}, "source": [ "## Override the allocator in the subclass" ] }, { "cell_type": "markdown", "id": "f674efce", "metadata": {}, "source": [ "We define a new allocator to use inside Numba for `MyArray`. Numba exposes an API for external code to register a new allocator table. 
The C structure for the allocator table is defined below:\n", "\n", "(From: https://github.com/numba/numba/blob/0.54.0/numba/core/runtime/nrt_external.h#L10-L19)\n", "\n", "```C\n", "\n", "typedef void *(*NRT_external_malloc_func)(size_t size, void *opaque_data);\n", "typedef void *(*NRT_external_realloc_func)(void *ptr, size_t new_size, void *opaque_data);\n", "typedef void (*NRT_external_free_func)(void *ptr, void *opaque_data);\n", "\n", "\n", "struct ExternalMemAllocator {\n", " NRT_external_malloc_func malloc;\n", " NRT_external_realloc_func realloc;\n", " NRT_external_free_func free;\n", " void *opaque_data;\n", "};\n", "```\n", "\n", "In the following, we use `ctypes` to expose Python functions as C functions (using `ctypes.CFUNCTYPE`). These functions will be used as the allocator and deallocator. Then, we put the pointers to these functions into a `ctypes.Structure` that matches the `ExternalMemAllocator` structure shown above.\n", "\n", "As this is not a performance-focused implementation, we are using Python functions as the allocator/deallocator so that we can `print()` when they are invoked. For production use, users are expected to write the allocator/deallocator in native code.\n", "\n", "**WARNING: DO NOT rerun** the following cells. Doing so will cause a segfault because the deallocator (`free_func()`) can be removed before all the Numba dynamic memory is released." ] }, { "cell_type": "code", "execution_count": 6, "id": "348f3881", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "allocator_table: 0x7faad81b2fb0\n" ] } ], "source": [ "lib = ctypes.CDLL(None)\n", "lib.malloc.argtypes = [ctypes.c_size_t]\n", "lib.malloc.restype = ctypes.c_size_t\n", "lib.free.argtypes = [ctypes.c_void_p]\n", "\n", "@ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_size_t, ctypes.c_void_p)\n", "def malloc_func(size, data):\n", " \"\"\"\n", " The allocator. 
Numba takes opaque data as a void* in the second argument.\n", " \"\"\"\n", " # Call underlying C malloc\n", " out = lib.malloc(size)\n", " print(f\">>> Malloc size={size} data={data} -> {hex(np.uintp(out))}\")\n", " return out\n", "\n", "\n", "@ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p)\n", "def free_func(ptr, data):\n", " \"\"\"\n", " The deallocator. Numba takes opaque data as a void* in the second argument.\n", " \"\"\"\n", " if lib is None:\n", " # Note: in practice guard against global being removed during interpreter shutdown\n", " return\n", " print(f\">>> Free ptr={hex(ptr)} data={data}\")\n", " # Call underlying C free()\n", " lib.free(ptr)\n", " return\n", "\n", "\n", "class ExternalMemAllocator(ctypes.Structure):\n", " \"\"\"\n", " This defines a struct for the allocator table. \n", " Its fields must match ExternalMemAllocator defined in `nrt_external.h`\n", " \"\"\"\n", " _fields_ = [\n", " (\"malloc_func\", ctypes.c_void_p),\n", " (\"realloc_func\", ctypes.c_void_p),\n", " (\"free_func\", ctypes.c_void_p),\n", " (\"data\", ctypes.c_void_p),\n", " ]\n", "\n", "# Instantiate the allocator table\n", "allocator_table = ExternalMemAllocator(\n", " malloc_func=ctypes.cast(malloc_func, ctypes.c_void_p),\n", " realloc_func=None, # unused; skipped for demo purposes\n", " free_func=ctypes.cast(free_func, ctypes.c_void_p),\n", " data=None, # no extra data needed\n", ")\n", "# Inspect the address of the table\n", "print(\"allocator_table:\", hex(ctypes.addressof(allocator_table)))" ] }, { "cell_type": "markdown", "id": "5783ee53", "metadata": {}, "source": [ "Next, we override the memory allocator for this array subclass.\n", "\n", "Note: For demonstration purposes, the allocator references the dynamic runtime address of the allocator table. This disables several features of Numba, including caching and AOT compilation. 
" ] }, { "cell_type": "code", "execution_count": 7, "id": "de2cf2bd", "metadata": {}, "outputs": [], "source": [ "@overload_classmethod(MyArrayType, \"_allocate\")\n", "def _ol_array_allocate(cls, allocsize, align):\n", " \"\"\"Implements a Numba-only classmethod on the array type.\n", " \"\"\"\n", " def impl(cls, allocsize, align):\n", " # The bulk of the work in implemented in the intrinsic below.\n", " return allocator_MyArray(allocsize, align)\n", "\n", " return impl\n", "\n", "@intrinsic\n", "def allocator_MyArray(typingctx, allocsize, align):\n", " def impl(context, builder, sig, args):\n", " context.nrt._require_nrt()\n", " size, align = args\n", "\n", " mod = builder.module\n", " u32 = ir.IntType(32)\n", " voidptr = cgutils.voidptr_t\n", "\n", " # We will use our custom allocator table here.\n", " # The table is referenced by its dynamic runtime address.\n", " addr = ctypes.addressof(allocator_table)\n", " ext_alloc = context.add_dynamic_addr(builder, addr, info='custom_alloc_table')\n", "\n", " # Invoke the allocator routine that uses our custom allocator\n", " fnty = ir.FunctionType(voidptr, [cgutils.intp_t, u32, voidptr])\n", " fn = cgutils.get_or_insert_function(\n", " mod, fnty, name=\"NRT_MemInfo_alloc_safe_aligned_external\"\n", " )\n", " fn.return_value.add_attribute(\"noalias\")\n", "\n", " if isinstance(align, builtins.int):\n", " align = context.get_constant(types.uint32, align)\n", " else:\n", " assert align.type == u32, \"align must be a uint32\"\n", "\n", " call = builder.call(fn, [size, align, ext_alloc])\n", " return call\n", "\n", " mip = types.MemInfoPointer(types.voidptr) # return untyped pointer\n", " sig = typing.signature(mip, allocsize, align)\n", " return sig, impl\n" ] }, { "cell_type": "markdown", "id": "fc2f602d", "metadata": {}, "source": [ "## Testing\n", "\n", "To test, we define a simple functions that computes `a * 2 + a`:" ] }, { "cell_type": "code", "execution_count": 8, "id": "9f0d1a4b", "metadata": {}, "outputs": [ { 
"data": { "text/plain": [ "MyArray([0, 1, 2, 3])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def foo(a):\n", " return a * 2 + a\n", "\n", "buf = np.arange(4)\n", "a = MyArray(buf.shape, buf.dtype, buf)\n", "a" ] }, { "cell_type": "markdown", "id": "741ddb83", "metadata": {}, "source": [ "When `foo()`, is not Numba-compiled, is executed, we can see that the `MyArray.__array_ufunc__` method is used for the `*` and `+` operations." ] }, { "cell_type": "code", "execution_count": 9, "id": "011af807", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NumPy: .__array_ufunc__ method=__call__ inputs=(MyArray([0, 1, 2, 3]), 2)\n", "NumPy: .__array_ufunc__ method=__call__ inputs=(MyArray([0, 2, 4, 6]), MyArray([0, 1, 2, 3]))\n" ] }, { "data": { "text/plain": [ "MyArray([0, 3, 6, 9])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "foo(a)" ] }, { "cell_type": "markdown", "id": "15243909", "metadata": {}, "source": [ "Below is the Numba JIT version:" ] }, { "cell_type": "code", "execution_count": 10, "id": "c761b309", "metadata": {}, "outputs": [], "source": [ "jit_foo = njit(foo)" ] }, { "cell_type": "markdown", "id": "63ddb6dd", "metadata": {}, "source": [ "When `jit_foo()` is executed, `MyArrayType.__array_ufunc__` method is used to compute the types of the `*` and `+` operations. Note, type-inference is invoking the `__array_ufunc__` method multiple times due to specifics of the algorithm. We can also see a series of prints to `stdout` as part of the implementation of the allocator (`malloc_func()`) and deallocator (`free_func()`). It is showing two allocations for the result of `*` and `+`, and one deallocation for the intermediate in `*`." 
] }, { "cell_type": "code", "execution_count": 11, "id": "0435ef04", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)\n", "Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))\n", "Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)\n", "Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))\n", "Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), int64)\n", "Numba: MyArray(1, int64, C).__array_ufunc__ method=__call__ inputs=(MyArray(1, int64, C), MyArray(1, int64, C))\n", ">>> Malloc size=144 data=None -> 0x7faa85511aa0\n", ">>> Malloc size=144 data=None -> 0x7faa85561760\n", ">>> Free ptr=0x7faa85511aa0 data=None\n" ] }, { "data": { "text/plain": [ "array([0, 3, 6, 9])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jit_foo(a)" ] }, { "cell_type": "markdown", "id": "f8d961ea", "metadata": {}, "source": [ "Lastly, we can observe the use of the `MyArray` type in the annotated IR." 
] }, { "cell_type": "code", "execution_count": 13, "id": "94ebd691", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "foo (MyArray(1, int64, C),)\n", "--------------------------------------------------------------------------------\n", "# File: \n", "# --- LINE 1 --- \n", "\n", "def foo(a):\n", "\n", " # --- LINE 2 --- \n", " # label 0\n", " # a = arg(0, name=a) :: MyArray(1, int64, C)\n", " # $const4.1 = const(int, 2) :: Literal[int](2)\n", " # $6binary_multiply.2 = a * $const4.1 :: MyArray(1, int64, C)\n", " # del $const4.1\n", " # $10binary_add.4 = $6binary_multiply.2 + a :: MyArray(1, int64, C)\n", " # del a\n", " # del $6binary_multiply.2\n", " # $12return_value.5 = cast(value=$10binary_add.4) :: MyArray(1, int64, C)\n", " # del $10binary_add.4\n", " # return $12return_value.5\n", "\n", " return a * 2 + a\n", "\n", "\n", "================================================================================\n" ] } ], "source": [ "jit_foo.inspect_types()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 5 }