{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Symbol Tutorial\n", "\n", "Besides the tensor computation interface [NDArray](./ndarray.ipynb), another main object in MXNet is the `Symbol` provided by `mxnet.symbol`, or `mxnet.sym` for short. A symbol represents a multi-output symbolic expression. Symbols are composed of operators, such as simple matrix operations (e.g. “+”) or neural network layers (e.g. a convolution layer). An operator can take several input variables, produce more than one output variable, and have internal state variables. A variable can be either free, in which case we bind it to a value later, or the output of another symbol. \n", "\n", "## Symbol Composition \n", "### Basic Operators\n", "The following example composes a simple expression `a+b`. We first create the named placeholders `a` and `b` using `mx.sym.Variable`, and then construct the desired symbol with the operator `+`. When no name is given at creation time, MXNet automatically generates a unique name for the symbol, which is the case for `c`. 
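Composing symbols only records the computation. Nothing is evaluated until values are bound to the free variables. The toy pure-Python model below is a hypothetical illustration of that idea and of automatic naming; it is not MXNet's implementation. The `ToySymbol` class and its per-operator counter are invented for this sketch and only mimic how an unnamed `a + b` receives a generated name such as `_plus0`.

```python
# Hypothetical toy model of a symbolic expression graph (not MXNet code).
_counters = {}


def _auto_name(op):
    # Generate a unique name per operator type: _plus0, _plus1, ...
    idx = _counters.get(op, 0)
    _counters[op] = idx + 1
    return '%s%d' % (op, idx)


class ToySymbol(object):
    def __init__(self, name, op=None, inputs=()):
        self.name = name
        self.op = op  # None marks a free variable (placeholder)
        self.inputs = list(inputs)

    def __add__(self, other):
        # Composition only records the operation; nothing is computed yet.
        return ToySymbol(_auto_name('_plus'), '+', [self, other])

    def list_arguments(self):
        # Collect the free variables in depth-first order, without duplicates.
        if self.op is None:
            return [self.name]
        args = []
        for s in self.inputs:
            for n in s.list_arguments():
                if n not in args:
                    args.append(n)
        return args

    def eval(self, **values):
        # Evaluation happens only once values are supplied for the free variables.
        if self.op is None:
            return values[self.name]
        left, right = (s.eval(**values) for s in self.inputs)
        return left + right


a = ToySymbol('a')
b = ToySymbol('b')
c = a + b
print(c.name, c.list_arguments(), c.eval(a=1, b=2))
```

Under Python 3, running it prints `_plus0 ['a', 'b'] 3`. These are the generated name, the free variables found by walking the graph, and the result once `a=1` and `b=2` are bound.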
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "(, , )" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import mxnet as mx\n", "a = mx.sym.Variable('a')\n", "b = mx.sym.Variable('b')\n", "c = a + b\n", "(a, b, c)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Most `NDArray` operators can be applied to `Symbol`, for example: " ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "plot\n", "\n", "\n", "a\n", "\n", "a\n", "\n", "\n", "b\n", "\n", "b\n", "\n", "\n", "_mul0\n", "\n", "_mul\n", "\n", "\n", "_mul0->a\n", "\n", "\n", "\n", "\n", "_mul0->b\n", "\n", "\n", "\n", "\n", "dot0\n", "\n", "dot\n", "\n", "\n", "dot0->a\n", "\n", "\n", "\n", "\n", "dot0->b\n", "\n", "\n", "\n", "\n", "_plus1\n", "\n", "elemwise_add\n", "\n", "\n", "_plus1->_mul0\n", "\n", "\n", "\n", "\n", "_plus1->dot0\n", "\n", "\n", "\n", "\n", "reshape0\n", "\n", "Reshape\n", "\n", "\n", "reshape0->_plus1\n", "\n", "\n", "\n", "\n", "broadcast_to0\n", "\n", "broadcast_to\n", "\n", "\n", "broadcast_to0->reshape0\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# elemental wise times\n", "d = a * b \n", "# matrix multiplication\n", "e = mx.sym.dot(a, b) \n", "# reshape\n", "f = mx.sym.Reshape(d+e, shape=(1,4)) \n", "# broadcast\n", "g = mx.sym.broadcast_to(f, shape=(2,4)) \n", "mx.viz.plot_network(symbol=g)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Basic Neural Networks\n", "\n", "Besides the basic operators, `Symbol` has a rich set of neural network layers. The following codes construct a two layer fully connected neural work and then visualize the structure by given the input data shape. 
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "plot\n", "\n", "\n", "data\n", "\n", "data\n", "\n", "\n", "fc1\n", "\n", "FullyConnected\n", "128\n", "\n", "\n", "fc1->data\n", "\n", "\n", "200\n", "\n", "\n", "relu1\n", "\n", "Activation\n", "relu\n", "\n", "\n", "relu1->fc1\n", "\n", "\n", "128\n", "\n", "\n", "fc2\n", "\n", "FullyConnected\n", "10\n", "\n", "\n", "fc2->relu1\n", "\n", "\n", "128\n", "\n", "\n", "out_label\n", "\n", "out_label\n", "\n", "\n", "out\n", "\n", "SoftmaxOutput\n", "\n", "\n", "out->fc2\n", "\n", "\n", "10\n", "\n", "\n", "out->out_label\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Output may vary\n", "net = mx.sym.Variable('data')\n", "net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)\n", "net = mx.sym.Activation(data=net, name='relu1', act_type=\"relu\")\n", "net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=10)\n", "net = mx.sym.SoftmaxOutput(data=net, name='out')\n", "mx.viz.plot_network(net, shape={'data':(100,200)})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Modulelized Construction for Deep Networks\n", "For deep networks, such as the Google Inception, constructing layer by layer is painful given the large number of layers. For these networks, we often modularize the construction. 
Take the Google Inception as an example, we can first define a factory function to chain the convolution layer, batch normalization layer, and Relu activation layer together:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "plot\n", "\n", "\n", "Previos Output\n", "\n", "Previos Output\n", "\n", "\n", "conv_None\n", "\n", "Convolution\n", "7x7/2, 64\n", "\n", "\n", "conv_None->Previos Output\n", "\n", "\n", "3x28x28\n", "\n", "\n", "bn_None_gamma\n", "\n", "bn_None_gamma\n", "\n", "\n", "bn_None_beta\n", "\n", "bn_None_beta\n", "\n", "\n", "bn_None_moving_mean\n", "\n", "bn_None_moving_mean\n", "\n", "\n", "bn_None_moving_var\n", "\n", "bn_None_moving_var\n", "\n", "\n", "bn_None\n", "\n", "BatchNorm\n", "\n", "\n", "bn_None->conv_None\n", "\n", "\n", "64x11x11\n", "\n", "\n", "bn_None->bn_None_gamma\n", "\n", "\n", "\n", "\n", "bn_None->bn_None_beta\n", "\n", "\n", "\n", "\n", "bn_None->bn_None_moving_mean\n", "\n", "\n", "\n", "\n", "bn_None->bn_None_moving_var\n", "\n", "\n", "\n", "\n", "relu_None\n", "\n", "Activation\n", "relu\n", "\n", "\n", "relu_None->bn_None\n", "\n", "\n", "64x11x11\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Output may vary\n", "def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''):\n", " conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix))\n", " bn = mx.symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix))\n", " act = mx.symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' %(name, suffix))\n", " return act\n", "prev = mx.symbol.Variable(name=\"Previos Output\")\n", "conv_comp = ConvFactory(data=prev, num_filter=64, kernel=(7,7), stride=(2, 2))\n", "shape = {\"Previos Output\" : (128, 3, 28, 
28)}\n", "mx.viz.plot_network(symbol=conv_comp, shape=shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we define a function that constructs an Inception module based on `ConvFactory`" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "plot\n", "\n", "\n", "Previos Output\n", "\n", "Previos Output\n", "\n", "\n", "conv_in3a_1x1\n", "\n", "Convolution\n", "1x1/1, 64\n", "\n", "\n", "conv_in3a_1x1->Previos Output\n", "\n", "\n", "3x28x28\n", "\n", "\n", "bn_in3a_1x1_gamma\n", "\n", "bn_in3a_1x1_gamma\n", "\n", "\n", "bn_in3a_1x1_beta\n", "\n", "bn_in3a_1x1_beta\n", "\n", "\n", "bn_in3a_1x1_moving_mean\n", "\n", "bn_in3a_1x1_moving_mean\n", "\n", "\n", "bn_in3a_1x1_moving_var\n", "\n", "bn_in3a_1x1_moving_var\n", "\n", "\n", "bn_in3a_1x1\n", "\n", "BatchNorm\n", "\n", "\n", "bn_in3a_1x1->conv_in3a_1x1\n", "\n", "\n", "64x28x28\n", "\n", "\n", "bn_in3a_1x1->bn_in3a_1x1_gamma\n", "\n", "\n", "\n", "\n", "bn_in3a_1x1->bn_in3a_1x1_beta\n", "\n", "\n", "\n", "\n", "bn_in3a_1x1->bn_in3a_1x1_moving_mean\n", "\n", "\n", "\n", "\n", "bn_in3a_1x1->bn_in3a_1x1_moving_var\n", "\n", "\n", "\n", "\n", "relu_in3a_1x1\n", "\n", "Activation\n", "relu\n", "\n", "\n", "relu_in3a_1x1->bn_in3a_1x1\n", "\n", "\n", "64x28x28\n", "\n", "\n", "conv_in3a_3x3_reduce\n", "\n", "Convolution\n", "1x1/1, 64\n", "\n", "\n", "conv_in3a_3x3_reduce->Previos Output\n", "\n", "\n", "3x28x28\n", "\n", "\n", "bn_in3a_3x3_reduce_gamma\n", "\n", "bn_in3a_3x3_reduce_gamma\n", "\n", "\n", "bn_in3a_3x3_reduce_beta\n", "\n", "bn_in3a_3x3_reduce_beta\n", "\n", "\n", "bn_in3a_3x3_reduce_moving_mean\n", "\n", "bn_in3a_3x3_reduce_moving_mean\n", "\n", "\n", "bn_in3a_3x3_reduce_moving_var\n", "\n", "bn_in3a_3x3_reduce_moving_var\n", "\n", "\n", "bn_in3a_3x3_reduce\n", "\n", "BatchNorm\n", "\n", "\n", "bn_in3a_3x3_reduce->conv_in3a_3x3_reduce\n", "\n", "\n", "64x28x28\n", "\n", 
"\n", "bn_in3a_3x3_reduce->bn_in3a_3x3_reduce_gamma\n", "\n", "\n", "\n", "\n", "bn_in3a_3x3_reduce->bn_in3a_3x3_reduce_beta\n", "\n", "\n", "\n", "\n", "bn_in3a_3x3_reduce->bn_in3a_3x3_reduce_moving_mean\n", "\n", "\n", "\n", "\n", "bn_in3a_3x3_reduce->bn_in3a_3x3_reduce_moving_var\n", "\n", "\n", "\n", "\n", "relu_in3a_3x3_reduce\n", "\n", "Activation\n", "relu\n", "\n", "\n", "relu_in3a_3x3_reduce->bn_in3a_3x3_reduce\n", "\n", "\n", "64x28x28\n", "\n", "\n", "conv_in3a_3x3\n", "\n", "Convolution\n", "3x3/1, 64\n", "\n", "\n", "conv_in3a_3x3->relu_in3a_3x3_reduce\n", "\n", "\n", "64x28x28\n", "\n", "\n", "bn_in3a_3x3_gamma\n", "\n", "bn_in3a_3x3_gamma\n", "\n", "\n", "bn_in3a_3x3_beta\n", "\n", "bn_in3a_3x3_beta\n", "\n", "\n", "bn_in3a_3x3_moving_mean\n", "\n", "bn_in3a_3x3_moving_mean\n", "\n", "\n", "bn_in3a_3x3_moving_var\n", "\n", "bn_in3a_3x3_moving_var\n", "\n", "\n", "bn_in3a_3x3\n", "\n", "BatchNorm\n", "\n", "\n", "bn_in3a_3x3->conv_in3a_3x3\n", "\n", "\n", "64x28x28\n", "\n", "\n", "bn_in3a_3x3->bn_in3a_3x3_gamma\n", "\n", "\n", "\n", "\n", "bn_in3a_3x3->bn_in3a_3x3_beta\n", "\n", "\n", "\n", "\n", "bn_in3a_3x3->bn_in3a_3x3_moving_mean\n", "\n", "\n", "\n", "\n", "bn_in3a_3x3->bn_in3a_3x3_moving_var\n", "\n", "\n", "\n", "\n", "relu_in3a_3x3\n", "\n", "Activation\n", "relu\n", "\n", "\n", "relu_in3a_3x3->bn_in3a_3x3\n", "\n", "\n", "64x28x28\n", "\n", "\n", "conv_in3a_double_3x3_reduce\n", "\n", "Convolution\n", "1x1/1, 64\n", "\n", "\n", "conv_in3a_double_3x3_reduce->Previos Output\n", "\n", "\n", "3x28x28\n", "\n", "\n", "bn_in3a_double_3x3_reduce_gamma\n", "\n", "bn_in3a_double_3x3_reduce_gamma\n", "\n", "\n", "bn_in3a_double_3x3_reduce_beta\n", "\n", "bn_in3a_double_3x3_reduce_beta\n", "\n", "\n", "bn_in3a_double_3x3_reduce_moving_mean\n", "\n", "bn_in3a_double_3x3_reduce_moving_mean\n", "\n", "\n", "bn_in3a_double_3x3_reduce_moving_var\n", "\n", "bn_in3a_double_3x3_reduce_moving_var\n", "\n", "\n", "bn_in3a_double_3x3_reduce\n", "\n", 
"BatchNorm\n", "\n", "\n", "bn_in3a_double_3x3_reduce->conv_in3a_double_3x3_reduce\n", "\n", "\n", "64x28x28\n", "\n", "\n", "bn_in3a_double_3x3_reduce->bn_in3a_double_3x3_reduce_gamma\n", "\n", "\n", "\n", "\n", "bn_in3a_double_3x3_reduce->bn_in3a_double_3x3_reduce_beta\n", "\n", "\n", "\n", "\n", "bn_in3a_double_3x3_reduce->bn_in3a_double_3x3_reduce_moving_mean\n", "\n", "\n", "\n", "\n", "bn_in3a_double_3x3_reduce->bn_in3a_double_3x3_reduce_moving_var\n", "\n", "\n", "\n", "\n", "relu_in3a_double_3x3_reduce\n", "\n", "Activation\n", "relu\n", "\n", "\n", "relu_in3a_double_3x3_reduce->bn_in3a_double_3x3_reduce\n", "\n", "\n", "64x28x28\n", "\n", "\n", "conv_in3a_double_3x3_0\n", "\n", "Convolution\n", "3x3/1, 96\n", "\n", "\n", "conv_in3a_double_3x3_0->relu_in3a_double_3x3_reduce\n", "\n", "\n", "64x28x28\n", "\n", "\n", "bn_in3a_double_3x3_0_gamma\n", "\n", "bn_in3a_double_3x3_0_gamma\n", "\n", "\n", "bn_in3a_double_3x3_0_beta\n", "\n", "bn_in3a_double_3x3_0_beta\n", "\n", "\n", "bn_in3a_double_3x3_0_moving_mean\n", "\n", "bn_in3a_double_3x3_0_moving_mean\n", "\n", "\n", "bn_in3a_double_3x3_0_moving_var\n", "\n", "bn_in3a_double_3x3_0_moving_var\n", "\n", "\n", "bn_in3a_double_3x3_0\n", "\n", "BatchNorm\n", "\n", "\n", "bn_in3a_double_3x3_0->conv_in3a_double_3x3_0\n", "\n", "\n", "96x28x28\n", "\n", "\n", "bn_in3a_double_3x3_0->bn_in3a_double_3x3_0_gamma\n", "\n", "\n", "\n", "\n", "bn_in3a_double_3x3_0->bn_in3a_double_3x3_0_beta\n", "\n", "\n", "\n", "\n", "bn_in3a_double_3x3_0->bn_in3a_double_3x3_0_moving_mean\n", "\n", "\n", "\n", "\n", "bn_in3a_double_3x3_0->bn_in3a_double_3x3_0_moving_var\n", "\n", "\n", "\n", "\n", "relu_in3a_double_3x3_0\n", "\n", "Activation\n", "relu\n", "\n", "\n", "relu_in3a_double_3x3_0->bn_in3a_double_3x3_0\n", "\n", "\n", "96x28x28\n", "\n", "\n", "conv_in3a_double_3x3_1\n", "\n", "Convolution\n", "3x3/1, 96\n", "\n", "\n", "conv_in3a_double_3x3_1->relu_in3a_double_3x3_0\n", "\n", "\n", "96x28x28\n", "\n", "\n", 
"bn_in3a_double_3x3_1_gamma\n", "\n", "bn_in3a_double_3x3_1_gamma\n", "\n", "\n", "bn_in3a_double_3x3_1_beta\n", "\n", "bn_in3a_double_3x3_1_beta\n", "\n", "\n", "bn_in3a_double_3x3_1_moving_mean\n", "\n", "bn_in3a_double_3x3_1_moving_mean\n", "\n", "\n", "bn_in3a_double_3x3_1_moving_var\n", "\n", "bn_in3a_double_3x3_1_moving_var\n", "\n", "\n", "bn_in3a_double_3x3_1\n", "\n", "BatchNorm\n", "\n", "\n", "bn_in3a_double_3x3_1->conv_in3a_double_3x3_1\n", "\n", "\n", "96x28x28\n", "\n", "\n", "bn_in3a_double_3x3_1->bn_in3a_double_3x3_1_gamma\n", "\n", "\n", "\n", "\n", "bn_in3a_double_3x3_1->bn_in3a_double_3x3_1_beta\n", "\n", "\n", "\n", "\n", "bn_in3a_double_3x3_1->bn_in3a_double_3x3_1_moving_mean\n", "\n", "\n", "\n", "\n", "bn_in3a_double_3x3_1->bn_in3a_double_3x3_1_moving_var\n", "\n", "\n", "\n", "\n", "relu_in3a_double_3x3_1\n", "\n", "Activation\n", "relu\n", "\n", "\n", "relu_in3a_double_3x3_1->bn_in3a_double_3x3_1\n", "\n", "\n", "96x28x28\n", "\n", "\n", "avg_pool_in3a_pool\n", "\n", "Pooling\n", "avg, 3x3/1\n", "\n", "\n", "avg_pool_in3a_pool->Previos Output\n", "\n", "\n", "3x28x28\n", "\n", "\n", "conv_in3a_proj\n", "\n", "Convolution\n", "1x1/1, 32\n", "\n", "\n", "conv_in3a_proj->avg_pool_in3a_pool\n", "\n", "\n", "3x28x28\n", "\n", "\n", "bn_in3a_proj_gamma\n", "\n", "bn_in3a_proj_gamma\n", "\n", "\n", "bn_in3a_proj_beta\n", "\n", "bn_in3a_proj_beta\n", "\n", "\n", "bn_in3a_proj_moving_mean\n", "\n", "bn_in3a_proj_moving_mean\n", "\n", "\n", "bn_in3a_proj_moving_var\n", "\n", "bn_in3a_proj_moving_var\n", "\n", "\n", "bn_in3a_proj\n", "\n", "BatchNorm\n", "\n", "\n", "bn_in3a_proj->conv_in3a_proj\n", "\n", "\n", "32x28x28\n", "\n", "\n", "bn_in3a_proj->bn_in3a_proj_gamma\n", "\n", "\n", "\n", "\n", "bn_in3a_proj->bn_in3a_proj_beta\n", "\n", "\n", "\n", "\n", "bn_in3a_proj->bn_in3a_proj_moving_mean\n", "\n", "\n", "\n", "\n", "bn_in3a_proj->bn_in3a_proj_moving_var\n", "\n", "\n", "\n", "\n", "relu_in3a_proj\n", "\n", "Activation\n", "relu\n", "\n", 
"\n", "relu_in3a_proj->bn_in3a_proj\n", "\n", "\n", "32x28x28\n", "\n", "\n", "ch_concat_in3a_chconcat\n", "\n", "Concat\n", "\n", "\n", "ch_concat_in3a_chconcat->relu_in3a_1x1\n", "\n", "\n", "64x28x28\n", "\n", "\n", "ch_concat_in3a_chconcat->relu_in3a_3x3\n", "\n", "\n", "64x28x28\n", "\n", "\n", "ch_concat_in3a_chconcat->relu_in3a_double_3x3_1\n", "\n", "\n", "96x28x28\n", "\n", "\n", "ch_concat_in3a_chconcat->relu_in3a_proj\n", "\n", "\n", "32x28x28\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# @@@ AUTOTEST_OUTPUT_IGNORED_CELL\n", "def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):\n", " # 1x1\n", " c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))\n", " # 3x3 reduce + 3x3\n", " c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')\n", " c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))\n", " # double 3x3 reduce + double 3x3\n", " cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')\n", " cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))\n", " cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))\n", " # pool + proj\n", " pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))\n", " cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))\n", " # concat\n", " concat = mx.symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)\n", " return concat\n", "prev = mx.symbol.Variable(name=\"Previos Output\")\n", "in3a = InceptionFactoryA(prev, 64, 64, 64, 64, 
96, \"avg\", 32, name=\"in3a\")\n", "mx.viz.plot_network(symbol=in3a, shape=shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can obtain the whole network by chaining multiple Inception modules. A complete example is available at [mxnet/example/image-classification/symbol_inception-bn.py](https://github.com/dmlc/mxnet/blob/master/example/image-classification/symbol_inception-bn.py)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Group Multiple Symbols\n", "\n", "To construct neural networks with multiple loss layers, we can use `mxnet.sym.Group` to group multiple symbols together. The following example groups two outputs:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "['softmax_output', 'regression_output']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net = mx.sym.Variable('data')\n", "fc1 = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)\n", "net = mx.sym.Activation(data=fc1, name='relu1', act_type=\"relu\")\n", "# two loss layers on top of the same hidden representation\n", "out1 = mx.sym.SoftmaxOutput(data=net, name='softmax')\n", "out2 = mx.sym.LinearRegressionOutput(data=net, name='regression')\n", "group = mx.sym.Group([out1, out2])\n", "group.list_outputs()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Relationship to NDArray\n", "\n", "As we have seen, both `Symbol` and `NDArray` provide multi-dimensional array operations, such as `c=a+b`. Users are sometimes unsure which one to use. We briefly clarify the difference here; a more detailed explanation is available [here](http://mxnet.readthedocs.io/en/latest/system/program_model.html). \n", "\n", "`NDArray` provides an imperative-style interface, in which computations are evaluated statement by statement. `Symbol`, by contrast, is closer to declarative programming: we first declare the computation and then evaluate it with data. Other examples of declarative programming include regular expressions and SQL.\n", "\n", "The pros of `NDArray`:\n", "- straightforward\n", "- easy to combine with other language features (for loops, if-else conditions, ...) and libraries (NumPy, ...)\n", "- easy to debug step by step\n", "\n", "The pros of `Symbol`:\n", "- provides almost all the functionality of NDArray, such as +, \\*, sin, and reshape \n", "- provides a large number of neural network operators such as Convolution, Activation, and BatchNorm\n", "- provides automatic differentiation \n", "- easy to construct and manipulate complex computations such as deep neural networks\n", "- easy to save, load, and visualize\n", "- easy for the backend to optimize computation and memory usage\n", "\n", "The [mixed programming tutorial](./mixed.ipynb) shows how these two interfaces can be used together to develop a complete training program. This tutorial focuses on the usage of `Symbol`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Symbol Manipulation *\n", "\n", "One important difference between `Symbol` and `NDArray` is that with `Symbol` we first declare the computation and then bind it with data to run it. \n", "\n", "In this section we introduce functions for manipulating a symbol directly. Note, however, that most of them are conveniently wrapped by [`mx.module`](./module.ipynb), so this section can safely be skipped. \n", "\n", "### Shape Inference\n", "For each symbol, we can query its inputs (or arguments) and outputs. We can also infer the output shapes from the given input shapes, which facilitates memory allocation. 
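Shape inference applies a simple rule for each operator. The sketch below hand-codes two such rules as an illustration; it is not MXNet's implementation. Element-wise addition requires identical input shapes. A 2-D convolution without dilation produces `(in + 2*pad - kernel) // stride + 1` along each spatial axis. With a 28x28 input, a 7x7 kernel, and stride 2, this gives the `64x11x11` output shown for the `ConvFactory` example above. The function names are hypothetical.

```python
def infer_elemwise_add(a_shape, b_shape):
    # Element-wise ops require identical input shapes; the output has the same shape.
    if a_shape != b_shape:
        raise ValueError('shape mismatch: %s vs %s' % (a_shape, b_shape))
    return a_shape


def infer_conv2d(data_shape, num_filter, kernel, stride=(1, 1), pad=(0, 0)):
    # data_shape is (batch, channels, height, width); no dilation is assumed.
    n, _, h, w = data_shape
    out_h = (h + 2 * pad[0] - kernel[0]) // stride[0] + 1
    out_w = (w + 2 * pad[1] - kernel[1]) // stride[1] + 1
    return (n, num_filter, out_h, out_w)


print(infer_elemwise_add((2, 3), (2, 3)))
print(infer_conv2d((128, 3, 28, 28), 64, kernel=(7, 7), stride=(2, 2)))
```

The two calls print `(2, 3)` and `(128, 64, 11, 11)`. With a 3x3 kernel and `pad=(1, 1)`, the same rule keeps a 28x28 input at 28x28, as in the Inception example.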
" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'input': {'a': (2L, 3L), 'b': (2L, 3L)},\n", " 'output': {'_plus0_output': (2L, 3L)}}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "arg_name = c.list_arguments() # get the names of the inputs\n", "out_name = c.list_outputs() # get the names of the outputs\n", "arg_shape, out_shape, _ = c.infer_shape(a=(2,3), b=(2,3)) \n", "{'input' : dict(zip(arg_name, arg_shape)), \n", " 'output' : dict(zip(out_name, out_shape))}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Bind with Data and Evaluate\n", "The symbol `c` we constructed declares what computation should be run. To evaluate it, we need to feed arguments, namely free variables, with data first. We can do it by using the `bind` method, which accepts device context and a `dict` mapping free variable names to `NDArray`s as arguments and returns an executor. The executor provides method `forward` for evaluation and attribute `outputs` to get all results. " ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number of outputs = 1\n", "the first output = \n", "[[ 2. 2. 2.]\n", " [ 2. 2. 
2.]]\n" ] } ], "source": [ "ex = c.bind(ctx=mx.cpu(), args={'a' : mx.nd.ones([2,3]), \n", " 'b' : mx.nd.ones([2,3])})\n", "ex.forward()\n", "print 'number of outputs = %d\\nthe first output = \\n%s' % (\n", " len(ex.outputs), ex.outputs[0].asnumpy())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can evaluate the same symbol on GPU with different data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 5., 5., 5., 5.],\n", " [ 5., 5., 5., 5.],\n", " [ 5., 5., 5., 5.]], dtype=float32)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ex_gpu = c.bind(ctx=mx.gpu(), args={'a' : mx.nd.ones([3,4], mx.gpu())*2,\n", " 'b' : mx.nd.ones([3,4], mx.gpu())*3})\n", "ex_gpu.forward()\n", "ex_gpu.outputs[0].asnumpy()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Load and Save\n", "\n", "Similar to NDArray, we can either serialize a `Symbol` object by using `pickle`, or use `save` and `load` directly. Different to the binary format chosen by `NDArray`, `Symbol` uses the more readable json format for serialization. The `tojson` method returns the json string." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"nodes\": [\n", " {\n", " \"op\": \"null\", \n", " \"name\": \"a\", \n", " \"inputs\": []\n", " }, \n", " {\n", " \"op\": \"null\", \n", " \"name\": \"b\", \n", " \"inputs\": []\n", " }, \n", " {\n", " \"op\": \"elemwise_add\", \n", " \"name\": \"_plus0\", \n", " \"inputs\": [[0, 0, 0], [1, 0, 0]]\n", " }\n", " ], \n", " \"arg_nodes\": [0, 1], \n", " \"node_row_ptr\": [0, 1, 2, 3], \n", " \"heads\": [[2, 0, 0]], \n", " \"attrs\": {\"mxnet_version\": [\"int\", 901]}\n", "}\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(c.tojson())\n", "c.save('symbol-c.json')\n", "c2 = mx.symbol.load('symbol-c.json')\n", "# the reloaded symbol serializes to the same JSON\n", "c.tojson() == c2.tojson()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Customized Symbol *\n", "\n", "Most operators, such as `mx.sym.Convolution` and `mx.sym.Reshape`, are implemented in C++ for better performance. MXNet also allows users to write new operators in a frontend language such as Python, which often makes development and debugging much easier. \n", "\n", "To implement an operator in Python, we just need to define the two computation methods `forward` and `backward`, together with several methods for querying its properties, such as `list_arguments` and `infer_shape`. \n", "\n", "The arguments of both `forward` and `backward` are `NDArray`s by default, so we often implement the computation with `NDArray` operations as well. To show the flexibility of MXNet, however, we will demonstrate an implementation of the `softmax` layer using NumPy. Although a NumPy-based operator can only run on CPU and loses some optimizations that can be applied to NDArray, it benefits from the rich functionality provided by NumPy.\n", "\n", "We first create a subclass of `mx.operator.CustomOp` and then define `forward` and `backward`.\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n", "\n", "class Softmax(mx.operator.CustomOp):\n", "    def forward(self, is_train, req, in_data, out_data, aux):\n", "        x = in_data[0].asnumpy()\n", "        # subtract the row-wise maximum before exponentiating for numerical stability\n", "        y = np.exp(x - x.max(axis=1).reshape((x.shape[0], 1)))\n", "        y /= y.sum(axis=1).reshape((x.shape[0], 1))\n", "        self.assign(out_data[0], req[0], mx.nd.array(y))\n", "\n", "    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):\n", "        # gradient of softmax cross-entropy w.r.t. the input: softmax output minus one-hot label\n", "        l = in_data[1].asnumpy().ravel().astype(int)\n", "        y = out_data[0].asnumpy()\n", "        y[np.arange(l.shape[0]), l] -= 1.0\n", "        self.assign(in_grad[0], req[0], mx.nd.array(y))" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "Here we use `asnumpy` to convert the `NDArray` inputs into `numpy.ndarray`s. Then we use `CustomOp.assign` to assign the results back to `mxnet.NDArray` according to the value of `req`, which could be \"overwrite\" or \"add to\". \n", "\n", "Next we create a subclass of `mx.operator.CustomOpProp` for querying the properties." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# register this operator into MXNet under the name \"softmax\"\n", "@mx.operator.register(\"softmax\")\n", "class SoftmaxProp(mx.operator.CustomOpProp):\n", "    def __init__(self):\n", "        # softmax is a loss layer so we don’t need gradient input\n", "        # from layers above.\n", "        super(SoftmaxProp, self).__init__(need_top_grad=False)\n", "\n", "    def list_arguments(self):\n", "        return ['data', 'label']\n", "\n", "    def list_outputs(self):\n", "        return ['output']\n", "\n", "    def infer_shape(self, in_shape):\n", "        # the label holds one class index per sample; the output has the data's shape\n", "        data_shape = in_shape[0]\n", "        label_shape = (in_shape[0][0],)\n", "        output_shape = in_shape[0]\n", "        return [data_shape, label_shape], [output_shape], []\n", "\n", "    def create_operator(self, ctx, shapes, dtypes):\n", "        return Softmax()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can use this operator through `mx.sym.Custom` with its registered name:\n", "```python\n", "net = mx.symbol.Custom(data=prev_input, op_type='softmax')\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Advanced Usage *\n", "\n", "### Type Cast\n", "MXNet uses 32-bit floats by default. Sometimes we want to use a lower-precision data type for a better accuracy-performance trade-off. For example, Nvidia Tesla Pascal GPUs (e.g. P100) have improved 16-bit float performance, while GTX Pascal GPUs (e.g. GTX 1080) are fast on 8-bit integers. \n", "\n", "We can use the `mx.sym.Cast` operator to convert the data type." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'input': [<type 'numpy.float32'>], 'output': [<type 'numpy.float16'>]}\n", "{'input': [<type 'numpy.int32'>], 'output': [<type 'numpy.uint8'>]}\n" ] } ], "source": [ "a = mx.sym.Variable('data')\n", "b = mx.sym.Cast(data=a, dtype='float16')\n", "# infer_type returns argument, output, and auxiliary-state types\n", "arg, out, _ = b.infer_type(data='float32')\n", "print({'input':arg, 'output':out})\n", "\n", "c = mx.sym.Cast(data=a, dtype='uint8')\n", "arg, out, _ = c.infer_type(data='int32')\n", "print({'input':arg, 'output':out})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Variable Sharing\n", "\n", "Sometimes we want several symbols to share the same contents. This can be done simply by binding these symbols to the same array. 
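The following NumPy analogy illustrates the idea; it does not use MXNet's binding mechanism. All three arguments refer to the same array object. For an all-2 input, `d = a + b * c` therefore computes `2 + 2 * 2 = 6`, and an in-place update of that one buffer is seen by every argument. The `evaluate` helper is hypothetical.

```python
import numpy as np


def evaluate(args):
    # Evaluate d = a + b * c given a dict mapping argument names to arrays.
    return args['a'] + args['b'] * args['c']


data = np.ones((2, 3)) * 2
# Bind all three arguments to the same array object (shared contents).
args = {'a': data, 'b': data, 'c': data}
print(evaluate(args))

# Updating the shared buffer in place is seen by every argument.
data[:] = 1
print(evaluate(args))
```

The first print shows a 2x3 array of 6.0. The second runs after `data[:] = 1` and shows a 2x3 array of 2.0, because `1 + 1 * 1 = 2`.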
\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 6., 6., 6.],\n", " [ 6., 6., 6.]], dtype=float32)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = mx.sym.Variable('a')\n", "b = mx.sym.Variable('b')\n", "c = mx.sym.Variable('c')\n", "d = a + b * c\n", "\n", "data = mx.nd.ones((2,3))*2\n", "# bind the same NDArray to all three arguments, so every element is 2 + 2 * 2 = 6\n", "ex = d.bind(ctx=mx.cpu(), args={'a':data, 'b':data, 'c':data})\n", "ex.forward()\n", "ex.outputs[0].asnumpy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Further Reading\n", "\n", "- [Use torch operators in MXNet](http://mxnet.dmlc.ml/en/latest/how_to/torch.html)\n", "- [Use Caffe operators in MXNet](http://dmlc.ml/mxnet/2016/07/29/use-caffe-operator-in-mxnet.html)" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 1 }