{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# inspired from http://peterroelants.github.io/posts/neural_network_implementation_part05/\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "EPS = 10e-7\n", "\n", "class Layer(object):\n", " def forward(self, h):\n", " \"\"\" Perform a forward step for the given layer and returns the result\n", " # Argument\n", " h: np.array of the previous layer\n", " # Return\n", " np.array of the activation \n", " \"\"\"\n", " \n", " def backward(self, grad):\n", " \"\"\" Perform the backpropagation of the layer. This method updates weights if necessary\n", " # Argument\n", " grad: np.array of the incoming gradient\n", " # Return\n", " np.array of the calculated gradient\n", " \"\"\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class Linear(Layer):\n", " def __init__(self, input_size, nb_neurons, lr=0.0001, freeze=False):\n", " self._weights = np.random.randn(input_size + 1, nb_neurons) * 0.1\n", " self.lr = lr\n", " self.freeze = freeze\n", " \n", " def forward(self, h):\n", " self._h = np.concatenate((np.ones(1), h))\n", " return self._weights.T.dot(self._h)\n", " \n", " def backward(self, grad):\n", " if not self.freeze:\n", " dW = grad[:, np.newaxis].dot(self._h[:, np.newaxis].T).T\n", " self._weights -= self.lr * dW\n", " return grad.dot(self._weights.T)\n", " \n", "class Relu(Layer):\n", " def forward(self, h):\n", " self._h = h\n", " return np.maximum(0, h)\n", " \n", " def backward(self, grad):\n", " return np.multiply(grad[1:], (self._h >= 0).astype(dtype=np.float))\n", " \n", "class Sigmoid(Layer):\n", " def forward(self, h):\n", " self._h = h\n", " return 1 / (1 - np.exp(-h))\n", " \n", " def backward(self, grad):\n", " pass\n", " \n", "class Softmax(Layer):\n", " def forward(self, h):\n", " self._h = h\n", " s = np.sum(np.exp(h))\n", " return np.exp(h) / s\n", " \n", " def backward(self, y_true):\n", " return self._h - y_true\n", " \n", "def categorical_cross_entropy(y_true, y_pred):\n", " return - np.multiply(y_true, np.log(y_pred + EPS)).sum()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class Model(object):\n", " def __init__(self, cost):\n", " self._layers = []\n", " self._activations = []\n", " self._cost = cost\n", "\n", " def add_layer(self, layer):\n", " self._layers.append(layer)\n", " \n", " def get_cost(self, y_true , y_pred):\n", " return self._cost(y_true, y_pred)\n", " \n", " def propagate(self, x):\n", " self._activations = [x]\n", " for layer in self._layers:\n", " self._activations.append(layer.forward(self._activations[-1]))\n", " \n", " def back_prop(self, y):\n", " cur_grad = None\n", " for layer in reversed(self._layers):\n", " Y = self._activations.pop()\n", " if cur_grad is None:\n", " cur_grad = layer.backward(y)\n", " else:\n", " cur_grad = layer.backward(cur_grad)\n", " \n", " def predict(self, X):\n", " predictions = []\n", " for x in X:\n", " self.propagate(x)\n", " predictions.append(self._activations[-1])\n", " return np.array(predictions)\n", "\n", " def score(self, X_test, y_test):\n", " y_pred = self.predict(X_test)\n", " return np.sum(np.argmax(y_pred, axis=1) == np.argmax(y_test, axis=1))/len(y_pred)\n", " \n", " \n", " def train(self, epochs, X_train, y_train, X_valid=None, y_valid=None):\n", " n_train, n_valid = len(y_train), len(y_valid)\n", " loss_epochs = []\n", " val_loss_epochs = []\n", " for _ in range(epochs):\n", " Ls = []\n", " for i in range(n_train):\n", " x, y = X_train[i, :], y_train[i, :]\n", " self.propagate(x)\n", " L = self.get_cost(y, self._activations[-1])\n", " Ls.append(L)\n", " self.back_prop(y)\n", " loss_epochs.append(np.mean(Ls))\n", " Ls = []\n", " for i in range(n_valid):\n", " x, y = X_valid[i, :], y_valid[i, :]\n", " self.propagate(x)\n", " L = self.get_cost(y, self._activations[-1])\n", " Ls.append(L)\n", " val_loss_epochs.append(np.mean(Ls))\n", " plt.title(\"Loss over {} epochs\".format(str(epochs)))\n", " plt.plot(val_loss_epochs, 'b', label=\"Val Loss\")\n", " plt.plot(loss_epochs, 'r', label=\"Loss\")\n", " plt.legend()\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "model = Model(categorical_cross_entropy)\n", "model.add_layer(Linear(64, 20))\n", "model.add_layer(Relu())\n", "model.add_layer(Linear(20, 20))\n", "model.add_layer(Relu())\n", "model.add_layer(Linear(20, 10))\n", "model.add_layer(Softmax())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "# Test on MNIST\n", "from sklearn import datasets\n", "from keras.utils import to_categorical\n", "\n", "# The digits dataset\n", "digits = datasets.load_digits()\n", "targets = digits.target\n", "X, y = digits.data, to_categorical(targets)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3XeYFGXywPFvLWmRJbuAZBVQMsoKqGcAzAn1EEVUDBw/0BMU0xnPnO5EUMQIYkAxYQTFAHfeHYguOYmCgIAgSaKSlvr9Ub0wLBtmd2e3d2br8zz97Ex3T0/1NNT0vP12vaKqOOecSyxJYQfgnHMu9jy5O+dcAvLk7pxzCciTu3POJSBP7s45l4A8uTvnXALy5O5cKSQiKiJNwo7DFR1P7g4AEVkqIqeEHUdxEpEeIjJZRH4XkX9ls7ydiEwLlk8TkXYRy0REHhOR9cH0mIhINK91rjh4cnelgoiUyWb2BmAI8Gg265cHPgReB6oDrwAfBvMB+gLnA22BNsC5wP9F+Vrnipwnd5cnEfmLiCwSkQ0i8pGI1A3mi4g8KSJrRGSziMwRkVbBsrNEZL6IbBGRlSJycw7bThKRu0RkWbCdV0WkarDsUxH5a5b1Z4nIhcHjI0XkiyCuhSLSI2K9USLyrIiMF5FtQOes762qX6rq28Av2YR2MlAWGKKqO1T1KUCALsHy3sATqrpCVVcCTwBXRvnarJ9BVREZISKrgs/qwcwvIxG5UkT+JyLDRGSTiHwvIl0jXls3OCYbgmP0l4hlZUTkDhFZHByHaSLSIOKtTxGRH0Vko4g8k/nLQ0SaiMi/g/dbJyJvZRe3K9k8ubtciUgX4BGgB3AIsAwYEyw+DTgRaAZUDdZZHywbAfyfqlYGWgETc3iLK4OpM3AYkAIMC5a9CfSMiKUF0AgYJyKVgC+AN4BawCXA8GCdTJcCDwGVgf/mc9dbArN1//ocs4P5mctnRSyblWVZbq/NahSwG2gCHIV9rn0ilncEFgMHA38HxopIjWDZGGAFUBfoDjwcHDOAQdjndxZQBbga+D1iu+cAx2C/PHoApwfzHwA+x3511AeeziFuV4J5cnd56QWMVNXpqroDuB04VkQaA7uwxHkkIKq6QFVXBa/bBbQQkSqq+puqTs9l+4NV9SdV3Rps/xIRKQu8D7QTkUYR644N4jgHWKqqL6vqblWdAbwHXBSx7Q9V9X+qukdVt+dzv1OATVnmbQr2N7vlm4CU4Ow3r9fuJSK1seR7g6puU9U1wJPYl1WmNdivgF2q+hawEDg7OAs/HrhNVber6kzgJeCK4HV9gLtUdaGaWaq6PmK7j6rqRlX9GZgEZF4X2IV9idYNtpvfL0ZXAnhyd3mpi52tAxAk4PVAPVWdiJ1lPwOsEZEXRKRKsOqfsaS1LPiJf2w02w8elwVqq+oWYBz7El1PYHTwuBHQMWhS2CgiG7HkXydiW8sLtMdmK3a2G6kKsCWH5VWArcHZel6vjdQIKAesitiP57FfI5lWZvkVsAz73OoCG4LPKXJZveBxA+yMPyerIx7/jn0pAdyKNSN9KyLzROTqXLbhSihP7i4vv2AJCICgOaQmsBJAVZ9S1fZAC6x55pZg/neq2g1LUh8Ab0ezfaAh1kTxa/D8TaBn8OWQjJ1hgiXuf6tqtYgpRVX7R2yrMCVP5wFtInvAYM0X8yKWt41Y1jbLstxeG2k5sAM4OGI/qqhqZBNOvSzbaoh9br8ANUSkcpZlKyO2fXge+3kAVV2tqn9R1brYReLh4t0m444ndxepnIgkR0xlseR6VdC1rwLwMDBVVZeKyDEi0lFEygHbgO3AHhEpLyK9RKSqqu4CNgN7cnjPN4EbReRQEUkJtv+Wqu4Olo/Hkv/9wfzM7XwCNBORy0WkXDAdIyLNo93Z4IJjMvZLISnY53LB4n8BGcAAEakQcWE389rBq8AgEakndoH5JqztPJrX7hU0Y30OPCEiVYILzIeLyEkRq9UKtlVORC4CmgPjVXU5MBl4JIi9DXAN1ksHrInmARFpKqaNiNSM4nO5SETqB09/w74kczp+rqRSVZ98AliK/SeOnB4MlvXDft5vwJJq/WB+V+xC4VZgHdZkkgKUBz7DEsNm4DvgTzm8bxJwD3aWuZag+2CWdUYE8RyTZf4RWLPNWqypaCLQLlg2KjP+XPb5ymz2eVTE8qOAacAfwHTgqIhlAjwefCYbgscSzWuziaMq8Cx2YXQTMAO4JCLG/2HNX5uAH4DTIl5bPzgmG4Jj1C9iWRngLmAJ1iT0XcSxU6BJxLp7P69gX1YGx3Ux0Dfsf58+5X+S4GA650ogEbkS6KOqfwo7FhdfvFnGOecSkCd355xLQN4s45xzCcjP3J1zLgGVDeuNDz74YG3cuHFYb++cc3Fp2rRp61Q1Na/1QkvujRs3Jj09Pay3d865uCQiy/Jey5tlnHMuIeWZ3EWkgYhMEivfOk9EBmazTjcRmS0iM0UkXUS8T65zzoUommaZ3cBNqjo9qGExTUS+UNX5Eet8BXykqhrcAv02VinQOedcCPJM7mq1L1YFj7eIyAKs6tz8iHW2RrykEoUr2OScS0C7du1ixYoVbN+e3+rLpVNycjL169enXLlyea+cjXxdUA1qeB8FTM1m2QXYoA61gLMLFI1zLmGtWLGCypUr07hxY/YvcumyUlXWr1/PihUrOPTQQwu0jagvqAYV+97DBhXYnE0w76vqkdi4kg/ksI2+QZt8+tq1awsUsHMuPm3fvp2aNWt6Yo+CiFCzZs1C/cqJKrkHZVDfA0ar6tjc1lXVr4HDROTgbJa9oKppqpqWmppnN03nXILxxB69wn5W0fSWEazk6gJVHZzDOk0iBtc9GqjAvrE0Y2vBAvSGG2HnziLZvHPOJYJoztyPBy4HugRdHWeKjWzfT0T6Bev8GZgrIjOxIdcu1iIqWjP9vSXI0CFsfmdCUWzeOZegOnfuzIQJ++eNIUOG0L9//xxeYVJSUvI1v6TIM7mr6n9VVVS1jaq2C6bxqvqcqj4XrPOYqrYMlh2rRTigbpkzTmUdNfll8JiiegvnXALq2bMnY8bsnzfGjBlDz549Q4qoaMXdHapt08rxda3uNJjxIbpla94vcM45oHv37owbN46dQZPu0qVL+eWXXzjhhBPYunUrXbt25eijj6Z169Z8+OGHBXqPpUuX0qVLF9q0aUPXrl35+eefAXjnnXdo1aoVbdu25cQTTwRg3rx5dOjQgXbt2tGmTRt+/PHH2OxoILTaMoWR/H9XUumB55l16+u0fbZf3i9wzpUoN9wAM2fGdpvt2sGQITkvr1GjBh06dODTTz+lW7dujBkzhh49eiAiJCcn8/7771OlShXWrVtHp06dOO+88/J9UfP666+nd+/e9O7dm5EjRzJgwAA++OAD7r//fiZMmEC9evXYuHEjAM899xwDBw6kV69e7Ny5k4yMjMLs/gHi7swd4NS7OjKnQnsqjRrGngy/X8o5F53IppnIJhlV5Y477qBNmzaccsoprFy5kl9//TXf258yZQqXXnopAJdffjn//a+1UB9//PFceeWVvPjii3uT+LHHHsvDDz/MY489xrJly6hYsWIsdnGfsAZvbd++vRbG//7ysiropL9PKtR2nHPFY/78+WGHoFu2bNHU1FSdNm2aNm3adO/8l19+WXv06KE7d+5UVdVGjRrpkiVLVFW1UqVK2W4ru/k1a9bcu42dO3dqzZo19y775ptv9O6779ZGjRrpunXrVFV10aJFOnToUG3SpIl+9dVXB2wvu88MSNcocmxcnrkDdBx8Mb+VqcnOwU8T418zzrkElZKSQufOnbn66qv3u5C6adMmatWqRbly5Zg0aRLLlkVVVfcAxx133N5fBqNHj+aEE04AYPHixXTs2JH777+f1NRUli9fzk8//cRhhx3GgAED6NatG7Nnzy78DkaI2+ReJqUiv577F7pu+YCPni7YgXDOlT49e/Zk1qxZ+yX3Xr16kZ6eTuvWrXn11Vc58si86x7+/vvv1K9ff+80ePBgnn76aV5++WXatGnDa6+9xtChQwG45ZZbaN26Na1ateK4446jbdu2vP3227Rq1Yp27doxd+5crrjiipjuZ2hjqKalpWlhB+vYs2w5exofyshqg7hqzeMUsL6Oc64YLFiwgObNm4cdRlzJ7jMTkWmqmpbXa+P2zB0gqVED1hx/IRdtfJHXn98WdjjOOVdixHVyBzjk0YFUZyM/3P0aXknUOedM3Cd3Of44NjdtzxUbh/Li83vCDsc550qEuE/uiFD5roE053umPvgFu3aFHZBzzoUv/pM7IBf3YHu12ly6bihjvOSMc84lRnKnQgXKD+zPWXzK2w8sJKQOQM45V2IkRnIHkvr3I6NseU7/8Wk++yzsaJxzJVFJL9MbSwmT3KldGy65hKtkFMMf3hh2NM45F6rESe5AmYHXU0m3Uf+/bzJvXtjROOfiQUkq0xtLcVnyN0ft27O7ZVv+Mv8lnnmmP8OHhx2Qcy5bYdT8zUFJKtMbSwl15o4IZfv14WidzsyXZ7BpU9gBOedKuhJVpjeG8jxzF5EGwKtAbUCBF1R1aJZ1egG3AQJsAfqr6qzYhxuFXr3Yc9PNXLp9BK+8MowBA0KJwjmXmwKcYRe35557jqlTpzJu3Djat2/PtGnTuPTSS+nYsSPjxo3jrLPO4vnnn6dLly5hh5qtaM7cdwM3qWoLoBNwnYi0yLLOEuAkVW0NPAC8ENsw86F6dZK6/5neZV7npaf/YI/ftOqcy0VJKtMbS9EMkL1KVacHj7cAC4B6WdaZrKq/BU+/AerHOtB8ueYaKmdsovWisUyaFGokzrkSpKSX6Y2lfJX8FZHGwNdAK1XdnMM6NwNHqmqf3LYVi5K/Odqzhz1NmvK/5Q159qJJvPFG0byNcy56XvI3/4ql5K+IpADvATfkktg7A9dg7e/ZLe8rIukikr527dpo3zr/kpJI6nMNJ+z+FzPfW8SGDUX3Vs45VxJFldxFpByW2Eer6tgc1mkDvAR0U9X12a2jqi+oapqqpqWmphY05uhceSWalESvnS8zenTRvpVzzpU0eSZ3ERFgBLBAVQfnsE5DYCxwuar+ENsQC6huXeTUU7my/BuMeEm93oxzJUBYI7/Fo8J+VtGcuR8PXA50EZGZwXSWiPQTkX7BOvcANYHhwfIiakzPp169qLdzKZVmT2b69LCDca50S05OZv369Z7go6CqrF+/nuTk5AJvI89+7qr6X6z/em7r9AFyvYAaivPPRytW5Iqdoxkx4njatw87IOdKr/r167NixQqK9HpbAklOTqZ+/YJ3PIzrAbKjcsklbPngSw6tsIrlq8tRgm8oc865PJWKAbKj0qsXlXesp9PmCbz3XtjBOOdc8Uj85H766WiNGvxfymhGjAg7GOecKx6Jn9zLl0d69OCMHR+S/q8tLF4cdkDOOVf0Ej+5A/TqRbldf3CBfMjIkWEH45xzRa90JPfjjoNGjRhw8GhGjYLdu8MOyDnnilbpSO5JSXDppRy9/gt2//IrEyaEHZBzzhWt0pHcAXr1ImlPBlenvO0XVp1zCa/0JPeWLaFtW/pVHs3HH8OaNWEH5JxzRaf0JHeASy+l0aqpNNq9iNdeCzsY55wrOqUruffsCSLcVv8NRozAi4k55xJW6UruDRrAiSdy0a7RLFigfPNN2AE551zRKF3JHaBXL6r9+gN/Sp7mF1adcwmr9CX37t2hfHnuPmw0b70FW7eGHZBzzsVe6Uvu1avDWWdx8q9j+H1rBm+/HXZAzjkXe6UvuQP06kX59au5ov4kb5pxziWk0pnczz4bqlThpjqjmTwZ5swJOyDnnIut0pncK1aECy+kxffvUa3CHwwbFnZAzjkXW9EMkN1ARCaJyHwRmSciA7NZ50gRmSIiO0Tk5qIJNcZ69SJp6xYePu4TXn8dfvst7ICccy52ojlz3w3cpKotgE7AdSLSIss6G4ABwD9jHF/R6dwZDjmES/e8zu+/46WAnXMJJc/krqqrVHV68HgLsACol2WdNar6HbCrSKIsCmXKwGWXUfW/4+jWcTXPPAMZGWEH5ZxzsZGvNncRaQwcBUwtyJuJSF8RSReR9BIxAnqfPpCRwf2Hj2LJEhg/PuyAnHMuNqJO7iKSArwH3KCqmwvyZqr6gqqmqWpaampqQTYRW82awYkn0nrqS9Srqzz1VNgBOedcbESV3EWkHJbYR6vq2KINqZj16YMsXsyjZ/6bL7+EBQvCDsg55wovmt4yAowAFqjq4KIPqZh17w5Vq9J940skJ8OQIWEH5JxzhRfNmfvxwOVAFxGZGUxniUg/EekHICJ1RGQFMAi4S0RWiEiVIow7dipWhMsuI/mTd+l/yW+88gr8+mvYQTnnXOFE01vmv6oqqtpGVdsF03hVfU5VnwvWWa2q9VW1iqpWCx4XqF0+FH36wI4d/K3+6+zcCU8/HXZAzjlXOKXzDtWs2rWDtDRqffACF5yvDB/u1SKdc/HNk3um/v1h7lweOOXf/PYbXlDMORfXPLln6tkTatakxcRhnHACDB4Mu+LnliznnNuPJ/dMFSta2/sHH3DPVcv5+Wd4552wg3LOuYLx5B6pf39QpesPz9KiBTz+uA+i7ZyLT57cIzVqBN26IS+9yN9u2M6sWfDZZ2EH5Zxz+efJPau//hXWraNn0hgaNoQHHvCzd+dc/PHknlXnztCyJWWfeYrbblWmTIF//SvsoJxzLn88uWclAgMHwowZ9Gn6b+rUgQcfDDso55zLH0/u2bn8cqhVi/JP/ZNbboGJE2HKlLCDcs656Hlyz05ysrW9jxtHvxPnU7MmPPRQ2EE551z0PLnnpH9/qFiRg54bzKBBMG4cfPdd2EE551x0PLnn5OCD4cor4bXXGNBjNQcfDHfdFXZQzjkXHU/uubnxRti1i5RRw7j9dvj8c+8545yLD57cc9O0KZx/Pjz7LP2v2Ea9enDnnd7v3TlX8nlyz8stt8CGDVR87QXuuQcmT/aBtJ1zJZ9oSKehaWlpmp6eHsp751uXLvD99+xa+BPNj0omJQWmT4ck/2p0zhUzEZmmqml5refpKRp33w2rVlHu1RHcfz/MmgVvvRV2UM45l7NoBshuICKTRGS+iMwTkYHZrCMi8pSILBKR2SJydNGEG5KTT4bjj4fHHuOSC3fSti3ccQds3x52YM45l71oztx3AzepagugE3CdiLTIss6ZQNNg6gs8G9MowyZiZ+/Ll5P02is88QQsXepjrTrnSq5oBsheparTg8dbgAVAvSyrdQNeVfMNUE1EDol5tGE67TQ45hh45BG6nriLs8+2mjNr14YdmHPOHShfbe4i0hg4CpiaZVE9YHnE8xUc+AWAiPQVkXQRSV8bb1kx8+x9yRJ45RX+8Q/Ytg3uuy/swJxz7kBRJ3cRSQHeA25Q1c0FeTNVfUFV01Q1LTU1tSCbCNc550CHDnD//TQ/dDt9+8Jzz8H334cdmHPO7S+q5C4i5bDEPlpVx2azykqgQcTz+sG8xCICDz8My5fD889z771w0EFw661hB+acc/uLpreMACOABao6OIfVPgKuCHrNdAI2qeqqGMZZcnTtatNDD1HroK3ccQd8/DFMmhR2YM45t080Z+7HA5cDXURkZjCdJSL9RKRfsM544CdgEfAicG3RhFtCPPSQXUkdMoQbboCGDa0MTUZG2IE555zxO1QL6vzz7XR9yRLe+aoGPXrAM8/AtYn9teacC5nfoVrUHnwQtmyBxx+ne3cbevWuu2DdurADc845T+4F16oV9OoFTz2FrF7F00/D5s1e8905VzJ4ci+M++6DXbvgvvto2dJG5nvhBSsq5pxzYfLkXhiHHWaN7C++CPPmce+9NoDT9dd7zXfnXLg8uRfWPfdAlSpw881UqwaPPWY1319/PezAnHOlmSf3wqpZ08oSfPYZfPYZvXtDx452Y9OmTWEH55wrrTy5x8J118Hhh8PNN5O0ZzfPPANr1lhZYOecC4Mn91ioUAEefxzmzYORI2nfHgYMgGefhSlTwg7OOVca+U1MsaIKJ50ECxfCjz+yRarQogVUrw7TpkG5cmEH6JxLBH4TU3ETgSeesPaYRx+lcmUYNgzmzIEnnww7OOdcaePJPZaOOQYuuwwGD4YlS+jWDS64AO6918rAO+dccfHkHmuPPAJlysCgQQA89ZQ97d/f+74754qPJ/dYq1/fukZ+8AF89hn161sRyQkTYPTosINzzpUWfkG1KOzYAW3a2Kn6nDlklK3ASSdZZ5q5c6HeAQMQOudcdPyCapgqVLD2mB9/hCefpEwZePlly/l/+Ys3zzjnip4n96Jy+ulW8/2BB2DFCpo2tdIEn34KI0aEHZxzLtF5ci9KTz4Je/bATTcBdiNr5852rXXZspBjc84lNE/uRalxY7j9dnj7bfjyS5KSYORIa5a5+mrL+845VxSiGSB7pIisEZG5OSyvLiLvi8hsEflWRFrFPsw4duut0LQp9OsHv/9O48bWDX7iRBg+POzgnHOJKpoz91HAGbksvwOYqaptgCuAoTGIK3EkJ1u998WL7W4moE8fOPNMuOUWmD073PCcc4kpz+Suql8DG3JZpQUwMVj3e6CxiNSOTXgJ4qSTrJvME0/A9OmIWO+ZatWgRw/YujXsAJ1ziSYWbe6zgAsBRKQD0Aion92KItJXRNJFJH3t2rUxeOs48vjjUKsWXHMN7NpF7dp2U9MPP9hgTt490jkXS7FI7o8C1URkJnA9MAPIyG5FVX1BVdNUNS01NTUGbx1HqlWDZ56BmTP3VhLr0sUGcnrtNXjllZDjc84llKjuUBWRxsAnqprrxVIREWAJ0EZVN+e2bkLfoZqbCy+0zu5z5kCTJmRkwKmnwjff2NSmTdgBOudKsmK7Q1VEqolI+eBpH+DrvBJ7qTZsGJQvb30hMzIoUwbeeMPqvp93HpS21irnXNGIpivkm8AU4AgRWSEi14hIPxHpF6zSHJgrIguBM4GBRRduAqhb10oT/Oc/1icSqFPH6oytXg0XXQS7doUco3Mu7nnhsDCowp//DOPGwXff7W2LGT3aysH37+994J1z2fPCYSWZCDz/vLXFXH65VRQDevWyvu/PPgtDhoQco3MurnlyD0tqKrz0kt3FdPfde2c/8ohdc73xRnjrrRDjc87FNU/uYTrnHCtL8I9/wPjxgI3a9Prr8Kc/wRVXwKRJIcfonItLntzDNniwtblfcQWsWAFAxYrw0UfQpIlVDZ4+PeQYnXNxx5N72CpWtKqRO3bAJZfA7t2ANcd/9pnd+3TqqTBrVshxOufiiif3kuCII+CFF+B//4O77to7u0EDa5apVAlOOcWG6HPOuWh4ci8pevaEvn1tuKb33ts7+7DDrDxwhQpWrsCbaJxz0fDkXpIMHQodO0Lv3vvVAm7SxM7gDzoITj7ZL7I65/Lmyb0kSU6GsWOhShXo1g3Wrdu7qGlTa7Vp2BDOOMNWc865nHhyL2nq1rVaBKtWWbH3iFoE9erB119D+/bQvbv1ifdSwc657HhyL4k6dLALrJMmwYAB+2XwGjXgq6+sif6OO+wG1+3bQ4zVOVcilQ07AJeDK66AefNskI8GDSyTBypWtBudWraEO++ERYvsZL9OnRDjdc6VKH7mXpI98ohVErvzThuXL4KI5fuxY600/DHHwLffhhSnc67E8eRekiUlwYgRcNppNgbruHEHrHLBBXahtUwZK1kwdKi3wzvnPLmXfOXLw7vvQtu2Vuw9m36Q7dpZ//czz4QbbrBqwhs3hhCrc67E8OQeDypXtqH5Dj0Uzj4b/vWvA1apUcPa3Z94Aj7+GI4+2krFO+dKJ0/u8aJWLbtVNTPBT5x4wCoiMGiQDfKUkQHHHw9PP+3NNM6VRp7c40nt2vsS/JlnWnNNNjp1ghkz4PTTrSflBRfAr78Wc6zOuVBFM4bqSBFZIyLZlq0Skaoi8rGIzBKReSJyVezDdHvVrm13MqWl2U1Ozz2X7Wo1aljZ4CeesOqSrVrl+F3gnEtA0Zy5jwLOyGX5dcB8VW0LnAw8ISLlCx+ay1GNGvDFF3DWWTbg6n33Zdv2ktlMM306NGpk12MvvRQ2bAghZudcscozuavq10Bu6UCByiIiQEqw7u7YhOdydNBB8P77VmTs3nvhr3+1hvZstGgBU6bA/ffDO+/YzU8ff1y84Trnilcs2tyHAc2BX4A5wEBV3ZPdiiLSV0TSRSR97dq1MXjrUq5cObu56ZZbYPhwKza2eXOOq959t93olJoK551nZ/F+GJxLTLFI7qcDM4G6QDtgmIhUyW5FVX1BVdNUNS01NTUGb+0QsRIFzzxjjeudOsGPP+a4+lFHQXq6ncW/+y40bw5vvOE9apxLNLFI7lcBY9UsApYAR8Zguy4/rr3W2uHXrLHCY198keOq5cvbWfyMGVYrvlcv6125ZEkxxuucK1KxSO4/A10BRKQ2cATwUwy26/Krc2e7c6l+fSv6/uSTuZ6St2xppQuGDLG+8S1bwqOPws6dxRizc65IRNMV8k1gCnCEiKwQkWtEpJ+I9AtWeQA4TkTmAF8Bt6nqupy254rYoYfC5MnW/j5oEFx8cY7t8GA1aQYOhAULrOv87bfb3a3/+U8xxuyciznRkBpb09LSND09PZT3LhVU4Z//tGx92GHWTaZt2zxf9skn1vFm2TK4+mprzq9Zsxjidc5FRUSmqWpaXuv5HaqJSsR60UyaBFu32oXW4cPzvHJ6zjlWRv7WW+HVV+GII2DUKL/g6ly88eSe6E44AWbOtJG1r7vOsncetQgqVYLHHrObn444Aq66yprzFywonpCdc4Xnyb00qFULxo+3KmITJ0Lr1tb+kofWra3t/cUXYfZsa9W56y74449iiNk5Vyie3EsLEWtMT0+3QbjPPddKF/z+e64vS0qCPn3g++/hkkvgoYesTs2ECcUUt3OuQDy5lzYtW8LUqXDzzfD889Y1Ztq0PF9Wq5a1wU+cCGXLWk/Liy+GlSuLIWbnXL55ci+NKlSAf/wDvvxy38XWu+6CHTvyfGnnztZEc9998OGHcOSR1iln165iiNs5FzVP7qUHfaafAAARrUlEQVRZly42uvZll1l7y9FHwzff5PmyChXgnntg/nxL9rfcYkP9ZTMCoHMuJJ7cS7vq1a342KefwpYtcNxxcNNNebbFg3Wf/+gjm/74w74revb0phrnSgJP7s6ccQbMnQv9+sHgwdCmTbZjtWbn3HOtb/y991oVYm+qcS58ntzdPlWq2I1Ome0rnTtbj5pcyhdkqlgR/v73A5tqvvqqiGN2zmXLk7s70Mkn21XTQYOsR02rVlZOOApZm2pOOQVOO816YDrnio8nd5e9gw6yAVgnT4aUFKsq1rMnrFoV1cvPPdfO4p980koLH3MMdO9u/eWdc0XPk7vLXadOVofg3nth7FhrUB8+PMch/SIlJ8MNN8DixfbyCROsm/0118DPPxd55M6Vap7cXd6Sk61Bfc4cOwW/7jrrVTNjRlQvr1LFXv7TT1Ze+PXXoVkza/XxYf6cKxqe3F30mjWzEZ5Gj4alSyEtDW680bpQRiE11Tri/Pijjf40dKi10d97b1TXbJ1z+eDJ3eWPiI2s/f330LevZejmza3JJsq6wA0bwogR1vPy9NPtbtfDD7f2+e3bizh+50oJT+6uYKpXh2eftQuuBx8Mf/6zXUVdujTqTTRvboN0f/utdZscNMjO5B9/3M/knSssT+6ucDp1sn6OgwfbTU8tWlgx+HzcwXTMMdbaM3GiXXC97TY7u7/9dli9uuhCdy6RRTOG6kgRWSMic3NYfouIzAymuSKSISI1Yh+qK7HKlrW29wUL7E7Xv/0NjjoK/v3vfG2mc2dL8unp1jf+scegcWO7aXbRoqIJ3blEFc2Z+yjgjJwWquo/VLWdqrYDbgf+raobYhSfiycNGljb+0cfWbXJk0+GCy/Md2Zu3x7efhsWLoTeva30TbNmcN55lvx9yD/n8pZnclfVr4Fok3VP4M1CReTi37nn2ln8Qw9ZNm7RwoqRbdyYr800bWo3yC5bBnfeaQUrTzvNmm6GD7fvD+dc9mLW5i4iB2Fn+O/lsk5fEUkXkfS13sE5sVWsCHfcYf0ee/e2rjBNmsCwYfmuKFanDjzwACxfbgOGVKpkXe3r1bPWIG+yce5Asbygei7wv9yaZFT1BVVNU9W01NTUGL61K7Hq1LFBWGfMsEFYr7/eKk6OG5fv9pUKFeDyy613zZQpcPbZ9l3RrJk9njAB9uwpov1wLs7EMrlfgjfJuJy0bWsjP330kWXgc86xNpY5c/K9KRHrpPPGG1bG4J57bKTAM86w/vL33mt3wzpXmsUkuYtIVeAk4MNYbM8lKBFrj587125+mjbNOrhfeWW++sdHOuQQS+bLllmyb9oU7r/fkvxJJ8HIkbBpUyx3wrn4EE1XyDeBKcARIrJCRK4RkX4i0i9itQuAz1V1W1EF6hJIuXIwYIA1lt94I4wZY20rAwbAr78WaJMVKljRys8/t0T/0ENWwPKaa6B2beu08/bbsM3/hbpSQjSkfmVpaWma7kW+HcCKFXbFdMQIy9LXXgs332xZuRBUrYfNW29ZYl+1yioZn3ceXHKJNeNUqBCjfXCumIjINFVNy3M9T+6uxFi0yArNvPGGZd1+/WxIp0MOKfSmMzLgP/+xHwnvvgvr11uvm1NPtYuxZ50FdevGYB+cK2Ke3F38+uEHePhhqw1crpwVKLv1Vuv7GAO7dtnwfx9+aJ12li+3+UcfbYn+7LOtJEKSF+dwJZAndxf/Fi+2JP/qq1CmjDWg33STVReLEVW7vvvJJ5bop0yxzjy1atngU2efbZ16qlaN2Vs6Vyie3F3iWLIEHnkERo2y9pWLLrLmmvbtY/5W69dbf/lx4+DTT+G336x0TocOVk2hc2cbp+Sgg2L+1s5FxZO7Szy//AJPPWWlhjdvhi5d7MLrGWdYN8sY273bLsiOHw+TJsF339l3S7lyluxPPBGOPdb63Ps9ea64eHJ3iWvzZrvr9cknYeVKq11z441w2WU2JGAR2bLFytdPmmTVjdPT9w0le/jh0LGjTR06WPf9IgzFlWKe3F3i27kT3nkHnnjCyhukplo3ymuvtUbzIvb77zZ2+DffWFv91Kn2XQN2dt+2rY1E2L69/W3Z0uY7Vxie3F3poWq14wcPho8/3leE5sYb7ay+GK1cabVvvv3Wkv20aftGlapQwcrqtGsHrVpZsm/Vyr6HiqBVySUoT+6udFq4EIYMgVdegT/+gBNOgD59oHv3UK6C7tljnX6mTbNmnGnTYPZs2BBRXq9mzf2TffPmcMQRVnPNk77LypO7K93WrbM7XkeMsLLDVapAr16W6I8+OtTQVK3Kwty5MG+e/c18vGXLvvWqVLGqDEceack+c2ra1Coqu9LJk7tzYJn0P/+Bl16y9vnt220IwMsug4svjtmNUbGgajdUff+9/QBZuHDf4xUr9q0nAo0a7Z/wjzzSuv/Xq+ft+onOk7tzWW3caKUNRo609hERKx3Zs6c129QouUP/bttmN+5GJvzM6fff962XlGRlFBo1sqlhw/0f169vvwi8uSd+eXJ3Ljc//ABvvmnTwoV2unv66ZbozznHMmAcULWLuAsX2r1eP/9sVTEz/y5fbv31I1WsaF8AOU2HHGJ/K1cOZ59c7jy5OxcNVZg5087ox4yx9o/y5eGUU+CCC6yEZDF0qywqGRmwevW+ZP/LLwdOK1fuf/afKSUl+8RfqxYcfLBNNWva35QU/zVQXDy5O5dfe/bYXUrvvw9jx9oAIklJ8Kc/wfnn252wRx6ZcFlM1S7kZib7Vauy/xL45Re7ZJGd8uUPTPjZPa9ZE6pXt6lqVSsZ5PLHk7tzhaEKs2ZZkn//fevOAtZwfdpp1oTTtatlqVJC1S5brF1rnZHWr7e/mVN2z9evz3moXBFr/cpM9tWrQ7VqOf+NfFylivVsTbDv2ah4cnculpYssWGeJkywesGbN9tZfYcOluhPPNEep6SEHWmJkpFhXwiRyf+337KfNm60KfNxdk1FkZKSLMlXrWp/M6fKlQ98npKy7292U6VK9usjHnhyd66o7Nplt6BOmGDTd9/Z6WlSkt2CetxxNh17LBx6aOk8vYyBHTts/NvIhJ/5d/PmA6dNm6x5KXJefoZVLFt2X6KvVGn/x1mfZz6uWNGm5OTs/2Y3r7BNUTFL7iIyEjgHWKOqrXJY52RgCFAOWKeqJ+X1xp7cXcL47TerNTB5sk1Tp8LWrbasdu39k3379l5RrBjt3m2HIrdpyxb7Eoictm7N/nHm8x07Ch5T2bJw223w4IMFe30sk/uJwFbg1eySu4hUAyYDZ6jqzyJSS1XX5PXGntxdwsrIsDb6yZOtotjkyVaDAKzLZfv2luiPO84eN27sZ/dxZvduazb64w+btm/P/m9Oy044wQaDKYiYNsuISGPgkxyS+7VAXVW9Kz8BenJ3pcqaNfsS/eTJVmgms+tJ1apWQvKoo6yqWLt2VvAsXhqBXbEqzuSe2RzTEqgMDFXVV3PYTl+gL0DDhg3bL1u2LM/3di4h7dxpvXFmzLB+9jNn2vPMq4jlylm3y9atrZpYq1b2uGFDH9y1lCvO5D4MSAO6AhWBKcDZqvpDbtv0M3fnssjIsOabGTNsmjsX5syxO5AypaRY+cjIpN+qlbXtu1Ih2uReNgbvtQJYr6rbgG0i8jXQFsg1uTvnsihTxspANmtmRc0ybdoE8+dbos8sIfnBB1YMLVNqqiX5Zs2sgtihh9rfww4rVX3x3T6xSO4fAsNEpCxQHugIPBmD7TrnwNrkjz3Wpkyq1o6feXafmfTffdc6lEeqVm1fos+cMpN/w4betp+g8kzuIvImcDJwsIisAP6OtbGjqs+p6gIR+QyYDewBXlLVuUUXsnMOEWuKqV3b7pSNtHmz3XT100/7T7Nnw0cfWXt/pqQkaNDAEn3jxvvKSDZsaFODBt51M075TUzOlSYZGVYk5qef9v8CWLzYKoutWnXga2rX3pfssyb/hg2taIx35Sw2xdnm7pyLF2XK2Nl4gwZWyz6rHTusTGRk7eDMad48GD/eOmtHyqwhnFkrOKe/Vav6l0Ax8uTunNunQoV97fLZUbU2/cikn3nG/8sv1qVz/Ph9d+hGSk7OPvHXqbOvialOHbs47MNJFZond+dc9ET21e/NbSzaLVss4WdOkbWEV62yi8Cff27XB7JTs+b+CT/zca1aNqWm2lSrlhV58V8EB/Dk7pyLvcqVbWrWLPf1tm2z0cIjp9Wr93/+7bf2N7tfA2C/CCKTfWbh+NymUlAv2JO7cy48lSrl3gwUads2Kya/dq11A83p8YIF1nSU05cBWPNT1oRfo0bOXwY1athUNn5SZvxE6pwr3TLr7TZuHN36O3fChg37Rg3JbZo/3/5u2HDgoLORqlbdl+gjRxGJZirmXwue3J1zial8eWuvr1Mn+tdkjjkYmfgzvyCy/t240XoWZY4ykrUXUVZly+5L9P37w6BBhdu/PHhyd865TJlj/1WpYnfx5kfk6CJ5TcVQC8iTu3POxUKFCvt685QAXjvUOecSkCd355xLQJ7cnXMuAXlyd865BOTJ3TnnEpAnd+ecS0Ce3J1zLgF5cnfOuQQU2khMIrIWWFbAlx8MrIthOPHA97l08H0uHQqzz41UNTWvlUJL7oUhIunRDDOVSHyfSwff59KhOPbZm2Wccy4BeXJ3zrkEFK/J/YWwAwiB73Pp4PtcOhT5Psdlm7tzzrncxeuZu3POuVx4cnfOuQQUd8ldRM4QkYUiskhE/hZ2PLEiIg1EZJKIzBeReSIyMJhfQ0S+EJEfg7/Vg/kiIk8Fn8NsETk63D0oGBEpIyIzROST4PmhIjI12K+3RKR8ML9C8HxRsLxxmHEXhohUE5F3ReR7EVkgIscm8nEWkRuDf9NzReRNEUlOxOMsIiNFZI2IzI2Yl+/jKiK9g/V/FJHeBY0nrpK7iJQBngHOBFoAPUWkRbhRxcxu4CZVbQF0Aq4L9u1vwFeq2hT4KngO9hk0Daa+wLPFH3JMDAQWRDx/DHhSVZsAvwHXBPOvAX4L5j8ZrBevhgKfqeqRQFts/xPyOItIPWAAkKaqrYAywCUk5nEeBZyRZV6+jquI1AD+DnQEOgB/z/xCyDdVjZsJOBaYEPH8duD2sOMqon39EDgVWAgcEsw7BFgYPH4e6Bmx/t714mUC6gf/4LsAnwCC3bVXNuvxBiYAxwaPywbrSdj7UIB9rgosyRp7oh5noB6wHKgRHLdPgNMT9TgDjYG5BT2uQE/g+Yj5+62XnymuztzZ9w8l04pgXkIJfooeBUwFaqvqqmDRaiBzZN1E+CyGALcCe4LnNYGNqro7eB65T3v3N1i+KVg/3hwKrAVeDpqjXhKRSiTocVbVlcA/gZ+BVdhxm0biH+dM+T2uMTve8ZbcE56IpADvATeo6ubIZWpf5QnRd1VEzgHWqOq0sGMpZmWBo4FnVfUoYBv7fqoDCXecqwPdsC+1ukAlDmy6KBWK+7jGW3JfCTSIeF4/mJcQRKQclthHq+rYYPavInJIsPwQYE0wP94/i+OB80RkKTAGa5oZClQTkbLBOpH7tHd/g+VVgfXFGXCMrABWqOrU4Pm7WLJP1ON8CrBEVdeq6i5gLHbsE/04Z8rvcY3Z8Y635P4d0DS40l4euzDzUcgxxYSICDACWKCqgyMWfQRkXjHvjbXFZ86/Irjq3gnYFPHzr8RT1dtVtb6qNsaO40RV7QVMAroHq2Xd38zPoXuwftyd3arqamC5iBwRzOoKzCdBjzPWHNNJRA4K/o1n7m9CH+cI+T2uE4DTRKR68KvntGBe/oV9AaIAFyzOAn4AFgN3hh1PDPfrT9hPttnAzGA6C2tv/Ar4EfgSqBGsL1jPocXAHKw3Quj7UcB9Pxn4JHh8GPAtsAh4B6gQzE8Oni8Klh8WdtyF2N92QHpwrD8AqifycQbuA74H5gKvARUS8TgDb2LXFXZhv9CuKchxBa4O9n8RcFVB4/HyA845l4DirVnGOedcFDy5O+dcAvLk7pxzCciTu3POJSBP7s45l4A8uTvnXALy5O6ccwno/wE7Ko2N8mRCngAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def split_data(X, y, percent):\n", " split = int(len(y) * percent)\n", " return X[:split, :], y[:split, :], X[split:, :], y[split:, :]\n", "X = (X - np.mean(X)) / np.std(X)\n", "X_train, y_train, X_valid, y_valid = split_data(X, y, 0.5)\n", "X_valid, y_valid, X_test, y_test = split_data(X_valid, y_valid, 0.5)\n", "model.train(1000, X_train, y_train, X_valid, y_valid)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 0.9066666666666666\n" ] } ], "source": [ "score = model.score(X_test, y_test)\n", "print(\"Accuracy = {}\".format(str(score)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }