{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Reparameterization Trick" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we look at the reparameterization trick used by Kingma and Welling (2014) to train their variational autoencoder. \n", "\n", "Assume we have a normal distribution $q$ that is parameterized by $\\theta$, specifically $q_{\\theta}(x) = N(\\theta,1)$. We want to solve the problem below\n", "$$\n", "\\min_{\\theta} \\quad E_q[x^2]\n", "$$\n", "This is of course a rather silly problem and the optimal $\\theta$ is obvious: since $E_q[x^2] = \\theta^2 + 1$, the minimizer is $\\theta = 0$. We want to understand how the reparameterization trick helps in calculating the gradient of this objective $E_q[x^2]$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One way to calculate $\\nabla_{\\theta} E_q[x^2]$ is as follows\n", "$$\n", "\\nabla_{\\theta} E_q[x^2] = \\nabla_{\\theta} \\int q_{\\theta}(x) x^2 dx = \\int x^2 \\nabla_{\\theta} q_{\\theta}(x) \\frac{q_{\\theta}(x)}{q_{\\theta}(x)} dx = \\int q_{\\theta}(x) \\nabla_{\\theta} \\log q_{\\theta}(x) x^2 dx = E_q[x^2 \\nabla_{\\theta} \\log q_{\\theta}(x)]\n", "$$\n", "\n", "For our example where $q_{\\theta}(x) = N(\\theta,1)$, we have $\\nabla_{\\theta} \\log q_{\\theta}(x) = x-\\theta$, so this method gives\n", "$$\n", "\\nabla_{\\theta} E_q[x^2] = E_q[x^2 (x-\\theta)]\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The reparameterization trick is a way to rewrite the expectation so that the distribution with respect to which we take the expectation is independent of the parameter $\\theta$. To achieve this, we need to make the stochastic element in $q$ independent of $\\theta$. Hence, we write $x$ as\n", "$$\n", "x = \\theta + \\epsilon, \\quad \\epsilon \\sim N(0,1)\n", "$$\n", "Then, we can write \n", "$$\n", "E_q[x^2] = E_p[(\\theta+\\epsilon)^2]\n", "$$ \n", "where $p$ is the distribution of $\\epsilon$, i.e., $N(0,1)$. 
Because $p$ does not depend on $\\theta$, we can move the gradient inside the expectation and write the derivative of $E_q[x^2]$ as follows\n", "$$\n", "\\nabla_{\\theta} E_q[x^2] = \\nabla_{\\theta} E_p[(\\theta+\\epsilon)^2] = E_p[2(\\theta+\\epsilon)]\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let us compare the variances of the two methods; we are hoping to see that the first method has high variance while the reparameterization trick decreases the variance substantially." ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.86872102149\n", "4.03506045463\n" ] } ], "source": [ "import numpy as np\n", "N = 1000\n", "theta = 2.0\n", "eps = np.random.randn(N)\n", "x = theta + eps\n", "\n", "# First method: sample average of x^2 * (x - theta)\n", "grad1 = lambda x: np.sum(np.square(x)*(x-theta)) / x.size\n", "# Reparameterized method: sample average of 2 * (theta + eps)\n", "grad2 = lambda eps: np.sum(2*(theta + eps)) / eps.size\n", "\n", "# Both estimates should be close to the true gradient 2*theta = 4\n", "print(grad1(x))\n", "print(grad2(eps))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us compute and plot the variance of both estimators for different sample sizes." 
] }, { "cell_type": "code", "execution_count": 66, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 4.10377908 4.07894165 3.97133622 4.00847457 3.99620013]\n", "[ 3.95374031 4.0025519 3.99285189 4.00065614 4.00154934]\n", "\n", "[ 8.63411090e+00 8.90650401e-01 8.94014392e-02 8.95798809e-03\n", " 1.09726802e-03]\n", "[ 3.70336929e-01 4.60841910e-02 3.59508788e-03 3.94404543e-04\n", " 3.97245142e-05]\n" ] } ], "source": [ "Ns = [10, 100, 1000, 10000, 100000]\n", "reps = 100\n", "\n", "means1 = np.zeros(len(Ns))\n", "vars1 = np.zeros(len(Ns))\n", "means2 = np.zeros(len(Ns))\n", "vars2 = np.zeros(len(Ns))\n", "\n", "est1 = np.zeros(reps)\n", "est2 = np.zeros(reps)\n", "for i, N in enumerate(Ns):\n", "    # Repeat each estimate reps times to measure its spread\n", "    for r in range(reps):\n", "        x = np.random.randn(N) + theta\n", "        est1[r] = grad1(x)\n", "        eps = np.random.randn(N)\n", "        est2[r] = grad2(eps)\n", "    means1[i] = np.mean(est1)\n", "    means2[i] = np.mean(est2)\n", "    vars1[i] = np.var(est1)\n", "    vars2[i] = np.var(est2)\n", "\n", "print(means1)\n", "print(means2)\n", "print('')\n", "print(vars1)\n", "print(vars2)" ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python2.7/dist-packages/matplotlib/__init__.py:872: UserWarning: axes.color_cycle is deprecated and replaced with axes.prop_cycle; please use the latter.\n", " warnings.warn(self.msg_depr % (key, alt_key))\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXIAAAECCAYAAADjBlzIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3X28HFWd5/HPCQ8BQQwoGFEBZTcLzvIkFIjAku0ruiOs\nM4KeNeygo/CCCSBEd5p1zTqKI44zpU7CU5SXOqgD0SO44sOMD3QNw/NSCQIBBHQQYYEYnomBBEhq\n/+i6pLncvl19u6rrnNvf9+vVr6T61q365iT5dd1fVZ0yWZYhIiLhmlV3ABERGYwKuYhI4FTIRUQC\np0IuIhI4FXIRkcCpkIuIBG7LIitZa88B3gNsAP7aOXdFpalERKQw0+s6cmvtgcCFwNuAnYBfAns7\n59ZVH09ERHop0lrZE7jFOZc55x4DHgSiamOJiEhRRVordwKftNZuA+wC7A3MrTSViIgU1vOI3Dl3\nO3AxcANwAZAA66uNJSIiRfXskU9krb0BOM05d/PEr7VaLU3cIiIyDWNjY2a631v0qpWdnHOPW2uP\nBOZMVsTLCCMiMooGPQgueh35P1hr7wCWAH82yA59YIyZX3eGIpSzXCHkDCEjKKdvCh2RO+f+pOog\nIiIyPSN5Z2eWZVfVnaEI5SxXCDlDyAjK6ZuRLOQiIjNJodbKTGOMmR/CJ7VyliuEnCFkbLVas1at\nWnXePvvs8yjg9ZVq99xzzx7z5s27r+4cgAGeApaMjY1tKnvjI1nIRWQgi2666aabFy1a9PW6g/Ty\njne8w5sPxlartR+wCPhy2dseydaKL3+xvShnuULIGUJG4FWXXnqp90Uc/BrPsbGxW4FXVbHt0gt5\nFCe7l71NEfGK1+0Uz1UydlUckZ9WwTZLFcq1pcpZrhByhpARlNM3VRTyj0Rxsl0F2xURkUlUUciv\nBj5YwXZL41PfbCrKWa4QcoaQEZRzImPMp40x3xrGviZTRSFfApwZxclInkgVERm2KortNcCzwDsr\n2HYpQumbKWe5QsgZQkbwN6cx5rfGmM8YY35tjHnUGPPFjq/tYIy5xBjzSP71E/vY7u7GmE3GmPcZ\nY+4zxjxljGkaYw43xqwF/hfw34wxa40xTxtjXlPFn6+b0gt52mxktI/KF5W9bRGRAg4G9gFOAv7C\nGDNe584BtgfeALwXiI0x+/W57fcDBwCvBX6eZdm1WZa9Evgb4LtZlr0yy7Idsix7tIw/SFFV3RD0\nHeBvozjZO202flXRPqZN/b1yKWd5QsgIU+eM4qSUS+zSZmO6U2J/I8uy9caYnwCvoP1Es4eA/wr8\neZZlG4DbjTE/Av4EuLWPbZ+dZdkT+e/7+b5KVVLI02ZjQxQnXwHOABZWsQ8R8dMABbgsjwNkWfa8\nMQZgm/z91wKrO9ZbTX+PrcyA35QRsGxVnpD8CvCBKE52qnAf0+Jrf28i5SxXCDlDyAjh5JxgDS8t\n3HOB3/e5jRe6vF/6/Cn9qKyQp83GauCHtPtUIiJ1+yHwMWPMtsaYfWi3Wn7Ux/dP9ZPGamAvY8wW\ngwScrkKF3FrbtNaustbebq39qz62vxQ4PYoTrybnmgl9SJ8oZ3lCyAhe55zYn+9c/t/AOuAB4Arg\nE1mWdX1sZYFtd3LA08CDxpj7jTGv7mO7A+tZyK21rwdOBvajfbb2Q9baQvOppM3GzcB9tM8Qi4hU\nKsuyN2dZlnQsb5Fl2b3575/Ksuz4LMtek693UR/b/V2+rUlbKPm235Fl2dwsy3bLsuyxwf80xRVt\nrWwBbEv7pMEG2vPqFuXdpYih9PeUs1wh5AwhIyinb3oWcufcg8C5wP3560vOuSf72McVwOujOImm\nF1FERKZSpLUyB/hjYHfg3wFNa+1ru63f+QlojJm/4qyxI4DzgDONMfMnfr2O5fH+ni95ui2Pv+dL\nHo1n9csTs9adp9tylmVX+ZSn2zIdfMhzzz337NHt64MwWTb1t
fvW2uOAdznnTs6XlwPfdM79dOK6\nrVYrGxsbe9mZ3ShO5gD3Av8xbTYeKiO4iNSj1Wp9Zmxs7DN15whRt7HrVjuLKtIjXw1E1tqtrLXb\nAm8FftvPTtJm40ngUjy5OaisT8GqKWe5QsgZQkZQTt8U6ZFfB/wMuA1IgYucc3dPY1/nAidHcbJN\nzzVFRKSwQtd3O+c+AXxikB2lzcY9UZysAI4HvjHItgbl8TWwL6Gc5QohZwgZQTl9M+w5w5cAi6I4\nqXsuBhGRGWPYhfxK2tek/+ch7/clQumbKWe5QsgZQkZQTt8MtZDnc5UvBc4c5n5FRIwxRxpjHqg7\nRxXqeBzbPwJvj+Jkzxr2DYTTN1POcoWQM4SMEGxOw9TzpQRr6IU8bTaeAb4GfHTY+xaRmc+0H/d2\nmjFmhWk/eu0K034c20+AXc3mx7HNmDmg6npA8gXACVGc7FDHzkPpmylnuULIGUJGCCLnycCfAe8B\nzskfx/Zu4KGOx7H9n1oTlqiW6WXTZuP/RXHyC+DDtHvmIjJDLF64vJT2xTnLFgxyddtXsyy7yxgz\nN8uym8rI47M65wlfAlwSxcn5abOxcZg7DrS/5y3lLE8IGWHqnAMW4LL8BsIZz0HV1VohbTZuBB4F\njqkrg4jMWJM9kq3Wx7FVqbZCnltCDZciBtDfA5SzbCHkDCEjBJtzNbCzMWaXmuJUpu5CfhkwL4qT\nfWvOISIzx6Q9+izLfk37irlbTftxbDOmG1BrIU+bjeeBCxnyUXkofTPlLFcIOUPICH7n7Hzc28Sc\nWZZ9NMuy1+WPY/txLQErUPcROcBFwLFRnOxcdxARkRDVXsjTZuNR2i2WU4a1z0D7e95SzvKEkBGU\n0ze1F/LcUuDUKE62rjuIiEhovCjkabNxO3AnYIexP5/7e52Us1wh5AwhIyinb3reEGStfSfwt7TP\nBBvgLUDknLut5CxLgE9HcXJJPkuiiPjJhxt+QlXJ2BV51NvPnXMHOOfeCvwxcF8FRRzgn4A5wNsr\n2PZLhNI3U85yhZAzhIzAU8cff/yJdYcowqfxbLVa+wFPVbHtfm/R/wDtE5OlS5uNTVGcnEv7UsTr\nqtiHiJRiycEHH3xeq9XaDc+nhb3wwgv3aLVa8+vOQftI/CnanYfS9VvI/zvwkSqC5C4GPhPFye5p\ns/G7qnYSSt9MOcsVQs4QMo6NjW0aGxs7re4cRYyNjdUdYSgKn+y01s4DtnXOrZpqvc4fZYwx8/tZ\nXnHW2IHPrr6vBZw2ne/Xspa1rOVQlwdhsqzYT0bW2s8AG51zf91tnVarlY2NjQ3UzI/i5E1ACuye\nNhvrBtlWN8aY+SEc+ShnuULIGUJGUM6yDVo7+7n88HjgO9PdUVFps/Fb4Brgg1XvS0RkJihUyK21\nBwNrnXO/rjjPuCXAGVGcVHKdewif0KCcZQshZwgZQTl9U+hkp3PuJuDAirN0uhpYD7wT+OkQ9ysi\nEhwv7uycKL8haCkVzYpY1gmGqilnuULIGUJGUE7feFnIc98BDojiZO+6g4iI+MzbQp42G+uBrwJn\nlL3tUPpmylmuEHKGkBGU0zfeFvLcMuADUZzsVHcQERFfeV3I02ZjNfBD4KQytxtK30w5yxVCzhAy\ngnL6xutCnlsKnB7FSb/TCYiIjATvC3nabNwM3Ae8t6xthtI3U85yhZAzhIygnL7xvpDnKrsUUUQk\ndKEU8iuAN0RxEpWxsVD6ZspZrhByhpARlNM3QRTytNl4ATgfHZWLiLxMEIU893Xg3VGc7DrohkLp\nmylnuULIGUJGUE7fBFPI02bjCeBSYGHdWUREfBJMIc+dB5wcxck2g2wklL6ZcpYrhJwhZATl9E1Q\nhTxtNu4GVtCeG11ERAiskOeWAouiOJn20zRC6ZspZ7lCyBlCRlBO34RYyH9Bex71+TXnEBHxQtEn\nBB1irb3VWnuHtbbyx71Np
WOu8kXT3UYofTPlLFcIOUPICMrpm56F3FprgG8Bf+Gc+yPyJ9zX7NvA\n26M42bPuICIidStyRH4gsMY5dwOAc+6xaiP1ljYbzwBfAz46ne8PpW+mnOUKIWcIGUE5fVOkkO8G\nPG2t/Wdr7UprrS/XcV8InBDFyQ51BxERqVORQr4N8Hbac4LPBxZZa/fotnJnT8oYM7+q5bTZeGDD\nk2tuXXvvref0+/3j71WZr6TlRZ7l0XhWv7xowO8fyvLEv/u680yxHMR4DspkWTblCtbaMeCzzrnD\n8uVLgW855172dPtWq5WNjY1N+7LAfkVxcijwj8C8tNnYWPT7jDHzQ/iRSznLFULOEDKCcpZt0NpZ\n5GENKbCbtXYO8AywD3DvdHdYshuBR4GjaT9JqJAQ/mJBOcsWQs4QMoJy+qZna8U59zTtS/3+BVgJ\nXOKcu6fqYEWUcSmiiEjoCj0+zTl3OXB5xVmm6zIgjuJk37TZuK3IN4Ty45ZyliuEnCFkBOX0TYh3\ndr5E2mw8B1yA5ioXkREVfCHPXQQcG8XJzkVWDuUTWjnLFULOEDKCcvpmRhTytNl4lHaL5ZS6s4iI\nDNuMKOS5c4FTozjZuteKZV27WTXlLFcIOUPICMrpmxlTyNNmYxVwJ/D+urOIiAzTjCnkuSUUmKs8\nlL6ZcpYrhJwhZATl9M1MK+T/BMwBDq07iIjIsMyoQp42G5to98qnvEEolL6ZcpYrhJwhZATl9M2M\nKuS5i4GxKE52qzuIiMgwzLhCnjYba4FvMsUDMELpmylnuULIGUJGUE7fzLhCnjsfODGKk+3qDiIi\nUrUZWcjTZuNe4BrghMm+HkrfTDnLFULOEDKCcvpmRhby3BLgzChOZvKfUURkRhfyq4ENwFETvxBK\n30w5yxVCzhAygnL6ZsYW8nyu8iVornIRmeEKFXJr7UZr7c35a0nVoUr0HeCAKE726nwzlL6ZcpYr\nhJwhZATl9E2hB0sA65xzb600SQXSZmN9FCdfBc4ATq07j4hIFYq2Vob2QOUKLAMWRHGy4/gbofTN\nlLNcIeQMISMop2+KFvLZ1toV1tprrLVHVJqoZGmzsRr4EXBS3VlERKpQtJC/wTl3EPAx4FJr7ewK\nM1VhKfDRKE62hHD6ZspZrhByhpARlNM3hQq5c25N/usK4CFgj27rdg6cMWa+D8tps7ES+N3jt161\n2Ic8RZeB/X3KE/oyAYwnsL9PeUJfJqDxHITJsmzKFay1OwLPOufWW2v3oH3H5Dzn3LMT1221WtnY\n2JiX/fQoTo4DPpY2G4fXnUVEpNOgtbPIEflewC3W2luAy4ETJyviAbgCeGMUJwfVHUREpEw9Lz90\nzt1Au5gHLW02Xoji5DzgTGPM10M4m22Mma+c5QkhZwgZQTl9M2Pv7Ozi68DRr9r70J3qDiIiUpaR\nKuRps/EEsPzff/hz+/dc2QOhHEkoZ3lCyAjK6ZuRKuS5c4FTojjZpu4gIiJlGLlCnjYbdz+/9vHf\nAsfXnaWXsi5NqppylieEjKCcvhm5Qg7w5B3XX0Z7rnIvL5UUEenHSBbynd92zJeArYD5NUeZUij9\nPeUsTwgZQTl9M5KFPJ+rfClwZt1ZREQGNZKFPO+bfRs4LIqTPWuO01Uo/T3lLE8IGUE5fTOShRwg\nbTaeoX1d+el1ZxERGcRIFvKOvtkFwAejONmhxjhdhdLfU87yhJARlNM3I1nIx6XNxgPAlcCH684i\nIjJdI1nIJ/TNltCeq3yLmuJ0FUp/TznLE0JGUE7fjGQhn+BG4HHg6LqDiIhMx0gW8s6+WX4p4hI8\nvBQxlP6ecpYnhIygnL4ZyUI+icuAvaI42bfuICIi/RrJQj6xb5Y2G88BF+LZUXko/T3lLE8IGUE5\nfVO4kFtrt7fWPmit/XiVgWp0EXBsFCc71x1ERKQf/RyRLwZWVBVkmCbrm6XNxiO0H2V3ytA
DdRFK\nf085yxNCRlBO3xQq5NbaecDOwMpq49RuKbAwipOt6w4iIlJU0SPyLwBnAzNi2tdufbO02VgF3AW8\nf6iBugilv6ec5QkhIyinb3oWcmvtMcDdzrkHhpDHB0uARZqrXERCUeSI/BDgOGvtnbQnmDrLWrug\n28qdn4DGmPk+Lo/3zSb7+spP/pc/ADsCh9add/y9user1/JU4+nT8vh7vuSZbHli1rrzdFvOsuwq\nn/J0W6aDD3l6/fucLpNlWeGVrbWfBtY657482ddbrVY2NjYW/JFsFCdnAIenzYatO4uIzHyD1k5d\nRz65fwDGojjZbQhxuirr07pqylmeEDKCcvpmy35Wds6dXVUQn6TNxtooTr4FnAb8z7rziIhMZSSP\nyAteW3oe8JEoTrarOE5XoVwDq5zlCSEjKKdvRrKQF5E2G/cC1wIn1J1FRGQqI1nI++ibLQXOjOKk\nlnEKpb+nnOUJISMop29GspD34V+BDcBRdQcREelmJAt50b5ZPlf5UmBRpYG6CKW/p5zlCSEjKKdv\nRrKQ92k5cEAUJ3vVHUREZDIjWcj76ZulzcZ62lPcnlFZoC5C6e8pZ3lCyAjK6ZuRLOTTsAz4QBQn\nO9YdRERkopEs5P32zdJm42Hgx8BJlQTqIpT+nnKWJ4SMoJy+GclCPk1LgdOjOOnrblgRkaqNZCGf\nTt8sbTZWAvcDf1p6oC5C6e8pZ3lCyAjK6ZuRLOQDqO1SRBGRbkaykA/QN/sB8MYoTg4qMU5XofT3\nlLM8IWQE5fTNSBby6UqbjReA84Ez684iIjJuJAv5gH2zrwFHR3HyupLidBVKf085yxNCRlBO34xk\nIR9E2mw8Qftuz4V1ZxERgQIPlrDW7gT8LF83A852zl1RdbAqldA3Oxe4OoqTz+d3flYilP6ecpYn\nhIygnL4pckT+FHCkc+4A2rMALqs2kv/SZuNuYCXQ9SHUIiLD0rOQO+c2OueeyRfnAFtba4O+Kaak\nvtlSYFEUJ5U9bDqU/p5ylieEjKCcvilUkK212wPXA28CTnTOvVBpqjD8HPh74EjgqnqjiMgoK3Sy\n0zn3B+fcvsBBwOnW2i26rdv5CWiMme/j8njfbJDtpc1G9vSvV/70uace/VxVecffq3u8ei2XMZ7D\nWB5/z5c8ky1PzFp3nm7LWZZd5VOebst08CFPr3+f02WyLOvrG6y1LaDpnLt54tdarVY2NjZWWavB\nN/mDmX8HHJI2G/9Wdx4RCdOgtbPnEbm1dtf8yhWstXOBvYEHp7tDH5T1KZg2G+toX1d+ehnbm6is\nnFVTzvKEkBGU0zdFWiu7AVdZa28FrgTOcs79vtpYQbkA+FAUJzvUHURERlPfrZWpjFprZVwUJ98F\nrkubjXPrziIi4am8tSKFLAXOiOKk60lgEZGqjGQhr6BvdgPwOPDuMjcaSn9POcsTQkZQTt+MZCEv\nW9psZMASNFe5iNRgJAt5RfMvXAbsFcXJvmVtMJR5IpSzPCFkBOX0zUgW8iqkzcZzwIXAGXVnEZHR\nMpKFvMK+2UXAcVGc7FzGxkLp7ylneULICMrpm5Es5FVJm41HgMuBk+vOIiKjYyQLecV9s6XAqVGc\nbD3ohkLp7ylneULICMrpm5Es5FVKm41VwF3A++rOIiKjYSQL+RD6ZkuAjw06V3ko/T3lLE8IGUE5\nfTOShXwIfgLsCBxadxARmflGspBX3TdLm41NtJ/reeYg2wmlv6ec5QkhIyinb0aykA/JxcBRUZy8\nse4gIjKzjWQhH0bfLG02nga+CZw23W2E0t9TzvKEkBGU0zcjWciH6DzgpPxJQiIilRjJQj6svlna\nbNwLXAucMJ3vD6W/p5zlCSEjKKdvtuy1grV2V+C7wBxgA/AJ59yVVQebQZYAy6I4uSg/CSoiUqoi\nR+QvAKc65/YBjqV9Ei9oQ+6b/SvwHHBUv98YSn9POcs
TQkZQTt/0LOTOuTXOuVX57+8HtrLWblV5\nshmiY67ygS5FFBHppq8eubX2XcDNzrnnK8ozFDX0zZYDB0Zxslc/3xRKf085yxNCRlBO3xQu5Nba\nucAXgVOnWq/zRxljzHwtm/lps7Ee+Or6NQ/8jQ95tKxlLfu3PAiTZVnPlay1s4Ergc86537Rbb1B\nnwQ9LMaY+cP+pI7i5HXAHcCeabPxRJHvqSPndChneULICMpZtkFrZ9Ej8ouBS6Yq4jK1tNl4mPYc\nLCfVnUVEZpaeR+TW2sOAhPbRpAEy4N3OudUT1w3liLwuUZwcCHyf9lH5C3XnERE/DFo7e15H7py7\nDpg93R3IZmmzsTKKkweAP6X9sGYRkYGN5J2dZZ1gmKbClyLWnLMw5SxPCBlBOX0zkoW8Zj8Adovi\n5KC6g4jIzDCShbzOs9h5b/x8ChyVh3C2HZSzTCFkBOX0zUgWcg98DTg6vyRRRGQgI1nI6+6b5deR\nfwdYONV6decsSjnLE0JGUE7fjGQh98S5wClRnGxTdxARCdtIFnIf+mZps3EXcDOwoNs6PuQsQjnL\nE0JGUE7fjGQh98gS4MwoTnQTlYhM20gWco/6Zj+nfbPVkZN90aOcU1LO8oSQEZTTNyNZyH2Rz1W+\nFFhUdxYRCddIFnLP+mbfBg6L4uTNE7/gWc6ulLM8IWQE5fTNSBZyn6TNxjrgG8BH684iImEayULu\nYd/sAuCDUZzs0PmmhzknpZzlCSEjKKdvRrKQ+yZtNu4HWsCf1xxFRAI0koXc077ZEuCMKE5e/Dvx\nNOfLKGd5QsgIyumbkSzknroBeAI4uu4gIhKWnoXcWhtba1dba28bRqBh8LFvll+KuISOSxF9zDkZ\n5SxPCBlBOX1T5Ij8cuDdVQcRAL4H7BXFyT51BxGRcPQs5M65G4HHhpBlaHztm6XNxnPAMvK5yn3N\nOZFylieEjKCcvlGP3D9fBY6L4mTnuoOISBh6Pny5X8aY+eOfguP9Kd+Wx9/zJc/E5YP+rvV94GRj\nzDrglrrzhD6eHcuLAhjP/bMsW+JRnkmXJ/7d150n9PFkQCbLsp4rWWt3B37knNt3qvVarVY2Njbm\n/Ux+nR82PoriZF/gn2/57HEfen7tE1fWnacX38dzXAg5Q8gIylm2QWtn0SNyk79mBN//YtNm47Yo\nTu7a/68u/3gUJ3sD1wCr0mZjY93ZJuP7eI4LIWcIGUE5fdOzkFtrzweOBV5trb0fONU59+PKk4ml\nfU35EcBpwNwoTm6gXdSvAdK02VhfYz4R8USh1kpRaq2UqzNnFCe7AIfnryOAtwC/ZHNhvz5tNp6s\nO6fPQsgZQkZQzrINq7UiNUubjTXA9/MXUZy8Engb7cL+l8DBUZz8G5sL+zVps/FQTXFFZIhG8oh8\nJoriZGvgrWw+Yj8ceIqOwg7ck99BKiIe0RG5AC/eTHRj/vpiPvnW3rSL+nzgU8C2UZxcy+bCfkva\nbLxQT2IRKctIFvJQ+maD5EybjU3AHfnrKwBRnOxGu7AfAZwI7BbFyY3AeHH/v2mz8cwwcw5TCDlD\nyAjK6ZuRLOSjKp/3/JL8RRQnrwYOo13YPw/sG8XJbbSL+rXAtWmz8XhNcUWkIPXI5UVRnLwCOITN\nR+2HAPezuRVzbf5hICIlUo9cSpO3Vf4lfxHFyZbA/rSL+vuAJVGcPEtHYQd+lbdxRKQmI1nIQ+mb\n1Z0zPxG6In/9fRQnBpjH5iP2s4BX7fep79219Q6v/gHtwr4ybTaeryvzVOoezyJCyAjK6ZuRLOQy\nPfmli3fnr68BRHHy+nX33XHy1vv+pz2AE4A3R3GSsvmo/ca02fhDTZFFRoJ65FKqKE7mAG9n81H7\nAcCdvLTP/kh9CUX8M2jtLL2QJ5etuQ54GFid/9r5+9XAI+csW+Dl5E9SvihOtgEiNhf2Q2n/O+i8\nUek+3agko8zHQn4
kMBd4Xcevnb/fEXiElxb3SQv/OcsWPFtauA6h9M1mYs4oTrYA9mFzYT8C2Mjm\na9mvAW6v4gRqCOMZQkZQzrJ5d9XKOcsWXD3V1xcvXL4VsAsvLe5zaU8CNdb53uKFyzcw+VH9xPce\nP2fZAh3RBSCfiveW/HVefgJ1TzZPLXAmsHMUJ9exubivSJuNDTVFFvGetz3yxQuXG2AO3Y/sO997\nBfB7Ji/4nYX/9+csW/BcGfmkOlGczGXzjUpHAP8BWMnmSx6vT5uNp+tLKFIu71ordZzsXLxw+Ta0\nC/tULZ25tH8SeJru/fvO957WUb4fojjZgXZvffyo/SDgHjYX9mvSZmN1fQlFBjOUQm6ttcDngE3A\nX3Z7sITvV60sXrh8FvCa5MblxzTetuBBpj7S34JiBX9NVSdvQ+nvDTtnFCezgQPZXNgPAx6jo7AD\nv5l4AjWE8QwhIyhn2SrvkVtrtwK+ABwMbEv7rr8gnxB0zrIFm4A1xhx/7w2//OFVU627eOHy7Zn8\nCP/wCe/ttHjh8seY+kqdh2mfvO17Qip5ubxffn3++rt8psc/ol3U3wGcDWzZMdPjtcCtNcUVqVzP\nI3Jr7eHAWc659+TLCbDIOXfbxHV9PyKvQn7ydmeK9fKfo0DBp33yVre9T1N+AnV3Ns/LfgSwG+35\n2Z8r8NpQ8nq91t2oyy9HW+WtFWvt+4CjaJ9segJ4L/BN59zPyg4zk+Unb19FsYK/Pe2Tt3+gfWne\nRuCFjt9Ptlxknel8jxf7HvSDLe+zbw9s3fGaPWG526vIeoNsaxbVfYiU9WGjD5wKDe3yQ+fcRQDW\n2mOnuzNf1NE3y0+cPpm/fjXVuosXLp8NzP3Fdd+ef9RhJ9xMu1+/Zf5r52vie72We60zm/YVQH1t\nZ+26J3Z55XY7PlVlvsULl8MAHxDvhI3rN6zbfpvZ263tGOqJRWmyItVrnamWM2B9/ur6PVl72WQw\na/2GdXNmb7Pd08CsDGOAWVn7v/fWWfvvZ1ZmjBlfP/96vmxmvfi+efHr+bJ5cX3A5N8za8L3zMoY\nf59ZYEz+3iza75tTPn75JmATWQbGZON/li4VqPd4dl3KJnlvfC+THn1mE9cC2JRt2mKWmTXZw1O6\nfSAV+TfR7f1pj8Vxx+zYZTfFFCnkDwO7dizPzd+bVKvV8v4T+8orr6TVatUdo6vG+3bJf/0fNScp\napdh7Wi8uEt9xj8MxCNFWitbAXfRnpt6W6DlnJs3hGwiIlJAz09W59zzwCdoXyFwJbCo6lAiIlJc\nqTcEiYi+v4ZiAAADgklEQVTI8KnXJSISOBVyEZHA9T37YdHb9ftdt2x95tzI5jv/rnbODeU8gLU2\npv1UnTXOuX17rFvnWPaTs66x3BX4Lu2J1jYAn3DOXTnF+rWM5zRy1jWeOwE/o10jMuBs59wVU6xf\n13j2m7OW8cz3vT3tp2t9yTn35SnW63ss+yrk/dyuX+et/dPY9zrn3FuHkW2Cy4HlwMVTreTBNAmF\ncubqGssXgFOdc6ustbvRPjn/hslWrHk8C+fM1TWeTwFHOueesda+GlgFTFogax7PwjlzdY0nwGLa\nz7/tarpj2W9r5RDgdufco865B4D7rbXdjtD6Wbds/e67lrtRnXM30p7sqZc6x7KfnFDfWK5xzq3K\nf38/sFX+n2IytY1nnzmhvvHc6JwbnxtoDrC1tbbbgV+d49lPTqhpPK2182hP5bGyx6rTGst+Wytz\ngYettSfTvl1/Ne1byl8270qf65at333PttauAJ4FPumcu2YIGftR51j2q/axtNa+C7g5v3R2Ml6M\nZ4GcUON45q2A64E3ASc65ya7QxJqHs8+ckJ94/kF2g9N+UiP9aY1ltM62emcu8g59z0KfLr1s27Z\n+tj3G5xzBwEfAy611s6uPl3/6hzLPtQ6ltbaucAXgVN7rVvnePaRs7bxdM79IT8nc
hBwurV2yrtq\n6xrPPnMOfTyttccAd+dH2IX0O5b9FvJ+btfv69b+kvW1b+fcmvzXFcBDwB5VhpuGOseyL3WOZf6f\n8nvAx51zv51i1VrHs4+cXvzbdM7dDTwP7NdlFS/+fRbIWdd4HgIcZ629EzgdOMtau6DLutMay35b\nKzcBb7HWvoZ2I/7149PZWms/D+Cc+2SvdYegcE5r7Y7As8659dbaPWgP4v1Dygn5ZESdb3g2luN6\n5vRgLC8GLnHO/WKqnNQ/noVy1jme+dU1651zj+c/PewNPDhZTmocz35y1jWezrlPAZ/KM3waWOuc\nWz5ZRqY5ln0dkfe4XX98KtYi61aqn5zAXsAt1tpbaF+dcaJz7tlh5LTWnp9nnGetvT//EexlGesc\ny35yUu9YHgYcC5xsrf2ltfbm/D/2y3LWOZ795KTG8aQ9f/tV1tpbaY/RWc6530+Ws+Z/n4VzUu94\ndlPKWOoWfRGRwOnOThGRwKmQi4gEToVcRCRwKuQiIoFTIRcRCZwKuYhI4FTIRUQCp0IuIhK4/w8/\nUASvCYfeawAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "plt.plot(vars1)\n", "plt.plot(vars2)\n", "plt.legend(['no rt', 'rt'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Variance of the estimates using reparameterization trick is one order of magnitude smaller than the estimates from the first method!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.12" } }, "nbformat": 4, "nbformat_minor": 0 }