{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Overloading Torch-TensorRT Converters with Custom Converters\n\nIf you want to change how Torch-TensorRT converts a specific PyTorch operation to TensorRT, you can do so by writing a custom converter that overloads Torch-TensorRT's.\nYou might do this to use a custom kernel instead of TensorRT's kernels, or to use a different TensorRT implementation of a layer than the one\nTorch-TensorRT would normally use.\n\nIn this tutorial, we will demonstrate how to overload Torch-TensorRT's conversion of the `torch.nn.functional.gelu` operation to TensorRT with a custom converter that uses a different implementation\nof the GeLU layer.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import logging\nimport sys\n\nimport torch\nimport torch_tensorrt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "GeLU has two modes in PyTorch, one using the ``erf`` function and the other using the ``tanh`` approximation.\nTensorRT natively supports both implementations as an activation layer, but suppose we want to use a custom implementation of GeLU in TensorRT only for ``tanh`` mode.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class GeLU(torch.nn.Module):\n    def __init__(self, mode=\"tanh\"):\n        super().__init__()\n        self.mode = mode\n\n    def forward(self, x):\n        return torch.nn.functional.gelu(x, approximate=self.mode)\n\n\nmy_mod = GeLU(mode=\"tanh\").to(\"cuda\").eval()\nex_input = torch.randn(2, 5).to(\"cuda\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a baseline, we can use the standard Torch-TensorRT GeLU converter (in tanh approximation mode) with our module.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"my_standard_gelu = torch_tensorrt.compile(\n    my_mod, arg_inputs=(ex_input,), min_block_size=1\n)\nprint(my_standard_gelu.graph)\nprint(my_standard_gelu(ex_input))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Writing a Custom Converter\n\nConverters are functions that take a specific instance of a PyTorch operation in a PyTorch graph and convert it to an equivalent set of TensorRT operations in an under-construction TensorRT graph.\nThey are registered with Torch-TensorRT using the ``@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter`` decorator.\nAt a code level, a converter takes the current conversion state (``ConversionContext``), the next operator in the graph to convert, and the arguments to that node,\nand returns the placeholder outputs for that operation while, as a side effect, inserting the necessary TensorRT layers into the TensorRT network.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Dict, Sequence, Tuple, Union\n\nimport tensorrt as trt\nfrom torch.fx.node import Argument, Node, Target\nfrom torch_tensorrt.dynamo import CompilationSettings\nfrom torch_tensorrt.dynamo.conversion import ConversionContext" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Converter Metadata\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(\n    # The PyTorch operation to convert; when this operation is encountered, this converter will be called\n    torch.ops.aten.gelu.default,\n    # Validators are functions that determine whether a specific node can be converted by the converter\n    capability_validator=lambda node, settings: (\n        \"approximate\" in node.kwargs and node.kwargs[\"approximate\"] == \"tanh\"\n    ),\n    # Can this converter be used in cases where the input shapes are dynamic\n
    supports_dynamic_shapes=True,\n    # Set the priority of the converter to supersede the default one\n    priority=torch_tensorrt.dynamo.conversion.ConverterPriority.HIGH,\n    # Whether the converter requires a dynamic output allocator to run (e.g. data dependent ops)\n    requires_output_allocator=True,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the decorator defining a converter, there is one required argument and a few optional ones.\nAll converters need a target operator they will run against, the idea being that when there is an instance of ``torch.ops.aten.gelu.default`` in the graph, this converter will be called.\n\nFollowing the target operator, you can provide additional metadata that defines the capabilities of the converter and its priority versus other possible converters for the target in question.\n\nThe primary tool for defining the capabilities of a converter is the ``capability_validator`` argument,\nwhich is a function (here a lambda) that takes a specific node in the graph as well as the user compilation settings and returns a boolean indicating whether the converter can be used for that node.\nThis validator function is run prior to the graph partitioning phase against each instance of the converter target op. Nodes for which no converter has a validator that passes during this phase will be executed in PyTorch at runtime.\nThis is useful when you want to use a custom converter only in specific cases, as in our case, where we only want to use our converter when ``approximate == \"tanh\"``.\n\nDistinct from the validator is the ``supports_dynamic_shapes`` argument, which is a boolean indicating whether the converter can be used in cases where the input shapes are dynamic.\nIf this is set to ``False``, in cases where the inputs provided by the user are dynamic, this converter will be disabled. 
If there are no alternatives that support dynamic shapes, the operation will be run in PyTorch.\n\nFinally, there is the ``priority`` argument, which is an enum from the ``torch_tensorrt.dynamo.conversion.ConverterPriority`` class that defines the priority of the converter. The two options are ``HIGH`` and ``STANDARD``.\nConverters registered with ``STANDARD`` will be appended to the converter list for a given operation, while converters registered with ``HIGH`` will be prepended to the list.\nCandidate converters are evaluated for their suitability in this priority order, and the first converter that passes the validator is used.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Converter Implementation\nThe converter function itself takes the following arguments: the current conversion context, the target operator, the arguments to the target operator, the keyword arguments to the target operator, and the name of the target operator.\nArguments can be any of Python primitives, ``torch.Tensor``, ``np.ndarray`` or ``ITensor`` objects.\nThe converter function should return the outputs of the target operator, primarily in terms of TensorRT ``ITensor`` objects. 
These inputs and outputs should correspond to the schema\nof the target PyTorch operator, which can be found here: [https://pytorch.org/docs/main/torch.compiler_ir.html](https://pytorch.org/docs/main/torch.compiler_ir.html).\n\nSince Torch-TensorRT covers the core ATen opset, it has already abstracted many of the common low-level operations into helper functions that can be used to build up the TensorRT network.\nThis allows developers to avoid the boilerplate of creating the TensorRT layers directly and instead focus on the high-level logic of the conversion.\nThe helper functions are located in the ``torch_tensorrt.dynamo.conversion.impl`` module and are designed to be composable and interoperable with raw-TensorRT implementations.\nIn this case, we will use the Torch-TensorRT ``mul``, ``add`` and ``tanh`` functions from ``impl`` to implement our alternative GeLU layer.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def aten_ops_gelu(\n    ctx: ConversionContext,\n    target: Target,\n    args: Tuple[Argument, ...],\n    kwargs: Dict[str, Argument],\n    name: str,\n) -> Union[trt.ITensor, Sequence[trt.ITensor]]:\n    # The schema for torch.ops.aten.gelu.default is gelu(Tensor self, *, str approximate='none') -> Tensor\n\n    from torch_tensorrt.dynamo import SourceIR\n    from torch_tensorrt.dynamo.conversion import impl\n\n    # Cheap way to allow layer names to be unique\n    op_count = 0\n\n    def get_op_count():\n        nonlocal op_count\n        op_count += 1\n        return op_count\n\n    mul = lambda x, y: impl.elementwise.mul(\n        ctx,\n        target,\n        name=f\"mul_{get_op_count()}\",\n        source_ir=SourceIR.ATEN,\n        lhs_val=x,\n        rhs_val=y,\n    )\n    add = lambda x, y: impl.elementwise.add(\n        ctx,\n        target,\n        name=f\"add_{get_op_count()}\",\n        source_ir=SourceIR.ATEN,\n        lhs_val=x,\n        rhs_val=y,\n    )\n    tanh = lambda x: impl.activation.tanh(\n        ctx, target, name=f\"tanh_{get_op_count()}\", source_ir=SourceIR.ATEN, input_val=x\n    )\n\n    # So we know 
that our custom converter is being run instead of the standard one\n    print(\"\\n\\n---------------------------\")\n    print(\"Using custom GeLU converter\")\n    print(\"---------------------------\\n\\n\")\n\n    x_7 = mul(args[0], 0.5)\n    x_8 = mul(args[0], 0.79788456080000003)\n    x_9 = mul(args[0], 0.044714999999999998)\n    x_10 = mul(x_9, args[0])\n    x_11 = add(x_10, 1.0)\n    x_12 = mul(x_8, x_11)\n    x_13 = tanh(x_12)\n    x_14 = add(x_13, 1.0)\n    x_15 = mul(x_7, x_14)\n\n    return x_15" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using our Custom Converter\n\nWe can now recompile and see that our custom converter is being called to convert GeLU to TensorRT.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "my_custom_gelu = torch_tensorrt.compile(\n    my_mod, arg_inputs=(ex_input,), min_block_size=1\n)\nwith torch.no_grad():\n    print(my_custom_gelu.graph)\n    print(my_custom_gelu(ex_input))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can verify that our implementation matches the TensorRT implementation for the ``tanh`` approximation.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(\n    f\"tanh approximations are close: {torch.allclose(my_standard_gelu(ex_input), my_custom_gelu(ex_input))}\"\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we want to verify that in the case that the ``approximate`` argument is not set to ``tanh``, our custom converter is not used.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "my_mod_erf = GeLU(mode=\"none\").to(\"cuda\").eval()\nmy_gelu_erf = torch_tensorrt.compile(\n    my_mod_erf, arg_inputs=(ex_input,), min_block_size=1\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that we don't see the print statement from our custom converter, indicating 
that it was not used. However, looking at the graph, we can still see that a TensorRT engine was created to run the GeLU operation.\nIn this case, the validator for our custom converter returned ``False``, so the conversion system moved on to the next converter in the list, the standard GeLU converter, and used that one to convert the operation.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with torch.no_grad():\n    print(my_gelu_erf.graph)\n    print(my_gelu_erf(ex_input))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 0 }