Lowering Phase

The lowering phase is made up of passes, which are operations that map a graph from a high-level representation to a lower-level one. Each pass does something specific, for instance inlining method calls. The idea is to significantly reduce what the conversion phase needs to handle when actually mapping to TensorRT. We aim for closer to 1-to-1 op conversion rather than searching for applicable subgraphs, which limits the number of converters and reduces the scope of each converter.

You can see the effects of each pass by setting the log level to Level::kGraph.
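
For example, with the Python API this looks roughly like the following (a minimal sketch; it assumes the torch_tensorrt.logging module, whose names may differ slightly between releases):

import torch_tensorrt

# Raise the log verbosity so each lowering pass dumps the graph it produces.
# Level.Graph is the Python counterpart of the C++ Level::kGraph.
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)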

Passes Used

EliminateCommonSubexpression

Removes common subexpressions in the graph.
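
A minimal sketch of the effect, using PyTorch's internal torch._C._jit_pass_cse binding as a stand-in for the pass (internal API, shown only for illustration):

import torch

@torch.jit.script
def fn(x):
    a = x * x + 1.0   # x * x appears here ...
    b = x * x + 2.0   # ... and here, as a common subexpression
    return a + b

torch._C._jit_pass_cse(fn.graph)   # internal hook into common subexpression elimination
print(fn.graph)                    # only one aten::mul for x * x should remain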

Eliminate Dead Code

Dead code elimination removes nodes whose results are unused. It checks whether a node has side effects and does not delete it if it does.
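
A sketch of the side-effect rule, again using an internal PyTorch binding (torch._C._jit_pass_dce) purely for illustration:

import torch

@torch.jit.script
def fn(x):
    y = x + 1              # result never used: safe to delete
    print("side effect")   # prim::Print has side effects, so it must stay
    return x * 2

torch._C._jit_pass_dce(fn.graph)   # internal hook into dead code elimination
print(fn.graph)                    # the print remains; the unused add does not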

Eliminate Exception Or Pass Pattern

A common pattern in scripted modules is dimension guards, which will throw exceptions if the input dimension is not what was expected.

%1013 : bool = aten::ne(%1012, %24) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:11
    = prim::If(%1013) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:248:8
    block0():
        = prim::RaiseException(%23) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:249:12
    -> ()
    block1():
    -> ()

Since we are resolving all of this at compile time and there are no exceptions in the TensorRT graph, we just remove it.
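
The pattern above typically comes from Python like the following hypothetical module (shown only to illustrate where the guard originates):

import torch

class GuardedModule(torch.nn.Module):
    def forward(self, x):
        # Scripting this check produces the prim::If / prim::RaiseException
        # pattern shown above, which the lowering pass deletes.
        if x.dim() != 4:
            raise ValueError("expected a 4D input")
        return x * 2

print(torch.jit.script(GuardedModule()).graph)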

Eliminate Redundant Guards

Eliminates redundant guards for ops whose outputs are fully determined by their inputs, i.e. if the inputs to such ops are guarded, we are allowed to remove a guard on the ops’ outputs.

Freeze Module

Freezes attributes and inlines constants and modules. Propagates constants through the graph.
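
At the TorchScript API level this is roughly what torch.jit.freeze does; a small sketch of the effect on a standalone module:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
        self.scale = 2.0

    def forward(self, x):
        return self.linear(x) * self.scale

scripted = torch.jit.script(M()).eval()   # freezing requires eval mode
frozen = torch.jit.freeze(scripted)
print(frozen.graph)   # submodule inlined; weights and self.scale baked in as constants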

Fuse AddMM Branches

A common pattern in scripted modules is that tensors of different dimensions use different constructions for implementing linear layers. We fuse these different variants into a single one that will get caught by the Unpack AddMM pass.

%ret : Tensor = prim::If(%622)
block0():
  %ret.1 : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
  -> (%ret.1)
block1():
  %output.1 : Tensor = aten::matmul(%x9.1, %3677)
  %output0.1 : Tensor = aten::add_(%output.1, %self.fc.bias, %3)
  -> (%output0.1)

We fuse this set of blocks into a graph like this:

%ret : Tensor = aten::addmm(%self.fc.bias, %x9.1, %3677, %3, %3)
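
For reference, a hypothetical scripted function shaped like the following produces the branch structure shown above:

import torch

@torch.jit.script
def linear_like(x, weight, bias):
    # Branching on rank picks addmm for the 2D case and matmul + add otherwise;
    # Fuse AddMM Branches collapses both branches into the single addmm above.
    if x.dim() == 2:
        ret = torch.addmm(bias, x, weight.t())
    else:
        ret = x.matmul(weight.t()) + bias
    return ret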

Fuse Linear

Matches the decomposed linear pattern and fuses it back into a single aten::linear. This pass fuses the addmm or matmul + add sequences generated by the JIT back into aten::linear.

Fuse Flatten Linear

TensorRT implicitly flattens input layers into fully connected layers when they are higher than 1D. So when there is an aten::flatten -> aten::linear pattern, we remove the aten::flatten.
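
This pattern typically originates from classifier heads written like the following hypothetical module:

import torch

class Classifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16 * 7 * 7, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)   # this aten::flatten is dropped by the pass ...
        return self.fc(x)         # ... leaving only the aten::linear

print(torch.jit.script(Classifier()).graph)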

Lower Graph

Given the graph of a method whose first argument is %self, lower it to a graph where all attribute accesses are replaced with explicit inputs to the graph (rather than the results of prim::GetAttr executed on %self). Returns a tuple (graph, parameters) where the last module.parameters.size() inputs to the graph are the trainable parameters used in this method. The remaining inputs are the true inputs to the function.
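
For context, attribute accesses appear as prim::GetAttr on %self in the method graph this pass consumes (sketch):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.weight

# self.weight is fetched via prim::GetAttr(%self) in this graph; Lower Graph
# rewrites such accesses into explicit graph inputs and returns the parameter
# tensors alongside the rewritten graph.
print(torch.jit.script(M()).graph)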

Lower Tuples

  • LowerSimpleTuples :

Removes tuples where TupleConstruct and TupleUnpack are matched, but leaves tuples in place across if statements, loops, and as inputs/outputs.

  • LowerAllTuples :

Removes _all_ tuples and raises an error if some cannot be removed. This is used by ONNX to ensure there are no tuples before conversion, but it will not work on graphs whose inputs contain tuples.
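
A small sketch using PyTorch's internal torch._C._jit_pass_lower_all_tuples binding as a stand-in for the pass:

import torch

@torch.jit.script
def fn(x):
    t = (x + 1, x + 2)   # prim::TupleConstruct
    a, b = t             # prim::TupleUnpack
    return a * b

torch._C._jit_pass_lower_all_tuples(fn.graph)   # internal hook into the pass
print(fn.graph)   # no tuple construction/unpacking nodes should remain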

Module Fallback

Torch-TensorRT/core/lowering/passes/module_fallback.cpp <https://github.com/nvidia/Torch-TensorRT/blob/master/core/lowering/passes/module_fallback.cpp>

Module fallback consists of two lowering passes that must be run as a pair. The first pass is run before freezing to place delimiters in the graph around modules that should run in PyTorch. The second pass marks nodes between these delimiters after freezing to signify they should run in PyTorch.

  • NotateModuleForFallback

Places delimiting nodes around module calls pre-freezing to signify where in the graph nodes should run in PyTorch.

  • MarkNodesForFallback

Looks for the delimiters, then marks all nodes between them to tell partitioning to run them in PyTorch.
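
From the user's perspective, module fallback is driven by the torch_executed_modules compile setting; a sketch using the Python TS frontend (argument names and defaults may differ between releases):

import torch
import torchvision
import torch_tensorrt

model = torch.jit.script(torchvision.models.resnet18().eval())

trt_model = torch_tensorrt.compile(
    model,
    ir="ts",
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    # Modules listed here are bracketed by the delimiters described above and
    # are left to run in PyTorch by the partitioner.
    torch_executed_modules=["torchvision.models.resnet.BasicBlock"],
)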

Peephole Optimize

The intent of this optimization pass is to catch all of the small, easy-to-catch peephole optimizations you might be interested in doing.

Right now, it does:
  • Eliminate no-op ‘expand’ nodes

  • Simplify x.t().t() to x (see the sketch below)
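
For example, using PyTorch's internal torch._C._jit_pass_peephole binding purely for illustration:

import torch

@torch.jit.script
def fn(x):
    return x.t().t()   # transpose of a transpose is a no-op

torch._C._jit_pass_peephole(fn.graph)   # internal hook into the peephole pass
print(fn.graph)                         # the pair of aten::t nodes should be folded away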

Remove Contiguous

Removes contiguous operators since, when targeting TensorRT, memory is already contiguous.

Remove Dropout

Removes dropout operators since we are doing inference.

Remove To

Removes aten::to operators that do casting, since TensorRT manages it itself. It is important that this is one of the last passes run so that other passes have a chance to move required cast operators out of the main namespace.

Unpack AddMM

Unpacks aten::addmm into aten::matmul and aten::add_ (with an additional trt::const op to freeze the bias in the TensorRT graph). This lets us reuse the aten::matmul and aten::add_ converters instead of needing a dedicated converter.
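
The rewrite is justified by the identity aten::addmm implements; a quick numerical check (with the default beta and alpha of 1):

import torch

bias = torch.randn(4)
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 4)

fused = torch.addmm(bias, mat1, mat2)        # what aten::addmm computes
unpacked = torch.matmul(mat1, mat2) + bias   # what the unpacked graph computes
print(torch.allclose(fused, unpacked))       # True (up to floating point error)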

Unpack LogSoftmax

Unpacks aten::logsoftmax into aten::softmax and aten::log. This lets us reuse the aten::softmax and aten::log converters instead of needing a dedicated converter.
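
The same idea applies here; a quick numerical check of the identity:

import torch

x = torch.randn(2, 5)
a = torch.log_softmax(x, dim=1)           # what the fused op computes
b = torch.log(torch.softmax(x, dim=1))    # what the unpacked graph computes
print(torch.allclose(a, b))               # True (up to floating point error)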

Unroll Loops

Unrolls the operations of compatible loops (e.g. sufficiently short) so that you only have to go through the loop once.
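
For example, a loop like the one in this hypothetical scripted function has a fixed, short trip count, so its body can be replicated three times and the loop removed:

import torch

@torch.jit.script
def fn(x):
    for _ in range(3):   # fixed, short trip count: eligible for unrolling
        x = x + 1
    return x

print(fn.graph)   # contains a prim::Loop that unrolling replaces with three adds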