{ "cells": [ { "cell_type": "markdown", "metadata": { "cell_id": "00004-493eb01f-6b39-4f1c-96ea-0059bb582301", "output_cleared": false, "tags": [] }, "source": [ "# Project 2 - TMA4215\n", "
\n", "\n", "## Introduction\n", "This project attempts to learn Hamiltonian functions using machine learning. In particular, we implement a\n", "neural network trained with the Adam algorithm, with an option for stochastic gradient descent (SGD).\n", "The network is trained and tested on a set of given example functions, and thereafter trained on trajectories\n", "of Hamiltonian systems. Using the trained weights, the model then attempts to compute unknown trajectories.\n", "\n", "Section 1 will cover the declarations and implementations of the functions needed for generating synthetic data.\n", "\n", "Section 2 will implement the neural network and the training algorithm with the Adam method and SGD.\n", "The section also covers testing the network trained on the synthetic data and comparing it to the true values,\n", "as well as training and testing on the trajectory data.\n", "\n", "Section 3 will first discuss how to calculate the gradient of a trained function, which is necessary for\n", "the implementation of numerical integrators. Thereafter, two numerical integrators,\n", "the symplectic Euler method and the Størmer-Verlet method, are implemented for both known and unknown Hamiltonians." ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00002-b2c81420-14e5-426f-8e09-f8d5c655c487", "output_cleared": false, "tags": [] }, "source": [ "We start off by importing the necessary libraries and data. Note that the trajectories used throughout the entire project are the new trajectory data."
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cell_id": "00001-75128b47-cf19-481e-896c-e74ea6c4a6b2", "execution_millis": 1, "execution_start": 1604311509902, "output_cleared": false, "source_hash": "a22a0809", "tags": [] }, "outputs": [], "source": [ "# Imports\n", "import numpy as np # NumPy library for numerical computation in Python\n", "import matplotlib.pyplot as plt # Main plotting package\n", "\n", "# Necessary for surface plotting ______\n", "from matplotlib import cm # Color maps\n", "from matplotlib import colors \n", "from mpl_toolkits.mplot3d import Axes3D\n", "# _____________________________________\n", "\n", "import intervals as interval # Used for defining domains of closed intervals \n", "import sys # For printing and flushing\n", "import time # Used to estimate the remaining training time\n", "%matplotlib notebook \n", "import project_2_data_acquisition as dataFile # The given .ipynb for importing trajectory data, converted to .py" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00004-c977a480-cbcc-4120-ac47-088d2d6f8516", "output_cleared": false, "tags": [] }, "source": [ "## Section 1 - Generating synthetic data \n", "The following code defines the class infrastructure used to generate synthetic data from the given\n", "example functions. We also define a `scaler` class that min-max scales data into an interval $[\\alpha, \\beta]$\n", "(by default $[0.2, 0.8]$), a common preprocessing step in machine learning.\n", "\n", "These classes are later used to define `f1`, `f2`, `f3` and `f4` as function objects based on $F(y)$\n", "from the table on page 6 of the project description."
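, "\n",
"One way to check the min-max scaling used by `scaler` is a minimal standalone sketch. It does not use the classes above, and the helper names `minmax_scale` and `minmax_rescale` are hypothetical. Data in $[a, b]$ is mapped onto $[\\alpha, \\beta] = [0.2, 0.8]$ (the defaults of `scaler`), and the inverse $c = a + (s - \\alpha)\\frac{b - a}{\\beta - \\alpha}$ is algebraically the same map as `scaler.rescale`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def minmax_scale(c, alpha=0.2, beta=0.8):\n",
"    # Map c linearly from [min(c), max(c)] onto [alpha, beta] (same formula as scaler.scalefunc)\n",
"    a, b = np.amin(c), np.amax(c)\n",
"    return ((b - c) * alpha + (c - a) * beta) / (b - a), (a, b)\n",
"\n",
"def minmax_rescale(s, a, b, alpha=0.2, beta=0.8):\n",
"    # Inverse map from [alpha, beta] back onto [a, b]\n",
"    return a + (s - alpha) * (b - a) / (beta - alpha)\n",
"\n",
"c = np.array([-2.0, 0.0, 1.0, 2.0])\n",
"s, (a, b) = minmax_scale(c)\n",
"print(s)                        # every value lies in [0.2, 0.8]\n",
"print(minmax_rescale(s, a, b))  # recovers c up to rounding\n",
"```\n",
"\n",
"A plausible reason for scaling into $[0.2, 0.8]$ rather than $[0, 1]$ is that the hypothesis function $\\eta(x) = \\frac{1}{2}(1 + \\tanh(\\frac{x}{2}))$ only takes values in $(0, 1)$, so targets at the endpoints could never be reached exactly."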
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cell_id": "00004-d7addd86-3bbb-4af7-af87-52dfc65a8634", "execution_millis": 0, "execution_start": 1604334723803, "output_cleared": false, "source_hash": "1a32d6db", "tags": [] }, "outputs": [], "source": [ "\n", "class domain:\n", " '''\n", " This class allows us to define multi-dimensional domains, and is used\n", " for the example data generating functions.\n", " The domain class is also used for the known Hamiltonians in Section 3.\n", " \n", " dim : dimension, d0 in the table\n", " intervals : intervals as described in table\n", " excluded : intervals excluded from the domain (one per dimension), e.g. a square around (0,0) for the last function in the table\n", " '''\n", " def __init__(self, intrvals):\n", " self.dim = len(intrvals)\n", " self.intervals = intrvals\n", " self.excluded = [interval.empty() for i in range(self.dim)]\n", " \n", " def __contains__(self, item):\n", " '''\n", " Overloads the *in* operator, i.e. checks if item is in self.\n", " A point is excluded only if all of its coordinates lie in the excluded intervals.\n", " '''\n", " if np.isscalar(item):\n", " if self.dim != 1: return False\n", " return item in self.intervals[0]\n", " \n", " if len(item) != self.dim: return False\n", " contains = True\n", " notInExcluded = False\n", " for i in range(self.dim):\n", " contains *= (item[i] in self.intervals[i])\n", " notInExcluded += (item[i] not in self.excluded[i]) \n", " contains *= (notInExcluded > 0)\n", " return contains\n", "\n", " def __sub__(self, dom_to_remove):\n", " '''\n", " Remove subdomains from a domain.\n", " This allows a much simpler notation when setting domains, e.g.\n", " domain([interval.closed(-1,1)]) - domain([interval.closed(-0.1, 0.1)]) for [-1,1] minus [-0.1,0.1]\n", " '''\n", " if dom_to_remove.dim != self.dim:\n", " raise Exception(\"Error: Domain dimension mismatch in subtract\")\n", " new_dom = domain(self.intervals)\n", " for i in range(self.dim):\n", " new_dom.excluded[i] = dom_to_remove.intervals[i]\n", "\n", " if new_dom.intervals == new_dom.excluded:\n", " return domain([interval.empty() 
for i in range(new_dom.dim)])\n", " return new_dom\n", "\n", " def drawRandom(self, n):\n", " '''\n", " Draw n random samples from within the domain, with a uniform probability distribution\n", " (rejection sampling: samples in the excluded region are discarded and redrawn)\n", " '''\n", " samples = []\n", " while len(samples) < n:\n", " sample = [np.random.uniform(self.intervals[k].lower, self.intervals[k].upper) for k in range(self.dim)]\n", " if sample in self:\n", " samples.append(sample)\n", " \n", " return np.array(samples).T\n", "\n", " def drawGrid(self, shape):\n", " '''\n", " Draw a grid of sample coordinates from the domain\n", "\n", " shape : number of grid points along each dimension (tuple of length d0)\n", "\n", " returns : one linspace per dimension\n", " -Note- these points might be in the excluded domain\n", " '''\n", " linspaces = tuple([np.linspace(self.intervals[k].lower, self.intervals[k].upper, shape[k]) for k in range(self.dim)])\n", "\n", " return linspaces\n", " \n", "\n", "class datGenFunctions:\n", " '''\n", " Class of example functions for generating data, as well as for known Hamiltonians\n", "\n", " d0 : input dimension\n", " d : hidden layer dimension, usually 2*d0\n", " F : function F(y), implemented inline with anonymous lambda functions\n", " dom : domain (domain object)\n", " '''\n", " def __init__(self, d0, d, F, dom):\n", " self.d0 = d0\n", " self.d = d\n", " self.F = F\n", " self.dom = dom\n", "\n", " def __call__(self, *y):\n", " '''\n", " Calls the function, and raises an exception if y is outside of the domain\n", "\n", " returns : F(y)\n", " '''\n", " if np.isscalar(y) or np.isscalar(y[0]):\n", " if y not in self.dom:\n", " raise Exception(\"Error: y = \"+str(y)+\" outside of function domain\")\n", " \n", " else: \n", " yy = np.array(y)\n", " if not np.isscalar(y[0][0]):\n", " shape = list(yy.shape)\n", " yy = np.reshape(yy, (shape[0], np.prod(shape[1:])))\n", " for row in yy.T:\n", " if row not in self.dom:\n", " raise Exception(\"Error: y = \"+str(row)+\" outside of function domain\")\n", "\n", " return self.F(*y)\n", " \n", "\n", "\n", "class scaler:\n", " '''\n", " This class stores the scale 
parameters, so that a rescaling can be performed easily\n", "\n", " ci : data to scale\n", " [alpha, beta] : range to scale the data into\n", " '''\n", " def __init__(self, ci, alpha=0.2, beta=0.8):\n", " # Store the scale parameters\n", " self.a = np.amin(ci)\n", " self.b = np.amax(ci)\n", " self.alpha = alpha\n", " self.beta = beta\n", " # Min-max scaling from [a, b] onto [alpha, beta], as described on page 6 of the project description\n", " self.scalefunc = (lambda c : 1/(self.b-self.a)*((self.b-c)*self.alpha+(c-self.a)*self.beta))\n", "\n", " def scale(self, ci):\n", " '''Scale data'''\n", " return self.scalefunc(ci)\n", "\n", " def rescale(self, ci):\n", " '''\n", " Rescale data, the inverse function of scale\n", " '''\n", " return (self.a - self.b)/(self.alpha - self.beta) * ci + (self.alpha*self.b - self.a*self.beta)/(self.alpha - self.beta)\n", "\n", " def gradientRescale(self, ci):\n", " '''\n", " Rescaler needed for rescaling the gradient:\n", " the linear part of the inverse of scalefunc, without the constant shift\n", " '''\n", " return (self.a - self.b)/(self.alpha - self.beta) * ci\n", "\n", "\n", "\n", "def genData(funcObj, numInput=250):\n", " '''\n", " Generates random data within the domain of the given function\n", "\n", " returns : the scaler for c, the inputs y and the scaled values c\n", " '''\n", " y = funcObj.dom.drawRandom(numInput)\n", " c = funcObj(*tuple(y))\n", " sc_c = scaler(c)\n", " return sc_c, y, sc_c.scale(c)\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00006-865e67b1-04a8-4031-853c-5fa13b83b5fe", "output_cleared": false, "tags": [] }, "source": [ "## Section 2 - Implementation of the neural network\n", "### 2.1 - Foundations and infrastructure\n", "We start by defining the necessary functions as given in the project description, that is\n", "$$\n", "\\eta(x),\\quad \\eta'(x),\\quad \\sigma(x), \\quad \\sigma'(x), \\quad \\tilde{F}(Z_k, \\omega, \\mu, n),\\quad \\frac{\\partial J}{\\partial \\mu},\\quad \\frac{\\partial J}{\\partial \\omega}.\n", "$$\n", "An
additional `embed` function is defined, which embeds data points of dimension $d_0$ into $\\mathbb{R}^d$ by padding them with zeros.\n", "\n", "We chose $\\eta(x) = \\frac{1}{2} (1 + \\tanh(\\frac{x}{2}))$ as the hypothesis function simply because it gave us the most\n", "consistent results. The unknown Hamiltonian in particular proved to be more prone to divergence when using the identity function as\n", "hypothesis function." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cell_id": "00005-1d793141-510a-462f-943e-e4d91b71a162", "execution_millis": 4, "execution_start": 1604336520135, "output_cleared": false, "source_hash": "7612eefd", "tags": [] }, "outputs": [], "source": [ "def eta(x):\n", " '''\n", " Hypothesis function\n", " '''\n", " return 0.5*(1+np.tanh(x/2)) # Scaled tanh, used instead of the identity eta(x) = x\n", " #return x \n", "\n", "\n", "def eta_derivative(x):\n", " '''\n", " Derivative of hypothesis function\n", " '''\n", " return (1/4)*(1-np.square(np.tanh(x/2))) # Derivative of the scaled tanh above\n", " #return np.ones(x.shape)\n", "\n", "\n", "def F_tilde(Z_k, omega, mu, n_samples):\n", " '''Network output eta(Z_k^T omega + mu) for all n_samples columns of Z_k'''\n", " return eta(np.transpose(Z_k)@omega+mu*np.ones((n_samples, 1)))\n", "\n", "\n", "def P_k(omega, Z_k, mu, c, n_samples):\n", " '''P for the last layer, i.e. the gradient of J with respect to Z_K'''\n", " ypsilon = eta(Z_k.T@omega + mu*np.ones((Z_k.T@omega).shape))\n", " return np.outer(omega, ((ypsilon-c)*eta_derivative(Z_k.T@omega + mu*np.ones((Z_k.T@omega).shape))).T)\n", "\n", "\n", "def dJdMu(Z_k, omega, mu, c, n_samples):\n", " '''Gradient of J with respect to mu'''\n", " z_minus_c = F_tilde(Z_k, omega, mu, n_samples)-c\n", " eta_arg = (np.transpose(Z_k)@omega + mu*np.ones((n_samples, 1))).T\n", " return eta_derivative(eta_arg)@z_minus_c\n", "\n", "\n", "def dJdOmega(Z_k, omega, mu, c, n_samples):\n", " '''Gradient of J with respect to omega'''\n", " return Z_k@((F_tilde(Z_k, omega, mu, n_samples)-c)*eta_derivative(np.transpose(Z_k)@omega + mu))\n", "\n", "\n", "def sigma(x):\n", " '''Activation function sigma(x) = tanh(x)'''\n", " return np.tanh(x)\n", "\n", "def sigmaDerivative(x):\n", " '''Derivative of activation function'''\n", " 
return 1-np.square(np.tanh(x))\n", "\n", "\n", "def Z_k(Z_k_minus_1, h, W_k_minus_1, b_k_minus_1):\n", " '''Forward step through one layer: Z_k = Z_{k-1} + h*sigma(W_{k-1} Z_{k-1} + b_{k-1})'''\n", " return Z_k_minus_1 + h*np.tanh(W_k_minus_1@Z_k_minus_1 + b_k_minus_1)\n", "\n", "\n", "def P_k_minus_1(P_k, h, W_k, Z_k_minus_1, b_k):\n", " '''Backpropagation step through one layer: computes the previous P from P_k'''\n", " return P_k + h*np.transpose(W_k)@(sigmaDerivative(W_k@Z_k_minus_1 + b_k)*P_k)\n", "\n", "\n", "def embed(y, d):\n", " '''This function embeds the input data in R^d by setting all extra coordinates to 0'''\n", " y = np.array(y)\n", " return np.concatenate((y, np.zeros((d-y.shape[0],y.shape[1]))))\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00008-d03d644e-4114-416a-8bd5-fcbe37c78063", "output_cleared": false, "tags": [] }, "source": [] }, { "cell_type": "markdown", "metadata": { "cell_id": "00008-82ab56f3-0ebb-4c52-ab5f-ee2044e0dad8", "output_cleared": false, "tags": [] }, "source": [ "### 2.2 - Network class and training implementation\n", "The `network` class is the main class for storing information about a neural network.\n", "The class stores all learning parameters and the parameters for Adam descent, and defines the member functions\n", "`train`, `showConvergence`, `getParams`, `checkTraining`, `grad`, along with some helper functions\n", "(see the docstrings in the code). The addition operator is also overloaded so that one can sum two networks, e.g. $T$ and $V$.\n", "\n", "\n", "The `train` function contains an implementation of the Adam algorithm and an option for stochastic gradient descent (SGD),\n", "i.e. an optional input parameter $\\bar{I} \\ll I$ such that a different random subset of $\\bar{I}$ of the points $(y_i)_{i=1}^{I}$\n", "is used in each iteration, instead of all $I$ points.\n", "\n", "`showConvergence` draws the residue plots, i.e. 
$J/I$ over the training iterations.\n", "`getParams` returns the trained parameters and is only used for initializing the tester class,\n", "which is explained further on.\n", "`checkTraining` simply confirms that a network has been trained, and runs before a network is tested.\n", "`grad` implements the gradient of the trained function, that is $\\nabla_y \\tilde{F}(y)$.\n", "See *Section 3* for a detailed discussion of how this result is derived." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cell_id": "00010-53bce103-6b21-4c77-a942-8ed4fa094dda", "execution_millis": 7, "execution_start": 1604311509971, "output_cleared": false, "source_hash": "c68ed114", "tags": [] }, "outputs": [], "source": [ "class network:\n", " '''\n", " Network class - main class for storing values and weights.\n", " Init Parameters:\n", " K : Number of layers (int)\n", " h : Step length (float)\n", " tau: Learning parameter (float); stored, but the Adam update in train uses self.alpha as its step size\n", " d : Number of dimensions for each layer (int)\n", " d0 : Number of dimensions in input layer (int)\n", " iterations : number of iterations (int)\n", " hasTrained : toggles when training is done (bool)\n", " We also define necessary parameters for Adam descent method;\n", " these values are universal for all network objects.\n", " '''\n", " def __init__(self, K, h, tau, d, d0, iterations):\n", " self.K = K\n", " self.h = h\n", " self.tau = tau\n", " self.d = d\n", " self.d0 = d0\n", " self.iterations = iterations\n", " self.hasTrained = False\n", "\n", " # Params for Adam descent\n", " self.m = [0, 0, 0, 0]\n", " self.v = [0, 0, 0, 0]\n", " self.beta_1 = 0.9\n", " self.beta_2 = 0.999\n", " self.alpha = 0.01\n", " self.epsilon = 1e-8 \n", " \n", " def train(self, y, c, I_bar=False):\n", " '''\n", " Input: y-values, c-values and I_bar (optional); the last one enables stochastic gradient descent\n", " This function trains the network as described in the pseudocode of the Adam method,\n", " and in addition has an optional implementation of stochastic gradient 
descent.\n", " Toggles the hasTrained bool/flag to True and saves trained weights when completed.\n", " '''\n", " if not I_bar:\n", " I_bar = len(c)\n", "\n", " n_samples = len(c)\n", " \n", " b = np.random.randn(self.K, self.d, 1)\n", " W = np.random.randn(self.K, self.d, self.d)\n", " mu = np.random.rand()\n", " omega = np.random.randn(self.d,1)\n", "\n", " Z = np.zeros((self.K+1, self.d, n_samples))\n", " Z[0] = np.array(embed(y, self.d)[:, :n_samples])\n", " c = np.expand_dims(c[:n_samples], axis=1)\n", "\n", " J = []\n", " dJdW = np.zeros((self.K, self.d, self.d))\n", " dJdb = np.zeros((self.K, self.d, 1))\n", " J_grad = [dJdW, dJdb, 0, 0]\n", " U = [W, b, mu, omega]\n", "\n", " startTime = time.time() # Starting a timer for keeping track of training progress\n", "\n", " sys.stdout.write(\"Training\")\n", " sys.stdout.flush()\n", " for n in range(1, self.iterations+1):\n", " sys.stdout.write(\"\\rTraining: \" + str(round(n/self.iterations*100, 1)) + \"% Estimated time remaining: \" + str(round((time.time()-startTime)*self.iterations/n - (time.time()-startTime))) + \" seconds \")\n", " sys.stdout.flush()\n", " \n", "\n", "\n", " #Clearing the list of P's from the previous iteration\n", " P = np.zeros((self.K, self.d, n_samples))\n", "\n", " #Running the images through all layers\n", " for k in range(1, self.K+1):\n", " Z[k] = Z_k(Z[k-1], self.h, U[0][k-1], U[1][k-1])\n", " \n", " \n", " #Finding P for the last layer\n", " P[-1] = P_k(U[3], Z[-1], U[2], c, n_samples)\n", " \n", " #Stochastic gradient descent - choose random indices\n", " if I_bar:\n", " selected = np.random.default_rng().choice(n_samples, I_bar, replace=False)\n", " else:\n", " selected = np.arange(n_samples)\n", "\n", " \n", " \n", " #Computing the contributions to grad(U) by mu and omega\n", " subZet = np.array([Z[self.K][:,i] for i in selected]).T #Potentially a 
subset of Z with indices according to SGD\n", " subCet = np.array([c[i] for i in selected]) # Corresponding subset of c\n", " J_grad[2] = dJdMu(subZet, U[3], U[2], subCet, I_bar)\n", " J_grad[3] = dJdOmega(subZet, U[3], U[2], subCet, I_bar)\n", " \n", " \n", " #Backpropagating; finding P for all layers\n", " # P[k] = dJ/dZ[k+1], so P[k-1] = dJ/dZ[k] is obtained through layer k, i.e. with W[k], Z[k] and b[k]\n", " for k in range(self.K-1, 0, -1):\n", " P[k-1] = P_k_minus_1(P[k], self.h, U[0][k], Z[k], U[1][k])\n", "\n", " \n", " #Calculating the contributions to grad(U) by W and b\n", " for k in range(self.K):\n", " sigmaArg = U[0][k]@Z[k] + U[1][k]\n", " brackets = P[k]*sigmaDerivative(sigmaArg)\n", " # dJ/dW[k] uses the input Z[k] of layer k, restricted to the selected samples\n", " J_grad[0][k] = self.h*(np.array([brackets[:, i] for i in selected]).T@np.transpose(Z[k][:, selected]))\n", " J_grad[1][k] = self.h*np.array([brackets[:, i] for i in selected]).T@np.ones((I_bar, 1))\n", " \n", " \n", " #Adam descent:\n", " for i in range(len(J_grad)):\n", " self.m[i] = self.beta_1 * self.m[i] + (1-self.beta_1) * J_grad[i]\n", " self.v[i] = self.beta_2 * self.v[i] + (1-self.beta_2) * np.square(J_grad[i])\n", " m_hat = self.m[i] / (1-self.beta_1**(n))\n", " v_hat = self.v[i] / (1-self.beta_2**(n))\n", " U[i] = U[i] - self.alpha*(m_hat/(np.sqrt(v_hat)+self.epsilon))\n", " \n", " \n", " #Finding approximation \n", " Ypsilon = F_tilde(Z[self.K], U[3], U[2], n_samples)\n", " \n", " #Computing the average error\n", " J.append(0.5*np.linalg.norm(Ypsilon - c)**2*(1/n_samples))\n", " \n", " #--------------------------------------------------------------------------- \n", "\n", " \n", " sys.stdout.write(\"\\r\\rTraining: Done Runtime: \" + str(round((time.time() - startTime), 2)) + \" seconds \\n\")\n", " \n", " self.W = U[0]\n", " self.b = U[1]\n", " self.mu = U[2]\n", " self.omega = U[3]\n", " self.J = J\n", " self.hasTrained = True\n", "\n", " def showConvergence(self, axs, plotType):\n", " '''\n", " Draws the residue plot, i.e. 
J/I over each iteration.\n", " plotType : {'log', 'linear'}, option to either log-plot or linear-plot (string) \n", " '''\n", " self.checkTraining()\n", "\n", " axs.plot(self.J)\n", " axs.yaxis.set_label_coords(-0.1,1.02)\n", " axs.set_title(\"Residue plot\")\n", " axs.set_ylabel(r\"$J/I$\", rotation=\"horizontal\", size=\"large\")\n", " axs.set_xlabel(\"Training iteration\")\n", " axs.set_yscale(plotType)\n", " plt.show()\n", "\n", " def getParams(self):\n", " '''Returns trained parameters'''\n", " self.checkTraining()\n", " return self.W, self.b, self.mu, self.omega\n", "\n", "\n", " def checkTraining(self):\n", " '''Ensures that only trained networks can be tested.'''\n", " if not self.hasTrained:\n", " raise Exception(\"Error, neural network has not been trained\")\n", "\n", " def propagate(self, y):\n", " '''\n", " Propagate y through the network\n", "\n", " returns : all Z and the output of the network\n", " '''\n", " self.checkTraining()\n", " \n", " z = []\n", " z.append(embed(np.array(y), self.d))\n", " for k in range(self.K):\n", " z.append(z[-1] + self.h*sigma(self.W[k]@(z[-1]) + self.b[k]))\n", " f_tilde = eta((z[-1]).T@self.omega + self.mu)\n", " return z, f_tilde\n", "\n", " def grad(self, y):\n", " '''\n", " Compute the gradient of the neural network at y\n", "\n", " y : point to compute the gradient\n", "\n", " returns : first d0 elements of gradient at point y\n", " '''\n", " Z_list, f_tilde = self.propagate(y)\n", " \n", " A = eta_derivative((np.dot(np.squeeze(self.omega), np.squeeze(Z_list[-1])) + self.mu))*self.omega \n", " for k in range(self.K, 0, -1):\n", " A = A + np.transpose(self.W[k-1, :])@(self.h*sigmaDerivative(self.W[k-1, :]@Z_list[k-1] + self.b[k-1]) * A)\n", " return np.squeeze(A)[:self.d0]\n", "\n", " def gradSum(self, other, p, q):\n", " '''\n", " Take the gradient of a sum of neural networks, analogous to grad(T+V)\n", " '''\n", " gradT = self.grad(p)\n", " gradV = other.grad(q)\n", " return gradT + gradV\n", "\n", " def __add__(self, 
other):\n", " '''\n", " This allows us to effectively add two neural networks\n", " by propagating through each of them and adding the result\n", "\n", " Note: the returned function takes (p, q), where p is passed to the left operand and q to the right one,\n", " so the order matters: use T + V, not V + T\n", " '''\n", " def call(p, q):\n", " _, T_tilde = self.propagate(p)\n", " _, V_tilde = other.propagate(q)\n", " return T_tilde + V_tilde\n", " \n", " return call\n", "\n", " " ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00009-c372fb2d-48ab-4f31-8203-8bf9cfeb1454", "output_cleared": false, "tags": [] }, "source": [ "### 2.3 - Tester class and plotting\n", "The `tester` class takes a trained neural network and testing data\n", "(input parameters `y` and `c`), and compares the network's predictions with the data.\n", "The class has two member functions, `plot` and `trajectoryPlotter`, where `plot`\n", "compares the test data and the network approximation of said data. Similarly, `trajectoryPlotter`\n", "is adapted to plot the trajectories from the test data and the network approximation.\n", "See docstrings and inline comments in the code below for plotting details." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "cell_id": "00006-dd64ae94-772f-4036-b18f-03d09add6d5c", "execution_millis": 46, "execution_start": 1604311509991, "output_cleared": false, "source_hash": "e5ce7d5e", "tags": [] }, "outputs": [], "source": [ "class tester:\n", " '''\n", " This class can run various tests on the trained neural network.\n", "\n", " y : Input data to test (numpy ndarray)\n", " c : Analytic result of y; F(y) (numpy ndarray)\n", " sc : The scaler used to scale c in the training phase (scaler object)\n", " nn : The neural network to test (network object)\n", " funcObj : The function object that the network is trained on (datGenFunctions object)\n", " plotType : {'log', 'linear'} The y-scale of the convergence plot. 
Default: 'log'\n", " grid : The grid to plot 3D surfaces on, if applicable (tuple)\n", " '''\n", " def __init__(self, y, c, sc, nn, funcObj, plotType=\"log\", grid = None):\n", " if nn.d0 == 1:\n", " self.fig, self.axs = plt.subplots(2, 1)\n", " else:\n", " self.fig, self.axs = plt.subplots(1)\n", " self.axs = [self.axs]\n", " \n", " # Plot the convergence of J from the training\n", " nn.showConvergence(self.axs[0], plotType)\n", " print(\"Last J from training: \", nn.J[-1])\n", " \n", " # Get obtained parameters from training\n", " W, b, mu, omega = nn.getParams()\n", "\n", " self.y = np.squeeze(y)\n", "\n", " if nn.d0 < 3:\n", " # This line is only needed when we want to plot, hence disabled for d0 >= 3\n", " y = np.reshape(np.array(np.meshgrid(*y)), (nn.d//2,len(y[-1])**(nn.d//2)))\n", "\n", "\n", " # Propagate through network\n", " z, f_tilde = nn.propagate(y)\n", "\n", " # Rescale using the same scaler as the data was scaled with\n", " f_tilde = sc.rescale(f_tilde) \n", "\n", " # Write data to class attributes\n", " self.c = c\n", " self.f_tilde = f_tilde\n", " self.d = nn.d\n", " self.d0 = nn.d0\n", " self.network = nn\n", " self.grid = grid\n", " self.funcObj = funcObj\n", " self.sc = sc\n", "\n", " def plot(self):\n", " '''\n", " Main function for performing test plots.\n", " For d0 == 1, simply plots c(y) for the analytical and approximated values.\n", " '''\n", " if self.d0 == 1:\n", " # Plot a comparison between the analytical and the neural network approximation\n", " self.axs[1].plot(self.y, self.c, label=\"Analytical\")\n", " self.axs[1].plot(self.y, self.f_tilde, label=\"Neural network approx.\")\n", " self.fig.legend()\n", " plt.show()\n", "\n", " \n", " elif self.d0 == 2:\n", " '''\n", " In this case, the function plots three 3D surfaces, the first one being the analytical graph,\n", " the second being the neural network approximation, and the third being the relative difference between them\n", " '''\n", "\n", " # Set up a grid of y-values\n", " if self.grid is not 
None:\n", " y1 , y2 = np.linspace(min(self.y[0]), max(self.y[0]), self.grid[0]), np.linspace(min(self.y[1]), max(self.y[1]), self.grid[1])\n", " y1, y2 = np.meshgrid(y1, y2)\n", " else:\n", " y1, y2 = np.meshgrid(self.y[0], self.y[1])\n", "\n", " self.c = self.funcObj(y1, y2)\n", "\n", " self.f_tilde = np.reshape(self.f_tilde, self.grid)\n", "\n", "\n", " # Make a figure\n", " fig = plt.figure()\n", " \n", " fig.suptitle(\"Test results\")\n", " \n", " # Containers for surfaces and axes\n", " surfaces = []\n", " axs = []\n", "\n", " # ------- First plot ------\n", " # Add the first subplot to the figure\n", " axs.append(fig.add_subplot(3, 1, 1, projection=\"3d\"))\n", " \n", " # Plot the first surface\n", " surfaces.append(axs[0].plot_surface(y1, y2, self.c, cmap=cm.Spectral,\n", " linewidth=0, antialiased=False))\n", "\n", " # Set title\n", " axs[0].set_title(\"Analytic\")\n", "\n", " # Add a colorbar\n", " #fig.colorbar(surf1, shrink=0.5, aspect=5)\n", "\n", "\n", " # ------- Second plot ------\n", " # Add the second subplot to the figure\n", " axs.append(fig.add_subplot(3, 1, 2, projection=\"3d\"))\n", " \n", " # Plot a normal vector to the surface at a random point\n", " randomIndex = np.random.randint(len(y1.flatten()))\n", " x = y1.flatten()[randomIndex]\n", " y = y2.flatten()[randomIndex]\n", " z = self.f_tilde.flatten()[randomIndex]\n", " axs[1].quiver(x,y,z,*tuple(self.sc.gradientRescale(self.network.grad([[x], [y]]))),-1, \n", " label=\"Normal vector at\\n\" + str(np.around([x, y, z], 2)))\n", "\n", " # Plot the second surface\n", " surfaces.append(axs[1].plot_surface(y1, y2, self.f_tilde, cmap=cm.Spectral,\n", " linewidth=0, antialiased=False))\n", " \n", " fig.legend()\n", " \n", " # Set title\n", " axs[1].set_title(\"Neural network approx.\")\n", "\n", " # Find max and min value across the two first surfaces, for mapping colors\n", " vmin = min(surface.get_array().min() for surface in surfaces)\n", " vmax = max(surface.get_array().max() for surface in 
surfaces)\n", " norm = colors.Normalize(vmin=vmin, vmax=vmax)\n", " for sf in surfaces:\n", " sf.set_norm(norm)\n", " \n", " # Add colorbar, this will be shared for the first two surfaces\n", " fig.colorbar(surfaces[0], ax=axs, orientation='vertical')\n", "\n", "\n", " # ------- Third plot -------\n", " # Add the third subplot\n", " ax = fig.add_subplot(3, 1, 3, projection=\"3d\")\n", "\n", " # Plot the surface\n", " surf3 = ax.plot_surface(y1, y2, np.absolute((self.f_tilde-self.c)/self.c.max()), cmap=cm.coolwarm,\n", " linewidth=0, antialiased=False)\n", "\n", " # Set title\n", " ax.set_title(\"Relative difference\") \n", "\n", " # Add colorbar, the last surface gets its own colorbar\n", " fig.colorbar(surf3, shrink=0.5, aspect=5)\n", "\n", " plt.show()\n", " \n", " # Print J for the test data\n", " print(\"J for test data:\", 0.5*np.linalg.norm(np.squeeze(self.f_tilde)-self.c)**2/np.prod(self.f_tilde.shape))\n", "\n", " \n", " def trajectoryPlotter(self, axs, ylabel):\n", " '''\n", " Additional function for plotting trajectory data\n", " '''\n", " axs.plot(self.c, label=\"Trajectory data\")\n", " axs.plot(self.f_tilde, label=\"Neural network prediction\")\n", " axs.set_xlabel(\"Time step\")\n", " axs.set_ylabel(ylabel)\n", " axs.legend()\n", " plt.show()\n", "\n", " # Print J for the test data\n", " print(\"J for test data: \", 0.5*np.linalg.norm(np.squeeze(self.f_tilde)-self.c)**2/np.prod(self.f_tilde.shape))\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00014-e8fdfc74-bb69-4f6e-8c0b-111e5479409a", "output_cleared": false, "tags": [] }, "source": [ "### 2.4 - Optimal choices for network parameters\n", "Through systematic testing, we have found that the best results we can expect are around $J = 10^{-4}$ for the\n", "training error and $J = 10^{-3}$ for newly generated test data. Our optimization of the training parameters\n", "
Our optimization of the training parameters\n", "will therefore have as a goal to achieve these results relatively consistently without unnecessary high amount\n", "of computation time.\n", "\n", "* $K$: By testing the network for different parameter values, we have found that $K=30$ layers\n", "will be sufficient without too great error. Although the error decreases by choosing $K>30$,\n", "the impact of the additional layers will not justify the increase in computation time.\n", "* $d$: For a function $F$ with input dimension $d_0$, we get a sufficiently small $J$ by setting the dimension of each hidden layer, $d=2d_0$. This seems to apply universally.\n", "* $h$: The parameter $h$ affect how \"much impact\" one layer has. We find $h=0.1$ to be sufficient for functions whose dimensions are $1$, $d_0=1$.\n", " * For cases where $d_0>1$, the value of $0.1$ is too great and negatively affects the results. Thus, we use $h=0.01$ for functions whose $d_0>1$.\n", "* $\\tau$: The learning parameter $\\tau$ impacts the speed of convergence. Similarly to the parameter $h$, we found that $\\tau = 0.002$\n", "seems to be a good value with reasonable convergence speed, given $d_0(F(y))=1$.\n", " * However, where $d_0>1$, $\\tau$ is set to $0.001$. \n", "* Number of iterations: $600$ iterations will be enough in most cases. For more complex functions, $1,000$ iterations have had been necessary.\n", "\n", "The final parameter to choose is the number of generated data points in testing phase, $I$.\n", "Using too many points will affect the computation time greatly. However, using too few points might result in overfitting.\n", "One thing to keep in mind when choosing this value, is that training the network is essentially analogous to solving a system\n", "with $K \\cdot d^2 + K \\cdot d + 2$ unknowns and $I$ equations. Hence, as a rule of thumb, we want to use at least $I = K \\cdot d^2 + K \\cdot d + 2$.\n", "We therefore landed on $I \\approx 2 \\cdot K \\cdot d^2$." 
] }, { "cell_type": "markdown", "metadata": { "cell_id": "00010-86985456-cf93-469a-bc97-a430c0225ecb", "output_cleared": false, "tags": [] }, "source": [ "### 2.5 - Performing tests of trained networks\n", "\n", "#### 2.5.1 - Testing synthetic functions\n", "We restate the given example functions for generating synthetic data:\n", "$$\n", "\\begin{array}{c|cccc}\n", " & d_0 & d & F(y) & \\text{Domain}\\\\[6pt]\n", "F_1(y) & 1 & 2 & \\frac12 y^2 & [-2,2] \\\\[6pt]\n", "F_2(y) & 1 & 2 & 1-\\cos y & [-\\frac{\\pi}{3},\\frac{\\pi}{3}] \\\\[6pt]\n", "F_3(y) & 2 & 4 & \\frac12 (y_1^2+y_2^2) & [-2,2]\\times[-2,2] \\\\[6pt]\n", "F_4(y) & 2 & 4 & -\\frac{1}{\\sqrt{y_1^2+y_2^2}} & \\mathbb{R}^2\\setminus \\{(0,0)\\} \\\\\n", "\\end{array}\n", "$$\n", "Note: for $F_4$ we set the domain to be $[-10, 10]^2\\setminus [-\\frac12,\\frac12]^2$,\n", "as it is infeasible to sample from the entire real plane.\n", "\n", "All example functions are implemented below as function objects, with $F_i(y)$ given by `fi`.\n", "The cells below output a log-plot of the residue during the training of the network, and a plot comparing the\n", "analytical $F_i(y)$ with the neural network approximation $\\tilde{F}(y; \\theta)$ on new test points $y$."
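, "\n",
"The domain of $F_4$ has a hole around the origin, so `domain.drawRandom` relies on rejection sampling. Points are drawn uniformly from the outer box, and points in the excluded set are discarded and redrawn. Below is a minimal NumPy-only sketch of the same idea for $[-10,10]^2 \\setminus [-\\frac12,\\frac12]^2$. The function `draw_punctured_square` is hypothetical and does not use the `intervals` package. As in `domain.__contains__`, a point is rejected only if every coordinate lies in $[-\\frac12,\\frac12]$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def draw_punctured_square(n, outer=10.0, inner=0.5, seed=None):\n",
"    # Rejection sampling: draw uniformly from [-outer, outer]^2 and keep a point\n",
"    # unless both coordinates lie in the excluded square [-inner, inner]^2\n",
"    rng = np.random.default_rng(seed)\n",
"    samples = []\n",
"    while len(samples) < n:\n",
"        p = rng.uniform(-outer, outer, size=2)\n",
"        if np.any(np.abs(p) > inner):\n",
"            samples.append(p)\n",
"    return np.array(samples).T  # shape (2, n), like domain.drawRandom\n",
"\n",
"y = draw_punctured_square(1000, seed=0)\n",
"F4 = -1 / np.sqrt(y[0]**2 + y[1]**2)\n",
"print(y.shape, F4.min(), F4.max())\n",
"```\n",
"\n",
"Every accepted point lies at a distance of at least $\\frac12$ from the origin. Hence $F_4 \\geq -2$ on this domain, and the singularity at $(0,0)$ is never evaluated."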
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "cell_id": "00013-74e3cafa-4aae-4da7-98a2-9fc14fe29efe", "execution_millis": 1, "execution_start": 1604311510038, "output_cleared": false, "source_hash": "afbab6a0", "tags": [] }, "outputs": [], "source": [ "dom1 = domain([interval.closed(-2, 2)])\n", "f1 = datGenFunctions(1, 2, lambda y : 0.5*y**2, dom1)\n", "\n", "dom2 = domain([interval.closed(-np.pi/3, np.pi/3)])\n", "f2 = datGenFunctions(1, 2, lambda y : 1-np.cos(y), dom2)\n", "\n", "dom3 = domain([interval.closed(-2, 2), interval.closed(-2, 2)])\n", "f3 = datGenFunctions(2, 4, lambda y1,y2 : 0.5*(y1**2+y2**2), dom3)\n", "\n", "xy_plane = domain([interval.closed(-10, 10), interval.closed(-10, 10)]) #[-10, 10]x[-10, 10] set as the entire xy plane\n", "zero = domain([interval.closed(-0.5,0.5), interval.closed(-0.5,0.5)])\n", "dom4 = xy_plane - zero\n", "f4 = datGenFunctions(2, 4, lambda y1,y2 : -1*1/np.sqrt(y1**2+y2**2), dom4)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "cell_id": "00013-f0069de3-e6db-42ce-8d58-82dde4a60384", "execution_millis": 6143, "execution_start": 1604313849864, "output_cleared": false, "source_hash": "64f30167", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training: Done Runtime: 4.41 seconds \n" ] }, { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is 
too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Last J from training: 5.885117266078403e-05\n", "J for test data: 5.4024776945141805e-05\n" ] } ], "source": [ "#__________ f2(y) __________\n", "sc2, y2, c2 = genData(f2, 200)\n", "\n", "nn2 = network(K = 30, \n", " tau = 0.002,\n", " h = 0.1,\n", " d = 2,\n", " d0 = 1,\n", " iterations = 600\n", " )\n", "\n", "nn2.train(y2, c2, 20)\n", "y2 = np.linspace(-np.pi/3, np.pi/3, 100)\n", "c2 = f2(y2)\n", "test2 = tester([y2], c2, sc2, nn2, f2)\n", "test2.plot()" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00019-8dfc36c2-d1dc-47c3-b3a5-1fb9809f47ad", "output_cleared": false, "tags": [] }, "source": [ "**Remarks to $F_2(y)$** The residue plot has a similar behavior to that of $F_1(y)$, except for a somewhat\n", "jagged pattern around the later iterations. Increasing $\\tau$ seem to dim the severity of these artifacts.\n", "\n", "Furthermore, the network approximates $F_2(y)$ to a great prescision, with $J$ being of order $10^{-6}$ for both training and testing.\n", "The comparison plot also highlights the similarities between analytical and approximation. 
" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "cell_id": "00015-f2d2e9b8-dec7-42a6-807e-274a31906f66", "execution_millis": 21857, "output_cleared": false, "source_hash": "64b9080f", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training: Done Runtime: 16.28 seconds \n" ] } ], "source": [ "#__________ f3(y) __________\n", "sc3, y3, c3 = genData(f3, 800)\n", "\n", "\n", "nn3 = network(K = 30, \n", " tau = 0.001,\n", " h = 0.01,\n", " d = 4,\n", " d0 = 2,\n", " iterations = 1000\n", " )\n", "\n", "nn3.train(y3, c3, 20)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "allow_embed": false, "cell_id": "00009-f2d37276-fcc9-452c-bd56-361c7ba8e5c8", "execution_millis": 1061, "execution_start": 1604338858487, "output_cleared": false, "source_hash": "297e1601", "tags": [] }, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. 
\" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is 
too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "J for test data: 0.01670546936100807\n" ] } ], "source": [ "grid = (20,20)\n", "y3 = f3.dom.drawGrid(grid)\n", "c3 = f3(*y3)\n", "test3 = tester(y3, c3, sc3, nn3, f3, grid = grid)\n", "test3.plot()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "cell_id": "00010-270fe6ce-3d94-426d-9812-706ac8e2a934", "execution_millis": 23311, "execution_start": 1604338868963, "output_cleared": false, "source_hash": "d90dbf9a", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training: Done Runtime: 17.89 seconds \n" ] }, { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is 
too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "J for test data: 0.00020822897538426766\n" ] } ], "source": [ "#__________ f4(y) __________\n", "\n", "sc4, y4, c4 = genData(f4, 800)\n", "\n", "nn4 = network(K = 30, \n", " tau = 0.001,\n", " h = 0.01,\n", " d = 4,\n", " d0 = 2,\n", " iterations = 1000\n", " )\n", "\n", "nn4.train(y4, c4, 20)\n", "\n", "grid = (10,10)\n", "y4 = f4.dom.drawGrid(grid)\n", "c4 = f4(*y4)\n", "test4 = tester(y4, c4, sc4, nn4, f4, grid = grid)\n", "test4.plot()" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00023-159d3360-c5e9-4c28-a696-144281ba1a9f", "output_cleared": false, "tags": [] }, "source": [ "**Remarks on $F_3(y)$ and $F_4(y)$:** The functions $F_3(y)$ and $F_4(y)$, both of dimension $d_0=2$, are far more\n", "complex than one-dimensional polynomials of degree $2$. Therefore, more demanding parameters seemed necessary. Through systematic testing, we changed $h$ from $0.1$ to $0.01$, increased the number of\n", "data points from $200$ to $800$, and increased the number of iterations from $600$ to $1,000$.\n", "Also note that for $d_0=2$, by the rule of thumb, the number of sample points ought to be at least $Kd^2=30 \\cdot 4^2=480$.\n", "\n", "It is worth noting that for $F_3(y)$, $J$ is of order $10^{-3}$ on the test data but of order $10^{-5}$\n", "in the training phase. Such a gap may suggest overfitting. However, the surface plots show that\n", "the network approximates the analytical solution very well, and the surface plot of the differences indicates\n", "that the error is concentrated mainly in the outer corners of the function domain. The difference in $J$ between the test and training values is of order $10^1$, so we do not consider the network overfitted.\n", "\n", "The resulting plots show that $\\tilde{F}(y;\\theta)$ appears to capture the shape of $F(y)$ in both two-dimensional cases." ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00018-e82ea611-2652-464d-a835-0ff6803c6958", "output_cleared": false, "tags": [] }, "source": [ "#### 2.5.2 - Testing and training on trajectory data\n", "We define a function `genTrajData` that imports the trajectory data in the same manner as\n", "`genData`, including the necessary scaling. The function offers two ways of importing the trajectory data:\n", "\n", "The first imports all data from a specified `batch`, which is `0` by default.\n", "The second randomly selects `I` data points from the batches `batchMin` through `batchMax`.\n", "\n", "In the following two cells, we train a new network on randomly selected trajectory data (using `concatenate`) and then test the network."
] }, { "cell_type": "code", "execution_count": 29, "metadata": { "cell_id": "00017-4aa943ec-5d8e-45f1-8cc0-67a6f942c951", "execution_millis": 1, "execution_start": 1604331124223, "output_cleared": false, "source_hash": "6c540cdc", "tags": [] }, "outputs": [], "source": [ "def genTrajData(hamParam, I=None, batch=0, batchMin=0, batchMax=50):\n", "    '''\n", "    Generates data in the same manner as genData, including the necessary scaling, from the trajectory data.\n", "    Inputs:\n", "        hamParam : the part of the Hamiltonian to import, i.e. \"T\" or \"V\"\n", "        I        : number of points to select at random; if None, all points from the batch given by `batch` are returned\n", "        batch    : which batch to import, used only when I is None\n", "        batchMin : the first batch included when concatenating data\n", "        batchMax : the last batch included when concatenating data\n", "    Returns:\n", "        sc_c : scaler fitted to the values c\n", "        y    : the (unscaled) inputs, P or Q\n", "        c    : the scaled values of T or V\n", "    '''\n", "    if hamParam == \"T\":\n", "        pq = \"P\"\n", "    elif hamParam == \"V\":\n", "        pq = \"Q\"\n", "    else:\n", "        raise ValueError(\"Hamiltonian must be V or T.\")\n", "\n", "    if I is not None:\n", "        datDict = dataFile.concatenate(batchMin, batchMax)\n", "    else:\n", "        datDict = dataFile.generate_data(batch)\n", "    c = datDict[hamParam]\n", "    y = datDict[pq]\n", "\n", "    # Select a random subset (without replacement) of all the trajectory data\n", "    if I is not None:\n", "        selected = np.random.default_rng().choice(y.shape[1], I, replace=False)\n", "        c = c[selected]\n", "        y = y[:, selected]\n", "\n", "    sc_c = scaler(c)\n", "    return sc_c, y, sc_c.scale(c)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "cell_id": "00024-201a4716-8faf-436e-9702-4eed4eb6cd3b", "execution_millis": 145242, "execution_start": 1604338672687, "output_cleared": false, "source_hash": "2f176d00", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training: Done Runtime: 45.17 seconds \n" ] } ], "source": [ "sc_t1, y_t1, c_t1 = genTrajData(\"T\", 2000, batchMin = 0, batchMax = 47)\n", "nn_t1 = network(K = 30, \n", " tau = 0.001,\n", " h = 0.1,\n", " d
= 6,\n", " d0 = 3,\n", " iterations = 800\n", " )\n", "\n", "nn_t1.train(y_t1, c_t1, 50)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "cell_id": "00025-5d568f71-3d03-48ba-b46f-6ec58fb2a74d", "execution_millis": 755, "execution_start": 1604338824614, "output_cleared": false, "source_hash": "980121ba", "tags": [] }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "J for test data: 9.549373382183024e-06\n" ] } ], "source": [ "sc_t1_test, y_t1_test, c_t1_test = genTrajData(\"T\", batch=49)\n", "c_t1_test = sc_t1_test.rescale(c_t1_test)  # The test values c should not be scaled, so we undo the scaling right away\n", "\n", "y_t1_test = tuple(y_t1_test)\n", "\n", "test4 = tester(y_t1_test, c_t1_test, sc_t1, nn_t1, f1)\n", "\n", "fig, axs = plt.subplots(1)\n", "test4.trajectoryPlotter(axs, \"T\")" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00029-06f88e35-3e21-4686-860a-9782a883f81e", "output_cleared": false, "tags": [] }, "source": [ "**Remarks:** In this realization, $J$ for the test data is in fact lower than $J$ for the training data.\n", "The difference is small, however, and is an artifact of the stochastic nature of training neural networks.\n", "Also note that the network prediction matches the shape of the trajectory data for $T$ very well.\n", "\n", "*Section 3* provides a more in-depth analysis of $H=V+T$, including whether $H$ is preserved."
] }, { "cell_type": "markdown", "metadata": { "cell_id": "00018-a8c037a7-f7ce-498f-aace-7f693f912d81", "output_cleared": false, "tags": [] }, "source": [ "## Section 3 - Gradient of $F(y)$ and numerical integrators" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00031-c5ff71ff-182d-48d3-94c4-17343743fa6a", "output_cleared": false, "tags": [] }, "source": [ "### 3.1 - Expression for the gradient of $F(y)$\n", "In order to implement numerical integrators for trained neural networks, it is expedient to\n", "derive an expression for $\\nabla_y F(y)$, that is, the gradient of the trained function at a point $y$.\n", "
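The symplectic Euler and Størmer-Verlet methods of Section 3.2 only access the Hamiltonian through the partial derivatives dT/dp and dV/dq. This is why an expression for the gradient is needed.

The sketch below illustrates this with a hypothetical helper `symplectic_euler`, which is not part of this project's code. It takes the two derivatives as plain functions. As an assumption for illustration, it uses the exact derivatives of the nonlinear pendulum with mgl = 1 instead of trained networks. In the project, these functions would be replaced by the (rescaled) gradients of the trained networks.

```python
import numpy as np

# Hypothetical helper (not part of the project code): symplectic Euler for a
# separable Hamiltonian H(p, q) = T(p) + V(q), given dT/dp and dV/dq as functions.
def symplectic_euler(dTdp, dVdq, p0, q0, dt, N):
    p = np.empty(N + 1)
    q = np.empty(N + 1)
    p[0], q[0] = p0, q0
    for n in range(N):
        q[n + 1] = q[n] + dt * dTdp(p[n])      # position update uses p_n
        p[n + 1] = p[n] - dt * dVdq(q[n + 1])  # momentum update uses q_{n+1}
    return p, q

# Assumed example: nonlinear pendulum with mgl = 1, i.e. T(p) = p**2/2 and
# V(q) = 1 - cos(q), with exact derivatives dT/dp = p and dV/dq = sin(q).
def total_energy(p, q):
    return 0.5 * p**2 + (1.0 - np.cos(q))

p, q = symplectic_euler(lambda p: p, np.sin, p0=0.0, q0=1.0, dt=0.05, N=2000)
H = total_energy(p, q)
print('max deviation of H from H(0):', np.max(np.abs(H - H[0])))
```

Because the method is symplectic, the energy error should stay bounded (oscillating on the order of dt) rather than drift. This is the property `plotTotalEnergy` is later used to inspect for the trained networks.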
" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00028-39fffefa-1136-4e5f-965d-edc2675b3d7f", "output_cleared": false, "tags": [] }, "source": [ "For the symplectic Euler and Størmer-Verlet methods, we need an approximation to the gradient of the function that\n", "the neural network approximates. We can write the approximation as the composition of each of the transformations of the neural network. \n", "We have\n", "\n", "$$\n", "F(y) = G \\circ \\Phi_{K-1}(y_{K-1}) \\circ \\Phi_{K-2}(y_{K-2}) \\circ ... \\circ \\Phi_0 = G(\\Phi(y)),\n", "$$ \n", "\n", "where\n", "\n", "$$\n", "\\Phi_k(y) = y + h\\sigma(W_k y + b_k), \\: G(y) = \\eta(w^T y + \\mu),\n", "$$\n", "\n", "which are the maps between each layer and $G(y)$ maps from the final layer into a scalar.\n", "We have all these parameters, so now it has to be made into a computer process so that we can calculate the gradient.\n", "We define \n", "\n", "$$\n", "\\Psi_k = \\Phi_k \\circ \\Phi_{k-1} \\circ ... \\circ \\Phi_0, \\: k = 1, ..., K-1,\n", "$$\n", "\n", "which will be the combined transformations up until the current layer $k$. We now have\n", "\n", "$$\n", "F(y) = G(\\Psi_{K-1}(y)).\n", "$$\n", "\n", "From earlier we have used $Z^{(k)}$ as the input into the $k$th layer, i.e. we have that $Z^{(k)} = \\Psi_{k-1}(y), \\: k = 1, ..., K.$\n", "\n", "Now we try to find the gradient in regards to the input $y$. 
We use the chain rule an have\n", "\n", "$$\n", "\\nabla_{y} F(y) = (D\\Psi_{K-1}(y))^T \\nabla G(Z^{K}).\n", "$$\n", "\n", "Now we have tho parts that are unknown to us, $\\nabla G(Z^{(K)})$ and $D\\Psi_{K-1}(y))^T$.\n", "We have $G$, which can be written as \n", "\n", "$$\n", "G(y) = \\eta(\\sum_{i = 1}^{d}\\omega_i y_i + \\mu)\n", "$$\n", "\n", "For the jht element of the gradient, we get\n", "\n", "$$\n", "\\frac{\\partial G}{\\partial y_j} = \\eta '(\\sum_{i=1}^{d} \\omega_i y_i + \\mu)\\omega_j = \\eta ' (\\omega ^T y + \\mu)\\omega_j\n", "$$\n", "\n", "and so the full gradient of $G$ is\n", "\n", "$$\n", "\\nabla G = \\eta '(\\omega y + \\mu)\\omega.\n", "$$\n", "\n", "As for $(D\\Psi_{K-1}(y))^T$, we can use that the input into the final layer is $Z^{(K)} = \\Phi_{K-1} \\circ \\Psi_{K-2}(y)$.\n", "We have \n", "\n", "$$\n", "D\\Psi_{K-1}(y) = D\\Phi{K-1}(Z^{(K))}) \\cdot D\\Psi_{K-2}(y),\n", "$$\n", "\n", "and transposed\n", "\n", "$$\n", "(D\\Psi_{K-1}(y))^T = (D\\Psi_{K-2}(y))^T \\cdot \\left( D\\Phi_{K-1}(Z^{(K-1)})\\right) ^T\n", "$$\n", "\n", "\n", "We look at one of the elements of $\\Phi(y)$. We can write the $i$th element as\n", "\n", "$$\n", "[\\Phi(y)]_i = y_i + h\\sigma \\left(\\sum_{j=1}^d W_{ij}y_j + b_i \\right).\n", "$$\n", "\n", "We then differentiate it, we have\n", "\n", "$$\n", "\\frac{\\partial \\Phi_i}{\\partial y_r} = \\delta_{ir} + h \\sigma '\\left( \\sum_{j=1}^{d} W_{ij} y_j + b_i \\right) W_{ir}\n", "$$\n", "\n", "For a vector $A$, we need to calculate $(D\\Phi(y))^T A$, where the $r$th component is\n", "\n", "$$\n", "\\sum_{i=1}^{d}[D\\Phi(y)]_{ir}A_i = A_r + h \\sum_{i=1}^{d}\\sigma ' \\left(\\sum_{j=1}^{d} W_{ij} y_j + b_i \\right) W_{ir} A_i \n", "$$\n", "\n", "\n", "Considering the last term as a Hadamard product, we can then finally find\n", "\n", "$$\n", "D\\Phi(y)^T A = A + W^T (h \\sigma ' (Wy + b) \\odot A).\n", "$$\n", "\n", "
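The backward recursion can be sketched in a few lines of NumPy. Start from `A = eta'(w^T Z_K + mu) * w`, the gradient of G at the last layer. Then apply `A := A + W_k^T (h * sigma'(W_k Z_k + b_k) * A)` for k = K-1 down to 0.

This is a self-contained illustration, not the `grad` member function of the `network` class, and it rests on assumptions of our own:

- the input y is taken to live directly in the latent dimension d (no embedding from d0);
- sigma is tanh;
- eta(x) = (1 + tanh(x/2))/2;
- all parameters are random.

The analytic gradient is compared with central finite differences as a sanity check.

```python
import numpy as np

# Assumed activations (illustration only): sigma = tanh inside the layers,
# eta(x) = (1 + tanh(x/2))/2 for the output layer.
def sigma(x):
    return np.tanh(x)

def dsigma(x):
    return 1.0 - np.tanh(x) ** 2

def eta(x):
    return 0.5 * (1.0 + np.tanh(x / 2))

def deta(x):
    return 0.25 * (1.0 - np.tanh(x / 2) ** 2)

def forward(y, Ws, bs, w, mu, h):
    '''Return F(y) and the layer inputs Z_0 = y, Z_1, ..., Z_K.'''
    Z = [y]
    for W, b in zip(Ws, bs):
        Z.append(Z[-1] + h * sigma(W @ Z[-1] + b))
    return eta(w @ Z[-1] + mu), Z

def grad_F(y, Ws, bs, w, mu, h):
    '''Gradient of F at y via the backward recursion over the layers.'''
    _, Z = forward(y, Ws, bs, w, mu, h)
    A = deta(w @ Z[-1] + mu) * w  # gradient of G at Z_K
    for k in reversed(range(len(Ws))):
        # A := DPhi_k(Z_k)^T A = A + W_k^T (h * dsigma(W_k Z_k + b_k) * A)
        A = A + Ws[k].T @ (h * dsigma(Ws[k] @ Z[k] + bs[k]) * A)
    return A

# Random parameters for a small test network (assumed sizes)
rng = np.random.default_rng(0)
d, K, h = 4, 5, 0.1
Ws = [rng.standard_normal((d, d)) for _ in range(K)]
bs = [rng.standard_normal(d) for _ in range(K)]
w, mu = rng.standard_normal(d), 0.3
y = rng.standard_normal(d)

# Central finite differences as an independent check
eps = 1e-6
fd = np.array([(forward(y + eps * e, Ws, bs, w, mu, h)[0]
                - forward(y - eps * e, Ws, bs, w, mu, h)[0]) / (2 * eps)
               for e in np.eye(d)])
err = np.max(np.abs(grad_F(y, Ws, bs, w, mu, h) - fd))
print('max difference to finite differences:', err)
```

Only matrix-vector products appear in the loop, so the Jacobians of the layer maps never have to be formed explicitly.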
\n", "\n", "An implementation of this method is relatively straightforward.\n", "See the definition of the member function `grad` inside the `network` class for details." ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00023-3098682c-f0c4-4bff-a493-1cd187ee823a", "output_cleared": false, "tags": [] }, "source": [ "### 3.2 - The symplectic Euler and Størmer-Verlet methods\n", "This section covers the implementation and testing of two symplectic integrators, namely the symplectic Euler method\n", "and the Størmer-Verlet method. Both methods can be applied to known as well as unknown Hamiltonians.\n", "\n", "With time step $\\Delta t$ (called `h` in the code below, not to be confused with the step size $h$ of the network), the symplectic Euler method reads\n", "$$\n", "\\begin{matrix}\n", "q_{n+1}=q_n+\\Delta t \\frac{\\partial T}{\\partial p}(p_n) \\\\\n", " \\\\\n", "p_{n+1}=p_n-\\Delta t \\frac{\\partial V}{\\partial q}(q_{n+1}),\n", "\\end{matrix}\n", "$$\n", "\n", "\n", "and the Størmer-Verlet method is given by\n", "$$\n", "\\begin{matrix}\n", "p_{n+\\frac12}=p_n-\\frac{\\Delta t}{2}\\frac{\\partial V}{\\partial q}(q_n) \\\\\n", " \\\\\n", "q_{n+1}=q_n+\\Delta t\\frac{\\partial T}{\\partial p}(p_{n+\\frac12}) \\\\\n", " \\\\\n", "p_{n+1}=p_{n+\\frac12}-\\frac{\\Delta t}{2}\\frac{\\partial V}{\\partial q}(q_{n+1}).\n", "\\end{matrix}\n", "$$\n", "\n", "Before testing the methods on an unknown Hamiltonian, we test them on three known separable Hamiltonians:\n", "the nonlinear pendulum, the Kepler two-body problem and the Henon-Heiles problem.\n", "For convenience, the Hamiltonians are listed below.\n", "\n", "* **Nonlinear pendulum** $H(p,q)=\\frac12 p^2 + mgl(1-\\cos(q)), \\quad p, q \\in \\mathbb{F}$,\n", "* **Kepler two-body problem** $H(\\mathbf{p},\\mathbf{q})=\\frac12 \\mathbf{p}^T\\mathbf{p} -\\frac{1}{\\sqrt{q_1^2+q_2^2}}$,\n", "* **Henon-Heiles problem** $H(\\mathbf{p},\\mathbf{q})=\\frac12 \\mathbf{p}^T\\mathbf{p}+\\frac12 \\mathbf{q}^T\\mathbf{q}+q_1^2 q_2-\\frac13 q_2^3$,\n", "\n", "where $\\mathbb{F}$ is the scalar field, i.e. 
$\\mathbb{F}=\\mathbb{R} \\text{ or } \\mathbb{C}$.\n", "\n", "\n", "### 3.3 - Implementation of numerical integrators and known Hamiltonians\n", "\n", "The following cell implements both methods and the known Hamiltonians, using the infrastructure for the synthetic functions from *Section 1*. In addition, we define `testNumIntKnown` and `plotTotalEnergy`. `testNumIntKnown` trains the necessary networks (unless they have already been trained), runs both methods with these networks, and prints the difference between the results of the two methods.\n", "\n", "`plotTotalEnergy` plots the total energy $T+V$, as approximated by the networks, against the time steps in order to determine whether the Hamiltonian is preserved.\n" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "cell_id": "00024-9dd1f7fd-c08f-4d2b-a637-6261fe58e6a6", "execution_millis": 1, "execution_start": 1604327237487, "output_cleared": false, "source_hash": "56309c1c", "tags": [] }, "outputs": [], "source": [ "def numIntInit(N, T, y0, nn):\n", "    '''Helper function that initializes the containers used by both integrators.'''\n", "    p = np.zeros((N+1, nn.d0))\n", "    q = np.zeros((N+1, nn.d0))\n", "    p[0] = y0[0]\n", "    q[0] = y0[1]\n", "    h = T / N  # time step \\Delta t\n", "    return p, q, h\n", "\n", "def sympEuler(sc_V, sc_T, nn_V, nn_T, y0, N, T):\n", "    '''Symplectic Euler method\n", "    Input:\n", "        sc_V : scaler belonging to the training data of nn_V\n", "        sc_T : scaler belonging to the training data of nn_T\n", "        nn_V : trained network object for V(q)\n", "        nn_T : trained network object for T(p)\n", "        y0   : initial value [p0, q0]\n", "        N    : number of steps\n", "        T    : final time\n", "    Output:\n", "        p, q : NumPy arrays with the values of p and q, where index n corresponds to time step n.\n", "    '''\n", "    # Initialize containers\n", "    p, q, h = numIntInit(N, T, y0, nn_V)\n", "\n", "    for n in range(N):  # Perform steps with the symplectic Euler method\n", "        # The following lines implement the method with gradient rescaling enabled. See the discussion below.\n", "        q[n+1] = q[n] + h*np.squeeze(sc_T.gradientRescale(nn_T.grad(np.array([p[n]]).T)))    # q_{n+1} = q_n + h*dT/dp(p_n)\n", "        p[n+1] = p[n] - h*np.squeeze(sc_V.gradientRescale(nn_V.grad(np.array([q[n+1]]).T)))  # p_{n+1} = p_n - h*dV/dq(q_{n+1})\n", "        # Without gradient rescaling, the two steps would read:\n", "        # q[n+1] = q[n] + h*np.squeeze(nn_T.grad(np.array([p[n]]).T))\n", "        # p[n+1] = p[n] - h*np.squeeze(nn_V.grad(np.array([q[n+1]]).T))\n", "    return p, q\n", "\n", "def stormerVerlet(sc_V, sc_T, nn_V, nn_T, y0, N, T):\n", "    '''Størmer-Verlet method\n", "    Input:\n", "        sc_V : scaler belonging to the training data of nn_V\n", "        sc_T : scaler belonging to the training data of nn_T\n", "        nn_V : trained network object for V(q)\n", "        nn_T : trained network object for T(p)\n", "        y0   : initial value [p0, q0]\n", "        N    : number of steps\n", "        T    : final time\n", "    Output:\n", "        p, q : NumPy arrays with the values of p and q, where index n corresponds to time step n.\n", "    '''\n", "    p, q, h = numIntInit(N, T, y0, nn_V)\n", "    for n in range(N):\n", "        # The following lines implement the method with gradient rescaling enabled. See the discussion below.\n", "        # np.squeeze is applied as in sympEuler, so that p_n_half stays a 1D array like p[n].\n", "        p_n_half = p[n] - h/2*np.squeeze(sc_V.gradientRescale(nn_V.grad(np.array([q[n]]).T)))  # p_{n+1/2}\n", "        q[n+1] = q[n] + h*np.squeeze(sc_T.gradientRescale(nn_T.grad(np.array([p_n_half]).T)))\n", "        p[n+1] = p_n_half - h/2*np.squeeze(sc_V.gradientRescale(nn_V.grad(np.array([q[n+1]]).T)))\n", "        # Without gradient rescaling:\n", "        # p_n_half = p[n] - h/2*np.squeeze(nn_V.grad(np.array([q[n]]).T))\n", "        # q[n+1] = q[n] + h*np.squeeze(nn_T.grad(np.array([p_n_half]).T))\n", "        # p[n+1] = p_n_half - h/2*np.squeeze(nn_V.grad(np.array([q[n+1]]).T))\n", "    return p, q\n", "\n", "\n", "def nonLinP(mgl=1):\n", "    '''\n", "    Helper function returning two datGenFunctions objects for a nonlinear pendulum system.\n", "    The first is for T and the second for V.\n", "    '''\n", "    d_nLPT = domain([interval.closed(-2,2)])\n", "    d_nLPV = domain([interval.closed(-np.pi/3,np.pi/3)])\n", "    nonLinPend_T = datGenFunctions(1, 2, lambda p : 0.5*p**2, d_nLPT)\n", "    nonLinPend_V = datGenFunctions(1, 2, lambda q : mgl*(1-np.cos(q)), d_nLPV)\n", "    return nonLinPend_T, nonLinPend_V\n", "\n", "def keplerTB():\n", "    '''\n", "    Helper function returning two datGenFunctions objects for a Kepler two-body system.\n", "    The first is for T and the second for V.\n", "    '''\n", "    d_kTBT = domain([interval.closed(-1,1), interval.closed(-1,1)])\n", "    # Exclude a small box around the singularity of V at q = 0\n", "    d_kTBV = domain([interval.closed(-4,4), interval.closed(-4,4)]) - domain([interval.closed(-0.05, 0.05), interval.closed(-0.05, 0.05)])\n", "    keplerTB_T = datGenFunctions(2, 4, lambda p1, p2 : 0.5*np.array([np.inner(p,p) for p in np.array([p1, p2]).T]), d_kTBT)\n", "    keplerTB_V = datGenFunctions(2, 4, lambda q1, q2 : -1/(np.sqrt(q1**2+q2**2)), d_kTBV)\n", "    return keplerTB_T, keplerTB_V\n", "\n", "def henonHeiles():\n", "    '''\n", "    Helper function returning two datGenFunctions objects for a Henon-Heiles system.\n", "    The first is for T and the second for V.\n", "    '''\n", "    d_hHT = domain([interval.closed(-1,1), interval.closed(-1, 1)]) - domain([interval.closed(-0.1, 0.1), interval.closed(-0.1, 0.1)])\n", "    d_hHV = domain([interval.closed(-0.2,0.2), interval.closed(-0.2, 0.2)]) - domain([interval.closed(-0.1, 0.1), interval.closed(-0.1, 0.1)])\n", "    hHeilesTB_T = datGenFunctions(2, 4, lambda p1, p2 : 0.5*np.array([np.inner(p,p) for p in np.array([p1, p2]).T]), d_hHT)\n", "    hHeilesTB_V = datGenFunctions(2, 4, lambda q1, q2 : 0.5*np.array([np.inner(q,q) for q in np.array([q1, q2]).T]) + q1**2*q2 - 1/3 *q2**3, d_hHV)\n", "    return hHeilesTB_T, hHeilesTB_V\n", "\n", "\n", "def testNumIntKnown(y0, system, nn_T, nn_V, I, I_bar, N, T):\n", "    '''\n", "    Collection of instructions for testing the two methods on the known Hamiltonians.\n", "\n", "    y0     : initial value [p0, q0]\n", "    system : function returning the datGenFunctions objects for T and V, e.g. nonLinP\n", "    nn_T   : network object for T(p); trained here unless it has already been trained\n", "    nn_V   : network object for V(q); trained here unless it has already been trained\n", "    I      : number of training data points\n", "    I_bar  : number of points used in stochastic gradient descent\n", "    N      : number of steps\n", "    T      : final time\n", "    '''\n", "    system_T, system_V = system()\n", "\n", "    # Generate data and train the networks if necessary\n", "    sc_T, y_T, c_T = genData(system_T, I)\n", "    if not nn_T.hasTrained:\n", "        nn_T.train(y_T, c_T, I_bar)\n", "\n", "    sc_V, y_V, c_V = genData(system_V, I)\n", "    if not nn_V.hasTrained:\n", "        nn_V.train(y_V, c_V, I_bar)\n", "\n", "    symp_p, symp_q = sympEuler(sc_V, sc_T, nn_V, nn_T, y0, N, T)\n", "    verlet_p, verlet_q = stormerVerlet(sc_V, sc_T, nn_V, nn_T, y0, N, T)\n", "\n", "    print(\"\\nDifference in q between the two methods: \", np.linalg.norm(symp_q-verlet_q)/np.prod(symp_q.shape))\n", "    print(\"Difference in p between the two methods: \", np.linalg.norm(symp_p-verlet_p)/np.prod(symp_p.shape))\n", "\n", "    return symp_p, symp_q, verlet_p, verlet_q, nn_T, nn_V, sc_T, sc_V\n", "\n", "\n", "def plotTotalEnergy(ax, x, p, q, nn_T, nn_V, **kwargs):\n", "    '''\n", "    Plot the total energy of the system, as approximated by the networks, against the time steps.\n", "\n", "    ax       : axis to plot the total energy into\n", "    x        : values for the horizontal axis, e.g. the time steps\n", "    p        : series of momentum data, generated by symplectic Euler or Størmer-Verlet\n", "    q        : series of position data, generated by symplectic Euler or Størmer-Verlet\n", "    nn_T     : trained neural network approximating T(p)\n", "    nn_V     : trained neural network approximating V(q)\n", "    **kwargs : all further arguments are passed to ax.plot()\n", "    '''\n", "    H = (nn_T + nn_V)(p, q)\n", "\n", "    ax.set_ylim(0, 1.1*max(max(np.squeeze(H)), ax.get_ylim()[1]))\n", "    ax.plot(x, np.squeeze(H), **kwargs)\n", "    print(np.std(np.squeeze(H)))  # Standard deviation of H, a measure of how well H is preserved" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00035-88dd1a8d-b543-42a3-b68c-f5050c322379", "output_cleared": false, "tags": [] }, "source": [ "#### 3.3.1 - Applied to the nonlinear pendulum" ] }, { "cell_type": "markdown", "metadata": { "cell_id": "00035-f6bb2efe-8260-4508-bb51-b8599ec00fd9", "tags": [] }, "source": [ "Note that for the following tests, the number of training iterations for the networks is set to $2,000$.\n", "With only around $600$ iterations, important features of the solutions, such as the initial movement away from the starting point,\n", "were not captured properly by the approximation." ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "cell_id": "00032-ffd98d46-bf94-4632-9130-a8bbb95d978f", "execution_millis": 86621, "execution_start": 1604334763231, "output_cleared": false, "source_hash": "87696583", "tags": [] }, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is 
too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Standard deviation of total energy:\n", "Symplectic Euler: 0.06237039932984198\n", "Størmer-Verlet: 0.060625022748471306\n" ] } ], "source": [ "## Step 1 : Create new networks\n", "\n", "nn_T_TB = network(\n", " K = 30, \n", " tau = 0.001,\n", " h = 0.01,\n", " d = 4,\n", " d0 = 2,\n", " iterations = 1000\n", " )\n", "\n", "nn_V_TB = network(\n", " K = 30, \n", " tau = 0.001,\n", " h = 0.01,\n", " d = 4,\n", " d0 = 2,\n", " iterations = 1000\n", " )\n", "\n", "N_TB = 1000 #Number of iterations\n", "T_TB = 30 #Time interval [0, T]\n", "\n", "## Step 2 : Test numerical integration for Kepler two-body problem\n", "symp_p_TB, symp_q_TB, verlet_p_TB, verlet_q_TB, nn_T_TB, nn_V_TB, _, __ = testNumIntKnown(np.array([[0.5,0.5],[0.3,0.3]]), keplerTB, nn_T_TB, nn_V_TB, 1500, 100, N_TB, T_TB)\n", "\n", "## Step 3 : Presentation of results\n", "fig, ax = plt.subplots(1, 3)\n", "\n", "### Phase plot\n", "x_TB = np.linspace(0, T_TB, N_TB+1)\n", "ax[0].plot(symp_p_TB[:,0], symp_q_TB[:,0])\n", "ax[0].plot(verlet_p_TB[:,0], verlet_q_TB[:,0])\n", "ax[0].set_title(\"$p_1$ versus $q_1$\", pad=22)\n", "ax[0].plot(symp_p_TB[:,0][0], symp_q_TB[:,0][0], 
'm*', label=\"Starting point\", markersize=10)\n",
"\n",
"ax[1].plot(symp_p_TB[:,1], symp_q_TB[:,1])  # Symplectic Euler: p_2 vs. q_2\n",
"ax[1].plot(verlet_p_TB[:,1], verlet_q_TB[:,1])  # Størmer-Verlet: p_2 vs. q_2\n",
"ax[1].set_title(\"$p_2$ versus $q_2$\", pad=22)\n",
"ax[1].plot(verlet_p_TB[:,1][0], verlet_q_TB[:,1][0], 'm*', markersize=10)\n",
"\n",
"### Total energy\n",
"print(\"\\nStandard deviation of total energy:\")\n",
"print(\"Symplectic Euler: \",end=\"\")\n",
"plotTotalEnergy(ax[2], x_TB, symp_p_TB.T, symp_q_TB.T, nn_T_TB, nn_V_TB, label=\"Symplectic Euler\")\n",
"print(\"Størmer-Verlet: \", end=\"\")\n",
"plotTotalEnergy(ax[2], x_TB, verlet_p_TB.T, verlet_q_TB.T, nn_T_TB, nn_V_TB, label=\"Størmer-Verlet\")\n",
"fig.legend(bbox_to_anchor=(0.8, 0.87), loc=\"upper left\")\n",
"ax[2].set_title(\"Total energy\", pad=22)\n",
"ax[2].set_xlabel(\"Time\")\n",
"ax[2].yaxis.set_label_coords(-0.1,1.02)\n",
"ax[2].set_ylabel(\"E\", rotation=\"horizontal\")\n",
"fig.tight_layout()\n",
"plt.subplots_adjust(right=0.8)\n",
"plt.show()" ] },
{ "cell_type": "markdown", "metadata": { "cell_id": "00038-09f147a4-84c5-4a17-8798-7855083f255f", "output_cleared": false, "tags": [] }, "source": [
"In the Kepler two-body problem, $p$ and $q$ are both two-dimensional, which makes the solutions harder to visualize.\n",
"We therefore plot each momentum component against the corresponding position component. These phase plots appear to be closed curves, which is desirable.\n",
"\n",
"The total energy $H = T(p) + V(q)$, evaluated with the trained networks, follows the same periodic pattern as in the case of the nonlinear pendulum,\n",
"but with a significantly larger amplitude (its standard deviation over all time steps is printed above).\n",
"Changing the initial values would reduce this amplitude, but would in return also destabilize the phase plots.\n",
"Regardless, the total energy oscillates around a constant value instead of drifting, so we may conclude that the Hamiltonian is approximately conserved." ] },
{ "cell_type": "markdown", "metadata": { "cell_id": "00041-1c53a523-5a9c-40e1-b327-b7f019338494", "output_cleared": false, "tags": [] }, "source": [ "#### 3.3.3 - Applied to the Henon-Heiles problem" ] },
{ "cell_type": "code", "execution_count": 64, "metadata": { "cell_id": "00034-8a8c71b8-fd46-465d-9b06-16376eb28bb4", "execution_millis": 139092, "execution_start": 1604336757111, "output_cleared": false, "source_hash": "9f39da70", "tags": [] }, "outputs": [
{ "name": "stdout", "output_type": "stream", "text": [ "\n", "Standard deviation of total energy:\n", "Symplectic Euler: 0.02077699857450407\n", "Størmer-Verlet: 0.02075028761369274\n" ] } ], "source": [
"def testNumIntUnknown(nn_T, nn_V, I, I_bar, N, T, y0=None):\n",
"    '''Generate training data from the trajectories, train the networks, and compute p and q with the two methods.\n",
"\n",
"    Input:\n",
"        nn_T  : neural network for T(p), trained here unless already trained\n",
"        nn_V  : neural network for V(q), trained here unless already trained\n",
"        I     : number of points to be generated\n",
"        I_bar : number of points sampled per stochastic gradient descent step\n",
"        N     : number of steps in the numerical integrators\n",
"        T     : final time, the numerical integrators run over [0, T]\n",
"        y0    : optional initial value; by default the averages of the sampled p_i and q_i are used\n",
"    '''\n",
"    sc_t, y_t, c_t = genTrajData(\"T\", I, batchMin=0, batchMax=47)\n",
"    if not nn_T.hasTrained:\n",
"        nn_T.train(y_t, c_t, I_bar)\n",
"\n",
"    avg_p = np.average(y_t, axis=1)\n",
"\n",
"    sc_v, y_v, c_v = genTrajData(\"V\", I, batchMin=0, batchMax=47)\n",
"    if not nn_V.hasTrained:\n",
"        nn_V.train(y_v, c_v, I_bar)\n",
"\n",
"    avg_q = np.average(y_v, axis=1)\n",
"\n",
"    # y0 may be a NumPy array, so test identity with None instead of using == (which compares elementwise)\n",
"    if y0 is None:\n",
"        y0 = np.array([avg_p, avg_q])\n",
"\n",
"    symp_p, symp_q = sympEuler(sc_v, sc_t, nn_V, nn_T, y0, N, T)\n",
"    verlet_p, verlet_q = stormerVerlet(sc_v, sc_t, nn_V, nn_T, y0, N, T)\n",
"\n",
"    print(\"\\nDifference in q between the two methods: \", np.linalg.norm(symp_q-verlet_q)/np.prod(symp_q.shape))\n",
"    print(\"Difference in p between the two methods: \", np.linalg.norm(symp_p-verlet_p)/np.prod(symp_p.shape))\n",
"\n",
"    return symp_p, symp_q, verlet_p, verlet_q, nn_T, nn_V, sc_t, sc_v\n",
"\n",
"\n",
"## Part 1 : Create new networks\n",
"nn_T_UH = network(\n",
"    K = 30,\n",
"    tau = 0.001,\n",
"    h = 0.1,\n",
"    d = 2*3,\n",
"    d0 = 3,\n",
"    iterations = 800\n",
"    )\n",
"nn_V_UH = network(\n",
"    K = 30,\n",
"    tau = 0.001,\n",
"    h = 0.1,\n",
"    d = 2*3,\n",
"    d0 = 3,\n",
"    iterations = 800\n",
"    )\n",
"\n",
"T_UH = 30    # Final time, the methods integrate over [0, T]\n",
"N_UH = 1000  # Number of time steps\n",
"\n",
"## Part 2 : Train the networks and integrate with both methods\n",
"sp, sq, vp, vq, nnT, nnV, _, __ = testNumIntUnknown(nn_T_UH, nn_V_UH, 1000, 100, N_UH, T_UH)\n",
"\n",
"## Part 3 : Plot the total energy\n",
"fig, ax = plt.subplots(1)\n",
"x_UH = np.linspace(0, T_UH, N_UH+1)\n",
"\n",
"ax.set_title(\"Total energy\", pad=22)\n",
"\n",
"print(\"\\nStandard deviation of total energy:\")\n",
"print(\"Symplectic Euler: \",end=\"\")\n",
"plotTotalEnergy(ax, x_UH, sp.T, sq.T, nnT, nnV, label=\"Symplectic Euler\")\n",
"\n",
"print(\"Størmer-Verlet: \", end=\"\")\n",
"plotTotalEnergy(ax, x_UH, vp.T, vq.T, nnT, nnV, label=\"Størmer-Verlet\")\n",
"\n",
"ax.set_xlabel(\"Time\")\n",
"ax.yaxis.set_label_coords(-0.1,1.02)\n",
"ax.set_ylabel(\"E\", rotation=\"horizontal\")\n",
"\n",
"fig.legend(bbox_to_anchor=(0.8, 0.87), loc=\"upper left\")\n",
"\n",
"fig.tight_layout()\n",
"plt.subplots_adjust(right=0.8)\n",
"\n",
"plt.show()" ] },
{ "cell_type": "markdown", "metadata": { "cell_id": "00047-cf8fc36f-9f19-47ac-b7cf-eef70ceb847a", "output_cleared": false, "tags": [] }, "source": [
"The initial value for the unknown Hamiltonian is the average of all randomly sampled training points, computed separately for $p$ and $q$.\n",
"\n",
"We observe that the total energy also appears to be conserved for the unknown Hamiltonian. The curve has properties similar to\n",
"those of the known problems, but with an even smaller periodic deviation (a standard deviation of about $0.021$ for both methods)." ] },
{ "cell_type": "markdown", "metadata": { "cell_id": "00047-7f8adedc-7f16-41ef-a5ea-583594ac45f8", "tags": [] }, "source": [
"## Conclusion\n",
"\n",
"The first section implements synthetic functions that are used to generate data for testing the algorithms.\n",
"The second section implements the machine learning algorithm, using the Adam method with stochastic\n",
"gradient descent.\n",
"The third section derives the gradient of the network-trained function and implements\n",
"two numerical integrators (symplectic Euler and Størmer-Verlet) for networks trained on both\n",
"known and unknown Hamiltonians.\n",
"\n",
"Sections 1 and 2 give satisfactory results, indicating that the neural network implementation works as intended.\n",
"Furthermore, it appears possible to approximate Hamiltonians using machine learning.\n",
"The approximations naturally differ to some degree, and due to the stochastic nature of the training,\n",
"the results unfortunately diverge in some runs. Our testing indicates, however, that\n",
"the plots converge more often than not.\n",
"\n",
"One noteworthy observation is that the systems are highly sensitive to the choice of initial values.\n",
"They behave chaotically, in the sense that small changes to the initial values produce large differences in the plots.\n",
"They are also sensitive to $\\bar{I}$, the number of points sampled at once in SGD,\n",
"as the plots diverge when $\\bar{I}$ is too small.\n",
"\n",
"Finally, the Størmer-Verlet method gives a slightly lower standard deviation of the total energy in our runs\n",
"($0.0606$ versus $0.0624$ for the Kepler two-body problem, and $0.02075$ versus $0.02078$ for the unknown Hamiltonian),\n",
"suggesting that it may be the more precise of the two methods."
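, "\n",
"\n",
"To support the comparison of the two methods above, here is a sketch of one standard formulation of the schemes for a separable Hamiltonian $H(p,q) = T(p) + V(q)$ with step size $h$. In our setting $\\nabla T$ and $\\nabla V$ would be the gradients of the trained networks from Section 3. This is one common variant; the functions `sympEuler` and `stormerVerlet` may order the $p$ and $q$ updates differently.\n",
"\n",
"Symplectic Euler:\n",
"$$\n",
"\\begin{aligned}\n",
"p_{n+1} &= p_n - h\\,\\nabla V(q_n), \\\\\n",
"q_{n+1} &= q_n + h\\,\\nabla T(p_{n+1}).\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"Størmer-Verlet:\n",
"$$\n",
"\\begin{aligned}\n",
"p_{n+1/2} &= p_n - \\tfrac{h}{2}\\,\\nabla V(q_n), \\\\\n",
"q_{n+1} &= q_n + h\\,\\nabla T(p_{n+1/2}), \\\\\n",
"p_{n+1} &= p_{n+1/2} - \\tfrac{h}{2}\\,\\nabla V(q_{n+1}).\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"Both schemes are symplectic. Symplectic Euler has global error $\\mathcal{O}(h)$, whereas Størmer-Verlet has global error $\\mathcal{O}(h^2)$, which is consistent with the slightly smaller energy fluctuations observed for Størmer-Verlet."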
] }, { "cell_type": "markdown", "metadata": { "cell_id": "00049-27881c89-3114-4789-b6f5-9043d7bf4941", "tags": [] }, "source": [] } ], "metadata": { "deepnote_execution_queue": [], "deepnote_notebook_id": "3626dd21-fa62-47ca-8a2a-31c6cd47d51c", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }