{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Chapter 4 - Under the Hood - Training a Digit Classifier\n", "> Deep Learning for Coders with fastai & PyTorch - Under the Hood - Training a Digit Classifier. In this notebook I added some cells on utility functions and their usage: `path`, `ls`, `untar`, `!` and `tree`. As usual, I followed both Jeremy Howard's lesson and the Weights & Biases reading group videos. Click the `open in colab` button on the right side to view this post as a notebook.\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- categories: [fastbook]\n", "- image: images/magpie.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](images/chapter-04/magpie.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> I found this little one in front of my window, suffering from a foot deformity and unable to fly. Now fully recovered and back with his/her family." ] }, { "cell_type": "code", "execution_count": 259, "metadata": {}, "outputs": [], "source": [ "#!pip install -Uqq fastbook\n", "import fastbook\n", "fastbook.setup_book()\n", "# Disable Jedi autocompletion, which doesn't work well here.\n", "%config Completer.use_jedi = False\n" ] }, { "cell_type": "code", "execution_count": 260, "metadata": {}, "outputs": [], "source": [ "from fastai.vision.all import *\n", "from fastbook import *\n", "\n", "matplotlib.rc('image', cmap='Greys')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## EXPLORING THE DATASET" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What does untar do?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: `untar_data` comes from the fastai library. It downloads the dataset and extracts (untars) it if that hasn't been done already, then returns the destination folder." 
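] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: To make this download-once, extract-once behavior concrete, here is a minimal sketch that uses only Python's standard library (`urllib.request`, `tarfile`, `pathlib`). It illustrates the idea under stated assumptions and is not fastai's actual implementation. `fetch_and_untar` is a hypothetical name. The archive is assumed to be a `.tgz` whose top-level folder has the same name as the file without its extension (as with `mnist_sample.tgz`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tarfile\n", "import urllib.request\n", "from pathlib import Path\n", "\n", "def fetch_and_untar(url, dest_dir):\n", "    # Hypothetical helper, NOT fastai's untar_data: a sketch of the same idea.\n", "    dest_dir = Path(dest_dir)\n", "    dest_dir.mkdir(parents=True, exist_ok=True)\n", "    fname = url.rsplit('/', 1)[-1]           # e.g. 'mnist_sample.tgz'\n", "    archive = dest_dir / fname\n", "    target = dest_dir / fname.split('.')[0]  # assumed folder name, e.g. 'mnist_sample'\n", "    if target.exists():                      # already extracted: skip all work\n", "        return target\n", "    if not archive.exists():                 # download only if not cached yet\n", "        urllib.request.urlretrieve(url, archive)\n", "    with tarfile.open(archive, 'r:gz') as tar:\n", "        tar.extractall(dest_dir)             # creates the archive's top-level folder\n", "    return target" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A second call with the same arguments returns the existing folder immediately without downloading or extracting again. This mirrors why re-running `untar_data` is fast once the data is already on disk." 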
] }, { "cell_type": "code", "execution_count": 261, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)" ] }, { "cell_type": "code", "execution_count": 262, "metadata": {}, "outputs": [], "source": [ "??untar_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Tip: In Jupyter, prefixing a name with `??` shows its source code." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What is path?" ] }, { "cell_type": "code", "execution_count": 263, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Path('.')" ] }, "execution_count": 263, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Tip: To see what is inside the current folder (the folder the Jupyter notebook runs in), use `!ls`. A leading `!` runs the command in the terminal (system shell)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What is !ls?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: __`ls` is a terminal command that lists the contents of a folder (with `-d`, it lists the directory itself instead of its contents)__" ] }, { "cell_type": "code", "execution_count": 264, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2020-02-20-test.ipynb\t ghtop_images my_icons\r\n", "2021-07-16-chapter-4.ipynb images\t README.md\r\n" ] } ], "source": [ "!ls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__It can also be used with a path:__" ] }, { "cell_type": "code", "execution_count": 265, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/niyazi/.fastai/data/mnist_sample/train\r\n" ] } ], "source": [ "!ls /home/niyazi/.fastai/data/mnist_sample/train -d" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__or like this:__" ] }, { "cell_type": "code", "execution_count": 266, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/niyazi/.fastai/data/mnist_sample/train/3\r\n" ] } ], "source": [ "!ls /home/niyazi/.fastai/data/mnist_sample/train/3 -d" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What is tree?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: __`tree` shows the tree structure of files and folders (the `-d` argument lists directories only)__" ] }, { "cell_type": "code", "execution_count": 267, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[01;34m/home/niyazi/.fastai/data/mnist_sample/\u001b[00m\r\n", "├── \u001b[01;34mtrain\u001b[00m\r\n", "│   ├── \u001b[01;34m3\u001b[00m\r\n", "│   └── \u001b[01;34m7\u001b[00m\r\n", "└── \u001b[01;34mvalid\u001b[00m\r\n", " ├── \u001b[01;34m3\u001b[00m\r\n", " └── \u001b[01;34m7\u001b[00m\r\n", "\r\n", "6 directories\r\n" ] } ], "source": [ "!tree /home/niyazi/.fastai/data/mnist_sample/ -d" ] }, { "cell_type": "code", "execution_count": 268, "metadata": {}, "outputs": [], "source": [ "Path.BASE_PATH = path" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What is ls()?" 
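] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: For comparison, here is a rough equivalent written with plain `pathlib`, before we look at fastai's version. `ls_sorted` is a hypothetical helper. Unlike fastai's `ls`, it returns an ordinary Python `list` rather than an `L`. It also sorts the result, so the order does not depend on the file system." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "def ls_sorted(folder):\n", "    # Hypothetical pathlib-only stand-in for fastai's Path.ls():\n", "    # list every file and sub-folder directly inside `folder`, sorted by path.\n", "    return sorted(Path(folder).iterdir())\n", "\n", "# Example usage: ls_sorted(path) or ls_sorted(path/'train')" 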
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: `ls` is a method that fastai adds to Python's `Path` class. It lists the contents of a folder and returns a fastai `L`, a list-like collection that is more powerful than a plain Python list. Its printout starts with the number of items, e.g. `(#3)`." ] }, { "cell_type": "code", "execution_count": 269, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#3) [Path('labels.csv'),Path('train'),Path('valid')]" ] }, "execution_count": 269, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path.ls()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: The `/` operator joins path components:" ] }, { "cell_type": "code", "execution_count": 270, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Path('train')" ] }, "execution_count": 270, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(path/'train')" ] }, { "cell_type": "code", "execution_count": 271, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#2) [Path('train/7'),Path('train/3')]" ] }, "execution_count": 271, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(path/'train').ls()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: There are two folders, `7` and `3`, under the training folder." ] }, { "cell_type": "code", "execution_count": 272, "metadata": {}, "outputs": [], "source": [ "threes = (path/'train'/'3').ls().sorted()\n", "sevens = (path/'train'/'7').ls().sorted()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: This code returns a sorted list of image paths for each digit." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What is PIL? (Python Imaging Library)" ] }, { "cell_type": "code", "execution_count": 273, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PIL.PngImagePlugin.PngImageFile" ] }, "execution_count": 273, "metadata": {}, "output_type": "execute_result" } ], "source": [ "im3_path = threes[1]\n", "im3 = Image.open(im3_path)\n", "type(im3)\n", "#im3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### NumPy array" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The 4:10 indicates we requested the rows from index 4 (included) to 10 (not included) and the same for the columns. NumPy indexes from top to bottom and left to right, so this section is located in the top-left corner of the image. " ] }, { "cell_type": "code", "execution_count": 274, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 29],\n", " [ 0, 0, 0, 48, 166, 224],\n", " [ 0, 93, 244, 249, 253, 187],\n", " [ 0, 107, 253, 253, 230, 48],\n", " [ 0, 3, 20, 20, 15, 0]], dtype=uint8)" ] }, "execution_count": 274, "metadata": {}, "output_type": "execute_result" } ], "source": [ "array(im3)[4:10,4:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: This is how part of the image looks as a NumPy array." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PyTorch tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's the same thing as a PyTorch tensor:\n" ] }, { "cell_type": "code", "execution_count": 275, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 29],\n", " [ 0, 0, 0, 48, 166, 224],\n", " [ 0, 93, 244, 249, 253, 187],\n", " [ 0, 107, 253, 253, 230, 48],\n", " [ 0, 3, 20, 20, 15, 0]], dtype=torch.uint8)" ] }, "execution_count": 275, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tensor(im3)[4:10,4:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: We can also convert the whole image to a tensor and display a slice of it as a pandas DataFrame, color-coded by pixel value:" ] }, { "cell_type": "code", "execution_count": 276, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
     0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
 0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
 1   0   0   0   0   0  29 150 195 254 255 254 176 193 150  96   0   0   0
 2   0   0   0  48 166 224 253 253 234 196 253 253 253 253 233   0   0   0
 3   0  93 244 249 253 187  46  10   8   4  10 194 253 253 233   0   0   0
 4   0 107 253 253 230  48   0   0   0   0   0 192 253 253 156   0   0   0
 5   0   3  20  20  15   0   0   0   0   0  43 224 253 245  74   0   0   0
 6   0   0   0   0   0   0   0   0   0   0 249 253 245 126   0   0   0   0
 7   0   0   0   0   0   0   0  14 101 223 253 248 124   0   0   0   0   0
 8   0   0   0   0   0  11 166 239 253 253 253 187  30   0   0   0   0   0
 9   0   0   0   0   0  16 248 250 253 253 253 253 232 213 111   2   0   0
10   0   0   0   0   0   0   0  43  98  98 208 253 253 253 253 187  22   0
" ], "text/plain": [ "" ] }, "execution_count": 276, "metadata": {}, "output_type": "execute_result" } ], "source": [ "im3_t = tensor(im3)\n", "df = pd.DataFrame(im3_t[4:15,4:22])\n", "df.style.set_properties(**{'font-size':'6pt'}).background_gradient('OrRd')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## BASELINE: Pixel similarity" ] }, { "cell_type": "code", "execution_count": 277, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(6131, 6265)" ] }, "execution_count": 277, "metadata": {}, "output_type": "execute_result" } ], "source": [ "seven_tensors = [tensor(Image.open(o)) for o in sevens]\n", "three_tensors = [tensor(Image.open(o)) for o in threes]\n", "len(three_tensors),len(seven_tensors)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: `sevens` is still a list of paths. In the list comprehension, each path `o` is used to open the image, which is then converted into a tensor (the same is done for `threes`). As a result, `seven_tensors` is a list of tensors." ] }, { "cell_type": "code", "execution_count": 278, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIEElEQVR4nO2bS08TbRuArx6ntJQONBQoYIFgEPCQYBQ0Rl2YuDFGoy504Q/xJ7j1B7hwR+LGjboQSZRo8IQBLIdWC5aDI1Baep525luYzgulKF++TjHv1yvpZmae9u41z9z3c2gNqqpS5R+MBx3A30ZVSBFVIUVUhRRRFVKE+Q/n/80lyFDqYLWHFFEVUkRVSBFVIUVUhRRRFVLEn8pu2VEUhXw+Tz6fR5ZlMpkMqVQKs9mMxWLBbDZjMpm06wVBwGKxYDCUrJJlp+JCstksW1tbrK6uMjs7y/v373nz5g0+nw+v16u9Cpw5c4bm5maMRmNFpFRMiCzLpNNp1tfX+fbtG9++fWN+fp5AIEAoFCKbzZJMJolEIqysrABgNBrp7u6mvr4eq9WK2ax/uBUTEo1GmZiY4MWLFzx69IhUKkUymSSXy5HP51laWuL9+/cYjUaMxl+pzWAwIIoiDoeDlpYWamtrdY+zYkIsFgv19fXYbDaSySTpdJp0Oq2dz+fzJdvNzc3x7t07hoaGMJlMWp7RDVVVf/cqG7Isq4lEQh0eHlZbWlpUp9Op8muu9NuXy+VS29vb1QcPHqh+v1+Nx+PlCqnkd65Y2TUajZjNZlpaWjh9+jRtbW2YTCYtUZrNZux2+667n0ql2Nzc5MePH0iShCzL+sap67tv/yCjEavVSldXF3fv3uXixYsIgqAJqK2tpampCYfDsaNdNpslHo8TDAYZHx8nHo/rG6eu716C2tpaenp6aG9vx+VyIQgCACaTCUEQdoxBtuPxeDhy5Ag1NTW6xldxIU6nk76+PoaGhvD5fLhcLgCsVit1dXVYLJZdbQwGA8ePH+fs2bO6V5qKD8wKOcPtdjM0NIQoilq1kSRpR+U5CCoupIDX6+XWrVuMjIwQi8UIhUKEQqE9ry8M+fXmwCZ3NTU1tLe3c/78eW7cuMG5c+dwu90lc4SqqgQCAfx+P6lUSte4DqyHOBwOHA4HTU1NDAwM4Ha7CYVCLC4u7vrSqqoyPj7O1tYWXq8XURR1i+vAhKTTaeLxOIlEgmg0SiAQYHV1lUQisetag8GA2+2mpaUFq9Wqa1wHJmRrawu/38/y8jLhcJhPnz7x/ft31D32mj0eDz09PbqX3YoJyeVyyLLM2toac3NzzM/PMzMzQywWIxqNMjc3t6cMgKamJrq7u7HZbLrGWVEh8XicsbEx7t+/jyRJ/PjxA0VRUBTlt20NBgM+n4/Ozk5tIKcXFasy+XyeWCyGJEmsra2RSCRQFOW3vaKAqqrMz8/z5cuXf0+VyeVybGxsIEkSP3/+RJblP/aM7czNzWGz2Whra9NGt3pQMSGCINDR0cGFCxcIh8P4/X4+fvy4r0cGIJlMEovFdB+cVUyIzWbDZrPR39/PlStXEASBycnJffUUVVVJJBLE43FyuZyucVa87IqiyODgIKIo0tDQQDabJZvNaudVVUVRFMbHx5menkaW5YoM2QtUXIjdbsdut+N0OmlsbCSTyZDJZLTzBSHpdJpgMKhtWVSKAxuY1dTU4PP5Sk7aFEXh4sWLKIrC69evCQaDrK6u4nA42NjYIJ1OY7FY9lw7+V84MCGCIOw5plBVlRMnThCLxVhYWCAYDLK5ucni4iKxWIxMJoPJZNJFyF+3lVl4ZKampnj8+DEzMzPAP2uyem9YHVgP2YuCkEAgwOjoqHa8sAXxfydkdXWVYDDI/Py8dsxgMDAwMMCxY8doa2vDZrPp8riADkIK+xvb7+J/c0clSWJsbIzv37/vOO71eunv70cUxZLrruWibEIURdGG569evaKhoYGOjg5EUcTtdv+xfS6XI5f
LMT09zdOnT3f1kK6uLk6ePIndbi9XyCUpW1JVVRVZlvn58yfPnj1jdHSUUChEJBLZcxK3fccsl8uRyWRYXFzk8+fPrK+vA79kmEwmmpqaaGxs1LV3QBl7yNbWFs+fP2dycpKXL1/idrtZWlri1KlTXL9+HavVitVqxWKxIAgC6XSaVCqlbXr7/X4+fPjAq1evtJkwwODgIL29vfT19emaOwqUTUgqlWJiYoJAIEA4HCYSiZBOp3E6nUiSpK2hOhwOzGYz2WyWSCRCNBolEonw9u1bRkZGCIVC2nzFaDTS2dnJsWPHaGho0MqunpRNiNPp5OrVq3z69InFxUXW1tZYWFjgyZMnzM7OarmksDYaDof5+vUr8XicaDTK8vIy6+vr2r5MQ0MDoihy6dIlLl++TH19/Y69YL0omxCz2UxrayvxeJy2tjby+TwrKyuEw2FCoRC1tbV4PB6am5s5dOgQX79+xe/3k8lkdkzu4FfecLlc+Hw+Dh8+jMfjqYgMAMMfVqz2/dNuRVHIZrMsLS3x8OFDVFXFarXi9/sZHh7WNrutVqv2G5FkMrkr4dpsNgRB4N69e9y8eROPx4PdbsdgMJRbSMk3K1sPMRqN2Gw26urqaG1tRRAEvF4vsixjt9u1JJnJZDQRpWaxdrsdURQ5ceIEXV1d5Qpv35R9YCaKIrdv3wZ+df3CD+YKQjY3N5EkCb/fz9TUlNauUH3u3LnDtWvXOHr0aLlD2xdlF2KxWBBFUZuTNDc3MzAwoPUGSZJwuVwoioIkSVq7worakSNH6O3txel0lju0fVG2HFKy8bYBV+FzCo9KJpPZsdNfyBGiKFJTU1OJElsyh+gq5C+n+n+Z/VAVUkRVSBFVIUVUhRRRFVLEnwZmlfmTyl9EtYcUURVSRFVIEVUhRVSFFFEVUsR/AP0FXN1zCRLUAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_image(three_tensors[0]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Tip: Show image shows the first tensor as image" ] }, { "cell_type": "code", "execution_count": 279, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 279, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAJHElEQVR4nO2bXXMSZxuAL1jYXYQsiCYxIWIIjImJ0TaVTu2H41FnnGmPPOtMf0NP+i/6H9oDx+OOOtMjW6cdGz/Sk1ZqxkRDhCTQEAjfsLDsvgeWbbMmxgrEzDtcR5ln+bi5eJ5n7/t+iM0wDPr8g/1tB3DY6Aux0BdioS/EQl+IBcc+1/+fb0G23Qb7M8RCX4iFvhALfSEW+kIs9IVY6Aux0BdiYb/ErGMMw6DVatFsNqlUKmiaRrPZpNls0mg0Xnq8LMtIkoSu6+i6jtPpxOFw4HQ6EQQBWZZxOHoXds+FaJpGvV4nHo/zww8/kMlkSKVSxONxnjx5wr/7MTabjffee4+pqSmq1SqqqjIyMsKxY8eIRCIMDw8zOzuLz+frWbxdF6LrujkDisUilUqFbDbLn3/+yeLiIoVCgfX1dTKZDJVKBVEUEUWRer1OrVZjZWUFh8NhCqlUKmxublIulzl+/DgTExM9FWLbp2P2n2sZVVXZ3t5mdXWV77//nlQqxaNHj8jlcqTTaQzDwDAMZFnG7XYzODhIIBDgyZMnPH/+HLvdjt1uN2eOzWbDZrPh8/nwer1cv36daDT6hh93B7vWMh3PEE3TqNVq1Go1UqkUlUqFVCrFysoKT58+pVQqIQgCExMTRKNRnE4nkiQhiiIulwtFUfB6vczOzpLJZFheXubZs2eUy2VqtZr5Po1GA1VV0XW905BfScdCVFUlFovx22+/8c0331CpVGg2m7RaLRqNBoFAgGg0ysWLF7l69Soejwe3220+3263Y7PZ0HUdwzC4ceMG165dIxaLkUwmOw3vP9OxEMMwaDQaVKtVSqUS1WoVXdex2+2IokggEGBubo5z587h8/lwOp04nU7z+e0l8e+lJAgCNtvOGe33+wkGg8iy3GnIr6QrQqrVKrVaDVVV0TQNAFEUOXr0KHNzc3zxxRf4fD4URXnpg7ZpjzudTux2+0vXpqenmZmZQVGUTkN+JR0LcTqdhEIhRFEkn8/TbDaBF0Lcbjdzc3N4vV5EUdxTBrzYizRNI5/Pk8/nzRzFZrMhCALDw8OEw2FcLlenIb+SjoVIksTp06cJh8N88MEH5nh7KQiCgCiK+75Oo9GgVCqxtrZGMpmkWq0CmM+PRCJEo9Ed+08v6FhI+1sXBGHH3tC+Zp3+e7G+vs7PP//M77//TqlUQtM0BEEgHA4zPj7OzMwMJ06ceOk9uk1XErP2bHidmbAXd+7c4auvvkLTNHRdx+FwIIoiFy9eJBqN8u6773LixIluhPtKep66WzEMA13XqdfrlMtlM2FbWFgwN2S73U4kEiESifDhhx8SjUZ7vpm2OXAhuq6jaRrZbJbHjx/z448/cvPmTbLZrHm
7djgcXLhwgY8//phPP/2UkydPHlh8B1LtGoZBsVhkdXXVrGWSySSrq6ssLS2Ry+Wo1+vAi3zD7/dz+vRpzp8/z8DAQK9D3MGBlf/JZJLvvvuO5eVlHjx4gKqqO1LzNkNDQ0xPT3PhwgUmJyd7fpu1ciBLRtd1CoUCjx49Yn19HVVVzXzFSiaTIRaLcevWLeLxOMeOHcPj8TA2NobX62VwcLCnkg5syWxtbfHgwQM0TTM31t3IZDJkMhnW19dxu90oioLH4+HKlSucP3+eS5cuIcvyK5O8Tuh6+f/SC/y9ZDY2Nrh9+zblcplCoUCxWCSXy5mPi8fjLC0tmT0UURRxOp1mv2RqaorR0VE++ugjzpw5w9mzZ/F6vQiC8Nq5joVdjfZciJVarWbKWFtbM8d/+uknfvnlF1ZXV0mn07s+t91Ri0QifP3110xOTiJJEoIgvEkovemH/FecTieKoiDL8o7O19DQEJcuXWJtbY10Ok0mkyGXy3H//n3i8TjwYrYlEgmq1SrPnz9ncHCQ48ePv6mQXTlwIQ6HA4fDgcvlwuv1muPDw8NMT09TrVbNW/TKygrZbNYUArC5uUkulyMejxMKhTh69Gh34+vqq3VAuxBsd9VlWSYYDBKPx/nrr79IJBJsb28DL2bKxsYG8XicU6dOdTWOQ3Mu0y4EJUkye63BYJDZ2VmmpqZ2pO6GYZgzR1XVrsZxaITsRbtwtI55vd6eVL+HXgjAbndCt9uN3+/v6oYKh2gPsVIsFtne3mZhYYGHDx/uyFlsNhunTp0iEol01HLYjUMppF0MJpNJEokEiURiR2YrCAJDQ0P4/f6uH2seOiHVapVKpcK9e/e4c+cOf/zxh3lEATA5Ocn4+DiBQABZlt80S92TQyOk/YHr9TrZbJZYLMb8/DypVMq8ZrfbCQaDRCIRvF4vDoej6zXNoRFSKBRIp9Pcvn2bu3fv8vjxY5LJpNknGRgYwO12c/XqVS5fvszo6Oiu5zed8laFtCthwzDI5/M8ffqUhw8fcuPGDbO3Ci820SNHjjA4OMi5c+cIhUI9kQFvUYiqqqiqysbGBktLS9y9e5f5+XkSiQTNZtNcJi6XC1mW+fLLL/nkk08Ih8M9kwFdFvLvb9yaULXH2383Gg0KhQLxeJz5+XkWFha4d++e+fj2me/AwACKonD27FneeecdPB5Pz2RAF4Vomka1WqVcLrO2toaiKIyMjJiH3pVKha2tLUqlEpubmywvL7O4uGjWKfl8HvgnM41EIoTDYa5cuUI0GiUUCqEoSk9/PQRdFNJqtSiVSmxtbRGLxRgdHUWSJPMXRLlcjuXlZba2tsxl0u6tqqq645RPFEXGx8cJh8NEo1FmZmbMhlGv6ZqQfD7Pt99+SzKZ5P79++Zhd6vVotVqmecwqqqaM6ZSqZgb5/DwMGNjY7z//vtmk/nkyZMoioIkSV3PN/aia0IajQaJRIJnz56xuLj4Wj9ssdlsSJKEJEmMjY0RiUQ4c+YM0WiUiYkJ/H5/t8J7bbomxOPxcPnyZTweD7/++uu+QmRZ5siRI3z++ed89tlnhEIhRkZGcLlcB7Y8dqNrQhwOB8FgkHQ6TSAQMI8l98LlcuHz+ZicnGR2dpahoaEdHbS3RdeazK1WyzxvKRaL+7/x3w0ht9ttdsm6XcrvF8KugwfddT9E9P+j6nXoC7HQF2KhL8TCfrfd3lVRh5T+DLHQF2KhL8RCX4iFvhALfSEW/gcMlBno19ugeQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_image(tensor(im3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: A more direct route: convert the PIL image `im3` straight into a tensor and display it (im3 > tensor > image)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training Set: Stacking Tensors" ] }, { "cell_type": "code", "execution_count": 280, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([6131, 28, 28])" ] }, "execution_count": 280, "metadata": {}, "output_type": "execute_result" } ], "source": [ "stacked_sevens = torch.stack(seven_tensors).float()/255\n", "stacked_threes = torch.stack(three_tensors).float()/255\n", "stacked_threes.shape" ] }, { "cell_type": "code", "execution_count": 281, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Tensor" ] }, "execution_count": 281, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(stacked_sevens)" ] }, { "cell_type": "code", "execution_count": 282, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Tensor" ] }, "execution_count": 282, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(stacked_sevens[0])" ] }, { "cell_type": "code", "execution_count": 283, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "list" ] }, "execution_count": 283, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(seven_tensors)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: `torch.stack` has turned our list of 2D tensors into a single 3D tensor. For the threes, its shape is `[6131, 28, 28]`. `.float()/255` also converts the pixel values from integers (0 to 255) to floats between 0 and 1." ] }, { "cell_type": "code", "execution_count": 284, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": 284, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(stacked_threes.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: This is the rank: the length of the shape, i.e. the number of axes." ] }, { "cell_type": "code", "execution_count": 285, "metadata": {}, "outputs": [ { "data": 
{ "text/plain": [ "3" ] }, "execution_count": 285, "metadata": {}, "output_type": "execute_result" } ], "source": [ "stacked_threes.ndim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: This is more direct way to get it. (ndim)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Mean of threes and sevens our ideal 3 and 7." ] }, { "cell_type": "code", "execution_count": 286, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAJtUlEQVR4nO1b2XLiWhJM7QsChDG22x3h//+qfnKzWVhoX5HmoaNqDufK9jRge2aCiiCEAS0nVUtWlqz0fY+r/dvU776A/za7AiLZFRDJroBIdgVEMv2D7/+fS5Ay9OHVQyS7AiLZFRDJroBI9lFS/RQ7tV1QlME8eFH7dEDkxdPf4udDAMmLVxQFfd8Pfn5JuyggQ4vs+/7o1XUdfy6+l01RFCiKAlX9E9WqqvJn9JJ/fwm7CCDy4ruu423XdTgcDjgcDqiqCm3boq5rtG2LpmlwOBzQti3vr6oqVFWFaZrQNA2WZUHXdViWBU3ToOs6NE3j3xFQZOcCcxYgQ15AIHRdh7Zt0bYtqqpCXdcoigJlWaIoCtR1jTzPUdc1yrLk4+i6DlVVMRqNYFnW0dY0Tdi2DcMwYBgGNE1jEGRgTrWTARHBkD2haRpUVYU8z1EUBcIwRJIk2Gw2CMMQu90OaZoiiiKUZYk8z3E4HNB1HUzThGEY8DwPo9EIi8UCs9kMj4+PmM1mmM/ncF0X4/EYlmUdeY4YYqeCc7aHDAFSliWqqkKSJMjzHLvdDmEYYrPZII5jBEGALMsQRRGyLEOWZRw6dOdnsxk8z0PXdSiKAoqioKoq6LqOw+EAXdfR9z17CYXPUOL9dEDkXCHmiKqqEMcx0jTFdrtFGIZ4fn7Gfr/HZrNBmqYIggBJkiCKIlRVhbIs0bYtDocDn8NxHFiWhYeHB9ze3iIMQ9zc3CDLMtze3qJtW0wmEyiKAsdxjpLvOXZ2UhXzBy1K3LZty+GlaRpM08RoNIKiKNA0jQFpmgZt27Knkaf0fY+6rlHXNYdhlmX8N53DNE2+ji/3EAJiKHfQhdZ1zVVEVVVYlsVxb9s270P7U/WhvNM0DZqmga7rnJgp76iqivl8Dk3TMB6Poes6uq7jkKEbcAowJ4eMaHRiimNd19kTyHNs24Zt27xQApRAITAJkKIokOc5NE07SpbyfqKHXsIuRszoonVdh+M4nOw8z4Pv++wB8j60QAJgv98jjmOuTMRZdP3PpYqe+R4Q31Jl6MQiGADQdR2Tp6ZpOESImYpGn2dZBl3XUVUViqJgPkILo5xD5zEM44ikib87x04CRD45ubVpmnyRXdfBcZyjaiTfTSJvTdPAsiyYpvmPUAFwlBNM02SCRpxlCLxT7SwPES9AVVVehBgKsnuLpZq25BVhGGK/32O/3yNNU+R5zhVLVVUYhnHEXonWEykb6nG+HBDxLhIQcqKTgaBS3Pc9V4+Xlxes12usViu8vLwgiiLs93tesKZpsG0bk8kEvu9jNBrBdV0GhRL6uXYyIDIQiqKg67pBUESeIvKJJE
mYwa7Xa2y3WwRBgN1uhyzLkOc5JpMJl2rXdeF5HsbjMRzHOQpRsRv+FkBEUAgEIlJvtfvU4NHdXy6XWK1W7BUESBiGHIau60LTNDiOg8lkwpTecRw4jnPkHd+WQwgA8b28JRAOhwN3tHEcI4oiDo3lcsmhst1usdlsUBQFqqri5EkVi7jNUHW5VIU5GZD/BBSRxRZFgSRJsN1usVqt8OvXL6zXazw/P+P3799YLpeI4xhJkvAxPc+D53kA/lQxSqhi6y+LRnQt59jFRGbxQihcyDuKouCmbrPZYLPZIAgCLJdLBEGANE1R1zULRCLPAHBUiaiBpLZAZqvnMtazc8iQZirS67qukWUZ4jjGer1mQJbLJZbLJZIkQZIkXIYVRTnyAjpWVVUsFbiue9QMUh/zrSHznhFIIg+h5EpyoO/7WCwWGI1GmEwmDKSmaZxEKUwo7KIoQhAE3NRR9yxrr8A3UnfRZJFZJmZEvy3Lgud5uLu7Y3WNTNRLAcAwDHRdhzzPoaoqXl9foWkabm5uoOs6bNvm4367h7w3SlAUhXOB67po2xY/fvxghlkUBbIs49AiI/BkQbrve65UqqoiiiJomgbXdTnviJ5C1/C3dramKr8XL0am2/P5HI7jwHVdFn1kSk+9DVH3OI4ZOGK1qqpiv9/DsixUVcVeJHriqXaWHjK0FUsxjRNc14VhGDBNE3VdYz6fs2fIgJCC9vr6ijiOmXiRStY0DVet0WiEsixhmib3O8SWvyyHvFVVxO+o+pD7ip0pdbhDs5yu61igtm2bQ4tAAoC6rqFpGqv1ojInMmU69t8Cc5aHyG39EDAU3zRzeav5I05BJZd6HgLGMAxOvqJsQFLkECf58hwytDBxS6CQejaUa2QPIbNtm0uvyErl38tgnGt/BYjsGVQd3tM236PW4gLp+LRYyhUkWFOYDc13L8FQyU7OISIAlBxliVAeWL8Finx8keWSQCQn7KF9ZftS1V2Me3FoTQuStVZRURMBot9TcozjmGn+arXCbrfjkSclTmK7NBCn7ve9pwM+DRBZ6xABaZrmKBcQ42zbli9cFKMVRWFQReEoTVNW37MsQ1EURyFD9F7WU7+NqYqA0BCpbVu+8Kqq+Ht5VkPAkFG1oMZts9mwRvL6+ookSVAUBatjJBT5vo/pdMojTzr2uU3eSUlVBoW8g2a0RVFwDiAKL48OyGhARSradrvFbrdDEAQcKsRGiehRBSJuQ99dwlP+ChBZFBJPTqSqLEsEQcC0m7xIfIaDRo9Ex8uyRJqmSNOU5YA8z1GWJfMQz/MwnU5xf3+Pu7s73N3dYTKZwPM82LY9mEc+HRARGPmkYnIkgTiKIvYcAk0UpEV5kYbYSZLw4xF93zMpcxwHnudhMplgMpnAcRxmwEO66qn214CIYFBiI+2TdFAAR3khDEPUdY04jo9yDok8IuOkhOn7PsbjMR4fH+H7Pp6ennB/f4+npydMp1PMZjMGhfb58pAZAoUSpvwiLxCbsd1ux5WE8g6Vb3J3Yqeu68L3ffi+j/l8jsVigdvb26MwuVQiPRkQOinxCHERNL60bRsAMJ1Ooes69vs9DMNAkiQwDIPlRNI66O7SvGU+n2M8HuP+/h6z2Qw/f/7EYrHAzc0NC88iGEPc5ssAkcERgSHdAwCr5VmWQVGOH4Ui8IjIkUfNZjOMRiPMZjP4vs9PDj08PGA8HnPeoFmMqKxdcgyhfNADDH451NOQuEP6Z9M0XCnSNOVHrUg9pypDi6PRpLilZ0rkofZQiT0BjMEdzgJkiLVSmaXRgchPaEu5gzQTEotN02SSRS+x2x16NvUMr7gcIPzlAFED/tn9fjQ7kZO03JO81aOcGSKDO19ktiv/LbblQ9uPjvfW9q3zXtLO8pCP7BIaxScu/vIe8uEZP/FOfpZd/4FIsisgkn0UMv97Pn+mXT1Esisgkl0BkewKiGRXQCS7AiLZvwBtCZqwAvXF1QAAAABJRU5ErkJggg==\n
", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "mean3 = stacked_threes.mean(0)\n", "show_image(mean3);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: This is the mean of the all tensors through first axis. 'Ideal Three'" ] }, { "cell_type": "code", "execution_count": 287, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAI6klEQVR4nO1baVPiTBc9ZF/IhoOWOjUf5v//KgelNEZICCErvB+eutemjaNC8N04VanG7H1y99uOdrsdzniF8u9+gf80nAmRcCZEwpkQCWdCJGgfHP9fdkGjvp1nCZFwJkTCmRAJZ0IknAmRcCZEwpkQCWdCJJwJkfBRpHoQPlNjkc8ZjXoDxzf47HmHYhBCxMnR74/Gz0IkYDQa8d/y2Hf+ITiKEHmS2+2Wx91uxxv9TcfFY+L1BHGy9FtRFIxGo96RjsvXH4KDCBEnIk6aJt51HbbbLbquQ9d1aNuWR9pP58n3kUET1zQNqqrCMAyoqgrTNKGqKjRNg6IoexvhEGK+TIj44kQCEdA0DZqmQVmWaJoGm80GdV1jvV6jrmvkeY6qqrDZbNA0Deq6Rtu2TFSfJKmqCkVR4LouTNPExcUFXNdFFEVwHAe+78M0TViWxQSNRiOoqordbvdlUr5EiCwZ9KVJAoqiQNM0WK/XKMsSWZahKAokSYLNZoMsy1CWJRPVNA1fS6SKKgaACRmPxzBNE9PpFJ7n4devX/B9H5qmYbfbMRHb7RaKohxExpcI6VMPmgxNcLVaoSgKxHGMNE0xn8+RZRniOEae51gsFlitVlgulywpoiqJ6kQbffUgCOB5Hn7//o0wDJGmKS4vLwEAQRBA0/6ZCqnYyQnpI0ckhlSFJCHLMqRpymOe50iSBOv1GqvVCm3boq7rvfuJIHLatkVVVRiNRmiaBlEUQVEU5HkO13VZ0kjCjsWXVUYkous6NE2DqqpQliXyPEeapojjGMvlEnEcI8sy3N/fY7VaIUkS1HWNsiyZAEVR2DCS1xCfQXamrmvoug7TNFHXNS4vL2EYBvI8h23bLK3vGeeTEPI3kIskEdd1HbquwzAM+L4PVVUBgL+6fA55ETKyy+US6/UaWZZhvV7zMwDskSkSSb9Fd/1VfIqQjxgXyaAJmqYJ27aZhPF4jDAMoaoqu03HcfhcXdehqip7qiRJkGUZ7u7u8PT0xAabDCdBdrnHkPFpQshIyQQoisISsd1uYZomuq5DEARQFAV1XcO2bei6ziqg6zosy4Jt23BdF5ZlwbIslpDNZoOyLHl/nucoioLVgQik0bIsjk2+jRCRiD5CDMMAALiuC1VV0XUdTNOEoiioqgphGLKtoNjB8zy4rssTo3sSIZ7nYTabIc9zrNdrVFWFtm1hWRYcx4Ft27Btm4kjCRNV5tu8jKizAKDrOn89ALBtm41j13WoqgqapsE0Tbiuy5LhOA50XedYYrfb7UWmwD/qRt5IVVU4jgPXdTEejxEEAUuIaJiPwacJER+kKAoHQPTypNuqqrL6ULS42+1YVSzLYsmwLIuJJZUiIk3TZEKqquKYxHEcOI4Dz/M4WqUoVbQjJydEJIaCHjHxAv6RFABMhghR98muEJF0P
bnbuq6RZRkHcuRlDMPAeDyG7/sIw5CjV13X3+Qxh+JglQFeJYV0V9d1Jqxt2z2dFr0P6TtNgK6h2KYsS6RpisVigcViwUEYqVwQBIii6A0hx7rcLxMiehvRsJLEkOEkkogQ2i8TQRCDvDzP8fz8jMfHRzw/PyNNU9R1zaG77/sIggCu68K27T3SAey938mTO3oQPVj0OuRxSBrETJU2MVUX70PGl/Kh+XyOp6cnzGYzpGmKpmmgaRo8z4PneQjDkCWGpGMoHBypytICvNoSsh80igUdoL+EUBQF0jTF/f097u7uEMcx4jhG0zRQVRVhGGIymfBIrvZYFZFxVOjeV84Tv5ZMmLyfJKNtW2w2G86QHx4eMJ/PkaYpezFys7TJrnYoUg42qrItod8A9ryGDLmMQMlekiT48+cPZrMZ5vM5FosFmqZhb3J1dYXLy0tMp1OMx2OOP0TJG4KUoyVEtCVkYGkTiZOLS2RI67pm6Xh4eEAcx3h4eECe5+i6DoZhIIoihGGIi4sLNqhiZErvMgSOtiF9Ve/tdvsm/wH2SaGSY1EUeHl5wWw2w2w2w8vLC6uK53m4vb3Fz58/cX19jZubGwRBwEmhHIx9u9v9G+RIVm5NyG6RCktUR0mSBEmSYLFYoCxLqKoK27bx48cPXFxcYDKZYDKZwHEcmKb5xn70kfBtucxnH0xSIhdtKHCjuuv9/T3HHGVZQtM0RFGEIAhwfX2N29tb3NzcIIoi2LbNkXBfyn+sCg1iQ+S/33sZUWWoClYUBZbLJfI8R5Zl7GZ938d0OkUURZhMJvB9H7Zt73mX98g4BkdLSB8phPdUheqvaZri6ekJSZJwi4LynNvbW0ynU5YQMqaidPTZjW/Ldv8GedLv7SPVIemgEiGRQV5FzGajKILneRyIDW1EZQza7O7zLMCrVxGN6HK5xOPjI6vLdruF4zgIggC2bePq6gpXV1fchxErY3LkKz7/WJx8OURfy4J6Mnmec08HANdIKGchcvq8CtAfFB6LQSVEVg+5b0M9mcVigefnZywWC2w2G4xGI1aJyWSCKIrY1ZLdoJrr0KG6jJOtD/lbD2ez2aAoCpRlia7rOF/RdZ2Lz5TeExGniEr7cJL1IWJoTipSFAVWqxWyLMPLywvyPOe2AnkWUUJEdZELQKfE4CrTJx3Ua6mqipO5uq65qKxpGtsPalGIlfTvIgMYgJC+tSLUZyXpoJ7ver1mQ0pVNbHOats2wjCE7/uwLOvDmOMUGNyG9NkOMqo00nIHsTBMTScav8uIyhhsSdV7RpQWxtBGrQbTNN8siKEikOhqRYP6X6EyIvoW1IgrgyiUp8YUqYSqqiwdhmFwi2KIPstXMWhy99451AR3XZcnKa8CIBsitiffa1mI49A4SRwC7Lc7adLU6gReWw+apnGb0zAM3khy/lbzOAUpow++8KdWnvTZEHHJFdkSMqxt2+6pEEkRkUNumEj5qDJ2IDG9Fw1eMaNqmdinESdMC+xEQoDXxXWiIZXvcYp0/808hpAQPrknYhX391XPel+qJ4EbqiImPqZ355CE8EXvlADeO953ft/EB5aKgwj5v8P530MknAmRcCZEwpkQCWdCJJwJkfAv6ObhbeIGuNEAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "mean7 = stacked_sevens.mean(0)\n", "show_image(mean7);" ] }, { "cell_type": "code", "execution_count": 288, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAJHElEQVR4nO2bXXMSZxuAL1jYXYQsiCYxIWIIjImJ0TaVTu2H41FnnGmPPOtMf0NP+i/6H9oDx+OOOtMjW6cdGz/Sk1ZqxkRDhCTQEAjfsLDsvgeWbbMmxgrEzDtcR5ln+bi5eJ5n7/t+iM0wDPr8g/1tB3DY6Aux0BdioS/EQl+IBcc+1/+fb0G23Qb7M8RCX4iFvhALfSEW+kIs9IVY6Aux0BdiYb/ErGMMw6DVatFsNqlUKmiaRrPZpNls0mg0Xnq8LMtIkoSu6+i6jtPpxOFw4HQ6EQQBWZZxOHoXds+FaJpGvV4nHo/zww8/kMlkSKVSxONxnjx5wr/7MTabjffee4+pqSmq1SqqqjIyMsKxY8eIRCIMDw8zOzuLz+frWbxdF6LrujkDisUilUqFbDbLn3/+yeLiIoVCgfX1dTKZDJVKBVEUEUWRer1OrVZjZWUFh8NhCqlUKmxublIulzl+/DgTExM9FWLbp2P2n2sZVVXZ3t5mdXWV77//nlQqxaNHj8jlcqTTaQzDwDAMZFnG7XYzODhIIBDgyZMnPH/+HLvdjt1uN2eOzWbDZrPh8/nwer1cv36daDT6hh93B7vWMh3PEE3TqNVq1Go1UqkUlUqFVCrFysoKT58+pVQqIQgCExMTRKNRnE4nkiQhiiIulwtFUfB6vczOzpLJZFheXubZs2eUy2VqtZr5Po1GA1VV0XW905BfScdCVFUlFovx22+/8c0331CpVGg2m7RaLRqNBoFAgGg0ysWLF7l69Soejwe3220+3263Y7PZ0HUdwzC4ceMG165dIxaLkUwmOw3vP9OxEMMwaDQaVKtVSqUS1WoVXdex2+2IokggEGBubo5z587h8/lwOp04nU7z+e0l8e+lJAgCNtvOGe33+wkGg8iy3GnIr6QrQqrVKrVaDVVV0TQNAFEUOXr0KHNzc3zxxRf4fD4URXnpg7ZpjzudTux2+0vXpqenmZmZQVGUTkN+JR0LcTqdhEIhRFEkn8/TbDaBF0Lcbjdzc3N4vV5EUdxTBrzYizRNI5/Pk8/nzRzFZrMhCALDw8OEw2FcLlenIb+SjoVIksTp06cJh8N88MEH5nh7KQiCgCiK+75Oo9GgVCqxtrZGMpmkWq0CmM+PRCJEo9Ed+08v6FhI+1sXBGHH3tC+Zp3+e7G+vs7PP//M77//TqlUQtM0BEEgHA4zPj7OzMwMJ06ceOk9uk1XErP2bHidmbAXd+7c4auvvkLTNHRdx+FwIIoiFy9eJBqN8u6773LixIluhPtKep66WzEMA13XqdfrlMtlM2FbWFgwN2S73U4kEiESifDhhx8SjUZ7vpm2OXAhuq6jaRrZbJbHjx/z448/cvPmTbLZrHm7djgcXLhwgY8//phPP/2UkydPHlh8B1LtGoZBsVhkdXXVrGWSySSrq6ssLS2Ry+Wo1+vAi3zD7/dz+vRpzp8/z8DAQK9D3MGBlf/JZJLvvvuO5eVlHjx4gKqqO1LzNkNDQ0xPT3PhwgUmJyd7fpu1ciBLRtd1CoUCjx49Yn19HVVVzXzFSiaTIRaLcevWLeLxOMeOH
cPj8TA2NobX62VwcLCnkg5syWxtbfHgwQM0TTM31t3IZDJkMhnW19dxu90oioLH4+HKlSucP3+eS5cuIcvyK5O8Tuh6+f/SC/y9ZDY2Nrh9+zblcplCoUCxWCSXy5mPi8fjLC0tmT0UURRxOp1mv2RqaorR0VE++ugjzpw5w9mzZ/F6vQiC8Nq5joVdjfZciJVarWbKWFtbM8d/+uknfvnlF1ZXV0mn07s+t91Ri0QifP3110xOTiJJEoIgvEkovemH/FecTieKoiDL8o7O19DQEJcuXWJtbY10Ok0mkyGXy3H//n3i8TjwYrYlEgmq1SrPnz9ncHCQ48ePv6mQXTlwIQ6HA4fDgcvlwuv1muPDw8NMT09TrVbNW/TKygrZbNYUArC5uUkulyMejxMKhTh69Gh34+vqq3VAuxBsd9VlWSYYDBKPx/nrr79IJBJsb28DL2bKxsYG8XicU6dOdTWOQ3Mu0y4EJUkye63BYJDZ2VmmpqZ2pO6GYZgzR1XVrsZxaITsRbtwtI55vd6eVL+HXgjAbndCt9uN3+/v6oYKh2gPsVIsFtne3mZhYYGHDx/uyFlsNhunTp0iEol01HLYjUMppF0MJpNJEokEiURiR2YrCAJDQ0P4/f6uH2seOiHVapVKpcK9e/e4c+cOf/zxh3lEATA5Ocn4+DiBQABZlt80S92TQyOk/YHr9TrZbJZYLMb8/DypVMq8ZrfbCQaDRCIRvF4vDoej6zXNoRFSKBRIp9Pcvn2bu3fv8vjxY5LJpNknGRgYwO12c/XqVS5fvszo6Oiu5zed8laFtCthwzDI5/M8ffqUhw8fcuPGDbO3Ci820SNHjjA4OMi5c+cIhUI9kQFvUYiqqqiqysbGBktLS9y9e5f5+XkSiQTNZtNcJi6XC1mW+fLLL/nkk08Ih8M9kwFdFvLvb9yaULXH2383Gg0KhQLxeJz5+XkWFha4d++e+fj2me/AwACKonD27FneeecdPB5Pz2RAF4Vomka1WqVcLrO2toaiKIyMjJiH3pVKha2tLUqlEpubmywvL7O4uGjWKfl8HvgnM41EIoTDYa5cuUI0GiUUCqEoSk9/PQRdFNJqtSiVSmxtbRGLxRgdHUWSJPMXRLlcjuXlZba2tsxl0u6tqqq645RPFEXGx8cJh8NEo1FmZmbMhlGv6ZqQfD7Pt99+SzKZ5P79++Zhd6vVotVqmecwqqqaM6ZSqZgb5/DwMGNjY7z//vtmk/nkyZMoioIkSV3PN/aia0IajQaJRIJnz56xuLj4Wj9ssdlsSJKEJEmMjY0RiUQ4c+YM0WiUiYkJ/H5/t8J7bbomxOPxcPnyZTweD7/++uu+QmRZ5siRI3z++ed89tlnhEIhRkZGcLlcB7Y8dqNrQhwOB8FgkHQ6TSAQMI8l98LlcuHz+ZicnGR2dpahoaEdHbS3RdeazK1WyzxvKRaL+7/x3w0ht9ttdsm6XcrvF8KugwfddT9E9P+j6nXoC7HQF2KhL8TCfrfd3lVRh5T+DLHQF2KhL8RCX4iFvhALfSEW/gcMlBno19ugeQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "a_3 = stacked_threes[1]\n", "show_image(a_3);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Distance between the ideal three and other threes" ] }, { "cell_type": "code", "execution_count": 289, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1114), tensor(0.2021))" ] }, "execution_count": 289, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dist_3_abs = (a_3 - mean3).abs().mean()\n", "dist_3_sqr = ((a_3 - mean3)**2).mean().sqrt()\n", "dist_3_abs,dist_3_sqr" ] }, { "cell_type": "code", "execution_count": 290, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1586), tensor(0.3021))" ] }, "execution_count": 290, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dist_7_abs = (a_3 - mean7).abs().mean()\n", "dist_7_sqr = ((a_3 - mean7)**2).mean().sqrt()\n", "dist_7_abs,dist_7_sqr" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: Next we need to measure the distance between the 'ideal' three and an ordinary three. There are two common ways to do this: the __L1 norm__ (mean absolute difference) and the __RMSE__ (root mean squared error, i.e. the square root of the MSE). RMSE penalizes bigger mistakes more heavily, while L1 treats all differences uniformly.\n", "\n", "`a_3` is clearly closer to the ideal 3 than to the ideal 7 (with both L1 and RMSE), so our approach works for this example." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PyTorch L1 and MSE functions" ] }, { "cell_type": "code", "execution_count": 291, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.1586), tensor(0.3021))" ] }, "execution_count": 291, "metadata": {}, "output_type": "execute_result" } ], "source": [ "F.l1_loss(a_3.float(),mean7), F.mse_loss(a_3,mean7).sqrt()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: `F` is `torch.nn.functional`. `F.mse_loss` returns the mean squared error, so we take `.sqrt()` manually to get the RMSE." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Important: (from notebook) If you don't know what C is, don't worry as you won't need it at all. In a nutshell, it's a low-level (low-level means more similar to the language that computers use internally) language that is very fast compared to Python. To take advantage of its speed while programming in Python, try to avoid as much as possible writing loops, and replace them by commands that work directly on arrays or tensors."
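] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Important note above recommends replacing Python loops with operations on whole tensors. Below is a minimal sketch of that advice. It uses two random 28x28 tensors as hypothetical stand-ins for `a_3` and `mean7` (the names `img`, `ideal` and `l1_loop` are illustrative, not from the book), computes the L1 distance once with a Python loop and once with tensor operations, and checks the L1 and RMSE values against `F.l1_loss` and `F.mse_loss(...).sqrt()`. Both versions should agree up to float rounding; the loop is typically much slower." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Hypothetical stand-ins for a_3 and mean7: two random 28x28 'images' in [0, 1)\n",
"g = torch.Generator().manual_seed(0)\n",
"img = torch.rand(28, 28, generator=g)\n",
"ideal = torch.rand(28, 28, generator=g)\n",
"\n",
"def l1_loop(a, b):\n",
"    # Slow: visits every pixel from Python, one at a time\n",
"    total = 0.0\n",
"    for i in range(a.shape[0]):\n",
"        for j in range(a.shape[1]):\n",
"            total += abs(a[i, j].item() - b[i, j].item())\n",
"    return total / a.numel()\n",
"\n",
"# Fast: single tensor expressions, executed by PyTorch's optimized C++ code\n",
"l1_vec = (img - ideal).abs().mean()\n",
"rmse_vec = ((img - ideal) ** 2).mean().sqrt()\n",
"\n",
"print(l1_loop(img, ideal), l1_vec.item(), F.l1_loss(img, ideal).item())\n",
"print(rmse_vec.item(), F.mse_loss(img, ideal).sqrt().item())"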
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Array and Tensor Examples" ] }, { "cell_type": "code", "execution_count": 292, "metadata": {}, "outputs": [], "source": [ "data = [[1,2,3],[4,5,6]]\n", "arr = array(data)\n", "tns = tensor(data)" ] }, { "cell_type": "code", "execution_count": 293, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1, 2, 3],\n", " [4, 5, 6]])" ] }, "execution_count": 293, "metadata": {}, "output_type": "execute_result" } ], "source": [ "arr # numpy" ] }, { "cell_type": "code", "execution_count": 294, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 2, 3],\n", " [4, 5, 6]])" ] }, "execution_count": 294, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns # pytorch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Slicing, adding, multiplying tensors" ] }, { "cell_type": "code", "execution_count": 295, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([2, 5])" ] }, "execution_count": 295, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns[:,1]" ] }, { "cell_type": "code", "execution_count": 296, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([5, 6])" ] }, "execution_count": 296, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns[1,1:3]" ] }, { "cell_type": "code", "execution_count": 297, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[2, 3, 4],\n", " [5, 6, 7]])" ] }, "execution_count": 297, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns+1" ] }, { "cell_type": "code", "execution_count": 298, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'torch.LongTensor'" ] }, "execution_count": 298, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns.type()" ] }, { "cell_type": "code", "execution_count": 299, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1.5000, 3.0000, 4.5000],\n", " [6.0000, 7.5000, 9.0000]])" ] }, "execution_count": 299, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tns*1.5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Validation set: Stacking Tensors" ] }, { "cell_type": "code", "execution_count": 300, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))" ] }, "execution_count": 300, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid_3_tens = torch.stack([tensor(Image.open(o)) \n", " for o in (path/'valid'/'3').ls()])\n", "valid_3_tens = valid_3_tens.float()/255\n", "valid_7_tens = torch.stack([tensor(Image.open(o)) \n", " for o in (path/'valid'/'7').ls()])\n", "valid_7_tens = valid_7_tens.float()/255\n", "valid_3_tens.shape,valid_7_tens.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Manual L1 distance function" ] }, { "cell_type": "code", "execution_count": 301, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.1114)" ] }, "execution_count": 301, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def mnist_distance(a,b): return (a-b).abs().mean((-1,-2))\n", "mnist_distance(a_3, mean3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### This is broadcasting:" ] }, { "cell_type": "code", "execution_count": 302, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([0.1270, 0.1254, 0.1114, ..., 0.1494, 0.1097, 0.1365]),\n", " torch.Size([1010]))" ] }, "execution_count": 302, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid_3_dist = mnist_distance(valid_3_tens, mean3)\n", "valid_3_dist, valid_3_dist.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: I think this is an example of avoiding loops, which would slow the process down (see the Important note above). Although the shapes of the two tensors don't match (`valid_3_tens` is `[1010, 28, 28]` while `mean3` is `[28, 28]`), our function still works: PyTorch broadcasts `mean3` across all 1010 images, as if it had been copied 1010 times, without actually allocating those copies. `mean((-1,-2))` then averages over the last two axes, leaving one distance per image."
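] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see what broadcasting does here, below is a minimal sketch with small random tensors as hypothetical stand-ins for `valid_3_tens` and `mean3` (the names `batch`, `ideal_img` and the batch size of 5 are illustrative). It checks that subtracting a `[28, 28]` tensor from a `[5, 28, 28]` tensor gives a `[5, 28, 28]` result, that `mean((-1, -2))` then leaves one distance per image, and that the result matches an explicit loop. `expand()` is used only to show that the broadcast axis has stride 0, i.e. the same memory is reused rather than copied." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"# Hypothetical stand-ins: a batch of 5 'images' and one 28x28 'ideal' image\n",
"g = torch.Generator().manual_seed(0)\n",
"batch = torch.rand(5, 28, 28, generator=g)\n",
"ideal_img = torch.rand(28, 28, generator=g)\n",
"\n",
"# Broadcasting: ideal_img (28, 28) behaves as if it were (5, 28, 28)\n",
"diff = batch - ideal_img\n",
"print(diff.shape)  # torch.Size([5, 28, 28])\n",
"\n",
"# Averaging over the last two axes leaves one distance per image\n",
"dists = diff.abs().mean((-1, -2))\n",
"print(dists.shape)  # torch.Size([5])\n",
"\n",
"# Same result with an explicit Python loop over the batch\n",
"loop_dists = torch.stack([(im - ideal_img).abs().mean() for im in batch])\n",
"print(torch.allclose(dists, loop_dists))\n",
"\n",
"# expand() shows why no copies are needed: the new leading axis has stride 0\n",
"expanded = ideal_img.expand(5, 28, 28)\n",
"print(expanded.stride())  # (0, 28, 1)"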
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Here is another example where the shapes don't match:__" ] }, { "cell_type": "code", "execution_count": 303, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([2, 3, 4])" ] }, "execution_count": 303, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tensor([1,2,3]) + tensor(1)" ] }, { "cell_type": "code", "execution_count": 304, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1010, 28, 28])" ] }, "execution_count": 304, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(valid_3_tens-mean3).shape" ] }, { "cell_type": "code", "execution_count": 305, "metadata": {}, "outputs": [], "source": [ "def is_3(x): return mnist_distance(x,mean3) < mnist_distance(x,mean7)" ] }, { "cell_type": "code", "execution_count": 306, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(True), tensor(1.))" ] }, "execution_count": 306, "metadata": {}, "output_type": "execute_result" } ], "source": [ "is_3(a_3), is_3(a_3).float()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Here is another example of broadcasting, over the whole validation set:" ] }, { "cell_type": "code", "execution_count": 307, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([True, True, True, ..., True, True, True])" ] }, "execution_count": 307, "metadata": {}, "output_type": "execute_result" } ], "source": [ "is_3(valid_3_tens)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Accuracy of our 'ideal' 3 and 7" ] }, { "cell_type": "code", "execution_count": 308, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.9168), tensor(0.9854), tensor(0.9511))" ] }, "execution_count": 308, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_3s = is_3(valid_3_tens).float().mean()\n", "accuracy_7s = (1 - is_3(valid_7_tens).float()).mean()\n", "\n", "accuracy_3s,accuracy_7s,(accuracy_3s+accuracy_7s)/2" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "## STOCHASTIC GRADIENT DESCENT (SGD)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Arthur Samuel's machine learning process:__" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Initialize the weights.\n", "- For each image, use these weights to predict whether it appears to be a 3 or a 7.\n", "- Based on these predictions, calculate how good the model is (its loss).\n", "- Calculate the gradient, which measures, for each weight, how changing that weight would change the loss (SGD).\n", "- Step (that is, change) all the weights based on that calculation.\n", "- Go back to step 2, and repeat the process.\n", "- Iterate until you decide to stop the training process (for instance, because the model is good enough or you don't want to wait any longer).\n" ] }, { "cell_type": "code", "execution_count": 309, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 309, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#id gradient_descent\n", "#caption The gradient descent process\n", "#alt Graph showing the steps for Gradient Descent\n", "gv('''\n", "init->predict->loss->gradient->step->stop\n", "step->predict[label=repeat]\n", "''')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient descent (GD) example" ] }, { "cell_type": "code", "execution_count": 310, "metadata": {}, 
"outputs": [], "source": [ "def f(x): return x**2" ] }, { "cell_type": "code", "execution_count": 311, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEMCAYAAADeYiHoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyKklEQVR4nO3dd3yV9fn/8deVTSYrCTMJYQ9ZBmS5SquiVWqhCuIsiqNaW6t+bdVqXb9qd6uoKIrgqKMobq1VVBCRgKwgG5IwEhJG9s71++Mc2hhPIAk5932SXM/H43547nM+Ofeb25Nz5XOPz0dUFWOMMaa+ILcDGGOMCUxWIIwxxvhkBcIYY4xPViCMMcb4ZAXCGGOMTyFuB2gpXbt21ZSUFLdjGGNMq7J69ep8VY339VqbKRApKSmkp6e7HcMYY1oVEcls6DU7xGSMMcYnKxDGGGN8sgJhjDHGJysQxhhjfLICYYwxxifHC4SI9BeRchF5voHXRUQeFpGD3uVhERGncxpjTHvnxmWujwGrjvH6HOBHwAhAgX8Du4An/J7MGGPMfznagxCRGcAR4D/HaHYF8CdV3aOqe4E/AVf6K9Pa7CM8/P5mf729Mcb4jary4DubyNhX4Jf3d6xAiEgscB9wy3GaDgXW1Vlf533O13vOEZF0EUnPy8trVq4Ne47w+NIdbNzrnx1sjDH+8uXOQzz1+S625BT55f2d7EHcD8xX1T3HaRcN1P22LgCifZ2HUNV5qpqmqmnx8T7vFD+uC0b2JDwkiH+uymrWzxtjjFteXpVFTEQIU4Z198v7O1IgRGQk8H3gL41oXgzE1lmPBYrVT1PfxXUI5dyTurNk7T7KKmv8sQljjGlxBaVVvLcxh6kje9AhLNgv23CqB3EGkAJkiUgOcCswTUTW+GibgecE9VEjvM/5zUVpvSkqr+a9jfv9uRljjGkxS9btpaK6lhljkvy2DacKxDygLzDSuzwBvAOc7aPtQuAWEekpIj2AXwEL/BluXGpnUrpE8vKqbH9uxhhjWszLq7IZ0j2WYT3j/LYNRwqEqpaqas7RBc9hpHJVzRORU0WkuE7zJ4G3gA3ARjyF5El/5hMRfpLWm5W7DrEzr/j4P2CMMS7auLeAjH2FzBjb26/bceVOalW9V1Uv9T7+XFWj67ymqnq7qnb2Lrf76/xDXdNP7kVwkPBK+vHOoRtjjLv+uSqL8JAgpo7o6dft2FAbXomxEZw5MJ7XVu+hqqbW7TjGGONTWWUNS9buY8qwbsRFhvp1W1Yg6rh4TBL5xRV8vPmA21GMMcandzfsp6i8movG+PfwEliB+JYzB8aTGBvOS1/ZPRHGmMD00ldZ9OkaxfjULn7flhWIOkKCg7gorTefbs1j75Eyt+MYY8y3bMstIj3zMDPG9MaJMUytQNRzUZqn22aXvBpjAs1LX2UTGixMO7mXI9uzAlFP786RnNY/nlfTs6m2k9XGmABRXlXD4q/3cNbQbnSNDndkm1YgfJg5tjf7C8r5dGvzBgA0xpiW9kFGDkdKq5jpxzun67MC4cPkwYl0jQ7npa/sMJMxJjC89FUWSZ0jmdDX/yenj7IC4UNocBAXpfXi48255BSUux3HGNPO7cwr5sudh5gxtjdBQc5NsGkFogEzxiRRq3ay2hjjvpe+yiIkSJju0Mnpo6xANCCpSySnDYjnn6uy7GS1McY15VU1vLp6D2cP7UZCTISj27YCcQyzTklif0E5n2yxk9XGGHe8t3E/R0qruOQU505OH2UF4hgmD0ogMTacF1Zmuh3FGNNOvfClc3dO12cF4hhCgoO4eEwSn27NI/tQqdtxjDHtzOacQtIzD3
PJ2CRHT04fZQXiOGaM6Y2Ajc9kjHHciyuzCAsOcuzO6focKxAi8ryI7BeRQhHZKiJXN9DuShGpEZHiOssZTuWsr0fHDnxvUCKvpO+hstpOVhtjnFFaWc3ra/Zy7knd6BwV5koGJ3sQ/w9IUdVY4ALgARE5uYG2K1Q1us6y1LGUPswa5xkG/IOMHDdjGGPakTfX7qOooppLTkl2LYNjBUJVM1S14uiqd+nr1PZPxGn94+nduQPPf2knq40x/qeqLFyRycDEGMakdHIth6PnIERkroiUApuB/cC7DTQdJSL53kNRd4tISAPvN0dE0kUkPS/Pf5eiBgcJs05JZuWuQ2zNLfLbdowxBuDr7CNs2l/IpeOTHRnWuyGOFghVvQGIAU4FFgMVPpp9BgwDEoBpwEzgtgbeb56qpqlqWnx8vH9Ce12U1puwkCDrRRhj/O75FZlEh4dw4Sj/zjl9PI5fxaSqNaq6DOgFXO/j9Z2quktVa1V1A3AfMN3pnPV1jgrjhyd1Z/GavRRXVLsdxxjTRh0qqeTt9fv58eieRIf7PHjiGDcvcw2hcecgFHCvj1XHpeOTKa6o5o2v97odxRjTRr2Snk1lTS2XjnPv5PRRjhQIEUkQkRkiEi0iwSJyNp5DR//x0XaKiCR6Hw8C7gaWOJHzeEb17sjQHrE8/2Umqup2HGNMG1NTq7ywMpNT+nRmQGKM23Ec60EonsNJe4DDwB+BX6jqmyKS5L3X4ehAI5OB9SJSguck9mLgIYdyHpOIcNm4ZDbneOaFNcaYlvTZ1jyyD5Vx2Xj3ew/gOczjd6qaB5zewGtZQHSd9VuBW53I1RwXjOzBQ+9+w3Nf7GZMSme34xhj2pDnVuwmPiacs4Z0czsKYENtNFlkWAgXpfXm/Y055BbaZELGmJaxO7+EpVvymHVKEmEhgfHVHBgpWpnLxidTo8oLK218JmNMy1i4IpPQYHFlWO+GWIFohuQuUZw5MIEXV2bZ+EzGmBNWUlHNq6uzmTKsu+OTAh2LFYhmunx8MvnFFby3cb/bUYwxrdzrX++lqLyaKyYExsnpo6xANNNp/ePp0zWK577Y7XYUY0wr5hl3aTfDesYyOsm9cZd8sQLRTEFBnkte12QdYcOeArfjGGNaqRU7D7I1t5jLx6e4Ou6SL1YgTsD0tF5EhgXz7Be73I5ijGmlFizfTafIUC4Y0cPtKN9hBeIExEaEMv3kXry9bj95Rb7GHTTGmIZlHyrlo29ymTk2iYjQYLfjfIcViBN0xYQUKmtqedEueTXGNNHCFbs9IzQEyJ3T9VmBOEF946M5Y2A8z6/MtEtejTGNVlJRzT9XZTNlWDe6x3VwO45PViBawJUTUsgrquCdDfvcjmKMaSUWr9lDUXk1V01McTtKg6xAtIDT+seTGh/Fs8t32yivxpjjqq1Vnv1iN8N7xQXcpa11WYFoAUFBwlUTUli/p4A1WTbKqzHm2D7fns/OvBKumhh4l7bWZQWihfx4dC9iI0J4Zvlut6MYYwLcM8t2ER8TznknBd6lrXVZgWghUeEhzBybxHsb9pN9qNTtOMaYALUtt4hPt+Zx+bjkgBm1tSGOpROR50Vkv4gUishWEbn6GG1/KSI53rbPiEi4UzlPxBUTPN1FG37DGNOQZ5bvIjwkiFkBMKXo8ThZvv4fkKKqscAFwAMicnL9Rt7pSO/AM7NcMpAK/M7BnM3Wo2MHzj2pOy+vyqaovMrtOMaYAHOwuIJ/rdnLj0f3onNUmNtxjsuxAqGqGap69HZj9S59fTS9ApjvbX8YuB+40pmUJ272pD4UVVTzSvoet6MYYwLMC94pAmZPSnE7SqM4egBMROaKSCmwGdiPZ87p+oYC6+qsrwMSRaSLj/ebIyLpIpKel5fnl8xNNbJ3R9KSO/Hs8l3U1Nolr8YYj4rqGhauyOSMgfH0S4hxO06jOFogVPUGIAY4FVgM+BrAKBqoOzzq0cff2aOqOk9V01Q1LT4+vqXjNtvsSX
3Yc7iMDzNy3I5ijAkQb67dR35xBbMn9XE7SqM5fgpdVWtUdRnQC7jeR5NiILbO+tHHRf7O1lLOGtqN3p078PQyG+XVGOOZ8+Hpz3cxMDGGSf26uh2n0dy8xioE3+cgMoARddZHALmqetCRVC0gOEiYPbEPqzMPszrzkNtxjDEu+2xbPltyi7jmtNSAvjGuPkcKhIgkiMgMEYkWkWDvlUozgf/4aL4QmC0iQ0SkI3AXsMCJnC3pJ2m9iesQyrzPdrodxRjjsqc+20libHhAzvlwLE71IBTP4aQ9wGHgj8AvVPVNEUkSkWIRSQJQ1feBR4BPgCwgE7jHoZwtJio8hMvGJfPhplx25Ze4HccY45KNewtYtj2fqyb2Cfgb4+pzJK2q5qnq6araUVVjVfUkVX3K+1qWqkaralad9n9W1URv26vqXB7bqlw+IZnQoCDmL7NehDHt1dOf7yQqLJiZY5PcjtJkrauctTIJMRFcOKonr6bv4WBxq6xxxpgTsO9IGW+t38+MsUnEdQh1O06TWYHws6tP7UNFdS2Lvsx0O4oxxmHPLvdcyRjIcz4cixUIP+ufGMPkQQksXJFJWWWN23GMMQ4pKK3ixZVZ/HB4d3p1inQ7TrNYgXDAdWf05VBJJa+kZ7sdxRjjkOdXZlJSWcO1p/m6mr91sALhgDEpnTk5uRNPfb6T6hqbt9qYtq68qoZnl+/itAHxDOkRe/wfCFBWIBxy3el92XO4jHc27Hc7ijHGz15bvYf84kquOz3V7SgnxAqEQyYPSqB/QjRPfLrT5q02pg2rqVWe+nwnI3rFMT71O2OMtipWIBwSFCTMOS2Vb/YX8unWwBh51hjT8t7buJ/Mg6Vcd3rfVjWshi9WIBw0dWRPusdF8PjSHW5HMcb4garyxKc76NM1irOGdnM7zgmzAuGgsJAgrj41lZW7DrE687DbcYwxLeyzbfls3FvIdaenEhzUunsPYAXCcTPH9qZTZCiPL93udhRjTAub+8l2usdFcOGoXm5HaRFWIBwWGRbCTyf24aNvDvDN/kK34xhjWkj67kOs3HWIa05NbXWD8jWkbfwrWpnLx6cQHR5i5yKMaUPmLt1B56gwZozt7XaUFmMFwgVxkaHMGpfE2+v3sduGAjem1cvYV8DHmw9w1YQUIsNC3I7TYpyaMChcROaLSKaIFInIWhGZ0kDbK0WkxjtHxNHlDCdyOmn2pD6EBAfxxKfWizCmtXt86Q6iw0O4fHyK21FalFM9iBAgGzgdiMMzS9wrIpLSQPsV3jkiji5LnYnpnISYCGaM6c2/1uxh75Eyt+MYY5pp+4Fi3tmwn0vHJRMX2fqG9D4WpyYMKlHVe1V1t6rWqurbwC7gZCe2H6iuPd0ziNeT1oswptWa+8l2wkOCuPrUPm5HaXGunIMQkURgAJDRQJNRIpIvIltF5G4R8XlQT0TmiEi6iKTn5bW+u5N7duzAtNG9+OeqbA4UlrsdxxjTRJkHS1iybh+XnpJM1+hwt+O0OMcLhIiEAi8Az6nqZh9NPgOGAQnANGAmcJuv91LVeaqapqpp8fHx/orsVzec0Y+aWuXJz2xaUmNam7mf7CDYO4xOW+RogRCRIGARUAnc6KuNqu5U1V3eQ1EbgPuA6Q7GdFRSl0imjuzBCyszybdpSY1pNfYcLuVfa/Ywc0xvEmIj3I7jF44VCPGMWjUfSASmqWpVI39UgdZ/z/ox/OzMflRU1/LU59aLMKa1eOLTHYj871xiW+RkD+JxYDBwvqo2eNmOiEzxnqNARAYBdwNLnInojr7x0fxweA8WrcjkUEml23GMMcexv6CMV1btYfrJvenRsYPbcfzGqfsgkoFrgZFATp37G2aJSJL3cZK3+WRgvYiUAO8Ci4GHnMjppp9/rx9lVTXMs3MRxgS8uZ/soFaVn53ZdnsP4Lk/we9UNZNjHyaKrtP2VuBWv4cKMP0TY/jh8B4sXLGba07tQ5c2eEWEMW3BviNlvLwqm5+k9a
ZXp0i34/iVDbURQG6e7OlFPPX5LrejGGMaMHfpdpS233sAKxABpV9CDOd7exEH7YomYwJOe+o9gBWIgPNzby9inl3RZEzAeewTzzwuPzuzn8tJnGEFIsD0S4jhgthKFv7nG/KiO0NKCrzwgtuxjGn39hwu5ZV0T++hZxu+cqkuKxCB5oUXuPnR26gICuGJU6ZBZibMmWNFwhiX/eM/2xERbvpe++g9gBWIwHPnnaTu28GPMz5m0ahzyYnuAqWlcOedbiczpt3alV/Ca2v2cMnYJLrHtY/eA1iBCDxZWQDcvPyf1EoQj46/6FvPG2Oc97ePthIaLNzQDq5cqssKRKBJ8twv2Lsgl4vXf8jLI84iOzbhv88bY5y1NbeIJev2ccWEFBJi2uaYSw1pVIEQkUgRGSUiMT5em9jysdqxBx+ESM/lczeueBlR5R+nXep53hjjuL9+tJWosBCuO6199R6gEQVCRMYCmcBSIFdEbq/X5D0/5Gq/Zs2CefMgOZnuxYe4dMcyXht6JjvOmup2MmPanY17C3h3Qw4/nZhCp6gwt+M4rjE9iD8Bv1HVOGACcKmIPFHn9TY90qorZs2C3buhtpYbFj1ERFgIf/5wq9upjGl3HvlgCx0jQ7m6jc73cDyNKRDDgKcBVHUtMAkYJCILvfM7GD/qGh3O1aem8s6G/WzYU+B2HGPajRU7DvLZ1jx+dkY/YiPa1lzTjdWYL/hS4L/TtalqIXAOnhnfXsN6EH53zal96BQZyiMf+JqAzxjT0lSVRz7YTLfYCC4bn+x2HNc0pkB8ClxS9wlVLQcuAEKB9nNRsEtiIkK54Yx+fL4tny925Lsdx5g279+bcvk66wg3f78/EaHBbsdxTWMKxM34mLBHVSuBC4EzWzqU+a7LxifTPS6CR97fgqq6HceYNqumVvnjh1tI7RrFT07u5XYcVx23QKhqHjD86LqITK3zWrWqfna89xCRcBGZLyKZIlIkImtFZMox2v9SRHJEpFBEnhGRdj85QkRoML/4fn/WZh/hg4wct+MY02YtXrOHrbnF3HLWAEKC2/dp1sb+6xNF5CoRuQLPnNJNFQJkA6cDccBdwCsiklK/oYicDdyBZ2a5ZCAV+F0zttnmTBvdi34J0Tzy/haqamrdjmNMm1NeVcOf/72VEb3iOO+k7m7HcV1j7oM4DdgKXONdtnifazRVLVHVe1V1t6rWqurbwC7gZB/NrwDmq2qGqh4G7geubMr22qqQ4CDuOGcQO/NL+OeqbLfjGNPmPLt8N/sLyrljymBE7PqbxvQg+uD5Sz4CiPQ+7nMiGxWRRGAAkOHj5aHAujrr6/D0YLr4eJ85IpIuIul5eXknEqnVmDw4gbEpnfnbR9soqah2O44xbcbhkkrmLt3O9wYlML7vd75u2qXGnIN4DigAXgCeBwq9zzWLiIR63+s5VfV13Wa0d3tHHX38nWE+VHWeqqapalp8fHz9l9skEeHX5w4iv7iCp2xSIWNazGOfbKekopr/O2eQ21ECRmPPQcQDfwX+Tp17IprKe2PdIqASuLGBZsVAbJ31o4+LmrvdtmZUUifOPakb8z7byYGicrfjGNPqZR8qZeGKTKaN7sXAbt/5W7TdamyBEP53Q1yzDsyJ54DefDwnuaepalUDTTOAEXXWRwC5qnqwOdttq24/exBVNbX85d82BIcxJ+rh9zcTFAS3nDXA7SgBpbEFIg/P/RA3AQeaua3HgcHA+apadox2C4HZIjJERDriueJpQTO32WaldI3isnEpvLwqm805hW7HMabVWp15mLfX72fOqantajKgxmjMVUyX4znMcylwGRDrfa7RRCQZuBYYCeSISLF3mSUiSd7HSQCq+j7wCPAJkIVnJNl7mrK99uLnk/sRExHKg+98YzfPGdMMqsoD72wiPiaca09vf8N5H09II9pkev9bCmid9UZT1UyOfWgqul77PwN/bup22puOkWH8fHJ/7n97E0u35nHmwAS3IxnTqr
yzYT9fZx3hkWnDiQpvzNdh+9KYq5g+xXNJ6tPeZYD3ORMALhuXTEqXSB565xuq7eY5YxqtvKqGh9/fzKBuMUxr50NqNKSx5yAOADtVdQGe8xH/JSIzWzqUabywkCDumDKYbQeKeekrm7famMZ6dvlusg+Vcdd5QwgOspvifGlUgVDVJcCrIvIw8A6AiHQUkZexYTBcd/bQRManduHP/97KkdJKt+MYE/AOFJXz6Mfb+MGQRCb17+p2nIDVlJGoRnqXVSIyG9gAHAFGtXgq0yQiwm/PH0JBWRV//Wib23GMCXh/eH8LlTW13HnuYLejBLRGFwhV3Qf8yPsz84D3VPVaVS3xUzbTBIO7xzJzbBKLvsxkW67dU2hMQ9ZlH+HV1Xv46aQ+pHSNcjtOQGt0gRCRkcAqYCcwFfieiLzovVfBBIBbfjCAqLBg7nt7k132aowPqsrv3sqga3Q4N57Zz+04Aa8ph5j+A/xFVX/kHY11BFCG51CTCQBdosP5xfcH8Pm2fD76prn3MxrTdi1Zu481WUe4/ZyBxLTTeaaboikFYoyqzj+64h3Cezbws5aPZZrrsvHJ9E+I5r63MyivqnE7jjEBo6i8igff/YYRveKYPtoua22MppyD8Dl0qKq+2XJxzIkKDQ7idxcMJftQGU9+aqO9GnPU3/+zjfziCu6bOowgu6y1Udr3fHpt1IR+XTlveHfmLt1O9qFSt+MY47ptuUU8u3w3F6f1ZkTvjm7HaTWsQLRRd503mCARHnhnk9tRjHGVqnLvWxlEhgVz29kD3Y7TqliBaKO6x3Xgpsn9+CAjl6Vb7IS1ab/e3ZDD8u0Hue3sgXSJDnc7TqtiBaINu3pSKqnxUdzzpp2wNu1TUXkV972dwdAesVxySrLbcVodKxBtWFhIEA9MHUbmwVLmfrLd7TjGOO7P/97KgaIKHrzwJBtvqRmsQLRxE/p15Ucje/DEpzvZkVfsdhxjHLNxbwHPfbGbWackMdJOTDeLYwVCRG4UkXQRqRCRBcdod6WI1NSZVKhYRM5wKmdbdOd5QwgPDeLuNzbaHdamXaitVe56YyOdo8K47exBbsdptZzsQewDHgCeaUTbFaoaXWdZ6t9obVt8TDi3nzOIL3YcZMnafW7HMcbvXvwqi7XZR7jrvCHEdbA7ppvLsQKhqotV9Q3goFPbNP9zyVhPN/v+tzdxuMSGBDdtV25hOQ+/t5mJ/bowdWQPt+O0aoF6DmKUiOSLyFYRuVtEfM4FKCJzvIet0vPy8nw1MV7BQcLvp51EQZlnuAFj2qp7lmRQWVPLgz86CRE7MX0iArFAfAYMAxKAacBM4DZfDVV1nqqmqWpafHy8gxFbp0HdYrn29FReW72H5dvz3Y5jTIv7ICOH9zNyuPn7/W0o7xYQcAVCVXeq6i5VrVXVDcB9wHS3c7UVN32vPyldIvnN6xvs3gjTphSVV3HPkgwGdYvhmlNT3Y7TJgRcgfBBAesntpCI0GAe+vFJZB4s5S8fbXU7jjEt5uH3N5NbVM7vpw0nNLg1fLUFPicvcw0RkQggGAgWkQhf5xZEZIqIJHofDwLuBpY4lbM9mNC3KzPG9Oapz3ayLvuI23GMOWErdhzk+S+zuGpCH7vnoQU5WWbvwjPB0B3Apd7Hd4lIkvdehyRvu8nAehEpAd4FFgMPOZizXfjNeYNJiIng9tfWU1ld63YcY5qttLKa//vXepK7RNpgfC3Myctc71VVqbfcq6pZ3nsdsrztblXVRFWNUtVUVf2tqlY5lbO9iI0I5cELh7Elt4hHbRgO04r96cOtZB0q5fc/Hk6HsGC347QpdqCuHZs8OJELR/Vk7ifb2bSv0O04xjTZ6szDPLN8F5eOS2J83y5ux2lzrEC0c7/94RA6RoZy66vr7FCTaVXKKmu47dV19IjrwB1TBrsdp02yAtHOdYoK48ELT2LT/kL+8fE2t+MY02iPfLCZnfklPDJ9ONHhPu+lNSfICoTh7KHd+PHons
xduoO1dlWTaQW+2JHPs8t3c8X4ZCb26+p2nDbLCoQB4J7zh5IQE86vXllrN9CZgFZUXsVtr64npUsk/zfFRmr1JysQBoC4DqE8PG04O/JK+MMHW9yOY0yDHnj7G/YXlPGni0YQGWaHlvzJCoT5r9MGxHPZuGTmL9vFsm02VpMJPO9vzOHl9GyuPb0vJyd3djtOm2cFwnzLb84dTN/4KH716lqOlNqw4CZwHCgs59eL1zOsZyy//P4At+O0C1YgzLd0CAvmbzNGcaikkt+8vsFmoDMBobZWufW19ZRV1fDXi0cRFmJfXU6wvWy+Y1jPOG75wUDe3ZDDa6v3uB3HGJ5bsZvPtuZx53lD6JcQ7XacdsMKhPFpzmmpnNKnM/e8mcHOvGK345h2bNO+Qv7fe5v53qAELj0l6fg/YFqMFQjjU3CQ8NcZIwkLCeKml76motoufTXOK62s5saX1tCxQyh/mD7cZohzmBUI06DucR34w/QRZOwr5PfvbXY7jmmH7lmSwa78Ev46YyRdosPdjtPuWIEwx/SDIYlcOSGFZ5fv5qNNuW7HMe3IkrV7eXX1Hm48sx8T+trd0m5wcsKgG0UkXUQqRGTBcdr+UkRyRKRQRJ4REfvTwUW/PncQQ7rHcutr69hzuNTtOKYd2JFXzG8WbyAtuRM3T+7vdpx2y8kexD7gAeCZYzUSkbPxTCo0GUgGUoHf+T2daVB4SDCPzRpNdY1y44tf26ivxq/KKmu44fk1hIUE8feZowix6UNd4+SEQYtV9Q3g4HGaXgHMV9UMVT0M3A9c6ed45jj6dI3iD9OHszb7CA+9+43bcUwbdveSjWw9UMRfZ4yiR8cObsdp1wKxNA8F1tVZXwckiojNBuKyKSd156qJKSz4YjfvrN/vdhzTBr2Sns1rq/dw05n9OH1AvNtx2r1ALBDRQEGd9aOPY+o3FJE53vMa6Xl5eY6Ea+9+PWUwo5I6cvtr69h+oMjtOKYN2bi3gLvf2MiEvl242YbSCAiBWCCKgdg660cff+fbSFXnqWqaqqbFx9tfG04ICwli7qzRdAgLZs6i1RSV23Th5sQdKqnk2kWr6RwVxt9njiI4yO53CASBWCAygBF11kcAuap6vHMXxiHd4zrw6CWjyTxYyi2vrKO21sZrMs1XXVPLTS+tIa+4gicuPZmudr9DwHDyMtcQEYkAgoFgEYkQEV+DuS8EZovIEBHpCNwFLHAqp2mccalduPPcwfx7Uy6PfrLd7TimFfvDB1tYvv0gD/xoGCN6d3Q7jqnDyR7EXUAZnktYL/U+vktEkkSkWESSAFT1feAR4BMgC8gE7nEwp2mkqyamcOGonvzlo618mJHjdhzTCr3x9V6e/Gwnl45L4qK03m7HMfVIWxnOOS0tTdPT092O0e6UV9Vw8ZMr2HagmH9dP4HB3WOP/0PGAF9nHebieV8yOqkji2afQqjd7+AKEVmtqmm+XrP/I+aERIQGM+/yNGIiQrj6uXTyiyvcjmRagf0FZcxZtJpusRE8PutkKw4Byv6vmBOWGBvBU5encbCkgusWrbaRX80xlVZWc83CdMoqa3j6ijQ6RYW5Hck0wAqEaRHDe3Xkjz8ZQXrmYW5/bb3NRGd8qqlVfv7SWjbtK+TvM0cyIPE7tzeZAOLrKiJjmuWHw3uQdaiUR97fQlLnSH511kC3I5kAc//bm/jom1zumzqU7w1KdDuOOQ4rEKZFXX96X7IOlvKPj7fTu1MkF42xK1OMxzPLdrHgi93MntSHy8enuB3HNIIVCNOiRIT7fzSMvUfK+M3rG0iIDeeMgQluxzIue2/Dfu5/ZxNnD03kN+cOdjuOaSQ7B2FaXGiwZziOgd1iuP75NXydddjtSMZFX+zI5+Z/rmV0Uif+erENo9GaWIEwfhETEcqCq8aSEBvOTxesYvuBYrcjGRds3FvAnIWrSe4Syfwr0ugQFux2JNMEViCM38THhLPwp2MJDhKueOYr9h0pczuScVDmwR
KufHYVsREhLJw9lo6Rdjlra2MFwvhVcpcoFlw1lsKyKi59eiV5RXYjXXuw70gZlzy1kpraWp776Vi6x9nEP62RFQjjd8N6xvHsVWPYX1DOZfNXcqS00u1Ixo8OFJUz6+mVFJZVsWj2KfS3ex1aLSsQxhFpKZ156vI0duaVcMUzX9k8Em3U4ZJKLp//FTkF5Sz46RiG9YxzO5I5AVYgjGMm9e/K3FmjydhXyOVWJNqcwyWVzHp6JTvzS3j6ijROTu7sdiRzgqxAGEd9f0gij14ymg17Crj8ma8otCLRJhwqqeSSp1eyPa+Ypy9PY2K/rm5HMi3ACoRx3DnDuvHYLG+RmG9ForU7dLTn4C0Opw2w6X/bCidnlOssIq+LSImIZIrIJQ20u1dEqryTCB1dUp3KaZxx9tBu3sNNBVzy1JcctGHCW6XcwnIufnIFO/OKecqKQ5vjZA/iMaASSARmAY+LyNAG2r6sqtF1lp2OpTSOOWtoN+Zdnsa23GIunvcluYXlbkcyTZB9qJSfPLGCfUfKWHDVWCsObZAjBUJEooBpwN2qWqyqy4A3gcuc2L4JXGcOTOC5n44lp6Cc6U98QdbBUrcjmUbYfqCInzyxgoKyKl64Zhzj+3ZxO5LxA6d6EAOAalXdWue5dUBDPYjzReSQiGSIyPUNvamIzBGRdBFJz8vLa8m8xkHjUrvwwtWnUFRezY8fX86GPQVuRzLHsGr3IaY9voLqWuXla8cxsndHtyMZP3GqQEQDhfWeKwB83UHzCjAYiAeuAX4rIjN9vamqzlPVNFVNi4+37m1rNqJ3R167bgLhIcFcPG8Fn261gh+I3t+Yw6VPr6RLVBiv3zCBQd1sDvK2zKkCUQzU/yTFAkX1G6rqJlXdp6o1qvoF8DdgugMZjcv6JUSz+IYJJHeJYvaCVbyyKtvtSMZLVVmwfBfXv7CaIT1iee36CfTuHOl2LONnThWIrUCIiPSv89wIIKMRP6uAjQ/cTiTGRvDKtZ5j2rf/az0PvfsNNbU2fambqmpquXvJRu59axOTByXy4tXj6GzzSLcLjhQIVS0BFgP3iUiUiEwEpgKL6rcVkaki0kk8xgI/B5Y4kdMEhpiIUJ65cgyXjUtm3mc7uXbRaoorqt2O1S4VlFXx0wWreP7LLK49LZUnLzvZhuxuR5y8zPUGoANwAHgJuF5VM0TkVBGpO1nADGA7nsNPC4GHVfU5B3OaABAaHMT9PxrG7y4Yysebc/nx3OXszLM5JZy0JaeIqY8u48udB3lk+nB+fe5gm+ynnRHVttF9T0tL0/T0dLdjGD9Yti2fm15aQ3WN8ueLR/KDITbZvb+9tW4ft7+2nuiIEObOGs2YFBtXqa0SkdWqmubrNRtqwwS8Sf278tZNk0juGsk1C9P5wwebqa6pdTtWm1RRXcPv3srgppe+ZmiPWN65aZIVh3bMCoRpFXp1iuS16yZwcVpvHvtkBzPmfclem6GuRe3OL2H64yt4dvlurpyQwovXjCMhNsLtWMZFViBMqxERGszD04fztxkj2ZxTxLl/+5z3N+53O1arp6q88fVefviPZWQdKuXJy07m3guGEhZiXw/tnX0CTKszdWRP3vn5JJK7RHLd82v45ctrKSizEWGb42BxBTe8sIZfvLyWQd1iePfmUzl7aDe3Y5kAEeJ2AGOaI7lLFP+6fgKPfrydRz/ZzoodB3l4+nBOtwHjGu2DjBzufH0DhWXV/N85g5hzWqpdpWS+xXoQptUKDQ7ilz8YwOs3TCA6IoQrnvmKm//5Nfk2dPgx5RSUc92i1Vy7aDXxMRG8edNErj+jrxUH8x12matpE8qrapi7dAePL91OZFgIv54yiIvSehNkX3r/VV1Tywsrs/jDB1uoqqnl5u/355pTUwkNtr8T27NjXeZqBcK0Kdtyi/jN6xtYtfswJ/WM457zh5Bml2myfHs+9721iS25RUzq15UHLxxGcpcot2OZAGAFwrQrqsqStfv4/XubyS
ks54fDu3PrWQNJ6dr+vhC35Rbxhw+28OGmXHp16sCd5w7mnGHdELGelfGwAmHapdLKap5YuoOnPt9FZU0tF6X15ubJ/ekW1/av7c8+VMpfP9rG61/vITIshOtOT+XqU1OJCLVxlMy3WYEw7dqBonIe+3g7L36VhYgw/eReXHtaaps8xLL9QDFPfLqDN77eS1CQcMX4ZK4/o5+NvmoaZAXCGDx/Vc9duoN/rd5DdW0t557Unasm9mF0UsdWfchFVVm56xALlu/mg005hIcEMWNMEteenkr3uA5uxzMBzgqEMXUcKCxn/rJdvLgyi6KKaob2iOXy8cmcN7wH0eGt59agwvIq3ly7j0UrMtmSW0Rch1AuG5fMVRNT6BId7nY800pYgTDGh5KKal7/ei8LV+xma24xHUKDOWdYNy4c1ZPxfbsE5OWfldW1LNuex+I1e/lwUy6V1bUM6R7LlRNSOH9ED5urwTSZFQhjjkFVWZ15mMVf7+XtdfsoLK8mrkMokwcncNaQbkzs14WYiFDX8hWUVrFsez4fZOTwyeYDFFVU0ykylAtG9ODC0b0Y0SuuVR8iM+4KiAIhIp2B+cBZQD7wa1V90Uc7AX4PXO196mngDj1OUCsQpiWUV9WwdEseH27K4T/fHKCgrIrgIGFErzgm9uvK6OROjOjV0a8nffOLK1iXfYTVmYdZvj2fDXsLqFXoHBXG971F67QB8TaYnmkRxyoQTh5wfQyoBBKBkcA7IrJOVevPSz0H+BGeOasV+DewC3jCsaSm3YrwHmY6Z1g3qmpqSd/t+ZJeviOfuUt3/Hd+7N6dOzAwMZZ+CdH0jY+iV6dIusVF0C02olGHeUoqqskpLCe3oJw9h8vYkVfMjrxivtlf9N9hzEOChFFJHbnpe/2Z1L8ro3p3JCQAD3uZtsuRHoSIRAGHgWGqutX73CJgr6reUa/tF8ACVZ3nXZ8NXKOq4461DetBGH8rrqhm494C1mUfYd2eI2zLLWb3wRKqar79OxQeEkRMRAhR4SGEeb/QFc/5g5KKaooqqqms/vaER2EhQaR2jaJfQjQjenVkZFJHhvaIJTKs9Zw0N61TIPQgBgDVR4uD1zrgdB9th3pfq9tuqK83FZE5eHocJCUltUxSYxoQHR7CuNQujEvt8t/nqmtqyT5cxr4jZeQUlJNTWE5hWRVFFdUUl1dTXfu/QhAaHER0eAjRESF07BBGt7hwEmMj6NmxA706RdpgeSbgOFUgooHCes8VADENtC2o1y5aRKT+eQhvL2MeeHoQLRfXmMYJCQ6iT9co+rTDYTxM2+fUAc1iILbec7FAUSPaxgLFxztJbYwxpmU5VSC2AiEi0r/OcyOA+ieo8T43ohHtjDHG+JEjBUJVS4DFwH0iEiUiE4GpwCIfzRcCt4hITxHpAfwKWOBETmOMMf/j5DVzNwAdgAPAS8D1qpohIqeKSHGddk8CbwEbgI3AO97njDHGOMixa+hU9RCe+xvqP/85nhPTR9cVuN27GGOMcYnddWOMMcYnKxDGGGN8sgJhjDHGpzYzmquI5AGZzfzxrngGEAw0lqtpLFfTBWo2y9U0J5IrWVXjfb3QZgrEiRCR9IbGInGT5Woay9V0gZrNcjWNv3LZISZjjDE+WYEwxhjjkxUIj3luB2iA5Woay9V0gZrNcjWNX3LZOQhjjDE+WQ/CGGOMT1YgjDHG+GQFwhhjjE/trkCISLiIzBeRTBEpEpG1IjLlOD/zSxHJEZFCEXlGRML9lO1GEUkXkQoRWXCctleKSI2IFNdZznA7l7e9U/urs4i8LiIl3v+flxyj7b0iUlVvf6U6nUU8HhaRg97lYRHx21yjTcjl1/1Tb1tN+Zw78llqajaHf/+a9J3Vkvus3RUIPCPYZuOZDzsOuAt4RURSfDUWkbOBO4DJQDKQCvzOT9n2AQ8AzzSy/QpVja6zLHU7l8P76zGgEkgEZgGPi4jP+cu9Xq63v3a6kGUOnlGNRwDDgfOBa1
swR3NzgX/3T12N+jw5/FlqUjYvp37/Gv2d1eL7TFXb/QKsB6Y18NqLwEN11icDOX7O8wCw4DhtrgSWObyfGpPLkf0FROH54htQ57lFwO8baH8v8Lyf9kujswBfAHPqrM8GvgyAXH7bP839PLnxu9eEbI7//tXbvs/vrJbeZ+2xB/EtIpIIDKDhaU2HAuvqrK8DEkWki7+zNcIoEckXka0icreIODa/xzE4tb8GANWqurXeto7VgzhfRA6JSIaIXO9SFl/751iZncoF/ts/zRXIv3vg0u/fcb6zWnSftesCISKhwAvAc6q6uYFm0UBBnfWjj2P8ma0RPgOGAQnANGAmcJuriTyc2l/RQGG95wqOsZ1XgMFAPHAN8FsRmelCFl/7J9pP5yGaksuf+6e5AvV3D1z6/WvEd1aL7rM2VyBEZKmIaAPLsjrtgvB0tyuBG4/xlsVAbJ31o4+L/JGrsVR1p6ruUtVaVd0A3AdMb+r7tHQunNtf9bdzdFs+t6Oqm1R1n6rWqOoXwN9oxv5qQFOy+No/xeo9HtDCGp3Lz/unuVrks+QPLfX71xSN/M5q0X3W5gqEqp6hqtLAMgk8V5IA8/GcuJumqlXHeMsMPCcUjxoB5KrqwZbOdYIUaPJfoX7I5dT+2gqEiEj/ettq6FDhdzZBM/ZXA5qSxdf+aWxmf+aqryX3T3O1yGfJIX7dX034zmrRfdbmCkQjPY6nO32+qpYdp+1CYLaIDBGRjniuIFjgj1AiEiIiEUAwECwiEQ0d1xSRKd5jkYjIIOBuYInbuXBof6lqCbAYuE9EokRkIjAVz19Yvv4NU0Wkk3iMBX5OC+2vJmZZCNwiIj1FpAfwK/z0eWpKLn/uHx/bauznybHfvaZmc/L3z6ux31ktu8/cOgvv1oLn0i8FyvF0x44us7yvJ3nXk+r8zC1ALp7juc8C4X7Kdq83W93lXl+5gD96M5UAO/F0cUPdzuXw/uoMvOHdB1nAJXVeOxXPoZuj6y8BB71ZNwM/dyKLjxwCPAIc8i6P4B0Tzcl95PT+acznyc3PUlOzOfz71+B3lr/3mQ3WZ4wxxqf2eojJGGPMcViBMMYY45MVCGOMMT5ZgTDGGOOTFQhjjDE+WYEwxhjjkxUIY4wxPlmBMMYY45MVCGOMMT5ZgTDGD0Skr3duhdHe9R4ikid+mpbSGH+woTaM8RMRuQb4JZAGvA5sUNVb3U1lTONZgTDGj0TkTaAPnsHWxqhqhcuRjGk0O8RkjH89hWfmsX9YcTCtjfUgjPETEYnGMyfwJ8AU4CRVPeRuKmMazwqEMX4iIvOBaFW9WETmAR1V9SK3cxnTWHaIyRg/EJGpwDnA9d6nbgFGi8gs91IZ0zTWgzDGGOOT9SCMMcb4ZAXCGGOMT1YgjDHG+GQFwhhjjE9WIIwxxvhkBcIYY4xPViCMMcb4ZAXCGGOMT/8fNzhXulh8USQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_function(f, 'x', 'x**2')\n", "plt.scatter(-1.5, f(-1.5), color='red');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to decrease the loss: starting from the red point at x = -1.5, we want to move x in the direction that makes f(x) smaller, and the gradient (the slope at that point) tells us which way is downhill." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How to calculate the gradient:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Calling `requires_grad_()` tells PyTorch to keep track of every operation on `xt`, so that it can calculate gradients with respect to it later.__" ] }, { "cell_type": "code", "execution_count": 312, "metadata": {}, "outputs": [], "source": [ "xt = tensor(3.).requires_grad_()" ] }, { "cell_type": "code", "execution_count": 313, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(9., grad_fn=<PowBackward0>)" ] }, "execution_count": 313, "metadata": {}, "output_type": "execute_result" } ], "source": [ "yt = f(xt)\n", "yt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__The result is 9, but it also carries a `grad_fn`: PyTorch has recorded how it was computed.__\n", "***" ] }, { "cell_type": "code", "execution_count": 314, "metadata": {}, "outputs": [], "source": [ "yt.backward()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__`backward()` calculates the derivative and stores it in `xt.grad`.__" ] }, { "cell_type": "code", "execution_count": 315, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(6.)" ] }, "execution_count": 315, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xt.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__The result is 6, which matches the derivative of `x**2` (that is, `2*x`) at x = 3.__\n", "***" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Now with a bigger tensor:__" ] }, { "cell_type": "code", "execution_count": 316, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 3., 4., 10.], requires_grad=True)" ] }, "execution_count": 316, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xt = tensor([3.,4.,10.]).requires_grad_()\n", "xt" ] }, 
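{ "cell_type": "markdown", "metadata": {}, "source": [ "With `backward()` and `.grad` in hand, the loop from the SGD section (predict, loss, gradient, step, repeat) can be run on this toy function. Below is a minimal sketch, assuming a hypothetical helper `gradient_descent_demo`, a learning rate of 0.1 and 20 steps (all illustrative values, not from the book). It starts from the red point at x = -1.5 plotted above and keeps its variables local, so `xt` and `f` used in the following cells are unaffected. Since the gradient of `x**2` is `2*x`, each step multiplies x by `1 - 2*lr = 0.8`, moving it towards the minimum at 0." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"def gradient_descent_demo(start=-1.5, lr=0.1, steps=20):\n",
"    # Hypothetical helper: minimise f(x) = x**2 with plain gradient descent.\n",
"    # Local names are used so that xt and f in this notebook are not overwritten.\n",
"    x = torch.tensor(start).requires_grad_()\n",
"    for _ in range(steps):\n",
"        loss = x ** 2            # predict and compute the loss\n",
"        loss.backward()          # gradient: d(loss)/dx = 2x, stored in x.grad\n",
"        with torch.no_grad():    # step: update x without recording it in the graph\n",
"            x -= lr * x.grad\n",
"        x.grad.zero_()           # reset, since backward() adds to existing gradients\n",
"    return x.item()\n",
"\n",
"# x shrinks from -1.5 towards the minimum at 0\n",
"print(gradient_descent_demo())"
] }, 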
{ "cell_type": "code", "execution_count": 317, "metadata": {}, "outputs": [], "source": [ "def f(x): return (x**2).sum()" ] }, { "cell_type": "code", "execution_count": 318, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(125., grad_fn=<SumBackward0>)" ] }, "execution_count": 318, "metadata": {}, "output_type": "execute_result" } ], "source": [ "yt = f(xt)\n", "yt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Again we expect the gradient to be 2*xt:__" ] }, { "cell_type": "code", "execution_count": 319, "metadata": {}, "outputs": [], "source": [ "yt.backward()\n" ] }, { "cell_type": "code", "execution_count": 320, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 6., 8., 20.])" ] }, "execution_count": 320, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xt.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### End-to-end SGD example" ] }, { "cell_type": "code", "execution_count": 321, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.])" ] }, "execution_count": 321, "metadata": {}, "output_type": "execute_result" } ], "source": [ "time = torch.arange(0,20).float()\n", "time" ] }, { "cell_type": "code", "execution_count": 322, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 322, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXMAAAD7CAYAAACYLnSTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXl0lEQVR4nO3df4zcdZ3H8efLtpEN7bpCV7xuQisorVew9lyjkYDmwGs0MfZak0M4RHOk/ggX71ejvbNQQVO0epec+KsJKireIbrsiZw2xwGneEpcrKXZ2PauYJWtP7bqlm67hdJ73x/zHZiOszPf6Xy/M7PfeT2SCZ3P9zMzbz47+9rvfL6f+X4VEZiZ2dz2nE4XYGZmrXOYm5kVgMPczKwAHOZmZgXgMDczK4D5nXjRxYsXx7Jlyzrx0mZmc9bDDz98KCIGa23rSJgvW7aMsbGxTry0mdmcJenAbNs8zWJmVgAOczOzAnCYm5kVgMPczKwAGoa5pOmq20lJn6jYfpmkPZKOSbpf0tJ8SzYzs2oNV7NExMLyvyUtBH4J3JncXwyMANcCdwM3AXcAr86j2NGdE2zbsZeDUzMsGehj45rlrF09lMdLmZnNKc0uTVwP/Br4bnJ/HTAeEeVw3wIckrQiIvZkViWlIN80spuZEycBmJiaYdPIbgAHupn1vGbnzK8BvhjPnjd3JbCrvDEijgL7k/ZTSNogaUzS2OTkZNOFbtux95kgL5s5cZJtO/Y2/VxmZkWTOsyTufDXArdVNC8EDld1PQwsqn58RGyPiOGIGB4crPkFproOTs001W5m1kua2TO/GngwIh6raJsG+qv69QNHWi2s2pKBvqbazcx6STNh/jZO3SsHGAdWle9IOhM4P2nP1MY1y+lbMO+Utr4F89i4ZnnWL2VmNuekOgAq6TXAEMkqlgp3AdskrQfuAa4HHsn64Cc8e5DTq1nMbC7KezVe2tUs1wAjEXHK9ElETCZBfgvwZeAh4IrMqquydvWQw9vM5px2rMZLFeYR8c462+4FVmRSjZlZAdVbjZdVmPvr/GZmOWvHajyHuZlZztqxGs9hbmaWs3asxuvIlYbMzHpJO1bjOczNzNog79V4nmYxMysAh7mZWQE4zM3MCsBhbmZWAA5zM7MCcJibmRWAw9zMrAAc5mZmBeAwNzMrAIe5mVkBOMzNzArAYW5mVgCpw1zSFZJ+IumopP2SLknaL5O0R9IxSfdLWppfuWZmVkuqMJf0euAjwDuARcClwKOSFgMjwGbgLGAMuCOfUs3MbDZpT4H7QeDGiPhBcn8CQNIGYDwi7kzubwEOSVoREXuyLtbMzGpruGcuaR4wDAxK+l9Jj0u6RVIfsBLYVe4bEUeB/Ul79fNskDQmaWxycjK7/wMzM0s1zXIOsAB4C3AJ8HJgNfABYCFwuKr/YUpTMaeIiO0RMRwRw4ODg63UbGZmVdKEefny0Z+IiF9ExCHgH4E3AtNAf1X/fuBIdiWamVkjDcM8In4HPA5EZXPy33FgVblR0pnA+Um7mZm1SdqliZ8H/lLSCyQ9H/hr4JvAXcCFktZLOgO4HnjEBz/NrGhGd05w8c338aL338PFN9/H6M6JTpd0irSrWW4CFgP7gOPAV4EPR8RxSeuBW4AvAw8BV+RRqJlZp4zunGDTyG5mTpwEYGJqhk0juwFyvUhzM1KFeUScAN6T3Kq33QusyLguM7OusW3H3meCvGzmxEm27djbNWHur/ObmTVwcGqmqfZOSDvNUgijOyfYtmMvB6dmWDLQx8Y1y7vmr6qZda8lA31M1AjuJQN9Haimtp7ZMy/PeU1MzRA8O+fVbQcxzKz7bFyznL4F805p61swj41rlneoot/XM2Feb87LzKyetauH2LruIoYG+hAwNNDH1nUXddUn+56ZZpkLc15m1r3Wrh7qqvCu1jN75rPNbXXTnJeZ2enqmTCfC3NeZmanq2emWcofj7yaxcyKqGfCHLp/zsvM7HT1zDSLmVm
ROczNzArAYW5mVgAOczOzAnCYm5kVgMPczKwAHOZmZgXgMDczK4BUYS7pAUnHJU0nt70V266UdEDSUUmjks7Kr1wzM6ulmT3z6yJiYXJbDiBpJfBZ4GrgHOAY8KnsyzQzs3pa/Tr/VcDdEfEdAEmbgZ9IWhQRR1quzszMUmlmz3yrpEOSvifpdUnbSmBXuUNE7AeeAi6ofrCkDZLGJI1NTk62ULKZmVVLG+bvA84DhoDtwN2SzgcWAoer+h4GFlU/QURsj4jhiBgeHBxsoWQzM6uWKswj4qGIOBIRT0bEbcD3gDcC00B/Vfd+wFMsZmZtdLpLEwMQMA6sKjdKOg94LrCv9dLMzCythgdAJQ0ArwL+C3ga+DPgUuC9wALg+5IuAX4E3AiM+OCnmVl7pVnNsgD4ELACOAnsAdZGxD4ASe8CbgfOBu4F3pFPqWZmNpuGYR4Rk8Ar62z/CvCVLIsyM7Pm+Ov8ZmYF4DA3MysAh7mZWQE4zM3MCsBhbmZWAA5zM7MCcJibmRWAw9zMrAAc5mZmBeAwNzMrAIe5mVkBOMzNzAqg1WuAmpnNCaM7J9i2Yy8Hp2ZYMtDHxjXLWbt6qNNlZcZh3oSivxnMimp05wSbRnYzc+IkABNTM2wa2Q1QmN9hT7OkVH4zTEzNEDz7ZhjdOdHp0sysgW079j4T5GUzJ06ybcfeDlWUPYd5Sr3wZjArqoNTM021z0UO85R64c1gVlRLBvqaap+LmgpzSS+RdFzSlyvarpR0QNJRSaOSzsq+zM7rhTeDWVFtXLOcvgXzTmnrWzCPjWuWd6ii7DW7Z/5J4IflO5JWAp8FrgbOAY4Bn8qsui7SC28Gs6Jau3qIresuYmigDwFDA31sXXdRYQ5+QhOrWSRdAUwB/w28OGm+Crg7Ir6T9NkM/ETSoog4knGtHVX+oXs1i9nctHb1UKF/X1OFuaR+4Ebgj4FrKzatpBTuAETEfklPARcAD2dYZ1co+pvBzOautNMsNwG3RsTjVe0LgcNVbYeBRdVPIGmDpDFJY5OTk81XamZms2oY5pJeDlwO/FONzdNAf1VbP/B7UywRsT0ihiNieHBw8DRKNTOz2aSZZnkdsAz4mSQo7Y3Pk/SHwLeBVeWOks4Dngvsy7pQMzObXZow3w78a8X9v6MU7u8GXgB8X9IlwI8ozauPFO3gp5lZt2sY5hFxjNKSQwAkTQPHI2ISmJT0LuB24GzgXuAdOdVqZmazaPpEWxGxper+V4CvZFWQmZk1z1/nNzMrAIe5mVkBOMzNzArAYW5mVgAOczOzAnCYm5kVgMPczKwAHOZmZgXgMDczKwCHuZlZATjMzcwKwGFuZlYATZ9oy07f6M4JX0PUzHLhMG+T0Z0TbBrZzcyJkwBMTM2waWQ3gAPdzFrmaZY22bZj7zNBXjZz4iTbduztUEVmViQO8zY5ODXTVLuZWTM8zdImSwb6mKgR3EsG+jpQjdnc42NO9XnPvE02rllO34J5p7T1LZjHxjXLO1SR2dxRPuY0MTVD8Owxp9GdE50urWukCnNJX5b0C0lPSNon6dqKbZdJ2iPpmKT7JS3Nr9y5a+3qIbauu4ihgT4EDA30sXXdRd6zMEvBx5waSzvNshX4i4h4UtIK4AFJO4EDwAhwLXA3cBNwB/DqPIqd69auHnJ4m50GH3NqLNWeeUSMR8ST5bvJ7XxgHTAeEXdGxHFgC7AqCXwzs0zMdmzJx5yelXrOXNKnJB0D9gC/AP4dWAnsKveJiKPA/qS9+vEbJI1JGpucnGy5cDPrHT7m1FjqMI+I9wCLgEsoTa08CSwEDld1PZz0q3789ogYjojhwcHB06/YzHqOjzk11tTSxIg4CTwo6c+BdwPTQH9Vt37gSDblmZmV+JhTfae7NHE+pTnzcWBVuVHSmRXtZmbWJg3DXNILJF0haaGkeZLWAG8F/hO4C7hQ0npJZwDXA49ExJ58yzYzs0pp9syD0pTK48DvgI8BfxUR34iISWA
98OFk26uAK3Kq1czMZtFwzjwJ7NfW2X4v4KWIZmYd5HOzmFlb+Nwq+XKYm1nufD7//PlEW2aWO59bJX8OczPLnc+tkj+HuZnlzudWyZ/D3Mxy53Or5M8HQM0sd+WDnF7Nkh+HuZm1hc+tki9Ps5iZFYDD3MysABzmZmYF4DA3MysAh7mZWQE4zM3MCsBhbmZWAA5zM7MCcJibmRVAmmuAPlfSrZIOSDoi6ceS3lCx/TJJeyQdk3S/pKX5lmxmZtXS7JnPB35O6dJxzwM+AHxV0jJJi4ERYDNwFjAG3JFTrWZmNos01wA9CmypaPqmpMeAVwBnA+MRcSeApC3AIUkrImJP9uWamVktTZ9oS9I5wAXAOPBuYFd5W0QclbQfWAnsqXrcBmADwLnnnttCyb3L11A0s9k0dQBU0gLgduC2ZM97IXC4qtthYFH1YyNie0QMR8Tw4ODg6dbbs8rXUJyYmiF49hqKozsnOl2amXWB1Hvmkp4DfAl4CrguaZ4G+qu69gNHMqnOnlHvGoreO7d28CfD7pZqz1ySgFuBc4D1EXEi2TQOrKrodyZwftJuGfI1FK2T/Mmw+6WdZvk08FLgTRFRmR53ARdKWi/pDOB64BEf/Myer6FonVTvk6F1hzTrzJcC7wReDvxS0nRyuyoiJoH1wIeB3wGvAq7Isd6e5WsoWif5k2H3S7M08QCgOtvvBVZkWZT9Pl9D0TppyUAfEzWC258Mu4evATqH+BqK1ikb1yxn08juU6Za/MmwuzjMzawhfzLsfg5zM0vFnwy7m8+aaGZWAA5zM7MCcJibmRWAw9zMrAAc5mZmBeAwNzMrAIe5mVkBOMzNzArAYW5mVgAOczOzAnCYm5kVgM/N0kN82a/e5p9/sTnMe0T5sl/lU5iWL/sF+Be6B/jnX3yeZukRvuxXb/PPv/jSXtD5Okljkp6U9IWqbZdJ2iPpmKT7k8vMWZfxZb96m3/+xZd2z/wg8CHgc5WNkhYDI8Bm4CxgDLgjywItG1lcEHp05wQX33wfL3r/PVx8832+Mvsc4guCF1+qMI+IkYgYBX5TtWkdMB4Rd0bEcWALsEqSrwnaZVq9IHR5znViaobg2TlXB3p6nfxj6AuCF1+rc+YrgV3lOxFxFNiftFsXWbt6iK3rLmJooA8BQwN9bF13UeqDX55zbU2n/xi2+vO37tfqapaFwGRV22FgUXVHSRuADQDnnntuiy9rp6OVy355zrU19f4YtitQfdm3Ymt1z3wa6K9q6weOVHeMiO0RMRwRw4ODgy2+rLWb51xb4z+GlrdWw3wcWFW+I+lM4Pyk3QrEc66t8R9Dy1vapYnzJZ0BzAPmSTpD0nzgLuBCSeuT7dcDj0TEnvxKtk7wnGtr/MfQ8qaIaNxJ2gLcUNX8wYjYIuly4BZgKfAQ8PaI+Gm95xseHo6xsbHTKthsrmr16/T+Or5JejgihmtuSxPmWXOY2+no5TCr/jo+lPbs/emot9QLc3+d3+aETi/t6zQvDbVGHOY2J/R6mHk1jDXiMLc5odfDzKthrBGHuc0JvR5mXg1jjTjMbU7o9TDz0lBrxBensDmhHFq9upoF/HV8q89hbm3T6tJCh5nZ7Bzm1ha+bJlZvjxnbm3R60sLzfLmMLe26PWlhWZ5c5hbW/T60kKzvDnMrS16fWmhWd58ANTawksLzfLlMLe28dJCs/x4msXMrAAc5mZmBeAwNzMrAIe5mVkBZBLmks6SdJeko5IOSLoyi+c1y9Lozgkuvvk+XvT+e7j45vt65ipF1huyWs3ySeAp4Bzg5cA9knZFxHhGz2/WEp8bxoqu5T1zSWcC64HNETEdEQ8C3wCubvW5zbLic8NY0WUxzXIB8HRE7Kto2wWsrOwkaYOkMUljk5OTGbysWXo+N4wVXRZhvhB4oqrtMLCosiEitkfEcEQMDw4OZvCyZun53DBWdFmE+TTQX9XWDxzJ4LnNMuFzw1jRZXEAdB8wX9JLIuJ/krZVgA9
+WtfwuWGs6FoO84g4KmkEuFHStZRWs7wZeE2rz22WJZ8bxoosq6WJ7wE+B/wa+A3wbi9LtKJp9RqmZnnKJMwj4rfA2iyey6wbeZ26dTt/nd8sBa9Tt27nMDdLwevUrds5zM1S8Dp163YOc7MUvE7dup0vG2eWgtepW7dzmJul5HXq1s08zWJmVgAOczOzAnCYm5kVgMPczKwAHOZmZgWgiGj/i0qTwIEWnmIxcCijcvLg+lrj+lrj+lrTzfUtjYiaV/fpSJi3StJYRAx3uo7ZuL7WuL7WuL7WdHt9s/E0i5lZATjMzcwKYK6G+fZOF9CA62uN62uN62tNt9dX05ycMzczs1PN1T1zMzOr4DA3MysAh7mZWQF0ZZhLOkvSXZKOSjog6cpZ+knSRyT9Jrl9RJJyru25km5N6joi6ceS3jBL37dLOilpuuL2ujzrS173AUnHK16z5oUqOzR+01W3k5I+MUvftoyfpOskjUl6UtIXqrZdJmmPpGOS7pe0tM7zLEv6HEsec3me9Ul6taT/kPRbSZOS7pT0B3WeJ9X7IsP6lkmKqp/f5jrP0+7xu6qqtmNJva+Y5XlyGb+sdGWYA58EngLOAa4CPi1pZY1+G4C1wCrgZcCbgHfmXNt84OfAa4HnAR8Avipp2Sz9vx8RCytuD+RcX9l1Fa852+Vw2j5+lWMBvBCYAe6s85B2jN9B4EPA5yobJS0GRoDNwFnAGHBHnef5F2AncDbwD8DXJNX8tl4W9QHPp7TyYhmwFDgCfL7Bc6V5X2RVX9lAxWveVOd52jp+EXF71fvxPcCjwI/qPFce45eJrgtzSWcC64HNETEdEQ8C3wCurtH9GuDjEfF4REwAHwfenmd9EXE0IrZExE8j4v8i4pvAY0DNv+Zdru3jV2U98Gvgu218zd8TESMRMQr8pmrTOmA8Iu6MiOPAFmCVpBXVzyHpAuCPgBsiYiYivg7spvT/mEt9EfGtpLYnIuIYcAtwcauvl1V9zejE+NVwDfDFmKNL/LouzIELgKcjYl9F2y6g1p75ymRbo365kXQOpZrHZ+myWtIhSfskbZbUrqs7bU1e93t1piY6PX5pfnk6NX5QNT4RcRTYz+zvxUcj4khFW7vH81Jmfx+WpXlfZO2ApMclfT75tFNLR8cvmT67FPhig66dGL9UujHMFwJPVLUdBhbN0vdwVb+Fec/7lklaANwO3BYRe2p0+Q5wIfACSnsYbwU2tqG09wHnAUOUPobfLen8Gv06Nn7JL89rgdvqdOvU+JVVjw+kfy/W65s5SS8Drqf++KR9X2TlEPBKSlNAr6A0FrfP0rej4we8DfhuRDxWp0+7x68p3Rjm00B/VVs/pfnARn37gel2fEyS9BzgS5Tm9q+r1SciHo2Ix5LpmN3AjcBb8q4tIh6KiCMR8WRE3AZ8D3hjja4dGz9K02YP1vvl6dT4VWjlvVivb6YkvRj4FvDeiJh1yqqJ90UmkmnSsYh4OiJ+Ren35E8k1Qrojo1f4m3U37Fo+/g1qxvDfB8wX9JLKtpWUfvj43iyrVG/TCV7rrdSOkC7PiJOpHxoAG351JDydTsyfomGvzw1tHv8Thmf5HjO+cz+XjyvKqhyH8/kE869wE0R8aUmH97u8SzvJNTKnY6MH4Cki4ElwNeafGinfp9r6rowT+YlR4AbJZ2ZDPSbKe0FV/si8DeShiQtAf4W+EIbyvw08FLgTRExM1snSW9I5tRJDpptBv4tz8IkDUhaI+kMSfMlXUVpLvDbNbp3ZPwkvYbSR9V6q1jaNn7JOJ0BzAPmlccOuAu4UNL6ZPv1wCO1ptSSYzw/Bm5IHv+nlFYIfT2v+iQNAfcBt0TEZxo8RzPvi6zqe5Wk5ZKeI+ls4J+BByKiejqlI+NX0eUa4OtV8/XVz5Hb+GUmIrruRmkZ2ChwFPgZcGXSfgmlaYByPwEfBX6b3D5Kcr6ZHGtbSukv8nFKHw3Lt6uAc5N/n5v0/Rjwq+T/41FK0wQ
Lcq5vEPghpY+nU8APgNd3y/glr/tZ4Es12jsyfpRWqUTVbUuy7XJgD6UllA8Ayyoe9xngMxX3lyV9ZoC9wOV51gfckPy78n1Y+fP9e+Bbjd4XOdb3VkorvY4Cv6C08/DCbhm/ZNsZyXhcVuNxbRm/rG4+0ZaZWQF03TSLmZk1z2FuZlYADnMzswJwmJuZFYDD3MysABzmZmYF4DA3MysAh7mZWQH8P+vctCOu+6ZUAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "speed = torch.randn(20)*3 + 0.75*(time-9.5)**2 + 1\n", "plt.scatter(time,speed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Now we try to find parameters for a quadratic function that predicts the speed at any given time. We chose a quadratic, but it could be another kind of function; a quadratic simply makes the problem much easier.__" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Here is the function that takes the time and the parameters as inputs and returns a prediction:__" ] }, { "cell_type": "code", "execution_count": 323, "metadata": {}, "outputs": [], "source": [ "def f(t, params):\n", " a,b,c = params\n", " return a*(t**2) + (b*t) + c" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__This is our loss function: it measures the distance between the predictions and the targets (the actual measurements). Despite the name, the final `.sqrt()` makes it the root mean squared error (RMSE).__" ] }, { "cell_type": "code", "execution_count": 324, "metadata": {}, "outputs": [], "source": [ "def mse(preds, targets): return ((preds-targets)**2).mean().sqrt()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### __Step 1: Initialize the parameters with random values__" ] }, { "cell_type": "code", "execution_count": 325, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 0.9569, 0.0048, -0.1506], requires_grad=True)" ] }, "execution_count": 325, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = torch.randn(3).requires_grad_()\n", "params" ] }, { "cell_type": "code", "execution_count": 326, "metadata": {}, "outputs": [], "source": [ "orig_params = params.clone()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 2: Calculate the predictions" ] }, { "cell_type": "code", "execution_count": 327, "metadata": {}, "outputs": [], "source": [ "preds = f(time,params)" ] }, { "cell_type": "code", "execution_count": 328, "metadata": {}, "outputs": [], "source": [ "def 
show_preds(preds, ax=None):\n", " if ax is None: ax=plt.subplots()[1]\n", " ax.scatter(time, speed)\n", " ax.scatter(time, to_np(preds), color='red')\n", " ax.set_ylim(-300,100)" ] }, { "cell_type": "code", "execution_count": 329, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAEACAYAAACznAEdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbXElEQVR4nO3df5Ac5Z3f8fdHEpaEpD1JoMMgl6RAQPIJn0xYiouxARufsUmRIyh/ALINuRyyz6WrVJGCgyAZlYECW3f5wyH+IQpZiMiOzVkosc+WY8cIn0jiyhKQfWsEVTIstvjhFQhJKwnx65s/+hkzDDO7M+qenpndz6tqamf6ebrnu72z8+3u50crIjAzM5vU6QDMzKw7OCGYmRnghGBmZokTgpmZAU4IZmaWOCGYmRnghGBmZkmhCUHSKkkDko5K2lhTdpGkXZIOS3pQ0sKqsqmSNkg6IOl5SdcVGZeZmY2t6DOEZ4HbgA3VCyWdCGwB1gBzgQHg21VV1gKnAwuBDwM3SPp4wbGZmdko1I6RypJuA94TEdek1yuBayLiA+n1DGAvcFZE7JL0bCr/H6n8VuD0iLii8ODMzKyuKSW9z1JgZ+VFRByStBtYKukF4OTq8vT8snobSsllJcCMGTPOXrJkSbtiNiveSy/Bnj3w6qvwrnfB/Pkwd26no7IJ5pFHHtkbEfNql5eVEGYCwzXL9gOzUlnldW3ZO0TEemA9QH9/fwwMDBQbqVm7bN4MK1dmyQCyny+8ALfeCitWdDY2m1AkDdVbXlYvoxGgr2ZZH3AwlVFTXikzGz9uvhkOH377ssOHs+VmXaCshDAILKu8SG0IpwGDEbEPeK66PD0fLCk2s3I880xry81KVnS30ymSpgGTgcmSpkmaAjwAnClpeSr/PPCLiNiVVt0ErJY0R9IS4FpgY5GxmXXcggWtLTcrWdFnCKuBI8CNwCfT89URMQwsB24H9gHnAtU9iG4BdgNDwEPAuojYVnBsZp11++1w/PFvX3b88dlysy7Qlm6nZXGjsvWczZuzNoNnnsnODG6/3Q3KVjpJj0REf+3ysnoZmRlkX/5OANalPJeRmZkBTghmZpY4IZi1YvNmWLQIJk3Kfm7e3OmIzArjNgSzZlVGGlcGlw0NZa/B7QI2LvgMwaxZHmls45wTglmzPNLYxrkJd8lo66N7WPejJ3j25SOcMns611+8mMvOmt/psKwXLFiQXSaqt9xsHJhQZwhbH93DTVt+yZ6XjxDAnpePcNOWX7L10T2dDs16gUcaW4dtfXQP5935U/7JjX/PeXf+tPDvrgmVENb96AmOvPbG25Ydee0N1v3oiQ5FZD1lxQpYvx4WLgQp+7l+vRuUrRRlHNBOqITw7MtHWlpu9g4rVsDTT8Obb2Y/nQysJGUc0E6ohHDK7OktLbdxyOMIrEeVcUA7oRLC9RcvZvpxk9+2bPpxk7n+4sUdishKVRlHMDQEEW+NI3BSsB5QxgHthEoIl501nzsufx/zZ09HwPzZ07nj8ve5l9FE4XEE1mF5GoXLOKD19NctcrfVHjZpUnZmUEvK2gTM2qjSKFzdDjD9uMktHZQW9f3j6a8LUPsHrbTyA04KvcDjCKyDRmsUbvb747Kz5rf1u2ZCXTLKy91We5zHEVgH9UIvx1ITgqTtkl6RNJIeT1SVXSVpSNIhSVslzS0ztmb0wh/URuFxBNZBvdDLsRNnCK
siYmZ6LAaQtBT4OvAp4CTgMPCVDsQ2ql74g9oYPI7AOqQXejl2yyWjFcD3IuJnETECrAEulzSrw3G9TS/8Qcc9jyOwHtULvRw70ah8h6Q7gSeAmyNiO7AU+F+VChGxW9KrwBnAIx2Isa7KH869jDrE9yOwHtfuRuG8Su12Kulc4FfAq8AVwF3A+4H1wP0R8bWqunuAFSlhVG9jJbASYMGCBWcP1es1YuPTokX1ewktXJhd/jGzpnRFt9OI+HnVy3slXQlcAowAfTXV+4CDdbaxniyB0N/f33ODKDyOIQffj8A6bLz//3Z6HEIAAgaBZZWFkk4FpgJPdiiutvA4hpw8jsA6aCL8/5bWqCxptqSLJU2TNEXSCuB8YBuwGbhU0ockzQC+AGyJiHecIfQyj2PIyeMIrIMmwv9vmb2MjgNuA4aBvcBfAZdFxJMRMQh8liwx/A6YBXyuxNhK4XEMOXkcgXXQRPj/LS0hRMRwRJwTEbMiYnZE/ElE/Liq/JsRsSAiZkTEn0XES2XFVhaPYyB/t1GPI7AOmQj/v90yDmFCKGIcQ7tvoddWnn7aethEGIfk2U5LlqeXQhGzJXaUu41ah+XtJTReehk16nbqhNBDzrvzp+ypc71y/uzpPHzjRzoQUYs8/bR1UM8fUBWoUULwJaMe0vONWo26h7rbqDUpzyXTidBLKC8nhB7S841a7jZqOVSO8Pe8fITgrXEAzSaFnj+gKoETQg/pikatPL2E3G3Ucsh7hN/zB1QlcELoIR2fLbGIXkLuNmrHKO8RflccUHW5Tk9dYS3KO1tirl4So92k3l/s1manzJ5et1NFs0f4nq14bE4IE0juuVg8uZzllOeA5PqLF9ftJdTKEX63Tz/daU4IE8i6Hz3Bnz72E2742SZOObCXZ/tO5Evnf5p1M97V3D+JJ5ezHPIekPgIv/2cECaQ/od/wB3b7uL4148C8J4Dw9y57S5uAqCJcQy33/72G9SAewlZ00ZrFG72S91H+O3lRuUJ5KYd9/0+GVQc//pRbtpxX3MbcC+hCS/POAB3++x+PkOYQE7aP9zS8nq2/tGFrPvshrdO2f9oMZe1EMN4Gfo/EeW95JO3Udjaz2cIvSbHOAA1uNbfaHmtvAOD8q5vnZ3cMO84AHf77H5OCL0k7ziAnCOF834hdMPUAb08W2wRCbWTl3w6Po7GxuRLRr0k7ziASp2bb866ii5YkCWDJtsA8n4hFHENucjZYnvtFoh5G2W74ZKPG4W7m88QypZn6ocixgHkGCmcd+h/3vXzHiF3wxlKHnkTqi/52FicEFqV5ws97yWfDs8WmvcLIe/6eb/Qer2XS96E6ks+NpauSQiS5kp6QNIhSUOSrmrLG3XyC320Sz7N6PBsoXm/EPKun/cLrYjJzTrZBpE3oRbx+1921nwevvEjPHXnv+DhGz/iZDDOdM0NciR9iyxB/Vvg/cDfAx+IiMFG67R8g5zKF3rtwKpm+9LnveNXETeI2bz5mNsAel3eGwTlvUFKETdY6eQdu3yDGKvo6jumSZoB7APOjIgn07L7gD0RcWOj9VpOCJ3+QvctJHPp9BdypxNSETwOxKBxQuiWXkZnAK9XkkGyE7igtqKklcBKgAWtXjvP2yibdy4fT/2QSxFz2eTp5dLORt2yvpTdy8dG0y0JYSZwoGbZfmBWbcWIWA+sh+wMoaV36fQXes5un9bZL7S83S57vVHbxr9uaVQeAfpqlvUBBwt9l7yNskXM5eMbxPSsbmjUNWunbkkITwJTJJ1etWwZ0LBB+Zj4C91yyNtLyv34rdt1RaMygKT/CgTwF2S9jH5A0b2MzDrMjbrWDbq9URngc8AG4HfAi8BfjpYMzHqRG3Wtm3VNQoiIl6ClmZTNzKxA3dKGYGZmHeaEYGZmgBOCmZklTghmZgY4IZiZWeKEYGZmgBOCmZklTghmZgY4IZiZWeKEYGZmgBOCmZklTghmZgY4IZiZWeKEYGZmgBOCmZklTg
hmZgY4IZiZWVJKQpC0XdIrkkbS44ma8qskDUk6JGmrpLllxGVmZm8p8wxhVUTMTI/FlYWSlgJfBz4FnAQcBr5SYlxmZkZ33FN5BfC9iPgZgKQ1wOOSZkXEwc6GZmY2cZR5hnCHpL2SHpZ0YdXypcDOyouI2A28CpxRbyOSVkoakDQwPDzcznjNzCaUshLCXwOnAvOB9cD3JJ2WymYC+2vq7wdm1dtQRKyPiP6I6J83b1674jUzm3ByJ4TUYBwNHjsAIuLnEXEwIo5GxL3Aw8AlaRMjQF/NZvsAXy4yMytR7jaEiLjwWFYDlJ4PAssqBZJOBaYCT+aNzczMmtf2S0aSZku6WNI0SVMkrQDOB7alKpuBSyV9SNIM4AvAFjcom5mVq4xeRscBtwFLgDeAXcBlEfEkQEQMSvosWWI4AfgJ8G9KiMvMzKq0PSFExDBwzhh1vgl8s92xmJlZY566wszMACcEMzNLnBDMzAxwQjAzs8QJwczMACcEMzNLnBDMzAxwQjAzs8QJwczMACcEMzNLnBDMzAxwQjAzs8QJwczMACcEMzNLnBDMzAxwQjAzs8QJwczMgIISgqRVkgYkHZW0sU75RZJ2STos6UFJC6vKpkraIOmApOclXVdETGZm1pqizhCeJbtv8obaAkknAluANcBcYAD4dlWVtcDpwELgw8ANkj5eUFxmZtakQhJCRGyJiK3Ai3WKLwcGI+L+iHiFLAEsk7QklV8N3BoR+yLiceBu4Joi4jIzs+aV0YawFNhZeRERh4DdwFJJc4CTq8vT86WNNiZpZbo8NTA8PNymkM3MJp4yEsJMYH/Nsv3ArFRGTXmlrK6IWB8R/RHRP2/evEIDNTObyMZMCJK2S4oGjx1NvMcI0FezrA84mMqoKa+UmZlZicZMCBFxYUSoweODTbzHILCs8kLSDOA0snaFfcBz1eXp+WBrv4aZmeVVVLfTKZKmAZOByZKmSZqSih8AzpS0PNX5PPCLiNiVyjcBqyXNSQ3N1wIbi4jLzMyaV1QbwmrgCHAj8Mn0fDVARAwDy4HbgX3AucAVVeveQtbIPAQ8BKyLiG0FxWVmZk1SRHQ6hmPW398fAwMDnQ7DzKynSHokIvprl3vqCjMzA5wQzMwscUIwMzPACcHMzBInBDMzA5wQzMwscUIwMzPACcHMzBInBDMzA5wQzMwscUIwMzPACcHMzBInBDMzA5wQzMwscUIwMzPACcHMzBInBDMzA4q7p/IqSQOSjkraWFO2SFJIGql6rKkqnyppg6QDkp6XdF0RMZmZWWumFLSdZ4HbgIuB6Q3qzI6I1+ssXwucDiwE3g08KOlXvq+ymVm5CjlDiIgtEbEVePEYVr8auDUi9kXE48DdwDVFxGVmZs0rsw1hSNJvJX1D0okAkuYAJwM7q+rtBJY22oikleny1MDw8HB7IzYzm0DKSAh7gXPILgmdDcwCNqeymenn/qr6+1OduiJifUT0R0T/vHnz2hCumdnENGZCkLQ9NQrXe+wYa/2IGImIgYh4PSJeAFYBH5M0CxhJ1fqqVukDDh7LL2NmZsduzEbliLiw4PeM9HNSROyT9BywDPhxWr4MGCz4Pc3MbAxFdTudImkaMBmYLGmapCmp7FxJiyVNknQC8GVge0RULhNtAlZLmiNpCXAtsLGIuMzMrHlFtSGsBo4ANwKfTM9Xp7JTgW1kl4H+ETgKXFm17i3AbmAIeAhY5y6nZmblU0SMXatL9ff3x8DAQKfDMDPrKZIeiYj+2uWeusLMzAAnBDMzS5wQzMwMcEIwM7PECcHMzAAnBDMzS5wQzMwMcEIwM7PECcHMzAAnBDMzS5wQzMwMcEIwM7PECcHMzAAnBDMzS5wQzMwMcEIwM7PECcHMzIACEoKkqZLukTQk6aCkxyR9oqbORZJ2STos6UFJC2vW3yDpgKTnJV2XNyYzM2tdEWcIU4DfABcAf0B2L+XvSFoEIOlEYAuwBpgLDADfrlp/LXA6sBD4MH
CDpI8XEJeZmbUgd0KIiEMRsTYino6INyPi+8BTwNmpyuXAYETcHxGvkCWAZZKWpPKrgVsjYl9EPA7cDVyTNy4zM2tN4W0Ikk4CzgAG06KlwM5KeUQcAnYDSyXNAU6uLk/Pl46y/ZWSBiQNDA8PFx2+mdmEVWhCkHQcsBm4NyJ2pcUzgf01VfcDs1IZNeWVsroiYn1E9EdE/7x584oJ3MzMxk4IkrZLigaPHVX1JgH3Aa8Cq6o2MQL01Wy2DziYyqgpr5SZmVmJxkwIEXFhRKjB44MAkgTcA5wELI+I16o2MQgsq7yQNAM4jaxdYR/wXHV5ej6ImZmVqqhLRl8F3gtcGhFHasoeAM6UtFzSNODzwC+qLiltAlZLmpMamq8FNhYUl5mZNamIcQgLgc8A7weelzSSHisAImIYWA7cDuwDzgWuqNrELWSNzEPAQ8C6iNiWNy4zM2vNlLwbiIghQGPU+QmwpEHZUeDP08PMzDrEU1eYmRnghGBmZokTgpmZAU4IZmaWOCGYmRnghGBmZokTgpmZAU4IZmaWOCGYmRnghGBmZokTgpmZAU4IZmaWOCGYmRnghGBmZokTgpmZAU4IZmaWOCGYmRlQzC00p0q6R9KQpIOSHpP0iaryRZKi6taaI5LW1Ky/QdIBSc9Lui5vTGZm1rrct9BM2/gNcAHwDHAJ8B1J74uIp6vqzY6I1+usvxY4HVgIvBt4UNKvfF9lM7Ny5T5DiIhDEbE2Ip6OiDcj4vvAU8DZTW7iauDWiNgXEY8DdwPX5I3LzMxaU3gbgqSTgDOAwZqiIUm/lfQNSSemunOAk4GdVfV2AkuLjsvMzEZXaEKQdBywGbg3InalxXuBc8guCZ0NzEp1AGamn/urNrM/1Wn0HislDUgaGB4eLjJ8M7MJbcyEIGl7ahSu99hRVW8ScB/wKrCqsjwiRiJiICJej4gXUtnHJM0CRlK1vqq37AMONoonItZHRH9E9M+bN6+lX9bMzBobs1E5Ii4cq44kAfcAJwGXRMRro20y/ZwUEfskPQcsA36cli/jnZebzMyszYq6ZPRV4L3ApRFxpLpA0rmSFkuaJOkE4MvA9oioXCbaBKyWNEfSEuBaYGNBcZmZWZOKGIewEPgM8H7g+aqxBitSlVOBbWSXgf4ROApcWbWJW4DdwBDwELDOXU7NzMqXexxCRAwBGqX8W8C3Rik/Cvx5epiZWYd46gozMwOcEMzMLHFCMDMzwAnBzMwSJwQzMwOcEMzMLHFCMDMzwAnBzMwSJwQzMwOcEMzMLHFCMDMzwAnBzMwSJwQzMwOcEMzMLHFCMDMzwAnBzMwSJwQzMwOcEMzMLCkkIUj6L5Kek3RA0pOS/qKm/CJJuyQdlvRgug9zpWyqpA1p3eclXVdETGZm1pqizhDuABZFRB/wL4HbJJ0NIOlEYAuwBpgLDADfrlp3LXA6sBD4MHCDpI8XFJeZmTWpkIQQEYMRcbTyMj1OS68vBwYj4v6IeIUsASyTtCSVXw3cGhH7IuJx4G7gmiLiMjOz5k0pakOSvkL2RT4deBT4QSpaCuys1IuIQ5J2A0slvQCcXF2enl82yvusBFamlyOSnjjGkE8E9h7jumVwfPk4vnwcXz7dHt/CegsLSwgR8TlJfwX8c+BCoHLGMBMYrqm+H5iVyiqva8savc96YH3eeCUNRER/3u20i+PLx/Hl4/jy6fb4GhnzkpGk7ZKiwWNHdd2IeCMidgDvAf4yLR4B+mo22wccTGXUlFfKzMysRGMmhIi4MCLU4PHBBqtN4a02hEFgWaVA0oxUNhgR+4DnqsvT88Fj+WXMzOzY5W5UlvSHkq6QNFPSZEkXA1cC/zNVeQA4U9JySdOAzwO/iIhdqXwTsFrSnNTQfC2wMW9cTch92anNHF8+ji8fx5dPt8dXlyIi3wakecDfkR3ZTwKGgC9HxN1VdT4K3EXWkPFz4JqIeDqVTQW+Cvxr4AjwxYj4j7mCMjOzluVOCGZmNj546g
ozMwOcEMzMLBm3CUHSXEkPSDokaUjSVQ3qSdIXJb2YHl+UpBLimyrpnhTbQUmPSfpEg7rXSHpD0kjV48ISYtwu6ZWq96w7CLAT+7BmX4yk/fOfGtRt+/6TtErSgKSjkjbWlDWcy6vOdhalOofTOh9tZ3yS/kTSjyW9JGlY0v2STh5lO019JgqMb1Hq4l79t1szynbK3n8ramI7nOI9u8F22rL/ijJuEwLwn4FXgZOAFcBXJS2tU28l2cjoZcAfA5cCnykhvinAb4ALgD8AVgPfkbSoQf3/HREzqx7bS4gRYFXVey5uUKf0fVi9L4B3k3VIuH+UVdq9/54FbgM2VC/U2HN51foW2Uj/E4Cbgb9LHTfaEh8wh6xHzCKyTh8HgW+Msa1mPhNFxVcxu+o9bx1lO6Xuv4jYXPNZ/Bzwa+D/jbKtduy/QozLhKBsrMNyYE1EjKTBcv8d+FSd6lcDfxsRv42IPcDfUsJcShFxKCLWRsTTEfFmRHwfeAqoe2TR5TqyD6ssB34H/EOJ7/k2EbElIrYCL9YUjTWX1+9JOgP4Z8AtEXEkIr4L/JLs92tLfBHxwxTbgYg4TNYb8Ly871dUfK3oxP6r42pgU/Rob51xmRCAM4DXI+LJqmU7yeZVqvW2uZZGqddWkk4ii7vRoLyzJO1VNr34GkmFTTsyhjvS+z48ymWWTu/DZv4JO7X/3jGXF7Cbxp/FX0dE9Uj9svfl+Yw9MLSZz0TRhiT9VtI30llXPR3df+lS4PlkY6tG04n915TxmhBmAgdqljWaI2km75xLaWa7r4FXk3QcsBm4t2rAXrWfAWcCf0h2tHMlcH0Jof01cCown+yywvcknVanXsf2YfonvAC4d5Rqndp/8M59A81/FkerWzhJf0w2cHS0fdPsZ6Ioe4FzyC5nnU22LzY3qNvR/Qd8GviHiHhqlDpl77+WjNeEMNr8SWPV7QNGyjrlkzQJuI+svWNVvToR8euIeCpdWvol8AWygXxtFRE/j4iDEXE0Iu4FHgYuqVO1k/vwU8CO0f4JO7X/kjyfxdHqFkrSPwV+CPy7iGh46a2Fz0Qh0iXfgYh4PSJeIPsf+Zikel/yHdt/yacZ/cCk9P3XqvGaEJ4Epkg6vWpZozmS3jbX0ij1CpeOoO8ha/heHhGvNblqAKWdwTTxvh3bhzTxT1hHmfuv4VxeDeqeWvNl1/Z9mc6yfkJ2X5L7Wly97M9i5SCj3ndXR/YfgKTzgFPIZm1oRaf+l+salwkhXafdAnxB0oz0x/ozsiPxWpuA6yTNl3QK8O8pZy4lyKbseC9waUQcaVRJ0idSGwOpMXIN8N/aGZik2ZIuljRN0hRJK8iuj26rU70j+1DSB8hOvUfrXVTK/kv7aBowGZhc2W+MPZfX76U2r8eAW9L6/4qs19Z32xWfpPnAT4G7IuJrY2yjlc9EUfGdK2mxpEmSTgC+DGyPiNpLQx3Zf1VVrga+W9N+UbuNtu2/wkTEuHyQdfHbChwCngGuSss/RHY5o1JPwJeAl9LjS6QpPdoc30Kyo4NXyE51K48VwIL0fEGq+zfAC+l3+TXZJY/j2hzfPOD/kp1uvwz8H+BPu2wffh24r87y0vcfWe+hqHmsTWUfBXaRdY3dTna72cp6XwO+VvV6UapzBHgC+Gg74wNuSc+rP4PVf9v/APxwrM9EG+O7kqz33SGymZE3Ae/ulv2Xyqal/XFRnfVK2X9FPTyXkZmZAeP0kpGZmbXOCcHMzAAnBDMzS5wQzMwMcEIwM7PECcHMzAAnBDMzS5wQzMwMgP8PeOZyChMQEdEAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_preds(preds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 3: Calculate the loss" ] }, { "cell_type": "code", "execution_count": 330, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(139.3082, grad_fn=)" ] }, "execution_count": 330, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss = mse(preds,speed)\n", "loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "***\n", "__The question is: how do we improve these results?__" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 4: Calculate the gradients" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__PyTorch makes this easy: we just call `backward()` on the loss, and it calculates the gradient of the loss with respect to each of the parameters `a`, `b` and `c`, storing the result in `params.grad`.__" ] }, { "cell_type": "code", "execution_count": 331, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([165.0324, 10.5991, 0.6615])" ] }, "execution_count": 331, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss.backward()\n", "params.grad # derivative of the loss w.r.t. each parameter, in other words the slope" ] }, { "cell_type": "code", "execution_count": 332, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1.6503e-03, 1.0599e-04, 6.6150e-06])" ] }, "execution_count": 332, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params.grad * 1e-5 # the scalar 1e-5 is the learning rate" ] }, { "cell_type": "code", "execution_count": 333, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 0.9569, 0.0048, -0.1506], requires_grad=True)" ] }, "execution_count": 333, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params # still unchanged: calculating gradients does not update the parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "***\n", "#### Step 5: Step the weights"
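 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before the notebook's own update below, here is a minimal, self-contained sketch of a single weight step. It assumes toy data `t` and `y` built from the noise-free version of the speed formula above; the names `t`, `y`, `rmse` and `gen`, the fixed seed and the `torch.no_grad()` update are illustrative choices, not the notebook's code. Each parameter moves against its gradient by `lr` times that gradient, so the loss should drop slightly:\n", "\n", "```python\n", "import torch\n", "\n", "# Hypothetical toy data: the noise-free version of the speed formula above\n", "t = torch.arange(0, 20).float()\n", "y = 0.75 * (t - 9.5) ** 2 + 1\n", "\n", "def f(t, params):\n", "    a, b, c = params\n", "    return a * t**2 + b * t + c\n", "\n", "def rmse(preds, targets):\n", "    return ((preds - targets) ** 2).mean().sqrt()\n", "\n", "gen = torch.Generator().manual_seed(0)  # local RNG, the global seed is untouched\n", "params = torch.randn(3, generator=gen).requires_grad_()\n", "lr = 1e-5\n", "\n", "loss_before = rmse(f(t, params), y)\n", "loss_before.backward()          # fills params.grad\n", "with torch.no_grad():           # do not record the update itself in autograd\n", "    params -= lr * params.grad  # step each parameter against its gradient\n", "params.grad = None              # reset, since backward() accumulates gradients\n", "loss_after = rmse(f(t, params), y)\n", "print(loss_before.item(), loss_after.item())\n", "```\n", "\n", "Updating inside `torch.no_grad()` serves the same purpose as the notebook's `params.data -= lr * params.grad.data`: the values change without the operation being recorded in the autograd graph."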
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We picked a very small learning rate (1e-5) so that each step moves the parameters only slightly and does not overshoot the minimum of the loss." ] }, { "cell_type": "code", "execution_count": 334, "metadata": {}, "outputs": [], "source": [ "lr = 1e-5\n", "params.data -= lr * params.grad.data # step against the gradient\n", "params.grad = None # reset, because backward() adds to any existing gradient" ] }, { "cell_type": "code", "execution_count": 335, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(139.0348, grad_fn=)" ] }, "execution_count": 335, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = f(time,params)\n", "mse(preds, speed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's put all these steps into a single function:" ] }, { "cell_type": "code", "execution_count": 336, "metadata": {}, "outputs": [], "source": [ "def apply_step(params, prn=True):\n", " preds = f(time, params)\n", " loss = mse(preds, speed)\n", " loss.backward()\n", " params.data -= lr * params.grad.data\n", " params.grad = None\n", " if prn: print(loss.item())\n", " return preds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Step 6: Repeat the process" ] }, { "cell_type": "code", "execution_count": 337, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "139.03475952148438\n", "138.76133728027344\n", "138.4879150390625\n", "138.2145538330078\n", "137.94122314453125\n", "137.6679229736328\n", "137.39466857910156\n", "137.12144470214844\n", "136.84825134277344\n", "136.5751190185547\n" ] } ], "source": [ "for i in range(10): apply_step(params)" ] }, { "cell_type": "code", "execution_count": 338, "metadata": {}, "outputs": [], "source": [ "params = orig_params.detach().requires_grad_()" ] }, { "cell_type": "code", "execution_count": 339, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAA1QAAADMCAYAAAB0vOLuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbkUlEQVR4nO3db6xc9Z2Y8edrnAZjcAFBaSDydYRY6JqVi3Il1IUWErSFqtouwm9onTaoWlztikrbqEFGGIISrKB6X1TVbiM5CqJQb0VTGWvVrKDaQlKB+uYiFiIrJlW0mK0B1WzMH2NDQvj1xZzhzh3PnDNzzu/OnJl5PpLFnTlz7z3MzHNnvjPnnImUEpIkSZKk8W2Y9gpIkiRJ0qxyoJIkSZKkmhyoJEmSJKkmBypJkiRJqsmBSpIkSZJqcqCSJEmSpJocqCRJkiSppqwDVUTcExErEfFRRDzWt+yWiDgaEacj4rmIWOpZ9tmIeDQi3ouItyLiaznXS5pF9iTlZVNSXjYldeR+h+oN4GHg0d4zI+IS4BDwAHAxsAI82XORh4CrgCXgS8C9EXFb5nWTZo09SXnZlJSXTUlApJTy/9CIh4HPp5TuKk7vBu5KKf1mcXoz8DZwXUrpaES8USz/H8XybwFXpZTuzL5y0oyxJykvm5LysiktukntQ7UdeLl7IqX0AfAzYHtEXAR8rnd58fX2Ca2bNGvsScrLpqS8bEoLZeOEfs/5wIm+894FLiiWdU/3LztL8arHboDNmzd/8Zprrsm7povm5z+HY8fgk09Wz9uwAZaW4OKLp7deM+DFF198O6V06RR+dbaewKays6na5qEpe8rMnhqxKZ3Fpmor62lSA9UpYEvfeVuA94tl3dMf9i07S0rpAHAAYHl5Oa2srGRf2YWybdvaqKBz+pNPwOu2VEQcm9KvztYT2FR2NlXbPDRlT5nZUyM2pbPYVG1lPU1qk78jwI7uiWJb2iuBIymlk8CbvcuLr49MaN0W2+uvj3e+2sCe2symZpFNtZU9zSqbaiubWhe5D5u+MSLOBc4BzomIcyNiI/AUcG1E7CyWPwi8klI6Wnzr48DeiLgoIq4B7gYey7luGmLr1vHO18TY04yyqdayqRlkT61mUzPIptZF7neo9gJngD3AV4qv96aUTgA7gX3ASeB6oPdILt+gs7PiMeBHwP6U0tOZ102D7NsH55239rzzzuucr2mzp1lkU21mU7PGntrOpmaNTa2LdTls+qS4LW0mBw/C/fd33u7durUT1a5d016r1ouIF1NKy9Nej5xsKhObqmXemrKnTOypNpvSQDZVS1lPk9qHStN08GBnJ8QNGzr/PXhw7fJdu+C11zo7JL72mlFJVWxKyqusKXuSxuNj1MRN6ih/mpaDB2H3bjh9unP62LHOaTAgqQ6bkvKyKSkfe5oK36Gad/ffvxpV1+nTnfMljc+mpLxsSsrHnqZi7t6hOvzScfY/8ypvvHOGyy/cxNdvvZrbr7ti2qs1PR4eUw3ZVB+bUkM21cem1IA99bGnqZirgerwS8e579CPOfPLXwFw/J0z3HfoxwCLG9fWrZ23ewedr6H8A91hUwPY1NjsaZVNDWBTY7OpDnsawJ5qadrUXG3yt/+ZVz+NquvML3/F/mdendIatYCHxxxb9w/08XfOkFj9A334pePTXrWJs6kBbGos9rSWTQ1gU2OxqVX2NIA9jS1HU3M1UL3xzpmxzl8Iu3bBgQOwtAQRnf8eOOCOiSX8A73KpgawqbHY01o2NYBNjcWmVtnTAPY0thxNzdVAdfmFm8Y6f254eMys/AO9yqZsqil7WsumbKopm1q1sD2BHzWQUY6m5mqg+vqtV7PpM+esOW/TZ87h67dePaU1moDu4TGPHYOUVg+P2f9gpTUOv3ScGx55li/s+QE3PPLsmrd1F/oPdB+bsqlRDWvKntayKZsalU1VW8iewKZqWO/nfXM1UN1+3RV8+47f4IoLNxHAFRdu4tt3/Ma
ancrKrtCZ5OExx1a1rezC/oEewKYKNlWqrCl7WquqqbnrCWyqBpsazUI+RoFNjWkSz/sipZRznSdqeXk5raysjHz5/qPBQOcK649vpmzY0Hl1ol9E561eneWGR57l+IC3ca+4cBMv7PkyMNrRXiLixZTS8kRWekJsCpuqoaqpUY+eNG9N2VPBpsZmU4PZVMGmxjKJ531zddj0KmU7nc1sWB4ec2yjbCt7+3VXzO59YoJsSlDdlD2NZi57ApuqwabysCnBZJ73zdUmf1XmckdOD485Nrc/z8emBDaVy1z2BDZVg03lYVOCyfS0UAPVTP+BGnY0Fw+POTa3P8/HpgQ2lctc9gQ2VYNN5WFTgsn0tFAD1cz+gao6mouHxxzLKDuxajQ2JbCpXOa2J7CpMdlUHjYlmExPC3VQCijf6WzUnTwnbtu2wdvKLi11ItJZJnFbztvOvmBTNjWcTY0vd0+jLJ8Ke6rFpsZnUzZVZr1vy7KeFm6gGqbVR4LxaC5jmdRtOW8PVGBTNjWYTdWTsydocVP2NDabqsembGqYSdyWZT0t1CZ/ZcqOBDN1w47a4tFcBmr1bblAWn072NRYWn1bLpDW3g72NLbW3pYLprW3g02Nbdq3pQNVYepHginb+dCjuYxl6relgBbcDjaVzdRvSwEtuB2GNWVPY5v6bSlgyreDj1FZTbuphfocqjKXX7hp4Id+9R4JZt22zezufNj91OvuzofQ2cmwu6Ph/ffD6693XqHYt2/hd0AcdnuMcltq/dnUbCm7LWyqHVrdFNhTH5tqv6rbwceodmlzU75DVag6Ekx328zj75whAcffOcN9h37M4ZeON//l99+/GlXX6dOd87s8mssaZbfHzB7VZ87Y1Oyoui1sqh1a3ZQ9rWFTs6HsdvAxql3a3pQDVaHqkIrrum3m66+Pd/6COPzScW545Fm+sOcH3PDIs2v+iFV9+rmHm50+m2qfYU1V3RY21Q421T42NdvKbgd7mry6z/tg+k25yV+P26+7YugV33jbzIMHh791u3Xr4MNjLvDOh/1Ha+m+EgGd26nq9ii7LTU5NtUeZU2NclvYVDvYVHvY1HwYdjtk2SdnWFP2dJamz/u6l5tWU75DNaJGn7Zd9QFt7nx4lqpXImb6088F2NSklTVlT/PBpibLpuZb49uwrCl7OsusP+9zoBrRKNtmDn2rcpRtzw8c6HxgW0TnvwcOzP32smVv7Va9EjHtbWXVnE3lV7cpe5oPo+xjNez+YVOD2dTiavQYBeVNLWhPMPw6m/XnfW7yN6LebdQHHV2k9K3KUbaV7T2qywKoemu36mgtVbeH2s+m8mrSlD3Nh7Lbser+MdI+HTZlUwuk0WPUKE0tWE9Qfp3N+vM+B6oxlG2buf+ZV/mtv/hz7v1fj3P5e2/zxpZL+Hf/4F+wf/Pf4Ha3lT1L1UElvn7r1QM/8br3lQi3P599NpVP06bsaT4Mux1Le7ruCvfpGMCmVPsxyqYGKmtq1p/3OVBlsvzCn/Htp/+I8z7+CIDPv3eCR57+I+4D2LePj3/3bjZ+uDp5f3zuJjYuwLaywz4zYJSDSkB7X4nQ+rOps5V9BodNqUxpT3zZpmxKY7Kpweo2Nes9OVCNo+QISPc9/8SnUXWd9/FH3Pf8Exz+5r/h+dvu4Q+efezTVzH+/Zfv4sZfv5nbi8uu24fHNdRkvZq8tQvtfiVCmdjUyOvVdDPZ7uXacB1o8sp6gj/k8K/fbFM2pTE0baqtPcH0mprlnhyoRlXxqdaXvXti4Ldd9u4J9j/zKsevvon/dvVNa5b972KzgcrtcCusV5SjrFfZ72761q7mnE2N1VSOzWS1AIa8SFHWE2BTBZvSGiUv+jVpCmjUE0yvqbrP++a9KY/y1+vgQdi2DTZs6Py3e7hYqDwCUgzZJja2bq3cbGCUD48bdlSU9fwk76r1qvrdVW/t+qGGC8Cm1mjS1CibH9nUAihrquQwzWU9QfXmbaPcd4cd7cy
m1GrDmqr4KIEmTTV5jOoum0ZTTZ73wXw31Zp3qCLiYuB7wD8E3gbuSyn9ycRWoOLV8sqjtezbt/b74dPPFLj8r8rf4qy6A5a9WlD1akATTR5g5/2t3VlgU2eb5abc/Gj6Wt9U2YsUJT1B9WbYZffdqle0bUqDTL0nKG+q6rDnDZpq8hg1zaaaPu/rrv88NtWmd6j+GPgFcBmwC/hORGwf+6eUvXpXtrzqMziGHZWle37JZwpUHTu/6sPKyu7Ao3xydOnnJJSoWq9Z/8yABWBTfWa5KXtqhfVvqsG7uqUvUlR87k2TpqpebbcpDTHdxygob2qUw57XbKrJYxSM9qJhnZ5612HQ+T7vG64VA1VEbAZ2Ag+klE6llJ4H/hT452P9oKpPei9bPsqr5VWfar1rF7z2GnzySee/RVRVb3FW3QHL7sBVUY7ytvCw8Jo+aZ3nt3bbzqbmryl7mq6JNFXVW1VTo7xIMaAnaNZU1ZOspk2VPTm0qdnUiscoKG+qqieo3VSTxygov183eYyCZoPgIjcVKaVprwMRcR3wQkrpvJ7z/i1wU0rpt4d93/LyclpZWVk9Y9u2wcf8X1rq3NHLlkP590LpzolNle3kd8Mjzw58C/WK4nKDdvDr3oHLvveFPV8+623l/u8vW6+q7513EfFiSml52usxiE3Z1Cyat6bO6gmaNVPVY//mS9B5kaLnVfMmht13mzbRpMey9eous6n2NdWKx6iq5cM26VvnnqD8Maqqqf3PvNqox7J1s6fhPbVloPr7wPdTSn+757y7gV0ppZv7Lrsb2A2wdevWLx7rDWHDhs4rEGf/gs6rB2XLn3hiXcNposkTtC/s+QGDbuEA/vKRf1wZ7Sjr1tbDfq63tj5QgU1Vsal2moemSnuC8magvLdRBqZ1fJFimCZP0KC8qWH7ZIzaU9XvnndtbaoVj1GjNDWFnqBZUz5GrZ+yntpyUIpTwJa+87YA7/dfMKV0ADgAnVcq1iys+lTqsuXdQKYQTpWqDzsr28GvyY7Go67booQ0Y2yqhE2phpGaKu0JqpsqWzZKU7t2TbyxUT6Qs25TTXuq+t2amuk/RkF1U1PoCZo15WPUdLRloPopsDEirkop/Z/ivB3AkbF+SsURVyqXTymcUdS9A1cd83+UI7JoJtlUBZvSmCbTVNkyaG1TTZ5klTU1bPMle5p57XiMgrlryseo6WjFQSlSSh8Ah4BvRsTmiLgB+B3gibF+UMURVyqXz6GmO+9rNtnU+rGpxTSRphawJ2i2875mk49R68fHqOloxT5U8OnnETwK/Bbw18Ceqs8jGLjDr8a2yNvDNtHWbdO7bGp6bKqeeWvKnvKwp/ra3JSPUdNjU/XMwj5UpJR+Dtw+7fVYRG4PO59sanpsaj7Z1HTY03yyp+mxqfxascmfJEmSJM0iBypJkiRJqsmBSpIkSZJqcqCSJEmSpJocqCRJkiSpJgcqSZIkSarJgUqSJEmSanKgkiRJkqSaHKgkSZIkqSYHKkmSJEmqyYFKkiRJkmpyoJIkSZKkmhyoJEmSJKkmBypJkiRJqsmBSpIkSZJqcqCSJEmSpJocqCRJkiSpJgcqSZIkSarJgUqSJEmSanKgkiRJkqSaHKgkSZIkqSYHKkmSJEmqyYFKkiRJkmpyoJIkSZKkmhyoJEmSJKkmBypJkiRJqsmBSpIkSZJqcqCSJEmSpJocqCRJkiSpJgcqSZIkSarJgUqSJEmSanKgkiRJkqSasgxUEXFPRKxExEcR8diA5bdExNGIOB0Rz0XEUs+yz0bEoxHxXkS8FRFfy7FO0iyzKSkvm5LysilpVa53qN4AHgYe7V8QEZcAh4AHgIuBFeDJnos8BFwFLAFfAu6NiNsyrZc0q2xKysumpLxsSipkGahSSodSSoeBvx6w+A7gSErp+ymlD+lEtCMirimWfxX4VkrpZErpJ8B3gbtyrJc0q2xKysumpLxsSlo1iX2otgMvd0+klD4AfgZsj4iLgM/
1Li++3j7sh0XE7uIt5pUTJ06s0ypLrWZTUl7ZmrInCbApLZhJDFTnA+/2nfcucEGxjL7l3WUDpZQOpJSWU0rLl156adYVlWaETUl5ZWvKniTAprRgKgeqiPhhRKQh/54f4XecArb0nbcFeL9YRt/y7jJpLtmUlJdNSXnZlDSeyoEqpXRzSimG/LtxhN9xBNjRPRERm4Er6WxbexJ4s3d58fWR8f43pNlhU1JeNiXlZVPSeHIdNn1jRJwLnAOcExHnRsTGYvFTwLURsbO4zIPAKymlo8Xyx4G9EXFRsbPi3cBjOdZLmlU2JeVlU1JeNiWtyrUP1V7gDLAH+Erx9V6AlNIJYCewDzgJXA/c2fO936Czo+Ix4EfA/pTS05nWS5pVNiXlZVNSXjYlFSKlNO11qG15eTmtrKxMezW0oCLixZTS8rTXIyeb0jTNW1P2pGmzKSmfsp4mcZQ/SZIkSZpLDlSSJEmSVJMDlSRJkiTV5EAlSZIkSTU5UEmSJElSTQ5UkiRJklSTA5UkSZIk1eRAJUmSJEk1OVBJkiRJUk0OVJIkSZJUkwOVJEmSJNXkQCVJkiRJNTlQSZIkSVJNDlSSJEmSVJMDlSRJkiTV5EAlSZIkSTU5UEmSJElSTQ5UkiRJklSTA5UkSZIk1eRAJUmSJEk1OVBJkiRJUk0OVJIkSZJUkwOVJEmSJNXkQCVJkiRJNTlQSZIkSVJNDlSSJEmSVJMDlSRJkiTV5EAlSZIkSTU5UEmSJElSTQ5UkiRJklSTA5UkSZIk1dR4oIqIz0bE9yLiWES8HxF/ERH/qO8yt0TE0Yg4HRHPRcRS3/c/GhHvRcRbEfG1puskzTKbkvKyKSkvm5LWyvEO1Ubgr4CbgL8J7AX+a0RsA4iIS4BDwAPAxcAK8GTP9z8EXAUsAV8C7o2I2zKslzSrbErKy6akvGxK6tF4oEopfZBSeiil9FpK6ZOU0n8H/hL4YnGRO4AjKaXvp5Q+pBPRjoi4plj+VeBbKaWTKaWfAN8F7mq6XtKssikpL5uS8rIpaa3s+1BFxGXArwFHirO2Ay93l6eUPgB+BmyPiIuAz/UuL77ennu9pFllU1JeNiXlZVNadFkHqoj4DHAQ+E8ppaPF2ecD7/Zd9F3ggmIZfcu7y4b9jt0RsRIRKydOnMiz4lJL2ZSU13o3ZU9aNDYljTBQRcQPIyIN+fd8z+U2AE8AvwDu6fkRp4AtfT92C/B+sYy+5d1lA6WUDqSUllNKy5deemnV6kutY1NSXm1qyp40D2xKGk/lQJVSujmlFEP+3QgQEQF8D7gM2JlS+mXPjzgC7OieiIjNwJV0tq09CbzZu7z4+gjSnLIpKS+bkvKyKWk8uTb5+w7wd4DfTimd6Vv2FHBtROyMiHOBB4FXet4WfhzYGxEXFTsr3g08lmm9pFllU1JeNiXlZVNSIcfnUC0B/wr4u8BbEXGq+LcLIKV0AtgJ7ANOAtcDd/b8iG/Q2VHxGPAjYH9K6emm6yXNKpuS8rIpKS+bktba2PQHpJSOAVFxmT8Hrhmy7CPgXxb/pIVnU1JeNiXlZVPSWtkPmy5JkiRJi8KBSpIkSZJqcqCSJEmSpJocqCRJkiSpJgcqSZIkSarJgUqSJEmSanKgkiRJkqSaHKgkSZIkqSYHKkmSJEmqyYFKkiRJkmpyoJIkSZKkmhyoJEmSJKkmBypJkiRJqsmBSpIkSZJqcqCSJEmSpJocqCRJkiSpJgcqSZIkSarJgUqSJEmSanKgkiRJkqSaHKgkSZIkqSYHKkmSJEmqyYFKkiRJkmpyoJIkSZKkmhyoJEmSJKkmBypJkiRJqsmBSpIkSZJqcqCSJEmSpJocqCRJkiSpJgcqSZIkSarJgUqSJEmSanKgkiRJkqSasgxUEfGfI+LNiHgvIn4aEb/bt/yWiDgaEacj4rmIWOpZ9tmIeLT43rci4ms51kmaZTYl5WVTUl42Ja3K9Q7Vt4FtKaUtwD8BHo6ILwJExCXAIeA
B4GJgBXiy53sfAq4CloAvAfdGxG2Z1kuaVTYl5WVTUl42JRWyDFQppSMppY+6J4t/Vxan7wCOpJS+n1L6kE5EOyLimmL5V4FvpZROppR+AnwXuCvHekmzyqakvGxKysumpFXZ9qGKiP8YEaeBo8CbwJ8Vi7YDL3cvl1L6APgZsD0iLgI+17u8+Hp7rvWSZpVNSXnZlJSXTUkdG3P9oJTS70fEvwb+HnAz0H3V4nzgRN/F3wUuKJZ1T/cvGygidgO7i5OnIuLVIRe9BHh71PUX4HU2rqXqi9RnU3PB62w8M9/UGD2B949xeX2Nb5Ga8v4xPq+z8QztqXKgiogfAjcNWfxCSunG7omU0q+A5yPiK8DvAf8BOAVs6fu+LcD7xbLu6Q/7lg2UUjoAHBhhvVdSSstVl9Mqr7PJsKnF4XU2GW1qatSeivX2/jEGr6/JmcWmvH+Mz+ssn8pN/lJKN6eUYsi/G4d820ZWt6M9AuzoLoiIzcWyIymlk3TeIt7R8707iu+R5pJNSXnZlJSXTUnjabwPVUT8rYi4MyLOj4hzIuJW4J8C/7O4yFPAtRGxMyLOBR4EXkkpHS2WPw7sjYiLip0V7wYea7pe0qyyKSkvm5LysilprRwHpUh03uL9v8BJ4A+BP0gp/SlASukEsBPYVyy/Hriz5/u/QWdHxWPAj4D9KaWnM6zXSJtcaA2vs3awqfnhddYONjUfvL7ao41Nef8Yn9dZJpFSmvY6SJIkSdJMynbYdEmSJElaNA5UkiRJklTT3A1UEXFxRDwVER9ExLGI+GfTXqe2iYh7ImIlIj6KiMf6lt0SEUcj4nREPBcR6/oZFmo/mypnTxqXTZWzKY3DnqrZ1Pqbu4EK+GPgF8BlwC7gOxHhp2+v9QbwMPBo75kRcQlwCHgAuBhYAZ6c+NqpbWyqnD1pXDZVzqY0DnuqZlPrbK4OSlF8zsFJ4NqU0k+L854AjqeU9kx15VooIh4GPp9Suqs4vRu4K6X0m8XpzXQ+Qfu6nkOdaoHY1OjsSaOwqdHZlKrY03hsav3M2ztUvwZ83I2q8DLgKxWj2U7n+gIgpfQBncOaev0tLpuqz540iE3VZ1PqZ0/N2FQm8zZQnQ+813feu8AFU1iXWXQ+neurl9ffYrOp+uxJg9hUfTalfvbUjE1lMm8D1SlgS995W4D3p7Aus8jrT/28T9TndadBvF/U53Wnft4nmvH6y2TeBqqfAhsj4qqe83YAR6a0PrPmCJ3rC/h0W9or8fpbZDZVnz1pEJuqz6bUz56asalM5mqgKrb9PAR8MyI2R8QNwO8AT0x3zdolIjZGxLnAOcA5EXFuRGwEngKujYidxfIHgVfcMXFx2VQ1e9I4bKqaTWlU9jQam1p/czVQFX4f2AT8P+C/AL+XUnLSXmsvcAbYA3yl+HpvSukEsBPYR+eoOdcDd05rJdUaNlXOnjQumypnUxqHPVWzqXU2V4dNlyRJkqRJmsd3qCRJkiRpIhyoJEmSJKkmBypJkiRJqsmBSpIkSZJqcqCSJEmSpJocqCRJkiSpJgcqSZIkSarJgUqSJEmSanKgkiRJkqSa/j89K9hCGd/VTwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_,axs = plt.subplots(1,4,figsize=(12,3))\n", "for ax in axs: show_preds(apply_step(params, False), ax)\n", "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "***\n", "## MNIST" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loss function for our 3 and 7 recognizer (so far we have only used a metric, not a loss)" ] }, { "cell_type": "code", "execution_count": 340, "metadata": {}, "outputs": [], "source": [ "train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)" ] }, { "cell_type": "code", "execution_count": 341, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([12396, 784])" ] }, "execution_count": 341, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_x.size()" ] }, { "cell_type": "code", "execution_count": 342, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([12396, 784]), torch.Size([12396, 1]))" ] }, "execution_count": 342, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)\n", "train_x.shape,train_y.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How tensors are manipulated" ] }, { "cell_type": "code", "execution_count": 343, "metadata": {}, "outputs": [], "source": [ "temp_tensor = tensor(1)" ] }, { "cell_type": "code", "execution_count": 344, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1)" ] }, "execution_count": 344, "metadata": {}, "output_type": "execute_result" } ], "source": [ "temp_tensor" ] }, { "cell_type": "code", "execution_count": 345, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Tensor" ] }, "execution_count": 345, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(temp_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Is the tensor above wrong? What's the difference?__ `tensor(1)` is a 0-dimensional (scalar) tensor, while `tensor([1])` below is a 1-dimensional tensor containing one element.\n", "***" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__We start with a one-element tensor:__" ] }, { "cell_type": "code", "execution_count": 346, "metadata": {}, "outputs": [], "source": [ "temp_tensor = tensor([1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Then we multiply the list inside it (`[1]*4` repeats the element four times):__" ] }, { "cell_type": "code", "execution_count": 347, "metadata": {}, "outputs": [], "source": [ "temp_tensor = tensor([1]*4)" ] }, { "cell_type": "code", "execution_count": 348, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1, 1, 1, 1])" ] }, "execution_count": 348, "metadata": {}, "output_type": "execute_result" } ], "source": [ "temp_tensor" ] }, { "cell_type": "code", "execution_count": 349, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4])" ] }, "execution_count": 349, "metadata": {}, "output_type": "execute_result" } ], "source": [ "temp_tensor.shape" ] }, { "cell_type": "code", "execution_count": 350, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": 350, "metadata": {}, "output_type": "execute_result" } ], "source": [ "temp_tensor.ndim" ] }, { "cell_type": "code", "execution_count": 351, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4])" ] }, "execution_count": 351, "metadata": {}, "output_type": "execute_result" } ], "source": [ "temp_tensor.size()" ] }, { "cell_type": "code", "execution_count": 352, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1],\n", " [1],\n", " [1],\n", " [1]])" ] }, "execution_count": 352, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(temp_tensor).unsqueeze(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Warning: __The output looks changed, so why is the shape still `[4]` and not `[4, 1]`?__" ] }, { "cell_type": "code", "execution_count": 353, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4])" ] }, "execution_count": 353, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "temp_tensor.shape" ] }, { "cell_type": "code", "execution_count": 354, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4])" ] }, "execution_count": 354, "metadata": {}, "output_type": "execute_result" } ], "source": [ "temp_tensor.size()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How does unsqueeze work?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Warning: Whaaaaaaaaaaaaat?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It looks as if `(temp_tensor).unsqueeze(1)` doesn't work, but it does: `unsqueeze` is not an in-place operation. It returns a *new* tensor of shape `[4, 1]` and leaves `temp_tensor` itself unchanged, which is why the shape above is still `[4]`. You don't have to unsqueeze while creating the tensor; you only have to keep the result, e.g. `temp_tensor = temp_tensor.unsqueeze(1)`, or use the in-place version `temp_tensor.unsqueeze_(1)`. Unsqueezing at creation time, as below, works for the same reason: the result is assigned." ] }, { "cell_type": "code", "execution_count": 355, "metadata": {}, "outputs": [], "source": [ "temp_tensor = tensor([1]).unsqueeze(1)" ] }, { "cell_type": "code", "execution_count": 356, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 1])" ] }, "execution_count": 356, "metadata": {}, "output_type": "execute_result" } ], "source": [ "temp_tensor.shape" ] }, { "cell_type": "code", "execution_count": 357, "metadata": {}, "outputs": [], "source": [ "temp_tensor = tensor([1]*1).unsqueeze(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dataset" ] }, { "cell_type": "code", "execution_count": 358, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([784]), 1, tensor([1]))" ] }, "execution_count": 358, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dset = list(zip(train_x,train_y))\n", "x,y = dset[0]\n", "x.shape,x.ndim,y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__We create a list of tuples; each tuple contains an image and its target.__" ] }, { "cell_type": "code", "execution_count": 359, "metadata": {}, "outputs": [], "source": [ "valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)\n", "valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)\n", 
"valid_dset = list(zip(valid_x,valid_y))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Same for the validation set__\n", "***" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Weights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__This isn't clear in the videos, but think of it as a single-layer neural network with 784 (28*28) inputs and 1 output.__" ] }, { "cell_type": "code", "execution_count": 360, "metadata": {}, "outputs": [], "source": [ "def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()" ] }, { "cell_type": "code", "execution_count": 361, "metadata": {}, "outputs": [], "source": [ "weights = init_params((28*28,1))" ] }, { "cell_type": "code", "execution_count": 362, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([784, 1])" ] }, "execution_count": 362, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weights.shape" ] }, { "cell_type": "code", "execution_count": 363, "metadata": {}, "outputs": [], "source": [ "bias = init_params(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: The function `weights*pixels` won't be flexible enough—it is always equal to 0 when the pixels are equal to 0 (i.e., its *intercept* is 0). You might remember from high school math that the formula for a line is `y=w*x+b`; we still need the `b`. 
We'll initialize it to a random number too:" ] }, { "cell_type": "code", "execution_count": 364, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.0959], requires_grad=True)" ] }, "execution_count": 364, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bias" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Again, transposing the weight matrix isn't explained clearly; Tariq Rashid's book is very helpful at this point.__ `weights` has shape `[784, 1]`, so `weights.T` has shape `[1, 784]` and lines up element by element with the 784 pixels of `train_x[0]`." ] }, { "cell_type": "code", "execution_count": 365, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-5.6867], grad_fn=<AddBackward0>)" ] }, "execution_count": 365, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(train_x[0]*weights.T).sum() + bias" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__To do this for the whole dataset at once, put the multiplication in a function that uses matrix multiplication (`@`)__" ] }, { "cell_type": "code", "execution_count": 366, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ -5.6867],\n", " [ -6.5451],\n", " [ -2.0241],\n", " ...,\n", " [-14.3286],\n", " [ 4.3505],\n", " [-12.6773]], grad_fn=<AddBackward0>)" ] }, "execution_count": 366, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def linear1(xb): return xb@weights + bias\n", "preds = linear1(train_x)\n", "preds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Turn each prediction into a label (above 0.5 means 3, because `train_y` is 1 for threes; otherwise 7) and compare it with the targets__" ] }, { "cell_type": "code", "execution_count": 367, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[False],\n", " [False],\n", " [False],\n", " ...,\n", " [ True],\n", " [False],\n", " [ True]])" ] }, "execution_count": 367, "metadata": {}, "output_type": "execute_result" } ], "source": [ "corrects = (preds>0.5).float() == train_y\n", "corrects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "***\n", "__check it__" ] }, { "cell_type": "code", "execution_count": 368, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"0.4636172950267792" ] }, "execution_count": 368, "metadata": {}, "output_type": "execute_result" } ], "source": [ "corrects.float().mean().item()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Roughly half of the predictions are correct, which is what we'd expect since the weights are totally random__ \n", "***" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Why we need a loss function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Basically, we need gradients to correct our weights: they tell us which direction to move each weight. Accuracy can't give us those, because it only changes when a prediction flips between 3 and 7, so its gradient is zero almost everywhere. We need a loss that changes smoothly as the weights change.__ \n", "\n", "\n", "If you don't understand all of this, check Khan Academy's lessons on gradients." ] }, { "cell_type": "code", "execution_count": 369, "metadata": {}, "outputs": [], "source": [ "trgts = tensor([1,0,1])\n", "prds = tensor([0.9, 0.4, 0.2])" ] }, { "cell_type": "code", "execution_count": 370, "metadata": {}, "outputs": [], "source": [ "def mnist_loss(predictions, targets):\n", "    return torch.where(targets==1, 1-predictions, predictions).mean()" ] }, { "cell_type": "code", "execution_count": 371, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.1000, 0.4000, 0.8000])" ] }, "execution_count": 371, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.where(trgts==1, 1-prds, prds)" ] }, { "cell_type": "code", "execution_count": 372, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.4333)" ] }, "execution_count": 372, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mnist_loss(prds,trgts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sigmoid" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__We need this to squish predictions into the range 0 to 1__" ] }, { "cell_type": "code", "execution_count": 373, "metadata": {}, "outputs": [], "source": [ "def sigmoid(x): return 1/(1+torch.exp(-x))" ] }, { "cell_type": "code", "execution_count": 374, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAEMCAYAAAA/Jfb8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlXklEQVR4nO3deXxU9b3G8c8XCCQkJGwh7DvIpoBEEBStVety61ZstSrutaLWrfXW6tVatbW112urtS63KIriWnGjaqtWxaXKGiDs+04Cgex7vvePCb0xJmaAJGdm8rxfr3npnPnN8BhnHk5+58zvmLsjIiKxpVXQAUREpPGp3EVEYpDKXUQkBqncRURikMpdRCQGqdxFRGKQyl1ijpndZWZrg86xn5nNMLP3GhhzqZlVNFcmiX0qd4kqZpZgZveY2RozKzazHDObZ2bX1xj238DRQWWsww3A94MOIS1Lm6ADiBygR4ETCBVmBpAMjAX67h/g7gVAQSDp6uDuuUFnkJZHe+4Sbc4Gfu/ur7n7BnfPcPcZ7n73/gF1TcuY2Y1mttXMiszsXTObamZuZr2rH7/UzCrM7AQzW1r9W8GHZtbTzI4zs0VmVmhm75lZr1qvfYmZLTezsuo/414za1Pj8a9My5hZq+rfPrLMrMDMXgQ6NdHPS1oolbtEmx3AqWbWOdwnmNn3CE3V/B4YDTwP/K6Ooa2AXwJXAscAvYAXgbuBadXbegP/U+O1/wN4EpgJjAJ+Clxb/Tr1+QlwM3ALcCSwoIHxIgfO3XXTLWpuhAp2E1AJLAGeILQ3bzXG3AWsrXH/U2Bmrdf5LeBA7+r7l1bfH1NjzC3V28bV2HYTsLvG/bnAS7Ve+wagGGhbfX8G8F6Nx7cCv671nFeAiqB/vrrFzk177hJV3P1TYBAwGXgaSCNUjG+YmdXztBHAv2pt+7yulweW1ri/s/qfS2pt62JmravvjwQ+rvU6HwHx1Tm/wsySCf1G8Fmthz6pJ7vIQVG5S9Rx9wp3/8zdH3D3swjtdX8XOO6bnhbGS1e5e2Xt57h7eR2vU99fJCIRQeUusWBF9T+71fP4cmBirW2NdapkJl//S+V4QtMy62oPdvc8YBswqdZDxzRSHhFAp0JKlDGzjwgdEJ0PZAODgd8A+4B/1vO0B4AXzexL4G1CxXpx9WOHekGD+4A3zexW4FVgDKE5/wfcvewb8txjZisJTRedCZx0iDlEvkJ77hJt3gYuBP4GrAKeAtYAx7j77rqe4O6vAv8J3EpoTv1C4FfVD5ccShh3/xtwOXAJsAx4EPhzjdevyx+Bh6rHLib0W8Xd3zBe5ICZu67EJC2Pmd0JXO/uXYPOItIUNC0jMc/M4gidf/43oJDQN1xvAR4JMpdIU9Keu8S86m+LvgWMAzoAG4BnCH3TVYt1SUxSuYuIxCAdUBURiUERMefetWtX79+/f9AxRESiyoIFC3a7e2pdj0VEuffv35/58+cHHUNEJKqY2ab6HtO0jIhIDAqr3M3sOjObb2alZjajgbE3mdlOM8szsyfNrF2jJBURkbCFu+e+HbiX0LrV9TKzUwh9C/BEoB8wkG/+pp6IiDSBsMrd3V9199eAPQ0MvQSY7u6Z7r4XuIfQin0iItKMGnvOfSSh61rulwGkmVmXRv5zRETkGzR2uScBNS8GvP/fO9QeaGZXVc/jz8/Ozm7kGCIiLVtjl3sBoavR77f/3/NrD3T3J9w93d3TU1PrPE1TREQOUmOf555J6ALEL1XfHw3scveG5upFRGKau5NTWMbOvBKy8krJyi9hV14pY/t2ZPKQxt/BDavcqxdeagO0BlqbWTyhi/nWXnTpGWCGmT1H6Ayb/yJ0cWARkZhWVlHFtn3FbN1bxNa9xWzbW8z2fcVs21fMjtwSduaVUFZR9bXnTfvWoODKnVBJ/7LG/YuAX5nZk4QuYTbC3Te7+ztmdj+hK+IkAH+t9TwRkahVXlnF5pwi1mcXsmF3ARt2F7FxdyG
bc4rYkVtMVY11GFu3Mronx9OzYzxj+nSkR8d4uieHbt2S4+nWoR2pHdoRH9e6/j/wEETEqpDp6emu5QdEJFJUVjkbdhewcmc+q3cVsGZXPmuyCti0p5Dyyv/vzE7t4+jfNZF+ndvTt0sifTu3p0+nBHp3bk9ah3a0ad20iwCY2QJ3T6/rsYhYW0ZEJCgl5ZWs2pnP0m25ZG7PJXN7Hqt25lNaPYXSyqBfl0QGd0vi5BFpDE5NYmBqIgO7JpHSPi7g9PVTuYtIi+HubM4pYsGmvSzavI+MrftYsSPv33vjKQlxjOyZzNSj+zG8RzLDenRgUGpSk02dNCWVu4jErKoqZ8XOPL5Yn8OXG3KYv2kvuwtKAUhs25ojenfkyskDOaJXCqN6pdC7UwJmFnDqxqFyF5GYsnF3IZ+s3c2na3fz2bo95BaXA9C7UwKTh3RlXL9OpPfvxJBuHWjdKjaKvC4qdxGJaiXllXy+bg8frsriw9XZbNpTBEDPlHi+MyKNiYO6MGFgF3p1TAg4afNSuYtI1NlXVMY/lu/ivRW7+Hj1borLK4mPa8WkQV254tgBTB6SSv8u7WNmiuVgqNxFJCrsLSzjncyd/G3pDj5ft4eKKqdHSjznjuvNicO7cfTALlF54LOpqNxFJGIVl1Xy9+U7eWPxdj5anU1FldOvS3t+dNxAThvVncN7pbTovfNvonIXkYji7izcvJdXFmzlrYwd5JdW0D05nsuPHcCZo3sysmeyCj0MKncRiQi5ReX8deFWZn25mbVZBSTEteb0w3swZVwvjh7QhVYxfGZLU1C5i0iglm/PY8ZnG3h98XZKK6oY3acj9085gtOP6EFSO1XUwdJPTkSaXVWV896KXUz/ZANfbMghPq4V3zuyNxcd3ZeRPVOCjhcTVO4i0mxKKyp5bdE2Hv94PeuzC+nVMYHbTh/Geel9I3qdlmikcheRJldSXskLX27msY/WszOvhJE9k3noh2M5fVT3Jl85saVSuYtIkykpr+S5Lzbz2EfryM4vZfyAzvz++0dw7OCuOuOliancRaTRVVRW8cqCrfzx/TXsyC1h0qAuPPzDsRw9sEvQ0VoMlbuINBp3570VWdz39grWZxcypk9HHvj+aCYN7hp0tBZH5S4ijWLZtlzueWs5X2zIYWBqIk9MHcfJI9I0/RIQlbuIHJKcwjJ+/+4qXpi3mU7t23LPWSM5f3xf4nSgNFAqdxE5KFVVzqwvN/P7d1dRUFrBZZMGcOPJQ0iO1ymNkUDlLiIHbOXOPH7x6lIWbd7HxIFd+NVZIxma1iHoWFKDyl1EwlZSXslD76/hiY/Xk5wQx4PnjebsMb00rx6BVO4iEpZFm/dyyytLWJtVwLnjenP76cPplNg26FhSD5W7iHyj0opKHvzHGp74eB1pyfE8ffl4jh+aGnQsaYDKXUTqtXpXPje8sJgVO/I4L70Pt393uA6YRgmVu4h8jbvz9Gcbue/tlSS1a8NfLk7npBFpQceSA6ByF5Gv2FdUxs9eXsJ7K3ZxwmGp3H/uaFI7tAs6lhwglbuI/NuCTXu5/vlFZOWXcMd3R3D5Mf11JkyUUrmLCO7OU59u5Dd/W0GPjvG8cvUkRvfpGHQsOQQqd5EWrqisgl+8upTXF2/npOFpPPCD0aQk6KBptFO5i7Rgm/cUcdXM+azalc8tpxzGtOMH6ULUMSKslX3MrLOZzTazQjPbZGYX1DOunZk9Zma7zCzHzN40s16NG1lEGsPn6/Zw1iOfsCO3hBmXjefaEwar2GNIuMu2PQKUAWnAhcCjZjayjnE3ABOBI4CewF7g4UbIKSKNaNYXm5k6/Qu6JLXj9WuP0ZeSYlCD5W5micAU4A53L3D3T4A3gKl1DB8AvOvuu9y9BHgRqOsvAREJQGWVc89by7lt9lKOHdKVV6+ZRP+uiUHHkiYQzpz7UKDC3VfX2JYBHF/H2OnAH82sJ7CP0F7+24caUkQOXXFZJTe+uIh3M3dx6aT+3PH
dEbTWNEzMCqfck4C8WttygbrW91wDbAG2AZXAUuC6ul7UzK4CrgLo27dvmHFF5GDsKSjliqfnk7F1H3d+dwSXHzsg6EjSxMKZcy8AkmttSwby6xj7CNAO6AIkAq9Sz567uz/h7ununp6aqvk+kaayJaeIcx/7nJU783jsonEq9hYinHJfDbQxsyE1to0GMusYOwaY4e457l5K6GDqeDPT1XFFArBiRx5THv2MnMIynrtyAqeM7B50JGkmDZa7uxcS2gO/28wSzewY4CxgZh3D5wEXm1mKmcUB1wDb3X13Y4YWkYbN25jDDx7/nFZmvHz1RMb16xx0JGlG4Z4KeQ2QAGQBzwPT3D3TzCabWUGNcT8DSgjNvWcDpwPnNGJeEQnD3DXZTJ3+BalJ7Xhl2kRdAq8FCusbqu6eA5xdx/a5hA647r+/h9AZMiISkL9n7uS6WYsYmJrIzCsmaEXHFkrLD4jEkDcztnPji4sZ1SuFpy87io7tdRm8lkrlLhIjXl+8jZteXEx6v85MvzSdDrpiUoumcheJAa8t2sbNLy3mqP6defLSo0hsp492S6d3gEiU21/s4weEir19W32sReUuEtXmLNnBzS8tZsKALjx56VEktG0ddCSJEOGeCikiEebvmTu54YVFHNm3E9MvTVexy1eo3EWi0Eers7lu1iJG9krhqcs0FSNfp3IXiTLzN+bw45nzGdQtiWcuG6+zYqROKneRKLJ8ex6XzZhHz5QEZl4xnpT2Knapm8pdJEps2F3IxU9+SVK7Nsy8cgJdk/TNU6mfyl0kCmTllTB1+hdUuTPzign06pgQdCSJcCp3kQiXV1LOJU/NI6ewjBmXHcXgbkkNP0laPJW7SAQrrajk6pkLWLMrn8cuGscRvTsGHUmihM6fEolQVVXOz15ewmfr9vDgeaM5bqiuWCbh0567SIT63bsreTNjO7eeNoxzxvYOOo5EGZW7SAR69l+bePyj9Vx0dF9+fNzAoONIFFK5i0SYD1bu4s7Xl/HtYd2464yRmFnQkSQKqdxFIsjy7XlcN2sRI3om8/APx9KmtT6icnD0zhGJEFl5JVz59DxSEuKYfonWZJdDo3ePSAQoLqvkR8/MZ19xOS9fPZG05PigI0mUU7mLBCx0ymMGS7bl8vhF4xjZMyXoSBIDNC0jErCHPljDnKU7uPXUYXxnZPeg40iMULmLBOjtpTv4w3trmHJkb67SKY/SiFTuIgHJ3J7LzS9lMLZvR359ziid8iiNSuUuEoDdBaVc9cwCOraP4/Gp44iP0yXypHHpgKpIMyuvrOLa5xayu6CUV66eRLcOOjNGGp/KXaSZ/XrOCr7YkMOD543m8N46M0aahqZlRJrRy/O3MOOzjVxx7AAtBiZNSuUu0kyWbN3H7a8tY9KgLvzitGFBx5EYp3IXaQZ7Ckq5euYCUpPa8acLjtSaMdLkNOcu0sQqKqu4/oVF7C4s469XT6JzYtugI0kLENbug5l1NrPZZlZoZpvM7IJvGHukmX1sZgVmtsvMbmi8uCLR57//vppP1+7h3rNH6QCqNJtw99wfAcqANGAMMMfMMtw9s+YgM+sKvAPcBLwCtAV01EharHeW7eSxj9ZxwYS+/CC9T9BxpAVpcM/dzBKBKcAd7l7g7p8AbwBT6xh+M/Cuuz/n7qXunu/uKxo3skh0WJ9dwM9ezmB0n4788owRQceRFiacaZmhQIW7r66xLQMYWcfYo4EcM/vMzLLM7E0z69sYQUWiSVFZBdOeXUhca+PPFx5Juzb6Bqo0r3DKPQnIq7UtF+hQx9jewCXADUBfYAPwfF0vamZXmdl8M5ufnZ0dfmKRCOfu3D57Gauz8vnD+WPp1TEh6EjSAoVT7gVAcq1tyUB+HWOLgdnuPs/dS4BfAZPM7GtHkdz9CXdPd/f01NTUA80tErGe+2Izsxdt48YTh3L8UL23JRjhlPtqoI2ZDamxbTSQWcfYJYDXuO91jBGJWUu35nL3m8s5bmg
qP/n24KDjSAvWYLm7eyHwKnC3mSWa2THAWcDMOoY/BZxjZmPMLA64A/jE3XMbM7RIJMotKueaWQvoktSWP5w3hlattISvBCfcr8ldAyQAWYTm0Ke5e6aZTTazgv2D3P0D4DZgTvXYwUC958SLxAp352evZLBjXwl/uuBIfVFJAhfWee7ungOcXcf2uYQOuNbc9ijwaGOEE4kWf5m7gX8s38Ud3x3BuH6dgo4jorVlRA7Vgk17+d07Kzl1ZHcuP6Z/0HFEAJW7yCHZW1jGT2YtpEfHeH537hG6VJ5EDC0cJnKQqqqcn76cwe6CMv46bRIpCXFBRxL5N+25ixyk/527ng9WZnH7fwzXgmAScVTuIgdhwaYc7n93Facf3p2LJ/YLOo7I16jcRQ5QaJ59Eb06JvDbKZpnl8ikOXeRA+Du/KzGPHtyvObZJTJpz13kAPxl7gbeX5nFbacP0zy7RDSVu0iYFm0Onc9+ysg0LpnUP+g4It9I5S4Shtyicq6btYjuKfHcf+5ozbNLxNOcu0gD3J2f/3UJu/JKePnqiTqfXaKC9txFGvDM55t4J3MnPz91GGP7at0YiQ4qd5FvsGxbLr+es4JvD+vGlZMHBB1HJGwqd5F65JeUc92shXRJassD39c8u0QXzbmL1MHduW32MrbsLeaFq46mk9ZnlyijPXeROrw4bwtvZmzn5pOHclT/zkHHETlgKneRWlbtzOeXb2Ry7OCuTDt+UNBxRA6Kyl2khqKyCq6dtZAO8XE8qOugShTTnLtIDXe+nsm67AKevWICqR3aBR1H5KBpz12k2l8XbOWVBVv5ybeHcMzgrkHHETkkKncRYG1WPv/12jLGD+jMDScOCTqOyCFTuUuLV1xWybXPLSKhbWseOn8srTXPLjFAc+7S4t31RiarduXz9OXj6Z4SH3QckUahPXdp0V5btI0X52/hmm8N4vihqUHHEWk0KndpsdZmFXDb7KUc1b8TN588NOg4Io1K5S4tUmiefSHxca156IdjadNaHwWJLZpzlxZp/zz7jMuOokdKQtBxRBqddlekxXl14VZenL+Fa08YxLcO6xZ0HJEmoXKXFmXNrnxunx06n/2mkzTPLrFL5S4tRmFpBdOeW0hiu9b8SfPsEuM05y4tgrtz++ylrK9eN6Zbss5nl9gW1q6LmXU2s9lmVmhmm8zsggbGtzWzFWa2tXFiihyaWV9u5rXF27nppKFM0rox0gKEu+f+CFAGpAFjgDlmluHumfWMvwXIBjocckKRQ7Rk6z5+9cZyjh+ayrUnDA46jkizaHDP3cwSgSnAHe5e4O6fAG8AU+sZPwC4CLivMYOKHIx9RWVMe3YhqR3aaX12aVHCmZYZClS4++oa2zKAkfWMfxi4DSg+xGwih6SqyrnxxcVk55fy5wuPpLOugyotSDjlngTk1dqWSx1TLmZ2DtDa3Wc39KJmdpWZzTez+dnZ2WGFFTkQD3+wlg9XZXPnGSMY3adj0HFEmlU45V4AJNfalgzk19xQPX1zP3B9OH+wuz/h7ununp6aqgWbpHF9uCqLP7y/mnPG9uLCCX2DjiPS7MI5oLoaaGNmQ9x9TfW20UDtg6lDgP7AXDMDaAukmNlO4Gh339goiUUasHlPETe8sJjD0jrwm3MOp/r9KNKiNFju7l5oZq8Cd5vZlYTOljkLmFRr6DKgT437k4A/AUcSOnNGpMkVl1Vy9bMLcHcenzqOhLatg44kEohwv6J3DZAAZAHPA9PcPdPMJptZAYC7V7j7zv03IAeoqr5f2STpRWpwd25/bSkrdubxx/PH0q9LYtCRRAIT1nnu7p4DnF3H9rmEDrjW9ZwPgd6HkE3kgDzz+SZeXbiNG08awgnDtCCYtGxaXENiwufr9nD3W8s5aXga139bF7gWUblL1Nu2r5hrZy2kf5f2PHjeaH1RSQSVu0S5kvJKfjxzPuUVVTxxcTod4uOCjiQSEbQqpEQtd+eWV5aQuT2Pv1yczqDUOg/
/iLRI2nOXqPXnD9fxZsZ2bjnlME4cnhZ0HJGIonKXqPT3zJ38/t1VnDWmJ9OOHxR0HJGIo3KXqLNyZx43vbiY0b1T+N2UI/QNVJE6qNwlqmTnl3LFjPkkxbfh8anpxMfpG6giddEBVYkaJeWVXDVzPjmFZbx89US6p+hSeSL1UblLVNh/Zsyizft47KJxjOqVEnQkkYimaRmJCg/+YzVvZmzn56cO49RR3YOOIxLxVO4S8V6at4WHPljLeel9uPr4gUHHEYkKKneJaHPXZHPb7KUcNzSVe88ZpTNjRMKkcpeItWJHHtOeXcjgbkk8csFY4lrr7SoSLn1aJCJt3VvEpU99SVK7Njx12VFaM0bkAKncJeLsLSzj4ie/pLiskmeuGE+PlISgI4lEHZ0KKRGluKySy5+ex9a9xTx7xQSGpnUIOpJIVNKeu0SMsooqrnluARlb9vHQ+WMZP6Bz0JFEopb23CUiVFY5P305g3+uyua+7x2uc9lFDpH23CVw7s6dry/jzYzt3HraMH44vm/QkUSinspdAuXu3P/uKp77YjNXHz+Iq7V8r0ijULlLoB56fy2PfriOCyb05eenHhZ0HJGYoXKXwDz+0ToefG81547rzb1n6dunIo1J5S6BeOrTDdz39krOGN2T3005glatVOwijUlny0ize/KTDdz91nJOHdmd//nBaFqr2EUancpdmtVf5q7n3jkrOHVkdx7WejEiTUafLGk2+4v9tFEqdpGmpj13aXLuzsMfrOV//rGa/zi8B384f4yKXaSJqdylSbk7v31nJY9/tJ4pR/bmd1MOp42KXaTJqdylyVRWOb98YxnP/mszU4/ux6/OHKmzYkSaicpdmkRpRSU3v5jBnKU7+PHxA7n11GE6j12kGYX1+7GZdTaz2WZWaGabzOyCesbdYmbLzCzfzDaY2S2NG1eiQUFpBZfPmMecpTu4/fTh/OK04Sp2kWYW7p77I0AZkAaMAeaYWYa7Z9YaZ8DFwBJgEPB3M9vi7i80Ul6JcFl5JVz+9DxW7Mjnge+PZsq43kFHEmmRGtxzN7NEYApwh7sXuPsnwBvA1Npj3f1+d1/o7hXuvgp4HTimsUNLZFq9K59z/vwZ67ML+cvF6Sp2kQCFMy0zFKhw99U1tmUAI7/pSRb6PXwyUHvvXmLQp2t3M+XPn1FWWcVLP57ICcO6BR1JpEULp9yTgLxa23KBhq5/dlf16z9V14NmdpWZzTez+dnZ2WHEkEj13BebuOTJL+nRMZ7Xrj2GUb1Sgo4k0uKFM+deACTX2pYM5Nf3BDO7jtDc+2R3L61rjLs/ATwBkJ6e7mGllYhSUVnFPW8t5+nPN3H80FQevmAsyfFxQccSEcIr99VAGzMb4u5rqreNpp7pFjO7HLgVOM7dtzZOTIk0OYVlXP/8Ij5Zu5sfTR7AracN1wJgIhGkwXJ390IzexW428yuJHS2zFnApNpjzexC4DfACe6+vpGzSoRYujWXq59dQHZBKfefewQ/SO8TdCQRqSXc74FfAyQAWcDzwDR3zzSzyWZWUGPcvUAXYJ6ZFVTfHmvcyBKkl+ZvYcpjnwHwytUTVewiESqs89zdPQc4u47tcwkdcN1/f0CjJZOIUlRWwZ2vZ/LKgq0cO7grD/1wLJ0T2wYdS0TqoeUHpEGrduZz7ayFrMsu4PoTh3DDiUM0vy4S4VTuUi9359kvNvPrOctJahfHs1dM4JjBXYOOJSJhULlLnbLzS/n5X5fwwcosjhuayn9//wi6dYgPOpaIhEnlLl/zzrKd3D57KfmlFdx1xggunthfS/WKRBmVu/xbTmEZv3wjkzcztjOyZzLPnzeGoWkNfRFZRCKRyl1wd+Ys3cFdb2SSW1zOzScPZdq3BulSeCJRTOXewm3JKeLO15fxz1XZHN4rhZlXTGB4j9qrTYhItFG5t1ClFZVM/2QDD7+/FjO447sjuGRiP13fVCRGqNxboA9XZfGrN5ezYXchJ49I464zR9K
rY0LQsUSkEancW5A1u/K57+2VfLAyi4FdE3n68vEcPzQ16Fgi0gRU7i1Adn4pf3hvNS/M20L7tq35xWnDuOyYAbRtoykYkVilco9huUXlPDF3HU9+spHyyiqmHt2P608cojVhRFoAlXsMyisp5+lPN/K/c9eTV1LBmaN7ctPJQxnQNTHoaCLSTFTuMWRfURlPfbqRJz/dQH5JBScN78bNJx/GiJ46tVGkpVG5x4Bt+4qZPncDL8zbTFFZJaeMTOMn3x6ia5mKtGAq9yi2eMs+nvp0A28t2YEBZ4zuyVXHDdSXkERE5R5tSsoreWfZTmZ8tpHFW/aR1K4Nl0zszxWTB+hcdRH5N5V7lFifXcDzX27mlQVb2VtUzoCuidx1xgjOTe9DUjv9bxSRr1IrRLC8knLmLNnBKwu2smDTXtq0Mk4ekcaFE/oxaVAXLcMrIvVSuUeYkvJKPlyVzRsZ23h/RRalFVUM7pbEracN43tje9EtWRfMEJGGqdwjQEl5JR+vzubtZTt5b8Uu8ksq6JrUlvOP6sPZY3sxpk9HzLSXLiLhU7kHJKewjH+uzOK9Fbv4eHU2hWWVpCTEccrI7pw5uieTBnXRCo0ictBU7s2ksspZti2XD1dl89HqLBZv2UeVQ1pyO84c04vTRnVn4qAuukCGiDQKlXsTcXfWZRfyr/V7+HTtbj5bt4fc4nLM4IheKVz37SGcNLwbo3qm6MCoiDQ6lXsjKa+sYsWOPOZv3Mv8TTl8uSGH3QVlAPRMiec7I9I4dkhXjh3clS5J7QJOKyKxTuV+ENydrXuLWbotl8Vb9rF4yz6Wbs2luLwSCJX55CGpTBjQmQkDu9C/S3sdEBWRZqVyb0BZRRXrsgtYuTOPFTvyWb49j2Xbc9lXVA5A29atGNEzmfOO6kN6/04c2bcTPfVNUREJmMq9Wkl5JRt2F7Iuu4C1WQWsySpg9c58NuwupKLKAWjbphVD05I4bVR3RvVKYVTPFIb3SNZFL0Qk4rSocs8tLmfr3iK25BSxaU8Rm3OK2LinkI27i9ieW4yHOhwz6NOpPUPTkjhpRBrDundgeI9kBnZN1OmJIhIVYqbcC0sryMovZUduMbvyStiZW8r2fcXsyC1m274Stu4tIr+k4ivP6dg+jv5dEhk/oDP9uyQyMDWRwd2SGNA1kfi41gH9l4iIHLqoLvd/rszi7reWk5VXQmFZ5dceT0mIo2fHBHqmxDO+fyd6d2pPr04J9O3cnj6d25OSEBdAahGRphdWuZtZZ2A68B1gN/ALd59VxzgDfgtcWb3pL8Ct7vsnPBpXx/ZxjOiRzLcOS6Vbh3i6dWhHj5R4ulff2reN6r+7REQOWrjt9whQBqQBY4A5Zpbh7pm1xl0FnA2MBhz4B7ABeKwxwtY2tm8nHrmwU1O8tIhIVGvw6KCZJQJTgDvcvcDdPwHeAKbWMfwS4AF33+ru24AHgEsbMa+IiIQhnFM/hgIV7r66xrYMYGQdY0dWP9bQOBERaULhlHsSkFdrWy7QoZ6xubXGJVkdX880s6vMbL6Zzc/Ozg43r4iIhCGcci8Aal9xORnID2NsMlBQ1wFVd3/C3dPdPT01NTXcvCIiEoZwyn010MbMhtTYNhqofTCV6m2jwxgnIiJNqMFyd/dC4FXgbjNLNLNjgLOAmXUMfwa42cx6mVlP4KfAjEbMKyIiYQj3u/TXAAlAFvA8MM3dM81sspkV1Bj3OPAmsBRYBsyp3iYiIs0orPPc3T2H0PnrtbfPJXQQdf99B/6z+iYiIgGxJvry6IGFMMsGNh3k07sS+tZspInUXBC52ZTrwCjXgYnFXP3cvc4zUiKi3A+Fmc139/Sgc9QWqbkgcrMp14FRrgPT0nJp/VoRkRikchcRiUGxUO5PBB2gHpGaCyI3m3IdGOU6MC0qV9TPuYuIyNfFwp67iIjUonIXEYlBKncRkRgUc+VuZkPMrMTMng06C4CZPWtmO8wsz8xWm9mVDT+ryTO1M7P
pZrbJzPLNbLGZnRZ0LgAzu656KehSM5sRcJbOZjbbzAqrf1YXBJmnOlPE/HxqivD3VMR9Bmtqqs6KuXIndEnAeUGHqOE+oL+7JwNnAvea2biAM7UBtgDHAynAfwEvmVn/IENV2w7cCzwZdBC+ennJC4FHzSzoi89E0s+npkh+T0XiZ7CmJumsmCp3Mzsf2Ae8H3CUf3P3THcv3X+3+jYowEi4e6G73+XuG929yt3fInSt28Df8O7+qru/BuwJMscBXl6y2UTKz6e2CH9PRdxncL+m7KyYKXczSwbuBm4OOkttZvZnMysCVgI7gL8FHOkrzCyN0OUUtfb+/zuQy0tKLZH2norEz2BTd1bMlDtwDzDd3bcGHaQ2d7+G0GUJJxNaG7/0m5/RfMwsDngOeNrdVwadJ4IcyOUlpYZIfE9F6GewSTsrKsrdzD40M6/n9omZjQFOAh6MpFw1x7p7ZfWv9r2BaZGQy8xaEbroShlwXVNmOpBcEeJALi8p1Zr7PXUgmvMz2JDm6Kyw1nMPmrt/65seN7Mbgf7A5uprcScBrc1shLsfGVSuerShief7wslVfdHy6YQOFp7u7uVNmSncXBHk35eXdPc11dt02chvEMR76iA1+WcwDN+iiTsrKvbcw/AEof9ZY6pvjxG6CtQpwUUCM+tmZuebWZKZtTazU4AfEhkHfB8FhgNnuHtx0GH2M7M2ZhYPtCb0Zo83s2bfCTnAy0s2m0j5+dQj4t5TEfwZbPrOcveYuwF3Ac9GQI5U4CNCR8PzCF1+8EcRkKsfoTMGSghNP+y/XRgB2e7i/89o2H+7K6AsnYHXgEJgM3CBfj7R9Z6K1M9gPf9fG7WztHCYiEgMipVpGRERqUHlLiISg1TuIiIxSOUuIhKDVO4iIjFI5S4iEoNU7iIiMUjlLiISg/4PDYVAdiC75S0AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_function(torch.sigmoid, title='Sigmoid', min=-4, max=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Update the loss function so it applies sigmoid to the predictions first; that's all.__" ] }, { "cell_type": "code", "execution_count": 375, "metadata": {}, "outputs": [], "source": [ "def mnist_loss(predictions, targets):\n", "    predictions = predictions.sigmoid()\n", "    return torch.where(targets==1, 1-predictions, predictions).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What are SGD and mini-batches?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__A `DataLoader` takes a collection and splits it into mini-batches (shuffled here, because `shuffle=True`); this explains most of it.__" ] }, { "cell_type": "code", "execution_count": 376, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([ 0, 2, 10, 13, 8]),\n", " tensor([11, 12, 4, 1, 5]),\n", " tensor([ 3, 14, 6, 9, 7])]" ] }, "execution_count": 376, "metadata": {}, "output_type": "execute_result" } ], "source": [ "coll = range(15)\n", "dl = DataLoader(coll, batch_size=5, shuffle=True)\n", "list(dl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__But this is only a collection of numbers; for training we need tuples consisting of an independent and a dependent variable (an input and its label).__" ] }, { "cell_type": "code", "execution_count": 377, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#26) [(0, 'a'),(1, 'b'),(2, 'c'),(3, 'd'),(4, 'e'),(5, 'f'),(6, 'g'),(7, 'h'),(8, 'i'),(9, 'j')...]" ] }, "execution_count": 377, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds = L(enumerate(string.ascii_lowercase))\n", "ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### DataLoader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Then put it into a DataLoader.__" ] }, { "cell_type": "code", "execution_count": 378, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(tensor([ 1, 23, 9, 8, 24, 2]), ('b', 'x', 'j', 'i', 'y', 'c')),\n", " (tensor([14, 25, 13, 
11, 19, 5]), ('o', 'z', 'n', 'l', 't', 'f')),\n", " (tensor([ 0, 10, 4, 7, 18, 12]), ('a', 'k', 'e', 'h', 's', 'm')),\n", " (tensor([ 6, 21, 15, 16, 22, 3]), ('g', 'v', 'p', 'q', 'w', 'd')),\n", " (tensor([20, 17]), ('u', 'r'))]" ] }, "execution_count": 378, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dl = DataLoader(ds, batch_size=6, shuffle=True)\n", "list(dl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Now we have mini-batches, and each one is a tuple of inputs and labels__\n", "***" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__All together__" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's time to implement the gradient descent process we saw earlier. In code, our process will be implemented something like this for each epoch:\n", "\n", "```python\n", "for x,y in dl:\n", "    pred = model(x)\n", "    loss = loss_func(pred, y)\n", "    loss.backward()\n", "    parameters -= parameters.grad * lr\n", "```" ] }, { "cell_type": "code", "execution_count": 379, "metadata": {}, "outputs": [], "source": [ "weights = init_params((28*28,1))\n", "bias = init_params(1)" ] }, { "cell_type": "code", "execution_count": 380, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([256, 784]), torch.Size([256, 1]))" ] }, "execution_count": 380, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dl = DataLoader(dset, batch_size=256)\n", "xb,yb = first(dl)\n", "xb.shape,yb.shape" ] }, { "cell_type": "code", "execution_count": 381, "metadata": {}, "outputs": [], "source": [ "valid_dl = DataLoader(valid_dset, batch_size=256)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__A small test with the first four images__" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "***\n", 
"__predictions__" ] }, { "cell_type": "code", "execution_count": 383, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 8.0575],\n", " [14.3841],\n", " [-3.8017],\n", " [ 5.1179]], grad_fn=<AddBackward0>)" ] }, "execution_count": 383, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = linear1(batch)\n", "preds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__loss__" ] }, { "cell_type": "code", "execution_count": 384, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.2461, grad_fn=<MeanBackward0>)" ] }, "execution_count": 384, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss = mnist_loss(preds, train_y[:4])\n", "loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__gradients__" ] }, { "cell_type": "code", "execution_count": 385, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([784, 1]), tensor(-0.0010), tensor([-0.0069]))" ] }, "execution_count": 385, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss.backward()\n", "weights.grad.shape,weights.grad.mean(),bias.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__For the update step we need an optimizer__\n", "***" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Put everything except the optimizer step into a function.__" ] }, { "cell_type": "code", "execution_count": 386, "metadata": {}, "outputs": [], "source": [ "def calc_grad(xb, yb, model):\n", "    preds = model(xb)\n", "    loss = mnist_loss(preds, yb)\n", "    loss.backward()" ] }, { "cell_type": "code", "execution_count": 387, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(-0.0021), tensor([-0.0138]))" ] }, "execution_count": 387, "metadata": {}, "output_type": "execute_result" } ], "source": [ "calc_grad(batch, train_y[:4], linear1)\n", "weights.grad.mean(),bias.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Warning: if you run it twice, the results change. `loss.backward()` adds the new gradients to whatever is already stored in `.grad`, so they accumulate until you reset them with `.grad.zero_()`." 
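, "\n", "\n", "To see why the numbers change, here is a minimal stand-alone sketch in plain PyTorch. The scalar parameter `w` and the toy loss `w * 3` are hypothetical and not part of the book's code. The point is only that `backward()` adds the new gradient to whatever is already in `.grad` instead of overwriting it.\n", "\n", "
```python
import torch

# Hypothetical scalar parameter, not from the notebook
w = torch.tensor(2.0, requires_grad=True)

def toy_loss(w):
    # d(toy_loss)/dw = 3
    return w * 3

toy_loss(w).backward()
first = w.grad.item()        # 3.0

toy_loss(w).backward()       # backward() ADDS to w.grad
second = w.grad.item()       # 6.0, not 3.0

w.grad.zero_()               # reset, like weights.grad.zero_() in the notebook
toy_loss(w).backward()
after_reset = w.grad.item()  # 3.0 again

print(first, second, after_reset)
```
"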
] }, { "cell_type": "code", "execution_count": 388, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(-0.0031), tensor([-0.0207]))" ] }, "execution_count": 388, "metadata": {}, "output_type": "execute_result" } ], "source": [ "calc_grad(batch, train_y[:4], linear1)\n", "weights.grad.mean(),bias.grad" ] }, { "cell_type": "code", "execution_count": 389, "metadata": {}, "outputs": [], "source": [ "weights.grad.zero_()\n", "bias.grad.zero_();" ] }, { "cell_type": "code", "execution_count": 390, "metadata": {}, "outputs": [], "source": [ "def train_epoch(model, lr, params):\n", "    for xb,yb in dl:\n", "        calc_grad(xb, yb, model)\n", "        for p in params:\n", "            p.data -= p.grad*lr\n", "            p.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__A small conversion of our results. It matters because we need to know what the model is actually saying about each digit (three or not three). A raw output above 0.0 is the same as a sigmoid output above 0.5, since sigmoid(0) = 0.5.__" ] }, { "cell_type": "code", "execution_count": 391, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ True],\n", " [ True],\n", " [False],\n", " [ True]])" ] }, "execution_count": 391, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(preds>0.0).float() == train_y[:4]" ] }, { "cell_type": "code", "execution_count": 392, "metadata": {}, "outputs": [], "source": [ "def batch_accuracy(xb, yb):\n", "    preds = xb.sigmoid()\n", "    correct = (preds>0.5) == yb\n", "    return correct.float().mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "***\n", "__This is the accuracy on our small training batch__" ] }, { "cell_type": "code", "execution_count": 393, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.7500)" ] }, "execution_count": 393, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_accuracy(linear1(batch), train_y[:4])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__This computes the accuracy over the whole validation set__" ] }, { "cell_type": "code", "execution_count": 394, "metadata": {}, "outputs": 
[], "source": [ "def validate_epoch(model):\n", "    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]\n", "    return round(torch.stack(accs).mean().item(), 4)" ] }, { "cell_type": "code", "execution_count": 395, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5136" ] }, "execution_count": 395, "metadata": {}, "output_type": "execute_result" } ], "source": [ "validate_epoch(linear1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__One epoch of training__" ] }, { "cell_type": "code", "execution_count": 396, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7121" ] }, "execution_count": 396, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lr = 1.\n", "params = weights,bias\n", "train_epoch(linear1, lr, params)\n", "validate_epoch(linear1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__then 20 more epochs__" ] }, { "cell_type": "code", "execution_count": 397, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8656 0.9203 0.9457 0.9549 0.9593 0.9623 0.9652 0.9666 0.9681 0.9705 0.9706 0.9711 0.972 0.973 0.9735 0.9735 0.974 0.9745 0.9755 0.9755 " ] } ], "source": [ "for i in range(20):\n", "    train_epoch(linear1, lr, params)\n", "    print(validate_epoch(linear1), end=' ')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Optimizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Let's create our model with PyTorch's `nn.Linear` instead of our \"linear1\" function. `nn.Linear` also creates and initializes its parameters, just like our init_params function.__\n" ] }, { "cell_type": "code", "execution_count": 398, "metadata": {}, "outputs": [], "source": [ "linear_model = nn.Linear(28*28,1)" ] }, { "cell_type": "code", "execution_count": 399, "metadata": {}, "outputs": [], "source": [ "w,b = linear_model.parameters()" ] }, { "cell_type": "code", "execution_count": 400, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 784]), torch.Size([1]))" ] }, "execution_count": 400, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w.shape, b.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Custom optimizer__" ] }, { "cell_type": "code", "execution_count": 401, "metadata": {}, "outputs": [], "source": [ "class BasicOptim:\n", "    def __init__(self,params,lr): self.params,self.lr = list(params),lr\n", "\n", "    def step(self, *args, **kwargs):\n", "        for p in self.params: p.data -= p.grad.data * self.lr\n", "\n", "    def zero_grad(self, *args, **kwargs):\n", "        for p in self.params: p.grad = None" ] }, { "cell_type": "code", "execution_count": 402, "metadata": {}, "outputs": [], "source": [ "opt = BasicOptim(linear_model.parameters(), lr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__The new training function becomes:__" ] }, { "cell_type": "code", "execution_count": 403, "metadata": {}, "outputs": [], "source": [ "def train_epoch(model):\n", "    for xb,yb in dl:\n", "        calc_grad(xb, yb, model)\n", "        opt.step()\n", "        opt.zero_grad()" ] }, { "cell_type": "code", "execution_count": 404, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.4078" ] }, "execution_count": 404, "metadata": {}, "output_type": "execute_result" } ], "source": [ "validate_epoch(linear_model)" ] }, { "cell_type": "code", "execution_count": 405, "metadata": {}, "outputs": [], "source": [ "def train_model(model, epochs):\n", "    for i in range(epochs):\n", "        train_epoch(model)\n", "        print(validate_epoch(model), end=' ')" ] }, { 
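"cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check (a sketch, not from the book), `BasicOptim` should make the same update as PyTorch's built-in `torch.optim.SGD` when momentum is off: `p = p - lr * grad`. The toy random data, the mean-squared-error loss and the `one_step` helper below are assumptions made just for this comparison.\n", "\n", "
```python
import torch
from torch import nn

torch.manual_seed(0)

# Copy of the notebook's BasicOptim: plain gradient descent, p <- p - lr * grad
class BasicOptim:
    def __init__(self, params, lr): self.params, self.lr = list(params), lr

    def step(self, *args, **kwargs):
        for p in self.params: p.data -= p.grad.data * self.lr

    def zero_grad(self, *args, **kwargs):
        for p in self.params: p.grad = None

def one_step(model, opt, xb, yb):
    # Toy mean-squared-error loss (an assumption; the notebook uses mnist_loss)
    loss = ((model(xb) - yb) ** 2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

# Hypothetical toy data: 8 samples with 4 features each
xb, yb = torch.randn(8, 4), torch.randn(8, 1)

m1, m2 = nn.Linear(4, 1), nn.Linear(4, 1)
m2.load_state_dict(m1.state_dict())  # start both models from identical weights

one_step(m1, BasicOptim(m1.parameters(), lr=0.1), xb, yb)
one_step(m2, torch.optim.SGD(m2.parameters(), lr=0.1), xb, yb)

same = all(torch.allclose(p1, p2) for p1, p2 in zip(m1.parameters(), m2.parameters()))
print(same)  # True: without momentum, SGD makes the same update as BasicOptim
```
" ] }, { 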
"cell_type": "code", "execution_count": 406, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4932 0.8193 0.8418 0.9136 0.9331 0.9477 0.9555 0.9629 0.9658 0.9673 0.9697 0.9717 0.9736 0.9751 0.9761 0.9761 0.9775 0.9775 0.9785 0.9785 " ] } ], "source": [ "train_model(linear_model, 20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fastai's SGD optimizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Instead of our own \"BasicOptim\" class, we can use fastai's built-in `SGD` optimizer.__" ] }, { "cell_type": "code", "execution_count": 407, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4932 0.7808 0.8623 0.9185 0.9365 0.9521 0.9575 0.9638 0.9658 0.9678 0.9707 0.9726 0.9741 0.9751 0.9761 0.9765 0.9775 0.978 0.9785 0.9785 " ] } ], "source": [ "linear_model = nn.Linear(28*28,1)\n", "opt = SGD(linear_model.parameters(), lr)\n", "train_model(linear_model, 20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Next, we replace our own \"train_model\" loop with fastai's \"Learner.fit\". Before we can create a Learner, we need to combine our training and validation data loaders into a single \"DataLoaders\" object (plural \"DataLoaders\", not \"DataLoader\").__" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Fastai's DataLoaders" ] }, { "cell_type": "code", "execution_count": 408, "metadata": {}, "outputs": [], "source": [ "dls = DataLoaders(dl, valid_dl)" ] }, { "cell_type": "code", "execution_count": 409, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,\n", " loss_func=mnist_loss, metrics=batch_accuracy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Fastai's Learner.fit" ] }, { "cell_type": "code", "execution_count": 410, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
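<table>
<thead><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>batch_accuracy</th><th>time</th></tr></thead>
<tbody>
<tr><td>0</td><td>0.637166</td><td>0.503575</td><td>0.495584</td><td>00:00</td></tr>
<tr><td>1</td><td>0.562232</td><td>0.139727</td><td>0.900393</td><td>00:00</td></tr>
<tr><td>2</td><td>0.204552</td><td>0.207935</td><td>0.806183</td><td>00:00</td></tr>
<tr><td>3</td><td>0.088904</td><td>0.114767</td><td>0.904809</td><td>00:00</td></tr>
<tr><td>4</td><td>0.046327</td><td>0.081602</td><td>0.930324</td><td>00:00</td></tr>
<tr><td>5</td><td>0.029754</td><td>0.064530</td><td>0.944553</td><td>00:00</td></tr>
<tr><td>6</td><td>0.022963</td><td>0.054135</td><td>0.954858</td><td>00:00</td></tr>
<tr><td>7</td><td>0.019966</td><td>0.047293</td><td>0.961236</td><td>00:00</td></tr>
<tr><td>8</td><td>0.018464</td><td>0.042515</td><td>0.965162</td><td>00:00</td></tr>
<tr><td>9</td><td>0.017573</td><td>0.039011</td><td>0.966634</td><td>00:00</td></tr>
</tbody>
</table>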
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(10, lr=lr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Adding a Nonlinearity" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__The basic idea is that by using more linear layers, we can have our model do more computation, and therefore model more complex functions. But there's no point just putting one linear layer directly after another one, because when we multiply things together and then add them up multiple times, that could be replaced by multiplying different things together and adding them up just once! That is to say, a series of any number of linear layers in a row can be replaced with a single linear layer with a different set of parameters.__ (From Fastbook)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Amazingly enough, it can be mathematically proven that this little function can solve any computable problem to an arbitrarily high level of accuracy, if you can find the right parameters for w1 and w2 and if you make these matrices big enough. For any arbitrarily wiggly function, we can approximate it as a bunch of lines joined together; to make it closer to the wiggly function, we just have to use shorter lines. This is known as the __universal approximation theorem__. The three lines of code that we have here are known as layers. The first and third are known as linear layers, and the second line of code is known variously as a nonlinearity, or activation function. (From Fastbook)\n", "\n", "In this notebook the \"little function\" is `simple_net` in the next cell: `w1` and `w2` are the weight matrices of its two `nn.Linear` layers, and its three layers (`nn.Linear`, `nn.ReLU`, `nn.Linear`) play the role of the book's three lines of code." ] }, { "cell_type": "code", "execution_count": 411, "metadata": {}, "outputs": [], "source": [ "simple_net = nn.Sequential(\n", " nn.Linear(28*28,30),\n", " nn.ReLU(),\n", " nn.Linear(30,1)\n", ")" ] }, { "cell_type": "code", "execution_count": 412, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, simple_net, opt_func=SGD,\n", " loss_func=mnist_loss, metrics=batch_accuracy)" ] }, { "cell_type": "code", "execution_count": 413, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
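<table>
<thead><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>batch_accuracy</th><th>time</th></tr></thead>
<tbody>
<tr><td>0</td><td>0.303284</td><td>0.398378</td><td>0.511776</td><td>00:00</td></tr>
<tr><td>1</td><td>0.142384</td><td>0.221517</td><td>0.817959</td><td>00:00</td></tr>
<tr><td>2</td><td>0.079702</td><td>0.112610</td><td>0.917076</td><td>00:00</td></tr>
<tr><td>3</td><td>0.052855</td><td>0.076474</td><td>0.942100</td><td>00:00</td></tr>
<tr><td>4</td><td>0.040301</td><td>0.059791</td><td>0.958783</td><td>00:00</td></tr>
<tr><td>5</td><td>0.033824</td><td>0.050389</td><td>0.964181</td><td>00:00</td></tr>
<tr><td>6</td><td>0.030075</td><td>0.044483</td><td>0.966143</td><td>00:00</td></tr>
<tr><td>7</td><td>0.027629</td><td>0.040465</td><td>0.966634</td><td>00:00</td></tr>
<tr><td>8</td><td>0.025865</td><td>0.037553</td><td>0.969578</td><td>00:00</td></tr>
<tr><td>9</td><td>0.024499</td><td>0.035336</td><td>0.971541</td><td>00:00</td></tr>
<tr><td>10</td><td>0.023391</td><td>0.033579</td><td>0.972522</td><td>00:00</td></tr>
<tr><td>11</td><td>0.022467</td><td>0.032142</td><td>0.973994</td><td>00:00</td></tr>
<tr><td>12</td><td>0.021679</td><td>0.030936</td><td>0.973994</td><td>00:00</td></tr>
<tr><td>13</td><td>0.020997</td><td>0.029901</td><td>0.974975</td><td>00:00</td></tr>
<tr><td>14</td><td>0.020398</td><td>0.028998</td><td>0.974975</td><td>00:00</td></tr>
<tr><td>15</td><td>0.019869</td><td>0.028198</td><td>0.975957</td><td>00:00</td></tr>
<tr><td>16</td><td>0.019395</td><td>0.027484</td><td>0.976448</td><td>00:00</td></tr>
<tr><td>17</td><td>0.018966</td><td>0.026841</td><td>0.976938</td><td>00:00</td></tr>
<tr><td>18</td><td>0.018577</td><td>0.026259</td><td>0.977429</td><td>00:00</td></tr>
<tr><td>19</td><td>0.018220</td><td>0.025730</td><td>0.977429</td><td>00:00</td></tr>
<tr><td>20</td><td>0.017892</td><td>0.025244</td><td>0.978410</td><td>00:00</td></tr>
<tr><td>21</td><td>0.017588</td><td>0.024799</td><td>0.979882</td><td>00:00</td></tr>
<tr><td>22</td><td>0.017306</td><td>0.024388</td><td>0.979882</td><td>00:00</td></tr>
<tr><td>23</td><td>0.017042</td><td>0.024008</td><td>0.980373</td><td>00:00</td></tr>
<tr><td>24</td><td>0.016794</td><td>0.023656</td><td>0.980864</td><td>00:00</td></tr>
<tr><td>25</td><td>0.016561</td><td>0.023328</td><td>0.980864</td><td>00:00</td></tr>
<tr><td>26</td><td>0.016341</td><td>0.023022</td><td>0.980864</td><td>00:00</td></tr>
<tr><td>27</td><td>0.016133</td><td>0.022737</td><td>0.981845</td><td>00:00</td></tr>
<tr><td>28</td><td>0.015935</td><td>0.022470</td><td>0.981845</td><td>00:00</td></tr>
<tr><td>29</td><td>0.015746</td><td>0.022221</td><td>0.981845</td><td>00:00</td></tr>
<tr><td>30</td><td>0.015566</td><td>0.021988</td><td>0.982336</td><td>00:00</td></tr>
<tr><td>31</td><td>0.015395</td><td>0.021769</td><td>0.982336</td><td>00:00</td></tr>
<tr><td>32</td><td>0.015231</td><td>0.021565</td><td>0.982826</td><td>00:00</td></tr>
<tr><td>33</td><td>0.015076</td><td>0.021371</td><td>0.982826</td><td>00:00</td></tr>
<tr><td>34</td><td>0.014925</td><td>0.021190</td><td>0.982826</td><td>00:00</td></tr>
<tr><td>35</td><td>0.014782</td><td>0.021018</td><td>0.982826</td><td>00:00</td></tr>
<tr><td>36</td><td>0.014643</td><td>0.020856</td><td>0.982826</td><td>00:00</td></tr>
<tr><td>37</td><td>0.014510</td><td>0.020703</td><td>0.982826</td><td>00:00</td></tr>
<tr><td>38</td><td>0.014382</td><td>0.020558</td><td>0.982826</td><td>00:00</td></tr>
<tr><td>39</td><td>0.014258</td><td>0.020420</td><td>0.982826</td><td>00:00</td></tr>
</tbody>
</table>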
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(40, 0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### recorder is a fast ai method" ] }, { "cell_type": "code", "execution_count": 414, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD+CAYAAADBCEVaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZx0lEQVR4nO3de3Bc53nf8e8DLC7EjRQIiJQiU9SNtkTXVCvIcc3I8bRuPPIkkWw1qaOL5bg2E6kae5K0E7VjtqoyU+c24+l0VHmUkSXLF9mySiVyFMvtJHIjRrUjqi6lAS1TlmzKCikCvAi7C2LvT/84B+BiuQAOwSXP7jm/z8wOsWdfLB6+BH98+Z4X72vujoiIJEtX3AWIiEjrKdxFRBJI4S4ikkAKdxGRBFK4i4gkkMJdRCSBFO4iIgkUKdzN7C4z22NmRTN7eIW2v2Nmb5pZ1sy+aGZ9LalUREQisyg/xGRmHwFqwAeBNe7+8SXafRB4BPhnwEHgCeB77n73cu8/NjbmmzdvPq3CRUTS7oUXXjji7uPNXstEeQN33wVgZhPARcs0vR140N0nw/Z/AHwVWDbcN2/ezJ49e6KUIiIiITM7sNRrrZ5z3wrsrXu+F9hgZutb/HVERGQZrQ73IWCm7vn8x8ONDc1sRziPv2d6errFZYiIpFurwz0PjNQ9n/8419jQ3R9w9wl3nxgfbzplJCIiq9TqcJ8EttU93wYcdvejLf46IiKyjKhLITNm1g90A91m1m9mzW7GPgL8azO7yszWAZ8FHm5VsSIiEk3UkftngTmCVS+3hh9/1sw2mVnezDYBuPvTwB8DzwCvAweA/9TyqkVEZFmR1rmfbRMTE66lkCIip8fMXnD3iWavRVrnLiLJVanWmC1WKVarq38Th3LNKVVqlCo1ytUaxfDjUrVGuVJb9tNr7pSqJz+3VAk/P/y4Vot/EHq2TGwe5X1bWr+oROEucgbcnXyxwvHZMsdOlMgVysu2r9Sccl1oleo/rr82H3J1IVmu1jiT/2hXa0Gts6UK+UKFfLFKvlimUF4+eNuFWdwVnB2//YuXKdxFluPulKtLjADnR5DV5UeBDsyVquSLlSAIixVyheDXfLFCrljh+GyJY7Mljp8ocXy2TKna2nDMdBm9mS56urvozXTR291FX/j8TAKuy4yh/gznD/dz6ViGwb4Mw/0ZBnszDPZ109fTzZnkZ0+3hfV2h/UHzxdqX+bdzVj4vfaEv9Z/bndXQpP9LFK4yxlzdwrl2qJArCwXoO7MlarkinWhWRegs8UqxUp1IZyD/9Y7xYbQbjbqPVsGe7sZ7Msw1J9hdKCXt40OsO2idZw32MvoYA/nDfSyfqiXob4elsshM6MvczK0e8OPe7oVZNJaCveUcHeKldqiMD1+IhyBhiPRY+FI9OhskXyxsuz7lSrBPG2uUGa2VKXagjnRTFcwshzszSwE3cnwM9b29oSBaIuCsbe7m56M0bfo2qkjwPkQ7V5h+Lumt5uhMMgH+4J6FLjSaRTuHaxYqTKVLfJmtsCbMwUOh7++mS0wlS2SLZQXjabL1aUD2AzOG+jlvIEeRgd7OX+4f9n/ome6jaG+nuC/9X3dDPX1MNTXvRDOPd3Lr7JdCNAwRIf6gkC3pE6sipxjCvc25+4cminw6nSeH0/l636dZTpXPKV9f08XG0f6
OX+kn02jA4tGoEN1j8G+DKODvQuPtWt6NDoVSRCF+zlSrtaYyhUXjbCPnygtutlXv1KiVK0xnSvy6nSeE6WTS9SG+zNcfv4Q798yzttGB9g40s+Gtf1sHAkeI2syGv2KiMK91dyd147M8szLU3z/J8c4NDPHmzNFjs4WT1nG1lW3QqA3003v/GqDcG74vIFefn3ibVx2/hCXjw9x2fmDjA/1KbxFZEUK9xYolKt877WjPPPyFM/8aJrXj50A4JKxQS5eP8A7L1zLhpF+NoYj7A0j/Vywtp91Az0KahE5KxTuq1SrOY+/8AZPT77Jc68eoVCu0d/TxfbLxvjU+y5dmDYREYmDwn0VipUqv/fYXv7yxUNsGh3go9du4v1vH+c9l66nv6c77vJERBTupytbKPNbj7zA/3ntKP/++new432XampFRNqOwv00TGUL3P7Q87xyOMfn/9U2PvyPlzsrXEQkPgr3iF6dzvOxB/+e4ydKfPHj156VjX5ERFpF4R7BD14/zicefp4uM76+4z2866J1cZckIrIshfsK/vqHh/k3X/u/bBjp50u/+W42jw3GXZKIyIoU7sv45p6fcfeul7jqghEe+s1rGRvqi7skEZFIFO5L+P5rR/n9//Ei2y8f4wu3XsNgn7pKRDpH1AOyU+Vovsinv/4DNq8f5H4Fu4h0IKVWg1rN+b1v7uX4iTJf/Pi1DCnYRaQDaeTe4M+efY3v/mianb98FVsvXBt3OSIiq6Jwr/PCgeP8yXd+xIf+0UZu/flNcZcjIrJqCvfQWydKfPrRH3DBun7+8KZ3aUsBEelomlAm2IP93z3+IlO5Ao//9nsZ6e+JuyQRkTOikTvw8HM/5X/tO8zd11/Jtreti7scEZEzlvpwf/GNt/gvf/VDPnDlBj6xfXPc5YiItESqwz1bKHPX137A+FAff/prmmcXkeRI9Zz7H377Zf7hrTke+633sG6gN+5yRERaJtUj992vHOGDWzdwzcWjcZciItJSqQ33bKHM68dO6AeVRCSRUhvuLx/KAXDVhSMxVyIi0nqRwt3MRs3sCTObNbMDZnbzEu3WmdmXzGwqfNzT0mpbaN/BGQC2XqBwF5HkiXpD9T6gBGwArgaeMrO97j7Z0O7zwACwGTgf+GszO+DuD7Wm3NbZdyjL2FAv48Pao11EkmfFkbuZDQI3ATvdPe/uu4EngduaNP8V4I/d/YS7/xR4EPhEC+ttmX2Hslx5wYiWP4pIIkWZltkCVNx9f921vcDWJdpbw8fvXGVtZ025WmP/m3nNt4tIYkUJ9yEg23BtBhhu0vZp4G4zGzazywlG7QPN3tTMdpjZHjPbMz09fTo1n7FXp/OUqjWu0ny7iCRUlHDPA40pOALkmrT9NDAHvAL8BfAo8EazN3X3B9x9wt0nxsfHo1fcAvsOBv9WbdXIXUQSKkq47wcyZnZF3bVtQOPNVNz9mLvf4u4b3X1r+P5/35pSW2ffwSz9PV1cMjYUdykiImfFiqtl3H3WzHYB95rZJwlWy9wAvLexrZldBrwVPn4J2AH8YuvKbY3Jg1nevnGE7i7dTBWRZIr6Q0x3AmuAKYKpljvcfdLMrjOzfF27a4CXCKZsPgfc0mS5ZKzcnX2HsppvF5FEi7TO3d2PATc2uf4swQ3X+eePAY+1qriz4eBMgZm5slbKiEiipW77gfmbqRq5i0iSpTLczeAdG5ut5BQRSYb0hfuhGS5ZP8hgX6q3sheRhEthuGe5UvPtIpJwqQr3mbkyPzs2p/l2EUm8VIX7y4fCm6kauYtIwqUq3PeF4a493EUk6dIV7ge1h7uIpEO6wl17uItISqQm3EuVGq8c1h7uIpIOqQl37eEuImmSmnDXHu4ikibpCfdD2sNdRNIjPeGuPdxFJEVSEe7aw11E0iYV4a493EUkbVIR7trDXUTSJjXhrj3cRSRN0hHu2sNdRFImJeGuPdxFJF0SH+7aw11E0ijx4a493EUkjRIf7trDXUTS
KPnhrj3cRSSFEh/ukwe1h7uIpE+iw71UqfHKVE7z7SKSOokO9x9P5SlXna0Xro27FBGRcyrR4T5/M1XLIEUkbRId7j+eytPTbVwyNhh3KSIi51Siw31mrszaNb3aw11EUifR4Z4tlBnp134yIpI+iQ73XKHC8JqeuMsQETnnIoW7mY2a2RNmNmtmB8zs5iXa9ZnZF8zssJkdM7NvmdnPtbbk6LJzGrmLSDpFHbnfB5SADcAtwP1mtrVJu88A/xR4F3AhcBz4by2oc1VyhTIj/Rq5i0j6rBjuZjYI3ATsdPe8u+8GngRua9L8EuA77n7Y3QvAN4Bm/wicE9lChWGN3EUkhaKM3LcAFXffX3dtL81D+0Fgu5ldaGYDBKP8b595mauTK5QZ0Zy7iKRQlGHtEJBtuDYDNDuz7hXgZ8A/AFXgJeCuZm9qZjuAHQCbNm2KWG50pUqNQrnGsE5fEpEUijJyzwONP+I5AuSatL0P6APWA4PALpYYubv7A+4+4e4T4+Pj0SuOKFcoB4Vq5C4iKRQl3PcDGTO7ou7aNmCySdurgYfd/Zi7Fwlupr7bzMbOuNLTlC1UADTnLiKptGK4u/sswQj8XjMbNLPtwA3Al5s0fx74mJmtNbMe4E7goLsfaWXRUWTnwpG7VsuISApFXQp5J7AGmAIeBe5w90kzu87M8nXt/i1QIJh7nwY+BHy4hfVGlgtH7pqWEZE0ijRn4e7HgBubXH+W4Ibr/POjBCtkYpcN59w1LSMiaZTY7Qd0Q1VE0iyx4Z6d0w1VEUmvxIZ7rlDGDIZ6Fe4ikj6JDfdsocJQX4Yu7eUuIimU4HDXpmEikl7JDfc5bRomIumV2HDXpmEikmaJDfdsoaJpGRFJreSGu05hEpEUS2y4a1pGRNIskeFeqzm5om6oikh6JTLcZ0sV3LUjpIikVyLDXXu5i0jaJTLctWmYiKRdIsNdm4aJSNolMtwXRu6acxeRlEpkuOugDhFJu0SGu47YE5G0S2S4zx+OrZG7iKRVMsO9UKEv00VfpjvuUkREYpHIcNfWAyKSdokMd+3lLiJpl8xw1ylMIpJyCQ13jdxFJN0SGe6acxeRtEtkuGfnKjqoQ0RSLZHhntOcu4ikXOLCvVipUqzUNC0jIqmWuHDPaS93EZHkhfv81gOalhGRNEteuGvkLiKSvHDXKUwiIhHD3cxGzewJM5s1swNmdvMS7b5tZvm6R8nMXmptycvTKUwiIhA1Ae8DSsAG4GrgKTPb6+6T9Y3c/fr652b2XeBvzrzM6HQKk4hIhJG7mQ0CNwE73T3v7ruBJ4HbVvi8zcB1wCMtqDMyncIkIhJtWmYLUHH3/XXX9gJbV/i8jwHPuvtPV1nbquQKFboMBnsV7iKSXlHCfQjINlybAYZX+LyPAQ8v9aKZ7TCzPWa2Z3p6OkIZ0WTnygz399DVZS17TxGRThMl3PPASMO1ESC31CeY2S8AG4HHl2rj7g+4+4S7T4yPj0epNZKcdoQUEYkU7vuBjJldUXdtGzC5RHuA24Fd7p4/k+JWQ3u5i4hECHd3nwV2Afea2aCZbQduAL7crL2ZrQF+nWWmZM4m7eUuIhL9h5juBNYAU8CjwB3uPmlm15lZ4+j8RuAt4JlWFXk6snPay11EJNIQ192PEYR24/VnCW641l97lOAfgFhozl1EJIHbD2jOXUQkYeFeqzn5ok5hEhFJVLjnSxXctWmYiEiiwn1+L3fNuYtI2iUq3OdPYdKcu4ikXaLCfeEUJk3LiEjKJSrcdX6qiEggUeGe1V7uIiJAwsJdI3cRkUCiwv3kahmN3EUk3ZIV7oUy/T1d9GYS9dsSETltiUrBXKGi+XYRERIW7tlCWfPtIiIkLNxzhYrWuIuIkLBwz85pR0gREUhYuGsvdxGRQKLCPVvQKUwiIpC4cNfIXUQEEhTuhXKVUqWmOXcRERIU7ie3+9XIXUQkMeG+sGmY5txFRBIU7jqFSURkQWLCXacwiYiclJhw
n5+W0Y6QIiIJCveFkfsaTcuIiCQm3BfOT9XIXUQkOeGeK1To7jIGervjLkVEJHaJCff57X7NLO5SRERil5hw16ZhIiInJSbctd2viMhJiQl3jdxFRE5KTLhnCxq5i4jMixTuZjZqZk+Y2ayZHTCzm5dp+0/M7G/NLG9mh83sM60rd2nZubJ+gElEJBR1HuM+oARsAK4GnjKzve4+Wd/IzMaAp4HfAR4HeoGLWlbtMoLzUzUtIyICEUbuZjYI3ATsdPe8u+8GngRua9L8d4HvuPtX3b3o7jl3/2FrSz5VtebkihVNy4iIhKJMy2wBKu6+v+7aXmBrk7bvAY6Z2XNmNmVm3zKzTa0odDn5YrD1gG6oiogEooT7EJBtuDYDDDdpexFwO/AZYBPwE+DRZm9qZjvMbI+Z7Zmeno5ecRMLWw9oL3cRESBauOeBkYZrI0CuSds54Al3f97dC8B/Bt5rZmsbG7r7A+4+4e4T4+Pjp1v3IjqFSURksSjhvh/ImNkVdde2AZNN2r4IeN1zb9Km5RZOYdKcu4gIECHc3X0W2AXca2aDZrYduAH4cpPmDwEfNrOrzawH2AnsdveZVhbdaH7krqWQIiKBqD/EdCewBpgimEO/w90nzew6M8vPN3L3vwH+A/BU2PZyYMk18a1ycs5d0zIiIhBxnbu7HwNubHL9WYIbrvXX7gfub0VxUeV0CpOIyCKJ2H4gW9BSSBGReokI91yhzJqebnq6E/HbERE5Y4lIw+ycth4QEamXjHDXjpAiIoskIty1l7uIyGKJCPdsoaytB0RE6iQi3IORu8JdRGReIsI9OD9V0zIiIvM6PtzdXSN3EZEGHR/uxUqNUrWmpZAiInU6Ptyz2npAROQUnR/uc9rLXUSkUceH+/ymYVoKKSJyUseHe1anMImInKLzw31OpzCJiDTq+HDXKUwiIqfq+HBfOD9VSyFFRBZ0fLjnCmW6u4w1Pd1xlyIi0jY6PtyzcxVG+jOYWdyliIi0jY4P91yhrPl2EZEGHR/u2YJOYRIRadTx4Z4rlBnu08hdRKRex4e7zk8VETlVx4d7TuenioicouPDPau93EVETtHR4V6tOfmipmVERBp1dLjntfWAiEhTHR3uC1sPaEdIEZFFEhHuGrmLiCzW2eE+fwqT5txFRBbp6HBfOIVJI3cRkUU6OtzXD/Vy/Ts3Mj7cF3cpIiJtJVK4m9momT1hZrNmdsDMbl6i3T1mVjazfN3j0taWfNI1F49y/63XsGGk/2x9CRGRjhR1svo+oARsAK4GnjKzve4+2aTtN9z91hbVJyIiq7DiyN3MBoGbgJ3unnf33cCTwG1nuzgREVmdKNMyW4CKu++vu7YX2LpE+18xs2NmNmlmd5xxhSIictqihPsQkG24NgMMN2n7GHAlMA58CviPZvYbzd7UzHaY2R4z2zM9PX0aJYuIyEqihHseGGm4NgLkGhu6+z53P+juVXd/DvivwL9s9qbu/oC7T7j7xPj4+OnWLSIiy4gS7vuBjJldUXdtG9DsZmojB3S4qYjIObZiuLv7LLALuNfMBs1sO3AD8OXGtmZ2g5mdZ4F3A58G/qLVRYuIyPKi/hDTncAaYAp4FLjD3SfN7Dozy9e1+yjwY4Ipm0eAP3L3L7WyYBERWZm5e9w1YGbTwIFVfvoYcKSF5bSSaluddq4N2rs+1bY6nVrbxe7e9KZlW4T7mTCzPe4+EXcdzai21Wnn2qC961Ntq5PE2jp6bxkREWlO4S4ikkBJCPcH4i5gGaptddq5Nmjv+lTb6iSuto6fcxcRkVMlYeQuIiINFO4iIgnUseEe9QCRuJjZd82sUHdoyY9iquOucIO2opk93PDaPzezl83shJk9Y2YXt0NtZrbZzLzh0Jed57i2PjN7MPzeypnZ/zOz6+tej63vlqutTfruK2Z2yMyyZrbfzD5Z91rc33NNa2uHfqur8YowO75Sd+3m8M971sz+3MxGV3wjd+/IB8FPyn6DYNfK
XyDYqXJr3HXV1fdd4JNtUMdHgBuB+4GH666PhX32a0A/8CfA99qkts0E+xJlYuy3QeCesJYu4JcJfvJ6c9x9t0Jt7dB3W4G+8ON3AG8C18TdbyvUFnu/1dX4P4Fnga/U1ZwD3hfm3deAr6/0PlFPYmordQeIvNPd88BuM5s/QOTuWItrM+6+C8DMJoCL6l76CDDp7t8MX78HOGJm73D3l2OuLXYe7Kl0T92lvzSznxAEwXpi7LsVanvhbH/9lfjiE9o8fFxGUF/c33NL1Xb0XHz9lZjZR4G3gOeAy8PLtwDfcve/DdvsBH5oZsPufsruvPM6dVrmdA8QicvnzOyImf2dmb0/7mIabCXoM2AhMF6lvfrwgJm9YWYPmdlYnIWY2QaC77tJ2qzvGmqbF2vfmdl/N7MTwMvAIeCvaJN+W6K2ebH1m5mNAPcCv9vwUmO/vUpw7OmW5d6vU8P9dA4QicvvA5cCP0ewTvVbZnZZvCUtMkTQZ/XapQ+PANcCFxOM9oaBr8ZVjJn1hF//S+EIs236rkltbdF37n5n+LWvI9hVtkib9NsStbVDv/0B8KC7v9FwfVX91qnhHvkAkbi4+/fdPefuRQ92xvw74ENx11WnbfvQg7N697h7xd0PA3cBv2RmcYRnF8H21qWwDmiTvmtWWzv1nQeH9uwmmHK7gzbpt2a1xd1vZnY18AHg801eXlW/deScO3UHiLj7K+G1qAeIxKXdDi6ZBG6ffxLex7iM9uzD+Z+0O6eDETMz4EFgA/Ahdy+HL8Xed8vU1iiWvmuQ4WT/tNv33Hxtjc51v72f4Kbu68EfLUNAt5ldBTxNkG8AmNmlQB9BDi4t7jvDZ3BH+esEK2YGge200WoZYB3wQYIVARmCGyKzwJYYasmEdXyOYJQ3X9N42Gc3hdf+iHO/cmGp2n4eeDvBX6z1BKuinomh774AfA8YarjeDn23VG2x9h1wPsG5DkNAd/j3YBb41bj7bYXa4u63AWBj3eNPgcfDPttKMA19XZh3XyHCaplz9s14FjpjFPjz8A/ndeDmuGuqq20ceJ7gv01vhX8J/0VMtdzDyVUB8497wtc+QHBTaY5g6ebmdqgN+A3gJ+Gf7SGCg182nuPaLg7rKRD8t3j+cUvcfbdcbXH3Xfi9/7/D7/ss8BLwqbrX4+y3JWuLu9+a1HoP4VLI8PnNYc7NEpxuN7rSe2hvGRGRBOrUG6oiIrIMhbuISAIp3EVEEkjhLiKSQAp3EZEEUriLiCSQwl1EJIEU7iIiCaRwFxFJoP8P3mOrNPBRzp4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(L(learn.recorder.values).itemgot(2));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Final value: `batch_accuracy` after the last epoch__ (each entry of `recorder.values` holds `[train_loss, valid_loss, batch_accuracy]` for one epoch, so index `2` is the metric)" ] }, { "cell_type": "code", "execution_count": 415, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.982826292514801" ] }, "execution_count": 415, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.recorder.values[-1][2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GOING DEEPER" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Why go deeper, if two linear layers with a nonlinearity between them are already enough?__" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We already know that a single nonlinearity with two linear layers is enough to approximate any function. So why would we use deeper models? The reason is performance. With a deeper model (that is, one with more layers) we do not need to use as many parameters; it turns out that we can use smaller matrices with more layers, and get better results than we would get with larger matrices, and few layers. (From Fastbook)" ] }, { "cell_type": "code", "execution_count": 416, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
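<table>
<thead><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th><th>time</th></tr></thead>
<tbody>
<tr><td>0</td><td>0.089727</td><td>0.011755</td><td>0.997056</td><td>00:13</td></tr>
</tbody>
</table>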
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dls = ImageDataLoaders.from_folder(path)\n", "learn = cnn_learner(dls, resnet18, pretrained=False,\n", " loss_func=F.cross_entropy, metrics=accuracy)\n", "learn.fit_one_cycle(1, 0.1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }