torch_tensorrt
Functions
-
torch_tensorrt.set_device(gpu_id)
-
torch_tensorrt.compile(module: Any, ir='default', inputs=[], enabled_precisions={<dtype.float: 0>}, **kwargs)
Compile a PyTorch module for NVIDIA GPUs using TensorRT.
Takes an existing PyTorch module and a set of compiler settings, then lowers and compiles the module to TensorRT using the path specified in ir, returning a PyTorch Module. Only the forward method of the Module is converted.
- Parameters
-
module ( Union ( torch.nn.Module , torch.jit.ScriptModule ) ) – Source module
- Keyword Arguments
-
-
inputs ( List [ Union ( torch_tensorrt.Input , torch.Tensor ) ] ) –
Required. List of specifications of the shape, dtype and memory layout of each input to the module. Input sizes can be given as torch sizes, tuples or lists. dtypes can be given as torch datatypes or torch_tensorrt datatypes, and the device type can be selected with either torch devices or the torch_tensorrt device type enum.
inputs=[
    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
    torch_tensorrt.Input(
        min_shape=(1, 224, 224, 3),
        opt_shape=(1, 512, 512, 3),
        max_shape=(1, 1024, 1024, 3),
        dtype=torch.int32,
        format=torch.channels_last
    ),  # Dynamic input shape for input #2
    torch.randn((1, 3, 224, 244))  # Use an example tensor and let torch_tensorrt infer settings
]
-
enabled_precisions ( Set ( Union ( torch.dtype , torch_tensorrt.dtype ) ) ) – The set of datatypes that TensorRT can use when selecting kernels
-
ir ( str ) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
-
**kwargs – Additional settings for the specific requested strategy (See submodules for more info)
-
- Returns
-
Compiled module that executes via TensorRT when run
- Return type
-
torch.nn.Module
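As a rough illustration of the arguments described above, the following sketch wraps a call to torch_tensorrt.compile. The wrapper name compile_for_trt and its example_shape and use_fp16 parameters are hypothetical and exist only for this example; actually running the wrapper assumes a CUDA-capable GPU with torch and torch_tensorrt installed.

```python
def compile_for_trt(model, example_shape=(1, 3, 224, 224), use_fp16=False):
    """Compile `model` with Torch-TensorRT (hypothetical wrapper, needs an NVIDIA GPU)."""
    # Imported lazily so this sketch can be loaded on machines without the GPU stack.
    import torch
    import torch_tensorrt

    # FP32 kernels are always allowed; FP16 kernels are opt-in.
    precisions = {torch.float}
    if use_fp16:
        precisions.add(torch.half)

    model = model.eval().cuda()
    return torch_tensorrt.compile(
        model,
        ir="default",  # let Torch-TensorRT pick the compilation path
        inputs=[torch_tensorrt.Input(example_shape, dtype=torch.float32)],
        enabled_precisions=precisions,
    )
```

The returned module would then be called like the original one, for example trt_model(x.cuda()), with inputs that match the declared shape.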
-
torch_tensorrt.convert_method_to_trt_engine(module: Any, method_name: str, ir='default', inputs=[], enabled_precisions={<dtype.float: 0>}, **kwargs)
Convert a TorchScript module method to a serialized TensorRT engine
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
- Parameters
-
module ( Union ( torch.nn.Module , torch.jit.ScriptModule ) ) – Source module
- Keyword Arguments
-
-
inputs ( List [ Union ( torch_tensorrt.Input , torch.Tensor ) ] ) –
Required. List of specifications of the shape, dtype and memory layout of each input to the module. Input sizes can be given as torch sizes, tuples or lists. dtypes can be given as torch datatypes or torch_tensorrt datatypes, and the device type can be selected with either torch devices or the torch_tensorrt device type enum.
inputs=[
    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
    torch_tensorrt.Input(
        min_shape=(1, 224, 224, 3),
        opt_shape=(1, 512, 512, 3),
        max_shape=(1, 1024, 1024, 3),
        dtype=torch.int32,
        format=torch.channels_last
    ),  # Dynamic input shape for input #2
    torch.randn((1, 3, 224, 244))  # Use an example tensor and let torch_tensorrt infer settings
]
-
enabled_precisions ( Set ( Union ( torch.dtype , torch_tensorrt.dtype ) ) ) – The set of datatypes that TensorRT can use when selecting kernels
-
ir ( str ) – The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
-
**kwargs – Additional settings for the specific requested strategy (See submodules for more info)
-
- Returns
-
Serialized TensorRT engine, which can either be saved to a file or deserialized via TensorRT APIs
- Return type
-
bytes
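Because the return value is plain bytes, persisting an engine needs nothing beyond ordinary file I/O. The sketch below shows one possible way to do that; build_engine, save_engine and load_engine are hypothetical helper names, and build_engine assumes a GPU machine with torch_tensorrt installed.

```python
from pathlib import Path


def build_engine(module, input_shape=(1, 3, 224, 224)):
    """Convert `module.forward` to a serialized TensorRT engine (needs an NVIDIA GPU)."""
    # Imported lazily so the file helpers below work without the GPU stack installed.
    import torch_tensorrt

    return torch_tensorrt.convert_method_to_trt_engine(
        module,
        "forward",
        inputs=[torch_tensorrt.Input(input_shape)],
    )


def save_engine(engine: bytes, path: str) -> int:
    """Write the serialized engine to `path` and return the number of bytes written."""
    return Path(path).write_bytes(engine)


def load_engine(path: str) -> bytes:
    """Read a serialized engine back from disk, ready for the TensorRT runtime APIs."""
    return Path(path).read_bytes()
```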
-
torch_tensorrt.get_build_info() → str
Returns a string containing the build information of the torch_tensorrt distribution
- Returns
-
String containing the build information for the torch_tensorrt distribution
- Return type
-
str
-
torch_tensorrt.dump_build_info()
Prints build information about the torch_tensorrt distribution to stdout
Classes
-
class torch_tensorrt.Input(*args, **kwargs)
Defines an input to a module in terms of expected shape, data type and tensor format.
-
__init__(*args, **kwargs)
__init__ Method for torch_tensorrt.Input
Input accepts one of a few construction patterns
- Parameters
-
shape ( Tuple or List , optional ) – Static shape of input tensor
- Keyword Arguments
-
-
shape ( Tuple or List , optional ) – Static shape of input tensor
-
min_shape ( Tuple or List , optional ) – Minimum size of the input tensor's shape range. Note: all three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined. Providing them implicitly sets the Input's shape_mode to DYNAMIC.
-
opt_shape ( Tuple or List , optional ) – Optimal (opt) size of the input tensor's shape range. Note: all three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined. Providing them implicitly sets the Input's shape_mode to DYNAMIC.
-
max_shape ( Tuple or List , optional ) – Maximum size of the input tensor's shape range. Note: all three of min_shape, opt_shape and max_shape must be provided, there must be no positional arguments, and shape must not be defined. Providing them implicitly sets the Input's shape_mode to DYNAMIC.
-
dtype ( torch.dtype or torch_tensorrt.dtype ) – Expected data type for input tensor (default: torch_tensorrt.dtype.float32)
-
format ( torch.memory_format or torch_tensorrt.TensorFormat ) – The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
-
Examples
-
Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)
-
Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
-
Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW
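The construction rules above (one static shape, or all three of min_shape, opt_shape and max_shape with no positional arguments) can be summarized in a small pure-Python check. This is only a sketch: shape_mode is a hypothetical helper, not the library's implementation, and the element-wise min <= opt <= max check is an assumption based on how TensorRT shape ranges are normally defined.

```python
def shape_mode(*args, **kwargs):
    """Return "STATIC" or "DYNAMIC" for a set of Input-style arguments.

    Hypothetical helper that mirrors the documented rules; it is not
    the library's implementation.
    """
    dynamic_keys = ("min_shape", "opt_shape", "max_shape")
    given = [k for k in dynamic_keys if k in kwargs]

    if given:
        if len(given) != 3:
            raise ValueError("min_shape, opt_shape and max_shape must all be provided")
        if args or "shape" in kwargs:
            raise ValueError("dynamic shapes cannot be combined with a static shape")
        lo, opt, hi = (tuple(kwargs[k]) for k in dynamic_keys)
        if not (len(lo) == len(opt) == len(hi)):
            raise ValueError("all three shapes must have the same rank")
        # Assumption: every dimension must satisfy min <= opt <= max.
        if any(not (a <= b <= c) for a, b, c in zip(lo, opt, hi)):
            raise ValueError("expected min <= opt <= max in every dimension")
        return "DYNAMIC"

    # Static case: exactly one shape, given positionally or as shape=...
    if len(args) == 1 and "shape" not in kwargs:
        return "STATIC"
    if not args and "shape" in kwargs:
        return "STATIC"
    raise ValueError("provide exactly one static shape or all three dynamic shapes")
```

Applied to the three examples above, the helper reports STATIC for the first two and DYNAMIC for the third.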
-
dtype = <dtype.unknown: 5>
The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
-
format = <TensorFormat.contiguous: 0>
The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
-
shape = None
Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }
- Type
-
(Tuple or Dict)
-
shape_mode = None
Is input statically or dynamically shaped
- Type
-
(torch_tensorrt.Input._ShapeMode)
-
-
class torch_tensorrt.Device(*args, **kwargs)
Defines a device that can be used to specify target devices for engines
-
__init__(*args, **kwargs)
__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
- Parameters
-
spec ( str ) – String with a device spec, e.g. "dla:0" for DLA core 0
- Keyword Arguments
-
-
gpu_id ( int ) – ID of target GPU (overridden with the ID of the GPU managing the DLA if dla_core is specified). If specified, no positional arguments should be provided.
-
dla_core ( int ) – ID of target DLA core. If specified, no positional arguments should be provided.
-
allow_gpu_fallback ( bool ) – Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)
-
Examples
-
Device("gpu:1")
-
Device("cuda:1")
-
Device("dla:0", allow_gpu_fallback=True)
-
Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
-
Device(dla_core=0, allow_gpu_fallback=True)
-
Device(gpu_id=1)
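To make the string form of the spec concrete, the sketch below splits the strings used in these examples into a device kind and an index. parse_device_spec is a hypothetical helper that covers only the "gpu:N", "cuda:N" and "dla:N" forms shown here; how the library itself parses specs is not described in this reference.

```python
def parse_device_spec(spec: str):
    """Split a spec such as "gpu:1" or "dla:0" into (device_type, index).

    Hypothetical helper: it covers only the string forms shown in the
    examples above and is not the library's own parser.
    """
    kind, sep, index = spec.strip().lower().partition(":")
    if not sep or not index.isdigit():
        raise ValueError(f"expected '<device>:<index>', got {spec!r}")
    # "gpu" and "cuda" both name a GPU target in the examples.
    if kind in ("gpu", "cuda"):
        return ("GPU", int(index))
    if kind == "dla":
        return ("DLA", int(index))
    raise ValueError(f"unknown device kind {kind!r}")
```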
-
allow_gpu_fallback = False
(bool) Whether to allow falling back to the GPU when DLA cannot support an op
-
device_type = None
(DeviceType) Target device type (GPU or DLA). Set implicitly based on whether dla_core is specified.
-
dla_core = -1
(int) Core ID for target DLA core
-
gpu_id = -1
(int) Device ID for target GPU
-
Enums
-
class torch_tensorrt.dtype
Enum to specify operating precision for engine execution
Members:
float : 32 bit floating point number
float32 : 32 bit floating point number
half : 16 bit floating point number
float16 : 16 bit floating point number
int8 : 8 bit integer number
int32 : 32 bit integer number
bool : Boolean value
unknown : Unknown data type
-
class torch_tensorrt.DeviceType
Enum to specify device kinds to build TensorRT engines for
Members:
GPU : Specify using GPU to execute TensorRT Engine
DLA : Specify using DLA to execute TensorRT Engine (Jetson Only)
-
class torch_tensorrt.EngineCapability
Enum to specify engine capability settings (selections of kernels to meet safety requirements)
Members:
safe_gpu : Use safety GPU kernels only
safe_dla : Use safety DLA kernels only
default : Use default behavior
-
class torch_tensorrt.TensorFormat
Enum to specify the memory layout of tensors
Members:
contiguous : Contiguous memory layout (NCHW / Linear)
channels_last : Channels last memory layout (NHWC)
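The difference between the two layouts is easiest to see in how a logical element (n, c, h, w) maps to a flat memory offset. The following pure-Python sketch uses the standard row-major formulas for NCHW and NHWC; the function names are made up for this illustration and are not part of torch_tensorrt.

```python
def offset_contiguous(n, c, h, w, C, H, W):
    """Flat offset of element (n, c, h, w) in a contiguous NCHW buffer."""
    return ((n * C + c) * H + h) * W + w


def offset_channels_last(n, c, h, w, C, H, W):
    """Flat offset of the same logical element in an NHWC buffer."""
    return ((n * H + h) * W + w) * C + c


C, H, W = 3, 2, 2
# Moving to the next channel jumps a whole H*W plane in NCHW ...
step_nchw = offset_contiguous(0, 1, 0, 0, C, H, W) - offset_contiguous(0, 0, 0, 0, C, H, W)
# ... but only one element in NHWC, where channels are interleaved.
step_nhwc = offset_channels_last(0, 1, 0, 0, C, H, W) - offset_channels_last(0, 0, 0, 0, C, H, W)
print(step_nchw, step_nhwc)  # 4 1
```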