{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Neural Networks from scratch" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (3.1.3)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (1.1.0)\n", "Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (1.18.1)\n", "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (2.8.1)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (0.10.0)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (2.4.6)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib) (45.2.0)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.1->matplotlib) (1.14.0)\n" ] } ], "source": [ "!pip3 install matplotlib" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from math import e\n", "from typing import List, Callable\n", "from random import random\n", "from matplotlib import pyplot as plt" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Function to calculate the dot product\n", "# See: https://en.wikipedia.org/wiki/Dot_product\n", "def dot(a: List[float], b: List[float]) -> float:\n", "    \"\"\"Return the dot product of the equally long vectors `a` and `b`.\n", "\n", "    Raises:\n", "        ValueError: if `a` and `b` differ in length. A plain `zip` would\n", "            silently drop the surplus elements and hide the mistake.\n", "    \"\"\"\n", "    if len(a) != len(b):\n", "        raise ValueError(f'dot: vectors differ in length ({len(a)} != {len(b)})')\n", "    return sum(a_i * b_i for a_i, b_i in zip(a, b))\n", "\n", "assert dot([1, 2, 3, 4], [5, 6, 7, 8]) == 70" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# The Step activation function\n", "# See: 
https://en.wikipedia.org/wiki/Step_function\n", "def step(x: float) -> int:\n", " return 0.0 if x < 0.0 else 1.0\n", "\n", "assert step(-0.1) == 0.0\n", "assert step(0.0) == 1.0\n", "assert step(0.1) == 1.0" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAASrElEQVR4nO3dfYxl9V3H8fcHttRoH1B3WxGWLtWt6WrU4gSbtGqTYgWirM+BxPjUSEzEaOpDaDDYYGJSGzVpROs2NtVGi9THja5BrfUhRipDSykPYrdYZRHLWpvWpAri/frHPQM/DjM7d2bv3Du/8f1KJnvvOWfnfnPu4cNvv+ec30lVIUnq3znLLkCSNB8GuiTtEQa6JO0RBrok7REGuiTtEfuW9cH79++vQ4cOLevjJalLd911179X1YH11i0t0A8dOsTq6uqyPl6SupTknzdaZ8tFkvYIA12S9ggDXZL2CANdkvYIA12S9ohNAz3JO5I8luTeDdYnyVuTnExyT5JL51+mJGkzs4zQ3wlccYb1VwKHh5/rgF85+7IkSVu16XXoVfXXSQ6dYZOjwG/UdB7eO5Kcn+SCqnp0TjVKC3Xqk5/hPauncGpp7ZTXvvzFfMXB8+f+e+dxY9GFwMPN+1PDsmcFepLrmI7iufjii+fw0dL83bZ6ire+9yMky65Ee9WLXvBZuzbQZ1ZVx4BjACsrKw5/tCs9+b8T9p0TTv7sVcsuRdqSeVzl8ghwsHl/0bBM6tKk4ByH5+rQPAL9OPDdw9UurwQ+Zf9cPasq2y3q0qYtlyTvBl4D7E9yCvhp4DkAVfU24ARwFXAS+AzwfTtVrLQIkypH6OrSLFe5XLvJ+gJ+aG4VSUs2bbksuwpp67xTVBpxhK5eGejSSBX20NUlA10amVRxjj0XdchAl0ZsuahXBro04klR9cpAl0am16Gb6OqPgS6NTCaO0NUnA10asYeuXhno0ohzuahXBro04lwu6pWBLo3YclGvDHRpxMsW1SsDXRpxhK5eGejSiHO5qFcGujTiCF29MtClEQNdvTLQpZGJLRd1ykCXRsoRujploEsjk4Jz/C9DHfKwlUbsoatXBro0Mu2hG+jqj4EujUx76MuuQto6A10aKWdbVKcMdGlk4ghdnTLQpZGJj6BTpwx0acTZFtUrA10a8cYi9cpAl0Z8BJ16ZaBLIxMfQadOGejSiCN09cpAl0a8sUi9minQk1yR5MEkJ5PcsM76i5O8L8kHk9yT5Kr5lyothnO5qFebBnqSc4FbgCuBI8C1SY6MNvsp4LaqegVwDfDL8y5UWpTJxLlc1KdZRuiXASer6qGqegK4FTg62qaAFwyvXwj86/xKlBbLO0XVq1kC/ULg4eb9qWFZ603AdyU5BZwAfni9X5TkuiSrSVZPnz69jXKlnedcLurVvE6KXgu8s6ouAq4C3pXkWb+7qo5V1UpVrRw4cGBOHy3N16TKB1yoS7Mcto8AB5v3Fw3LWq8HbgOoqr8DPgvYP48CpUVzLhf1apZAvxM4nOSSJOcxPel5fLTNvwCvBUjycqaBbk9FXbLlol5tGuhV9SRwPXA78ADTq1nuS3JzkquHzX4M+IEkHwLeDXxvVdVOFS3tJE+Kqlf7Ztmoqk4wPdnZLrupeX0/8Kr5liYth3eKqlee+pFGnMt
FvTLQpRF76OqVgS6N2ENXrwx0acS5XNQrA10amZRzuahPBro04vS56pWBLo1MR+jLrkLaOgNdGrGHrl4Z6NLIZGKgq08GujRStlzUKQNdGrHlol4Z6NLIdC6XZVchbZ2BLo04QlevDHRppLyxSJ0y0KUR53JRrwx0acSWi3ploEsjnhRVrwx0qbH25ER76OqRgS41JsOTcG25qEcGutSYDCN0Wy7qkYEuNZ4KdBNdHTLQpcaQ587loi4Z6FLj6ZaLia7+GOhS4+mTosutQ9oOA11qOEJXzwx0qVGT6Z9eh64eGehSw8sW1TMDXWrYclHPDHSp4UlR9cxAlxrO5aKezRToSa5I8mCSk0lu2GCb70xyf5L7kvzWfMuUFsO5XNSzfZttkORc4Bbg64FTwJ1JjlfV/c02h4E3Aq+qqk8medFOFSztJE+KqmezjNAvA05W1UNV9QRwK3B0tM0PALdU1ScBquqx+ZYpLcYwQHeEri7NEugXAg83708Ny1ovA16W5G+T3JHkivV+UZLrkqwmWT19+vT2KpZ20GSy1kNfciHSNszrpOg+4DDwGuBa4O1Jzh9vVFXHqmqlqlYOHDgwp4+W5qfsoatjswT6I8DB5v1Fw7LWKeB4Vf1PVf0T8I9MA17qytPT5y65EGkbZjls7wQOJ7kkyXnANcDx0TZ/wHR0TpL9TFswD82xTmkhvLFIPds00KvqSeB64HbgAeC2qrovyc1Jrh42ux34RJL7gfcBP1FVn9ipoqWdMnlqPnQDXf3Z9LJFgKo6AZwYLbupeV3AG4YfqVvlZYvqmJ1CqeGNReqZgS41vLFIPTPQpcbEuVzUMQNdangdunpmoEsNWy7qmYEuNTwpqp4Z6FLj6R76kguRtsFAlxrlnaLqmIEuNWy5qGcGutRYmz7Xk6LqkYEuNZzLRT0z0KWGc7moZwa61Hiqh26iq0MGutTwxiL1zECXGs7lop4Z6FLDuVzUMwNdathyUc8MdKnhjUXqmYEuNdZG6FKPDHSp4Vwu6pmBLjWevg59uXVI2+FhKzUmjtDVMQNdajx9UnS5dUjbYaBLjfLGInXMQJcatlzUMwNdakwm0z9tuahHBrrUcISunhnoUqOeesDFcuuQtsNAlxqO0NUzA11qOJeLemagSw1nW1TPZgr0JFckeTDJySQ3nGG7b0tSSVbmV6K0OF6Hrp5tGuhJzgVuAa4EjgDXJjmyznbPB34EeP+8i5QWxTtF1bNZRuiXASer6qGqegK4FTi6znY/A7wZ+O851ictlCdF1bNZAv1C4OHm/alh2VOSXAocrKo/PtMvSnJdktUkq6dPn95ysdJO86SoenbWJ0WTnAP8AvBjm21bVceqaqWqVg4cOHC2Hy3N3VM9dC8XUIdmOWwfAQ427y8alq15PvBlwF8m+RjwSuC4J0bVI1su6tksgX4ncDjJJUnOA64Bjq+trKpPVdX+qjpUVYeAO4Crq2p1RyqWdpAnRdWzTQO9qp4ErgduBx4Abquq+5LcnOTqnS5QWiRH6OrZvlk2qqoTwInRsps22PY1Z1+WtBzO5aKeeepHakwmjtDVLwNdanjZonpmoEsN53JRzwx0qeFcLuqZgS41Ckfn6peBLjUmVfbP1S0DXWpMyhOi6peBLjUmVV6Drm4Z6FKjHKGrYwa61JhMypOi6paBLjXsoatnBrrUsIeunhnoUqOqOMeeizploEsNWy7qmYEuNaY3Fi27Cml7DHSpMSnncVG/DHSpUY7Q1TEDXWo4l4t6ZqBLDU+KqmcGutTwOnT1zECXGs7lop4Z6FLDyxbVMwNdathDV88MdKlhD109M9ClRnnZojpmoEuNycSWi/ploEsNWy7qmYEuNTwpqp4Z6FJjOh/6squQtsdDV2pMqgiO0NWnmQI9yRVJHkxyMskN66x/Q5L7k9yT5L1JXjL/UqWdN225LLsKaXs2DfQk5wK3AFcCR4BrkxwZbfZBYKWqvhz4HeDn5l2otAjTk6Imuvo
0ywj9MuBkVT1UVU8AtwJH2w2q6n1V9Znh7R3ARfMtU1qMcoSujs0S6BcCDzfvTw3LNvJ64E/WW5HkuiSrSVZPnz49e5XSgjgfuno215OiSb4LWAHest76qjpWVStVtXLgwIF5frQ0Fwa6erZvhm0eAQ427y8alj1DksuBG4Gvq6rH51OetFjTZ4ouuwppe2YZod8JHE5ySZLzgGuA4+0GSV4B/CpwdVU9Nv8ypcVwLhf1bNNAr6ongeuB24EHgNuq6r4kNye5etjsLcDzgPckuTvJ8Q1+nbSrTQpvLFK3Zmm5UFUngBOjZTc1ry+fc13SUthDV88ci0iNaQ/dQFefDHSpUT6CTh0z0KWGLRf1zECXGtMHXCy7Cml7DHSp4Vwu6pmBLjWcy0U9M9Clhj109cxAlxoGunpmoEuNci4XdcxAlxqO0NUzA11q+Ag69cxAlxqO0NUzA11qlHO5qGMGutSYOJeLOmagSw1bLuqZgS41fMCFeuahKzXKuVzUMQNdajiXi3pmoEsNe+jqmYEuNaY3Fhno6pOBLjWm86Evuwppewx0qVGO0NUxA11qeGORemagSw1PiqpnBrrUmDiXizpmoEuNsuWijhnoUsPLFtUzA11qeFJUPTPQpUFVOR+6umagS4Oq6Z+2XNQrA10aTIZEt+WiXs0U6EmuSPJgkpNJblhn/XOT/Paw/v1JDs27UGmnTdZG6Ca6OrVpoCc5F7gFuBI4Alyb5Mhos9cDn6yqLwZ+EXjzvAuVdtraCN2Oi3q1b4ZtLgNOVtVDAEluBY4C9zfbHAXeNLz+HeCXkqRqrSs5P7fd+TBv/5uH5v1rpacDHRNdfZol0C8EHm7enwK+eqNtqurJJJ8CPh/493ajJNcB1wFcfPHF2yr4/M9+Dodf/Lxt/V1pM0e+8IVc/vIXLbsMaVtmCfS5qapjwDGAlZWVbY3eX/elX8DrvvQL5lqXJO0Fs5wUfQQ42Ly/aFi27jZJ9gEvBD4xjwIlSbOZJdDvBA4nuSTJecA1wPHRNseB7xlefzvwFzvRP5ckbWzTlsvQE78euB04F3hHVd2X5GZgtaqOA78GvCvJSeA/mIa+JGmBZuqhV9UJ4MRo2U3N6/8GvmO+pUmStsI7RSVpjzDQJWmPMNAlaY8w0CVpj8iyri5Mchr4523+9f2M7kLdRXZrbda1Nda1dbu1tr1W10uq6sB6K5YW6GcjyWpVrSy7jvXs1tqsa2usa+t2a23/n+qy5SJJe4SBLkl7RK+BfmzZBZzBbq3NurbGurZut9b2/6auLnvokqRn63WELkkaMdAlaY/Y9YGe5DuS3JdkkmRltO6Nw4OpH0zyDc3yMz7Uegdq/O0kdw8/H0ty97D8UJL/ata9badrGdX1piSPNJ9/VbNu3X23wNrekuQfktyT5PeTnD8sX+o+G2pY6PFzhjoOJnlfkvuH/wZ+ZFi+4fe6wNo+luTDw+evDss+L8mfJfnI8OfnLrimL2n2yd1JPp3kR5e1v5K8I8ljSe5tlq27jzL11uGYuyfJpdv60Kra1T/Ay4EvAf4SWGmWHwE+BDwXuAT4KNPpfc8dXr8UOG/Y5sgC6/154Kbh9SHg3iXuuzcBP77O8nX33YJrex2wb3j9ZuDNu2SfLfX4GdVyAXDp8Pr5wD8O39263+uCa/sYsH+07OeAG4bXN6x9p0v8Hv8NeMmy9hfwtcCl7fG80T4CrgL+BAjwSuD92/nMXT9Cr6oHqurBdVYdBW6tqser6p+Ak0wfaP3UQ62r6glg7aHWOy5JgO8E3r2IzzsLG+27hamqP62qJ4e3dzB9EtZusLTjZ6yqHq2qDwyv/xN4gOnze3ero8CvD69/HfjmJdbyWuCjVbXdu9HPWlX9NdPnQ7Q22kdHgd+oqTuA85NcsNXP3PWBfgbrPbz6wjMsX4SvAT5eVR9pll2S5INJ/irJ1yyojtb1wz/h3tH8E3iZ+2g93890dLJmmftst+0bYNqKAl4BvH9YtN7
3ukgF/GmSuzJ9+DvAi6vq0eH1vwEvXkJda67hmQOrZe+vNRvto7kcd7si0JP8eZJ71/lZyshoPTPWeC3PPIgeBS6uqlcAbwB+K8kLFljXrwBfBHzlUMvPz/Ozz7K2tW1uBJ4EfnNYtOP7rDdJngf8LvCjVfVplvy9Dl5dVZcCVwI/lORr25U17SMs5ZroTB+VeTXwnmHRbthfz7IT+2imJxbttKq6fBt/7UwPr97sodZbtlmNmT4c+1uBr2r+zuPA48Pru5J8FHgZsHq29cxaV1Pf24E/Gt7O8uDvszbDPvte4BuB1w4H90L22SYWsm9mleQ5TMP8N6vq9wCq6uPN+vZ7XZiqemT487Ekv8+0VfXxJBdU1aNDu+CxRdc1uBL4wNp+2g37q7HRPprLcbcrRujbdBy4Jslzk1wCHAb+ntkear0TLgf+oapOrS1IciDJucPrlw41PrSAWtY+v+3BfQuwdrZ9o323MEmuAH4SuLqqPtMsX+o+Y3nHz7MM52R+DXigqn6hWb7R97qouj4nyfPXXjM9wX0vz3xY/PcAf7jIuhrP+JfysvfXyEb76Djw3cPVLq8EPtW0Zma36DO/2zhT/C1M+0mPAx8Hbm/W3cj0ioQHgSub5VcxvSLgo8CNC6rzncAPjpZ9G3AfcDfwAeCbFrzv3gV8GLhnOGAu2GzfLbC2k0x7hncPP2/bDftsWcfPBnW8muk/ye9p9tNVZ/peF1TXS5le/fOh4bu6cVj++cB7gY8Afw583hL22ecAnwBe2Cxbyv5i+j+VR4H/GTLs9RvtI6ZXt9wyHHMfprmibys/3vovSXtEzy0XSVLDQJekPcJAl6Q9wkCXpD3CQJekPcJAl6Q9wkCXpD3i/wAmSRRoAvz5hwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot([x for x in range(-100, 100)], [step(x) for x in range(-100, 100)]);" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# The `Neuron` class represents a single Neuron / Perceptron\n", "class Neuron:\n", "    \"\"\"A single neuron computing `activation(dot(x, weights) + bias)`.\n", "\n", "    Weights and bias start out as `random()` draws, i.e. uniform in [0, 1).\n", "    \"\"\"\n", "\n", "    def __init__(self, activation: Callable, num_inputs: int = 2) -> None:\n", "        # `num_inputs` defaults to 2, the input size every caller in this notebook uses\n", "        self._weights: List[float] = [random() for _ in range(num_inputs)]\n", "        self._bias: float = random()\n", "        self._activation: Callable = activation\n", "\n", "    def forward(self, x: List[float]) -> float:\n", "        \"\"\"Return the activation of the weighted sum of `x` plus the bias.\"\"\"\n", "        return self._activation(dot(x, self._weights) + self._bias)\n", "\n", "    @property\n", "    def weights(self) -> List[float]:\n", "        \"\"\"The list of input weights (returned by reference, so it can be updated in place).\"\"\"\n", "        return self._weights\n", "\n", "    @weights.setter\n", "    def weights(self, weights: List[float]) -> None:\n", "        self._weights = weights\n", "\n", "    @property\n", "    def bias(self) -> float:\n", "        \"\"\"The bias added to the weighted sum.\"\"\"\n", "        return self._bias\n", "\n", "    @bias.setter\n", "    def bias(self, bias: float) -> None:\n", "        self._bias = bias" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Using a Perceptron to \"learn\" AND\n", "# See: https://en.wikipedia.org/wiki/Logical_conjunction\n", "x_1: List[float] = [0, 0]\n", "x_2: List[float] = [0, 1]\n", "x_3: List[float] = [1, 0]\n", "x_4: List[float] = [1, 1]\n", "\n", "y_1: float = 0.0\n", "y_2: float = 0.0\n", "y_3: float = 0.0\n", "y_4: float = 1.0\n", "\n", "# Creating a single `Neuron` (with a `step` activation function)\n", "# and setting the weights and bias\n", "perceptron_and: Neuron = Neuron(step)\n", "perceptron_and.weights = [1.5, 1.5]\n", "perceptron_and.bias = -2\n", "\n", "assert perceptron_and.forward(x_1) == y_1\n", "assert perceptron_and.forward(x_2) == y_2\n", "assert perceptron_and.forward(x_3) == y_3\n", "assert perceptron_and.forward(x_4) == y_4" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ 
"# The Sigmoid activation function\n", "# See: https://en.wikipedia.org/wiki/Sigmoid_function\n", "def sigmoid(x: float) -> float:\n", " return 1 / (1 + (e ** -x))\n", "\n", "assert sigmoid(0) == 0.5" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot([x for x in range(-10, 10)], [sigmoid(x) for x in range(-10, 10)]);" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# The derivative of the Sigmoid activation function\n", "def d_sigmoid(x: float) -> float:\n", " return sigmoid(x) * (1 - sigmoid(x))\n", "\n", "assert d_sigmoid(0) == 0.25" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot([x for x in range(-10, 10)], [d_sigmoid(x) for x in range(-10, 10)]);" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Sum of squared errors function\n", "# See: https://en.wikipedia.org/wiki/Residual_sum_of_squares\n", "def sum_squared_error(ys: List[float], ys_pred: List[float]) -> float:\n", "    \"\"\"Return the sum over all samples of `(y - y_pred) ** 2`.\n", "\n", "    Raises:\n", "        ValueError: if `ys` and `ys_pred` differ in length (a plain `zip`\n", "            would silently ignore the unmatched samples).\n", "    \"\"\"\n", "    if len(ys) != len(ys_pred):\n", "        raise ValueError(f'sum_squared_error: {len(ys)} targets but {len(ys_pred)} predictions')\n", "    return sum((y - y_pred) ** 2 for y, y_pred in zip(ys, ys_pred))\n", "\n", "assert sum_squared_error([0, 1], [0, 1]) == 0\n", "assert sum_squared_error([0, 0], [1, 1]) == 2" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Derivative of the squared error function\n", "# NOTE: The squared error function is expressed as `(y - y_pred) ** 2`\n", "def d_squared_error(y: float, y_pred: float) -> float:\n", "    \"\"\"Return the derivative of `(y - y_pred) ** 2` w.r.t. `y_pred`, i.e. `-2 * (y - y_pred)`.\"\"\"\n", "    return -2 * (y - y_pred)\n", "\n", "assert d_squared_error(1, 4) == 6\n", "assert d_squared_error(1, 1) == 0" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- Before training ---\n", "Prediction for [0, 0] (should be 0): 0.59\n", "Prediction for [0, 1] (should be 1): 0.67\n", "Prediction for [1, 0] (should be 1): 0.76\n", "Prediction for [1, 1] (should be 1): 0.81\n", "\n", "Training Perceptron for 5000 epochs with Learning Rate 0.1\n", "\n", "--- After training ---\n", "Prediction for [0, 0] (should be 0): 0.05\n", "Prediction for [0, 1] (should be 1): 0.97\n", "Prediction for [1, 0] (should be 1): 0.97\n", "Prediction for [1, 1] (should be 1): 1.00\n" ] } ], "source": [ "# Using a Perceptron to \"learn\" OR\n", "# See: https://en.wikipedia.org/wiki/Logical_disjunction\n", "xs: List[List[float]] = [[0, 0], [0, 1], [1, 0], [1, 1]]\n", "ys: List[float] = [0, 1, 1, 1]\n", "\n", "# Creating a single `Neuron` (with a `sigmoid` activation function)\n", "perceptron_or: Neuron = 
Neuron(sigmoid)\n", "\n", "# Print out what the randomized weights and bias initializations cause the Perceptron to predict\n", "print('--- Before training ---')\n", "print('Prediction for [0, 0] (should be 0): {num:.{digits}f}'.format(num=perceptron_or.forward([0, 0]), digits=2))\n", "print('Prediction for [0, 1] (should be 1): {num:.{digits}f}'.format(num=perceptron_or.forward([0, 1]), digits=2))\n", "print('Prediction for [1, 0] (should be 1): {num:.{digits}f}'.format(num=perceptron_or.forward([1, 0]), digits=2))\n", "print('Prediction for [1, 1] (should be 1): {num:.{digits}f}'.format(num=perceptron_or.forward([1, 1]), digits=2))\n", "\n", "# Train the Perceptron\n", "epochs: int = 5000\n", "learning_rate: float = 0.1\n", "\n", "print(f'\\nTraining Perceptron for {epochs} epochs with Learning Rate {learning_rate}\\n')\n", "\n", "for epoch in range(epochs):\n", "    for x, y in zip(xs, ys):\n", "        y_pred: float = perceptron_or.forward(x)\n", "\n", "        # Chain rule: dE/dparam = dE/dy_pred * dy_pred/dparam\n", "        d_error_d_y_pred: float = d_squared_error(y, y_pred)\n", "\n", "        w1: float = perceptron_or.weights[0]\n", "        w2: float = perceptron_or.weights[1]\n", "        b: float = perceptron_or.bias\n", "\n", "        # dy_pred/dparam = d_sigmoid(pre-activation sum) * (the input that param multiplies)\n", "        sum_p: float = dot([x[0], x[1]], [w1, w2]) + b\n", "        d_y_pred_d_w1: float = x[0] * d_sigmoid(sum_p)\n", "        d_y_pred_d_w2: float = x[1] * d_sigmoid(sum_p)\n", "        d_y_pred_d_b: float = d_sigmoid(sum_p)\n", "\n", "        perceptron_or.weights[0] += -learning_rate * d_error_d_y_pred * d_y_pred_d_w1\n", "        perceptron_or.weights[1] += -learning_rate * d_error_d_y_pred * d_y_pred_d_w2\n", "        perceptron_or.bias += -learning_rate * d_error_d_y_pred * d_y_pred_d_b\n", "\n", "# Print out what the Perceptron predicts after training (updating the weights and biases)\n", "print('--- After training ---')\n", "print('Prediction for [0, 0] (should be 0): {num:.{digits}f}'.format(num=perceptron_or.forward([0, 0]), digits=2))\n", "print('Prediction for [0, 1] (should be 1): {num:.{digits}f}'.format(num=perceptron_or.forward([0, 1]), digits=2))\n", 
"print('Prediction for [1, 0] (should be 1): {num:.{digits}f}'.format(num=perceptron_or.forward([1, 0]), digits=2))\n", "print('Prediction for [1, 1] (should be 1): {num:.{digits}f}'.format(num=perceptron_or.forward([1, 1]), digits=2))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Our 3 Layer (input, hidden, output) Neural Network\n", "class NeuralNetwork:\n", "    \"\"\"2 inputs -> 2 hidden `Neuron`s -> 1 output `Neuron`.\n", "\n", "    `forward` caches the inputs and every neuron's output; `backward` uses\n", "    that cache to take one gradient-descent step on the squared error.\n", "    \"\"\"\n", "\n", "    def __init__(self, activation: Callable, d_activation: Callable = d_sigmoid):\n", "        # `d_activation` must be the derivative of `activation`; `backward` relies on it.\n", "        # It defaults to `d_sigmoid` (matching `NeuralNetwork(sigmoid)`); pass the\n", "        # matching derivative when using any other activation function\n", "        self._activation: Callable = activation\n", "        self._d_activation: Callable = d_activation\n", "        # `Neuron`s of the hidden layer\n", "        self._n1: Neuron = Neuron(self._activation)\n", "        self._n2: Neuron = Neuron(self._activation)\n", "        # `Neuron` of the output layer\n", "        self._n3: Neuron = Neuron(self._activation)\n", "        # The most recent x values we've seen\n", "        self._x1: float = 0\n", "        self._x2: float = 0\n", "        # The results from the most recent forward-pass\n", "        self._res_n1: float = 0\n", "        self._res_n2: float = 0\n", "        self._res_n3: float = 0\n", "\n", "    def forward(self, x: List[float]) -> float:\n", "        \"\"\"Forward-pass `x` through the network and return the prediction.\"\"\"\n", "        # Remember the inputs for the next `backward` call\n", "        self._x1 = x[0]\n", "        self._x2 = x[1]\n", "        # Construct the data flow by combining the NN layers\n", "        res_n1: float = self._n1.forward(x)\n", "        res_n2: float = self._n2.forward(x)\n", "        res_n3: float = self._n3.forward([res_n1, res_n2])\n", "        # Store the most recent result of every single `Neuron`\n", "        self._res_n1 = res_n1\n", "        self._res_n2 = res_n2\n", "        self._res_n3 = res_n3\n", "        # Return the overall result (the prediction)\n", "        return res_n3\n", "\n", "    def backward(self, y: float, lr: float) -> None:\n", "        \"\"\"Back-propagate the error for target `y` and update all weights and biases.\n", "\n", "        Uses the values cached by the most recent `forward` call, so call\n", "        `forward` on the matching input first. `lr` is the learning rate.\n", "        \"\"\"\n", "        d_act: Callable = self._d_activation\n", "\n", "        # The most recent x values we've used\n", "        x1: float = self._x1\n", "        x2: float = self._x2\n", "\n", "        # The individual `Neuron` results\n", "        n1: float = self._res_n1\n", "        n2: float = self._res_n2\n", "        y_pred: float = self._res_n3\n", "\n", "        # Snapshot every weight and bias *before* updating any of them, so all\n", "        # gradients below are evaluated at the same point\n", "        # `Neuron` 1\n", "        b1: float = self._n1.bias\n", "        w1: float = self._n1.weights[0]\n", "        w2: float = self._n1.weights[1]\n", "        # `Neuron` 2\n", "        b2: float = self._n2.bias\n", "        w3: float = self._n2.weights[0]\n", "        w4: float = self._n2.weights[1]\n", "        # `Neuron` 3 (output)\n", "        b3: float = self._n3.bias\n", "        w5: float = self._n3.weights[0]\n", "        w6: float = self._n3.weights[1]\n", "\n", "        # The partial derivative of the error function is used in every computation\n", "        d_error_d_y_pred: float = d_squared_error(y, y_pred)\n", "\n", "        # Derivative of each neuron's activation at its pre-activation sum\n", "        d_act_n1: float = d_act(dot([x1, x2], [w1, w2]) + b1)\n", "        d_act_n2: float = d_act(dot([x1, x2], [w3, w4]) + b2)\n", "        d_act_n3: float = d_act(dot([n1, n2], [w5, w6]) + b3)\n", "\n", "        # Calculate the partial derivatives for the individual weights and biases\n", "        # `Neuron` 1\n", "        d_n1_d_w1: float = x1 * d_act_n1\n", "        d_n1_d_w2: float = x2 * d_act_n1\n", "        d_n1_d_b1: float = d_act_n1\n", "        # `Neuron` 2\n", "        d_n2_d_w3: float = x1 * d_act_n2\n", "        d_n2_d_w4: float = x2 * d_act_n2\n", "        d_n2_d_b2: float = d_act_n2\n", "        # `Neuron` 3 (output)\n", "        d_y_pred_d_n1: float = w5 * d_act_n3\n", "        d_y_pred_d_n2: float = w6 * d_act_n3\n", "        d_y_pred_d_w5: float = n1 * d_act_n3\n", "        d_y_pred_d_w6: float = n2 * d_act_n3\n", "        d_y_pred_d_b3: float = d_act_n3\n", "\n", "        # Update the weights and biases (chain rule back from the error)\n", "        # `Neuron` 1\n", "        self._n1.weights[0] += -lr * d_error_d_y_pred * d_y_pred_d_n1 * d_n1_d_w1\n", "        self._n1.weights[1] += -lr * d_error_d_y_pred * d_y_pred_d_n1 * d_n1_d_w2\n", "        self._n1.bias += -lr * d_error_d_y_pred * d_y_pred_d_n1 * d_n1_d_b1\n", "        # `Neuron` 2\n", "        self._n2.weights[0] += -lr * d_error_d_y_pred * d_y_pred_d_n2 * d_n2_d_w3\n", "        self._n2.weights[1] += -lr * d_error_d_y_pred * d_y_pred_d_n2 * d_n2_d_w4\n", "        self._n2.bias += -lr * d_error_d_y_pred * d_y_pred_d_n2 * d_n2_d_b2\n", "        # `Neuron` 3 (output)\n", "        self._n3.weights[0] += -lr * d_error_d_y_pred * d_y_pred_d_w5\n", "        self._n3.weights[1] += -lr * d_error_d_y_pred * d_y_pred_d_w6\n", "        self._n3.bias += -lr * d_error_d_y_pred * d_y_pred_d_b3" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Create a new Neural Network with a `sigmoid` activation function\n", "# (and its derivative `d_sigmoid`, used during back-propagation)\n", "nn: NeuralNetwork = NeuralNetwork(sigmoid, d_sigmoid)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- Before training ---\n", "Prediction for [0, 0] (should be 0): 0.79\n", "Prediction for [0, 1] (should be 1): 0.83\n", "Prediction for [1, 0] (should be 1): 0.83\n", "Prediction for [1, 1] (should be 0): 0.85\n" ] } ], "source": [ "# Print out what the randomized weights and bias initializations cause the NN to predict\n", "print('--- Before training ---')\n", "print('Prediction for [0, 0] (should be 0): {num:.{digits}f}'.format(num=nn.forward([0, 0]), digits=2))\n", "print('Prediction for [0, 1] (should be 1): {num:.{digits}f}'.format(num=nn.forward([0, 1]), digits=2))\n", "print('Prediction for [1, 0] (should be 1): {num:.{digits}f}'.format(num=nn.forward([1, 0]), digits=2))\n", "print('Prediction for [1, 1] (should be 0): {num:.{digits}f}'.format(num=nn.forward([1, 1]), digits=2))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0 --> loss: 1.4036893181184023\n", "Epoch 5000 --> loss: 0.014574431740830018\n", "Epoch 10000 --> loss: 0.004078498610061902\n", "Epoch 15000 --> loss: 0.0023138620933954145\n", "Epoch 20000 --> loss: 0.0016032268987062385\n", "Epoch 25000 --> loss: 0.001222352177475733\n" ] } ], "source": [ "# Train the Neural Network to learn the XOR function\n", "# See: https://en.wikipedia.org/wiki/Exclusive_or\n", "\n", "# XOR function data\n", "xs: List[List[float]] = [[0, 0], [0, 1], [1, 0], [1, 
1]]\n", "ys: List[float] = [0, 1, 1, 0]\n", "\n", "epochs: int = 30000\n", "learning_rate: float = 0.1\n", "\n", "# This list is used to record the loss at every epoch to plot it later on\n", "losses: List[float] = []\n", "\n", "for epoch in range(epochs):\n", "    # The predictions made during this epoch (a fresh list every epoch)\n", "    ys_pred: List[float] = []\n", "\n", "    # Show the NN the whole data set once\n", "    for x, y in zip(xs, ys):\n", "        # Let it make a prediction\n", "        y_pred: float = nn.forward(x)\n", "        # Record the prediction\n", "        ys_pred.append(y_pred)\n", "        # Do a backward-pass to update the weights and biases\n", "        nn.backward(y, lr=learning_rate)\n", "\n", "    # Calculate and record the loss for every epoch so that we can plot it later on.\n", "    # NOTE: each prediction was made before the update for its own sample, so this is\n", "    # the loss observed *during* the epoch, not the loss of the weights at its end\n", "    loss: float = sum_squared_error(ys, ys_pred)\n", "    losses.append(loss)\n", "\n", "    # Print the loss every once in a while\n", "    if epoch % 5000 == 0:\n", "        print(f'Epoch {epoch} --> loss: {loss}')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAd60lEQVR4nO3deXgc9Z3n8fdHp+VLtmXZOD6wzZHEBBJAIeRYQm4gWUgWdoBnspsQMs5Fkp3M5AnzZB8my+w+T45nswkzTBgnSwjMBELOx5N1TsLADFcsEnBsM4AwdpAxWDbGt2VL+u4fXbJbcuuw6eruUn1ez6Onq35VXfUttaSPqn7Vv1ZEYGZm+VVX7QLMzKy6HARmZjnnIDAzyzkHgZlZzjkIzMxyrqHaBRyr2bNnx+LFi6tdhplZpjz88MPbIqK91LLMBcHixYvp7OysdhlmZpkiadNIy3xpyMws5xwEZmY55yAwM8s5B4GZWc45CMzMcs5BYGaWc6kFgaSbJW2VtHaM9V4rqU/SZWnVYmZmI0vzjOAW4ILRVpBUD3wJ+GWKdQDw+HO7+eovH2fbnt60d2VmlimpBUFE3Au8MMZqnwR+CGxNq45BT27dzQ2/6eKFvQfT3pWZWaZUrY9A0nzgfcA3xrHuckmdkjp7enqOb38IAH8Oj5nZUNXsLP4a8LmIGBhrxYhYEREdEdHR3l5yqIwxScm2cBKYmRWr5lhDHcAdKvyFng1cJKkvIn6Sxs7qkiAYGDN2zMzypWpBEBFLBqcl3QL8NK0QSPZS2K/PCMzMhkgtCCTdDpwPzJbUDfw10AgQETeltd+R6yk8uo/AzGyo1IIgIq48hnU/mFYdg5T2DszMMio37yxO+iJ8RmBmNkx+giB5dB+BmdlQ+QkC9xGYmZWUmyCoS5JgwElgZjZEboKAw28oMzOzYrkJgsN9BE4CM7Mh8hMEOtJdbGZmR+QnCJJHnxGYmQ2VnyBwH4GZWUn5CQIPQ21mVlJugqDu8PsInARmZsVyEwSDnQQDzgEzsyFyEwTyMNRmZiXlJwh896iZWUn5CYLk0TlgZjZUfoLAw1CbmZWUoyAoPLqPwMxsqNwEQZ2HoTYzKyk3QTDYS+BhqM3MhkotCCTdLGmrpLUjLP9TSWsk/UHS/ZJenVYthf0VHh0DZmZDpXlGcAtwwSjLnwbeHBGnA38DrEixliMfXu8kMDMboiGtDUfEvZIWj7L8/qLZB4EFadUCRXcNOQnMzIaolT6Cq4GfjbRQ0nJJnZI6e3p6jmsHHobazKy0qgeBpLdQCILPjbRORKyIiI6I6Ghvbz/O/Qxu67iebmY2YaV2aWg8JJ0BfAu4MCK2p7mvusOXhszMrFjVzggkLQJ+BPyXiHiiUvv17aNmZkOldkYg6XbgfGC2pG7gr4FGgIi4CbgOaAP+PunI7YuIjvTqKTw6B8zMhkrzrqErx1j+YeDDae1/OHnYOTOzkqreWVwpPiMwMystf0FQ3TLMzGpOfoLAH15vZlZSboKgzsNQm5mVlJsgkD+83syspNwEAYcvDTkJzMyK5SYIDn94vZmZDZGfIEgefUJgZjZUfoLAw1CbmZWUnyBIHn1GYGY2VG6C4PDoow4CM7MhchMER24fdRKYmRXLTRAMcgyYmQ2VmyCQBx81MyspR0Hgu4bMzErJTxAkj+4iMDMbKj9B4GGozcxKyk0QDN4+6ruGzMyGyk0QtLY0AnDr/Zu4+9+3cuBQf5UrMjOrDakFgaSbJW2VtHaE5ZJ0g6QuSWsknZVWLQCTGuu5vGMhG7fv5apbVvOGL/6Gr//6SQeCmeVemmcEtwAXjLL8QuCU5Gs58I0UawHgS5edwaN//U5uueq1nLVoJv/n10/w7hv+lT9u35f2rs3MalZqQRAR9wIvjLLKJcCtUfAgMEPSvLTqGTSpsZ7zXz6Hb32gg1s/dA7b9x7kspvuZ8vO/Wnv2sysJlWzj2A+8EzRfHf
SdhRJyyV1Surs6ekpWwHnndrO95a/nn0H+/nobQ/T1z9Qtm2bmWVFJjqLI2JFRHREREd7e3tZt/3yE6bxxUtP59HunXzngU1l3baZWRZUMwg2AwuL5hckbRX37tPncd6p7Xz910+wp7evGiWYmVVNNYNgJfBfk7uHzgV2RsSWahQiic+841R2Hejjuw/5rMDM8iXN20dvBx4AXi6pW9LVkj4q6aPJKquADUAX8E3g42nVMh6vWTiDc5fO4tYHNjEw4DedmVl+NKS14Yi4cozlAXwirf0fj8tfu5A//96jrN74Aq9b2lbtcszMKiITncWV8q7TTmBKUz0/+l1VuirMzKrCQVBkclMDb33lXO769+d9ecjMcsNBMMzbXjGHbXsOsmbzzmqXYmZWEQ6CYd58ajt1gt889ny1SzEzqwgHwTAzpzRx5qKZ3PNE+d7BbGZWyxwEJbzhpDbWPruL3QcOVbsUM7PUOQhKeN2SNvoHgs5NO6pdiplZ6hwEJZx14gwa68WDG7ZXuxQzs9Q5CEqY3NTAqxfM4MENo42ibWY2MTgIRnDu0jbWbt7pfgIzm/AcBCM4d6n7CcwsHxwEI3A/gZnlhYNgBO4nMLO8cBCM4vUnuZ/AzCY+B8Eo3E9gZnngIBjFWYtmFvoJnnI/gZlNXA6CUbQ01fOahTPcYWxmE5qDYAyvX9rGH9xPYGYTmINgDOcubWMgYPVG3z1kZhNTqkEg6QJJj0vqknRtieWLJN0t6feS1ki6KM16jsdZJ85kUmMd9zzuYanNbGJKLQgk1QM3AhcCy4ArJS0bttp/B+6MiDOBK4C/T6ue4zWpsZ43ndzOr9Y/T4Q/vtLMJp40zwjOAboiYkNEHATuAC4Ztk4A05PpVuDZFOs5bu88bS7P7jzAumd3VbsUM7OySzMI5gPPFM13J23FvgC8X1I3sAr4ZKkNSVouqVNSZ09P5S/RvO0Vc6gT/HLdcxXft5lZ2qrdWXwlcEtELAAuAm6TdFRNEbEiIjoioqO9vb3iRbZNbea1i2fxz2u2+PKQmU04aQbBZmBh0fyCpK3Y1cCdABHxADAJmJ1iTcftsrMX8PS2vTzsdxmb2QSTZhCsBk6RtERSE4XO4JXD1vkj8DYASa+kEAQ1eXvORafPY3JTPXd2PjP2ymZmGZJaEEREH3AN8AvgMQp3B62TdL2ki5PV/gL4M0mPArcDH4wavfYypbmB95wxj5+u2cLOfX5zmZlNHA1pbjwiVlHoBC5uu65oej3wxjRrKKer3riEOzu7+ceHNvGJt5xc7XLMzMpiXGcEkk6S1JxMny/pU5JmpFta7XnlvOmcd2o7375vIwcO9Ve7HDOzshjvpaEfAv2STgZWUOgE/m5qVdWwj5y3lG17evnR74b3e5uZZdN4g2Agueb/PuBvI+KzwLz0yqpdbzipjdPnt7Li3qfoH6jJ7gwzs2My3iA4JOlK4APAT5O2xnRKqm2S+Pj5J7Fx+z5+vtZvMDOz7BtvEFwFvB74XxHxtKQlwG3plVXb3nnaCSydPYVv3NPlN5iZWeaNKwgiYn1EfCoibpc0E5gWEV9KubaaVV8nPvLmpazdvIt/69pW7XLMzF6S8d419C+SpkuaBfwO+Kakr6ZbWm1775nzmTu9mX+4Z0O1SzEze0nGe2moNSJ2Af8JuDUiXge8Pb2yal9zQz1XnrOI+57axvO7DlS7HDOz4zbeIGiQNA/4E450Fufee86YRwSs+sOWapdiZnbcxhsE11MYKuKpiFgtaSnwZHplZcPJc6Zxypyp3PXY1mqXYmZ23MbbWfz9iDgjIj6WzG+IiEvTLS0b3njybB7etIODfQPVLsXM7LiMt7N4gaQfS9qafP1Q0oK0i8uCc5fOYv+hftZ0v1jtUszMjst4Lw19m8IQ0i9Lvv45acu9s0+cBcCj3TurXImZ2fEZbxC0R8S3I6Iv+boFqPxHhdWg9mnNtE9rZr0/z9jMMmq8QbBd0vsl1Sdf7we2p1l
YliybN53HtjgIzCybxhsEH6Jw6+hzwBbgMuCDKdWUOa+cN50nt+7mUL87jM0se8Z719CmiLg4ItojYk5EvBfwXUOJk+dM5VB/0L1jf7VLMTM7Zi/loyo/U7YqMm5x22QANm3fW+VKzMyO3UsJApWtioxbdDgI9lW5EjOzY/dSgmDM8ZclXSDpcUldkq4dYZ0/kbRe0jpJmfzUs/apzUxuqmejzwjMLING/fB6Sbsp/QdfQMsYz60HbgTeAXQDqyWtTD6wfnCdU4C/At4YETskzTnG+muCJE5sm+IzAjPLpFGDICKmvYRtnwN0RcQGAEl3AJcA64vW+TPgxojYkewvs4P2zJ/RQvcOB4GZZc9LuTQ0lvnAM0Xz3UlbsVOBUyXdJ+lBSRekWE+qTmht5jkPR21mGTTqGUGF9n8KcD6wALhX0ukRMWTgHknLgeUAixYtqnSN43LC9Em8uO8QBw71M6mxvtrlmJmNW5pnBJuBhUXzC5K2Yt3Ayog4FBFPA09QCIYhImJFRHREREd7e22ObDF3+iQAf0iNmWVOmkGwGjhF0hJJTcAVFAauK/YTCmcDSJpN4VJRJj/78YTWQhA8t9NBYGbZkloQREQfcA2FD7R5DLgzItZJul7Sxclqv6AwjtF64G7gsxGRyTGMTkjOCNxPYGZZk2ofQUSsAlYNa7uuaDoovEM58+9SntvqS0Nmlk1pXhrKlWnNDTQ11LF9z8Fql2JmdkwcBGUiibYpTbyw10FgZtniICijmZMdBGaWPQ6CMmqb2sR2B4GZZYyDoIxm+dKQmWWQg6CMZk1pYoeDwMwyxkFQRrMmN7G7t4/evv5ql2JmNm4OgjKaNbUJgB17D1W5EjOz8XMQlFHblEIQuJ/AzLLEQVBGMyc7CMwsexwEZdQ6uRGAnft9acjMssNBUEatLQ4CM8seB0EZOQjMLIscBGXU0lhPU32dg8DMMsVBUEaSmN7S6CAws0xxEJRZa0sDuxwEZpYhDoIya21p5MX9vn3UzLLDQVBmrb40ZGYZ4yAoMweBmWWNg6DMWlsa2bnPQWBm2ZFqEEi6QNLjkrokXTvKepdKCkkdadZTCa0tjezu7WNgIKpdipnZuKQWBJLqgRuBC4FlwJWSlpVYbxrwaeChtGqppOktjUTA7gN91S7FzGxc0jwjOAfoiogNEXEQuAO4pMR6fwN8CTiQYi0V43cXm1nWpBkE84Fniua7k7bDJJ0FLIyI/zfahiQtl9QpqbOnp6f8lZbRjGQEUgeBmWVF1TqLJdUBXwX+Yqx1I2JFRHREREd7e3v6xb0EPiMws6xJMwg2AwuL5hckbYOmAa8C/kXSRuBcYGXWO4wHg8BvKjOzrEgzCFYDp0haIqkJuAJYObgwInZGxOyIWBwRi4EHgYsjojPFmlLnMwIzy5rUgiAi+oBrgF8AjwF3RsQ6SddLujit/Vabg8DMsqYhzY1HxCpg1bC260ZY9/w0a6mUSY11HorazDLF7ywus8GhqD0CqZllhYMgBTMmN/Kih5kws4xwEKRgpoPAzDLEQZCCGZOb2LHPt4+aWTY4CFIwy0FgZhniIEjBjCmN7Nh7iAiPQGpmtc9BkIKZk5s42D/AvoP91S7FzGxMDoIUzEoGnvPlITPLAgdBCmZMLry7eMde3zlkZrXPQZCCWVN8RmBm2eEgSMEMXxoyswxxEKTg8BnBXgeBmdU+B0EKWlsakWCH311sZhngIEhBfZ1obWn0pSEzywQHQUpmTm7yGYGZZYKDICVtU5rYtru32mWYmY3JQZCSudMnsXX3gWqXYWY2JgdBSuZMb+b5XT4jMLPa5yBIydzpk9jT28ee3r5ql2JmNqpUg0DSBZIel9Ql6doSyz8jab2kNZLuknRimvVU0tzpzQBs3eXLQ2ZW21ILAkn1wI3AhcAy4EpJy4at9nugIyLOAH4AfDmteipt7vRJADznIDCzGpfmGcE5QFdEbIiIg8AdwCXFK0TE3RGxL5l9EFiQYj0
VNRgEW91PYGY1Ls0gmA88UzTfnbSN5GrgZ6UWSFouqVNSZ09PTxlLTM9gEDzvMwIzq3E10Vks6f1AB/CVUssjYkVEdERER3t7e2WLO05TmxuYNqmBzS/ur3YpZmajakhx25uBhUXzC5K2ISS9Hfg88OaImFDXURa3TWHT9n1jr2hmVkVpnhGsBk6RtERSE3AFsLJ4BUlnAv8AXBwRW1OspSpObJvMxu17q12GmdmoUguCiOgDrgF+ATwG3BkR6yRdL+niZLWvAFOB70t6RNLKETaXSYvbptC9Yz+H+geqXYqZ2YjSvDRERKwCVg1ru65o+u1p7r/aTmybTP9AsHnHfhbPnlLtcszMSqqJzuKJavCP/9PbfHnIzGqXgyBFp86dBsBjz+2qciVmZiNzEKSotaWRhbNaWLfZQWBmtctBkLJXvayVtc/urHYZZmYjchCk7FXzW9m0fR87/WllZlajHAQp6zhxJgAPPr29ypWYmZXmIEjZmYtmMrmpnvu6tlW7FDOzkhwEKWtqqON1S2Zx7xM9RES1yzEzO4qDoALesewENm7fx1rfPWRmNchBUAHvPn0eTfV1/Oj33dUuxczsKA6CCmid3Mg7T5vLDzq72bnfdw+ZWW1xEFTIx84/id29fdz8b09XuxQzsyEcBBVy2stauej0E7jpnqc89pCZ1RQHQQV94T+eRlNDHdd893fs7e2rdjlmZoCDoKLmTJ/EDVecyWNbdvGR2x5m9wH3F5hZ9TkIKuwtr5jDVy57NQ9s2M6l37iftZs9DpGZVZeDoAouPXsB37nqHF7cd4hLbryPv/z+o2zo2VPtsswsp5S1d7t2dHREZ2dntcsoi537DvG1u57guw/9kd6+Ac4+cSbvOWMebzp5NifPmYqkapdoZhOEpIcjoqPkMgdB9W3dfYAfPryZH/++myeeL5wZtE9r5vT5rbxy3jReccJ0FrdNYcHMFmZMbnRAmNkxq1oQSLoA+DpQD3wrIr44bHkzcCtwNrAduDwiNo62zYkYBMWeeWEf93Vt46GnX2D9s7vo6tlD/8CR12hKUz3zZ7Ywe2ozs6Y0HX6cNaWJ1pZGpjY3MKW5gSnN9YenpzY30NxQ5wAxy7GqBIGkeuAJ4B1AN7AauDIi1het83HgjIj4qKQrgPdFxOWjbXeiB8FwvX39PLV1L8/s2Ef3jv1079jH5h372b73INv39LJ970F2Hxj7VtT6OjGpoY6mhjqaG+qTxzqaG+toqh/a1tRQaKuvEw31KjzWJfN1KnqsK1qePNbXHZ6ul6irgzoJSdSpMF0nkvkjbTq8rHg51NUVHqH4+UXr1x29TQESiMJ6g5Sso+L5ZJ3Dq5Vo07BtHl4vWWcwYEfdb1Gbhm2Tom2YpWW0IGhIcb/nAF0RsSEp4g7gEmB90TqXAF9Ipn8A/J0kRdauV6WouaGeZS+bzrKXTR9xnYN9A+zYd5Bd+w+xp7ePvb397Ok9xJ7efvb29iVtffT2DXCwb4Devv7kcWDI44v7DtKbzPcNDNDfH/QNBP0DxY8D9A8Eh/r9EqVhXAEEDM+NI/F2ZDtDlxcv04jLSjUMX178/NH2M559HZ1/x7Lt4ctH/h6MVddRVQzf9pBtjf97XWpfGnFm7Ode8dqFfPg/LD2q3pcqzSCYDzxTNN8NvG6kdSKiT9JOoA0YMni/pOXAcoBFixalVW9mNTXUMXf6JOZOn1TR/Q4kAdE3MFAIiv6jA2MgYCCCCIg4Mj/YNhDF6yTTyfNKrR8EAwNHnjfSNgvrFkREYTp5fqGtsDxKtDG4/mBbMj/470kk22RIWwzZZnEbw9Yf934ZfhzJc4Zl8PBIHv5/VPHs0esO39bIzx3uqP28xG3HKMuGb/2o546yr9H2M546GfX7d6zfg/E/9+gygtlTm4evVRZpBkHZRMQKYAUULg1VuRxL1NWJpjrR5LuQzTItzd/gzcDCovkFSVvJdSQ1AK0UOo3NzKxC0gyC1cApkpZIagKuAFY
OW2cl8IFk+jLgN+4fMDOrrNQuDSXX/K8BfkHh9tGbI2KdpOuBzohYCfxf4DZJXcALFMLCzMwqKNU+gohYBawa1nZd0fQB4D+nWYOZmY3OvXxmZjnnIDAzyzkHgZlZzjkIzMxyLnOjj0rqATYd59NnM+xdyxnmY6lNE+VYJspxgI9l0IkR0V5qQeaC4KWQ1DnSoEtZ42OpTRPlWCbKcYCPZTx8acjMLOccBGZmOZe3IFhR7QLKyMdSmybKsUyU4wAfy5hy1UdgZmZHy9sZgZmZDeMgMDPLudwEgaQLJD0uqUvStdWupxRJGyX9QdIjkjqTtlmSfiXpyeRxZtIuSTckx7NG0llF2/lAsv6Tkj4w0v7KXPvNkrZKWlvUVrbaJZ2dfG+6kuem9iG/IxzLFyRtTl6bRyRdVLTsr5K6Hpf0rqL2kj9zydDsDyXt30uGaU/jOBZKulvSeknrJH06ac/c6zLKsWTxdZkk6beSHk2O5X+Mtn9Jzcl8V7J88fEe44gi+YjAifxFYRjsp4ClQBPwKLCs2nWVqHMjMHtY25eBa5Ppa4EvJdMXAT+j8DGn5wIPJe2zgA3J48xkemYFaj8POAtYm0btwG+TdZU898IKH8sXgL8sse6y5OepGViS/JzVj/YzB9wJXJFM3wR8LKXjmAeclUxPA55I6s3c6zLKsWTxdREwNZluBB5Kvocl9w98HLgpmb4C+N7xHuNIX3k5IzgH6IqIDRFxELgDuKTKNY3XJcB3kunvAO8tar81Ch4EZkiaB7wL+FVEvBARO4BfARekXWRE3EvhMyXKXnuybHpEPBiF34Bbi7ZVqWMZySXAHRHRGxFPA10Uft5K/swl/zG/FfhB8vzi70tZRcSWiPhdMr0beIzC54Rn7nUZ5VhGUsuvS0TEnmS2MfmKUfZf/Hr9AHhbUu8xHeNoNeUlCOYDzxTNdzP6D1G1BPBLSQ9LWp60zY2ILcn0c8DcZHqkY6qlYy1X7fOT6eHtlXZNcsnk5sHLKRz7sbQBL0ZE37D2VCWXE86k8N9npl+XYccCGXxdJNVLegTYSiFYnxpl/4drTpbvTOot29+AvARBVrwpIs4CLgQ+Iem84oXJf12ZvN83y7UnvgGcBLwG2AL87+qWM36SpgI/BP5bROwqXpa116XEsWTydYmI/oh4DYXPcj8HeEU168lLEGwGFhbNL0jaakpEbE4etwI/pvAD8nxyCk7yuDVZfaRjqqVjLVftm5Pp4e0VExHPJ7+8A8A3Kbw2cOzHsp3CJZeGYe2pkNRI4Q/nP0XEj5LmTL4upY4lq6/LoIh4EbgbeP0o+z9cc7K8Nam3fH8D0ugMqbUvCh/JuYFCh8pg58lp1a5rWI1TgGlF0/dTuLb/FYZ27H05mX43Qzv2fpu0zwKeptCpNzOZnlWhY1jM0A7WstXO0Z2SF1X4WOYVTf85hWuzAKcxtMNuA4XOuhF/5oDvM7RT8OMpHYMoXLf/2rD2zL0uoxxLFl+XdmBGMt0C/CvwnpH2D3yCoZ3Fdx7vMY5YU5q/TLX0ReGOiCcoXIv7fLXrKVHf0uQFexRYN1gjhWuBdwFPAr8u+gUUcGNyPH8AOoq29SEKHUddwFUVqv92Cqfmhyhck7y6nLUDHcDa5Dl/R/Ku+Aoey21JrWuAlcP+AH0+qetxiu6aGelnLnmtf5sc4/eB5pSO400ULvusAR5Jvi7K4usyyrFk8XU5A/h9UvNa4LrR9g9MSua7kuVLj/cYR/ryEBNmZjmXlz4CMzMbgYPAzCznHARmZjnnIDAzyzkHgZlZzjkIzBKS+otGsXxkXKM2jn/bi1U0mqlZLWkYexWz3Ngfhbf9m+WKzwjMxqDC50R8ORl3/7eSTk7aF0v6TTLg2V2SFiXtcyX9OBlv/lFJb0g2VS/pm8kY9L+U1JKs/6lknP01ku6o0mFajjkIzI5oGXZp6PKiZTsj4nQK7579WtL2t8B3IuIM4J+AG5L2G4B7IuLVFD7XYF3Sfgp
wY0ScBrwIXJq0XwucmWzno2kdnNlI/M5is4SkPRExtUT7RuCtEbEhGfjsuYhok7SNwpAGh5L2LRExW1IPsCAieou2sZjCmP6nJPOfAxoj4n9K+jmwB/gJ8JM4Mla9WUX4jMBsfGKE6WPRWzTdz5E+undTGOPnLGB10QiUZhXhIDAbn8uLHh9Ipu+nMBokwJ9SGEUSCgO6fQwOfwBJ60gblVQHLIyIu4HPURhi+KizErM0+T8PsyNakk+NGvTziBi8hXSmpDUU/qu/Mmn7JPBtSZ8FeoCrkvZPAyskXU3hP/+PURjNtJR64B+TsBBwQxTGqDerGPcRmI0h6SPoiIht1a7FLA2+NGRmlnM+IzAzyzmfEZiZ5ZyDwMws5xwEZmY55yAwM8s5B4GZWc79f1IRuS3sir74AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(range(epochs), losses)\n", "plt.title('XOR training loss per epoch')\n", "plt.xlabel('Epochs')\n", "plt.ylabel('Loss');" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- After training ---\n", "Prediction for [0, 0] (should be 0): 0.02\n", "Prediction for [0, 1] (should be 1): 0.99\n", "Prediction for [1, 0] (should be 1): 0.99\n", "Prediction for [1, 1] (should be 0): 0.02\n" ] } ], "source": [ "# Print out what the NN predicts after training (updating the weights and biases)\n", "print('--- After training ---')\n", "print('Prediction for [0, 0] (should be 0): {num:.{digits}f}'.format(num=nn.forward([0, 0]), digits=2))\n", "print('Prediction for [0, 1] (should be 1): {num:.{digits}f}'.format(num=nn.forward([0, 1]), digits=2))\n", "print('Prediction for [1, 0] (should be 1): {num:.{digits}f}'.format(num=nn.forward([1, 0]), digits=2))\n", "print('Prediction for [1, 1] (should be 0): {num:.{digits}f}'.format(num=nn.forward([1, 1]), digits=2))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }