{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'svg'\n",
"import numpy\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"from matplotlib import cm\n",
"from matplotlib.ticker import LinearLocator, FormatStrFormatter\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from torch.autograd import Variable, grad"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SGD: 0.0922205299139\n",
"Newton: 0.0\n",
"Optimal: 0\n"
]
}
],
"source": [
"def sgd(func, x_init, lr=1e-2, n_iter=100):\n",
"    \"\"\"Minimize `func` with plain full-batch gradient descent.\n",
"\n",
"    Despite the name there is no sampling: every step uses the exact gradient.\n",
"\n",
"    Args:\n",
"        func: callable expected to map the 1-D tensor `x` to a single-element tensor.\n",
"        x_init: starting point; it is copied, so the caller's tensor is not modified.\n",
"        lr: step size.\n",
"        n_iter: number of gradient steps.\n",
"\n",
"    Returns:\n",
"        numpy array of shape (n_iter, len(x_init) + 1). Row i is\n",
"        [*x_i, func(x_i)], recorded *before* update i, so the point produced\n",
"        by the final update is not included.\n",
"    \"\"\"\n",
"    # Clone: Variable shares storage with the tensor it wraps, and the in-place\n",
"    # update below would otherwise overwrite the caller's x_init.\n",
"    x = Variable(x_init.clone(), requires_grad=True)\n",
"    xy = []\n",
"    for _ in range(n_iter):\n",
"        y = func(x)\n",
"        # .tolist()/.item() instead of x.data[0]/y.data[0]: on current PyTorch\n",
"        # y is 0-dim and indexing it raises IndexError.\n",
"        xy.append(x.data.tolist() + [y.item()])\n",
"        dx = grad(y, x)[0]\n",
"        x.data -= lr * dx.data\n",
"    return numpy.array(xy)\n",
"\n",
"\n",
"def newton(func, x_init, n_iter=100):\n",
"    \"\"\"Minimize `func` with Newton's method, using autograd for the Hessian.\n",
"\n",
"    Args:\n",
"        func: callable expected to map the 1-D tensor `x` to a single-element\n",
"            tensor; its Hessian must be invertible at every visited point.\n",
"        x_init: 1-D starting point of any length; it is copied, so the caller's\n",
"            tensor is not modified.\n",
"        n_iter: number of Newton steps.\n",
"\n",
"    Returns:\n",
"        numpy array of shape (n_iter, len(x_init) + 1). Row i is\n",
"        [*x_i, func(x_i)], recorded *before* update i.\n",
"    \"\"\"\n",
"    # Clone so the in-place update below cannot leak into the caller's tensor.\n",
"    x = Variable(x_init.clone(), requires_grad=True)\n",
"    dim = x.data.numel()\n",
"    xy = []\n",
"    for _ in range(n_iter):\n",
"        y = func(x)\n",
"        xy.append(x.data.tolist() + [y.item()])\n",
"        # create_graph=True keeps the gradient differentiable so that each of\n",
"        # its components can be differentiated again, giving one Hessian row.\n",
"        dx = grad(y, x, create_graph=True)[0]\n",
"        rows = [grad(dx[i], x, retain_graph=True)[0] for i in range(dim)]\n",
"        hessian = torch.stack(rows).data\n",
"        x.data -= hessian.inverse().mv(dx.data)\n",
"    return numpy.array(xy)\n",
" \n",
"\n",
"\n",
"def f(x):\n",
"    \"\"\"Rosenbrock-style objective with a=1, b=10; its minimum is 0 at (1, 1).\n",
"\n",
"    Only x[0] and x[1] are read, so `x` may be a tensor, a list of numbers, or\n",
"    a pair of equally shaped numpy arrays (evaluated elementwise).\n",
"    \"\"\"\n",
"    u, v = x[0], x[1]\n",
"    return (u - 1)**2 + 10 * (u**2 - v)**2\n",
"\n",
"\n",
"# Give each optimiser a freshly built starting tensor so the two runs stay\n",
"# independent even if an optimiser updates its input in place.\n",
"x_init = torch.FloatTensor([0, 1])\n",
"sgd_path = sgd(f, x_init)\n",
"\n",
"x_init = torch.FloatTensor([0, 1])\n",
"nt_path = newton(f, x_init)\n",
"\n",
"# Objective value at the last recorded iterate of each method, then at the optimum.\n",
"for label, trajectory in ((\"SGD:\", sgd_path), (\"Newton:\", nt_path)):\n",
"    print(label, trajectory[-1, -1])\n",
"print(\"Optimal:\", f([1, 1]))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def annotate(ax, point, message, xytext=(-20, 20)):\n",
"    \"\"\"Label `point` on `ax` with `message` in a yellow box joined by a curved arrow.\n",
"\n",
"    `xytext` is the offset of the text from `point`, measured in points.\n",
"    \"\"\"\n",
"    box = dict(boxstyle='round,pad=0.2', fc='yellow', alpha=1.0)\n",
"    arrow = dict(arrowstyle='->', connectionstyle='arc3,rad=0.5')\n",
"    ax.annotate(\n",
"        message,\n",
"        xy=point,\n",
"        xytext=xytext,\n",
"        textcoords='offset points',\n",
"        ha='center',\n",
"        va='bottom',\n",
"        bbox=box,\n",
"        arrowprops=arrow,\n",
"    )\n",
"\n",
"def plot(f, point_dict, n_grid=500, xlim=(0, 1), ylim=(0, 1), levels=20,\n",
"         initial=(0, 1), optimal=(1, 1), fname=\"./plot.png\"):\n",
"    \"\"\"Draw optimiser trajectories over a contour plot of `f` and save the figure.\n",
"\n",
"    Args:\n",
"        f: objective evaluated on the meshgrid pair, i.e. f([X1, X2]).\n",
"        point_dict: maps a method name to its trajectory array; columns 0 and 1\n",
"            hold the iterate coordinates (the layout returned by sgd/newton).\n",
"        n_grid: number of grid points per axis for the contour plot.\n",
"        xlim, ylim: plotted ranges of x[0] and x[1].\n",
"        levels: number of contour levels.\n",
"        initial, optimal: points labelled as the start and the optimum.\n",
"        fname: file the figure is saved to.\n",
"    \"\"\"\n",
"    x1 = numpy.linspace(xlim[0], xlim[1], n_grid)\n",
"    x2 = numpy.linspace(ylim[0], ylim[1], n_grid)\n",
"    X1, X2 = numpy.meshgrid(x1, x2)\n",
"    Y_plot = numpy.asarray(f([X1, X2])).reshape(X1.shape)\n",
"\n",
"    fig, ax = plt.subplots()\n",
"    ax.contour(X1, X2, Y_plot, levels)\n",
"    for k, points in point_dict.items():\n",
"        line, = ax.plot(points[:, 0], points[:, 1])\n",
"        # On 2-D axes the third positional argument of scatter is the marker\n",
"        # size (apparently a leftover from a 3-D version), so pass only the\n",
"        # coordinates; reuse the line colour so each method has one colour.\n",
"        ax.scatter(points[:, 0], points[:, 1], color=line.get_color(),\n",
"                   label=k, marker=\"X\", linewidths=10, alpha=0.5)\n",
"        annotate(ax, points[-1, :2], k + \"-end\")\n",
"    annotate(ax, initial, \"initial\")\n",
"    annotate(ax, optimal, \"optimal\", (20, -20))\n",
"    ax.set_xlabel(\"x[0]\")\n",
"    ax.set_ylabel(\"x[1]\")\n",
"    ax.legend(loc=\"lower left\")\n",
"    fig.savefig(fname)\n",
"\n",
"plot(f, {\"SGD\": sgd_path, \"Newton\": nt_path})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}