{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = 'svg'\n", "import numpy\n", "from mpl_toolkits.mplot3d import Axes3D\n", "from matplotlib import cm\n", "from matplotlib.ticker import LinearLocator, FormatStrFormatter\n", "import matplotlib.pyplot as plt\n", "import torch\n", "from torch.autograd import Variable, grad" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SGD: 0.0922205299139\n", "Newton: 0.0\n", "Optimal: 0\n" ] } ], "source": [ "def sgd(func, x_init, lr=1e-2, n_iter=100):\n", " x = Variable(x_init, requires_grad=True)\n", " xy = []\n", " for i in range(n_iter):\n", " y = func(x)\n", " xy.append([x.data[0], x.data[1], y.data[0]])\n", " dx = grad(y, x)[0]\n", " x.data -= lr * dx.data\n", " return numpy.array(xy)\n", "\n", "\n", "def newton(func, x_init, n_iter=100):\n", " x = Variable(x_init, requires_grad=True)\n", " xy = []\n", " for i in range(n_iter):\n", " y = func(x)\n", " xy.append([x.data[0], x.data[1], y.data[0]])\n", " dx = grad(y, x, create_graph=True)[0]\n", " ddx0 = grad(dx[0], x, retain_graph=True)[0]\n", " ddx1 = grad(dx[1], x)[0]\n", " ddx = torch.stack((ddx0, ddx1))\n", " iddx = ddx.data.inverse()\n", " x.data -= iddx.mv(dx.data)\n", " return numpy.array(xy)\n", " \n", "\n", "\n", "def f(x):\n", " return (x[0] - 1)**2 + 10 * (x[0]**2 - x[1])**2\n", "\n", "\n", "x_init = torch.FloatTensor([0, 1])\n", "sgd_path = sgd(f, x_init)\n", "x_init = torch.FloatTensor([0, 1])\n", "nt_path = newton(f, x_init)\n", "print(\"SGD:\", sgd_path[-1, -1])\n", "print(\"Newton:\", nt_path[-1, -1])\n", "print(\"Optimal:\", f([1, 1]))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def annotate(ax, point, message, xytext=(-20, 20)):\n", " ax.annotate(message, xy=point, xytext=xytext, \n", " textcoords='offset points', ha='center', va='bottom',\n", " bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=1.0),\n", " arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.5')) \n", "\n", "def plot(f, point_dict):\n", " N = 500\n", " x1 = numpy.linspace(0, 1, N)\n", " x2 = numpy.linspace(0, 1, N)\n", "\n", " X1, X2 = numpy.meshgrid(x1, x2)\n", " X = numpy.c_[numpy.ravel(X1), numpy.ravel(X2)]\n", " Y_plot = f([X1, X2])\n", " Y_plot = Y_plot.reshape(X1.shape)\n", "\n", " fig = plt.figure()\n", " ax = fig.gca() # projection='3d')\n", " ax.contour(X1, X2, Y_plot, 20)\n", " for k, points in point_dict.items():\n", " ax.plot(points[:, 0], points[:, 1])\n", " ax.scatter(points[:, 0], points[:, 1], points[:, 2],\n", " label=k, marker=\"X\", linewidths=10, alpha=0.5)\n", " annotate(ax, points[-1, :2], k + \"-end\")\n", " # point optimal point\n", " annotate(ax, [0, 1], \"initial\")\n", " annotate(ax, [1, 1], \"optimal\", (20, -20))\n", " ax.legend(loc=\"lower left\")\n", " plt.savefig(\"./plot.png\")\n", "\n", "plot(f, {\"SGD\": sgd_path, \"Newton\": nt_path})" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }