{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introducing Automatic Optimization: Let's build a Deep Learning Framework\n", "\n", "In this chapter, we will:\n", "- Learn what a deep learning framework is.\n", "- Introduce `Tensors`.\n", "- Introduce the `Autograd` system.\n", "- Learn how addition backpropagation works.\n", "- Explain how to learn a framework.\n", "- Implement nonlinearity layers.\n", "- Implement the embedding layer.\n", "- Implement the cross entropy layer.\n", "- Implement the recurrent layer.\n", "\n", "> [Arthur C. Clarke] \"Whether we are based on carbon or silicon makes no fundamental difference; we should each be treated with appropriate respect\"." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## What is a Deep Learning Framework?\n", "### Good Tools Reduce Errors, Speed Development, & Increase Runtime Performance\n", "\n", "Now we're going to transition to using a framework, because the networks we'll train next, Long Short-Term Memory (LSTM) networks, are very complex. `numpy` code implementing them is difficult to read, use, and debug, because gradients fly everywhere.\n", "\n", "Deep learning frameworks were created to mitigate exactly this code complexity, especially when we want to train a model on a GPU, which can make training 10-100x faster. A thorough understanding of a deep learning framework is essential on our journey toward becoming a user or researcher in deep learning. But we won't jump straight into an existing framework, because that would stifle our ability to learn what complex models (such as LSTMs) are doing under the hood. Instead, we'll build a lightweight deep learning framework that follows the latest trends in framework development. This way, we'll have no doubt about what deep learning frameworks do when we use them for complex architectures.\n", "\n", "Building our own framework will provide a smooth transition into using actual deep learning frameworks, because we'll already be familiar with the API and the functionality underneath it. The most beneficial feature of deep learning frameworks is their ability to perform **automatic backpropagation & optimization**. These features let us specify only the forward propagation logic, and the framework handles the rest." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction to Tensors\n", "### Tensors are an abstract form of Scalars, Vectors, & Matrices\n", "\n", "Recall that a matrix is a list of vectors, and that a vector is a list of scalars. Based on this, a tensor is the abstract version of scalars, vectors, matrices, and any n-dimensional **array**. So:\n", "- A scalar is a 0-dimensional tensor.\n", "- A vector is a 1-dimensional tensor.\n", "- A matrix is a 2-dimensional tensor.\n", "- Higher-dimensional arrays are referred to as n-dimensional tensors.\n", "\n", "The beginning of a new deep learning framework is the definition of this new type: the `Tensor`. Let's implement it:" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"class Tensor(object):\n",
"    def __init__(self, data):\n",
"        # Store the tensor's values in `self.data` as a NumPy array\n",
"        self.data = np.array(data)\n",
"\n",
"    def __add__(self, other):\n",
"        return Tensor(self.data + other.data)\n",
"\n",
"    def __sub__(self, other):\n",
"        return Tensor(self.data - other.data)\n",
"\n",
"    def __mul__(self, other):\n",
"        # Element-wise product, consistent with + and -\n",
"        return Tensor(self.data * other.data)\n",
"\n",
"    def __repr__(self):\n",
"        return str(self.data.__repr__())\n",
"\n",
"    def __str__(self):\n",
"        return str(self.data.__str__())" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "By the way, what's the difference between `__repr__()` and `__str__()` in Python?\n", "- `__repr__()` aims to be **unambiguous**. It is invoked when we simply inspect the object in the console, or call `repr(object)`.\n", "- `__str__()` aims to be **readable**. It is invoked by `print(object)` and `str(object)`."
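, "\n",
 "\n",
 "A minimal sketch of the difference, using a hypothetical `Point` class that is not part of our framework:\n",
 "\n",
 "```python\n",
 "class Point(object):\n",
 "    # Hypothetical example class, used only to contrast the two methods\n",
 "    def __init__(self, x, y):\n",
 "        self.x, self.y = x, y\n",
 "\n",
 "    def __repr__(self):\n",
 "        # Unambiguous: mirrors the code that would rebuild the object\n",
 "        return 'Point(x=%r, y=%r)' % (self.x, self.y)\n",
 "\n",
 "    def __str__(self):\n",
 "        # Readable: meant for end users\n",
 "        return '(%s, %s)' % (self.x, self.y)\n",
 "\n",
 "p = Point(1, 2)\n",
 "print(repr(p))  # Point(x=1, y=2)\n",
 "print(str(p))   # (1, 2)\n",
 "print(p)        # print() also uses __str__: (1, 2)\n",
 "```\n",
 "\n",
 "Our `Tensor` simply delegates both methods to the underlying NumPy array, which is why the cells below show `array([1, 2, 3, 4, 5])` when inspecting `x` and `[1 2 3 4 5]` when printing it."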
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "x = Tensor([1,2,3,4,5])" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 2, 3, 4, 5])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x  # invoking __repr__()" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1 2 3 4 5]\n" ] } ], "source": [ "print(x)  # invoking __str__()" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "y = x + x" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 2  4  6  8 10]\n" ] } ], "source": [ "print(y)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "This is the first version of this basic data structure. Note that it stores all the numerical information in a NumPy array (`self.data`) and supports element-wise operations. Adding more operations is relatively simple: we add more methods to the `Tensor` class with the appropriate functionality." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction to automatic gradient computation\n", "### Previously, we performed backpropagation by hand. Let's make it automatic!\n", "\n", "Previously, we computed derivatives by hand for each network we trained. Recall that this is done by moving backward through the neural network:\n", "1. Compute the gradient at the output of the network.\n", "2. Use the result to compute the gradients at the next-to-last component.\n", "3. Continue backward in this way until every weight in the architecture has the correct gradient.\n", "\n", "This logic for computing gradients can also be added to the tensor object:" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"class Tensor(object):\n",
"    def __init__(self, data, creators=None, creation_op=None):\n",
"        self.data = np.array(data)\n",
"        self.creation_op = creation_op  # the operation that created this tensor\n",
"        self.creators = creators        # the tensors that were combined to create it\n",
"        self.grad = None\n",
"\n",
"    def __add__(self, other):\n",
"        return Tensor(self.data + other.data,\n",
"                      creators=[self, other],\n",
"                      creation_op=\"+\")\n",
"\n",
"    def __sub__(self, other):\n",
"        return Tensor(self.data - other.data,\n",
"                      creators=[self, other],\n",
"                      creation_op=\"-\")\n",
"\n",
"    def __mul__(self, other):\n",
"        return Tensor(self.data * other.data,\n",
"                      creators=[self, other],\n",
"                      creation_op=\"*\")\n",
"\n",
"    def __repr__(self):\n",
"        return str(self.data.__repr__())\n",
"\n",
"    def __str__(self):\n",
"        return str(self.data.__str__())\n",
"\n",
"    def backward(self, grad):\n",
"        self.grad = grad\n",
"        if (self.creation_op == \"+\"):\n",
"            # d(a+b)/da = d(a+b)/db = 1: pass the gradient through unchanged\n",
"            self.creators[0].backward(grad)\n",
"            self.creators[1].backward(grad)\n",
"        elif (self.creation_op == \"-\"):\n",
"            # d(a-b)/da = 1 and d(a-b)/db = -1 (grad must be a Tensor here)\n",
"            self.creators[0].backward(grad)\n",
"            self.creators[1].backward(Tensor([-1]) * grad)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now every tensor can store a gradient and remembers how it was created. Let's experiment with the new functionality:" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "a = Tensor([1,2,3])" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None None None\n" ] } ], "source": [ "print(a.creators, a.creation_op, a.grad)" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "b = Tensor([4,5,6])" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "c = a + b" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[array([1, 2, 3]), array([4, 5, 6])]  |  +  |  None\n" ] } ], "source": [ "print(c.creators, \" | \", c.creation_op, \" | \", c.grad)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "When we call `.backward()` on a tensor produced by adding two tensors, we pass the same gradient, unchanged, to both of its parents. This is because `.backward()` receives the gradient with respect to the current tensor and converts it into the gradient with respect to each parent, and the local derivative of a sum with respect to each operand is 1." ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "x = Tensor([1,2,3,4,5])\n", "y = Tensor([2,2,2,2,2])" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "z = x + y" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we calculate $\\nabla x$ and $\\nabla y$, given $\\nabla z$:\n", "\n", "Since we are dealing with the `+` operator, we have $\\frac{\\partial z}{\\partial x}=1$ and $\\frac{\\partial z}{\\partial y}=1$. By the chain rule, $\\nabla x = \\nabla z \\cdot \\frac{\\partial z}{\\partial x} = \\nabla z$, and likewise $\\nabla y = \\nabla z$."
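, "\n",
 "\n",
 "As a sanity check, here is a NumPy-only sketch (independent of our `Tensor` class) that applies this chain rule and compares it with a finite-difference estimate. The scalar loss `L = sum(grad_z * z)` is an assumption made for the check: its gradient with respect to `z` is exactly `grad_z`.\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "x = np.array([1., 2., 3., 4., 5.])\n",
 "y = np.array([2., 2., 2., 2., 2.])\n",
 "grad_z = np.ones(5)  # upstream gradient, like the one passed to z.backward below\n",
 "\n",
 "# Chain rule for z = x + y: both local derivatives are 1\n",
 "grad_x = grad_z * 1.0\n",
 "grad_y = grad_z * 1.0\n",
 "\n",
 "def loss(x_, y_):\n",
 "    # scalar loss whose gradient with respect to z is grad_z\n",
 "    return np.sum(grad_z * (x_ + y_))\n",
 "\n",
 "# Finite-difference estimate of dL/dx, one element at a time\n",
 "eps = 1e-6\n",
 "numeric_x = np.array([(loss(x + eps * np.eye(5)[i], y) - loss(x, y)) / eps\n",
 "                      for i in range(5)])\n",
 "\n",
 "print(grad_x, grad_y)\n",
 "print(np.allclose(numeric_x, grad_x, atol=1e-4))\n",
 "```\n",
 "\n",
 "Because the local derivative is 1, the incoming gradient passes through unchanged, which is exactly what `backward()` does for `+`."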
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "z.backward([1,1,1,1,1])" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([1, 1, 1, 1, 1], [1, 1, 1, 1, 1])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.grad, y.grad" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Each tensor gets three new attributes:\n", "- **Creation operation** (`creation_op`): the operation that created the current tensor (`+`, `-`, `*`, ...).\n", "- **Creators** (`creators`): a list of the tensors that contributed to the creation of the current tensor.\n", "- **Grad** (`grad`): the gradient with respect to the current tensor, i.e. the generalization of its derivative.\n", "\n", "Performing `z = x + y` creates a computation graph with 3 nodes (`x`, `y`, & `z`) and 2 edges (`z -> x` & `z -> y`). Each edge is labeled by the `creation_op` `+`. This graph allows us to recursively backpropagate gradients.\n", "\n", "The first new concept of this implementation is the **automatic creation of a graph whenever we perform operations**. If we took `z` and performed further operations, the graph would continue to grow. The second new concept is **automatic, recursive gradient calculation**, which lets us compute the derivative of the final tensor with respect to any node connected to it in the graph.\n", "\n", "Perhaps the most elegant part of this form of `autograd` is that it works recursively, because each node calls `.backward()` on all of its `self.creators`." ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Tensor Definition\n", "a = Tensor([1,2,3,4,5])\n", "b = Tensor([2,2,2,2,2])\n", "c = Tensor([5,4,3,2,1])\n", "d = Tensor([-1,-2,-3,-4,-5])" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Compute Graph Creation & Forward Propagation\n", "e = a + b\n", "f = c + d\n", "g = e + f" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Back Propagation\n", "g.backward(Tensor([1,1,1,1,1]))" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 1, 1, 1, 1])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a.grad" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## A Quick Checkpoint\n", "### Everything in `Tensor` is another form of lessons already learned\n", "\n", "The previous implementation is nothing new compared with what we've already been working with. Previously, we hard-coded the forward and backpropagation steps; now it's time to automate and generalize these processes.\n", "\n", "The notion of a graph that gets built during forward propagation is called a *dynamic computation graph*, because it's built on the fly during the forward pass. This is the type of **autograd** present in newer deep learning frameworks such as **DyNet** and **PyTorch**. Older frameworks such as **Theano** and **TensorFlow** (1.x) use what's called a **static computation graph**, which is specified before forward propagation even begins.\n", "\n", "In general, dynamic computation graphs are easier to architect and experiment with, while static computation graphs can be faster at runtime, because the whole graph is known ahead of time and can be optimized. Dynamic and static frameworks have been converging toward the middle. In this book, we'll stick with dynamic graphs.\n", "\n", "Debugging these frameworks can be extremely difficult at times, because most bugs don't raise an error: the model seems like it's training, but it's not." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Tensors That Are Used Multiple Times\n", "### The basic autograd has a rather pesky bug. Let's squish it!\n", "\n", "The current version of `Tensor` supports backpropagating into a variable only once. But sometimes, during forward propagation, we use the same tensor multiple times, and multiple parts of the graph then backpropagate gradients into that same tensor. Because `backward()` simply assigns `self.grad = grad`, each new gradient overwrites the previous one instead of being added to it.\n", "\n", "Here is an example:" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "
" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "d = a + b\n", "e = b + c\n", "f = d + e" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "f.backward(Tensor([1,1,1,1,1]))" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[False False False False False]\n" ] } ], "source": [ "print(b.grad.data == [2,2,2,2,2])" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The variable `b` is used twice in the creation of the tensor `f`, so its gradient should be the sum of the two incoming gradients: `[2,2,2,2,2]`. Instead, the second gradient overwrote the first one, leaving `[1,1,1,1,1]`.\n", "\n", "We need to fix the current implementation of our `Tensor` so that it accumulates gradients instead of merely overwriting them." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's remember how gradients flow through a simple example:\n", "\n", "
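", "\n",
 "For the example above, here is a NumPy-only sketch of the expected result (independent of our `Tensor` class). `b` reaches `f` along two paths, `b -> d -> f` and `b -> e -> f`, and the chain rule says the contributions of the two paths must be **added**. The values of `a`, `b`, and `c` match the cells above; the finite-difference loss `L = sum(grad_f * f)` is an assumption made for the check:\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "a = np.array([1., 2., 3., 4., 5.])\n",
 "b = np.array([2., 2., 2., 2., 2.])\n",
 "c = np.array([5., 4., 3., 2., 1.])\n",
 "grad_f = np.ones(5)  # gradient fed into f.backward\n",
 "\n",
 "# Every edge in d = a + b, e = b + c, f = d + e is an addition,\n",
 "# so each path passes grad_f through unchanged\n",
 "grad_b_via_d = grad_f\n",
 "grad_b_via_e = grad_f\n",
 "\n",
 "grad_b_overwritten = grad_b_via_e                 # what the current backward() keeps\n",
 "grad_b_accumulated = grad_b_via_d + grad_b_via_e  # what the chain rule requires\n",
 "\n",
 "def loss(b_):\n",
 "    # f = (a + b) + (b + c), weighted by grad_f\n",
 "    return np.sum(grad_f * ((a + b_) + (b_ + c)))\n",
 "\n",
 "eps = 1e-6\n",
 "numeric_b = np.array([(loss(b + eps * np.eye(5)[i]) - loss(b)) / eps\n",
 "                      for i in range(5)])\n",
 "print(grad_b_overwritten, grad_b_accumulated, numeric_b)\n",
 "```\n",
 "\n",
 "The finite-difference estimate agrees with the accumulated gradient, `[2,2,2,2,2]`, not with the overwritten one.", "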
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Upgrading Autograd to Support Multiple Tensors\n", "### Add one new function, and update 3 old ones" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "%load_ext line_profiler" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"class Tensor(object):\n",
"    def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):\n",
"        self.data = np.array(data)\n",
"        self.autograd = autograd\n",
"        self.parents = parents\n",
"        self.children = {}\n",
"        self.creation_op = creation_op\n",
"        self.grad = None\n",
"        if (id is None): id = np.random.randint(100)\n",
"        self.id = id\n",
"\n",
"        # Updates Parents' Children\n",
"        # For your library, don't do this until you need it.\n",
"        # A simpler solution is to just keep a counter for the number of grads that must be passed while doing back propagation\n",
"        if (parents is not None):\n",
"            for parent in parents:\n",
"                # children is keyed by the child tensor itself, so check `self`, not `self.id`\n",
"                if (self not in parent.children):\n",
"                    # 1 is the number of grads to be passed from self to parent\n",
"                    parent.children[self] = 1\n",
"                else:\n",
"                    parent.children[self] += 1\n",
"\n",
"    def all_grads_propagated(self):\n",
"        for _, grads_count in self.children.items():\n",
"            if (grads_count != 0): return False\n",
"        return True\n",
"\n",
"    def __add__(self, other):\n",
"        if (self.autograd and other.autograd):\n",
"            return Tensor(self.data + other.data,\n",
"                          autograd=True,\n",
"                          parents=[self, other],\n",
"                          creation_op=\"+\")\n",
"        return Tensor(self.data + other.data)\n",
"\n",
"    def __repr__(self):\n",
"        return str('Tensor(' + self.id.__repr__() + ')')\n",
"\n",
"    def __str__(self):\n",
"        return str(self.data.__str__())\n",
"\n",
"    def backward(self, grad=None, grad_origin=None):\n",
"        if (self.autograd):\n",
"            if (grad_origin is not None):\n",
"                # checks to make sure you can backpropagate or whether you're waiting for a gradient, in which case, decrement the counter\n",
"                if (self.children[grad_origin] == 0):\n",
"                    raise Exception(\"cannot backprop more than once\")\n",
"                else:\n",
"                    self.children[grad_origin] -= 1\n",
"            if (self.grad is None):\n",
"                self.grad = grad\n",
"            else:\n",
"                self.grad += grad\n",
"            if ((self.parents is not None) and (self.all_grads_propagated() or grad_origin is None)):\n",
"                if (self.creation_op == \"+\"):\n",
"                    # begins actual back propagation\n",
"                    self.parents[0].backward(self.grad, grad_origin=self)\n",
"                    self.parents[1].backward(self.grad, grad_origin=self)" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "a = Tensor([1,2,3,4,5], autograd=True)\n", "b = Tensor([2,2,2,2,2], autograd=True)\n", "c = Tensor([5,4,3,2,1], autograd=True)" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "d = a + b\n", "e = b + c\n", "f = d + e" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "f.backward(Tensor([1,1,1,1,1]))" ] },
{ "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 1, 1, 1, 1])" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f.grad.data" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([1, 1, 1, 1, 1]), array([1, 1, 1, 1, 1]))" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "e.grad.data, d.grad.data" ] },
{ "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([1, 1, 1, 1, 1]), array([2, 2, 2, 2, 2]), array([1, 1, 1, 1, 1]))" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a.grad.data, b.grad.data, c.grad.data" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's also implement a `Tensor` without the book's more elaborate per-child \"connections/grads\" bookkeeping. Instead, we'll use a simple counter (`required_grads`) to make sure all gradients have been received before passing the accumulated gradient on to the parent nodes:" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "import numpy as np" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [
"class Tensor(object):\n",
"    def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):\n",
"        self.data = np.array(data)\n",
"        # make all tensors at least 2-D, ready for matrix multiplication\n",
"        if (len(self.data.shape) == 0): self.data = self.data.reshape((1, 1))\n",
"        elif (len(self.data.shape) == 1): self.data = self.data.reshape((1, self.data.shape[0]))\n",
"        self.autograd = autograd\n",
"        # use None as the default to avoid sharing one mutable list between tensors\n",
"        self.parents = parents if parents is not None else []\n",
"        self.children = list()\n",
"        self.creation_op = creation_op\n",
"        self.grad = None\n",
"        if (self.autograd): self.required_grads = 0\n",
"        if id is None: id = np.random.randint(100)\n",
"        self.id = id\n",
"\n",
"        # when this tensor is created, register it as a child of its parents & increment their required_grads\n",
"        for parent in self.parents:\n",
"            parent.children.append(self)\n",
"            if (self.autograd and parent.autograd):\n",
"                parent.required_grads += 1\n",
"\n",
"    def __add__(self, other):\n",
"        if (self.autograd and other.autograd):\n",
"            return Tensor(self.data + other.data,\n",
"                          autograd=True,\n",
"                          parents=[self, other],\n",
"                          creation_op='+')\n",
"        return Tensor(self.data + other.data)\n",
"\n",
"    def __repr__(self):\n",
"        return str('Tensor(' + self.id.__repr__() + ')')\n",
"\n",
"    def __str__(self):\n",
"        return str(self.data.__str__())\n",
"\n",
"    def backward(self, grad=None, grad_origin=None):\n",
"        if (self.autograd):\n",
"            if (self.grad is None):\n",
"                self.grad = grad\n",
"            else:\n",
"                self.grad += grad\n",
"            self.required_grads -= 1\n",
"            # propagate once every child has sent its gradient, or right away if backward() started here\n",
"            if ((self.parents != []) and ((self.required_grads == 0) or (grad_origin is None))):\n",
"                if (self.creation_op == '+'):\n",
"                    self.parents[0].backward(self.grad, grad_origin=self)\n",
"                    self.parents[1].backward(self.grad, grad_origin=self)" ] },
{ "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "a = Tensor([1,2,3,4,5], autograd=True)\n", "b = Tensor([2,2,2,2,2], autograd=True)\n", "c = Tensor([5,4,3,2,1], autograd=True)" ] },
{ "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Tensor(48), Tensor(95), Tensor(94))" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a, b, c" ] },
{ "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "d = a + b\n", "e = b + c" ] },
{ "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Tensor(86), [Tensor(48), Tensor(95)], 0)" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d, d.parents, d.required_grads" ] },
{ "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Tensor(48), [Tensor(86)], 1)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a, a.children, a.required_grads" ] },
{ "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Tensor(95), [Tensor(86), Tensor(60)], 2)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b, b.children, b.required_grads" ] },
{ "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "f = d + e" ] },
{ "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "f.backward(Tensor([1,1,1,1,1]))" ] },
{ "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[2, 2, 2, 2, 2]])" ] }, "execution_count": 44,
"metadata": {}, "output_type": "execute_result" } ], "source": [ "b.grad.data" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's go back to the book's `Tensor` implementation:" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We create a `self.children` counter that records, for each child, how many gradients are still expected from it during backpropagation. This way, we also prevent a variable from accidentally backpropagating from the same child twice (which raises an exception).\n", "\n", "Previously, whenever we called `.backward()`, the tensor immediately called `.backward()` on its parents. Now, each tensor first waits to receive all of its gradients from its children, adds them up, and only then backpropagates to its parents. None of these concepts are new from a deep learning theory perspective; they are engineering challenges that deep learning frameworks must solve, and we'll run into them when debugging neural networks in a standard framework." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## How does addition backpropagation work?\n", "### Let's study the abstraction to learn how to add support for more functions\n", "\n", "We can now add support for arbitrary operations by:\n", "- Adding the function to the `Tensor` class, so that it records its parents and `creation_op` when `autograd` is on.\n", "- Adding its derivative to the `.backward()` method.\n", "\n", "Backpropagation is skipped if the variable has `autograd` turned off."
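, "\n",
 "\n",
 "As a minimal sketch of these two steps, here is a hypothetical, stripped-down `Node` class (not the book's `Tensor`, and without the child-counting logic) that adds element-wise multiplication. The forward method records its parents and `creation_op`, and one new branch in `backward()` applies the product rule, `d(a*b)/da = b` and `d(a*b)/db = a`:\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "class Node(object):\n",
 "    # Hypothetical minimal node: no child counting, just the two-step recipe\n",
 "    def __init__(self, data, autograd=False, parents=None, creation_op=None):\n",
 "        self.data = np.array(data, dtype=float)\n",
 "        self.autograd = autograd\n",
 "        self.parents = parents if parents is not None else []\n",
 "        self.creation_op = creation_op\n",
 "        self.grad = None\n",
 "\n",
 "    # Step 1: the forward function records its parents and creation_op\n",
 "    def __mul__(self, other):\n",
 "        if self.autograd and other.autograd:\n",
 "            return Node(self.data * other.data, autograd=True,\n",
 "                        parents=[self, other], creation_op='*')\n",
 "        return Node(self.data * other.data)\n",
 "\n",
 "    # Step 2: the derivative becomes one more branch in backward()\n",
 "    def backward(self, grad):\n",
 "        if not self.autograd:\n",
 "            return  # autograd is off: skip backpropagation entirely\n",
 "        grad = np.array(grad, dtype=float)\n",
 "        self.grad = grad if self.grad is None else self.grad + grad\n",
 "        if self.creation_op == '*':\n",
 "            a, b = self.parents\n",
 "            a.backward(grad * b.data)  # d(a*b)/da = b\n",
 "            b.backward(grad * a.data)  # d(a*b)/db = a\n",
 "\n",
 "a = Node([1, 2, 3], autograd=True)\n",
 "b = Node([4, 5, 6], autograd=True)\n",
 "c = a * b\n",
 "c.backward(np.ones(3))\n",
 "print(a.grad, b.grad)\n",
 "\n",
 "# With autograd off, no graph is recorded and no gradient is stored\n",
 "d = Node([1, 2, 3]) * Node([4, 5, 6])\n",
 "d.backward(np.ones(3))\n",
 "print(d.grad)\n",
 "```\n",
 "\n",
 "The full `Tensor` below follows the same pattern, plus the child counters that make each tensor wait for all incoming gradients before backpropagating."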
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Adding Support for Negation\n", "### Let's modify the support for addition to also support negation" ] },
{ "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"class Tensor(object):\n",
"    def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):\n",
"        self.data = np.array(data)\n",
"        self.autograd = autograd\n",
"        self.parents = parents\n",
"        self.children = {}\n",
"        self.creation_op = creation_op\n",
"        self.grad = None\n",
"        if (id is None): id = np.random.randint(100)\n",
"        self.id = id\n",
"\n",
"        # Updates Parents' Children\n",
"        # For your library, don't do this until you need it.\n",
"        # A simpler solution is to just keep a counter for the number of grads that must be passed while doing back propagation\n",
"        if (parents is not None):\n",
"            for parent in parents:\n",
"                if (self not in parent.children):\n",
"                    # 1 is the number of grads to be passed from self to parent\n",
"                    parent.children[self] = 1\n",
"                else:\n",
"                    parent.children[self] += 1\n",
"\n",
"    def all_grads_propagated(self):\n",
"        for _, grads_count in self.children.items():\n",
"            if (grads_count != 0): return False\n",
"        return True\n",
"\n",
"    def __add__(self, other):\n",
"        if (self.autograd and other.autograd):\n",
"            return Tensor(self.data + other.data,\n",
"                          autograd=True,\n",
"                          parents=[self, other],\n",
"                          creation_op=\"+\")\n",
"        return Tensor(self.data + other.data)\n",
"\n",
"    def __neg__(self):\n",
"        if (self.autograd):\n",
"            return Tensor(self.data * -1,\n",
"                          autograd=True,\n",
"                          parents=[self],\n",
"                          creation_op=\"neg\")\n",
"        return Tensor(self.data * -1)\n",
"\n",
"    def __repr__(self):\n",
"        return str('Tensor(' + self.id.__repr__() + ')')\n",
"\n",
"    def __str__(self):\n",
"        return str(self.data.__str__())\n",
"\n",
"    def backward(self, grad=None, grad_origin=None):\n",
"        if (self.autograd):\n",
"            if (grad_origin is not None):\n",
"                # checks to make sure you can backpropagate or whether you're waiting for a gradient, in which case, decrement the counter\n",
"                if (self.children[grad_origin] == 0):\n",
"                    raise Exception(\"cannot backprop more than once\")\n",
"                else:\n",
"                    self.children[grad_origin] -= 1\n",
"            if (self.grad is None):\n",
"                self.grad = grad\n",
"            else:\n",
"                self.grad += grad\n",
"            if ((self.parents is not None) and (self.all_grads_propagated() or grad_origin is None)):\n",
"                if (self.creation_op == \"+\"):\n",
"                    self.parents[0].backward(self.grad, grad_origin=self)\n",
"                    self.parents[1].backward(self.grad, grad_origin=self)\n",
"                if (self.creation_op == \"neg\"):\n",
"                    # pass grad_origin so the parent's child counter is decremented\n",
"                    self.parents[0].backward(self.grad.__neg__(), grad_origin=self)" ] },
{ "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "a = Tensor([1,2,3,4,5], autograd=True)\n", "b = Tensor([2,2,2,2,2], autograd=True)\n", "c = Tensor([5,4,3,2,1], autograd=True)" ] },
{ "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "d = a + (-b)\n", "e = (-b) + c\n", "f = d + e" ] },
{ "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "f.backward(Tensor([1,1,1,1,1]))" ] },
{ "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-2, -2, -2, -2, -2])" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.grad.data" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's add some more:" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Adding Support for Additional Functions\n", "### Subtraction, Multiplication, Sum, Expand, Transpose, and Matrix Multiplication" ] },
{ "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"class Tensor(object):\n",
"    def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):\n",
"        self.data = np.array(data)\n",
"        self.autograd = autograd\n",
"        self.parents = parents\n",
"        self.children = {}\n",
"        self.creation_op = creation_op\n",
"        self.grad = None\n",
"        if (id is None): id = np.random.randint(100)\n",
"        self.id = id\n",
"\n",
"        if (parents is not None):\n",
"            for parent in parents:\n",
"                if (self not in parent.children):\n",
"                    parent.children[self] = 1\n",
"                else:\n",
"                    parent.children[self] += 1\n",
"\n",
"    def all_grads_propagated(self):\n",
"        for _, grads_count in self.children.items():\n",
"            if (grads_count != 0): return False\n",
"        return True\n",
"\n",
"    def __add__(self, other):\n",
"        if (self.autograd and other.autograd):\n",
"            return Tensor(self.data + other.data,\n",
"                          autograd=True,\n",
"                          parents=[self, other],\n",
"                          creation_op=\"+\")\n",
"        return Tensor(self.data + other.data)\n",
"\n",
"    def __sub__(self, other):\n",
"        if (self.autograd and other.autograd):\n",
"            return Tensor(self.data - other.data,\n",
"                          autograd=True,\n",
"                          parents=[self, other],\n",
"                          creation_op=\"-\")\n",
"        return Tensor(self.data - other.data)\n",
"\n",
"    def __mul__(self, other):\n",
"        if (self.autograd and other.autograd):\n",
"            return Tensor(self.data * other.data,\n",
"                          autograd=True,\n",
"                          parents=[self, other],\n",
"                          creation_op=\"*\")\n",
"        return Tensor(self.data * other.data)\n",
"\n",
"    def sum(self, dim):\n",
"        if (self.autograd):\n",
"            return Tensor(self.data.sum(dim),\n",
"                          autograd=True,\n",
"                          parents=[self],\n",
"                          creation_op=\"sum_\" + str(dim))\n",
"        return Tensor(self.data.sum(dim))\n",
"\n",
"    def __neg__(self):\n",
"        if (self.autograd):\n",
"            return Tensor(self.data * -1,\n",
"                          autograd=True,\n",
"                          parents=[self],\n",
"                          creation_op=\"neg\")\n",
"        return Tensor(self.data * -1)\n",
"\n",
"    def __repr__(self):\n",
"        return str('Tensor(' + self.id.__repr__() + ')')\n",
"\n",
"    def __str__(self):\n",
"        return str(self.data.__str__())\n",
"\n",
"    def expand(self, dim, copies):\n",
"        trans_cmd = list(range(0, len(self.data.shape)))\n",
"        trans_cmd.insert(dim, len(self.data.shape))\n",
"        new_shape = list(self.data.shape) + [copies]\n",
"        new_data = self.data.repeat(copies).reshape(new_shape)\n",
"        new_data = new_data.transpose(trans_cmd)\n",
"\n",
"        if (self.autograd):\n",
"            return Tensor(new_data,\n",
"                          autograd=True,\n",
"                          parents=[self],\n",
"                          creation_op=\"expand_\" + str(dim))\n",
"        return Tensor(new_data)\n",
"\n",
"    def transpose(self):\n",
"        if (self.autograd):\n",
"            return Tensor(self.data.transpose(),\n",
"                          autograd=True,\n",
"                          parents=[self],\n",
"                          creation_op=\"T\")\n",
"        return Tensor(self.data.transpose())\n",
"\n",
"    def mm(self, x):\n",
"        if (self.autograd):\n",
"            return Tensor(self.data.dot(x.data),\n",
"                          autograd=True,\n",
"                          parents=[self, x],\n",
"                          creation_op=\"mm\")\n",
"        return Tensor(self.data.dot(x.data))\n",
"\n",
"    def backward(self, grad=None, grad_origin=None):\n",
"        if (self.autograd):\n",
"            if (grad is None):\n",
"                grad = Tensor(np.ones_like(self.data))\n",
"            if (grad_origin is not None):\n",
"                if (self.children[grad_origin] == 0):\n",
"                    raise Exception(\"cannot backprop more than once\")\n",
"                else:\n",
"                    self.children[grad_origin] -= 1\n",
"            if (self.grad is None):\n",
"                self.grad = grad\n",
"            else:\n",
"                self.grad += grad\n",
"            if ((self.parents is not None) and (self.all_grads_propagated() or grad_origin is None)):\n",
"                if (self.creation_op == \"+\"):\n",
"                    self.parents[0].backward(self.grad, grad_origin=self)\n",
"                    self.parents[1].backward(self.grad, grad_origin=self)\n",
"                if (self.creation_op == \"neg\"):\n",
"                    self.parents[0].backward(self.grad.__neg__(), grad_origin=self)\n",
"                if (self.creation_op == '-'):\n",
"                    self.parents[0].backward(self.grad, grad_origin=self)\n",
"                    self.parents[1].backward(self.grad.__neg__(), grad_origin=self)\n",
"                if (self.creation_op == '*'):\n",
"                    self.parents[0].backward(self.grad * self.parents[1], grad_origin=self)\n",
"                    self.parents[1].backward(self.grad * self.parents[0], grad_origin=self)\n",
"                if (self.creation_op == 'mm'):\n",
"                    activation = self.parents[0]  # usually an activation function\n",
"                    weights = self.parents[1]     # usually a weights matrix\n",
"                    # transpose the raw data so the backward pass adds no new nodes to the graph\n",
"                    activation.backward(self.grad.mm(Tensor(weights.data.transpose())), grad_origin=self)\n",
"                    weights.backward(self.grad.transpose().mm(activation).transpose(), grad_origin=self)\n",
"                if (self.creation_op == 'T'):\n",
"                    self.parents[0].backward(self.grad.transpose(), grad_origin=self)\n",
"                if (\"sum\" in self.creation_op):\n",
"                    dim = int(self.creation_op.split(\"_\")[1])\n",
"                    ds = self.parents[0].data.shape[dim]\n",
"                    self.parents[0].backward(self.grad.expand(dim, ds), grad_origin=self)\n",
"                if (\"expand\" in self.creation_op):\n",
"                    dim = int(self.creation_op.split(\"_\")[1])\n",
"                    self.parents[0].backward(self.grad.sum(dim), grad_origin=self)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Remember that `sum()` removes a dimension and `expand()` adds one.\n", "\n", "Expanding along the last dimension copies each value along that new dimension, so each entry of the original tensor becomes a list of entries. That is why, when we perform `.sum(dim=1)` on a tensor with four entries in that dimension, we need to perform `.expand(dim=1, copies=4)` **on the gradient when we backpropagate it**: every summed entry contributed to the result with derivative 1, so each one receives a copy of the gradient." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at how to take the derivative of matrix multiplication. The starting point:\n", "\n", "
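\n", "\n", "The gradients start at the end of the network and flow backward. For $Y = XW$ with upstream gradient $\\frac{\\partial L}{\\partial Y}$, the chain rule gives $\\frac{\\partial L}{\\partial X} = \\frac{\\partial L}{\\partial Y} W^\\top$ and $\\frac{\\partial L}{\\partial W} = X^\\top \\frac{\\partial L}{\\partial Y}$, which is what the `mm` branch of `backward()` computes.\n", "The following figure explains how backpropagation works for fully connected (FC) layers:\n", "\n", "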
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using autograd to train a Neural Network\n", "### We no longer have to write backpropagation logic!\n", "\n", "Without autograd, we had to forward propagate in such a way that `layer_1`, `layer_2`, and `diff` existed as variables, because we needed them later. We then had to backpropagate each gradient to its appropriate weight matrix by hand and perform the weight update. With autograd, we only write the forward pass and the weight update:" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(0)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "x = Tensor(np.array([[0,0], [0,1], [1,0], [1,1]]), autograd=True)\n", "y = Tensor(np.array([[0], [1], [0], [1]]), autograd=True)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "weights = list()\n", "weights.append(Tensor(np.random.rand(2,3), autograd=True))\n", "weights.append(Tensor(np.random.rand(3,1), autograd=True))" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.58128304]\n", "[0.48988149]\n", "[0.41375111]\n", "[0.34489412]\n", "[0.28210124]\n", "[0.2254484]\n", "[0.17538853]\n", "[0.1324231]\n", "[0.09682769]\n", "[0.06849361]\n" ] } ], "source": [ "for i in range(10): # epochs\n", " y_hat = x.mm(weights[0]).mm(weights[1]) # predict\n", " loss = ((y_hat - y)*(y_hat - y)).sum(0) # compare\n", " loss.backward(Tensor(np.ones_like(loss.data))) # learn, feeding an initial gradient of 1 to the loss\n", " \n", " for weight in weights:\n", " weight.data -= weight.grad.data * 0.1\n", " weight.grad.data *= 0\n", " \n", " print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We forward propagate over the loss computation graph and then backpropagate by feeding an initial gradient of 1 into the loss. 
With the fancy new autograd system, the code is much simpler.\n", "\n", "When we have an autograd system, stochastic gradient descent becomes trivial to implement.\n", "\n", "Let's try making it its own class as well:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Adding Automatic Optimization\n", "### Let's make a Stochastic Gradient Descent Optimizer" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "class SGD(object):\n", " def __init__(self, parameters, alpha):\n", " self.parameters = parameters\n", " self.alpha = alpha\n", " \n", " def zero(self):\n", " for p in self.parameters:\n", " p.grad.data *= 0\n", " \n", " def step(self, zero=True):\n", " for p in self.parameters:\n", " p.data -= p.grad.data * self.alpha\n", " if (zero):\n", " p.grad.data *= 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The previous neural network is further simplified as follows, with exactly the same results as before:" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(0)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "x = Tensor(np.array([[0,0], [0,1], [1,0], [1,1]]), autograd=True)\n", "y = Tensor(np.array([[0], [1], [0], [1]]), autograd=True)" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "weights = list()\n", "weights.append(Tensor(np.random.rand(2,3), autograd=True))\n", "weights.append(Tensor(np.random.rand(3,1), autograd=True))" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "optimizer = SGD(weights, 0.1)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.58128304]\n", "[0.48988149]\n", "[0.41375111]\n", "[0.34489412]\n", "[0.28210124]\n", "[0.2254484]\n", "[0.17538853]\n", 
"[0.1324231]\n", "[0.09682769]\n", "[0.06849361]\n" ] } ], "source": [ "for i in range(10): # epochs\n", " y_hat = x.mm(weights[0]).mm(weights[1]) # forward propagation\n", " loss = ((y_hat - y)*(y_hat - y)).sum(0) # compare\n", " loss.backward(Tensor(np.ones_like(loss.data))) # back propagation, feeding an initial gradient of 1 to the loss\n", " optimizer.step() # learn\n", " print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Adding Support for Layer Types\n", "### Layer types are also present in Keras & PyTorch\n", "\n", "Probably the most common abstraction among all deep learning framework abstractions is **the layer abstraction**. A layer packages a commonly used forward propagation technique behind a simple API, typically a `forward()` method that we call.\n", "\n", "Here is an example of a simple Linear layer:" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "class Layer(object):\n", " def __init__(self):\n", " self.parameters = list()\n", " \n", " def get_parameters(self):\n", " return self.parameters" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "class Linear(Layer):\n", " def __init__(self, n_inputs, n_outputs):\n", " super().__init__()\n", " # scale by sqrt(2 / n_inputs) (He-style initialization)\n", " W = np.random.randn(n_inputs, n_outputs)*np.sqrt(2.0/n_inputs)\n", " self.weight = Tensor(W, autograd=True)\n", " self.bias = Tensor(np.zeros(n_outputs), autograd=True)\n", " self.parameters.append(self.weight)\n", " self.parameters.append(self.bias)\n", " \n", " def forward(self, input):\n", " # copy the bias once per row of the batch so it is added to every sample\n", " return input.mm(self.weight)+self.bias.expand(0,len(input.data))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The weights are organized into a class, and we add a bias vector because this is a true Linear layer. 
Because each layer creates its own weights and biases, they are always initialized with the correct sizes, and the correct forward propagation logic is always employed.\n", "\n", "We created an abstract `Layer` class, which allows for more complicated layers (for example, layers that contain other layers)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Layers that Contain Layers\n", "### Layers can also contain other Layers\n", "\n", "The most popular such layer is a sequential layer that forward propagates a list of layers, where each layer feeds its output to the input of the next layer:" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "class Sequential(Layer):\n", " def __init__(self, layers=None):\n", " super().__init__()\n", " # avoid a shared mutable default argument: each instance gets its own list\n", " self.layers = layers if layers is not None else list()\n", " \n", " def add(self, layer):\n", " self.layers.append(layer)\n", " \n", " def forward(self, input):\n", " for layer in self.layers:\n", " input = layer.forward(input)\n", " return input\n", "\n", " def get_parameters(self):\n", " params = list()\n", " for layer in self.layers:\n", " params += layer.get_parameters()\n", " return params" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "x = Tensor(np.array([[0, 0], [0,1], [1,0], [1,1]]), autograd=True)\n", "y = Tensor(np.array([[1], [0], [1], [0]]), autograd=True)" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "model = Sequential([Linear(2,3), Linear(3,1)])" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "optimizer = SGD(model.get_parameters(), alpha=0.05)" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.39435371]\n", "[1.03442471]\n", "[0.80333761]\n", "[0.60197476]\n", "[0.46415449]\n", "[0.34221874]\n", "[0.25943595]\n", "[0.1908049]\n", "[0.14431529]\n", "[0.10740191]\n" 
] } ], "source": [ "for i in range(10): # epochs\n", " y_hat = model.forward(x) # forward propagation\n", " loss = ((y_hat - y)*(y_hat - y)).sum(0) # loss\n", " loss.backward(Tensor(np.ones_like(loss.data))) # back propagation\n", " optimizer.step()\n", " print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This looks very similar to PyTorch's API!\n", "\n", "Let's also implement loss functions as layers:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loss-function layers\n", "### Some Layers have no Weights\n", "\n", "We can also create layers that are pure functions of their inputs. The most common example is a loss function such as *mean squared error*. Note that this implementation sums the squared errors over the batch rather than averaging them:" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "class MSELoss(Layer):\n", " def __init__(self):\n", " super().__init__()\n", " \n", " def forward(self, y_hat, y):\n", " # sum (not mean) of the squared errors along the batch dimension\n", " return ((y_hat - y) * (y_hat - y)).sum(0)" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(0)" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "x = Tensor(np.array([[0, 0], [0,1], [1,0], [1,1]]), autograd=True)\n", "y = Tensor(np.array([[1], [0], [1], [0]]), autograd=True)" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "model = Sequential([Linear(2,3), Linear(3,1)])\n", "loss = MSELoss()\n", "optimizer = SGD(parameters=model.get_parameters(), alpha=0.05)" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.6813686]\n", "[0.95192748]\n", "[0.72454581]\n", "[0.57489823]\n", "[0.45840608]\n", "[0.36465316]\n", "[0.28883237]\n", "[0.22760439]\n", "[0.17835522]\n", "[0.13895393]\n" ] } ], "source": [ "for i in range(10):\n", " y_hat = model.forward(x)\n", " l = loss.forward(y_hat, y)\n", " 
l.backward(Tensor(np.ones_like(l.data)))\n", " optimizer.step()\n", " print(l)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Autograd handles all of the backpropagation, while the forward propagation steps are organized into layer classes that can be composed freely." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to learn a framework\n", "### Oversimplified, frameworks are autograd + prebuilt layers and optimizers\n", "\n", "Autograd ensures that we can piece together different types of layers without losing sight of the underlying relationships and flowing gradients. This is the main feature of modern frameworks: they eliminate the need to handwrite every mathematical operation for forward and backward propagation.\n", "\n", "Viewing a framework as merely an autograd system plus a list of layers, loss functions, and optimizers will help us learn it. We should take a moment to read through the lists of layers and optimizers offered by the major frameworks:\n", "- [PyTorch](https://pytorch.org/docs/stable/nn.html)\n", "- [Keras](https://keras.io/layers/about-keras-layers)\n", "- [TensorFlow](https://www.tensorflow.org/api_docs/python/tf/layers)\n", "\n", "We've also added a small convenience to `.backward()`: when it's called without a gradient, it defaults to a Tensor of `1`s, so we don't have to pass one explicitly every time." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Nonlinearity Layers\n", "### Let's add some nonlinear functions to Tensor, then create some layer types\n", "\n", "For the next chapter, we'll need `.sigmoid()` and `.tanh()`. Let's add them to the `Tensor` class:" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "class Tensor(object):\n", " def __init__(self, data, autograd=False, parents=None, creation_op=None, id=None):\n", " self.data = np.array(data)\n", " self.autograd = autograd\n", " self.parents = parents\n", " self.children = {}\n", " self.creation_op = creation_op\n", " self.grad = None\n", " if (id is None): id = np.random.randint(100)\n", " self.id = id\n", " \n", " if (parents is not None):\n", " for parent in parents:\n", " # children are keyed by Tensor: count how many times this new tensor uses each parent\n", " if (self not in parent.children):\n", " parent.children[self] = 1\n", " else:\n", " parent.children[self] += 1\n", " \n", " def all_grads_propagated(self):\n", " for _, grads_count in self.children.items():\n", " if (grads_count != 0): return False\n", " return True\n", " \n", " def __add__(self, other):\n", " if (self.autograd and other.autograd):\n", " return Tensor(self.data + other.data, \n", " autograd=True,\n", " parents=[self, other], \n", " creation_op=\"+\")\n", " return Tensor(self.data + other.data)\n", " \n", " def __sub__(self, other):\n", " if (self.autograd and other.autograd):\n", " return Tensor(self.data-other.data,\n", " autograd=True,\n", " parents=[self, other],\n", " creation_op=\"-\")\n", " return Tensor(self.data-other.data)\n", " \n", " def __mul__(self, other):\n", " if (self.autograd and other.autograd):\n", " return Tensor(self.data * other.data,\n", " autograd=True,\n", " parents=[self, other],\n", " creation_op=\"*\")\n", " return Tensor(self.data * other.data)\n", " \n", " def sum(self, dim):\n", " if (self.autograd):\n", " return Tensor(self.data.sum(dim),\n", " autograd=True,\n", " parents=[self],\n", " 
creation_op=\"sum_\" + str(dim))\n", " return Tensor(self.data.sum(dim))\n", "\n", " def __neg__(self):\n", " if (self.autograd):\n", " return Tensor(self.data * -1,\n", " autograd=True,\n", " parents=[self],\n", " creation_op=\"neg\")\n", " return Tensor(self.data * -1)\n", "\n", " def __repr__(self):\n", " return str('Tensor(' + self.id.__repr__() + ')')\n", "\n", " def __str__(self):\n", " return str(self.data.__str__())\n", " \n", " def expand(self, dim, copies):\n", " trans_cmd = list(range(0, len(self.data.shape)))\n", " trans_cmd.insert(dim, len(self.data.shape))\n", " new_shape = list(self.data.shape) + [copies]\n", " new_data = self.data.repeat(copies).reshape(new_shape)\n", " new_data = new_data.transpose(trans_cmd)\n", " \n", " if (self.autograd):\n", " return Tensor(new_data, \n", " autograd=True,\n", " parents=[self],\n", " creation_op=\"expand_\"+str(dim))\n", " return Tensor(new_data)\n", " \n", " def transpose(self):\n", " if (self.autograd):\n", " return Tensor(self.data.transpose(),\n", " autograd=True,\n", " parents=[self],\n", " creation_op=\"T\")\n", " return Tensor(self.data.transpose())\n", " \n", " def mm(self, x):\n", " if (self.autograd):\n", " return Tensor(self.data.dot(x.data),\n", " autograd=True,\n", " parents=[self, x],\n", " creation_op=\"mm\")\n", " return Tensor(self.data.dot(x.data))\n", " \n", " def sigmoid(self):\n", " if (self.autograd):\n", " return Tensor(1/(1+np.exp(-self.data)),\n", " autograd=True,\n", " parents=[self],\n", " creation_op=\"sigmoid\")\n", " return Tensor(1/(1+np.exp(-self.data)))\n", " \n", " def tanh(self):\n", " if (self.autograd):\n", " return Tensor(np.tanh(self.data),\n", " autograd=True,\n", " parents=[self],\n", " creation_op=\"tanh\")\n", " return Tensor(np.tanh(self.data))\n", " \n", " def index_select(self, indices):\n", " if (self.autograd):\n", " new = Tensor(self.data[indices.data],\n", " autograd=True,\n", " parents=[self],\n", " creation_op=\"index_select\")\n", " new.index_select_indices = 
indices\n", " return new\n", " return Tensor(self.data[indices.data])\n", " \n", " def cross_entropy(self, target_indices):\n", " temp = np.exp(self.data)\n", " softmax_output = temp / np.sum(temp, axis=len(self.data.shape)-1, keepdims=True)\n", " t = target_indices.data.flatten()\n", " p = softmax_output.reshape(len(t), -1)\n", " target_dist = np.eye(p.shape[1])[t]\n", " loss = - (np.log(p) * (target_dist)).sum(1).mean()\n", "\n", " if (self.autograd):\n", " out = Tensor(loss,\n", " autograd=True,\n", " parents=[self],\n", " creation_op=\"cross_entropy\")\n", " out.softmax_output = softmax_output\n", " out.target_dist = target_dist\n", " return out\n", " return Tensor(loss)\n", " \n", " def backward(self, grad=None, grad_origin=None):\n", " if (self.autograd):\n", " if (grad == None):\n", " grad = Tensor(np.ones_like(self.data))\n", " if (grad_origin is not None):\n", " if (self.children[grad_origin] == 0):\n", " raise Exception(\"cannot backprop more than once\")\n", " else:\n", " self.children[grad_origin] -= 1\n", " if (self.grad is None):\n", " self.grad = grad\n", " else:\n", " self.grad += grad\n", " if ((self.parents is not None) and (self.all_grads_propagated() or grad_origin is None)):\n", " if (self.creation_op == \"+\"):\n", " self.parents[0].backward(self.grad, grad_origin=self)\n", " self.parents[1].backward(self.grad, grad_origin=self)\n", " if (self.creation_op == \"neg\"):\n", " self.parents[0].backward(self.grad.__neg__())\n", " if (self.creation_op == '-'):\n", " self.parents[0].backward(self.grad, grad_origin=self)\n", " self.parents[1].backward(self.grad.__neg__(), grad_origin=self)\n", " if (self.creation_op == '*'):\n", " self.parents[0].backward(self.grad*self.parents[1], grad_origin=self)\n", " self.parents[1].backward(self.grad*self.parents[0], grad_origin=self)\n", " if (self.creation_op == 'mm'):\n", " activation = self.parents[0] # usually an activation function\n", " weights = self.parents[1] # usually a weights matrix\n", " 
activation.backward(self.grad.mm(weights.transpose()))\n", " weights.backward(self.grad.transpose().mm(activation).transpose())\n", " if (self.creation_op == 'T'):\n", " self.parents[0].backward(self.grad.transpose())\n", " if (\"sum\" in self.creation_op):\n", " dim = int(self.creation_op.split(\"_\")[1])\n", " ds = self.parents[0].data.shape[dim]\n", " self.parents[0].backward(self.grad.expand(dim, ds))\n", " if (\"expand\" in self.creation_op):\n", " dim = int(self.creation_op.split(\"_\")[1])\n", " self.parents[0].backward(self.grad.sum(dim))\n", " if (self.creation_op == 'sigmoid'):\n", " # d/dx sigmoid(x) = s * (1 - s), where s is this tensor's output\n", " ones = Tensor(np.ones_like(self.grad.data))\n", " self.parents[0].backward(self.grad * (self * (ones - self)))\n", " if (self.creation_op == 'tanh'):\n", " # d/dx tanh(x) = 1 - tanh(x)^2\n", " ones = Tensor(np.ones_like(self.grad.data))\n", " self.parents[0].backward(self.grad * (ones - (self * self)))\n", " if (self.creation_op == 'index_select'):\n", " # scatter-add each row of the accumulated gradient back into the row it was selected from\n", " new_grad = np.zeros_like(self.parents[0].data)\n", " indices_ = self.index_select_indices.data.flatten()\n", " grad_ = self.grad.data.reshape(len(indices_), -1)\n", " for i in range(len(indices_)):\n", " new_grad[indices_[i]] += grad_[i]\n", " self.parents[0].backward(Tensor(new_grad))\n", " if (self.creation_op == 'cross_entropy'):\n", " # fused softmax + cross-entropy gradient with respect to the logits\n", " dx = self.softmax_output - self.target_dist\n", " self.parents[0].backward(Tensor(dx))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hopefully, this feels fairly routine:" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "class Tanh(Layer):\n", " def __init__(self):\n", " super().__init__()\n", " \n", " def forward(self, input):\n", " return input.tanh()" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "class Sigmoid(Layer):\n", " def __init__(self):\n", " super().__init__()\n", " \n", " def forward(self, input):\n", " return input.sigmoid()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try out the new 
nonlinearities:" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [], "source": [ "x = Tensor(np.array([[0, 0], [0,1], [1,0], [1,1]]), autograd=True)\n", "y = Tensor(np.array([[1], [0], [1], [0]]), autograd=True)" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [], "source": [ "model = Sequential([Linear(2,3), Tanh(), Linear(3,1), Sigmoid()])\n", "loss = MSELoss()\n", "optimizer = SGD(parameters=model.get_parameters(), alpha=1)" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.10815212]\n", "[0.54905107]\n", "[0.31290284]\n", "[0.18050833]\n", "[0.11220714]\n", "[0.07893695]\n", "[0.06391673]\n", "[0.05358047]\n", "[0.04600529]\n", "[0.04022943]\n" ] } ], "source": [ "for i in range(10):\n", " y_hat = model.forward(x)\n", " l = loss.forward(y_hat, y)\n", " l.backward(Tensor(np.ones_like(l.data)))\n", " optimizer.step()\n", " print(l)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, we can drop the new `Tanh()` and `Sigmoid()` nonlinearities into `Sequential`, and the network knows exactly what to do with them. Next, we'll implement RNN layers in our framework. To do that, we need three new layer types:\n", "- An embedding layer that learns word embeddings.\n", "- An RNN layer that can learn to model sequences of input.\n", "- A softmax layer that can predict a probability distribution over labels (implemented below as a cross-entropy layer that includes the softmax)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Embedding Layer\n", "### An Embedding Layer translates indices into activations\n", "\n", "Word embeddings are vectors mapped to words that we can forward propagate into a neural network. 
If we have a vocabulary of 200 words, we'll have 200 embeddings, one per word.\n", "\n", "First, let's initialize a weight matrix with one row per word in the vocabulary:" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "class Embedding(Layer):\n", " def __init__(self, vocab_size, dim):\n", " super().__init__()\n", " self.vocab_size = vocab_size\n", " self.dim = dim\n", " # this initialization style is a convention from word2vec\n", " weight = (np.random.rand(vocab_size, dim) - 0.5) / dim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The weight matrix has a row (vector) for each unique word in the vocabulary. Forward propagation always starts with the question \"How will the inputs be encoded?\" For embeddings, we forward propagate word indices, not words: indexing a matrix with an array of indices selects the corresponding rows, as this identity-matrix example shows:" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1., 0., 0., 0., 0.],\n", " [0., 1., 0., 0., 0.],\n", " [0., 0., 1., 0., 0.],\n", " [0., 0., 0., 1., 0.],\n", " [0., 0., 0., 0., 1.]])" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "identity = np.eye(5)\n", "identity" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[[0., 1., 0., 0., 0.],\n", " [0., 0., 1., 0., 0.],\n", " [0., 0., 0., 1., 0.],\n", " [0., 0., 0., 0., 1.]],\n", "\n", " [[0., 0., 1., 0., 0.],\n", " [0., 0., 0., 1., 0.],\n", " [0., 0., 0., 0., 1.],\n", " [1., 0., 0., 0., 0.]]])" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "identity[np.array([[1,2,3,4], [2,3,4,0]])]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Adding Indexing to Autograd\n", "### Before we can build the embedding layer, autograd needs to support indexing\n", "\n", "Before doing anything with the embedding layer, autograd must support indexing. 
We need to make sure that during backpropagation, the gradients are placed in the same rows that were indexed during forward propagation. This requires keeping around whatever indices were passed in, so that we can place each gradient in the appropriate location during backpropagation with a simple `for` loop:" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [], "source": [ "# Added to the Tensor class\n", "def index_select(self, indices):\n", " if (self.autograd):\n", " new = Tensor(self.data[indices.data],\n", " autograd=True,\n", " parents=[self],\n", " creation_op=\"index_select\")\n", " new.index_select_indices = indices\n", " return new\n", " return Tensor(self.data[indices.data])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, during `.backward()`:\n", "1. Initialize a new gradient of the correct size.\n", "2. Flatten the indices so we can iterate through them.\n", "3. Collapse `grad_` to a simple list of rows.\n", "4. Iterate through each index, add the corresponding gradient row into the correct row of the new gradient, and backpropagate the result into `parents[0]`.\n", "\n", "Example: index `0` is never selected, indices `1` and `4` are selected once, and indices `2` and `3` are selected twice, so the rows of the resulting gradient should be filled with `0`, `1`, `2`, `2`, and `1` respectively:" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1., 0., 0., 0., 0.],\n", " [0., 1., 0., 0., 0.],\n", " [0., 0., 1., 0., 0.],\n", " [0., 0., 0., 1., 0.],\n", " [0., 0., 0., 0., 1.]])" ] }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = Tensor(np.eye(5), autograd=True)\n", "x.data" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [], "source": [ "x.index_select(Tensor([[1,2,3], [2,3,4]])).backward()" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 0., 0., 0.],\n", " [1., 1., 1., 1., 1.],\n", " [2., 2., 2., 2., 2.],\n", " [2., 2., 2., 2., 2.],\n", " [1., 1., 1., 1., 1.]])" ] }, "execution_count": 85, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "x.grad.data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Embedding Layer (Revisited)\n", "### Now we can finish forward propagation using the `.index_select()` method" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [], "source": [ "class Embedding(Layer):\n", " def __init__(self, vocab_size, dim):\n", " super().__init__()\n", " self.vocab_size = vocab_size\n", " self.dim = dim\n", " \n", " # this initialization style is a convention from word2vec\n", " weight = (np.random.rand(vocab_size, dim) - 0.5) / dim\n", " self.weight = Tensor(weight, autograd=True)\n", " self.parameters.append(self.weight)\n", " \n", " def forward(self, input):\n", " # input is word indices\n", " return self.weight.index_select(input)" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [], "source": [ "x = Tensor(np.array([1,2,1,2]), autograd=True)\n", "y = Tensor(np.array([[0], [1], [0], [1]]), autograd=True)" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [], "source": [ "embed = Embedding(vocab_size=5, dim=3)" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [], "source": [ "model = Sequential([embed, Tanh(), Linear(3,1), Sigmoid()])" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [], "source": [ "loss = MSELoss()" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "optimizer = SGD(parameters=model.get_parameters(), alpha=0.5)" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.15164466]\n", "[0.40309683]\n", "[0.20199962]\n", "[0.12654533]\n", "[0.08969294]\n", "[0.06850048]\n", "[0.0549488]\n", "[0.04562442]\n", "[0.03885794]\n", "[0.03374584]\n" ] } ], "source": [ "for i in range(10): # epochs\n", " y_hat = model.forward(x)\n", " l = loss.forward(y_hat, 
y)\n", " l.backward(Tensor(np.ones_like(l.data)))\n", " optimizer.step()\n", " print(l)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this neural network, we learn to correlate inputs `1` and `2` with predictions `0` and `1`. In theory, indices `1` and `2` could correspond to token indices (such as words, characters, or objects)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The cross-entropy layer\n", "### Let's add cross entropy to autograd and create its corresponding layer" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "# added to `Tensor`\n", "def cross_entropy(self, target_indices):\n", " # softmax over the last axis\n", " temp = np.exp(self.data)\n", " softmax_output = temp / np.sum(temp, axis=len(self.data.shape)-1, keepdims=True)\n", " t = target_indices.data.flatten()\n", " p = softmax_output.reshape(len(t), -1)\n", " # one-hot encode the target indices\n", " target_dist = np.eye(p.shape[1])[t]\n", " # mean negative log-likelihood of the targets\n", " loss = - (np.log(p) * (target_dist)).sum(1).mean()\n", " \n", " if (self.autograd):\n", " out = Tensor(loss,\n", " autograd=True,\n", " parents=[self],\n", " creation_op=\"cross_entropy\")\n", " out.softmax_output = softmax_output\n", " out.target_dist = target_dist\n", " return out\n", " return Tensor(loss)" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [], "source": [ "# Cross Entropy Layer\n", "class CrossEntropyLoss(Layer):\n", " def __init__(self):\n", " super().__init__()\n", " \n", " def forward(self, input, target):\n", " return input.cross_entropy(target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An example:" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(0)" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [], "source": [ "x = Tensor(np.array([1,2,1,2]), autograd=True)\n", "y = Tensor(np.array([0,1,0,1]), autograd=True)" ] }, { "cell_type": "code", "execution_count": 97, 
"metadata": {}, "outputs": [], "source": [ "model = Sequential([Embedding(3,3), Tanh(), Linear(3,4)])" ] }, { "cell_type": "code", "execution_count": 98, "metadata": {}, "outputs": [], "source": [ "loss = CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 99, "metadata": {}, "outputs": [], "source": [ "optimizer = SGD(parameters=model.get_parameters(), alpha=0.1)" ] }, { "cell_type": "code", "execution_count": 101, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.1885620377278218\n", "0.16792356657516766\n", "0.15086902981053774\n", "0.13661027751993166\n", "0.1245597425659242\n", "0.11427427141364929\n", "0.10541566875575809\n", "0.09772282456212722\n", "0.0909918671530601\n", "0.08506189900243766\n" ] } ], "source": [ "for i in range(10):\n", " y_hat = model.forward(x)\n", " l = loss.forward(y_hat, y)\n", " l.backward(Tensor(np.ones_like(l.data)))\n", " optimizer.step()\n", " print(l)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One noticeable difference between this loss and the others is that both the final softmax and the computation of the loss happen inside the loss class.\n", "\n", "When we design a network to be trained using cross entropy, we can leave the softmax off the forward propagation step and call a cross-entropy class that automatically performs the softmax as part of the loss function. It's much faster to calculate the gradient of softmax and negative log-likelihood together in a cross-entropy function, where it reduces to `softmax_output - target_dist`, than to forward propagate and backpropagate them separately in two different modules." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Recurrent Neural Network Layer\n", "### By combining several layers, we can learn over time series\n", "\n", "Let's create one more layer that's the composition of multiple smaller layer types. This layer is the **Recurrent Layer**. We'll construct it using three Linear layers. 
\n", "\n", "The `.forward()` method will take both the output from the previous hidden state and the input from the current training data." ] }, { "cell_type": "code", "execution_count": 102, "metadata": {}, "outputs": [], "source": [ "class RNNCell(Layer):\n", " def __init__(self, n_inputs, n_hidden, n_output, activation='sigmoid'):\n", " super().__init__()\n", " self.n_inputs = n_inputs\n", " self.n_hidden = n_hidden\n", " self.n_output = n_output\n", " \n", " if (activation == 'sigmoid'):\n", " self.activation = Sigmoid()\n", " elif (activation == 'tanh'):\n", " self.activation = Tanh()\n", " else:\n", " raise Exception(\"Non-Linearity not found\")\n", " \n", " self.w_ih = Linear(n_inputs, n_hidden)\n", " self.w_hh = Linear(n_hidden, n_hidden)\n", " self.w_ho = Linear(n_hidden, n_output)\n", " \n", " self.parameters += self.w_ih.get_parameters()\n", " self.parameters += self.w_hh.get_parameters()\n", " self.parameters += self.w_ho.get_parameters()\n", " \n", " def forward(self, input, hidden):\n", " from_prev_hidden = self.w_hh.forward(hidden)\n", " combined = self.w_ih.forward(input) + from_prev_hidden\n", " new_hidden = self.activation.forward(combined)\n", " output = self.w_ho.forward(new_hidden)\n", " return output, new_hidden\n", "\n", " def init_hidden(self, batch_size=1):\n", " return Tensor(np.zeros((batch_size,self.n_hidden)),autograd=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "RNNs have a state vector that passes from timestep to timestep. 
In this case, it's the variable `hidden`, which is both an input parameter and an output of the forward function.\n", "\n", "RNNs also have several weight matrices:\n", "- One that maps input vectors to hidden vectors (processing the input data).\n", "- One that maps from hidden to hidden (updating each hidden vector based on the previous one).\n", "- Optionally, a hidden-to-output layer that learns to make predictions based on the hidden vector.\n", "\n", "The `activation` input parameter defines which nonlinearity is applied to the hidden vectors at each timestep.\n", "\n", "Let's train the network:" ] }, { "cell_type": "code", "execution_count": 103, "metadata": {}, "outputs": [], "source": [ "import sys, random, math\n", "from collections import Counter\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 104, "metadata": {}, "outputs": [], "source": [ "with open('static/data/tasksv11/en/qa1_single-supporting-fact_train.txt', 'r') as f:\n", "    raw = f.readlines()" ] }, { "cell_type": "code", "execution_count": 105, "metadata": {}, "outputs": [], "source": [ "tokens = list()\n", "for line in raw[0:1000]:\n", "    # Lowercase, strip the newline, and drop the leading line number\n", "    tokens.append(line.lower().replace(\"\\n\", \"\").split(\" \")[1:])" ] }, { "cell_type": "code", "execution_count": 106, "metadata": {}, "outputs": [], "source": [ "new_tokens = list()\n", "for line in tokens:\n", "    # Left-pad every sentence with '-' so that all sentences have 6 tokens\n", "    new_tokens.append(['-'] * (6 - len(line)) + line)\n", "tokens = new_tokens" ] }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [], "source": [ "vocab = set()\n", "for sent in tokens:\n", "    for word in sent:\n", "        vocab.add(word)" ] }, { "cell_type": "code", "execution_count": 108, "metadata": {}, "outputs": [], "source": [ "vocab = list(vocab)" ] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [], "source": [ "word2index = {}\n", "for i, word in enumerate(vocab):\n", "    word2index[word] = i" ] }, { "cell_type": "code", "execution_count": 110, "metadata": {}, "outputs": 
[], "source": [ "def words2indices(sentence):\n", "    # Map each word of a sentence to its vocabulary index\n", "    idx = list()\n", "    for word in sentence:\n", "        idx.append(word2index[word])\n", "    return idx" ] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [], "source": [ "# Convert every padded sentence into a list of vocabulary indices\n", "indices = [words2indices(line) for line in tokens]" ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [], "source": [ "# One row per sentence, one column per (padded) token position\n", "data = np.array(indices)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### We can learn to fit the task we accomplished in the preceding chapter\n", "\n", "Now we can initialize the recurrent layer with an embedding input and train a network to solve the same task as in the previous chapter. We should note that this network is slightly more complex (it has one extra layer), even though the code is much simpler, thanks to our little framework." ] }, { "cell_type": "code", "execution_count": 121, "metadata": {}, "outputs": [], "source": [ "embed = Embedding(vocab_size=len(vocab), dim=16)\n", "model = RNNCell(n_inputs=16, n_hidden=16, n_output=len(vocab))" ] }, { "cell_type": "code", "execution_count": 122, "metadata": {}, "outputs": [], "source": [ "loss = CrossEntropyLoss()\n", "params = model.get_parameters() + embed.get_parameters()\n", "optimizer = SGD(parameters=params, alpha=0.05)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Cell** is the conventional name for an RNN layer that implements only a single recurrence. If we created another layer that could chain an arbitrary number of cells together, it would be called an RNN, and `n_layers` would be an input parameter."
] }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.4210781038536967 % Correct: 0.0\n", "Loss: 0.17030485013740324 % Correct: 0.27\n", "Loss: 0.14961780443925604 % Correct: 0.36\n", "Loss: 0.1390113005612828 % Correct: 0.36\n", "Loss: 0.13628065998222028 % Correct: 0.35\n" ] } ], "source": [ "for iter in range(1000):\n", "    batch_size = 100\n", "    total_loss = 0\n", "\n", "    hidden = model.init_hidden(batch_size=batch_size)\n", "\n", "    # Feed the first five tokens of each sentence through the RNN, one timestep at a time\n", "    for t in range(5):\n", "        input = Tensor(data[0:batch_size, t], autograd=True)\n", "        rnn_input = embed.forward(input=input)\n", "        output, hidden = model.forward(input=rnn_input, hidden=hidden)\n", "\n", "    # Predict the sixth (last) token from the final hidden state\n", "    target = Tensor(data[0:batch_size, t+1], autograd=True)\n", "    l = loss.forward(output, target)\n", "    l.backward()\n", "    optimizer.step()\n", "    total_loss += l.data\n", "    if iter % 200 == 0:\n", "        # p_correct is a fraction: 0.35 means 35% of the batch was predicted correctly\n", "        p_correct = (target.data == np.argmax(output.data, axis=1)).mean()\n", "        print_loss = total_loss / (len(data) / batch_size)\n", "        print(\"Loss:\", print_loss, \"% Correct: \", p_correct)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's use the trained model to make a prediction:" ] }, { "cell_type": "code", "execution_count": 124, "metadata": {}, "outputs": [], "source": [ "batch_size = 1\n", "hidden = model.init_hidden(batch_size=batch_size)\n", "for t in range(5):\n", "    input = Tensor(data[0:batch_size, t], autograd=True)\n", "    rnn_input = embed.forward(input=input)\n", "    output, hidden = model.forward(input=rnn_input, hidden=hidden)" ] }, { "cell_type": "code", "execution_count": 125, "metadata": {}, "outputs": [], "source": [ "target = Tensor(data[0:batch_size, t+1], autograd=True)\n", "l = loss.forward(output, target)" ] }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [], "source": [ "# Rebuild the context (every token except the one being predicted)\n", "ctx = \"\"\n", "for idx in data[0:batch_size][0][0:-1]:\n", "    ctx += vocab[idx] + \" \"" ] }, { "cell_type": "code", 
"execution_count": 127, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Context: - mary moved to the \n", "Pred: office.\n" ] } ], "source": [ "print(\"Context: \", ctx)\n", "print(\"Pred: \", vocab[output.data.argmax()])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Neural Network learns to predict the first `100` examples of the training dataset with an accuracy of over `37%`. It predicts a plausible location for Mary to be moving toward." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "### Frameworks are efficient & convenient abstractions of backward and forward logic\n", "\n", "Frameworks can make our code more readable, faster to write, and faster to execute (through built-in optimizations). \n", "\n", "This chapter will prepare us to use and extend industry standard frameworks like PyTorch or TensorFlow. The skills we've learned in this chapter will be the most valuable ones from this book. We highly recommend diving in `PyTorch` after finishing this book.\n", "\n", "---" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 4 }