{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# <img src=\"https://img.icons8.com/clouds/100/000000/robot.png\" style=\"height:50px;display:inline\"> Technion - Reinforcement Learning Research Labs - $\\text{RL}^2$\n", "---\n", "\n", "#### <a href=\"https://taldatech.github.io\">Tal Daniel</a>\n", "\n", "## Tutorial - Maximizing CPU and GPU Utilization in PyTorch\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### <img src=\"https://img.icons8.com/bubbles/50/000000/training.png\" style=\"height:50px;display:inline\"> Introduction\n", "---\n", "* In this tutorial we provide tips and tricks for efficient code execution on CPU and GPU with PyTorch.\n", "* The main goal: optimized code that runs faster.\n", "* We wish to maximize GPU utilization and allow efficient data transfer between the CPU and the GPU.\n", "* Here we provide general directions; refer to the links in each section for more detailed examples." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### <img src=\"https://img.icons8.com/bubbles/50/000000/checklist.png\" style=\"height:50px;display:inline\"> Agenda\n", "---\n", "\n", "* [Monitoring Utilization](#-Monitoring-Utilization)\n", "* [General Tips](#-General-Tips)\n", "* [Data Loading Tips to Maximize GPU Utility](#-Data-Loading-Tips-to-Maximize-GPU-Utility)\n", "* [Training Tips](#-Training-Tips)\n", "* [Reinforcement Learning](#-Reinforcement-Learning)\n", "* [Recommended Videos](#-Recommended-Videos)\n", "* [Credits](#-Credits)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# imports\n", "import numpy as np\n", "import torch\n", "from torch.utils.data import DataLoader\n", "import kornia # `pip install kornia`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <img 
src=\"https://img.icons8.com/external-flaticons-lineal-color-flat-icons/64/000000/external-hud-game-development-flaticons-lineal-color-flat-icons.png\" style=\"height:50px;display:inline\"> Monitoring Utilization\n", "---\n", "* How can we monitor the CPU and GPU utilization of our code?\n", "* The simplest way to monitor GPU utilization is to run `nvidia-smi` in the Terminal/CMD/PowerShell (or `!nvidia-smi` inside a Jupyter Notebook).\n", "* For CPU monitoring, running `htop` in an Ubuntu terminal usually gets the job done.\n", "* Recommendation: for detailed usage monitoring, <a href=\"https://github.com/XuehaiPan/nvitop\">NVITOP</a> is a great tool: install it with `pip install nvitop` and then run `nvitop -m`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`nvitop -m`\n", "---\n", "\n", "<center><img src=\"https://user-images.githubusercontent.com/16078332/171005261-1aad126e-dc27-4ed3-a89b-7f9c1c998bf7.png\" style=\"height:500px\"></center>" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tue Oct 4 10:28:20 2022 \n", "+-----------------------------------------------------------------------------+\n", "| NVIDIA-SMI 517.48 Driver Version: 517.48 CUDA Version: 11.7 |\n", "|-------------------------------+----------------------+----------------------+\n", "| GPU Name TCC/WDDM | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|===============================+======================+======================|\n", "| 0 NVIDIA GeForce ... 
WDDM | 00000000:3B:00.0 Off | N/A |\n", "| N/A 42C P8 N/A / N/A | 0MiB / 4096MiB | 2% Default |\n", "| | | N/A |\n", "+-------------------------------+----------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=============================================================================|\n", "| 0 N/A N/A 14804 C ...nda\\envs\\torch\\python.exe N/A |\n", "+-----------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nvcc: NVIDIA (R) Cuda compiler driver\n", "Copyright (c) 2005-2019 NVIDIA Corporation\n", "Built on Wed_Oct_23_19:32:27_Pacific_Daylight_Time_2019\n", "Cuda compilation tools, release 10.2, V10.2.89\n" ] } ], "source": [ "# check the current CUDA toolkit version installed in this `conda` environment\n", "!nvcc -V" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Understanding the Output of `nvidia-smi`\n", "---\n", "Adapted from: <a href=\"https://medium.com/analytics-vidhya/explained-output-of-nvidia-smi-utility-fc4fbee3b124\">Explained Output of Nvidia-smi Utility - Shachi Kaul</a>\n", "\n", "* Top row:\n", " * **Driver Version**: the version of the installed NVIDIA GPU driver (the general driver, not the CUDA toolkit).\n", " * **CUDA Version**: the **maximal** CUDA version supported by the installed NVIDIA driver. This is NOT necessarily the version of the CUDA toolkit installed on the machine. To check the installed CUDA toolkit version, run `nvcc -V` (see above).\n", "* First table:\n", " * **Temp**: Core GPU temperature in degrees Celsius. 
Under load, a usual operating temperature is around 80-85C; if it reaches 90C, there might be a problem with the cooling or some other hardware issue.\n", " * **Perf**: The GPU’s current performance state, ranging from P0 (maximum performance) to P12 (minimum performance).\n", " * **Persistence-M**: The value of the Persistence Mode flag, where “On” means that the NVIDIA driver remains loaded (persists) even when no active client such as `nvidia-smi` is running. This reduces the driver load latency for dependent applications such as CUDA programs. Usually set to \"Off\". On Windows, this column shows the **TCC/WDDM** driver model instead, as in the output above.\n", " * **Pwr: Usage/Cap**: The GPU’s current power usage out of its total power capacity, in Watts.\n", " * **Disp.A**: Display Active, a flag indicating whether a display is initialized on the GPU (i.e., whether GPU memory is allocated for a display). “Off” indicates that no display is using the GPU.\n", " * **Memory-Usage**: The GPU memory currently allocated out of the total memory. This helps you balance model size, batch size and other hyper-parameters that greatly affect the memory required from the GPU.\n", " * **Volatile Uncorr. ECC**: ECC stands for Error Correction Code, which detects and corrects memory errors. NVIDIA GPUs report a count of uncorrectable ECC errors; the volatile counter counts the errors since the driver was last loaded.\n", " * **GPU-Util**: The percent of time over the past sample period during which one or more kernels were executing on the GPU. The sample period may be between 1 second and 1/6 second, depending on the product being queried. A low percentage means that the GPU is under-utilized (e.g., the code spends its time reading data from disk).\n", " * **Compute M.**: The compute mode of the GPU, which determines whether and how multiple processes can share it; it is reset to Default after each reboot. 
“Default” allows multiple clients to access the GPU at the same time.\n", "* Second table:\n", " * **GPU**: The GPU index, useful in multi-GPU setups to determine which process is using which GPU. This index is the NVML index of the device.\n", " * **PID**: The ID of the process using the GPU. If you need to kill a ghost process that holds the GPU (this usually happens in distributed training / multiprocessing), pass that PID to the `kill` command.\n", " * **Type**: The type of the process: “C” (Compute), “G” (Graphics), or “C+G” (Compute and Graphics context).\n", " * **GPU Memory Usage**: The GPU memory used by each process.\n", " \n", "For more options, run `nvidia-smi --help`.\n", "\n", "* **Advanced**: manually control the cooling of the GPU (by setting temperature ranges for fan speed) with `coolgpus`, <a href=\"https://github.com/andyljones/coolgpus\">GitHub link</a>." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is also possible to access the <a href=\"https://pytorch.org/docs/stable/generated/torch.cuda.utilization.html\">GPU utilization in code (PyTorch 1.12 and later)</a>:\n", "\n", "`torch.cuda.utilization(device=None)`: Returns the percent of time over the past sample period during which one or more kernels were executing on the GPU, as given by `nvidia-smi`. It requires the `pynvml` module to be installed.\n", "\n", "`device`: The selected device. If `device` is `None` (default), returns the statistic for the current device, given by `current_device()`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Cleaning CUDA Cache\n", "---\n", "PyTorch uses a caching memory allocator, so memory released by tensors stays reserved for reuse and still appears as used in `nvidia-smi`. Calling `torch.cuda.empty_cache()` releases this unoccupied cached memory back to the GPU (it does not free memory held by tensors that are still referenced). This is especially useful when working in a Jupyter Notebook, but can also help after an epoch has ended (use carefully). 
" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <img src=\"https://img.icons8.com/external-sbts2018-flat-sbts2018/58/000000/external-tip-basic-ui-elements-2.4-sbts2018-flat-sbts2018.png\" style=\"height:50px;display:inline\"> General Tips\n", "---\n", "More tricks and hacks can be found in the <a href=\"https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html\">PyTorch Performance Tuning Guide</a>." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.cuda.is_available(): True\n", "device: cuda:0\n" ] } ], "source": [ "# reminder - define device at the top of your code, and send models and tensors to it\n", "# check if there is a CUDA device available\n", "print(f'torch.cuda.is_available(): {torch.cuda.is_available()}')\n", "# define device\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print(f'device: {device}')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model(\n", " (nn): Sequential(\n", " (0): Linear(in_features=10, out_features=128, bias=True)\n", " (1): ReLU(inplace=True)\n", " (2): Linear(in_features=128, out_features=256, bias=True)\n", " (3): ReLU(inplace=True)\n", " (4): Linear(in_features=256, out_features=10, bias=True)\n", " )\n", ")\n", "model device: cuda:0\n" ] } ], "source": [ "# a simple neural network\n", "class Model(torch.nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.nn = torch.nn.Sequential(torch.nn.Linear(10, 128),\n", "                                      torch.nn.ReLU(inplace=True),\n", "                                      torch.nn.Linear(128, 256),\n", "                                      torch.nn.ReLU(inplace=True),\n", "                                      torch.nn.Linear(256, 10))\n", "\n", "    def forward(self, x):\n", "        return self.nn(x)\n", "\n", "# `inplace=True`: performs the operation in-place -- not creating a copy of the 
tensor, can save memory.\n", "model = Model()\n", "# send model to device\n", "model = model.to(device)\n", "print(model)\n", "print(f'model device: {next(model.parameters()).device}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Avoid unnecessary CPU-GPU synchronization\n", "---\n", "Avoid unnecessary synchronizations to let the CPU run ahead of the accelerator as much as possible, so that the accelerator work queue always contains many operations.\n", "\n", "When possible, avoid operations which require synchronization, for example:\n", "* `print(cuda_tensor)`\n", "\n", "* `cuda_tensor.item()`\n", "\n", "* Memory copies: `tensor.cuda()`, `cuda_tensor.cpu()` and equivalent `tensor.to(device)` calls.\n", "\n", "* `cuda_tensor.nonzero()`\n", "\n", "* `cuda_tensor.data.cpu().numpy()`\n", "\n", "* Python control flow which depends on results of operations performed on CUDA tensors, e.g., `if (cuda_tensor != 0).all()`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Create tensors directly on the target device\n", "---\n", "Instead of calling `torch.rand(size).cuda()` to generate a random tensor, produce the output directly on the target device: `torch.rand(size, device=torch.device('cuda'))`.\n", "\n", "This is applicable to all functions which create new tensors and accept a `device` argument: `torch.rand()`, `torch.zeros()`, `torch.ones()`, `torch.full()` and similar." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# example\n", "a = torch.randn(32, 10).to(device) # BAD: created on the CPU, then copied to the device\n", "b = torch.randn(32, 10, device=device) # GOOD: created directly on the device" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### cuDNN auto-tuner\n", "---\n", "NVIDIA cuDNN supports many algorithms to compute a convolution. 
The autotuner runs a short benchmark and selects the kernel with the best performance on the given hardware for a given input size.\n", "\n", "For convolutional networks (other types are currently not supported), enable the cuDNN autotuner before launching the training loop by setting:\n", "\n", "`torch.backends.cudnn.benchmark = True`\n", "\n", "Notes:\n", "\n", "* The auto-tuner decisions may be **non-deterministic**; different algorithms may be selected in different runs.\n", "\n", "* In some rare cases, such as with highly variable input sizes, it’s better to run convolutional networks with the autotuner *disabled* to avoid the overhead of algorithm selection for each new input size.\n", "\n", "* If you care about **reproducibility**, it is better to use:\n", "\n", "`torch.backends.cudnn.benchmark = False`\n", "\n", "\n", "`torch.backends.cudnn.deterministic = True`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# example\n", "# imports\n", "import torch\n", "# enable the autotuner once, before the training loop\n", "torch.backends.cudnn.benchmark = True\n", "# rest of the training function goes below" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Disable gradient calculation for validation or inference\n", "---\n", "* PyTorch saves intermediate buffers from all operations which involve tensors that require gradients. \n", "* Typically, gradients aren’t needed for validation or inference. \n", "* The `torch.no_grad()` context manager disables gradient calculation within a specified block of code, which accelerates execution and reduces the amount of required memory. `torch.no_grad()` can also be used as a function decorator." 
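, "\n", "\n", "Below is a minimal sketch of the decorator form; the `evaluate` function and the small `torch.nn.Linear` model are hypothetical placeholders, not part of the original code. Tensors produced inside a function decorated with `@torch.no_grad()` have `requires_grad=False`, so no intermediate buffers are kept for backpropagation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"# hypothetical helper: gradient tracking is disabled for the whole function body\n",
"@torch.no_grad()\n",
"def evaluate(net, x):\n",
"    return net(x)\n",
"\n",
"net = torch.nn.Linear(10, 2)  # placeholder model whose parameters require gradients\n",
"out = evaluate(net, torch.randn(4, 10))\n",
"print(out.requires_grad)  # False: no autograd graph was recorded"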
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# validation loop\n", "# placeholder validation set for illustration: 256 samples with 10 features (matching the model input)\n", "valid_dataset = torch.utils.data.TensorDataset(torch.randn(256, 10))\n", "valid_dataloader = DataLoader(valid_dataset, batch_size=64)\n", "model.eval()  # for dropout and batch-norm layers\n", "with torch.no_grad():\n", "    for batch in valid_dataloader:\n", "        x = batch[0].to(device)\n", "        output = model(x)  # no intermediate buffers are saved for backprop\n", "        # metrics calculation on output\n", "# make sure to put the model back in training mode after validation ends\n", "model.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Disable bias for convolutions directly followed by a Batch Normalization\n", "---\n", "* `torch.nn.Conv2d()` has a bias parameter which defaults to `True` (the same is true for `Conv1d` and `Conv3d`).\n", "\n", "* If an `nn.Conv2d` layer is directly followed by an `nn.BatchNorm2d` layer, then the bias in the convolution is not needed; use `nn.Conv2d(..., bias=False, ...)` instead. The bias is not needed because BatchNorm first subtracts the per-channel mean, which cancels out any constant per-channel bias.\n", "\n", "* This is also applicable to 1D and 3D convolutions as long as BatchNorm (or another normalization layer) normalizes on the same dimension as the convolution’s bias." 
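, "\n", "\n", "To check this numerically, here is a minimal sketch (the layer sizes and input shape are arbitrary placeholders): two convolutions share the same weights and only one has a bias, and each is followed by a freshly created `BatchNorm2d` in training mode. Since BatchNorm subtracts the per-channel batch mean, the bias cancels out and both outputs should match up to floating-point error." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"x = torch.randn(8, 3, 16, 16)  # placeholder batch: [batch_size, ch, h, w]\n",
"\n",
"conv_bias = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=True)\n",
"conv_no_bias = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False)\n",
"with torch.no_grad():\n",
"    conv_no_bias.weight.copy_(conv_bias.weight)  # identical weights, only the bias differs\n",
"\n",
"# fresh BatchNorm layers in training mode normalize with the statistics of the current batch\n",
"out_bias = torch.nn.BatchNorm2d(4)(conv_bias(x))\n",
"out_no_bias = torch.nn.BatchNorm2d(4)(conv_no_bias(x))\n",
"print(torch.allclose(out_bias, out_no_bias, atol=1e-5))  # the bias has no effect on the output"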
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "conv_layer = torch.nn.Sequential(torch.nn.Conv2d(3, 64, stride=1, kernel_size=3, padding=1, bias=False),\n", " torch.nn.BatchNorm2d(64),\n", " torch.nn.ReLU(inplace=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Use `parameter.grad = None` instead of `optimizer.zero_grad()`\n", "---\n", "* Note: this is a minor optimization; do not expect large gains from it.\n", "\n", "Instead of calling `optimizer.zero_grad()` to zero out gradients, set them to `None` directly:\n", "\n", "<code>for param in model.parameters():\n", " param.grad = None</code>\n", " \n", "Since PyTorch 1.7, you can also use `optimizer.zero_grad(set_to_none=True)`; since PyTorch 2.0, `set_to_none=True` is the default.\n", "\n", "Setting gradients to `None` does not zero the memory of each individual parameter, and the subsequent backward pass uses assignment instead of addition to store gradients, which reduces the number of memory operations.\n", "\n", "Setting a gradient to `None` has slightly **different numerical behavior** than setting it to zero, so be careful when using it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# the training step barely changes (placeholder loss, optimizer and data for illustration)\n", "loss_fn = torch.nn.MSELoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", "x = torch.randn(32, 10, device=device)\n", "target = torch.randn(32, 10, device=device)\n", "output = model(x)\n", "loss = loss_fn(output, target)\n", "optimizer.zero_grad(set_to_none=True)\n", "loss.backward()\n", "optimizer.step()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`set_to_none (bool)` – instead of setting to zero, set the grads to `None`. This will in general have a lower memory footprint, and can modestly improve performance. \n", "\n", "However, it changes certain behaviors. \n", "\n", "For example: \n", "1. When the user tries to access a gradient and perform manual ops on it, a `None` attribute or a `Tensor` full of 0s will behave differently. \n", "2. 
If the user requests `zero_grad(set_to_none=True)` followed by a backward pass, `.grad`s are guaranteed to be `None` for params that did not receive a gradient. \n", "3. `torch.optim` optimizers behave differently if the gradient is 0 or `None` (in one case they perform the step with a gradient of 0, and in the other they skip the step altogether)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Use `model.requires_grad_(False)` when using pre-trained models\n", "---\n", "* When using a pre-trained model inside another model, we want gradients to flow through the pre-trained network to its input, but we don't want to compute gradients for the pre-trained network's parameters.\n", "* Examples:\n", " * Perceptual loss: use the features of a pre-trained VGG network to calculate a reconstruction loss.\n", " * Use the CLIP score as guidance for a generative network.\n", "* Why is `torch.no_grad()` not good enough? Wrapping the forward pass of the pre-trained model with `torch.no_grad()` also stops gradients from flowing back to its input, which is the **output** of the neural network being trained, so that network would receive no gradient from this loss.\n", "* Solution: after loading the pre-trained model, call `model.requires_grad_(False)`: its parameters are frozen, but gradients still propagate through it to its input." 
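, "\n", "\n", "A minimal sketch of the difference, using two small `torch.nn.Linear` layers as hypothetical stand-ins for the network being trained (`generator`) and the pre-trained network (`frozen`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"generator = torch.nn.Linear(4, 8)  # hypothetical network being trained\n",
"frozen = torch.nn.Linear(8, 2)  # hypothetical stand-in for a pre-trained network (e.g., VGG)\n",
"frozen.requires_grad_(False)  # freeze the pre-trained parameters\n",
"\n",
"# requires_grad_(False): gradients still flow through `frozen` back to `generator`\n",
"loss = frozen(generator(torch.randn(3, 4))).pow(2).mean()\n",
"loss.backward()\n",
"print(generator.weight.grad is not None)  # True: the trained network receives gradients\n",
"print(frozen.weight.grad is None)  # True: no gradients are computed for the frozen parameters\n",
"\n",
"# torch.no_grad(): the graph is cut, so `generator` would get no gradient from this loss\n",
"gen_out = generator(torch.randn(3, 4))\n",
"with torch.no_grad():\n",
"    feats = frozen(gen_out)\n",
"print(feats.requires_grad)  # False: calling backward() on a loss built only from feats raises an error"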
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# example: vgg for perceptual loss\n", "from torchvision import models\n", "\n", "vggnet = models.vgg16(weights='DEFAULT')  # the `weights` argument requires torchvision >= 0.13\n", "vggnet.eval()  # eval mode for dropout and batch-norm layers\n", "vggnet.requires_grad_(False)  # freeze parameters; gradients still flow to the input" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Avoid `for` loops (duh...)\n", "---\n", "* It goes without saying that avoiding loops in the code is crucial for fast training.\n", "* Always try to batch operations: even if you eventually find out that a loop cannot be avoided, it is worth looking for ways to batch operations.\n", "* Trivial cases are when we need to work with tensors of shape `[batch_size, dim_a, dim_b]` or `[batch_size, dim_a, dim_b, dim_c]`: it might be tempting to loop over `dim_a`, but if GPU memory allows, we can simply stack everything on the batch dimension.\n", " * Another popular case is applying convolutional layers to patches: consider an image tensor of size `[batch_size, ch, h, w]` which is patchified to a tensor of size `[batch_size, num_patches, ch, h_p, w_p]`. We don't need to loop over `num_patches`; we can just stack the `num_patches` dimension onto the `batch_size` dimension, i.e., reshape to `[batch_size * num_patches, ch, h_p, w_p]`." 
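, "\n", "\n", "A minimal sketch of the patch case, with arbitrary placeholder sizes: the loop version calls the convolution once per patch index, while the batched version merges `num_patches` into the batch dimension, applies the convolution once and reshapes back. Both should give the same result up to floating-point error." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"batch_size, num_patches, ch, h_p, w_p = 4, 16, 3, 8, 8\n",
"patches = torch.rand(batch_size, num_patches, ch, h_p, w_p)  # placeholder patchified images\n",
"conv = torch.nn.Conv2d(ch, 32, kernel_size=3, padding=1)\n",
"\n",
"# slow: one conv call per patch index\n",
"out_loop = torch.stack([conv(patches[:, i]) for i in range(num_patches)], dim=1)\n",
"\n",
"# fast: merge num_patches into the batch dimension, apply conv once, then split it back\n",
"out = conv(patches.view(batch_size * num_patches, ch, h_p, w_p))\n",
"out = out.view(batch_size, num_patches, *out.shape[1:])  # [batch_size, num_patches, 32, h_p, w_p]\n",
"print(out.shape)  # torch.Size([4, 16, 32, 8, 8])\n",
"print(torch.allclose(out, out_loop, atol=1e-5))  # same result up to floating-point error"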
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a: torch.Size([32, 10, 16])\n", "a: torch.Size([320, 16])\n", "a_f: torch.Size([320, 32])\n", "a_f: torch.Size([32, 10, 32])\n", "a_f_1: torch.Size([32, 10, 32])\n" ] } ], "source": [ "batch_size = 32\n", "dim_a = 10\n", "dim_b = 16\n", "\n", "func = torch.nn.Linear(16, 32)\n", "\n", "a = torch.rand(batch_size, dim_a, dim_b)\n", "print(f'a: {a.shape}') # [batch_size, dim_a, dim_b]\n", "# we want to apply a function on dim_b -> merge dim_a into the batch dimension\n", "a = a.view(-1, a.shape[-1]) # [batch_size * dim_a, dim_b]\n", "print(f'a: {a.shape}')\n", "# apply the function and then reshape to the original dimensions\n", "a_f = func(a) # [batch_size * dim_a, 32]\n", "print(f'a_f: {a_f.shape}')\n", "a_f = a_f.view(batch_size, dim_a, a_f.shape[-1])\n", "print(f'a_f: {a_f.shape}')\n", "\n", "# note: torch.nn.Linear actually does this automatically\n", "a = torch.rand(batch_size, dim_a, dim_b) # [batch_size, dim_a, dim_b]\n", "a_f_1 = func(a) # [batch_size, dim_a, 32]\n", "print(f'a_f_1: {a_f_1.shape}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Model Compiling: `torch.compile`\n", "---\n", "* Since PyTorch 2.0, it is possible to compile models for speed-up gains on certain GPU architectures (e.g., H100, A100, or V100).\n", "* `torch.compile` makes PyTorch code run faster by **Just-in-Time** (JIT) compiling it into optimized kernels.\n", "* The speedup mainly comes from reducing Python overhead and GPU reads/writes, so the observed speedup may vary depending on factors such as model architecture and batch size.\n", " * For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.\n", "* The `mode` parameter specifies what the compiler should optimize for while compiling.\n", " * The `default` mode is a 
preset that tries to compile efficiently without taking too long to compile or using extra memory.\n", " * Other modes such as `reduce-overhead` reduce the framework overhead much further, at the cost of a small amount of extra memory.\n", " * The `max-autotune` mode compiles for a long time, trying to give you the fastest code it can generate.\n", " * <a href=\"https://pytorch.org/get-started/pytorch-2.0/#user-experience\">More information on modes</a>.\n", "* At the time of writing (PyTorch 2.0), `torch.compile` only works on **Linux** machines (Windows is not supported yet, and calling `torch.compile` raises an error).\n", "* More details on `torch.compile` can be found on the <a href=\"https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html\">official PyTorch website</a>." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# example\n", "import torch\n", "\n", "class MyModule(torch.nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.lin = torch.nn.Linear(100, 10)\n", "\n", "    def forward(self, x):\n", "        return torch.nn.functional.relu(self.lin(x))\n", "\n", "mod = MyModule()\n", "opt_mod = torch.compile(mod, mode='default')\n", "print(opt_mod(torch.randn(10, 100)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# another example\n", "import torch\n", "import torchvision.models as models\n", "\n", "device = torch.device('cuda:0')  # this example assumes a CUDA GPU is available\n", "model = models.resnet18().to(device)\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", "compiled_model = torch.compile(model)\n", "\n", "x = torch.randn(16, 3, 224, 224, device=device)  # create the input directly on the GPU\n", "optimizer.zero_grad()\n", "out = compiled_model(x)\n", "out.sum().backward()\n", "optimizer.step()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Advanced: Fuse pointwise operations with Just-in-Time (JIT)\n", "---\n", "* Pointwise operations (elementwise addition, multiplication, math functions - `sin()`, 
`cos()`, `sigmoid()` etc.) can be fused into a single kernel to amortize memory access time and kernel launch time.\n", "\n", "* <a href=\"https://pytorch.org/docs/stable/jit.html\">PyTorch JIT</a> can fuse kernels automatically, although there could be additional fusion opportunities not yet implemented in the compiler, and not all device types are supported equally.\n", "\n", "* Pointwise operations are **memory-bound**: for each operation, PyTorch launches a separate kernel. Each kernel loads data from memory, performs the computation (this step is usually inexpensive) and stores the results back into memory.\n", "\n", "* A fused operator launches only one kernel for multiple fused pointwise ops and loads/stores data from/to memory **only once**. This makes JIT very useful for **activation functions, optimizers, custom RNN cells** etc.\n", "\n", "* In the simplest case, fusion can be enabled by applying the `torch.jit.script` decorator to the function definition." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@torch.jit.script\n", "def fused_gelu(x):\n", "    # GELU via the error function; 1.41421 approximates sqrt(2)\n", "    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <img src=\"https://img.icons8.com/external-flaticons-lineal-color-flat-icons/64/000000/external-gpu-computer-science-flaticons-lineal-color-flat-icons.png\" style=\"height:50px;display:inline\"> Data Loading Tips to Maximize GPU Utility\n", "---\n", "* One of the greatest bottlenecks in training NNs is loading the data into memory and transferring it to the GPU for computation.\n", "* If the data loading process is lengthy (e.g., the data itself is complex or a lot of pre-processing is required), the CPU may spend a long time on it while the GPU sits idle (i.e., low GPU utilization in `nvidia-smi`).\n", "* We'd like to fetch the data as fast as possible and quickly transfer it to the GPU. We'll review a couple of approaches to accelerate this process." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Use pinned memory data loading\n", "---\n", "* Host-to-GPU copies are much faster when they originate from pinned (page-locked) memory.\n", "* In CUDA, non-paged CPU (RAM) memory is referred to as pinned memory. Pinning a block of memory is done via a CUDA API call, which issues an OS call that reserves the memory block and sets the constraint that it cannot be paged out to disk.\n", "* Pinned memory speeds up CPU-to-GPU memory copies (as executed by, e.g., `tensor.cuda()` in PyTorch) because it is guaranteed to stay in RAM, so the GPU can read it directly with DMA. \n", "* The GPU cannot read directly from regular (pageable) memory: the CUDA driver first copies the data into a temporary pinned buffer and only then transfers it to the GPU, i.e., *the data is copied twice*. You can naively expect this to be twice as slow (the true slowdown depends on the size of the data and how busy the relevant memory buses are).\n", "* Setting `pin_memory=True` on the `DataLoader` makes it place the loaded batches in pinned memory. \n", " * Note: this technique requires that the OS is willing to give the PyTorch process as much main memory as it needs to complete its load and transform operations, i.e., the batch must fit into RAM in its entirety." 
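, "\n", "\n", "Pinned memory also enables asynchronous copies. A minimal sketch, assuming a single batch of random placeholder data: pin a CPU tensor (what `pin_memory=True` does for each batch) and copy it with `tensor.to(device, non_blocking=True)`, so the CPU does not wait for the transfer to finish. The sketch falls back to the CPU, where pinning is skipped, if no CUDA device is available." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"x = torch.randn(64, 3, 224, 224)  # regular (pageable) CPU tensor, e.g., a placeholder batch\n",
"\n",
"if device.type == 'cuda':\n",
"    x = x.pin_memory()  # page-locked copy of the batch (pinning requires CUDA)\n",
"    assert x.is_pinned()\n",
"\n",
"# from pinned memory, non_blocking=True makes the host-to-GPU copy asynchronous w.r.t. the CPU\n",
"x_gpu = x.to(device, non_blocking=True)\n",
"# kernels launched afterwards on the same CUDA stream are ordered after the copy\n",
"y = x_gpu * 2\n",
"print(y.device)"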
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<center><img src=\"https://images.prismic.io/spell/7f788bc9-f60c-4e7c-bcad-0106d09585fd_foo.png\" style=\"height:300px\"></center>" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# example (placeholder dataset for illustration: 1024 random 3x32x32 images with integer labels)\n", "train_dataset = torch.utils.data.TensorDataset(torch.randn(1024, 3, 32, 32), torch.randint(0, 10, (1024,)))\n", "train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, pin_memory=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Use multiprocessing for data loading\n", "---\n", "* Applying the following tip can speed up your training code significantly.\n", "* PyTorch allows loading data in multiple worker processes simultaneously.\n", "* To set the number of workers: `DataLoader(num_workers=4)`.\n", "* How to set the number of workers? A common rule of thumb is `num_workers = 4 * num_gpus`; however, this is highly machine-dependent and should be treated as a hyper-parameter.\n", " * Usually 4 or 8 is a good number to start with.\n", " * A nice discussion on PyTorch forums: <a href=\"https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/69\">Guidelines for assigning num_workers to DataLoader</a>\n", " * A script to search for the optimal `num_workers`: <a href=\"https://github.com/developer0hye/Num-Workers-Search\">Num-Workers-Search</a>.\n", "* Important notes:\n", " * Having more workers will **increase the memory usage**.\n", " * `num_workers=0` means that the main process does the data loading when needed.\n", " * `num_workers=1` works like any $n > 0$ (loading happens in a separate worker process), but with only a single worker, which is probably slower." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, pin_memory=True, num_workers=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<center><img src=\"https://images.prismic.io/spell/57bc0609-3f85-45ac-820d-41ecd83a66af_foo.png\" style=\"height:400px\"></center>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Perform image augmentations directly on the GPU with Kornia\n", "---\n", "* Instead of performing augmentations on the CPU with `torchvision`, we can perform them on the GPU, which can speed up the data loading process.\n", "* <a href=\"https://kornia.github.io/\">Kornia</a> is a differentiable library that allows classical computer vision to be integrated into deep learning models.\n", " * That means augmentations and filters can run on the GPU and also be differentiable for backpropagation!\n", " * `pip install kornia`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from kornia import augmentation as K\n", "from kornia.augmentation import AugmentationSequential\n", "\n", "aug_list = AugmentationSequential(\n", "    K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),\n", "    K.RandomAffine(360, [0.1, 0.1], [0.7, 1.2], [30., 50.], p=1.0),\n", "    K.RandomPerspective(0.5, p=1.0),\n", "    same_on_batch=False,\n", ")\n", "\n", "# placeholder batch of images with values in [0, 1], created directly on the device\n", "img_tensor = torch.rand(16, 3, 224, 224, device=device)  # [batch_size, num_ch, h, w]\n", "img_aug = aug_list(img_tensor)  # [batch_size, num_ch, h, w]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<table class=\"docutils align-default\" id=\"id1\">\n", "<caption><span class=\"caption-text\">Here is a benchmark performed on <a class=\"reference external\" href=\"https://colab.research.google.com/drive/1b-HpK4EsZR8uolztgH4roNBLaDwcMULx?usp=sharing\">Google Colab</a>\n", "K80 GPU with different libraries and batch sizes. 
This benchmark shows\n", "strong GPU augmentation speed acceleration brought by Kornia data augmentations. The image size is fixed to 224x224 and the\n", "unit is milliseconds (ms).</span></caption>\n", "<colgroup>\n", "<col style=\"width: 27%\">\n", "<col style=\"width: 15%\">\n", "<col style=\"width: 15%\">\n", "<col style=\"width: 15%\">\n", "<col style=\"width: 15%\">\n", "<col style=\"width: 15%\">\n", "</colgroup>\n", "<thead>\n", "<tr class=\"row-odd\"><th class=\"head\"><p>Libraries</p></th>\n", "<th class=\"head\"><p>TorchVision</p></th>\n", "<th class=\"head\"><p>Albumentations</p></th>\n", "<th class=\"head\" colspan=\"3\"><p>Kornia (GPU)</p></th>\n", "</tr>\n", "<tr class=\"row-even\"><th class=\"head\"><p>Batch Size</p></th>\n", "<th class=\"head\"><p>1</p></th>\n", "<th class=\"head\"><p>1</p></th>\n", "<th class=\"head\"><p>1</p></th>\n", "<th class=\"head\"><p>32</p></th>\n", "<th class=\"head\"><p>128</p></th>\n", "</tr>\n", "</thead>\n", "<tbody>\n", "<tr class=\"row-odd\"><td><p>RandomPerspective</p></td>\n", "<td><p>4.88±1.82</p></td>\n", "<td><p>4.68±3.60</p></td>\n", "<td><p>4.74±2.84</p></td>\n", "<td><p>0.37±2.67</p></td>\n", "<td><p>0.20±27.00</p></td>\n", "</tr>\n", "<tr class=\"row-even\"><td><p>ColorJiggle</p></td>\n", "<td><p>4.40±2.88</p></td>\n", "<td><p>3.58±3.66</p></td>\n", "<td><p>4.14±3.85</p></td>\n", "<td><p>0.90±24.68</p></td>\n", "<td><p>0.83±12.96</p></td>\n", "</tr>\n", "<tr class=\"row-odd\"><td><p>RandomAffine</p></td>\n", "<td><p>3.12±5.80</p></td>\n", "<td><p>2.43±7.11</p></td>\n", "<td><p>3.01±7.80</p></td>\n", "<td><p>0.30±4.39</p></td>\n", "<td><p>0.18±6.30</p></td>\n", "</tr>\n", "<tr class=\"row-even\"><td><p>RandomVerticalFlip</p></td>\n", "<td><p>0.32±0.08</p></td>\n", "<td><p>0.34±0.16</p></td>\n", "<td><p>0.35±0.82</p></td>\n", "<td><p>0.02±0.13</p></td>\n", "<td><p>0.01±0.35</p></td>\n", "</tr>\n", "<tr class=\"row-odd\"><td><p>RandomHorizontalFlip</p></td>\n", "<td><p>0.32±0.08</p></td>\n", 
"<td><p>0.34±0.18</p></td>\n", "<td><p>0.31±0.59</p></td>\n", "<td><p>0.01±0.26</p></td>\n", "<td><p>0.01±0.37</p></td>\n", "</tr>\n", "<tr class=\"row-even\"><td><p>RandomRotate</p></td>\n", "<td><p>1.82±4.70</p></td>\n", "<td><p>1.59±4.33</p></td>\n", "<td><p>1.58±4.44</p></td>\n", "<td><p>0.25±2.09</p></td>\n", "<td><p>0.17±5.69</p></td>\n", "</tr>\n", "<tr class=\"row-odd\"><td><p>RandomCrop</p></td>\n", "<td><p>4.09±3.41</p></td>\n", "<td><p>4.03±4.94</p></td>\n", "<td><p>3.84±3.07</p></td>\n", "<td><p>0.16±1.17</p></td>\n", "<td><p>0.08±9.42</p></td>\n", "</tr>\n", "<tr class=\"row-even\"><td><p>RandomErasing</p></td>\n", "<td><p>2.31±1.47</p></td>\n", "<td><p>1.89±1.08</p></td>\n", "<td><p>2.32±3.31</p></td>\n", "<td><p>0.44±2.82</p></td>\n", "<td><p>0.57±9.74</p></td>\n", "</tr>\n", "<tr class=\"row-odd\"><td><p>RandomGrayscale</p></td>\n", "<td><p>0.41±0.18</p></td>\n", "<td><p>0.43±0.60</p></td>\n", "<td><p>0.45±1.20</p></td>\n", "<td><p>0.03±0.11</p></td>\n", "<td><p>0.03±7.10</p></td>\n", "</tr>\n", "<tr class=\"row-even\"><td><p>RandomResizedCrop</p></td>\n", "<td><p>4.23±2.86</p></td>\n", "<td><p>3.80±3.61</p></td>\n", "<td><p>4.07±2.67</p></td>\n", "<td><p>0.23±5.27</p></td>\n", "<td><p>0.13±8.04</p></td>\n", "</tr>\n", "<tr class=\"row-odd\"><td><p>RandomCenterCrop</p></td>\n", "<td><p>2.93±1.29</p></td>\n", "<td><p>2.81±1.38</p></td>\n", "<td><p>2.88±2.34</p></td>\n", "<td><p>0.13±2.20</p></td>\n", "<td><p>0.07±9.41</p></td>\n", "</tr>\n", "</tbody>\n", "</table>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <img src=\"https://img.icons8.com/external-flat-land-kalash/64/000000/external-training-education-and-e-learning-flat-land-kalash.png\" style=\"height:50px;display:inline\"> Training Tips\n", "---\n", "There are 2 main approaches (that can be combined) to achieve a speed-up during train time:\n", "1. 
Automatic Mixed Precision (AMP) - instead of working with tensors at full precision (`float32`), we can (sometimes) work at half precision (`float16`). Recent work also experiments with even lower precision.\n", "2. Distributed training - utilizing multiple GPUs to train a single model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Use mixed precision and AMP\n", "---\n", "* Mixed precision leverages Tensor Cores and offers up to 3x overall speedup on Volta and newer GPU architectures.\n", "* To use Tensor Cores, AMP should be enabled and matrix/tensor dimensions should satisfy the requirements for calling kernels that use Tensor Cores (e.g., for FP16, dimensions that are multiples of 8).\n", "* Deep Neural Network training has traditionally relied on FP32 (32-bit Floating Point, IEEE single-precision format).\n", "* The (automatic) mixed precision technique trains with FP16 (16-bit Floating Point, half-precision) where possible, while maintaining the network accuracy achieved with FP32.\n", "* Enabling mixed precision involves two steps:\n", " * Porting the model to use the half-precision data type **where appropriate**.\n", " * Using loss scaling to preserve small gradient values." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Performance of mixed precision training on NVIDIA 8xV100 vs. FP32 training on 8xV100 GPU**\n", "\n", "<center><img src=\"https://raw.githubusercontent.com/taldatech/ee046211-deep-learning/main/assets/tut_compress_amp_chart.png\" style=\"height:300px\"></center>\n", "\n", "* Bars represent the speedup factor of V100 AMP over V100 FP32.
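The higher the better.\n", "\n", "<a href=\"https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/\">Image Source</a>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* <a href=\"https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html\">AMP Recipe from PyTorch</a>\n", "* <a href=\"https://nbviewer.org/github/taldatech/ee046211-deep-learning/blob/main/ee046211_tutorial_10_compression_pruning_amp.ipynb\">ECE 046211 AMP Tutorial</a>" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# general recipe (toy sizes and random data, following the PyTorch AMP recipe)\n", "import torch\n", "\n", "use_amp = True\n", "device = \"cuda\"  # float16 autocast and GradScaler target CUDA GPUs\n", "\n", "in_size, out_size, num_layers = 4096, 4096, 3\n", "epochs, batch_size, num_batches = 3, 512, 50\n", "\n", "def make_model(in_size, out_size, num_layers):\n", "    layers = []\n", "    for _ in range(num_layers - 1):\n", "        layers += [torch.nn.Linear(in_size, in_size), torch.nn.ReLU()]\n", "    layers.append(torch.nn.Linear(in_size, out_size))\n", "    return torch.nn.Sequential(*layers).to(device)\n", "\n", "data = [torch.randn(batch_size, in_size, device=device) for _ in range(num_batches)]\n", "targets = [torch.randn(batch_size, out_size, device=device) for _ in range(num_batches)]\n", "loss_fn = torch.nn.MSELoss()\n", "\n", "net = make_model(in_size, out_size, num_layers)\n", "opt = torch.optim.SGD(net.parameters(), lr=0.001)\n", "scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # notice the `enabled` parameter\n", "# Gradient scaling helps prevent gradients with small magnitudes from flushing\n", "# to zero (“underflowing”) when training with mixed precision\n", "\n", "for epoch in range(epochs):\n", "    for inputs, target in zip(data, targets):\n", "        # notice the `enabled` parameter: with use_amp=False this is plain FP32 training\n", "        with torch.cuda.amp.autocast(enabled=use_amp):\n", "            output = net(inputs)\n", "            loss = loss_fn(output, target)\n", "\n", "        # set_to_none=True sets the .grad tensors to None instead of filling them with zeros,\n", "        # which skips a memset and can modestly improve performance\n", "        opt.zero_grad(set_to_none=True)\n", "        scaler.scale(loss).backward()  # backward pass on the scaled loss\n", "        scaler.step(opt)  # unscales the gradients; skips the step if they contain inf/NaN\n", "        scaler.update()  # adjusts the scale factor for the next iteration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Distributed multi-GPU training with HuggingFace's Accelerate\n", "---\n", "* In the past, a lot of boilerplate code was needed to make our PyTorch code run on multiple GPUs with <a href=\"https://pytorch.org/docs/stable/distributed.html\">`torch.distributed`</a>.\n", "* Today, things are much easier with wrappers like <a href=\"https://huggingface.co/docs/accelerate/index\">HuggingFace Accelerate</a>.\n", " * `pip install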
accelerate`\n", "* Accelerate is a library that enables the same PyTorch code to be run across any distributed configuration by adding just four lines of code.\n", "* Accelerate supports the following configurations: CPU only, multi-CPU on one node (machine), multi-CPU on several nodes (machines), single GPU, multi-GPU on one node (machine), multi-GPU on several nodes (machines), TPU, FP16 with native AMP (apex on the roadmap), DeepSpeed support (Experimental), PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental).\n", "* <a href=\"https://github.com/huggingface/accelerate\">Examples can be found here</a>" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from accelerate import Accelerator\n", "\n", "accelerator = Accelerator()  # NEW: detects the available devices / distributed setup\n", "\n", "# toy classifier and random data, stand-ins for your own model and dataset\n", "model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))\n", "optimizer = torch.optim.Adam(model.parameters())\n", "\n", "dataset = TensorDataset(torch.randn(1024, 32), torch.randint(0, 10, (1024,)))\n", "data = DataLoader(dataset, batch_size=64, shuffle=True)\n", "\n", "# NEW: wraps the objects and places them on the right device(s)\n", "model, optimizer, data = accelerator.prepare(model, optimizer, data)\n", "\n", "model.train()\n", "for epoch in range(10):\n", "    for source, targets in data:  # batches already arrive on the right device\n", "        output = model(source)\n", "        loss = F.cross_entropy(output, targets)\n", "        optimizer.zero_grad()\n", "        accelerator.backward(loss)  # NEW: replaces loss.backward()\n", "        optimizer.step()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* To launch the distributed code, first create a config file by running `accelerate config`, which asks a few questions (e.g., how many GPUs to use).\n", "* Then run `accelerate launch my_script.py --args_to_my_script`.\n", "* If you work with <a href=\"https://www.pytorchlightning.ai/\">PyTorch Lightning</a>, multi-GPU and distributed training are already built into its `Trainer` API (e.g., via the `accelerator` and `devices` arguments)."
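, "\n", "* Accelerate can also apply the AMP technique from the previous section. The cell below is a minimal sketch, assuming a hypothetical toy linear model and random data: `Accelerator(mixed_precision=\"fp16\")` enables FP16 autocasting and loss scaling, and the cell falls back to full precision (`\"no\"`) when no GPU is available." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from accelerate import Accelerator\n", "\n", "# FP16 mixed precision requires a GPU; fall back to full precision otherwise\n", "precision = \"fp16\" if torch.cuda.is_available() else \"no\"\n", "accelerator = Accelerator(mixed_precision=precision)\n", "\n", "# hypothetical toy model and random data, for illustration only\n", "model = torch.nn.Linear(32, 10)\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", "dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 10, (256,)))\n", "data = DataLoader(dataset, batch_size=32, shuffle=True)\n", "\n", "# with mixed precision enabled, `prepare` runs the model's forward pass under autocast\n", "model, optimizer, data = accelerator.prepare(model, optimizer, data)\n", "\n", "model.train()\n", "for source, targets in data:\n", "    loss = F.cross_entropy(model(source), targets)\n", "    optimizer.zero_grad()\n", "    accelerator.backward(loss)  # scales the loss before backward when fp16 is enabled\n", "    optimizer.step()  # unscales gradients and skips the step on inf/NaN when fp16 is enabled\n", "\n", "print(f\"mixed precision: {precision}, last loss: {loss.item():.4f}\")"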
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <img src=\"https://img.icons8.com/office/80/system-task.png\" style=\"height:50px;display:inline\"> PyTorch Profiler\n", "---\n", "* The PyTorch Profiler helps identify the most expensive operators in a model and analyze their execution time and memory usage.\n", "* <a href=\"https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html\">PyTorch Profiler tutorial</a>" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# usage example\n", "import torch\n", "import torchvision.models as models\n", "from torch.profiler import profile, record_function, ProfilerActivity\n", "\n", "model = models.resnet18()\n", "inputs = torch.randn(5, 3, 224, 224)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# analyze execution time\n", "with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:\n", "    with record_function(\"model_inference\"):\n", "        model(inputs)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n", " Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls \n", "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n", " model_inference 4.44% 10.480ms 100.00% 236.255ms 236.255ms 0 b -106.30 Mb 1 \n", " aten::conv2d 2.04% 4.815ms 61.43% 145.141ms 7.257ms 47.37 Mb 0 b 20 \n", " aten::convolution 1.75% 4.137ms 59.40% 140.326ms 7.016ms 47.37 Mb 0 b 20 \n", " aten::_convolution 0.63% 1.482ms 57.64% 136.189ms 6.809ms 47.37 Mb 0 b 20 \n", " aten::mkldnn_convolution 56.15% 132.665ms 57.02% 134.707ms 6.735ms 47.37 Mb 0 b 20 \n", " aten::max_pool2d 0.30%
704.000us 14.15% 33.425ms 33.425ms 11.48 Mb 0 b 1 \n", " aten::max_pool2d_with_indices 13.85% 32.721ms 13.85% 32.721ms 32.721ms 11.48 Mb 11.48 Mb 1 \n", " aten::batch_norm 0.30% 717.000us 10.10% 23.860ms 1.193ms 47.41 Mb 0 b 20 \n", " aten::_batch_norm_impl_index 0.45% 1.066ms 9.80% 23.143ms 1.157ms 47.41 Mb 0 b 20 \n", " aten::native_batch_norm 9.09% 21.469ms 9.33% 22.044ms 1.102ms 47.41 Mb -64.00 Kb 20 \n", "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n", "Self CPU time total: 236.255ms\n", "\n" ] } ], "source": [ "# stats for the execution\n", "print(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=10))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ -------------------------------------------------------------------------------- \n", " Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls Input Shapes \n", "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ -------------------------------------------------------------------------------- \n", " model_inference 4.44% 10.480ms 100.00% 236.255ms 236.255ms 0 b -106.30 Mb 1 [] \n", " aten::conv2d 1.99% 4.701ms 16.21% 38.296ms 38.296ms 15.31 Mb 0 b 1 [[5, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], []] \n", " aten::convolution 1.57% 3.721ms 14.22% 33.595ms 33.595ms 15.31 Mb 0 b 1 [[5, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], []] \n", " aten::max_pool2d 0.30% 704.000us 14.15% 33.425ms 33.425ms 11.48 Mb 0 b 1 [[5, 64, 112, 112], [], [], [], [], []] \n", " aten::max_pool2d_with_indices 13.85% 32.721ms 13.85% 32.721ms 32.721ms 11.48 Mb 11.48 Mb 1 [[5, 64, 
112, 112], [], [], [], [], []] \n", " aten::_convolution 0.52% 1.239ms 12.64% 29.874ms 29.874ms 15.31 Mb 0 b 1 [[5, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []] \n", " aten::mkldnn_convolution 11.39% 26.910ms 12.12% 28.635ms 28.635ms 15.31 Mb 0 b 1 [[5, 3, 224, 224], [64, 3, 7, 7], [], [], [], [], []] \n", " aten::conv2d 0.01% 23.000us 11.55% 27.276ms 6.819ms 15.31 Mb 0 b 4 [[5, 64, 56, 56], [64, 64, 3, 3], [], [], [], [], []] \n", " aten::convolution 0.03% 78.000us 11.54% 27.253ms 6.813ms 15.31 Mb 0 b 4 [[5, 64, 56, 56], [64, 64, 3, 3], [], [], [], [], [], [], []] \n", " aten::_convolution 0.02% 49.000us 11.50% 27.175ms 6.794ms 15.31 Mb 0 b 4 [[5, 64, 56, 56], [64, 64, 3, 3], [], [], [], [], [], [], [], [], [], [], []] \n", "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ -------------------------------------------------------------------------------- \n", "Self CPU time total: 236.255ms\n", "\n" ] } ], "source": [ "# include operator input shapes and sort by the total CPU time\n", "print(prof.key_averages(group_by_input_shape=True).table(sort_by=\"cpu_time_total\", row_limit=10))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n", " Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls \n", "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n", " aten::empty 0.25% 588.000us 0.25% 588.000us 2.940us 94.85 Mb 94.85 Mb 200 \n", " aten::batch_norm 0.30% 717.000us 10.10% 23.860ms 1.193ms 47.41 Mb 0 b 20 \n", " aten::_batch_norm_impl_index 0.45% 1.066ms 9.80% 23.143ms 1.157ms 47.41 Mb 0 b 20 \n", "
aten::native_batch_norm 9.09% 21.469ms 9.33% 22.044ms 1.102ms 47.41 Mb -64.00 Kb 20 \n", " aten::conv2d 2.04% 4.815ms 61.43% 145.141ms 7.257ms 47.37 Mb 0 b 20 \n", " aten::convolution 1.75% 4.137ms 59.40% 140.326ms 7.016ms 47.37 Mb 0 b 20 \n", " aten::_convolution 0.63% 1.482ms 57.64% 136.189ms 6.809ms 47.37 Mb 0 b 20 \n", " aten::mkldnn_convolution 56.15% 132.665ms 57.02% 134.707ms 6.735ms 47.37 Mb 0 b 20 \n", " aten::empty_like 0.12% 277.000us 0.15% 343.000us 17.150us 47.37 Mb 0 b 20 \n", " aten::max_pool2d 0.30% 704.000us 14.15% 33.425ms 33.425ms 11.48 Mb 0 b 1 \n", "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n", "Self CPU time total: 236.255ms\n", "\n" ] } ], "source": [ "# sort by memory usage\n", "print(prof.key_averages().table(sort_by=\"cpu_memory_usage\", row_limit=10))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n", " Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls \n", "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n", " model_inference 0.38% 3.567ms 100.00% 947.626ms 947.626ms 195.000us 0.03% 679.493ms 679.493ms 1 \n", " aten::conv2d 0.02% 154.000us 81.83% 775.462ms 38.773ms 74.000us 0.01% 515.147ms 25.757ms 20 \n", " aten::convolution 0.04% 338.000us 81.82% 775.308ms 38.765ms 68.000us 0.01% 515.073ms 25.754ms 20 \n", " aten::_convolution 0.04% 340.000us 81.78% 774.970ms 38.748ms 84.000us 0.01% 515.005ms 25.750ms 20 \n", " aten::cudnn_convolution 81.74% 774.630ms 81.74% 774.630ms 38.731ms 514.921ms 75.78% 514.921ms 
25.746ms 20 \n", " aten::add_ 3.32% 31.494ms 3.32% 31.494ms 1.125ms 100.832ms 14.84% 100.832ms 3.601ms 28 \n", " aten::batch_norm 0.01% 109.000us 5.53% 52.364ms 2.618ms 75.000us 0.01% 38.227ms 1.911ms 20 \n", " aten::_batch_norm_impl_index 0.11% 1.011ms 5.51% 52.255ms 2.613ms 70.000us 0.01% 38.152ms 1.908ms 20 \n", " aten::cudnn_batch_norm 5.17% 49.028ms 5.41% 51.244ms 2.562ms 37.790ms 5.56% 38.082ms 1.904ms 20 \n", " aten::adaptive_avg_pool2d 0.00% 17.000us 3.35% 31.774ms 31.774ms 3.000us 0.00% 10.017ms 10.017ms 1 \n", "--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n", "Self CPU time total: 947.626ms\n", "Self CUDA time total: 679.493ms\n", "\n" ] } ], "source": [ "# analyze performance of models executed on GPUs\n", "device = torch.device(\"cuda:0\")\n", "model = models.resnet18().to(device)\n", "inputs = torch.randn(5, 3, 224, 224).to(device)\n", "\n", "with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:\n", "    with record_function(\"model_inference\"):\n", "        model(inputs)\n", "\n", "print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## <img src=\"https://img.icons8.com/office/80/000000/controller.png\" style=\"height:50px;display:inline\"> Reinforcement Learning\n", "---\n", "* In (online) RL, the main bottleneck is often the interaction with the environment to collect data, which usually takes place in a simulator that runs on the CPU.\n", "* This results in \"spikes\" in GPU utilization: the GPU sits idle while the agent interacts with the environment.\n", "* There are several approaches to speed up the RL process.
For example, one can use simulators that run natively on GPUs, where inputs and outputs are already tensors; this reduces the CPU-GPU transfer latency and exploits the GPU's parallelism to speed up the simulation itself (faster data collection). Another perspective is CPU multiprocessing: running several environments in parallel, which can reduce the wall-clock time spent on data collection." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### GPU-based Simulators\n", "---\n", "* <a href=\"https://developer.nvidia.com/isaac-gym\">NVIDIA's Isaac Gym</a> - NVIDIA’s physics simulation environment for reinforcement learning research.\n", " * GPU-accelerated tensor API for evaluating environment state and applying actions.\n", " * Physics simulation in Isaac Gym runs on the GPU, storing results in PyTorch GPU tensors.\n", " * Using the Isaac Gym tensor-based APIs, observations and rewards can be calculated on the GPU in PyTorch, enabling thousands of environments to run in parallel on a single workstation.\n", " * <a href=\"https://www.youtube.com/watch?v=nleDq-oJjGk&list=PLq2Xfjf6QzkrgDkQdtEzlnXeUAbTPEXNH\">Isaac Gym Video Tutorial</a>\n", " * Code examples: <a href=\"https://github.com/NVIDIA-Omniverse/IsaacGymEnvs\">Isaac Gym Environments</a>, <a href=\"https://github.com/AI4Finance-Foundation/ElegantRL/tree/master\">ElegantRL</a>, <a href=\"https://github.com/wangcongrobot/awesome-isaac-gym\">Collection of Isaac Gym Resources</a>, <a href=\"https://skrl.readthedocs.io/en/latest/intro/examples.html#learning-in-an-isaac-gym-environment\">PPO Example</a>, <a href=\"https://github.com/nv-tlabs/ASE\">Adversarial Skill Embedding</a>" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<center>\n", "<video width=\"320\" height=\"240\" controls>\n", " <source src=\"https://developer.download.nvidia.com/video/Nut_Bolt_Screw_IK_OSC.mp4\" type=\"video/mp4\">\n", "</video>\n", "</center>\n" ], "text/plain": [ "<IPython.core.display.HTML object>" ]
}, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%HTML\n", "<center>\n", "<video width=\"320\" height=\"240\" controls>\n", " <source src=\"https://developer.download.nvidia.com/video/Nut_Bolt_Screw_IK_OSC.mp4\" type=\"video/mp4\">\n", "</video>\n", "</center>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* <a href=\"https://github.com/google/brax\">Google's BRAX</a> - a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. \n", " * Brax is written in JAX and is designed for use on acceleration hardware. \n", " * It is both efficient for single-device simulation, and scalable to massively parallel simulation on multiple devices, without the need for pesky datacenters.\n", " * Brax also includes a suite of learning algorithms that train agents in seconds to minutes.\n", " * Code examples: <a href=\"https://colab.research.google.com/github/google/brax/blob/main/notebooks/training_torch.ipynb\">Training in Brax with PyTorch on GPUs</a>, <a href=\"https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py\">PPO Example</a>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<center><img src=\"https://github.com/google/brax/raw/main/docs/img/fetch.gif\" style=\"height:150px\"></center>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### RL Frameworks with Parallel/Batch Environments\n", "---\n", "* <a href=\"https://github.com/sail-sg/envpool\">EnvPool</a> - a batched environment pool with pybind11 and thread pool. 
\n", " * High performance (~1M raw FPS with Atari games, ~3M raw FPS with the MuJoCo simulator on a DGX-A100) and flexible APIs (both sync and async execution, both single- and multi-player environments).\n", " * Compatible with OpenAI `gym` APIs and DeepMind `dm_env` APIs.\n", " * 1 million Atari frames / 3 million MuJoCo steps per second of simulation with 256 CPU cores, ~20x the throughput of a Python subprocess-based vector env.\n", " * ~3x the throughput of a Python subprocess-based vector env on a low-resource setup such as 12 CPU cores.\n", " * Compared with the existing GPU-based solutions (Brax / Isaac Gym), EnvPool is a general solution for speeding up the parallelization of many kinds of RL environments.\n", " * Code examples: <a href=\"https://colab.research.google.com/drive/1U_NxL6gSs0yRVhfl0cKl9ttRmcmMCiBS?usp=sharing\">15-min Atari Breakout</a>, <a href=\"https://colab.research.google.com/drive/1bser52bpItzmlME00IA0bbmPdp1Xm0fy?usp=sharing\">5-min MuJoCo HalfCheetah</a>\n", " * **CleanRL**: EnvPool is compatible with <a href=\"https://docs.cleanrl.dev/\">CleanRL</a> - a <a href=\"https://github.com/vwxyzjn/cleanrl\">Deep Reinforcement Learning library</a> that provides high-quality single-file implementations with research-friendly features. For example, `ppo_atari.py` has only 340 lines of code but contains all the implementation details of how PPO works with Atari games, so it is a great reference implementation for readers who do not wish to read an entire modular library."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<center><img src=\"https://envpool.readthedocs.io/en/latest/_images/throughput.png\" style=\"height:400px\"></center>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* <a href=\"https://stable-baselines3.readthedocs.io/en/master/\">Stable-Baselines-3</a> - Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch.\n", " * <a href=\"https://github.com/DLR-RM/stable-baselines3\">GitHub</a>\n", " * <a href=\"https://github.com/DLR-RM/rl-baselines3-zoo\">RL Baselines3 Zoo</a> provides a collection of pre-trained agents and scripts for training and evaluating agents, tuning hyperparameters, plotting results and recording videos.\n", " * <a href=\"https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html\">Tips and Tricks with SB3</a> - <a href=\"https://www.youtube.com/watch?v=Ikngt0_DXJg\">YouTube Video</a>\n", " * <a href=\"https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html\">Vectorized Environments</a> with multiprocessing support - vectorized environments stack multiple independent environments into a single environment.\n", " * <a href=\"https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#sb3-with-envpool-or-isaac-gym\">SB3 with EnvPool or Isaac Gym</a>\n", " * <a href=\"https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb\">Multiprocessing Example with SB3</a>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* <a href=\"https://elegantrl.readthedocs.io/en/latest/\">ElegantRL</a> - an open-source massively parallel framework for deep reinforcement learning (DRL) algorithms implemented in PyTorch.\n", " * <a href=\"https://github.com/AI4Finance-Foundation/ElegantRL\">GitHub</a>\n", " * Scalability: ElegantRL fully exploits the parallelism of DRL algorithms at multiple levels, making it easy to scale out
to hundreds or thousands of computing nodes on a cloud platform.\n", " * Elastic: computing resources on the cloud are allocated elastically and automatically.\n", " * Efficient: according to the authors, it is \"more efficient than Ray RLlib\" in many test cases (single GPU / multi-GPU / GPU cloud).\n", " * Stable: according to the authors, \"much much much more stable than Stable Baselines 3 by utilizing various ensemble methods\".\n", " * <a href=\"https://elegantrl.readthedocs.io/en/latest/tutorial/Creating_VecEnv.html\">VecEnvs</a> - ElegantRL supports massively parallel simulation through GPU-accelerated VecEnv.\n", " * <a href=\"https://elegantrl.readthedocs.io/en/latest/tutorial/isaacgym.html\">Supports Isaac Gym</a>\n", " * <a href=\"https://towardsdatascience.com/elegantrl-mastering-the-ppo-algorithm-part-i-9f36bc47b791\">PPO Example</a>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<center><img src=\"https://github.com/AI4Finance-Foundation/ElegantRL/raw/master/figs/SB3_vs_ElegantRL.png\" style=\"height:300px\"></center> " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* <a href=\"https://pytorch.org/rl/\">TorchRL</a> - TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.\n", " * <a href=\"https://github.com/pytorch/rl\">GitHub</a>\n", " * TorchRL provides PyTorch- and Python-first, low- and high-level abstractions for RL that are intended to be efficient, modular, documented and properly tested.\n", " * On the low-level end, TorchRL comes with a set of highly reusable functionals for cost functions, returns and data processing.\n", " * <a href=\"https://pytorch.org/rl/#tutorials\">Tutorials</a>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<center><img src=\"./assets/torch_rl.PNG\" style=\"height:300px\"></center> " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### <img src=\"https://img.icons8.com/bubbles/50/000000/video-playlist.png\"
style=\"height:50px;display:inline\"> Recommended Videos\n", "---\n", "* <a href=\"https://www.youtube.com/watch?v=jF4-_ZK_tyc\">AMP - Training Neural Networks with Tensor Cores - Dusan Stosic, NVIDIA</a>\n", "* Automatic Mixed Precision (AMP) - <a href=\"https://www.youtube.com/watch?v=b5dAmcBKxHg\">NVIDIA - Automatic Mixed Precision Training in PyTorch</a>\n", "* <a href=\"https://www.youtube.com/watch?v=9mS1fIYj1So\">PyTorch Performance Tuning Guide - Szymon Migacz, NVIDIA</a>\n", "* <a href=\"https://www.youtube.com/watch?v=Ikngt0_DXJg\">RL Tips and Tricks with Stable-Baselines-3</a>\n", "* <a href=\"https://www.youtube.com/watch?v=nleDq-oJjGk&list=PLq2Xfjf6QzkrgDkQdtEzlnXeUAbTPEXNH\">Isaac Gym Tutorial</a>" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "## <img src=\"https://img.icons8.com/dusk/64/000000/prize.png\" style=\"height:50px;display:inline\"> Credits\n", "---\n", "* Icons from <a href=\"https://icons8.com/\">Icons8.com</a> - https://icons8.com\n", "* <a href=\"https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html\">Performance Tuning Guide - Szymon Migacz</a>\n", "* <a href=\"https://medium.com/analytics-vidhya/explained-output-of-nvidia-smi-utility-fc4fbee3b124\">Explained Output of Nvidia-smi Utility - Shachi Kaul</a>\n", "* <a href=\"https://towardsdatascience.com/7-tips-for-squeezing-maximum-performance-from-pytorch-ca4a40951259\">7 Tips To Maximize PyTorch Performance - William Falcon</a>\n", "* <a href=\"https://spell.ml/blog/pytorch-training-tricks-YAnJqBEAACkARhgD\">Tricks for training PyTorch models to convergence more quickly - Aleksey Bilogur</a>" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3",
"version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 4 }