{ "cells": [ { "cell_type": "markdown", "id": "787060ce-57f5-4951-b1ef-c2b27e9c4d72", "metadata": {}, "source": [ "# Notebook 4: Activation [Open in Colab](https://colab.research.google.com/github/mattsankner/micrograd/blob/main/mg4_activation.ipynb) [View on nbviewer](https://nbviewer.jupyter.org/github/mattsankner/micrograd/blob/main/mg4_activation.ipynb)" ] }, { "cell_type": "markdown", "id": "36145db7-df04-4169-af61-2e1636139018", "metadata": {}, "source": [ "# Welcome to the fourth lecture! Here we take finer-grained control of our ```tanh()``` function." ] }, { "cell_type": "markdown", "id": "9c79f45c-e6d9-4d43-8c42-2a8353cb2065", "metadata": {}, "source": [ "## Now that we have automated the forward and backward pass for a single neuron, did you spot the bug in the previous code? Run the cells below, up to the next text block, to find out what it is and how to fix it." ] }, { "cell_type": "code", "execution_count": 1, "id": "1ec7ba82-fc87-45b7-ab2a-36feaf1ee2f1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: graphviz in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (0.20.3)\n" ] } ], "source": [ "!pip install graphviz\n", "!export PATH=\"/usr/local/opt/graphviz/bin:$PATH\"" ] }, { "cell_type": "code", "execution_count": 13, "id": "984291a4-0274-4f3f-a5c9-b4b529f2f1c7", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import math" ] }, { "cell_type": "code", "execution_count": 10, "id": "0d369653-eee5-48c2-8f36-c432fe364ab3", "metadata": 
{}, "outputs": [], "source": [ "from graphviz import Digraph  # Graphviz is open-source graph visualization software; we build the graph through its Python API.\n", "\n", "def trace(root):  # helper function that enumerates the nodes and edges in the graph\n", "    # builds a set of all nodes and edges in a graph\n", "    nodes, edges = set(), set()\n", "    def build(v):\n", "        if v not in nodes:\n", "            nodes.add(v)\n", "            for child in v._prev:\n", "                edges.add((child, v))\n", "                build(child)\n", "    build(root)\n", "    return nodes, edges\n", "\n", "def draw_dot(root):  # draws the graph, creating op nodes (which are not Value objects)\n", "    dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'})  # LR = left to right\n", "\n", "    nodes, edges = trace(root)  # call trace\n", "    for n in nodes:\n", "        uid = str(id(n))\n", "        # for any value in the graph, create a rectangular ('record') node for it\n", "        dot.node(name=uid, label=\"{ %s | data %.4f | grad % .4f}\" % (n.label, n.data, n.grad), shape='record')\n", "        if n._op:\n", "            # if this value is the result of some operation, create an op node for it (not a Value object)\n", "            dot.node(name=uid + n._op, label=n._op)\n", "            # and connect this node to it\n", "            dot.edge(uid + n._op, uid)\n", "\n", "    for n1, n2 in edges:\n", "        # connect n1 to the op node of n2\n", "        dot.edge(str(id(n1)), str(id(n2)) + n2._op)\n", "\n", "    return dot" ] }, { "cell_type": "code", "execution_count": 23, "id": "f51b0db0-05fb-400a-93c8-cf53ac3bb150", "metadata": {}, "outputs": [], "source": [ "class Value:\n", "\n", "    def __init__(self, data, _children=(), _op='', label=''):\n", "        self.data = data\n", "        self.grad = 0.0\n", "        self._prev = set(_children)\n", "        self._backward = lambda: None\n", "        self._op = _op\n", "        self.label = label\n", "\n", "    def __repr__(self):\n", "        return f\"Value(data={self.data})\"\n", "\n", "    def __add__(self, other):\n", "        out = Value(self.data + other.data, (self, other), '+')\n", "\n", "        def _backward():\n", "            self.grad = 1.0 * out.grad\n", "            other.grad = 1.0 * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __mul__(self, other):\n", "        out = Value(self.data * other.data, (self, other), '*')\n", "\n", "        def _backward():\n", "            self.grad = other.data * out.grad\n", "            other.grad = self.data * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def tanh(self):\n", "        x = self.data\n", "        t = (math.exp(2*x)-1)/(math.exp(2*x)+1)\n", "\n", "        out = Value(t, (self, ), 'tanh')\n", "\n", "        def _backward():\n", "            self.grad = (1-t**2) * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    # define backward() without the underscore.\n", "    def backward(self):  # build the topological order starting at self\n", "        topo = []  # list of nodes in topological order\n", "        visited = set()  # maintain a set of visited nodes\n", "        def build_topo(v):  # start at the root node\n", "            if v not in visited:\n", "                visited.add(v)\n", "                for child in v._prev:\n", "                    build_topo(child)\n", "                topo.append(v)\n", "        build_topo(self)\n", "\n", "        self.grad = 1.0  # initialize root.grad\n", "        for node in reversed(topo):\n", "            node._backward()  # call _backward() on each node, backpropagating from the root to the leaves" ] }, { "cell_type": "code", "execution_count": 15, "id": "8996f017-5597-4eba-aac1-7306d1589483", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 11.0.0 (20240428.1522)\n", " -->\n", "<!-- Pages: 1 -->\n", "<svg width=\"1567pt\" height=\"210pt\"\n", " viewBox=\"0.00 0.00 1566.50 210.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 206)\">\n", "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-206 1562.5,-206 1562.5,4 -4,4\"/>\n", "<!-- 4458697808 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>4458697808</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"0,-165.5 0,-201.5 198,-201.5 198,-165.5 0,-165.5\"/>\n", "<text text-anchor=\"middle\" x=\"16.25\" y=\"-178.7\" font-family=\"Times,serif\" font-size=\"14.00\">w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"32.5,-166 32.5,-201.5\"/>\n", "<text text-anchor=\"middle\" x=\"74.62\" y=\"-178.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -3.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"116.75,-166 116.75,-201.5\"/>\n", "<text text-anchor=\"middle\" 
x=\"157.38\" y=\"-178.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4461558416* -->\n", "<g id=\"node6\" class=\"node\">\n", "<title>4461558416*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"261\" cy=\"-128.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"261\" y=\"-123.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4458697808->4461558416* -->\n", "<g id=\"edge13\" class=\"edge\">\n", "<title>4458697808->4461558416*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M172.12,-165.01C180.9,-162.35 189.7,-159.5 198,-156.5 208.09,-152.86 218.82,-148.27 228.47,-143.88\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"229.86,-147.09 237.45,-139.69 226.9,-140.75 229.86,-147.09\"/>\n", "</g>\n", "<!-- 4461003856 -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>4461003856</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1.5,-110.5 1.5,-146.5 196.5,-146.5 196.5,-110.5 1.5,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"16.25\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"31,-111 31,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"70.88\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 2.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"110.75,-111 110.75,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"153.62\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -1.5000</text>\n", "</g>\n", "<!-- 4461003856->4461558416* -->\n", "<g id=\"edge14\" class=\"edge\">\n", "<title>4461003856->4461558416*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M196.76,-128.5C205.77,-128.5 214.47,-128.5 222.4,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"222.25,-132 232.25,-128.5 222.25,-125 222.25,-132\"/>\n", "</g>\n", "<!-- 4461551760 -->\n", "<g id=\"node3\" class=\"node\">\n", 
"<title>4461551760</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1374.75,-54.5 1374.75,-90.5 1558.5,-90.5 1558.5,-54.5 1374.75,-54.5\"/>\n", "<text text-anchor=\"middle\" x=\"1386.12\" y=\"-67.7\" font-family=\"Times,serif\" font-size=\"14.00\">o</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1397.5,-55 1397.5,-90.5\"/>\n", "<text text-anchor=\"middle\" x=\"1437.38\" y=\"-67.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.7071</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1477.25,-55 1477.25,-90.5\"/>\n", "<text text-anchor=\"middle\" x=\"1517.88\" y=\"-67.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4461551760tanh -->\n", "<g id=\"node4\" class=\"node\">\n", "<title>4461551760tanh</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1311.75\" cy=\"-72.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1311.75\" y=\"-67.45\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n", "</g>\n", "<!-- 4461551760tanh->4461551760 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>4461551760tanh->4461551760</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1339.03,-72.5C1346.17,-72.5 1354.36,-72.5 1363.07,-72.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1362.81,-76 1372.81,-72.5 1362.81,-69 1362.81,-76\"/>\n", "</g>\n", "<!-- 4461558416 -->\n", "<g id=\"node5\" class=\"node\">\n", "<title>4461558416</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"324,-110.5 324,-146.5 542.25,-146.5 542.25,-110.5 324,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"350.38\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"376.75,-111 376.75,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"418.88\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"461,-111 
461,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"501.62\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4461559504+ -->\n", "<g id=\"node10\" class=\"node\">\n", "<title>4461559504+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"605.25\" cy=\"-100.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"605.25\" y=\"-95.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4461558416->4461559504+ -->\n", "<g id=\"edge8\" class=\"edge\">\n", "<title>4461558416->4461559504+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M542.35,-110.69C551.12,-109.24 559.54,-107.86 567.19,-106.6\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"567.72,-110.06 577.02,-104.98 566.58,-103.15 567.72,-110.06\"/>\n", "</g>\n", "<!-- 4461558416*->4461558416 -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>4461558416*->4461558416</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M288.21,-128.5C295.29,-128.5 303.43,-128.5 312.17,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"312.01,-132 322.01,-128.5 312.01,-125 312.01,-132\"/>\n", "</g>\n", "<!-- 4461384336 -->\n", "<g id=\"node7\" class=\"node\">\n", "<title>4461384336</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"711.75,-27.5 711.75,-63.5 895.5,-63.5 895.5,-27.5 711.75,-27.5\"/>\n", "<text text-anchor=\"middle\" x=\"723.12\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"734.5,-28 734.5,-63.5\"/>\n", "<text text-anchor=\"middle\" x=\"774.38\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.8814</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"814.25,-28 814.25,-63.5\"/>\n", "<text text-anchor=\"middle\" x=\"854.88\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4461544656+ -->\n", "<g id=\"node12\" class=\"node\">\n", 
"<title>4461544656+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1002\" cy=\"-72.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1002\" y=\"-67.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4461384336->4461544656+ -->\n", "<g id=\"edge10\" class=\"edge\">\n", "<title>4461384336->4461544656+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M895.91,-58.05C919.97,-61.36 944.58,-64.74 963.93,-67.4\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"963.36,-70.86 973.75,-68.75 964.32,-63.92 963.36,-70.86\"/>\n", "</g>\n", "<!-- 4461353616 -->\n", "<g id=\"node8\" class=\"node\">\n", "<title>4461353616</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"3.75,-55.5 3.75,-91.5 194.25,-91.5 194.25,-55.5 3.75,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"18.5\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"33.25,-56 33.25,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"73.12\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"113,-56 113,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"153.62\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4461555600* -->\n", "<g id=\"node14\" class=\"node\">\n", "<title>4461555600*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"261\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"261\" y=\"-68.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4461353616->4461555600* -->\n", "<g id=\"edge7\" class=\"edge\">\n", "<title>4461353616->4461555600*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M194.46,-73.5C204.21,-73.5 213.66,-73.5 222.21,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"222.12,-77 232.12,-73.5 222.12,-70 222.12,-77\"/>\n", "</g>\n", "<!-- 
4461559504 -->\n", "<g id=\"node9\" class=\"node\">\n", "<title>4461559504</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"668.25,-82.5 668.25,-118.5 939,-118.5 939,-82.5 668.25,-82.5\"/>\n", "<text text-anchor=\"middle\" x=\"720.88\" y=\"-95.7\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1 + x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"773.5,-83 773.5,-118.5\"/>\n", "<text text-anchor=\"middle\" x=\"815.62\" y=\"-95.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"857.75,-83 857.75,-118.5\"/>\n", "<text text-anchor=\"middle\" x=\"898.38\" y=\"-95.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4461559504->4461544656+ -->\n", "<g id=\"edge9\" class=\"edge\">\n", "<title>4461559504->4461544656+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M934.24,-82.02C944.99,-80.49 955.18,-79.03 964.21,-77.75\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"964.46,-81.25 973.86,-76.37 963.47,-74.32 964.46,-81.25\"/>\n", "</g>\n", "<!-- 4461559504+->4461559504 -->\n", "<g id=\"edge3\" class=\"edge\">\n", "<title>4461559504+->4461559504</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M632.73,-100.5C639.73,-100.5 647.79,-100.5 656.52,-100.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"656.39,-104 666.39,-100.5 656.39,-97 656.39,-104\"/>\n", "</g>\n", "<!-- 4461544656 -->\n", "<g id=\"node11\" class=\"node\">\n", "<title>4461544656</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1065,-54.5 1065,-90.5 1248.75,-90.5 1248.75,-54.5 1065,-54.5\"/>\n", "<text text-anchor=\"middle\" x=\"1076.38\" y=\"-67.7\" font-family=\"Times,serif\" font-size=\"14.00\">n</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1087.75,-55 1087.75,-90.5\"/>\n", "<text text-anchor=\"middle\" x=\"1127.62\" y=\"-67.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.8814</text>\n", 
"<polyline fill=\"none\" stroke=\"black\" points=\"1167.5,-55 1167.5,-90.5\"/>\n", "<text text-anchor=\"middle\" x=\"1208.12\" y=\"-67.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4461544656->4461551760tanh -->\n", "<g id=\"edge12\" class=\"edge\">\n", "<title>4461544656->4461551760tanh</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1249.01,-72.5C1257.39,-72.5 1265.52,-72.5 1272.98,-72.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1272.95,-76 1282.95,-72.5 1272.95,-69 1272.95,-76\"/>\n", "</g>\n", "<!-- 4461544656+->4461544656 -->\n", "<g id=\"edge4\" class=\"edge\">\n", "<title>4461544656+->4461544656</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1029.28,-72.5C1036.42,-72.5 1044.61,-72.5 1053.32,-72.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1053.06,-76 1063.06,-72.5 1053.06,-69 1053.06,-76\"/>\n", "</g>\n", "<!-- 4461555600 -->\n", "<g id=\"node13\" class=\"node\">\n", "<title>4461555600</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"326.25,-55.5 326.25,-91.5 540,-91.5 540,-55.5 326.25,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"352.62\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"379,-56 379,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"418.88\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"458.75,-56 458.75,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"499.38\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4461555600->4461559504+ -->\n", "<g id=\"edge6\" class=\"edge\">\n", "<title>4461555600->4461559504+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M540.42,-90.37C549.89,-91.87 559,-93.32 567.21,-94.62\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"566.63,-98.07 577.05,-96.18 567.72,-91.16 
566.63,-98.07\"/>\n", "</g>\n", "<!-- 4461555600*->4461555600 -->\n", "<g id=\"edge5\" class=\"edge\">\n", "<title>4461555600*->4461555600</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M288.21,-73.5C296,-73.5 305.08,-73.5 314.82,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"314.55,-77 324.55,-73.5 314.55,-70 314.55,-77\"/>\n", "</g>\n", "<!-- 4461383120 -->\n", "<g id=\"node15\" class=\"node\">\n", "<title>4461383120</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"2.25,-0.5 2.25,-36.5 195.75,-36.5 195.75,-0.5 2.25,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"18.5\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"34.75,-1 34.75,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"74.62\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"114.5,-1 114.5,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"155.12\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 4461383120->4461555600* -->\n", "<g id=\"edge11\" class=\"edge\">\n", "<title>4461383120->4461555600*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M172.12,-36.99C180.9,-39.65 189.7,-42.5 198,-45.5 208.09,-49.14 218.82,-53.73 228.47,-58.12\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"226.9,-61.25 237.45,-62.31 229.86,-54.91 226.9,-61.25\"/>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x109e97f90>" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#FINISH IT!\n", "o.backward()\n", "draw_dot(o)" ] }, { "cell_type": "markdown", "id": "c29db3b1-6661-485c-a889-086aeafae7a2", "metadata": {}, "source": [ "### Did you find the bug?\n", "\n", "Try running the following code:" ] }, { "cell_type": "code", "execution_count": 22, "id": "22b64644-af1d-4d5b-8e47-b56cf316d3e8", 
"metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 11.0.0 (20240428.1522)\n", " -->\n", "<!-- Pages: 1 -->\n", "<svg width=\"501pt\" height=\"45pt\"\n", " viewBox=\"0.00 0.00 500.75 45.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 41)\">\n", "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-41 496.75,-41 496.75,4 -4,4\"/>\n", "<!-- 4641931088 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>4641931088</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 183,-36.5 183,-0.5 0,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"11\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">a</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"22,-1 22,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"61.88\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 3.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"101.75,-1 101.75,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"142.38\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 2.0000</text>\n", "</g>\n", "<!-- 4463965136+ -->\n", "<g id=\"node3\" class=\"node\">\n", "<title>4463965136+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"246\" cy=\"-18.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"246\" y=\"-13.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4641931088->4463965136+ -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>4641931088->4463965136+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M183.41,-18.5C191.77,-18.5 199.88,-18.5 207.32,-18.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" 
points=\"207.26,-22 217.26,-18.5 207.26,-15 207.26,-22\"/>\n", "</g>\n", "<!-- 4463965136 -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>4463965136</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"309,-0.5 309,-36.5 492.75,-36.5 492.75,-0.5 309,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"320.38\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"331.75,-1 331.75,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"371.62\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"411.5,-1 411.5,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"452.12\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4463965136+->4463965136 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>4463965136+->4463965136</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M273.28,-18.5C280.42,-18.5 288.61,-18.5 297.32,-18.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"297.06,-22 307.06,-18.5 297.06,-15 297.06,-22\"/>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x10a1206d0>" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = Value(3.0, label='a')\n", "b = a + a; b.label = 'b'\n", "b.backward()\n", "draw_dot(b)" ] }, { "cell_type": "markdown", "id": "abe321a3-0a0e-409a-a0e9-c128f6420bb0", "metadata": {}, "source": [ "```a``` feeds into ```b``` twice in the forward pass (the two arrows are drawn on top of each other), but what do you notice about the gradient?\n", "\n", "\n", "Since $b = a + a = 2a$, the derivative of ```b``` with respect to ```a``` should be $2$, or $1 + 1$: by the multivariable chain rule, the contributions from the two paths add. Instead, the gradient of ```a``` comes out as $1.0$.\n", "\n", "We are going to run into this issue every time we use a variable more than once in an expression. This is because the ```_backward()``` functions in the ```Value``` class assign to ```self.grad``` and ```other.grad``` with ```=```, so each write overwrites whatever was there before. Here ```self``` and ```other``` both point to the same object ```a```, so the second assignment overwrites the first.\n", "\n", "```python\n", "def __add__(self, other):\n", "    out = Value(self.data + other.data, (self, other), '+')\n", "\n", "    def _backward():\n", "        self.grad = 1.0 * out.grad   # a.grad = 1.0 * 1.0\n", "        other.grad = 1.0 * out.grad  # a.grad = 1.0 * 1.0 (overwrites a.grad)\n", "    out._backward = _backward\n", "    return out\n", "```" ] }, { "cell_type": "markdown", "id": "b0b44f12-dc5f-4040-a635-57a998ffe158", "metadata": {}, "source": [ "We can see this in a more complicated example. In the example below, whichever of ```d._backward()``` and ```e._backward()``` runs last overwrites the gradients the other one wrote to ```a``` and ```b```. In the output shown here, ```a.grad``` and ```b.grad``` are both $-6.0$, which is only the contribution through ```e```; the contribution through ```d``` ($3.0$ for ```a``` and $-2.0$ for ```b```) has been lost. The correct totals are $-3.0$ and $-8.0$. For you, it might be the other way around, but either way the result is wrong." ] }, { "cell_type": "code", "execution_count": 25, "id": "b526392f-642e-4816-a32b-d7bd94f06f7f", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 11.0.0 (20240428.1522)\n", " -->\n", "<!-- Pages: 1 -->\n", "<svg width=\"826pt\" height=\"100pt\"\n", " viewBox=\"0.00 0.00 826.25 100.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 96)\">\n", "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-96 822.25,-96 822.25,4 -4,4\"/>\n", "<!-- 4464208976 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>4464208976</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"318.38,-55.5 318.38,-91.5 505.88,-91.5 505.88,-55.5 318.38,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"329.38\" y=\"-68.7\" 
font-family=\"Times,serif\" font-size=\"14.00\">e</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"340.38,-56 340.38,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"380.25\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"420.12,-56 420.12,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"463\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -6.0000</text>\n", "</g>\n", "<!-- 4464209616* -->\n", "<g id=\"node5\" class=\"node\">\n", "<title>4464209616*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"569.25\" cy=\"-45.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"569.25\" y=\"-40.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4464208976->4464209616* -->\n", "<g id=\"edge5\" class=\"edge\">\n", "<title>4464208976->4464209616*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M506.05,-56.73C514.96,-55.12 523.59,-53.56 531.44,-52.14\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"531.99,-55.6 541.21,-50.38 530.75,-48.71 531.99,-55.6\"/>\n", "</g>\n", "<!-- 4464208976+ -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>4464208976+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"255\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"255\" y=\"-68.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4464208976+->4464208976 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>4464208976+->4464208976</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M282.31,-73.5C289.55,-73.5 297.86,-73.5 306.7,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"306.62,-77 316.62,-73.5 306.62,-70 306.62,-77\"/>\n", "</g>\n", "<!-- 4641917584 -->\n", "<g id=\"node3\" class=\"node\">\n", "<title>4641917584</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"0,-55.5 0,-91.5 192,-91.5 192,-55.5 0,-55.5\"/>\n", "<text 
text-anchor=\"middle\" x=\"11\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">a</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"22,-56 22,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"64.12\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -2.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"106.25,-56 106.25,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"149.12\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -6.0000</text>\n", "</g>\n", "<!-- 4641917584->4464208976+ -->\n", "<g id=\"edge6\" class=\"edge\">\n", "<title>4641917584->4464208976+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M192.4,-73.5C200.73,-73.5 208.79,-73.5 216.18,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"216.05,-77 226.05,-73.5 216.05,-70 216.05,-77\"/>\n", "</g>\n", "<!-- 4464211792* -->\n", "<g id=\"node7\" class=\"node\">\n", "<title>4464211792*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"255\" cy=\"-18.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"255\" y=\"-13.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4641917584->4464211792* -->\n", "<g id=\"edge7\" class=\"edge\">\n", "<title>4641917584->4464211792*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M166.39,-55.06C175.08,-52.38 183.79,-49.5 192,-46.5 201.99,-42.85 212.63,-38.3 222.22,-33.93\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"223.55,-37.17 231.15,-29.78 220.6,-30.83 223.55,-37.17\"/>\n", "</g>\n", "<!-- 4464209616 -->\n", "<g id=\"node4\" class=\"node\">\n", "<title>4464209616</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"632.25,-27.5 632.25,-63.5 818.25,-63.5 818.25,-27.5 632.25,-27.5\"/>\n", "<text text-anchor=\"middle\" x=\"642.5\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">f</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"652.75,-28 652.75,-63.5\"/>\n", "<text 
text-anchor=\"middle\" x=\"694.88\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"737,-28 737,-63.5\"/>\n", "<text text-anchor=\"middle\" x=\"777.62\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4464209616*->4464209616 -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>4464209616*->4464209616</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M596.72,-45.5C603.86,-45.5 612.03,-45.5 620.71,-45.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"620.44,-49 630.44,-45.5 620.44,-42 620.44,-49\"/>\n", "</g>\n", "<!-- 4464211792 -->\n", "<g id=\"node6\" class=\"node\">\n", "<title>4464211792</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"318,-0.5 318,-36.5 506.25,-36.5 506.25,-0.5 318,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"329.38\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">d</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"340.75,-1 340.75,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"382.88\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"425,-1 425,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"465.62\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4464211792->4464209616* -->\n", "<g id=\"edge8\" class=\"edge\">\n", "<title>4464211792->4464209616*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M506.49,-34.75C515.27,-36.28 523.77,-37.76 531.51,-39.1\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"530.66,-42.51 541.11,-40.78 531.86,-35.61 530.66,-42.51\"/>\n", "</g>\n", "<!-- 4464211792*->4464211792 -->\n", "<g id=\"edge3\" class=\"edge\">\n", "<title>4464211792*->4464211792</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M282.31,-18.5C289.49,-18.5 297.72,-18.5 306.47,-18.5\"/>\n", "<polygon 
fill=\"black\" stroke=\"black\" points=\"306.29,-22 316.29,-18.5 306.29,-15 306.29,-22\"/>\n", "</g>\n", "<!-- 4458511248 -->\n", "<g id=\"node8\" class=\"node\">\n", "<title>4458511248</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1.88,-0.5 1.88,-36.5 190.12,-36.5 190.12,-0.5 1.88,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"13.25\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"24.62,-1 24.62,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"64.5\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 3.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"104.38,-1 104.38,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"147.25\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -6.0000</text>\n", "</g>\n", "<!-- 4458511248->4464208976+ -->\n", "<g id=\"edge9\" class=\"edge\">\n", "<title>4458511248->4464208976+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M163.53,-36.96C173.14,-39.96 182.87,-43.18 192,-46.5 201.86,-50.08 212.39,-54.46 221.92,-58.63\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"220.23,-61.71 230.79,-62.59 223.09,-55.32 220.23,-61.71\"/>\n", "</g>\n", "<!-- 4458511248->4464211792* -->\n", "<g id=\"edge4\" class=\"edge\">\n", "<title>4458511248->4464211792*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M190.59,-18.5C199.58,-18.5 208.29,-18.5 216.23,-18.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"216.11,-22 226.11,-18.5 216.11,-15 216.11,-22\"/>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x10a1669d0>" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = Value(-2.0, label='a')\n", "b = Value(3.0, label= 'b')\n", "d = a * b; d.label = 'd'\n", "e = a + b; e.label = 'e'\n", "f = d * e; f.label = 'f'\n", "\n", "f.backward()\n", "\n", "draw_dot(f)" ] }, { "cell_type": "markdown", "id": 
"20f0da93-7dce-49f5-89a9-2f2ee73f7a22", "metadata": {}, "source": [ "The multivariable chain rule (see its Wikipedia article) gives us a simple fix: we accumulate the gradients. Instead of overwriting a gradient with ```=```, we add to it with ```+=```. Each branch deposits its own contribution, even if it's the same variable, so the contributions add on top of one another.\n", "\n", "We redefine the ```Value``` class so that each ```_backward()``` (in ```__add__```, ```__mul__```, and ```tanh()```) accumulates instead of overwriting:" ] }, { "cell_type": "code", "execution_count": 26, "id": "3761db50-8791-4db6-a172-fbde3a339416", "metadata": {}, "outputs": [], "source": [ "class Value:\n", "\n", "    def __init__(self, data, _children=(), _op='', label=''):\n", "        self.data = data\n", "        self.grad = 0.0\n", "        self._prev = set(_children)\n", "        self._backward = lambda: None\n", "        self._op = _op\n", "        self.label = label\n", "\n", "    def __repr__(self):\n", "        return f\"Value(data={self.data})\"\n", "\n", "    def __add__(self, other):\n", "        out = Value(self.data + other.data, (self, other), '+')\n", "\n", "        def _backward():\n", "            self.grad += 1.0 * out.grad  # accumulate instead of overwriting\n", "            other.grad += 1.0 * out.grad  # accumulate instead of overwriting\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __mul__(self, other):\n", "        out = Value(self.data * other.data, (self, other), '*')\n", "\n", "        def _backward():\n", "            self.grad += other.data * out.grad  # accumulate here too\n", "            other.grad += self.data * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def tanh(self):\n", "        x = self.data\n", "        t = (math.exp(2*x)-1)/(math.exp(2*x)+1)\n", "\n", "        out = Value(t, (self, ), 'tanh')\n", "\n", "        def _backward():\n", "            self.grad += (1-t**2) * out.grad  # accumulate here too\n", "        out._backward = _backward\n", "        return out\n", "\n", "    # define backward() without the underscore\n", 
"    def backward(self):  # build the topological order starting at self\n", "        topo = []  # nodes in topological order\n", "        visited = set()  # nodes we have already visited\n", "        def build_topo(v):  # depth-first search starting at v\n", "            if v not in visited:\n", "                visited.add(v)\n", "                for child in v._prev:\n", "                    build_topo(child)\n", "                topo.append(v)\n", "        build_topo(self)\n", "\n", "        self.grad = 1.0  # initialize the root gradient\n", "        for node in reversed(topo):\n", "            node._backward()  # propagate this node's gradient to its children" ] }, { "cell_type": "code", "execution_count": 28, "id": "f99f2a72-0bfd-4f08-b385-44fb662d817a", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 11.0.0 (20240428.1522)\n", " -->\n", "<!-- Pages: 1 -->\n", "<svg width=\"501pt\" height=\"45pt\"\n", " viewBox=\"0.00 0.00 500.75 45.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 41)\">\n", "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-41 496.75,-41 496.75,4 -4,4\"/>\n", "<!-- 4464200720 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>4464200720</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"309,-0.5 309,-36.5 492.75,-36.5 492.75,-0.5 309,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"320.38\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"331.75,-1 331.75,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"371.62\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"411.5,-1 411.5,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"452.12\" 
y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4464200720+ -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>4464200720+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"246\" cy=\"-18.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"246\" y=\"-13.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4464200720+->4464200720 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>4464200720+->4464200720</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M273.28,-18.5C280.42,-18.5 288.61,-18.5 297.32,-18.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"297.06,-22 307.06,-18.5 297.06,-15 297.06,-22\"/>\n", "</g>\n", "<!-- 4464201040 -->\n", "<g id=\"node3\" class=\"node\">\n", "<title>4464201040</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 183,-36.5 183,-0.5 0,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"11\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">a</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"22,-1 22,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"61.88\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 3.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"101.75,-1 101.75,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"142.38\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 2.0000</text>\n", "</g>\n", "<!-- 4464201040->4464200720+ -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>4464201040->4464200720+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M183.41,-18.5C191.77,-18.5 199.88,-18.5 207.32,-18.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"207.26,-22 217.26,-18.5 207.26,-15 207.26,-22\"/>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x10a1660d0>" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#try again\n", "a = 
Value(3.0, label='a')\n", "b = a + a; b.label = 'b'\n", "b.backward()\n", "draw_dot(b) #a.grad should be 2" ] }, { "cell_type": "code", "execution_count": 31, "id": "f1bb25ad-c6f3-493b-977c-e5081bd7822c", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 11.0.0 (20240428.1522)\n", " -->\n", "<!-- Pages: 1 -->\n", "<svg width=\"826pt\" height=\"100pt\"\n", " viewBox=\"0.00 0.00 826.25 100.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 96)\">\n", "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-96 822.25,-96 822.25,4 -4,4\"/>\n", "<!-- 4464229136 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>4464229136</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"0,-55.5 0,-91.5 192,-91.5 192,-55.5 0,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"11\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">a</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"22,-56 22,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"64.12\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -2.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"106.25,-56 106.25,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"149.12\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -3.0000</text>\n", "</g>\n", "<!-- 4464228752+ -->\n", "<g id=\"node6\" class=\"node\">\n", "<title>4464228752+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"255\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"255\" y=\"-68.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4464229136->4464228752+ -->\n", "<g id=\"edge4\" 
class=\"edge\">\n", "<title>4464229136->4464228752+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M192.4,-73.5C200.73,-73.5 208.79,-73.5 216.18,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"216.05,-77 226.05,-73.5 216.05,-70 216.05,-77\"/>\n", "</g>\n", "<!-- 4464227280* -->\n", "<g id=\"node8\" class=\"node\">\n", "<title>4464227280*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"255\" cy=\"-18.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"255\" y=\"-13.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4464229136->4464227280* -->\n", "<g id=\"edge8\" class=\"edge\">\n", "<title>4464229136->4464227280*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M166.39,-55.06C175.08,-52.38 183.79,-49.5 192,-46.5 201.99,-42.85 212.63,-38.3 222.22,-33.93\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"223.55,-37.17 231.15,-29.78 220.6,-30.83 223.55,-37.17\"/>\n", "</g>\n", "<!-- 4464217040 -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>4464217040</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"632.25,-27.5 632.25,-63.5 818.25,-63.5 818.25,-27.5 632.25,-27.5\"/>\n", "<text text-anchor=\"middle\" x=\"642.5\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">f</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"652.75,-28 652.75,-63.5\"/>\n", "<text text-anchor=\"middle\" x=\"694.88\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"737,-28 737,-63.5\"/>\n", "<text text-anchor=\"middle\" x=\"777.62\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4464217040* -->\n", "<g id=\"node3\" class=\"node\">\n", "<title>4464217040*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"569.25\" cy=\"-45.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"569.25\" y=\"-40.45\" 
font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4464217040*->4464217040 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>4464217040*->4464217040</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M596.72,-45.5C603.86,-45.5 612.03,-45.5 620.71,-45.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"620.44,-49 630.44,-45.5 620.44,-42 620.44,-49\"/>\n", "</g>\n", "<!-- 4464228240 -->\n", "<g id=\"node4\" class=\"node\">\n", "<title>4464228240</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1.88,-0.5 1.88,-36.5 190.12,-36.5 190.12,-0.5 1.88,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"13.25\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"24.62,-1 24.62,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"64.5\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 3.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"104.38,-1 104.38,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"147.25\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -8.0000</text>\n", "</g>\n", "<!-- 4464228240->4464228752+ -->\n", "<g id=\"edge7\" class=\"edge\">\n", "<title>4464228240->4464228752+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M163.53,-36.96C173.14,-39.96 182.87,-43.18 192,-46.5 201.86,-50.08 212.39,-54.46 221.92,-58.63\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"220.23,-61.71 230.79,-62.59 223.09,-55.32 220.23,-61.71\"/>\n", "</g>\n", "<!-- 4464228240->4464227280* -->\n", "<g id=\"edge5\" class=\"edge\">\n", "<title>4464228240->4464227280*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M190.59,-18.5C199.58,-18.5 208.29,-18.5 216.23,-18.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"216.11,-22 226.11,-18.5 216.11,-15 216.11,-22\"/>\n", "</g>\n", "<!-- 4464228752 -->\n", "<g id=\"node5\" class=\"node\">\n", "<title>4464228752</title>\n", "<polygon fill=\"none\" 
stroke=\"black\" points=\"318.38,-55.5 318.38,-91.5 505.88,-91.5 505.88,-55.5 318.38,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"329.38\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">e</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"340.38,-56 340.38,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"380.25\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"420.12,-56 420.12,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"463\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -6.0000</text>\n", "</g>\n", "<!-- 4464228752->4464217040* -->\n", "<g id=\"edge6\" class=\"edge\">\n", "<title>4464228752->4464217040*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M506.05,-56.73C514.96,-55.12 523.59,-53.56 531.44,-52.14\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"531.99,-55.6 541.21,-50.38 530.75,-48.71 531.99,-55.6\"/>\n", "</g>\n", "<!-- 4464228752+->4464228752 -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>4464228752+->4464228752</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M282.31,-73.5C289.55,-73.5 297.86,-73.5 306.7,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"306.62,-77 316.62,-73.5 306.62,-70 306.62,-77\"/>\n", "</g>\n", "<!-- 4464227280 -->\n", "<g id=\"node7\" class=\"node\">\n", "<title>4464227280</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"318,-0.5 318,-36.5 506.25,-36.5 506.25,-0.5 318,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"329.38\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">d</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"340.75,-1 340.75,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"382.88\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"425,-1 425,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"465.62\" y=\"-13.7\" 
font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4464227280->4464217040* -->\n", "<g id=\"edge9\" class=\"edge\">\n", "<title>4464227280->4464217040*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M506.49,-34.75C515.27,-36.28 523.77,-37.76 531.51,-39.1\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"530.66,-42.51 541.11,-40.78 531.86,-35.61 530.66,-42.51\"/>\n", "</g>\n", "<!-- 4464227280*->4464227280 -->\n", "<g id=\"edge3\" class=\"edge\">\n", "<title>4464227280*->4464227280</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M282.31,-18.5C289.49,-18.5 297.72,-18.5 306.47,-18.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"306.29,-22 316.29,-18.5 306.29,-15 306.29,-22\"/>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x10a168a90>" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#try again\n", "a = Value(-2.0, label='a')\n", "b = Value(3.0, label= 'b')\n", "d = a * b; d.label = 'd'\n", "e = a + b; e.label = 'e'\n", "f = d * e; f.label = 'f'\n", "\n", "f.backward()\n", "\n", "draw_dot(f) # a.grad should be 3 + (-6) = -3; b.grad should be -2 + (-6) = -8\n", "# this way each branch deposits its gradient on top of the others" ] }, { "cell_type": "markdown", "id": "079b1585-acf7-4896-95c3-4a008f1373dc", "metadata": {}, "source": [ "## Now, we'll go through a few more exercises. \n", "\n", "Remember how we said earlier that we could have implemented ```tanh()``` as a series of explicit expressions with the ```exp()``` function? We built it as a single function because we knew its derivative and could backpropagate through it directly, but here we will build it from exponentiation, subtraction, and division. 
Remember the original ```tanh()``` function:" ] }, { "cell_type": "markdown", "id": "77aed594-fb31-42b3-a7b2-b7f7e1930356", "metadata": {}, "source": [ "$$\\tanh(x)={\\frac {\\sinh x}{\\cosh x}}={\\frac {e^{x}-e^{-x}}{e^{x}+e^{-x}}}={\\frac {e^{2x}-1}{e^{2x}+1}}$$\n", "\n", "\n", "\n", "Let's test to see where we can improve functionality. What happens when we add a number to a ```Value``` object?" ] }, { "cell_type": "code", "execution_count": null, "id": "9a3c581d-4726-44ab-afc4-08d8797dec60", "metadata": {}, "outputs": [], "source": [ "a = Value(2.0)\n", "# a + 1 and a * 1 raise AttributeError ('int' object has no attribute 'data'), since __add__ and __mul__ expect other to be a Value" ] }, { "cell_type": "markdown", "id": "093097f6-649a-4167-90bf-20d25aa77630", "metadata": {}, "source": [ "We solve this by redefining the ```Value``` class with a check: if ```other``` is not already an instance of ```Value```, we assume it is a plain number (an ```int``` or a ```float```) and wrap it in a ```Value``` object. We do this in both ```__add__``` and ```__mul__```:\n", "\n", "``` python\n", "other = other if isinstance(other, Value) else Value(other)\n", "```\n", "\n", "After this update we still have a problem. ```a + 2``` and ```a * 2``` work, but what about ```2 + a``` or ```2 * a```?\n", "\n", "They do not. Python knows how to do ```a.__mul__(2)```, but ```(2).__mul__(a)``` fails because ```int``` knows nothing about ```Value```. In other words, it can do ```self * other```, but not ```other * self```.\n", "\n", "Fortunately, we can define ```__rmul__```, which Python falls back to when the left operand can't handle the operation. You can do the same with ```__radd__```. They swap the order of the operands so Python can compute the operation. 
\n", "\n", "``` python\n", "def __radd__(self, other):\n", "    return self + other\n", "\n", "def __rmul__(self, other):\n", "    return self * other\n", "```" ] }, { "cell_type": "markdown", "id": "a0a0ef78-7eea-4ca1-b869-40909a4b662a", "metadata": {}, "source": [ "Now, we update the ```Value``` class and test:" ] }, { "cell_type": "code", "execution_count": 56, "id": "f81e32f4-9bce-4c1f-9277-22c69b328331", "metadata": {}, "outputs": [], "source": [ "class Value:\n", "\n", "    def __init__(self, data, _children=(), _op='', label=''):\n", "        self.data = data\n", "        self.grad = 0.0\n", "        self._prev = set(_children)\n", "        self._backward = lambda: None\n", "        self._op = _op\n", "        self.label = label\n", "\n", "    def __repr__(self):\n", "        return f\"Value(data={self.data})\"\n", "\n", "    def __add__(self, other):\n", "        other = other if isinstance(other, Value) else Value(other)  # new: allows a + 1\n", "        out = Value(self.data + other.data, (self, other), '+')\n", "\n", "        def _backward():\n", "            self.grad += 1.0 * out.grad\n", "            other.grad += 1.0 * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __mul__(self, other):\n", "        other = other if isinstance(other, Value) else Value(other)  # new: allows a * 1\n", "        out = Value(self.data * other.data, (self, other), '*')\n", "\n", "        def _backward():\n", "            self.grad += other.data * out.grad\n", "            other.grad += self.data * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def tanh(self):\n", "        x = self.data\n", "        t = (math.exp(2*x)-1)/(math.exp(2*x)+1)\n", "\n", "        out = Value(t, (self, ), 'tanh')\n", "\n", "        def _backward():\n", "            self.grad += (1-t**2) * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __radd__(self, other):  # new: allows 1 + a\n", "        return self + other\n", "\n", "    def __rmul__(self, other):  # new: allows 1 * a\n", "        return self * other\n", "\n", "    def backward(self):\n", "        topo = []\n", "        visited = set()\n", "        def build_topo(v):\n", "            if v not in visited:  # visit each node only once\n", "                visited.add(v)\n", 
"                for child in v._prev:\n", "                    build_topo(child)\n", "                topo.append(v)\n", "        build_topo(self)\n", "\n", "        self.grad = 1.0\n", "        for node in reversed(topo):\n", "            node._backward()" ] }, { "cell_type": "code", "execution_count": 57, "id": "599f2158-d415-4039-a16e-bb93ad93c79c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=4.0)" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# these should all work now:\n", "a = Value(2.0)\n", "a + 1\n", "a * 3\n", "3 + a\n", "2 * a" ] }, { "cell_type": "markdown", "id": "fe7f99aa-d791-4af1-aec8-0967049fb4b7", "metadata": {}, "source": [ "Now that we can add and multiply in either order, let's add an ```exp()``` method to the ```Value``` class. It mirrors ```tanh()``` in that it takes a single scalar value, transforms it, and outputs a single scalar value.\n", "\n", "In ```exp()``` we compute ${e^{x}}$, where $x$ is ```self.data```.\n", "\n", "Once ${e^{x}}$ is computed, the question is how to backpropagate through it.\n", "\n", "We need to know the local derivative of ${e^{x}}$.\n", "\n", "$\\frac{d}{dx}e^{x} = e^{x}$. Thus, by the chain rule, we multiply ${e^{x}}$ (already stored in ```out.data```) by ```out.grad``` to backpropagate.\n", "\n", "The method returns a new ```Value``` holding ${e^{x}}$.\n", "\n", "Here $e$ is the base of the natural logarithm (approximately 2.718282) and $x$ is the number passed in, that is, ```self.data```. 
" ] }, { "cell_type": "code", "execution_count": 59, "id": "b3539bad-44e6-48af-b661-47a25cf14cef", "metadata": {}, "outputs": [], "source": [ "class Value:\n", "\n", "    def __init__(self, data, _children=(), _op='', label=''):\n", "        self.data = data\n", "        self.grad = 0.0\n", "        self._prev = set(_children)\n", "        self._backward = lambda: None\n", "        self._op = _op\n", "        self.label = label\n", "\n", "    def __repr__(self):\n", "        return f\"Value(data={self.data})\"\n", "\n", "    def __add__(self, other):\n", "        other = other if isinstance(other, Value) else Value(other)\n", "        out = Value(self.data + other.data, (self, other), '+')\n", "\n", "        def _backward():\n", "            self.grad += 1.0 * out.grad\n", "            other.grad += 1.0 * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __mul__(self, other):\n", "        other = other if isinstance(other, Value) else Value(other)\n", "        out = Value(self.data * other.data, (self, other), '*')\n", "\n", "        def _backward():\n", "            self.grad += other.data * out.grad\n", "            other.grad += self.data * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def tanh(self):\n", "        x = self.data\n", "        t = (math.exp(2*x)-1)/(math.exp(2*x)+1)\n", "\n", "        out = Value(t, (self, ), 'tanh')\n", "\n", "        def _backward():\n", "            self.grad += (1-t**2) * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __radd__(self, other):\n", "        return self + other\n", "\n", "    def __rmul__(self, other):\n", "        return self * other\n", "\n", "    def exp(self):  # computes e^self.data\n", "        x = self.data  # hold the value\n", "        out = Value(math.exp(x), (self, ), 'exp')  # out.data is e^x\n", "        def _backward():  # the derivative of e^x is e^x, which is already stored in out.data\n", "            self.grad += out.data * out.grad  # local derivative times out.grad (chain rule)\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def backward(self):\n", "        topo = []\n", "        visited = set()\n", "        def build_topo(v):\n", "            if v not in visited:  # visit each node only once\n", 
"                visited.add(v)\n", "                for child in v._prev:\n", "                    build_topo(child)\n", "                topo.append(v)\n", "        build_topo(self)\n", "\n", "        self.grad = 1.0\n", "        for node in reversed(topo):\n", "            node._backward()" ] }, { "cell_type": "code", "execution_count": 60, "id": "c1c7b767-6d2f-41a7-a9fb-d81ca23a662c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=54.598150033144236)" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#let's test it:\n", "a = Value(4.0)\n", "a.exp()" ] }, { "cell_type": "markdown", "id": "acfa72d4-c1d0-41de-a7fd-51b2c2d84dde", "metadata": {}, "source": [ "Now, we want to divide. To do this, we can treat division as a special case of a power function. \n", "This is because:\n", "\n", "```a/b = a*(1/b) = a*b**-1```\n", "\n", "We make a power function:\n", "\n", "```python\n", "def __pow__(self, other):  # self to the power of other\n", "    # other must be a plain int or float here, not a Value; self.data is the raw number\n", "    assert isinstance(other, (int, float)), \"only supporting int/float powers for now\"\n", "    out = Value(self.data**other, (self,), f'**{other}')\n", "\n", "    # chain rule for backprop through x**n, where n is a constant\n", "    def _backward():\n", "        # local derivative * out.grad\n", "        self.grad += other * self.data ** (other - 1) * out.grad\n", "    out._backward = _backward\n", "    return out\n", "```\n", "\n", "Why is this the local derivative? 
If you look up derivative rules, you will come across the power rule:\n", "$$\\frac{d}{dx}\\left(x^{n}\\right)=n x^{n-1}$$\n", "\n", "Here $x$ is ```self.data``` and $n$ is ```other```.\n", "\n", "With the power function in place, division is a one-line method:\n", "```python\n", "def __truediv__(self, other):  # self / other, computed as self * other**-1\n", "    return self*other**-1\n", "```" ] }, { "cell_type": "markdown", "id": "3d49cecf-4844-408c-a282-554db2b6f705", "metadata": {}, "source": [ "We add it to the Value class and test:" ] }, { "cell_type": "code", "execution_count": 72, "id": "c241a358-01ee-4ef4-b3e0-5f9944464910", "metadata": {}, "outputs": [], "source": [ "class Value:\n", "\n", "    def __init__(self, data, _children=(), _op='', label=''):\n", "        self.data = data\n", "        self.grad = 0.0\n", "        self._prev = set(_children)\n", "        self._backward = lambda: None\n", "        self._op = _op\n", "        self.label = label\n", "\n", "    def __repr__(self):\n", "        return f\"Value(data={self.data})\"\n", "\n", "    def __add__(self, other):\n", "        other = other if isinstance(other, Value) else Value(other)\n", "        out = Value(self.data + other.data, (self, other), '+')\n", "\n", "        def _backward():\n", "            self.grad += 1.0 * out.grad\n", "            other.grad += 1.0 * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __mul__(self, other):\n", "        other = other if isinstance(other, Value) else Value(other)\n", "        out = Value(self.data * other.data, (self, other), '*')\n", "\n", "        def _backward():\n", "            self.grad += other.data * out.grad\n", "            other.grad += self.data * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def tanh(self):\n", "        x = self.data\n", "        t = (math.exp(2*x)-1)/(math.exp(2*x)+1)\n", "\n", "        out = Value(t, (self, ), 'tanh')\n", "\n", "        def _backward():\n", "            self.grad += (1-t**2) * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __radd__(self, other):\n", "        return self + other\n", "\n", "    def __rmul__(self, other):\n", "        return self * other\n", "\n", 
"    def exp(self):  # computes e^self.data\n", "        x = self.data  # hold the value\n", "        out = Value(math.exp(x), (self, ), 'exp')  # out.data is e^x\n", "        def _backward():  # the derivative of e^x is e^x, which is already stored in out.data\n", "            self.grad += out.data * out.grad  # local derivative times out.grad (chain rule)\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __pow__(self, other):  # self to the power of other\n", "        assert isinstance(other, (int, float)), \"only supporting int/float powers for now\"\n", "        out = Value(self.data**other, (self,), f'**{other}')\n", "\n", "        def _backward():\n", "            self.grad += other * self.data ** (other - 1) * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __truediv__(self, other):  # self / other\n", "        return self*other**-1\n", "\n", "    def backward(self):\n", "        topo = []\n", "        visited = set()\n", "        def build_topo(v):\n", "            if v not in visited:  # visit each node only once\n", "                visited.add(v)\n", "                for child in v._prev:\n", "                    build_topo(child)\n", "                topo.append(v)\n", "        build_topo(self)\n", "\n", "        self.grad = 1.0\n", "        for node in reversed(topo):\n", "            node._backward()" ] }, { "cell_type": "code", "execution_count": 84, "id": "8cbfc5b5-0eea-4738-ad26-d2ba9e34f0b2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=-0.5)" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "e = Value(-2.0)\n", "f = Value(4.0)\n", "e / f" ] }, { "cell_type": "markdown", "id": "5d281bfc-5b9e-4cbe-ac4a-5041b8820a3a", "metadata": {}, "source": [ "Now, we'll add subtraction, which follows the same pattern. 
Instead of implementing ```self - other``` directly, we compute ```self + (-other)``` (addition of the negation):" ] }, { "cell_type": "code", "execution_count": 136, "id": "cb03a033-b67a-4734-aaeb-12124269708e", "metadata": {}, "outputs": [], "source": [ "class Value:\n", "    def __init__(self, data, _children=(), _op='', label=''):\n", "        self.data = data\n", "        self.grad = 0.0\n", "        self._backward = lambda: None\n", "        self._prev = set(_children)\n", "        self._op = _op\n", "        self.label = label\n", "\n", "    def __repr__(self):\n", "        return f\"Value(data={self.data})\"\n", "\n", "    def __add__(self, other):\n", "        other = other if isinstance(other, Value) else Value(other)\n", "        out = Value(self.data + other.data, (self, other), '+')\n", "\n", "        def _backward():\n", "            self.grad += 1.0 * out.grad\n", "            other.grad += 1.0 * out.grad\n", "        out._backward = _backward\n", "\n", "        return out\n", "\n", "    def __mul__(self, other):\n", "        other = other if isinstance(other, Value) else Value(other)\n", "        out = Value(self.data * other.data, (self, other), '*')\n", "\n", "        def _backward():\n", "            self.grad += other.data * out.grad\n", "            other.grad += self.data * out.grad\n", "        out._backward = _backward\n", "\n", "        return out\n", "\n", "    def __neg__(self):  # new: negate self\n", "        return self * -1\n", "\n", "    def __sub__(self, other):  # new: self - other\n", "        return self + (-other)\n", "\n", "    def __pow__(self, other):\n", "        assert isinstance(other, (int, float)), \"only supporting int/float powers for now\"\n", "        out = Value(self.data**other, (self,), f'**{other}')\n", "\n", "        def _backward():\n", "            self.grad += other * self.data ** (other - 1) * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def __radd__(self, other):  # kept from the earlier version so that 1 + a still works\n", "        return self + other\n", "\n", "    def __rmul__(self, other):\n", "        return self * other\n", "\n", "    def __truediv__(self, other):\n", "        return self*other**-1\n", "\n", "    def tanh(self):\n", "        x = self.data\n", "        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)\n", "        out = Value(t, (self, ), 'tanh')\n", "\n", "        def _backward():\n", 
"            self.grad += (1 - t**2) * out.grad\n", "        out._backward = _backward\n", "\n", "        return out\n", "\n", "    def exp(self):\n", "        x = self.data\n", "        out = Value(math.exp(x), (self, ), 'exp')\n", "        def _backward():\n", "            self.grad += out.data * out.grad\n", "        out._backward = _backward\n", "        return out\n", "\n", "    def backward(self):\n", "        topo = []\n", "        visited = set()\n", "        def build_topo(v):\n", "            if v not in visited:\n", "                visited.add(v)\n", "                for child in v._prev:\n", "                    build_topo(child)\n", "                topo.append(v)\n", "        build_topo(self)\n", "\n", "        self.grad = 1.0\n", "        for node in reversed(topo):\n", "            node._backward()" ] }, { "cell_type": "code", "execution_count": 137, "id": "6cb01076-ba71-4792-8254-48dcc5431944", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=7.4)" ] }, "execution_count": 137, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#test it\n", "g = Value(-3.4)\n", "z = Value(4)\n", "z - g" ] }, { "cell_type": "markdown", "id": "48809e1f-98a1-4ad4-9c7a-44df1b2979aa", "metadata": {}, "source": [ "Now, let's revisit the variables from before, so we can break ```tanh()``` up into the operations we've built above and create this:\n", "$$\\tanh(x)={\\frac {e^{2x}-1}{e^{2x}+1}}$$" ] }, { "cell_type": "code", "execution_count": 138, "id": "629ce317-7b6f-480d-87a4-f1ff987f84c8", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 11.0.0 (20240428.1522)\n", " -->\n", "<!-- Pages: 1 -->\n", "<svg width=\"1567pt\" height=\"210pt\"\n", " viewBox=\"0.00 0.00 1566.50 210.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 206)\">\n", "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-206 1562.5,-206 
1562.5,4 -4,4\"/>\n", "<!-- 4635406416 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>4635406416</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"711.75,-137.5 711.75,-173.5 895.5,-173.5 895.5,-137.5 711.75,-137.5\"/>\n", "<text text-anchor=\"middle\" x=\"723.12\" y=\"-150.7\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"734.5,-138 734.5,-173.5\"/>\n", "<text text-anchor=\"middle\" x=\"774.38\" y=\"-150.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.8814</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"814.25,-138 814.25,-173.5\"/>\n", "<text text-anchor=\"middle\" x=\"854.88\" y=\"-150.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635406736+ -->\n", "<g id=\"node12\" class=\"node\">\n", "<title>4635406736+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1002\" cy=\"-127.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1002\" y=\"-122.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4635406416->4635406736+ -->\n", "<g id=\"edge9\" class=\"edge\">\n", "<title>4635406416->4635406736+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M895.91,-142.48C919.97,-139.05 944.58,-135.55 963.93,-132.79\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"964.34,-136.26 973.75,-131.39 963.35,-129.33 964.34,-136.26\"/>\n", "</g>\n", "<!-- 4634054736 -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>4634054736</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"3.75,-165.5 3.75,-201.5 194.25,-201.5 194.25,-165.5 3.75,-165.5\"/>\n", "<text text-anchor=\"middle\" x=\"18.5\" y=\"-178.7\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"33.25,-166 33.25,-201.5\"/>\n", "<text text-anchor=\"middle\" x=\"73.12\" y=\"-178.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 
0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"113,-166 113,-201.5\"/>\n", "<text text-anchor=\"middle\" x=\"153.62\" y=\"-178.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635402128* -->\n", "<g id=\"node8\" class=\"node\">\n", "<title>4635402128*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"261\" cy=\"-128.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"261\" y=\"-123.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4634054736->4635402128* -->\n", "<g id=\"edge10\" class=\"edge\">\n", "<title>4634054736->4635402128*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M172.12,-165.01C180.9,-162.35 189.7,-159.5 198,-156.5 208.09,-152.86 218.82,-148.27 228.47,-143.88\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"229.86,-147.09 237.45,-139.69 226.9,-140.75 229.86,-147.09\"/>\n", "</g>\n", "<!-- 4634058384 -->\n", "<g id=\"node3\" class=\"node\">\n", "<title>4634058384</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"0,-55.5 0,-91.5 198,-91.5 198,-55.5 0,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"16.25\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"32.5,-56 32.5,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"74.62\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -3.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"116.75,-56 116.75,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"157.38\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4635403664* -->\n", "<g id=\"node10\" class=\"node\">\n", "<title>4635403664*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"261\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"261\" y=\"-68.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 
4634058384->4635403664* -->\n", "<g id=\"edge12\" class=\"edge\">\n", "<title>4634058384->4635403664*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M198.14,-73.5C206.61,-73.5 214.8,-73.5 222.29,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"222.29,-77 232.29,-73.5 222.29,-70 222.29,-77\"/>\n", "</g>\n", "<!-- 4634056848 -->\n", "<g id=\"node4\" class=\"node\">\n", "<title>4634056848</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"2.25,-110.5 2.25,-146.5 195.75,-146.5 195.75,-110.5 2.25,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"18.5\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"34.75,-111 34.75,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"74.62\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"114.5,-111 114.5,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"155.12\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 4634056848->4635402128* -->\n", "<g id=\"edge11\" class=\"edge\">\n", "<title>4634056848->4635402128*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M195.84,-128.5C205.14,-128.5 214.15,-128.5 222.32,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"222.13,-132 232.13,-128.5 222.13,-125 222.13,-132\"/>\n", "</g>\n", "<!-- 4635397328 -->\n", "<g id=\"node5\" class=\"node\">\n", "<title>4635397328</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"668.25,-82.5 668.25,-118.5 939,-118.5 939,-82.5 668.25,-82.5\"/>\n", "<text text-anchor=\"middle\" x=\"720.88\" y=\"-95.7\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1 + x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"773.5,-83 773.5,-118.5\"/>\n", "<text text-anchor=\"middle\" x=\"815.62\" y=\"-95.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline 
fill=\"none\" stroke=\"black\" points=\"857.75,-83 857.75,-118.5\"/>\n", "<text text-anchor=\"middle\" x=\"898.38\" y=\"-95.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635397328->4635406736+ -->\n", "<g id=\"edge13\" class=\"edge\">\n", "<title>4635397328->4635406736+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M939.14,-118.99C947.95,-120.21 956.31,-121.36 963.87,-122.4\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"963.16,-125.83 973.55,-123.73 964.12,-118.9 963.16,-125.83\"/>\n", "</g>\n", "<!-- 4635397328+ -->\n", "<g id=\"node6\" class=\"node\">\n", "<title>4635397328+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"605.25\" cy=\"-100.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"605.25\" y=\"-95.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4635397328+->4635397328 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>4635397328+->4635397328</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M632.73,-100.5C639.73,-100.5 647.79,-100.5 656.52,-100.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"656.39,-104 666.39,-100.5 656.39,-97 656.39,-104\"/>\n", "</g>\n", "<!-- 4635402128 -->\n", "<g id=\"node7\" class=\"node\">\n", "<title>4635402128</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"326.25,-110.5 326.25,-146.5 540,-146.5 540,-110.5 326.25,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"352.62\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"379,-111 379,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"418.88\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"458.75,-111 458.75,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"499.38\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 
4635402128->4635397328+ -->\n", "<g id=\"edge8\" class=\"edge\">\n", "<title>4635402128->4635397328+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M540.42,-111.01C549.89,-109.45 559,-107.95 567.21,-106.6\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"567.75,-110.05 577.05,-104.98 566.62,-103.15 567.75,-110.05\"/>\n", "</g>\n", "<!-- 4635402128*->4635402128 -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>4635402128*->4635402128</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M288.21,-128.5C296,-128.5 305.08,-128.5 314.82,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"314.55,-132 324.55,-128.5 314.55,-125 314.55,-132\"/>\n", "</g>\n", "<!-- 4635403664 -->\n", "<g id=\"node9\" class=\"node\">\n", "<title>4635403664</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"324,-55.5 324,-91.5 542.25,-91.5 542.25,-55.5 324,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"350.38\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"376.75,-56 376.75,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"418.88\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"461,-56 461,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"501.62\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635403664->4635397328+ -->\n", "<g id=\"edge14\" class=\"edge\">\n", "<title>4635403664->4635397328+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M542.35,-90.68C551.12,-92.07 559.54,-93.4 567.19,-94.62\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"566.59,-98.07 577.02,-96.18 567.69,-91.15 566.59,-98.07\"/>\n", "</g>\n", "<!-- 4635403664*->4635403664 -->\n", "<g id=\"edge3\" class=\"edge\">\n", "<title>4635403664*->4635403664</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M288.21,-73.5C295.29,-73.5 303.43,-73.5 
312.17,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"312.01,-77 322.01,-73.5 312.01,-70 312.01,-77\"/>\n", "</g>\n", "<!-- 4635406736 -->\n", "<g id=\"node11\" class=\"node\">\n", "<title>4635406736</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1065,-109.5 1065,-145.5 1248.75,-145.5 1248.75,-109.5 1065,-109.5\"/>\n", "<text text-anchor=\"middle\" x=\"1076.38\" y=\"-122.7\" font-family=\"Times,serif\" font-size=\"14.00\">n</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1087.75,-110 1087.75,-145.5\"/>\n", "<text text-anchor=\"middle\" x=\"1127.62\" y=\"-122.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.8814</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1167.5,-110 1167.5,-145.5\"/>\n", "<text text-anchor=\"middle\" x=\"1208.12\" y=\"-122.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635143568tanh -->\n", "<g id=\"node14\" class=\"node\">\n", "<title>4635143568tanh</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1311.75\" cy=\"-127.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1311.75\" y=\"-122.45\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n", "</g>\n", "<!-- 4635406736->4635143568tanh -->\n", "<g id=\"edge7\" class=\"edge\">\n", "<title>4635406736->4635143568tanh</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1249.01,-127.5C1257.39,-127.5 1265.52,-127.5 1272.98,-127.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1272.95,-131 1282.95,-127.5 1272.95,-124 1272.95,-131\"/>\n", "</g>\n", "<!-- 4635406736+->4635406736 -->\n", "<g id=\"edge4\" class=\"edge\">\n", "<title>4635406736+->4635406736</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1029.28,-127.5C1036.42,-127.5 1044.61,-127.5 1053.32,-127.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1053.06,-131 1063.06,-127.5 1053.06,-124 1053.06,-131\"/>\n", "</g>\n", "<!-- 4635143568 -->\n", "<g id=\"node13\" 
class=\"node\">\n", "<title>4635143568</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1374.75,-109.5 1374.75,-145.5 1558.5,-145.5 1558.5,-109.5 1374.75,-109.5\"/>\n", "<text text-anchor=\"middle\" x=\"1386.12\" y=\"-122.7\" font-family=\"Times,serif\" font-size=\"14.00\">o</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1397.5,-110 1397.5,-145.5\"/>\n", "<text text-anchor=\"middle\" x=\"1437.38\" y=\"-122.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.7071</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1477.25,-110 1477.25,-145.5\"/>\n", "<text text-anchor=\"middle\" x=\"1517.88\" y=\"-122.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4635143568tanh->4635143568 -->\n", "<g id=\"edge5\" class=\"edge\">\n", "<title>4635143568tanh->4635143568</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1339.03,-127.5C1346.17,-127.5 1354.36,-127.5 1363.07,-127.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1362.81,-131 1372.81,-127.5 1362.81,-124 1362.81,-131\"/>\n", "</g>\n", "<!-- 4634059216 -->\n", "<g id=\"node15\" class=\"node\">\n", "<title>4634059216</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1.5,-0.5 1.5,-36.5 196.5,-36.5 196.5,-0.5 1.5,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"16.25\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"31,-1 31,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"70.88\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 2.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"110.75,-1 110.75,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"153.62\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -1.5000</text>\n", "</g>\n", "<!-- 4634059216->4635403664* -->\n", "<g id=\"edge6\" class=\"edge\">\n", "<title>4634059216->4635403664*</title>\n", "<path fill=\"none\" stroke=\"black\" 
d=\"M168.91,-36.94C178.74,-39.93 188.67,-43.15 198,-46.5 207.96,-50.07 218.58,-54.47 228.18,-58.68\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"226.54,-61.78 237.1,-62.67 229.4,-55.39 226.54,-61.78\"/>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x1144abf90>" ] }, "execution_count": 138, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#redefine variables and run, noting the original gradients after a backward pass:\n", "x1 = Value(2.0, label='x1')\n", "x2 = Value(0.0, label='x2')\n", "w1 = Value(-3.0, label='w1')\n", "w2 = Value(1.0, label='w2')\n", "b = Value(6.8813735870195432, label='b')\n", "x1w1 = x1*w1; x1w1.label = 'x1*w1'\n", "x2w2 = x2*w2; x2w2.label = 'x2*w2'\n", "x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'\n", "n = x1w1x2w2 + b; n.label = 'n'\n", "o = n.tanh(); o.label = 'o'\n", "o.backward()\n", "draw_dot(o)" ] }, { "cell_type": "code", "execution_count": 140, "id": "e9f261dd-e7d0-4378-bfa8-369ccc489d31", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 11.0.0 (20240428.1522)\n", " -->\n", "<!-- Pages: 1 -->\n", "<svg width=\"2929pt\" height=\"210pt\"\n", " viewBox=\"0.00 0.00 2929.25 210.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 206)\">\n", "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-206 2925.25,-206 2925.25,4 -4,4\"/>\n", "<!-- 4635287056 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>4635287056</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1065,-54.5 1065,-90.5 1248.75,-90.5 1248.75,-54.5 1065,-54.5\"/>\n", "<text text-anchor=\"middle\" x=\"1076.38\" y=\"-67.7\" 
font-family=\"Times,serif\" font-size=\"14.00\">n</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1087.75,-55 1087.75,-90.5\"/>\n", "<text text-anchor=\"middle\" x=\"1127.62\" y=\"-67.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.8814</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1167.5,-55 1167.5,-90.5\"/>\n", "<text text-anchor=\"middle\" x=\"1208.12\" y=\"-67.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635132688* -->\n", "<g id=\"node14\" class=\"node\">\n", "<title>4635132688*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1311.75\" cy=\"-99.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1311.75\" y=\"-94.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4635287056->4635132688* -->\n", "<g id=\"edge17\" class=\"edge\">\n", "<title>4635287056->4635132688*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1249.01,-88.6C1257.84,-90.16 1266.4,-91.67 1274.19,-93.04\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1273.42,-96.46 1283.88,-94.75 1274.64,-89.57 1273.42,-96.46\"/>\n", "</g>\n", "<!-- 4635287056+ -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>4635287056+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1002\" cy=\"-72.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1002\" y=\"-67.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4635287056+->4635287056 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>4635287056+->4635287056</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1029.28,-72.5C1036.42,-72.5 1044.61,-72.5 1053.32,-72.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1053.06,-76 1063.06,-72.5 1053.06,-69 1053.06,-76\"/>\n", "</g>\n", "<!-- 4635293200 -->\n", "<g id=\"node3\" class=\"node\">\n", "<title>4635293200</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"326.25,-110.5 326.25,-146.5 
540,-146.5 540,-110.5 326.25,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"352.62\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"379,-111 379,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"418.88\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"458.75,-111 458.75,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"499.38\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635295056+ -->\n", "<g id=\"node23\" class=\"node\">\n", "<title>4635295056+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"605.25\" cy=\"-100.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"605.25\" y=\"-95.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4635293200->4635295056+ -->\n", "<g id=\"edge27\" class=\"edge\">\n", "<title>4635293200->4635295056+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M540.42,-111.01C549.89,-109.45 559,-107.95 567.21,-106.6\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"567.75,-110.05 577.05,-104.98 566.62,-103.15 567.75,-110.05\"/>\n", "</g>\n", "<!-- 4635293200* -->\n", "<g id=\"node4\" class=\"node\">\n", "<title>4635293200*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"261\" cy=\"-128.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"261\" y=\"-123.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4635293200*->4635293200 -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>4635293200*->4635293200</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M288.21,-128.5C296,-128.5 305.08,-128.5 314.82,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"314.55,-132 324.55,-128.5 314.55,-125 314.55,-132\"/>\n", "</g>\n", "<!-- 4634145872 -->\n", "<g id=\"node5\" class=\"node\">\n", 
"<title>4634145872</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"2214,-53.5 2214,-89.5 2394.75,-89.5 2394.75,-53.5 2214,-53.5\"/>\n", "<text text-anchor=\"middle\" x=\"2223.88\" y=\"-66.7\" font-family=\"Times,serif\" font-size=\"14.00\"> </text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"2233.75,-54 2233.75,-89.5\"/>\n", "<text text-anchor=\"middle\" x=\"2273.62\" y=\"-66.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 4.8284</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"2313.5,-54 2313.5,-89.5\"/>\n", "<text text-anchor=\"middle\" x=\"2354.12\" y=\"-66.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.1464</text>\n", "</g>\n", "<!-- 4634277072* -->\n", "<g id=\"node11\" class=\"node\">\n", "<title>4634277072*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"2674.5\" cy=\"-98.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"2674.5\" y=\"-93.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4634145872->4634277072* -->\n", "<g id=\"edge24\" class=\"edge\">\n", "<title>4634145872->4634277072*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M2394.96,-78.07C2472.25,-83.74 2580.03,-91.64 2636.17,-95.76\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"2635.59,-99.23 2645.82,-96.47 2636.1,-92.25 2635.59,-99.23\"/>\n", "</g>\n", "<!-- 4634145872+ -->\n", "<g id=\"node6\" class=\"node\">\n", "<title>4634145872+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1929.75\" cy=\"-71.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1929.75\" y=\"-66.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4634145872+->4634145872 -->\n", "<g id=\"edge3\" class=\"edge\">\n", "<title>4634145872+->4634145872</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1957.25,-71.5C2007.51,-71.5 2119.07,-71.5 2202.03,-71.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"2202.02,-75 2212.02,-71.5 
2202.02,-68 2202.02,-75\"/>\n", "</g>\n", "<!-- 4635113104 -->\n", "<g id=\"node7\" class=\"node\">\n", "<title>4635113104</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1.5,-55.5 1.5,-91.5 196.5,-91.5 196.5,-55.5 1.5,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"16.25\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"31,-56 31,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"70.88\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 2.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"110.75,-56 110.75,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"153.62\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -1.5000</text>\n", "</g>\n", "<!-- 4635285264* -->\n", "<g id=\"node16\" class=\"node\">\n", "<title>4635285264*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"261\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"261\" y=\"-68.45\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 4635113104->4635285264* -->\n", "<g id=\"edge21\" class=\"edge\">\n", "<title>4635113104->4635285264*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M196.76,-73.5C205.77,-73.5 214.47,-73.5 222.4,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"222.25,-77 232.25,-73.5 222.25,-70 222.25,-77\"/>\n", "</g>\n", "<!-- 4635135120 -->\n", "<g id=\"node8\" class=\"node\">\n", "<title>4635135120</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1681.5,-26.5 1681.5,-62.5 1866.75,-62.5 1866.75,-26.5 1681.5,-26.5\"/>\n", "<text text-anchor=\"middle\" x=\"1691.38\" y=\"-39.7\" font-family=\"Times,serif\" font-size=\"14.00\"> </text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1701.25,-27 1701.25,-62.5\"/>\n", "<text text-anchor=\"middle\" x=\"1743.38\" y=\"-39.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -1.0000</text>\n", "<polyline 
fill=\"none\" stroke=\"black\" points=\"1785.5,-27 1785.5,-62.5\"/>\n", "<text text-anchor=\"middle\" x=\"1826.12\" y=\"-39.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.1464</text>\n", "</g>\n", "<!-- 4635135120->4634145872+ -->\n", "<g id=\"edge13\" class=\"edge\">\n", "<title>4635135120->4634145872+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1867.15,-60.67C1875.89,-62.21 1884.36,-63.7 1892.08,-65.05\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1891.2,-68.45 1901.66,-66.74 1892.41,-61.56 1891.2,-68.45\"/>\n", "</g>\n", "<!-- 4634144976 -->\n", "<g id=\"node9\" class=\"node\">\n", "<title>4634144976</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1681.5,-136.5 1681.5,-172.5 1866.75,-172.5 1866.75,-136.5 1681.5,-136.5\"/>\n", "<text text-anchor=\"middle\" x=\"1691.38\" y=\"-149.7\" font-family=\"Times,serif\" font-size=\"14.00\"> </text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1701.25,-137 1701.25,-172.5\"/>\n", "<text text-anchor=\"middle\" x=\"1741.12\" y=\"-149.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1781,-137 1781,-172.5\"/>\n", "<text text-anchor=\"middle\" x=\"1823.88\" y=\"-149.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -0.1036</text>\n", "</g>\n", "<!-- 4634142608+ -->\n", "<g id=\"node26\" class=\"node\">\n", "<title>4634142608+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1929.75\" cy=\"-126.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1929.75\" y=\"-121.45\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 4634144976->4634142608+ -->\n", "<g id=\"edge11\" class=\"edge\">\n", "<title>4634144976->4634142608+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1867.15,-137.73C1875.89,-136.13 1884.36,-134.59 1892.08,-133.18\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1892.45,-136.67 1901.66,-131.44 1891.19,-129.79 
1892.45,-136.67\"/>\n", "</g>\n", "<!-- 4634277072 -->\n", "<g id=\"node10\" class=\"node\">\n", "<title>4634277072</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"2737.5,-80.5 2737.5,-116.5 2921.25,-116.5 2921.25,-80.5 2737.5,-80.5\"/>\n", "<text text-anchor=\"middle\" x=\"2748.88\" y=\"-93.7\" font-family=\"Times,serif\" font-size=\"14.00\">o</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"2760.25,-81 2760.25,-116.5\"/>\n", "<text text-anchor=\"middle\" x=\"2800.12\" y=\"-93.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.7071</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"2840,-81 2840,-116.5\"/>\n", "<text text-anchor=\"middle\" x=\"2880.62\" y=\"-93.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4634277072*->4634277072 -->\n", "<g id=\"edge4\" class=\"edge\">\n", "<title>4634277072*->4634277072</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M2701.78,-98.5C2708.92,-98.5 2717.11,-98.5 2725.82,-98.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"2725.56,-102 2735.56,-98.5 2725.56,-95 2725.56,-102\"/>\n", "</g>\n", "<!-- 4635110096 -->\n", "<g id=\"node12\" class=\"node\">\n", "<title>4635110096</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"2.25,-165.5 2.25,-201.5 195.75,-201.5 195.75,-165.5 2.25,-165.5\"/>\n", "<text text-anchor=\"middle\" x=\"18.5\" y=\"-178.7\" font-family=\"Times,serif\" font-size=\"14.00\">w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"34.75,-166 34.75,-201.5\"/>\n", "<text text-anchor=\"middle\" x=\"74.62\" y=\"-178.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"114.5,-166 114.5,-201.5\"/>\n", "<text text-anchor=\"middle\" x=\"155.12\" y=\"-178.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 4635110096->4635293200* -->\n", "<g id=\"edge12\" class=\"edge\">\n", 
"<title>4635110096->4635293200*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M172.12,-165.01C180.9,-162.35 189.7,-159.5 198,-156.5 208.09,-152.86 218.82,-148.27 228.47,-143.88\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"229.86,-147.09 237.45,-139.69 226.9,-140.75 229.86,-147.09\"/>\n", "</g>\n", "<!-- 4635132688 -->\n", "<g id=\"node13\" class=\"node\">\n", "<title>4635132688</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1374.75,-81.5 1374.75,-117.5 1555.5,-117.5 1555.5,-81.5 1374.75,-81.5\"/>\n", "<text text-anchor=\"middle\" x=\"1384.62\" y=\"-94.7\" font-family=\"Times,serif\" font-size=\"14.00\"> </text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1394.5,-82 1394.5,-117.5\"/>\n", "<text text-anchor=\"middle\" x=\"1434.38\" y=\"-94.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.7627</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1474.25,-82 1474.25,-117.5\"/>\n", "<text text-anchor=\"middle\" x=\"1514.88\" y=\"-94.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.2500</text>\n", "</g>\n", "<!-- 4635142992exp -->\n", "<g id=\"node21\" class=\"node\">\n", "<title>4635142992exp</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1618.5\" cy=\"-99.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1618.5\" y=\"-94.45\" font-family=\"Times,serif\" font-size=\"14.00\">exp</text>\n", "</g>\n", "<!-- 4635132688->4635142992exp -->\n", "<g id=\"edge22\" class=\"edge\">\n", "<title>4635132688->4635142992exp</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1555.93,-99.5C1564.27,-99.5 1572.37,-99.5 1579.81,-99.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1579.75,-103 1589.75,-99.5 1579.75,-96 1579.75,-103\"/>\n", "</g>\n", "<!-- 4635132688*->4635132688 -->\n", "<g id=\"edge5\" class=\"edge\">\n", "<title>4635132688*->4635132688</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1339.12,-99.5C1346.26,-99.5 1354.45,-99.5 1363.14,-99.5\"/>\n", 
"<polygon fill=\"black\" stroke=\"black\" points=\"1362.86,-103 1372.86,-99.5 1362.86,-96 1362.86,-103\"/>\n", "</g>\n", "<!-- 4635285264 -->\n", "<g id=\"node15\" class=\"node\">\n", "<title>4635285264</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"324,-55.5 324,-91.5 542.25,-91.5 542.25,-55.5 324,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"350.38\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"376.75,-56 376.75,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"418.88\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"461,-56 461,-91.5\"/>\n", "<text text-anchor=\"middle\" x=\"501.62\" y=\"-68.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635285264->4635295056+ -->\n", "<g id=\"edge16\" class=\"edge\">\n", "<title>4635285264->4635295056+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M542.35,-90.68C551.12,-92.07 559.54,-93.4 567.19,-94.62\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"566.59,-98.07 577.02,-96.18 567.69,-91.15 566.59,-98.07\"/>\n", "</g>\n", "<!-- 4635285264*->4635285264 -->\n", "<g id=\"edge6\" class=\"edge\">\n", "<title>4635285264*->4635285264</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M288.21,-73.5C295.29,-73.5 303.43,-73.5 312.17,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"312.01,-77 322.01,-73.5 312.01,-70 312.01,-77\"/>\n", "</g>\n", "<!-- 4635114256 -->\n", "<g id=\"node17\" class=\"node\">\n", "<title>4635114256</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 198,-36.5 198,-0.5 0,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"16.25\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"32.5,-1 32.5,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"74.62\" 
y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -3.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"116.75,-1 116.75,-36.5\"/>\n", "<text text-anchor=\"middle\" x=\"157.38\" y=\"-13.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 4635114256->4635285264* -->\n", "<g id=\"edge15\" class=\"edge\">\n", "<title>4635114256->4635285264*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M172.12,-36.99C180.9,-39.65 189.7,-42.5 198,-45.5 208.09,-49.14 218.82,-53.73 228.47,-58.12\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"226.9,-61.25 237.45,-62.31 229.86,-54.91 226.9,-61.25\"/>\n", "</g>\n", "<!-- 4634133840 -->\n", "<g id=\"node18\" class=\"node\">\n", "<title>4634133840</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"2430.75,-106.5 2430.75,-142.5 2611.5,-142.5 2611.5,-106.5 2430.75,-106.5\"/>\n", "<text text-anchor=\"middle\" x=\"2440.62\" y=\"-119.7\" font-family=\"Times,serif\" font-size=\"14.00\"> </text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"2450.5,-107 2450.5,-142.5\"/>\n", "<text text-anchor=\"middle\" x=\"2490.38\" y=\"-119.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.1464</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"2530.25,-107 2530.25,-142.5\"/>\n", "<text text-anchor=\"middle\" x=\"2570.88\" y=\"-119.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 4.8284</text>\n", "</g>\n", "<!-- 4634133840->4634277072* -->\n", "<g id=\"edge19\" class=\"edge\">\n", "<title>4634133840->4634277072*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M2611.93,-109.08C2620.54,-107.6 2628.89,-106.16 2636.53,-104.85\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"2637.11,-108.3 2646.37,-103.16 2635.92,-101.4 2637.11,-108.3\"/>\n", "</g>\n", "<!-- 4634133840**-1 -->\n", "<g id=\"node19\" class=\"node\">\n", "<title>4634133840**-1</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"2304.38\" 
cy=\"-126.5\" rx=\"27.81\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"2304.38\" y=\"-121.45\" font-family=\"Times,serif\" font-size=\"14.00\">**-1</text>\n", "</g>\n", "<!-- 4634133840**-1->4634133840 -->\n", "<g id=\"edge7\" class=\"edge\">\n", "<title>4634133840**-1->4634133840</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M2332.57,-126.25C2354.68,-126.04 2387.33,-125.74 2419.22,-125.44\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"2418.96,-128.94 2428.93,-125.35 2418.9,-121.94 2418.96,-128.94\"/>\n", "</g>\n", "<!-- 4635142992 -->\n", "<g id=\"node20\" class=\"node\">\n", "<title>4635142992</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1683.75,-81.5 1683.75,-117.5 1864.5,-117.5 1864.5,-81.5 1683.75,-81.5\"/>\n", "<text text-anchor=\"middle\" x=\"1693.62\" y=\"-94.7\" font-family=\"Times,serif\" font-size=\"14.00\"> </text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1703.5,-82 1703.5,-117.5\"/>\n", "<text text-anchor=\"middle\" x=\"1743.38\" y=\"-94.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 5.8284</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1783.25,-82 1783.25,-117.5\"/>\n", "<text text-anchor=\"middle\" x=\"1823.88\" y=\"-94.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0429</text>\n", "</g>\n", "<!-- 4635142992->4634145872+ -->\n", "<g id=\"edge18\" class=\"edge\">\n", "<title>4635142992->4634145872+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1864.93,-83.13C1874.48,-81.39 1883.76,-79.7 1892.14,-78.17\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1892.64,-81.64 1901.85,-76.4 1891.39,-74.75 1892.64,-81.64\"/>\n", "</g>\n", "<!-- 4635142992->4634142608+ -->\n", "<g id=\"edge28\" class=\"edge\">\n", "<title>4635142992->4634142608+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1864.93,-115.28C1874.48,-116.96 1883.76,-118.59 1892.14,-120.07\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1891.39,-123.49 1901.85,-121.77 
1892.61,-116.59 1891.39,-123.49\"/>\n", "</g>\n", "<!-- 4635142992exp->4635142992 -->\n", "<g id=\"edge8\" class=\"edge\">\n", "<title>4635142992exp->4635142992</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1645.91,-99.5C1653.62,-99.5 1662.54,-99.5 1672.01,-99.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1671.82,-103 1681.82,-99.5 1671.82,-96 1671.82,-103\"/>\n", "</g>\n", "<!-- 4635295056 -->\n", "<g id=\"node22\" class=\"node\">\n", "<title>4635295056</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"668.25,-82.5 668.25,-118.5 939,-118.5 939,-82.5 668.25,-82.5\"/>\n", "<text text-anchor=\"middle\" x=\"720.88\" y=\"-95.7\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1 + x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"773.5,-83 773.5,-118.5\"/>\n", "<text text-anchor=\"middle\" x=\"815.62\" y=\"-95.7\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"857.75,-83 857.75,-118.5\"/>\n", "<text text-anchor=\"middle\" x=\"898.38\" y=\"-95.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635295056->4635287056+ -->\n", "<g id=\"edge20\" class=\"edge\">\n", "<title>4635295056->4635287056+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M934.24,-82.02C944.99,-80.49 955.18,-79.03 964.21,-77.75\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"964.46,-81.25 973.86,-76.37 963.47,-74.32 964.46,-81.25\"/>\n", "</g>\n", "<!-- 4635295056+->4635295056 -->\n", "<g id=\"edge9\" class=\"edge\">\n", "<title>4635295056+->4635295056</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M632.73,-100.5C639.73,-100.5 647.79,-100.5 656.52,-100.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"656.39,-104 666.39,-100.5 656.39,-97 656.39,-104\"/>\n", "</g>\n", "<!-- 4635407184 -->\n", "<g id=\"node24\" class=\"node\">\n", "<title>4635407184</title>\n", "<polygon fill=\"none\" 
stroke=\"black\" points=\"3.75,-110.5 3.75,-146.5 194.25,-146.5 194.25,-110.5 3.75,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"18.5\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"33.25,-111 33.25,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"73.12\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"113,-111 113,-146.5\"/>\n", "<text text-anchor=\"middle\" x=\"153.62\" y=\"-123.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635407184->4635293200* -->\n", "<g id=\"edge14\" class=\"edge\">\n", "<title>4635407184->4635293200*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M194.46,-128.5C204.21,-128.5 213.66,-128.5 222.21,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"222.12,-132 232.12,-128.5 222.12,-125 222.12,-132\"/>\n", "</g>\n", "<!-- 4634142608 -->\n", "<g id=\"node25\" class=\"node\">\n", "<title>4634142608</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1992.75,-108.5 1992.75,-144.5 2178,-144.5 2178,-108.5 1992.75,-108.5\"/>\n", "<text text-anchor=\"middle\" x=\"2002.62\" y=\"-121.7\" font-family=\"Times,serif\" font-size=\"14.00\"> </text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"2012.5,-109 2012.5,-144.5\"/>\n", "<text text-anchor=\"middle\" x=\"2052.38\" y=\"-121.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.8284</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"2092.25,-109 2092.25,-144.5\"/>\n", "<text text-anchor=\"middle\" x=\"2135.12\" y=\"-121.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad -0.1036</text>\n", "</g>\n", "<!-- 4634142608->4634133840**-1 -->\n", "<g id=\"edge23\" class=\"edge\">\n", "<title>4634142608->4634133840**-1</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M2178.4,-126.5C2208.64,-126.5 2240.63,-126.5 2264.67,-126.5\"/>\n", 
"<polygon fill=\"black\" stroke=\"black\" points=\"2264.62,-130 2274.62,-126.5 2264.62,-123 2264.62,-130\"/>\n", "</g>\n", "<!-- 4634142608+->4634142608 -->\n", "<g id=\"edge10\" class=\"edge\">\n", "<title>4634142608+->4634142608</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1957.16,-126.5C1964.2,-126.5 1972.25,-126.5 1980.8,-126.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1980.79,-130 1990.79,-126.5 1980.79,-123 1980.79,-130\"/>\n", "</g>\n", "<!-- 4635147664 -->\n", "<g id=\"node27\" class=\"node\">\n", "<title>4635147664</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1066.5,-109.5 1066.5,-145.5 1247.25,-145.5 1247.25,-109.5 1066.5,-109.5\"/>\n", "<text text-anchor=\"middle\" x=\"1076.38\" y=\"-122.7\" font-family=\"Times,serif\" font-size=\"14.00\"> </text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1086.25,-110 1086.25,-145.5\"/>\n", "<text text-anchor=\"middle\" x=\"1126.12\" y=\"-122.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 2.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1166,-110 1166,-145.5\"/>\n", "<text text-anchor=\"middle\" x=\"1206.62\" y=\"-122.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.2203</text>\n", "</g>\n", "<!-- 4635147664->4635132688* -->\n", "<g id=\"edge25\" class=\"edge\">\n", "<title>4635147664->4635132688*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1247.69,-111.05C1256.94,-109.35 1265.93,-107.71 1274.08,-106.22\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1274.67,-109.67 1283.87,-104.42 1273.41,-102.78 1274.67,-109.67\"/>\n", "</g>\n", "<!-- 4635295184 -->\n", "<g id=\"node28\" class=\"node\">\n", "<title>4635295184</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"711.75,-27.5 711.75,-63.5 895.5,-63.5 895.5,-27.5 711.75,-27.5\"/>\n", "<text text-anchor=\"middle\" x=\"723.12\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", "<polyline fill=\"none\" stroke=\"black\" 
points=\"734.5,-28 734.5,-63.5\"/>\n", "<text text-anchor=\"middle\" x=\"774.38\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.8814</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"814.25,-28 814.25,-63.5\"/>\n", "<text text-anchor=\"middle\" x=\"854.88\" y=\"-40.7\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 4635295184->4635287056+ -->\n", "<g id=\"edge26\" class=\"edge\">\n", "<title>4635295184->4635287056+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M895.91,-58.05C919.97,-61.36 944.58,-64.74 963.93,-67.4\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"963.36,-70.86 973.75,-68.75 964.32,-63.92 963.36,-70.86\"/>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x1143753d0>" ] }, "execution_count": 140, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1 = Value(2.0, label='x1')\n", "x2 = Value(0.0, label='x2')\n", "w1 = Value(-3.0, label='w1')\n", "w2 = Value(1.0, label='w2')\n", "b = Value(6.8813735870195432, label='b')\n", "x1w1 = x1*w1; x1w1.label = 'x1*w1'\n", "x2w2 = x2*w2; x2w2.label = 'x2*w2'\n", "x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'\n", "n = x1w1x2w2 + b; n.label = 'n'\n", "\n", "# hyperbolic tangent built from smaller ops: tanh(n) = (e^(2n) - 1) / (e^(2n) + 1)\n", "e = (2*n).exp() # e^(2n)\n", "o = (e - 1) / (e + 1)\n", "\n", "o.label = 'o'\n", "o.backward()\n", "\n", "# test backprop to see if the gradients are the same as before\n", "# expect to see a much longer graph because we broke up tanh()\n", "draw_dot(o)" ] }, { "cell_type": "markdown", "id": "35e6868d-7483-4b31-b21d-16fca502718f", "metadata": {}, "source": [ "Since the forward pass produces the same output data and the backward pass produces the same input gradients as before, we can conclude that the broken-down version is equivalent to our original ```tanh()```.\n", "\n", "Andrej had us do these exercises to:\n", "- practice writing more operations 
for the backward pass\n", "- illustrate that the level of abstraction is up to the coder. You can write atomic operations or more composite functions like ```tanh()``` and achieve the same result. The only thing you lose with a composite function is the granularity of seeing each step of the backward pass. As long as you can do the forward and backward pass, it doesn't matter what the operation is or how composite it is. If you can find the local gradient and chain it with the chain rule, you can do backpropagation through any type of operation." ] }, { "cell_type": "markdown", "id": "5686caf7-faf0-4e7f-9923-2c144030905d", "metadata": {}, "source": [ "## Great job! You should now have a firm grasp on how backpropagation through a neuron looks and is done. In the next lecture, we will abstract this process further with PyTorch, and continue building out our deep neural network, micrograd." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }