{"nbformat": 4, "cells": [{"source": "# Symbol - Neural network graphs and auto-differentiation\n\nIn a [previous tutorial](http://mxnet.io/tutorials/basic/ndarray.html), we introduced `NDArray`,\nthe basic data structure for manipulating data in MXNet.\nUsing `NDArray` by itself, we can execute a wide range of mathematical operations.\nIn fact, we could define and update a full neural network just by using `NDArray`.\n`NDArray` allows you to write programs for scientific computation\nin an imperative fashion, making full use of the native control of any front-end language.\nSo you might wonder, why don't we just use `NDArray` for all computation?\n\nMXNet provides the Symbol API, an interface for symbolic programming.\nWith symbolic programming, rather than executing operations step by step,\nwe first define a *computation graph*.\nThis graph contains placeholders for inputs and designated outputs.\nWe can then compile the graph, yielding a function\nthat can be bound to `NDArray`s and run.\nMXNet's Symbol API is similar to the network configurations\nused by [Caffe](http://caffe.berkeleyvision.org/)\nand the symbolic programming in [Theano](http://deeplearning.net/software/theano/).\n\nAnother advantage conferred by the symbolic approach is that\nwe can optimize our functions before using them.\nFor example, when we execute mathematical computations in an imperative fashion,\nwe don't know, at the time we run each operation,\nwhich values will be needed later on.\nBut with symbolic programming, we declare the required outputs in advance.\nThis means that we can recycle memory allocated in intermediate steps,\nfor example by performing operations in place. The Symbol API also uses less memory for the\nsame network. 
Refer to the [How To](http://mxnet.io/how_to/index.html) and\n[Architecture](http://mxnet.io/architecture/index.html) sections to learn more.\n\nIn our design notes, we present [a more thorough discussion on the comparative strengths\nof imperative and symbolic programming](http://mxnet.io/architecture/program_model.html).\nBut in this document, we'll focus on teaching you how to use MXNet's Symbol API.\nIn MXNet, we can compose Symbols from other Symbols, using operators,\nsuch as simple matrix operations (e.g. \"+\"),\nor whole neural network layers (e.g. a convolution layer).\nAn operator can take multiple input variables,\ncan produce multiple output symbols\nand can maintain internal state symbols.\n\nFor a visual explanation of these concepts, see\n[Symbolic Configuration and Execution in Pictures](http://mxnet.io/api/python/symbol_in_pictures.html).\n\nTo make things concrete, let's take a hands-on look at the Symbol API.\nThere are a few different ways to compose a `Symbol`.\n\n## Prerequisites\n\nTo complete this tutorial, we need:\n\n- MXNet. See the instructions for your operating system in [Setup and Installation](http://mxnet.io/get_started/install.html)\n- [Jupyter](http://jupyter.org/)\n ```\n pip install jupyter\n ```\n- GPUs - A section of this tutorial uses GPUs. 
If you don't have GPUs on your machine, simply\nset the variable `gpu_device` to `mx.cpu()`.\n\n## Basic Symbol Composition\n\n### Basic Operators\n\nThe following example builds a simple expression: `a + b`.\nFirst, we create two placeholders with `mx.sym.Variable`,\ngiving them the names `a` and `b`.\nWe then construct the desired symbol by using the operator `+`.\nWe don't need to name our variables while creating them;\nMXNet will automatically generate a unique name for each.\nIn the example below, `c` is assigned a unique name automatically.", "cell_type": "markdown", "metadata": {}}, {"source": "import mxnet as mx\na = mx.sym.Variable('a')\nb = mx.sym.Variable('b')\nc = a + b\n(a, b, c)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Most operators supported by `NDArray` are also supported by `Symbol`, for example:", "cell_type": "markdown", "metadata": {}}, {"source": "# element-wise multiplication\nd = a * b\n# matrix multiplication\ne = mx.sym.dot(a, b)\n# reshape\nf = mx.sym.reshape(d+e, shape=(1,4))\n# broadcast\ng = mx.sym.broadcast_to(f, shape=(2,4))\n# plot\nmx.viz.plot_network(symbol=g)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "The computations declared in the above examples can be bound to the input data\nfor evaluation by using the `bind` method. 
We discuss this further in the\n[symbol manipulation](#Symbol-Manipulation) section.\n\n### Basic Neural Networks\n\nBesides the basic operators, `Symbol` also supports a rich set of neural network layers.\nThe following example constructs a two-layer fully connected neural network\nand then visualizes the structure of that network given the input data shape.", "cell_type": "markdown", "metadata": {}}, {"source": "net = mx.sym.Variable('data')\nnet = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)\nnet = mx.sym.Activation(data=net, name='relu1', act_type=\"relu\")\nnet = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=10)\nnet = mx.sym.SoftmaxOutput(data=net, name='out')\nmx.viz.plot_network(net, shape={'data':(100,200)})", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Each symbol takes a (unique) string name. `NDArray` and `Symbol` both represent\na single tensor. *Operators* represent the computation between tensors.\nOperators take symbols (or NDArrays) as inputs, may additionally accept\nhyperparameters such as the number of hidden neurons (*num_hidden*) or the\nactivation type (*act_type*), and produce the output.\n\nWe can view a symbol simply as a function taking several arguments.\nWe can retrieve those arguments with the following method call:", "cell_type": "markdown", "metadata": {}}, {"source": "net.list_arguments()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "These arguments are the parameters and inputs needed by each symbol:\n\n- *data*: Input data needed by the variable *data*.\n- *fc1_weight* and *fc1_bias*: The weight and bias for the first fully connected layer *fc1*.\n- *fc2_weight* and *fc2_bias*: The weight and bias for the second fully connected layer *fc2*.\n- *out_label*: The label needed by the loss.\n\nWe can also specify the names explicitly:", "cell_type": "markdown", "metadata": {}}, {"source": "net = 
mx.symbol.Variable('data')\nw = mx.symbol.Variable('myweight')\nnet = mx.symbol.FullyConnected(data=net, weight=w, name='fc1', num_hidden=128)\nnet.list_arguments()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "In the above example, the `FullyConnected` layer has 3 inputs: data, weight, and bias.\nWhen any input is not specified, a variable will be automatically generated for it.\n\n## More Complicated Composition\n\nMXNet provides well-optimized symbols for layers commonly used in deep learning\n(see [src/operator](https://github.com/dmlc/mxnet/tree/master/src/operator)).\nWe can also define new operators in Python. The following example first\nperforms an element-wise addition of two symbols, then feeds the result to the fully\nconnected operator:", "cell_type": "markdown", "metadata": {}}, {"source": "lhs = mx.symbol.Variable('data1')\nrhs = mx.symbol.Variable('data2')\nnet = mx.symbol.FullyConnected(data=lhs + rhs, name='fc1', num_hidden=128)\nnet.list_arguments()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can also construct a symbol in a more flexible way than the single forward\ncomposition depicted in the preceding example:", "cell_type": "markdown", "metadata": {}}, {"source": "data = mx.symbol.Variable('data')\nnet1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)\nnet1.list_arguments()\nnet2 = mx.symbol.Variable('data2')\nnet2 = mx.symbol.FullyConnected(data=net2, name='fc2', num_hidden=10)\ncomposed = net2(data2=net1, name='composed')\ncomposed.list_arguments()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "In this example, *net2* is used as a function to apply to an existing symbol *net1*,\nand the resulting *composed* symbol will have all the attributes of *net1* and *net2*.\n\nOnce you start building bigger networks, you might want to name some\nsymbols with a common prefix to outline the structure of 
your network.\nYou can use the\n[Prefix](https://github.com/dmlc/mxnet/blob/master/python/mxnet/name.py)\nNameManager as follows:", "cell_type": "markdown", "metadata": {}}, {"source": "data = mx.sym.Variable(\"data\")\nnet = data\nn_layer = 2\nfor i in range(n_layer):\n    with mx.name.Prefix(\"layer%d_\" % (i + 1)):\n        net = mx.sym.FullyConnected(data=net, name=\"fc\", num_hidden=100)\nnet.list_arguments()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "### Modularized Construction for Deep Networks\n\nConstructing a *deep* network layer by layer (like the Google Inception network)\ncan be tedious owing to the large number of layers.\nSo, for such networks, we often modularize the construction.\n\nFor example, in the Google Inception network,\nwe can first define a factory function which chains the convolution,\nbatch normalization and rectified linear unit (ReLU) activation layers together.", "cell_type": "markdown", "metadata": {}}, {"source": "def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''):\n    conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel,\n                              stride=stride, pad=pad, name='conv_%s%s' %(name, suffix))\n    bn = mx.sym.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix))\n    act = mx.sym.Activation(data=bn, act_type='relu', name='relu_%s%s'\n                            %(name, suffix))\n    return act\nprev = mx.sym.Variable(name=\"Previous Output\")\nconv_comp = ConvFactory(data=prev, num_filter=64, kernel=(7,7), stride=(2, 2))\nshape = {\"Previous Output\" : (128, 3, 28, 28)}\nmx.viz.plot_network(symbol=conv_comp, shape=shape)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Then we can define a function that constructs an inception module based on\nthe factory function `ConvFactory`.", "cell_type": "markdown", "metadata": {}}, {"source": "def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3,\n pool, proj, name):\n # 1x1\n 
c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))\n # 3x3 reduce + 3x3\n c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')\n c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))\n # double 3x3 reduce + double 3x3\n cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')\n cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))\n cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))\n # pool + proj\n pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))\n cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))\n # concat\n concat = mx.sym.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)\n return concat\nprev = mx.sym.Variable(name=\"Previous Output\")\nin3a = InceptionFactoryA(prev, 64, 64, 64, 64, 96, \"avg\", 32, name=\"in3a\")\nmx.viz.plot_network(symbol=in3a, shape=shape)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Finally, we can obtain the whole network by chaining multiple inception\nmodules. See a complete example\n[here](https://github.com/dmlc/mxnet/blob/master/example/image-classification/symbols/inception-bn.py).\n\n### Group Multiple Symbols\n\nTo construct neural networks with multiple loss layers, we can use\n`mxnet.sym.Group` to group multiple symbols together. 
The following example\ngroups two outputs:", "cell_type": "markdown", "metadata": {}}, {"source": "net = mx.sym.Variable('data')\nfc1 = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)\nnet = mx.sym.Activation(data=fc1, name='relu1', act_type=\"relu\")\nout1 = mx.sym.SoftmaxOutput(data=net, name='softmax')\nout2 = mx.sym.LinearRegressionOutput(data=net, name='regression')\ngroup = mx.sym.Group([out1, out2])\ngroup.list_outputs()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Relations to NDArray\n\nAs you can see now, both `Symbol` and `NDArray` provide multi-dimensional array\noperations, such as `c = a + b` in MXNet. We briefly clarify the differences here.\n\n`NDArray` provides an imperative-style interface, in which\ncomputations are evaluated statement by statement. `Symbol`, by contrast, is closer to\ndeclarative programming, in which we first declare the computation and then\nevaluate it with data. Other examples of declarative programming include regular expressions and\nSQL.\n\nThe pros for `NDArray`:\n\n- Straightforward.\n- Easy to work with native language features (for loops, if-else conditions, ...)\n and libraries (numpy, ...).\n- Easy step-by-step code debugging.\n\nThe pros for `Symbol`:\n\n- Provides almost all the functionality of `NDArray`, such as `+`, `*`, `sin`,\n `reshape` etc.\n- Easy to save, load and visualize.\n- Easy for the backend to optimize the computation and memory usage.\n\n## Symbol Manipulation\n\nOne important difference between `Symbol` and `NDArray` is that with `Symbol` we first\ndeclare the computation and then bind the computation with data to run.\n\nIn this section, we introduce the functions to manipulate a symbol directly. 
Note,\nhowever, that most of them are wrapped by the `module` package.\n\n### Shape and Type Inference\n\nFor each symbol, we can query its arguments, auxiliary states and outputs.\nWe can also infer the output shape and type of the symbol given the known input\nshape or type of some arguments, which facilitates memory allocation.", "cell_type": "markdown", "metadata": {}}, {"source": "arg_name = c.list_arguments() # get the names of the inputs\nout_name = c.list_outputs() # get the names of the outputs\n# infers output shape given the shape of input arguments\narg_shape, out_shape, _ = c.infer_shape(a=(2,3), b=(2,3))\n# infers output type given the type of input arguments\narg_type, out_type, _ = c.infer_type(a='float32', b='float32')\n{'input' : dict(zip(arg_name, arg_shape)),\n 'output' : dict(zip(out_name, out_shape))}\n{'input' : dict(zip(arg_name, arg_type)),\n 'output' : dict(zip(out_name, out_type))}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "### Bind with Data and Evaluate\n\nThe symbol `c` constructed above declares what computation should be run. To\nevaluate it, we first need to feed the arguments, namely the free variables, with data.\n\nWe can do this by using the `bind` method, which accepts a device context and\na `dict` mapping free variable names to `NDArray`s as arguments and returns an\nexecutor. 
The executor provides a `forward` method for evaluation and an attribute\n`outputs` to get all the results.", "cell_type": "markdown", "metadata": {}}, {"source": "ex = c.bind(ctx=mx.cpu(), args={'a' : mx.nd.ones([2,3]),\n 'b' : mx.nd.ones([2,3])})\nex.forward()\nprint('number of outputs = %d\\nthe first output = \\n%s' % (\n len(ex.outputs), ex.outputs[0].asnumpy()))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can evaluate the same symbol on a GPU with different data.\n\n**Note:** To execute the following section on a CPU, set `gpu_device` to `mx.cpu()`.", "cell_type": "markdown", "metadata": {}}, {"source": "gpu_device=mx.gpu() # Change this to mx.cpu() in absence of GPUs.\n\nex_gpu = c.bind(ctx=gpu_device, args={'a' : mx.nd.ones([3,4], gpu_device)*2,\n 'b' : mx.nd.ones([3,4], gpu_device)*3})\nex_gpu.forward()\nex_gpu.outputs[0].asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "We can also use the `eval` method to evaluate the symbol. It combines calls to the `bind`\nand `forward` methods.", "cell_type": "markdown", "metadata": {}}, {"source": "ex = c.eval(ctx = mx.cpu(), a = mx.nd.ones([2,3]), b = mx.nd.ones([2,3]))\nprint('number of outputs = %d\\nthe first output = \\n%s' % (\n len(ex), ex[0].asnumpy()))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "For neural nets, a more commonly used pattern is `simple_bind`, which\ncreates all of the argument arrays for you. Then you can call `forward`\nand, if the gradient is needed, `backward`.\n\n### Load and Save\n\nLogically, symbols correspond to NDArrays. They both represent a tensor. They both\nare inputs/outputs of operators. 
We can serialize a `Symbol` object either by\nusing `pickle`, or by using the `save` and `load` methods directly, as we discussed in the\n[NDArray tutorial](http://mxnet.io/tutorials/basic/ndarray.html#serialize-from-to-distributed-filesystems).\n\nWhen serializing an `NDArray`, we serialize the tensor data in it and dump it directly to\ndisk in binary format.\nA symbol, however, is based on the concept of a graph. Graphs are composed by chaining operators and are\nimplicitly represented by their output symbols. So, when serializing a `Symbol`, we\nserialize the graph of which the symbol is an output. For serialization, `Symbol`\nuses the more readable `json` format. To convert a symbol to a `json`\nstring, use the `tojson` method.", "cell_type": "markdown", "metadata": {}}, {"source": "print(c.tojson())\nc.save('symbol-c.json')\nc2 = mx.sym.load('symbol-c.json')\nc.tojson() == c2.tojson()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Customized Symbol\n\nMost operators such as `mx.sym.Convolution` and `mx.sym.Reshape` are implemented\nin C++ for better performance. MXNet also allows users to write new operators\nusing any front-end language such as Python. This often makes development and\ndebugging much easier. To implement an operator in Python, refer to\n[How to create new operators](http://mxnet.io/how_to/new_op.html).\n\n## Advanced Usages\n\n### Type Cast\n\nBy default, MXNet uses 32-bit floats.\nBut for a better accuracy-performance trade-off,\nwe can also use a lower-precision data type.\nFor example, the Nvidia Tesla Pascal GPUs\n(e.g. P100) have improved 16-bit float performance,\nwhile GTX Pascal GPUs (e.g. 
GTX 1080) are fast on 8-bit integers.\n\nTo convert the data type as required,\nwe can use the `mx.sym.cast` operator as follows:", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.sym.Variable('data')\nb = mx.sym.cast(data=a, dtype='float16')\narg, out, _ = b.infer_type(data='float32')\nprint({'input':arg, 'output':out})\n\nc = mx.sym.cast(data=a, dtype='uint8')\narg, out, _ = c.infer_type(data='int32')\nprint({'input':arg, 'output':out})", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "### Variable Sharing\n\nTo share the contents between several symbols,\nwe can bind these symbols with the same array as follows:", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.sym.Variable('a')\nb = mx.sym.Variable('b')\nb = a + a * a\n\ndata = mx.nd.ones((2,3))*2\nex = b.bind(ctx=mx.cpu(), args={'a':data, 'b':data})\nex.forward()\nex.outputs[0].asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "\n\n\n", "cell_type": "markdown", "metadata": {}}], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2}