{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Optimizations with Physical Gradients\n", "\n", "The previous section has made many comments about the advantages and disadvantages of different optimization methods. Below we'll show with a practical example how much of a difference these properties actually make.\n", "\n", "\n", "## Problem formulation\n", "\n", "We'll consider a very simple setup to clearly illustrate what's happening: we have a two-dimensional input space $\\mathbf{x}$, a mock \"physical model\" likewise with two dimensions $\\mathbf{z}$, and a scalar loss $L$, i.e. \n", "$\\mathbf{x} \\in \\mathbb{R}^2$, \n", "$\\mathbf{z}: \\mathbb{R}^2 \\rightarrow \\mathbb{R}^2 $, and \n", "$L: \\mathbb{R}^2 \\rightarrow \\mathbb{R} $.\n", "The components of a vector like $\\mathbf{x}$ are denoted with $x_i$, and to be consistent with Python arrays the indices start at 0.\n", "\n", "Specifically, we'll use the following $\\mathbf{z}$ and $L$:\n", "\n", "$\\quad \\mathbf{z}(\\mathbf{x}) = \\mathbf{z}(x_0,x_1) = \\begin{bmatrix} x_0 \\\\ x_1^2 \\end{bmatrix}$, \n", "i.e. $\\mathbf{z}$ only squares the second component of its input, and\n", "\n", "$\\quad L(\\mathbf{z}) = |\\mathbf{z}|^2 = z_0^2 + z_1^2 \\ $ \n", "represents a simple squared $L^2$ loss.\n", "\n", "As the starting point for the example optimizations we'll use \n", "$\\mathbf{x} = \\begin{bmatrix} \n", " 3 \\\\ 3\n", "\\end{bmatrix}$ as the initial guess for solving the following simple minimization problem:\n", "\n", "$\\quad \\text{arg min}_{\\mathbf{x}} \\ L(\\mathbf{x}).$\n", "\n", "For us as humans it's quite obvious that $[0 \\ 0]^T$ is the right answer, but let's see how quickly the different optimization algorithms discussed in the previous section can find that solution. 
And while $\\mathbf{z}$ is a very simple function, it is nonlinear due to the $x_1^2$ term.\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 3 Spaces\n", "\n", "In order to understand the following examples, it's important to keep in mind that we're dealing with mappings between the three _spaces_ we've introduced here:\n", "$\\mathbf{x}$, $\\mathbf{z}$ and $L$. A regular forward pass maps an\n", "$\\mathbf{x}$ to $L$, while for the optimization we'll need to associate values\n", "and changes in $L$ with positions in $\\mathbf{x}$. While doing this, it will \n", "be interesting to see how this influences the positions in $\\mathbf{z}$ that develop while searching for\n", "the right position in $\\mathbf{x}$.\n", "\n", "```{figure} resources/placeholder.png\n", "---\n", "height: 220px\n", "name: pg-three-spaces\n", "---\n", "TODO, visual overview of 3 spaces\n", "```\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Implementation\n", "\n", "For this example we'll use the [JAX framework](https://github.com/google/jax), which is a nice alternative to PyTorch and TensorFlow for efficiently working with differentiable functions.\n", "JAX also provides a NumPy wrapper that implements most of NumPy's functions. Below we'll use this wrapper as `np`, and the _original_ NumPy as `onp`.\n", "\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as np\n", "import numpy as onp\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll start by defining the $\\mathbf{z}$ and $L$ functions, together with a single composite function `fun` that calls `fun_L` and `fun_z`. Having a single native Python function is necessary for many of the JAX operations." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting point x = [3. 
3.]\n", "\n", "Some test calls of the functions we defined so far, from top to bottom, z, Lz (manual), Lz:\n" ] }, { "data": { "text/plain": [ "(DeviceArray([3., 9.], dtype=float32),\n", " DeviceArray(90., dtype=float32),\n", " DeviceArray(90., dtype=float32))" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# \"physics\" function z\n", "def fun_z(x):\n", " return np.array( [x[0], x[1]*x[1]] )\n", "\n", "# simple L2 loss\n", "def fun_L(z):\n", " #return z[0]*z[0] + z[1]*z[1] # \"manual version\"\n", " return np.sum( np.square(z) )\n", "\n", "# composite function with L & z\n", "def fun(x):\n", " return fun_L(fun_z(x))\n", "\n", "\n", "x = np.asarray([3,3], dtype=np.float32)\n", "print(\"Starting point x = \"+format(x) +\"\\n\")\n", "\n", "print(\"Some test calls of the functions we defined so far, from top to bottom, z, Lz (manual), Lz:\") \n", "fun_z(x) , fun_L( fun_z(x) ), fun(x) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can evaluate the derivatives of our function via `jax.grad`. E.g., `jax.grad(fun_L)(fun_z(x))` evaluates the Jacobian $\\partial L / \\partial \\mathbf{z}$. The cell below evaluates this and a few variants, together with a sanity check for the inverse of the Jacobian of $\\mathbf{z}$:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Jacobian L(z): [ 6. 18.]\n", "\n", "Jacobian z(x): \n", "[[1. 0.]\n", " [0. 6.]]\n", "\n", "Sanity check with inverse Jacobian of z, this should give x again: [3. 3.]\n", "\n", "Gradient for full L(x): [ 6. 
108.]\n", "\n" ] } ], "source": [ "# this works:\n", "print(\"Jacobian L(z): \" + format(jax.grad(fun_L)(fun_z(x))) +\"\\n\")\n", "\n", "# the following would give an error as z (and hence fun_z) is not scalar\n", "#jax.grad(fun_z)(x) \n", "\n", "# computing the jacobian of z is a valid operation:\n", "J = jax.jacobian(fun_z)(x)\n", "print( \"Jacobian z(x): \\n\" + format(J) ) \n", "\n", "# the following also gives an error, as jax.grad needs a single function object\n", "#jax.grad( fun_L(fun_z) )(x) \n", "\n", "print( \"\\nSanity check with inverse Jacobian of z, this should give x again: \" + format(np.linalg.solve(J, np.matmul(J,x) )) +\"\\n\")\n", "\n", "# instead use composite 'fun' from above\n", "print(\"Gradient for full L(x): \" + format( jax.grad(fun)(x) ) +\"\\n\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last line is worth a closer look: here we print the gradient $\\partial L / \\partial \\mathbf{x}$ at our initial position. And while we know that we should just move diagonally towards the origin (with the zero vector being the minimizer), this gradient is far from diagonal: it has a strongly dominant component along $x_1$, with an entry of 108.\n", "\n", "Let's see how the different methods cope with this situation. 
We'll compare \n", "\n", "* the first-order method _gradient descent_ (i.e., regular, non-stochastic, \"steepest gradient descent\"), \n", "\n", "* _Newton's method_ as a representative of the second-order methods, \n", "\n", "* and _physical gradients_.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient descent\n", "\n", "For gradient descent, the simple gradient-based update from equation {eq}`GD-update`\n", "in our setting gives the following update step in $\\mathbf{x}$:\n", "\n", "$$\\begin{aligned}\n", "\\Delta \\mathbf{x} \n", "&= \n", "- \\eta ( J_{L} J_{\\mathbf{z}} )^T \\\\\n", "&=\n", "- \\eta ( \\frac{\\partial L }{ \\partial \\mathbf{z} } \\frac{\\partial \\mathbf{z} }{ \\partial \\mathbf{x} } )^T\n", "\\end{aligned}$$\n", "\n", "where $\\eta$ denotes the step size parameter.\n", "\n", "Let's start the optimization via gradient descent at $\\mathbf{x}=[3,3]$, and update our solution ten times with\n", "$\\eta = 0.01$:\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GD iter 0: [2.94 1.9200001]\n", "GD iter 1: [2.8812 1.6368846]\n", "GD iter 2: [2.823576 1.4614503]\n", "GD iter 3: [2.7671044 1.3365935]\n", "GD iter 4: [2.7117622 1.2410815]\n", "GD iter 5: [2.657527 1.1646168]\n", "GD iter 6: [2.6043763 1.1014326]\n", "GD iter 7: [2.5522888 1.0479842]\n", "GD iter 8: [2.501243 1.0019454]\n", "GD iter 9: [2.4512184 0.96171147]\n" ] } ], "source": [ "x = np.asarray([3.,3.])\n", "eta = 0.01\n", "historyGD = [x]; updatesGD = []\n", "\n", "for i in range(10):\n", " G = jax.grad(fun)(x)\n", " x += -eta * G\n", " historyGD.append(x); updatesGD.append(G)\n", " print( \"GD iter %d: \"%i + format(x) )\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we've already printed the resulting positions in $\\mathbf{x}$, and they seem to be going down, i.e. moving in the right direction. 
The last point, $[2.451 \\ 0.962]$ still has a fair distance of 2.63 to the origin.\n", "\n", "Let's take a look at the progression over the course of the iterations (the evolution was stored in the `history` list above). The blue points denote the positions in $\\mathbf{x}$ from the GD iterations, with the target at the origin shown with a thin black cross." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'x1')" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAFtCAYAAADrr7rKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcCElEQVR4nO3df5BdZZ3n8fc3IQQ3STfWqAlCQjJSqyjU8KvUsA6ibFZFgrIrq+U4RdadXXBMRV1X1yzo6CxMZrQWcZaMUOgo7Ew5/rOFxF8I46C1ElGj/FojKys4oCSMP+hud03MkO/+cW7r5ebe23277+1zbz/vV9Wp5jz3Ofd8c7x++vRznntOZCaSpDIsqbsASdLCMfQlqSCGviQVxNCXpIIY+pJUEENfkgpi6EtSQQx9SSrIUXUXsNAiIoBnA1N11yJJfbQK+HHO8I3b4kKfKvAfrbsISRqAE4AfdetQYuhPATzyyCOMjY3VXYskzdvk5CRr166FWYxglBj6AIyNjRn6korjhVxJKoihL0kFMfQlqSCGviQVxNCXpIIY+pI0RAb9MMNaQz8i3hIR90bEZGPZHRGvmmGbiyPiexFxICLui4jzF6peSRqEqSnYtg02bIC1a6uf27ZV7f0WdT4jNyI2A08C3wcCuAR4F3B6Zv6vNv3PBr4KbAc+C7wR+E/AGZl5/yz3OQZMTExMOE9fUu2mpmDjRti7Fw4f/k37kiVw8smwezesWtX9PSYnJxkfHwcYz8zJbn1rPdPPzF2Z+fnM/H5m/u/MvBz4BfDiDpu8DfhiZn4oM/dm5nuBbwNbF6pmSeqnyy8/MvChWt+7F664or/7G5ox/YhYGhFvAFYAuzt02wjc3tJ2a6O90/suj4ix6YXqpkSSNBR27Toy8KcdPgy33NLf/dUe+hFxakT8AjgIXAdclJnf7dB9DbC/pW1/o72T7cBE0+LN1iQNhUw4dKh7n0OH+ntxt/bQBx4ATgNeBHwUuDEint/H998BjDctJ/TxvSVpziJg2bLufZYtq/r1S+2hn5m/yswHM3NPZm4H7qEau29nH7C6pW11o73T+x/MzMnpBe+jL2mIbN5cXbRtZ8kSuPDC/u6v9tBvYwmwvMNru4HzWto20fkagCQNtauuqmbptAb/9OydK6/s7/7qnqe/IyLOiYj1jbH9HcC5wF83Xr+p0TbtI8ArI+KdEfG8iHg/cBZw7ULXLkn9sGpVNS1z61ZYvx6OP776uXXr7KZr9qruefofpzpzP47qIuu9wJ9l5m2N1+8AHs7MLU3bXAxcCaynmt//7sz8fA/7dJ6+pKGV2fsYfi/z9GsN/ToY+pIWm5H5cpYkaWEZ+pJUEENfkgpi6EtSQQx9SSqIoS9JBTH0Jakghr4kFcTQl6SCGPqSVBBDX5
IKYuhLUkEMfUkqiKEvSQUx9CWpIIa+JBXE0Jekghj6klQQQ1+SCmLoS1JBDH1JKoihL0kFMfQlqSCGviQVxNCXpIIY+pJUEENfkgpi6EtSQQx9SSqIoS9JBTH0Jakghr4kFcTQl6SCGPqSVBBDX5IKUmvoR8T2iPhmRExFxOMRcXNEPHeGbbZERLYsBxaqZkkaZXWf6b8U2Am8GNgELAO+FBErZthuEjiuaTlxkEVK0mJxVJ07z8xXNq9HxBbgceBM4KvdN819AyxNkhalus/0W403fv5shn4rI+KHEfFIRHwmIl7QqWNELI+IsekFWNW3aiVpxAxN6EfEEuAa4GuZeX+Xrg8AbwZeA7yJ6t9wZ0Sc0KH/dmCiaXm0XzVL0qiJzKy7BgAi4qPAq4CXZOasgzkilgF7gU9l5nvbvL4cWN7UtAp4dGJigrGxsXlWLUn1m5ycZHx8HGA8Mye79a11TH9aRFwLXACc00vgA2TmoYj4DnBSh9cPAgeb9jWfUiVppNU9ZTMagX8R8PLMfGgO77EUOBV4rN/1SdJiU/eZ/k7gjVTj81MRsabRPpGZvwSIiJuAH2Xm9sb6+4CvAw8CxwLvopqy+bGFLV2SRk/dof+Wxs87Wtr/DfDJxn+vAw43vfZ04AZgDfBzYA9wdmZ+d2BVStIiMTQXchdKY9rmhBdyJS0WvVzIHZopm5KkwTP0Jakghr4kFcTQl6SCGPqSVBBDX5IKYuhLUkEMfUkqiKEvSQUx9CWpIIa+JBXE0Jekghj6klQQQ1+SCmLoS1JBDH1JKoihL0kFMfQlqSCGviQVxNCXpIIY+pJUEENfkgpi6EtSQQx9SSqIoS9JBTH0Jakghr4kFcTQl6SCGPqSVBBDX5IKYuhLUkEMfUkqiKEvSQUx9CWpIIa+JBWk1tCPiO0R8c2ImIqIxyPi5oh47iy2uzgivhcRByLivog4fyHqlaRRV/eZ/kuBncCLgU3AMuBLEbGi0wYRcTbwKeDjwOnAzcDNEXHKwKuVpBEXmVl3Db8WEc8EHgdemplf7dDn08CKzLygqe3rwN2Zedks9jEGTExMTDA2NtanyiWpPpOTk4yPjwOMZ+Zkt751n+m3Gm/8/FmXPhuB21vabm20HyEilkfE2PQCrJp/mZI0moYm9CNiCXAN8LXMvL9L1zXA/pa2/Y32drYDE03Lo/OrVJJG19CEPtXY/inAG/r8vjuo/oKYXk7o8/tL0sg4qu4CACLiWuAC4JzMnOlMfB+wuqVtdaP9CJl5EDjYtK95VCpJ/ZMJCx1JdU/ZjEbgXwS8PDMfmsVmu4HzWto2NdolaahNTcG2bbBhA6xdW/3ctq1qXwi1zt6JiL8A3gi8Bnig6aWJzPxlo89NwI8yc3tj/WzgK8B7gM9RDQf9Z+CMGa4FTO/T2TuSajE1BRs3wt69cPjwb9qXLIGTT4bdu2HVHKaajNLsnbdQjbPfATzWtLy+qc864Ljplcy8k+oXxb8H7gFeB7x2NoEvSXW6/PIjAx+q9b174YorBl/DUM3TXwie6Uuqy4YN8PDDnV9fvx4ems0gd4tROtOXpCJkwqFD3fscOlT1GyRDX5IWQAQsW9a9z7Jlg5/NY+hL0gLZvLm6aNvOkiVw4YWDr8HQl6QFctVV1Syd1uCfnr1z5ZWDr8HQl6QFsmpVNS1z69bqou3xx1c/t26d+3TNXjl7R5Jq0q9v5Dp7R5JGQB13hTH0Jakghr4kFcTQl6SCGPqSVBBDX5IKYuhLUkEMfUkqiKEvSQUx9CWpIIa+JBXE0Jekghj6klQQQ1+SCmLoS1JBDH1JGrBhemyJoS9JAzA1Bdu2wYYNsHZt9XPbtqq9Tj45S5L6bGoKNm6EvXvh8OHftE8/C7ffj0b0yVmSVKPLLz8y8KFa37sXrriinrrA0Jekvtu168jAn3b4MNxyy8LW08zQl6Q+yoRDh7r3OXSovou7hr4k9VEELFvWvc+yZfU8FB0MfUnqu82bq4u27S
xZAhdeuLD1PGX/9e1akhanq66qZum0Bv/07J0rr6ynLjD0JanvVq2qpmVu3Qrr18Pxx1c/t27t/3TNXjlPX5IGLHOwY/jO05ekIVLXRdt2DH1JKkjfQj8iTo6IH/S4zTkRsSsifhwRGRGvnaH/uY1+rcuaeRUvSYXo55n+0cCJPW6zArgHeGuP2z0XOK5pebzH7SWpSEfNtmNEXD1Dl2f2uvPM/ALwhcb797Lp45n5RK/7k6TSzTr0gbcBdwOdrgyvnHc1s3d3RCwH7gfen5lfW8B9S1JHg56pM1+9hP6DwIcz86/avRgRpwF7+lFUF48BlwHfApYDfwDcEREvysxvd6hreaPvtBpnyEpajKamqjtr7tpV3Vdn2bLqW7lXXVXvnPx2egn9bwFnAm1DH0hgoL/fMvMB4IGmpjsj4jnAO4Df77DZduCPBlmXpHJ1unf+zp3w5S/X/2WsVr1cyH0ncE2nFzPznsysYwroN4CTury+AxhvWk5YiKIklWGY753fzqxDOjP3ZeYPI+JlnfpExKX9Kasnp1EN+7SVmQczc3J6AWp+WJmkxWSY753fzlzOzL8YER+KiF/fPDQinhERu4A/7eWNImJlRJzWuB4AsKGxvq7x+o6IuKmp/9sj4jURcVJEnBIR1wAvB3bO4d8hSfMy7PfOb2cuof8y4CLgmxHx/Ih4NdUsmjGqs+5enAV8p7EAXN347z9urB8HrGvqfzTwX4H7gK8AvwP888z8297/GZI0P8N+7/x2ermQC0Bm3tk4M78O+DbVL473Ah/MHu/elpl30OXib2ZuaVn/IPDB3iqWpMHZvLm6aNtuiKfue+e3M9cLr/+U6iz9UeAfqb4h+0/6VZQkjYphvnd+Oz2HfkS8B9gN3AacArwQOB24NyI29rc8SRpuw3zv/HZ6vp9+RDwGvLlxC4XptmXAnwDbMnN5x42HgPfTlzRIdXwjt5f76fc8pg+cmpk/aW7IzEPAuyLis3N4P0laNIbpom07PQ/vtAZ+y2tfmV85kjQahmkaZi98iIokzdLUFGzbBhs2wNq11c9t26r2UeEzciVpFjrdY2d6lk6dF219Rq4k9dmo3WOnE0NfkmZh1O6x04mhL0kzGMV77HRi6EvSDEbxHjudGPqSNAubNx95q4Vpw3iPnU4MfUmaQebo3WOnk7l8I1eSFr12z719xSvgd38XvvjF37RdeGEV+MN2j51ODH1JatFpTv4NN1Rn9ffeCytXjsYYfiuHdySpxWzm5I9i4IOhL0lHWCxz8tsx9CWpyWKak9+OoS9JTRbTnPx2DH1JapK5eObkt+PsHUnFa52euXQpHHss/PznTx3GGbU5+e0Y+pKK1u2WyU9/ejU188knR3NOfjuGvqSidZue+cQT8KY3wTXXjO4YfivH9CUVbTbTMxdL4IOhL6lgi316ZjuGvqSiLebpme0Y+pKK0vpw85/+tHPfUZ+e2Y4XciUVo9NMnXYWw/TMdgx9ScXoNFMHqiGclSthbGzxTM9sx9CXVIxuM3Uy4bd+C37wg8U1ht/KMX1JRTh8eHYzdRY7Q1/SotV80XbdOti3r3v/xTZTpx2HdyQtSr1ctIXFOVOnHc/0JS1K3S7atlqsM3XaMfQlLTqZ3S/aAhx1FBx/PKxfD1u3wu7di2+mTju1hn5EnBMRuyLixxGREfHaWWxzbkR8OyIORsSDEbFl8JVKGnatX7p65JHu/Vevhr//e3joIfjIR8oIfKh/TH8FcA/wl8D/mKlzRGwAPgdcB/wecB7wsYh4LDNvHWShkoZXr+P3UF207fSglMWs1tDPzC8AXwCI2V0yvwx4KDPf2VjfGxEvAd4BGPpSoXoZv4dyLtq2M2q/5zYCt7e03dpobysilkfE2PQCFPJHnFSOW27pLfBLuWjbzqiF/hpgf0vbfmAsIp7WYZvtwETT8ujgypO0UJrH8Gcav1+6FJ797PIu2rZT95j+QtgBXN20vgqDXxppvY7hr127+G+vMF
ujFvr7gNUtbauBycz8ZbsNMvMgcHB6fZbXDiQNsV7n4F94oYE/bdSGd3ZTzdhptqnRLqkQM83Bn1b6+H07dc/TXxkRp0XEaY2mDY31dY3Xd0TETU2bXAf8dkR8MCKeFxF/CPxr4MMLW7mkuszmEYdLlsCJJzp+307dwztnAX/XtD499n4jsAU4Dlg3/WJmPhQRr6YK+bdRjc3/gXP0pXJEzPyIw3Xrqi9d6Uh1z9O/A+g40paZWzpsc/rAipI09DZvhp072w/xlDwHfzZGbUxfkrjqqmqsvvUbtY7hz8zQlzRyVq2qxuq3bq3m3pd447S5isysu4YF1fhW7sTExARjY2N1lyOpDzLLnpI5OTnJ+Pg4wHhmTnbr65m+pJFXcuD3ytCXpIIY+pJUEENfkgpi6EtSQQx9SSqIoS9JBTH0Jakghr4kFcTQl6SCGPqSVBBDX5IKYuhLUkEMfUkqiKEvSQUx9CWpIIa+JBXE0Jekghj6klQQQ1+SCmLoS1JBDH1JKoihL0kFMfQlqSCGviQVxNCXpIIY+pJUEENfkgpi6EtSQQx9SSqIoS9JBTH0JakgQxH6EfHWiHg4Ig5ExF0R8cIufbdERLYsBxayXkkaVbWHfkS8Hrga+ABwBnAPcGtEPKvLZpPAcU3LiYOuU5IWg9pDH/gPwA2Z+YnM/C5wGfD/gDd32SYzc1/Tsn9BKpWkEVdr6EfE0cCZwO3TbZl5uLG+scumKyPihxHxSER8JiJe0GUfyyNibHoBVvWrfkkaNXWf6T8DWAq0nqnvB9Z02OYBqr8CXgO8ierfcGdEnNCh/3Zgoml5dJ41S9LIqjv0e5aZuzPzpsy8OzO/AvxL4B+ASztssgMYb1o6/XKQpEXvqJr3/xPgSWB1S/tqYN9s3iAzD0XEd4CTOrx+EDg4vR4Rc6tUkhaBWs/0M/NXwB7gvOm2iFjSWN89m/eIiKXAqcBjg6hRkhaTus/0oZqueWNEfAv4BvB2YAXwCYCIuAn4UWZub6y/D/g68CBwLPAuqimbH1vowiVp1NQe+pn56Yh4JvDHVBdv7wZe2TQNcx1wuGmTpwM3NPr+nOovhbMb0z0lSV1EZtZdw4JqTNucmJiYYGxsrO5yJGneJicnGR8fBxjPzMlufUdu9o4kae4MfUkqiKEvSQUx9CWpIIa+JBXE0Jekghj6klQQQ1+SCmLoS1JBDH1JKoihL0kFMfQlqSCGviQVxNCXpIIY+pJUEENfkgpi6EtSQQx9SSqIoS9JBTH0Jakghr4kFcTQl6SCGPqSVBBDX5IKYuhLUkEMfUkqiKEvSQUx9CWpIIa+JBXE0Jekghj6klQQQ1+SCmLoS1JBDH1JKoihL0kFGYrQj4i3RsTDEXEgIu6KiBfO0P/iiPheo/99EXH+IOratWsXBw4caPvagQMH2LVr1yB2K0kDU3voR8TrgauBDwBnAPcAt0bEszr0Pxv4FPBx4HTgZuDmiDil37Vt2rSJ66+//ojgP3DgANdffz2bNm3q9y4laaAiM+stIOIu4JuZubWxvgR4BPhvmfmnbfp/GliRmRc0tX0duDszL5vF/saAiYmJCcbGxmasbzrgL730Uo455pgj1iWpbpOTk4yPjwOMZ+Zkt761nulHxNHAmcDt022ZebixvrHDZhub+zfc2ql/RCyPiLHpBVjVS43HHHMMl156Kddffz1PPPGEgS9ppB1V8/6fASwF9re07wee12GbNR36r+nQfzvwR3MtEKrgv+SSSzjzzDPZs2ePgS9pZNU+pr8AdgDjTcsJvb7BgQMHuPHGG9mzZw833nhjx4u7kjTs6g79nwBPAqtb2lcD+zpss6+X/pl5MDMnpxdgqpcCm8fwjz322F8P9Rj8kkZRraGfmb8C9gDnTbc1LuSeB+zusNnu5v4Nm7r0n7N2F22bx/gNfkmjpu4zfaima/67iLgkIk4GPgqsAD4BEBE3RcSOpv4fAV4ZEe+MiOdFxPuBs4Br+13Ybbfd1v
ai7XTw33bbbf3epSQNVO1TNgEiYivwLqqLsXcD2zLzrsZrdwAPZ+aWpv4XA1cC64HvA+/OzM/Pcl89TdmUpGHXy5TNoQj9hWToS1psRmaeviRpYRn6klQQQ1+SCmLoS1JBDH1JKoihL0kFqfuGa7WZnOw6q0mSRkYveVbiPP3jgUfrrkOSBuCEzPxRtw4lhn4Az6bHG69R3Yf/Uaq7dPa6bd2svR7WXo9Sa18F/DhnCPXihncaB6Trb8J2qt8VAEzN9I23YWPt9bD2ehRc+6z6eyFXkgpi6EtSQQz92TsIfKDxc9RYez2svR7W3kVxF3IlqWSe6UtSQQx9SSqIoS9JBTH0Jakghn6TiHhrRDwcEQci4q6IeOEM/S+OiO81+t8XEecvVK1tapl17RGxJSKyZTmwkPU21XJOROyKiB836njtLLY5NyK+HREHI+LBiNgy+Erb1tFT7Y26W497RsSaBSp5uo7tEfHNiJiKiMcj4uaIeO4stqv98z6X2ofl8x4Rb4mIeyNisrHsjohXzbBN34+5od8QEa8HrqaaLnUGcA9wa0Q8q0P/s4FPAR8HTgduBm6OiFMWpOCn1tJT7Q2TwHFNy4mDrrODFVT1vnU2nSNiA/A54O+A04BrgI9FxCsGVF83PdXe5Lk89dg/3ue6ZvJSYCfwYmATsAz4UkSs6LTBEH3ee669YRg+748C7wHOBM4Cvgx8JiJe0K7zwI55ZrpU01bvAq5tWl9CdbuG93To/2ngsy1tXweuG4HatwBP1H3M29SVwGtn6PNnwP0tbX8DfHEEaj+30e/Yuo91S13PbNR1Tpc+Q/N5n0PtQ/l5b9T2M+DfLuQx90wfiIijqX773j7dlpmHG+sbO2y2sbl/w61d+g/EHGsHWBkRP4yIRyKi49nGEBqK4z5Pd0fEYxFxW0T8s7qLAcYbP3/Wpc+wHvfZ1A5D9nmPiKUR8QaqvxZ3d+g2kGNu6FeeASwF9re07wc6jbeu6bH/oMyl9geANwOvAd5E9Tm4MyJOGFSRfdTpuI9FxNNqqKcXjwGXAf+qsTwC3BERZ9RVUEQsoRoi+1pm3t+l67B83n+th9qH5vMeEadGxC+ovnF7HXBRZn63Q/eBHPPi7rIpyMzdNJ1dRMSdwF7gUuC9ddW12GXmA1QBNO3OiHgO8A7g9+upip3AKcBLatr/fMyq9iH7vD9AdS1qHHgdcGNEvLRL8PedZ/qVnwBPAqtb2lcD+zpss6/H/oMyl9qfIjMPAd8BTupvaQPR6bhPZuYva6hnvr5BTcc9Iq4FLgBelpkzPVhoWD7vQM+1P0Wdn/fM/FVmPpiZezJzO9VEgLd16D6QY27oU/0PAewBzptua/zpeB6dx9t2N/dv2NSl/0DMsfaniIilwKlUww/DbiiOex+dxgIf96hcC1wEvDwzH5rFZkNx3OdYe+t7DNPnfQmwvMNrgznmdV+9HpYFeD1wALgEOBm4Hvg5sLrx+k3Ajqb+ZwOHgHcCzwPeD/wKOGUEan8f8C+A36aa4vkp4JfA82uofSVV8J1GNQvjHY3/Xtd4fQdwU1P/DcD/BT7YOO5/CPwj8IoRqP3tVOPKJ1ENS1xD9VfaeQtc918AT1BNf1zTtDytqc9Qft7nWPtQfN4bn4dzgPVUv3R2AIeBTQt5zBf0/yTDvgBbgR9SXWS5C3hR02t3AJ9s6X8x1RjdQeB+4PxRqB34cFPffVTz3k+vqe5zG4HZunyy8fongTvabPOdRv3/B9gyCrUD7wYebATOT6m+a/CyGupuV3M2H8dh/bzPpfZh+bxTzbd/uFHH41QzczYt9DH31sqSVBDH9CWpIIa+JBXE0Jekghj6klQQQ1+SCmLoS1JBDH1JKoihL0kFMfSlARuWxztKYOhLAzVkj3eUvA2DNB8R8UzgPuDPM/NPGm1nU91H5VVUN/p6dWae0rTN31A9MvGVC1+xSueZvjQPmfkPVE9len9EnBURq4D/TvXM4r9leB8zqEL55C
xpnjLz8xFxA/DXwLeobv28vfFy18c75mg++EUjzDN9qT/+I9VJ1MXA72XmwZrrkdoy9KX+eA7wbKr/T61val9sj3fUiHN4R5qniDga+Cvg01QPvPhYRJyamY9TPdru/JZNRvnxjhpxzt6R5ikiPgS8Dvgd4BfAV4CJzLygMWXzfmAn8JfAy4E/p5rRc2tNJatghr40DxFxLnAb1WMP/2ejbT1wD/CezPxoo8+HgecDjwL/JTM/ufDVSoa+JBXFC7mSVBBDX5IKYuhLUkEMfUkqiKEvSQUx9CWpIIa+JBXE0Jekghj6klQQQ1+SCmLoS1JBDH1JKsj/B+rWSbCM6dwHAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "axes = plt.figure(figsize=(4, 4), dpi=100).gca()\n", "historyGD = onp.asarray(historyGD)\n", "updatesGD = onp.asarray(updatesGD) # for later\n", "axes.scatter(historyGD[:,0], historyGD[:,1], lw=0.5, color='blue')\n", "axes.scatter([0], [0], lw=0.25, color='black', marker='x') # target at 0,0\n", "axes.set_xlabel('x0')\n", "axes.set_ylabel('x1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "No surprise here: the initial step mostly moves downwards along $x_1$ (in the top right corner), and the updates afterwards curve towards the origin. But they don't get very far. It's still quite a distance to the solution in the bottom left corner." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Newton\n", "\n", "For Newton's method, the update step is given by\n", "\n", "$$\n", "\\begin{aligned}\n", "\\Delta \\mathbf{x} &= \n", "- \\eta \\left( \\frac{\\partial^2 L }{ \\partial \\mathbf{x}^2 } \\right)^{-1}\n", " \\frac{\\partial L }{ \\partial \\mathbf{x} }\n", "\\\\\n", "&=\n", "- \\eta \\ H_L^{-1} \\ ( J_{L} J_{\\mathbf{z}} )^T\n", "\\end{aligned}\n", "$$\n", "\n", "Hence, in addition to the same gradient as for GD, we now need to evaluate and invert the Hessian $\\frac{\\partial^2 L }{ \\partial \\mathbf{x}^2 }$.\n", "\n", "This is quite straightforward in JAX: we can call `jax.jacobian` two times, and then use the JAX version of `linalg.inv` to invert the resulting matrix.\n", "\n", "For the optimization with Newton's method we'll use a larger step size of $\\eta =1/3$. For this example and the following one, we've chosen the step size such that the magnitude of the first update step is roughly the same as that of GD. In this way, we can compare the trajectories of all three methods relative to each other. Note that this is by no means meant to illustrate or compare the stability of the methods here. 
Stability and upper limits for $\\eta$ are separate topics. Here we're focusing on convergence properties.\n", "\n", "In the next cell, we apply the Newton updates ten times starting from the same initial guess:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Newton iter 0: [2. 2.6666667]\n", "Newton iter 1: [1.3333333 2.3703704]\n", "Newton iter 2: [0.88888884 2.1069958 ]\n", "Newton iter 3: [0.59259254 1.8728852 ]\n", "Newton iter 4: [0.39506167 1.6647868 ]\n", "Newton iter 5: [0.26337445 1.4798105 ]\n", "Newton iter 6: [0.17558296 1.315387 ]\n", "Newton iter 7: [0.1170553 1.1692328]\n", "Newton iter 8: [0.07803687 1.0393181 ]\n", "Newton iter 9: [0.05202458 0.92383826]\n" ] } ], "source": [ "x = np.asarray([3.,3.])\n", "eta = 1./3.\n", "historyNt = [x]; updatesNt = []\n", "\n", "for i in range(10):\n", " G = jax.grad(fun)(x)\n", " H = jax.jacobian(jax.jacobian(fun))(x)\n", " #H = jax.jacfwd(jax.jacrev(fun))(x) # alternative\n", " Hinv = np.linalg.inv(H)\n", " \n", " x += -eta * np.matmul( Hinv , G)\n", " historyNt.append(x); updatesNt.append( np.matmul( Hinv , G) )\n", " print( \"Newton iter %d: \"%i + format(x) )\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last line already indicates that Newton's method does quite a bit better. The last point $[0.052 \\ 0.924]$ only has a distance of 0.925 to the origin (compared to 2.63 for GD).\n", "\n", "Below, we plot the Newton trajectory in orange next to the GD version in blue."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'x1')" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAFtCAYAAADrr7rKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgpklEQVR4nO3df5RkZX3n8fe3h2Fgh+7GrGbAAZwJHAwGMiicxGE9iJJJFB3A3bB6jIkTNyu69hnjumadhSSaZUJ+nKhknSirRiHZY/xnF6ejBjAEcyIj0VEmso4QDOgwMkMU6G6zztjQ3/3jVmvRXdXd1VNV91bf9+uce4q69dyq7xQ1n7l1n6eeJzITSVI9DJVdgCSpfwx9SaoRQ1+SasTQl6QaMfQlqUYMfUmqEUNfkmrE0JekGjmu7AL6LSICeDYwVXYtktRFw8C3c5Ff3NYu9CkC/+Gyi5CkHjgNOLhQgzqG/hTAgQMHGBkZKbsWSTpmk5OTnH766bCEKxh1DH0ARkZGDH1JtWNHriTViKEvSTVi6EtSjRj6klQjhr4k1YihL0kVkjO9Xc2w1NCPiDdHxD9ExGRj2xMRL1/kmKsi4usRcSQivhoRl/WrXknqhanHp7jzj7ZzYNdGDt14Ogd2beTOP9rO1OPdnzggylwjNyK2Ak8B/wgE8HrgHcDzM/P/tmh/EfC3wA7gL4HXAv8VeEFm3rvE1xwBJiYmJhynL6l0U49PcejmzfzEj+1n1aqZH+5/8qkhHnzsHE75lT0MP2N4weeYnJxkdHQUYDQzJxdqW+qZfmaOZ+anM/MfM/P+zLwG+B7wwjaHvBX4q8z8w8zcn5m/CXwZGOtXzZLUTXv/9Jp5gQ9w3KoZNv7YfvZ+9Nquvl5lrulHxKqIeA2wFtjTptlm4LNz9t3a2N/ueddExMjsRjEpkSRVwpknjM8L/FnHrZrhzON3d/X1Sg/9iDgvIr4HHAU+CLwqM7/WpvkpwOE5+w439rezA5ho2pxsTVIl5Exy3ND0gm2OG5ruaudu6aEP3AecD/ws8AHgpoh4Xhef/3pgtGk7rYvPLUnLFkPBkzOrF2zzZK4mhqJrr1l66GfmDzLzgczcm5k7gH0U1+5bOQSsm7NvXWN/u+c/mpmTsxvOoy+pQr5xZCtPPtU6ip98aohvHL28q69Xeui3MASsafPYHuDSOfu20L4PQJIq7YI37OTBx86ZF/xPPjXEg989hwt+9bquvl6pUytHxPXAZ4BvUXSwvha4BPiFxuM3Awcb3wAAbgA+FxFvBz4FvAa4EHhjfyuXpO4YfsYw/Moe/u6j13Lm8bs5bmiaJ2dW840fXM4Fv3rdosM1O1X2OP2PUJy5n0rRyfoPwO9n5u2Nx+8EHsrMbU3HXAVcB2ygGN//G5n56Q5e03H6kiorZ7Lja/idjNMvNfTLYOhLWmkG5sdZkqT+MvQlqUYMfUmqEUNfkmrE0JekGjH0JalGDH1JqhFDX5JqxNCXpBox9CWpRgx9SaoRQ1+SasTQl6QaMfQlqUYMfUmqEUNfkmrE0JekGjH0JalGDH1JqhFDX1JrNVs/uy6OK7sASRUyPQX7roGD4zAzDUOrYf1W2LQTVg+XXZ26wNCXVJiegts2w8R+YOZH++/fBYfvgJ/fY/CvAF7ekVTYd838wIfi/sR+2HdtGVWpywx9SYWD48wP/FkzcHB3P6tRjxj6kopO25nphdvMTNu5uwIY+pIgoui0XcjQ6qKdBpqhL6mwfivtI2EI1l/ez2rUI4a+
pMKmnTB6DvNjYajYv+m6MqpSlxn6kgqrh4thmWePwdoNcOL64vbsMYdrriCRNeuYiYgRYGJiYoKRkZGyy5GqK9Nr+ANicnKS0dFRgNHMnFyorWf6kloz8FckQ1+SasTQl6QaMfQlqUYMfUmqkVJDPyJ2RMQXI2IqIh6NiFsi4rmLHLMtInLOdqRfNUvSICv7TP/FwC7ghcAWYDVwW0SsXeS4SeDUpu05vSxSklaKUufTz8yXNd+PiG3Ao8AFwN8ufGge6mFpkrQilX2mP9do4/axRdqdFBHfjIgDEfHJiPipdg0jYk1EjMxugD8rlFRblQn9iBgC3gd8PjPvXaDpfcAbgCuA11H8Ge6KiNPatN8BTDRtD3erZkkaNJWZhiEiPgC8HHhRZi45mCNiNbAf+Hhm/maLx9cAa5p2DQMPOw2DpJWik2kYKrFGbkS8H3glcHEngQ+QmdMR8RXgrDaPHwWONr3WsZQqSQOt7CGb0Qj8VwEvzcwHl/Ecq4DzgEe6XZ8krTRln+nvAl5LcX1+KiJOaeyfyMzvA0TEzcDBzNzRuP9bwBeAB4CTgXdQDNn8cH9Ll6TBU3bov7lxe+ec/b8KfKzx32fw9NWanwF8CDgFeBzYC1yUmV/rWZWStEJUpiO3X5xPX9JK43z6UjfV7MRIK1vZl3ekapqegn3XwMFxmJmGodXFwuGbdrpsoAaaoS/NNT0Ft22Gif08rTvp/l1w+A7Xi9VA8/KONNe+a+YHPhT3J/bDvmvLqErqCkNfmuvgOPMDf9YMHNzdz2qkrjL0pWaZxTX8hcxM27mrgWXoS80iik7bhQytLtpJA8jQl+Zav5X2fzWGYP3l/axG6ipDX5pr004YPYf5fz2Giv2briujKqkrDH1prtXDxbDMs8dg7QY4cX1xe/aYwzU18JyGQVpMptfwVWlOwyB1k4GvFcTQl6QaMfQlqUYMfUmqEUNfkmrE0JekGjH0JalGDH1JqhFDX5JqxNCXpBox9CWpRgx9SaoRQ1+SasTQV3lqNsOrVAXHlV2AamZ6CvZdUyw+PjNdLD24fmuxcInz1Es9Z+irf6an4LbNMLEfmPnR/vt3weE7XKBE6gMv76h/9l0zP/ChuD+xH/ZdW0ZVUq0Y+uqfg+PMD/xZM3Bwdz+rkWrJ0Fd/ZBbX8BcyM23nrtRjhr76I6LotF3I0GqXJpR6zNBX/6zfSvuP3BCsv7yf1Ui1ZOirfzbthNFzmP+xGyr2b7qujKqkWjH01T+rh4thmWePwdoNcOL64vbsMYdrSn0SWbOOs4gYASYmJiYYGRkpu5x6y/QavtQFk5OTjI6OAoxm5uRCbUs904+IHRHxxYiYiohHI+KWiHjuEo67KiK+HhFHIuKrEXFZP+pVlxn4Ut+VfXnnxcAu4IXAFmA1cFtErG13QERcBHwc+AjwfOAW4JaIOLfn1UrSgKvU5Z2IeBbwKPDizPzbNm0+AazNzFc27fsCcE9mvmkJr+HlHUkrysBc3mlhtHH72AJtNgOfnbPv1sb+eSJiTUSMzG6AvYWSaqsyoR8RQ8D7gM9n5r0LND0FODxn3+HG/lZ2ABNN28PHVqkkDa7KhD7Ftf1zgdd0+Xmvp/gGMbud1uXnl6SBUYmplSPi/cArgYszc7Ez8UPAujn71jX2z5OZR4GjTa91DJVKUveUMWq57CGb0Qj8VwEvzcwHl3DYHuDSOfu2NPZLUqVNTcH27bBxI5x+enG7fXuxvx9KHb0TEX8CvBa4Ariv6aGJzPx+o83NwMHM3NG4fxHwOeCdwKcoLgf9N+AFi/QFzL6mo3cklWJqCjZvhv37YaZplvGhITjnHNizB4aXMdRkkEbvvJniOvudwCNN26ub2pwBnDp7JzPvoviH4o3APuAXgSuXEviSVKZrrpkf+FDc378fru3DOkKVGqffD57pSyrLxo3w0EPtH9+wAR5cykXuOQbpTF9VUrMTAKmfMmF6kXWEpvuwjlAlRu+oRNNTxdq1B8eL
lauGVhfz3m/a6ayXUhdFwOpF1hFa3Yd1hDzTr7PpKbhtM9y/C/7lIfj+weL2/l3F/uk+DSeQamLr1qLTtpWhIbi8D+sIGfp1tu8amNjP/MXKZ4r9+/rQqyTVyM6dxSiducE/O3rnuj6sI2To19nBceYH/qwZOLi7n9VIK97wcDEsc2ys6LRdv764HRtb/nDNTnlNv64yi2v4C5mZdqETqcuGh+GGG4qtdr/IVYkiik7bhQz1oVdJqrEy/noZ+nW2fivtPwJDsL4PvUqS+srQr7NNO2H0HOZ/DIaK/Zv60Kskqa8M/TpbPQw/vwfOHoO1G+DE9cXt2WPFfsfpSyuO0zDoR+y0lQaS0zBoeQx8acUz9CWpRgx9SaoRQ1+SasTQl6QaMfQlqUYMfUmqEUNfkmrE0B90NftxnaRj49TKg8glDiUtk6E/aGaXOJy74tX9u+DwHc6ZI1VQlWY48fLOoHGJQ2kgTE3B9u2wcSOcfnpxu317sb9MTrg2aD65sVi8vJ21G+CKB/tVjaQWpqZg82bYvx9mms7PZtfC7fbSiE64tlJ1ssShpNJcc838wIfi/v79cG2JX8gN/UHiEofSQBgfnx/4s2ZmYPfu/tbTzNAfNC5xKFVaJkwv8oV8usQv5Ib+oHGJQ6nSImD1Il/IV5f4hdzQHzQucShV3tatRadtK0NDcHmJX8gdvTPoqjQAWBLg6B31koEvVc7wcBHsY2OwYQOsX1/cjo11P/A75Zm+JPVYr7+Qe6YvSRVSpS/khr4k1UjXQj8izomIf+rwmIsjYjwivh0RGRFXLtL+kka7udspx1S8JNVEN8/0jwee0+Exa4F9wFs6PO65wKlN26MdHi9JtbTkqZUj4j2LNHlWpy+emZ8BPtN4/k4OfTQzn+j09QaCQzAl9VAn8+m/FbgHaNczfNIxV7N090TEGuBe4F2Z+fk+vnb3uSiKtGJU/bytk9B/AHhvZv55qwcj4nxgbzeKWsAjwJuALwFrgF8D7oyIn83ML7epa02j7axqpaiLokgDb2qqmFlzfLyYV2f16uJXuTt3ljsmv5VOQv9LwAVAy9AHEujpv2+ZeR9wX9OuuyLiTOBtwC+3OWwH8Nu9rOuYLGVRlAtvKKMySUvQ7te3u3bBHXeU/2OsuTrpyH078L52D2bmvswsYwjo3wNnLfD49cBo03ZaP4pasoPjzA/8WTNwsMQ5WCUtqspz57ey5JDOzEOZ+c2IeEm7NhFxdXfK6sj5FJd9WsrMo5k5ObsBJS9W1sRFUaSBV+W581tZzpn5X0XEH0bEDycPjYhnRsQ48HudPFFEnBQR5zf6AwA2Nu6f0Xj8+oi4uan9r0fEFRFxVkScGxHvA14K7FrGn6N8LooiDbSqz53fynJC/yXAq4AvRsTzIuIVFKNoRijOujtxIfCVxgbwnsZ//07j/qnAGU3tjwf+CPgq8DlgE/BzmfnXnf8xKsJFUaSBVfW581vpOPQz8y6KcL8X+DLwf4D3Apdk5jc7fK47MzNabNsaj2/LzEua2v9BZp6VmSdm5r/OzJdk5t90+meoFBdFkQZalefOb2W5Ha9nU5ylPww8SfEL2X/VraJqxUVRpIG2c2cxR/7c4J+dO/+6ip23dTy1ckS8E3g38D+Bd1CMnPkziss7r8vMPd0uspsqP7Vy1X/ZIWmeqalilM7u3T8ap3/55UXg92O4ZidTKy8n9B8B3tCYQmF232rgd4Htmbmm7cEVUPnQlzTQyjhv6yT0O/lx1qzzMvM7zTsycxp4R0T85TKeT5JWjKp/UV9OR+53Fnjsc8dWjiQNhioNw+yEi6hI0hJNTcH27bBxI5x+enG7fXuxf1C4Rm6/2VErDaR2c+zMjtIpc44d18itmukp+NJ2+ORGuOX04vZL24v9kgbCoM2x045n+r3Wburk2R9fORZfGggbN8JDD7V/fMMGePDBflXzdJ7pV8lSpk6WVGmDOMdOO4Z+rzl1sjTwBnGOnXYM/V5y6mRp
xRi0OXbaMfR7yamTpRUhc/Dm2GnH0O81p06WBtLcMfk//dPwohfBG99YdNquX1/cjo1Vb0nEhTh6p9ccvSMNnKWMyT/ppOp8SXf0TpU4dbI0cJYyJr8qgd8pz/T7zV/kSpVX5TH5rXimX2UGvlRpK2lMfiuGviQ1WUlj8lsx9HtlUE8DpJrLXDlj8ltZziIqamd6qph24eB48aOrodXFkM1NO+2wlSpsaqrovB0fLy7drFoFJ58Mjz/+9PO3QRuT34qh3y3thmbevwsO3+FIHamiFhqe+YxnFEMzn3qq/+ve9oqXd7rFidWkgbTQ8MwnnoArr4QDB4rROjfcMNiBD4Z+9zixmjSQxsfnB/6smRnYvXtwO21bMfS7wYnVpIG00odntmLod4MTq0kDayUPz2zF0O8WJ1aTBsLcidS++932bQd9eGYrTsPQLU6sJlVeu5E6rVRhwfOlchqGMjixmlR57UbqQHEJZ3h4cKdMXirP9HvFidWkylnKRGr/9E+D91fXM/0qGLRPjbTCzcwsbaTOSmfod1PNvjVJVdfcaXvGGXDo0MLtV9pInVachuFYOd+OVEmddNrCyhyp04qhfyycb0eqrIU6bedaCROpLZWXd46F8+1IlZS58PQKAMcdt/JH6rRSauhHxMURMR4R346IjIgrl3DMJRHx5Yg4GhEPRMS23lfahvPtSJUx90dXBw4s3H7dOvjWt1bORGpLVfblnbXAPuBPgf+9WOOI2Ah8Cvgg8EvApcCHI+KRzLy1l4XO08l8Oyu9Z0gqWafX76HotG23UMpKVmroZ+ZngM8AxNKC8U3Ag5n59sb9/RHxIuBtQH9D3/l2pMro5Po91KfTtpVB+3duM/DZOftubexvKSLWRMTI7AZ070uc8+1IlbB7d2eBX5dO21YGLfRPAQ7P2XcYGImIE9scswOYaNoe7lo1m3YW8+rMexsb8+1squmnSuqD5mv4i12/X7UKnv3s+nXatlL2Nf1+uB54T9P9YboV/LPz7ey7tui0/eE4/cuLwHe4ptQTnV7DP/30wZxeoRcGLfQPAevm7FsHTGbm91sdkJlHgaOz95fYd7B0x50EF95QbHbaSn3R6Rj8yy/3r+asQQv9PcBlc/ZtaezvH3+FK5VqsTH4s+p+/b6VUkM/Ik4CzmratTEizgcey8xvRcT1wPrM/JXG4x8ExiLiDyiGeb4U+PfAK/pWtL/ClUq1lCUOh4aKSzpXXFEEfl2v37dS9pn+hcDfNN2fvfZ+E7ANOBU4Y/bBzHwwIl4BvBd4K8W1+V/r6xj9pfwK98Ib+laOVDcRiy9xeMYZxY+uNF+po3cy887MjBbbtsbj2zLzkhbHPD8z12TmmZn5sb4W7a9wpdJt3dr+h1V1HoO/FIM2ZLNcnfwKV1LP7NxZXKufG/xew1+cod8Jf4UrVcLwcDHWfmysGHtfx4nTlqvsa/qDZ/3WotO25SUef4Ur9cvwcDFR2g2Olu6IZ/qd8le4UuUY+Etn6Hdq9le4Z4/B2g1w4vri9uwxh2tKqrzImnU6NiZdm5iYmGBkZOTYn9DvlZJKNjk5yejoKMBoZk4u1NYz/WNl4EsaIIa+JNWIoS9JNWLoS1KNGPqSVCOGviTViKEvSTVi6EtSjRj6klQjhr4k1YihL0k1YuhLUo0Y+pJUI4a+JNWIoS9JNWLoS1KNGPqSVCOGviTViKEvSTVi6EtSjRj6klQjhr4k1YihL0k1YuhLUo0Y+pJUI4a+JNWIoS9JNWLoS1KNGPqSVCOVCP2IeEtEPBQRRyLi7oj4mQXabouInLMd6We9kjSoSg/9iHg18B7g3cALgH3ArRHx4wscNgmc2rQ9p9d1StJKUHroA/8Z+FBmfjQzvwa8Cfh/wBsWOCYz81DTdrgvlUrSgCs19CPieOAC4LOz+zJzpnF/8wKHnhQR34yIAxHxyYj4qQVeY01EjMxuwHC36pekQVP2mf4zgVXA3DP1w8ApbY65j+JbwBXA6yj+
DHdFxGlt2u8AJpq2h4+xZkkaWGWHfscyc09m3pyZ92Tm54B/C/wzcHWbQ64HRpu2dv84SNKKd1zJr/8d4Clg3Zz964BDS3mCzJyOiK8AZ7V5/ChwdPZ+RCyvUklaAUo908/MHwB7gUtn90XEUOP+nqU8R0SsAs4DHulFjZK0kpR9pg/FcM2bIuJLwN8Dvw6sBT4KEBE3Awczc0fj/m8BXwAeAE4G3kExZPPD/S5ckgZN6aGfmZ+IiGcBv0PReXsP8LKmYZhnADNNhzwD+FCj7eMU3xQuagz3lCQtIDKz7Br6qjFsc2JiYoKRkZGyy5GkYzY5Ocno6CjAaGZOLtR24EbvSJKWz9CXpBox9CWpRgx9SaoRQ1+SasTQl6QaMfQlqUYMfUmqEUNfkmrE0JekGjH0JalGDH1JqhFDX5JqxNCXpBox9CWpRgx9SaoRQ1+SasTQl6QaMfQlqUYMfUmqEUNfkmrE0JekGjH0JalGDH1JqhFDX5JqxNCXpBox9CWpRgx9SaoRQ1+SasTQl6QaMfQlqUYMfUmqEUNfkmrE0JekGjH0JalGKhH6EfGWiHgoIo5ExN0R8TOLtL8qIr7eaP/ViLisF3WNj49z5MiRlo8dOXKE8fHxXrysJPVM6aEfEa8G3gO8G3gBsA+4NSJ+vE37i4CPAx8Bng/cAtwSEed2u7YtW7Zw4403zgv+I0eOcOONN7Jly5Zuv6Qk9VRkZrkFRNwNfDEzxxr3h4ADwP/IzN9r0f4TwNrMfGXTvi8A92Tmm5bweiPAxMTEBCMjI4vWNxvwV199NSeccMK8+5JUtsnJSUZHRwFGM3NyobalnulHxPHABcBnZ/dl5kzj/uY2h21ubt9wa7v2EbEmIkZmN2C4kxpPOOEErr76am688UaeeOIJA1/SQDuu5Nd/JrAKODxn/2HgJ9scc0qb9qe0ab8D+O3lFghF8L/+9a/nggsuYO/evQa+pIFV+jX9PrgeGG3aTuv0CY4cOcJNN93E3r17uemmm9p27kpS1ZUd+t8BngLWzdm/DjjU5phDnbTPzKOZOTm7AVOdFNh8Df/kk0/+4aUeg1/SICo19DPzB8Be4NLZfY2O3EuBPW0O29PcvmHLAu2XrVWnbfM1foNf0qAp+0wfiuGa/zEiXh8R5wAfANYCHwWIiJsj4vqm9jcAL4uIt0fET0bEu4ALgfd3u7Dbb7+9ZaftbPDffvvt3X5JSeqp0odsAkTEGPAOis7Ye4DtmXl347E7gYcyc1tT+6uA64ANwD8Cv5GZn17ia3U0ZFOSqq6TIZuVCP1+MvQlrTQDM05fktRfhr4k1YihL0k1YuhLUo0Y+pJUI4a+JNVI2ROulWZycsFRTZI0MDrJszqO018PPFx2HZLUA6dl5sGFGtQx9AN4Nh1OvEYxD//DFLN0dnps2ay9HNZejrrWPgx8OxcJ9dpd3mm8IQv+S9hK8W8FAFOL/eKtaqy9HNZejhrXvqT2duRKUo0Y+pJUI4b+0h0F3t24HTTWXg5rL4e1L6B2HbmSVGee6UtSjRj6klQjhr4k1YihL0k1Yug3iYi3RMRDEXEkIu6OiJ9ZpP1VEfH1RvuvRsRl/aq1RS1Lrj0itkVEztmO9LPeploujojxiPh2o44rl3DMJRHx5Yg4GhEPRMS23lfaso6Oam/UPfd9z4g4pU8lz9axIyK+GBFTEfFoRNwSEc9dwnGlf96XU3tVPu8R8eaI+IeImGxseyLi5Ysc0/X33NBviIhXA++hGC71AmAfcGtE/Hib9hcBHwc+AjwfuAW4JSLO7UvBT6+lo9obJoFTm7bn9LrONtZS1PuWpTSOiI3Ap4C/Ac4H3gd8OCJ+oUf1LaSj2ps8l6e/9492ua7FvBjYBbwQ2AKsBm6LiLXtDqjQ573j2huq8Hl/GHgncAFwIXAH8MmI+KlWjXv2nmemWzFs9W7g/U33hyima3hnm/afAP5yzr4vAB8cgNq3AU+U/Z63qCuBKxdp8/vAvXP2
/QXwVwNQ+yWNdieX/V7PqetZjbouXqBNZT7vy6i9kp/3Rm2PAf+hn++5Z/pARBxP8a/vZ2f3ZeZM4/7mNodtbm7fcOsC7XtimbUDnBQR34yIAxHR9myjgirxvh+jeyLikYi4PSL+TdnFAKON28cWaFPV930ptUPFPu8RsSoiXkPxbXFPm2Y9ec8N/cIzgVXA4Tn7DwPtrree0mH7XllO7fcBbwCuAF5H8Tm4KyJO61WRXdTufR+JiBNLqKcTjwBvAv5dYzsA3BkRLyiroIgYorhE9vnMvHeBplX5vP9QB7VX5vMeEedFxPcofnH7QeBVmfm1Ns178p7XbpZNQWbuoensIiLuAvYDVwO/WVZdK11m3kcRQLPuiogzgbcBv1xOVewCzgVeVNLrH4sl1V6xz/t9FH1Ro8AvAjdFxIsXCP6u80y/8B3gKWDdnP3rgENtjjnUYfteWU7tT5OZ08BXgLO6W1pPtHvfJzPz+yXUc6z+npLe94h4P/BK4CWZudjCQlX5vAMd1/40ZX7eM/MHmflAZu7NzB0UAwHe2qZ5T95zQ5/ifwSwF7h0dl/jq+OltL/etqe5fcOWBdr3xDJrf5qIWAWcR3H5oeoq8b530fn0+X2PwvuBVwEvzcwHl3BYJd73ZdY+9zmq9HkfAta0eaw373nZvddV2YBXA0eA1wPnADcCjwPrGo/fDFzf1P4iYBp4O/CTwLuAHwDnDkDtvwX8PPATFEM8Pw58H3heCbWfRBF851OMwnhb47/PaDx+PXBzU/uNwL8Af9B43/8T8CTwCwNQ+69TXFc+i+KyxPsovqVd2ue6/wR4gmL44ylN24lNbSr5eV9m7ZX4vDc+DxcDGyj+0bkemAG29PM97+tfkqpvwBjwTYpOlruBn2167E7gY3PaX0Vxje4ocC9w2SDUDry3qe0hinHvzy+p7ksagTl3+1jj8Y8Bd7Y45iuN+r8BbBuE2oHfAB5oBM53KX5r8JIS6m5Vcza/j1X9vC+n9qp83inG2z/UqONRipE5W/r9nju1siTViNf0JalGDH1JqhFDX5JqxNCXpBox9CWpRgx9SaoRQ1+SasTQl6QaMfSlHqvK8o4SGPpST1VseUfJaRikYxERzwK+CvxxZv5uY99FFPOovJxioq9XZOa5Tcf8BcWSiS/rf8WqO8/0pWOQmf9MsSrTuyLiwogYBv6MYs3iv6a6ywyqplw5SzpGmfnpiPgQ8L+AL1FM/byj8fCCyzvmYC78ogHmmb7UHf+F4iTqKuCXMvNoyfVILRn6UnecCTyb4u/Uhqb9K215Rw04L+9Ixygijgf+HPgExYIXH46I8zLzUYql7S6bc8ggL++oAefoHekYRcQfAr8IbAK+B3wOmMjMVzaGbN4L7AL+FHgp8McUI3puLalk1ZihLx2DiLgEuJ1i2cO/a+zbAOwD3pmZH2i0eS/wPOBh4L9n5sf6X61k6EtSrdiRK0k1YuhLUo0Y+pJUI4a+JNWIoS9JNWLoS1KNGPqSVCOGviTViKEvSTVi6EtSjRj6klQjhr4k1cj/B2DqTZ3iQk/XAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "axes = plt.figure(figsize=(4, 4), dpi=100).gca()\n", "historyNt = onp.asarray(historyNt)\n", "updatesNt = onp.asarray(updatesNt) \n", "axes.scatter(historyGD[:,0], historyGD[:,1], lw=0.5, color='blue')\n", "axes.scatter(historyNt[:,0], historyNt[:,1], lw=0.5, color='orange')\n", "axes.scatter([0], [0], lw=0.25, color='black', marker='x') # target at 0,0\n", "axes.set_xlabel('x0')\n", "axes.set_ylabel('x1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is not completely surprising: for this simple example we can reliably evaluate the Hessian, and Newton's method profits from the second-order information. Its trajectory is much more diagonal (the diagonal would be the ideal, shortest path to the solution), and it does not slow down as much as GD." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Physical Gradients\n", "\n", "Now we also use inverse physics, i.e. the inverse of $\\mathbf{z}$:\n", "$\\mathbf{z}^{-1}(\\mathbf{x}) = [x_0 \\ x_1^{1/2}]^T$, to compute the _physical gradient_. As a slight look-ahead to the next section, we'll use a Newton step for $L$, and combine it with the inverse physics function to get an overall update. 
This gives an update step:\n", "\n", "$$\\begin{aligned}\n", "\\Delta \\mathbf{x} &= \n", "\\mathbf{z}^{-1} \\left( \\mathbf{z}(\\mathbf{x}) - \\eta\n", " \\left( \\frac{\\partial^2 L }{ \\partial \\mathbf{z}^2 } \\right)^{-1}\n", " \\frac{\\partial L }{ \\partial \\mathbf{z} }\n", "\\right) - \\mathbf{x}\n", "\\end{aligned}$$\n", "\n", "Below, we define the inverse function `fun_z_inv_analytic` (a variant of it will follow later), and then run an optimization with the physical gradient for ten steps:\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PG iter 0: [2.1 2.5099802]\n", "PG iter 1: [1.4699999 2.1000001]\n", "PG iter 2: [1.0289999 1.7569861]\n", "PG iter 3: [0.72029996 1.47 ]\n", "PG iter 4: [0.50421 1.2298902]\n", "PG iter 5: [0.352947 1.029 ]\n", "PG iter 6: [0.24706289 0.86092323]\n", "PG iter 7: [0.17294402 0.7203 ]\n", "PG iter 8: [0.12106082 0.60264623]\n", "PG iter 9: [0.08474258 0.50421 ]\n" ] } ], "source": [ "x = np.asarray([3.,3.])\n", "eta = 0.3\n", "historyPG = [x]; historyPGz = []; updatesPG = []\n", "\n", "def fun_z_inv_analytic(y):\n", " # inverse of fun_z: keep y0, take the (positive) square root of y1\n", " return np.array( [y[0], np.power(y[1],0.5)] )\n", "\n", "for i in range(10):\n", " \n", " # Newton step for L(z)\n", " zForw = fun_z(x)\n", " GL = jax.grad(fun_L)(zForw)\n", " HL = jax.jacobian(jax.jacobian(fun_L))(zForw)\n", " HLinv = np.linalg.inv(HL)\n", " \n", " # step in z space\n", " zBack = zForw -eta * np.matmul( HLinv , GL)\n", " historyPGz.append(zBack)\n", "\n", " # \"inverse physics\" step via z-inverse\n", " x = fun_z_inv_analytic(zBack)\n", " historyPG.append(x)\n", " updatesPG.append( historyPG[-2] - historyPG[-1] )\n", " print( \"PG iter %d: \"%i + format(x) )\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we obtain $[0.085 \\ 0.504]$ as the final position, with a distance of only 0.51 to the origin! 
This is clearly better than both Newton and GD.\n", "\n", "Let's directly visualize how the PGs (in red) fare in comparison to Newton's method (orange) and GD (blue)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'x1')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAFtCAYAAADrr7rKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAi1UlEQVR4nO3df5RcZZ3n8fe3kiawSXczjppAk5ggJxgFG4SjJstBhI0iEn44MnpcZ8m6M4JLn6jrOks2jKsjMTPOWSSztJpVR8joOPwX0qMOP4aB3TUtI9E0ohFkTSRpSFCR7tYloUl/94+nCivVdauruqvqPrfu53XOPUXdem7Vk6L6U7fufe73MXdHRETyoZB2B0REpH0U+iIiOaLQFxHJEYW+iEiOKPRFRHJEoS8ikiMKfRGRHFHoi4jkyPy0O9BuZmbAqcBE2n0REWmibuApn+GK29yFPiHwD6bdCRGRFjgNGK3VII+hPwFw4MABenp60u6LiMicjY+Ps3TpUqjjCEYeQx+Anp4ehb6I5I5O5IqI5IhCX0QkRxT6IiI5otAXEckRhb6ISI4o9EVEIuLHplr6/KmGvpl9yMweMbPx4jJsZu+YYZtrzOwnZnbEzH5oZpe1q78iIq0w8eRTHLygnxd/fz5TL+/ixd+fz8EL+pl48qmmv1bae/oHgRuB84DzgfuBu8zsddUam9ka4BvAV4BzgR3ADjM7qy29FRFpsoknn+KEN51O365HmP/sMeY9N8X8Z4/RN/wIJ7zp9KYHv8U2MbqZPQt83N2/UuWxO4GF7n552brvAnvc/fo6n78HGBsbG9PFWSKSuoMX9NO36xGsShR7AUbX9HPa/95T8znGx8fp7e0F6HX38Vpt097Tf4mZzTOz9wILgeGEZquB+yrW3V1cn/S8C8ysp7QQihKJiERhyd4fVQ18AJuCJT/+UVNfL/XQN7Ozzew3wFHgi8DV7v7jhOZLgMMV6w4X1yfZCIyVLSq2JiJR8GNT2FTtoy02NdXUk7uphz7wGHAO8CbgC8AdZvbaJj7/FqC3bDmtic8tIjJrNq+AF6xmGy8YNq95UZ166Lv7C+7+hLvvdveNwAjw4YTmh4DFFesWF9cnPf9Rdx8vLaiOvohE5NCq1+EJSewFOPTa5o5TST30qygACxIeGwYuqVi3luRzACIiUev9u2/zwisXTAt+L8ALr1xA79e/1dTXS3uc/hYzu9DMlheP7W8BLgK+Xnx8e3FdyVbgUjP7mJm9xsw+SRjqeVu7+y4i0gzdy07lhYd+xuiafl582XyOnVzgxZfNZ3RNPy889DO6l53a1NdLu57+K4HtwCmEk6yPAG9393uLjy8DXjqD4e67zOx9wM3AZ4CfAle5+6Nt7bWISBN1LzuV7uKwTD82hc0rtOzkY3Tj9FtN4/RFpNNkcpy+iIi0nkJfRCRHFPoiIjmi0BcRyRGFvohIjij0RURyRKEvIpIjCn0RkRxR6IuI5IhCX0QkRxT6IiI5otAXEckRhb6ISI4o9EVEckShLyKSIwp9EZEcUeiLiOSIQl9EJEcU+iIiOaLQF5HqcjZ/dl7MT7sDIhKRyQkY2QSjQzA1CYUu6FsH
/Zuhqzvt3kkTKPRFJJicgHtWw9heYOp36x8fhMP3w9uGFfwdQId3RCQY2TQ98CHcH9sLIzel0StpMoW+iASjQ0wP/JIpGN3Zzt5Iiyj0RSSctJ2arN1malIndzuAQl9EwCyctK2l0BXaSaYp9EUk6FtHciQUoO+KdvZGWkShLyJB/2boXcX0WCiE9f03p9EraTKFvogEXd1hWObKAVi4HE7qC7crBzRcs4OY5+zEjJn1AGNjY2P09PSk3R2ReLnrGH5GjI+P09vbC9Dr7uO12mpPX0SqU+B3JIW+iEiOKPRFRHJEoS8ic5Oz84JZp9AXkcZNTMCGDbBiBSxdGm43bAjrJWqphr6ZbTSz75nZhJk9Y2Y7zOzMGbZZb2ZesRxpV59Fcm9iAlavhsFB2L8fRkfD7eBgWK/gj1rae/pvAQaBNwNrgS7gHjNbOMN248ApZcurWtlJESmzaRPs3QtTFcXZpqbC+ptUjTNmqYa+u1/q7re7+4/cfQRYDywDzpt5Uz9UthxueWdFJBgamh74JVNTsFPVOGOW9p5+pd7i7bMztFtkZj83swNmdpeZvS6poZktMLOe0gLoskKR2XKHyRmqcU6qGmfMogl9MysAtwLfcfdHazR9DPgAcCXwfsK/YZeZnZbQfiMwVrYcbFafRXLHDLpmqMbZpWqcMYsm9AnH9s8C3lurkbsPu/t2d9/j7g8C7wJ+AVyXsMkWwi+I0pL05SAi9Vi3DgoJ0VEowBWqxhmzKELfzG4DLgfe6u4N7Ym7+yTwA+CMhMePuvt4aQE0tEBkLjZvhlWrpgd/oRDW36xqnDFLe8imFQP/auBid983i+eYB5wNPN3s/olIFd3dMDwMAwOwfDn09YXbgYGwvlunzWKWapVNM/s88D7C8fnHyh4ac/fni222A6PuvrF4/xPAd4EngJOBjwNXAee5+4/reE1V2RRpJlXjTF0jVTbnt6dLiT5UvH2gYv2/B24v/vcyjp+t+feALwFLgF8Du4E19QS+iLSAAj9TVE9fRCTjVE9fpJlytmMknS3twzsicZqcgJFNMDoEU5NQ6AoTh/dv1rSBkmkKfZFKkxNwz2oY28txp5MeH4TD92u+WMk0Hd4RqTSyaXrgQ7g/thdGVFBMskuhL1JpdIjpgV8yBaMqKCbZpdAXKecejuHXMqWCYpJdCn2RcmbhpG0tBRUUk+xS6ItU6ltH8p9GAfpUUEyyS6EvUql/M/SuYvqfRyGs71dBMckuhb5Ipa7uMCxz5QAsXA4n9YXblQMarimZpzIMIjNRQTGJnMowiDSTAl86iEJfRCRHFPoiIjmi0BcRyRGFvohIjij0RURyRKEvIpIjCn0RkRxR6IuI5IhCX0QkRxT6IiI5otAXEckRhb6ISI4o9CU9OavwKhKD+Wl3QHJmcgJGNoXJx6cmw9SDfevCxCWdUqdepZglYtrTl/aZnIB7VsPjg/Db/fD8aLh9fDCsn5xIu4ezNzEBGzbAihWwdGm43bAhrBeJiCZRkfZ5eEMIeKaqPFgIM1Odv7XdvZq7iQlYvRr27oWpsn9boQCrVsHwMHR3yK8YiZImUZE4jQ5RPfAJ60d3trM3zbNp0/TAh3B/71646aZ0+iVShUJf2sM9HMOvZWoymyd3h4amB37J1BTszOiXmXQkhb60h1k4aVtLoSt7J0DdYXKGL7PJjH6ZSUdS6Ev79K0j+SNXgL4r2tmb5jCDrhm+zLoy+GUmHUuhL+3Tvxl6VzH9Y1cI6/tvTqNXc7duXThpW02hAFdk8MtMOpZG70h7TU7AyE3hpO1L4/SvCIGf1XH6Gr0jKWtk9I5CX9LTSRcxTUyEUTo7d4Zj+F1dYQ//5psV+NJymQl9M9sIvAt4DfA8sAv4L+7+2AzbXQN8GlgO/LS4zbfqfE2FvrRWJ32ZSSZkaZz+W4BB4M3AWqALuMfMFiZtYGZrgG8AXwHOBXYAO8zsrJb3VqQeCnyJWFSHd8zsFcAzwFvc/X8ltLkT
WOjul5et+y6wx92vr+M1tKcvIh0lS3v6lXqLt8/WaLMauK9i3d3F9dOY2QIz6yktgA6wikhuRRP6ZlYAbgW+4+6P1mi6BDhcse5wcX01G4GxsuXg3HoqIpJd0YQ+4dj+WcB7m/y8Wwi/IErLaU1+fhGRzIiinr6Z3QZcDlzo7jPtiR8CFlesW1xcP427HwWOlr3WHHoqItI8aQz0SnVP34LbgKuBi919Xx2bDQOXVKxbW1wvIhK1tKdeSHuc/ueB9wFXAuVj88fc/flim+3AqLtvLN5fAzwI3Ah8k3A46L8Cb5jhXEDpNTV6R0RS0aqLt7M0eudDhOPsDwBPly3vKWuzDDildMfddxG+KD4IjADvBq6qJ/BFRNIUw9QLUY3Tbwft6YtIWlasgP37kx9fvhz21XOQu0KW9vQlJjnbARBpp1imXohi9I6kaHICRjaFqQxfqnq5LpRBzmrVS5EIxTL1gvb082xyAu5ZHSYr/+1+eH403D4+GNZPtmk4gUhOxDD1gkI/z0Y2wdhepk9WPhXWj2hCb5Fm2rw5jNKpDP7S6J2b2zCPkEI/z0aHmB74JVNhohMRaZru7jAsc2AgnLTt6wu3AwPtm2tHx/Tzyj0cw69lalK14UWarLsbtm4NS+6uyJUUmYWTtrUUNKG3SCul8eel0M+zvnUkfwQKYe5aEekoCv08698MvauY/jEohPX9bTirJCJtpdDPs65ueNswrByAhcvhpL5wu3IgrNc4fZGOozIM8jtZOGmbhT6KtJnKMMjsxBqmadeiFekg2tOXuLWqFq1IB9GevnSOGGrRinQQhb7EbWhoeuCXTE3BTl01LNIIhb7EK5ZatCIdRKEv8YqlFq1IB1HoS9xiqEUr0kEU+hK3GGrRinQQhX7Wdfrx7Bhq0Yp0EI3Tz6I8T3GoK3JFpmlknL7q6WdNaYrDyhmvHh+Ew/d3fs0cBb5kUEz7Kjq8kzWa4lAkE2KtHqLDO1lz14oweXmShcvhyn3t6o2IVNHu6iEqw9CpGpniUERSE3P1EIV+lmiKQ5FMiLl6iEI/azTFoUjUYq8eotDPGk1xKBK12KuHKPSzRlMcikQv5uohGr2TdTENABYRQKN3pJUU+CLRibl6iPb0RURarNU/yLWnLyISkZh+kCv0RURypGmhb2arzOxnDW5zoZkNmdlTZuZmdtUM7S8qtqtclsyp8yIiOdHMPf0TgFc1uM1CYAS4ocHtzgROKVueaXB7EZFcqru0spndMkOTVzT64u7+beDbxedvZNNn3P25Rl8vE7I0BDNLfRURoLF6+h8G9gBJZ4YXzbk39dtjZguAR4FPuvt32vjazZelSVEmJkI1qaGhcC15V1e4EmXzZs1iJUL8+0J1D9k0s8eAT7v71xIePwfY7e7zZtURMweudvcdNdqcCVwEPAwsAP4Y+CPgTe7+/YRtFhTblnQDB6MZspk0KUqprEJMV9m2+4oTkYxIe1+oVUM2HwbOq/G4Ay39fnP3x9x9m7vvdvdd7v4BYBfw0RqbbQTGypaDrexjw7I0KUrM9WJFUlLaFxochP37YXQ03A4OhvVpT5pSqZHQ/xhwa9KD7j7i7mkMAf0X4Iwaj28BesuW09rRqbqNDjE98EumYDTFGqyVYq4XK5KSrO0L1R3S7n7I3X9uZm9NamNm1zWnWw05B3g66UF3P+ru46UFiOd7N0uTosReL1YkJVnbF5rNnvk/mtlfmdlLxUPN7OVmNgT8RSNPZGaLzOyc4vkAgBXF+8uKj28xs+1l7T9iZlea2RlmdpaZ3QpcDAzO4t+RvixNihJ7vViRFGRxX2g2of9W4Grge2b2WjN7J2EUTQ9hr7sR5wM/KC4AtxT/+8+L908BlpW1PwH478APgQeBfuDfuPs/Nf7PiESWJkWJuV6sSAqyuC80q4JrZrYI+CLwbkJi/RnwWc9A9bboCq5p9I5Ipm3YEE7aVjvEUyiEyppbt7a2D+0ouLaSsJd+EHiRcIXsv5rl
c+VbliZFiblerEhKNm8O+zyVP4JL+0I3RzaZXcN7+mZ2I/Ap4H8CHyeMnPlbwuGd97v7cLM72UzR7elXiv3KjnJZ6qtIC01MhFE6O3f+bpz+FVeEwI9tnP5sQv9p4APFEgqldV3AZ4AN7r4gceMIRB/6IpJpaewLNRL6jZRhKDnb3X9ZvsLdJ4GPm9k/zOL5REQ6Ruw/fhs+pl8Z+BWPPTi37oiIZEP8w1aq0yQqIiJ1mpgIo3VWrIClS8Pthg3xlVqoRXPktptOfopkUswjljVHbmwmJ+DhDXDXCtixNNw+vCGsF5FMyFqNnSTa02+1LF18JSKJVqwI1TOTLF8O+/a1qzfH055+TLJUOllEqspijZ0kCv1Wy1LpZBGpKos1dpIo9FspS6WTRaSmTqk3qNBvpSyVThaRRO7Zq7GTRKHfajGVTtYvCpG6VY7Jf/3r4YIL4IMfzHa9QY3eabW0R++kPWOzSAbVMyZ/0aJ4fqS3tOBa1qVycdbkRBilM7ozHMMvdIU9/P6bWx/4sV5NIhKxGGrkN0KhX0OursjN2idXJBIxj8mvRuP0Y9bO34NZm7FZJAKdNCa/GoV+p+r0T65Ii3TSmPxqFPqtknaYdvonV6RF3DtnTH41Cv1miq2wWid/ckWaqHJ45l13wcknT98nytqY/Gp0IrdZ0h6aWY1G74jMqNafycknh6GZx461f97bRuhEbhpiLKzW3R2CfWAg21eTiLRQrZLJzz0HV10FBw6E0Tpbt2b/z0Z7+s1y1wr47f7kxxcuhytTHuOlCVxEpsna8MxqtKffblkprKbAFzlOHge5KfSbQYXVRDIrb4PcFPrNElNhNRFJVDlS51e/Sm7biYPcdEy/WWIcvSMix0kaqVNNlga56Zh+Grq6Q7CvHAgnbU/qC7crBxT4IpFIGqkD4RBOd3fnD3LTnn6raKSMSHTqGanzs59l709Xe/oxyNqnRqTDTU3VN1Kn0yn0mylnv5pEYld+0nbZMjh0qHb7ThupU838tDuQeZMT4Wrc0aGyCVLWQf/m1h3H16EjkRk1ctIWOnOkTjXa05+L0oidxwfD1bjPj4bbxwfD+mYWWqscZ7ZiRbg/kVIxN5HI1TppW6kTCqnVSydy5+LhDSHgp9XbASiEkTvnN2FmKhVOE2mIO5x+eu2TtvPnw+LFcRdSq1dmTuSa2YVmNmRmT5mZm9lVdWxzkZl938yOmtkTZra+9T1NMDpE9cAnrB9t0sxUtSpC7d0LN6VQzE0kMpU/hg8cqN1+8WJ48snOKaRWr7QP7ywERoAb6mlsZiuAbwL/DJwD3Ap82cze3qL+JWtnvR1NeyhSU+nH8OBg2LsfHQ3lkGvp6kqebqKTpXoi192/DXwbwOo7MXk9sM/dP1a8v9fMLgA+Ctzdkk4maVe9nUYqQunkruRUI8fvIT8nbavJ2vfcauC+inV3F9dXZWYLzKyntADN+xHXjno7mvZQZEY7dzYW+Hk5aVtN1kJ/CXC4Yt1hoMfMTkrYZiMwVrYcbFpv+jeHujrT3sZivZ3+Jn2qNO2hyDTlx/BnOn4/bx6cempnl1eoVx7G6W8Bbim7302zgr9Ub2fkpnDS9qVx+leEwG/WOP3Nm+H++5NH7+R1l0Vyq9Ex+EuXZrO8QitkLfQPAYsr1i0Gxt39+WobuPtR4Gjpfp3nDuo3f1EYlnn+1tYdVy9Ne3jTTeF37ORkZ4wzE5mlRsfgX3GFAr8kmnH6ZubA1e6+o0abvwQuc/ezy9b9HfAyd7+0zteZ+zj9NK7CLaeTtpJzMxVOK8nLpSyNjNNPdU/fzBYBZ5StWmFm5wDPuvuTZrYF6HP3f1d8/IvAgJl9Fvgb4GLgD4F3tq3TSXXzHx+Ew/e3p4yyAl9yrJ4BbYVCOKRz5ZX6MVwp7cM75xPG3JeUjr3fAawHTgGWlR50931m9k7gc8CHCcfm/9jd2zdcc2RTlYlSCPfH9obj+824CldEqqpnQNuyZfFPZp6WVEfv
uPsD7m5VlvXFx9e7+0VVtjnX3Re4+6vd/fa2drpdV+GKSCINaJu9rA3ZTFc7r8IVkUSbN4dj9ZXBrwFtM1PoN6LVV+Hqy0KkLqUBbQMDYex9p09x2EwK/UY1+ypclUwWmZXu7lAobd++cHFW3gqnzVY0QzbbZc5DNpNG75Suwm1k9I5KJotIE2SmtHImla7CXTkAC5fDSX3hduVA48M1VTJZRNpMe/pzNZcLpWa6wmT5co07E5EZaU+/neZy0rbekskiIk2i0E+LSiaLSAoU+mnSFSYi0mYK/TTpChMRaTOFfpp0hYmItJlG78REJZNFZBY0eierFPgi0mIKfRGRHFHot1vODqeJSFwU+u2gomoiEgmdyG01FVUTkRbTidyYqKiaiEREod9qQ0PTA79kagp2anpFEWkfhX4rqaiaiERGod9KKqomIpFR6LeaiqqJSEQU+q2momoiEhGFfqupqJqIRETj9NtNRdVEpMk0Tj8W1b5QFfgikiKFfrOp5IKIREyHd5pJJRdEJAU6vJMWlVwQkcgp9JtJJRdEJHIK/WZRyQURyQCFfrOo5IKIZIBCv5lUckFEIqfQbyaVXBCRyCn0m2nRIpVcEJGoRRH6ZnaDme03syNm9pCZvbFG2/Vm5hXLkXb29ziVF2O9/vXhZO0jj8CBA7BvH2zdqsAXkSjMT7sDZvYe4BbgeuAh4CPA3WZ2prs/k7DZOHBm2f10hsQkXYw1OAj336+9exGJTgx7+v8J+JK7f9Xdf0wI//8HfKDGNu7uh8qWw23paSVdjCUiGZNq6JvZCcB5wH2lde4+Vby/usami8zs52Z2wMzuMrPX1XiNBWbWU1qA5u1662IsEcmYtPf0Xw7MAyr31A8DSxK2eYzwK+BK4P2Ef8MuMzstof1GYKxsOTjHPge6GEtEMijt0G+Yuw+7+3Z33+PuDwLvAn4BXJewyRagt2xJ+nJojC7GEpEMSjv0fwkcAxZXrF8MHKrnCdx9EvgBcEbC40fdfby0AM2rcayLsUQkY1INfXd/AdgNXFJaZ2aF4v3hep7DzOYBZwNPt6KPNeliLBHJmLT39CEM1/wTM7vWzFYBXwAWAl8FMLPtZral1NjMPmFmbzOz083sDcDXgFcBX257zzX/rYhkTOrj9N39TjN7BfDnhJO3e4BLy4ZhLgPKh8j8HvClYttfE34prCkO92y/7u5w8dXWrZr/VkSip5mzREQyTjNniYhIVQp9EZEcUeiLiOSIQl9EJEcU+iIiOaLQFxHJEYW+iEiOKPRFRHJEoS8ikiMKfRGRHFHoi4jkiEJfRCRHFPoiIjmi0BcRyRGFvohIjij0RURyRKEvIpIjCn0RkRxR6IuI5IhCX0QkRxT6IiI5otAXEckRhb6ISI4o9EVEckShLyKSIwp9EZEcUeiLiOSIQl9EJEcU+iIiOaLQFxHJEYW+iEiOKPRFRHJEoS8ikiMKfRGRHFHoi4jkSBShb2Y3mNl+MztiZg+Z2RtnaH+Nmf2k2P6HZnZZK/o1NDTEkSNHqj525MgRhoaGWvGyIiItk3rom9l7gFuATwFvAEaAu83slQnt1wDfAL4CnAvsAHaY2VnN7tvatWvZtm3btOA/cuQI27ZtY+3atc1+SRGRljJ3T7cDZg8B33P3geL9AnAA+B/u/hdV2t8JLHT3y8vWfRfY4+7X1/F6PcDY2NgYPT09M/avFPDXXXcdJ5544rT7IiJpGx8fp7e3F6DX3cdrtU11T9/MTgDOA+4rrXP3qeL91QmbrS5vX3R3UnszW2BmPaUF6G6kjyeeeCLXXXcd27Zt47nnnlPgi0imzU/59V8OzAMOV6w/DLwmYZslCe2XJLTfCPy32XYQQvBfe+21nHfeeezevVuBLyKZlfox/TbYAvSWLac1+gRHjhzhjjvuYPfu3dxxxx2JJ3dFRGKXduj/EjgGLK5Yvxg4lLDNoUbau/tRdx8vLcBEIx0sP4Z/8sknv3SoR8EvIlmUaui7+wvA
buCS0rriidxLgOGEzYbL2xetrdF+1qqdtC0/xq/gF5GsSXtPH8JwzT8xs2vNbBXwBWAh8FUAM9tuZlvK2m8FLjWzj5nZa8zsk8D5wG3N7ti9995b9aRtKfjvvffeZr+kiEhLpT5kE8DMBoCPE07G7gE2uPtDxcceAPa7+/qy9tcANwPLgZ8Cf+ru36rztRoasikiErtGhmxGEfrtpNAXkU6TmXH6IiLSXgp9EZEcUeiLiOSIQl9EJEcU+iIiOaLQFxHJkbQLrqVmfLzmqCYRkcxoJM/yOE6/DziYdj9ERFrgNHcfrdUgj6FvwKk0WHiNUIf/IKFKZ6Pbpk19T4f6no689r0beMpnCPXcHd4pviE1vwmrCd8VAEzMdMVbbNT3dKjv6chx3+tqrxO5IiI5otAXEckRhX79jgKfKt5mjfqeDvU9Hep7Dbk7kSsikmfa0xcRyRGFvohIjij0RURyRKEvIpIjCv0yZnaDme03syNm9pCZvXGG9teY2U+K7X9oZpe1q69V+lJ3381svZl5xXKknf0t68uFZjZkZk8V+3FVHdtcZGbfN7OjZvaEma1vfU+r9qOhvhf7Xfm+u5ktaVOXS/3YaGbfM7MJM3vGzHaY2Zl1bJf65302fY/l825mHzKzR8xsvLgMm9k7Ztim6e+5Qr/IzN4D3EIYLvUGYAS428xemdB+DfAN4CvAucAOYIeZndWWDh/fl4b6XjQOnFK2vKrV/UywkNDfG+ppbGYrgG8C/wycA9wKfNnM3t6i/tXSUN/LnMnx7/0zTe7XTN4CDAJvBtYCXcA9ZrYwaYOIPu8N970ohs/7QeBG4DzgfOB+4C4ze121xi17z91dSxi2+hBwW9n9AqFcw40J7e8E/qFi3XeBL2ag7+uB59J+z6v0y4GrZmjzl8CjFev+HvjHDPT9omK7k9N+ryv69Ypivy6s0Saaz/ss+h7l573Yt2eB/9DO91x7+oCZnUD49r2vtM7dp4r3Vydstrq8fdHdNdq3xCz7DrDIzH5uZgfMLHFvI0JRvO9ztMfMnjaze83sX6fdGaC3ePtsjTaxvu/19B0i+7yb2Twzey/h1+JwQrOWvOcK/eDlwDzgcMX6w0DS8dYlDbZvldn0/THgA8CVwPsJn4NdZnZaqzrZREnve4+ZnZRCfxrxNHA98AfF5QDwgJm9Ia0OmVmBcIjsO+7+aI2msXzeX9JA36P5vJvZ2Wb2G8IVt18Ernb3Hyc0b8l7nrsqmwLuPkzZ3oWZ7QL2AtcBf5ZWvzqduz9GCKCSXWb2auCjwB+l0ysGgbOAC1J6/bmoq++Rfd4fI5yL6gXeDdxhZm+pEfxNpz394JfAMWBxxfrFwKGEbQ412L5VZtP347j7JPAD4Izmdq0lkt73cXd/PoX+zNW/kNL7bma3AZcDb3X3mSYWiuXzDjTc9+Ok+Xl39xfc/Ql33+3uGwkDAT6c0Lwl77lCn/A/AtgNXFJaV/zpeAnJx9uGy9sXra3RviVm2ffjmNk84GzC4YfYRfG+N9E5tPl9t+A24GrgYnffV8dmUbzvs+x75XPE9HkvAAsSHmvNe5722etYFuA9wBHgWmAVsA34NbC4+Ph2YEtZ+zXAJPAx4DXAJ4EXgLMy0PdPAG8DTicM8fwG8Dzw2hT6vogQfOcQRmF8tPjfy4qPbwG2l7VfAfwW+Gzxff+PwIvA2zPQ948QjiufQTgscSvhV9olbe7354HnCMMfl5QtJ5W1ifLzPsu+R/F5L34eLgSWE750tgBTwNp2vudt/SOJfQEGgJ8TTrI8BLyp7LEHgNsr2l9DOEZ3FHgUuCwLfQc+V9b2EGHc+7kp9fuiYmBWLrcXH78deKDKNj8o9v//Auuz0HfgT4EnioHzK8K1Bm9Nod/V+uzl72Osn/fZ9D2WzzthvP3+Yj+eIYzMWdvu91yllUVEckTH9EVEckShLyKSIwp9EZEcUeiLiOSIQl9EJEcU+iIiOaLQFxHJEYW+iEiOKPRFWiyW
6R1FQKEv0lKRTe8oojIMInNhZq8Afgj8tbt/prhuDaGOyjsIhb7e6e5nlW3z94QpEy9tf48l77SnLzIH7v4LwqxMnzSz882sG/hbwpzF/0S80wxKTmnmLJE5cvdvmdmXgK8DDxNKP28sPlxzekfP5sQvkmHa0xdpjv9M2Im6Bvi37n405f6IVKXQF2mOVwOnEv6mlpet77TpHSXjdHhHZI7M7ATga8CdhAkvvmxmZ7v7M4Sp7S6r2CTL0ztKxmn0jsgcmdlfAe8G+oHfAA8CY+5+eXHI5qPAIPA3wMXAXxNG9NydUpclxxT6InNgZhcB9xKmPfw/xXXLgRHgRnf/QrHN54DXAgeBT7v77e3vrYhCX0QkV3QiV0QkRxT6IiI5otAXEckRhb6ISI4o9EVEckShLyKSIwp9EZEcUeiLiOSIQl9EJEcU+iIiOaLQFxHJEYW+iEiO/H8eT2QWVl9OjgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "historyPG = onp.asarray(historyPG)\n", "updatesPG = onp.asarray(updatesPG) \n", "\n", "axes = plt.figure(figsize=(4, 4), dpi=100).gca()\n", "axes.scatter(historyGD[:,0], historyGD[:,1], lw=0.5, color='blue')\n", "axes.scatter(historyNt[:,0], historyNt[:,1], lw=0.5, color='orange')\n", "axes.scatter(historyPG[:,0], historyPG[:,1], lw=0.5, color='red')\n", "axes.scatter([0], [0], lw=0.25, color='black', marker='x') # target at 0,0\n", "axes.set_xlabel('x0')\n", "axes.set_ylabel('x1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This illustrates that the PG variant does significantly better than Newton's method. It yields a trajectory that is better aligned with the _diagonal_, and its final state is closer to the origin. A key ingredient here is the inverse function for $\\mathbf{z}$, which provides information beyond the second-order approximation used by Newton's method. Despite the simplicity of the problem, Newton's method has difficulties finding the right search direction. For PG, on the other hand, the higher-order information yields an improved direction for the optimization.\n", "\n", "This difference also shows in the first update step of each method: below we measure how well each one is aligned with the diagonal."
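, "\n", "\n",
"As a small sanity check, the PG step can also be worked out by hand for this particular setup, which makes its alignment with the diagonal easy to predict. The derivation below only uses the definitions of $\\mathbf{z}$ and $L$ from above. It assumes $x_1 > 0$ (as for our starting point), so that the positive square root in `fun_z_inv_analytic` is the correct branch. With $L(\\mathbf{z}) = z_0^2 + z_1^2$ we have $\\partial L / \\partial \\mathbf{z} = 2\\mathbf{z}$ and $\\partial^2 L / \\partial \\mathbf{z}^2 = 2\\mathbf{I}$. The Newton step in $\\mathbf{z}$ space therefore simply rescales $\\mathbf{z}$, and the inverse maps the result back to $\\mathbf{x}$:\n",
"\n",
"$$\\begin{aligned}\n",
"\\mathbf{z}(\\mathbf{x}) - \\eta \\left( 2\\mathbf{I} \\right)^{-1} 2\\mathbf{z}(\\mathbf{x}) &= (1-\\eta)\\, \\mathbf{z}(\\mathbf{x}) = \\begin{bmatrix} (1-\\eta)\\, x_0 \\\\ (1-\\eta)\\, x_1^2 \\end{bmatrix} \\\\\n",
"\\mathbf{x} + \\Delta \\mathbf{x} &= \\mathbf{z}^{-1}\\big( (1-\\eta)\\, \\mathbf{z}(\\mathbf{x}) \\big) = \\begin{bmatrix} (1-\\eta)\\, x_0 \\\\ \\sqrt{1-\\eta}\\, x_1 \\end{bmatrix}\n",
"\\end{aligned}$$\n",
"\n",
"With $\\eta = 0.3$ and $\\mathbf{x} = [3 \\ 3]^T$, the first PG iterate is $[2.1 \\ \\ 3\\sqrt{0.7}] \\approx [2.1 \\ \\ 2.510]$, matching the output above. After ten steps we reach $[3 \\cdot 0.7^{10} \\ \\ 3 \\cdot 0.7^{5}] \\approx [0.085 \\ \\ 0.504]$. The first update, stored as old minus new position in `updatesPG`, is $[0.9 \\ \\ 3 - 3\\sqrt{0.7}] \\approx [0.900 \\ \\ 0.490]$. Normalizing it and projecting it onto $[1 \\ 1]^T$ gives $1.390 / 1.025 \\approx 1.356$, which is the value we expect for PG in the next cell."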
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Diagonal lengths (larger is better): GD 1.053930, Nt 1.264911, PG 1.356443 \n" ] } ], "source": [ "def mag(x):\n", " # L2 norm of a vector\n", " return np.sqrt(np.sum(np.square(x)))\n", "\n", "def one_len(x):\n", " # projection of the normalized update onto [1,1]; a perfectly diagonal direction gives sqrt(2)\n", " return np.dot( x/mag(x), np.array([1,1])) \n", "\n", "print(\"Diagonal lengths (larger is better): GD %f, Nt %f, PG %f \" % \n", " (one_len(updatesGD[0]) , one_len(updatesNt[0]) , one_len(updatesPG[0])) )\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The largest value of 1.356 for PG confirms what we've seen above: the PG update is the one most closely aligned with the diagonal direction from our starting point to the origin. For reference, a perfectly diagonal step would reach the maximum of $\\sqrt{2} \\approx 1.414$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "---\n", "\n", "
\n", "\n", "## Z Space\n", "\n", "To understand the behavior and differences of the methods here, it's important to keep in mind that we're not dealing with a black box that maps between $\\mathbf{x}$ and $L$, but rather with a chain of mappings through intermediate spaces that matter. In our case, we only have a single $\\mathbf{z}$ space, but in DL settings we might have a large number of latent spaces, over which we have a certain amount of control. We will return to NNs soon, but for now let's focus on $\\mathbf{z}$. \n", "\n", "A first thing to note is that for PG, we explicitly map from $L$ to $\\mathbf{z}$, and then continue with a mapping to $\\mathbf{x}$. Thus we directly obtain the trajectory in $\\mathbf{z}$ space, and not coincidentally, we already stored it in the `historyPGz` list above.\n", "\n", "Let's directly take a look at what PG did in $\\mathbf{z}$ space:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'z1')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAGDCAYAAAAyH0uIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcxklEQVR4nO3de5zddX3n8deHBCY+IDOIA1IbLTha0ha7XIJotgWMCtq1raigZd0lG1svFbdiH7VS7+6yrFsfDMVgTSPetongrpU+xFWLlVxQ4gVvg3dToFCuI8wZIjCJ8bN//M6Qk2FmMjM5Z37nm3k9H4/f4+R8z/d3fp/z4/DON9/zu0RmIkkqz0F1FyBJmhsDXJIKZYBLUqEMcEkqlAEuSYUywCWpUAa4JBXKAJekQhngklQoA1ySCmWAS1KhDHBJKpQBLkmFMsBVlIg4JiJyqmUf6y6NiMsi4taIGIuIeyPiuog4qaXPpoi4OSJOjoivRMTDEXFLRLx2wnsdEhHviYibIqIRET+PiK0R8ZxJtntQRPxZRAxFxCMRcV9EfD4iVkzo98rm+z0cEfdHxFUR8eT93Wc6cIWXk1VJIuJQ4OwJzQcDg8DOzDxqmnU3AC8D1gLfB54A/A5wdWZuaPbZBDwdWAx8EvgxcG6z36sy88PNfv3Ad4FPAD8BlgKvAp4KPDMzv92y3Y8Aq4HPAV9ovvfvAl/MzLXNPm8F/ltzm5uBI4E3ADuAEzNzZBa7SQtFZrq4FL0AVwC/AJ6zj34jwNp99NkEJPCmlrZDgG8B9wAHN9sWAYdMWPdw4G7gypa25zTf728m2db4AOrXmvX/1YTXjwd2TWx3cRlfnEJR0SLiPwN/Crw5M6/fR/cR4NSIeNI++v0CWDf+JDN3Np8fBZzcbNvdbB+fIjmCamT9DeCklvd6KVWAv3viRjJz/J+/L6GazvxkRPSPL1R/GfyE6i8B6TEMcBUrIk4APgh8IjMvncEqb6Ya1d4eEV+LiHdFxFMn6XdnZv58QtuPm4/HtGz//Ij4LvAI8DPgPuA/AH0t6w003+/+aep6OhBUYX3fhOU3qP7ikB5jcd0FSHMREY8HPkUVrH88k3Uy85MRsZVqDv1M4C+Av4yIl2Tm52a5/VcCHwWuAf4auBfYDVxEFdqzcRDVKP2FzfeYaMcs308LhAGu4kTEQcAGqjnn52XmQzNdNzPvAj4AfCAijgK+CbyV6gfGcU+KiEMnjMJ/vfl4a/PxZcC/AC9pmQohIiZOlWwHzoqII6YZhW+nGoHfkpk/nqKP9BhOoahE7wTOAv4oM2+ZyQoRsSgiWqc2yMx7gTuBngndFwOvaVn3kObz+4Cbms3jI+Vo6Xcq8OwJ7/WpZp93TlLT+Lr/0Hy/d7a0PdonIp6wr8+nhckRuIoSEc8A3g5sAY5qTmU8KjP/fopVlwJ3RMT/Bb5DNS3xPOAU4M8n9L2TamrlGKopmpcDJwCvzsxdzT7XUv34+OmI+CxwLPBaqsMTD2up5/qI+N/Af42IpwOfpxo4/S5wPdVRMdsj4m3AJcAxEXEN8GDzPc8G/g543wx3kRYQjwNXUSLiDKrgm1RmxmTtzVH0f6ea+34qVYj+FFiXmX/b0m8T0A+cD7wfOJHq8MG/zswrWvoF8BaqkfnRVMH9duAc4IzMPKal7yLgQvYcJ96gOlrlbZn5zZZ+L2n2O7HZdDvwz8DlTq1oMga41GI8wDPz+LprkfbFOXBJKpQBLkmFMsAlqVDOgUtSoRyBS1KhDHBJKlTRJ/I0j8V9EtVJD5J0oFhKdRG0aee4iw5wqvC+o+4iJKkDlgH/Nl2H0gP8QYDbb7+d3t7eumuRpP02OjrKk5/8ZJjBzELpAQ5Ab2+vAS5pwfFHTEkqlAEuSYUywCWpUAa4JBXKAJekQhngktQpHb7WlAEuSZ2wezesWFE9dogBLkmdsHUrDA3BDTd0bBMHxIk8ktQVhoeh0aj+PDgIu3ZVj8uWVW19fdDf37b
NOQKXpHa5+GJYvhxWroQtW6q2zZur58uXV6+3kQEuSe0yOAjr1sGiRTAyUrWNjMDixbB+ffV6GxngktROa9bAqlV7t61aBatXt31TzoFLUjtlVtMm/f0wMADbt8OmTVV7RFs35QhcktppaAh27ICNG2HbNtiwoXo+NNT2TRV9U+OI6AUajUbDy8lK6g5jY9XSmkmNBixZAj09+1x9dHSUvr4+gL7MHJ2ur1MoktROPT2PDeoqkNvOKRRJKpQBLkmFMsAlqVAGuCQVqvYAj4hfjYi/j4ifRcTDETEUESvqrkuSul2tR6FExOOBLwPXAy8E7gOeDjxQZ12SVIK6DyP8S+D2zPwvLW231FWMJJWk7imUPwC+ERH/JyLujYhvRcSfTNU5Inoiond8AZbOX6mS1F3qDvCnAq8DfgKcBfwtcHlEnD9F/4uARstyx3wUKUndqNZT6SNiJ/CNzFzZ0nY5cEpmPnuS/j1A6ylOS4E7PJVe0oGipFPp7wK+P6HtB8BLJ+ucmWPA2PjzaPOVvSSpJHVPoXwZOG5C268Dt9VQiyQVpe4AHwSeFRF/FRFPi4jzgFcDV9RclyR1vVoDPDO/DpwN/BFwM/B24I2ZuaHOuiSpBHXPgZOZ1wLX1l2HJJWm7ikUSdIcGeCSVCgDXJIKZYBLUqEMcEkqlAEuSYUywCWpUAa4JBXKAJekQhngklQoA1ySCmWAS1KhDHBJKpQBLkmFMsAlqVAGuCQVygCXpEIZ4JJUKANckgplgEtSoQxwSSqUAS5JhTLAJalQBrgkFcoAl6RCGeCSVCgDXJIKZYBLUqEMcEkqlAEuSYUywCWpUAa4JBWq1gCPiHdFRE5YflhnTZJUisV1FwB8D3hey/Nf1FWIJJWkGwL8F5l5d91FSFJpumEO/OkRcWdE/EtEbIiIp0zVMSJ6IqJ3fAGWzmOdktRV6g7wrwKrgRcArwOOBbZGxFTBfBHQaFnumIcaJakrRWbWXcOjIuJw4DbgTZl55SSv9wA9LU1LgTsajQa9vb3zU6QkddDo6Ch9fX0AfZk5Ol3fbpgDf1RmjkTEj4GnTfH6GDA2/jwi5qs0Seo6dU+h7CUiDgMGgLvqrkWSul3dx4G/LyJOj4hjImIl8GlgN/CJOuuSpBLUPYWyjCqsnwDcB9wAPCsz76u1KkkqQK0BnpmvqHP7klSyrpoDlyTNnAEuSYUywCWpUAa4JBXKAJekQhngklQoA1ySCmWAS1KhDHBJKpQBLkmFMsAlqVAGuCQVygCXpEIZ4JJUKANc0h5ddI9c7ZsBLqmyezesWFE9qggGuKTK1q0wNAQ33FB3JZqhum+pJqlOw8PQaFR/HhyEXbuqx2XLqra+Pujvr68+TcsRuLSQXXwxLF8OK1fCli1V2+bN1fPly6vX1bUMcGkhGxyEdetg0SIYGanaRkZg8WJYv756XV3LAJcWujVrYNWqvdtWrYLVq2spRzPnHLi00GVW0yb9/TAwANu3w6ZNVXtE3dVpGo7ApYVuaAh27ICNG2HbNtiwoXo+NFR3ZdqHyIIP3I+IXqDRaDTo7e2tuxypTGNj1dL6/1CjAUuWQE9PfXUtUKOjo/T19QH0ZebodH2dQpEWup6exwZ1FSDqck6hSFKhDHBJKpQBLkmFMsAlqVAGuCQVygCXpEIZ4JJUKANckgrVNQEeEW+JiIyIy+quRZJK0BUBHhGnAK8Bvlt3LZJUitoDPCIOAzYAfwI8UHM5klSM2gMcuAL4bGZ+cV8dI6InInrHF2Bp58uTpO5U68WsIuIVwEnAKTNc5SLgnZ2rSJLKUdsIPCKeDPwN8B8z85EZrnYJ0NeyLOtQeZLU9eocgZ8MHAV8M/bc9WMRcFpEXAD0ZObu1hUycwwYG38e3i1E0gJWZ4D/M/CMCW0fAX4IvHdieEuS9lZbgGfmg8DNrW0R8XPgZ5l58+RrSZLGdcNRKJKkOeiqW6pl5hl11yBJpXAELkmFMsAlqVAGuCQVygCXpEI
Z4JJUKANckgplgEtSoQxwSSqUAS5JhTLAJalQBrgkFcoAl6RCGeCSVCgDXJIKZYBLUqEMcEkqlAEuSYUywCWpUAa4JBXKAJekQhngklQoA1ySCmWAS1KhDHBJKpQBLkmFamuAR8S/i4jd7XxPSdLkOjECjw68pyRpgsWz6RwR/7CPLn1Azr0cqctkQjgmUXea7Qj894ElQGOKZUdbq5PqtHs3rFhRPUpdaFYjcOAHwKcy88rJXoyIE4AX7W9RUlfYuhWGhuCGG+D00+uuRnqM2Qb4TcBJwKQBDowB/7pfFUl1Gh6GRqP68+Ag7NpVPS5bVrX19UF/f331SS0ic+ZT1hHRAyzKzIc6V9LMRUQv0Gg0GvT29tZdjg4EF14Ia9fCEUfAzp0wMgKHHw6HHAL33w8XXFAFutQho6Oj9PX1AfRl5uh0fWc1B56ZY5n5UER8KSLeOfH1iHh8RHxppu8XEa+LiO9GxGhzuTEiXjibmqS2GhyEdetg0aIqvKF6XLwY1q83vNVV5noY4RnABRFxTUQc2tJ+CDCbycI7gLcAJwMrgC8B/xgRvzXHuqT9t2YNrFq1d9uqVbB6dS3lSFPZn+PAnwccDWyLiGPm8gaZ+ZnM/H+Z+ZPM/HFmvpXqSJZn7Udd0v7JhM2bq7nuU0+tHjdtqtqlLrI/AX4X1Wh7CPh6RJyxP4VExKKIeAVwKHDjFH16IqJ3fAGW7s82pUkNDcGOHbBxI2zbBhs2VM+HhuquTNrLbI9CGZdQzYkD50XE24DPA++d7RtFxDOoAnsJ1ej77Mz8/hTdLwIeM/cutdVxx8Ftt8H4D+Nnngm33gpLltRaljTRrI5CeXSliF8CR2fmvS1tLwU+BjwuMxfN4r0OAZ5CdRbny4A/Bk6fLMSbR8H0tDQtBe7wKBRJB4rZHIUy1xH4scB9rQ2Z+amI+CHVj5Ezlpk7gZ82n94UEacAfwa8ZpK+Y1THmgMQnuIsaQGbU4Bn5m1TtH8P+N5+VVTNy/fss5ckLXBzHYG3RURcAnyO6uzNpcB5VIconlVjWZJUhFoDHDgK+DjwK1QXw/oucFZmXldrVZJUgFoDPDNfVef2Jalk3lJNkgplgEtSoQxwSSqUAS5JhTLAJalQBrgkFcoAl6RCGeCSVCgDXJIKZYBLUqEMcEkqlAEuSYUywCWpUAa4JBXKAJekQhngklQoA1ySCmWAS1KhDHBJKpQBLkmFMsAlqVAGuCQVygCXpEIZ4JJUKANckgplgKs+mXVXIBXNAFc9du+GFSuqR0lzYoCrHlu3wtAQ3HBD3ZVIxVpcdwFaQIaHodGo/jw4CLt2VY/LllVtfX3Q319ffVJhHIFr/lx8MSxfDitXwpYtVdvmzdXz5cur1yXNmAGu+TM4COvWwaJFMDJStY2MwOLFsH599bqkGTPANb/WrIFVq/ZuW7UKVq+upRypZM6Ba35lVtMm/f0wMADbt8OmTVV7RN3VSUWpdQQeERdFxNcj4sGIuDciromI4+qsSR02NAQ7dsDGjbBtG2zYUD0fGqq7Mqk4kTWeTBERnweuAr5O9a+B/wEcD/xmZv58Buv3Ao1Go0Fvb29Ha1WbjI1VS+t/r0YDliyBnp766pK6xOjoKH19fQB9mTk6Xd9ap1Ay8wWtzyNiNXAvcDKwpY6a1GE9PY8N6urLKmmWum0OfPz/5PsnezEieoDW//uXdrwiSepSXXMUSkQcBFwGfDkzb56i20VAo2W5Y36qk6Tu0zUBDlxBNf/9imn6XEI1Sh9fls1DXZLUlbpiCiUi1gIvAk7LzClH1Zk5Boy1rDcP1UlSd6o1wKNK4PcDZwNnZOYtddYjSSWpewR+BXAe8IfAgxFxdLO9kZkP11eWJHW/uufAX0c1l70JuKtleXmNNUlSEeo+DtxJbEmao7pH4JKkOTLAJalQBrgkFcoAl6RCGeCSVCgDXJIKZYBLUqEMcEkqlAEuSYUywFWp8dZ6kubGABfs3g0rVlSPkophgAu
2bq3uCn/DDXVXImkW6r6crOoyPFzdDR5gcBB27aoelzVvctTXB/399dUnaZ8cgS9UF18My5fDypWwZUvVtnlz9Xz58up1SV3NAF+oBgdh3TpYtAhGRqq2kRFYvBjWr69el9TVDPCFbM0aWLVq77ZVq2D16lrKkTQ7zoEvZJnVtEl/PwwMwPbtsGlT1e4No6Wu5wh8IRsagh07YONG2LYNNmyong8N1V2ZpBmILPgEjojoBRqNRoPe3t66yynP2Fi1tO67RgOWLIGenvrqkhaw0dFR+vr6APoyc3S6vk6hLGQ9PY8N6uqLI6kATqFIUqEMcEkqlAEuSYUywCWpUAa4JBXKAJekQhngklQoA1ySCmWAS1KhDPBSFXwJBEntYYCXyHtYSsIAL5P3sJSEF7Mqh/ewlDRBrSPwiDgtIj4TEXdGREbEi+usp6t5D0tJE9Q9hXIo8B3g9TXX0f28h6WkCWoN8Mz8XGa+LTM/PZP+EdETEb3jC7C0wyV2F+9hKalF3SPw2boIaLQsd9RbzjxrvYflqadWj+P3sJS04JQW4JcAfS3LsnrLmWfew1JSi6KOQsnMMWBs/HkstDunH3cc3HbbnntYnnkm3HprdQ9LSQtOUQG+4HkPS0ktSptCkSQ11ToCj4jDgKe1NB0bEScA92fmv9ZTlSSVoe4plBXA9S3PL20+fgxYPe/VSFJBag3wzNwELLBfIltkwkL7IVZS2zgHXhevKChpPxngdfGKgpL2U91z4AuLVxSU1EaOwOeTVxSU1EYG+HzyioKS2sgAn29eUVBSmzgHPt9aryg4MADbt++5oqCHFEqaBUfg880rCkpqk8iCryXdvKlDo9Fo0Dt+hb5uMN1oemysWlrrbTSqKwpOvFCVpAVndHSUvuoidX2ZOTpdX0fg7bavE3R6evYOb6gOHzS8Jc2SAd5unqAjaZ74I2Y7eIKOpBo4Am8HT9CRVAMDvB08QUdSDQzwdvEEHUnzzDnwuZp4qKAn6EiaZ47A52KyQwU9QUfSPHMEPhethwqefnrVdtxxcNtte47xPvNMuPXW6gQdSeoAA3ym5nKoYHU2lSR1hFMoM+WhgpK6jAE+U5de6qGCkrqKAT4T4z9ann++hwpK6hrOgc/E+I+WW7d6qKCkruEIfCrDw1VAb9++50fL97yn+iHzfe+Da6/1UEFJtfJ64FO58EJYuxaOOAJ27qzmuw8/HA4+GB54AC64oAp2r+UtqY28Hng7THV9k4MP3vtHS6/lLakmBvh0vL6JpC7mj5jT8fomkrqYI/DpeH0TSV3MHzGn4w2IJc2z2fyI6RTKdHp6HhvUXt9EUpdwCkWSCtUVAR4Rr4+IWyPikYj4akQ8s9PbfMc73tHpTUhSR9Ue4BHxcuBS4N3AScB3gC9ExFGd2uZ9993HlVdeyfDwcKc2IUkdV3uAA28C1mfmRzLz+8BrgYeANZ3a4Pr167nzzjv50Ic+1KlNSFLH1RrgEXEIcDLwxfG2zPxl8/mzJ+nfExG94wuwdKbb2rhxIwMDAwwMDHD55ZcDcNlllz3atnHjxv39OJI0r+oegfcDi4B7JrTfAxw9Sf+LgEbLcsdMN3Teeedx1VVX0d/fzz33VJu75557OPLII7n66qs577zz5lK/JNWm7gCfrUuAvpZl2WxWPuWUUzj33HP3ajvnnHNYsWJF2wqUpPlSd4APA7uBJ05ofyJw98TOmTmWmaPjC/DgbDe4detWBgYGeMMb3sDAwABbt26dU+GSVLdaAzwzdwI3Ac8db4uIg5rPb2z39h566CGGh4e57rrruPzyy7nuuusYHh7m4YcfbvemJKnjaj+VvnkY4ceA1wBfA94InAssz8yJc+MT153VqfTjnzVaLkQ1WZsk1aWoU+kz8+qIOBJ4D9UPl98GXrCv8J6LyULa4JZUqtoDHCAz1wJr665DkkpS94+YkqQ5MsAlqVAGuCQVygCXpEIZ4JJUKANckgrVFYcR7q/R0WmPdZekYsw
mz2o/E3N/RMSvMosrEkpSQZZl5r9N16H0AA/gScz+olZLqYJ/2RzWPdC4L/ZwX+zN/bHHfO+LpcCduY+ALnoKpfnhpv0bajItp88/uK9rDRzo3Bd7uC/25v7Yo4Z9MaNt+COmJBXKAJekQi3UAB8D3t18XOjcF3u4L/bm/tijK/dF0T9iStJCtlBH4JJUPANckgplgEtSoQxwSSrUARvgEfH6iLg1Ih6JiK9GxDP30f+ciPhhs/9QRPzefNXaabPZFxGxOiJywvLIfNbbKRFxWkR8JiLubH6uF89gnTMi4psRMRYRP42I1Z2vtPNmuy+a+2Hi9yIj4uh5KrljIuKiiPh6RDwYEfdGxDURcdwM1qs9Mw7IAG/e6f5SqsN+TgK+A3whIo6aov9K4BPAlcCJwDXANRFx/LwU3EGz3RdNo8CvtCy/1uk658mhVJ//9TPpHBHHAp8FrgdOAC4DPhQRZ3Wovvk0q33R4jj2/m7c2+a66nA6cAXwLOD5wMHAP0XEoVOt0DWZkZkH3AJ8FVjb8vwgqlPu3zJF/6uBaye0bQM+WPdnqWFfrAZG6q57HvZLAi/eR5/3AjdPaLsK+Hzd9dewL85o9ju87nrnYX8c2fysp03Tpysy44AbgUfEIcDJwBfH2zLzl83nz55itWe39m/6wjT9izDHfQFwWETcFhG3R8Q/RsRvdbjUbnVAfi/207cj4q6IuC4i/n3dxXRIX/Px/mn6dMV344ALcKAfWATcM6H9HmCq+bqjZ9m/FHPZFz8C1gB/CLyS6jvylYhY1qkiu9hU34veiHhcDfXU6S7gtcBLm8vtwKaIOKnWqtosIg6imir7cmbePE3XrsiMoq9GqPbLzBuBG8efR8RXgB8ArwHeXlddqldm/ojqL/dxX4mIAeBC4D/VU1VHXAEcD/xO3YXMxIE4Ah8GdgNPnND+RODuKda5e5b9SzGXfbGXzNwFfAt4WntLK8JU34vRzHy4hnq6zdc4gL4XEbEWeBHwnMzc141iuiIzDrgAz8ydwE3Ac8fbmv8sei4tI8sJbmzt3/T8afoXYY77Yi8RsQh4BtU/oReaA/J70UYncAB8L6KyFjgbWJWZt8xgte74btT9i2+HfkV+OfAIcD7wG8A64AHgic3XPw5c0tJ/JbAL+HNgOfAuYCdwfN2fpYZ98Q7gTOCpVIcdfgJ4GPjNuj9LG/bFYVShcwLVUQYXNv/8lObrlwAfb+l/LPBz4H81vxd/CvwCOKvuz1LDvngj1e8iT6OaYriM6l93z637s7RhX3wAGKE6nPDoluVxLX26MjNq33kd/I9yAXAb1eUfvwqc2vLaJuCjE/qfQzXHNwbcDPxe3Z+hjn0BDLb0vZvqOOgT6/4MbdoPZzTDauLy0ebrHwU2TbLOt5r7Yzuwuu7PUce+AN4M/JTqL/OfUR0b/5y6P0eb9sVk+yFb/1t3a2Z4OVlJKtQBNwcuSQuFAS5JhTLAJalQBrgkFcoAl6RCGeCSVCgDXJIKZYBLUqEMcEkqlAEuzUJE/HZEbG3eB/H2iHhz3TVp4TLApRmKiF7gn6iuFXMy8BfAuyLi1bUWpgXLa6FILSLiGGCyy4luproP4sXA0VldqpeI+J9U95NcPm9FSk2OwKW93c7ed10/kerqe1uo7ne4ZTy8m74AHBcRj5/vQiUDXGqRmbsz8+7MvJvqGtEfpLpI/7uY+j6IUP79U1Ug74kpTe3DwFLg+Zn5y4ioux5pLwa4NImIeBtwFvDMzHyw2TzVfRDHX5PmlVMo0gQR8VKqW8udm5nbW166ETgtIg5uaXs+8KPMfGA+a5TAo1CkvUTE8VS3nbsUuKLlpZ1U94D8EdWhhO+lujfkh4ELM/Pv5rlUyQCXWkXEauAjk7y0OTPPiIjfpgr2U4Bh4P2Z+d55LFF6lAEuSYVyDlySCmWAS1KhDHBJKpQBLkmFMsAlqVAGuCQVygCXpEIZ4JJUKANckgplgEtSoQxwSSrU/wdQpwajiEJ
+lQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "historyPGz = onp.asarray(historyPGz)\n", "\n", "axes = plt.figure(figsize=(4, 4), dpi=100).gca()\n", "axes.set_title('z space')\n", "axes.scatter(historyPGz[:,0], historyPGz[:,1], lw=0.5, color='red', marker='*')\n", "axes.scatter([0], [0], lw=0.25, color='black', marker='*') \n", "axes.set_xlabel('z0')\n", "axes.set_ylabel('z1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For PG we're making explicit steps in $\\mathbf{z}$ space, which progress along a straight line towards the origin (which is likewise the solution in $\\mathbf{z}$ space).\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Interestingly, neither GD nor Newton's method gives us information about progress in intermediate spaces (like the $\\mathbf{z}$ space). \n", "\n", "For **GD** we're chaining the Jacobians, so we're moving in directions that locally should decrease the loss. However, the $\\mathbf{z}$ position is influenced by $\\mathbf{x}$, and hence we don't know where we end up in $\\mathbf{z}$ space until we have evaluated $\\mathbf{z}$ at the updated $\\mathbf{x}$. (For NNs in general, we won't know at which latent-space points we end up after a GD update until we've actually computed all updated weights.)\n", "\n", "More specifically, we have an update $-\\eta \\frac{\\partial L}{\\partial \\mathbf{x}}$ for GD, which means we arrive at $\\mathbf{z}(\\mathbf{x} -\\eta \\frac{\\partial L}{\\partial \\mathbf{x}})$ in $\\mathbf{z}$ space. 
A Taylor expansion with \n", "$h = \\eta \\frac{\\partial L}{\\partial \\mathbf{x}}$ yields \n", "\n", "$\n", "\\quad\n", "\\mathbf{z}(\\mathbf{x} - h) = \n", "\\mathbf{z}(\\mathbf{x}) - h \\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{x}} + \\mathcal{O}( h^2 )\n", "= \\mathbf{z}(\\mathbf{x}) - \\eta \\frac{\\partial L}{\\partial \\mathbf{z}} (\\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{x}})^2 + \\mathcal{O}( h^2 )\n", "$.\n", "\n", "And $\\frac{\\partial L}{\\partial \\mathbf{z}} (\\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{x}})^2$ clearly differs from the step $\\frac{\\partial L}{\\partial \\mathbf{z}}$ we would compute during the backpropagation pass in GD for $\\mathbf{z}$.\n", "\n", "**Newton's method** does not fare much better: we compute first-order derivatives like for GD, and second-order derivatives for the Hessian of the full process. But since both are approximations, the actual intermediate states resulting from an update step are unknown until the full chain is evaluated. In the _Consistency in function compositions_ paragraph for Newton's method in {doc}`physgrad`, the squared $\\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{x}}$ term for the Hessian already indicated this dependency.\n", "\n", "With **PGs** we do not have this problem: PGs can directly map points in $\\mathbf{z}$ to $\\mathbf{x}$ via the inverse function. Hence we know exactly where an update takes us in $\\mathbf{z}$ space: the target position in $\\mathbf{z}$ is computed first, and it is the input for evaluating the inverse.\n", "\n", "In the simple setting of this section, we only have a single latent space, and we already stored all values in $\\mathbf{x}$ space during the optimization (in the `history` lists). Hence, we can now go back and re-evaluate `fun_z` to obtain the positions in $\\mathbf{z}$ space."
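 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before re-evaluating the stored trajectories, the Taylor argument above can be checked numerically for a single GD step. The next cell is a minimal sketch that was not part of the original experiment. It redefines the toy functions locally as `z_fn` and `L_fn` (hypothetical names, assumed to be identical to `fun_z` and `fun_L` from above) and takes one GD step in $\\mathbf{x}$ from $[3 \\ 3]^T$ with $\\eta = 0.01$. It then compares the resulting point in $\\mathbf{z}$ space with two alternatives: the first-order prediction $\\mathbf{z}(\\mathbf{x}) - \\eta J J^T \\frac{\\partial L}{\\partial \\mathbf{z}}$ with $J = \\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{x}}$, and a plain GD step $\\mathbf{z}(\\mathbf{x}) - \\eta \\frac{\\partial L}{\\partial \\mathbf{z}}$ taken directly in $\\mathbf{z}$ space. Since $J$ is diagonal here, $J J^T \\frac{\\partial L}{\\partial \\mathbf{z}}$ is exactly the $\\frac{\\partial L}{\\partial \\mathbf{z}} (\\frac{\\partial \\mathbf{z}}{\\partial \\mathbf{x}})^2$ term of the expansion." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "def taylor_check(x0, eta):\n", "    # local copies of the toy functions, assumed to match fun_z and fun_L\n", "    def z_fn(x):\n", "        return jnp.array([x[0], x[1] ** 2])\n", "\n", "    def L_fn(z):\n", "        return jnp.sum(z ** 2)\n", "\n", "    z0 = z_fn(x0)\n", "    dLdz = jax.grad(L_fn)(z0)          # dL/dz evaluated at z(x0)\n", "    J = jax.jacobian(z_fn)(x0)         # dz/dx, a diagonal 2x2 matrix for this z\n", "    dLdx = J.T @ dLdz                  # chain rule: dL/dx = J^T dL/dz\n", "\n", "    z_gd = z_fn(x0 - eta * dLdx)       # where one GD step in x actually lands in z space\n", "    z_lin = z0 - eta * J @ J.T @ dLdz  # first-order Taylor prediction of z_gd\n", "    z_direct = z0 - eta * dLdz         # a GD step taken directly in z space\n", "    return z_gd, z_lin, z_direct\n", "\n", "z_gd, z_lin, z_direct = taylor_check(jnp.array([3.0, 3.0]), 0.01)\n", "print('GD step in x, resulting z:', z_gd)\n", "print('first-order prediction:   ', z_lin)\n", "print('direct GD step in z:      ', z_direct)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Worked out by hand, the GD step moves $\\mathbf{x}$ to $[2.94 \\ 1.92]^T$ and thus lands at $\\mathbf{z} \\approx [2.94 \\ 3.69]^T$. The first-order prediction is $[2.94 \\ 2.52]^T$. The gap of $h_1^2 \\approx 1.17$ in the second component is the neglected $\\mathcal{O}(h^2)$ term. A direct GD step in $\\mathbf{z}$, by contrast, would give $[2.94 \\ 8.82]^T$. The $\\mathbf{z}$ update induced by GD in $\\mathbf{x}$ therefore differs strongly from the step that $\\frac{\\partial L}{\\partial \\mathbf{z}}$ alone suggests."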
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "x = np.asarray([3.,3.])\n", "eta = 0.01\n", "historyGDz = []\n", "historyNtz = []\n", "\n", "for i in range(1,10):\n", " historyGDz.append(fun_z(historyGD[i]))\n", " historyNtz.append(fun_z(historyNt[i]))\n", "\n", "historyGDz = onp.asarray(historyGDz)\n", "historyNtz = onp.asarray(historyNtz)\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'z1')" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAGDCAYAAAA26pu1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlkElEQVR4nO3df5zcVX3v8dcnm2SjkF2FhQQbKbhC0luwYH6Ae70EIpDa670Kiliu95rGtmrF22oftdLij/aay/W2dSkNahpBbZuo9NrSVisGCwmJEGjRyoISdGuAFEiyxp1lhewuu5/7x5kh311mktndmTnfM/N+Ph7fx+R75nx3PplMPnvmfM8Pc3dERCRdc2IHICIis6NELiKSOCVyEZHEKZGLiCROiVxEJHFK5CIiiVMiFxFJnBK5iEjilMhFRBKnRC4ikjglchGRxCmRi4gkTolcRCRxSuSSFDM7zcy80nGMaxea2fVmttfMRszsgJndbmavztTZbmYPmtlyM7vbzJ41sx+Z2bun/Kz5ZvaHZna/mRXM7KdmttPMLirzunPM7DfNrM/MDpvZQTO7zcxWTKn39uLPe9bMDpnZl8zs5bN9z6T5mZaxlZSY2XHAZVOK5wG9wKi7n3yUa7cAbwE2At8DTgReC3zZ3bcU62wHzgDmArcAjwBvLdZ7p7vfXKzXBTwAfBH4AbAQeCfwCmCVu/9r5nU/B6wDvg58o/iz/xPwTXffWKzz+8D/Kr7mDuAk4H3AMHCuuw9O422SVuPuOnQkfQA3As8BFx2j3iCw8Rh1tgMOfCBTNh/4DrAfmFcsawPmT7n2JcBTwE2ZsouKP+9Py7xWqSH1s8X4f2/K82cBY1PLdeiYeqhrRZJmZv8D+A3gg+5+5zGqDwLnmdnLjlHvOWBT6cTdR4vnJwPLi2XjxfJS18kJhJb2vwCvzvysNxMS+R9MfRF3L30dvpzQzXmLmXWVDsIvhR8QfhmIVKRELskys3OAzwBfdPdPVnHJBwmt3MfN7D4z+5iZvaJMvSfc/adTyh4pPp6Wef13mNkDwGHgx8BB4D8DnZnruos/79BR4joDMELSPjjl+DnCLxCRiubGDkBkJszspcBXCAn2V6u5xt1vMbOdhD72S4HfAX7XzC53969P8/XfDnweuBX4I+AAMA5cQ0je0zGH0Gp/ffFnTDU8zZ8nLUaJXJJjZnOALYQ+6Yvd/Zlqr3X3J4FPAZ8ys5OBbwO/T7gRWfIyMztuSqv8zOLj3uLjW4B/Ay7PdJFgZlO7UPqBtWZ2wlFa5f2EFvmP3P2RCnVEKlLXiqToo8Ba4Jfd/UfVXGBmbWaW7fLA3Q8ATwDtU6rPBd6VuXZ+8fw
gcH+xuNRytky984DXTPlZXynW+WiZmErX/k3x5300U/Z8HTM78Vh/P2ltapFLUszsbODDwF3AycUujue5+19VuHQhsM/M/h/wXUJ3xcXASuC3p9R9gtDlchqh6+ZK4Bzg1919rFjnq4SblH9rZl8DTgfeTRjWeHwmnjvN7C+B/2lmZwC3ERpQ/wm4kzCKpt/MrgWuA04zs1uBp4s/8zLgz4E/rvItkhakceSSFDO7kJAAy3J3K1debFV/nNA3/gpCMv0hsMndP52ptx3oAt4B/BlwLmHY4R+5+42ZegZ8iNBSX0xI4B8GrgAudPfTMnXbgPdzZJx5gTC65Vp3/3am3uXFeucWix4H/gm4QV0ucjRK5CIZpUTu7mfFjkWkWuojFxFJnBK5iEjilMhFRBKnPnIRkcSpRS4ikjglchGRxCU9Iag4lvdlhMkTIiLNYiFhsbWq+r6TTuSEJL4vdhAiInWwBPj3aiqmnsifBnj88cfp6OiIHYuIyKwNDQ3x8pe/HKbR05B6Igego6NDiVxEWpZudoqIJE6JXEQkcUrkIiKJUyIXEUmcErmISOKUyEVakdZYaipK5CKtZmIcblsRHqUpKJGLtJqDO6HQBwd3xY5EaqQpJgSJyDEcHoCxQvjzw70wMQZ7euHFS0LZvE5Y0BUvPpkVJXKRVvDQBnhkI8w/ASZGQ9n+HbCtB0YPwZlXw/LeuDHKjKlrRaQVLO+FVZvA2mBsMJSNDYLNhVWblcQTFzWRm9leM/Myx40x4xJpSt3rYdGayWWL10D3uijhSO3E7lpZCbRlzs8Cbgf+Ok44Ik3MHQ7sgPYuOL4bhvth//ZQbhY7OpmFqC1ydz/o7k+VDuANQD+wI2ZcIk1psA+eG4aerbB2N/RsCeeDfbEjk1nKzebLZjYfeAL4pLv/7wp12oH2TNFCYF+hUNAytiLHMj4CEyMwL/N/ZbQAbQugrb3yddJQQ0NDdHZ2AnS6+1A11+TpZuebgJcAnz9KnWuAQubQ7kAi1Wprn5zEAeZ3Kok3gTwl8ncCX3f3J45S5zqgM3MsaURgIiJ5FvtmJwBm9rPAxcDlR6vn7iPASOa6OkcmIpJ/eWmR/wpwAPha7EBERFITPZGb2RxCIv+Cuz8XOx4RkdRET+SELpVTgZtjByIikqLofeTuvg1QZ7eIyAzloUUuIiKzoEQuIpI4JXIRkcQpkYuIJE6JXEQkcUrkIiKJUyIXEUmcErmISOKUyEVEEqdELiKSOCVyEZHEKZGLiCROiVxEJHFK5CIiiVMiFxFJnBK5iFTmHjsCqYISuYiUNz4OK1aER8k1JXIRKW/nTujrg127YkcixxB9qzcRyZGBASgUwp97e2FsLDwuWRLKOjuhqytefFKWWuQicsSGDbBsGfT0wF13hbIdO8L5smXheckdJXIROaK3FzZtgrY2GBwMZYODMHcubN4cnpfcUSIXkcnWr4c1ayaXrVkD69ZFCUeOTX3kIjKZe+hO6eqC7m7o74ft20O5WezopAy1yEVksr4+GB6GrVth927YsiWc9/XFjkwqME94wL+ZdQCFQqFAR0dH7HBEmsPISDiy/6cKBViwANrb48XVIoaGhujs7ATodPehaq5R14qITNbe/sKEHRKL5JS6VkREEhc9kZvZz5jZX5nZj83sWTPrM7MVseMSEUlF1K4VM3sp8C3gTuD1wEHgDOAnMeMSEUlJ7D7y3wUed/dfyZT9KFYwIiIpit218l+BfzGzvzazA2b2HTP7tUqVzazdzDpKB7CwcaGKiORT7ET+CuA9wA+AtcCngRvM7B0V6l8DFDLHvkYEKSKSZ1HHkZvZKPAv7t6TKbsBWOnurylTvx3IjotaCOzTOHIRaRYpjiN/EvjelLLvA28uV9ndR4CR0rlpurCISPSulW8BS6eUnQk8GiEWEZEkxU7kvcD5ZvZ7ZvZKM7sK+HXgxshxiYgkI2oid/d/Bi4Dfhl4EPgw8FvuviVmXCIiKYndR467fxX
4auw4RERSFbtrRUREZkmJXEQkcUrkIiKJUyKX1pTwhioiUymRS+uZGIfbVoRHkSagRC6t5+BOKPTBwV2xIxGpiejDD0Ua4vAAjBXCnx/uhYkx2NMLL14SyuZ1woKuePGJzIISubSGhzbAIxth/gkwMRrK9u+AbT0wegjOvBqW98aNUWSG1LUirWF5L6zaBNYGY4OhbGwQbC6s2qwkLklTIpfW0b0eFq2ZXLZ4DXSvixKOSK2oa0Vahzsc2AHtXXB8Nwz3w/7toVxLIkvC1CKX1jHYB88NQ89WWLsberaE88G+2JGJzErUHYJmq7hvZ0E7BElVxkdgYgTmZT4rowVoWwBt7ZWvE2mgFHcIEmmctvYXJuz5nXFiEakhda2IiCROiVxEJHFK5CIiiVMiFxFJnBK5iEjilMhFRBKnRC4ikjglchGRxCmRi4gkTolcRCRxSuQiIolTIhcRSZwSuYhI4pTIRUQSFzWRm9nHzMynHA/HjElEJDV5WI/8IeDizPlzsQIREUlRHhL5c+7+VOwgRERSlYc+8jPM7Akz+zcz22Jmp1aqaGbtZtZROoCFDYxTRCSXYifye4F1wC8C7wFOB3aaWaUEfQ1QyBz7GhCjiEiu5WrzZTN7CfAo8AF3v6nM8+1AdtPFhcA+bb4sIs0i+c2X3X3QzB4BXlnh+RFgpHRuZo0KTUQkt2J3rUxiZscD3cCTsWMREUlF7HHkf2xmq83sNDPrAf4WGAe+GDMuEZGUxO5aWUJI2icCB4FdwPnufjBqVCIiCYmayN39bTFfX0SkGeSqj1xERKZPiVxEJHFK5CIiiVMiFxFJnBK5iEjilMhFRBKnRC4ikjglchGRxCmRi4gkTolcRCRxSuRSOzla216klSiRS21MjMNtK8Jjs9IvKskpJXKpjYM7odAHB3fFjqQ+xsdhxYrwKJIzSuQyc4cH4On+cDzcCxNjsKf3SNnhgdgR1s7OndDXB7ua9BeVJC32euSSsoc2wCMbYf4JMDEayvbvgG09MHoIzrwalvfGjXE2BgagUAh/7u2FsbHwuGRJKOvshK6uePGJFKlFLjO3vBdWbQJrg7HBUDY2CDYXVm1OO4kDbNgAy5ZBTw/cdVco27EjnC9bFp4XyQElcpmd7vWwaM3kssVroHtdlHBqqrcXNm2CtjYYHAxlg4Mwdy5s3hyeF8kBJXKZHXc4sAPau+DE88Lj/u3NM8Jj/XpYM+UX1Zo1sG5dlHBEylEil9kZ7IPnhqFnK6zdDT1bwvlgX+zIasM9dKd0dcF554XH7dub5xeVNAUlcpmdjqXwpkfhlEvC+SmXwhv3hvJm0NcHw8OwdSvs3g1btoTzvib5RSVNwTzhloWZdQCFQqFAR0dH7HCkGY2MhCP7+SoUYMECaG+PF5c0raGhITo7OwE63X2omms0/FDkaNrbX5iww38ykdxQ14qISOKUyEVEEqdELiKSOCVyEZHEKZGLiCROiVxEJHG5SeRm9iEzczO7PnYsIiIpyUUiN7OVwLuAB2LHIiKSmuiJ3MyOB7YAvwb8JHI4IiLJiZ7IgRuBr7n7N49V0czazayjdAAL6x+eiEi+RZ2ib2ZvA14NrKzykmuAj9YvIhGR9ERrkZvZy4E/Bf6bux+u8rLrgM7MsaRO4YmIJCNmi3w5cDLwbTMrlbUBF5jZ1UC7u0/astzdR4CR0nnmOhGRlhUzkf8TcPaUss8BDwOfmJrERUSkvGiJ3N2fBh7MlpnZT4Efu/uD5a8SEZGp8jBqRUREZiFXG0u4+4WxYxARSY1a5CIiddKonTSVyEVE6mB8HFasCI/1pkQuIlIHO3dCXx/s2lX/18pVH7mISMoGBqBQCH/u7YWxsfC4pDh1sbMTurpq/7pqkYuI1MiGDbBsGfT0wF13hbIdO8L5smXh+XqoaSI3s18wM03kEZGW1NsLmzZBWxsMDoaywUGYOxc2bw7P10M9WuSaNy8iLWv9elizZnLZmjWwbl39XnNafeRm9jf
HqNIJNGjAjRyTO2g9GpGGcg/dKV1d0N0N/f2wfXt9/ztOt0X+X4AFQKHCMVzT6GTmJsbhthXhUUQapq8Phodh61bYvRu2bAnnfX31e83pjlr5PvAVd7+p3JNmdg7whtkGJTVwcCcU+uDgLli0OnY0Ii1j6VJ49FHo6Ajnl14Ke/fCggX1e83pJvL7CRtBlE3khCVmH5tVRDJzhwdgrDj26eFemBiDPb3w4uLYp3mdsKAOY59E5Hnt7eHI6uys72uaT2MOqZm1A23u/kz9Qqpecbu3QqFQoKP066+V3f9+eGQjzD8BJkZhbBDmvQTmzIfRQ3Dm1bC8TrfNRaQmhoaG6AyZv9Pdh6q5Zlp95O4+4u7PmNkdZvaCLdfM7KVmdsd0fqbU0PJeWLUJrC0kcQiPNhdWbVYSF2lSMx1+eCFwtZndambHZcrnA+qQjal7PSyaMvZp8RroXhclHBGpv9mMI78YWAzsNrPTahOOzJo7HNgB7V1w4nnhcf/2xi3DJiINN5tE/iSh9d0H/LOZXViLgGSWBvvguWHo2Qprd0PPlnA+WMexTyIS1UwXzXJ4fjPkq8zsWuA24BO1CkxmqGMpvOlRmFe8+XvKpfDGvdBWx7FPIhLVTBP5pPlJ7v5xM/s+8IXZhySz0tYejqz5dR77JCJRzTSRnw4czBa4+1fM7GFgxayjEhGRqs0okbv7oxXKHwIemlVEIiIyLVqPXEQkcUrkIiKJUyIXEUmcErnEp8lKIrOiRC5xjY/DihXhUURmRIlc4tq5M6y4v2tX7EhEkjXTceQiMzcwAIXiuum9vTA2Fh6XFNdN7+wM+2SJSFWitsjN7D1m9oCZDRWPe8zs9TFjkgbYsAGWLYOeHrjrrlC2Y0c4X7YsPC8iVYvdtbIP+BCwnDAj9A7g78zs56NGJfXV2wubNkFbGwwOhrLBQZg7FzZvDs+LSNWiJnJ3/wd3/0d3/4G7P+Luv0/YwPn8mHFJA6xfD2umrJu+Zg2sWxclHJGU5aaP3MzagCuA44B7KtRpB7IrQi1sQGhSD+6hO6WrC7q7ob8ftm8P5WbHvFxEjojdtYKZnW1mw4SNmz8DXObu36tQ/RqgkDn2NSZKqbm+Phgehq1bYfdu2LIlnPdp3XSR6ZrW5st1CcBsPnAq0Am8BfhVYHW5ZF6hRb5Pmy8naGQkHNl/t0IBFix44RbkIi1kJpsvR+9acfdR4IfF0/vNbCXwm8C7ytQdIbTcATB9BU9Xe/sLE3an1k0XmYnoXStlzGFyq1tERI4iaovczK4Dvg48RugmuQq4EFgbMay4dLNPRKYpdov8ZOAvgD3APwErgbXufnvUqGKZGIfbVoRHEZEqRW2Ru/s7Y75+7hzcCYU+OLgLFq2OHY2IJCL6zc6Wd3gAxorrjjzcCxNjsKcXXlxcd2ReJyzQuiMiUpkSeWwPbYBHNsL8E2BiNJTt3wHbemD0EJx5NSzXlHWRPIt9ayt2H7ks74VVm8DaYGwwlI0Ngs2FVZuVxEVyLg9L6iuR50H3elg0Zd2RxWuge12UcESkenlYUl9dK3ngDgd2QHsXHN8Nw/2wf3v872siUlbeltRXizwPBvvguWHo2Qprd0PPlnA+qHVHRPIob0vqR19rZTbMrAMoJL/WyvgITIzAvMzfYbQAbQugTZNcRfLo5pvh2mvhySePlL3sZSGJz2Y15pmstaIWeR60tU9O4gDzO5XERXIsT0vqq49cRGQG8rSkvlrkIiIzkKcl9dVHLiIyA/VaUj/J9chFRFKUpyX11bUiIpI4JXIRkcQpkYuIJE6JXEQkcUrkIiKJUyIXEUmcErmISOKUyOsp4clWIpIOJfJ6mRiH21aERxGROlIir5eDO6HQBwcjbhsyXfoGIZIkJfJaOjwAT/eH4+FemBiDPb1Hyg4PxI6wsjxsPCiSgDy2d5TIa+mhDfDVZbCtBw4Utw3ZvyOcf3VZeD6v8rDxoEjO5bW9o0ReS8t7YdUmsDYYGwxlY4N
gc2HV5vB8ngwMhEWU+/snbzxYKhvI8TcIkQjy2t5RIq+17vWwaMq2IYvXQPe6KOEcVd42HhTJoRTaO0rkteYOB3ZAexeceF543L89nx1rvb2waRO0tcHgYCgbHIS5c2Hz5vC8SItLob2jRF5rg33w3DD0bIW1u6FnSzgfjLBtSDXytPGgSA6l0N6JukOQmV0DXA4sA54F7gZ+1933VHl9/nYIGh+BiZHJmymPFqBtQT43U3aHU0+Fw4ePbDy4YAE89ljjNx4UybG3vz1s55Y9/8u/rP3rzGSHoNgt8tXAjcD5wCXAPGCbmR0XNarZaGufnMQB5nfmM4lDvjYeFMmp7EbL550XHksbLedBrvbsNLOTgAPAane/q4r6+WuRp6ZeGw+KNJEHHoDVq+GWW+CSS2DbNrjyypDcX/Wq2r5WM+zZWdrx7lC5J82sHchml4V1j6jZ5WnjQZGcWroUHn30SHvn0kth797Q3smD2F0rzzOzOcD1wLfc/cEK1a4BCpljX2OiE5FW1t4++UsrhPZOXr605iaRE/rKzwLedpQ61xFa7aVjSQPiEhHJtVx0rZjZRuANwAXuXrGV7e4jwEjmugZEJyKSb1ETuYVM/GfAZcCF7v6jmPGIiJS4pzMCN3bXyo3A24GrgKfNbHHxeFHkuESkheV1caxKYify9xD6urcDT2aOKyPGVL0cDd0UkdrJ6+JYlUTtWnH3RL64lDExDt9YBWvvgzltsaMRkVkaGAhTKGDy4lhLikMqOjvDRKA8it0iT1eKOwCJSEUpLI5ViRL5dKS8A5CIHFUKi2NVokQ+HSnvACQix5TqYqBK5NOR2g5AIjIteV8cqxIl8ulKaQcgEZmWVBcDzcXMzqRkdwA6vhuG+4/sAJTK7AERKSvvi2NVohb5dKW2A5CIHFOp6yTvi2NVkqv1yKcrynrkedoBSN8CRGZtfBxWrYL77gsjVmJLcYeg9ORlB6DU5hCL5FRqszjLUR95qrKfvtWrY0cjkpSUZ3GWo0Sekmb79IlEsmEDbNwIJ5wAo6OhrDSL89AhuPrqfE8AmkpdKylJeQ6xSI6kPIuzHCXyauTlhnCzffpEIkp1Fmc5SuTHMjEOt60Ij3nQTJ8+kQhK7bJUZ3GWo0R+LHlb5bCZPn0iDZYd7JXqLM5ydLOznMMDMFa8qZhd5fDFxZuK8zphQaSbiqVP3y23wCWXwLZtcOWVofxVr4oTk0gisoO9zj8/zVmc5WhCUDn3vx8e2QjzT4CJ0bAw1ryXwJz5MHoIzrw63gJZIyPhyP59C4Xw6cv79DORCLKDvT7wAfj7v4c3vhH+5E9CWd4Ge2lCUK3keZXDVOcQi0TSCoO9lMgr0SqHIslzb43BXkrklWRXOTzxvPBYWuVQRHIve2Oz2Qd7KZFXolUORZKWvbHZ7IO9NGqlko6l8KZHjyyQdcql8Ma9YZVDEcmlSqtYDA+H8ptvhre8pfkGeymRV9LW/sIVDed3xolFRKpSaQ2Vu++GZ5+Fb30rJPKUhxqWo66VSprlO5dIizjajc158154Y7OZBnspkZfTyGn5+oUhMmutdGOzHCXycho1LV+bQ4jURCvd2CxHibzk8AA83R+O7LT8Utnhgdq/ZjNsTSISgXu4sdnfH47sjc1//MdwY/PTn05/DZVqRb3ZaWYXAL8DLAdOAS5z91ujBPPQhsnT8gH274BtPbWdlq/NIURmpbTH5mtfC5/6VOvd2Cwndov8OOC7wHsjx9G4afmtMF9YpI5KX2Qvv7w1b2yWEzWRu/vX3f1ad//bmHE8rxHT8lthvrBIjVXqRlm9GlaunFy32W9slpPUOHIzaweyv1cX1vQFstPyj++G4f4j0/LNavc669fDHXeEzruSVvz0iVTpaHtsHjgQuk1+4RdCoi/d2Kzlf9m8i921Ml3XAIXMsa+mP71R0/Jb8ba6yCxU+iLrDi96UViatlVubJaTWiK/DujMHEtq+tNL0/JPuSScl6bldyyt6cs
01dYkIg1Sbnz4xRfDU0+FPVbgyI3NpTX+L5t3udlYwsycaY5aqdvGEvWmzSFEps0dTj0VDh+G7u7QjbJgATz2WHN1o2hjiVRocwiRadMX2cpijyM/Hnhlpuh0MzsHOOTuj8WJSkTyaOnS5tljs9Zij1pZAdyZOf9k8fELwLqGR1NLrXbbXKTO2ttf+KW1UwuSApETubtvB5ov25Wmnt13X7jNLiJSR+ojrwetoSIiDRS7a6V5aA0VEYlELfJa0RoqIhKJEnmtaA0VEYlEibyWWnFrEhGJTn3ktZRdQ6U09awVV/ARkYZSi7yWNPVMRCLIzVorMxF9rZWpLW2toSIis6S1Vhqp3MbJWkNFRCJQIp8pTfoRkZzQzc7p0KQfEckhtcinQ5N+RCSHlMinQ5N+RCSHlMirVRrdo0k/IpIzSuTVyI5Q0cbJIpIzSuTVyI5Q0aQfEckZJfJKBgbCFPv+/skjVObOhTvvhHPPDfVaddtuEckNJfJKPv7x8iNULroIVq6cPEJFk35EJCIl8nLGx0N3yqc/rREqIpJ7SuTllPrEzzhDI1REJPc0s7Ok0qzNe++Fl74UTj8dHntMy9KKSO6oRV5SbtbmHXfA/v0wNAQXXKARKiKSS2qRl/T2wtlnw7XXHukTf/ppWLwYrrvuSHfK3r1hWVoRkZxQizyr3KzNiy+e3CeuESoikjNqkWdpqzYRSZBa5FmatSkiCdJWb1naqk1EIpvJVm/qWslqb39hwg5vqIhIbuWia8XM3mtme83ssJnda2ar6v2aH/nIR+r9EiIiDRE9kZvZlcAngT8AXg18F/iGmZ1cr9c8ePAgN910EwMDA/V6CRGRhomeyIEPAJvd/XPu/j3g3cAzwPp6veDmzZt54okn+OxnP1uvlxARaZioidzM5gPLgW+Wytx9onj+mjL1282so3QAC6t9ra1bt9Ld3U13dzc33HADANdff/3zZVu3bp3tX0dEJIrYLfIuoA3YP6V8P7C4TP1rgELm2FftC1111VV86Utfoquri/37w8vt37+fk046iS9/+ctcddVVM4lfRCS62Il8uq4DOjPHkulcvHLlSt761rdOKrviiitYsWJFzQIUEWm02Il8ABgHFk0pXwQ8NbWyu4+4+1DpAJ6e7gvu3LmT7u5u3ve+99Hd3c3OnTtnFLiISF5ETeTuPgrcD7yuVGZmc4rn99T69Z555hkGBga4/fbbueGGG7j99tsZGBjg2WefrfVLiYg0TPSZncXhh18A3gXcB/wW8FZgmbtP7Tufeu20ZnaW/q6WWTelXJmISCxJzux09y+b2UnAHxJucP4r8IvHSuIzUS5ZK4GLSOqiJ3IAd98IbIwdh4hIimLf7BQRkVlSIhcRSZwSuYhI4pTIRUQSp0QuIpI4JXIRkcTlYvjhbA0NVTVmXkQk92aSz6LP7JwNM/sZprECoohIQpa4+79XUzH1RG7Ay5j+4lkLCb8Alszg2jxQ/HEp/rhaIf6FwBNeZYJOumul+Jes6jdWVmZa/tPVrmWQJ4o/LsUfV4vEP62/l252iogkTolcRCRxrZrIR4A/KD6mSPHHpfjjUvxTJH2zU0REWrdFLiLSNJTIRUQSp0QuIpI4JXIRkcQ1bSI3s/ea2V4zO2xm95rZqmPUv8LMHi7W7zOzX2pUrBXiqTp+M1tnZj7lONzIeKfEc4GZ/YOZPVGM5U1VXHOhmX3bzEbM7Idmtq7+kVaMZVrxF2Of+v67mS1uUMjZWK4xs382s6fN7ICZ3WpmS6u4Lhef/5nEn6fPv5m9x8weMLOh4nGPmb3+GNfM+r1vykRuZlcCnyQM8Xk18F3gG2Z2coX6PcAXgZuAc4FbgVvN7KyGBPzCeKYVf9EQcErm+Nl6x3kUxxFifm81lc3sdOBrwJ3AOcD1wGfNbG2d4juWacWfsZTJ/wYHahxXNVYDNwLnA5cA84BtZnZcpQty9vmfdvxFefn87wM+BCwHVgB3AH9nZj9frnLN3nt
3b7oDuBfYmDmfQ5jK/6EK9b8MfHVK2W7gM4nEvw4YjP2+V4jNgTcdo84ngAenlH0JuC2R+C8s1ntJ7HjLxHZSMbYLjlInV5//GcSf289/Mb5DwDvr+d43XYvczOYTfht+s1Tm7hPF89dUuOw12fpF3zhK/bqZYfwAx5vZo2b2uJlVbAHkVG7e/1n6VzN70sxuN7P/GDuYos7i46Gj1Mnz+19N/JDDz7+ZtZnZ2wjf8O6pUK0m733TJXKgC2gD9k8p3w9U6rNcPM369TST+PcA64E3Am8n/LvebWZL6hVkjVV6/zvM7EUR4pmuJ4F3A28uHo8D283s1TGDMrM5hG6qb7n7g0epmqfP//OmEX+uPv9mdraZDRNmbn4GuMzdv1ehek3e+6RXP5TA3e8h8xvfzO4Gvg+8C/hwrLhahbvvISSTkrvNrBt4P/Df40QFhL7ms4DXRoxhNqqKP4ef/z2Eez2dwFuAL5jZ6qMk81lrxhb5ADAOLJpSvgh4qsI1T02zfj3NJP5J3H0M+A7wytqGVjeV3v8hd382Qjy1cB8R338z2wi8AbjI3Y+1+UqePv/AtOOfJPbn391H3f2H7n6/u19DuHH+mxWq1+S9b7pE7u6jwP3A60plxa9or6NyP9U92fpFlxylft3MMP5JzKwNOJvwlT8FuXn/a+gcIrz/FmwELgPWuPuPqrgsN+//DOOf+jPy9vmfA7RXeK42733sO7p1ukt8JXAYeAfwc8Am4CfAouLzfwFcl6nfA4wBvw0sAz4GjAJnJRL/R4BLgVcQhit+EXgW+A+R4j+ekMjOIYw4eH/xz6cWn78O+ItM/dOBnwL/t/j+/wbwHLA2kfh/i9A/+0pCV8D1hG9Vr4sQ+6eAQcIwvsWZ40WZOrn9/M8w/tx8/oufjQuA0wi/TK4DJoBL6vneN/w/SQPf0KuBRwk3HO4Fzss8tx34/JT6VxD6tkaAB4FfSiV+oDdT9ynCmOxzI8Z+YTEBTj0+X3z+88D2Mtd8p/h36AfWpRI/8EHgh8Xk8WPCePiLIsVeLm7Pvp95/vzPJP48ff4J48H3FmM5QBiRckm933stYysikrim6yMXEWk1SuQiIolTIhcRSZwSuYhI4pTIRUQSp0QuIpI4JXIRkcQpkYuIJE6JXEQkcUrkItNgZq8ys53F/RUfN7MPxo5JRIlcpEpm1gFsI6zrsRz4HeBjZvbrUQOTlqe1VkQyzOw0oNzSqTsI+ytuABZ7WG4YM/s/hD09lzUsSJEp1CIXmexxJu/Gfi5hRcO7CPso3lVK4kXfAJaa2UsbHahIiRK5SIa7j7v7U+7+FGFd7M8QFvn/GJX3V4TI+1tKa9OenSKV3QwsJKwnPWFmseMRKUuJXKQMM7sWWAuscveni8WV9lcsPScShbpWRKYwszcTtg97q7v3Z566B7jAzOZlyi4B9rj7TxoZo0iWRq2IZJjZWYSt9T4J3Jh5apSwD+cewhDETxD257wZeL+7/3mDQxV5nhK5SIaZrQM+V+apHe5+oZm9ipDgVwIDwJ+5+ycaGKLICyiRi4gkTn3kIiKJUyIXEUmcErmISOKUyEVEEqdELiKSOCVyEZHEKZGLiCROiVxEJHFK5CIiiVMiFxFJnBK5iEji/j+amfIAeNJExgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "axes = plt.figure(figsize=(4, 4), dpi=100).gca()\n", "axes.set_title('z space')\n", "axes.scatter(historyGDz[:,0], historyGDz[:,1], lw=0.5, marker='*', color='blue')\n", "axes.scatter(historyNtz[:,0], historyNtz[:,1], lw=0.5, marker='*', color='orange')\n", "axes.scatter(historyPGz[:,0], historyPGz[:,1], lw=0.5, marker='*', color='red')\n", "axes.scatter([0], [0], lw=0.25, color='black', marker='*') \n", "axes.set_xlabel('z0')\n", "axes.set_ylabel('z1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These trajectories confirm the intuition outlined in the previous sections: GD (in blue) yields a clearly sub-optimal trajectory in $\\mathbf{z}$. Newton's method (in orange) does better, but its trajectory is still clearly curved, in contrast to the straight, diagonal red trajectory of the PG-based optimization.\n", "\n", "The behavior in intermediate spaces becomes especially important when they're not just abstract latent spaces as in this example, but have an actual physical meaning." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusions \n", "\n", "That concludes our simple example. Despite its simplicity, it already showed surprisingly large differences between gradient descent, Newton's method, and physical gradients.\n", "\n", "The main takeaways of this section are:\n", "* GD easily yields \"unbalanced\" updates\n", "* Newton's method does better, but is far from optimal\n", "* PGs outperform both if an inverse function is available\n", "* The choice of optimizer strongly affects progress in latent spaces\n", " \n", "In the next sections we can build on these observations to use PGs for training NNs via invertible physical models."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "---\n", "\n", "## Approximate inversions\n", "\n", "If an analytic inverse like the `fun_z_inv_analytic` above is not readily available, we can resort to optimization schemes like Newton's method or BFGS to approximate it numerically. This topic is orthogonal to the comparison of different optimization methods, but it can be easily illustrated based on the PG example above.\n", "\n", "Below, we'll use `fmin_l_bfgs_b` from `scipy`, which implements L-BFGS-B, a limited-memory variant of BFGS, to compute the inverse. The approach itself is simple, but since scipy works on regular numpy arrays, the code switches from JAX to plain numpy, which makes it a bit messier than it should be." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BFGS optimization test run, find x such that y=[2,2]:\n" ] }, { "data": { "text/plain": [ "array([2.00000003, 1.41421353])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def fun_z_inv_opt(target_y, x_ini):\n", " # a bit ugly, we switch to pure scipy here inside each iteration for BFGS\n", " from scipy.optimize import fmin_l_bfgs_b\n", " target_y = onp.array(target_y)\n", " x_ini = onp.array(x_ini)\n", "\n", " def fun_z_opt(x,target_y=[2,2]):\n", " y = onp.array( [x[0], x[1]*x[1]] ) # plain numpy version of fun_z for use with scipy\n", " ret = onp.sum( onp.square(y-target_y) )\n", " return ret\n", " \n", " # approx_grad=True: scipy approximates the gradient via finite differences\n", " ret = fmin_l_bfgs_b(lambda x: fun_z_opt(x,target_y), x_ini, approx_grad=True )\n", " # print(ret) # uncomment to show the full BFGS result\n", " return ret[0]\n", "\n", "print(\"BFGS optimization test run, find x such that y=[2,2]:\")\n", "fun_z_inv_opt([2,2], [3,3])\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now use this numerically inverted $\\mathbf{z}$ function to perform the PG optimization. Apart from calling `fun_z_inv_opt`, the rest of the code is unchanged."
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PG iter 0: [2.09999967 2.50998022]\n", "PG iter 1: [1.46999859 2.10000011]\n", "PG iter 2: [1.02899871 1.75698602]\n", "PG iter 3: [0.72029824 1.4699998 ]\n", "PG iter 4: [0.50420733 1.22988982]\n", "PG iter 5: [0.35294448 1.02899957]\n", "PG iter 6: [0.24705997 0.86092355]\n", "PG iter 7: [0.17294205 0.72030026]\n", "PG iter 8: [0.12106103 0.60264817]\n", "PG iter 9: [0.08474171 0.50421247]\n" ] } ], "source": [ "x = np.asarray([3.,3.])\n", "eta = 0.3\n", "history = [x]; updates = []\n", "\n", "for i in range(10): \n", " # same as before, Newton step for L(y)\n", " y = fun_z(x)\n", " GL = jax.grad(fun_L)(y)\n", " y += -eta * np.matmul( np.linalg.inv( jax.jacobian(jax.jacobian(fun_L))(y) ) , GL)\n", "\n", " # optimize for inverse physics, assuming we don't have access to an inverse for fun_z\n", " x = fun_z_inv_opt(y,x)\n", " history.append(x)\n", " updates.append( history[-2] - history[-1] )\n", " print( \"PG iter %d: \"%i + format(x) )\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nice! It works just like the PG optimization with the analytic inverse. There's little point in plotting the trajectory, as it is essentially identical, but let's measure the difference. Below, we compute the mean absolute error (MAE) between the two trajectories in $\\mathbf{x}$, which for this simple example turns out to be negligible (on the order of $10^{-6}$)." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MAE difference between analytic PG and approximate inversion: 0.000001\n" ] } ], "source": [ "historyPGa = onp.asarray(history)\n", "updatesPGa = onp.asarray(updates) \n", "\n", "print(\"MAE difference between analytic PG and approximate inversion: %f\" % (np.average(np.abs(historyPGa-historyPG))) )\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "---\n", "\n", "## Next steps\n", "\n", "Based on this code example, you can try the following modifications:\n", "\n", "- Instead of the simple $L(\\mathbf{z}(\\mathbf{x}))$ function above, try other, more complicated functions.\n", "\n", "- Replace the simple \"regular\" gradient descent with another optimizer, e.g., commonly used DL optimizers such as AdaGrad, RMSprop or Adam. Compare the resulting trajectories with the ones above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }