{ "cells": [ { "cell_type": "markdown", "metadata": { "cell_id": "9F60A887F43F4A678C6B6447F3327B83" }, "source": [ "# Implicit functions in PyTorch\n", "\n", "*Thomas Viehmann, tv@lernapparat.de*\n", "\n", "Sometimes, we do not know the explicit mapping of a function we wish to apply, but only an equation that describes it. In mathematical terms, we wish to apply $f : \\mathbb{R}^n \\rightarrow \\mathbb{R}^m$, of which we only know that $F(x,f(x))=0$ for some function $F : \\mathbb{R}^n \\times \\mathbb{R}^m \\rightarrow \\mathbb{R}^m$.\n", "\n", "This is the realm of the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem).\n", "Under reasonable conditions (smoothness, the right sort of nondegenerate derivatives), you have the following:\n", "Given $x,y$ such that $F(x,y) = 0$, there is a neighborhood $U\\ni x$ and a function $f$ on $U$ such that $f(x)=y$ and $F(x',f(x'))=0$ for all $x' \\in U$. And if $F$ is nice enough, we can also compute the derivative of $f$ at $x$, namely $\\frac{df}{dx}(x) = - (\\frac{dF}{dy}(x,y))^{-1} \\frac{dF}{dx}(x,y)$.\n", "\n", "There is an example below.\n", "\n", "Nice. Computing the entire Jacobian matrix with respect to $y$ is not that straightforward in PyTorch, though, so we will stick with just the scalar case.\n", "\n", "Let us not be lazy about it and get our hands dirty. First, we import everything we need." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cell_id": "5802BF8F7F724585BEFEA2528B60BBD2", "collapsed": true }, "outputs": [], "source": [ "import torch\n", "import numpy\n", "from matplotlib import pyplot\n", "from mpl_toolkits.mplot3d import Axes3D\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "F3ABBD4BB63E42C18685C2BDDDA7421E" }, "source": [ "Now we can turn to the implicit function. To evaluate $f(x)=y$, we need to actually search for the solution $y$ of $F(x,y)=0$. We do this by minimizing $(F(x,y))^2$, which also requires a starting point $y_0$ near which to look. (Indeed, in the circle example below, there are two solutions, and this would typically become a problem when they are close to each other.)\n", "\n", "We use PyTorch's LBFGS optimizer (`torch.optim.LBFGS`) for a limited number of iterations. Note that LBFGS also has other stopping criteria (tolerances on the gradient and on the change between iterations), so `max_iter` really is just an upper bound: meeting one of the other criteria can be considered success, while hitting the maximal number of iterations might be considered a failure to find the minimum. The scipy optimization documentation has more details. For LBFGS to work, you need to pass a function that re-evaluates $F^2$ and its gradient as the `closure` argument of the `step` call.\n", "\n", "~~As we need to call `backward` on `F` for the gradient computation and I ran into locking problems when doing so in `Implicit`'s backward method, we compute the required $\\frac{dF}{dx}$ and $\\frac{dF}{dy}$ in the forward. In fact we precompute $-\\frac{df}{dx}$ and save it to the context `ctx`. In the `backward`, we wrap the saved result in a variable and multiply by the `output_grad` to do our step in the backpropagation.~~ PyTorch is ever improving: for the 0.4 update, I moved the derivative calculation into the backward.\n", "\n", "We compute the required $\\frac{dF}{dx}$ and $\\frac{dF}{dy}$ by calling `backward` on $F$ inside `Implicit`'s `backward` method, and then return $-\\frac{dF}{dx} / \\frac{dF}{dy}$ times the incoming `output_grad`. I don't know if all the `detach` calls are needed or if some of the tensors are already detached, but I'm going to play it safe here."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cell_id": "3DB49B74E00149CE847EF4F7F5E625AA", "collapsed": true }, "outputs": [], "source": [ "class Implicit(torch.autograd.Function):\n", "    @staticmethod\n", "    def forward(ctx, x, y0, F, max_iter=100):\n", "        # forward runs with autograd disabled, so re-enable it for the inner optimization\n", "        with torch.enable_grad():\n", "            y = y0.clone().detach().requires_grad_()\n", "            xv = x.detach()\n", "            opt = torch.optim.LBFGS([y], max_iter=max_iter)\n", "            def reevaluate():\n", "                opt.zero_grad()\n", "                z = F(xv,y)**2\n", "                z.backward()\n", "                return z\n", "            opt.step(reevaluate)\n", "        ctx._the_function = F\n", "        y = y.detach()  # the solution, cut off from the optimizer's graph\n", "        ctx.save_for_backward(x, y)\n", "        return y\n", "\n", "    @staticmethod\n", "    def backward(ctx, output_grad):\n", "        x, y = ctx.saved_tensors\n", "        F = ctx._the_function\n", "        with torch.enable_grad():\n", "            xv = x.detach().requires_grad_()\n", "            y = y.detach().requires_grad_()\n", "            z = F(xv,y)\n", "            z.backward()\n", "        # implicit function theorem (scalar y): df/dx = -(dF/dx) / (dF/dy)\n", "        return -xv.grad/y.grad*output_grad, None, None, None\n" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "331983CBCD354721B4F41718553388C3" }, "source": [ "So now that we have a cool new `autograd` function, let us apply it to an example (I must admit that I just took it from the Wikipedia page; the application I have in mind needs more context).\n", "\n", "The unit circle in the plane can be described by the equation $x^2+y^2 = 1$ or, equivalently, $F(x,y) := x^2+y^2-1 = 0$. So let us define $F$." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cell_id": "1A05EB10B00F41B6876792D5E32F3608", "collapsed": true }, "outputs": [], "source": [ "def circle(x,y):\n", "    return x**2+y**2-1" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "1F30165C0DD34AF28F2DFC0732845D1F" }, "source": [ "Everybody loves pictures, so let us plot $F$ (the blue grid). The black curve is the circle $F(x,y)=0$."
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cell_id": "9CF19E5055C841DDA0E7A7A7955C72A5" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x = numpy.linspace(-1.2,1.2,30)\n", "y = numpy.linspace(-1.2,1.2,30)\n", "t = numpy.linspace(0, 2*numpy.pi)\n", "xx, yy = numpy.meshgrid(x, y, indexing='ij')\n", "\n", "fig = pyplot.figure()\n", "ax = fig.add_subplot(111, projection='3d')\n", "ax.plot_wireframe(xx,yy,circle(xx,yy),cmap=pyplot.cm.coolwarm_r,\n", "                  linewidth=0.01, antialiased=False)\n", "ax.plot(numpy.cos(t),numpy.sin(t), 0, linewidth=3, c=\"black\")\n", "ax.view_init(elev=40., azim=30)\n" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "678E9C7603434BC68617DD76EE238D5E" }, "source": [ "Now we can pick a point $x$, say $\\frac{1}{2}$, and seek the matching $y=f(x)$ on the circle, starting from $y_0 = \\frac{1}{2}$. We know that the exact answer is $y = f(x) = \\sqrt{1-x^2} = \\frac{\\sqrt{3}}{2}$. Of course, we would run into trouble for $x$ close to $1$ or $-1$: there $y$ approaches $0$, so $\\frac{dF}{dy} = 2y$ nearly vanishes and the two solutions $\\pm\\sqrt{1-x^2}$ come close together." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "cell_id": "1D879F5FE3D14C618F9AA2ED54FC2C1E" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8660251031723489 0.8660254037844386\n" ] } ], "source": [ "x = torch.tensor([0.5], dtype=torch.double, requires_grad=True)\n", "y0 = torch.tensor([0.5], dtype=torch.double)\n", "y = Implicit.apply(x, y0, circle)\n", "print(y.item(), (1-0.5**2)**0.5)" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "B75F8DDBE73D4A0BB15F41914462984C" }, "source": [ "Works! Let us compute the derivative. We can do that by hand: the implicit function theorem with $\\frac{dF}{dx}(x,y) = 2x$ and $\\frac{dF}{dy}(x,y) = 2y$ gives $\\frac{df}{dx}(x) = -\\frac{2x}{2y} = -\\frac{x}{y}$, which at $x = \\frac{1}{2}$ is $-\\frac{1}{\\sqrt{3}} \\approx -0.577$. That is about the level of technical complication I can handle.\n", "\n", "If we didn't like using the implicit function theorem, we would have to differentiate the explicit solution, $\\frac{d}{dx} f(x) = \\frac{d}{dx} \\sqrt{1-x^2} = \\frac{1}{2} \\frac{-2x}{\\sqrt{1-x^2}}$, and substituting $y = \\sqrt{1-x^2}$ back in, we see that $\\frac{d}{dx} f(x) = -\\frac{x}{y}$.\n", "\n", "But of course, we can also let autograd do its thing:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "cell_id": "E9A5A59E194646BE9DAE7B9C44721367" }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1]), torch.Size([1]))" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.backward()\n", "x.grad.shape, y.shape" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "4AACABC131B4409880261154E7AA8A17" }, "source": [ "Awesome. The cell above only confirms that the gradient has the right shape; the value stored in `x.grad` is what should match $-\\frac{x}{y}$. We can also use PyTorch's automated gradient checker, `torch.autograd.gradcheck` (and indeed this is why I used double-precision tensors: the finite-difference comparison it performs is too strict to pass reliably with single-precision floats). It seems that gradcheck does not like the 1d $x$, so the lambda below passes `x.unsqueeze(0)` instead." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "cell_id": "74596AC8D005455F851D5B8DE48DC51E" }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.autograd.gradcheck(lambda x: Implicit.apply(x.unsqueeze(0),y0,circle), x)" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "54903BD23B58404A889C1A2C31C311B1" }, "source": [ "So this toy example works great. But is the implicit function really useful?\n", "For one thing, the limitation to a scalar $y$ (and hence a scalar-valued $F$) is quite severe.\n", "\n", "Quite likely, it is only worth it for very specific use cases. In general, more common alternatives are:\n", "- Manually solving for the implicit function.\n", "- Adding a free parameter $y$ and introducing a penalty $\\lambda F(x,y)^2$ into the loss function.\n", "\n", "However, when there is just this last transformation you want to add and a free $y$ isn't a good option for you, this might be nice to have.\n", "\n", "I hope you enjoyed this little feature; your feedback is very welcome. I read and appreciate every email you send me.\n", "\n", "Thomas" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "0F9A83A4955547DF85BDC676EBA4CBD0" }, "source": [ "P.S.: And here is a **thing to try** based on a recent [PyTorch issue](https://github.com/pytorch/pytorch/issues/7698) (thank you for the idea!):\n", "\n", "HIPS autograd has a [fixed point with parameter function](https://github.com/HIPS/autograd/blob/master/autograd/misc/fixed_points.py).\n", "Write a PyTorch equivalent. Note:\n", "- Careful with the naming: $x$ in the HIPS autograd fixed point function is $y$ in the implicit function above, and $a$ is $x$.\n", "- In the forward, you replace the LBFGS minimization by the simpler fixed point iteration, but to compute the derivative, you can still write the fixed point equation $f(x,y)=y$ as the implicit function problem $F(x,y) := f(x,y) - y = 0$.\n", "- In particular, you get to store the result, and I don't think you need to iterate again in the backward." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5rc1" } }, "nbformat": 4, "nbformat_minor": 2 }