{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Exporting a model from PyTorch to ONNX\n", "\n", "In this tutorial, we describe how to convert a model defined\n", "in PyTorch into the ONNX format.\n", "\n", "The ONNX exporter is part of the [PyTorch repository](http://pytorch.org/docs/master/onnx.html).\n", "\n", "To follow this tutorial, you need to install [onnx](https://github.com/onnx/onnx). You can get binary builds of onnx with\n", "``conda install -c conda-forge onnx``.\n", "\n", "``NOTE``: ONNX is under active development, so for the best support consider building the PyTorch master branch, which can be installed by following\n", "[the instructions here](https://github.com/pytorch/pytorch#from-source)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Invoking the exporter\n", "\n", "In most cases, it is a matter of replacing `my_model(input)` with `torch.onnx.export(my_model, input, \"my_model.onnx\")` in your script.\n", "\n", "### Limitations\n", "\n", "The ONNX exporter is trace-based: it executes your model once and exports the operators that were actually run during that execution. This means that if your model is dynamic, e.g., changes behavior depending on input data, the export won’t be accurate, because only the path taken for the example input is recorded.\n", "\n", "Similarly, a trace might be valid only for a specific input size (which is one reason why we require explicit inputs on tracing). Most operators export size-agnostic versions and should work on different batch sizes or input sizes. We recommend examining the model trace and making sure the traced operators look reasonable." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on function export in module torch.onnx:\n", "\n", "export(model, args, f, export_params=True, verbose=False, training=False)\n", " Export a model into ONNX format. 
This exporter runs your model\n", " once in order to get a trace of its execution to be exported; at the\n", " moment, it does not support dynamic models (e.g., RNNs.)\n", " \n", " See also: :ref:`onnx-export`\n", " \n", " Arguments:\n", " model (torch.nn.Module): the model to be exported.\n", " args (tuple of arguments): the inputs to\n", " the model, e.g., such that ``model(*args)`` is a valid\n", " invocation of the model. Any non-Variable arguments will\n", " be hard-coded into the exported model; any Variable arguments\n", " will become inputs of the exported model, in the order they\n", " occur in args. If args is a Variable, this is equivalent\n", " to having called it with a 1-ary tuple of that Variable.\n", " (Note: passing keyword arguments to the model is not currently\n", " supported. Give us a shout if you need it.)\n", " f: a file-like object (has to implement fileno that returns a file descriptor)\n", " or a string containing a file name. A binary Protobuf will be written\n", " to this file.\n", " export_params (bool, default True): if specified, all parameters will\n", " be exported. Set this to False if you want to export an untrained model.\n", " In this case, the exported model will first take all of its parameters\n", " as arguments, the ordering as specified by ``model.state_dict().values()``\n", " verbose (bool, default False): if specified, we will print out a debug\n", " description of the trace being exported.\n", " training (bool, default False): export the model in training mode. 
At\n", " the moment, ONNX is oriented towards exporting models for inference\n", " only, so you will generally not need to set this to True.\n", "\n" ] } ], "source": [ "import torch.onnx\n", "help(torch.onnx.export)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Trying it out on AlexNet\n", "\n", "If you already have your model built, it's just a few lines:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch.onnx\n", "import torchvision\n", "\n", "# Standard ImageNet input - 3 channels, 224x224,\n", "# values don't matter as we care about network structure.\n", "# But they can also be real inputs.\n", "dummy_input = torch.randn(1, 3, 224, 224)\n", "# Obtain your model, it can be also constructed in your script explicitly\n", "model = torchvision.models.alexnet(pretrained=True)\n", "# Invoke export\n", "torch.onnx.export(model, dummy_input, \"alexnet.onnx\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**That's it!**\n", "\n", "## Inspecting model\n", "\n", "You can also use ONNX tooling to check the validity of the resulting model or inspect the details" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "graph torch-jit-export (\n", " %0[FLOAT, 1x3x224x224]\n", ") initializers (\n", " %1[FLOAT, 64x3x11x11]\n", " %2[FLOAT, 64]\n", " %3[FLOAT, 192x64x5x5]\n", " %4[FLOAT, 192]\n", " %5[FLOAT, 384x192x3x3]\n", " %6[FLOAT, 384]\n", " %7[FLOAT, 256x384x3x3]\n", " %8[FLOAT, 256]\n", " %9[FLOAT, 256x256x3x3]\n", " %10[FLOAT, 256]\n", " %11[FLOAT, 4096x9216]\n", " %12[FLOAT, 4096]\n", " %13[FLOAT, 4096x4096]\n", " %14[FLOAT, 4096]\n", " %15[FLOAT, 1000x4096]\n", " %16[FLOAT, 1000]\n", ") {\n", " %17 = Conv[dilations = [1, 1], group = 1, kernel_shape = [11, 11], pads = [2, 2, 2, 2], strides = [4, 4]](%0, %1)\n", " %18 = Add[axis = 1, broadcast = 1](%17, %2)\n", " %19 = Relu(%18)\n", " %20 = 
MaxPool[kernel_shape = [3, 3], pads = [0, 0], strides = [2, 2]](%19)\n", " %21 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [2, 2, 2, 2], strides = [1, 1]](%20, %3)\n", " %22 = Add[axis = 1, broadcast = 1](%21, %4)\n", " %23 = Relu(%22)\n", " %24 = MaxPool[kernel_shape = [3, 3], pads = [0, 0], strides = [2, 2]](%23)\n", " %25 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%24, %5)\n", " %26 = Add[axis = 1, broadcast = 1](%25, %6)\n", " %27 = Relu(%26)\n", " %28 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%27, %7)\n", " %29 = Add[axis = 1, broadcast = 1](%28, %8)\n", " %30 = Relu(%29)\n", " %31 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%30, %9)\n", " %32 = Add[axis = 1, broadcast = 1](%31, %10)\n", " %33 = Relu(%32)\n", " %34 = MaxPool[kernel_shape = [3, 3], pads = [0, 0], strides = [2, 2]](%33)\n", " %35 = Reshape[shape = [1, 9216]](%34)\n", " %36, %37 = Dropout[is_test = 1, ratio = 0.5](%35)\n", " %38 = Transpose[perm = [1, 0]](%11)\n", " %40 = Gemm[alpha = 1, beta = 1, broadcast = 1](%36, %38, %12)\n", " %41 = Relu(%40)\n", " %42, %43 = Dropout[is_test = 1, ratio = 0.5](%41)\n", " %44 = Transpose[perm = [1, 0]](%13)\n", " %46 = Gemm[alpha = 1, beta = 1, broadcast = 1](%42, %44, %14)\n", " %47 = Relu(%46)\n", " %48 = Transpose[perm = [1, 0]](%15)\n", " %50 = Gemm[alpha = 1, beta = 1, broadcast = 1](%47, %48, %16)\n", " return %50\n", "}\n" ] } ], "source": [ "import onnx\n", "\n", "# Load the ONNX model\n", "model = onnx.load(\"alexnet.onnx\")\n", "\n", "# Check that the IR is well formed\n", "onnx.checker.check_model(model)\n", "\n", "# Print a human readable representation of the graph\n", "print(onnx.helper.printable_graph(model.graph))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that all parameters are listed as graph's inputs but they also have 
stored values initialized in `model.graph.initializer`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What's next\n", "\n", "Check the [PyTorch documentation on ONNX export](http://pytorch.org/docs/master/onnx.html).\n", "\n", "Take a look at [other tutorials, including importing ONNX models into other frameworks](https://github.com/onnx/tutorials/tree/master/tutorials)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.11" } }, "nbformat": 4, "nbformat_minor": 2 }