{ "cells": [ { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "# Train in Tensorflow, Export to ONNX\n", "In this tutorial, we demonstrate the complete process of training an MNIST model in Tensorflow and exporting the trained model to ONNX.\n", "\n", "### Training\n", "\n", "First, run the [training script](./assets/tf-train-mnist.py) by issuing the command `python tf-train-mnist.py` in your terminal. After a short while, this produces a trained MNIST model. The training process needs no special instrumentation. However, to convert the trained model successfully, onnx-tensorflow requires three pieces of information, all of which can be obtained after training is complete:\n", "\n", " - *Graph definition*: You need the graph definition in serialized protocol buffer form (a `GraphDef`). The easiest way to obtain it is the following snippet, taken from the example training script:\n", "```\n", "with open(\"graph.proto\", \"wb\") as file:\n", "    graph = tf.get_default_graph().as_graph_def(add_shapes=True)\n", "    file.write(graph.SerializeToString())\n", "```\n", " - *Shape information*: By default, `as_graph_def` does not serialize the shapes of the intermediate tensors, which onnx-tensorflow requires. We therefore ask Tensorflow to serialize the shape information by passing the keyword argument `add_shapes=True`, as demonstrated above.\n", " - *Checkpoint*: Tensorflow checkpoint files contain the trained weights, so they are needed to convert the trained model to ONNX format.\n", "\n", "### Graph Freezing\n", "\n", "Second, we freeze the graph. The Tensorflow documentation describes graph freezing as follows:\n", "> One confusing part about this is that the weights usually aren't stored inside the file format during training. 
Instead, they're held in separate checkpoint files, and there are Variable ops in the graph that load the latest values when they're initialized. It's often not very convenient to have separate files when you're deploying to production, so there's the freeze_graph.py script that takes a graph definition and a set of checkpoints and freezes them together into a single file.\n", "\n", "We therefore build the freeze_graph tool in the Tensorflow source folder and run it with the location of the graph definition file, the location of the checkpoint file, and the path where the frozen graph should be written. One caveat is that you must supply the name of the output node to this utility. If you have trouble finding the name of the output node, please refer to [this article](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md#inspecting-graphs) for help.\n", "```\n", "bazel build tensorflow/python/tools:freeze_graph\n", "bazel-bin/tensorflow/python/tools/freeze_graph \\\n", "    --input_graph=/home/mnist-tf/graph.proto \\\n", "    --input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \\\n", "    --output_graph=/tmp/frozen_graph.pb \\\n", "    --output_node_names=fc2/add \\\n", "    --input_binary=True\n", "```\n", "\n", "We now have `frozen_graph.pb`, which contains both the graph definition and the weight information in a single file.\n", "\n", "### Model Conversion\n", "\n", "Third, we convert the model to ONNX format using `tensorflow_graph_to_onnx_model` from the onnx-tensorflow API (documentation available at https://github.com/onnx/onnx-tensorflow/blob/master/doc/API.md)." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "import tensorflow as tf\n", "from onnx_tf.frontend import tensorflow_graph_to_onnx_model\n", "\n", "# Load the frozen graph, which holds the graph definition and the weights.\n", "with tf.gfile.GFile(\"frozen_graph.pb\", \"rb\") as f:\n", "    graph_def = tf.GraphDef()\n", "    graph_def.ParseFromString(f.read())\n", "\n", "# Convert to ONNX; \"fc2/add\" is the output node name used when freezing.\n", "onnx_model = tensorflow_graph_to_onnx_model(graph_def,\n", "                                            \"fc2/add\",\n", "                                            opset=6)\n", "\n", "# Write the serialized ONNX model to disk.\n", "with open(\"mnist.onnx\", \"wb\") as onnx_file:\n", "    onnx_file.write(onnx_model.SerializeToString())" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "As a simple sanity check that we have obtained the correct model, we print the first node of the converted ONNX model graph, which corresponds to the reshape operation that converts the 1D serial input to a 2D image tensor:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input: \"Placeholder\"\n", "input: \"reshape/Reshape/shape\"\n", "output: \"reshape/Reshape\"\n", "op_type: \"Reshape\"\n", "\n" ] } ], "source": [ "print(onnx_model.graph.node[0])" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Inference using Backend\n", "\n", "We continue the demonstration by performing inference with the converted ONNX model. We exported an image representing a handwritten 7 and stored the numpy array as image.npz. Using the onnx-tensorflow backend, we classify this image with the converted ONNX model." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The digit is classified as 7\n" ] } ], "source": [ "import onnx\n", "import numpy as np\n", "from onnx_tf.backend import prepare\n", "\n", "# Load the converted model and prepare it for the Tensorflow backend.\n", "model = onnx.load('mnist.onnx')\n", "tf_rep = prepare(model)\n", "\n", "# Run the image, flattened to shape [1, 784], and report the most likely digit.\n", "img = np.load(\"./assets/image.npz\")\n", "output = tf_rep.run(img.reshape([1, 784]))\n", "print(\"The digit is classified as {}\".format(np.argmax(output)))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }