{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualizing JIT Modules\n",
"\n",
"Copyright 2020 by Thomas Viehmann\n",
"\n",
"This is the code for my blog post [Visualize PyTorch models](https://lernapparat.de/visualize-pytorch-models/).\n",
"The code has been made available thanks to my [single github sponsor](https://github.com/sponsors/t-vi/) at the time of writing. Thank you!\n",
"\n",
    "I license this code under the CC BY-SA 4.0 license. Please link to my blog post or the original GitHub source (linked from the blog post) in the attribution notice.\n",
"\n",
"\n",
"## Introduction\n",
"\n",
    "Did you ever wish for a concise picture of your PyTorch model's structure, only to find it surprisingly hard to get?\n",
"\n",
"\n",
    "Recently, I did some work that involved looking at model structure in some detail. For my write-up, I wanted a diagram of some model structures. Even though the model is relatively common, searching for a diagram didn't turn up anything in the shape I was looking for.\n",
"\n",
    "So how can we get the structure of PyTorch models? The first stop is probably the neat string representation that PyTorch provides for `nn.Module`s - even without doing anything, it covers our custom models pretty well. It is, however, not without shortcomings.\n",
"\n",
"Let's look at TorchVision's ResNet18 basic block as an example."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
")"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torchvision\n",
"m = torchvision.models.resnet18()\n",
"m.layer1[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "So we have two convolutions and two batch norms. But how are things connected? And is there really only one ReLU?\n",
"\n",
"Looking at the forward method (you can get this using Python's `inspect` module or `??` in IPython), we see some important details not in the summary:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" def forward(self, x: Tensor) -> Tensor:\n",
" identity = x\n",
"\n",
" out = self.conv1(x)\n",
" out = self.bn1(out)\n",
" out = self.relu(out)\n",
"\n",
" out = self.conv2(out)\n",
" out = self.bn2(out)\n",
"\n",
" if self.downsample is not None:\n",
" identity = self.downsample(x)\n",
"\n",
" out += identity\n",
" out = self.relu(out)\n",
"\n",
" return out\n",
"\n"
]
}
],
"source": [
"import inspect\n",
"print(inspect.getsource(m.layer1[0].forward))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "So we missed the entire residual bit. Also, the forward applies the ReLU twice, while the summary only shows one. Arguably, it is wrong to re-use stateless modules like this: it will haunt you when you do things like quantization (the module becomes stateful then, due to the quantization parameters), and it mixes things too much. If you want stateless, use the functional interface.\n",
    "\n",
    "So instead, we can build a visualization based on JIT-ed (traced or scripted) modules, whose graphs record the actual data flow.\n",
    "\n",
    "We recurse into method calls to make subgraphs, taking some care that the edges connecting a subgraph to the outer graph must belong to the outer graph. Other than that, it is straightforward, even if the details are messy."
]
},
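   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "As a minimal sketch of the ingredients (not the full function, just the skeleton it builds on): `mod.graph` gives the TorchScript IR graph, the first graph input is the module itself, and `prim::CallMethod` nodes mark calls into submodules that we can recurse into. Assuming a standard TorchVision install:\n",
     "\n",
     "```python\n",
     "import torch\n",
     "import torchvision\n",
     "\n",
     "m = torchvision.models.resnet18().eval()\n",
     "tm = torch.jit.trace(m, [torch.randn(1, 3, 224, 224)])\n",
     "\n",
     "gr = tm.graph\n",
     "self_input = next(gr.inputs())  # the first graph input is the module itself\n",
     "for n in gr.nodes():\n",
     "    if n.kind() == 'prim::CallMethod':\n",
     "        # type of the submodule being called, e.g. 'Conv2d' or 'Sequential'\n",
     "        submodule_type = list(n.inputs())[0].type().str().split('.')[-1]\n",
     "        print(submodule_type)\n",
     "```\n",
     "\n",
     "The function below builds on this per-node walk, adding the bookkeeping for names, edges, and subgraphs."
    ]
   },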
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import graphviz\n",
"\n",
"def make_graph(mod, classes_to_visit=None, classes_found=None, dot=None, prefix=\"\",\n",
" input_preds=None, \n",
" parent_dot=None):\n",
" preds = {}\n",
" \n",
" def find_name(i, self_input, suffix=None):\n",
" if i == self_input:\n",
" return suffix\n",
" cur = i.node().s(\"name\")\n",
" if suffix is not None:\n",
" cur = cur + '.' + suffix\n",
" of = next(i.node().inputs())\n",
" return find_name(of, self_input, suffix=cur)\n",
"\n",
" gr = mod.graph\n",
" toshow = []\n",
" # list(traced_model.graph.nodes())[0]\n",
" self_input = next(gr.inputs())\n",
" self_type = self_input.type().str().split('.')[-1]\n",
" preds[self_input] = (set(), set()) # inps, ops\n",
" \n",
" if dot is None:\n",
" dot = graphviz.Digraph(format='svg', graph_attr={'label': self_type, 'labelloc': 't'})\n",
" #dot.attr('node', shape='box')\n",
"\n",
" seen_inpnames = set()\n",
" seen_edges = set()\n",
" \n",
" def add_edge(dot, n1, n2):\n",
" if (n1, n2) not in seen_edges:\n",
" seen_edges.add((n1, n2))\n",
" dot.edge(n1, n2)\n",
"\n",
" def make_edges(pr, inpname, name, op, edge_dot=dot):\n",
" if op:\n",
" if inpname not in seen_inpnames:\n",
" seen_inpnames.add(inpname)\n",
" label_lines = [[]]\n",
" line_len = 0\n",
" for w in op:\n",
" if line_len >= 20:\n",
" label_lines.append([])\n",
" line_len = 0\n",
" label_lines[-1].append(w)\n",
" line_len += len(w) + 1\n",
" edge_dot.node(inpname, label='\\n'.join([' '.join(w) for w in label_lines]), shape='box', style='rounded')\n",
" for p in pr:\n",
" add_edge(edge_dot, p, inpname)\n",
" add_edge(edge_dot, inpname, name)\n",
" else:\n",
" for p in pr:\n",
" add_edge(edge_dot, p, name)\n",
"\n",
" for nr, i in enumerate(list(gr.inputs())[1:]):\n",
" name = prefix+'inp_'+i.debugName()\n",
" preds[i] = {name}, set()\n",
" dot.node(name, shape='ellipse')\n",
" if input_preds is not None:\n",
" pr, op = input_preds[nr]\n",
" make_edges(pr, 'inp_'+name, name, op, edge_dot=parent_dot)\n",
" \n",
" def is_relevant_type(t):\n",
" kind = t.kind()\n",
" if kind == 'TensorType':\n",
" return True\n",
" if kind in ('ListType', 'OptionalType'):\n",
" return is_relevant_type(t.getElementType())\n",
" if kind == 'TupleType':\n",
" return any([is_relevant_type(tt) for tt in t.elements()])\n",
" return False\n",
"\n",
" for n in gr.nodes():\n",
" only_first_ops = {'aten::expand_as'}\n",
" rel_inp_end = 1 if n.kind() in only_first_ops else None\n",
" \n",
" relevant_inputs = [i for i in list(n.inputs())[:rel_inp_end] if is_relevant_type(i.type())]\n",
" relevant_outputs = [o for o in n.outputs() if is_relevant_type(o.type())]\n",
" if n.kind() == 'prim::CallMethod':\n",
" fq_submodule_name = '.'.join([nc for nc in list(n.inputs())[0].type().str().split('.') if not nc.startswith('__')])\n",
" submodule_type = list(n.inputs())[0].type().str().split('.')[-1]\n",
" submodule_name = find_name(list(n.inputs())[0], self_input)\n",
" name = prefix+'.'+n.output().debugName()\n",
" label = prefix+submodule_name+' (' + submodule_type + ')'\n",
" if classes_found is not None:\n",
" classes_found.add(fq_submodule_name)\n",
" if ((classes_to_visit is None and\n",
" (not fq_submodule_name.startswith('torch.nn') or \n",
" fq_submodule_name.startswith('torch.nn.modules.container')))\n",
" or (classes_to_visit is not None and \n",
" (submodule_type in classes_to_visit\n",
" or fq_submodule_name in classes_to_visit))):\n",
" # go into subgraph\n",
" sub_prefix = prefix+submodule_name+'.'\n",
" with dot.subgraph(name=\"cluster_\"+name) as sub_dot:\n",
" sub_dot.attr(label=label)\n",
" submod = mod\n",
" for k in submodule_name.split('.'):\n",
" submod = getattr(submod, k)\n",
" make_graph(submod, dot=sub_dot, prefix=sub_prefix,\n",
" input_preds = [preds[i] for i in list(n.inputs())[1:]],\n",
" parent_dot=dot, classes_to_visit=classes_to_visit,\n",
" classes_found=classes_found)\n",
" for i, o in enumerate(n.outputs()):\n",
" preds[o] = {sub_prefix+f'out_{i}'}, set()\n",
" else:\n",
" dot.node(name, label=label, shape='box')\n",
" for i in relevant_inputs:\n",
" pr, op = preds[i]\n",
" make_edges(pr, prefix+i.debugName(), name, op)\n",
" for o in n.outputs():\n",
" preds[o] = {name}, set()\n",
" elif n.kind() == 'prim::CallFunction':\n",
" funcname = list(n.inputs())[0].type().__repr__().split('.')[-1]\n",
" name = prefix+'.'+n.output().debugName()\n",
" label = funcname\n",
" dot.node(name, label=label, shape='box')\n",
" for i in relevant_inputs:\n",
" pr, op = preds[i]\n",
" make_edges(pr, prefix+i.debugName(), name, op)\n",
" for o in n.outputs():\n",
" preds[o] = {name}, set()\n",
" else:\n",
" unseen_ops = {'prim::ListConstruct', 'prim::TupleConstruct', 'aten::index', \n",
" 'aten::size', 'aten::slice', 'aten::unsqueeze', 'aten::squeeze',\n",
" 'aten::to', 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous',\n",
" 'aten::permute', 'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', 'aten::unbind',\n",
" 'aten::select', 'aten::detach', 'aten::stack', 'aten::reshape', 'aten::split_with_sizes',\n",
" 'aten::cat', 'aten::expand', 'aten::expand_as', 'aten::_shape_as_tensor',\n",
" }\n",
" \n",
" absorbing_ops = ('aten::size', 'aten::_shape_as_tensor') # probably also partially absorbing ops. :/\n",
" if False:\n",
" print(n.kind())\n",
" #DEBUG['kinds'].add(n.kind())\n",
" #DEBUG[n.kind()] = n\n",
" label = n.kind().split('::')[-1].rstrip('_')\n",
" name = prefix+'.'+relevant_outputs[0].debugName()\n",
" dot.node(name, label=label, shape='box', style='rounded')\n",
" for i in relevant_inputs:\n",
" pr, op = preds[i]\n",
" make_edges(pr, prefix+i.debugName(), name, op)\n",
" for o in n.outputs():\n",
" preds[o] = {name}, set()\n",
" if True:\n",
" label = n.kind().split('::')[-1].rstrip('_')\n",
" pr, op = set(), set()\n",
" for i in relevant_inputs:\n",
" apr, aop = preds[i]\n",
" pr |= apr\n",
" op |= aop\n",
" if pr and n.kind() not in unseen_ops:\n",
" print(n.kind(), n)\n",
" if n.kind() in absorbing_ops:\n",
" pr, op = set(), set()\n",
" elif len(relevant_inputs) > 0 and len(relevant_outputs) > 0 and n.kind() not in unseen_ops:\n",
" op.add(label)\n",
" for o in n.outputs():\n",
" preds[o] = pr, op\n",
"\n",
" for i, o in enumerate(gr.outputs()):\n",
" name = prefix+f'out_{i}'\n",
" dot.node(name, shape='ellipse')\n",
" pr, op = preds[o]\n",
" make_edges(pr, 'inp_'+name, name, op)\n",
" return dot"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Applications\n",
"\n",
"\n",
"Let's apply it! These are the pictures from my blog post along with the code that generated them.\n",
"\n",
"The following code is from the [transformers library](https://github.com/huggingface/transformers/) (Copyright 2018- The Hugging Face team. Apache Licensed.)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.11.0'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"import transformers\n",
"\n",
"from transformers import BertModel, BertTokenizer, BertConfig\n",
"import numpy\n",
"\n",
"import torch\n",
"\n",
"enc = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"\n",
"# Tokenizing input text\n",
"text = \"[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]\"\n",
"tokenized_text = enc.tokenize(text)\n",
"\n",
"# Masking one of the input tokens\n",
"masked_index = 8\n",
"tokenized_text[masked_index] = '[MASK]'\n",
"indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)\n",
"segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]\n",
"\n",
"# Creating a dummy input\n",
"tokens_tensor = torch.tensor([indexed_tokens])\n",
"segments_tensors = torch.tensor([segments_ids])\n",
"dummy_input = [tokens_tensor, segments_tensors]\n",
"\n",
"# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag\n",
"model = BertModel.from_pretrained(\"bert-base-uncased\", torchscript=True)\n",
"\n",
"model.eval()\n",
"for p in model.parameters():\n",
" p.requires_grad_(False)\n",
"\n",
"transformers.__version__"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Creating the trace\n",
"traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])\n",
"traced_model.eval()\n",
"for p in traced_model.parameters():\n",
" p.requires_grad_(False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"if 0:\n",
    "    # resolving functions?\n",
" t = fn.type()\n",
" def lookup(fn):\n",
" n = str(fn.type()).split('.')[1:]\n",
" res = globals()[n[0]]\n",
" for nc in n[1:]:\n",
" res = getattr(res, nc)\n",
" return res\n",
" lookup(fn).graph"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"aten::rsub %668 : Float(1:14, 1:14, 1:14, 14:1) = aten::rsub(%665, %666, %667) # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:395:0\n",
"\n",
"aten::mul %attention_mask : Float(1:14, 1:14, 1:14, 14:1) = aten::mul(%668, %669) # /usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:228:0\n",
"\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d = make_graph(traced_model, classes_to_visit={'BertEncoder'})\n",
"d.render('bert_model')\n",
"d"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"aten::matmul %attention_scores.1 : Float(1:2352, 12:196, 14:14, 14:1) = aten::matmul(%query_layer.1, %75), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:236:0\n",
"\n",
"aten::div %attention_scores.2 : Float(1:2352, 12:196, 14:14, 14:1) = aten::div(%attention_scores.1, %77), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:237:0\n",
"\n",
"aten::add %input.6 : Float(1:2352, 12:196, 14:14, 14:1) = aten::add(%attention_scores.2, %attention_mask, %79), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:240:0\n",
"\n",
"aten::softmax %input.7 : Float(1:2352, 12:196, 14:14, 14:1) = aten::softmax(%input.6, %81, %82), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1498:0\n",
"\n",
"aten::matmul %context_layer.1 : Float(1:10752, 12:896, 14:64, 64:1) = aten::matmul(%114, %value_layer.1), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:253:0\n",
"\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mod = getattr(traced_model.encoder.layer, \"0\") # traced_model.encoder.layer[0]\n",
    "d = make_graph(mod, classes_to_visit={'BertAttention', 'BertSelfAttention'})\n",
"d.render('bert_layer')\n",
"d"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import torchvision"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"m = torchvision.models.resnet18()\n",
"tm = torch.jit.trace(m, [torch.randn(1, 3, 224, 224)])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BasicBlock(\n",
" (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (relu): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
")"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = torchvision.models.resnet18()\n",
"m.layer1[0]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" def forward(self, x):\n",
" identity = x\n",
"\n",
" out = self.conv1(x)\n",
" out = self.bn1(out)\n",
" out = self.relu(out)\n",
"\n",
" out = self.conv2(out)\n",
" out = self.bn2(out)\n",
"\n",
" if self.downsample is not None:\n",
" identity = self.downsample(x)\n",
"\n",
" out += identity\n",
" out = self.relu(out)\n",
"\n",
" return out\n",
"\n"
]
}
],
"source": [
"print(inspect.getsource(m.layer1[0].forward))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"aten::add_ %input.10 : Float(1:200704, 64:3136, 56:56, 56:1) = aten::add_(%19, %1, %12), scope: __module.layer1/__module.layer1.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0\n",
"\n",
"aten::add_ %input.16 : Float(1:200704, 64:3136, 56:56, 56:1) = aten::add_(%19, %1, %12), scope: __module.layer1/__module.layer1.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0\n",
"\n",
"aten::add_ %input.23 : Float(1:100352, 128:784, 28:28, 28:1) = aten::add_(%21, %22, %14), scope: __module.layer2/__module.layer2.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0\n",
"\n",
"aten::add_ %input.29 : Float(1:100352, 128:784, 28:28, 28:1) = aten::add_(%19, %1, %12), scope: __module.layer2/__module.layer2.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0\n",
"\n",
"aten::add_ %input.36 : Float(1:50176, 256:196, 14:14, 14:1) = aten::add_(%21, %22, %14), scope: __module.layer3/__module.layer3.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0\n",
"\n",
"aten::add_ %input.42 : Float(1:50176, 256:196, 14:14, 14:1) = aten::add_(%19, %1, %12), scope: __module.layer3/__module.layer3.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0\n",
"\n",
"aten::add_ %input.49 : Float(1:25088, 512:49, 7:7, 7:1) = aten::add_(%21, %22, %14), scope: __module.layer4/__module.layer4.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0\n",
"\n",
"aten::add_ %input.55 : Float(1:25088, 512:49, 7:7, 7:1) = aten::add_(%19, %1, %12), scope: __module.layer4/__module.layer4.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0\n",
"\n",
"aten::flatten %input : Float(1:512, 512:1) = aten::flatten(%1536, %1182, %1183) # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:214:0\n",
"\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d = make_graph(tm)\n",
"d.render(\"resnet18_full\")\n",
"d"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"aten::flatten %input : Float(1:512, 512:1) = aten::flatten(%1536, %1182, %1183) # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:214:0\n",
"\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d = make_graph(tm, classes_to_visit={'Sequential'})\n",
"d.render(\"resnet18_highlevel\")\n",
"d"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"aten::add_ %input.10 : Float(1:200704, 64:3136, 56:56, 56:1) = aten::add_(%19, %1, %12), scope: __module.layer1/__module.layer1.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0\n",
"\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d = make_graph(getattr(tm.layer1, \"0\"))\n",
"d.render(\"resnet18_basicblock\")\n",
"d"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"aten::upsample_bilinear2d %3096 : Float(1:1053696, 21:50176, 224:224, 224:1) = aten::upsample_bilinear2d(%3955, %3092, %3093, %3094, %3095) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3163:0\n",
"\n",
"prim::DictConstruct %3098 : Dict(str, Tensor) = prim::DictConstruct(%3097, %3096)\n",
"\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = torchvision.models.segmentation.fcn_resnet50()\n",
"tm = torch.jit.trace(m, [torch.randn(1, 3, 224, 224)], strict=False)\n",
"d = make_graph(tm, classes_to_visit={'IntermediateLayerGetter', 'FCNHead'})\n",
"d.render(\"segmentation_fcn_high_level\")\n",
"d"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
":6: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
" assert inp.shape[0] == 1\n",
"/usr/local/lib/python3.8/dist-packages/torch/tensor.py:457: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).\n",
" warnings.warn('Iterating over a tensor might cause the trace to be incorrect. '\n",
"/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3000: UserWarning: The default behavior for interpolate/upsample with float scale_factor will change in 1.6.0 to align with other frameworks/libraries, and use scale_factor directly, instead of relying on the computed output size. If you wish to keep the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. \n",
" warnings.warn(\"The default behavior for interpolate/upsample with float scale_factor will change \"\n",
"/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3009: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" return [(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i],\n",
"/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:163: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),\n",
"/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:164: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]\n",
"/usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:125: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))\n",
"/usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:127: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))\n",
"/usr/local/lib/python3.8/dist-packages/torchvision/ops/poolers.py:216: UserWarning: This overload of nonzero is deprecated:\n",
"\tnonzero(Tensor input, *, Tensor out)\n",
"Consider using one of the following signatures instead:\n",
"\tnonzero(Tensor input, *, bool as_tuple) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:761.)\n",
" idx_in_level = torch.nonzero(levels == level).squeeze(1)\n",
"/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" torch.tensor(s, dtype=torch.float32, device=boxes.device) /\n",
"/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:270: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)\n"
]
}
],
"source": [
"class Detection(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.m = torchvision.models.detection.fasterrcnn_resnet50_fpn().eval()\n",
" def forward(self, inp):\n",
" assert inp.shape[0] == 1\n",
" res, = self.m(inp)\n",
" return res['boxes'], res['labels'], res['scores']\n",
"\n",
"tm = torch.jit.trace(Detection(), [torch.randn(1, 3, 224, 224)], check_trace=False)\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"aten::div %ratio_height : Float() = aten::div(%52, %60), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0\n",
"\n",
"aten::div %ratio_width : Float() = aten::div(%69, %77), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0\n",
"\n",
"aten::mul %xmin : Float(0:1) = aten::mul(%xmin.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:276:0\n",
"\n",
"aten::mul %xmax : Float(0:1) = aten::mul(%xmax.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:277:0\n",
"\n",
"aten::mul %ymin : Float(0:1) = aten::mul(%ymin.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:278:0\n",
"\n",
"aten::mul %ymax : Float(0:1) = aten::mul(%ymax.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:279:0\n",
"\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
    "d = make_graph(tm.m, classes_to_visit=set())\n",
"d.render('fasterrcnn.highlevel')\n",
"d"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"aten::flatten %objectness.1 : Float(159882:1, 1:1) = aten::flatten(%434, %435, %436), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:248:0\n",
"\n",
"aten::sub %widths.1 : Float(159882:1) = aten::sub(%472, %480, %481), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:192:0\n",
"\n",
"aten::sub %heights.1 : Float(159882:1) = aten::sub(%490, %498, %499), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:193:0\n",
"\n",
"aten::mul %510 : Float(159882:1) = aten::mul(%widths.1, %509), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0\n",
"\n",
"aten::add %ctr_x.1 : Float(159882:1) = aten::add(%508, %510, %511), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0\n",
"\n",
"aten::mul %522 : Float(159882:1) = aten::mul(%heights.1, %521), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0\n",
"\n",
"aten::add %ctr_y.1 : Float(159882:1) = aten::add(%520, %522, %523), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0\n",
"\n",
"aten::div %dx.1 : Float(159882:1, 1:1) = aten::div(%534, %535), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:198:0\n",
"\n",
"aten::div %dy.1 : Float(159882:1, 1:1) = aten::div(%546, %547), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:199:0\n",
"\n",
"aten::div %dw.1 : Float(159882:1, 1:1) = aten::div(%558, %559), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:200:0\n",
"\n",
"aten::div %dh.1 : Float(159882:1, 1:1) = aten::div(%570, %571), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:201:0\n",
"\n",
"aten::clamp %dw.2 : Float(159882:1, 1:1) = aten::clamp(%dw.1, %573, %574), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:204:0\n",
"\n",
"aten::clamp %dh.2 : Float(159882:1, 1:1) = aten::clamp(%dh.1, %576, %577), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:205:0\n",
"\n",
"aten::mul %586 : Float(159882:1, 1:1) = aten::mul(%dx.1, %585), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0\n",
"\n",
"aten::add %pred_ctr_x.1 : Float(159882:1, 1:1) = aten::add(%586, %593, %594), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0\n",
"\n",
"aten::mul %603 : Float(159882:1, 1:1) = aten::mul(%dy.1, %602), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0\n",
"\n",
"aten::add %pred_ctr_y.1 : Float(159882:1, 1:1) = aten::add(%603, %610, %611), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0\n",
"\n",
"aten::exp %613 : Float(159882:1, 1:1) = aten::exp(%dw.2), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0\n",
"\n",
"aten::mul %pred_w.1 : Float(159882:1, 1:1) = aten::mul(%613, %620), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0\n",
"\n",
"aten::exp %622 : Float(159882:1, 1:1) = aten::exp(%dh.2), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0\n",
"\n",
"aten::mul %pred_h.1 : Float(159882:1, 1:1) = aten::mul(%622, %629), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0\n",
"\n",
"aten::mul %639 : Float(159882:1, 1:1) = aten::mul(%638, %pred_w.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0\n",
"\n",
"aten::sub %pred_boxes1.1 : Float(159882:1, 1:1) = aten::sub(%pred_ctr_x.1, %639, %640), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0\n",
"\n",
"aten::mul %650 : Float(159882:1, 1:1) = aten::mul(%649, %pred_h.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0\n",
"\n",
"aten::sub %pred_boxes2.1 : Float(159882:1, 1:1) = aten::sub(%pred_ctr_y.1, %650, %651), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0\n",
"\n",
"aten::mul %661 : Float(159882:1, 1:1) = aten::mul(%660, %pred_w.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0\n",
"\n",
"aten::add %pred_boxes3.1 : Float(159882:1, 1:1) = aten::add(%pred_ctr_x.1, %661, %662), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0\n",
"\n",
"aten::mul %672 : Float(159882:1, 1:1) = aten::mul(%671, %pred_h.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0\n",
"\n",
"aten::add %pred_boxes4.1 : Float(159882:1, 1:1) = aten::add(%pred_ctr_y.1, %672, %673), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0\n",
"\n",
"aten::flatten %pred_boxes.1 : Float(159882:4, 4:1) = aten::flatten(%677, %678, %679), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:216:0\n",
"\n",
"aten::topk %778 : Float(1:1000, 1000:1), %top_n_idx.1 : Long(1:1000, 1000:1) = aten::topk(%x.54, %774, %775, %776, %777), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0\n",
"\n",
"aten::add %782 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.1, %780, %781), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0\n",
"\n",
"aten::topk %808 : Float(1:1000, 1000:1), %top_n_idx.2 : Long(1:1000, 1000:1) = aten::topk(%x.55, %804, %805, %806, %807), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0\n",
"\n",
"aten::add %811 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.2, %offset.1, %810), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0\n",
"\n",
"aten::topk %836 : Float(1:1000, 1000:1), %top_n_idx.3 : Long(1:1000, 1000:1) = aten::topk(%x.56, %832, %833, %834, %835), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0\n",
"\n",
"aten::add %839 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.3, %offset.2, %838), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0\n",
"\n",
"aten::topk %864 : Float(1:1000, 1000:1), %top_n_idx.4 : Long(1:1000, 1000:1) = aten::topk(%x.57, %860, %861, %862, %863), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0\n",
"\n",
"aten::add %867 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.4, %offset.3, %866), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0\n",
"\n",
"aten::topk %892 : Float(1:507, 507:1), %top_n_idx.5 : Long(1:507, 507:1) = aten::topk(%x.58, %888, %889, %890, %891), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0\n",
"\n",
"aten::add %895 : Long(1:507, 507:1) = aten::add(%top_n_idx.5, %offset, %894), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0\n",
"\n",
"aten::max %boxes_x.2 : Float(4507:2, 2:1) = aten::max(%boxes_x.1, %1002), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:124:0\n",
"\n",
"aten::min %boxes_x.3 : Float(4507:2, 2:1) = aten::min(%boxes_x.2, %1011), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:125:0\n",
"\n",
"aten::max %boxes_y.2 : Float(4507:2, 2:1) = aten::max(%boxes_y.1, %1020), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:126:0\n",
"\n",
"aten::min %boxes_y.3 : Float(4507:2, 2:1) = aten::min(%boxes_y.2, %1029), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:127:0\n",
"\n",
"aten::sub %1061 : Float(4507:1) = aten::sub(%1051, %1059, %1060), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0\n",
"\n",
"aten::sub %1079 : Float(4507:1) = aten::sub(%1069, %1077, %1078), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0\n",
"\n",
"aten::ge %1081 : Bool(4507:1) = aten::ge(%1061, %1080), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0\n",
"\n",
"aten::ge %1083 : Bool(4507:1) = aten::ge(%1079, %1082), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0\n",
"\n",
"aten::__and__ %keep.1 : Bool(4507:1) = aten::__and__(%1081, %1083), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:101:0\n",
"\n",
"aten::nonzero %1085 : Long(4507:1, 1:1) = aten::nonzero(%keep.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:102:0\n",
"\n",
"aten::sub %widths : Float(1000:1) = aten::sub(%61, %69, %70), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:192:0\n",
"\n",
"aten::sub %heights : Float(1000:1) = aten::sub(%79, %87, %88), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:193:0\n",
"\n",
"aten::mul %99 : Float(1000:1) = aten::mul(%widths, %98), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0\n",
"\n",
"aten::add %ctr_x : Float(1000:1) = aten::add(%97, %99, %100), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0\n",
"\n",
"aten::mul %111 : Float(1000:1) = aten::mul(%heights, %110), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0\n",
"\n",
"aten::add %ctr_y : Float(1000:1) = aten::add(%109, %111, %112), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0\n",
"\n",
"aten::div %dx : Float(1000:91, 91:1) = aten::div(%123, %124), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:198:0\n",
"\n",
"aten::div %dy : Float(1000:91, 91:1) = aten::div(%135, %136), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:199:0\n",
"\n",
"aten::div %dw.3 : Float(1000:91, 91:1) = aten::div(%147, %148), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:200:0\n",
"\n",
"aten::div %dh.3 : Float(1000:91, 91:1) = aten::div(%159, %160), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:201:0\n",
"\n",
"aten::clamp %dw : Float(1000:91, 91:1) = aten::clamp(%dw.3, %162, %163), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:204:0\n",
"\n",
"aten::clamp %dh : Float(1000:91, 91:1) = aten::clamp(%dh.3, %165, %166), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:205:0\n",
"\n",
"aten::mul %175 : Float(1000:91, 91:1) = aten::mul(%dx, %174), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0\n",
"\n",
"aten::add %pred_ctr_x : Float(1000:91, 91:1) = aten::add(%175, %182, %183), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0\n",
"\n",
"aten::mul %192 : Float(1000:91, 91:1) = aten::mul(%dy, %191), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0\n",
"\n",
"aten::add %pred_ctr_y : Float(1000:91, 91:1) = aten::add(%192, %199, %200), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0\n",
"\n",
"aten::exp %202 : Float(1000:91, 91:1) = aten::exp(%dw), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0\n",
"\n",
"aten::mul %pred_w : Float(1000:91, 91:1) = aten::mul(%202, %209), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0\n",
"\n",
"aten::exp %211 : Float(1000:91, 91:1) = aten::exp(%dh), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0\n",
"\n",
"aten::mul %pred_h : Float(1000:91, 91:1) = aten::mul(%211, %218), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0\n",
"\n",
"aten::mul %228 : Float(1000:91, 91:1) = aten::mul(%227, %pred_w), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0\n",
"\n",
"aten::sub %pred_boxes1 : Float(1000:91, 91:1) = aten::sub(%pred_ctr_x, %228, %229), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0\n",
"\n",
"aten::mul %239 : Float(1000:91, 91:1) = aten::mul(%238, %pred_h), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0\n",
"\n",
"aten::sub %pred_boxes2 : Float(1000:91, 91:1) = aten::sub(%pred_ctr_y, %239, %240), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0\n",
"\n",
"aten::mul %250 : Float(1000:91, 91:1) = aten::mul(%249, %pred_w), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0\n",
"\n",
"aten::add %pred_boxes3 : Float(1000:91, 91:1) = aten::add(%pred_ctr_x, %250, %251), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0\n",
"\n",
"aten::mul %261 : Float(1000:91, 91:1) = aten::mul(%260, %pred_h), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0\n",
"\n",
"aten::add %pred_boxes4 : Float(1000:91, 91:1) = aten::add(%pred_ctr_y, %261, %262), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0\n",
"\n",
"aten::flatten %pred_boxes : Float(1000:364, 364:1) = aten::flatten(%266, %267, %268), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:216:0\n",
"\n",
"aten::softmax %276 : Float(1000:91, 91:1) = aten::softmax(%18, %274, %275), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1498:0\n",
"\n",
"aten::max %boxes_x.5 : Float(1000:182, 91:2, 2:1) = aten::max(%boxes_x.4, %302), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:124:0\n",
"\n",
"aten::min %boxes_x : Float(1000:182, 91:2, 2:1) = aten::min(%boxes_x.5, %311), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:125:0\n",
"\n",
"aten::max %boxes_y.5 : Float(1000:182, 91:2, 2:1) = aten::max(%boxes_y.4, %320), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:126:0\n",
"\n",
"aten::min %boxes_y : Float(1000:182, 91:2, 2:1) = aten::min(%boxes_y.5, %329), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:127:0\n",
"\n",
"aten::gt %399 : Bool(90000:1) = aten::gt(%scores.5, %398), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0\n",
"\n",
"aten::nonzero %400 : Long(0:1, 1:1) = aten::nonzero(%399), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py:699:0\n",
"\n",
"aten::sub %450 : Float(0:1) = aten::sub(%440, %448, %449), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0\n",
"\n",
"aten::sub %468 : Float(0:1) = aten::sub(%458, %466, %467), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0\n",
"\n",
"aten::ge %470 : Bool(0:1) = aten::ge(%450, %469), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0\n",
"\n",
"aten::ge %472 : Bool(0:1) = aten::ge(%468, %471), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0\n",
"\n",
"aten::__and__ %keep.10 : Bool(0:1) = aten::__and__(%470, %472), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:101:0\n",
"\n",
"aten::nonzero %474 : Long(0:1, 1:1) = aten::nonzero(%keep.10), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:102:0\n",
"\n",
"aten::div %ratio_height : Float() = aten::div(%52, %60), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0\n",
"\n",
"aten::div %ratio_width : Float() = aten::div(%69, %77), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0\n",
"\n",
"aten::mul %xmin : Float(0:1) = aten::mul(%xmin.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:276:0\n",
"\n",
"aten::mul %xmax : Float(0:1) = aten::mul(%xmax.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:277:0\n",
"\n",
"aten::mul %ymin : Float(0:1) = aten::mul(%ymin.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:278:0\n",
"\n",
"aten::mul %ymax : Float(0:1) = aten::mul(%ymax.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:279:0\n",
"\n"
]
},
{
"data": {
"image/svg+xml": [
"<!-- SVG graph output stripped -->\n"
],
"text/plain": [
""
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d = make_graph(tm.m, classes_to_visit={'RegionProposalNetwork', 'RoIHeads'})\n",
"d.render(\"fasterrcnn.detail\")\n",
"d"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}