{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Mixture Density Networks in TensorFlow\n", "An implementation of MDNs adapted from Andrej Karpathy's [example](https://github.com/karpathy/randomfun). My version is powered by [TensorFlow](https://www.tensorflow.org/versions/master/get_started/index.html).\n", "\n", "Mixture Density Networks are like vanilla neural networks, except that they output the parameters of a probability distribution over the target instead of a single point estimate, so every prediction comes with an estimate of its uncertainty. Karpathy writes:\n", "\n", "\"The core idea is to have a Neural Net that predicts an entire (and possibly complex)\n", "distribution. In this example we're predicting a mixture of gaussians distributions via\n", "its sufficient statistics (the means and diagonal covariances), which are on the last\n", "layer of the neural network. This means that the network knows what it doesn't know:\n", "it will predict diffuse distributions in situations where the target variable is very\n", "noisy, and it will predict a much more peaky distribution in nearly deterministic parts.\"\n", "\n", "Karpathy based his version on Christopher Bishop's treatment of Mixture Density Networks, which Bishop introduced in a 1994 technical report." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Exception AssertionError: AssertionError() in > ignored\n" ] } ], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import math\n", "%matplotlib inline\n", "\n", "sess = tf.InteractiveSession()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# the function we'll be trying to predict: a noisy sinusoid on [0, 1]\n", "N = 200\n", "X = np.linspace(0,1,N)\n", "Y = X + 0.3 * np.sin(2*np.pi*X) + np.random.uniform(-0.1, 0.1, N)\n", "X,Y = Y,X # swap the axes: the inverted mapping is multi-valued, so one x can have several plausible y\n", "# plt.scatter(X,Y,color='g')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "#hyperparameters\n", "input_size = 
1 # size of input (x)\n", "output_size = 1 # size of output (y)\n", "hidden_size = 30 # size of hidden layer\n", "M = 3 # number of mixture components\n", "batch_size = 200 # full batch (equals N)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize tensors for weights and biases\n", "I store them all in a dictionary so that they are easier to pass to functions." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def weight_variable(shape):\n", "    # weights start as small random values\n", "    initial = tf.truncated_normal(shape, stddev=0.25)\n", "    return tf.Variable(initial)\n", "\n", "def bias_variable(shape):\n", "    # biases start close to zero\n", "    initial = tf.truncated_normal(shape, stddev=0.01)\n", "    return tf.Variable(initial)\n", "\n", "x = tf.placeholder(\"float\", shape=[None, input_size]) #[nb x 1]\n", "y = tf.placeholder(\"float\", shape=[None, output_size]) #[nb x 1]\n", "\n", "w = {}\n", "w['Wxh'] = weight_variable([input_size, hidden_size]) # input -> hidden [input x hidden]\n", "w['Whu'] = weight_variable([hidden_size, M]) # hidden -> means [hidden x M]\n", "w['Whs'] = weight_variable([hidden_size, M]) # hidden -> log standard deviations [hidden x M]\n", "w['Whp'] = weight_variable([hidden_size, M]) # hidden -> mixture weight logits [hidden x M]\n", "\n", "w['bxh'] = bias_variable([1, hidden_size]) # [1 x hidden]\n", "w['bhu'] = bias_variable([1, M]) # [1 x M]\n", "w['bhs'] = bias_variable([1, M]) # [1 x M]\n", "w['bhp'] = bias_variable([1, M]) # [1 x M]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build the computation graph\n", "All the variables are TensorFlow tensors. 
(`ps` = Gaussian density of `y` under each mixture component, `lp` = negative log-likelihood of each example, `loss` = mean negative log-likelihood over the batch)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "h = tf.nn.tanh(tf.matmul(x, w['Wxh']) + w['bxh']) #[nb x hidden]\n", "mu = tf.matmul(h, w['Whu']) + w['bhu'] #[nb x M] component means\n", "sigma = tf.exp( tf.matmul(h, w['Whs']) + w['bhs'] ) #[nb x M] exp keeps the standard deviations positive\n", "pi = tf.nn.softmax( tf.matmul(h, w['Whp']) + w['bhp'] ) #[nb x M] mixture weights, each row sums to 1\n", "\n", "ps = tf.exp(-(tf.pow(y-mu,2))/(2*tf.pow(sigma,2)))/(sigma*np.sqrt(2*math.pi)) #[nb x M]\n", "pin = ps * pi #[nb x M] densities weighted by mixture weights\n", "lp = -tf.log(tf.reduce_sum(pin, 1, keep_dims=True)) #[nb x 1] (sum across dimension 1)\n", "loss = tf.reduce_sum(lp) / batch_size # scalar" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [], "source": [ "#select training data (1 batch, all data)\n", "nb = N # full batch\n", "xbatch = np.reshape(X[:nb], (nb,1))\n", "ybatch = np.reshape(Y[:nb], (nb,1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optimize\n", "Use TensorFlow's built-in optimization methods. 
Training loss should reach about -0.99 after 20,000 steps. It can drop below zero because the loss is the negative log of a probability density, and densities can exceed 1." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step 0, training loss 0.97247\n", "step 1000, training loss -0.709157\n", "step 2000, training loss -0.802351\n", "step 3000, training loss -0.855826\n", "step 4000, training loss -0.896237\n", "step 5000, training loss -0.923585\n", "step 6000, training loss -0.941037\n", "step 7000, training loss -0.952521\n", "step 8000, training loss -0.960583\n", "step 9000, training loss -0.966559\n", "step 10000, training loss -0.971154\n", "step 11000, training loss -0.974751\n", "step 12000, training loss -0.977597\n", "step 13000, training loss -0.979862\n", "step 14000, training loss -0.981674\n", "step 15000, training loss -0.983125\n", "step 16000, training loss -0.984289\n", "step 17000, training loss -0.985222\n", "step 18000, training loss -0.985971\n", "step 19000, training loss -0.986574\n", "training loss -0.98706\n" ] } ], "source": [ "train_step = tf.train.AdagradOptimizer(1e-2).minimize(loss)\n", "sess.run(tf.initialize_all_variables())\n", "\n", "for i in range(20000):\n", "    if i%1000 == 0:\n", "        # report the full-batch loss every 1000 steps\n", "        train_loss = loss.eval(feed_dict={x:xbatch, y: ybatch})\n", "        print(\"step %d, training loss %g\"%(i, train_loss))\n", "    train_step.run(feed_dict={x: xbatch, y: ybatch})\n", "\n", "print(\"training loss %g\"%loss.eval(feed_dict={x:xbatch, y: ybatch}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sample\n", "For each input, pick the most probable mixture component and draw one sample from its Gaussian, in an attempt to reproduce the original function." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "\n", "def sample(mus, sigmas, pis):\n", " best = pis.argmax(axis = 1) #[nb] index of the most probable component\n", " print(best.shape)\n", " \n", " #select the best: one-hot mask over the components\n", " indices = np.zeros_like(mus) #[nb x M]\n", " indices[range(mus.shape[0]), best] = 1\n", " \n", " best_mus 
= np.sum( np.multiply(indices, mus), axis=1)\n", " best_sigmas = np.sum( np.multiply(indices, sigmas), axis=1)\n", " \n", " Y_ = np.random.normal(best_mus, best_sigmas)\n", " return Y_" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(200,)\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAEACAYAAABVtcpZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XtwXeV57/Hvsy12kIvxQVZBQIJaW0mdccHInFB7EpBc\nKslkckxtZg7GPQ7JTIrnJGkzg84EfERBNNE4MCOmZUJPIEMziSeOYQKeKDnQLZUigVOTUlsY6tgB\nYY64GJMKOWDHIsLWe/7YF/Zl7Yv2XtoXrd9nRjP78u61Xmm017Pe93kv5pxDRESCKVTpCoiISOUo\nCIiIBJiCgIhIgCkIiIgEmIKAiEiAKQiIiASYL0HAzB4ys7fN7IUs7282swOxnz1mdqkf5xURkdL4\n1RL4HtCV4/0jwNXOuZXAN4Hv+nReEREpQZ0fB3HO7TGz5hzvP5v09FngYj/OKyIipalETuBLwBMV\nOK+IiKTxpSVQKDNbC3wR+Ew5zysiIt7KFgTM7DLgQWCdc+54jnJazEhEZJacc1bM5/zsDrLYT+Yb\nZpcAjwJbnHOv5DuQc64mf+68886K10H1r3w9VP/a/Knl+pfCl5aAme0E2oElZvYacCcQBpxz7kHg\nb4AG4B/MzIAPnHNX+nFuEREpnl+jgzbnef8vgb/041wiIuIfzRj2UXt7e6WrUBLVv7JU/8qq9foX\ny0rtT/Kbmblqq5OISDUzM1wVJIZFRKTGKAiIiASYgoCISIApCIiIBJiCgIhIgCkIiIgEmIKAiEiA\nKQiIiASYgoCISIApCIiIBJiCgIhIgCkIiIgEmIKAiEiAKQiIiASYgoCISIApCIiIBJiCgIhIgCkI\niIgEmIKAiEiAKQiIiASYgoCIzFuRsQidOzrp3NFJZCxS6epUJXPOlX4Qs4eAzwFvO+cuy1LmPuBa\n4LfAF5xzz2cp5/yok4gEW2QswoaHNzB1egqA+rp6dt+wm66WrgrXzH9mhnPOivmsXy2B7wFZ/7Jm\ndi2wzDn3cWAr8B2fzisi4ql/b38iAABMnZ6if29/BWtUnXwJAs65PcDxHEWuA34QK/sLYLGZXeDH\nuUVECrXvrX3qGkpTrpzAxcDrSc/fjL0mIjInutd0U19Xn/La5NQkQ0eG2PDwBgWCmLpKV8BLb29v\n4nF7ezvt7e0Vq4uI1Kauli5237Cb/r397HtrH5NTk4n34l1DyfmByFgk0V3UvabbM3dQSJlyGB4e\nZnh42Jdj+ZIYBjCzZuCnXolhM/sO8JRz7uHY88NAm3PubY+ySgyLiK86d3QydGQo5bWOpR0MbhkE\nCksiV3OiuRoSwwAW+/EyAHwewMxWA7/xCgAiMn9FxiKsemAVS+5ZwqoHVpW1Oya9a6i+rp7uNd2J\n54UkkedrotmXIGBmO4F/BT5hZq+Z2RfNbKuZ3QzgnHsceNXMxoAHgC/7cV4RqQ2RsQjrd61n9Ngo\nk1OTjB4bZf2P1mcNBH6P7493D
bU2tdJQ38DyxuUlH3O+8CUn4JzbXECZr/pxLhGpPf17+5k+M53y\n2vTMdEa/PGR2u4yMj7Di91fQuLCx5H74wxOHmTo9xeTUJBse3pDozule082e1/akdPUktxSAgsrU\nIs0YFqlxft41V8MM2/Rul+kz04weGy15VE+u7px4S6FjaQcdSzs8+/oLKVOLqnJ0kIgUJv2uec9r\ne4q+OM3mWLMdJdO9ppuR8ZGU1kA4FJ71nbTXqB6/dLV05T1uIWVqjVoCIjXMz2RloceKB4uhI0MM\nHRli/Y/Ws+qBVTlbD10tXQxsGkj0ybc2tTJw44DnBbWtuY2QZb80TZyaKKq1ki85HFRqCYhIXsl3\n/hOnJlK7a2ai3TWQu/VQyF10ZCxC3zN9zLgZAAxjgS3gtDsNRFsPB//zYKJFMZuWT/K8AajsOP9q\n4ts8Ab9onoBI4UoZu57epQN4Hiv99ZCFEhdpL8nj72fLazx/a1MrjQsbgWgAigccP843X5QyT0At\nAZEaVuzdbbb+//ixJk5NACQeJ9/55woAc6FxYWPiIt+5o7Os5w4CBQGRGldMsjJb/3/8Ypty558l\ndbjsvGWc+5FzU7pnSu1nzzcMc74O06wkJYZFqlSlhmumB4gZvO/8j79/nP1b93PH1XfQUN9AQ30D\nPVf1lNTPnm8Y5nwdpllJygmIVKG5XKcmMhZh25PbOPD2gUTXTn1dPT1X9TAyPpKx2Bp45wEa6hvY\nuXFn1a6nEyTVsnaQiPikkOGakbEILfe1cNY3zuLc7efS93Rf3uPGg8vosVFm3AwhQiw7bxkXLbqI\nO4bvYOjIUEYAqK+rZ8ulWzKOdcvqW+btejpBopyASA2KjEX43M7PJYZOnpg+we1P3Q5Az9U9WT/n\n1dXz6m9e9Uz2NtQ3cMWFVySSzR9f8nHuffZeIBoAeq7uYWTHiJ+/llSAWgIiVaituS3na/17+xMB\nINndP7971ufKNtrniguvYHDLYKJrp+fqHt75+ju88/V3EoFGE7Bqn4KASBUaGc+8w3700KNAtBWw\n7619np87MX0i5+zd9It2tpE/hV7MlaitfUoMi1SZyFiEzY9tzuibB7jpspt45JePpHTpZJMtSZs8\nSaytuY2+Z/pShoOubFrJ9mu262JeQ0pJDCsIiFSR9FFBpSpkNm21bJkoxdOMYZF5Ij1xWw7zcWVM\nKZxyAiLzlGFK0kpeCgIiVSQ9cevlwnMuzJrQFZkt/SeJVJHk0TYN9Q2eZf74/D/m8b94nI6lHbQ2\ntRIOhT3LOZwmbkleCgIiVaarpYvBLYPs3LiT8ILUC3x8N654mf1b93NH2x05N2ERyUX/OSJVKr4b\n17LzllEXquOc8Dnc0XZHRhJ3ZHzEc8KXJm5JIRQERKrc0RNHOT1zmpPTJ+l7pq+gFUUb6hs0cUsK\noiAgUsUKWaDNa+mGnRt3KgBIQXwJAma2zswOm9lLZnarx/tLzOwJM3vezF40sy/4cV6R+S6+w1eu\n17R0g5Si5BnDZhYCXgKuAY4CzwGbnHOHk8rcCZztnNtmZo3Ar4ALnMtcAUszhkU+tOqBVRl76i4K\nL2L1R1drdq8kVHo/gSuBl51z4865D4BdwHVpZY4Bi2KPFwHveAUAEUkV32A92YnpEwwdGWLDwxvK\nuuOYzE9+BIGLgdeTnr8Rey3Zd4EVZnYUOAB8zYfzisx7uSaPaQMX8UO51g7aBhxwzq01s2XAkJld\n5pw76VW4t7c38bi9vZ329vayVFKk2sT7+/v39ntu+yjBNDw8zPDwsC/H8iMIvAlckvT8o7HXkn0a\n6ANwzr1iZq8Cy4F/9zpgchAQkajmxc2cnD7J9JlpQPMAgiz95viuu+4q+lh+BIHngBYzawbeAjYB\nN6aVOQT8GfBzM7sA+ARwxIdzi8xr6UtLh0NhWptaaVzYqMSw+KLkIOCcO2NmXwUGieYYHnLOHTK
z\nrdG33YPAduB7ZnYAMODrzjm1a0XySJ8nMD0zTePCxrx7BIgUypecgHPun4A/SnvtgaTHE8B/8+Nc\nIiLiH80YFqli2shd5pq2lxSpctr+UfLRHsMiIgFW6RnDIiJSoxQEREQCTEFAal5kLELnjk46d3Rq\nLR2RWVJOQGpa+mSq+rp6LaUsgaOcgASW16Yrmx/bnNIqKKWloFaGzHdqCUhN69zRydCRIc/36uvq\n6bmqh75n+opqKaiVIbVCLQEJrHxLLd/77L15t2fMppCtHUVqnYKA1LTkrRUb6hsqXR2RmqMgIDWv\nq6WLwS2D7Ny4M2OJhVtW3zKrZReScwBtzW1askHmPeUEZF7xWmKh0GUXvHIAPVf1MDI+kvezIpWk\nZSNECpAvGHglmTuWdmjZZql6pQSBcm0vKVJR6Xf5e17bo5E+IignIAFRyEgfLdssQaSWgEhM8qbu\noByABINyAhIImvgl85kSwyIFKGZzFm3oIrVAQUDEZ5GxCNue3MaBYweYYQZQ60Gql5aNECF64V71\nwCqW3LOEVQ+sKnrBt3jX0eix0UQAAC0bIfOTgoDMC5GxCOt3rWf02CiTU5OMHhtl/Y/WFxUI0kcS\nicxnCgIyL/Tv7Wf6zHTKa9Mz077euYdDYQ0ZlXnHlyBgZuvM7LCZvWRmt2Yp025mo2b2H2b2lB/n\nFYFoK2DfW/s83zty/EiiTKH7AmRdmdQ+PJ/2GJD5ouTEsJmFgJeAa4CjwHPAJufc4aQyi4F/BTqd\nc2+aWaNzbiLL8ZQYloKlD/30ctNlN/HILx9JlAkRYmXTSrZfsz2R5E0fBQSw+bHNTE5NphyrtamV\nwxOHNdRUqkpFRweZ2WrgTufctbHntwHOOXd3Upn/CVzonLujgOMpCEjBcm0qE7fAFnDGncl4PX4B\nBzznEPTv7c849qLwIk5Mn0h5raG+gZ0bdyoQSMVUenTQxcDrSc/fiL2W7BNAg5k9ZWbPmdkWH84r\nUhCvAAAfjvbZ9uQ2zy0q25rbCC8Ip3zm5PTJjONMTk2y4eEN6hqSmlSuZSPqgFXAnwK/B+w1s73O\nuTGvwr29vYnH7e3ttLe3l6GKUmsiYxEmTk0QshAzLjqUsy5Ux+mZ0wUfY+LUBAeOHch4fXJqkr5n\n+vjYuR/jleOvJF53eLdS4wFFrQEph+HhYYaHh305lh9B4E3gkqTnH429luwNYMI59z7wvpk9DawE\n8gYBES/puYDkfv6vPP6VlAt3NvG7/OS5AMmmTk8x/u64f5UW8Un6zfFdd91V9LH86A56Dmgxs2Yz\nCwObgIG0Mj8BPmNmC8xsIfAnwCEfzi0BlT6Wf4YZGhc20tXSxf2fvT/rvsNGUrepg/d+917O86S3\nKkLm/ZXRiqNSq0oOAs65M8BXgUHgILDLOXfIzLaa2c2xMoeBCPAC8CzwoHPul6WeW8RLfDXQ1qZW\nQkn/4iFCKd050zPTHD1xtKBjNtQ30LG0g79t/9uUABMiRGtTq0YISc3S2kFSkwpdFTR56OfEqQlG\nj40Wdb74DmPxNYXG3x2neXFzyjBTkUrRAnISSLNd4TMjj5CUUM4lZCFWXrCS6z95PX3P9GUNPPnq\nowAic0VBQKRAs2kZLAov4rfTv00kjkOEMpLIyS2EXC2TyFiE9T9az/TMh0tbhBeEGdg0oEAgJav0\nPAGRiup7uo8l9yxhyT1L6Hu6L+P95GUeAAa3DDK4ZZDt12zPedyzFpyVctH3GkW07619icCSa/vK\n/r39KQEAYPqMv2sbiRRD20tKTet7uo/bn7o98Tz+uOfqHiD3BvNdLV0sO2+Z53DS+rp6zjv7vIxl\nI9LFJ4otb1zu168kUlZqCUhNu/fZe3O+lu8O/f7P3k849OGsYMNobWql56oeXn8veSJ8dsldQHHp\nQ0a713SnnAei3UEaViqVpiAgNSsyFsk7zj/f5/v39rPi/BW
0NrXSsbSDJ/7iCfZv3c/I+EjG0tS5\nNC5sZPcNu+lY2kHH0o7EmkTJ3VADNw7Q2tRKQ30DrU2tygdIVVBiWGpSrtVDv7n2m1m7g/ItGhe/\nKLfc11LQrGOILlVxdt3ZhBeEuWX1LfRc3aON7aWsNDpIAsdr9dC6UB29bb2JABAXv+OfOPXh6uVj\nk2NZVwMFuPaH12ZdJyg+ZBSiM47Tg8VNl93ET1/+aUY+IT6SSMRvpQQBJYZl3lj7B2szAgCQuPvO\nt+9AcpI3WwAAcM4lxvgvuWdJxvvff+H7RdRepDKUE5CalL77VzgUZuLURNbdvgrdN7iQReMcjs2P\nbaZzR2fBeQMlgaVaKQhITYqvD9SxtIPWplYwGD02ytCRoZLX9m9e3Jx1Abq4yalJho4M8f7p9ws7\naKxhoa0ppdooJyA1zys/kDyTN54POPjrgxkTttIlJ463PbmtoLWGvGYSe9HWlDJXlBMQ8ZA+Qie8\nIBxtNWTRuLAx0WXTv7efxoWNWSeTJUsPAOFQmI/UfSQj8Tz+7rjnnAUFAakkBQGped1rutnz2p6U\nO+zuNd0ZeYB4/33jwsbE57wWeUsPHOFQONGCCIfCYGTNBbQ2tXL9J6/n0UOPcuDYgUSAqK+rp3lx\nc94ZyCLlpu4gmRe8VvD06iZKXjnUa5G3zY9tzrhQLwov4qwFZ3He2edx7kfOBbyHmEJml098OGl8\nnSLNHZC5oHkCIh7S7+oNyxj6mW0V0HzqrI7TLnMv44b6hpzzA2a7/LVIIZQTEPEQH0EUTww/f+z5\nrGULHUIa5xUAQoQ8F51LnqQWX7gOPhwpBKTkIuLPFSCkHBQEZF6LX3Q7d3R6TgCbODXh21DNlU0r\nPdcy8notveUx8v9GUnINyaudiswlzROQQBs9NsqGhzfQ1tyWd25AsnAoTF3ow3uo8IIw26/ZzvH3\nj2eU9XotI2k9M52SbE5f7VRkrigISCCkzzBONnV6ipHxEXbfsJuG+oaM9+ObzN902U2cEz6HulAd\nSxYuIaVhEXvcvLg54/Ner4lUCwUBCYTkGcZeF/p4mZ0bd2bsC7Bz406613Tzwxd/yMnpk5yeOc1b\nJ99KyQtMz0R3Cdt+zXbCCz7cNyDeQkiXKyjFP6dlJqQcNDpIAid9NzLIXH46PUG76oFVeWcPtza1\n0riwMZEIjk8+62rpyljJtHFhI23NbYyMj7DvrX0ZyeTWplb2b93vy+8r859GB4nMwsj4iOdrPUSD\nQPIInrh8i8qFQ2EO/ufBRL9+8hyAbMNP48nf/r39GfMZ4hPaROaaL91BZrbOzA6b2UtmdmuOcp8y\nsw/MbKMf5xUpF69+/fhWlB1LO1hx/oqsid1tT27zHH4aL5PeNZS+NaXIXCo5CJhZCPg20AWsAG40\ns4xdt2PlvgVo6USpqGIuuul9/YbxjbXfYP/W/QxuGcx65x4Zi3Dg7QM5j52cr4hvTamhoVIuJecE\nzGw1cKdz7trY89sA55y7O63c14Bp4FPAz5xzj2U5nnICMueKmbmb6zPZtpP06uqJC1mIxzc/rgu+\nlKyiy0aY2fVAl3Pu5tjz/wFc6Zz766QyFwE/dM6tNbPvAT9VEJD5ptD1iyA6u3jLZVs4evJoSnmR\nYtRCYvjvgORcQc7K9vb2Jh63t7fT3t4+J5USKUXyRT8+0gdSL+jda7oZGR9J5AsM4/Kmy7n+k9fT\n90xfouWgGcIyG8PDwwwPD/tyLL+6g3qdc+tizzO6g8zsSPwh0Aj8FrjZOTfgcTy1BKTq5VpwLn1k\n0PofrU9ZinrgxgHPbiJtRC/FKqUl4MfooOeAFjNrNrMwsAlIubg755bGfv4Q+DHwZa8AIFIrci04\nlzwyqH9vf8puZvFJZbloC0opp5K7g5xzZ8zsq8Ag0aDykHPukJltjb7tHkz/SKnnFKl12TbCSW9h\nqJtI5ppmDIsUITIWYf2
u9Z47jIUIsbLJeyOZ8IIwK35/RcqMYcidSFY3keRTC4lhkfkn6V6lLlRH\n8+Jmjhw/wgwzjB4bZf2u9QxsGkjZ0+Dgrw8mlp/QXb5UA7UERIrgdce+KLwoY8vJ5DWAvD7TUN+Q\nmI0cbx0kjxrSFpRSCLUERKqAV6I435pDk1OTKYvH7XltDz1X9XgONxWZCwoCIkXwSuxetOgiXjn+\nSkq5884+L+tnvMT3NlAOQMpF+wmIFMFrvZ/7P3s/dZZ6X/X6e68nhnkWsqeBSLkpJyDiI699B7xG\n92SbbKYcgBSj0pPFRCTGazXRfW/ty5j4ldwqaG1qTSxJrQAg5aaWgIiPCl1OQsRPagmIVIlc/f7J\ny0mIVAsFARGfdbV0MbhlkCsuvKLSVRHJS0FAZI60NbcV9JpIJSkIiMyRbBvai1QTBQERkQBTEBCZ\nI8VsaC9SbhoiKjKHitnQXmS2KrrRvN8UBEREZkfzBEREpCgKAiIiAaYgICISYAoCIiIBpiAgIhJg\nCgIiIgGmICAiEmC+BAEzW2dmh83sJTO71eP9zWZ2IPazx8wu9eO8IiJSmpIni5lZCHgJuAY4CjwH\nbHLOHU4qsxo45Jx718zWAb3OudVZjqfJYiIis1DpyWJXAi8758adcx8Au4Drkgs45551zr0be/os\ncLEP5xURkRLV+XCMi4HXk56/QTQwZPMl4AkfzisBEF97Z+LUBO/97j2Ov3+c5sXNbL9mu9bhEfGB\nH0GgYGa2Fvgi8Jlc5Xp7exOP29vbaW9vn9N6SXXKtl/v5NQk63etZ2DTgAKBBNLw8DDDw8O+HMuP\nnMBqon3862LPbwOcc+7utHKXAY8C65xzr+Q4nnICAkDnjk6Gjgxlfb9jaQeDWwbLWCOR6lTpnMBz\nQIuZNZtZGNgEDKRV8BKiAWBLrgAgIiLlVXIQcM6dAb4KDAIHgV3OuUNmttXMbo4V+xugAfgHMxs1\ns38r9bwyv0TGInTu6KRzRyeRsQiQuSlLsvCCsDZoEfGB9hOQikvv+6+vq2f3DbvpaulSYlikAKV0\nB5U1MSzipX9vf0ryd+r0FNue3JZ4D+D6T17PyPgIE6cmUl5XIBApjVoCUlGRsQibH9vM5NRkyuuG\ncdaCs5g+M531s8ktBpEgq3RiWKQo8W6g9AAA4HA5AwBEWwzxFoGIFEdBQComvRtIRMpPQUAqJt6/\nn42Ru3VbX1evEUIiJVIQkIqIjEUYmxzLWcaRPTfU2tSqfICIDzQ6SMqq7+k+7v753ZyYPlHU50MW\nYsulWzh68qhGCIn4QKODpCTxcfwQndyV64Lc93Qftz91e9HnigeAR375iOecApGgKmV0kIKAFMTr\nYl/IJK/k8kvuWeI5EgigLlTH6ZnTeevRUN+QcQytISRBp8liMqfSL/Z7XtvD7ht2e07yil/4vcrn\n0ry4mVeOpy4rtcAWcMad8fNXEZE0SgxLXrku9rMpf8vqW3KeJ3mdoHAojFnqjU04FOaW1beklNMI\nIZHSKAhI0brXdBMOhRPPw6Hci7r1XN3DN9d+k7pQZgP01eOv0nNVDx1LO+hY2sGK81dkdA+tOH8F\nPVf3sPuG3YlyygeIlEbdQZJXW3MbTx55khlmgLS77+Sb9djj7jXd7HltT0quIF6+5+oegIwE8Qwz\njIyPJPr2O3d0ZtSjcWEjEB0NpAu/iD8UBOaZ2YzWyff5tuY2Hj30KAfePpAIACEL0XNVD10tXXTu\n6ExZ2mH6zDT9e/sZ3DKYyBmk1yMyFqHvmb689cgVSETEPwoC80i2BG6hgSD98167es246B17Dz05\nj5Xtbj3bUhHpF/mulq6sgURE/KMgMI9kS8gWevGc7Vo+ft2tN9Q3sHPjzox6qttHZO4pCNQYP7t7\n4hfs+PN8a/kAhAglPlfM3bpX4PAKACJSHposVkNyTc7K9n7PVT2MjI8A0T7+vmf6Eu+HF
4TBwfRM\ntF8/HAqDkXMJ59amVvZv3V/y76FuHhH/aMZwQHTu6Mzop0+fLZue2E2+6IcsxIybyXmO1qbWxCic\ni865iB0v7kh8Rks0iFQnzRiWhOR+9M4dnSl9/PkCAESHYSYHlRsvvVF37SLzmIJADWlrbstoCbQ1\ntyUee/X355LeHeSV2FVyVmR+UxCoIfG+/fTXeujxHB7ac1VPyiSvuIb6Bq648IqMxLDu9EWCR0Gg\nwvxKknoNDx0ZH2Fl00pGj42mlL3iwitSunx04RcJLl/WDjKzdWZ22MxeMrNbs5S5z8xeNrPnzexy\nP85b6+J370NHhhg6MsSGhzcQGYtkLd+9pjvr4mlewzsnTk2w/ZrtWnBNRLIqOQiYWQj4NtAFrABu\nNLPlaWWuBZY55z4ObAW+U+p5a1lkLELnjk42P7Z5Vqtzxsflz2bxtGI+IyLB4Ud30JXAy865cQAz\n2wVcBxxOKnMd8AMA59wvzGyxmV3gnHvbh/PXlPS++3T73tpHZCyS9UKdLVEbH9bp9ZqSuyKSjR/d\nQRcDryc9fyP2Wq4yb3qUmRfid/mdOzo9u3byLc0wOTWZt1vIS66uIhGRbKoyMdzb25t43N7eTnt7\ne8XqMhvFLuCWvrXibNf8AS24JhIkw8PDDA8P+3KskmcMm9lqoNc5ty72/DbAOefuTirzHeAp59zD\nseeHgTav7qBanjFc6Ize9KUdljcuzxjBo31zRaRQpcwY9qM76DmgxcyazSwMbAIG0soMAJ+HRND4\nTRDzAeCdqNUIHhGpFF/WDjKzdcDfEw0qDznnvmVmW4m2CB6Mlfk2sA74LfBF55znKmS13BLIt8Bb\nvs+qK0dEiqEF5KqILuYiUm4KAiIiAVbpnICIiNQoBQERkQBTEBARCTAFARGRAFMQEBEJMAUBEZEA\nUxAQEQkwBQERkQBTEBARCTAFARGRAFMQEBEJMAUBEZEAUxAQEQkwBQERkQBTEBARCTAFARGRAFMQ\nEBEJMAUBEZEAUxAQEQkwBQERkQBTEBARCbCSgoCZnWdmg2b2KzOLmNlijzIfNbN/MbODZvaimf11\nKecUERH/lNoSuA34Z+fcHwH/AmzzKHMauMU5twJYA3zFzJaXeN6qNDw8XOkqlET1ryzVv7Jqvf7F\nKjUIXAd8P/b4+8Cfpxdwzh1zzj0fe3wSOARcXOJ5q1Kt/xOp/pWl+ldWrde/WKUGgfOdc29D9GIP\nnJ+rsJn9AXA58IsSzysiIj6oy1fAzIaAC5JfAhxwu0dxl+M45wA/Br4WaxGIiEiFmXNZr9v5P2x2\nCGh3zr1tZk3AU865T3qUqwN+BjzhnPv7PMcsvkIiIgHlnLNiPpe3JZDHAPAF4G7gJuAnWcr9I/DL\nfAEAiv9FRERk9kptCTQAjwAfA8aB/+6c+42ZXQh81zn3OTP7NPA08CLR7iIH/G/n3D+VXHsRESlJ\nSUFARERqW0VnDNfqZDMzW2dmh83sJTO7NUuZ+8zsZTN73swuL3cdc8lXfzPbbGYHYj97zOzSStQz\nm0L+/rFynzKzD8xsYznrl0+B/z/tZjZqZv9hZk+Vu47ZFPC/s8TMnoj9379oZl+oQDWzMrOHzOxt\nM3shR5lq/u7mrH9R313nXMV+iOYSvh57fCvwLY8yTcDlscfnAL8CllewziFgDGgGzgKeT68PcC3w\nf2OP/wR4tpJ/5yLqvxpYHHu8rtbqn1TuSaIDEjZWut6z/PsvBg4CF8eeN1a63rOo+53A9ni9gXeA\nukrXPanPRg0cAAAC70lEQVR+nyE6TP2FLO9X7Xe3wPrP+rtb6bWDanGy2ZXAy865cefcB8Auor9H\nsuuAHwA4534BLDazC6gOeevvnHvWOfdu7OmzVNfkvkL+/gB/RXRI8q/LWbkCFFL/zcCjzrk3AZxz\nE2WuYzaF1P0YsCj2eBHwjnPudBnrmJNzbg9wPEeRa
v7u5q1/Md/dSgeBWpxsdjHwetLzN8j8Q6eX\nedOjTKUUUv9kXwKemNMazU7e+pvZRcCfO+f+D9F5LdWkkL//J4AGM3vKzJ4zsy1lq11uhdT9u8AK\nMzsKHAC+Vqa6+aWav7uzVdB3t9QhonlpslntMrO1wBeJNkFryd8R7V6Mq7ZAkE8dsAr4U+D3gL1m\nttc5N1bZahVkG3DAObfWzJYBQ2Z2mb6z5TWb7+6cBwHnXEe292IJjgvch5PNPJvusclmPwZ2OOey\nzUUolzeBS5KefzT2WnqZj+UpUymF1B8zuwx4EFjnnMvVfC63Qur/X4FdZmZE+6WvNbMPnHMDZapj\nLoXU/w1gwjn3PvC+mT0NrCTaH19JhdT900AfgHPuFTN7FVgO/HtZali6av7uFmS2391KdwfFJ5uB\nT5PNyuA5oMXMms0sDGwi+nskGwA+D2Bmq4HfxLu9qkDe+pvZJcCjwBbn3CsVqGMueevvnFsa+/lD\nojcPX66SAACF/f/8BPiMmS0ws4VEE5SHylxPL4XU/RDwZwCxvvRPAEfKWsv8jOytw2r+7sZlrX9R\n390KZ7obgH8mOuJnEPgvsdcvBH4We/xp4AzRkQijwH6iEa6S9V4Xq/PLwG2x17YCNyeV+TbRO7cD\nwKpK1ne29Sfar/tO7G89Cvxbpes8279/Utl/pIpGB83i/+d/ER0h9ALwV5Wu8yz+dxqBn8b+718A\nbqx0ndPqvxM4CvwOeI1ol0ktfXdz1r+Y764mi4mIBFilu4NERKSCFARERAJMQUBEJMAUBEREAkxB\nQEQkwBQEREQCTEFARCTAFARERALs/wOqSgyLwl/38AAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# plot results\n", "X_ = xbatch\n", "mus = mu.eval(feed_dict={x:xbatch}) #[nb x M]\n", "sigmas = sigma.eval(feed_dict={x:xbatch}) #[nb x M]\n", "pis = pi.eval(feed_dict={x:xbatch}) #[nb x M]\n", "Y_ = sample(mus,sigmas,pis)\n", "plt.scatter(X,Y_,color='g')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Results are similar to those achieved by Karpathy" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.10" } }, "nbformat": 4, "nbformat_minor": 0 }