[1]:
# Copyright 2019 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Torch-TensorRT Getting Started - LeNet
## Overview
Few tools are as approachable as PyTorch for developing and experimenting with machine learning models. Its power comes from its deep integration with Python, its flexibility, and its approach to automatic differentiation and execution (eager execution). However, when moving from research into production the requirements change: we may no longer want that deep Python integration, and we want optimization to get the best performance possible on our deployment platform. PyTorch 1.0 introduced TorchScript as a way to separate your PyTorch model from Python, making it portable and optimizable. TorchScript uses PyTorch’s JIT compiler to transform your normal PyTorch code, which is interpreted by the Python interpreter, into an intermediate representation (IR) that can be optimized and, at runtime, executed by the PyTorch JIT interpreter. This has opened up a whole new world of possibilities for PyTorch, including deployment from other languages such as C++. It also provides a structured, graph-based format that we can use to optimize models down to the kernel level for inference.
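As a small aside, you can see this separation in action: once a function or module has been scripted, its TorchScript form can be inspected and serialized without any of the original Python source. The snippet below is a minimal sketch, not part of the LeNet example that follows; the `scale` function and file name are purely illustrative.

```python
import torch

# A tiny standalone function, used only to illustrate TorchScript compilation.
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    if factor > 0:
        return x * factor
    return x

scripted = torch.jit.script(scale)  # Python source -> TorchScript IR
print(scripted.code)                # the TorchScript the compiler generated
scripted.save("scale.ts")           # serialized form, loadable without the Python source
```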
When deploying on NVIDIA GPUs, TensorRT, NVIDIA’s deep learning optimization SDK and runtime, can take models from any major framework and tune them specifically for the target hardware in the NVIDIA family, be it an A100, TITAN V, Jetson Xavier, or NVIDIA’s Deep Learning Accelerator. TensorRT performs several sets of optimizations to achieve this: it fuses layers and tensors in the model graph, then uses a large kernel library to select the implementations that perform best on the target GPU. TensorRT also has strong support for reduced-precision execution, which lets users leverage the Tensor Cores on Volta and newer GPUs while reducing memory and compute footprints on device.
Torch-TensorRT is a compiler that uses TensorRT to optimize TorchScript code, compiling standard TorchScript modules into ones that internally run with TensorRT optimizations. This lets you remain in the PyTorch ecosystem, using all the great features PyTorch has such as module composability, its flexible tensor implementation, data loaders, and more. Torch-TensorRT is available to use with both PyTorch and LibTorch.
## Learning objectives
This notebook demonstrates the steps for compiling a TorchScript module with Torch-TensorRT on a simple LeNet network.
## Content
## 1. Requirements
Follow the steps in notebooks/README to prepare a Docker container, within which you can run this notebook.
## 2. Creating TorchScript modules
Here we create two submodules, a feature extractor and a classifier, and stitch them together in a single LeNet module. In this case it is overkill, but modules give us granular control over our program, including where we decide to optimize and where we don’t. The module is also the unit the TorchScript compiler operates on, so you can decide to convert/optimize only the feature extractor and leave the classifier in standard PyTorch, or you can convert the whole thing.
[2]:
import torch
from torch import nn
import torch.nn.functional as F
class LeNetFeatExtractor(nn.Module):
    def __init__(self):
        super(LeNetFeatExtractor, self).__init__()
        self.conv1 = nn.Conv2d(1, 128, 3)
        self.conv2 = nn.Conv2d(128, 16, 3)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        return x


class LeNetClassifier(nn.Module):
    def __init__(self):
        super(LeNetClassifier, self).__init__()
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.feat = LeNetFeatExtractor()
        self.classifer = LeNetClassifier()

    def forward(self, x):
        x = self.feat(x)
        x = self.classifer(x)
        return x
Let us define a helper function to benchmark a model.
[3]:
import time
import numpy as np
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
def benchmark(model, input_shape=(1024, 1, 32, 32), dtype='fp32', nwarmup=50, nruns=1000):
    input_data = torch.randn(input_shape)
    input_data = input_data.to("cuda")
    if dtype == 'fp16':
        input_data = input_data.half()

    print("Warm up ...")
    with torch.no_grad():
        for _ in range(nwarmup):
            features = model(input_data)
    torch.cuda.synchronize()

    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(1, nruns + 1):
            start_time = time.time()
            features = model(input_data)
            torch.cuda.synchronize()
            end_time = time.time()
            timings.append(end_time - start_time)
            if i % 100 == 0:
                print('Iteration %d/%d, ave batch time %.2f ms' % (i, nruns, np.mean(timings) * 1000))

    print("Input shape:", input_data.size())
    print("Output features size:", features.size())
    print('Average batch time: %.2f ms' % (np.mean(timings) * 1000))
### PyTorch model
[4]:
model = LeNet()
model.to("cuda").eval()
[4]:
LeNet(
  (feat): LeNetFeatExtractor(
    (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1))
  )
  (classifer): LeNetClassifier(
    (fc1): Linear(in_features=576, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=10, bias=True)
  )
)
[5]:
benchmark(model)
Warm up ...
Start timing ...
Iteration 100/1000, ave batch time 6.49 ms
Iteration 200/1000, ave batch time 6.49 ms
Iteration 300/1000, ave batch time 6.49 ms
Iteration 400/1000, ave batch time 6.49 ms
Iteration 500/1000, ave batch time 6.49 ms
Iteration 600/1000, ave batch time 6.49 ms
Iteration 700/1000, ave batch time 6.49 ms
Iteration 800/1000, ave batch time 6.49 ms
Iteration 900/1000, ave batch time 6.49 ms
Iteration 1000/1000, ave batch time 6.49 ms
Input shape: torch.Size([1024, 1, 32, 32])
Output features size: torch.Size([1024, 10])
Average batch time: 6.49 ms
When compiling your module to TorchScript, there are two paths: Tracing and Scripting.
### Tracing
Tracing follows the path of execution when the module is called and records what happens. This recording is what the TorchScript IR will describe. To trace an instance of our LeNet module, we can call torch.jit.trace with an example input.
[6]:
traced_model = torch.jit.trace(model, torch.empty([1,1,32,32]).to("cuda"))
traced_model
[6]:
LeNet(
  original_name=LeNet
  (feat): LeNetFeatExtractor(
    original_name=LeNetFeatExtractor
    (conv1): Conv2d(original_name=Conv2d)
    (conv2): Conv2d(original_name=Conv2d)
  )
  (classifer): LeNetClassifier(
    original_name=LeNetClassifier
    (fc1): Linear(original_name=Linear)
    (fc2): Linear(original_name=Linear)
    (fc3): Linear(original_name=Linear)
  )
)
[7]:
benchmark(traced_model)
Warm up ...
Start timing ...
Iteration 100/1000, ave batch time 6.49 ms
Iteration 200/1000, ave batch time 6.49 ms
Iteration 300/1000, ave batch time 6.49 ms
Iteration 400/1000, ave batch time 6.49 ms
Iteration 500/1000, ave batch time 6.49 ms
Iteration 600/1000, ave batch time 6.49 ms
Iteration 700/1000, ave batch time 6.49 ms
Iteration 800/1000, ave batch time 6.49 ms
Iteration 900/1000, ave batch time 6.49 ms
Iteration 1000/1000, ave batch time 6.49 ms
Input shape: torch.Size([1024, 1, 32, 32])
Output features size: torch.Size([1024, 10])
Average batch time: 6.49 ms
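If you are curious what the tracer actually recorded, the traced module exposes its TorchScript directly. This quick inspection is not part of the benchmark above, but it shows the recording that the TorchScript IR describes:

```python
# Python-like listing of the operations recorded while tracing LeNet's forward pass
print(traced_model.code)
# Lower-level TorchScript IR graph for the same recording
print(traced_model.graph)
```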
### Scripting
Scripting actually inspects your code with a compiler and generates an equivalent TorchScript program. The difference is that, because tracing simply follows the execution of your module, it cannot pick up control flow: it will only follow the code path that the particular example input triggers. By working from the Python source, the script compiler can include these components. We can run the script compiler on our LeNet module by calling torch.jit.script.
[8]:
model = LeNet().to("cuda").eval()
script_model = torch.jit.script(model)
[9]:
script_model
[9]:
RecursiveScriptModule(
  original_name=LeNet
  (feat): RecursiveScriptModule(
    original_name=LeNetFeatExtractor
    (conv1): RecursiveScriptModule(original_name=Conv2d)
    (conv2): RecursiveScriptModule(original_name=Conv2d)
  )
  (classifer): RecursiveScriptModule(
    original_name=LeNetClassifier
    (fc1): RecursiveScriptModule(original_name=Linear)
    (fc2): RecursiveScriptModule(original_name=Linear)
    (fc3): RecursiveScriptModule(original_name=Linear)
  )
)
[10]:
benchmark(script_model)
Warm up ...
Start timing ...
Iteration 100/1000, ave batch time 6.48 ms
Iteration 200/1000, ave batch time 6.48 ms
Iteration 300/1000, ave batch time 6.49 ms
Iteration 400/1000, ave batch time 6.49 ms
Iteration 500/1000, ave batch time 6.49 ms
Iteration 600/1000, ave batch time 6.48 ms
Iteration 700/1000, ave batch time 6.48 ms
Iteration 800/1000, ave batch time 6.48 ms
Iteration 900/1000, ave batch time 6.48 ms
Iteration 1000/1000, ave batch time 6.48 ms
Input shape: torch.Size([1024, 1, 32, 32])
Output features size: torch.Size([1024, 10])
Average batch time: 6.48 ms
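To make the tracing/scripting distinction concrete, consider a function with data-dependent control flow. The sketch below is illustrative only (the `clamp_sign` function is not part of the LeNet example): tracing records just the branch that the example input happens to take, while scripting preserves the `if` statement itself.

```python
def clamp_sign(x):
    # Data-dependent control flow: which branch runs depends on the values in x
    if x.sum() > 0:
        return x
    return -x

example = torch.ones(3)
traced = torch.jit.trace(clamp_sign, example)   # emits a TracerWarning: only one branch is recorded
scripted = torch.jit.script(clamp_sign)

print(traced.code)    # no `if` -- only the path taken by `example` was captured
print(scripted.code)  # the `if` statement survives compilation
```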
## 3. Compiling with Torch-TensorRT
### TorchScript traced model
First, we compile the TorchScript traced model with Torch-TensorRT. Notice the performance impact.
[11]:
import torch_tensorrt
# We use a batch-size of 1024, and half precision
compile_settings = {
    "inputs": [torch_tensorrt.Input(
        min_shape=[1024, 1, 32, 32],
        opt_shape=[1024, 1, 33, 33],
        max_shape=[1024, 1, 34, 34],
        dtype=torch.half
    )],
    "enabled_precisions": {torch.half}  # Run with FP16
}

trt_ts_module = torch_tensorrt.compile(traced_model, **compile_settings)

input_data = torch.randn((1024, 1, 32, 32))
input_data = input_data.half().to("cuda")

result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
WARNING: [Torch-TensorRT] - For input x.1, found user specified input dtype as Float16, however when inspecting the graph, the input type expected was inferred to be Float
The compiler is going to use the user setting Float16
This conflict may cause an error at runtime due to partial compilation being enabled and therefore
compatibility with PyTorch's data type convention is required.
If you do indeed see errors at runtime either:
- Remove the dtype spec for x.1
- Disable partial compilation by setting require_full_compilation to True
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Detected invalid timing cache, setup a local cache instead
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Max value of this profile is not valid
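The first warning is expected here: the traced graph was recorded in FP32, while the `Input` spec declares FP16. If this mismatch does cause runtime errors, the warning lists two remedies; as a rough sketch (using the same compile API as above), forcing full TensorRT compilation looks like this:

```python
# Sketch only: follow the warning's second suggestion and disable partial compilation,
# so no subgraph falls back to PyTorch and the FP16 input spec applies end to end.
trt_ts_module = torch_tensorrt.compile(
    traced_model,
    inputs=[torch_tensorrt.Input(
        min_shape=[1024, 1, 32, 32],
        opt_shape=[1024, 1, 33, 33],
        max_shape=[1024, 1, 34, 34],
        dtype=torch.half
    )],
    enabled_precisions={torch.half},
    require_full_compilation=True
)
```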
[12]:
benchmark(trt_ts_module, input_shape=(1024, 1, 32, 32), dtype="fp16")
Warm up ...
Start timing ...
Iteration 100/1000, ave batch time 1.92 ms
Iteration 200/1000, ave batch time 1.98 ms
Iteration 300/1000, ave batch time 2.05 ms
Iteration 400/1000, ave batch time 2.07 ms
Iteration 500/1000, ave batch time 2.08 ms
Iteration 600/1000, ave batch time 2.09 ms
Iteration 700/1000, ave batch time 2.09 ms
Iteration 800/1000, ave batch time 2.09 ms
Iteration 900/1000, ave batch time 2.10 ms
Iteration 1000/1000, ave batch time 2.09 ms
Input shape: torch.Size([1024, 1, 32, 32])
Output features size: torch.Size([1024, 10])
Average batch time: 2.09 ms
### TorchScript script model
Next, we compile the TorchScript script model with Torch-TensorRT. Notice the performance impact.
[13]:
import torch_tensorrt
# We use a batch-size of 1024, and half precision
compile_settings = {
    "inputs": [torch_tensorrt.Input(
        min_shape=[1024, 1, 32, 32],
        opt_shape=[1024, 1, 33, 33],
        max_shape=[1024, 1, 34, 34],
        dtype=torch.half
    )],
    "enabled_precisions": {torch.half}  # Run with FP16
}

trt_script_module = torch_tensorrt.compile(script_model, **compile_settings)

input_data = torch.randn((1024, 1, 32, 32))
input_data = input_data.half().to("cuda")

result = trt_script_module(input_data)
torch.jit.save(trt_script_module, "trt_script_module.ts")
WARNING: [Torch-TensorRT] - For input x.1, found user specified input dtype as Float16, however when inspecting the graph, the input type expected was inferred to be Float
The compiler is going to use the user setting Float16
This conflict may cause an error at runtime due to partial compilation being enabled and therefore
compatibility with PyTorch's data type convention is required.
If you do indeed see errors at runtime either:
- Remove the dtype spec for x.1
- Disable partial compilation by setting require_full_compilation to True
WARNING: [Torch-TensorRT TorchScript Conversion Context] - The logger passed into createInferBuilder differs from one already provided for an existing builder, runtime, or refitter. TensorRT maintains only a single logger pointer at any given time, so the existing value, which can be retrieved with getLogger(), will be used instead. In order to use a new logger, first destroy all existing builder, runner or refitter objects.
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Detected invalid timing cache, setup a local cache instead
WARNING: [Torch-TensorRT] - Max value of this profile is not valid
[14]:
benchmark(trt_script_module, input_shape=(1024, 1, 32, 32), dtype="fp16")
Warm up ...
Start timing ...
Iteration 100/1000, ave batch time 1.82 ms
Iteration 200/1000, ave batch time 1.90 ms
Iteration 300/1000, ave batch time 2.00 ms
Iteration 400/1000, ave batch time 2.05 ms
Iteration 500/1000, ave batch time 2.06 ms
Iteration 600/1000, ave batch time 2.08 ms
Iteration 700/1000, ave batch time 2.13 ms
Iteration 800/1000, ave batch time 2.15 ms
Iteration 900/1000, ave batch time 2.15 ms
Iteration 1000/1000, ave batch time 2.17 ms
Input shape: torch.Size([1024, 1, 32, 32])
Output features size: torch.Size([1024, 10])
Average batch time: 2.17 ms
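The two `.ts` files saved above are self-contained TorchScript programs with the TensorRT engines embedded. As a rough sketch of how they might be used in a separate deployment process (assuming `torch_tensorrt` and a compatible GPU are available there), they can be reloaded and run with the standard TorchScript APIs:

```python
import torch
import torch_tensorrt  # loads the Torch-TensorRT runtime needed to execute the embedded engines

# Reload the module compiled earlier and run FP16 inference on fresh data
trt_module = torch.jit.load("trt_ts_module.ts")
new_input = torch.randn((1024, 1, 32, 32)).half().to("cuda")
with torch.no_grad():
    output = trt_module(new_input)
print(output.shape)  # expected: torch.Size([1024, 10])
```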
## Conclusion
In this notebook, we have walked through the complete process of compiling TorchScript models with Torch-TensorRT and tested the performance impact of the optimization.
## What’s next
Now it’s time to try Torch-TensorRT on your own model. If you run into problems, file issues at https://github.com/NVIDIA/Torch-TensorRT . Your involvement will help guide future development of Torch-TensorRT.