{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Optimizing models using the PyTorch JIT\n", "\n", "*Thomas Viehmann, MathInf GmbH*\n", "\n", "Today we look at TorchScript, the language implemented by the PyTorch JIT (\"Just in Time compiler\"), PyTorch's solution for deployment and model optimization. \n", "We can use it to export models to work beyond Python, e.g. on mobile or embedded platforms, or just to escape the infamous Python Global Interpreter Lock during computation. This is possibly the more well-known application.\n", "\n", "But the JIT also lends itself to the implementation _holistic_ optimizations that consider several operations at once. This is as opposed to just writing a better implementation of any given PyTorch operation, although the JIT works for these, too, as we will see.\n", "\n", "We will start with a high-level overview of how PyTorch and the JIT work to then dive into the how it enables compiling fused kernels to optimize models at run time.\n", "\n", "*Sidenote:* If you want to take a look at exporting models, do check out Chapter 15 of our [book](https://www.manning.com/books/deep-learning-with-pytorch), from which I also took some diagrams below. There we introduce the JIT with a view towards running the model in C++ and on mobile. The book also as a comprehensive introduction from everything PyTorch to how to represent data and a detailed account of project to build an AI detecting cancerous lung nodules.\n", "\n", "This tutorial has been prepared in the context of work I did for AMD. Thank you!\n", "\n", "**Note:** This is the Notebook version of a [blog post](https://lernapparat.de/jit-optimization-intro/) on the subject." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The first thing we want to do when considering how the JIT works is consider the structure of PyTorch.\n", "\n", "\n", "(image from Deep Learning with PyTorch)\n", "\n", "PyTorch most prominently is a PyTorch library, I call this part classic PyTorch. Some parts are implemented in Python (e.g. the `torch.nn` modules and the optimizers), but the compute functions (like `torch.matmul`) are provided as a Python C++ extension.\n", "\n", "Looking a bit closer, this Python C++ extension is a thin wrapper around PyTorch's C++ library _LibTorch_. That in turn uses the ATen tensor library which itself dispatches into various backends.\n", "\n", "The PyTorch JIT now implements a virtual machine that takes in TorchScript programs (typically created through the `torch.jit` ) and runs them by calling into LibTorch itself, circumventing the Python parts.\n", "\n", "The JIT also is extendable by defining _Custom Ops_, we'll get back to this. To run PyTorch-exported programs in Torch Mobile or Torch Serving, the typical thing is to implement a wrapper around the JIT api to load and run modules.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## TorchScript\n", "\n", "Now that we know that we want to run our model in the JIT execution, we should see how to get our model into TorchScript, the form the JIT can process.\n", "\n", "*Sidenote:* TorchScript is used simultaneously for the language - mostly a typed subset of Python - and the representation (intermediate - IR).\n", "\n", "There are two main ways of achieving this (but they can be mixed), _scripting_ and _tracing_. Let's look at them." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [ { "data": { "text/plain": [ "'Vega 20 [Radeon VII]'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sys\n", "#sys.path.insert(0, '/home/tv/pytorch/pytorch/build/lib.linux-x86_64-3.9/')\n", "import torch\n", "\n", "%matplotlib inline\n", "from matplotlib import pyplot\n", "import numpy\n", "\n", "assert torch.cuda.is_available(), \"Some examples need the GPU\"\n", "torch.cuda.get_device_name()\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Scripting\n", "\n", "Scripting compiles (mostly) a subset of Python.\n", "It takes the Python source code and transforms it. \n", "\"Here is what the function should do\", just like normal programming.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "@torch.jit.script\n", "def fn(x):\n", " return x * 2" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,\n", " graph(%x.1 : Tensor):\n", " %2 : int = prim::Constant[value=2]() # :3:15\n", " %3 : Tensor = aten::mul(%x.1, %2) # :3:11\n", " return (%3))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fn, fn.graph" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Tracing\n", "\n", "Tracing runs the code and observers the calls into PyTorch with some sample input.\n", "\"Watch me, now you know how to do the same.\"" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "graph(%x : Float(5, strides=[1], requires_grad=0, device=cpu)):\n", " %1 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # :2:0\n", " %2 : Float(5, strides=[1], requires_grad=0, device=cpu) = aten::mul(%x, %1) # :2:0\n", " return (%2)\n", " def fn(x: Tensor) -> Tensor:\n", " return torch.mul(x, CONSTANTS.c0)\n", "\n" ] } ], "source": [ "def fn(x):\n", " return x * 2\n", "fn = torch.jit.trace(fn, [torch.randn(5)])\n", "\n", "print(fn.graph, fn.code)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "N.B.: The specialization for the Tensor shape isn't relevant here and will be erased e.g. during saving of the model." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### What is TorchScript?\n", "\n", "Now that had a glimpse of TorchScript, what is it?\n", "\n", "One important difference between TorchScript and Python is that in TorchScript everything is typed. Important\n", "types are\n", "- `bool`, `int`, `long`, `double` for numbers (int = 32 bit integer, long = 64 bit integer)\n", "- `Tensor` for tensors (of arbitrary shape, dtype, ...)\n", "- `List[T]` a list with elements fo type T (one of the above)\n", "- Tuples are of fixed size with arbitrary but fixed element type, so e.g. 
`Tuple[Tensor, int]`.\n", "- `Optional[T]` for things that can be `None`\n", "\n", "`None` always is of type `Optional[T]` for some specific `T` (except in the rarest circumstances).\n", "\n", "PyTorch will mostly infer the intermediate and return types, but you need to annotate any non-Tensor inputs.\n", "\n", "Another important difference is the binding behaviour - when a given variable name is looked up to find the associated variable. Python uses late binding. If we write a function that calls `torch.matmul`, the Python interpreter will look up what `torch.matmul` is when it executes the statement in which it is used.\n", "\n", "This is in contrast to many other languages, which use early binding, as - you guessed it - TorchScript does: When we compile a function to TorchScript, the JIT looks it up then and there and puts it into our function (it even inlines the called functions, but that is a separate step).\n", "*Sidenote:* And while functions are looked up early, the *operators* being executed by the PyTorch JIT are found at runtime.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Tracing vs. Scripting\n", "\n", "Scripting will process all code but may not understand all of it. This means it captures all constructs (like control flow) it understands, but it will fail if it doesn't understand something.\n", "\n", "Tracing doesn't see anything that does not call into PyTorch and will happily ignore it (e.g. control flow). This is also the reason why it will loudly complain if you have non-tensor inputs.\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "def fn(x):\n", " for i in range(x.dim()):\n", " x = x * x\n", " return x\n", "\n", "script_fn = torch.jit.script(fn)\n", "trace_fn = torch.jit.trace(fn, [torch.randn(5, 5)])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def fn(x: Tensor) -> Tensor:\n", " x0 = x\n", " for i in range(torch.dim(x)):\n", " x0 = torch.mul(x0, x0)\n", " return x0\n", "\n" ] } ], "source": [ "print(script_fn.code)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def fn(x: Tensor) -> Tensor:\n", " x0 = torch.mul(x, x)\n", " return torch.mul(x0, x0)\n", "\n" ] } ], "source": [ "print(trace_fn.code)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Tracing and Scripting Modules\n", "\n", "But our models often are not functions. What now?\n", "\n", "With tracing, we can work just like with functions. We get a `ScriptModule` subclass that behaves much like a\n", "`Module` with parameters, state dict etc."
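,
    "\n",
    "\n",
    "As a quick callback to the typing discussion above: non-Tensor arguments of a scripted function get ordinary Python annotations. A minimal sketch (name and logic are made up for illustration):\n",
    "\n",
    "```python\n",
    "@torch.jit.script\n",
    "def repeat_add(x: torch.Tensor, offset: float, repeats: int) -> torch.Tensor:\n",
    "    # non-Tensor inputs need annotations; intermediate and return types are inferred\n",
    "    for _ in range(repeats):\n",
    "        x = x + offset\n",
    "    return x\n",
    "```\n"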
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.jit._trace.TopLevelTracedModule,\n", " Sequential(\n", " original_name=Sequential\n", " (0): Linear(original_name=Linear)\n", " (1): ReLU(original_name=ReLU)\n", " (2): Linear(original_name=Linear)\n", " ))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = torch.nn.Sequential(\n", " torch.nn.Linear(1, 10),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(10, 1))\n", " \n", "traced_model = torch.jit.trace(model, [torch.randn(8, 1)])\n", "type(traced_model), traced_model" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Saving is a bit different, here we include the model on purpose:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1.0100],\n", " [0.4957],\n", " [0.5004],\n", " [0.6980],\n", " [0.8027],\n", " [0.5387],\n", " [0.6841],\n", " [0.7053]], grad_fn=)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "traced_model.save('./traced_model.pt')\n", "loaded_model = torch.jit.load('./traced_model.pt')\n", "\n", "loaded_model(torch.randn(8,1))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Scripting Modules\n", "\n", "Scripting modules is ... a bit tricky. We don't script the class in its entirety but instead take an instance (in particular past `__init__`) and process its data members and methods (the latters work like script functions)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def forward(self,\n", " input: Tensor) -> Tensor:\n", " _0 = getattr(self, \"0\")\n", " _1 = getattr(self, \"1\")\n", " _2 = getattr(self, \"2\")\n", " input0 = (_0).forward(input, )\n", " input1 = (_1).forward(input0, )\n", " return (_2).forward(input1, )\n", "\n" ] } ], "source": [ "scripted_model = torch.jit.script(model)\n", "print(scripted_model.code)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "We can also look at the graph including submodules, but it gets unwieldy rather fast:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "graph(%self : __torch__.torch.nn.modules.container.___torch_mangle_13.Sequential,\n", " %input.1 : Tensor):\n", " %2 : __torch__.torch.nn.modules.linear.___torch_mangle_10.Linear = prim::GetAttr[name=\"0\"](%self)\n", " %3 : __torch__.torch.nn.modules.activation.___torch_mangle_11.ReLU = prim::GetAttr[name=\"1\"](%self)\n", " %4 : __torch__.torch.nn.modules.linear.___torch_mangle_12.Linear = prim::GetAttr[name=\"2\"](%self)\n", " %8 : int = prim::Constant[value=1]()\n", " %9 : int = prim::Constant[value=2]() # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:22\n", " %10 : Tensor = prim::GetAttr[name=\"weight\"](%2)\n", " %11 : Tensor = prim::GetAttr[name=\"bias\"](%2)\n", " %12 : int = aten::dim(%input.1) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:7\n", " %13 : bool = aten::eq(%12, %9) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:7\n", " %input.3 : Tensor = prim::If(%13) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:4\n", " block0():\n", " %15 : Tensor = aten::t(%10) # 
/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1665:39\n", " %ret.2 : Tensor = aten::addmm(%11, %input.1, %15, %8, %8) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1665:14\n", " -> (%ret.2)\n", " block1():\n", " %17 : Tensor = aten::t(%10) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1667:30\n", " %output.2 : Tensor = aten::matmul(%input.1, %17) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1667:17\n", " %output.4 : Tensor = aten::add_(%output.2, %11, %8) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1669:12\n", " -> (%output.4)\n", " %input.5 : Tensor = aten::relu(%input.3) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1111:17\n", " %21 : int = prim::Constant[value=1]()\n", " %22 : int = prim::Constant[value=2]() # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:22\n", " %23 : Tensor = prim::GetAttr[name=\"weight\"](%4)\n", " %24 : Tensor = prim::GetAttr[name=\"bias\"](%4)\n", " %25 : int = aten::dim(%input.5) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:7\n", " %26 : bool = aten::eq(%25, %22) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:7\n", " %input.7 : Tensor = prim::If(%26) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:4\n", " block0():\n", " %28 : Tensor = aten::t(%23) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1665:39\n", " %ret.1 : Tensor = aten::addmm(%24, %input.5, %28, %21, %21) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1665:14\n", " -> (%ret.1)\n", " block1():\n", " %30 : Tensor = aten::t(%23) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1667:30\n", " %output.1 : Tensor = aten::matmul(%input.5, %30) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1667:17\n", " %output.3 : Tensor = aten::add_(%output.1, %24, %21) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1669:12\n", " -> (%output.3)\n", " return (%input.7)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scripted_model.forward.inlined_graph" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# What can you do with scripted modules?\n", "\n", "- Run them as is, bypassing Python.\n", " - not as much speedup as often is expected (maybe 5%-10% for some models I tested),\n", " - but - sometimes crucially - it avoids the dreaded Python Global Interpreter Lock (GIL), so it is useful e.g.\n", " for multithreaded things like serving PyTorch models.\n", "- Export and run in C++ / Mobile / ..., export to other frameworks like [TVM](https://tvm.ai/).\n", "- Apply holistic optimizations (this is what a submodule, the JIT fuser does)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# How the JIT works at a very high level\n", "\n", "For the in-depth discussion of fusers it will be useful to look closer at how the JIT works under the hood.\n", "The JIT has several phases to get us from a function to running our programs. For our purposes, we think of the following three stages:\n", "\n", "- The first thing is to go from tracing or source to a graph.\n", "- Then there are a number of compiler passes through the graph to go from `.graph` to an optimized graph (that can be retrieved with `.graph_for(*inputs)`. 
We will meet some of them in detail below.\n", "- Finally, the `.graph` is compiled to a form of bytecode that is then executed by a virtual machine. We might hope to not meet the bytecode too often, but clearly we want this part to be fast, too. The virtual machine maintains the operands on a stack and then dispatches to the various operators registered by LibTorch or the _custom operators_ that extend the JIT.\n", "\n", "The unoptimized `.graph` is the \"household\" format here; in particular, this is what is serialized, and loading a scripted function will then have to re-do the optimizations.\n", "\n", "## Tracing or scripting to a .graph\n", "\n", "- When tracing a function, the LibTorch dispatcher will call a special function (found in `torch/csrc/autograd/generated/TraceTypeEverything.cpp` after you have built PyTorch) for every call of a LibTorch function (*Sidenote*: For more on the dispatcher, see Ed Yang's excellent blog post [Let's talk about the PyTorch Dispatcher](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/)). This special function will record a graph node (the ones that show up in `.graph`) with source location and type information in a `TracingState` structure's `.graph`, and in between it re-dispatches to run the real LibTorch operation. This `.graph` is more or less directly what you can see as `.graph` of a traced function.\n", "\n", "- When tracing modules, the tracer will also hook into the module `__call__` method to record the current module as the scope to capture the module structure. This is done at the Python level in the `torch.nn.Module` class; see the `_slow_forward` method there.\n", "\n", "- When scripting a function from Python, the JIT grabs the Python source code (via the `inspect` module of the standard Python library) and then runs the Python parser from the `ast` (Abstract Syntax Tree) module. It then transforms the Python AST into a TorchScript (implemented in C++) one, from which it builds an initial graph form that still looks a lot like Python (i.e. before converting to [static single assignment](https://en.wikipedia.org/wiki/Static_single_assignment_form) (SSA) form). Any name lookup is also done at this stage (so TorchScript is (mostly) [statically binding](https://en.wikipedia.org/wiki/Name_resolution_(programming_languages)) rather than dynamically binding like Python), representing objects as _Sugared Values_ in between. Finally, the JIT transforms the graph into the SSA form that you can see with `.graph`.\n", "\n", "- There is a variant of scripting that can be called directly from C++ and does not use the Python `ast` but parses Python on its own. This is used internally by `AutoDiff` but is also a neat trick to use from C++.\n", "\n", "\n", "## Optimization passes\n", "\n", "The JIT compiler gets us from `.graph` to what we see with `.graph_for` above by running a series of optimization (and some other) passes. This is done by the JIT's GraphExecutor (actually there are two, the \"regular\" one and the profiling one) on the first run or first few runs in the case of the profiling executor. 
The optimized graphs are cached along with the bytecode.\n", "\n", "There are a number of passes that work without messing with AutoGrad, for example (this is not a complete list, and there are also analysis passes for shapes, types, and such):\n", "\n", "- Eliminating dead code and common subexpressions, pre-computing things that only involve constants,\n", "- Pooling redundant constants into single values, and some simple \"pattern matching\" optimizations (like eliminating `.t().t()`),\n", "- Unrolling small loops and batching matrix multiplications that result from unrolling loops.\n", "\n", "If the last one looks awfully special, it is, but it is quite commonly used in recurrent networks such as LSTMs with the input weights.\n", "\n", "As you might have guessed from the introduction, there are also some passes that can mess up AutoGrad, and we can only do them if we do not require gradients or have taken care of AutoGrad before. \n", "\n", "## Bytecode and execution\n", "\n", "Finally, the optimized graph is lowered to bytecode and run by the virtual machine. The virtual machine can also do function calls; this is used e.g. by the fallback mechanisms of the fusers. We will not deal much with this part.\n", "\n", "So this gives you a very high-level overview of what goes on in the JIT. As usual, things get complicated really soon, and the JIT is actively being worked on, making this a bit of a moving target in the details. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Excursion: GPU, efficiency, measurement\n", "\n", "Before we discuss optimization through the JIT we have to discuss measurement. In fact, one of my many informal mottos is _It's not optimization until you measure_. I'll only discuss the most basic measurement here; PyTorch offers a capable profiling facility, too.\n", "\n", "When you think about code being slow, it's important to figure out what is slow and why.\n", "To my mind, a lot of measurement can be done with very basic tools, e.g. IPython's `%timeit` magic.\n", "\n", "As GPU computation is and should be asynchronous, avoid unneeded synchronization points. Synchronization happens when the CPU waits for the GPU (to get the results).\n", "- Synchronizations can happen because the program needs to know something (e.g. sizes of tensors depending on\n", " the input). Often, these are unavoidable.\n", "- Typical sources of spurious synchronizations are too frequent \n", " `.to(device=\"cpu\")`, `.item()`, `.tolist()`, `print`.\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "If we want to time GPU kernels, we want to be sure to synchronize before taking the start and end times.\n", "Typically, we also want to have some \"warm-up\", i.e. 
run the measured function before timing.\n", "\n", "Let's take the uniformity loss from [Wang and Isola: Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere](https://arxiv.org/abs/2005.10242) (a great paper!).\n", "\n", "The uniformity loss is defined as a function of the pairwise distances over a largish set of vectors.\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/plain": [ "tensor(-3.9374, device='cuda:0', grad_fn=)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def lunif(x, t=2): # copied from the paper\n", " sq_pdist = torch.pdist(x, p=2).pow(2)\n", " return sq_pdist.mul(-t).exp().mean().log()\n", "\n", "x = torch.randn(1024, 128, device=\"cuda\")\n", "x /= x.norm(p=2, dim=1, keepdim=True).requires_grad_()\n", "\n", "lunif(x)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "One would think that the specialised `pdist` function is the right tool for the job.\n", "But is it? Let's time it." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "18.6 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "def totime(fn):\n", " l = fn(x)\n", " g, = torch.autograd.grad(l, x)\n", " torch.cuda.synchronize()\n", "\n", "totime(lunif) # warmup\n", "%timeit totime(lunif)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Let's use $|x-y|^2 = |x|^2 + |y|^2 - 2 \\langle x, y \\rangle$ and compare." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.0\n", "2.19 ms ± 9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "def lunif2(x, t=2):\n", " t=2\n", " xnorm = torch.norm(x, p=2, dim=1).pow(2)\n", " sq_pdist = xnorm[None] + xnorm[:, None] - 2 * torch.mm(x, x.t())\n", " exp = sq_pdist.mul(-t).exp().tril(diagonal=-1)\n", " N = x.size(0)\n", " res = exp.sum().mul(2/(N*N-N)).log()\n", " return res\n", "\n", "print((lunif2(x.to(torch.double)) - lunif(x.to(torch.double))).item())\n", "\n", "totime(lunif2)\n", "%timeit totime(lunif2)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "Even though we have stark inefficiencies (like taking `tril` and making a copy to do so), this is almost an order of magnitude faster!\n", "\n", "This is largely due to the implementation of `pdist`'s backward."
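,
    "\n",
    "\n",
    "If you prefer not to rely on the `%timeit` magic, CUDA events give you timings directly. A minimal sketch of such a helper (the name and defaults are mine, not a PyTorch API):\n",
    "\n",
    "```python\n",
    "def cuda_time_ms(fn, *args, warmup=3, reps=10):\n",
    "    # warm-up so that one-time costs (compilation, caching) stay out of the measurement\n",
    "    for _ in range(warmup):\n",
    "        fn(*args)\n",
    "    start = torch.cuda.Event(enable_timing=True)\n",
    "    end = torch.cuda.Event(enable_timing=True)\n",
    "    torch.cuda.synchronize()\n",
    "    start.record()\n",
    "    for _ in range(reps):\n",
    "        fn(*args)\n",
    "    end.record()\n",
    "    torch.cuda.synchronize()  # wait for the GPU before reading the clocks\n",
    "    return start.elapsed_time(end) / reps  # milliseconds per call\n",
    "```\n",
    "\n",
    "Usage would be e.g. `cuda_time_ms(lunif2, x)` for the forward pass only (unlike `totime`, which also measures the backward)."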
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Optimization\n", "\n", "### But Python is slow...\n", "\n", "Uniformity loss: \"what formulas you use\" is the real bottleneck (unless you optimize pdist).\n", "\n", "The \"what do we compute\" typically should be the first optimization target.\n", "\n", "But when we fix the task (\"what\"), how can we optimize?\n", "\n", "Conventional wisdom: **Python is slow**\n", "\n", "- certainly, Python isn't fast (`for` loop vs C++ `for` loop)\n", "- but, if the GPU is saturated $\\Rightarrow$ Python isn't the bottleneck\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### How PyTorch programs spend their time\n", "\n", "At a very high level, you can divide time spent into these parts:\n", "- Python program flow,\n", "- Data \"administrative overhead\" (creating `Tensor` data structures, autograd `Node`s etc.),\n", "- Data aquisition (I/O),\n", "- Computation roughly as\n", " - fixed overhead (kernel launches etc.),\n", " - reading / writing memory,\n", " - \"real computation\".\n", "\n", "**Thomas' rule of thumb**: As long as your operands are reasonably large (say 100s of elements, not single elements), Python and data \"administrative overhead\" probably isn't your main problem.\n", "\n", "So while the JIT takes away some Python overhead, this is not spectacular optimization.\n", "With this out of the way, let us get back to how the JIT helps us optimize things." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "## An ad-hoc graph plotter (skip this)\n", "\n", "It will be handy to draw some graphs, so here is a function that plots our graphs. It's not complete by any means, but it helps us here." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "def make_graph(gr):\n", " import graphviz\n", " dot = graphviz.Digraph(format='svg', graph_attr={'labelloc': 't'})\n", "\n", " nodes = {}\n", " for i in gr.inputs():\n", " nname = i.debugName()\n", " label = nname.split('.')[0]\n", " nodes[nname] = (nname, dot)\n", " dot.node(nname, label, color='blue')\n", "\n", " unseen_ops = {'prim::ListConstruct', 'aten::index', \n", " 'aten::size', 'aten::slice', 'aten::unsqueeze', 'aten::squeeze',\n", " 'aten::to', 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous',\n", " 'aten::permute', 'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', 'aten::unbind',\n", " 'aten::select', 'aten::detach', 'aten::stack', 'aten::reshape', 'aten::split_with_sizes',\n", " 'aten::cat', 'aten::expand', 'aten::expand_as', 'aten::_shape_as_tensor',\n", " 'aten::_size_if_not_equal', 'prim::BroadcastSizes',\n", " 'prim::Constant',\n", " }\n", "\n", " def process_block(nodeit, dot):\n", " firstnode = None\n", " lastnode = None\n", " for n in nodeit:\n", " k = n.kind()\n", " outs = list(n.outputs())\n", " inps = list(n.inputs())\n", " type_outs = [o.type().kind() for o in outs]\n", " type_inps = [o.type().kind() for o in inps]\n", " if k == 'prim::If':\n", " label = 'If'\n", " nname = outs[0].debugName()\n", " for i in inps:\n", " src, srcdot = nodes.get(i.debugName(), (None, None))\n", " if src is not None:\n", " srcdot.edge(src, nname + '_in')\n", " dot.node(nname + '_in', 'If', shape='diamond')\n", " dot.node(nname, '', width='0.1', height='0.1')\n", " dot.edge(nname + '_in', nname, style='invis')\n", " nodes[nname] = (nname, dot)\n", " bl = list(n.blocks())\n", " for i, b in enumerate(bl):\n", " with dot.subgraph(name=f\"cluster_{nname}_{i}\", graph_attr={'label':''}) as sub_dot:\n", " firstnode, lastnode = process_block(b.nodes(), sub_dot)\n", " dot.edge(nname + '_in', firstnode, label=\"yn\"[i])\n", " dot.edge(lastnode, nname)\n", " if firstnode is None:\n", " firstnode = nname + '_in'\n", " lastnode = nname\n", " elif k == 'prim::DifferentiableGraph':\n", " label = 'DifferentiableGraph'\n", " nname = outs[0].debugName()\n", " nodes[nname] = (nname, dot)\n", " sg = n.g('Subgraph')\n", " nis = list(n.inputs())\n", " sgis = list(sg.inputs())\n", " assert len(nis) == len(sgis)\n", " for ni, sgi in zip(nis, sgis):\n", " if ni.debugName() in nodes:\n", " nodes[sgi.debugName()] = nodes[ni.debugName()]\n", " with dot.subgraph(name=f\"cluster_{nname}\", graph_attr={\n", " 'label': 'DifferentiableGraph', 'labelloc':'b', 'labeljust':'r'}) as sub_dot:\n", " firstnode, lastnode = process_block(sg.nodes(), sub_dot)\n", " nos = list(n.outputs())\n", " sgos = list(sg.outputs())\n", " assert len(nos) <= len(sgos)\n", " for no, sgo in zip(nos, sgos):\n", " if sgo.debugName() in nodes:\n", " nodes[no.debugName()] = (nodes[sgo.debugName()][0], dot)\n", " elif k not in unseen_ops:\n", " if k == 'prim::CallFunction':\n", " label = 'call ' + next(n.inputs()).node().s(\"name\")\n", " else:\n", " label = k.replace('aten::', '').replace('prim::', '')\n", " nname = outs[0].debugName()\n", " dot.node(nname, label, shape='box', style='rounded')\n", " for o in outs:\n", " nodes[o.debugName()] = (nname, dot)\n", " for i in inps:\n", " src, srcdot = nodes.get(i.debugName(), (None, None))\n", " if src is not None:\n", " srcdot.edge(src, nname)\n", " if firstnode is None:\n", " firstnode = nname\n", " lastnode = nname\n", " return firstnode, 
lastnode\n", "\n", " process_block(gr.nodes(), dot)\n", " dot.node('.outputs', 'outputs', color='blue')\n", " for i, o in enumerate(gr.outputs()):\n", " src, srcdot = nodes.get(o.debugName(), (None, None))\n", " if src is not None:\n", " dot.edge(src, '.outputs')\n", "\n", " return dot\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Holistic Optimizations - JIT fusers\n", "\n", "So currently the fuser is a hotspot of development, and PyTorch has no fewer than three fusers:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on function fuser in module torch.jit._fuser:\n", "\n", "fuser(name)\n", " A context manager that facilitates switching between\n", " backend fusers.\n", " \n", " Valid names:\n", " * ``fuser0`` - enables only legacy fuser\n", " * ``fuser1`` - enables only NNC\n", " * ``fuser2`` - enables only nvFuser\n", "\n" ] } ], "source": [ "help(torch.jit.fuser)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# How the JIT optimizes pointwise operations\n", "\n", "\n", "\n", "To get a taste of how the JIT fuser works, let us look at the intersection over union ratio for detection models.\n", "We have a two lists of rectangles given by the top left (as x and y coordinates) and width and height.\n", "To measure the pairwise agreement of the $i$th rectangle in the first and in the second list.\n", "We do this by the intersection over union ratio which computes the areas of the intersection and the union of the two rectangles. The quotient of the two is between 0 (no agreement at all) and 1 (perfect agreement).\n", "\n", "*Sidenote*: Another prominent example of pointwise operations is in LSTMs: They can be though of as two matrix multiplications followed by a series of pointwise operations for the gates. The case of LSTMs has been a show case for the JIT\n", "[show case](https://lernapparat.de/fast-lstm-pytorch/) [for JIT](https://lernapparat.de/more-jit-optimizations/) [optimizations](https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/)." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):\n", " xi = torch.max(x1, x2) # Intersection left\n", " yi = torch.max(y1, y2) # Intersection top\n", " wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.) # Intersection width\n", " hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.) # Intersection height\n", " area_i = wi * hi # Area Intersection\n", " area_u = w1 * h1 + w2 * h2 - wi * hi # Area Union\n", " return area_i / torch.clamp(area_u, min=1e-5) # Intersection over Union\n", "\n", "# we make a scripted function\n", "ratio_iou_scripted = torch.jit.script(ratio_iou)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a simple enough function with elementwise computation. Let us look at the function graph." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "\n", "x1.1\n", "\n", "x1\n", "\n", "\n", "\n", "xi.1\n", "\n", "max\n", "\n", "\n", "\n", "x1.1->xi.1\n", "\n", "\n", "\n", "\n", "\n", "17\n", "\n", "add\n", "\n", "\n", "\n", "x1.1->17\n", "\n", "\n", "\n", "\n", "\n", "y1.1\n", "\n", "y1\n", "\n", "\n", "\n", "yi.1\n", "\n", "max\n", "\n", "\n", "\n", "y1.1->yi.1\n", "\n", "\n", "\n", "\n", "\n", "32\n", "\n", "add\n", "\n", "\n", "\n", "y1.1->32\n", "\n", "\n", "\n", "\n", "\n", "w1.1\n", "\n", "w1\n", "\n", "\n", "\n", "w1.1->17\n", "\n", "\n", "\n", "\n", "\n", "48\n", "\n", "mul\n", "\n", "\n", "\n", "w1.1->48\n", "\n", "\n", "\n", "\n", "\n", "h1.1\n", "\n", "h1\n", "\n", "\n", "\n", "h1.1->32\n", "\n", "\n", "\n", "\n", "\n", "h1.1->48\n", "\n", "\n", "\n", "\n", "\n", "x2.1\n", "\n", "x2\n", "\n", "\n", "\n", "x2.1->xi.1\n", "\n", "\n", "\n", "\n", "\n", "21\n", "\n", "add\n", "\n", "\n", "\n", "x2.1->21\n", "\n", "\n", "\n", "\n", "\n", "y2.1\n", "\n", "y2\n", "\n", "\n", "\n", "y2.1->yi.1\n", "\n", "\n", "\n", "\n", "\n", "36\n", "\n", "add\n", "\n", "\n", "\n", "y2.1->36\n", "\n", "\n", "\n", "\n", "\n", "w2.1\n", "\n", "w2\n", "\n", "\n", "\n", "w2.1->21\n", "\n", "\n", "\n", "\n", "\n", "51\n", "\n", "mul\n", "\n", "\n", "\n", "w2.1->51\n", "\n", "\n", "\n", "\n", "\n", "h2.1\n", "\n", "h2\n", "\n", "\n", "\n", "h2.1->36\n", "\n", "\n", "\n", "\n", "\n", "h2.1->51\n", "\n", "\n", "\n", "\n", "\n", "25\n", "\n", "sub\n", "\n", "\n", "\n", "xi.1->25\n", "\n", "\n", "\n", "\n", "\n", "40\n", "\n", "sub\n", "\n", "\n", "\n", "yi.1->40\n", "\n", "\n", "\n", "\n", "\n", "22\n", "\n", "min\n", "\n", "\n", "\n", "17->22\n", "\n", "\n", "\n", "\n", "\n", "21->22\n", "\n", "\n", "\n", "\n", "\n", "22->25\n", "\n", "\n", "\n", "\n", "\n", "wi.1\n", "\n", "clamp\n", "\n", "\n", "\n", "25->wi.1\n", "\n", "\n", "\n", "\n", "\n", "area_i.1\n", "\n", "mul\n", "\n", "\n", "\n", "wi.1->area_i.1\n", "\n", "\n", "\n", "\n", "\n", "56\n", "\n", "mul\n", "\n", "\n", "\n", "wi.1->56\n", "\n", "\n", "\n", "\n", "\n", "37\n", "\n", "min\n", "\n", "\n", "\n", "32->37\n", "\n", "\n", "\n", "\n", "\n", "36->37\n", "\n", "\n", "\n", "\n", "\n", "37->40\n", "\n", "\n", "\n", "\n", "\n", "hi.1\n", "\n", "clamp\n", "\n", "\n", "\n", "40->hi.1\n", "\n", "\n", "\n", "\n", "\n", "hi.1->area_i.1\n", "\n", "\n", "\n", "\n", "\n", "hi.1->56\n", "\n", "\n", "\n", "\n", "\n", "64\n", "\n", "div\n", "\n", "\n", "\n", "area_i.1->64\n", "\n", "\n", "\n", "\n", "\n", "53\n", "\n", "add\n", "\n", "\n", "\n", "48->53\n", "\n", "\n", "\n", "\n", "\n", "51->53\n", "\n", "\n", "\n", "\n", "\n", "area_u.1\n", "\n", "sub\n", "\n", "\n", "\n", "53->area_u.1\n", "\n", "\n", "\n", "\n", "\n", "56->area_u.1\n", "\n", "\n", "\n", "\n", "\n", "63\n", "\n", "clamp\n", "\n", "\n", "\n", "area_u.1->63\n", "\n", "\n", "\n", "\n", "\n", "63->64\n", "\n", "\n", "\n", "\n", "\n", ".outputs\n", "\n", "outputs\n", "\n", "\n", "\n", "64->.outputs\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "make_graph(ratio_iou_scripted.graph)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "It is not complex as code, but it has quite a few operations. 
Now, in terms of execution, every of these ops launches a kernel (a function run on the GPU) that does three things:\n", "\n", "- Load the inputs (from the incoming edges) from memory,\n", "- compute the output,\n", "- store the result.\n", "\n", "These are 37 times loading inputs and 20 times storing outputs with only trivial computation.\n", "Clearly this is heavily limited by the memory transfers, even if we can get helped by caching.\n", "\n", "What if we could make it all into one large kernel and have 8 loads and 1 store?\n", "\n", "This is exactly what a fuser does and it does give us a good speedup:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "161 µs ± 938 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", "38.2 µs ± 485 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" ] } ], "source": [ "x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda').exp()\n", "\n", "def take_time(fn):\n", " _ = fn(x1, y1, w1, h1, x2, y2, w2, h2)\n", " torch.cuda.synchronize()\n", "\n", "take_time(ratio_iou) # warmup\n", "%timeit take_time(ratio_iou)\n", "\n", "for i in range(2):\n", " take_time(ratio_iou_scripted)\n", "%timeit take_time(ratio_iou_scripted)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "We can see in the graph specialised for the inputs which operations are fused:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "scrolled": false, "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "cluster_121_0\n", "\n", "\n", "\n", "cluster_121_1\n", "\n", "\n", "\n", "\n", "x1.1\n", "\n", "x1\n", "\n", "\n", "\n", "112\n", "\n", "TypeCheck\n", "\n", "\n", "\n", "x1.1->112\n", "\n", "\n", "\n", "\n", "\n", "147\n", "\n", "call fallback_function\n", "\n", "\n", "\n", "x1.1->147\n", "\n", "\n", "\n", "\n", "\n", "y1.1\n", "\n", "y1\n", "\n", "\n", "\n", "y1.1->112\n", "\n", "\n", "\n", "\n", "\n", "y1.1->147\n", "\n", "\n", "\n", "\n", "\n", "w1.1\n", "\n", "w1\n", "\n", "\n", "\n", "w1.1->112\n", "\n", "\n", "\n", "\n", "\n", "w1.1->147\n", "\n", "\n", "\n", "\n", "\n", "h1.1\n", "\n", "h1\n", "\n", "\n", "\n", "h1.1->112\n", "\n", "\n", "\n", "\n", "\n", "h1.1->147\n", "\n", "\n", "\n", "\n", "\n", "x2.1\n", "\n", "x2\n", "\n", "\n", "\n", "x2.1->112\n", "\n", "\n", "\n", "\n", "\n", "x2.1->147\n", "\n", "\n", "\n", "\n", "\n", "y2.1\n", "\n", "y2\n", "\n", "\n", "\n", "y2.1->112\n", "\n", "\n", "\n", "\n", "\n", "y2.1->147\n", "\n", "\n", "\n", "\n", "\n", "w2.1\n", "\n", "w2\n", "\n", "\n", "\n", "w2.1->112\n", "\n", "\n", "\n", "\n", "\n", "w2.1->147\n", "\n", "\n", "\n", "\n", "\n", "h2.1\n", "\n", "h2\n", "\n", "\n", "\n", "h2.1->112\n", "\n", "\n", "\n", "\n", "\n", "h2.1->147\n", "\n", "\n", "\n", "\n", "\n", "121_in\n", "\n", "If\n", "\n", "\n", "\n", "112->121_in\n", "\n", "\n", "\n", "\n", "\n", "68\n", "\n", "TensorExprGroup\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "121\n", "\n", "\n", "\n", "\n", "\n", "121_in->68\n", "\n", "\n", "y\n", "\n", "\n", 
"\n", "121_in->147\n", "\n", "\n", "n\n", "\n", "\n", "\n", ".outputs\n", "\n", "outputs\n", "\n", "\n", "\n", "121->.outputs\n", "\n", "\n", "\n", "\n", "\n", "68->121\n", "\n", "\n", "\n", "\n", "\n", "147->121\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "make_graph(ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Aha, so we do some type check and if that returns OK, we run a `TensorExprGroup`, which will be executed as one kernel. We keep a fallback just in case.\n", "In the text representation, we can actually see the `TensorExprGroup` and we can see which operations are fused:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "graph(%x1.1 : Tensor,\n", " %y1.1 : Tensor,\n", " %w1.1 : Tensor,\n", " %h1.1 : Tensor,\n", " %x2.1 : Tensor,\n", " %y2.1 : Tensor,\n", " %w2.1 : Tensor,\n", " %h2.1 : Tensor):\n", " %112 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %113 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %114 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %115 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %116 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %117 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %118 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %119 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %120 : bool = prim::TypeCheck(%w2.1, %h2.1, %w1.1, %h1.1, %y2.1, %y1.1, %x2.1, %x1.1)\n", " %121 : Tensor = prim::If(%120)\n", " block0():\n", " %68 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%112, %113, %114, %115, %116, %117, %118, %119)\n", " -> (%68)\n", " block1():\n", " %146 : Function = prim::Constant[name=\"fallback_function\", fallback=1]()\n", " %147 : (Tensor) = prim::CallFunction(%146, %w2.1, %h2.1, %w1.1, %h1.1, %y2.1, %y1.1, %x2.1, %x1.1)\n", " %148 : Tensor = prim::TupleUnpack(%147)\n", " -> (%148)\n", " return (%121)\n", "with prim::TensorExprGroup_0 = graph(%14 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %15 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %17 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %18 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %34 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %37 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %51 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %54 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0)):\n", " %4 : float = prim::Constant[value=1.0000000000000001e-05]()\n", " %42 : None = prim::Constant()\n", " %41 : float = prim::Constant[value=0.]()\n", " %55 : int = prim::Constant[value=1]()\n", " %xi.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::max(%54, %51) # :2:9\n", " %yi.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::max(%37, %34) # :3:9\n", " %56 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%54, %17, %55) # :4:31\n", " %53 : Float(100, 1000, strides=[1000, 1], requires_grad=0, 
device=cuda:0) = aten::add(%51, %14, %55) # :4:38\n", " %50 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::min(%56, %53) # :4:21\n", " %47 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%50, %xi.2, %55) # :4:21\n", " %wi.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%47, %41, %42) # :4:9\n", " %39 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%37, %18, %55) # :5:31\n", " %36 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%34, %15, %55) # :5:38\n", " %33 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::min(%39, %36) # :5:21\n", " %30 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%33, %yi.2, %55) # :5:21\n", " %hi.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%30, %41, %42) # :5:9\n", " %area_i.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%wi.2, %hi.2) # :6:13\n", " %19 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%17, %18) # :7:13\n", " %16 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%14, %15) # :7:23\n", " %13 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%19, %16, %55) # :7:13\n", " %area_u.2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%13, %area_i.2, %55) # :7:13\n", " %6 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%area_u.2, %4, %42) # :8:20\n", " %2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::div(%area_i.2, %6) # :8:11\n", " return (%2)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)" ] }, { "cell_type": "markdown", "metadata": { "scrolled": false }, "source": [ "We will look in some detail how these things work, but the core idea is that operations of the `TensorExprGroup` here will be compiled into a single kernel that then computes the result from the inputs in one go.\n", "\n", "## How the Fusers Work at a High Level\n", "\n", "At a high level, PyTorch's fusers work in three parts:\n", "\n", "- In a fusion JIT compiler pass, the operations that can be fused are arranged in a fusion group. By looking at which operations can be fused, we get a good glimpse of what the fusers (think) they can achieve. The classic (or legacy) PyTorch fuser only considers pointwise operations (like the IOU above, see `isSimpleMap` in `torch/csrc/jit/passes/graph_fuser.cpp`). The cuda fuser (or fuser2/nvFuser above), which is conceptually somewhat close but much more elaborate than the classic fuser also handles `sum` (see `IRParser`'s `registerJitOperator` in `torch/csrc/jit/codegen/cuda/parser.cpp`). The TensorExpr fuser (fuser1, the default) fuses pointwise and `softmax` and `log_softmax` in addition to `sum` if reduction support is enabled (see `isSupported` in `torch/csrc/jit/passes/tensorexpr_graph_fuser.cpp`). It generates a fusion group node of some sort, but, in the case of the newer two fusers also inserts a check (`TypeCheck` or ...) and an explicit fallback. 
Interestingly, the fusers also support `rand_like`, which is very useful functionality for things like dropout.\n", "\n", "- At some point (typically the first invocation of the fusion group), it compiles a kernel for the computation. Typically this is specific to (some aspects of) the type and shape of the inputs. For the GPU, the fusers emit HIP/CUDA C code and compile using the GPU RTC (run time compile) library. For the CPU the classic fuser would also use C, but the TensorExpr fuser uses an LLVM backend (but note that the CPU is much less of a target and the main use case is the GPU). These kernels are cached.\n", "\n", "- When running a fusion group (the fuser registers an operator with the JIT that is then called), the fuser needs to launch the kernel. For the newer fusers, checking whether the inputs match expectations is done outside this node, but the classic fuser would do the fallback itself if needed.\n", "\n", "One thing to know about the fallback is that it itself will be optimized by the PyTorch JIT. So when we run a function that has been optimized with fusions with incompatible parameters (e.g. we change whether we want gradients), the failing type check would cause the JIT to call the fallback, and that would then get the optimizations for these parameters (and another level of check and fallback).\n", "\n", "\n", "### Code generation from TorchScript IR to GPU kernel\n", "\n", "In addition to the operator support, the code generation is where each fuser has a different approach.\n", "\n", "The CUDA fuser first transforms the TorchScript IR in the CudaFusionGroup to a Fusion IR.\n", "This is then further lowered to the Kernel IR and finally translated to C++ code from which the\n", "runtime compiler generates the kernel. The approach is conceptually relatively straightforward: there are optimizations for how the data access is laid out, and then pointwise operators are just loading, computing and storing. For reductions, there is a heuristic for how to deal with the reduction axes (this is somewhat similar to TensorIterators in ATen, and, indeed, the use-case is quite similar but with the compile-time vs. run-time distinction). But, as these things go, to get good results, there are quite a few things to take care of.\n", "\n", "The TensorExpr fuser (which is inspired by the lower levels of the [Apache TVM](https://tvm.apache.org/)) translates the TorchScript IR into a sequence of [Loop-Nest](https://en.wikipedia.org/wiki/Loop_nest_optimization) statements (this is done in `torch/csrc/jit/tensorexpr/kernel.cpp`, which implements the operator processing the `TensorExprGroup` TorchScript IR node). This is the TensorExpr IR (the quickest overview over the IR node types can maybe be had by looking at `torch/csrc/jit/tensorexpr/ir_visitor.h`). The statements are then optimized and lowered before they are passed to the code generators (CUDA source code for the GPU or LLVM for the CPU) that write kernel functions and then compile and run them (again, with caching).\n", "\n", "\n", "## Automatic Differentiation in TorchScript\n", "\n", "Things are a bit more complicated if we need gradients. The default mode of the JIT is to execute the LibTorch operations, and they will build an autograd graph just like in classic PyTorch. But when we want to fuse operators, things get a bit more complicated. The problem here is that AutoGrad needs intermediate results to compute the backward. This is OK, but our express purpose here is to skip storing and loading the intermediate results. 
This is mitigated by the PyTorch JIT's own automatic differentiation (AD) mechanism, AutoDiff (as opposed to AutoGrad in PyTorch). \n", "\n", "We can see it in action when we re-define our function and run it with gradient-requiring inputs: we get a `DifferentiableGraph` in there and the `TensorExprGroup` is inside that (usually this would be created as part of the fallback function but to start fresh and see this better we have to re-define the function here, just re-scripting isn't enough to clear the script):" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "graph(%x1.1 : Tensor,\n", " %y1.1 : Tensor,\n", " %w1.1 : Tensor,\n", " %h1.1 : Tensor,\n", " %x2.1 : Tensor,\n", " %y2.1 : Tensor,\n", " %w2.1 : Tensor,\n", " %h2.1 : Tensor):\n", " %68 : Tensor = prim::DifferentiableGraph_0(%h2.1, %h1.1, %w2.1, %w1.1, %y2.1, %y1.1, %x2.1, %x1.1)\n", " return (%68)\n", "with prim::DifferentiableGraph_0 = graph(%65 : Tensor,\n", " %70 : Tensor,\n", " %96 : Tensor,\n", " %101 : Tensor,\n", " %104 : Tensor,\n", " %106 : Tensor,\n", " %109 : Tensor,\n", " %111 : Tensor):\n", " %617 : int[] = aten::size(%111) # :3:44\n", " %620 : int[] = aten::size(%109) # :3:93\n", " %624 : int[] = aten::size(%106) # :3:44\n", " %627 : int[] = aten::size(%104) # :3:93\n", " %634 : int[] = aten::size(%101) # :3:93\n", " %641 : int[] = aten::size(%96) # :3:93\n", " %655 : int[] = aten::size(%70) # :3:93\n", " %662 : int[] = aten::size(%65) # :3:93\n", " %903 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %904 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %905 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %906 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %907 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %908 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %909 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %910 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %911 : bool = prim::TypeCheck(%96, %65, %101, %70, %104, %106, %109, %111)\n", " %912 : Tensor, %913 : Tensor, %914 : Tensor, %915 : Tensor, %916 : Tensor, %917 : Tensor, %918 : Tensor, %919 : Tensor, %920 : Tensor, %921 : Tensor, %922 : Tensor, %923 : Tensor = prim::If(%911)\n", " block0():\n", " %830 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %832 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %area_u.4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %area_i.4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %hi.4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %846 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %850 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %852 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %wi.4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %856 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %860 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0), %862 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%903, %904, %905, %906, %907, %908, %909, %910)\n", " -> (%830, %832, %area_u.4, %area_i.4, %hi.4, %846, %850, %852, %wi.4, %856, %860, %862)\n", 
" block1():\n", " %959 : Function = prim::Constant[name=\"fallback_function\", fallback=1]()\n", " %960 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::CallFunction(%959, %96, %65, %101, %70, %104, %106, %109, %111)\n", " %961 : Tensor, %962 : Tensor, %963 : Tensor, %964 : Tensor, %965 : Tensor, %966 : Tensor, %967 : Tensor, %968 : Tensor, %969 : Tensor, %970 : Tensor, %971 : Tensor, %972 : Tensor = prim::TupleUnpack(%960)\n", " -> (%961, %962, %963, %964, %965, %966, %967, %968, %969, %970, %971, %972)\n", " %875 : int[] = aten::size(%912)\n", " %876 : int[] = aten::size(%913)\n", " %877 : int[] = aten::size(%914)\n", " %878 : int[] = aten::size(%915)\n", " %879 : int[] = aten::size(%916)\n", " %880 : int[] = aten::size(%917)\n", " %881 : int[] = aten::size(%918)\n", " %882 : int[] = aten::size(%919)\n", " %883 : int[] = aten::size(%920)\n", " %884 : int[] = aten::size(%921)\n", " %885 : int[] = aten::size(%922)\n", " %886 : int[] = aten::size(%923)\n", " %887 : int[] = prim::BroadcastSizes(%617, %620)\n", " %888 : int[] = prim::BroadcastSizes(%624, %627)\n", " %891 : int[] = prim::BroadcastSizes(%886, %885)\n", " %895 : int[] = prim::BroadcastSizes(%882, %881)\n", " %898 : int[] = prim::BroadcastSizes(%634, %655)\n", " %899 : int[] = prim::BroadcastSizes(%641, %662)\n", " %900 : int[] = prim::BroadcastSizes(%898, %899)\n", " %619 : int[]? = aten::_size_if_not_equal(%617, %887) # :3:19\n", " %622 : int[]? = aten::_size_if_not_equal(%620, %887) # :3:68\n", " %626 : int[]? = aten::_size_if_not_equal(%624, %888) # :3:19\n", " %629 : int[]? = aten::_size_if_not_equal(%627, %888) # :3:68\n", " %633 : int[]? = aten::_size_if_not_equal(%617, %886) # :3:19\n", " %636 : int[]? = aten::_size_if_not_equal(%634, %886) # :3:68\n", " %640 : int[]? = aten::_size_if_not_equal(%620, %885) # :3:19\n", " %643 : int[]? = aten::_size_if_not_equal(%641, %885) # :3:68\n", " %647 : int[]? = aten::_size_if_not_equal(%891, %884) # :3:19\n", " %650 : int[]? = aten::_size_if_not_equal(%887, %884) # :3:68\n", " %654 : int[]? = aten::_size_if_not_equal(%624, %882) # :3:19\n", " %657 : int[]? = aten::_size_if_not_equal(%655, %882) # :3:68\n", " %661 : int[]? = aten::_size_if_not_equal(%627, %881) # :3:19\n", " %664 : int[]? = aten::_size_if_not_equal(%662, %881) # :3:68\n", " %668 : int[]? = aten::_size_if_not_equal(%895, %880) # :3:19\n", " %671 : int[]? = aten::_size_if_not_equal(%888, %880) # :3:68\n", " %675 : int[]? = aten::_size_if_not_equal(%883, %878) # :3:19\n", " %678 : int[]? = aten::_size_if_not_equal(%879, %878) # :3:68\n", " %682 : int[]? = aten::_size_if_not_equal(%634, %898) # :3:19\n", " %685 : int[]? = aten::_size_if_not_equal(%655, %898) # :3:68\n", " %689 : int[]? = aten::_size_if_not_equal(%641, %899) # :3:19\n", " %692 : int[]? = aten::_size_if_not_equal(%662, %899) # :3:68\n", " %696 : int[]? = aten::_size_if_not_equal(%898, %900) # :3:19\n", " %699 : int[]? = aten::_size_if_not_equal(%899, %900) # :3:68\n", " %703 : int[]? = aten::_size_if_not_equal(%900, %877) # :3:19\n", " %706 : int[]? = aten::_size_if_not_equal(%878, %877) # :3:68\n", " %710 : int[]? = aten::_size_if_not_equal(%878, %875) # :3:19\n", " %713 : int[]? 
= aten::_size_if_not_equal(%876, %875) # :3:68\n", " return (%912, %111, %109, %619, %622, %106, %104, %626, %629, %101, %633, %636, %96, %640, %643, %923, %922, %647, %650, %921, %70, %654, %657, %65, %661, %664, %919, %918, %668, %671, %917, %920, %916, %675, %678, %682, %685, %689, %692, %696, %699, %915, %703, %706, %914, %913, %710, %713)\n", "with prim::TensorExprGroup_0 = graph(%14 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %15 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %17 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %18 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %34 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %37 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %51 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %54 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0)):\n", " %4 : float = prim::Constant[value=1.0000000000000001e-05]()\n", " %42 : None = prim::Constant()\n", " %41 : float = prim::Constant[value=0.]()\n", " %55 : int = prim::Constant[value=1]()\n", " %xi.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::max(%54, %51) # :2:9\n", " %yi.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::max(%37, %34) # :3:9\n", " %56 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%54, %17, %55) # :4:31\n", " %53 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%51, %14, %55) # :4:38\n", " %50 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::min(%56, %53) # :4:21\n", " %47 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%50, %xi.3, %55) # :4:21\n", " %wi.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%47, %41, %42) # :4:9\n", " %39 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%37, %18, %55) # :5:31\n", " %36 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%34, %15, %55) # :5:38\n", " %33 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::min(%39, %36) # :5:21\n", " %30 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%33, %yi.3, %55) # :5:21\n", " %hi.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%30, %41, %42) # :5:9\n", " %area_i.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%wi.3, %hi.3) # :6:13\n", " %19 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%17, %18) # :7:13\n", " %16 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::mul(%14, %15) # :7:23\n", " %13 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::add(%19, %16, %55) # :7:13\n", " %area_u.3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::sub(%13, %area_i.3, %55) # :7:13\n", " %6 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::clamp(%area_u.3, %4, %42) # :8:20\n", " %2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0) = aten::div(%area_i.3, %6) # :8:11\n", " return (%2, %6, %area_u.3, %area_i.3, %hi.3, %30, %36, %39, %wi.3, %47, %53, %56)\n", "\n" ] }, { "data": { "image/svg+xml": [ "\n", 
"\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "cluster_68\n", "\n", "DifferentiableGraph\n", "\n", "\n", "cluster_912_0\n", "\n", "\n", "\n", "cluster_912_1\n", "\n", "\n", "\n", "\n", "x1.1\n", "\n", "x1\n", "\n", "\n", "\n", "903\n", "\n", "TypeCheck\n", "\n", "\n", "\n", "x1.1->903\n", "\n", "\n", "\n", "\n", "\n", "960\n", "\n", "call fallback_function\n", "\n", "\n", "\n", "x1.1->960\n", "\n", "\n", "\n", "\n", "\n", "y1.1\n", "\n", "y1\n", "\n", "\n", "\n", "y1.1->903\n", "\n", "\n", "\n", "\n", "\n", "y1.1->960\n", "\n", "\n", "\n", "\n", "\n", "w1.1\n", "\n", "w1\n", "\n", "\n", "\n", "w1.1->903\n", "\n", "\n", "\n", "\n", "\n", "w1.1->960\n", "\n", "\n", "\n", "\n", "\n", "h1.1\n", "\n", "h1\n", "\n", "\n", "\n", "h1.1->903\n", "\n", "\n", "\n", "\n", "\n", "h1.1->960\n", "\n", "\n", "\n", "\n", "\n", "x2.1\n", "\n", "x2\n", "\n", "\n", "\n", "x2.1->903\n", "\n", "\n", "\n", "\n", "\n", "x2.1->960\n", "\n", "\n", "\n", "\n", "\n", "y2.1\n", "\n", "y2\n", "\n", "\n", "\n", "y2.1->903\n", "\n", "\n", "\n", "\n", "\n", "y2.1->960\n", "\n", "\n", "\n", "\n", "\n", "w2.1\n", "\n", "w2\n", "\n", "\n", "\n", "w2.1->903\n", "\n", "\n", "\n", "\n", "\n", "w2.1->960\n", "\n", "\n", "\n", "\n", "\n", "h2.1\n", "\n", "h2\n", "\n", "\n", "\n", "h2.1->903\n", "\n", "\n", "\n", "\n", "\n", "h2.1->960\n", "\n", "\n", "\n", "\n", "\n", "912_in\n", "\n", "If\n", "\n", "\n", "\n", "903->912_in\n", "\n", "\n", "\n", "\n", "\n", "830\n", "\n", "TensorExprGroup\n", "\n", "\n", "\n", "903->830\n", "\n", "\n", "\n", "\n", "\n", "903->830\n", "\n", "\n", "\n", "\n", "\n", "903->830\n", "\n", "\n", "\n", "\n", "\n", "903->830\n", "\n", "\n", "\n", "\n", "\n", "903->830\n", "\n", "\n", "\n", "\n", "\n", "903->830\n", "\n", "\n", "\n", "\n", "\n", "903->830\n", "\n", "\n", "\n", "\n", "\n", "903->830\n", "\n", "\n", "\n", "\n", "\n", "912\n", "\n", "\n", "\n", "\n", "960->912\n", "\n", "\n", "\n", "\n", "\n", "912_in->960\n", "\n", "\n", "n\n", "\n", "\n", "\n", "\n", "912_in->830\n", "\n", "\n", "y\n", "\n", "\n", "\n", ".outputs\n", "\n", "outputs\n", "\n", "\n", "\n", "912->.outputs\n", "\n", "\n", "\n", "\n", "\n", "830->912\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):\n", " xi = torch.max(x1, x2) # Intersection left\n", " yi = torch.max(y1, y2) # Intersection top\n", " wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.) # Intersection width\n", " hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.) # Intersection height\n", " area_i = wi * hi # Area Intersection\n", " area_u = w1 * h1 + w2 * h2 - wi * hi # Area Union\n", " return area_i / torch.clamp(area_u, min=1e-5) # Intersection over Union\n", "\n", "ratio_iou_scripted = torch.jit.script(ratio_iou)\n", "\n", "x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=True).exp()\n", "\n", "for i in range(10):\n", " ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)\n", "print(ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2))\n", "make_graph(ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand why this is, we need to look at how AutoDiff works. It has roughly three stages:\n", "\n", "- The first part of AutoDiff is a pass that creates these differentiable graphs (in the optimizations, notably before the fusing). 
AutoDiff has a catalogue of operations for which it can compute backwards (*Sidenote*: with their own derivative definitions, which could potentially differ from the AutoGrad ones), and it will move those into the `DifferentiableGraph`.\n", "\n", "- Then, when we run a graph containing `DifferentiableGraph` nodes (i.e. during the forward pass), the second part of AutoDiff will compute the gradient by going through the nodes of the forward graph. This is a form of source-to-source differentiation (but in contrast to classic symbolic differentiation, it is specialized to autograd-style Jacobian-vector products). This can amend the forward to output intermediates that are then captured for the backward, similar to the `save_for_backward` mechanism in an `autograd.Function` (you can see that the `TensorExprGroup` now returns a lot more values and the `DifferentiableGraph` itself adds all these size computations).\n", "\n", "- Finally, the PyTorch AutoGrad(!) mechanism is used by making a `DifferentiableGraphBackward` node that holds on to the intermediate values and, when backward is called, runs the backward graph constructed in the previous step (including letting the JIT optimize it, potentially fusing operations etc.).\n", "\n", "What is it with these sizes then? The convenient broadcasting semantics cause PyTorch to implicitly expand the operands of (mostly) binary operations. But these expansions have a gradient operation associated with them - a summation over any broadcast dimensions. The size operations check whether broadcasting has happened (i.e. whether the output shape is larger than an input shape of a binary operation) and if so record the target size for the summation, or `None` if no summation is needed - this is what the `aten::_size_if_not_equal` operation does.\n", "\n", "There is another thing to note here: the JIT currently does not have terribly smart logic to decide which values to capture and which might just as well be re-computed (done manually, one might well choose to recompute all the intermediates of our little function instead of capturing them), but it mimics what AutoGrad does (as defined by the AutoDiff backward specifications)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "221 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", "91.5 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" ] } ], "source": [ "x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=True).exp()\n", "\n", "def take_time(fn):\n", "    _ = fn(x1, y1, w1, h1, x2, y2, w2, h2)\n", "    torch.cuda.synchronize()\n", "\n", "take_time(ratio_iou) # warmup\n", "%timeit take_time(ratio_iou)\n", "\n", "for i in range(2):\n", "    take_time(ratio_iou_scripted)\n", "%timeit take_time(ratio_iou_scripted)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Profiling Executor\n", "\n", "We mentioned that the JIT fusers specialize on detailed tensor type information. How do they get this information? Through the Profiling Executor, which is in charge of running the JITed graphs.\n", "\n", "The profiling executor will record tensor type information (dtype, strides, sizes, requires gradient) in its profiling phase (the first few invocations). Currently it does a single profiling run, but this is configurable.\n", "\n",
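"For example - a minimal sketch using an internal (and therefore unstable) helper, which we will also use in a later cell of this notebook:\n", "\n", "```python\n", "# have the profiling executor profile the first two invocations instead of just one\n", "torch._C._jit_set_num_profiled_runs(2)\n", "```\n", "\n",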
"The recording itself is done by inserting special `prim::profile` nodes into the graph which then run an operator collecting and aggregating this information. The executor then uses this information to implement optimizations. (*Sidenote*: Lest you should be thinking of taking time measurements when hearing profiling - I sure did - this does not seem to be done here, currently.)\n", "\n", "Traditionally, the same thing (getting tensor type information attached to every value) has been done by propagating the types from the inputs through the graph. While this works great in general, it soon hits limitations, e.g. for convolutions the output shape (and thus the precise type information) depends on the value (not even just the type) of e.g. the padding input. This means that unless we detect that the output-shaping inputs are constants *and* have some way of accessing type propagation, we do not know the output shape. (*Sidenote*: The same topic is also addressed by people interested in type-checking tensor programs who coordinate on the Python [typing sig mailing list](https://mail.python.org/archives/list/typing-sig@python.org/).) My best guess on the design choice here is that this is the reason we instead observe shapes at runtime (my impression is that PyTorch operations would ideally provide type propagation information, but that could be me).\n", "\n", "So when the JIT fuser passes mentioned above go to work, they find these type annotations on all tensor values and can specialize accordingly.\n", "\n", "One interesting aspect is the difference between the type expectations encoded by `TypeCheck` for the TensorExpr fuser and by `CudaFuserGuard` for the CUDA fuser. (*Sidenote*: Interestingly, `TypeCheck` is wired into the JIT interpreter and JIT type system, while `CudaFuserGuard` is a regular operator whose check is implemented \"manually\" in the function `complyWith` in `torch/csrc/jit/codegen/cuda/interface.cpp`.) While `TypeCheck` nails down the exact tensor shape and layout, the CUDA fuser will use the same kernel on tensors of different sizes as long as the contiguity pattern (i.e. whether there are gaps in the storage between the values of the tensor, e.g. from slicing) is the same. \n", "\n", "\n", "## Looking at fallback graphs\n", "\n", "We mentioned the importance of fallbacks and how the fallbacks are themselves optimized again, but we have yet to see this in action.\n", "Sadly, the JIT's Python interface is lacking here or, hopefully, just lagging a bit.\n", "\n", "But we can hack around this by building our own little PyTorch extension that provides the missing functionality.\n", "Again, I recommend skipping this bit on first reading and revisiting it if you really want to know about types in the JIT (that would be another tutorial)." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using /home/tv/.cache/torch_extensions as PyTorch extensions root...\n", "Emitting ninja build file /home/tv/.cache/torch_extensions/functiontype_ext/build.ninja...\n", "Building extension module functiontype_ext...\n", "Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n", "Loading extension module functiontype_ext...\n" ] } ], "source": [ "csrc = \"\"\"\n", "#include <torch/extension.h>\n", "\n", "using ::c10::Type;\n", "using ::torch::jit::FunctionType;\n", "\n", "PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n", "    py::class_<FunctionType, Type, std::shared_ptr<FunctionType>>(m, \"FunctionType\")\n", "        .def(\"name\", [](const std::shared_ptr<FunctionType>& self) {\n", "            return self->function()->name();\n", "        })\n", "        .def(\n", "            \"get_debug_state\",\n", "            [](const std::shared_ptr<FunctionType>& self) {\n", "                return self->function()->get_executor().getDebugState();\n", "            })\n", "        .def(\"optimized_graph\", [](const std::shared_ptr<FunctionType>& self) {\n", "            return self->function()->optimized_graph();\n", "        });\n", "}\n", "\"\"\"\n", "import torch.utils.cpp_extension\n", "ext = torch.utils.cpp_extension.load_inline(\"functiontype_ext\", [csrc], verbose=True)\n", "\n", "\n", "def find_function_types(graph_or_block, function_types=None):\n", "    # walk the graph (including nested blocks and subgraphs) and collect all FunctionType constants\n", "    if function_types is None:\n", "        function_types = []\n", "    for n in graph_or_block.nodes():\n", "        if n.kind() == 'prim::Constant':\n", "            t = n.output().type()\n", "            if t.kind() == 'FunctionType':\n", "                function_types.append(t)\n", "        else:\n", "            for b in n.blocks():\n", "                find_function_types(b, function_types=function_types)\n", "            if n.hasAttribute('Subgraph'):\n", "                find_function_types(n.g('Subgraph'), function_types=function_types)\n", "    return function_types\n", "\n", "def get_function_graphs(gr):\n", "    # map each function's name to the graph of its (first) recorded execution plan\n", "    return {t.name(): list(t.get_debug_state().execution_plans.values())[0].graph for t in find_function_types(gr)}\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With this, we can now extract the fallback. Let us run our function a few times, first without needing gradients and then with needing gradients.\n", "\n", "The original graph is the part that doesn't need gradients, as could be expected."
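, "\n", "As a small usage sketch of the helpers we just defined (hypothetical, assuming the extension above compiled and the scripted function has been run as before), we could list the names of the fallback functions stashed away in the last optimized graph:\n", "\n", "```python\n", "gr = torch.jit.last_executed_optimized_graph()\n", "for t in find_function_types(gr):\n", "    print(t.name())  # e.g. fallback_function\n", "```"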
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "cluster_121_0\n", "\n", "\n", "\n", "cluster_121_1\n", "\n", "\n", "\n", "\n", "x1.1\n", "\n", "x1\n", "\n", "\n", "\n", "112\n", "\n", "TypeCheck\n", "\n", "\n", "\n", "x1.1->112\n", "\n", "\n", "\n", "\n", "\n", "147\n", "\n", "call fallback_function\n", "\n", "\n", "\n", "x1.1->147\n", "\n", "\n", "\n", "\n", "\n", "y1.1\n", "\n", "y1\n", "\n", "\n", "\n", "y1.1->112\n", "\n", "\n", "\n", "\n", "\n", "y1.1->147\n", "\n", "\n", "\n", "\n", "\n", "w1.1\n", "\n", "w1\n", "\n", "\n", "\n", "w1.1->112\n", "\n", "\n", "\n", "\n", "\n", "w1.1->147\n", "\n", "\n", "\n", "\n", "\n", "h1.1\n", "\n", "h1\n", "\n", "\n", "\n", "h1.1->112\n", "\n", "\n", "\n", "\n", "\n", "h1.1->147\n", "\n", "\n", "\n", "\n", "\n", "x2.1\n", "\n", "x2\n", "\n", "\n", "\n", "x2.1->112\n", "\n", "\n", "\n", "\n", "\n", "x2.1->147\n", "\n", "\n", "\n", "\n", "\n", "y2.1\n", "\n", "y2\n", "\n", "\n", "\n", "y2.1->112\n", "\n", "\n", "\n", "\n", "\n", "y2.1->147\n", "\n", "\n", "\n", "\n", "\n", "w2.1\n", "\n", "w2\n", "\n", "\n", "\n", "w2.1->112\n", "\n", "\n", "\n", "\n", "\n", "w2.1->147\n", "\n", "\n", "\n", "\n", "\n", "h2.1\n", "\n", "h2\n", "\n", "\n", "\n", "h2.1->112\n", "\n", "\n", "\n", "\n", "\n", "h2.1->147\n", "\n", "\n", "\n", "\n", "\n", "121_in\n", "\n", "If\n", "\n", "\n", "\n", "112->121_in\n", "\n", "\n", "\n", "\n", "\n", "68\n", "\n", "TensorExprGroup\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "112->68\n", "\n", "\n", "\n", "\n", "\n", "121\n", "\n", "\n", "\n", "\n", "\n", "121_in->68\n", "\n", "\n", "y\n", "\n", "\n", "\n", "121_in->147\n", "\n", "\n", "n\n", "\n", "\n", "\n", ".outputs\n", "\n", "outputs\n", "\n", "\n", "\n", "121->.outputs\n", "\n", "\n", "\n", "\n", "\n", "68->121\n", "\n", "\n", "\n", "\n", "\n", "147->121\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):\n", " xi = torch.max(x1, x2) # Intersection left\n", " yi = torch.max(y1, y2) # Intersection top\n", " wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.) # Intersection width\n", " hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.) 
# Intersection height\n", " area_i = wi * hi # Area Intersection\n", " area_u = w1 * h1 + w2 * h2 - wi * hi # Area Union\n", " return area_i / torch.clamp(area_u, min=1e-5) # Intersection over Union\n", "\n", "ratio_iou_scripted = torch.jit.script(ratio_iou)\n", "\n", "x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda').exp()\n", "\n", "for i in range(10):\n", " ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)\n", "\n", "x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=True).exp()\n", "for i in range(10):\n", " ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)\n", "\n", "gr = torch.jit.last_executed_optimized_graph()\n", "\n", "make_graph(gr)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "cluster_69\n", "\n", "DifferentiableGraph\n", "\n", "\n", "cluster_928_0\n", "\n", "\n", "\n", "cluster_928_1\n", "\n", "\n", "\n", "\n", "w2.1\n", "\n", "w2\n", "\n", "\n", "\n", "919\n", "\n", "TypeCheck\n", "\n", "\n", "\n", "w2.1->919\n", "\n", "\n", "\n", "\n", "\n", "976\n", "\n", "call fallback_function\n", "\n", "\n", "\n", "w2.1->976\n", "\n", "\n", "\n", "\n", "\n", "h2.1\n", "\n", "h2\n", "\n", "\n", "\n", "h2.1->919\n", "\n", "\n", "\n", "\n", "\n", "h2.1->976\n", "\n", "\n", "\n", "\n", "\n", "w1.1\n", "\n", "w1\n", "\n", "\n", "\n", "w1.1->919\n", "\n", "\n", "\n", "\n", "\n", "w1.1->976\n", "\n", "\n", "\n", "\n", "\n", "h1.1\n", "\n", "h1\n", "\n", "\n", "\n", "h1.1->919\n", "\n", "\n", "\n", "\n", "\n", "h1.1->976\n", "\n", "\n", "\n", "\n", "\n", "y2.1\n", "\n", "y2\n", "\n", "\n", "\n", "y2.1->919\n", "\n", "\n", "\n", "\n", "\n", "y2.1->976\n", "\n", "\n", "\n", "\n", "\n", "y1.1\n", "\n", "y1\n", "\n", "\n", "\n", "y1.1->919\n", "\n", "\n", "\n", "\n", "\n", "y1.1->976\n", "\n", "\n", "\n", "\n", "\n", "x2.1\n", "\n", "x2\n", "\n", "\n", "\n", "x2.1->919\n", "\n", "\n", "\n", "\n", "\n", "x2.1->976\n", "\n", "\n", "\n", "\n", "\n", "x1.1\n", "\n", "x1\n", "\n", "\n", "\n", "x1.1->919\n", "\n", "\n", "\n", "\n", "\n", "x1.1->976\n", "\n", "\n", "\n", "\n", "\n", "928_in\n", "\n", "If\n", "\n", "\n", "\n", "919->928_in\n", "\n", "\n", "\n", "\n", "\n", "846\n", "\n", "TensorExprGroup\n", "\n", "\n", "\n", "919->846\n", "\n", "\n", "\n", "\n", "\n", "919->846\n", "\n", "\n", "\n", "\n", "\n", "919->846\n", "\n", "\n", "\n", "\n", "\n", "919->846\n", "\n", "\n", "\n", "\n", "\n", "919->846\n", "\n", "\n", "\n", "\n", "\n", "919->846\n", "\n", "\n", "\n", "\n", "\n", "919->846\n", "\n", "\n", "\n", "\n", "\n", "919->846\n", "\n", "\n", "\n", "\n", "\n", "928\n", "\n", "\n", "\n", "\n", "976->928\n", "\n", "\n", "\n", "\n", "\n", "928_in->976\n", "\n", "\n", "n\n", "\n", "\n", "\n", "\n", "928_in->846\n", "\n", "\n", "y\n", "\n", "\n", "\n", "67\n", "\n", "TupleConstruct\n", "\n", "\n", "\n", "928->67\n", "\n", "\n", "\n", "\n", "\n", "846->928\n", "\n", "\n", "\n", "\n", "\n", ".outputs\n", "\n", "outputs\n", "\n", "\n", "\n", "67->.outputs\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gr_fb1 = get_function_graphs(gr)['fallback_function']\n", "make_graph(gr_fb1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can take this to several levels, but when we get an \"internal assert failed\" error regarding a missing optimized plan, it means that we have reached the end of 
the *optimized* fallback passes." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "tags": [ "raises-exception" ] }, "outputs": [ { "ename": "RuntimeError", "evalue": "optimized_plan_ INTERNAL ASSERT FAILED at \"../torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp\":551, please report a bug to PyTorch. ", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgr_fb2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_function_graphs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgr_fb1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'fallback_function'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mget_function_graphs\u001b[0;34m(gr)\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_function_graphs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 42\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_debug_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecution_plans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfind_function_types\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_function_graphs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 42\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_debug_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecution_plans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfind_function_types\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m: optimized_plan_ INTERNAL ASSERT FAILED at 
\"../torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp\":551, please report a bug to PyTorch. " ] } ], "source": [ "gr_fb2 = get_function_graphs(gr_fb1)['fallback_function']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can inspect the unoptimized fallback, even if it may seem counterintuitive to the uninitiated like us that the unoptimized graph should be accessed via `optimized_graph` (Also note that the type annotations in the fallback branch are bogus. Oh well.):" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "graph(%0 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %1 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %2 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %3 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %4 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %5 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %6 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0),\n", " %7 : Float(100, 1000, strides=[1000, 1], requires_grad=0, device=cuda:0)):\n", " %11 : int = prim::Constant[value=1]()\n", " %10 : float = prim::Constant[value=0.]()\n", " %9 : None = prim::Constant()\n", " %8 : float = prim::Constant[value=1.0000000000000001e-05]()\n", " %xi.4 : Tensor = aten::max(%7, %6) # :2:9\n", " %yi.4 : Tensor = aten::max(%5, %4) # :3:9\n", " %14 : Tensor = aten::add(%7, %2, %11) # :4:31\n", " %15 : Tensor = aten::add(%6, %0, %11) # :4:38\n", " %16 : Tensor = aten::min(%14, %15) # :4:21\n", " %17 : Tensor = aten::sub(%16, %xi.4, %11) # :4:21\n", " %wi.4 : Tensor = aten::clamp(%17, %10, %9) # :4:9\n", " %19 : Tensor = aten::add(%5, %3, %11) # :5:31\n", " %20 : Tensor = aten::add(%4, %1, %11) # :5:38\n", " %21 : Tensor = aten::min(%19, %20) # :5:21\n", " %22 : Tensor = aten::sub(%21, %yi.4, %11) # :5:21\n", " %hi.4 : Tensor = aten::clamp(%22, %10, %9) # :5:9\n", " %area_i.4 : Tensor = aten::mul(%wi.4, %hi.4) # :6:13\n", " %25 : Tensor = aten::mul(%2, %3) # :7:13\n", " %26 : Tensor = aten::mul(%0, %1) # :7:23\n", " %27 : Tensor = aten::add(%25, %26, %11) # :7:13\n", " %area_u.4 : Tensor = aten::sub(%27, %area_i.4, %11) # :7:13\n", " %29 : Tensor = aten::clamp(%area_u.4, %8, %9) # :8:20\n", " %30 : Tensor = aten::div(%area_i.4, %29) # :8:11\n", " %31 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::TupleConstruct(%30, %29, %area_u.4, %area_i.4, %hi.4, %22, %20, %19, %wi.4, %17, %15, %14)\n", " return (%31)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "find_function_types(gr_fb1)[0].optimized_graph()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How we could go at benchmarking\n", "\n", "We can now pitch the various fusers against each other if we want. We abuse the context manager in a non-contextmanagery way. Note that we do not time the backwards here, but it would be straightforward to do, too." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fuser: None, requires gradient: False\n", "159 µs ± 457 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", "fuser: fuser1, requires gradient: False\n", "37.1 µs ± 180 ns per loop (mean ± std. dev. 
of 7 runs, 10000 loops each)\n", "fuser: fuser2, requires gradient: False\n", "47 µs ± 166 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", "fuser: None, requires gradient: True\n", "221 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", "fuser: fuser1, requires gradient: True\n", "92.7 µs ± 197 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", "fuser: fuser2, requires gradient: True\n", "106 µs ± 242 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" ] } ], "source": [ "for rq in [False, True]:\n", " for fuser in [None, \"fuser1\", \"fuser2\"]:\n", " if fuser is not None:\n", " c = torch.jit.fuser(fuser) \n", " c.__enter__()\n", " \n", " def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):\n", " xi = torch.max(x1, x2) # Intersection left\n", " yi = torch.max(y1, y2) # Intersection top\n", " wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.) # Intersection width\n", " hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.) # Intersection height\n", " area_i = wi * hi # Area Intersection\n", " area_u = w1 * h1 + w2 * h2 - wi * hi # Area Union\n", " return area_i / torch.clamp(area_u, min=1e-5) # Intersection over Union\n", "\n", " ratio_iou_scripted = torch.jit.script(ratio_iou) if fuser is not None else ratio_iou\n", " \n", "\n", " x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=rq).exp()\n", " \n", " print(f\"fuser: {fuser}, requires gradient: {rq}\")\n", " for i in range(10):\n", " take_time(ratio_iou_scripted)\n", "\n", " %timeit take_time(ratio_iou_scripted)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Doing funny things to kick the tires a bit\n", "\n", "If you followed along, you will have noticed that the order of kernels to try depends on how we have called our scripted function before. This can lead to somewhat funny effects.\n", "\n", "One thing is that whether we end up running a `DifferentiableGraph` (and computing the intermediates) depends on what we did during the profiling and the fallback mechanisms for the fusion groups.\n", "In fact, there are bugs to be found (reported as [#49299](https://github.com/pytorch/pytorch/issues/49299)) where whether we get gradient requiring outputs does not match what we feed into the scripted function:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fuser: fuser1 input requires grad: True output requires grad: True\n", "fuser: fuser1 input requires grad: False output requires grad: True\n", "fuser: fuser2 input requires grad: True output requires grad: True\n", "fuser: fuser2 input requires grad: False output requires grad: True\n" ] } ], "source": [ "for fuser in [\"fuser1\", \"fuser2\"]:\n", " for rq in [True, False]:\n", " c = torch.jit.fuser(fuser)\n", " c.__enter__()\n", "\n", " def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):\n", " xi = torch.max(x1, x2) # Intersection left\n", " yi = torch.max(y1, y2) # Intersection top\n", " wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.) # Intersection width\n", " hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.) 
# Intersection height\n", " area_i = wi * hi # Area Intersection\n", " area_u = w1 * h1 + w2 * h2 - wi * hi # Area Union\n", " return area_i / torch.clamp(area_u, min=1e-5) # Intersection over Union\n", "\n", " ratio_iou_scripted = torch.jit.script(ratio_iou)\n", "\n", " x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=not rq).exp()\n", " for i in range(10):\n", " ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)\n", " x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda', requires_grad=rq).exp()\n", " print(\"fuser:\", fuser, \"input requires grad:\", x1.requires_grad, \"output requires grad:\", ratio_iou_scripted(x1, y1, w1, h1, x2, y2, w2, h2).requires_grad)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another fun thing to try is what happens when the profiling runs see different tensor sizes (this is a real thing, e.g. for Neural Machine Translation or other NLP applications).\n", "\n", "Do change the fuser between `fuser1` and `fuser2` here. We see that the CUDA fuser can handle both sizes with the same kernel while the TensorExpr fuser decides to not optimize this path.\n" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "\n", "x1.1\n", "\n", "x1\n", "\n", "\n", "\n", "xi.1\n", "\n", "max\n", "\n", "\n", "\n", "x1.1->xi.1\n", "\n", "\n", "\n", "\n", "\n", "20\n", "\n", "add\n", "\n", "\n", "\n", "x1.1->20\n", "\n", "\n", "\n", "\n", "\n", "y1.1\n", "\n", "y1\n", "\n", "\n", "\n", "yi.1\n", "\n", "max\n", "\n", "\n", "\n", "y1.1->yi.1\n", "\n", "\n", "\n", "\n", "\n", "34\n", "\n", "add\n", "\n", "\n", "\n", "y1.1->34\n", "\n", "\n", "\n", "\n", "\n", "w1.1\n", "\n", "w1\n", "\n", "\n", "\n", "w1.1->20\n", "\n", "\n", "\n", "\n", "\n", "51\n", "\n", "mul\n", "\n", "\n", "\n", "w1.1->51\n", "\n", "\n", "\n", "\n", "\n", "h1.1\n", "\n", "h1\n", "\n", "\n", "\n", "h1.1->34\n", "\n", "\n", "\n", "\n", "\n", "h1.1->51\n", "\n", "\n", "\n", "\n", "\n", "x2.1\n", "\n", "x2\n", "\n", "\n", "\n", "x2.1->xi.1\n", "\n", "\n", "\n", "\n", "\n", "23\n", "\n", "add\n", "\n", "\n", "\n", "x2.1->23\n", "\n", "\n", "\n", "\n", "\n", "y2.1\n", "\n", "y2\n", "\n", "\n", "\n", "y2.1->yi.1\n", "\n", "\n", "\n", "\n", "\n", "37\n", "\n", "add\n", "\n", "\n", "\n", "y2.1->37\n", "\n", "\n", "\n", "\n", "\n", "w2.1\n", "\n", "w2\n", "\n", "\n", "\n", "w2.1->23\n", "\n", "\n", "\n", "\n", "\n", "54\n", "\n", "mul\n", "\n", "\n", "\n", "w2.1->54\n", "\n", "\n", "\n", "\n", "\n", "h2.1\n", "\n", "h2\n", "\n", "\n", "\n", "h2.1->37\n", "\n", "\n", "\n", "\n", "\n", "h2.1->54\n", "\n", "\n", "\n", "\n", "\n", "29\n", "\n", "sub\n", "\n", "\n", "\n", "xi.1->29\n", "\n", "\n", "\n", "\n", "\n", "43\n", "\n", "sub\n", "\n", "\n", "\n", "yi.1->43\n", "\n", "\n", "\n", "\n", "\n", "26\n", "\n", "min\n", "\n", "\n", "\n", "20->26\n", "\n", "\n", "\n", "\n", "\n", "23->26\n", "\n", "\n", "\n", "\n", "\n", "26->29\n", "\n", "\n", "\n", "\n", "\n", "wi.1\n", "\n", "clamp\n", "\n", "\n", "\n", "29->wi.1\n", "\n", "\n", "\n", "\n", "\n", "area_i.1\n", "\n", "mul\n", "\n", "\n", "\n", "wi.1->area_i.1\n", "\n", "\n", "\n", "\n", "\n", "40\n", "\n", "min\n", "\n", "\n", "\n", "34->40\n", "\n", "\n", "\n", "\n", "\n", "37->40\n", "\n", "\n", "\n", "\n", "\n", "40->43\n", "\n", "\n", "\n", "\n", "\n", "hi.1\n", "\n", "clamp\n", "\n", "\n", "\n", "43->hi.1\n", "\n", "\n", "\n", "\n", "\n", "hi.1->area_i.1\n", "\n", "\n", "\n", 
"\n", "\n", "area_u.1\n", "\n", "sub\n", "\n", "\n", "\n", "area_i.1->area_u.1\n", "\n", "\n", "\n", "\n", "\n", "65\n", "\n", "div\n", "\n", "\n", "\n", "area_i.1->65\n", "\n", "\n", "\n", "\n", "\n", "57\n", "\n", "add\n", "\n", "\n", "\n", "51->57\n", "\n", "\n", "\n", "\n", "\n", "54->57\n", "\n", "\n", "\n", "\n", "\n", "57->area_u.1\n", "\n", "\n", "\n", "\n", "\n", "62\n", "\n", "clamp\n", "\n", "\n", "\n", "area_u.1->62\n", "\n", "\n", "\n", "\n", "\n", "62->65\n", "\n", "\n", "\n", "\n", "\n", ".outputs\n", "\n", "outputs\n", "\n", "\n", "\n", "65->.outputs\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = torch.jit.fuser(\"fuser1\")\n", "c.__enter__()\n", "torch._C._jit_set_num_profiled_runs(2)\n", "\n", "def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2):\n", " xi = torch.max(x1, x2) # Intersection left\n", " yi = torch.max(y1, y2) # Intersection top\n", " wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0.) # Intersection width\n", " hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0.) # Intersection height\n", " area_i = wi * hi # Area Intersection\n", " area_u = w1 * h1 + w2 * h2 - wi * hi # Area Union\n", " return area_i / torch.clamp(area_u, min=1e-5) # Intersection over Union\n", "\n", "ratio_iou_scripted = torch.jit.script(ratio_iou)\n", "\n", "inputs1 = torch.randn(8, 100, 1000, device='cuda').exp()\n", "inputs2 = torch.randn(8, 101, 1000, device='cuda').exp()\n", "\n", "for i in range(10):\n", " ratio_iou_scripted.graph_for(*inputs1)\n", " ratio_iou_scripted.graph_for(*inputs2)\n", " \n", "make_graph(ratio_iou_scripted.graph_for(*inputs1))\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Getting more debug output\n", "\n", "When we run the JIT on the command line, we can make use of its [debug logging facility](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#jit-logging) to watch its parts in action more closely. \n", "\n", "The fusers also have various debugging facilities. The TensorExpr one uses the debug logging facility (grep for `GRAPH_` in `torch/csrc/jit/tensorexpr/`) and the CUDA one uses environment variables starting with `PYTORCH_CUDA_FUSER` (grep for that in `torch/csrc/jit/codegen/cuda/`).\n", "\n", "## Conclusion\n", "\n", "In this piece, we saw a bit how the JIT works, with a focus on the parts that make fusion optimizations possible and took a dive from a very high level to experimentation that try to show how some internals work.\n", "I hope you enjoyed this tour. 
As always, your feedback is appreciated: .\n", "\n", "*Sidenote*: There is also a more general technical overview in the file [`torch/csrc/jit/OVERVIEW.md`](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md) in the JIT directory of the PyTorch source code, and there are various bits of documentation in `.md` files and in comments throughout the source.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "celltoolbar": "Tags", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.1" }, "rise": { "header": "", "theme": "white", "transition": "off" } }, "nbformat": 4, "nbformat_minor": 4 }