{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch Tutorial\n", "\n", "> In this post, we cover the basics of PyTorch. This is a summary of the lecture CS285 \"Deep Reinforcement Learning\" from Berkeley.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, PyTorch, Berkeley]\n", "- image: " ] }, { "cell_type": "markdown", "metadata": { "id": "edLrpksDJfDg" }, "source": [ "## Intro\n", "This is a PyTorch tutorial for UC Berkeley's CS285.\n", "There are already many great tutorials that you might want to check out, in particular [this tutorial](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).\n", "This tutorial covers a lot of the same material. If you're familiar with PyTorch basics, you might want to skip ahead to the PyTorch Advanced section.\n", "\n", "First, let's import some things and define a useful plotting function." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "id": "dq6HedFjckGr" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "import numpy as np\n", "\n", "# Helper that plots ys against xs on fixed axes, with grid lines\n", "# and the x- and y-axes drawn through the origin.\n", "def plot(xs, ys, xlim=(-3, 3), ylim=(-3, 3)):\n", "    fig, ax = plt.subplots()\n", "    ax.plot(xs, ys, linewidth=5)\n", "    # ax.set_aspect('equal')\n", "    ax.grid(True, which='both')\n", "\n", "    ax.axhline(y=0, color='k')\n", "    ax.axvline(x=0, color='k')\n", "    ax.set_xlim(*xlim)\n", "    ax.set_ylim(*ylim)" ] }, { "cell_type": "markdown", "metadata": { "id": "ynLs1RCduz_O" }, "source": [ "## PyTorch Basic" ] }, { "cell_type": "markdown", "metadata": { "id": "U5rl_7Kx5vk8" }, "source": [ "### Tensors" ] }, { "cell_type": "markdown", "metadata": { "id": "rjXm2HUv_rBm" }, "source": [ "NumPy arrays are objects that allow you to store and manipulate matrices." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 166 }, "id": "x55DbowZBcDY", "outputId": "a37168ca-d4b1-4ef9-cb8c-c29c2ec569f1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 0.]\n", " [0. 0. 0.]]\n", "+\n", "[[1. 1. 1.]\n", " [1. 1. 1.]]\n", "=\n", "[[1. 1. 1.]\n", " [1. 1. 1.]]\n" ] } ], "source": [ "shape = (2, 3)\n", "x = np.zeros(shape)\n", "y = np.ones(shape)\n", "z = x + y\n", "print(x)\n", "print(\"+\")\n", "print(y)\n", "print(\"=\")\n", "print(z)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.0\n" ] } ], "source": [ "print(z.sum())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "id": "g3Acexuuz0RS", "outputId": "3941e8b2-f5b5-465e-a910-dbbc83bf12f8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1. 1.]\n" ] } ], "source": [ "print(z[0, 1:])" ] }, { "cell_type": "markdown", "metadata": { "id": "kPGVgAOZ_mIq" }, "source": [ "PyTorch is built around _tensors_, which play a similar role to NumPy arrays. 
You can do many of the same operations in PyTorch:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 166 }, "id": "ujftyWqRCGtK", "outputId": "471f29d4-5da3-4620-c6e4-3f6ad3f9d872" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0., 0., 0.],\n", "        [0., 0., 0.]])\n", "+\n", "tensor([[1., 1., 1.],\n", "        [1., 1., 1.]])\n", "=\n", "tensor([[1., 1., 1.],\n", "        [1., 1., 1.]])\n" ] } ], "source": [ "x = torch.zeros(shape)\n", "y = torch.ones(shape)\n", "z = x + y\n", "\n", "print(x)\n", "print(\"+\")\n", "print(y)\n", "print(\"=\")\n", "print(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "Nb6RWm3iAQjd" }, "source": [ "Many functions have an alternate syntax that accomplishes the same thing. For example, `torch.add(x, y)` is equivalent to `x + y`:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "ozmW7nxfKvue" }, "outputs": [ { "data": { "text/plain": [ "tensor([[1., 1., 1.],\n", "        [1., 1., 1.]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.add(x, y)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "6l0MwUdvDNsU" }, "outputs": [ { "data": { "text/plain": [ "tensor(1.)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z.min()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "Kv65CHa_D-WK" }, "outputs": [ { "data": { "text/plain": [ "tensor([1.])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z[1:, 0]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "T6T5UEfJ0eKa" }, "outputs": [ { "data": { "text/plain": [ "tensor(6.)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.sum(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "NKs4pG0oAhro" }, "source": [ "Functions that reduce dimensions will by default reduce all dimensions unless a 
dimension is specified with `dim`:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "rcLKnEh80gxY" }, "outputs": [ { "data": { "text/plain": [ "tensor([3., 3.])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.sum(z, dim=1)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "2hi_pF7a0i5Q" }, "outputs": [ { "data": { "text/plain": [ "tensor([2., 2., 2.])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.sum(z, dim=0)" ] }, { "cell_type": "markdown", "metadata": { "id": "VUjHBWOHAaXZ" }, "source": [ "Like NumPy, PyTorch will try to broadcast operations. Here a (3, 1) tensor plus a (1, 3) tensor gives a (3, 3) result:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "ffy_uyGF0qZV" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[1.],\n", "        [1.],\n", "        [1.]])\n", "+\n", "tensor([[1., 1., 1.]])\n", "=\n", "tensor([[2., 2., 2.],\n", "        [2., 2., 2.],\n", "        [2., 2., 2.]])\n" ] } ], "source": [ "x = torch.ones((3, 1))\n", "y = torch.ones((1, 3))\n", "z = x + y\n", "\n", "print(x)\n", "print(\"+\")\n", "print(y)\n", "print(\"=\")\n", "print(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "U8AQHu-K_6Z9" }, "source": [ "Functions whose names end with an underscore, such as `zero_()` and `add_()`, operate in place. Use these sparingly, as they easily lead to bugs." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "ewIwYOhnnCC5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[2., 2., 2.],\n", " [2., 2., 2.],\n", " [2., 2., 2.]])\n", "tensor([[0., 0., 0.],\n", " [0., 0., 0.],\n", " [0., 0., 0.]])\n" ] } ], "source": [ "print(z)\n", "z.zero_()\n", "print(z)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "BsC0k_cgnFW-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[5., 5., 5.],\n", " [5., 5., 5.],\n", " [5., 5., 5.]])\n" ] } ], "source": [ "z.add_(5)\n", "print(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "DwA9Pmj-K4Pr" }, "source": [ "#### Moving between numpy and PyTorch" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "v3J1zw1WyHlG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-0.13298262 0.72862306 0.95932965]\n", " [ 2.24109395 0.53208287 -0.42554932]]\n" ] } ], "source": [ "x_np = np.random.randn(*shape)\n", "print(x_np)" ] }, { "cell_type": "markdown", "metadata": { "id": "YkIYPgPzA2-C" }, "source": [ "numpy -> pytorch is easy" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "EMPx2iKzybuk" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.1330, 0.7286, 0.9593],\n", " [ 2.2411, 0.5321, -0.4255]], dtype=torch.float64)\n" ] } ], "source": [ "x = torch.from_numpy(x_np)\n", "print(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "vInPpem8ArmF" }, "source": [ "By default, numpy arrays are float64. You'll probably want to convert arrays to float32, as most tensors in pytorch are float32." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "xZSegeSByh9D" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.1330, 0.7286, 0.9593],\n", " [ 2.2411, 0.5321, -0.4255]])\n" ] } ], "source": [ "x = torch.from_numpy(x_np).to(torch.float32)\n", "print(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "lUd3zm2yAzuI" }, "source": [ "pytorch -> numpy is also easy" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "mjxo4Lwryu-6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-0.13298263 0.7286231 0.95932966]\n", " [ 2.2410939 0.53208286 -0.42554933]]\n" ] } ], "source": [ "print(x.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "4NDB2ypyy9eE" }, "source": [ "#### GPU support" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "Kh-UHYiVzSKc" }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.cuda.is_available()" ] }, { "cell_type": "markdown", "metadata": { "id": "gvdb3bSIA7i-" }, "source": [ "The code below errors out because both tensors need to be on the same device." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "C2DzoBcizK7j" }, "outputs": [ { "ename": "RuntimeError", "evalue": "expected device cpu but got device cuda:0", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mones\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mz\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;31mRuntimeError\u001b[0m: expected device cpu but got device cuda:0" ] } ], "source": [ "device = torch.device(\"cuda\")\n", "x = torch.zeros(shape)\n", "y = torch.ones(shape, device=device)\n", "z = x + y" ] }, { "cell_type": "markdown", "metadata": { "id": "inYRaSWbBI5y" }, "source": [ "You can move a tensor to the GPU by using the `to` function." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8jbmhtOY5Oj4" }, "outputs": [], "source": [ "x = x.to(device)\n", "z = x + y\n", "print(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "aXasNfBJBKFl" }, "source": [ "This code also errors out, because you can't convert tensors on a GPU into numpy arrays directly." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "ITCz002W5Vu0" }, "outputs": [ { "data": { "text/plain": [ "array([[5., 5., 5.],\n", " [5., 5., 5.],\n", " [5., 5., 5.]], dtype=float32)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "OEW3bs_ABaM6" }, "source": [ "First you need to move them to the CPU." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "z00AQCfG5Xaw" }, "outputs": [ { "data": { "text/plain": [ "array([[5., 5., 5.],\n", " [5., 5., 5.],\n", " [5., 5., 5.]], dtype=float32)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z_cpu = z.to('cpu')\n", "z_cpu.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "PP-wQ8sDzIbn" }, "source": [ "### (homework aside)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "F5W8aKyiyPtc" }, "outputs": [ { "ename": "ModuleNotFoundError", "evalue": "No module named 'cs285'", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mcs285\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minfrastructure\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mpytorch_util\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mptu\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0mptu\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx_np\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m 
\u001b[0mptu\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'cs285'" ] } ], "source": [ "from cs285.infrastructure import pytorch_util as ptu\n", "ptu.from_numpy(x_np)\n", "ptu.to_numpy(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "lkH9ZA-s49Mg" }, "source": [ "## Neural-Network specific functions\n", "PyTorch has a bunch of built-in funcitons.\n", "See [the docs](https://pytorch.org/docs/stable/torch.html) for a full list." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "gpEeDQWcC2AT" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAZp0lEQVR4nO3deXSV9b3v8fc3E2GOQiDMs2FKWit1tkbFipaKgO3t3HO9Lceea6tdtoDiLCpgr73e1q6Wc7SnnuPx9pRBqKA4bhAVRTyQhCGIyBAEGTcQICTZ+d0/COfiKU8S2E/28MvntVbWYiff/Xu+P3f88OTJs7+Ycw4REfFHRrIbEBGRcCnYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8E3ewm1mumb1vZmvMbK2ZPRhGYyIicnYs3vvYzcyA9s65KjPLBpYDtzvnVoTRoIiInJmseBdwJ/5mqGp4mN3woXc9iYgkSdzBDmBmmcAqYDDwlHPuvdPUTAImAeTm5l7Qt2/fMA6dkurr68nI8PfXF77ub/v27Tjn0Pdm+vJlf8djsPdYPbX1n/98za5Ne51z+U09P+5LMZ9bzCwPmA/81DlXHlRXWFjoKioqQjtuqolEIpSUlCS7jRbj6/5KSkqIRqOsXr062a20GF9fu5PSfX9Hjtfxq1cq+Od3tnC6aN46c+wq59yoptYJ5Yz9JOdc1MwiwBggMNhFROTzln+0l6nzSqk8cCzutcK4Kya/4UwdM2sLjAY2xLuuiEhrcPBoLZPnrOF7T78XGOpZGcbPrhnS7DXDOGPvAfyp4Tp7BvDvzrkXQ1hXRMRrL5fv4t4F5ew5fDyw5gu9OzPz5mKGFnTizmauG8ZdMaXA+fGuIyLSWuw+XM0DC9eyuGxXYE1udgZ3XlvILZcPIDPDzmj9UK+xi4hIMOcc8z7cwUMvruPgsdrAuosHnsuMCcX079r+rI6jYBcRSYDKA0e5e345yzbuCazp2CaLu24Yxre+3IeMMzxLP5WCXUSkBdXXO/5lxVZmvryBozWxwLprhnZj+viR9OjcNu5jKthFRFrIx3uqmDq3lJVbDgTWnNs+h/u/Ppwbv9CTExNa4qdgFxEJWW2sntnLNvPk6x9RU1cfWDfuiz25b+xwunRoE+rxFewiIiEq33GQyXNKWbfzUGBNQadcpt80ktHDu7dIDwp2EZEQVNfG+D+vf8Qflm0mVh88quU7F/Vl6
vVD6ZSb3WK9KNhFROK0cst+pswpZfPeI4E1/bq0Y8aEYi4Z1KXF+1Gwi4icparjdcx6eQPPvrs1sCbD4EdXDOTno8+jbU5mQvpSsIuInIWlG/dw97wydkSDh3YNLejIzInFfKFPXgI7U7CLiJyRA0dqeHjROuZ9uCOwJjvTuO2qIfykZBA5WYmfD69gFxFpBuccL5Xv4r4F5eytqgms+2KfPGbdXMx53TsmsLvPU7CLiDRh96Fq7l1QzpK1nwXW5GZn8IuvFvLfLzvzoV1hU7CLiARwzvGXVZVMf3Edh6rrAusuHdSFGROK6dulXQK7C6ZgFxE5je37j3LXvDKWb9obWNMxN4tpNwzjv325T2jjAMKgYBcROUWs3vHsu1uY9XIFx2qDh3ZdO7w7028aSfdOuYlrrpkU7CIiDTbtPszkOaV8uC0aWNO1Qw4P3jiSG4oKUuos/VQKdhFp9Wpj9fw+8jG/eWMTNbHgoV3jz+/FfWOHc077nAR2d+YU7CLSqpVVHuSXc9awYdfhwJoenXN5dHwRVw3tlsDOzp6CXURaperaGL9+bSP/uGwzjczs4vsX92PymEI6tuDQrrAp2EWk1Xlv8z6mzivjk0aGdg3o2p4ZE4q4aGDLD+0Km4JdRFqNw9W1zHx5A/+6YltgTWaG8eMrBnLH6CHkZidmaFfYFOwi0iq8uWE3d88vY+fB6sCaYT06MWtiMUW9Oyews/Ap2EXEa/uP1PDwi+uY/x/BQ7tyMjP46dWDubVkENmZiR/aFTYFu4h4yTnHi6U7eWDhWvYdCR7a9aW+J4Z2De6WvKFdYVOwi4h3dh2s5p4XynltffDQrnY5mUy+rpDvX9I/6UO7wqZgFxFvOOd4/v1tPLpoPYePBw/tumJIVx4dX0Sfc1NjaFfY4g52M+sDPAsUAPXAbOfck/GuKyJyJrbuO8KsldWs318WWNMpN4t7xg7nGxf0TtlxAGEI44y9DrjTOfehmXUEVpnZq865dSGsLSLSqFi9449vf8KvXqmgujZ4HMCYEQU8NG4E3VJwaFfY4g5259xOYGfDnw+b2XqgF6BgF5EWVbHrMJPnlrJme2NDu9rw0LgR3FDUI4GdJVeo19jNrD9wPvBemOuKiJyqpq6e30U28dSbm6iNBc8DmPil3tw7dhh57VJ7aFfYQgt2M+sAzAXucM4dOs3XJwGTAPLz84lEImEdOuVUVVVpf2koGo0Si8W83NtJPrx2mw/GeKbsOJVVwYHeJdf4uxE5FOUfYPX77ySwu9RgzjUy/aa5i5hlAy8CS5xzTzRVX1hY6CoqKuI+bqqKRCKUlJQku40W4+v+SkpKiEajrF69OtmttJh0fu2O1cR44tUKnl7+SeDQLgN+eGl/fnFdIR3a+HfTn5mtcs6NaqoujLtiDHgaWN+cUBcROVPvfryPqfNK2brvaGDNwPz2fHtgjB/fOCKBnaWmMP5Kuwz4PlBmZidPde52zi0OYW0RacUOVdfy2OINPP9+40O7br1yID+9eggr3n4rgd2lrjDuilnOiZ+ARERC8/r6z5g2v5xdh4KHdo3o2YlZNxczomd6D+0Km38XoUQkre2rOs6Df13HwjWfBtbkZGVwx+gh/PiKgV4M7Qqbgl1EUoJzjoVrPuWBhWs5cLQ2sG5Uv3OYMbGYwd06JLC79KJgF5Gk23nwGPfML+f1DbsDa9rnZDLl+qF876J+ZHg2tCtsCnYRSZr6esfzK7fx2OINVDUytOvK8/J5ZPxIep/j59CusCnYRSQptuw9wtR5pazYvD+wpnPbbO4bO5wJX+rl9dCusCnYRSSh6mL1PPP2J/yvVzZyvC54aNfXinrwwI0jyO/YJoHd+UHBLiIJs2HXIabMKWVN5cHAmvyObXh43EjGjCxIYGd+UbCLSIs7XhfjqTc/5ndvbqIuaB4A8M1RvZl2w3A6t8tOYHf+UbCLSIv6cNsBpswp5aPdVYE1vc9py4wJxVw+pGsCO/OXgl1EWsTRm
jp+tWQjf3znE4JmDZrB313an19eV0i7HMVRWPRfUkRC9/amvUydV8r2/ccCawZ368DMicVc0O+cBHbWOijYRSQ0B4/V8uii9fz5g+2BNVkZxk9KBnHb1YNpk5WZwO5aDwW7iIRiydpd3PtCObsPHw+sKerVmZkTixnes1MCO2t9FOwiEpc9h4/zwMK1LCrbGVjTJiuDn197Hj+6fABZGtrV4hTsInJWnHO8sHoHD/51HdFGhnZdOOBcZkwoYmC+hnYlioJdRM7Yjugxps0vI1KxJ7CmQ5ssplw/lO9e2FdDuxJMwS4izVZf73ju/W3MWLyeIzWxwLqSwnweHV9Ez7y2CexOTlKwi0izbN5TxdS5Zby/JXho1zntsrn/6yMY98WeGtqVRAp2EWlUXayef1r+Cb9+tfGhXWOLTwzt6tpBQ7uSTcEuIoHWfXqIyXPXUL7jUGBN905tmH5TEdcO757AzqQxCnYR+RvVtTF++8Ymfr/040aHdn37wj7cdcMwOuVqaFcqUbCLyOes2rqfyXNK+XjPkcCavue2Y8aEIi4drKFdqUjBLiIAHDlex+NLKvjTu1sCh3ZlGNxy2QDu/GohbXM0DiBVKdhFhLc+2sNd88qoPBA8tOu87ieGdp3fV0O7Up2CXaQVO3i0lumL1vGXVZWBNdmZxj+UDOYfrhqkoV1pQsEu0kq9XL6TexesZU8jQ7u+0LszM28uZmiBhnalEwW7SCuz+3A19y9Yy0vluwJrcrMzuPPaQm65fACZGgeQdhTsIq2Ec47lO2q5fekyDh4LHtp1ycAuzJhYRL8u7RPYnYQplGA3s2eAscBu59zIMNYUkfBUHjjKXfPKeOujmsCajm2yuPtrw/jWl/toHECaC+uM/Z+B3wLPhrSeiISgvt7x7LtbmLWkgqONDO0aPawb028qoqBzbuKakxYTSrA755aZWf8w1hKRcGzaXcXUuaV8sPVAYM257XN44MYRfL24h87SPZKwa+xmNgmYBJCfn08kEknUoROuqqpK+0tD0WiUWCyW9nurq3e89EktCzbVUhc8DYCLe2Ty3WFZdDywkaVLNyauwRbk6/fmmUpYsDvnZgOzAQoLC11JSUmiDp1wkUgE7S/95OXlEY1G03pv5TsOMnlOKet2Hg2sKeiUyyPjR3LNMP+Gdvn6vXmmdFeMiAeqa2M8+fpHzF62mVgjQ7uu6pPFk//jKxra5TkFu0iaW7llP1PmlLJ5b/DQrv5d2jFjYjHV28oU6q1AWLc7Pg+UAF3NrBK43zn3dBhri8jpVR2vY9bLG3j23a2BNRkGP75iIHeMPo+2OZlEtiWwQUmasO6K+XYY64hI8yzduIe755WxIxo8tGtoQUdm3VxMce+8BHYmqUCXYkTSyIEjNTy8aB3zPtwRWJOTmcFtVw/m1isHkZOVkcDuJFUo2EXSgHOOl8p3cd+CcvZWBb979Py+ecyaWMyQ7h0T2J2kGgW7SIrbfaiaexeUs2TtZ4E1bbMz+eV1hfzw0v4a2iUKdpFU5ZzjLx9UMn3ROg5V1wXWXTa4C4+NL6Zvl3YJ7E5SmYJdJAVt339iaNfyTXsDazrmZnHP14bxzVEa2iWfp2AXSSGxesef3tnC40sqOFYbPLTr2uHdmX7TSLp30tAu+VsKdpEU8dFnh5kyt5QPt0UDa7p2yOHBG0dyQ1GBztIlkIJdJMlq6ur5w9KP+c0bm6iJ1QfWTTi/F/eOHc457XMS2J2kIwW7SBKVVkaZPKeUDbsOB9b07JzLIxOKuKqwWwI7k3SmYBdJguraGL9+dSP/+NZmGpnZxfcv7seU64fSoY3+V5Xm03eLSIKt2LyPqXNL2bIveLTugK7tmTmxmAsHnJvAzsQXCnaRBDlcXcuMlzbw3HvBk7gyM4wfXTGAn48+j9zszAR2Jz5RsIskwBsbPmPa/HJ2HqwOrBnWoxOzJhZT1LtzAjsTHynYRVrQ/iM1PPTXtbyw+tPAmpzMDH52zWD+/spBZ
GdqaJfET8Eu0gKcc/y1dCcPLFzL/iPBQ7su6HcOMycWM7hbhwR2J75TsIuEbNfBau55oZzX1gcP7WqXk8nk6wr5wSX9ydDQLgmZgl0kJM45/rxyO48sWs/h48FDu64Y0pVHxxfR51wN7ZKWoWAXCcHWfUeYOreMdzfvC6zplJvFvWOHc/MFvTUOQFqUgl0kDrF6xx/f/oRfvVJBdW3wOIDrRxbw4LgRdOuooV3S8hTsImepYtdhJs8tZc32xoZ2teHhcSO4vqhHAjuT1k7BLnKGaurq+V1kE0+9uYnaWPA8gJsv6M09XxtGXjsN7ZLEUrCLnIHV26NMmVNKxWfBQ7t65bXlsQlFfOW8/AR2JvL/KdhFmuFYTYwnXq3g6eWfBA7tMoMfXNyPyWOG0l5DuySJ9N0n0oR3Pt7L1LllbNsfPLRrUP6JoV2j+mtolySfgl0kwKHqWh5bvIHn3298aNdPrhzEbVcP1tAuSRkKdpHTeG3dZ0x7oYzPDh0PrBnRsxOzbi5mRE8N7ZLUomAXOUVNDH72/H+wcE0jQ7uyMrj9miH8/VcGkqWhXZKCQgl2MxsDPAlkAv/knJvRWP2hmhP/EruvPtpay1btL63srTpOaWWUozWxRkP9y/3PYcbEYgbla2iXpK64g93MMoGngGuBSmClmS10zq0Les7+asf9C9fGe+jUtl77SzdHa2KBX2ufk8mU64fyvYv6aWiXpDxzrpF/cLE5C5hdAjzgnLuu4fFdAM65x4Kek5HT1uUUDI7ruCJhq9m9GYCcbgM/9/m8djkM6NqeNlnpf9klGo2Sl5eX7DZajO/7W7p06Srn3Kim6sK4FNML2H7K40rgov9aZGaTgEkAlq15GZL6Mg26tcugc5sYx6oOcSzZDYUgFosRjQaPQEh3vu+vucII9tP9XPo3PwY452YDswHa9BjiCr7T6GV4kYTb9W9TAbhuyu+5fmQB3xzVh3Pb+zUOIBKJUFJSkuw2Wozv+2vuVNAwgr0S6HPK495A8G+fgI45xg8u6RfCoVPTjh076NWrV7LbaDE+7i/DjD+/0p6MumoW/M/Lkt2OSFzCCPaVwBAzGwDsAL4FfKexJ3TJNR4aNzKEQ6emSGQvJSXaX7qJPJFLNBr8j02LpIu4g905V2dmtwFLOHG74zPOOf9umRARSROh3MfunFsMLA5jLRERiU/6378lIiKfo2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDNxBbuZfcPM1ppZvZmNCqspERE5e/GesZcDE4BlIfQiIiIhyIrnyc659QBmFk43IiISt7iC/UyY2SRgEkB+fj6RSCRRh064qqoq7S8NRaNRYrGYl3s7ydfX7iTf99dcTQa7mb0GFJzmS9OccwuaeyDn3GxgNkBhYaErKSlp7lPTTiQSQftLP3l5eUSjUS/3dpKvr91Jvu+vuZoMdufc6EQ0IiIi4dDtjiIinon3dsfxZlYJXAIsMrMl4bQlIiJnK967YuYD80PqRUREQqBLMSIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuKZuILdzB43sw1mVmpm880sL6zGRETk7MR7xv4qMNI5VwxsBO6KvyUREYlHXMHunHvFOVfX8HAF0Dv+lkREJB5hXmO/BXgpxPVEROQsZDVVYGavAQWn+dI059yChpppQB3wXCPrTAImAeTn5xOJRM6m37RQV
VWl/aWhaDRKLBbzcm8n+franeT7/prLnHPxLWD2Q+BW4Brn3NHmPKewsNBVVFTEddxUFolEKCkpSXYbLcbX/ZWUlBCNRlm9enWyW2kxvr52J/m+PzNb5Zwb1VRdk2fsTRxkDDAFuLK5oS4iIi0r3mvsvwU6Aq+a2Woz+30IPYmISBziOmN3zg0OqxEREQmH3nkqIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4Jq5gN7OHzazUzFab2Stm1jOsxkRE5OzEe8b+uHOu2Dn3ReBF4L4QehIRkTjEFezOuUOnPGwPuPjaERGReGXFu4CZPQL8ADgIXNVI3SRgUsPD42ZWHu+xU1hXYG+ym2hBPu+vq5n5ujfw+7UD//dX2Jwic67xk2wzew0oOM2XpjnnFpxSdxeQ65y7v8mDmn3gnBvVnAbTkfaXvnzeG2h/6a65+2vyjN05N7qZx/w3YBHQZLCLiEjLifeumCGnPLwR2BBfOyIiEq94r7HPMLNCoB7YCtzazOfNjvO4qU77S18+7w20v3TXrP01eY1dRETSi955KiLiGQW7iIhnkhbsPo8jMLPHzWxDw/7mm1lesnsKk5l9w8zWmlm9mXlza5mZjTGzCjPbZGZTk91PmMzsGTPb7ev7R8ysj5m9aWbrG743b092T2Exs1wze9/M1jTs7cEmn5Osa+xm1unkO1fN7GfAcOdcc3/5mtLM7KvAG865OjObCeCcm5LktkJjZsM48QvzPwC/cM59kOSW4mZmmcBG4FqgElgJfNs5ty6pjYXEzL4CVAHPOudGJrufsJlZD6CHc+5DM+sIrAJu8uH1MzMD2jvnqswsG1gO3O6cWxH0nKSdsfs8jsA594pzrq7h4QqgdzL7CZtzbr1zriLZfYTsQmCTc26zc64G+L/AuCT3FBrn3DJgf7L7aCnOuZ3OuQ8b/nwYWA/0Sm5X4XAnVDU8zG74aDQvk3qN3cweMbPtwHfxd4DYLcBLyW5CmtQL2H7K40o8CYbWxsz6A+cD7yW3k/CYWaaZrQZ2A6865xrdW4sGu5m9Zmblp/kYB+Ccm+ac6wM8B9zWkr2Eram9NdRMA+o4sb+00pz9ecZO8zlvfopsLcysAzAXuOO/XBVIa865WMMU3d7AhWbW6OW0uIeANdGMt+MImtqbmf0QGAtc49LwzQJn8Nr5ohLoc8rj3sCnSepFzkLD9ee5wHPOuXnJ7qclOOeiZhYBxgCBvwhP5l0x3o4jMLMxwBTgRufc0WT3I82yEhhiZgPMLAf4FrAwyT1JMzX8gvFpYL1z7olk9xMmM8s/eWedmbUFRtNEXibzrpi5nBhB+Z/jCJxzO5LSTMjMbBPQBtjX8KkVvtzxA2Bm44HfAPlAFFjtnLsuuV3Fz8xuAP43kAk845x7JMkthcbMngdKODHW9jPgfufc00ltKkRmdjnwFlDGiUwBuNs5tzh5XYXDzIqBP3Hi+zID+Hfn3EONPicNrxKIiEgj9M5TERHPKNhFRDyjYBcR8YyCXUTEMwp2ERHPKNhFRDyjYBcR8cz/AyCVB2MwYsZLAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xs = torch.linspace(-3, 3, 100)\n", "ys = torch.relu(xs)\n", "plot(xs.numpy(), ys.numpy())" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "of1U9ahX46eL" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAZ90lEQVR4nO3deZScdZ3v8fe3lt43knSSzr4Qml1ACAIqLYIwwnHBmTuI3vHK3Jvx3qPjHK8OKjrOjJcrHM91nHN1VGZg1Bl1jg4DjrgBV1r2LdBAyEZCts7eSSrd1Ut1V9X3/tHVEJz0kq6nq7p/+bzOeU49Vf2r3/P9pTqfPPnVs5i7IyIi4YiVuwAREYmWgl1EJDAKdhGRwCjYRUQCo2AXEQmMgl1EJDBFB7uZVZnZ02b2gpm9bGZ/FUVhIiIyOVbscexmZkCtu6fNLAk8CnzS3Z+MokARETkxiWI78OF/GdKFp8nCorOeRETKpOhgBzCzOLAWOBX4prs/dZw2a4A1AFVVVW9esmRJFJuelvL5PLFYuF9fhDq+Xbt24e7od3PmCn18mzdv7nL35vHaFT0V84bOzJqAe4BPuPu60dq1trb6pk2bItvudNPe3k5bW1u5y5gyoY6vra2NVCpFR0dHuUuZMqF+diNCH5+ZrXX3C8drF+k/be6eAtqBa6LsV0REJi6Ko2KaC3vqmFk1cCWwsdh+RURkcqKYY28BvleYZ48BP3b3+yLoV0REJiGKo2JeBM6PoBYREYlAuF8fi4icpBTsIiKBUbCLiARGwS4iEhgFu4hIYBTsIiKBUbCLiARGwS4iEhgFu4hIYBTsIiKBUbCLiARGwS4iEhgFu4hIYBTsIiKBUbCLiARGwS4iEhgFu4hIYBTsIiKBUbCLiARGwS4iEhgFu4hIYBTsIiKBUbCLiARGwS4iEhgFu4hIYBTsIiKBUbCLiASm6GA3s8Vm9pCZbTCzl83sk1EUJiIik5OIoI8s8D/d/TkzqwfWmtkD7r4+gr5FROQEFb3H7u573f25wnoPsAFYWGy/IiIyOZHOsZvZMuB84Kko+xURkYmLYioGADOrA+4G/szdu4/z8zXAGoDm5mba29uj2vS0k06nNb4ZKJVKkcvlghzbiFA/uxGhj2+izN2L78QsCdwH/NrdvzZe+9bWVt+0aVPR252u2tvbaWtrK3cZUybU8bW1tZFKpejo6Ch3KVMm1M9uROjjM7O17n7heO2iOCrGgDuBDRMJdRERmVpRzLFfBvxn4Aoz6ygs746gXxERmYSi59jd/VHAIqhFREQioDNPRUQCo2AXEQmMgl1EJDAKdhGRwCjYRUQCo2AXEQmMgl1EJDAKdhGRwCjYRUQCo2AXEQmMgl1EJDAKdhGRwCjYRUQCo2AXEQmMgl1EJDAKdhGRwCjYRUQCo2AXEQmMgl1EJDAKdhGRwCjYRUQCo2AXEQmMgl1EJDAKdhGRwCjYRUQCo2AXEQmMgl1EJDAKdhGRwCSi6MTM7gKuAw64+9lR9CkiUk65vDOUyxcWJ5vLM5QffszmnWzOyebz5PJONu/Dj7nhx5w7uXyeXJ7XH93JF9rlfWRh+DE/vD7yM+ANP3eHfN4nXHskwQ58F/gG8P2I+hMROa6BoRzpTJb0QJZ0JktvJkvvYJbeTI7nO4fY
9tg2+gZzDAzl6B/M0T80vGSG8gwM5chkX3/MZHMMZvPDSy5PJjsc5IPZPCeQo9NOJMHu7g+b2bIo+hKRk0M2l+dw7yAH0xkO9w5yuHeQQ+lBUn2DpPqHONI3xNH+4aW7sPQMZBnM5cfueN360gxgGotqj31cZrYGWAPQ3NxMe3t7qTZdcul0WuObgVKpFLlcLsixjSjVZ9efdbr6ncMDeQ73O4cHnFTGOZJxUgN5jmac9BDM4J3iaa1kwe7udwB3ALS2tnpbW1upNl1y7e3taHwzT1NTE6lUKsixjYjqs3N3DvZkeLWrl+1dvWw/1MeOQ73sOtLHrsP9HO0fKr5YmbSSBbuIzEwHegbYtK+HTft62Livhy0H0mw9mKZnIFvu0qZcRSJGMmYkEzESsRjJuJGIG8lYjHjMSMRjJGI2vB4zYoXH+Mhir6/HjnluxhvWY4X1mBlAYR3Mhl8bXoc/v31idSvYReQ1+44O0LHrCC/tPsrLe7pZt7ubrnSm3GW9QTxm1FclqKt8famtTFBbGaf7cBcrliykuiJOdTJOTUWcquTwUp0cWY9RmYhTmYhRmYxREY9RkTjmMREjWQhsKwTtdPHnE2wX1eGOPwLagDlm1gl8yd3vjKJvEZka2Vye9Xu7eXrbYdbuOMLzO1Ps6x4oaQ1NNUnm1FUyu7aCWccsTTUVNFUnaapJ0lg9vDRUJ2moSlKVjI0auMNTTTriOqqjYj4YRT8iMnVyeWf70Rzf/u1WHt96iGe3H6ZvMDcl20rGjZbGahY0VbGgsZqWpirmN1Yzv6GKeQ2VNNdXMru2koqEzpGcCpqKEQnYwZ4MD28+SPvmgzzyykFSfUPAxkj6rkrGWD6njhVzalk2p4als2tZOquGxbNqmNdQRTw2vaYxTiYKdpHAbDmQ5v71+7j/5f107EoV3V8ybpw6t57T59fTOr+e1nn1nDq3joVN1cQU3tOSgl0kAK/s7+FnL+7l5y/uYevB3kn3k4gZp7fU86ZFTZy7qJGzFjRy2rx6TZnMMAp2kRlq39EB7nl+N/c+v5tN+3sm1UdDVYILl83iomWzuHDZKZyzsJGqZDziSqXUFOwiM0gmm+PXL+/nJ8/u4rEtXSd8PZPKOFxyajOXrZzDJStnc2ZLg6ZTAqRgF5kBth5M86OndnL3c50c6TuxszpXNtfyjta5vOP0ufTtfImrrlg9RVXKdKFgF5mmcnnnoY0H+O7j23l0S9cJvfeCJU1cfdZ8rjpzHiua6157vb1Te+cnAwW7yDTTm8ny42d38Y+PbWfn4b4Jv++8xU1cd24L157bQktj9RRWKNOdgl1kmuhKZ/je49v5/hM7JnwRrcWzqrn+/EVcf8FCls6uneIKZaZQsIuU2d6j/Xznt6/yo6d3ksmOc61xhi9Mde05LfzhRYtZvWyWvvyU/0DBLlImu1P9/N1DW/jJs53j3zwCWD6nlg+/ZSkfuGAhTTUVJahQZioFu0iJHegZ4O8e2soPn9o5oUC//LRm/stly7h8VbP2zmVCFOwiJXK0f4hvtW/lu49vY2Bo7EBPxo33n7+Q//a2FayaV1+iCiUUCnaRKZbJ5vinJ3bwjYe2FC7CNbraijgffstSbnrrcuY1VJWoQgmNgl1kirg79724l9t/tZHOI/1jtq2vSvDRS5fx0cuWc0qt5s+lOAp2kSnwwq4UX75vPc/uODJmu9qKOH/81uX88dtW0FidLFF1EjoFu0iEDvZkuO2XG7n7uc4x21UmYnzk0mX8ydtXMLuuskTVyclCwS4SgaFcnu8/sYOvP7CZnszoN3k2g9+/YBGfetdpOjtUpoyCXaRIT287zBfvXTfupXMvP62Zz737dE6f31CiyuRkpWAXmaRD6Qz/+xfjT7ucOreOL1x7Bm2tc0tUmZzsFOwiJyifd3787C6+8suNY17TpbE6yaeuOo0bL15CMq47EEnpKNhFTsCWAz18/t/W8fT2w6O2MYMbLlrCZ65uZZYOXZQy
ULCLTEAmm+ObD23lW+1bGMqNftuiNy1q5MvvO5tzFzWVsDqRN1Kwi4zjme2H+ezdL455k+iGqgQ3/97p3HDREuK6nouUmYJdZBQ9A0Pc/quN/POTO8dsd/35C/n8tWcwR8ejyzShYBc5jgfX7+cL965jX/fAqG2Wza7h1vefw2WnzilhZSLjU7CLHCOXh0/86Hl+9sKeUdskYsafXL6CT1yxiqpkvITViUxMJMFuZtcAfwvEgX9w99ui6FekVNydrnSGbUdz9IwR6uctbuK2D5yjk4xkWis62M0sDnwTuAroBJ4xs3939/XF9i1SCntS/Xzh3nVsOZAetU1NRZzPXN3KH12yTF+OyrRn7qMfujWhDswuAf7S3a8uPP8cgLt/ZbT31NTU+OrVq4va7nSWSqVoagr3cLeQxre/e4Cdh/vI5Z3BA68CUDF3xRvaNNVUsHxOLZWJmX+SUUif3fGEPr7f/va3a939wvHaRTEVsxDYdczzTuDi321kZmuANQDJZJJUKhXBpqenXC6n8U1zgznY15unLzv6jk3cYG5NjMbKHP3pbsa+ovrMEMJnN5bQxzdRUQT78f5f+h/+trj7HcAdAK2trd7R0RHBpqen9vZ22trayl3GlJnJ4xvK5fn7R17l6w++QkM2z7Ez5ft++FkA5t94G9ee08KX3nMmc+vDuovRTP7sJiL08ZlNbBowimDvBBYf83wRMPq3TyJl8mJnipvvfokNe7tHbZOIGd/+8Ju55uz5JaxMJFpRBPszwCozWw7sBm4AboygX5FI9Gay/M0Dm7nrsW3kx/hKaW59JY2JrEJdZryivw1y9yzwceDXwAbgx+7+crH9ikThNxv3866/eZh/eHT0UF8yq4Yf/NeLWdFcR1wHvEgAIjmO3d1/Afwiir5EorC/e4C//tl6fv7S3lHbxAxuumw5n3rXadRU6Fw9CYd+myUo2cIt6r72wGbSY9yi7oyWBm7/wDm6CqMEScEuwVi74whfvHcd68f4crQyEeNP37mKNW9foZtfSLAU7DLjHezJcPuvNvKva8e+Rd3bVs3hf73vbJbOri1RZSLloWCXGWswm+f7T2znbx98hZ4xpl1m11bwxevO5L3nLZjwccAiM5mCXWYcd+c3Gw9w68838GrX6De/MIMPrl7CzVefTmNNsoQVipSXgl1mlPV7uvnKLzfwyCtdY7Y7s6WBW99/NucvOaVElYlMHwp2mRF2p/r5P/dv4p7ndzPWdesaq5N8+upWblytW9TJyUvBLtPa4d5BvtW+he89sYPBbH7UdmZww0WL+czVpzOrtqKEFYpMPwp2mZaO9g9x5yOvcuej2+gdzI3ZdvXyWfzFdWdy9sLGElUnMr0p2GVaSfUNctdj2/nuY9voHhj9SBeAxbOqueXdZ3D1WfN1tIvIMRTsMi0c7Mlw12Pb+Kcndox5xijAKTVJPnHFKj70liVUJnTPUZHfpWCXsnr1YJq/f2Qbdz/XOeYcOkBVMsZHL1vOf29bSUOVDl8UGY2CXUrO3Xli6yH+8fHtPLhh/5hHuQBUxGPcePES/sc7VgZ34wuRqaBgl5JJZ7L8tGM333t8O5v3j37j6BHJuPH7b17Ex69YxcKm6hJUKBIGBbtMKXfnxc6j/MszO/lpxx76xjnCBaAiEeOGixbzsctXskCBLnLCFOwyJfZ3D3Dv87u5+7nOCe2dA9RXJfjQxUu56bJlzG3QlIvIZCnYJTKpvkF+tW4f9724l8e3do15G7pjtTRWcdNly7lh9WLq9aWoSNEU7FKUAz0D/L8NB/j1y/t49JUushNNc4ZPLPropcu46sx5JHRtdJHIKNjlhOTzzvajOb7xm1f4zcYDPL8rNe5RLceqq0zwvvMXcOPqpZy5oGHqChU5iSnYZVx7Uv08vvUQj2/p4uFXuuhKZ4DNJ9THm5eewh9euJjr3tSi+4uKTDH9DZM3cHe2H+rjmW2HeWb78LL9UN+k+lp0SjXXX7CI689fyLI5
umuRSKko2E9yB3oGeHl3Ny90pujYleKFXSmO9A1Nur/5DVVce24L153bwnmLm3QNF5EyULCfJAaGcmzr6mXTvh427uth075uXt7TzYGeTNF9r2yu5aoz5/Ous+Zx3qImYroOukhZKdgDks3l2Xt0gB2H+th+qJftXb1s6+ply8E0uw73Tfjww/EkYnDJyjm0tc6lrbWZlc110XQsIpFQsM8Q7s6RviH2dw+wr3uAfUcH2JvqZ3dqgD2pfjpTfexNDZzQ4YYTZQZnL2jk0pWzufTUOQzsXMfVV14c+XZEJBoK9jLJZHMc7R+iuz9Lqm+QI31DHOkb5EjvIId7BzlUeDzYk6ErneFQepDB3NhXP4xKZSLGuYsauWjZLC5aNosLlp5CY/XrJw6179FUi8h0pmAfg7uTyeYZzOXJDOUZGMqRyeYYKKz3D+XoH3z9sW8wR99glg2vDPLQ0XWkMzl6M1nSmSw9mSw9A0P0DAw/DgyVJqTHk4wbq+bWc9aCBt60uInzFjfROr+epE4YEpmxigp2M/sD4C+BM4DV7v7sRN6XHnL+dW0neXfcnbyDO+Tdh5f88GuvPXfI5Ydfz/nrj7k85PJ5cvnhttnCei6fJ5t3cnknm3eyuTzZnDP0hvXCY244uIdyeQazeYZyzmB2eL2oPeRtOyb/3ikyv6GKVfPqaJ1XT+v8es5oaWDVvDrdrEIkMMXusa8Drge+cyJv6up3Pv2TF4rctBzPrNoKFp9SzbI5tSydXcuy2TWsbK5j5dw66ir1HzSRk0FRf9PdfQOgY5VLpLYizrzGKubVVzGvoZKWpmoWNFbR0ljNolnVLDqlRuEtIqWbYzezNcAagIr5p5Zqs9NWzKAmAbVJozZp1CWN2iTUVRj1FUbDMY+NlcOPVYmRf0AHCstRyAAHYN8B2Fei2tPpNO3t7SXaWumkUilyuVyQYxsR6mc3IvTxTdS4wW5mDwLzj/OjW9z9pxPdkLvfAdwBUNmyKvpj8qZIMm5UJeJUJmNUJuJUJmJUJeNUJYcfayriVCXjVBfWayoT7N+9k7NaV1FfmaC2MkFtZZz6qgT1VUnqKhPUVyWoq0zM2P/ptLe309bWVu4yItfU1EQqlQpybCNC/exGhD6+iRo32N39yqg3Wps0rr9gITEzDIjHDDPDDOJmxGx4eidmRjwGsVhhvfCzWKywHjMSMSN+zBIzIxk34rHYaz879nkyHiMRH34tGY8VluH1ikTstdcqEzEq4rFJnUXZ3r6Ptrcuj/qPTURkQsoyIdtcbXztP51Xjk2LiASvqIOVzez9ZtYJXAL83Mx+HU1ZIiIyWcUeFXMPcE9EtYiISAR0eqGISGAU7CIigVGwi4gERsEuIhIYBbuISGAU7CIigVGwi4gERsEuIhIYBbuISGAU7CIigVGwi4gERsEuIhIYBbuISGAU7CIigVGwi4gERsEuIhIYBbuISGAU7CIigVGwi4gERsEuIhIYBbuISGAU7CIigVGwi4gERsEuIhIYBbuISGAU7CIigVGwi4gEpqhgN7OvmtlGM3vRzO4xs6aoChMRkckpdo/9AeBsdz8X2Ax8rviSRESkGEUFu7vf7+7ZwtMngUXFlyQiIsWIco79JuCXEfYnIiKTkBivgZk9CMw/zo9ucfefFtrcAmSBH4zRzxpgDUBzczPt7e2TqXdGSKfTGt8MlEqlyOVyQY5tRKif3YjQxzdR5u7FdWD2EeBjwDvdvW8i72ltbfVNmzYVtd3prL29nba2tnKXMWVCHV9bWxupVIqOjo5ylzJlQv3sRoQ+PjNb6+4Xjtdu3D32cTZyDXAzcPlEQ11ERKZWsXPs3wDqgQfMrMPMvh1BTSIiUoSi9tjd/dSoChERkWjozFMRkcAo2EVEAqNgFxEJjIJdRCQwCnYRkcAo2EVEAqNgFxEJjIJdRCQwCnYRkcAo2EVEAqNgFxEJjIJdRCQwCnYRkcAo2EVEAqNgFxEJjIJdRCQwCnYRkcAo2EVE
AqNgFxEJjIJdRCQwCnYRkcAo2EVEAqNgFxEJjIJdRCQwCnYRkcAo2EVEAqNgFxEJjIJdRCQwRQW7mX3ZzF40sw4zu9/MFkRVmIiITE6xe+xfdfdz3f084D7gLyKoSUREilBUsLt79zFPawEvrhwRESlWotgOzOxW4I+Ao8A7xmi3BlhTeJoxs3XFbnsamwN0lbuIKRTy+OaYWahjg7A/Owh/fK0TaWTuY+9km9mDwPzj/OgWd//pMe0+B1S5+5fG3ajZs+5+4UQKnIk0vpkr5LGBxjfTTXR84+6xu/uVE9zmD4GfA+MGu4iITJ1ij4pZdczT9wAbiytHRESKVewc+21m1grkgR3Axyb4vjuK3O50p/HNXCGPDTS+mW5C4xt3jl1ERGYWnXkqIhIYBbuISGDKFuwhX47AzL5qZhsL47vHzJrKXVOUzOwPzOxlM8ubWTCHlpnZNWa2ycy2mNlny11PlMzsLjM7EOr5I2a22MweMrMNhd/NT5a7pqiYWZWZPW1mLxTG9lfjvqdcc+xm1jBy5qqZ/SlwprtP9MvXac3M3gX8xt2zZnY7gLvfXOayImNmZzD8hfl3gE+7+7NlLqloZhYHNgNXAZ3AM8AH3X19WQuLiJm9HUgD33f3s8tdT9TMrAVocffnzKweWAu8L4TPz8wMqHX3tJklgUeBT7r7k6O9p2x77CFfjsDd73f3bOHpk8CictYTNXff4O6byl1HxFYDW9z9VXcfBP4FeG+Za4qMuz8MHC53HVPF3fe6+3OF9R5gA7CwvFVFw4elC0+ThWXMvCzrHLuZ3Wpmu4APEe4FxG4CflnuImRcC4FdxzzvJJBgONmY2TLgfOCp8lYSHTOLm1kHcAB4wN3HHNuUBruZPWhm646zvBfA3W9x98XAD4CPT2UtURtvbIU2twBZhsc3o0xkfIGx47wWzP8iTxZmVgfcDfzZ78wKzGjunitcRXcRsNrMxpxOK/oiYOMUE+zlCMYbm5l9BLgOeKfPwJMFTuCzC0UnsPiY54uAPWWqRSahMP98N/ADd/+3ctczFdw9ZWbtwDXAqF+El/OomGAvR2Bm1wA3A+9x975y1yMT8gywysyWm1kFcAPw72WuSSao8AXjncAGd/9aueuJkpk1jxxZZ2bVwJWMk5flPCrmboYvQfna5QjcfXdZiomYmW0BKoFDhZeeDOWIHwAzez/wf4FmIAV0uPvV5a2qeGb2buDrQBy4y91vLXNJkTGzHwFtDF/Wdj/wJXe/s6xFRcjM3go8ArzEcKYAfN7df1G+qqJhZucC32P49zIG/Njd/3rM98zAWQIRERmDzjwVEQmMgl1EJDAKdhGRwCjYRUQCo2AXEQmMgl1EJDAKdhGRwPx/CGwV1xt9T54AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xs = torch.linspace(-3, 3, 100)\n", "ys = torch.tanh(xs)\n", "plot(xs.numpy(), ys.numpy())" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "smjZcbzo4Ers" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXxU9dn38c8v+54QCATCvoU9uBStS407rfsCT61dnnq3tHf34taWqrVWKy50u9v7qVVbe9tqQYtS3LWmLhUUNIGwhB0S9gCTfZ35PX8k9KY6JwnMme3k+3698kqGuXLm+jHD1+OZc64x1lpERMQ7EqLdgIiIuEvBLiLiMQp2ERGPUbCLiHiMgl1ExGMU7CIiHhNysBtj0owx7xpjKowx64wxd7rRmIiInBgT6nnsxhgDZFprG40xycBbwLettSvcaFBERI5PUqgbsF3/ZWjsvpnc/aWrnkREoiTkYAcwxiQCq4HxwK+ttSuD1MwD5gGkpaWdMnLkSDceOiYFAgESErz79oVX11ddXY21Fr0241e8rK8jAAebA7QHnGtSE6EgPYGkY5azadOmWmttQW/bD/lQzL9tzJg8YCnwTWttpVNdcXGxraqqcu1xY01ZWRmlpaXRbiNsvLq+0tJSfD4f5eXl0W4lbLz63B0V6+vzByy/e3Mbi17eRLs/eKqnJCZw40UT+dLZY0lMMP92nzFmtbX21N4ex5U99qOstT5jTBkwG3AMdhGR/mZHbRM3Lqlg9c4jjjXTinJYNHcmE4dkh/RYIQe7MaYA6OgO9XTgAmBhqNsVEfECay2Pr9jJPc9vpKXDH7QmMcHw9XPH883zxpOcGPqhJDf22IcCj3UfZ08AFltrl7uwXRGRuLa3roVbnlrDm5trHWvGFWSyaO5MSkbkufa4bpwVswY4yYVeREQ8wVrL0g92c8eydTS0dgatMQb+48wx3HRxMWnJia4+vqvH2EVE+rvaxjYWLF3LS+v2O9YMH5DOA3NKOH3swLD0oGAXEXHJi5X7WLB0LYea2h1rrps1kgWXTCYrNXzxq2AXEQlRXUsHP1q2jqUf7HasGZydysJrZ3Bu8eCw96NgFxEJwRubDnLLU2vYV9/qWHPlzGH86PKp5GWkRKQnBbuIyAloauvkpy9s4PEVuxxrBmQkc/dV0/nU9KER7EzBLiJy3N7bcZibllSw81CzY80Fk4fw06unU5CdGsHOuijYRUT6qLXDz89e2cRDb27DaRpLdmoSd1w+lWtOLqJr+G3kKdhFRPqgcncd8xeXs2l/o2PNGeMGcv+cEory0iPY2Ucp2EVEetDhD/Cb17fyq79vpjMQfDc9LTmB739yMp87fRQJCdHZSz+Wgl1ExMGWAw3MX1zBmpo6x5qTRubx4JwSxhZkRbCzninYRUQ+JBCwPPr2du57qYr2zuDjdZMTDd+9cCJf+cS4j4zXjTYFu4jIMXYdauampyp4d/thx5rJQ3NYNLeEyUNzIthZ3ynYRUToGtz1xLvV3P3ceprag4/XTTDwtdLxfOv8CaQkxe4nNSnYRaTf21fXyq1Pr+Efmw461owtyOTBOSWcNHJABDs7MQp2Eem3rLUsq9jDbc9UUu8wXhfghjPHcMts98frhouCXUT6pUONbdz2bCXPr93nWFOU1zVe
9+PjwjNeN1wU7CLS77y8bh8/WLqW2kbn8bpzTx3ObZdOITstOYKduUPBLiL9Rn1rB3cuW8/T79c41hRkp3Lv1dM5f/KQCHbmLgW7iPQLb2+p5eYlFeypcx6ve+mModx1xTQGZEZmvG64KNhFxNOa2ztZ+MJGHntnp2NNXkYyd10xjctKhkWws/BRsIuIZ63eeYQbF5ezo4fxuudNGsy9V09ncE5aBDsLLwW7iHhOW6efn7+6md/+YysOc7vISk3i9kunMOfU4VEbrxsuCnYR8ZR1e+q4cXEFG/c1ONacPjaf+68tYUR+RgQ7ixwFu4h4Qqc/wLKt7fztlbfp8AffTU9NSuCW2ZP44hmjY2K8brgo2EUk7m092Mj8xRVUVHc41pSM6BqvO35w7IzXDRcFu4jErUDA8od/7mDhixtp62G87rfPn8BXzxlHUmLsDu5yk4JdROJSzZFmbl6yhne2HXKsmVSYzYNzS5g6LDeCnUVfyMFujBkB/BEoBALAQ9baX4S6XRGRYKy1LF5VzV3LN9DYFnxwV4KBr5wzju9cMIHUpPgY3OUmN/bYO4EbrbXvG2OygdXGmFestetd2LaIyL8cqG/le39dy983HnCsGT0wgwfnlnDKqPwIdhZbQg52a+1eYG/3zw3GmA1AEaBgFxHX/K1iD7c9W4mv2fkN0vNHJvGrL51NRkr/Psrs6uqNMaOBk4CVbm5XRPqvI03t3PZsJcvX7HWsGZabxn3XltC5u7Lfhzq4GOzGmCzgaeA71tr6IPfPA+YBFBQUUFZW5tZDx5zGxkatLw75fD78fr8n13ZUvD135Qc6+f26duraHC4fBc4qSuIzkxLo3F0Zd+sLF2Ot819YnzdiTDKwHHjJWruot/ri4mJbVVUV8uPGqrKyMkpLS6PdRth4dX2lpaX4fD7Ky8uj3UrYxMtz19DawU+Wb+Avq6odawZlpXDPVdO5aGrhv/4sXtZ3oowxq621p/ZW58ZZMQZ4BNjQl1AXEenJP7fWcvOSNez2tTjWfGp6IT+5cjr5cT5eN1zcOBRzJvA5YK0x5uiuzg+stc+7sG0R6SdaO/wsfHEjv397h2NNTloSd105jctLhnlucJeb3Dgr5i1Af8MicsLKq33MX1zOtoNNjjXnTCxg4TUzKMz1znjdcNHbxyISNe2dAX752mZ+U7bFcbxuRkoiP7xkCtfNGqG99D5SsItIVGzYW8/8xRVs2PuRk+j+ZdbofB6YU8LIgd4crxsuCnYRiahOf4CH3tzGz17Z5DheNyUpgVsuLuaGM8d4erxuuCjYRSRittc2MX9xOR/s8jnWTC/KZdHcEiYMyY5gZ96iYBeRsAsELP+zYic/fWEDrR3Bx+smJRi+ed4EvnbuOJL7yXjdcFGwi0hY7fa1cMtTFby9xXm87oTBWSyaO5Ppw/vXeN1wUbCLSFhYa3lqdQ0//tt6GhzG6xoD884ey3cvnEhacv8brxsuCnYRcd3Bhja+/9e1vLphv2PNyPwMHphTwqwx/Xe8brgo2EXEVS+s3cuCZyo53NTuWHP9aSP5wacmk5mqCAoH/a2KiCvqmju4fVklz5bvcawpzElj4bUzOGdiQQQ7638U7CISsrKqA9z69Br217c51lx1UhE/umwquRnJEeysf1Kwi8gJa2zr5J7nN/DnlbscawZmpnD3VdOYPW1oBDvr3xTsInJCVm47xE1PVVB92Hm87kVThnDP1dMZlJUawc5EwS4ix6W1w88DL1XxyNvbcfqcnuy0JO68fCpXnVSkwV1RoGAXkT5bU+Nj/uIKthxodKw5e8IgFl4zg2F56RHsTI6lYBeRXnX4A/zq71v49etb8DvM101PTuQHl0zms6eN1F56lCnYRaRHVfsamL+4nHV7nMfrnjJqAA/OKWH0oMwIdiZOFOwiEpQ/YHnkrW088NIm2v3BB3elJCbw3QsnMu8TY0nUeN2YoWAXkY/YeaiJm5ZU8N6OI441U4flsGjuTIoLNV43
1ijYReRfrLU8vnIX9zy3gZYOf9CaxATD10vH8Y3zJpCSpPG6sUjBLiIA7K1r4Zan1vDm5lrHmnEFmSyaO5OSEXkR7EyOl4JdpJ+z1rL0g93csWwdDa3O43VvOHMMN19crPG6cUDBLtKP1Ta2sWDpWl5a5zxed/iAdB6YU8LpYwdGsDMJhYJdpJ96sXIfC5au5VAP43WvmzWSBZdMJkvjdeOKni2RfqaupYM7l63jrx/sdqwZnJ3KwmtncG7x4Ah2Jm5RsIv0I5W1nXzvZ2+wr77VsebykmH8+Iqp5GWkRLAzcZOCXaQfaG7vGq/7+CrneekDMpL5yZXTuWSGxuvGOwW7iMet2nGYG5dUsPNQs2PNBZMHc8/V0xmcnRbBziRcXAl2Y8yjwKXAAWvtNDe2KSKhae3w87NXN/HQG9scx+tmpSZx+2VTmHPKcA3u8hC39tj/APwX8EeXticiIajcXceNiyuo2t/gWHPGuIHcd+0Mhg/IiGBnEgmuBLu19g1jzGg3tiUiJ67DH+C/y7byy9c20+kwXjctOYHvf3Iynzt9FAka3OVJETvGboyZB8wDKCgooKysLFIPHXGNjY1aXxzy+Xz4/f64XduexgC/W9PG9vrgkxgBRmdbvjozlcL2Hbzxxo7INRchXn1tHq+IBbu19iHgIYDi4mJbWloaqYeOuLKyMrS++JOXl4fP54u7tQUClkff3s59K6po7wwe6smJhu9eOJFJtprzzj03wh1Gjldfm8dLZ8WIxLHqw83cuKSCd7cfdqyZPDSHRXNLmDw0h7Kymgh2J9GiYBeJQ9Zannyvmp8sX09Te/DxugkG/rN0HN8+f6LG6/Yzbp3u+ARQCgwyxtQAd1hrH3Fj2yLy7/bXt3Lr02soqzroWDNmUCYPzi3h5JEDItiZxAq3zoq5zo3tiEjPni3fze3PrqOupcOx5v+eMZpbZ08iPUXjdfsrHYoRiQOHm9q57ZlKnlu717GmKC+d+6+dwRnjB0WwM4lFCnaRGPfq+v18769rqW10nvMy99Th/PDSKeSkJUewM4lVCnaRGFXf2sGP/7aep1Y7n8lSkJ3KvVdP5/zJQyLYmcQ6BbtIDPrnllpufmoNu30tjjWXzBjKT66YxoBMjdeVf6dgF4khLe1+Fr64kT/8c4djTV5GMnddMY3LSoZFrjGJKwp2kRixeucRblpSwfbaJseac4sLWHjNDAbnaLyuOFOwi0RZW6efn7+6md/+YysOc7vITEnktkun8H8+NkLjdaVXCnaRKFq/p575i8vZuM95vO5pY/J5YE4JI/I1Xlf6RsEuEgWd/gC/fWMbP391Ex3+4LvpqUkJ3DJ7El88Y7TG68pxUbCLRNjWg43cuLiC8mqfY03J8FwenDuT8YOzItiZeIWCXSRCAgHLY+/sYOGLG2ntCD5eNynB8K3zJ/C10nEkJWpwl5wYBbtIBNQcaebmJWt4Z9shx5riIdk8OLeEaUW5EexMvEjBLhJG1loWr6rmruUbaGzrDFqTYODLnxjL/AsnkpqkwV0SOgW7SJgcqG/l+39dy2sbDzjWjB6YwYNzSzhlVH4EOxOvU7CLhMHyNXv44TOV+Jqdx+t+/uOj+N4nJ5GRon+G4i69okRcdKSpnduXreNvFXsca4blpnHftSWcNUHjdSU8FOwiLnl94wFueXoNBxucx+tec/Jw7rhc43UlvBTsIiFqaO3g7uc28OR71Y41g7JSuPuq6Vw8tTCCnUl/pWAXCcE7Ww9x05KKHsfrzp5ayN1XTWNgVmoEO5P+TMEucgJaO/zc92IVj7693bEmJy2JH18xjStmDtPgLokoBbvIcSqv9jF/cTnbDjqP1/3ExALuu2YGhbkaryuRp2AX6aP2zgC/+vtmflO2Fb/DfN2MlEQWXDKZz8waqb10iRoFu0gfbNxXz/y/VLB+b71jzazRXeN1Rw7UeF2JLgW7SA/8ActDb2zjZ69sot0ffHBXSlICt1xczBfPHEOixutK
DFCwizjYXtvETUsqWL3ziGPN9KJcFs0tYcKQ7Ah2JtIzBbvIhwQClsdX7uSnz2+kpcMftCYpwfCN88bz9XPHk6zxuhJjFOwix+gIwOcffZe3ttQ61kwYnMWiuTOZPlzjdSU2uRLsxpjZwC+AROBha+29bmxXJJKa2jrZUe+nySHUjYEvn901XjctWeN1JXaFHOzGmETg18CFQA3wnjFmmbV2fajbFomUNzYdZP3eehzeH2VEfjoPzpnJrDEaryuxz4099lnAFmvtNgBjzJPAFYBjsFdXV1NaWurCQ8cmn89HXl5etNsIG6+tr7axja0Hm2jbvxWAfX/+3r/dPyQnDfIzuOX5+D/jxWvP3Yd5fX195UawFwHHTj+qAU77cJExZh4wDyA5ORmfz/mDfOOd3+/X+uJEXZtlb5Pz548OzTRkJnXQUF8X4c7Cw0vPXTBeX19fuRHswXZjPnJZnrX2IeAhgOLiYlteXu7CQ8emsrIyT/8fiVfW98LavXz9z+9T2P1qPbqnXviZe5lUmM1jN8zq2lv3EK88d068vr6+Xs3sRrDXACOOuT0ccP6UAZEY8Obmg3z7yXKCTQaYNSaf333+VHLTNTNd4pMbJ+C+B0wwxowxxqQAnwaWubBdkbB4f9cRvvI/q4NeSZqdYvjjDbMU6hLXQt5jt9Z2GmO+AbxE1+mOj1pr14XcmUgYrKnx8YVH36W5/aMXHuVlpDAkza9TGSXuuXIeu7X2eeB5N7YlEi6Vu+v47MMraWjt/Mh9Hxs9gP1Dsqiv88abpNK/6Vpo6RfW7anj+odXUh8k1KcMzeHhL3yMBI3ZFY9QsIvnHd1Tr2vp+Mh94woyeUzH1MVjNCtGPG3VjsN88ffv0dD20T31sYMyeeLLp1OQrc8iFW9RsItnvbW5li//cVXQCY1jBmXyxLzTGeyx89RFQMEuHvXSun1884kPaO/86CmNowZm8MSXT/fcxUciRynYxXMeX7GT25+tDHrx0YTBWTz+pdMU6uJpCnbxDGstP3tlE7/8+5ag908dlsMfb5jFwCwdUxdvU7CLJ7R3BvjhM2tZvKom6P0nj8zj91/U2S/SPyjYJe75mtv56uOrWbHtcND7z5lYwG+uP5nMVL3cpX/QK13i2raDjfzHY6vYXtsU9P5rTh7OvddM1+eSSr+iYJe49Y9NB/nWEx8EvfAI4OvnjuOmi4r7POpUxCsU7BJ3rLX8pmwrD7xchQ1y5ktiguHOy6fy2dNHRb45kRigYJe40tjWyU2LK3hx3b6g92enJfHf15/CWRMGRbgzkdihYJe4Ubm7jm/8+X12HGoOev+ogRk88oWPMX5wVoQ7E4ktCnaJedZaHl+5i7uWrw96JSnA2RMG8ctPn8SAzJQIdycSexTsEtN8ze0sWFrJc2v3Otb8Z2nXm6SJCXqTVAQU7BLD3tpcy01LKthX3xr0/oyURB6YU8Knpg+NcGcisU3BLjGntcPP/S9V8chb2x1rJg/N4defOYmxBTqeLvJhCnaJKat3HubmJWvY5nDBEcD1p43ktkun6LNJRRwo2CUmtLT7eeDlKh59e3vQc9MBctKSuOfq6Vw6Y1hkmxOJMwp2ibrXqw5w2zOV1Bxpcaz5+NiBPDi3hGF56RHsTCQ+Kdglag7Ut3Ln8vU8t8b5jJeUxARuvGgiXz57LAk660WkTxTsEnHtnQH+8M/t/PK1LTQG+SzSo0pG5HH/tTOYOCQ7gt2JxD8Fu0RUWdUBfrx8PdsOOr85mpKUwPwLJ/Kls8aQpKmMIsdNwS4RsWFvPT99YSNvbDrYY90Z4wZy15XTGKfTGEVOmIJdwmpvXQuLXt7EU+/XOJ7tApCfmcIPL5nMVScVacyuSIgU7BIWtY1t/Ob1rTy+cqfjfBeABAPXnzaK+RdO1JwXEZeEFOzGmDnAj4DJwCxr7So3mpL4dbipnYff3MYf/rmD5nZ/j7WzxuTzo8umMmVYToS6E+kfQt1jrwSuBn7rQi8Sxw40
tPLwm9t5fMXOXgN9RH46t86exCXTh+qwi0gYhBTs1toNgP5x9mM7apt4+K1tLFlVQ1sPh1wActOT+eZ54/ncx0eRmqRxACLhErFj7MaYecA8gIKCAsrKyiL10BHX2Njo+fU9/MxrvLi9g9X7/fTwnigAqYlw4ahkPjkmmUz/Lt55a1dE+jxePp8Pv9/v+edO6/O+XoPdGPMqUBjkrgXW2mf7+kDW2oeAhwCKi4ttaWlpX3817pSVleHF9bV3Bnihci+/eGcN2+qCj9I9VkpSAp89bRRfO3ccg7JSI9BhaPLy8vD5fJ587o7y6mvzKK+vr696DXZr7QWRaERi125fC0++u4sn36vmYENbr/VpyQl8ZtYovnLOWIbkpEWgQxE5lk53lKA6/AFe33iAv7xXzetVBwj0drwFyEpN4vrTRvKls8dSkB37e+giXhXq6Y5XAb8CCoDnjDHl1tqLXelMomLz/gaeWl3D0+/vprax971zgCE5qdxw5hiuO20kOWnJYe5QRHoT6lkxS4GlLvUiUXKgvpVlFXtY+sFu1u2p7/PvTSvK4YtnjOGykmGkJGmmi0is0KGYfupQYxsvVO5j+Zo9rNx+uMfL/Y+VlGCYWZDArVfN4tRRA3Sqq0gMUrD3I3vrWnh53X5eWrePldsP4+/LgfNuRXnpXDdrBHNPHcH691fwsdH5YexUREKhYPcway3r99bz2oYDvLZhPxU1dcf1+8mJhoumFDLn1OGcPaGAxO4PulgfjmZFxDUKdo9paO3g7S21/GPTQf5RdZA9fTjf/MNKhudyxcwirjypiHwN5hKJOwr2ONfhD1BR7eOtLbW8tbmW8mofncdxiOWoUQMzuLxkGFfMLGL8YM1CF4lnCvY4094ZoHJPHSu3HeadbYdYteNwr0O3nAwfkM4lM4Zy2YxhTB2WozdCRTxCwR7j6po7eL/6CB/sPMK7Ow5TXu2jtaPnYVs9mTA4i9nTCrl4aqHCXMSjFOwxpL0zwMZ99VTU1FFR7eODXUfY2sNng/ZFUoLhtLH5nDdpCBdMHsyogZkudSsisUrBHiUt7X6q9jewbk8dlbvrqdxdR9W+Btr9J743flRRXjrnFBdwzsQCzhw/iKxUPc0i/Yn+xYdZIGCpOdLCxn31bNrfwMZ9DWzYW8/22qY+zV/pi9z0ZE4fm89ZEwo4a/wgRg/M0CEWkX5Mwe6S1g4/Ow81s+1gI69sbeeZfR+w+UAjWw82hnRMPJj8zBROGTWA08cO5ONjBzKpMJuEBAW5iHRRsB+HlnY/uw43s/NQU/f3ZrbXNrG9tok9dS0fuix/j2uPO7Ygk5NHDuCUUQP42Oh8xhVkao9cRBwp2I9R39rBXl8re3wt1Pha2H2khd2+FqoPN1NzpJnaxvaw9zAoK4WS4XnMGJ7HjBG5nDQij7wMXSQkIn3XL4K9vTPAoaY2DtS3sb++lf0NbRyob2VfXSv7ur/vrWulsa0zon0V5aUzeWgOU4blML0ol+lFuQzJSdXeuIiEJC6D3VpLfWsnR5raOdTUzuGmdg43tVHb2M6hxnZqG9uO+eq6P5qy05KYMDiL4sIcJhVmM3FINlOG5pCbodnlIuK+qAW7tZa2zgD1rR3Ut3R2f++gruV/v9e1dOBr7sDX0oGvuZ0jzV3ffc0dJ3TZfLgV5qQxbnAmqe11nF0ykQmDs5kwJIvB2doLF5HIiUqw72oIMPGHL9Dhj71w7kmCgaG56YzMz2DUwAxGDcxk1MAMxgzq+p6R0vXXWVZWRumZY6LcrYj0V1EJ9oAlJkM9JTGBwtw0huamUTQgneF56RQNSKcoL4MR+ekMzU3XJwWJSMyLy2PsJ2JgZgoF2akMzkljSHYqQ3LSGJLT9X1objpDclMZlJmq88FFJO7FbbCnJyeSn5lCfmYKAzJTGNT9c35WCoOyUinISmVQViqDsrtuJydqT1tE+oeoBntyoiE7LZnc9GRy0pL+9+f0JHLSu34ekJFCXvfPeRkp
DMjs+rO05MRoti4iErOiEuwjshOovGs2qUkJOltERMRlUQn2RIP2uEVEwkQHnkVEPEbBLiLiMQp2ERGPUbCLiHhMSMFujLnfGLPRGLPGGLPUGJPnVmMiInJiQt1jfwWYZq2dAWwCvh96SyIiEoqQgt1a+7K19ugQ8xXA8NBbEhGRULh5jP0G4AUXtyciIieg1wuUjDGvAoVB7lpgrX22u2YB0An8qYftzAPmARQUFFBWVnYi/caFxsZGrS8O+Xw+/H6/J9d2lFefu6O8vr6+MtaGNj7XGPMF4KvA+dba5r78TnFxsa2qqgrpcWNZWVkZpaWl0W4jbLy6vtLSUnw+H+Xl5dFuJWy8+twd5fX1GWNWW2tP7a0upJECxpjZwK3AOX0NdRERCa9Qj7H/F5ANvGKMKTfG/D8XehIRkRCEtMdurR3vViMiIuIOXXkqIuIxCnYREY9RsIuIeIyCXUTEYxTsIiIeo2AXEfEYBbuIiMco2EVEPEbBLiLiMQp2ERGPUbCLiHiMgl1ExGMU7CIiHqNgFxHxGAW7iIjHKNhFRDxGwS4i4jEKdhERj1Gwi4h4jIJdRMRjFOwiIh6jYBcR8RgFu4iIxyjYRUQ8RsEuIuIxCnYREY9RsIuIeIyCXUTEY0IKdmPMXcaYNcaYcmPMy8aYYW41JiIiJybUPfb7rbUzrLUzgeXA7S70JCIiIQgp2K219cfczARsaO2IiEiokkLdgDHmbuDzQB1wbg9184B53TfbjDGVoT52DBsE1Ea7iTDy8voGGWO8ujbw9nMH3l9fcV+KjLU972QbY14FCoPctcBa++wxdd8H0qy1d/T6oMasstae2pcG45HWF7+8vDbQ+uJdX9fX6x67tfaCPj7mn4HngF6DXUREwifUs2ImHHPzcmBjaO2IiEioQj3Gfq8xphgIADuBr/bx9x4K8XFjndYXv7y8NtD64l2f1tfrMXYREYkvuvJURMRjFOwiIh4TtWD38jgCY8z9xpiN3etbaozJi3ZPbjLGzDHGrDPGBIwxnjm1zBgz2xhTZYzZYoz5XrT7cZMx5lFjzAGvXj9ijBlhjHndGLOh+7X57Wj35BZjTJox5l1jTEX32u7s9XeidYzdGJNz9MpVY8y3gCnW2r6++RrTjDEXAX+31nYaYxYCWGtvjXJbrjHGTKbrDfPfAjdZa1dFuaWQGWMSgU3AhUAN8B5wnbV2fVQbc4kx5hNAI/BHa+20aPfjNmPMUGCotfZ9Y0w2sBq40gvPnzHGAJnW2kZjTDLwFvBta+0Kp9+J2h67l8cRWGtfttZ2dt9cAQyPZj9us9ZusNZWRbsPl80Ctlhrt1lr24EngSui3JNrrLVvAIej3Ue4WGv3Wmvf7/65AdgAFEW3K3fYLo3dN5O7v3rMy6geYzfG3G2MqQaux7sDxG4AXoh2E9KrIqD6mNs1eHnZTv4AAAGfSURBVCQY+htjzGjgJGBldDtxjzEm0RhTDhwAXrHW9ri2sAa7MeZVY0xlkK8rAKy1C6y1I4A/Ad8IZy9u621t3TULgE661hdX+rI+jzFB/swz/xfZXxhjsoCnge986KhAXLPW+run6A4HZhljejycFvIQsF6a8ew4gt7WZoz5AnApcL6Nw4sFjuO584oaYMQxt4cDe6LUi5yA7uPPTwN/stb+Ndr9hIO11meMKQNmA45vhEfzrBjPjiMwxswGbgUut9Y2R7sf6ZP3gAnGmDHGmBTg08CyKPckfdT9BuMjwAZr7aJo9+MmY0zB0TPrjDHpwAX0kpfRPCvmabpGUP5rHIG1dndUmnGZMWYLkAoc6v6jFV454wfAGHMV8CugAPAB5dbai6PbVeiMMZ8Cfg4kAo9aa++OckuuMcY8AZTSNdZ2P3CHtfaRqDblImPMWcCbwFq6MgXgB9ba56PXlTuMMTOAx+h6XSYAi621P+7xd+LwKIGIiPRAV56KiHiMgl1ExGMU7CIiHqNgFxHxGAW7iIjHKNhFRDxGwS4i4jH/
HxzAJhId/onEAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xs = torch.linspace(-3, 3, 100)\n", "ys = torch.selu(xs)\n", "plot(xs.numpy(), ys.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "X97sJV2bEqT2" }, "source": [ "## Automatic differentiation" ] }, { "cell_type": "markdown", "metadata": { "id": "-1wKYOcNc_I-" }, "source": [ "Given some loss function\n", "$$L(\\vec x, \\vec y) = ||2 \\vec x + \\vec y||_2^2$$\n", "we want to evaluate\n", "$$\\frac{\\partial L}{\\partial \\vec x}$$\n", "and\n", "$$\\frac{\\partial L}{\\partial \\vec y}$$" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "ceo_JWyUD0Py" }, "outputs": [], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "5f0cycqiBt0h" }, "source": [ "PyTorch makes this easy by having tensors keep track of their data..." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "GroZ0prJgAZ0" }, "outputs": [ { "data": { "text/plain": [ "tensor([1., 2., 3.])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.data" ] }, { "cell_type": "markdown", "metadata": { "id": "5Ozf4qtnByLZ" }, "source": [ "...and their gradient:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "_9-FwL4wgCDG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "print(x.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "sdnlnBFjB0KT" }, "source": [ "However, right now `x` has no gradient because it does not know what loss it must be differentiated with respect to.\n", "Below, we define the loss." 
] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "GN2CeNSOCZw9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(83., grad_fn=<SumBackward0>)\n" ] } ], "source": [ "loss = ((2 * x + y)**2).sum()\n", "print(loss)" ] }, { "cell_type": "markdown", "metadata": { "id": "OFfIwK0jB8T3" }, "source": [ "And we perform back-propagation by calling `backward` on it." ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "UIr4j_hrgN7n" }, "outputs": [], "source": [ "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "sfEKcV4YB-s9" }, "source": [ "Now we see that the gradients are populated!" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "yUSm8_r5gOhi" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([12., 20., 28.])\n", "tensor([ 6., 10., 14.])\n" ] } ], "source": [ "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "rUnvZN9Jg0mg" }, "source": [ "### gradients accumulate\n", "Gradients accumulate, so if you call `backward` twice..." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "qwAZ8fyMgQ1z" }, "outputs": [], "source": [ "loss = ((2 * x + y)**2).sum()\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ndyknuW_CEGb" }, "source": [ "...you'll get twice the gradient." ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "5bT1AWRngYVJ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([24., 40., 56.])\n", "tensor([12., 20., 28.])\n" ] } ], "source": [ "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "WJ3rei78gyxd" }, "source": [ "### multiple losses" ] }, { "cell_type": "markdown", "metadata": { "id": "atVB4ihsCH7s" }, "source": [ "This accumulation makes it easy to add gradients from different losses, which might not even use the same parameters. 
For example, this loss is only a function of `x`..." ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "n1Lf5NOcgfN3" }, "outputs": [], "source": [ "other_loss = (x**2).sum()\n", "other_loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "j_TI2Db4CQ5N" }, "source": [ "...and so only `x.grad` changes." ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "4NX0Oj1Vgjia" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([26., 44., 62.])\n", "tensor([12., 20., 28.])\n" ] } ], "source": [ "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "_tVQv9nCgpxk" }, "source": [ "### stopping and starting gradients" ] }, { "cell_type": "markdown", "metadata": { "id": "cHO0aQioCXrC" }, "source": [ "If you don't specify `requires_grad=True` when creating a tensor, its gradient will always be `None`." ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 54 }, "id": "ueSASTvNg6mP", "outputId": "3829650d-f4fc-4173-f97a-e27de7f4e262" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([12., 20., 28.])\n", "None\n" ] } ], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape)\n", "loss = ((2 * x + y)**2).sum()\n", "loss.backward()\n", "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "_o-HON3lCbjS" }, "source": [ "You can also turn `requires_grad` on after initializing a tensor." 
] }, { "cell_type": "code", "execution_count": 38, "metadata": { "id": "tOUN_k0hhFU5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([12., 20., 28.])\n", "tensor([ 6., 10., 14.])\n" ] } ], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape)\n", "y.requires_grad = True\n", "loss = ((2 * x + y)**2).sum()\n", "loss.backward()\n", "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "yHfgIecCCeyq" }, "source": [ "You can cut a gradient by calling `y.detach()`, which returns a new tensor with `requires_grad=False`. Note that `detach` is not an in-place operation: `y` itself still requires a gradient, which is why `y.grad` is populated below!" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "id": "_bQbj4u8hL3-", "outputId": "eaa75c2f-8258-4620-e09c-95658546531a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([12., 20., 28.])\n", "tensor([ 6., 10., 14.])\n" ] } ], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape, requires_grad=True)\n", "y_detached = y.detach()\n", "loss = ((2 * x + y)**2).sum()\n", "loss.backward()\n", "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([1., 1., 1.])\n" ] } ], "source": [ "print(y_detached)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "id": "RD-zF6Xna2GJ", "outputId": "60591077-350b-4112-bbde-3cf902c1251d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 40., 72., 104.])\n", "tensor([10., 18., 26.])\n" ] } ], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape, requires_grad=True)\n", "z = 2 * x\n", "z.required_grad = True\n", 
"loss = ((2 * z + y)**2).sum()\n", "loss.backward()\n", "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 71 }, "id": "YsxDGYysafqT", "outputId": "8d5f0ede-82ac-445e-d007-0ec1854fa5d8" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\kcsgo\\anaconda3\\lib\\site-packages\\torch\\tensor.py:746: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.\n", " warnings.warn(\"The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad \"\n" ] } ], "source": [ "z.grad" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "id": "AkXdmy9Fah_1", "outputId": "aae3245b-3d60-4eab-f8b6-af7767ad459b" }, "outputs": [ { "data": { "text/plain": [ "tensor([1, 2], dtype=torch.int32)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.from_numpy(np.array([1,2]))" ] }, { "cell_type": "markdown", "metadata": { "id": "g0nGMph3amPa" }, "source": [ "Any difference between p.data.add_(-0.001 + p.grad) and p.data+= -0.001 + p.grad?" 
] }, { "cell_type": "markdown", "metadata": { "id": "q8B8phm0hdrJ" }, "source": [ "## Modules" ] }, { "cell_type": "markdown", "metadata": { "id": "lLYr9XTskvut" }, "source": [ "`nn.Module`s represent the building blocks of a computation graph.\n", "For example, in typical PyTorch code, each convolution block is its own module, each fully connected block is a module, and the whole network itself is also a module.\n", "Modules can contain other modules.\n", "Most classes in `torch.nn`, such as `nn.Linear`, are subclasses of `nn.Module`.\n", "Below is an example definition of a module:" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "id": "vwmkBLL0dQiw" }, "outputs": [], "source": [ "import pdb" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "id": "LhkD9c-Uheqd" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "class Net(nn.Module):\n", "\n", "    def __init__(self, input_size, output_size):\n", "        super(Net, self).__init__()\n", "        self.fc1 = nn.Linear(input_size, 32)\n", "        self.fc2 = nn.Linear(32, 32)\n", "        self.fc3 = nn.Linear(32, output_size)\n", "\n", "    def forward(self, x):\n", "        pdb.set_trace()\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return x\n", "\n", "    def hook(self, gradient):\n", "        return 2*gradient" ] }, { "cell_type": "markdown", "metadata": { "id": "ulXK0yY-DCwt" }, "source": [ "The main function that you need to implement is the `forward` function.\n", "Otherwise, it's a normal Python object:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 104 }, "id": "CvcCU1s1itC3", "outputId": "06896f68-b7da-4ebe-cf8d-04650f2d6c7d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Net(\n", " (fc1): Linear(in_features=1, out_features=32, bias=True)\n", " (fc2): Linear(in_features=32, out_features=32, 
bias=True)\n", " (fc3): Linear(in_features=32, out_features=1, bias=True)\n", ")\n" ] } ], "source": [ "net = Net(input_size=1, output_size=1)\n", "print(net)" ] }, { "cell_type": "markdown", "metadata": { "id": "RaYLDhamEMKw" }, "source": [ "Here we create some dummy input. The first dimension will be the batch dimension." ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "id": "H2AGmUDrhy7i", "outputId": "da201e8d-b17c-4a57-f03a-0e7b75887fed" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 1])\n" ] } ], "source": [ "x = torch.linspace(-5, 5, 100).view(100, 1)\n", "print(x.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "vXIho7F8D9_x" }, "source": [ "To evaluate a neural network on some input, you pass an input through a module by calling it directly. In particular, don't call `net.forward(x)`." ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "JN4mWZlkiEmZ", "outputId": "6706faf8-dfd7-4a8e-8507-ed47d36a742e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "> (16)forward()\n", "-> x = F.relu(self.fc1(x))\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) print(x)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-5.0000],\n", " [-4.8990],\n", " [-4.7980],\n", " [-4.6970],\n", " [-4.5960],\n", " [-4.4949],\n", " [-4.3939],\n", " [-4.2929],\n", " [-4.1919],\n", " [-4.0909],\n", " [-3.9899],\n", " [-3.8889],\n", " [-3.7879],\n", " [-3.6869],\n", " [-3.5859],\n", " [-3.4848],\n", " [-3.3838],\n", " [-3.2828],\n", " [-3.1818],\n", " [-3.0808],\n", " [-2.9798],\n", " [-2.8788],\n", " [-2.7778],\n", " [-2.6768],\n", " [-2.5758],\n", " [-2.4747],\n", " [-2.3737],\n", " [-2.2727],\n", " [-2.1717],\n", " [-2.0707],\n", " [-1.9697],\n", " [-1.8687],\n", " [-1.7677],\n", " [-1.6667],\n", " 
[-1.5657],\n", " [-1.4646],\n", " [-1.3636],\n", " [-1.2626],\n", " [-1.1616],\n", " [-1.0606],\n", " [-0.9596],\n", " [-0.8586],\n", " [-0.7576],\n", " [-0.6566],\n", " [-0.5556],\n", " [-0.4545],\n", " [-0.3535],\n", " [-0.2525],\n", " [-0.1515],\n", " [-0.0505],\n", " [ 0.0505],\n", " [ 0.1515],\n", " [ 0.2525],\n", " [ 0.3535],\n", " [ 0.4545],\n", " [ 0.5556],\n", " [ 0.6566],\n", " [ 0.7576],\n", " [ 0.8586],\n", " [ 0.9596],\n", " [ 1.0606],\n", " [ 1.1616],\n", " [ 1.2626],\n", " [ 1.3636],\n", " [ 1.4646],\n", " [ 1.5657],\n", " [ 1.6667],\n", " [ 1.7677],\n", " [ 1.8687],\n", " [ 1.9697],\n", " [ 2.0707],\n", " [ 2.1717],\n", " [ 2.2727],\n", " [ 2.3737],\n", " [ 2.4747],\n", " [ 2.5758],\n", " [ 2.6768],\n", " [ 2.7778],\n", " [ 2.8788],\n", " [ 2.9798],\n", " [ 3.0808],\n", " [ 3.1818],\n", " [ 3.2828],\n", " [ 3.3838],\n", " [ 3.4848],\n", " [ 3.5859],\n", " [ 3.6869],\n", " [ 3.7879],\n", " [ 3.8889],\n", " [ 3.9899],\n", " [ 4.0909],\n", " [ 4.1919],\n", " [ 4.2929],\n", " [ 4.3939],\n", " [ 4.4949],\n", " [ 4.5960],\n", " [ 4.6970],\n", " [ 4.7980],\n", " [ 4.8990],\n", " [ 5.0000]])\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) l\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 11 \t self.fc2 = nn.Linear(32, 32)\n", " 12 \t self.fc3 = nn.Linear(32, output_size)\n", " 13 \t\n", " 14 \t def forward(self, x):\n", " 15 \t pdb.set_trace()\n", " 16 ->\t x = F.relu(self.fc1(x))\n", " 17 \t x = F.relu(self.fc2(x))\n", " 18 \t x = self.fc3(x)\n", " 19 \t return x\n", " 20 \t\n", " 21 \t def hook(self, gradient):\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) x.shape\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 1])\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) n\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "> (17)forward()\n", "-> x = F.relu(self.fc2(x))\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) 
n\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "> (18)forward()\n", "-> x = self.fc3(x)\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) p x\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor([[1.8007, 0.0000, 0.0000, ..., 0.0000, 0.0000, 1.2302],\n", " [1.7679, 0.0000, 0.0000, ..., 0.0000, 0.0000, 1.2086],\n", " [1.7351, 0.0000, 0.0000, ..., 0.0000, 0.0000, 1.1871],\n", " ...,\n", " [0.0000, 0.0000, 1.0427, ..., 0.0000, 0.0000, 1.0551],\n", " [0.0000, 0.0000, 1.0706, ..., 0.0000, 0.0000, 1.0787],\n", " [0.0000, 0.0000, 1.0985, ..., 0.0000, 0.0000, 1.1023]],\n", " grad_fn=<ReluBackward0>)\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) print(x.shape)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 32])\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) c\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 1])\n" ] } ], "source": [ "y = net(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "IHO5oRs1EQxz" }, "source": [ "Let's visualize what the network looks like." 
] }, { "cell_type": "code", "execution_count": 50, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 269 }, "id": "vJu-Tt2jiF2y", "outputId": "0c88791c-975d-4f30-c7e9-dfa2ebafee05" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAf/UlEQVR4nO3dfXRc9Z3f8fdXz7ZlWZIfZGPLD2ATwGaDQZgNtIl4Ntkcm7QkhZ7dddrkuNmGbLJp0pClJ7slSw/p9pSeZtNNnIQN2aSBJNsk3q6NQwCFNASwAAUkg23hRyHLsi2NbNl6HH37x72yBnkkjTwjzdj38zpHR3N/93dnvv6d8Xx0H353zN0REZHoyst2ASIikl0KAhGRiFMQiIhEnIJARCTiFAQiIhGnIBARibiMBIGZPWZm7WbWOMZ6M7P/aWbNZva6mV2bsG6Tme0NfzZloh4REUldpvYIvgusH2f9XcCq8Gcz8LcAZlYJ/AVwA7AO+Aszq8hQTSIikoKMBIG7Pw90jNNlI/A9D7wIlJvZIuBO4Gl373D3TuBpxg8UERHJsIJpep3FwOGE5Zawbaz2c5jZZoK9CWbMmHFddXX11FSaoqGhIfLydIoFNBbDDh8+jLuzdOnSbJeSE/S+GJErY7Fnz57j7j5/dPt0BYElafNx2s9tdN8CbAGoqanx+vr6zFV3Hurq6qitrc1qDblCYxGora0lFovR0NCQ7VJygt4XI3JlLMzsYLL26YqoFiDxT/glQOs47SIiMk2mKwi2An8cXj30+0CXux8BdgB3mFlFeJL4jrBNRESmSUYODZnZD4FaYJ6ZtRBcCVQI4O7fALYBHwSagTPAvwnXdZjZV4Cd4VM95O7jnXQWEZEMy0gQuPt9E6x34FNjrHsMeCwTdYiIyORl/zS2iIhklYJARCTiFAQiIhGnIBARiTgFgYhIxCkIREQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiEScgkBEJOIyEgRmtt7MdptZs5k9kGT9o2bWEP7sMbNYwrp4wrqtmahHRERSl/Z3FptZPvB14HagBdhpZlvdfddwH3f/s4T+nwbWJjxFj7tfk24dIiJyfjKxR7AOaHb3fe7eDzwBbByn/33ADzPwuiIikgGZCILFwOGE5Zaw7RxmtgxYATyb0FxiZvVm9qKZ3Z2BekREZBLSPjQEWJI2H6PvvcBP3D2e0LbU3VvN7FLgWTN7w93fPudFzDYDmwGqqqqoq6tLs+z0dHd3Z72GXKGxCMRiMeLxuMYipPfFiFwfi0wEQQtQnbC8BGgdo++9wKcSG9y9Nfy9z8zqCM4fnBME7r4F2AJQU1PjtbW16dadlrq6OrJdQ67QWATKy8uJxWIai5DeFyNyfSwycWhoJ7DKzFaYWRHBh/05V/+Y2XuACuC3CW0VZlYcPp4H3ATsGr2tiIhMnbT3CNx90MzuB3YA+cBj7t5kZg8B9e4+HAr3AU+4e+JhoyuBb5rZEEEoPZJ4tZGIiEy9TBwawt23AdtGtX151PJfJtnuBeDqTNQgIiLnRzOLRUQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiEScgkBEJOIUBCIiEacgEBGJOAWBiEjEKQhERCJOQSAiEnEKAhGRiMtIEJjZe
jPbbWbNZvZAkvUfM7NjZtYQ/nwiYd0mM9sb/mzKRD0iIpK6tL+83szyga8DtwMtwE4z2+ruu0Z1fdLd7x+1bSXwF0AN4MAr4bad6dYlIiKpycQewTqg2d33uXs/8ASwMcVt7wSedveO8MP/aWB9BmoSEZEUpb1HACwGDicstwA3JOn3L83s/cAe4M/c/fAY2y5O9iJmthnYDFBVVUVdXV36laehu7s76zXkCo1FIBaLEY/HNRYhvS9G5PpYZCIILEmbj1r+R+CH7t5nZp8EHgduSXHboNF9C7AFoKamxmtra8+74Eyoq6sj2zXkCo1FoLy8nFgsprEI6X0xItfHIhOHhlqA6oTlJUBrYgd3P+HufeHit4DrUt1WRESmViaCYCewysxWmFkRcC+wNbGDmS1KWNwAvBk+3gHcYWYVZlYB3BG2iYjINEn70JC7D5rZ/QQf4PnAY+7eZGYPAfXuvhX4UzPbAAwCHcDHwm07zOwrBGEC8JC7d6Rbk4iIpC4T5whw923AtlFtX054/CXgS2Ns+xjwWCbqEBGRydPMYhGRiFMQiIhEnIJARCTiFAQiIhGnIBARiTgFgYhIxCkIREQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiERcRoLAzNab2W4zazazB5Ks/5yZ7TKz183sGTNblrAubmYN4c/W0duKiMjUSvs7i80sH/g6cDvQAuw0s63uviuh22tAjbufMbM/Af4r8K/CdT3ufk26dYiIyPnJxB7BOqDZ3fe5ez/wBLAxsYO7P+fuZ8LFF4ElGXhdERHJgLT3CIDFwOGE5RbghnH6fxzYnrBcYmb1wCDwiLv/LNlGZrYZ2AxQVVVFXV1dOjWnrbu7O+s15AqNRSAWixGPxzUWIb0vRuT6WGQiCCxJmyftaPaHQA3wgYTmpe7eamaXAs+a2Rvu/vY5T+i+BdgCUFNT47W1tWkXno66ujqyXUOu0FgEysvLicViGouQ3hcjcn0sMnFoqAWoTlheArSO7mRmtwEPAhvcvW+43d1bw9/7gDpgbQZqEhGRFGUiCHYCq8xshZkVAfcC77r6x8zWAt8kCIH2hPYKMysOH88DbgISTzKLiMgUS/vQkLsPmtn9wA4gH3jM3ZvM7CGg3t23An8NlAI/NjOAQ+6+AbgS+KaZDRGE0iOjrjYSEZEplolzBLj7NmDbqLYvJzy+bYztXgCuzkQNIiJyfjSzWEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiEScgkBEJOIUBCIiEacgEBGJOAWBiEjEKQhERCJOQSAiEnEKAhGRiFMQiIhEnIJARCTiFAQiIhGXkSAws/VmttvMms3sgSTri83syXD9S2a2PGHdl8L23WZ2ZybqERGR1KUdBGaWD3wduAu4CrjPzK4a1e3jQKe7rwQeBb4abnsVcC+wGlgP/K/w+UREZJpk4svr1wHN7r4PwMyeADYCuxL6bAT+Mnz8E+BvzMzC9ifcvQ/Yb2bN4fP9drwX3L17N7W1tRko/fzFYjHKy8uzWkOu0FgEGhoaGBwczPp7M1fofTEi18ciE0GwGDicsNwC3DBWH3cfNLMuYG7Y/uKobRcnexEz2wxsBigsLCQWi2Wg9PMXj8ezXkOu0FgEBgcHcXeNRUjvixG5PhaZCAJL0uYp9kll26DRfQuwBaCmpsbr6+snU2PG1dXV6S+/kMYiUFtbSywWo6GhIdul5AS9L0bkylgEB2LOlYmTxS1AdcLyEqB1rD5mVgDMATpS3FZERKZQJoJgJ7DKzFaYWRHByd+to/psBTaFj+8BnnV3D9vvDa8qWgGsAl7OQE0iIpKitA8Nhcf87wd2APnAY+7eZGYPAfXuvhX4DvD34cngDoKwIOz3I4ITy4PAp9w9nm5NIiKSukycI8DdtwHbRrV9OeFxL/CRMbZ9GHg4E3WIiMjka
WaxiEjEKQhERCJOQSAiEnEKAhGRiFMQiIhEnIJARCTiMnL56HRrjfXwX7a9SVF+HkUF4U/C4+LwJ2jPp7gw75y+Z9cX5FFckE9+XvKp1yLD+geHONbdR1tXL0dP9rL/+Gn2Hj3F3vZu3on10HygE/chrvvK08wqLqCqrJgFZSUsLCthaeVMls2dybK5s1g0p4SSQt1kV3LHBRkEJ073s+X5fRl9zjzjbCgkC4vRQdLV0cs/tv/ubPAk6zNWMI3uU1yQd87zFORrZ22qxYecE919HOnqpe1kL21dvRzp6uVU7wC9A0P0Dsbp7h2k/VQfx071cry7f9znGxwaAoL354nT/RzqODNm3zkzCpk/u5iFZSVUhyGxtHImZSWFzCzOZ2ZRPqXFBcyZUUhpccGY94gRyYQLMgimwpAT/OcfGEp5m5faWqasnuFgCsIi/10hkWwPJzF0RgIm/90BkyR8hoPp3XtIwbbFhXmUFORTmG8X7AfRqd4B9h07zf7jpzl44gyHOs5wuOMM78R6OHqyl8GhpPc4nHJdPQN09QzQ3N49Yd/8PKNiZiGLy2ewpGImSypnMHdWEWUlhZTNKKSspJDZJQXMLilgVnHwX9odhnzk32YGp/sGaT/ZR/upPrr7Blk2dybXLq04u41El94BOerdwTSY1VrMoCQhGEoKg6AY/n3mVC/fP1hPSWEeJYVBe2L/4eAqLsynKD+PwoI8ivKNwvy8hJ+R5aICoyAvj4KwLT/PyE8IoiF34kPO4JAzGHdO9g5wMvxgbensYd/xbvYdO82+46c5dqoviyOXGfEh53h3P8e7+/ldS1dGnzs/z1h9SRnXL69k3YpKrl9eSeWsooy+huQ+BYFMyB16BuL0DMSBgaR9Gk8cnd6iJCPiQ87rLV283tLFd/7ffgBWLijl+uUV1CwLwmFJxYwLdo9QUnNBBsGiOSV8cf0V9A8O0R+PB78Hh+iPD9E3/DhcHn7cN6otWI6f3cazc4RALjDzSouoCk8AL5xTwsoFpVxeNZvL5pdyzwuVdHV18eyDt9HV009bVx9tJ3tpjfVw4MTIoakT3X1k6YhUSprbu2lu7+aHLwffNzV/djHXLi3n2qUVXLusgqsXz9HJ7ovMBRkE80qL+ZPayzL6nIPxMEgGRofFuwNlOHhee72Jyy5/T9LQGdk+fvY5xgqo0cHUF7YpmKZHWUkBC+eUsGjODBbNKaGqrIS5pUWUFAbnZWYU5jN/djFVZSXMKy2mqGDsk/gFeUa+BR+c82cXs3LB7KT94kPOidN9tJ/s451YDwfDkGjr6uV0/yA9/XFO98c51Rsc7prMeaupcOxUHzuajrKjKdjrK8rP4/eWzKFmeSXXL6/gumUVlM/U4aQL2QUZBFOhILxSJ9X3c8nx3dTWVE/c8Ty4B8e/k4dFEDDJAmXsvaH4hH1Gh2DvQBBKvQPxrJ1QzYT8PKO6YgaXzi9l+dxZwdU5c2dSXTGTRXNKsnKiND/PWDC7hAWzS1izeM6E/XsH4hw71UdLZw+HO8/QGuvhZM/g2XMjp3qDx6d6BznTH8csuNjAMMw4+0dFYYExv7SYBbNLKMg3XjsU451Yz6Tr748PUX+wk/qDnXzjV0Hb5VWlZ4Ph+uWVLKmYOennlexREOQgMzt78nRWcbarCfaW+kaFQ+9AEDC9A0PsfPU1Lr9y9Tnr+sJLMIcDpi9sHxhyBgaHGIgPnQ28gXCPbDDu9MeD5Xg8PCE85O+6AgagIC84wZyfZ5QWF5y9gmburCJWzJ/FpfNmcen8WSytnDXuX/EXgpLCfKorZ1JdOZP3MTejz90a62HngQ5e3h/87E3hKqZk9hztZs/Rbv73S4cAuGROCUtnDnC45CDrlleyakEpeZqrk7MUBDKh4b2lsf567jtcQO2aRdNclWTCJeUz2HjNYjZesxiAztP9vHKwk50HO6g/0MnrLTEG4pPfI2zt6qW1C178WSMQzJuoWVbB9SuCvYarF5df8AF9M
VEQiMhZFbOKuO2qKm67qgoIDks1tZ7ktUOdvHqok50HOs/rktyungGeeaudZ95qB6C4II9rqsvPXrJ67bIKSjWfIWs08iIyppLCfK5bFpwQhuD81eGOHl4+0EH9gQ52Hujg7WOnJ/28fYNDvLS/g5f2dwDBeZOrFg3PZ6igZnkl80pz4LhoRKQVBGZWCTwJLAcOAB91985Rfa4B/hYoA+LAw+7+ZLjuu8AHgOFZMh9z94Z0ahKRqWNmLA1PuN9z3RIATnT3BSePD3Tw8oFOmt7pmvQFBvEh5413unjjnS4e+00wn+HS+bN4/6r5fOA983nfpXN1yeoUSneP4AHgGXd/xMweCJe/OKrPGeCP3X2vmV0CvGJmO9w9Fq7/grv/JM06RCRL5pYWc+fqhdy5eiEAZ/oHaTgU48d1r3KMObx6qJMz/fFJP+++Y6fZd+w0333hAMUFedy0ch53XFXFrVdWMX+29hYyKd0g2AjUho8fB+oYFQTuvifhcauZtQPzgRgictGZWVTAjSvn0d9SRG3tDQzEh9jVevLs1Un1BzvpOD3+DfxG6xsc4tm32nn2rXbM3uC6pRWsXxOET3WlLlVNl3kaM5fMLObu5QnLne5eMU7/dQSBsdrdh8JDQ+8D+oBngAfcPemZKDPbDGwGqKqquu6JJ54477ozobu7m9LS0qzWkCs0FoHPfvazxONxvva1r2W7lJww1vvC3Tly2tnTGWdv5xB7OuMc6zn/z6FlZXlcV5VPTVUBl5Tm5pVIufJ/5Oabb37F3WtGt08YBGb2S2BhklUPAo+nGgRmtohgj2GTu7+Y0NYGFAFbgLfd/aGJ/jE1NTVeX18/UbcpVVdXR21tbVZryBUai0BtbS2xWIyGBp3mgsm9L4509bDzQCc79wcnoHcfPXVes+tXLijlrnBPYfUlZTlzj6Rc+T9iZkmDYMJDQ+5+2zhPetTMFrn7kfBDvX2MfmXAPwH/aTgEwuc+Ej7sM7O/Az4/UT0icvFZNGcGG947gw3vvQQI5jP8uvk4dbvb+dXuY5xI8VBSc3s3X3u2ma8920x15QzWr17I+jWLWFtdrglt40j3HMFWYBPwSPj756M7mFkR8FPge+7+41HrhkPEgLuBxjTrEZGLQMWsIja89xI2vPcShoac1w7H+MWuNn7RdJT9x1O7XPVwRw/f+vV+vvXr/VSVBSe0169ZyLrllfrip1HSDYJHgB+Z2ceBQ8BHAMysBviku38C+CjwfmCumX0s3G74MtEfmNl8wIAG4JNp1iMiF5m8PDs7l+GB9VfQ3N7N9sY2nmpsY9eRkyk9x9GTfXzvtwf53m8PUjmriNuvrGL91Qu58bK5FBfostS0gsDdTwC3JmmvBz4RPv4+8P0xtr8lndcXkWgxM1ZVzWZV1Wz+9NZVHDpxhqeajvBUYxuvHkrtQsSO0/08WX+YJ+sPM7u4gFuvXMD6NQv5wOULmFEUzVDQzGIRuWAtnTuTze+/jM3vv4y2rl52NLWxvfEIL+/vSOk7H071DfKzhlZ+1tBKSWEetZcv4K6rF3LzFQsoKymc+n9AjlAQiMhFYeGcEjbduJxNNy7nRHcfT+86yvbGNl54+3hKN87rHRjiqaY2nmpqoyg/j5tWzuWuNYu47aqqi/7rOxUEInLRmVtazL3rlnLvuqV09Qzw7FtHeaqxjV/tOZbSF/30x4d4bvcxntt9jPyfGjesqOSuNQu5Y/VCqspKpuFfML0UBCJyUZszo5APr13Ch9cu4Uz/IL/afYynmtp45s12uvsGJ9w+PuS88PYJXnj7BF/e2sS1SyvCy1IvnlnNCgIRiYyZRQXcdfUi7rp6EX2DcV5oPsH2xiM8vesonWcGJtzeHV452MkrBzt5eNubrL6kjLvWBHMVVi7I/szh86UgEJFIKi7I5+YrFnDzFQsYjA/x8v4Otje2saOpjfYUv3OhqfUkTa0n+W+/2HN2VvP6NQu5alHuzGpOhYJARCKvID+PG1fO4
8aV8/jPG1bz2uFOtr/RxvbGtpS/1zlxVvPSyplnb4q3trp84o2zTEEgIpIgmMBWyXXLKnnwD66kqfUkTzUGl6Wm+iU8hzrOsOX5fWx5fh9VZcWsKY9TVH08Z2c1KwhERMZgZqxZPIc1i+fw+TvfQ3P7qbN7CpOZ1Xz0JDzzrZdGZjWvWciNK3NnVrOCQEQkRSsXzObTt87m0wmzmrc3tvHaBT6rWUEgInIeMjmreUZhPrXvmc/6NQu55YoFzJ7mWc0KAhGRNCWb1fxUUxu/aU5tVnPPQJztjcEhp2zMalYQiIhk0OhZzc+91c7f171BU4ef96zm4SuQpmpWs4JARGSKzJlRyN1rF1PetZd1N/6z9GY1/7yJa5eWc9eaRRmf1awgEBGZBqNnNf+m+ThPNbalPKsZ4NVDMV49FOPhbW+yZnHZ2W9gS3dWs4JARGSaFRfkc8sVVdxyRdV5z2pufOckje+8e1bz+X5Xs4JARCSLxprV/FRTGy2dk5/VfD7f1awgEBHJEclmNW9vDOYq7EtxVnPS72pevZB1KyrH3CatIDCzSuBJYDlwAPiou3cm6RcH3ggXD7n7hrB9BfAEUAm8CvyRu/enU5OIyMUgcVbzF+68gr1HT4W3ujj/72oeS7o3vXgAeMbdVwHPhMvJ9Lj7NeHPhoT2rwKPhtt3Ah9Psx4RkYvSqqpgRvO2z/xznv/Czfz5B69g7dLUb2jXcXrsv7HTDYKNwOPh48eBu1Pd0IKzGbcAPzmf7UVEomp4VvNP//1NvPilW3lo42red+lcUjgdkFS65wiq3P0IgLsfMbMFY/QrMbN6YBB4xN1/BswFYu4+fDFtC7B4rBcys83AZoCqqirq6urSLD093d3dWa8hV2gsArFYjHg8rrEI6X0xYqrHYinw7y6H+5bP5LX2QV5pi9N0Ik4Kk5qBFILAzH4JLEyy6sHJ1OnurWZ2KfCsmb0BJDvINWbZ7r4F2AJQU1PjtbW1k3j5zKurqyPbNeQKjUWgvLycWCymsQjpfTFiOsdi+Nj78Kzm7Y1HJvyu5gmDwN1vG2udmR01s0Xh3sAioH2M52gNf+8zszpgLfAPQLmZFYR7BUuA1onqERGRiQ3Par577eKz39X8wa8m75vuOYKtwKbw8Sbg56M7mFmFmRWHj+cBNwG73N2B54B7xtteRETSMzyreSzpBsEjwO1mthe4PVzGzGrM7NthnyuBejP7HcEH/yPuvitc90Xgc2bWTHDO4Dtp1iMiIpOU1slidz8B3JqkvR74RPj4BeDqMbbfB6xLpwYREUlP7n15poiITCsFgYhIxCkIREQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiEScgkBEJOIUBCIiEacgEBGJOAWBiEjEpRUEZlZpZk+b2d7wd0WSPjebWUPCT6+Z3R2u+66Z7U9Yd0069YiIyOSlu0fwAPCMu68CngmX38Xdn3P3a9z9GuAW4Azwi4QuXxhe7+4NadYjIiKTlG4QbAQeDx8/Dtw9Qf97gO3ufibN1xURkQxJNwiq3P0IQPh7wQT97wV+OKrtYTN73cweNbPiNOsREZFJKpiog5n9EliYZNWDk3khM1sEXA3sSGj+EtAGFAFbgC8CD42x/WZgM0BVVRV1dXWTefmM6+7uznoNuUJjEYjFYsTjcY1FSO+LEbk+FhMGgbvfNtY6MztqZovc/Uj4Qd8+zlN9FPipuw8kPPeR8GGfmf0d8Plx6thCEBbU1NR4bW3tRKVPqbq6OrJdQ67QWATKy8uJxWIai5DeFyNyfSzSPTS0FdgUPt4E/Hycvvcx6rBQGB6YmRGcX2hMsx4REZmkdIPgEeB2M9sL3B4uY2Y1Zvbt4U5mthyoBn41avsfmNkbwBvAPOCv0qxHREQmacJDQ+Nx9xPArUna64FPJCwfABYn6XdLOq8vIiLp08xiEZGIUxCIiEScgkBEJ
OIUBCIiEacgEBGJOAWBiEjEKQhERCJOQSAiEnEKAhGRiFMQiIhEnIJARCTiFAQiIhGnIBARiTgFgYhIxCkIREQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4tIKAjP7iJk1mdmQmdWM02+9me02s2YzeyChfYWZvWRme83sSTMrSqceERGZvHT3CBqBfwE8P1YHM8sHvg7cBVwF3GdmV4Wrvwo86u6rgE7g42nWIyIik5RWELj7m+6+e4Ju64Bmd9/n7v3AE8BGMzPgFuAnYb/HgbvTqUdERCavYBpeYzFwOGG5BbgBmAvE3H0woX3xWE9iZpuBzeFit5lNFEBTbR5wPMs15AqNxYh5ZqaxCOh9MSJXxmJZssYJg8DMfgksTLLqQXf/eQovbEnafJz2pNx9C7AlhdebFmZW7+5jnheJEo3FCI3FCI3FiFwfiwmDwN1vS/M1WoDqhOUlQCtBOpabWUG4VzDcLiIi02g6Lh/dCawKrxAqAu4Ftrq7A88B94T9NgGp7GGIiEgGpXv56IfNrAV4H/BPZrYjbL/EzLYBhH/t3w/sAN4EfuTuTeFTfBH4nJk1E5wz+E469UyznDlMlQM0FiM0FiM0FiNyeiws+MNcRESiSjOLRUQiTkEgIhJxCoIMMLPPm5mb2bxs15ItZvbXZvaWmb1uZj81s/Js1zTdxrqVStSYWbWZPWdmb4a3oPlMtmvKJjPLN7PXzOz/ZruWsSgI0mRm1cDtwKFs15JlTwNr3P33gD3Al7Jcz7Sa4FYqUTMI/Ad3vxL4feBTER4LgM8QXCiTsxQE6XsU+I+MMxkuCtz9FwmzxF8kmBcSJUlvpZLlmrLC3Y+4+6vh41MEH4Jj3jXgYmZmS4A/AL6d7VrGoyBIg5ltAN5x999lu5Yc82+B7dkuYpolu5VKJD/8EpnZcmAt8FJ2K8ma/0Hwh+JQtgsZz3Tca+iCNt4tNoA/B+6Y3oqyJ5XbjZjZgwSHBn4wnbXlgEndMiUKzKwU+Afgs+5+Mtv1TDcz+xDQ7u6vmFlttusZj4JgAmPdYsPMrgZWAL8LbqTKEuBVM1vn7m3TWOK0meh2I2a2CfgQcKtHb4LKWLdSiSQzKyQIgR+4+//Jdj1ZchOwwcw+CJQAZWb2fXf/wyzXdQ5NKMsQMzsA1Lh7LtxhcNqZ2XrgvwMfcPdj2a5nuplZAcFJ8luBdwhurfKvE2bRR0Z4i/nHgQ53/2y268kF4R7B5939Q9muJRmdI5BM+RtgNvC0mTWY2TeyXdB0muBWKlFzE/BHwC3he6Eh/KtYcpT2CEREIk57BCIiEacgEBGJOAWBiEjEKQhERCJOQSAiEnEKAhGRiFMQiIhE3P8H6hvm0FfV+u8AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "ccm0jfOcETl0" }, "source": [ "The network keeps track of all its parameters and their gradients! A gradient is `None` until `backward()` has been called." ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "id": "_nQx9_nDmKXs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "print(net.fc1.bias.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "oirQcFSmDVW5" }, "source": [ "In the `__init__` function, any variable that you assign to `self` that is also a module will be automatically added as a sub-module. The parameters of a module (and all sub-modules) can be accessed through the `parameters()` function:" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 121 }, "id": "Vwrjulr8mlwR", "outputId": "6cc62d8f-491f-482e-ba48-e02f0745cd6a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32, 1])\n", "torch.Size([32])\n", "torch.Size([32, 32])\n", "torch.Size([32])\n", "torch.Size([1, 32])\n", "torch.Size([1])\n" ] } ], "source": [ "for p in net.parameters():\n", "    print(p.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "RAkasCBWDd2D" }, "source": [ "WARNING: if you want to keep a list of modules, use\n", "```\n", "def __init__(self, network1, network2):\n", "    super().__init__()\n", "    self.list = nn.ModuleList([network1, network2])\n", "```\n", "and **not** \n", "```\n", "def __init__(self, network1, network2):\n", "    super().__init__()\n", "    self.list = [network1, network2]\n", "```\n", "In the latter case, `network1` and `network2` won't be added as sub-modules." ] }, { "cell_type": "markdown", "metadata": { "id": "PcP0KwA4EYH-" }, "source": [ "The output of the module is just a tensor. 
We can perform operations on the tensor like before to automatically compute derivatives.\n", "For example, below we compute the sum of squares of the output as a loss and backpropagate through it." ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "id": "tqDLwkAXmSV3" }, "outputs": [], "source": [ "loss = (y**2).sum()\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "mzAJE131EkW5" }, "source": [ "We can manually update the parameters by adding the gradient (times a negative learning rate) and then zeroing out the gradients to prevent gradient accumulation." ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "id": "1tT_gl-QmfMx" }, "outputs": [], "source": [ "for p in net.parameters():\n", "    p.data.add_(-0.001 * p.grad)\n", "    p.grad.data.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "cBSnExmmEsy5" }, "source": [ "And we can do this in a loop to train our network!" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "class Net2(nn.Module):\n", "\n", "    def __init__(self, input_size, output_size):\n", "        super(Net2, self).__init__()\n", "        self.fc1 = nn.Linear(input_size, 32)\n", "        self.fc2 = nn.Linear(32, 32)\n", "        self.fc3 = nn.Linear(32, output_size)\n", "\n", "    def forward(self, x):\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return x\n", "    \n", "    def hook(self, gradient):\n", "        return 2*gradient\n", "    \n", "net = Net2(input_size=1, output_size=1)" ] }, { "cell_type": "code", "execution_count": 71, "metadata": { "id": "mrhmNphPnWEQ" }, "outputs": [], "source": [ "for _ in range(100):\n", "    y = net(x)\n", "    loss = (y**2).sum()\n", "    loss.backward()\n", "    for p in net.parameters():\n", "        p.data.add_(- 0.001 * p.grad)\n", "        p.grad.data.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "oGe1_gpCE1hs" }, "source": [ "Sure enough, our network learns to set everything to zero." 
] }, { "cell_type": "code", "execution_count": 72, "metadata": { "id": "YnZ9swUJnp1d" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAaA0lEQVR4nO3df5Dcd33f8edr935I8q+TkH2WJRlEKzJAyNhwY+iQIVdjOU7KWG5LgmHSiNaMph3c/KCk2HUHMk6YMc0MJJkwDSo4mEIwCSlBbcwY/2CbacEgEYR/1pYsG3SWbNmWFvl00t3t7rt/fL+n+97pfuj0Xd0ufF6PmZ39fj/fz3e/7/vc3r72+2P3FBGYmVm6Kp0uwMzMOstBYGaWOAeBmVniHARmZolzEJiZJc5BYGaWuLYEgaQ7JR2W9Og8yyXpTyXtk/SwpDcXlm2TtDe/bWtHPWZmdubatUfweeC6BZb/CrA5v20H/iuApDXAx4C3AlcBH5O0uk01mZnZGWhLEETE3wNHFuiyFfhCZB4CBiStA34ZuC8ijkTEUeA+Fg4UMzNrs55l2s564EBhfiRvm6/9NJK2k+1NsHLlyrds3Ljx3FR6hlqtFpWKT7GAx2LKgQMHiAguv/zyTpfSFfy8mNYtY/HUU0+9FBEXz25friDQHG2xQPvpjRE7gB0AQ0NDsXv37vZVdxZqtRrDw8MdraFbeCwyw8PD1Ot19uzZ0+lSuoKfF9O6ZSwk/Wiu9uWKqBGg+BZ+A3BwgXYzM1smyxUEO4HfzK8eehvwk4g4BNwLXCtpdX6S+Nq8zczMlklbDg1J+jIwDKyVNEJ2JVAvQET8OXAP8KvAPmAM+Nf5siOS/gDYlT/U7RGx0ElnMzNrs7YEQUS8d5HlAXxwnmV3Ane2ow4zM1u6zp/GNjOzjnIQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklri1BIOk6SU9K2ifpljmWf0rSnvz2lKR6YVmzsGxnO+oxM7MzV/p/FkuqAp8GtgAjwC5JOyPi8ak+EfG7hf7/Hriy8BAnIuKKsnWYmdnZaccewVXAvojYHxETwN3A1gX6vxf4chu2a2ZmbdCOIFgPHCjMj+Rtp5H0amAT8GCheYWk3ZIeknRDG+oxM7MlKH1oCNAcbTFP3xuBr0ZEs9B2eUQclPRa4EFJj0TE06dtRNoObAcYHBykVquVLLuc0dHRjtfQLTwWmXq9TrPZ9Fjk/LyY1u1j0Y4gGAE2FuY3AAfn6Xsj8MFiQ0QczO/3S6qRnT84LQgiYgewA2BoaCiGh4fL1l1KrVaj0zV0C49FZmBggHq97rHI+XkxrdvHoh2HhnYBmyVtktRH9mJ/2tU/kn4OWA18p9C2WlJ/Pr0WeDvw+Ox1zczs3Cm9RxARDUk3A/cCVeDOiHhM0u3A7oiYCoX3AndHRPGw0euBz0hqkYXSHcWrjczM7Nxrx6EhIuIe4J5ZbR+dNf/7c6z3beBN7ajBzMzOjj9ZbGaWOAeBmVniHARmZolzEJiZJc5BYGaWOAeBmVniHARmZolzEJiZJc5BYGaWOAeBmVniHARmZolzEJiZJc5BYGaWOAeBmVniHARmZolzEJiZJc5BYGaWOAeBmVniHARmZolrSxBIuk7Sk5L2SbpljuXvl/SipD357QOFZdsk7c1v29pRj5mZnbnS/7xeUhX4NLAFGAF2SdoZEY/P6vqViLh51rprgI8BQ0AA38/XPVq2LjMzOzPt2CO4CtgXEfsjYgK4G9h6huv+
MnBfRBzJX/zvA65rQ01mZnaGSu8RAOuBA4X5EeCtc/T7l5LeATwF/G5EHJhn3fVzbUTSdmA7wODgILVarXzlJYyOjna8hm7hscjU63WazabHIufnxbRuH4t2BIHmaItZ8/8T+HJEjEv6t8BdwNVnuG7WGLED2AEwNDQUw8PDZ11wO9RqNTpdQ7fwWGQGBgao1+sei5yfF9O6fSzacWhoBNhYmN8AHCx2iIiXI2I8n/1vwFvOdF0zMzu32hEEu4DNkjZJ6gNuBHYWO0haV5i9Hngin74XuFbSakmrgWvzNjMzWyalDw1FREPSzWQv4FXgzoh4TNLtwO6I2An8lqTrgQZwBHh/vu4RSX9AFiYAt0fEkbI1mZnZmWvHOQIi4h7gnlltHy1M3wrcOs+6dwJ3tqMOMzNbOn+y2MwscQ4CM7PEOQjMzBLnIDAzS5yDwMwscQ4CM7PEOQjMzBLnIDAzS5yDwMwscQ4CM7PEOQjMzBLnIDAzS5yDwMwscQ4CM7PEOQjMzBLnIDAzS5yDwMwscQ4CM7PEtSUIJF0n6UlJ+yTdMsfyD0l6XNLDkh6Q9OrCsqakPflt5+x1zczs3Cr9P4slVYFPA1uAEWCXpJ0R8Xih2w+AoYgYk/TvgP8CvCdfdiIirihbh5mZnZ127BFcBeyLiP0RMQHcDWwtdoiIb0XEWD77ELChDds1M7M2KL1HAKwHDhTmR4C3LtD/JuAbhfkVknYDDeCOiPjbuVaStB3YDjA4OEitVitTc2mjo6Mdr6FbeCwy9XqdZrPpscj5eTGt28eiHUGgOdpizo7SbwBDwC8Vmi+PiIOSXgs8KOmRiHj6tAeM2AHsABgaGorh4eHShZdRq9XodA3dwmORGRgYoF6veyxyfl5M6/axaMehoRFgY2F+A3BwdidJ1wC3AddHxPhUe0QczO/3AzXgyjbUZGZmZ6gdQbAL2Cxpk6Q+4EZgxtU/kq4EPkMWAocL7asl9efTa4G3A8WTzGZmdo6VPjQUEQ1JNwP3AlXgzoh4TNLtwO6I2An8EXA+8NeSAH4cEdcDrwc+I6lFFkp3zLrayMzMzrF2nCMgIu4B7pnV9tHC9DXzrPdt4E3tqMHMzM6OP1lsZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklri3/mMbsZ0lEcGKyyeh4gxMTTSabLSabQaMZnGw0OTHR5MRkk2MnJnlpdIKXR8epn5jkmZeOMznR4pPffJLLBlbyjtddzGUDKzv945gtykFgSXnl5CQ/PjLG8z85ycujE7x8fIKXRsd57ugJnquf4GD9BEfHJmjF0h/7hWMnAfjTB/edavv59Rdy7RsuZfjnLuaNl11EtaJ2/ShmbdOWIJB0HfAnZP+z+LMRcces5f3AF4C3AC8D74mIZ/NltwI3AU3gtyLi3sW298ShY1z18fvpqYhKRVQroqr8viIq+XSlIqriVHtxWU9hesZNoqcqeiqVU/2q1by9IqqVCj1V8aNnJ9hX3Z8vr+TLVLifuX7PHHVWKyICguxdaPbik91HTLcFWaesXzbfypdPzUcw3Zb3A2Zsv6dSoa9H9FYrp2591Qq9hba+aiX/+UX+/6W7ytS79ePjTcYmGhwfz96dn5zM3qmPTTYZG29wfKLJKycnOfzKOIePneSFY+OMHB3j6Njkstb76HPHePS5Y3zyvqe4cEUPb3vtq7jkwn5OTrY4Odk8dX9issnYRJMTE1ntY+MNWgEVQUWiv7fKprWr+EcXn5/dLjmP1649nw2rV9JT7fwR3mYrGJvI9qCOT2S/m71Hm+ipF7Ofaer3NdGkWUhZ5T9fVTP/XiuVYpuoVJj+ey20z+zLjLbpx+XUehJIQmR/L81W0GwFrcjumxG0WtPTM5czo28rsr3EU+vMWBdaraCR99v39AQ/mHzqtO00WtPrNlvM
eJxGPh2FOmY+dotWi1N1tmbVW3zsVgtaMf+7m9JBIKkKfBrYAowAuyTtnPVP6G8CjkbEP5Z0I/AJ4D2S3gDcCLwRuAy4X9LrIqK50DYbreDwK+NlSy/vySc6XcE51VvNAqJamb7vqUwH5VS4nRg7wcBj/5feqRAsLD81XVVh/el1sz/JTKPVyl/YGzNe6McKL45jk00WeD53tWMnG3zz8RfOat1Xxhu8NDrOrmePzmjvrYp1F63kghU9nN/fwwUrerhgRS/n9Vc5v7/3VPt5/T30VsVEo8V4ozXrvslEo8VEszXjfuqQWCum33iMN1pZ4OYv+mN5gE00WnMX/t3vndXP+zNp395OVzCvduwRXAXsi4j9AJLuBrYCxSDYCvx+Pv1V4M+Uvd3cCtwdEePAM5L25Y/3nYU2OPnyCM//5S1tKN2sfSYO7wdY9ufmgWXdmv0sakcQrGfmc3EEeOt8fSKiIeknwKvy9odmrbt+ro1I2g5sB1DvijaUbTa/7JAFVMgPJ+RtImufuu85ddgBnq+IIFizosLxyWC8+VO662LJaUcQzHUgefZfwHx9zmTdrDFiB7ADoH/d5rj0fXfM1c1sQT0VsWH1SjasXsXFF/Sz5rw+1pzXx7qLVrB+YCXrV6/kkgtW0Nez9OPuw8PD1Ot19uzZA8CBI2N88/EX+D97X+R7zxzh+MSCRzzNzrkffeJdc7a3IwhGgI2F+Q3AwXn6jEjqAS4Cjpzhuqd5/aUX8o1br85OihROlix44qdw8mf6xE+LZis7Nj3V1mi2TvWbbBZO2rSmj5c2WsEzz/6Iy9ZvzNpbxX7BZLM1fSKp0N5otWad9Ak04x2nsneb+ckszZ6e6pufYJu5bHr9Yr5O1TtV/2R+/LcxNT013yosy+vrVn09Fc7v72FVX5Xz+npY2VdlZW+VFb0VVuZt5+XL157fzyUX9jN44QouvXAF6y5asWwnVzeuWcVNv7iJm35xE5PNFg+P1Hns4DEiYEVvhRW9Vfp7qqzqq7Kyr8qKnirn9VdZlf9MPRWdOj5/5PgETx8e5ekXs9v+F4+z/6XjHDk+sSw/y2IkWNVbZWVfNu4re6s0Th5n3cVr8t9J9dTvpLdaOfVur3XqbzSfPvU3OvPE6eknRJlx0vb0v31O/e0XHze/HoNWxMwTznNcYHLaxST5SeqpC02K09XKzHUqmr5ApFIRIz/+MZs2vSbfHrPWnVnDqZPhFahWKtPrzDo5XrxYZua6nHZBzNT9+k/M/ftrRxDsAjZL2gQ8R3by932z+uwEtpEd+3838GBEhKSdwF9K+iTZyeLNwKJnl3ryE2SdVKs9z/DwGzpaw7nUagWTeXg0m1mINaZCrZkta7aysPverl38wpVvPhWCxQBsNFuFIJw5P9mceYKxWhHn9fWwqr966kV+6sVj1VR7b7UrrpJZqt5qhbe8eg1vefWas1r/opW9bFp7HtcwOKO9PjbBkeMTjI43GD3Z4NjJBsfHG4yON3jl5CSj401GxycZPdlgshX091To78muDuvvrRams/u+nip9PRV6q8qvIKvkV5BlbzR6q8p+F/mL/dTvZkVv5bSrzGq1GsPDs48Spyl7vXhdp8uYV+kgyI/53wzcS3b56J0R8Zik24HdEbET+Bzw3/OTwUfIwoK831+RnVhuAB9c7IohWx6ViuivVOk/g2fI4YuqvPny1ee+KDvNwKo+Blb1dboM+ynXls8RRMQ9wD2z2j5amD4J/No8634c+Hg76jAzs6X76dvHNjOztnIQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklrlQQSFoj6T5Je/P70/5xraQrJH1H0mOSHpb0nsKyz0t6RtKe/HZFmXrMzGzpyu4R3AI8EBGbgQfy+dnGgN+MiDcC1wF/LGmgsPz3IuKK/Lan
ZD1mZrZEZYNgK3BXPn0XcMPsDhHxVETszacPAoeBi0tu18zM2qSn5PqDEXEIICIOSbpkoc6SrgL6gKcLzR+X9FHyPYqIGJ9n3e3AdoDBwUFqtVrJ0ssZHR3teA3dwmORqdfrNJtNj0XOz4tp3T4WioiFO0j3A5fOseg24K6IGCj0PRoRp50nyJetA2rAtoh4qND2PFk47ACejojbFyt6aGgodu/evVi3c6pWqzE8PNzRGrqFxyIzPDxMvV5nzx4f4QQ/L4q6ZSwkfT8ihma3L7pHEBHXLPCgL0hal+8NrCM77DNXvwuBvwP+81QI5I99KJ8cl/QXwIcXq8fMzNqr7DmCncC2fHob8PXZHST1AV8DvhARfz1r2br8XmTnFx4tWY+ZmS1R2SC4A9giaS+wJZ9H0pCkz+Z9fh14B/D+OS4T/ZKkR4BHgLXAH5asx8zMlqjUyeKIeBl45xztu4EP5NNfBL44z/pXl9m+mZmV508Wm5klzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpa4UkEgaY2k+yTtze9Xz9OvWfh/xTsL7ZskfTdf/yv5P7o3M7NlVHaP4BbggYjYDDyQz8/lRERckd+uL7R/AvhUvv5R4KaS9ZiZ2RKVDYKtwF359F3ADWe6oiQBVwNfPZv1zcysPXpKrj8YEYcAIuKQpEvm6bdC0m6gAdwREX8LvAqoR0Qj7zMCrJ9vQ5K2A9sBBgcHqdVqJUsvZ3R0tOM1dAuPRaZer9NsNj0WOT8vpnX7WCwaBJLuBy6dY9FtS9jO5RFxUNJrgQclPQIcm6NfzPcAEbED2AEwNDQUw8PDS9h8+9VqNTpdQ7fwWGQGBgao1+sei5yfF9O6fSwWDYKIuGa+ZZJekLQu3xtYBxye5zEO5vf7JdWAK4G/AQYk9eR7BRuAg2fxM5iZWQllzxHsBLbl09uAr8/uIGm1pP58ei3wduDxiAjgW8C7F1rfzMzOrbJBcAewRdJeYEs+j6QhSZ/N+7we2C3ph2Qv/HdExOP5so8AH5K0j+ycwedK1mNmZktU6mRxRLwMvHOO9t3AB/LpbwNvmmf9/cBVZWowM7Ny/MliM7PEOQjMzBLnIDAzS5yDwMwscQ4CM7PEOQjMzBLnIDAzS5yDwMwscQ4CM7PEOQjMzBLnIDAzS5yDwMwscQ4CM7PEOQjMzBLnIDAzS5yDwMwscQ4CM7PEOQjMzBJXKggkrZF0n6S9+f3qOfr8U0l7CreTkm7Il31e0jOFZVeUqcfMzJau7B7BLcADEbEZeCCfnyEivhURV0TEFcDVwBjwzUKX35taHhF7StZjZmZLVDYItgJ35dN3ATcs0v/dwDciYqzkds3MrE3KBsFgRBwCyO8vWaT/jcCXZ7V9XNLDkj4lqb9kPWZmtkQ9i3WQdD9w6RyLblvKhiStA94E3FtovhV4HugDdgAfAW6fZ/3twHaAwcFBarXaUjbfdqOjox2voVt4LDL1ep1ms+mxyPl5Ma3bx2LRIIiIa+ZbJukFSesi4lD+Qn94gYf6deBrETFZeOxD+eS4pL8APrxAHTvIwoKhoaEYHh5erPRzqlar0ekauoXHIjMwMEC9XvdY5Py8mNbtY1H20NBOYFs+vQ34+gJ938usw0J5eCBJZOcXHi1Zj5mZLVHZILgD2CJpL7Aln0fSkKTPTnWS9BpgI/C/Z63/JUmPAI8Aa4E/LFmPmZkt0aKHhhYSES8D75yjfTfwgcL8s8D6OfpdXWb7ZmZWnj9ZbGaWOAeBmVniHARmZolzEJiZJc5BYGaWOAeBmVniHARmZolzEJiZJc5BYGaWOAeBmVniHARmZolzEJiZJc5BYGaWOAeBmVniHARmZolzEJiZJc5BYGaWOAeBmVniHARmZokrFQSSfk3SY5JakoYW6HedpCcl7ZN0
S6F9k6TvStor6SuS+srUY2ZmS1d2j+BR4F8Afz9fB0lV4NPArwBvAN4r6Q354k8An4qIzcBR4KaS9ZiZ2RKVCoKIeCIinlyk21XAvojYHxETwN3AVkkCrga+mve7C7ihTD1mZrZ0PcuwjfXAgcL8CPBW4FVAPSIahfb18z2IpO3A9nx2VNJiAXSurQVe6nAN3cJjMW2tJI9Fxs+Lad0yFq+eq3HRIJB0P3DpHItui4ivn8GGNUdbLNA+p4jYAew4g+0tC0m7I2Le8yIp8VhM81hM81hM6/axWDQIIuKaktsYATYW5jcAB8nScUBST75XMNVuZmbLaDkuH90FbM6vEOoDbgR2RkQA3wLenffbBpzJHoaZmbVR2ctH/7mkEeCfAH8n6d68/TJJ9wDk7/ZvBu4FngD+KiIeyx/iI8CHJO0jO2fwuTL1LLOuOUzVBTwW0zwW0zwW07p6LJS9MTczs1T5k8VmZolzEJiZJc5B0AaSPiwpJK3tdC2dIumPJP0/SQ9L+pqkgU7XtNzm+yqV1EjaKOlbkp7Iv4LmtztdUydJqkr6gaT/1ela5uMgKEnSRmAL8ONO19Jh9wE/HxG/ADwF3NrhepbVIl+lkpoG8B8i4vXA24APJjwWAL9NdqFM13IQlPcp4D+ywIfhUhAR3yx8Svwhss+FpGTOr1LpcE0dERGHIuIf8ulXyF4E5/3WgJ9lkjYA/wz4bKdrWYiDoARJ1wPPRcQPO11Ll/k3wDc6XcQym+urVJJ88SuS9BrgSuC7na2kY/6Y7I1iq9OFLGQ5vmvop9pCX7EB/Cfg2uWtqHPO5OtGJN1GdmjgS8tZWxdY0lempEDS+cDfAL8TEcc6Xc9yk/Qu4HBEfF/ScKfrWYiDYBHzfcWGpDcBm4AfZl+kygbgHyRdFRHPL2OJy2axrxuRtA14F/DOSO8DKvN9lUqSJPWShcCXIuJ/dLqeDnk7cL2kXwVWABdK+mJE/EaH6zqNP1DWJpKeBYYiohu+YXDZSboO+CTwSxHxYqfrWW6SeshOkr8TeI7sq1XeV/gUfTLyr5i/CzgSEb/T6Xq6Qb5H8OGIeFena5mLzxFYu/wZcAFwn6Q9kv680wUtp0W+SiU1bwf+FXB1/lzYk78rti7lPQIzs8R5j8DMLHEOAjOzxDkIzMwS5yAwM0ucg8DMLHEOAjOzxDkIzMwS9/8Be66HKqAY94EAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "5onPqgL_o9N9" }, "source": [ "## Loss Functions\n", "PyTorch has a bunch of built-in loss functions, which are just other modules that you can pass your data through." ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "id": "KutB4zljpCJz" }, "outputs": [], "source": [ "y_target = torch.sin(x)\n", "loss_fn = nn.SmoothL1Loss()" ] }, { "cell_type": "code", "execution_count": 74, "metadata": { "id": "LBtKp7iro-VR" }, "outputs": [], "source": [ "for _ in range(1000):\n", "    y = net(x)\n", "    loss = loss_fn(y, y_target)\n", "    loss.backward()\n", "    for p in net.parameters():\n", "        p.data.add_(- 0.001 * p.grad)\n", "        p.grad.data.zero_()" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "id": "8DHmqNfxpPoQ" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de3RU533u8e9Pd4HQXYj7HWwMdrAtg2/HyHfS02PSnjRxmqZO6yyaNjlt0pM2TnOadrknaznNWs3pxastJ3HjnKSx0zSJaYpDbMfCSRwDsgPmZi4WAgkh0G0Qukuj3/ljNpIsSyAxg2ZgP5+1Zmlm73fvefUyzKO99/u+29wdEREJr7RkV0BERJJLQSAiEnIKAhGRkFMQiIiEnIJARCTkFAQiIiGXkCAws6fM7IyZ7RtnvZnZ35nZUTN708xuGrHuETM7EjweSUR9RERk4hJ1RPB1YMMF1r8XWB48NgH/CGBmxcBfAOuAtcBfmFlRguokIiITkJAgcPdXgNYLFNkIfMNjXgMKzWw28CDwgru3unsb8AIXDhQREUmwjCl6n7lA3YjX9cGy8Za/i5ltInY0QW5u7s3z58+/PDWdoMHBQdLSdIkF1Bbn1dXV4e4sWLAg2VVJCfpcDEuVtjh8+HCzu5eNXj5VQWBjLPMLLH/3QvfNwGaAiooKr66uTlztLkFVVRWVlZVJrUOqUFvEVFZWEolE2L17d7KrkhL0uRiWKm1hZsfHWj5VEVUPjPwTfh7QcIHlIiIyRaYqCLYAvx30HroVOOvup4BtwANmVhRcJH4gWCYiIlMkIaeGzOzbQCVQamb1xHoCZQK4+z8BW4FfAY4CXcDvBOtazeyvgF3Brh539wtddBYRkQRLSBC4+4cust6BT4yz7ingqUTUQ0REJi/5l7FFRCSpFAQiIiGnIBARCTkFgYhIyCkIRERCTkEgIhJyCgIRkZBTEIiIhJyCQEQk5BQEIiIhpyAQEQk5BYGISMgpCEREQk5BICIScgoCEZGQUxCIiIScgkBEJOQUBCIiIZeQIDCzDWZ2yMyOmtljY6z/ipntDh6HzSwyYl10xLotiaiPiIhMXNz3LDazdOBJ4H6gHthlZlvc/cD5Mu7+6RHl/wdw44hddLv7mnjrISIilyYRRwRrgaPuXuPufcAzwMYLlP8Q8O0EvK+IiCRAIoJgLlA34nV9sOxdzGwhsBj4yYjFOWZWbWavmdn7ElAfERGZhLhPDQE2xjIfp+zDwHfdPTpi2QJ3bzCzJcBPzGyvu7/9rjcx2wRsAigvL6eqqirOaseno6Mj6XVIFWqLmEgkQjQaVVsE9LkYluptkYggqAfmj3g9D2gYp+zDwCdGLnD3huBnjZlVEbt+8K4gcPfNwGaAiooKr6ysjLfecamqqiLZdUgVaouYwsJCIpGI2iKgz8WwVG+LRJwa2gUsN7PFZpZF7Mv+Xb1/zOwaoAj4xYhlRWaWHTwvBe4ADozeVkRELp+4jwjcfcDMPglsA9KBp9x9v5k9DlS7+/lQ+BDwjLuPPG20EvhnMxskFkpPjOxtJCIil18iTg3h7luBraOWfWHU678cY7tXgesTUQcREbk0GlksIhJyCgIRkZBTEIiIhJyCQEQk5BQEIiIhpyAQEQk5BYGISMgpCEREQk5BICIScgoCEZGQUxCIiIScgkBEJOQUBCIiIacgEBEJOQWBiEjIKQhEREJOQSAiEnIKAhGRkFMQiIiEXEKCwMw2mNkhMztqZo+Nsf6jZtZkZruDx8dGrHvEzI4Ej0cSUR8REZm4uG9eb2bpwJPA/UA9sMvMtrj7gVFFn3X3T47athj4C6ACcOD1YNu2eOslIiITk4gjgrXAUXevcfc+4Blg4wS3fRB4wd1bgy//F4ANCaiTiIhMUNxHBMBcoG7E63pg3Rjl/ruZ3QUcBj7t7nXjbDt3rDcxs03AJoDy8nKqqqrir3kcOjo6kl6HVKG2iIlEIkS
jUbVFQJ+LYaneFokIAhtjmY96/R/At92918w+DjwN3DPBbWML3TcDmwEqKiq8srLykiucCFVVVSS7DqlCbRFTWFhIJBJRWwT0uRiW6m2RiFND9cD8Ea/nAQ0jC7h7i7v3Bi//L3DzRLcVEZHLKxFBsAtYbmaLzSwLeBjYMrKAmc0e8fIh4GDwfBvwgJkVmVkR8ECwTEREpkjcp4bcfcDMPknsCzwdeMrd95vZ40C1u28B/tDMHgIGgFbgo8G2rWb2V8TCBOBxd2+Nt04iIjJxibhGgLtvBbaOWvaFEc8/B3xunG2fAp5KRD1ERGTyNLJYRCTkFAQiIiGnIBARCTkFgYhIyCkIRERCTkEgIhJyCgIRkZBTEIiIhJyCQEQk5BQEIiIhpyAQEQm5KzII6tu62bKngbbOvmRXRUTkipeQSeemWltXH3/47V9iBjfMK2T9ijLWryjlPfMKyUi/IrNNRCRprsggOM8d9tRF2FMX4e9eOkJ+TgZ3Li9l/Yoy7lpRxuyC3GRXUUQk5V3RQTBae88AW/c2snVvIwAryvO4a3kZ668p45ZFxeRkpie5hiIiqeeqCoLRDp/u4PDpDr76s2PkZKaxbnFJ7DTSNWUsKZ2O2Vi3TBYRCZcrMghKpmexpHQ6Nc2dE96mp3+Q7Yeb2H64CX4IcwtzWX9NGXctL+OOZSXMyMm8jDUWEUldV2QQzCnM5SefqeRESxfbjzTxyuEmXj3aTGdfdML7OBnp5l93nOBfd5wgI824aUHRUDCsmpNPWpqOFkQkHBISBGa2AfhbYvcs/qq7PzFq/R8DHyN2z+Im4Hfd/XiwLgrsDYqecPeHJvq+C0qm8ZGShXzk1oX0DQzyxok2Xgn+6t/f0D7h+g8MOjtrW9lZ28qXtx2iZHoW/2V5KeuvKeO/LC+jNC97wvsSEbnSxB0EZpYOPAncD9QDu8xsi7sfGFHsl0CFu3eZ2e8Dfw18MFjX7e5r4q1HVkYaty4p4dYlJfzphmtpOtfLT4OjhVeONNM6iTEHLZ19/GB3Az/Y3QDA6rn5sZ5Iy8u4aWERmeqiKiJXkUQcEawFjrp7DYCZPQNsBIaCwN1fHlH+NeC3EvC+F1Q2I5tfv2kev37TPAYHnf0N7Ww/fIZXDjfz+ok2ooM+4X3tO9nOvpPtPPny2+RlZ3D70hJm0c/S1i7mF0+7jL+FiMjll4ggmAvUjXhdD6y7QPlHgedHvM4xs2pip42ecPcfjLWRmW0CNgGUl5dTVVU16YquToPV10LX0lwOtETZ1xxlb3OUlp6Jh0JH7wA/PnAagG8ceJlZ04zVpelcX5bOtcXpZKeH79pCR0fHJf17XG0ikQjRaFRtEdDnYliqt0UigmCsb74xv1nN7LeACmD9iMUL3L3BzJYAPzGzve7+9rt26L4Z2AxQUVHhlZWVcVX6V4b3y9tNHWw/3Mz2w03sqGmhd2Bwwvtp7HIaTwzw4okBsjLSWLuoeKiL6vKZeaHoolpVVUW8/x5Xg8LCQiKRiNoioM/FsFRvi0QEQT0wf8TreUDD6EJmdh/weWC9u/eeX+7uDcHPGjOrAm4E3hUEl4uZsWzmDJbNnMGjdy6mpz/KjmOtQxedj57pmPC++gYG+dnRZn52tJkvbj3IrPycoVHOdy4rpWCauqiKSOpJRBDsApab2WLgJPAw8JsjC5jZjcA/Axvc/cyI5UVAl7v3mlkpcAexC8lJk5OZHsxdVMafE+tm+tMgFH52tJlzPQMT3ldjew/PVtfxbHUdaQZr5heyfsVM7lpRyg3zCklXF1URSQFxB4G7D5jZJ4FtxLqPPuXu+83scaDa3bcAXwbygH8LTpWc7ya6EvhnMxskNhPqE6N6GyXd3MJcHl67gIfXLmAgOsjuugjbDzfxw9drqG0fxCd4eWHQ4Y0TEd44EeErLx6mcFomdy4rHQqdmfk5l/cXEREZR0LGEbj7VmDrqGVfGPH8vnG2exW4PhF1mAoZ6WlULCqmYlExN2e
d4oZbbuenR2JHC68cbqa5o/fiOwlEuvr54Zun+OGbpwBYOTvoorqilIqFxWRlqIuqiEyNK3Jkcaoonp7FxjVz2bhmLoODzsHG9iAUmnj9eBv90Yn3Rjp4qp2Dp9r5p+1vMz0rnduWxga0Va4oUxdVEbmsFAQJkpZmrJpTwKo5BfxB5TI6egf4xdstbD98hu2Hm6hr7Z7wvjr7orx48DQvHox1U11SOp27glNIty4pITdLs6iKSOIoCC6TvOwM7r+unPuvK8fdqW3pio1yPtzEq2+30N0/8XmRapo7qWnu5Ouv1oa2i6qIXD4KgilgZiwunc7i0uk8cvsiegeivF7bxvYjTWw/1MRbjecmvK/RXVRnF+QMXXC+fVkpBbnqoioik6MgSILsjHRuX1bK7ctK+dx7V9J4tic2buFIEz870szZ7v4J7+vU2R6e2VXHM7vqSE8zblpQyL0ry7lv5UyWluloQUQuTkGQAmYV5PCBW+bzgVvmMxAdZE/92aEBbXvqIxPuohoddHbVtrGrto0nnn+LRSXTeGDVLB5cNYsb5xdqam0RGZOCIMVkpKdx88Iibl5YxKfvX0FbZx+vHIl1T91+uGlSXVRrW7rY/EoNm1+pYeaMbB5YVc6GVbNZt6RYM6iKyBAFQYorGqOLaiwUzlBd28bABGdRPXOul2++doJvvnaCgtxM7l05kw2rZnHXijLdy1kk5BQEV5CRXVR/v3Ip53r6gy6qsdNI9W0T66J6truf771xku+9cZLczHTuvraMB1fN4u5rZ5KvW3aKhI6C4Ao2IyeTB1bN4oFVs4ZmUX3x4BlePHCaN060MZGDhe7+KFv3NrJ1byOZ6cYdy0rZsGoW911XrjuziYSEguAqMXIW1Y+vX0pLRy8vHjzNj/Y18vOjLfRFLz61dn/UqTrURNWhJtK+v5eKRcVsWDWLB1fPYm5h7hT8FiKSDAqCq1RJXjYfvGUBH7xlAed6+nn5UBPb9jXy8qEzdPVdfDDboMPOY63sPNbK4z88wPVzC9iwOtYDadnMvCn4DURkqigIQmBGTiYPvWcOD71nDj39UX56pJlt+xt58eBpIl0TG7Ow9+RZ9p48y5e3HWJp2XQ2rJ7FhlWzWT03/zLXXkQuNwVByORkpg9NfdEfHWTnsVa27W9k2/5GTrdPrGvq202dPPny2zz58tvMLczluoJ+che0ULGoWPdYELkCKQhCLDM9jTuWlXLHslL+8r+tYnd9hG37G/nRvkaOt3RNaB8nI92cjMALm1+jZHoWD6wq58FVs7h9aamm0ha5QigIBIh1Tb1pQRE3LSjisQ3Xcuj0OZ7fGztSmOhcSC2dfXx7Zx3f3lnHjJwM7rk2NlZh/TVlTMvSR00kVel/p7yLmXHtrHyunZXPp+9fwfGWTn60LxYKb5yITGgf53oGeG53A8/tbiA7I427VpTFuqWuLNe9m0VSjIJALmphyXR+b/1Sfm/9Uk639/DjA6fZtq+RX9S0EJ3AYIXegUFeOHCaFw6cJiPNuG1pCQ+umsUD15XrFp0iKSAhQWBmG4C/JXbP4q+6+xOj1mcD3wBuBlqAD7p7bbDuc8CjQBT4Q3fflog6yeVRnp/DR25dyEduXUhbZx8vvXWGb1bt40Cb0zdw8bEKA4POT48089Mjzfz5c/u4aUFRbKzCqlksKNGd2ESSIe4gMLN04EngfqAe2GVmW0bdhP5RoM3dl5nZw8CXgA+a2XXAw8AqYA7wopmtcPeJ37VFkqZoehbvv3kepeeOcsttd7L9cBPP72vk5bfO0NE7cNHt3eH14228fryNL249yMrZ+dx77UzWX1PGjfMLyQjhxHjuTqSrn9qWTk60dlHX2sWJ1q6hqckNwyx2TSfdjDSD3Kx08nMzKczNomhaJrMLc5kbPHQ3O5mIRBwRrAWOunsNgJk9A2wERgbBRuAvg+ffBf7BYhPlbwSecfde4JiZHQ3294sLveGhQ4eorKxMQNUvXSQSobCwMKl
1SBWj26LYnYzuAVo7+2jr6qN/AqOaARqBl4Pn6WlGQW4m+TmZ5OdmMu0K+ELbvXs3AwMDF/1sOrEbDPX2R+kZ8bOnP0pPf3RCp9smKjM9jezMNHIy0snOTCM3M52czHRyM9Mve1df/R8ZluptkYggmAvUjXhdD6wbr4y7D5jZWaAkWP7aqG3njvUmZrYJ2ASQmZlJJDKxi5aXSzQaTXodUsV4bVGcCUUFRvdAGuf6nI4+6J/gl1x00Gnt7KO1sw+AdIPcDIs9MiE3PfaXcSoZGBiI/UUftEXUoWfA6Y3Gfu++KPQPxqbySNxX/YX1Rwfpjw7SwbuP0DLSjKx0yEqD7HQjJyP2M1H5oP8jw1K9LRIRBGN9bEZ/zscrM5FtYwvdNwObASoqKry6unoydUy4qqqqpB+VpIqJtoW7s/fkWX60LzZWoaa585Lerw8gPY33zC9g3eIS1i0ppmJhcdJOg7g7zR19PHDfPbSc7eCeP/s6u+si1DTFfr80IDt4pLJeoM9gaVkeq+fks3puAavnFrBqTj4zLmFWWv0fGZYqbTHeHQsTEQT1wPwRr+cBDeOUqTezDKAAaJ3gtnKVMDNumFfIDfMK+ZMHr+HomY7YALb9jew72T6pffVFB4fuxvYPL0NmunHj/CJuXVrCLYuKWDO/8JK+vKKDTlffAN39Ubr7onT3R+kbGIydyhkYJNLVT0tnL80dfdS3dXGsuZOapk7OdvfTeCr2O3zvjZOTft9U4Q5Hz3Rw9EwHP9g9/F9xcel0Vs3J5/ogHFbPKVA34KtIIoJgF7DczBYDJ4ld/P3NUWW2AI8QO/f/fuAn7u5mtgX4VzP7G2IXi5cDOxNQJ0lxZsby8hksL5/BJ+9ZzslIN1WHzrD9UBM/P9pM5wQmxhupP+rsrG1lZ21rsH+4pnwGS2fmMa8wl7lFuWSlp9HRO0B7zwDt3f00d/QOnX5q7+6nvWdgQhe5L7esjDQWFE97x6M8P4c0ix0uu0PUncFBHwqus939RLpiv9PJSDcn27ppbO+Z0FTkE3GsuZNjzZ388M1TQ8vmF+eyek7B0JHD9XMLKJ6elZg3lCkVdxAE5/w/CWwj1n30KXffb2aPA9XuvgX4GvD/govBrcTCgqDcd4hdWB4APqEeQ+E0tzCXD69byIfXLaRvYJA36yPsONbKazUtVNe20d0/uY+FO7zVeG7Co6KnWkFuJgtLpjG/eBoLz3/hl0xjUcl0ZuXnJOT+0v3RQRoi3dS1dnO8tZPjLV3UNHVQ09TJ8dauuC9K17XG9v38vsahZfOKcrlhXgHvmVeIt0a5pXeA6dkarpTqEvIv5O5bga2jln1hxPMe4DfG2faLwBcTUQ+5OmRlpFGxqJiKRcV84u5l9EcHeevUOaqPt1J9vI2dx1ppOjfxezcn09Ky6ayZX8SymXksKJ4W+/IvmjYlp1Uy09NYWDKdhSXTuZPSd6zrGxjkRGtXcBroHPsb2tnXcJa61ond5W489W3d1Ld1s3VvLBz+etc2VpTP4OaFRdy+tJTblpboqCEFKaol5WWmp3H9vAKun1fA79yxGHentqWLHTUtvFbTwi9qWiY8c+rlkpWexrSsDDIsymPvvZaVs/NZM68wZc+jZ2WksWxmXnBviVlDyyNdfew72c7+hrPsa2hn38mzHLvEi/oQu6/F+SOzb+04AcDK2fncvrSE25aUsHZJsW6PmgIUBHLFMTMWl05ncel0Hl67AHfnWHMnv6hp4fXjbbxxvI3aCc6eOpZpWelMy0onNyt9qP99VnoaWRlp5OdkUpKXRcn0bMpmZLOodDpLSqczpzCXe3/+JSKRCB9fvzSBv+3UKpyWxZ3LS7lz+fARxLmefg40tLP35Fn2Bz/fburAL/HM0sFT7Rw81c7XfnaMNINVcwpYu7iYtYuLuWVRsY4YkkBBIFc8M2NJWR5LyvL48LqFADR39HKo8Rwn27qpj3RzKhI75ZGXk8GM7AzycjIomZ499KVeOC0
2eC0vJ0P3VBhlRk4m65aUsG5JydCyzt4BDp6KHTHsPRn7eeTMuUlfnB704Zsefe1nxwBYUZ7HLYtiwbBucQmzCjQf1eWmIJCrUmleNqXLUr3n/pVrenbG0HWc8zp7B9jf0M6b9RF210XYcaSRpu7JHzYcPt3B4dMdQ6eSFhRPGzpiWLe4mAXF08btDy+XRkEgIgkxPTtj6AsbYoOoVlfcRnVtK6++3cLPjzbzdtPkrzecCOZb+u7r9QCU52ezdnFJ7L0WFbN8Zl5CelmFmYJARC6b0rxsNqyezYbVswE43d7DazUtvHo0dpH/ROvkr+Wcbu/lP/Y08B97YgPeCqdlcsui2NHCLYuKWTUnP5QTFsZDQSAiU6Y8P4eNa+aycU1sSrFTZ7vZeayVncda2XGslaNnOia9z0hX/9D9LgCmZ6Vz08Ii1i0uZu3iEm6YV0BOZupPWphMCgIRSZrZBbnvCIaWjl521bax41gLu2pbOdDQPukL0J190aF7XkCsq+yaeYVDp61uWlhEnga5vYNaQ0RSRkleNhtWz2LD6tjYhnM9/VQfb2NXcNSwpz5Cf3RyydA3MDg8/cjLsSnOV8/Jj51OWhKbm6pwWri7rCoIRCRlzcjJ5O5rZnL3NTMB6OmP8ssTEXbVxoLh9eOTn34kOujsqT/LnvqzfDXosnrtrBlD4xjWLS4O3S1UFQQicsXIyUzntqUl3LY0NqahPzrIvpNnh64z7Kptpb1n8hMHnh/9/I1fHAdgUcn5LqslrF1UzPzi3Ku6y6qCQESuWJnpady4oIgbFxTxe+uXMjjoHDp97h0XoJs7Jj/9SG1LF7UtXXynOtZldVZ+zjvGMiybmXdVBYOCQESuGmlpxsrZ+aycnc8jty96x7xUO4PTSfVtk59Yr7G9hy17GtgSdFktCrqsnh/9vHL2jCu6y6qCQESuWqPnpQI4Gelm57EWdh5rY+exlksa5NbW1c+PD5zmx0GX1bzsjBFdVou5YV4B2RlXTpdVBYGIhMrcwlx+7cZ5/NqN84DYvFTVtbHTSDtqWjnY2D7pCfU6egd45XATrxxuAmJdVm+cP9xltWdgqu5SfWkUBCISaqNHP5/t7ueN423srG1lR00Le0+evaQuqzuCaxQAaQbXH/p57IhhUax3UipNUa4gEBEZoSA3k7uvncnd18a6rHb3RfllXRs7amK9kt440UZP/+Ck9jnosKcuwp66CJtfqRm6ler50c+3LC5i5ozkdVlVEIiIXEBuVjq3Ly3l9qWxezT0DQyyr+HsUDDsqm3l3CS7rI68lerTQZfVxaXTWRtcgF67uJh5RVPXZTWuIDCzYuBZYBFQC3zA3dtGlVkD/COQD0SBL7r7s8G6rwPrgbNB8Y+6++546iQicjllZaRx04IiblpQxO+zlOig81Zje2z0c9Azqbmjb9L7PdbcybHmTp6trgNgTkEOt4zosrq07PJ1WY33iOAx4CV3f8LMHgtef3ZUmS7gt939iJnNAV43s23uHgnW/4m7fzfOeoiIJEV6mrFqTgGr5hTw0eBWqjXNncNjGWpaaDjbM+n9Npzt4bndDTy3O9ZltXh6FrcsKmLt4hLWLS5m5ez8hN1EKd4g2AhUBs+fBqoYFQTufnjE8wYzOwOUARFERK4yZsbSsjyWluXxoaDL6nef/wlp5Sti4VDbSs0ldFlt7exj2/7TbNsf67I6IzuDmxcVDU2LccO8QrIyLm0sg/ml3ngUMLOIuxeOeN3m7kUXKL+WWGCscvfB4NTQbUAv8BLwmLuPOQzQzDYBmwDKy8tvfuaZZy653onQ0dFBXl5eUuuQKtQWMZ/61KeIRqP8/d//fbKrkhL0uRg2ui3O9jqH26IcbotyqHWQunODxNvBNDMNlhWmsaIonWuK01lakEZ2xjuPGO6+++7X3b1i9LYXDQIzexGYNcaqzwNPTzQIzGw2sSOGR9z9tRHLGoEsYDPwtrs/fsEKARUVFV5dXX2xYpdVVVUVlZW
VSa1DqlBbxFRWVhKJRNi9W5e5QJ+LkS7WFme7+6muHb7GsLf+LAOTnX97lIw0Y/XcAm5aUMSaBYXcOL+QBSXTxwyCi54acvf7xltnZqfNbLa7nwq+1M+MUy4f+E/gf50PgWDfp4KnvWb2L8BnLlYfEZGrTUFuJveuLOfeleUAdPUNsPtEhB3BdYY3TrTROzC5LqsDg87uutj9o/n5hcvGe41gC/AI8ETw87nRBcwsC/g+8A13/7dR686HiAHvA/bFWR8RkSvetKwMbl9Wyu3Lhrus7j0ZGRr9/PrxNjp6Jz/L6njiDYIngO+Y2aPACeA3AMysAvi4u38M+ABwF1BiZh8NtjvfTfRbZlYGGLAb+Hic9RERuepkZaRx88Jibl5YzB9Uxu6pcPBUOzuOtQ51W23tnHyX1fPiCgJ3bwHuHWN5NfCx4Pk3gW+Os/098by/iEgYpQfn/1fPLeDRO2NdVt9u6hgKhh3HWjk1iS6rGlksInKFMzOWzZzBspkz+PC6hbg79W3dvHGibeg6wf6G9nG3VxCIiFxlzIz5xdOYXzyNjWvmArHrDNlfHLv8lXsnBRERmbALDTZTEIiIhJyCQEQk5BQEIiIhpyAQEQk5BYGISMgpCEREQk5BICIScgoCEZGQUxCIiIScgkBEJOQUBCIiIacgEBEJOQWBiEjIKQhEREJOQSAiEnJxBYGZFZvZC2Z2JPhZNE65qJntDh5bRixfbGY7gu2fDW50LyIiUyjeI4LHgJfcfTnwUvB6LN3uviZ4PDRi+ZeArwTbtwGPxlkfERGZpHiDYCPwdPD8aeB9E93QzAy4B/jupWwvIiKJEe89i8vd/RSAu58ys5njlMsxs2pgAHjC3X8AlAARdx8IytQDc8d7IzPbBGwCKC8vp6qqKs6qx6ejoyPpdUgVaouYSCRCNBpVWwT0uRiW6m1x0SAwsxeBWWOs+vwk3meBuzeY2RLgJ2a2F2gfo5yPtwN33wxsBqioqPDKyspJvH3iVVVVkew6pAq1RUxhYSGRSERtEdDnYliqt8VFg8Dd7xtvnZmdNrPZwdHAbODMOPtoCH7WmFkVcCPw70ChmWUERwXzgIZL+B1ERCQO8V4j2AI8Ejx/BEDEj2UAAAaDSURBVHhudAEzKzKz7OB5KXAHcMDdHXgZeP+FthcRkcsr3iB4ArjfzI4A9wevMbMKM/tqUGYlUG1me4h98T/h7geCdZ8F/tjMjhK7ZvC1OOsjIiKTFNfFYndvAe4dY3k18LHg+avA9eNsXwOsjacOIiISH40sFhEJOQWBiEjIKQhEREJOQSAiEnIKAhGRkFMQiIiEnIJARCTkFAQiIiGnIBARCTkFgYhIyCkIRERCTkEgIhJyCgIRkZBTEIiIhJyCQEQk5BQEIiIhpyAQEQk5BYGISMjFFQRmVmxmL5jZkeBn0Rhl7jaz3SMePWb2vmDd183s2Ih1a+Kpj4iITF68RwSPAS+5+3LgpeD1O7j7y+6+xt3XAPcAXcCPRxT5k/Pr3X13nPUREZFJijcINgJPB8+fBt53kfLvB553964431dERBIk3iAod/dTAMHPmRcp/zDw7VHLvmhmb5rZV8wsO876iIjIJGVcrICZvQjMGmPV5yfzRmY2G7ge2DZi8eeARiAL2Ax8Fnh8nO03AZsAysvLqaqqmszbJ1xHR0fS65Aq1BYxkUiEaDSqtgjoczEs1dviokHg7veNt87MTpvZbHc/FXzRn7nArj4AfN/d+0fs+1TwtNfM/gX4zAXqsZlYWFBRUeGVlZUXq/plVVVVRbLrkCrUFjGFhYVEIhG1RUCfi2Gp3hbxnhraAjwSPH8EeO4CZT/EqNNCQXhgZkbs+sK+OOsjIiKTFG8QPAHcb2ZHgPuD15hZhZl99XwhM1sEzAe2j9r+W2a2F9gLlAL/O876iIjIJF301NCFuHsLcO8Yy6uBj414XQvMHaPcPfG8v4iIxE8ji0VEQk5BICIScgoCEZGQUxCIiIScgkB
EJOQUBCIiIacgEBEJOQWBiEjIKQhEREJOQSAiEnIKAhGRkFMQiIiEnIJARCTkFAQiIiGnIBARCTkFgYhIyCkIRERCTkEgIhJyCgIRkZCLKwjM7DfMbL+ZDZpZxQXKbTCzQ2Z21MweG7F8sZntMLMjZvasmWXFUx8REZm8eI8I9gG/DrwyXgEzSweeBN4LXAd8yMyuC1Z/CfiKuy8H2oBH46yPiIhMUlxB4O4H3f3QRYqtBY66e4279wHPABvNzIB7gO8G5Z4G3hdPfUREZPIypuA95gJ1I17XA+uAEiDi7gMjls8dbydmtgnYFLzsMLOLBdDlVgo0J7kOqUJtMazUzNQWMfpcDEuVtlg41sKLBoGZvQjMGmPV5939uQm8sY2xzC+wfEzuvhnYPIH3mxJmVu3u414XCRO1xTC1xTC1xbBUb4uLBoG73xfne9QD80e8ngc0EEvHQjPLCI4Kzi8XEZEpNBXdR3cBy4MeQlnAw8AWd3fgZeD9QblHgIkcYYiISALF233018ysHrgN+E8z2xYsn2NmWwGCv/Y/CWwDDgLfcff9wS4+C/yxmR0lds3ga/HUZ4qlzGmqFKC2GKa2GKa2GJbSbWGxP8xFRCSsNLJYRCTkFAQiIiGnIEgAM/uMmbmZlSa7LsliZl82s7fM7E0z+76ZFSa7TlNtvKlUwsbM5pvZy2Z2MJiC5o+SXadkMrN0M/ulmf0w2XUZj4IgTmY2H7gfOJHsuiTZC8Bqd78BOAx8Lsn1mVIXmUolbAaA/+nuK4FbgU+EuC0A/ohYR5mUpSCI31eAP+UCg+HCwN1/PGKU+GvExoWEyZhTqSS5Tknh7qfc/Y3g+TliX4LjzhpwNTOzecB/Bb6a7LpciIIgDmb2EHDS3fckuy4p5neB55NdiSk21lQqofzyG8nMFgE3AjuSW5Ok+T/E/lAcTHZFLmQq5hq6ol1oig3gz4AHprZGyTOR6UbM7PPETg18ayrrlgImNWVKGJhZHvDvwKfcvT3Z9ZlqZvarwBl3f93MKpNdnwtREFzEeFNsmNn1wGJgT2wiVeYBb5jZWndvnMIqTpmLTTdiZo8Avwrc6+EboDLeVCqhZGaZxELgW+7+vWTXJ0nuAB4ys18BcoB8M/umu/9Wkuv1LhpQliBmVgtUuHsqzDA45cxsA/A3wHp3b0p2faaamWUQu0h+L3CS2NQqvzliFH1oBFPMPw20uvunkl2fVBAcEXzG3X812XUZi64RSKL8AzADeMHMdpvZPyW7QlPpIlOphM0dwEeAe4LPwu7gr2JJUToiEBEJOR0RiIiEnIJARCTkFAQiIiGnIBARCTkFgYhIyCkIRERCTkEgIhJy/x88oJxHEOp6FgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "WX-qNKKPn2Ft" }, "source": [ "## Optimizers" ] }, { "cell_type": "markdown", "metadata": { "id": "pPhBZESWE-9M" }, "source": [ "We can use more fancy optimizers with the `optim` package." ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "id": "pZ0ptKURn5TE" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "net = Net2(input_size=1, output_size=1)\n", "\n", "optimizer = optim.Adam(net.parameters(), lr=1e-3)\n", "\n", "x = torch.linspace(-5, 5, 100).view(-1, 1)\n", "y = net(x)\n", "y_target = torch.sin(x)\n", "loss_fn = nn.MSELoss()" ] }, { "cell_type": "markdown", "metadata": { "id": "TTmTVBaDFDP-" }, "source": [ "Here's the network before training" ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "id": "vR278Ii7ooXa" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXxV9Z3/8dcnCSHIlrAFJGGToOCGElHcCJtia5W21qVjpa0OtaNWXKa105l2xrYzdkarU+u00paq1Z+4tZVal7oQcGELGEVBtrCFNZLcQFgScvP5/XGvkmJuEriX3BvO+/l45JF7zvme3I9fD+d971m+x9wdEREJrrRkFyAiIsmlIBARCTgFgYhIwCkIREQCTkEgIhJwCgIRkYBLSBCY2Uwz22FmH8RYbmb2CzNbY2bvm9mZjZZNNbPV0Z+piahHRERaL1HfCB4BJjez/BKgIPozDfgVgJn1AH4EnA2MBn5kZjkJqklERFohIUHg7vOAymaaXA485hELgGwz6wdcDLzq7pXuXgW8SvOBIiIiCZbRRu/TH9jUaLo8Oi/W/M8ws2lEvk3QqVOnUfn5+Uen0lZqaGggLU2nWEB98YlNmzbh7gwYMCDZpaQEbRcHpUpfrFq16mN3733o/LYKAmtinjcz/7Mz3WcAMwAKCwu9pKQkcdUdgeLiYoqKipJaQ6pQX0QUFRURCoUoLS1NdikpQdvFQanSF2a2oan5bRVR5UDjj/B5wJZm5ouISBtpqyCYDVwXvXroHKDa3bcCrwAXmVlO9CTxRdF5IiLSRhJyaMjMngSKgF5mVk7kSqAOAO7+a+BF4HPAGmAv8I3oskoz+zGwOPqn7nb35k46i4hIgiUkCNz9mhaWO3BTjGUzgZmJqENERA5f8k9ji4hIUikIREQCTkEgIhJwCgIRkYBTEIiIBJyCQEQk4BQEIiIBpyAQEQk4BYGISMApCEREAk5BICIScAoCEZGAUxCIiAScgkBEJOAUBCIiAacgEBEJOAWBiEjAKQhERAIuIUFgZpPNbKWZrTGzu5pYfr+ZlUZ/VplZqNGycKNlsxNRj4iItF7czyw2s3TgIWASUA4sNrPZ7r78kzbufluj9rcAZzT6E/vcfWS8dYiIyJFJxDeC0cAady9z9zpgFnB5M+2vAZ5MwPuKiEgCJCII+gObGk2XR+d9hpkNBAYDbzSanWVmJWa2wMymJKAeERE5DHEfGgKsiXkeo+3VwLPuHm40b4C7bzGzIcAbZrbM3dd+5k3MpgHTAHJzcykuLo6z7PjU1NQkvYZUob6ICIVChMNh9UWUtouDUr0vEhEE5UB+o+k8YEuMtlcDNzWe4e5bor/LzKyYyPmDzwSBu88AZgAUFhZ6UVFRvHXHpbi4mGTXkCrUFxHZ2dmEQiH1RZS2i4NSvS8ScWhoMVBgZoPNLJPIzv4zV/+Y2YlADjC/0bwcM+sYfd0LOA9Yfui6IiJy9MT9jcDd683sZuAVIB2Y6e4fmtndQIm7fxIK1wCz3L3xYaPhwMNm1kAklO5pfLWRiIgcfYk4NIS7vwi8eMi8Hx4y/e9NrPcOcGoiahARkSOjO4tFRAJOQSAiEnAKAhGRgFMQiIgEnIJARCTgFAQiIgGnIBARCTgFgYhIwCkIREQCTkEgIhJwCgIRkYBTEIiIBJyCQEQk4BQEIiIBpyAQEQk4BYGISMApCEREAk5BICIScAoCEZGAS0gQmNlkM1tpZmvM7K4mln/dzCrMrDT6c0OjZVPNbHX0Z2oi6hERkdaL++H1ZpYOPARMAsqBxWY2292XH9L0KXe/+ZB1ewA/AgoBB5ZE162Kty4REWmdRHwjGA2scfcyd68DZgGXt3Ldi4FX3b0yuvN/FZicgJpERKSV4v5GAPQHNjWaLgfObqLdl83sQmAVcJu7b4qxbv+m3sTMpgHTAHJzcykuLo6/8jjU1NQkvYZUob6ICIVChMNh9UWUtouDUr0vEhE
E1sQ8P2T6L8CT7l5rZjcCjwLjW7luZKb7DGAGQGFhoRcVFR1xwYlQXFxMsmtIFeqLiOzsbEKhkPoiStvFQaneF4k4NFQO5DeazgO2NG7g7jvdvTY6+RtgVGvXFRGRoysRQbAYKDCzwWaWCVwNzG7cwMz6NZq8DFgRff0KcJGZ5ZhZDnBRdJ6IiLSRuA8NuXu9md1MZAeeDsx09w/N7G6gxN1nA98xs8uAeqAS+Hp03Uoz+zGRMAG4290r461JRERaLxHnCHD3F4EXD5n3w0avvw98P8a6M4GZiahDREQOn+4sFhEJOAWBiEjAKQhERAJOQSAiEnAKAhGRgFMQiIgEnIJARCTgFAQiIgGnIBARCTgFgYhIwCkIREQCTkEgIhJwCgIRkYBTEIiIBJyCQEQk4BQEIiIBpyAQEQk4BYGISMAlJAjMbLKZrTSzNWZ2VxPLbzez5Wb2vpm9bmYDGy0Lm1lp9Gf2oeuKiMjRFfczi80sHXgImASUA4vNbLa7L2/U7F2g0N33mtm3gf8Groou2+fuI+OtQ0REjkwivhGMBta4e5m71wGzgMsbN3D3Oe6+Nzq5AMhLwPuKiEgCxP2NAOgPbGo0XQ6c3Uz764GXGk1nmVkJUA/c4+5/bmolM5sGTAPIzc2luLg4nprjVlNTk/QaUoX6IiIUChEOh9UXUdouDkr1vkhEEFgT87zJhmbXAoXA2EazB7j7FjMbArxhZsvcfe1n/qD7DGAGQGFhoRcVFcVdeDyKi4tJdg2pQn0RkZ2dTSgUUl9Eabs4KNX7IhGHhsqB/EbTecCWQxuZ2UTgB8Bl7l77yXx33xL9XQYUA2ckoCYREWmlRATBYqDAzAabWSZwNfB3V/+Y2RnAw0RCYEej+Tlm1jH6uhdwHtD4JLOIiBxlcR8acvd6M7sZeAVIB2a6+4dmdjdQ4u6zgf8BugDPmBnARne/DBgOPGxmDURC6Z5DrjYSEZGjLBHnCHD3F4EXD5n3w0avJ8ZY7x3g1ETUICIiR0Z3FouIBJyCQEQk4BQEIiIBpyAQEQk4BYGISMApCEREAk5BICIScAoCEZGAUxCIiAScgkBEJOAUBCIiAacgEBEJOAWBiEjAKQhERAJOQSAiEnAKAhGRgFMQiIgEnIJARCTgEhIEZjbZzFaa2Rozu6uJ5R3N7Kno8oVmNqjRsu9H5680s4sTUY+IiLRe3EFgZunAQ8AlwAjgGjMbcUiz64Eqdx8K3A/8LLruCOBq4GRgMvB/0b8nIiJtJBEPrx8NrHH3MgAzmwVcDixv1OZy4N+jr58FfmlmFp0/y91rgXVmtib69+Y394YrV66kqKgoAaUfuVAoRHZ2dlJrSBXqi4jS0lLq6+uTvm2mCm0XB6V6XyQiCPoDmxpNlwNnx2rj7vVmVg30jM5fcMi6/Zt6EzObBkwD6NChA6FQKAGlH7lwOJz0GlKF+iKivr4ed1dfRGm7OCjV+yIRQWBNzPNWtmnNupGZ7jOAGQCFhYVeUlJyODUmXHFxsT75RakvIoqKigiFQpSWlia7lJSg7eKgVOmLyIGYz0rEyeJyIL/RdB6wJVYbM8sAugOVrVxXRESOokQEwWKgwMwGm1kmkZO/sw9pMxuYGn19BfCGu3t0/tXRq4oGAwXAogTUJCIirRT3oaHoMf+bgVeAdGCmu39oZncDJe4+G/gd8IfoyeBKImFBtN3TRE4s1wM3uXs43ppERKT1EnGOAHd/EXjxkHk/bPR6P/CVGOv+FPhpIuoQEZHDpzuLRUQCTkEgIhJwCgIRkYBTEIiIBJyCQEQk4BQEIiIBpyAQEQk4BYGISMApCEREAi4hdxaLpIq6+gYqamrZsWs/O2vq2FNXz966MHvrwuw/EKauvoG6cAPhhsggt0ZkRMasDml0zszguI7ppJtRF2442Dbs1Dc44QYn7JHf9WEnMyONft2zOD67E8dnZ9HgTQ6cK5LyFATS7tTU1rNy224+2raL1dtr2Bzax5bQPrZW76d
yT13S6tq+vorMNPjus+9xen42owf1YGifLjGH/hVJFQoCSUnhBqdqbx2Ve+qo2F3Liq27eL+8mvfLQ6zfuTfZ5TXJ3akNw9Ml5TxdUg5Az86ZnDWoB4WDcjhjQA6n9O9Gxww9jVVSi4JA2py7U7mnjk1V+9hUuZdNVXspr9rH1tA+duyuZcfuWnbW1NJwDBxp2bmnjpc/3MbLH24DIDM9jVEDc7jqrHwmn9KXrA4KBUk+BYEkXEOD83FNLVur97O1ej+bQ5EdfnnVXjZV7mNT1V721gVztPG6cAPzy3Yyv2wnOX/pwBWj8rjqrAEM7dMl2aVJgCkI5LDt2L2fVdtq2Fq9j4qaWt5dUcvTm5ewrXo/23fVsn3XfuqT9HHeLHI4pnfXLHp37UjXrAw6Z6ZzXGYGHTuk0TE9jY4d0klPixy3d4cGd/YfCLOnNsye2nqcyIngDulpkd9paaSlGelmZKQb6dHXNbX1bK3ex5bQfsoqath2mLVW7T3Ab95cx2/eXBf5llCYz+dP60fnjvpnKW1LW5w0q3rfAZZurGLJ+ireKw+xYutuPq6pbaLl4e4G4zOw53Gc1Lcrw/t1Y3CvzvTP7kS/7E706dqRDunJuSr6vBdz2F5ZzT9NKGDxukqWbqyitr6hVesu2VDFkg1V/Gj2h0wakcuUM47ngoLeSftvkWBREMin3J1Nlfso2VDJ4vVVLNlQyeodNSTrqshuWRn07NKRHp0z6dc9i1P7d+e0vGxO6d+NrlkdklNUMzqkp9Glg3H7pGEA1NaHWVZezZINVZEw3RCKEaIH7TsQZvZ7W5j93hZ6dM7k8pHHc8WoPE4+vntb/CdIQCkIAizc4KzYuovF6yspWV/F4vWV7Njd/I4qUbI6pJGfcxx5OZ3I73Hcp69zu2fRp2tHenft2O6vrumYkU7hoB4UDuoBRIJ2yYYqnli4kb8u20pdC98WKvfU8fu31/P7t9czvF83rj4rny+d2T8lQ1Dat7iCwMx6AE8Bg4D1wJXuXnVIm5HAr4BuQBj4qbs/FV32CDAWqI42/7q7l8ZTk8R2INzA++XVLFpXyaJ1OylZX8Xu2vqj8l7dsjLo170Tfbtn0a97Fvk9/n6n36tLZuCurzezT4Phh5eO4I/vbubpxZtYuX13i+uu2LqLH83+kP9++SO+PCqP68YMZGifrm1QtQRBvN8I7gJed/d7zOyu6PT3DmmzF7jO3Veb2fHAEjN7xd1D0eX/7O7PxlmHNKG2Psx7m6pZULaThet2snRDiH0H4r9ap0O6MbRPV07o3Zncblns3lHO2aePoF/3LPpGf47L1JfN5uR0zuT68wfzzfMG8X55NU+XbOKF97dSve9As+vtqQvz2PwNPDZ/A+cM6cG15wzkohF9yczQuQQ5cvH+a70cKIq+fhQo5pAgcPdVjV5vMbMdQG8ghCTU/gNh3t0YYuG6nSwsO7yTlc0Z0rszowbkUDgoh1P7ZzO0T5e/2/EUF++gaFRe3O8TRGbG6fnZnJ6fzY++cDJzV1Xw59LNvLp8e4uHjhaUVbKgrJJeXTpy1Vl5XDN6AHk5x7VR5XIsMY/jTKCZhdw9u9F0lbvnNNN+NJHAONndG6KHhsYAtcDrwF3u3uRBajObBkwDyM3NHTVr1qwjrjsRampq6NIludd+14adtaEGVlaG+agyzNrqBuLd72cYDOqeRkFOOgXZaQzNSadbZvOHcFKhL1LB9OnTCYfDPPjgg3H/rT0HnEVb63lrcz1rq1v3P9WA03qnM35ABqf2SictyYfetF0clCp9MW7cuCXuXnjo/BaDwMxeA/o2segHwKOtDQIz60fkG8NUd1/QaN42IBOYAax197tb+o8pLCz0kpKSlpodVcXFxRQVFbXpe+6tq2fphtCnh3pKN4U4EI7vkp6uHTMYNSiH0YN7cNagHpzav/th3+2ajL5IRUVFRYRCIUpLE3uaa9X23Tw2fz1/XLq51TfiDenVmannDuLLo/LokqT
7ErRdHJQqfWFmTQZBi1uIu09s5o9uN7N+7r41ulPfEaNdN+CvwL9+EgLRv701+rLWzH4P3NlSPUGyp7aekg1VLCzbyYKynbxfXh33jVrZx3Vg9KAejB7cg3OG9GR4v26f3lwlqWlYbld+MuVUvjv5JJ5bUs7jCzawtmJPs+uUfbyHH83+kHv/tpKpYwZx/fmDyemc2UYVS3sT70eF2cBU4J7o7+cPbWBmmcCfgMfc/ZlDln0SIgZMAT6Is552bff+A9EdfyULynaybHP1p8MlH6menTM5e0hkp3/24J4U9OlCmnb87VK3rA5847zBfP3cQSwoq+TxhRt45YNtzX442L2/nl/OWcPv317HtWMG8o8XDKFXl45tWLW0B/EGwT3A02Z2PbAR+AqAmRUCN7r7DcCVwIVATzP7enS9Ty4TfcLMehM5vFkK3BhnPe1KXX0DSzdW8ebqCt5a/THLNlfHPdBary4dP93xnzNYwyAfi8yMMSf0ZMwJPdmxez/PlJTz/xZuZHNoX8x19tSFeXhuGY++s56vjh7It8YOIbdbVhtWLaksriBw953AhCbmlwA3RF8/DjweY/3x8bx/e7Spci/FqyqYu7KC+Ws/Zk+cg6/lduvI2YN7frrzH9Krs3b8AdKnaxY3jRvKjWNPYM5HO3jknfW8tebjmO33H2hg5tvreHzhBq4qzOfGohPon92pDSuWVKSLvY+y/QfCLCjbydxVFcxdVUFZC8d2W9K3WxbnfHKoZ0hPBvU8Tjt+IT3NmDgil4kjclm1fTe/f3sdzy3ZTF246SuO6uob+MOCDcxavJErRuXx7bFDGdBTl54GlYIgwdydso/3ULwysuNfWLYzrmv5+2d3inzaj37qH9BDO35p3rDcrvzXl07jOxMKeHhuGU8u2hhzGzwQdp5ctImnS8qZMrI/N407gSG9k3+Zo7QtBUEC1NTW886ajz/91F9eFftYbUvye3SKHOqJXtWT30Of0uTI9OveiX+/7GT+adwJ/GZeGY8v2BjzzvJwg/Pc0nL+9G45l552PDePH8qwXA1hERQKgiPg7mza3cCv565l7soKSjZUHvH1/N2yMji/oBcXFPTm/KG9tOOXhOvTNYsffH4EN449gd+9tY7H5m+gJsYYUw3Op6OfXnJKX24eP1QjnwaAgqCVqvce4K01HzN31Q7mrqpg+65a4KPD/jtmcFr/7ow9sQ9jh/Xm9LzuZGjMeWkDPbt05LuTT2LahUOio5quY9f+2IMOvvTBNl76YBsTh/fhlvEFnJ6fHbOttG8KghgaGpwPtlQzd2UFxasqeHdj1RFf2tmzcyZjh/Vm7ImRT/09dR23JFH2cZncNmkY118wmD/M38Bv3yyjam/swe5eW7GD11bs4MJhvfnO+KGfDqstxw4FQSMf19Ty5urIpZ3zVn9M5Z66I/o76WnGGfnZFJ3Ym6IT+zCiXzfdxCUpp1tWB24aN5SvnzuIJxZuYMa8dc0+OGfeqgrmrapgzJCe3DJhKGOG9NSFC8eIQAdBfbiB0k0h5q6qoHhlBcs2V7e8Ugx9u2VRdGJvxg7rzblDe9G9kx4eIu1D544ZTLvwBK4bM4gnF23k13PXRg99Nm1+2U7ml+2kcGAOt0wo4MKCXgqEdi5wQbCtej/zolf3vLm6otljpM1JNzh7SM/ozr8Pw3J1B6+0b1kd0vnGeYP56tkDeKaknF8Vr232buWSDVVMnbmI0/O6c8v4AiYM76N/A+3UMR8EdfUNlGyojFzaubKCj7a1/DSoWPJ7dKJoWOQkb/2W5UyeeE4CKxVJDR0z0rn2nIFcWZjPn94t56E5a9lYuTdm+/fKq7nhsRJG9OvGrRMLuGhErgKhnTkmgyBRwzh0zEjjnOin/qIT+/zdXbzFO1YksmSRlJOZkcZVZw3gy2fmMfu9Lfxyzppm74xfvnUX3/rDEkb068b0iQV0iONZJ9K2jokgSOQwDkP7dOHCgt4Undib0YN7HPbY/CLHmoz
0NL50Zh6Xj+zPi8u28ss31jT7nOXlW3cx7Q9LGNgtjfrc7UzUIaOU126DoKyi5tNhHBbEMYxD58x0zh3a69MTvXrUn0jT0tOML5x+PJ8/tR9/W76dB99YzYdbdsVsv2FXA//4WAmn9u/O9IkFjD9JgZCq2mUQrNy2m/H3zT3i9Yf368bYYZFP/WcOyNGDv0UOQ1qaMfmUvlx8ci5zVu7gF6+voXRT7EeQL9tczfWPlnBaXiQQxp2oQEg17TIIYo2oGEu3rAwuGBb5xD92WG+Nwy6SAGbG+JNyGXdiH+auquCB11Y3Gwjvl1fzzUdKOD0/m+kTCyga1luBkCLaZRC0xAxOy8tmbEEvxp7YR8M4iBxFZkZRdMiUuasquP/VVbxXHvuenPc2hfjG7xczMj+b2yYN030IKeCYCYJeXTK5sCAyjMMFBb3poeezirSpxoEwZ+UOfvzHpazbFfvbe+mmEFNnLuLMAdlMnziMCxQISdNugyA9zThzQHb0WL+GcRBJFZ8cMrIxWTT0HcH9r63ig82xTyov3RjiupmLGDUwh9smDuO8oRq6oq3FFQRm1gN4ChgErAeudPeqJtqFgWXRyY3ufll0/mBgFtADWAp8zd1bHOBnQI/jKPm3SRrGQSSFmRkThucy/qQ+vLZiBw+8tqrZq4yWbKji2t8tZMyQnnzvkpMYqdFO20y8B87vAl539wLg9eh0U/a5+8joz2WN5v8MuD+6fhVwfWvetHunDgoBkXbCzJg0IpcXbjmfGV8bxYh+3ZptP79sJ1MeeptvP76EsoqaNqoy2OINgsuBR6OvHwWmtHZFi3z3Gw88eyTri0j7YmZcdHJfXrjlfH597ShO6tv8E9Be+mAbF90/jx+/sJzqZobJlviZx3EbuJmF3D270XSVu+c00a4eKAXqgXvc/c9m1gtY4O5Do23ygZfc/ZQY7zUNmAaQm5s7atasWUdcdyLU1NTQpYue7Qrqi09Mnz6dcDjMgw8+mOxSUkJL20WDO0u3h/nzmjrKa5rfD3XpAFOGZjIuP4P0dnguMFX+jYwbN26JuxceOr/FcwRm9hrQt4lFPziM9x/g7lvMbAjwhpktA5o6WBhza3D3GcAMgMLCQi8qKjqMt0+84uJikl1DqlBfRGRnZxMKhdQXUa3ZLsYDtzc4L32wjfv+tpKyj5seHqbmADy+oo6FOzP5t0tHcOGw3okv+ChK9X8jLQaBu0+MtczMtptZP3ffamb9gB0x/saW6O8yMysGzgCeA7LNLMPd64E8YMsR/DeISDuWlmZ8/rR+XHxyLs8sKeeB11bFfB7C6h01XDdzERNO6sO/XjqCwb06t3G1x6Z4zxHMBqZGX08Fnj+0gZnlmFnH6OtewHnAco8ck5oDXNHc+iISDBnpaVwzegDFd47j9knD6NTMgI+vf7SDi++fx89e/og9tUf2TBE5KN4guAeYZGargUnRacys0Mx+G20zHCgxs/eI7Pjvcffl0WXfA243szVAT+B3cdYjIu1cp8x0vjOhgDl3FvGlM/vHbFcXbuBXxWuZcN9cni/dTDznO4MurvsI3H0nMKGJ+SXADdHX7wCnxli/DBgdTw0icmzq2z2Ln185kqljBnH3C8tZsuEztygBsG3Xfm6dVcoTCzfyH5edzPAWLk+Vz9IAPCKS0k7Pz+bZG8fwv1eP5PjusQeMXLSuks//4k1++PwHhPa2eF+qNKIgEJGUZ2ZcPrI/r99RxHfGD405dHyDw2PzNzDu3mKeXLSRcIMOF7WGgkBE2o1OmencftGJvHbbWCYOz43ZrmrvAb7/x2VMeehtlm5s+pCSHKQgEJF2Z0DP4/jt1EIe+cZZDGnmEtJlm6v50v+9wx1Pv8eO3fvbsML2RUEgIu1W0Yl9eHn6hdx1yUl0zox9uelzS8sZf+9cfjOvjAOH+WCrIFAQiEi7lpmRxo1jT+CNO4uYMvL4mO1qauv56YsrmPzAPN5cXdGGFaY+BYGIHBNyu2XxwNV
n8MyNY5od4XRtxR6+9rtFfOsPJWyq3NuGFaYuBYGIHFPOGtSDv9xyPj+eckqzw9W/8uF2Jv58Lve/uor9B8JtWGHqURCIyDEnPc342jkDKb6ziH84ewCxHnhWW9/A/76+mgn3zeXlD7YF9u5kBYGIHLNyOmfy0y+eyl9uPp9RAz8zQv6nNof2cePjS7hu5iLW7Ajew3AUBCJyzDulf3eevXEMP7/ydHp37Riz3ZurP2byA/P46V+Xs3t/cB6GoyAQkUAwM750Zh5v3DGWaRcOISPGA27qG5zfvLmO8ffN5bkl5TQE4O5kBYGIBErXrA78y+eG8/L0C7mgoFfMdhW7a7njmfe44tfv8MHm6jassO0pCEQkkIb26cJj3xzNw18bRV5Op5jtlm4M8YVfvsX3/7iMyj3H5mB2CgIRCSwz4+KT+/La7WO5fdIwsjo0vUt0hycXbWTcvcU8Nn899cfY3ckKAhEJvKwOkYfhvHb7WC45palHtEdU7zvAD5//kC/88m0WratswwqPLgWBiEhUXs5x/OraUTxxw9kU9OkSs92Krbu48uH53DrrXbZVt//B7BQEIiKHOG9oL1689QL+9fPD6dox9oMcny/dwvj7ivlV8Vpq69vv3clxBYGZ9TCzV81sdfT3Z+7YMLNxZlba6Ge/mU2JLnvEzNY1WjYynnpERBKlQ3oaN1wwhNfvHMsVo/JitttbF+ZnL3/E5AfeZM7KHW1YYeLE+43gLuB1dy8AXo9O/x13n+PuI919JDAe2Av8rVGTf/5kubuXxlmPiEhC9emaxb1fOZ3nvn0up/bvHrPduo/38I3fL+b6RxazYeeeNqwwfvEGweXAo9HXjwJTWmh/BfCSu2vIPxFpV0YNzOHPN53Hf33pVHp0zozZ7vWPdjDp5/O495WV7K2rb8MKj1y8QZDr7lsBor/7tND+auDJQ+b91MzeN7P7zSz2vd8iIkmWnmZcM3oAc+4oYuqYgcS4OZm6cAO/nLOGiffN5YX3t6T8YHbWUoFm9hrQ1PVUPwAedffsRm2r3L3JkZ3MrB/wPnC8ux9oNG8bkAnMANa6+90x1p8GTAPIzc0dNWvWrBb+046umpoaunSJfVVBkKgvIjVMI38AAAgQSURBVKZPn044HObBBx9MdikpIQjbxabdDTy+vJaVVc3fVzCsu3PdKceR1zW51+eMGzduibsXHjq/xSBojpmtBIrcfWt0p17s7ifGaHsrcLK7T4uxvAi4090vbel9CwsLvaSk5IjrToTi4mKKioqSWkOqUF9EFBUVEQqFKC3VqS4Iznbh7vzl/a38519XsG1X7EtJPxka+7ZJw5p9TsLRZGZNBkG88TQbmBp9PRV4vpm213DIYaFoeGBmRuT8wgdx1iMi0qbMjMtOP57X7xjLPxWdQGZ607vVcIPzyDvrGX9vMU8t3phSg9nFGwT3AJPMbDUwKTqNmRWa2W8/aWRmg4B8YO4h6z9hZsuAZUAv4Cdx1iMikhSdO2bw3ckn8cptFzL+pNinS3fuqeN7zy3ji//3NqWbQm1YYWyx75RoBXffCUxoYn4JcEOj6fVA/ybajY/n/UVEUs3gXp2Z+fWzeH3Fdu5+YTkbdjZ9keR75dVMeehtrizM47uTT6JXl+RdK6M7i0VEjoIJw3N5ZfqF/PPFJ5KZHrvd0yXljLu3mJlvrUvaYHYKAhGRoySrQzo3jRvKPRd04gunHx+z3e799dz9wnI+94s3eWftx21YYYSCQETkKOuRlcaD15zBrGnncFLfrjHbrdpew1d/s5CbnljK5tC+NqtPQSAi0kbOGdKTF245n/+47GS6ZcU+RfvXZVuZcF8xD76+mv0Hjv5gdgoCEZE2lJGextRzBzHnziKuPisfi3F38v4DDdz36iouun8ery7fflTvTlYQiIgkQc8uHbnny6fx/E3nccaA7JjtNlbu5R8fK+EbjyymrKLmqNSiIBARSaLT8rJ57sZz+Z8rTqNXl9iD2RWvrOD
iB+bxXy+toKY2sYPZKQhERJIsLc34SmE+b9xZxPXnDyY9xmh2B8LOw3PLmHBfMX9+d3PCDhcpCEREUkS3rA7826UjePnWCzhvaM+Y7bbvqmX6U6Vc+fB8PtxSHff7KghERFJMQW5XHr/+bH71D2fSP7tTzHaL11fxhQff4l//vIzQ3rojfj8FgYhICjIzLjm1H6/dPpbvjB9KZkbTu+sGh8cXbGTcvcU8sXAD4SMYzE5BICKSwjplpnP7RSfy2m1jmTQiN2a7qr0H+MGfPuDyh95iyYbKw3oPBYGISDswoOdx/Oa6Qh795miG9Oocs90Hm3fx5V/N5/anStnRzPMRGlMQiIi0I2OH9ebl6Rfy/UtOonMzo9n98d3NjL9vLjPmraWuvvnB7BQEIiLtTGZGGt8aewJv3FnEF8/4zAj/n6qprec/X/yIyf87j3mrKmK2UxCIiLRTud2yuP+qkTx74xhG9OsWs11ZxR6um7ko5nIFgYhIO1c4qAd/ueV8fjLlFLKPO/znISsIRESOAelpxrXnDGTOHUVce84AYtyc3CQFgYjIMSSncyY/mXIqf7nlfAoH5rRqnbiCwMy+YmYfmlmDmRU2026yma00szVmdlej+YPNbKGZrTazp8ws9ohLIiLSaicf351nbhzDA1eNpE/X5p+HHO83gg+ALwHzYjUws3TgIeASYARwjZmNiC7+GXC/uxcAVcD1cdYjIiJRZsaUM/rzxp1FfGvskJjt4goCd1/h7itbaDYaWOPuZe5eB8wCLjczA8YDz0bbPQpMiaceERH5rC4dM/j+JcNjLo/9rLTE6Q9sajRdDpwN9ARC7l7faH7MC2LNbBowLTpZY2YtBdDR1gto+6dMpyb1xUG9zEx9EaHt4qBU6YuBTc1sMQjM7DWgbxOLfuDuz7fijZs6d+3NzG+Su88AZrTi/dqEmZW4e8zzIkGivjhIfXGQ+uKgVO+LFoPA3SfG+R7lQH6j6TxgC5F0zDazjOi3gk/mi4hIG2qLy0cXAwXRK4QygauB2R55tM4c4Ipou6lAa75hiIhIAsV7+egXzawcGAP81cxeic4/3sxeBIh+2r8ZeAVYATzt7h9G/8T3gNvNbA2Rcwa/i6eeNpYyh6lSgPriIPXFQeqLg1K6LyxRz7wUEZH2SXcWi4gEnIJARCTgFAQJYGZ3mpmbWa9k15IsZvY/ZvaRmb1vZn8ys+xk19TWYg2lEjRmlm9mc8xsRXQImluTXVMymVm6mb1rZi8ku5ZYFARxMrN8YBKwMdm1JNmrwCnufhqwCvh+kutpUy0MpRI09cAd7j4cOAe4KcB9AXArkQtlUpaCIH73A9+lmZvhgsDd/9boLvEFRO4LCZImh1JJck1J4e5b3X1p9PVuIjvB2I/ROoaZWR7weeC3ya6lOQqCOJjZZcBmd38v2bWkmG8CLyW7iDbW1FAqgdz5NWZmg4AzgIXJrSRpHiDyQbH5hwYnWVuMNdSuNTfEBvAvwEVtW1HytGa4ETP7AZFDA0+0ZW0p4LCGTAkCM+sCPAdMd/ddya6nrZnZpcAOd19iZkXJrqc5CoIWxBpiw8xOBQYD70UGUiUPWGpmo919WxuW2GZaGm7EzKYClwITPHg3qMQaSiWQzKwDkRB4wt3/mOx6kuQ84DIz+xyQBXQzs8fd/dok1/UZuqEsQcxsPVDo7qkwwmCbM7PJwM+Bse5ekex62pqZZRA5ST4B2ExkaJWvNrqLPjCiQ8w/ClS6+/Rk15MKot8I7nT3S5NdS1N0jkAS5ZdAV+BVMys1s18nu6C21MJQKkFzHvA1YHx0WyiNfiqWFKVvBCIiAadvBCIiAacgEBEJOAWBiEjAKQhERAJOQSAiEnAKAhGRgFMQiIgE3P8HVvw9BpUUhAQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "CCCE8LCoFEiB" }, "source": [ "and here's how you can use the optimize to train the network.\n", "Note that we call `zero_grad` _before_ calling `loss.backward()`, and then we just call `optimizer.step()`. This `step` function will take care of updating all the parameters that were passed to that optimizer's constructor." ] }, { "cell_type": "code", "execution_count": 78, "metadata": { "id": "sJfVNGhPod5i" }, "outputs": [], "source": [ "for _ in range(100):\n", " y = net(x)\n", " loss = loss_fn(y, y_target)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "markdown", "metadata": { "id": "lOM7_Xr6FUaV" }, "source": [ "And we see that this trained a network quite well" ] }, { "cell_type": "code", "execution_count": 79, "metadata": { "id": "219Pojw0os9X" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dd3xUVdrA8d+TTg8QCCWhoxQpgYRqiYKKFVdQAYEgIPa1rGvbXfXV9V3LWlbXFgHpApZVVrFgiY0aIPTeQ2+BhPTJef/I8BJyb0LCTKZknu/nkw8z59ybebxO5pl77znPEWMMSimlAleQtwNQSinlXZoIlFIqwGkiUEqpAKeJQCmlApwmAqWUCnCaCJRSKsC5JRGIyGQROSQia8voFxF5U0S2ishqEelRoi9JRLY4f5LcEY9SSqmKc9cZwRRgUDn91wDtnT8TgHcBRKQB8AzQG+gFPCMi9d0Uk1JKqQpwSyIwxvwCHCtnk8HANFNsMRApIk2Bq4EFxphjxpjjwALKTyhKKaXcLMRDr9Mc2FPiebqzrax2CxGZQPHZBDVq1OgZGxtbNZFWUFFREUFBeosF9FictmfPHowxtGjRwtuh+AR9X5zhK8di8+bNR4wxjUq3eyoRiE2bKafd2mhMMpAMEB8fb1JTU90X3XlISUkhMTHRqzH4Cj0WxRITE8nIyCAtLc3bofgEfV+c4SvHQkR22bV7KkWlAyW/wscA+8ppV0op5SGeSgTzgNHO0UN9gBPGmP3At8BVIlLfeZP4KmebUkopD3HLpSER+QhIBKJEJJ3ikUChAMaY94D5wLXAViAbuMPZd0xEngeWOX/Vc8aY8m46K6WUcjO3JAJjzPBz9BvgvjL6JgOT3RGHUkqpyvP+bWyllFJepYlAKaUCnCYCpZQKcJoIlFIqwGkiUEqpAKeJQCmlApwmAqWUCnCaCJRSKsBpIlBKqQCniUAppQKcJgKllApwmgiUUirAaSJQSqkAp4lAKaUCnCYCpZQKcJoIlFIqwGkiUEqpAKeJQCmlApxbEoGIDBKRTSKyVUSesOl/XUTSnD+bRSSjRJ+jRN88d8SjlFKq4lxes1hEgoG3gSuBdGCZiMwzxqw/vY0x5uES2z8AxJX4FTnGmO6uxqGUUur8uOOMoBew1Riz3RiTD8wGBpez/XDgIze8rlJKKTdwRyJoDuwp8Tzd2WYhIi2B1sCPJZojRCRVRBaLyE1uiEcppVQluHxpCBCbNlPGtsOAT4wxjhJtLYwx+0SkDfCjiKwxxmyzvIjIBGACQHR0NCkpKS6G7ZqsrCyvx+Ar9FgUy8jIwOFw6LFw0vfFGb5+LNyRCNKB2BLPY4B9ZWw7DLivZIMxZp/z3+0ikkLx/QNLIjDGJAPJAPHx8SYxMdHVuF2SkpKCt2PwFXosikVGRpKRkaHHwknfF2f4+rFwx6WhZUB7EWktImEUf9hbRv+IyIVAfWBRibb6IhLufBwF9AfWl95XKaVU1XH5jMAYUygi9wPfAsHAZGPMOhF5Dkg1xpxOCsOB2caYkpeNOgLvi0gRxUnpxZKjjZRSSlU9d1wawhgzH5hfqu3pUs+ftdlvIdClsq+3+1g2eYUOwkOCK7urUkqpUvxyZvGJnALGT00lO7/Q26EopZTf88tEAPDrliOMmrSUE9kF3g5FKaX8mt8mAoDlu45zW/IiDmfmeTsUpZTyW36dCAA2HsjklvcWsudYtrdDUUopv+T3iQBg59FsbnlvEVsPZXo7FKWU8jt+mQiCxDqZ+cDJXG55bxGr0zNs9lBKKVUWv0wEraNqUa9GqKX9eHYBIz5YwuLtR70QlVJK+Se/TAQ1w4KZe1dfGtcJt/Rl5RWSNHkpP2w46IXIlFLK//hlIgC4sEkdPrm7H7ENalj68gqLmDB9OZ+v3OuFyJRSyr/4bSIAaNGwJp/c3Y8Lo+tY+hxFhofnpjFt0U6Px6WUtxljOJKVx6GTuZxd1UU
pK7eUmPCm6LoRzLmrD2M+XEbanrNvFBsDT3+xjpM5Bdx3eTvE5iazUtVJRnY+c1P3MH3xLvYcywGgRmgwLRvWpEOTOtyaEEu/tlFejlL5Gr9PBACRNcOYOb43E6an8vtW643if363mRM5BTx1bUdNBqpaOpFTwCvfbuST5enkFhSd1ZdT4GDjgUw2Hsjk87R9DO0Zw9+u72Q74EIFJr++NFRSrfAQJo9J4OrO0bb9H/y6g8c/XY2jSE+TVfVyJCuPIe8uZMbi3ZYkYOeT5elc9frP/LhRB1SoYtUmEQCEhwTz9ogeDOkRY9s/NzWd+2etIK/QYduvlL85fiqfkROXsPVQVqX2O3gyj7FTUnn1u016D0FVr0QAEBIcxCtDu3JH/1a2/V+vPaCVS1W1cCKngNGTl7LxgP2MepHiodbleevHrfzti7V6phzgqsU9gtKCgoSnr+9EZI0wXv9+s6X/1y1HGDlxCR+O6UW9mnqdVPmf3AIHd3y4lDV7T1j6aoYFM7xXC0b1aUnLhjU5nJXHgvUH+cf8jWTlWb8AzVi8m4zsAl67tTthIdXuu6GqgGr7f11EeHBge569oZNt/4rdGdyWvIhDmbkejkwp173w1QZW7LaWU6kVFsz0cb352/WdaBVVCxGhcZ0Ibu/dkm8fvpRLL2hk+/u+XL2f8dP0TDlQVdtEcNqY/q159ZZuBAdZRwttPJDJre8t0sqlyq8sWH+Q6Yt3WdojQoOYPCaBni3r2+7XPLIGU+9I4J7Etrb9v2w+zMiJS8jIzndrvMr3VftEADCkZwzv3N6DsGDrf+7pyqVbDmrlUuX7DpzI5bFPVlnaw4KDmDg6gd5tGpa7v4jw+KAOPHVtB9v+FbszuO39xRw8qWfKgcQtiUBEBonIJhHZKiJP2PSPEZHDIpLm/Blfoi9JRLY4f5LcEY+dqzs3YcodCbY3zw6czOXW97VyqfJtjiLDI3PTOG6zKt9fruvIxe0rPlFswqVteXloV2xOlNl0MJOh7y1k55FTroSr/IjLiUBEgoG3gWuATsBwEbG7MD/HGNPd+TPRuW8D4BmgN9ALeEZE7M9r3aBfuyhm3dmHSJsbxKcrly7appVLlW96/5dtLLR5fw7o0JjRfVtW+vfdGh/LuyN72p4p7zmWw9D3FrF+38nzilX5F3ecEfQCthpjthtj8oHZwOAK7ns1sMAYc8wYcxxYAAxyQ0xl6h4byZwJ5VQu/XAp36/XiTbKt6TtyeC176wj4BrXCefloV3Pe8b81Z2bMGVsArVszpSPZOVxW/IiUnceO6/frfyHO4aPNgf2lHieTvE3/NKGiMilwGbgYWPMnjL2bW73IiIyAZgAEB0dTUpKiktBPxoXxCvLhMM5Z4+fzi8sYsL0VMZ3Cadfs7IPT1ZWlssxVBd6LIplZGTgcDjcfixyCg1P/55DYamx/gIkXSisSV3k8ms82jOU11IdZJa66pSZW8iI5EXcFxdOt0aV+7jQ98UZvn4s3JEI7L6KlJ6d8l/gI2NMnojcDUwFrqjgvsWNxiQDyQDx8fEmMTHxvAM+LfHiXEZNWsqmUjeKiwwkr84jpnU7RvdtZbtvSkoK7oihOtBjUSwyMpKMjAy3H4tH5qRxOMc6sm3CZW2475qObnmNROCSvlmMnrSEfSfOvlGcXwRvrczn1Vs7Mri77fc0W/q+OMPXj4U7Lg2lA7ElnscA+0puYIw5aozJcz79AOhZ0X2rUmNn5dLusZG2/U9/sY63ftiiU/CV13y+ci+f2ayr0TWmHn+68kK3vla7xrX5+J5+tGlUy9JXWGR4aE4aUxfudOtrKt/gjkSwDGgvIq1FJAwYBswruYGINC3x9EZgg/Pxt8BVIlLfeZP4Kmebx5yuXNq/nf2wu1cXbOaFrzZoMlAel348m79+vtbSXjMsmH8Ni6uSWcDNI2vw8V196dK8nqXPGHhm3jr+9b1+OapuXH4nGWMKgfsp/gDfAMw1xqwTkedE5EbnZn8
UkXUisgr4IzDGue8x4HmKk8ky4Dlnm0edq3LpxN+0cqnyrKIiw+OfrrYtCfHc4ItoHWX91u4uDWuHM+vO3vQtY07C699v5n/+u54i/XuoNtzylcIYM98Yc4Expq0x5gVn29PGmHnOx08aYzobY7oZYy43xmwsse9kY0w758+H7ojnfJyuXDq0p1YuVd43c8ku27U1bujWjCE9Kn6d/nzViQjlwzsSuKqT/ZejKQt38sjcNAoc5y57rXxfQMwsrqiQ4CBeHtKVsf1b2/Zr5VLlCbuOnuJ/52+0tDepG8Hfb7rIY4srRYQG887tPbiljC9Hn6ft467py8nJ1y9H/k4TQSlBQcLfru/IwwMvsO0/Xbn0VIGeFiv3Kyoy/Pnj1eQUWD9cXxra1eOrioUEB/Hy0K5MuLSNbf+PGw8xYXoquTbxKv+hicDG6cqlz5RTufTFpblauVS53fTFu1hqM4FreK8WXFZG5dCqJiI8dW1HHh9kX5/o1y1HuGv6cr1s6sc0EZTjjnIql+7JLNLKpcqtDmfm8c/vNlnaY+rX4C/XuWe+gCvuSWzLP27uYluf6OfNh7lv5gryC/WegT/SRHAOWrlUecqLX28kM9d6/+nloV2pHe4ba0gN79WCt4b3sP1y9P2GQzw8J01HE/khTQQVoJVLVVVbtvMYn65It7QPS4ilX9uKVxX1hOu6NuWN27rbnhl8tWY/76Rs9XxQyiWaCCroXJVLhycv1sql6rwUOor4m83EsXo1QnmsjOvy3nZDt2a8ems37AYwvbpgMz9tOuT5oNR500RQCd1jI5l7l33l0lP5Dq1cqs7LtEW7bBegf2zQhTSoFeaFiCrmD3ExvDykq6XdGHjwo5Ucytb7Bf5CE0ElXRBdh0/v6UejGtavQvmFRdw1Yzn/WWk9xVfKzpGsPF5fYC0v3TWmHsMSWnghosq5JT6We22WvjyZW8ibK3J1zo2f0ERwHmIb1OQvvSO4MLqOpc9RZHh4ziqmLdrp8biU/3n1u01kliojIQLPD77I9oasL/rTVRdyic3qaOlZhic+XaN1ifyAJoLzFBkRxJy7+hDXQiuXqvOzdu8JZi/bY2kflhBLtzIq4vqi4CDhzWFxxNSvYembt2ofk3/f6fmgVKVoInBBZM0wZozrzcXt7Ed1aOVSVRZjDM/OW0fpt0adiBAevcq95aU9oX6tMN4f1ZOIUOtHyv/O36ADKXycJgIX1QoPYdKYeAZ1bmLbP/G3HTz2yWoKtTiXKuG/q/eTuuu4pf3hgRfQsLZ1MII/6NysHi/ebL157Cgy3D9rBfsycrwQlaoITQRuEB4SzL9HxJVZnOvj5ek88NFKnXWpAMjJd/CP+Rss7W0b1WLUeSxC70tuimvOHf1bWdqPnsrnnhnLtSaRj9JE4CYhwUG8NKQr4y4uu3LpvTO1jLUqLjG9/4S1TtXTN3Qm1GYGu7956tqO9GrdwNK+Kv0Ez85b54WI1Ln4/7vOhwQFCX+9riOPXGlfufT7DQe5e7p+KwpkuQUOkn/Zbmkf2LGx14rKuVtocBBvj+hB/XDrqKfZy/Ywa8luL0SlyqOJwM1EhD8OaM+zZVQu/WnTYe6cpmV7A9Uny9M5lJl3VpsIPOGmReh9RaM64dwfF25bo+uZeWtZsdt6f0R5jyaCKjKmf2teHtLVdgq+lu0NTAWOIt77eZul/dqLmtKucW0vRFS12kYG8z+DO1vaCxyGe2Ys1zLuPsQtiUBEBonIJhHZKiJP2PQ/IiLrRWS1iPwgIi1L9DlEJM35M6/0vv7s1oRYXru1W5llex+YtVKX+gsgX6TtI/24deTMfZe380I0njG8VwuGJcRa2g+ezOP+mfr+9xUuJwIRCQbeBq4BOgHDRaT0dZGVQLwxpivwCfByib4cY0x358+NVDN/iIvhjWFxtrNEv1t/kEfmrsKhZXurPUeRsa3KOaBDYzo1q+uFiDznfwZ3tp0gt3TnMV74yjp6SnmeO84IegFbjTH
bjTH5wGxgcMkNjDE/GWNOr+CyGLAfZ1lN3ditGW8Oi7M9M/jvqn089slqreFezX2z9gDbD5+ytN93RfU9GzgtPCSY90b2IKq2tYDelIU7tTaXD3DHahfNgZLz5NOB3uVsPw74usTzCBFJBQqBF40xn9vtJCITgAkA0dHRpKSkuBKzy7KysioVQy1g3EVhTFyTT+mP/E9XpHPs8AFGdwrz2MLk7lTZY1FdZWRk4HA4LMfCGMPLi6zXwzs2COLk9lWkWAcRVQul3xfjOwXx8jIo/Z3nsY9XcXLPJlrWta73UV34+t+IOxKB3SeX7ddbERkJxAOXlWhuYYzZJyJtgB9FZI0xxnJHzRiTDCQDxMfHm8TERJcDd0VKSgqVjSERaLN0N09+tsbS99OeQlq3iOVv13f0u2RwPseiOoqMjCQjI8NyLH7feoRd3y6xbP+3mxPoV0Z5kuqg9PsiEQhtvIPnv1x/1nYFRfD+Ovji/r40sinxXh34+t+IOy4NpQMl7wbFAPtKbyQiA4G/ADcaY/5//JwxZp/z3+1AChDnhph81vBeLcocWjr59x288u0mrU1UzdiNFOoWG0nftg29EI13je3fisHdm1na953I5Z4Zy3X2vZe4IxEsA9qLSGsRCQOGAWeN/hGROOB9ipPAoRLt9UUk3Pk4CugPnP11oRoa0781T1xjv/LUOynbeOtHXeqvuli37wS/bjliab/nsjZ+d+bnDiLCizd3pUMTawn31F3HefqLtfpFyAtcTgTGmELgfuBbYAMw1xizTkSeE5HTo4BeAWoDH5caJtoRSBWRVcBPFN8jqPaJAODuy9ry4ID2tn2vLdhM8i/Wb5HK/9jNIm7VsCZXdrIvUhgIaoQF88HoeNvV12Yv28PUhTs9H1SAc8c9Aowx84H5pdqeLvF4YBn7LQS6uCMGf/TQwPbkFjp4/2frh8X/zt9IeEgwSf1aeT4w5Rbpx7P5cvV+S/udl7bxm0Vnqkpsg5q8c3sPRk5cQmGpu8fPf7WB1o1qV5uSG/5AZxZ7kYjwxKAOjCnjw/6ZeeuYvVTrsvirSb/tsMwRiaodxpAeATV6ukx92jS0nXnsKDLcP3MFWw5a13FWVUMTgZeJCM/c0InhvayzLwGe/M8aPl+518NRKVedyC5gjs3qY2P6tSIitPoOk6ys23u3ZGQf69rMmXmFjJ26jKNZeTZ7KXfTROADRIQXburCzXHNLX3GwJ8+XsXXa6yXGJTvmrV0N9n5Z9eSqhkWzMg+/r3eQFV45obO9G9nHUG151gOd8/QmlyeoInARwQFCS8P7cp1XZpa+hxFhgc+WsmPGw96ITJVWQWOItsbnrfGxxJZ03qDNNCFBgfxzoietImqZelbtvM4T32mI4mqmiYCHxISHMQbw7ozsGO0pa+wyHDPjBWk7jzmhchUZcxfs58DJ8+eSSwCY/vbL1qkoF7NUCaNSaBejVBL36cr0nnXZi6Gch9NBD4mNDiIt2+P45L21hmneYVFjJ2yjM16E82nTfpth6Xt6k5NaNGwphei8R+to2rx3siehNiMqHr5m018s1Yvj1YVTQQ+KDwkmORR8fRpY13u72RuIaMnLWWvLgTuk7ILDavTT1jax12iZwMV0bdtQ/5+00W2fQ/NSWPtXuuxVa7TROCjaoQFMykpgW4x9Sx9B07mMnrSEo6fyvdCZKo8x3Ot17K7xdQjvmV9L0Tjn4b1asGdNokzt6CIcVOXcfCkLmjjbpoIfFit8BAmj0mgtc1NtG2HT3HHlGVk5xd6ITJlJ7fAQWa+NRGMvbh1QJaTcMUT13RkQIfGlvaDJ/MYPzWVnHwdSeROmgh8XMPa4Uwb24vGNlUZ0/ZkcO/MFbrKk4/Yf8L6TbVpvQiutRkJpsoXHCT8a3icbU2iNXtP8MjcNF3Dw400EfiB2AY1mTq2F3XCrRVBUjYd5nFd2MbrDpzI5XCmdfJTUr9WhNos4K7OrXZ4CJPGJBBV2/ol6Ou1B3h1wSY
vRFU96TvUT3RsWpcPkuIJC7H+L/ts5V5e/GajF6JSpyX/sp2iUmPd64SHMLyXddasqrjmkTX4YHRP2/f92z9t47MVurqZO2gi8CN92jTkzWHdbZe8TP5lu1Ys9ZIjWXnMWrrL0p7Ur5XtuHhVOXEt6vPPW7rZ9j3x6RqdW+MGmgj8zKCLmvL3m+wLtv7v/I36DckLPvh1O7kFZ9+nqRkWzNiLdciou9zYrRkPDbSWbc93FHHX9OXsOZZts5eqKE0EfmhE7xY8cuUFtn2PfbKanzYdsu1T7nf8VD4zFlnPBkb2aWlbb1+dvwcHtOeGbtbVzY6eymfc1GVk5hZ4IarqQROBn3rginaMsilgVlhkuHfGClbuPu6FqALPh7/v4FSpoYzhIUGM1wlkbicivDK0K91jIy19mw9m8cBHKynUEXTnRROBnxIRnr2xM9d2sa50lVPgYOyUZWw9lOWFyAJHZm4BH9oUlxveqwWN60R4PqAAEBEaTPLonjSrZz2+KZsO88L8DV6Iyv9pIvBjwUHC67d1p28bawnf49kFJE1eygGbse3KPaYv3kVm7tkT+gSYcGkb7wQUIBrXiWDSmARqhlnXdfjw953MWGy9VKfK55ZEICKDRGSTiGwVkSds+sNFZI6zf4mItCrR96SzfZOIXO2OeAJJeEjxN6ROTeta+vZm5JA0eSknsvXaqbvl5DuY9Ku1uFy9cKFZZA0vRBRYOjaty5vD4rCbsP3MvHX8tuWI54PyYy4nAhEJBt4GrgE6AcNFpFOpzcYBx40x7YDXgZec+3YChgGdgUHAO87fpyqhTkQoU8Ym0KKBtbrlpoOZjJu6jNwCnZLvTnOW7eZoqVpPAjSsoSfZnjKwUzRPXdPR0u4oMtwzc7leGq0Edyxe3wvYaozZDiAis4HBwPoS2wwGnnU+/gT4txQXXxkMzDbG5AE7RGSr8/ctKu8FN23aRGJiohtCP38ZGRlERlpvWnlTfoGDo/tOWkpOfAm0/lcYF0TXsf0G5SpfPBZVyRhI23OcvMKzj7M5upNdx4zX35u+wlPvi6LDpziUefYl0ANA/PRgLmpWj5Bg79d58vW/EXckguZAycVZ04HeZW1jjCkUkRNAQ2f74lL7WtdrBERkAjABIDQ0lIyMDDeEfv4cDofXY7DTvBbszoTSFSeOZ+ezad8xmtRy/zdWXz0WVeVEnrEkAYBgDMaYgDoW5fHU+6J+KGSFCtkFZ7/pcwscrN97nNi6QXg7Ffj634g7EoHdMS5d+KasbSqyb3GjMclAMkB8fLxJTU2tTIxul5KS4rPf/H7feoQxHy6lwGE9lLdf3pY/X93Bra/ny8fC3RxFhoGv/Yw5cuqs9qs6RbN58qNkZGSQlpbmpeh8iyffFxnZ+dz09u/sPGqdWHZ1fAwvDenq1QqwvvI3UtYxcMfXw3QgtsTzGGBfWduISAhQDzhWwX1VJfVvF8Xrt3W3vQz09k/bmPK79Sanqphv1h5gR6kkAHDf5e28EI06LbJmGJPGJFA3wvrddm5qOh/8ut0LUfkPdySCZUB7EWktImEU3/ydV2qbeUCS8/FQ4EdTvBr1PGCYc1RRa6A9sNQNMQW867s245nrS9+zL/Y/X67nv6s031aWMYb3bNbOvaR9FN1sJjkpz2rbqDbvjuxJsE0xrn98vZEF6w96ISr/4HIiMMYUAvcD3wIbgLnGmHUi8pyI3OjcbBLQ0Hkz+BHgCee+64C5FN9Y/ga4zxijw1vcZEz/1txv803VGHhkbpoOsaukRduPssZmqcR7LmvrhWiUnf7tonhucGdLuzHw4OyVrNunS13accudQ2PMfGPMBcaYtsaYF5xtTxtj5jkf5xpjbjHGtDPG9Do9wsjZ94JzvwuNMV+7Ix51xp+uuoBhCbGW9gKH4a7pqayxWV9X2Xv/Z+vlha4x9ejb1jqhT3nP7b1bckf/Vpb27HwH46emckiXurTQQc/VnIjw95su4spO0Za+U/kOxny41Pa
atzrb+n0n+XnzYUv7XZe21WUofdBfr+vE5Rc2srTvP5HLndOX67yaUjQRBICQ4CDeGh5HQivrAupHT+UzevISyzhsdTa7tR5aNqzJoIustZ6U9wUHCW8Oj+PCaOtSl6v2ZPDox6swRlf1O00TQYCICA1m4ugE2z+MPcdySJq8jJNaxtdW+vFs/rt6v6X9zkva2N6YVL6hTkQoE5PiaWhTDvzL1ft54/stXojKN2kiCCD1aoYydWwvmtvUwtmw/yQTpqXqKbONSb/twFFqhl7DWmEM7RnjpYhURcU2qEny6HjCbNaN/tcPW/giba8XovI9mggCTJN6EUwd24v6Na1LKC7efoxH5qZZPvQC2fFT+cxeusfSPqZfKyJCtSyWP+jZsj4vD+1q2/fnT1azQtfu0EQQiNo1rs3kMQnUsPkgm7/mAM/OW6fXT52mL95FTqmzpJphwYzqa10USPmum+Ka88AV1qHU+YVFTJiWSvrxwF7qUhNBgIprUZ93R/YgxOYa9/TFu3jzh61eiMq35BY4mGKz8MywhBZE1tRlKP3NwwMv4LouTS3tR7LyGT81lay8Qpu9AoMmggCWeGFjXrnF/pT59e83M2vJbg9H5Fs+Xp7OsVKlpoODhHG6DKVfCgoS/nlLN7rF1LP0bTyQyYMfrQzYy6KaCALcH+Ji+Ot11pruAH/9fA3frD3g4Yh8g6PI8MEv1glkN3ZrZnuzXfmHGmHBfDA6nqY2S13+sPEQ/wjQpS41ESjGX9KGu2yWVywy8MfZK1my/agXovKur9fuZ/cx63VjXYbS/zWuG8EHo+Nt75FN/G0Hs5cG3pmwJgIFwOODOnBzD+tSEPmFRYyflspamxo71ZUxxracxGUXNKKjzZKgyv9c1Lwebwyzr9D718/XsnBbYNXh0kSggOLrpy8N6UqizbT8zNxCRk1awsYDJ70Qmect3GZfXO5uLS5XrVzduQmPD7KuzVFYZLhnxgq2Hw6cpS41Eaj/FxocxDu396C7TUnl49kFjJy4JCDWgX03xVpOomtMPfq0aeCFaFRVuuvSNtxiMzHwRE4B46emkpGdb7NX9aOJQJ2lZlgIH45JoG2jWpa+I1n5jPhgMTurcZG61ZLGvQYAABVpSURBVOkZ/LbVelngnsu0uFx1JCK88Icu9GplTfLbj5zi3pkrLGuAV0eaCJRF/VphzBzfh5YNa1r6DmXmMeKDxeyxuZFaHdgtPNMmqhZXddbictVVWEgQ743qSYsG1vf7wm1HefqL6j/BUhOBstWkXgSz7uxjO1Ry34lchiUvrnazMbcfzuJrm+Gyd12mxeWquwa1wpg8Jp46NktdfrR0N5N+q97Lu2oiUGVqHlmDWXf2pkld65jrvRk5DP9gMfsycrwQWdVI/mU7pb/4RdcN56Y462gqVf20a1yHt0f0sE36L8zfwA8bqu9Sl5oIVLlaNqzFzDt706hOuKVvz7HiZHA81/+voR48mctnK6yVKMdf3IbwEC0uFyguvaARz95gXevbGPjjRyvZsL96jpxzKRGISAMRWSAiW5z/WlY+EZHuIrJIRNaJyGoRua1E3xQR2SEiac6f7q7Eo6pG20a1mTW+N1G1rfV1dh3N5qWluX6//N+k33aQX+qmYN2IEIb3buGliJS3jOrbiiSbooKnnEtdHs7M80JUVcvVM4IngB+MMe2BH5zPS8sGRhtjOgODgDdEpOT4xD8bY7o7f9JcjEdVkfbRdZg5vg8NbBb5OJBtGDFxid/+gZzILmDm4l2W9qR+ragdbr1mrKq/v13fiUsvsM6p2ZuRw4Tp1W/dDlcTwWBgqvPxVOCm0hsYYzYbY7Y4H+8DDgHWI6x83oVN6jBzfG8ibdYy2Hooi9snLuZolv8lg+mLd3Iq/+w/7IjQIJL6tfJOQMrrQoKD+PeIONo3rm3pW7k7g8c+WV2tRhKJK/8xIpJhjIks8fy4Mca6MO6Z/l4UJ4zOxpgiEZkC9AXycJ5RGGNsP0lEZAI
wASA6Orrn7Nmzzztud8jKyqJ2beubJBDsOungpaW5ZNtU7Y2tE8TjCRHUDvOPUTZ5DsOjP2eTWWre0IAWIYzqZL0vUp6HHnoIh8PBW2+95cYI/Vd1+Bs5lF3E84tyyLRZxfUP7UIZ3K5i5ch95Vhcfvnly40x8aXbz5kIROR7wG4Q9V+AqRVNBCLSFEgBkowxi0u0HQDCgGRgmzHmuXP9x8THx5vU1NRzbValUlJSSExM9GoM3rQ6PYPbJy4hM9eaDTo3q8us8X2oZ3Pm4GumLdrJ01+sO6stOEhIeTSRWJtx5eVJTEwkIyODtDS9wgnV529k2c5j3P7BEss9JIB/j4jj+q7Nzvk7fOVYiIhtIjjnpSFjzEBjzEU2P18AB50f5qc/1A+V8eJ1ga+Av55OAs7fvd8UywM+BHqd33+e8rSuMZFMG9vL9hr6un0nGTV5CSdybL5G+ZACR5FtcbkbuzWrdBJQ1VdCqwb84+Yutn1/mruKtD0ZHo7I/Vy9RzAPSHI+TgK+KL2BiIQB/wGmGWM+LtV3OokIxfcX1roYj/KguBb1mTo2gQib0ZWr008wevJSMnN9Nxl8tXo/e23mQdx1mZaaVmcb0jOGexOtRQfzCosYPzXV7+fTuJoIXgSuFJEtwJXO54hIvIhMdG5zK3ApMMZmmOhMEVkDrAGigL+7GI/ysJ4tG/Bwzwjb2u6r9mQw5sNlPrkEoDHGtpzEgA6N6dBES00rq0evupBBNqVGjmTlMW5qKqd88H1eUS4lAmPMUWPMAGNMe+e/x5ztqcaY8c7HM4wxoSWGiP7/MFFjzBXGmC7OS00jjTHVv7RlNXRhg2Amj0kgItT6dlq+6zhjpywjO9+3/kgWbjvKxgOZlvZ7bL71KQXFpdpfu60bFzW3flHYsP8kD85O89ulLnVmsXKLvm0bMikpgfAQ61tq6Y5jjJ/qW2OvP/jVem+gR4tI4m2qUCp1Ws2wECaOTiC6rnVE2fcbDvLyNxu9EJXrNBEot+nfLork0fGEBVvfVgu3HWXC9OU+kQy2HMwkZdNhS/udl+i9AXVuTepFMHG0/Rnw+79sZ+6yPV6IyjWaCJRbXXZBI94f1ZPQYOs8gl82H+b+WSvIL/RubaLJv1srScY2qKGlplWFdYmpxxu32VfEeeo/a1jsZ+t8ayJQbnd5h8a8c3tPQmyqOH6/4RAPzl5JoZcW+ziSlcenNsXlxvZvraWmVaUMuqgpf776Qkt7YZHh7hnL/WoBJ00Eqkpc2Smat4bH2X64fr32AI996p0p+jMW77KckdSJCOHW+FiPx6L8372JbbnZpkx5RnYBY6cu40S27w6fLkkTgaoy13Rpymu3dsNuhcfPVuzllW83eTSenHwH0xdZi8uN6N2CWlpcTp0HEeEfQ7oQ39JaUGH74VPcN8s/lrrURKCq1ODuzXl5SFfbvndStjF14U6PxTJt0U6Onjq7qFBIkDBGi8spF4SHBPP+qJ7ENrCu5vfb1iM8O8/3l7rURKCq3C3xsTx/00W2fc/+dx3z1+yv8hhO5hbwrs0Esuu7NqVpPesfsFKV0bB2OJOSEqhjc2Y5c8luvt/lW/NoStNEoDxiVJ+W/HFAe0u7MfDQnDSW7jhWpa8/8dcdZJS6XhskcP8V1piUOh8XRNfhrRFx2I05mLUxn5822ZZi8wmaCJTHPDywPcMSrDdl8wuLuHNaKlsPWWf6usPRrDwm2UwgG9IjhnY29eaVOl+JFzbm6ettlroEHpi1ks0Hq+Y97ipNBMpjRIS/33QRAzo0tvSdyCkgafKyKlny8p2UbZaFZ8KCg3hwoJ4NKPdL6teKkX2sS5xm5RUydsoyn1y8SROB8qiQ4CDeGhFHt9hIS9/ejBzumOLeIXe7jp5ius0ylCN6tyCmvpaaVu4nIjxzQ2cuaR9l6Us/nsNd05eTV+j9GfYlaSJQHlczLIRJSfG0bGj9IF637yS3vr+Ig244M9h19BTDkxdb5g3UDAvmvsv
bufz7lSpLaHAQ/x7Rg7aNaln6Uncd58lP1/jUSCJNBMoromqHM/WOXjSoZV3qb9PBTG5+ZyHbDp9/Mdpth7O49f1F7DthTShj+7emUZ3KLUOpVGXVqxHK5DEJ1LdZqe+zlXt5J8U6is1bNBEor2kVVYtJSfG2xbv2ZuQw9N2F/PvHLZVa9CO3wMHc1D3c9v5iDp60XottUjeCOy/V4nLKM1o2rMV7I3tiU3qLV77dxNceGDpdETqdUnlVXIv6TEpKYMK0VMsN3ePZBfzzu828umAzcbGRRNYMIyI0iIiQYMJDg4sfhwYjFI/KyMl38OXqfRzJyrd9raja4Uwb14t6NXx/LWVVffRu05AxncOYtNb6vnx4bhrN69ega4z1npknaSJQXte/XRSzJ/RlzIdLLTN/oXiuwYrdrq0LG103nFl39qFtIx0uqjzvkphQQhrGWNbIzi0oHjr9xX0X06RehJei00tDykd0ianHJ/f0s52m76pm9SKYM6GvJgHlVY9f3YErO0Vb2g+ezGPcVO+u4udSIhCRBiKyQES2OP+1Vl4q3s5RYr3ieSXaW4vIEuf+c5wL3asA1TqqFp/e049BnZvYFqo7H/3aNmTu3X1pFWUdvaGUJwUFCW/c1p2OTa1LXa7bd5JH5qyiyEtLXbp6RvAE8IMxpj3wg/O5nZwS6xXfWKL9JeB15/7HgXEuxqP8XOM6Ebw3qie/PX4Fj1x5AS0aVH6sf0iQcFP3ZnxxX39m3dlH5wson1ErvHjotN2otW/WHeCf33m2Iu9prt4jGAwkOh9PBVKAxyuyo4gIcAUwosT+zwLvuhiTqgaaR9bgjwPac//l7dh6OItDJ/PILXCQW+ggJ99BbmEReQUOcgscGFP8bQsgqnYYl13Q2KvXW5UqT7PIGnwwOp7b3l9EXqk5Lu+kbKNto9oM6Rnj0ZjElUkNIpJhjIks8fy4McZyeUhECoE0oBB40RjzuYhEAYuNMe2c28QCXxtjbMtUisgEYAJAdHR0z9mzZ5933O6QlZVF7dp6zRn0WJz20EMP4XA4eOutt7wdik/Q98UZdsdi6f5C3lllHeIcIvBYrwguqB/s9jguv/zy5caYeMtrnmtHEfkesFvM9S+VeP0Wxph9ItIG+FFE1gAnbbYrMysZY5KBZID4+HiTmJhYiZd3v5SUFLwdg6/QY1EsMjKSjIwMPRZO+r44w+5YJAIRjbfw2oLNZ7UXGnhvbRGf39uXFjaz76vCOe8RGGMGGmMusvn5AjgoIk0BnP/a1lk1xuxz/rud4stHccARIFJETiejGGCfy/9FSinlJx64oh2DuzeztB87lc+4qcs4meuZpS5dvVk8D0hyPk4Cvii9gYjUF5Fw5+MooD+w3hRfk/oJGFre/kopVV2JCC8N6UpcC+uEsi2Hsrh/1koKPbDUpauJ4EXgShHZAlzpfI6IxIvIROc2HYFUEVlF8Qf/i8aY9c6+x4FHRGQr0BCY5GI8SinlVyJCg0keFU/zSOscml82H+b5L9fb7OVeLo0aMsYcBQbYtKcC452PFwJdyth/O9DLlRiUUsrfNaoTzqQx8Qx5Z6Gl1MrURbto17g2o/q2qrLX15nFSinlAzo0qVvmUpfP/nc9v245XGWvrYlAKaV8xBUdonnq2o6WdkeR4d6ZK6psOVdNBEop5UPGXdya4b2sS11m5hYydkoqx2wKM7pKE4FSSvkQEeG5wZ3p17ahpW/3sWzuroKlLjURKKWUjwkNDuLd23vSxqZY4tKdx3jqs7VuXepSE4FSSvmgejVDmTQmwXYhpU9XpPNeqbUNXKGJQCmlfFTrqFq8O7IHITZDiV76ZiPfrD3gltfRRKCUUj6sX9sonr/JthYnD89JY+3eEy6/hiYCpZTyccN7tWD8xa0t7TkFDsZPTeXgyVyXfr8mAqWU8gNPXtuRAR0aW9oPnMzlzmmp5OSf/0giTQRKKeUHgoOEfw2Po0OTOpa+1ekn+NPHaee91KUmAqW
U8hO1w0OYmBRPVG3r8u7z1xzg9e832+x1bpoIlFLKj8TUr0ny6HjCQqwf32/9uJXPV+6t9O/URKCUUn6mR4v6/POWbrZ9j32ymuW7jlXq92kiUEopP3Rjt2Y8OKC9pT3fUcSEacvZcyy7wr9LE4FSSvmphwa254Zu1qUuj57KZ/zUVDIruNSlJgKllPJTIsIrQ7vSPda61OWmg5n88aOVOCowkkgTgVJK+bGI0GCSR/ekWb0IS99Pmw7zwlcbzvk7XEoEItJARBaIyBbnv/VttrlcRNJK/OSKyE3OvikisqNEX3dX4lFKqUDUuE4EE5MSqBkWbOmb/PsOZizeVe7+rp4RPAH8YIxpD/zgfH4WY8xPxpjuxpjuwBVANvBdiU3+fLrfGJPmYjxKKRWQOjWry5vD4hCbpS6fmbeOnzeXvdSlq4lgMDDV+XgqcNM5th8KfG2MqfjtbKWUUhUysFM0T17TwdLuKDLcN3NFmfu5mgiijTH7AZz/WgthnG0Y8FGpthdEZLWIvC4i4S7Go5RSAe3OS9pwW3yspT0rr7DMfULO9UtF5HugiU3XXyoTnIg0BboA35ZofhI4AIQBycDjwHNl7D8BmAAQHR1NSkpKZV7e7bKysrweg6/QY1EsIyMDh8Ohx8JJ3xdnePpYXNnAsKZhEOuPFlVo+3MmAmPMwLL6ROSgiDQ1xux3ftAfKudX3Qr8xxjz/wNbT59NAHki8iHwaDlxJFOcLIiPjzeJiYnnCr1KpaSk4O0YfIUei2KRkZFkZGTosXDS98UZ3jgWCX0LGPLuQrYeyjrntq5eGpoHJDkfJwFflLPtcEpdFnImD0REKL6/sNbFeJRSSgH1aoTy4ZgEGtayFqgrzdVE8CJwpYhsAa50PkdE4kVk4umNRKQVEAv8XGr/mSKyBlgDRAF/dzEepZRSTrENyi5QV9I5Lw2VxxhzFBhg054KjC/xfCfQ3Ga7K1x5faWUUuXr2bI+r97Sjcc+WV3mNjqzWCmlqrkbujXjl8cuL7NfE4FSSgWARnXKHp2viUAppQKcJgKllApwmgiUUirAaSJQSqkAp4lAKaUCnCYCpZQKcJoIlFIqwGkiUEqpAKeJQCmlApwmAqWUCnCaCJRSKsBpIlBKqQCniUAppQKcJgKllApwmgiUUirAaSJQSqkAp4lAKaUCnEuJQERuEZF1IlIkIvHlbDdIRDaJyFYReaJEe2sRWSIiW0RkjoiEuRKPUkqpynP1jGAtcDPwS1kbiEgw8DZwDdAJGC4inZzdLwGvG2PaA8eBcS7Go5RSqpJcSgTGmA3GmE3n2KwXsNUYs90Ykw/MBgaLiABXAJ84t5sK3ORKPEoppSovxAOv0RzYU+J5OtAbaAhkGGMKS7Q3L+uXiMgEYILzaZaInCsBVbUo4IiXY/AVeizOiBIRPRbF9H1xhq8ci5Z2jedMBCLyPdDEpusvxpgvKvDCYtNmymm3ZYxJBpIr8HoeISKpxpgy74sEEj0WZ+ixOEOPxRm+fizOmQiMMQNdfI10ILbE8xhgH8XZMVJEQpxnBafblVJKeZAnho8uA9o7RwiFAcOAecYYA/wEDHVulwRU5AxDKaWUG7k6fPQPIpIO9AW+EpFvne3NRGQ+gPPb/v3At8AGYK4xZp3zVzwOPCIiWym+ZzDJlXg8zGcuU/kAPRZn6LE4Q4/FGT59LKT4i7lSSqlApTOLlVIqwGkiUEqpAKeJwA1E5FERMSIS5e1YvEVEXhGRjSKyWkT+IyKR3o7J08oqpRJoRCRWRH4SkQ3OEjQPejsmbxKRYBFZKSJfejuWsmgicJGIxAJXAru9HYuXLQAuMsZ0BTYDT3o5Ho86RymVQFMI/MkY0xHoA9wXwMcC4EGKB8r4LE0ErnsdeIxyJsMFAmPMdyVmiS+meF5IILEtpeLlmLzCGLPfGLPC+TiT4g/BMqsGVGciEgNcB0z0dizl0UTgAhG5EdhrjFnl7Vh8zFjga28H4WF2pVQC8sOvJBF
pBcQBS7wbide8QfEXxSJvB1IeT9Qa8mvlldgAngKu8mxE3lORciMi8heKLw3M9GRsPqBSJVMCgYjUBj4FHjLGnPR2PJ4mItcDh4wxy0Uk0dvxlEcTwTmUVWJDRLoArYFVxYVUiQFWiEgvY8wBD4boMecqNyIiScD1wAATeBNUyiqlEpBEJJTiJDDTGPOZt+Pxkv7AjSJyLRAB1BWRGcaYkV6Oy0InlLmJiOwE4o0xvlBh0ONEZBDwGnCZMeawt+PxNBEJofgm+QBgL8WlVUaUmEUfMJwl5qcCx4wxD3k7Hl/gPCN41BhzvbdjsaP3CJS7/BuoAywQkTQRec/bAXnSOUqpBJr+wCjgCud7Ic35rVj5KD0jUEqpAKdnBEopFeA0ESilVIDTRKCUUgFOE4FSSgU4TQRKKRXgNBEopVSA00SglFIB7v8AJCMeM14LiI0AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "5lS-wv_ABTky" }, "source": [ "## Regression example" ] }, { "cell_type": "markdown", "metadata": { "id": "j1KeajKoec60" }, "source": [ "We're going to train two separate neural networks to solve a prediction task:\n", "$$f_\\theta(x) \\approx y$$\n", "First we generate the data" ] }, { "cell_type": "code", "execution_count": 81, "metadata": { "id": "0_TUikVIuBCA" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAYoElEQVR4nO3dfYxcV3nH8d/j9YaskxBHylYom7iOWuQ0SiBuRojWUtu8FAeSEjcpIkggWir5n9ICoga7kUpbVcWSK14kUNEKUFURARVJXaRAnSAHoUYkYp01JMExSkkTvAliUbMB4kVe20//2J0wnp07c2fuufeec+f7kSJlZ2fuPbOQZ84853nOMXcXACBdG+oeAACgGAI5ACSOQA4AiSOQA0DiCOQAkLiNddz00ksv9a1bt9ZxawBI1pEjR37q7tPdj9cSyLdu3aq5ubk6bg0AyTKzZ3s9TmoFABJHIAeAxBHIASBxQQK5mX3AzJ40syfM7Itmdn6I6wIABiscyM1sRtJfSWq5+zWSJiTdVfS6AIB8QlWtbJQ0ZWYrkjZJej7QdQEgmIPzCzpw6LieX1rWZZuntGfnNu3aPlP3sAorPCN39wVJ/yzpOUkvSHrJ3R/ofp6Z7TazOTObW1xcLHpbABjKwfkF7bvvcS0sLcslLSwta999j+vg/ELdQyssRGrlEkm3S7pS0mWSLjCzd3Y/z91n3b3l7q3p6XX17ABQqgOHjmt55cw5jy2vnNGBQ8drGlE4IRY7b5b0jLsvuvuKpPsk/W6A6wJAMM8vLQ/1eEpCBPLnJL3RzDaZmUm6SdKxANcFgGAu2zw11OMpCZEjf1TSVyQ9JunxtWvOFr0uAIS0Z+c2TU1OnPPY1OSE9uzcVtOIwglSteLuH5H0kRDXAoAytKtTmli1UsumWQBQh13bZxoRuLvRog8AiWNGDiAKTW3WqQKBHEDt2s067TrvdrOOJIJ5DqRWANSuyc06VSCQA6hdk5t1qkAgB1C7JjfrVIFADqB2qTXrHJxf0I79h3Xl3vu1Y//h2jfeYrETQO1Gbdapo9IlxoVZAjmAKAzbrFNXQO23MFtXICe1AiBJdVW6xLgwSyAHkKS6AmqMC7MEcgBJqiugxrgwSyAHkKS6Auqu7TP66B3XambzlEzSzOYpffSOa2vtQGWxE0CS6tyWNrZdFAnkAJIVW0CtC4EcgCR2H0wZgRxAlE0uyI/FTgDsPpg4ZuQAomxyKUsTU0jMyAFE2eRShnYKaWFpWa5fpZDq3vSqKAI5gCibXMrQ1BQSqRUAtdZkV6mpKaQggdzMNkv6rKRrJLmk97j7t0NcG0A1xqEm+7LNU1roEbRTT
yGFSq18UtJ/uftVkl4v6Vig6wJAME1NIRWekZvZqyX9nqQ/lSR3PyXpVNHrAohXqpUfTU0hmbsXu4DZdZJmJX1fq7PxI5Le5+4vdz1vt6TdkrRly5brn3322UL3BVCP7uYhaXVWW3TjqFQ/HKpkZkfcvdX9eIjUykZJvy3pX9x9u6SXJe3tfpK7z7p7y91b09PTAW4LoA5lVH40tSywKiEC+QlJJ9z90bWfv6LVwA6ggcqo/EitLLBxhy+7+4/N7Edmts3dj0u6SatpFgCJyZPe2LxpUi+eXFn32iKVHymVBca4L02oqpW/lHSPmX1P0nWS/inQdQFUJE964+D8gn7xy9PrXjs5YYUqP1LqLI3x20OQQO7uR9fy369z913u/mKI6wKoTp4AdeDQca2cXV8gccF5GwvNRlMqC4zx2wMt+gAk5QtQWc95aXl9qmUYMR6fliXGbw+06AMNNkxJX56uxzI7I1PpLN2zc1vP8ksOXwYQ1MH5BV339w/o/V8+mrukL096o9dzJjeYTp46HU0FR9li/PZQuCFoFK1Wy+fm5iq/LzAOejXsdJrZPKWH996Y+dpBM/jO51w8NamXT53WyplfxZFhm4PGoREo1HvMaggikAMNs2P/4Z7pjzaT9Mz+W0u9V78Pi05ldYnGJOR7zArk5MiBhhlUPRFyUW6YCo5es9J+lTJNCeRVvEcCORCxUb6SZy1ISuEW5drjyvo+3/lhcXB+QX/31Se11FHZ0s7XZ6V/hi3lizk9U0W5IoudQKRG3X+k14KkJF2yaTJIyqJzXL10fli0n7vUozxxeeWMNljvewzzrSH2fVqqKFckkAORGrWDsFdVxSfefp3m//ZNQWapvcbVtnnq3A+Lfs+VpLO+2hXaadhvDTF2WnaqotmJ1AoQqSJfyUPWZHenLfotpF7wqnM7PPs995XXnLdRF7xq48hpkRg7LTtVsQc6gRyI1DDNN2XliHttEGVSZm68M3genF/o+9y2l5ZXdPQjbxp5jCkc31Z2sxOpFSBSeb+Sj5ojzrMVa6+0Rb/A3Bk8+y2GZr1mFCnt01IWAjkQqbwdhKPkiPMG/37pie51yu7gmSe1ESLgxthpWTVSK0DE8nwlHyVHnLe2OSttMdNRB56Vzsl6rXX8PlQKKJV9WspCIAcSN0qOOG/w77dBVFbwbOfrsxY6N06YDvzJ60tbjI2phrwqBHIgcaPsxpc3+A9bcTFonxdJWjnjI3c1dgftG66a1r1HFqI6racOBHIgcaOUtw0T/IdJWwyqG28bpTSwVwXNPY88t25BtWkt/nkQyIGahUgNDJsjLqO2+eD8Qq66cWm0SpVhKmhiqSGvCoEcqNEwB/mGzgV3B/92OWLn9aV8wb79PvIYtVJlmOAcUw15FQjkQI3yVo8UObk97x7j3dff85XvSq5Xzujsd89+KZXJDaYLz9+opZMrhT6A+lXBdM7Mx62GXCKQA0MLOTPOWz0y6laoeT8Ael2/87CIQffsW+r4tjAVKll5/Tuvn9FDTy1StQIgnyIz416yZpmbN02e8/Oo+4nk/QAYJm3R67n96s1DBdUq9ixJVbBAbmYTkuYkLbj7baGuC8QkxCEB3UelTWwwnTl77uz3F788rYPzC69cc9T9RPJ+AAzaDGvQPas6kHjcG3+yhGzRf5+kYwGvB0Sn6E573a3xS8sr64K4tJqX7myxH3U/kbx7YWftYd6tfc/ufVokjX2bfJ2CBHIzu1zSrZI+G+J6QKyKHhKQt85aWk3btPc+GXU/kbwfAJ3X7+ejd1wrST33aZGkh/feqGf236qH995IEK9QqNTKJyR9SNJFWU8ws92SdkvSli1bAt0WKUuxtbpoCmHY+ubO/PsoaYVh8srt62//hwf04sn1J/pcsmlSu7bPaMf+w40/ZzM1hQO5md0m6SfufsTM/iDree4+K2lWklqtVp7dLdFgoRcNq1J0wS1vCV1biAA57AeAZ/zX2X489oMcsqQ4ccgrx
Ix8h6S3mtlbJJ0v6dVm9gV3f2eAa6OhUj49vciCW78Sui888lzP11QdIF/qcb5m5+MpHOTQLdWJQ16Fc+Tuvs/dL3f3rZLuknSYII5BUp3VFZWV6/7HXdn56aoD5KB1gEF59zwHVoyiyHVjP9ezKOrIUYsUZ3WhZM3oqyrhG2TQOPqll8qa+Ra9btMnDkEDubt/U9I3Q14TzRRL0IpJyIaXfvngQbniPOPI+jAqK2VW9LpNnzgwI0ctxqVLb9gFthANL/1mr5JyzWxHHUdZM9+i1236xIFAjto0vUtvUDqgrCqKQfngMheZy5r5Fr1u0ycOBHKgJIMCallVFKPMXkPlisua+Ya4bpMnDiFb9AF06BdQy6yi6Fd1UrQzdZCyTrQv67pNwYwcKEm/dECZVRSDZq9l54rLmvk2eUZdFDNyoCRZ9dY3XDWtDWY9XxNiZtxv9srMtpmYkQMl6bXA1j71/UyPPvhQM+M85YUE7mYhkAMl6g6avTackqQJsyAz46a3oqM3AjnGWt7zLMs+2u3s2gy9+/DjYe+T8h42GB2BHFGqYqe6PLPXqo52u3hqstB92n+vrFN+mtKKjt5Y7ERtsjZB6j5Fpx3UQm2+1JanBDB0mWDWAqhZdqPOIJ1/ryxNaUVHbwRy1KJfsK5qp7o8JYChZ7hZVSNLPQ5yyHufQacONakVHb2RWkEt+gXrqnaqG9T2fXB+IfPAhyIz3F5VI1lpkTz36fd3mWlYKzp6Y0aOWvQL1mV3H7YN2lf7wKHjPYO4rb22yrH0k/V3mdk8xdmZY4JAjlr0C9a9gpppNc3xG/u+pq2BDizolea48/oZHTh0XFfuvT8zreIKX8pXpFGnyIcAmoHUCmrRr428s5FmYWn5nPRGu5EmVH10Z5qju0Ily6CT5kOMZdjXSc3d2Q+DmWedtFqiVqvlc3Nzld8XcclTYrhj/+G+1Rjt9EEIg+4lrX7Y9Jopd7+XG66a1kNPLRJYEZSZHXH3VvfjzMhRmzwz0EELnCEXQAdda8LsnOqZfrXmnQcp012JspEjR9QGLXCGXADNutYlmyY1NTmxLq3TztEPKv+TmnXQL+JDIEdUupuEbrhqet1CXtuwC3qDTmHPWjR079+sk/dbAd2VKAuBHNHo1SR075EF3Xn9zCsLjBNr278Ou/1qnm7RrMqRl5b7N+vk/VZAdyXKQo4c0chqEnroqcXCC5p5N5MapVmnVwVOt9DlgP0WiqvYpwZxIZAjGmV2dBa59qATd7L2HS+raqXfRl5SeWeBIl6FA7mZXSHp3yS9RtJZSbPu/smi18X4KesEdml1d8GlHimSi6cmB742T512lYc1DNqLhm1sx0+IGflpSR9098fM7CJJR8zsQXf/foBrY4yUcQJ7O83QK4hLUsaJa+vEdKrOKN8uWGhttsKB3N1fkPTC2r//3MyOSZqRRCDHUEJ3KObp1MzadbDXtWLJOw/65lLWtxrEK2iO3My2Stou6dEev9stabckbdmyJeRt0SAhZ7556rvzBLjYjk8b9M0l9LcaxC9YIDezCyXdK+n97v6z7t+7+6ykWWm1RT/UfYEsg9IJeQNcbMen5fnmEsu3B1QjSCA3s0mtBvF73P2+ENdEvWJKJYwqKwUhDbdPd1X7ow+j3zeXmPL5qEbhhiAzM0mfk3TM3T9WfEioW1VHrZUtq1PzE2+/bqh9uqvaHx0YVYjOzh2S3iXpRjM7uvbPWwJcFzWp6qi1shXZ47sT+30jdiGqVv5bq/v+oyFiTCWMqkiaoTO9dPHUpM6f3KClkyvJpprQXHR2Yp0yG3NS0V2psrS8oqnJCX387dcRwBEdNs3COqQSmpNewnhgRj7GsipTODqsWeklNB+BfEwNanIZ9xK2zZsm9WKPrs9xSi8hHaRWxhSpg2wH5xf0i1+eXvf45ISNVXoJ6WBGPqbKTh2M2lAUQyPSgUPHtXJ2ffPxBedtHOtvKYgXgXxMlVmZMureJLHsaZL1YZZ1U
hBQN1IrY6rMypRR0zaxpHvo5ERqCORjKlTXYy+jpm1iqRSh/BKpIbUyxsqqTBk1bRNLIxLll0gNgRzBjXrSTxknBI1q3MsvkRYCOYIbdUbLTBgYjblXf8ZDq9Xyubm5yu+LdMVQlgjUzcyOuHur+3Fm5IheLGWJQKwI5IjeoKPWmK1j3BHIEb1+ZYnM1gHqyJGAfg06sTQRAXUikCN6/Rp0smbrC0vLunLv/dqx/3ByZ40CwyKQI1oH5xe0Y/9hfeDLR/WqjRt0yabJdV2o/ZqFUj44GhgGOXKsE8PiYd6j1no1EXXrXBgFmohAnrAyAm4si4eDKlXaupuIsroiONkHTUZqJVHtgLuwFrxCpRBiWTwcZgOtXdtn9PDeG/XM/ls1w86FGENBArmZ3WJmx83saTPbG+Ka6K+sgBvLDoRZgffiqcm+r2PnQoyjwoHczCYkfVrSmyVdLekdZnZ10euiv7ICblYA3WBW6YLhnp3bNLnB1j3+8qnTfcdR5va8QKxCzMjfIOlpd/+hu5+S9CVJtwe4Lvoo6/CDXjNaSTrjXmn1x67tM7rw/PVLOCtnfOC3js5Uy8N7bySIo/FCBPIZST/q+PnE2mMoUVkphPaMdsLWz4arzpUv9TjFXmLhEugWIpCv/y9e64sHzGy3mc2Z2dzi4mKA21anXc8cU4NJmSmEXdtndDZjV8wqgyhHrgH5hCg/PCHpio6fL5f0fPeT3H1W0qy0uo1tgPtWIpZyvF6KHH4wqHQxhtN6YjpoAohZiBn5dyS91syuNLPzJN0l6asBrhuFWMrxQspTunjDVdPrvmpVHURZuATyKTwjd/fTZvZeSYckTUj6vLs/WXhkkYilHC+kPNvC3ntk4Zz8mEm68/rsbwBldYNy5BowWJDOTnf/mqSvhbhWbGJIMYQ26MOpV6B3SQ891XttI+b0EzAO6OwcoIkNJoMWEYf9FtLE9BOQEgL5AE3M0w76cBq2WqSJ6ScgJWyalUPT8rSDTqsftlqkjPRTDDswAqkgkDdYv2DY68Op8/kXT03q/MkNWjq5MjCQhi4TJOcODIdA3lDDBsO8+3/3MmiGP6y8W9gCWEUgb6hhg2HR4Bky/UTOHRgOi50NNWwwjCl40poPDIdA3lDDBsOYgmcTSz6BMhHIG2rYYBhT8GxiySdQJnLkDTXsAmToBcuimlbyCZTJPGO70jK1Wi2fm5ur/L4AkDIzO+Lure7HSa0AQOII5ACQuKRy5LRtj46/HdBcyQRy2rZHx98OaLZkUitslTo6/nZAsyUzI4+p81BKK1UR298OQFjJzMhj6jzMc+ZlTGL62wEIL5lAHlPnYWqpipj+dgDCSya1ElPnYWqpipj+dgDCSyaQS/G0bad4IHMsfzsA4SWTWokJqQoAMUlqRl6n7iqVO6+f0UNPLZKqAFC7QoHczA5I+iNJpyT9j6Q/c/elEAOLSa+GmnuPLLC1KoAoFE2tPCjpGnd/naQfSNpXfEjxSa1KBcB4KRTI3f0Bdz+99uMjki4vPqT4pFalAmC8hFzsfI+kr2f90sx2m9mcmc0tLi4GvG35aKgBELOBgdzMvmFmT/T45/aO59wt6bSke7Ku4+6z7t5y99b09HSY0VeEKhUAMRu42OnuN/f7vZm9W9Jtkm7yOo4bqgANNQBiVrRq5RZJH5b0++5+MsyQ4kRDDYBYFc2Rf0rSRZIeNLOjZvaZAGMCAAyh0Izc3X8z1EAAAKOhRR8AEkcgB4DEEcgBIHEEcgBIHIEcABJHIAeAxBHIASBxBHIASByBHAASRyAHgMQRyAEgcQRyAEgcgRwAEkcgB4DEEcgBIHEEcgBIHIEcABJHIAeAxBHIASBxBHIASByBHAASRyAHgMQRyAEgcUECuZn9tZm5mV0a4noAgPwKB3Izu0LSH0p6rvhwAADDCjEj/7ikD0nyANcCAAypUCA3s7dKWnD37+Z47m4zmzOzucXFxSK3BQB02DjoCWb2DUmv6fGruyX9j
aQ35bmRu89KmpWkVqvF7B0AAhkYyN395l6Pm9m1kq6U9F0zk6TLJT1mZm9w9x8HHSUAINPAQJ7F3R+X9Gvtn83sfyW13P2nAcYFAMiJOnIASNzIM/Ju7r411LUAAPkxIweAxAWbkaO3g/MLOnDouJ5fWtZlm6e0Z+c27do+U/ewADQIgbxEB+cXtO++x7W8ckaStLC0rH33PS5JBHMAwZBaKdGBQ8dfCeJtyytndODQ8ZpGBKCJCOQlen5peajHAWAUBPISXbZ5aqjHAWAUBPIS7dm5TVOTE+c8NjU5oT07t9U0IgBNxGJnidoLmlStACgTgbxku7bPELgBlIrUCgAkjkAOAIkjkANA4gjkAJA4AjkAJM7cqz91zcwWJT1b+Y3rd6mkcTx4Y1zftzS+731c37dU7nv/dXef7n6wlkA+rsxszt1bdY+jauP6vqXxfe/j+r6let47qRUASByBHAASRyCv1mzdA6jJuL5vaXzf+7i+b6mG906OHAASx4wcABJHIAeAxBHIK2RmB8zsKTP7npn9h5ltrntMVTGzt5nZk2Z21swaX5ZmZreY2XEze9rM9tY9nqqY2efN7Cdm9kTdY6mSmV1hZg+Z2bG1/5+/r8r7E8ir9aCka9z9dZJ+IGlfzeOp0hOS7pD0rboHUjYzm5D0aUlvlnS1pHeY2dX1jqoy/yrplroHUYPTkj7o7r8l6Y2S/qLK/80J5BVy9wfc/fTaj49IurzO8VTJ3Y+5+7icOv0GSU+7+w/d/ZSkL0m6veYxVcLdvyXp/+oeR9Xc/QV3f2zt338u6Zikyg4iIJDX5z2Svl73IFCKGUk/6vj5hCr8jxr1MrOtkrZLerSqe3JCUGBm9g1Jr+nxq7vd/T/XnnO3Vr+K3VPl2MqW572PCevxGHW+Y8DMLpR0r6T3u/vPqrovgTwwd7+53+/N7N2SbpN0kzesiH/Qex8jJyRd0fHz5ZKer2ksqIiZTWo1iN/j7vdVeW9SKxUys1skfVjSW939ZN3jQWm+I+m1ZnalmZ0n6S5JX615TCiRmZmkz0k65u4fq/r+BPJqfUrSRZIeNLOjZvaZugdUFTP7YzM7Iel3JN1vZofqHlNZ1ha03yvpkFYXvf7d3Z+sd1TVMLMvSvq2pG1mdsLM/rzuMVVkh6R3Sbpx7b/to2b2lqpuTos+ACSOGTkAJI5ADgCJI5ADQOII5ACQOAI5ACSOQA4AiSOQA0Di/h/0dCUAhQ5piAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "N = 100\n", "d = 1\n", "X = np.random.randn(N, d)\n", "Y = X * 2 + 3 + np.random.randn(N, 1)\n", "plt.scatter(X, Y)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 82, "metadata": { "id": "EJ3h8uW1WWYg" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(100, 1) (100, 1)\n" ] } ], "source": [ "print(X.shape, Y.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "c-2yrp4uFa43" }, "source": [ "Next, we convert that data to PyTorch tensors." ] }, { "cell_type": "code", "execution_count": 83, "metadata": { "id": "5Q6Mil6uUoJs" }, "outputs": [], "source": [ "X_pt = torch.from_numpy(X).to(torch.float32)\n", "Y_pt = torch.from_numpy(Y).to(torch.float32)\n", "\n", "loss_fn = nn.MSELoss()" ] }, { "cell_type": "markdown", "metadata": { "id": "d5hhaIlsFd_4" }, "source": [ "Define the training loop" ] }, { "cell_type": "code", "execution_count": 84, "metadata": { "id": "JbQaSvbruUjo" }, "outputs": [], "source": [ "def train(net: nn.Module):\n", "    optimizer = optim.SGD(net.parameters(), lr=1e-2)\n", "    losses = []\n", "    for _ in range(100):\n", "        # Forward pass on the full dataset, then compute the MSE loss\n", "        Y_hat_pt = net(X_pt)\n", "        loss = loss_fn(Y_hat_pt, Y_pt)\n", "\n", "        # Clear old gradients, backpropagate, and take one SGD step\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        losses.append(loss.detach().numpy())\n", "    return np.array(losses)" ] }, { "cell_type": "markdown", "metadata": { "id": "J9UGOhquFfz-" }, "source": [ "Let's test two different networks" ] }, { "cell_type": "code", "execution_count": 86, "metadata": { "id": "VH4ARnKUqVt0" }, "outputs": [], "source": [ "linear_network = nn.Linear(1, 1)\n", "linear_losses = train(linear_network)\n", "\n", "non_linear_network = Net2(1, 1)\n", "non_linear_losses = train(non_linear_network)" ] }, { "cell_type": "markdown", "metadata": { "id": "2kHLnn2TFiXO" }, "source": [ "and plot the losses and predictions."
] }, { "cell_type": "code", "execution_count": 88, "metadata": { "id": "N3tTaRBkv573" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXxU1f3/8deZmUz2fQGykYQdwh4g7ChaUVHRStW614pLa21tv1bb38O2fttvtS5Vq3WpIlQtdcOq1AUREGU17DsJkJAEyAYJ2dfz++NOYggJ2WZyM5PP8/G4jztz5869n8vom8u5956jtNYIIYRwPxazCxBCCNE1EuBCCOGmJMCFEMJNSYALIYSbkgAXQgg3ZevJnUVEROiEhISe3KUQQri9rVu3FmqtI1su79EAT0hIIC0trSd3KYQQbk8pldXacmlCEUIINyUBLoQQbkoCXAgh3FSPtoELIdxfbW0tOTk5VFVVmV2Kx/Hx8SE2NhYvL68OrS8BLoTolJycHAIDA0lISEApZXY5HkNrTVFRETk5OSQmJnboO9KEIoTolKqqKsLDwyW8nUwpRXh4eKf+ZSMBLoToNAlv1+jsn6tbBPhXhwr4+9oMs8sQQohexS0CfENGIU+vPERxRY3ZpQgheoGAgAAAjh8/zrXXXmtyNeZpN8CVUouVUvlKqT0tlt+nlDqolNqrlPqL60qEK8ZGU9eg+XTPSVfuRgjhZqKjo3nvvfdcuo+6ujqXbr87OnIGvgSY13yBUuoC4CpgjNZ6FPCk80v7zqjoIJIi/Pl453FX7kYI4WYyMzNJTk4GYMmSJVxzzTXMmzePIUOG8OCDDzatt3LlSqZOncqECRNYuHAhZWVlADz66KNMmjSJ5ORkFi1aROMIZXPmzOE3v/kNs2fP5tlnn+35A+ugdm8j1FqvU0oltFh8D/CY1rrasU6+80v7jlKK+WOj+dvqdPLPVBEV5OPK3QkhOugPH+9l3/EzTt3myOggfnfFqC59d8eOHWzfvh1vb2+GDRvGfffdh6+vL3/84x9ZtWoV/v7+PP744zz99NM88sgj/PSnP+WRRx4B4Oabb2bFihVcccUVABQXF/PVV1857bhcoatt4EOBmUqpzUqpr5RSk9paUSm1SCmVppRKKygo6OLu4IoxA9AaPtl9osvbEEJ4trlz5xIcHIyPjw8jR44kKyuLTZs2sW/fPqZPn864ceNYunQpWVlG31Br1qxhypQpjB49mtWrV7N3796mbV133XVmHUaHdfVBHhsQCqQCk4B3lFJJupURkrXWrwCvAKSkpHR5BOUh/QIZ3j+Qj3ed4LbpHbvJXQjhWl09U3YVb2/vptdWq5W6ujq01lx88cUsW7bsrHWrqqq49957SUtLIy4ujt///vdn3YPt7+/fY3V3VVfPwHOA5dqwBWgAIpxXVuuuGBvN1qzT5JyucPWuhBAeIjU1lfXr15ORYdyKXFFRwaFDh5rCOiIigrKyMpdfDHWFrgb4f4ALAZRSQwE7UOisotpy5dhoAFbskmYUIUTHREZGsmTJEm644QbGjBlDamoqBw4cICQkhDvvvJPRo0ezYMECJk1qsyW411KttHqcvYJSy4A5GGfYecDvgDeAxcA4oAb4ldZ6dXs7S0lJ0d0d0GHBC+upqWvgk/tndms7Qoiu2b9/PyNGjDC7DI/V2p+vUmqr1jql5boduQvlhjY+uqlr5XXPlWOjeXTFPjLySxkcFWhGCUII0Su4xZOYzV0xNhqrRfHB9lyzSxFCCFO5XYBHBnozY3AE/9l+nIaGLt/UIoQQbs/tAhzg6vEx5BZX8m3mKbNLEUII07hHgGd8Ceufa3r7vVH9
8LNb+c8OaUYRQvRd7hHgh1fD6j9CndEboZ/dxrxR/Vmx6wRVtfUmFyeEEOZwjwCPmQj11ZD3XYeIV0+IobSqjjUHXNoNixBC9FruEeCxjtsfc7c2LZo2KIKoQG+Wy90oQggnWbt2LfPnzwfgo48+4rHHHjO5ovNzjwAPjgP/yLMC3GpRXDUumrUH8zldLgM9CCGc68orr+Shhx5y6T7q67vXBOweo9IrBTEpkHP2U5zXTIjlH18f5cMdudLBlRBm+PQhOLnbudvsPxoubfvMNzMzk0svvZQZM2awYcMGYmJi+PDDDzl48CB33303FRUVDBo0iMWLFxMaGsqcOXOYMmUKa9asobi4mNdee42ZM9t/knvJkiWkpaXx/PPPc9tttxEUFERaWhonT57kL3/5S9NIQE888QTvvPMO1dXVXH311fzhD38AYMGCBWRnZ1NVVcX999/PokWLAGM0oQceeIDPP/+cp556ihkzZnT5j8o9zsABYidCUTpUnm5aNGJAEMkxQby7NcfEwoQQPS09PZ2f/OQn7N27l5CQEN5//31uueUWHn/8cXbt2sXo0aObghSMUXW2bNnCM888c9byzjhx4gTffPMNK1asaDozX7lyJenp6WzZsoUdO3awdetW1q1bB8DixYvZunUraWlpPPfccxQVFQFQXl5OcnIymzdv7lZ4g7ucgYNxBg6Quw0Gz21avHBiHL/7yOhUfmR0kEnFCdFHnedM2ZUSExMZN24cABMnTuTw4cMUFxcze/ZsAG699VYWLlzYtP4111zTtG5mZmaX9rlgwQIsFgsjR44kLy8PMAJ85cqVjB8/HoCysjLS09OZNWsWzz33HB988AEA2dnZpKenEx4ejtVq5fvf/36XamjJfc7AYyYA6qx2cDD6RrFbLby7NducuoQQPa5lv9/FxcUdWr+xj/Du7rOxE0CtNQ8//DA7duxgx44dZGRkcMcdd7B27VpWrVrFxo0b2blzJ+PHj2/qvtbHxwer1dqlGlpynwD3CYaIoee0g4f627loZBQf7jhOTV2DScUJIcwUHBxMaGgoX3/9NQBvvPFG09m4K11yySUsXry4aYzN3Nxc8vPzKSkpITQ0FD8/Pw4cOMCmTZtcsn/3aUIB43bCQ5+B1saFTYeFE+P4ZPdJVh/IZ15yfxMLFEKYZenSpU0XMZOSknj99dddvs/vfe977N+/n6lTpwLGBco333yTefPm8dJLLzFmzBiGDRtGamqqS/bfbn/gztTt/sC/fQ3++wD8bAeEfXfXSV19A9MeW82Y2GBevdX9OmUXwp1If+Cu1Zn+wN2nCQVafaAHwGa1cM2EWNYcLCC/tKqVLwohhOdxrwCPGgU233MCHOAHKbHUN2jek1sKhRDt+Pzzzxk3btxZ09VXX212WZ3Wbhu4UmoxMB/I11ont/jsV8ATQKTW2uVjYmK1QfS4cy5kAiRFBjAlMYy3v83m7lmDsFhUKxsQQjiD1hql3Pf/sUsuuYRLLrnE7DLO0dkm7Y6cgS8B5rVcqJSKAy4GjnVqj90VMxFO7IS66nM+umFyPFlFFWw8UtSjJQnRl/j4+FBUVNTpsBHnp7WmqKgIHx+fDn+nI2NirlNKJbTy0V+BB4EPO7w3Zxg4DTY+bzSjDJx21kfzkvsT/JEXy7YcY/rgiB4tS4i+IjY2lpycHAoKCswuxeP4+PgQGxvb4fW7dBuhUupKIFdrvbO9f0YppRYBiwDi4+O7sruzxU8FFGSuPyfAfbysXDMhhjc3ZVFUVk14gHfr2xBCdJmXlxeJidL3UG/Q6YuYSik/4LfAIx1ZX2v9itY6RWudEhkZ2dndncsvDPolQ+bXrX58w+R4aus1y7dJN7NCCM/WlbtQBgGJwE6lVCYQC2xTSvXcEzQJ0yF7S9MIPc0N7RfIxIGhLPv2mLTRCSE8WqcDXGu9W2sdpbVO0FonADnABK31SadX15aEGVBXCce3tfrx9ZPiOFJQzqYjMuixEMJztRvg
SqllwEZgmFIqRyl1h+vLake8o+0785tWP54/JppgXy/e3JzVg0UJIUTPajfAtdY3aK0HaK29tNaxWuvXWnye0CP3gDfnH2481NNGgPvarSycGMvne06Sf0aezBRCeCb3ehKzuYTpkL0Z6mtb/fjG1IHUNWiWbZFuZoUQnsl9A3zgdKitgOPbW/04McKfWUMj+deWLGrrpZtZIYTnce8AhzabUQBuTh1I3plqvtyf10NFCSFEz3HfAA+IhMjhkLW+zVUuHB5FTIgvb2ySi5lCCM/jvgEOxu2Exza12Q5utSh+OCWe9RlFZOSX9XBxQgjhWm4e4DOhpqzV7mUbXTcpDrvNwtINmT1XlxBC9AD3DvDEWaAscHh1m6tEBHhz5dho3tuaQ0lF62fqQgjhjtw7wP3CIHoCHF5z3tVun55AZW09b6f1bM+3QgjhSu4d4ACDLoTcNKgsbnOVUdHBTEkMY+mGLOrklkIhhIfwjADXDXB03XlXu316IrnFlaySWwqFEB7C/QM8NgXsgedtBwe4eGQ/YkN9WfxNZs/UJYQQLub+AW71Mi5mHv4SztN9rNWiuG1aAlsyT7Ent6QHCxRCCNdw/wAHGHQBFB+DU0fOu9rClDj87VZe/fr86wkhhDvwkAC/0Ji304wS7OvF9ZPj+XjXCY4XV/ZAYUII4TqeEeBhSRAysN3bCQF+NMMYy2/xN0ddXZUQQriUZwS4UsZZ+NF1bT5W3ygmxJf5YwawbMsxSirlwR4hhPvyjAAHI8BrSo2xMttx58wkymvqWbZFHuwRQrivjgyptlgpla+U2tNs2RNKqQNKqV1KqQ+UUiGuLbMDkmaDshp3o7QjOSaY6YPDeX39UWrq5MEeIYR76sgZ+BJgXotlXwDJWusxwCHgYSfX1Xk+wRA3GTLaD3AwzsLzzlTz4Y5cFxcmhBCu0ZExMdcBp1osW6m1rnO83QTEuqC2zhs8F07sgLKCdledPTSSEQOCePGrw9Q3tH3/uBBC9FbOaAP/EfBpWx8qpRYppdKUUmkFBe0Ha7cMvsiYt3M7oaMu7p0ziCMF5azce9K1dQkhhAt0K8CVUr8F6oC32lpHa/2K1jpFa50SGRnZnd21r/9Y8IvoUDs4wGWjB5AY4c8LazPQ53mKUwgheqMuB7hS6lZgPnCj7i3pZ7EYd6NkfAkN7V+ctFoU98wexJ7cM6xLL+yBAoUQwnm6FOBKqXnAr4ErtdYVzi2pmwZfBBWFcHJnh1ZfMD6GAcE+vLAmw8WFCSGEc3XkNsJlwEZgmFIqRyl1B/A8EAh8oZTaoZR6ycV1dlzjY/UZqzq0ut1mYdGsJLYcPcW3mafa/4IQQvQSHbkL5Qat9QCttZfWOlZr/ZrWerDWOk5rPc4x3d0TxXZIQCQMGAsZ7V/IbHT9pHjC/e0892W6CwsTQgjn8pwnMZsbfBFkb4aqjnUb62u3ctfsJL5OL2RrlpyFCyHcg2cG+KC5oOvbHaWnuZtSBxLub+eZVXIWLoRwD54Z4LGTwMsfjqzt8Ff87DYWzWo8Cz/tutqEEMJJPDPAbXZImN6pAAe4eepAwvztPCtt4UIIN+CZAQ6QdAEUZUBxdoe/4me3cdesJNYdKpCzcCFEr+fBAT7HmHfhLDzc387TXxx0dkVCCOFUnhvgUSPAP6rTAe5nt3HPnEGszyhiQ4Y8nSmE6L08N8CVMs7Cj6zt0GP1zd2UOpABwT48sfKg9JEihOi1PDfAwRitvqIQ8vd26ms+XlZ+NncI248V8+X+fBcVJ4QQ3ePZAZ4425h3shkF4NqJsSSE+/HkyoM0SH/hQoheyLMDPDgGIoZ2KcC9rBZ+cfFQDpws5eNdx51fmxBCdJNnBzgYtxNmroe66k5/9Yox0QzvH8hTKw9RXVfvguKEEKLr+kCAz4G6yg6NVt+SxaJ46NLhHDtVwVubZAR7IUTv4vkBnjAdlAUyv+7S12cPjWTG4Aj+tjqdkspaJxcnhBBd5/kB7hNsdC+b
+U2Xvq6UcRZeXFnLi2sPO7k4IYToOs8PcICEGZDzLdRWdunryTHBXD0uhsXrj5Jb3LVtCCGEs/WRAJ8F9TVdagdv9MD3hgLw1OfyiL0QonfoyJBqi5VS+UqpPc2WhSmlvlBKpTvmoa4ts5viU0FZu9wODhAb6sePpieyfHsuO7OLnVicEEJ0TUfOwJcA81osewj4Ums9BPjS8b738gmC6HFdbgdv9NMLBxMR4M2jK/bJI/ZCCNN1ZEzMdUDLccauApY6Xi8FFji5LudLmAE5aVBT0eVNBHjbePCSYWzNOs1HO+XhHiGEubraBt5Pa30CwDGPamtFpdQipVSaUiqtoKCgi7tzgoRZ0FBrjJXZDddOjCU5JojHPj1AZY083COEMI/LL2JqrV/RWqdorVMiIyNdvbu2xU/pdjs4GA/3PDJ/FCdKqnh5ndxWKIQwT1cDPE8pNQDAMe/9XfZ5B0LMhG63gwNMTgzj8jEDeHHtYbJPdb1JRgghuqOrAf4RcKvj9a3Ah84px8USZkDuVqgu6/am/t/lI7BaFI+u2OeEwoQQovM6chvhMmAjMEwplaOUugN4DLhYKZUOXOx43/slzISGOsje1O1NDQj25Wdzh/DFvjxWH8hzQnFCCNE5tvZW0Frf0MZHc51ci+vFOdrBszbC4Iu6vbkfTU/k3bRsfv/RPqYNisDHy+qEIoUQomP6xpOYjbwDjH5RsjY4ZXN2m4VHr0rm2KkKXvpKLmgKIXpW3wpwgIHTjHbwLvQP3prpgyOYP2YAf197mKOF5U7ZphBCdETfC/D4qVBfDbnbnLbJR+aPxNtm4bcf7JYnNIUQPaZvBjjAMec0owBEBfnw63nD2XC4iPe35Tptu0IIcT59L8D9wyFyuHEh04l+ODmeiQND+dN/93GqvMap2xZCiNb0vQAH4yw8ezM0OO9ReItF8edrRlNWXccf5d5wIUQP6JsBPnAaVJ+BvD3tr9sJQ/sFcvfsQSzfnsuag73/4VQhhHvrmwHe2A7u5GYUMLqcHRIVwG+W7+ZMlYyhKYRwnb4Z4CFxEBzv1AuZjbxtVp5YOJa8M1X8+ZP9Tt++EEI06psBDjBwqvFAjwtu+xsXF8KdM5NYtiWbb9ILnb59IYSAvhzg8VOhvACKXPME5S8uHkpShD+/fn8XpdKUIoRwgb4b4AOnGXMXNKMA+HgZTSknSir5X7krRQjhAn03wCOGgm8YHOt+z4RtmTgwlLtnD+KdtBxW7j3psv0IIfqmvhvgShmj1bswwAF+ftFQRg4I4uHluyksc07/K0IIAX05wMHoXvbUYShz3ViddpuFv143jtKqOh5eLn2lCCGcp28HeOP94E4Y4OF8hvUP5H8uGcYX+/L497fZLt2XEKLv6NsBHj0OrN4ub0YBuGNGIjOHRPCHj/eSnlfq8v0JITxftwJcKfULpdRepdQepdQypZSPswrrETZvY6Dj7M0u35XFonhq4Vj87DbuW7adqlrn9cMihOibuhzgSqkY4GdAitY6GbAC1zursB4TNwWO74DaSpfvKirIhycXjuHAyVIe+/SAy/cnhPBs3W1CsQG+Sikb4Acc735JPSw+FRpqnTrAw/lcOLwft01LYMmGTLm1UAjRLV0OcK11LvAkcAw4AZRorVe2XE8ptUgplaaUSisocN3dHl0WN8WYH3N+x1Ztefiy4YyOCeaX7+7kWFFFj+1XCOFZutOEEgpcBSQC0YC/UuqmlutprV/RWqdorVMiIyO7Xqmr+IVBxLAeaQdv5G2z8sIPJwDwk39to7pO2sOFEJ3XnSaUi4CjWusCrXUtsByY5pyyelh8qmOAh4ae22W4H08uHMvu3BL+9F/ptVAI0XndCfBjQKpSyk8ppYC5gHsmUXwqVJVAQc9eWLxkVH/unJnIPzdm8eEOGUtTCNE53WkD3wy8B2wDdju29YqT6upZ8anG3MUP9LTmwXnDmZwYxq/f38Xe4yU9vn8hhPvq1l0oWuvfaa2Ha62T
tdY3a63ds7OP0ETwj4JjPdcO3sjLauGFH04gxNfO3W9upbhCBkQWQnRM334Ss5FSED+lR+9EaS4y0JsXb5pAXkk19y3bTn2D9JcihGifBHij+KlQnAVnTpiy+/HxoTx61Si+Ti/k8c/kIR8hRPskwBvFmdcO3uj6yfHcMnUgr6w7wrtp0umVEOL8JMAbDRgDNl9T2sGbe2T+SGYMjuA3H+zm28xTptYihOjdJMAbWb0gNsXUM3AAm+OiZlyoH3e9sZXsU/KkphCidRLgzcWnwoldUF1mahnBfl68emsK9Q2a25d8S0mFDIoshDiXBHhzcamg6yF3q9mVkBQZwMs3T+RYUQWL3kiTx+2FEOeQAG8ubhKgemSAh45ITQrniYVj2Hz0FL9+b5cMxyaEOIvN7AJ6FZ9g6DfK9Hbw5q4aF0PO6Uqe+PwgA0J8+fW84WaXJIToJSTAW4qbArvegYZ6sFjNrgaAe+cM4nhxJS+uPUy4v50fz0wyuyQhRC8gTSgtxU+FmlLI22t2JU2UUjx6VTKXje7PH/+7nw+255hdkhCiF5AAbyneMcBDD/YP3hFWi+Kv141j2qBw/ufdXaw+kGd2SUIIk0mAtxQcB0ExkLXB7ErO4W2z8vLNExkxIIh73tzGhsOFZpckhDCRBHhLSsHAaUbHVr3wro9AHy+W/mgyA8P9+PHSNLZmnTa7JCGESSTAWxM/FUpPwOmjZlfSqjB/O2/+eApRgd7c9voW9uRKP+JC9EUS4K0ZON2Y98JmlEZRgT68dWcqQT5e3PjqZglxIfogCfDWRA4Dv/BeHeAAMSG+/HtRKgHeNglxIfqgbgW4UipEKfWeUuqAUmq/UmqqswozlVJGM0rWerMraVdcmJ+EuBB9VHfPwJ8FPtNaDwfG4q6DGrdm4DQ4nQlnjptdSbuah/gP/7GJbcfkwqYQfUGXA1wpFQTMAl4D0FrXaK2LnVWY6QZOM+a9vBmlUVyYH2/flUqYv52bXt3MxsNFZpckhHCx7pyBJwEFwOtKqe1KqVeVUv4tV1JKLVJKpSml0goKCrqxux7WbzTYA90mwAFiQ/14566pxIT4ctvrW1hzIN/skoQQLtSdALcBE4AXtdbjgXLgoZYraa1f0VqnaK1TIiMju7G7Hma1GU9lulGAA0QF+fD2XVMZ0i+AO/+Zxn+255pdkhDCRboT4DlAjta68Znz9zAC3XPET4WC/VDuXs0RYf52/nVnKikJofz87R28+vURs0sSQrhAlwNca30SyFZKDXMsmgvsc0pVvUXj/eDHNppbRxcE+Xix5PbJXJpsdID150/309DQ+54sFUJ0XXfvQrkPeEsptQsYB/xf90vqRWImgNXb7ZpRGvl4WXn+hxO4KTWel786wv1v76CqVkb2EcJTdKs/cK31DiDFSbX0PjZviJsMmevMrqTLrBbF/16VTGyoH499eoCTJZW8cnMKof52s0sTQnSTPInZnsTZcHK327WDN6eU4u7Zg/jbDePZmV3CNS9u4EiBuQM3CyG6TwK8PUlzjPnRr8yswimuGBvNW3dOoaSylgUvrOebdOmOVgh3JgHenujx4B3kEQEOMCkhjA9/Mp0Bwb7c+voWlm7IlMGShXBTEuDtsdqMu1GOeEaAg/HU5vv3TuOCYVH87qO9/M97u+TiphBuSAK8I5LmGH2Dn84yuxKnCfC28crNE7l/7hDe25rDwpc2kltcaXZZQohOkADviKTZxtxDmlEaWSyKX1w8lFdvSSGzsJz5z33NV4fcqLsDIfo4CfCOiBwOAf08qhmluYtG9uPDn06nX5APt72+hadWHqReHvoRoteTAO8IpYzbCY9+1SvHyXSGpMgAPrh3Oj+YGMffVmdw46ubyDtTZXZZQojzkADvqKTZUF4A+Z7VW0BzvnYrj187hicXjmVndgnznlnHqn15ZpclhGiDBHhHJTrawT20GaW5ayfGsuJnM4gO8eXH/0zjkQ/3yF0qQvRCEuAdFRIHYYPg
yBqzK+kRgyIDWH7vNH48I5F/bszi8ue+ZleO54zXIYQnkADvjCEXw9F1UFNhdiU9wttm5f/NH8mbd0yhoqaeq/++gWdWHaK2vsHs0oQQSIB3ztBLoK7K424nbM+MIRF89vNZXDk2mmdWpbPghfXsPS6DJwthNgnwzhg4wxhm7eCnZlfS44J9vfjrdeN46aaJ5J2p5qrn1/P0yoNU10nbuBBmkQDvDJsdBl8Ihz6Hhr7ZjDAvuT+rHjDOxp9bncHlz33DlqOnzC5LiD5JAryzhl4KZSfhxA6zKzFNiJ+dp68bx+u3T6Kqtp4fvLyRh97fRXFFjdmlCdGnSIB31pDvgbLAoc/MrsR0FwyLYuUvZrFoVhLvbs3hwqe+4u1vj8nQbUL0kG4HuFLKqpTarpRa4YyCej3/cIid3CfbwVvjZ7fxm8tGsOK+GQyK9OfX7+/mmhc3yC2HQvQAZ5yB3w/sd8J23MeweXByF5Tkml1JrzFiQBDv3DWVp38wlpzTlVz5/Hp++c5OeRxfCBfqVoArpWKBy4FXnVOOmxg6z5hLM8pZlFJcMyGWNb+azV2zk/h453EueHItf/synYqaOrPLE8LjdPcM/BngQaDNWzKUUouUUmlKqbSCAg/pqjRyOIQMlGaUNgT6ePHwpSP44oFZzBwSwVNfHGLOE2v595Zj1MlDQEI4TZcDXCk1H8jXWm8933pa61e01ila65TIyMiu7q53UQpGXGE8Vu/Ggx272sBwf16+OYV3755KTKgvDy3fzbxnv+bT3SdkGDchnKA7Z+DTgSuVUpnAv4ELlVJvOqUqdzDmOmiog30fmF1JrzcpIYzl90zjxRsnoLXmnre2ceXz61lzMF+CXIhuUM74H0gpNQf4ldZ6/vnWS0lJ0Wlpad3eX6+gNfx9KvgEwR0rza7GbdQ3aD7Ynsszqw6Rc7qSsXEh/PyiIcwZGolSyuzyhOiVlFJbtdYpLZfLfeBdpRSMvQ6yN8OpI2ZX4zasFsW1E2NZ/cs5/Pma0RSWVnP769+y4IX1fL73pNxDLkQnOCXAtdZr2zv79kijFxrzXe+aW4cbstss3DA5njW/msNj14zmdEUtd72xlUueWcfybTnS46EQHSBn4N0RHAsJM2HX2x471Jqr2W0Wrp8cz+pfzubZ68dhUYoH3tnJrL+s4R/rjlBaVWt2iUL0WhLg3TXmB3DqMORuM7sSt2azWrhqXAyf3j+T12+bREK4P3/6ZD9T/7yaRyOoqskAAA60SURBVD/ex7GivtEHuxCd4ZSLmB3lURcxG1UWw5NDYeJtcNlfzK7Go+zKKWbxN0dZsesE9Vozd3g/bpk6kBmDI7BY5IKn6DvauogpAe4M79xqDPLwi31g9zO7Go+Td6aKNzdl8a/NxygqryExwp8bp8Tz/QmxhPrbzS5PCJeTAHelzPWw5DK4/CmY9GOzq/FY1XX1fLbnJG9szCIt6zR2q4V5yf25fnIcqYnhclYuPJYEuCtpDf+4EKpK4KdpYJFLC6528GQpy7YcY/m2HM5U1REX5su1E+L4/sQYYkPlX0HCs0iAu9qe9+G9H8H1/4Lhl5tdTZ9RVWuclb+7NZv1GUa3BqlJYVw9PoZ5yQMI9vUyuUIhuk8C3NXq6+C58cathT+STq7MkHO6guXbcvnP9lyOFJZjt1m4cFgUV4yNZu6IKHy8rGaXKESXSID3hI1/h88fhh+vhtiJZlfTZ2mt2ZVTwgfbc/nv7hMUlFbjb7cyd0Q/Lhvdn9lDo/C1S5gL9yEB3hOqS+HpUcbAxwuXmF2NwOh7ZfORIj7edZzP9+ZxqrwGXy8rc4ZFcsmo/lwwPEqaWUSvJwHeU1b9Hr55Bu5aBwPGmF2NaKauvoHNR0/xye4TfLEvj/zSamwWxZSkMOYO78dFI/oRHy4XQEXvIwHeUyqL4W8TIGIY3P6J0emV6HUaGjQ7cor5fO9JvtyfT0Z+GQCDowKYMzSSOcOimJQYirdNmlqE+STAe1La
67Di53DtYkj+vtnViA7IKipn1f581h7MZ/ORU9TUN+DrZSU1KYyZQyKZOSSCwVEB0uWtMIUEeE9qqIdXZkPFafjpt/J0ppupqKljQ0YRX6cX8HV6IUcKywGICvRm2qBwpg2KYOqgcGJDfSXQRY+QAO9pWRvg9Uth9kNwwcNmVyO6IftUBeszCll/uIiNhwspLKsBIDrYhylJ4UxKCGNyYiiDIuUMXbiGBLgZ3r0dDn4Cd66GfqPMrkY4gdaajPwyNh0pYtORU2w+WtQU6GH+dibEhzJxoDGNjgmW2xWFU0iAm6E0D16eCfYAWLTWGH5NeBStNUcLy/k28xTfZp5mW9bppiYXq0UxYkAg4+NCGRMbzNi4EAZFBmCVPltEJzk9wJVSccA/gf5AA/CK1vrZ832nzwU4GB1dLb0CRsyHhUvlrpQ+oKismm3HitmRfZrtx4rZmV1MeU09AH52K6Oig0iOCWZ0TDCjooMZFOmPzSr954i2uSLABwADtNbblFKBwFZggdZ6X1vf6ZMBDrD+WfjiEbjkzzD1XrOrET2soUFzpLCMHdkl7MopZk9uCftOnKGq1hg2zm6zMLx/ICP6BzFiQCDDBwQxvH8gIX7SVa4wuLwJRSn1IfC81vqLttbpswGuNbx9Exz6DL7/GoxaYHZFwmR19Q0cKSxn3/Ez7D1uBPr+E6WcKq9pWicq0Jth/QMZEhXI0H4BDOkXwODIQIL95MnRvsalAa6USgDWAcla6zMtPlsELAKIj4+fmJWV1e39uaWqM/DWQsjZAle/bAzFJkQzWmsKSqvZd+IM6XllHDhZyqG8UjLyy6isrW9aLyLATlJkAIMiAxgU6U9ihDHFhfnhJU0xHsllAa6UCgC+Av6ktV5+vnX77Bl4o+oyWHY9ZH4DVz4HE24xuyLhBhoaNLnFlRzKK+VwQRkZ+cZ0pLCc4orvBn22WhSxob4MDPdnYJgfA8P9iA/zIz7cj7hQP/y9bSYehegOlwS4UsoLWAF8rrV+ur31+3yAA9RWGs0pGatgyj1w0e/By8fsqoSbOl1ew5HCMo4WVpBZWM7RonKOFVWQWVROaVXdWeuG+9uJDfUlNtSPmFBfYkKMKdoxD/K1yX3svZQrLmIqYClwSmv98458RwLcoa7auKi5+SXol2y0i0cNN7sq4UG01hRX1HLsVAXZpyuM+alKck5XkHu6kpziSmrqGs76jp/dyoBgH6JDfOkf5EP/YB/6BfnQP8iY9wvyJjzAW26DNIErAnwG8DWwG+M2QoDfaK0/aes7EuAtHFoJ/7kHaspgyt0w/X7wCzO7KtEHNDRoispryC2u5HjTVMWJkkpOnqniZEkVeWeqaGgRDxYFEQHeRAV5ExngTVSgDxGBdiICvJtNdsIDvAnx9ZJxSp1EHuTprUrzYOVvYfd7xgM/U38Ck+6AgCizKxN9XF19A0XlNZwsqeLkmSryz1SRX1pN/plq8kurKCgzXheV11DfMukx2uRD/bwI87efPfnZCfGzE+rvZcz97IT4ehHqZyfQxyah3woJ8N4ubx+s/T/Y/zFYbDDkezDuRhg8F7x8za5OiDY1NGiKK2spLKt2TDUUllZzqryGovIaTpVXc7q8lqJyY1lJZe05Z/aNlIIgHy+Cfc+egnxtBPl4EeTrRZCPjUAfLwLPmtsI9PbC39vqkQ9FSYC7i/wDsOMt2PU2lOWBzQcGToNBcyF+KvRPBpu32VUK0WUNDZozVbWcKq/hdEUtJZU1nC6v5XRFDWcqaymprKXYMW+cSqvqKKmsPafdvjW+Xlb8vW0EeFsJ8LHhb7cR4G3Dv3GyW/HztuFntxqv7cZr33NeW/H1Ml7brRZTL/BKgLub+jo4uhbSV8HhL6HwkLHc4mV0jNUvGSKGQOQwCE2EkDiw+5tashCuVlVbT2lVHaVVtZypqqPM8bq0qo6yamMqraqlrLreeF9VS3lNPeXVdcZUU0+FY94ZFkVTmPt4NU4WfGzfvfb2sjre
W/B2zH28rHjbLHjbLFw0sh+xoV3rWrqtAJcbQ3srqw0GX2RMACU5kJMGx7fD8W2Q8QXsePPs7/iGQVC00X7uHwX+EeAb6phCwDsYvAPBOwC8/ByTrzFZ5ek+0fs1hmdkYPf+FdrQoKmsrafCEe4VNfVU1jrmNfVNn1XU1FNVa0yNrytrG5c1UFVbT3lNHUXlDVTX1VNda8wbP6tr1laUFBnQ5QBviwS4uwiONabmj+FXFkNhOhRnQfExKMk2LoqW5UFhBlQUQm1Fx7avrEZzjc0OVrtxpm+1OeZeRrv8WZO1lfdWYztNcxtYLC2WWUFZvps3X6as362vLGdPFqvRQNq0rOV7C6DOXaYsZy+j8TXNXrf8XJ39uq3vnPW5Ov9nLZfBebZxvjlnf7fpNS3W6eA+29qeh98PbrGopiaV7v5lcD71DZqaOiPUXdG1sAS4O/MNgbhJxtSW2iqoKobK08aToNVnoLrUeKCottyY11UZ69VVQX2NMdXVQEMdNNRCfa0xylDj+4Z6Y6qrara8znit68+en7Osodl7x2vdfrumMFMbfwl0+XXjNpttv+ll8784WvsLqZXXbX33fHW0+GqLN23U3IH9tVhuReEL+AJc8YxxPcuJJMA9nZcPePWHwP5mV3J+ZwV7vdEBWFPAO1431AP67HUa34Nj/YbWv6sbjHWb3ju+2/x14/Yb10GD5tz1znrdyrzxL6RzPmv8Hm1/95x5y+1w/tet7re1fTb/w29v2x3cd6uvabu+sz5r9nmr63T2u411tHaMba3f2rptrdfZ5Ri3CTuZBLjoHSwWwCJt8UJ0gufdMCmEEH2EBLgQQrgpCXAhhHBTEuBCCOGmJMCFEMJNSYALIYSbkgAXQgg3JQEuhBBuqkd7I1RKFQBdHZY+Aih0Yjnuoi8ed188Zuibx90Xjxk6f9wDtdaRLRf2aIB3h1IqrbXuFD1dXzzuvnjM0DePuy8eMzjvuKUJRQgh3JQEuBBCuCl3CvBXzC7AJH3xuPviMUPfPO6+eMzgpON2mzZwIYQQZ3OnM3AhhBDNSIALIYSbcosAV0rNU0odVEplKKUeMrseV1BKxSml1iil9iul9iql7ncsD1NKfaGUSnfMQ82u1dmUUlal1Hal1ArH+0Sl1GbHMb+tlLKbXaOzKaVClFLvKaUOOH7zqZ7+WyulfuH4b3uPUmqZUsrHE39rpdRipVS+UmpPs2Wt/rbK8Jwj23YppSZ0Zl+9PsCVUlbgBeBSYCRwg1JqpLlVuUQd8Eut9QggFfiJ4zgfAr7UWg8BvnS89zT3A/ubvX8c+KvjmE8Dd5hSlWs9C3ymtR4OjMU4fo/9rZVSMcDPgBStdTJgBa7HM3/rJcC8Fsva+m0vBYY4pkXAi53ZUa8PcGAykKG1PqK1rgH+DVxlck1Op7U+obXe5nhdivE/dAzGsS51rLYUWND6FtyTUioWuBx41fFeARcC7zlW8cRjDgJmAa8BaK1rtNbFePhvjTGEo69Sygb4ASfwwN9aa70OONVicVu/7VXAP7VhExCilBrQ0X25Q4DHANnN3uc4lnkspVQCMB7YDPTTWp8AI+SBKPMqc4lngAeBxqHpw4FirXWd470n/t5JQAHwuqPp6FWllD8e/FtrrXOBJ4FjGMFdAmzF83/rRm39tt3KN3cIcNXKMo+991EpFQC8D/xca33G7HpcSSk1H8jXWm9tvriVVT3t97YBE4AXtdbjgXI8qLmkNY4236uARCAa8MdoPmjJ037r9nTrv3d3CPAcIK7Z+1jguEm1uJRSygsjvN/SWi93LM5r/CeVY55vVn0uMB24UimVidE0diHGGXmI45/Z4Jm/dw6Qo7Xe7Hj/Hkage/JvfRFwVGtdoLWuBZYD0/D837pRW79tt/LNHQL8W2CI42q1HePCx0cm1+R0jrbf14D9Wuunm330EXCr4/WtwIc9XZuraK0f1lrHaq0TMH7X1VrrG4E1wLWO1TzqmAG01ieBbKXUMMei
ucA+PPi3xmg6SVVK+Tn+W288Zo/+rZtp67f9CLjFcTdKKlDS2NTSIVrrXj8BlwGHgMPAb82ux0XHOAPjn067gB2O6TKMNuEvgXTHPMzsWl10/HOAFY7XScAWIAN4F/A2uz4XHO84IM3xe/8HCPX03xr4A3AA2AO8AXh74m8NLMNo56/FOMO+o63fFqMJ5QVHtu3GuEunw/uSR+mFEMJNuUMTihBCiFZIgAshhJuSABdCCDclAS6EEG5KAlwIIdyUBLgQQrgpCXAhhHBT/x9m/dbOyyLfZQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "t = np.arange(len(linear_losses))\n", "plt.plot(t, linear_losses, t, non_linear_losses)\n", "plt.legend(['linear', 'non_linear'])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 90, "metadata": { "id": "Wu_cKZ5PUzT0" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dfXxU5Z338c9vMiEjARMqcRNFBHxoVYw8BEFsfaxBG0Rx7a21u5Xt3tDeW2va1S3gY6youLbdht69X32xqwWtKHvjs6k3qIjaVhGCqFBoRSoVDRKpCY8hmcx1/zGZMEnOmYfMOXPmZH7v12tfymRy5hrsfuea3/ld1yXGGJRSSvlXwOsBKKWUyowGuVJK+ZwGuVJK+ZwGuVJK+ZwGuVJK+VzQixcdPny4GTVqlBcvrZRSvtXY2PiZMaas9+OeBPmoUaNYv369Fy+tlFK+JSI7rB7X0opSSvmcBrlSSvmcBrlSSvmcJzVyKx0dHezcuZO2tjavhzLghEIhRowYQWFhoddDUUq5IGeCfOfOnQwdOpRRo0YhIl4PZ8AwxrBnzx527tzJ6NGjvR6OUsoFjpRWROSHIrJZRDaJyGMiEkr3Gm1tbRxzzDEa4g4TEY455hj9pqPUAJZxkIvI8cCNQJUxZixQAFzbz2tlOhxlQf9elRrYnLrZGQSOEpEgMBj4xKHrKqVUxhq2N1C9oprKpZVUr6imYXuD10NyVMZBboz5GPgJ8FegCWg1xqzq/TwRmSMi60VkfXNzc6Yv64ohQ4YA8Mknn3D11Vd7PBqllBMatjdQ94c6mg40YTA0HWii7g91AyrMnSitDAOuAEYDxwHFIvIPvZ9njFlsjKkyxlSVlfVZYZpTjjvuOFasWOHqa4TDYVevr5SKqt9QT1tnz3tEbZ1t1G+o92hEznOitPJV4C/GmGZjTAfwJDDVget65sMPP2Ts2LEALFmyhKuuuopLL72UU045hR/96Efdz1u1ahXnnHMOEyZM4Otf/zr79+8H4Mc//jGTJk1i7NixzJkzh9gpTBdccAG33HIL559/PvX1A+d/RErlsl0HdqX1uB85EeR/BaaIyGCJ3lW7GNjiwHVzxsaNG1m+fDnvvfcey5cv56OPPuKzzz5jwYIFvPTSS2zYsIGqqip+9rOfAXDDDTewbt06Nm3axKFDh3j++ee7r9XS0sKrr77KTTfd5NXbUSqvlBeXp/W4HzlRI18LrAA2AO91XXNxptdNxbK1O5hy38ssW2u5j4xjLr74YkpKSgiFQpx++uns2LGDN998kz/+8Y+ce+65jBs3jqVLl7JjR3Qcr7zyCpMnT+bMM89k9erVbN68ufta11xzjatjVUr1VDuhllBBz47oUEGI2gm1Ho3IeY4sCDLG3Anc6cS10rFo9TZ2tbbxi9XbuG7yia69TlFRUfe/FxQUEA6HMcZwySWX8Nhjj/V4bltbG//yL//C+vXrOeGEE6irq+vRw11cXOzaOJVSfdWMqQGitfJdB3ZRXlxO7YTa7scHAl/vtXLjRSdTURLi+xednPXXnjJlCr///e/Ztm0bAAcPHuTPf/5zd2gPHz6c/fv3u37
TVCmVXM2YGlZdvYp3r3+XVVevGlAhDjm0RL8/rpt8oqsz8UTKyspYsmQJ3/jGNzh8+DAACxYs4NRTT2X27NmceeaZjBo1ikmTJnkyPqVU/pBYR0U2VVVVmd4HS2zZsoXTTjst62PJF/r3q3JZw/aGAV36cIqINBpjqno/7usZuVLK/2ILdmK93rEFO4CGeYp8XSNXSvlfPizYcZsGuVLKU/mwYMdtGuRKKU/5acFOrm6+pUGulPKUXxbs5PLmWxrkSilP1YypoW5qHRXFFQhCRXEFdVPrEt7o9GJmnMu1fO1aUUp5rmZMTcodKl51ueRyLV9n5FmyZs0apk+fDsCzzz7LwoULPR6RUv7k1cw4l2v5GuQemDFjBvPmzXP1NTo7O129vlJe8WpmnMu1fA3yOB9++CGnnXYas2fP5owzzqC6uppDhw6xceNGpkyZQmVlJTNnzuTzzz8HovuLz507l7PPPptTTz2V119/PaXXWbJkCTfccAMAs2bN4sYbb2Tq1KmMGTOmx94sDzzwAJMmTaKyspI77zyyJ9mVV17JxIkTOeOMM1i8+MhGk0OGDOGOO+5g8uTJvPHGG078lSiVc7yaGfenlp8t/g7yxiXws9Oi/3TI+++/z/e+9z02b95MaWkpTzzxBN/61re4//77effddznzzDO56667up8fDod56623+PnPf97j8XQ0NTXxu9/9jueff757pr5q1Sref/993nrrLTZu3EhjYyOvvfYaAA899BCNjY2sX7+eRYsWsWfPHgAOHDjA2LFjWbt2LV/+8pcz/JtQKjd5OTPO1c23/B3kr94Pez+J/tMho0ePZty4cQBMnDiRDz74gJaWFs4//3wArr/++u5ABbjqqqu6n/vhhx/26zWvvPJKAoEAp59+Op9++ikQDfJVq1Yxfvx4JkyYwNatW3n//fcBWLRoEWeddRZTpkzho48+6n68oKCAv//7v+/XGJTyi1yeGXvF310r58+Nhvj5cx27ZO+9x1taWlJ6fmyf8kxfM7aJmTGG+fPn853vfKfHc9esWcNLL73EG2+8weDBg7ngggu6t84NhUIUFBT0awwqv/lt06p0ulzygb9n5BNnwb9uif7TJSUlJQwbNqy7/v3II490z87dNG3aNB566KHuc0A//vhjdu/eTWtrK8OGDWPw4MFs3bqVN9980/WxqIEtlxe6qNT4e0aeJUuXLuW73/0uBw8eZMyYMfz61792/TWrq6vZsmUL55xzDhC9kfmb3/yGSy+9lF/96ldUVlbyxS9+kSlTprg+FjWwJWrn01lvBhqXwIt3QLgdgkVwyV2uTTp1P/I8oX+/yk7l0koMfXNAEN69/l0PRuSOrJWPYgHethfi/16PPi5aQciA7keulLJUXlxO04Emy8cHCtdXg8bPvsNt9Ajw4FHRGbmD9/J683eNPAetXLmScePG9fi/mTNnej0spWzl8kIXp7i2GrRxCSwcCc/9ANpaIXyI7hAPlbL2jDuoDC+lsm0xy8IXZvZaCeiM3GHTpk1j2rRpXg9DqZTlwynzjq4GTWH2vfakG5m9eSz7GsMYot1sv1i9zbUzhh0JchEpBf4LGEv0XX3bGKNLC5XyiYHezpdx+ShReAOESuGSu1gWvpCFL2ztEeAAJaEg37/o5P6/gSScmpHXA//PGHO1iAwCBjt0XaWUyljthNoeNXJIsXxkd+MS+sy+25/q5HB4U49nlYSCzL3sS67NxLuHkukFRORo4DxgFoAxph1oz/S6Sqnc5LfFQ5Bm+SiVG5cJZt+QvQDvHpYD1xgDNAO/FpGzgEag1hhzIP5JIjIHmAMwcuRIB17WXXV1dQwZMoSbb77Z8udPP/00p556KqeffnqWR6aUd9zq/sjGh0PS8lGi2XdX6YSJs1i2dgcLn9vKvraes+9QUCgKFmQ1wGOcCPIgMAH4vjFmrYjUA/OA2+OfZIxZDCyGaB+5A6/rqaeffprp06drkKu84sbiIa8OigBSnn1
3B3jdSva1hT0pnyTiRJDvBHYaY9Z2/XkF0SB3lRuf4Pfccw8PP/wwJ5xwAmVlZUycOJH//M//ZPHixbS3t3PyySfzyCOPsHHjRp599lleffVVFixYwBNPPMHq1av7PG/wYL1VoAYWN/YCz/rK0hRvXMbPvtuf+i2Hw4aCo99m8PErkcIWJFzKjJH/zD3Vs5wfY5oy7iM3xuwCPhKRL3Y9dDHwx0yvm4gbe0M0Njby+OOP8/bbb/Pkk0+ybt06ILq74bp163jnnXc47bTTePDBB5k6dSozZszggQceYOPGjZx00kmWz1PKT1I5B9ONvcCzdlBEop7v4FHRAL+8HubtYFn4QirrVnLrU5vY2xamrSvEQxVPEhjUgghQ2MKq3b/MiT1pnOpa+T7waFfHynbgnxy6riU3PsFff/11Zs6c2T2LnjFjBgCbNm3itttuo6Wlhf3799v2iKf6PKVyUarljfNGnMfyPy3v8buZLh5ydWVpGqUTwLZ8EgoKhceuhEBHj8vnyp40jgS5MWYj0Gf9v1vc+gQXkT6PzZo1i6effpqzzjqLJUuWsGbNGsvfTfV5SuWiVCZHDdsbeGbbM31+94qTr8goyPrdGmgnjdJJzLK1O6IdKAnq35VL51nsSJMbhy/7cmWnG5/g5513HrNmzWLevHmEw2Gee+45vvOd77Bv3z4qKiro6Ojg0Ucf5fjjjwdg6NCh7Nu3r/v37Z6nlB+kMjmyCnuA13a+1uexdDiysjRZeNvNvl/YSnu4k8Nhk/QGZi7vSePLIHf8ExyYMGEC11xzDePGjePEE0/kK1/5CgB33303kydP5sQTT+TMM8/sDu9rr72W2bNns2jRIlasWGH7PKW8kk5DQCoh5WYtu18rS/sR3mA/+4bEHShu5I5TfLuNrR8XJXhJt7HNHw3bG7hv7X20trf2eDxUELI9Eq13jdzq+dUrqi3DPiABjDHZ+f/DfoY32Ad4Ov3fXueO3Ta2vg1ylR79+80PVoEcr6K4glVXr7L93UQhlezakPjDoj+v2S2FpfKJwjvV8onTnA5+3Y9cqTxgV8eOSVQGSVbe6F3LFhEiJtLjOel0cSTtlEmz4yRef8snTsrmQqecCnJjjGXniMqMF9+6lDeS1aszvTEXH/aVSytTGoPdrNS2U+aVH1Gz7J9T7jiJ50T5xCnZXOiUM0EeCoXYs2cPxxxzjIa5g4wx7Nmzh1AolPzJKqf052u53U1LcO7GXGxcVsfDxcYQe17vWn38rNT25mmBdC3W6ZJk9t2wvYF73/gpezuaiXSUcnjQNEzbeMDb5fNZW+hEDgX5iBEj2LlzJ83NzV4PZcAJhUKMGDHC62GoNPT3a7lVZwVAaVEp886el/FMsGF7A7f//nY6Ih2WP499WCSqp7d1tnHL6/Mp6eykpaDv4vLycGfS8Ibo7Pu+1x7FDF+BBDpAIDCohVDFk0hhgPnnfdPT/U+y2a6YM0FeWFjI6NGjvR6GUjmhv1/L3T7tZ+FbC21DvKK4ovu1qldUJ6zVRzDsFyg0ho64b+ChiKF29Az4n/fb/m58+WTwSS8Q6LXaUgIdVIxew3WTb0nz3Tkrm+2KORPkSqkjMvla7uRpP73LOy2HW2yfG98Nk8o4w4EAJZ0RBhvDrmCA8rgPAitW9W8ptB5PLqy2zOYRehrkSuWgdL+Wu9HfbFXeSdXRgUG0Rg4nfd7eggJ+d/27CZ+TaPn8kMIy9ob7lmNzYbUlZO8IPQ1ypXJQOl/L+1NPTyX4k7UyxisZVNKjXVCO+wIUFCT9PbvATbX/u2F7OGdXW2aTBrlSOSidr+Xp1tNTDf5UyxPBiGF+00fwpx8QaxdsDSTfIdsqcNPt/85m+SKXaZArlaNS/Vqebj091eC3K++UDCphsImwq30v5eFOaj9voebAwSNPCB5FeaehKWjdRixIn8DNpP87W+WLXKZBrpTPpVtPTzX4Lcs7XbPvmr2t9F5p2TCkmPpj/45dHfspKSpBDrf
26TUvDBRy97l3UzOmhmVrd1D58MqUls9HD72Y3WPWDToTj9EgV8rn0m1zSzX4u8sWb95jP/vuWmnZMKwsOoaO6K6fdt0tHZEO7n3jp8x/OJhy+cSqFHTb725DRLpbIbN6zmcO0iBXyufSrROnFPxdNy5rwu3UpLDPSX2SvvF4rR3N7G8LH3ntJOUTq1JQ2IT7rN7PldN6vKBBrpSHnGobTKdOnDD4E+0ymGCfk3RaE01HKZD68vl0esJzoX/cCxrkSnkk3bZBJ3vFewR/4xIanr2R6qH/xq6CAOXDi6n9PFqyqP/CsOhjg46mdsqttuOyYwzEb51kIoUEWi7j3pljU14+n2j/GKvn5qPkPUJKKVck6h7pLRb6TQeaMJju0E92gnv0JmE1lUsrqV5R3fP5XafKN6y+hbqSEE3BAowITYVBbi87htuOLTvyWMc+29ezGi9EQ7zj8ylE2ksxBugo5Yrjb+S9f7slrT1QaifUEirouelbUIIUBgp7PJaP/eMxOiNXKk1OzYzTaRvsz94rtjP+v7xGzbrHussn9SOOo61X33d0/5OepRW710tUzpA9Mwm2ZrZ9rF0pyOqxfKyPgwa5Umlx8rAA2z7topI+j/Vn7xXb8P/Ls9S0Hdladlcw+QpMu9dbtnYHpqMELPY8KSks4/cLvpbytROxuweQr8Hdm2OlFREpEJG3ReR5p66pVK5JpxxiJ1busKv77m/f36eEYVf7TVQTtg3/WHCHSuHyesqHHJfCqI+83rK1O6isW8mXbvsttz61iUO7p2Eifcsct5xzU8rXVZlxskZeC2xx8HpK5ZxMDwuIr3XbCZsw816f16OmbVUntq0Jd9W+y8Phvj8DyjsNXF4P83bAxFmW17YSlCKa/nIBd778CJ3HLyB48lwGn7QQAPnsao4OliEIFcUVaZ3bqTLnSGlFREYANcA9wL86cU018Hl9Inl/ZHpYQDobUVmVbRL+ffVqHaz9Wwd1w7/Qo/4dKghR+5U6iPu9+GvbfsAYYf/HV2KIRA9u6NoDXAa1MHTE0yz48l3UjPF2/+98Jk6c5ygiK4D7gKHAzcaY6RbPmQPMARg5cuTEHTt2ZPy6yr+sTpBJ9wR2L2Q67sqllbZHpNmpKLY/+T7ZAcXxy+ZT+bBc8OYClv9peY/HYt0nhz+9kuKTFhIY1LcennCMyjEi0miMqer9eMYzchGZDuw2xjSKyAV2zzPGLAYWA1RVVelpwHkumwfTOinT3fbS6YmOsSzbpLhwp4boV+VUvfDB6j6PiUBwyFZCrUEig1otfiu3F+L48ZtfupworZwLzBCRrwEh4GgR+Y0x5h8cuLYaoLJ5MK3TMtltz2p5fFCCDBk0xHZ/ku6yTZLZd7IzLhOJ7T4YGdWM1dnngUGtvFM3jeoVP7UtLeViYDrZZZTLMg5yY8x8YD5A14z8Zg1xlUw2D6bNJYlm9JZlGwlSu/MDWFDeN7wh4bL5VPTePra4oxSxLJ1E/7vY7dNy3ojzXAvMTD4g/PrNL13aR648kc2DaXNNsp7oaGg1UR6OUPu3T6k5cKDnE5PMvpMFX6LTdwItlxH8uycJmyPHtMX/d7H7IHIrMDOdUfv5m186HA1yY8waYI2T11QDk57sYq3m82Zqtm3tW/tOsXSSKPham8dy32uPEil9ARnVQkFHKQXN0wjvHR+3gVUNDdvPSvjfxeqDaP7r8y3Hk2lgZvoBkS/f/HRGrjyTDye7pFQWSFT7TrN0Yhd881YvpG33NIoqniQQ1zp4VMWTzKgawT3VR67fn/8ubgVmpjPqfPnmp5tmKeWSZBtdNbw8l+oHz6DyvZ9QPbyYhqK4/U26Vl3GFu2kyi7gTLCFQWUru/u/uwU6WLd3WT/eXU9pLVhKQ39WtMarGVND3dQ6KoorBvRiJZ2RK+US27LAm/fAszdSVxKiLRidSzUVBqkbfgwUDqbmK7f3++bl0OBw9oab+zxuOkoJWOyHAs7Ui90qlTkxo86Hb34
a5Eq5xLYs0L6X+iGD+uw42BYQ6kecRE0G7YOHBl1EUdzKSwAihVx54j+zbu8yV+vFbgSm3ktJjQa5Ui6x3d0wEqHJZsfBdGfHfU6fbxuPAULHrkSCLRxdWMYt59zU1d5Y5st6cT7MqDOlQa6US2qHjadu3ye0BY6ssCk0hv0SwHLVDanNjhO1Dw495h2OPm41+8KtlBdX9Ji96ux24NIgV8ppsYOL2/ZC8VHUDytlV7CA8k7DwUFH0Ro5bPlryWbHfWbfcUpCQaZPbWLV7qfYG7bvudbZ7cCkQa7ymuPLyhuXwPM/BBMBoObAQWo6B3W3EFYurbT9VbtuCrsA7336fLXFSfYDcRWj6kuDXOUst/fuSHXVYFq94PELeSx6wG3r5oNKqN9Qz/zX53e/RmvzWMsAtzt9Pl9WMaq+NMiVp+xCMhubHaWyajClcfSahSMBmP4fli2EdptmHQwfpLW9tfs15r56O4ebrqKjbXz38+wCPDZOEcFqW+qBtopR9aVBrjyTKCSzsdlRKjPY+9beZz+Oz5tTmoXHs7rheCh8qM/OhxLoYFDZSjp6LJ+3Prw49vcYiX2QxPFDV4rKnAa58kyisM5GmSDZsvKG7Q3ds+S+42hKeRbeW/wNx2Vrd3DflulgtXVsYQv3zhyb9PR5u1OHAhIYkKsYVV+6RF95JlFYZ7o0OxXJlpUnOlC5vCN8JMRDpSmHeEzsAONbn9pEpKPU8jkVQyqShjgkWJZvjIZ4ntAZufJMohmxVS0ZouWXsx4+i4iJUNGrTzpdyfqqbWf/xlD7eUtas3Cw7/8+3DyNoyqehLjVmOmURPJlhz9lT4NceSbRPhqJDgSO1YKduAHau6+6YXsD1Suq2XVgl+3Nw9JIJNpSOP2+lEI8Wf/33Iv/kZKy8f3u0MmXHf6UPUcOX05XVVWVWb9+fdZfV+WeVFr7qldUJzzn0qmDf61O6OktFIlQN3I6NRffn/A9ANz7xk/Z29FMpKOUw137fkPf/m8n5OIxa8p5docva5CrnJfs5HlBePf6dzN+HbsPjIAxRIjeUIqI9CjpWIa/KcAYgwSOdJGYSCHy2dXMP++bjoW3yj92Qa6lFZXzkp0871Qt2K4mHgFCgULaTBhI3iaJdPbZSkUCHVSMXsN1k29xZKxKxdOuFZVTYjXqyqWVVK+opmF7g2V3SYyTtWC7D4SABLpDPCbWJtmURjukrrBUbtEgVznD7kQdoPuUF4gGK5D2aS9WHxLdGpdQu/MDQpGei2pCBSEiNmWdpv1NRNpLUn5/2kWi3KKlFZUzEi0QWnX1KvdOY/+8GZ7/ITUmAh2Dqf9CKbuCwe5tYK06Z4Dum5ihXgc5FAYKMcYQjpvFO/nNIdGNTb3pmZ80yFXOcHM1p92HxH2/v5Oa7R8c2a2wcxA14+b2aCv8w7Zmnt23qEeft4kUcrh5GsXtk5h+/AjW7V3Wp2vFjUBN9IEEuL4/jcpNGQe5iJwAPAyUE70vtNgYY78kTikbbi5ssfswaO1so2FwiJqDbX0W9xzp/y6j4OirKCpbiRS2QLgU+fwyfnxxfAfKrD7XdiM8E31rif271c80yAc2J2bkYeAmY8wGERkKNIrIi8aYPzpwbZVH3FjYEis12LYvilD/hVJqLjoyC7dawBPeO57i9kmO9n73R3++tehN1oEv4yA3xjQBTV3/vk9EtgDHAxrkKi1OH0WWygIfgF3BIMvCF7KwbmWC49NeYl/4M5Z8VE5JmXd152TfWnSpfn5ydEGQiIwCXgPGGmP29vrZHGAOwMiRIyfu2LHDsddVykqyFaHdOkrZv21eguPTftnnW4JXuwpafTjFxgPY/kxLKwOD6wuCRGQI8ATwg94hDmCMWQwshujKTqdeVyk7qZQUTKSQtt3TukM8149PS+Vbi3at5B9HglxEComG+KPGmCeduKbynt9b2ezKEGKid+VN3B4ofjo+LdEBynq4cn5yomtFgAeBLcaYn2U+JJULsnHUmqu
6FvjUlYRoCxxZ92YihRxquqp7A6uSUJC5M+1vYOoWscoPnFjZeS7wj8BFIrKx6/++5sB1lYeStbnltK4zNGtaPqPus79R0REGA5H2Utq6QrwkFOTemWN5p25awi6UZIdPKJULnOha+R2WB1UpP8vFkkJKGpcQee6HBIgQMXDufni95TIe77yYUFAoDhYknIHHi5WW2jrbCEjAkcMslHKDruxUlvxWUlj7f3/Klzb/hKHmEAExdBrh1vC3ebzz4ujse0Z6/d+9S0sRE+lz6IVSuUI3zVKW/FJSWLZ2B3fV/RtVm+6mhIM9QvyFwmkplU+s+Lq0pPKOzsjznF1nitOLc5wWW31Z07GSBcGHKBBDxMA+ivm5fJPKGTewMIMVmL4tLam8pEGex5J1puRiK1t8gL8afIyS4EECAp0I68fezuSv38SdDrxOSVEJLYdb+jyeq6Ulld80yPNYovJBLgV479Pnryl4uXsWDhAhQMHl/8HkFE+zT6ZhewP72/f3ebwwUJhzpSWlQIM8r7ldPujPgqL43xkaHM7eT77K/j1nYYBrC15mbtFjXbXw6KKeQKiUwCV3pXSafarqN9T32Es8ZnBwcE59wCkVo0Gex9zsTOnPgqLev7M33IwZvoIvd77HokNrKCHakQKABAj02nbWKXYfZHvb++w8oVRO0K6VPOZmZ0p/uj7ufeOnfX5HAh18Onwjw+TgkRAPlfbZO9xJdh9kWh9XuUpn5HnMzc6UVMs28fXv4MnNfU6fB9gVLIj+S6gUHC6jWHFjX3Sl3KRBnufc6kxJVraxOryhuKMUGWTRKdJp4PJ61wM8JtdbL5XqTYNcucJuVjvp6OuorFvZI8AB/rFwNZUtf+Xfhw/psclVSILUXrgAshyiudh6qZQdDXLlivhZbdOBXdBRwv6maSzbVIbhSEfIP4XWMC/4GEXh/XDAMJj2PqfYa6AqlZgGuXJNa/NYdm26if29Zt/QK8DDR35qdYo9+H9vdKXcpEGuHGdV/4Zo+eTmgmUMlgiF5nCPAE90I9P3e6Mr5TINcuUYuwDvMfvG0OOHKXSiJFuBqrN1le80yFVGei+ftw3w+Nl38CgIFqXcSpiolVFn60qBGJP9c5CrqqrM+vXrs/66yjl2s2+wmIHH9LMPvHpFtWUrY0VxBYDlzwISwBijM3Q1oIhIozGmqvfjOiNXabEL8FBQ+EbwFesZeD8DPFYysQrq2AKd+a/Pt/zdiIkAOkNX+UGDXFnqXXeedPR1PPeHij4BXhIKsviMTUz+oB7a9joS4LHX792HHhN/3Jpd0MfLxR0dlXKSBvkA4PTNPqu68zP7FtE26CpM25HT57sDfPNenCihxLO6wQnREF919aruP1stPLKiB0KogUyD3OfcuNlnt3lVUdlKitsnuRrgManu1dJ7Ob2IdJdV4umGV2og0yD3OScPh4jVvyOjrDevCgxq4Z3QbFcDPCadLXbjl9NblWR0wys10DkS5CJyKVAPFAD/ZYxZ6MR1VZO+9xMAAAwbSURBVHKZHg5h1T5ot3lVRUcY2lqPPODiboR2JZPzRpyX8Pd0wyuVjzIOchEpAH4JXALsBNaJyLPGmD9mem2VXH8Ph0jUPhhouYxBx/437XQeedAYDorQUDyYms5Brm8nWzOmhrd3v83yPy3v8fgz255h/LHjEwazbnil8o0TB0ucDWwzxmw3xrQDjwNXOHBdlYJ0D4dYtnYHlXUrufWpTeyNC/FQUCgJBVk+cSvvdf6KH+/eTWlnJ8TWGYjQGiyg7u/Kafgf/ycrW8q+tvO1Po8lO5xCqXzkRGnleOCjuD/vBCb3fpKIzAHmAIwcOdKBl1WQeinBbgZeEgoy97IvcV3wFXjxju76dw1QP6yEloKexfI2E85aK5/bZ4oqNVA4EeQWt8X6fFvHGLMYWAzRlZ0OvG7W5eqeHnalhETL50tCQaZPbWLdZ4tZuGUvS8Kd1BZ0UMOR49R2FVr/zyNbQermmaJKDSR
OlFZ2AifE/XkE8IkD180psW6IpgNNGEx3m1/D9gavh9ZH7/JJW1yIl4SC3DtzLPd+ZTWrPv4JTR37MCI0FQapG/4FGkqHR0/jmbeD8q4l8L1lK0jdPFNUqYHEiRn5OuAUERkNfAxcC1znwHVzipNtfm5JtHy+KFgQV0KZTfXwYtp6zbjbAgHqR5xETVf9+7wR5/W52ZjNINUOFKVSk3GQG2PCInIDsJJo++FDxpjNGY8sx+RyvTZp/XvyidC4BF6cHV1Gj2FX8GjLa8XeT8P2Bp7Z9kyfn19x8hW2QepG6Uk7UJRKzpE+cmPMb4HfOnGtXJVr9dpk9e/42TcvtkO4jfhbF+WdhqZg39sbsfdjt0TeqpME9PAHpbzkRI08L+RKvTZZ/Xv5xK28E5rNdS9Ohud+EF3AEz4EcTcxubye2gv/PeH7SfcbSKLSk1LKXbpEP0Ve12tT2j6WMGzuOfMGLA9yiI3a7v2k+w3EjdJTrnYJKZVrNMjTkO16bUqn7xCOlk3CycO7N6v3k8oe4FacLj1pqUap1GmQ56CUT99JMbxTndmmuge4Fau9UTIpPfmhS0ipXKFBnkP6dfp8kpl3OjPbVPcAt+J06SmXu4SUyjUa5DnAKsCvLXiZ+YWPHQnvfpw+D+nNbDMNTydLT7nWJaRULtMg94hV/fvagpeZG3yMIglzFB1I7/BO8/R5SC+ccyk8nS7VKDWQaZBnWe/Z97UFLzO36DGK6CAkYQIpdJykI51wzqXw9LpLSCk/0SDPEssADz5GCYcIiLPhHS+dcM618NRVnUqlRoPcZfEBfk2i2beD4R0v3XDW8FTKf8SY7O8oW1VVZdavX5/1182WWHhf0bmKfw0so4gOAEJ0EOi9Kt7F49KUUgOLiDQaY6p6P64zcqc0LuHwC7fRGT7MlcZwJRAKWAQ3uDb7VkrlJ18GuedLtxuXRE/TCbcD0NFpKDCHKYqVSuLCOwJ0SojCAtHwVkq5wndBnvWl271CG+izk2Bh1z8jBtooBBEKJUBhUYjAJXcRyJHg9vwDUCnlCt8FuetLt3sHd9hiEyriQrtLO4XUyz9wymU3RPf/zjG6d4lSA5fvgtzRpdspzLZjOiREh4mAMbRTyMLwN3i882KOKn2b4PCVUNhKRfE7VJVtAnIvyHXvEqUGLt8Feb9XH6YR2kD0hiRwmEIWhq9lyaEL+hze8M3zm1i1+xlfzHJ17xKlBi7fBXnKC1xSLJEA3aEd/fci1p50I7M3j018+s7kE6leUe2bWW4uLb9XSjnLd0FuucBl2Hhq/vt/pTbb7hXa8V0k3Yt3GsMYwj1+rcf5l138NMvNpeX3Siln+S7IoWv14efNXTPu3RB+i6Sz7QStfymdPm9xA9NPs9xcW36vlHKOL4McgFfvj55HGS/BbNtKSqfPJ+C3Wa4uv1dqYPJvkJ8/90gNPI2FNimdPm8T4FZ92HVT63SWq5TyVEZ7rYjIA8DlQDvwAfBPxpiWZL/nxV4riY5PS2UGbnUMWqggRN3UOg1upVRWuLXXyovAfGNMWETuB+YDczO8pqP6W//uTfuwlVK5KqMgN8bEH+T4JnB1ZsNxRiblEzt+6lBRSuUXJ2vk3waW2/1QROYAcwBGjhzp4MsekWn5JBE/dagopfJL0iAXkZcAq7S61RjzTNdzbgXCwKN21zHGLAYWQ7RG3q/R2nCqfJKI3zpUlFL5I2mQG2O+mujnInI9MB242GT5lIpM2wfToX3YSqlclVFpRUQuJXpz83xjzEFnhpSaZWt3cNvTm4jEJbgbAR5P+7CVUrko0xr5/waKgBdFBOBNY8x3Mx5VChat3tYd4m4HuFJK5bJMu1ZOdmog6brxopP5xeptfP+ikzXAlVJ5zbcrO6+bfKIGuFJKAQGvB6CUUiozGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeV
zGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzGuRKKeVzjgS5iNwsIkZEhjtxPaWUUqnLOMhF5ATgEuCvmQ9HKaVUupyYkf8H8CPAOHAtpZRSacooyEVkBvCxMeadFJ47R0TWi8j65ubmTF5WKaVUnGCyJ4jIS0C5xY9uBW4BqlN5IWPMYmAxQFVVlc7elVLKIUmD3BjzVavHReRMYDTwjogAjAA2iMjZxphdjo5SKaWUraRBbscY8x5wbOzPIvIhUGWM+cyBcSmllEqR9pErpZTP9XtG3psxZpRT11JKKZU6nZErpZTPaZBnScP2BqpXVFO5tJLqFdU0bG/wekhKqQHCsdKKstewvYG6P9TR1tkGQNOBJur+UAdAzZgaD0emlBoIdEaeBfUb6rtDPKats436DfUejUgpNZBokGfBrgPWbfV2jyulVDo0yLOgvNhqYaz940oplQ4N8iyonVBLqCDU47FQQYjaCbUejUgpNZDozc4siN3QrN9Qz64DuygvLqd2Qq3e6FRKOUKDPEtqxtRocCulXKGlFaWU8jkNcqWU8jkNcqWU8jkNcqWU8jkNcqWU8jkNcqWU8jkNcqWU8jkxJvvnIItIM7ADGA7k29Fw+fieIT/fdz6+Z8jP952t93yiMaas94OeBHn3i4usN8ZUeTYAD+Tje4b8fN/5+J4hP9+31+9ZSytKKeVzGuRKKeVzXgf5Yo9f3wv5+J4hP993Pr5nyM/37el79rRGrpRSKnNez8iVUkplSINcKaV8zvMgF5G7ReRdEdkoIqtE5Divx+Q2EXlARLZ2ve+nRKTU6zFlg4h8XUQ2i0hERAZ0e5qIXCoifxKRbSIyz+vxZIOIPCQiu0Vkk9djyRYROUFEXhGRLV3/2/bk2C/Pgxx4wBhTaYwZBzwP3OH1gLLgRWCsMaYS+DMw3+PxZMsm4CrgNa8H4iYRKQB+CVwGnA58Q0RO93ZUWbEEuNTrQWRZGLjJGHMaMAX4nhf/rT0PcmPM3rg/FgMD/u6rMWaVMSbc9cc3gRFejidbjDFbjDF/8nocWXA2sM0Ys90Y0w48Dlzh8ZhcZ4x5Dfib1+PIJmNMkzFmQ9e/7wO2AMdnexw5cdSbiNwDfAtoBS70eDjZ9m1gudeDUI46Hvgo7s87gckejUVliYiMAsYDa7P92lkJchF5CSi3+NGtxphnjDG3AreKyHzgBuDObIzLTcnec9dzbiX61ezRbI7NTam87zwgFo8N+G+a+UxEhgBPAD/oVWXIiqwEuTHmqyk+dRnQwAAI8mTvWUSuB6YDF5sB1Myfxn/rgWwncELcn0cAn3g0FuUyESkkGuKPGmOe9GIMntfIReSUuD/OALZ6NZZsEZFLgbnADGPMQa/Hoxy3DjhFREaLyCDgWuBZj8ekXCAiAjwIbDHG/MyzcXg9GRSRJ4AvAhGiW9t+1xjzsaeDcpmIbAOKgD1dD71pjPmuh0PKChGZCfwCKANagI3GmGnejsodIvI14OdAAfCQMeYej4fkOhF5DLiA6JaunwJ3GmMe9HRQLhORLwOvA+8RzTCAW4wxv83qOLwOcqWUUpnxvLSilFIqMxrkSinlcxrkSinlcxrkSinlcxrkSinlcxrkSinlcxrkSinlc/8fSIMOS0HnGBkAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x_probe = torch.linspace(X.min(), X.max(), 300).reshape(-1, 1)\n", "y_lin = linear_network(x_probe).detach().numpy()\n", "y_non_lin = non_linear_network(x_probe).detach().numpy()\n", "\n", "plt.scatter(x_probe.numpy(), y_lin, s=3)\n", "plt.scatter(x_probe.numpy(), y_non_lin, s=3)\n", "plt.scatter(X, Y)\n", "plt.legend(['linear', 'non_linear', 'data'])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "N7Faplaeu1Ur" }, "source": [ "# PyTorch - Advanced" ] }, { "cell_type": "markdown", "metadata": { "id": "lSt5Pr5gu5qX" }, "source": [ "## Distributions\n", "PyTorch has a very convenient [distributions](https://pytorch.org/docs/stable/distributions.html) package." ] }, { "cell_type": "code", "execution_count": 91, "metadata": { "id": "qaCv2BGmbXWF" }, "outputs": [], "source": [ "from torch import distributions" ] }, { "cell_type": "markdown", "metadata": { "id": "C4p2WPuCGxww" }, "source": [ "You create distributions by passing the parameters of the distribution." ] }, { "cell_type": "code", "execution_count": 92, "metadata": { "id": "790r1b6tduyY" }, "outputs": [], "source": [ "mean = torch.zeros(1, requires_grad=True)\n", "std = torch.ones(1, requires_grad=True)\n", "gaussian = distributions.Normal(mean, std)" ] }, { "cell_type": "markdown", "metadata": { "id": "6Hoepp6iG0tl" }, "source": [ "These distributions are instances of the more general `Distribution` class, which you can read more about [here](https://pytorch.org/docs/stable/distributions.html#distribution)." 
] }, { "cell_type": "code", "execution_count": 93, "metadata": { "id": "1RO2ebhxgYUx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Normal(loc: tensor([0.], requires_grad=True), scale: tensor([1.], requires_grad=True))\n", "True\n" ] } ], "source": [ "print(gaussian)\n", "print(isinstance(gaussian, distributions.Distribution))" ] }, { "cell_type": "code", "execution_count": 94, "metadata": { "id": "4X2m8tlmd0Rb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.9333]])\n" ] } ], "source": [ "sample = gaussian.sample((1,))\n", "print(sample)" ] }, { "cell_type": "code", "execution_count": 95, "metadata": { "id": "0UbTN8I5Fmxs" }, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.3545]], grad_fn=)" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaussian.log_prob(sample)" ] }, { "cell_type": "markdown", "metadata": { "id": "z3BYXDwOG8Z-" }, "source": [ "The log probability depends on the parameters of the distribution. So, calling `backward` on a loss that depends on `log_prob` will back-propagate gradients into the parameters of the distribution.\n", "\n", "NOTE: this won't back-propagate through the samples (the \"reparameterization trick\"), unless you use `rsample`, which is only implemented for some distributions." ] }, { "cell_type": "markdown", "metadata": { "id": "gLAaXWLHeLco" }, "source": [ "At 5:18, we set the loss to `-log_prob` to maximize the probability of an event. That makes sense: the higher the probability, the smaller the loss. To incorporate a reward into this loss, people usually use `loss = -log_prob * reward`, but then a higher reward gives a higher loss. Is that because a higher reward should push the probability of the event up more strongly? Or should a higher reward give a lower loss (`loss = log_prob * reward`)?" 
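] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One way to approach the question above is to look at the gradient rather than at the value of the loss. The cell below is a toy sketch, not code from the lecture: `demo_mean`, `demo_policy`, `action`, and the reward values are hypothetical names and numbers chosen only for illustration. With `loss = -log_prob * reward`, the gradient is scaled by the reward, so a gradient-descent step moves the mean toward the rewarded action, and more strongly when the reward is larger. The size of the loss itself is not the quantity of interest here." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import distributions\n", "\n", "# Hypothetical toy setup: a 1-D Gaussian \"policy\" with a learnable mean.\n", "# Separate names are used so the notebook's `mean` and `gaussian` stay untouched.\n", "demo_mean = torch.zeros(1, requires_grad=True)\n", "demo_policy = distributions.Normal(demo_mean, torch.ones(1))\n", "\n", "action = torch.tensor([1.0])  # pretend this action was sampled earlier\n", "for reward in (1.0, 5.0):\n", "    demo_mean.grad = None  # clear the gradient left by the previous reward\n", "    surrogate_loss = -(demo_policy.log_prob(action) * reward).sum()\n", "    surrogate_loss.backward()\n", "    # d(loss)/d(mean) = -reward * (action - mean) / std**2,\n", "    # so a descent step (mean -= lr * grad) moves the mean toward `action`.\n", "    print(reward, surrogate_loss.item(), demo_mean.grad)" 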
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "azs4-dQ3fWoP" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "IZ7F-74Be6Rg" }, "source": [ "$$s_1, a_1, s_2, a_2, \\dots$$" ] }, { "cell_type": "code", "execution_count": 96, "metadata": { "id": "p1Hm8zTjd9pR" }, "outputs": [], "source": [ "loss = -gaussian.log_prob(sample).sum()" ] }, { "cell_type": "code", "execution_count": 97, "metadata": { "id": "oxkylpG9hGZU" }, "outputs": [], "source": [ "loss.backward()" ] }, { "cell_type": "code", "execution_count": 98, "metadata": { "id": "3Ad0bfIChIgS" }, "outputs": [ { "data": { "text/plain": [ "tensor([-0.9333])" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mean.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "_KmS36DoeAgn" }, "source": [ "## Batch-Wise Distributions" ] }, { "cell_type": "markdown", "metadata": { "id": "HB1GRLL3GZ-d" }, "source": [ "Distributions also support batch operations. When the parameters are batched, all the operations (`sample`, `log_prob`, etc.) are applied batch-wise." 
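] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the batch-wise behaviour concrete, the sketch below checks the shapes involved. The names `batch_gaussian`, `batch_samples`, and `joint` are hypothetical and chosen so they do not clash with the variables in the following cells. It also shows `distributions.Independent`, which is one way (an addition here, not something the original notebook uses) to treat the batch entries as a single multi-dimensional event." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import distributions\n", "\n", "# Ten independent 1-D Gaussians, one per entry of the parameter tensors.\n", "batch_gaussian = distributions.Normal(torch.zeros(10), torch.ones(10))\n", "print(batch_gaussian.batch_shape)  # torch.Size([10])\n", "\n", "# Draw 3 samples from each of the 10 Gaussians: sample_shape + batch_shape.\n", "batch_samples = batch_gaussian.sample((3,))\n", "print(batch_samples.shape)  # torch.Size([3, 10])\n", "\n", "# log_prob is evaluated element-wise, one value per sample per Gaussian.\n", "print(batch_gaussian.log_prob(batch_samples).shape)  # torch.Size([3, 10])\n", "\n", "# Reinterpret the batch dimension as one 10-dimensional event, so that\n", "# log_prob sums over the 10 entries and returns one value per sample.\n", "joint = distributions.Independent(batch_gaussian, 1)\n", "print(joint.log_prob(batch_samples).shape)  # torch.Size([3])" 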
] }, { "cell_type": "code", "execution_count": 99, "metadata": { "id": "DfPMQRfWeDRs" }, "outputs": [], "source": [ "mean = torch.zeros(10)\n", "std = torch.ones(10)\n", "gaussian = distributions.Normal(mean, std)" ] }, { "cell_type": "code", "execution_count": 100, "metadata": { "id": "cFPZf9mBhXMa" }, "outputs": [ { "data": { "text/plain": [ "Normal(loc: torch.Size([10]), scale: torch.Size([10]))" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaussian" ] }, { "cell_type": "code", "execution_count": 101, "metadata": { "id": "EhwiqGtseGD-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-1.4607, -0.5094, -1.2975, -1.6653, -0.1941, 1.3166, 0.5912, -0.6193,\n", " -0.2724, -0.4517]])\n" ] } ], "source": [ "sample = gaussian.sample((1,))\n", "print(sample)" ] }, { "cell_type": "code", "execution_count": 102, "metadata": { "id": "fGVHHImreH5n" }, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.9857, -1.0487, -1.7606, -2.3055, -0.9378, -1.7857, -1.0937, -1.1107,\n", " -0.9560, -1.0210]])" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaussian.log_prob(sample)" ] }, { "cell_type": "markdown", "metadata": { "id": "wUn4p3aReLeR" }, "source": [ "## Multivariate Normal" ] }, { "cell_type": "markdown", "metadata": { "id": "OUchz-ayHlVz" }, "source": [ "The package provides many other distributions. For example, `MultivariateNormal` takes a mean vector and a covariance matrix." ] }, { "cell_type": "code", "execution_count": 103, "metadata": { "id": "Dv74f_XZbTv_" }, "outputs": [], "source": [ "mean = torch.zeros(2)\n", "covariance = torch.tensor(\n", " [[1, 0.8],\n", " [0.8, 1]]\n", ")\n", "gaussian = distributions.MultivariateNormal(mean, covariance)" ] }, { "cell_type": "code", "execution_count": 104, "metadata": { "id": "g1vfCg1McpOM" }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.3727, -0.0504]])" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaussian.sample((1,))" ] }, { 
"cell_type": "code", "execution_count": 105, "metadata": { "id": "XMG929mzeqJz" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 105, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO2df4wc533en+/uDak9quGSMAFba/4QXIOMKVrH8mASZdGWqiuqpiWfJcuMILdpU0AIkKAhKxChKtYiE7k64OBQQBqgFWCjLSQrlE35TJkOKBtUkUYtZR9zR9GsxMCKJEoroWYsLRPxVuLe3ds/9uY4O/u+M+/82tm5ez6AYN7e7sw7s77nfed5vz9EKQVCCCHFpZT3AAghhCSDQk4IIQWHQk4IIQWHQk4IIQWHQk4IIQVnII+TfuxjH1MbNmzI49SEEFJYzp49+zdKqTX+13MR8g0bNmBiYiKPUxNCSGERkTd1r9NaIYSQgkMhJ4SQgkMhJ4SQgkMhJ4SQgkMhJ4SQgkMhJ4SQgpNL+CEhhBSd8ck6xk5dxDuNJm6qVnBg90aMbK3lMhYKOSGERGR8so6Hnj2PZmsWAFBvNPHQs+cBIBcxp7VCCCERGTt1cUHEXZqtWYydupjLeCjkhBASkXcazUivZw2FnBBCInJTtRLp9axJLOQicoOI/FREzonIBRE5ksbACCGkXzmweyMqTrnjtYpTxoHdG3MZTxqbnR8BuE0p9YGIOAD+QkT+TCl1JoVjE0JI3+FuaC6aqBXV7t78wfyPzvx/7OhMCFnUjGyt5SbcflLxyEWkLCJTAH4J4MdKqZc073lARCZEZOLy5ctpnJYQQghSEnKl1KxSagjAJwF8TkRu0bznCaXUsFJqeM2arrrohBBCYpJq1IpSqgHgfwK4I83jEkIIMZNG1MoaEanO/7sC4PMAXk16XEIIIXakEbXyCQD/XUTKaE8MzyilfpjCcQkhS5h+qmXS76QRtfIygK0pjIUQQgD0Xy2TfodFswghPSdstR1Uy4RC3g2FnBDSU2xW2/1Wy6TfYa0VQkhPsakc2G+1TPodCjkhpKfYrLb7rZZJv0MhJ4T0FJvV9sjWGh67ewtq1QoEQK1awWN3b6E/boAeOSGkpxzYvbHDIwf0q+1+qmXS71DICSE9pd8qBy4GKOSEkJ7D1Xa6UMgJISQi/ZZ1SiEnhPSMfhPAOPRj1imFnBDSE/ISwLQnj37MOmX4ISGkJ9gkAqWNO3nUG00oXJ88xifrsY9ZN8TBm17vBRRyQkhPyCPtPovJoyxi/N3O0dOJJom4UMgJIT0hj7T7LCaPWWVuSZzGij8OFHJCSE/II+0+i8mjFvJZ3Yp/fLKOnaOncfPBk5ms2inkhJCekEfafRaTh+6YfuqN5oJYZ+HT+2HUCiGkZ/Q6ESiLLFLvMYM2ON2InF5EuYgK8HuyYnh4WE1MTPT8vIQQkib+kEo/tWoF78yvxP0IgNdH90Q6n4icVUoN+1+ntUIIITFx7SIT7lOAjjQ3eSnkhBCSgJGtNeMGqGvlZL3JSyEnhJCEBIl1LzZ5udlJCOk5i6HmipewTdWsN3kp5ISQ1AkS6n4pOpX2ZJJnaV4KOSHESByxCxPqfig61S+TSVok9shFZK2IvCAir4jIBRH5vTQGRgjJl7iJLGH1TfKoueInjwJeWZLGZucMgAeVUr8OYAeA3xGRz6RwXEJIjsQVO5Mg1xtN7Bw9rY2pBrKtueKnHyaTNEk
s5Eqpd5VSfzn/778D8AqA4j2bEEI6iCt2QYJsyoTMuuaKnzwKeGVJquGHIrIBwFYAL2l+94CITIjIxOXLl9M8LSEkA+KKnU0tEi+9qLniJ48CXlmS2maniNwI4DiAfUqpv/X/Xin1BIAngHaKflrnJYRkw4HdG7vSz23Ezh+KVxIxln4VAC8evC21MYfh3bytDjpYPlDClWar8CGQqQi5iDhoi/hTSqln0zgmISScLOOxkxSc8obi3XzwpPF9aVgZtvfAH6ny/nQLFaeMo3uHCivgLomFXEQEwLcAvKKU+qPkQyKE2NCLELo0YqNvqla03rgAia2MKPegH8IesyINj3wngH8J4DYRmZr/7wspHJcQEkA/hdAFNU7Q+dEC4P4d6xILaJR7EBZNk1XTh16QeEWulPoLtL8XQkgPyTOEzu81f/DhDFpzbR/cvyrOoia4S5R7EPRk4L5e1MQgZnYSUlBMwpS176zzmv34LYus0tej3APd5q0AXXHtRbRbWP2QkIKSVQhdWEanzs7Q0Ysng12b1mjtgHevNLHBZ5XoqhCawufijD3rvpxBcEVOSEHJyrII2xS0Fbm0k2v8Twm7Nq3B8bN1rRjPuzxam8d7f3aOnk7lqUa36brv2BQOn7iAw3dtznx1TyEnpMBkYVmE+c4mO8NL2sk1OqF88swlq88GWSU6uwUApq/NYHyybn1vTU8pjWarJ547rRVCSAdhGZ02mZv3bEt3grG1c0yYJifXbqlWnI7X359uRep0H/SU0otIIgo5IaSDMO/d6zWbOPnyu6mOKanf7p+cvH722KmLEI3RHkWAw6yYrPcLaK0QQjqw8d7df+87NqU9hi6SJQk2do4J7yQ0PlnHkecudIwv6Li2AmyyaFyyLsZFISeEdGHjvcexC+KWFDiweyP2H5syRpmYqHnO4ffZbbAVYPca/JME0JtiXBRyQkgsglarfs8ZSFZSYGRrDRNvvoenzlyyFvNatdJRkCuqzx5VgN3JL49+pBRyQsgCUUQoyO5oNFvYOXq6YzX84DPnuqogRkm+eXRkC4bXr8bhExfQaAZbNzoRDrNJqhUHK5YPJBbgPHp3ijKUl8yS4eFhNTEx0fPzEkLM6KwHtybKoyNbrN7vp+KUcc+2Go6frRvfJwBeH90Teay6iQEAyiL45ldv7RJTU8y4O85e10SPg4icVUoN+19n1AohBIDeelAAnjpzSRuG58+ULGtCP5qtWTz90luBYh9nI3Bkaw33bV+r/d1929caY8Z1YZOC9MMlew2tFUKWKH4bxbRaVYDR/rCpO25qKuGya9OaWGMu6WIGAbzwqr4DmWlDUgE4fraO4fWrCyvmFHJClhCuENYbzY6CUf6f/bj+sj98r1pxFlLQTZNBOaBDEGAWXj+Hxs93bHaajhnkhY9srWHs1MWuyJIiFsryQmuFkCWCtxgW0C3aQevmm6oVHBo/j33HpjpEsNFs4cB3z2F8sm5MJLpv+9rATFCbWO3xybp1xErc5JxeFPnKCq7ICcmAPELQwoib5l5xyti1aQ2eMtQ2ac0pjJ26uBDqp7vu4fWrjZuTN1Urofdr7NRFKxG3CRmsDjrahKWsk3ayhEJOSMr0ogVbHOKsOF3rJExI3WObQu/c13TNnHdtWmO8XwAWrCATZRHMKWU1YY5P1vHBhzNdrztlyTxpJ0so5ISkTC96Q8ZZ8cdJc1+xfAAjW2vYb0jF9x47DFPqv+l+HXnuAj5szQU+RQigDTU0MXbq4kInIy8rlg0YGzbb3uc8n8Io5ISkTNYebNQVv2mDE9B3yPHi9rMMeo9TMq9mdeLmzbYEYJwkwuq1xOn7afoOrmgSjKLc57yfwrjZSUjKhJWBtSGo20yUhsO6DU43aK9WreDo3iG8MbrHWMnQ289Sx6BTwti9+hVxWKchlzjetDt2XaJSEFG+myj3Oe9G2BRyQlImaQs2nQDuPzaFQ+PtFZ7Nit+dCPYdm9Im+bh1SFwBNnW6D1q
Jrxp08J/u/izGTl3smnDczEuduB0+caHjNdP90tVrgWbsUYjy3UR5sso7EobWCiEpk7QFmynD8skzl3Dy5XeN4uquKm1S571x4e44V1Yc3OCU0JhuWfnpbvMFv50w8eZ7OH62bozzbjRbHd13TPcL0G+OJtmUjPLdRGnsnGUjbBso5IRkQJLCSUGrOJNv7BU4mzBDN+TPK5SNZgsVp4yje4cwsrUWWJsEaEeL6FbcNi3Yjjx3QVsp0D23l7Q3EG2/G12NcdNEEuW9WUAhJ6TPiBpdsmrQwSN3Xm/wG/Y47wpMmK979aPuMD3vMZK0Xnt/umW1QZhGJcG40SRRVu9ZNcK2JZXqhyLybQBfBPBLpdQtYe9n9UOymAlKY7f9fJQmCk5ZMPaV6xuOQSvp2nz3+RdevRw4WQT54260SNgxwqgFTFi1lIRQZzMVpdKhDlP1w7SE/B8D+ADA/6CQk6XM+GQdB753Dq3Zzr8rpyTG6A4d/roiYXibKASJF9DtO8fBFdq4x6pWHKua4ibB9a+y3cnJvxo2TWplEdy3fa32M/1MpkI+f4INAH5IISdLCb+gTF+bMfrY/o41Nse2aaIAdNf09o6rOuhAqXasdCmkgJUt7vkOjZ83euKmYllOSbD3c2utJirdPbOtg/7Y3VsiPdkUYaVuEvKeeeQi8gCABwBg3bp1vTotIZmh83iDsC0O5Y0iuXrN7FN78UdHeNuOeceYhoi75xufrOP42e465UCnKOo8atvaKbp7arOZ63r9UfYbilwBsWdCrpR6AsATQHtF3qvzEpIVUYtQhYWi6aJIbAiqExK3UFYQQZulQHsl7l3Z6jYsw1L+XQToCFUE7GOz32k0cXTvUCT7p6gVEJkQREhMovzRl4DQULQ4ojvolHDj8gHsPzbVlQEadYy2uCJtOvacUlZ1X2xwm1rE+exN1cpCFyNd96Ik4+o3KOSExMT0Rz/odP9ZzQHYZxBbF1vRrVUreGN0Dx7fOwQFwfvTLW0K/Phk3dhFJy5lkQWRXmnIvBxcVjaWF3A5sHsjnLLd2Pz3xdSyzYs3hntkaw3f/OqtkT5TNFIRchF5GsD/AbBRRN4WkX+bxnEJ6WdM6d7LBsyCYao3AtitBsMSf5qtWTz4zDkcGj+Ph549H9kTL5fEmBoPtPthjk/WsfUPnjdaP1evzYbWV5l4872uyB4TJZGOScHfK7RWreBrO9Z1/OzftHQ/Y7q2VYNO3290BpFa1EoUGLVCFgu6jbx9Fv6vbTSGUxLceMPAQtq8N0Tu5oMnA2O9o/5l+2O3D42fx9MvvYVZpRbC9YbXr44VcugPj4wSTeIlaWSJKRwxakRRXuQetULIYsSf0ecvCGVCJyZRswPDGiZHxS9kj45s6aouuHP0dKzNU689YhuxoiNpZEnexa2ygkJOSALiRproojGAaCnpSRJy4hJX8LwFvZJkgyYZgzuOPItbZQWFnJAAbHpJxhFSNxojStq+aRymXphRCPLFvcTpMgS0Jx130rNBAAyUgNacfgxxybu4VVZQyAkxYCrqNPHmewup3UnkU1dK1lvC1e3qUxLA251MV1zKxpc34ZQEh+/abByLd7Kx3QPwUq04C+nyQZPeimVlTF+bXUi5P/azt+A3iYK6EdmQd3GrrOBmJyksWfdING2MxdlI1OFuIB4/W+/c4CwLoKDtLemlWnGwYvmA1YRiavHmL6Llf59/czFsozLo80Gbs4/7ytea7v2qQQeTX7895GoXL9zsJIVHl77uhrBl0SPR5MXaiLh/Fa1jViltnRLbsLxGs2XtyXuP6I1O8T91+M/s3Vx036sbnbcol2lyNdkytfnEHS+me98I6eO5VKGQk0Jgs6mYdq2MuH4w0BbxtFbuabJq0MGLB2/DofHz1t66K6pRUvJ1RPGnF+umZFYws5MUAttNxTTDyEx9LHWUNL/oNxEH2g0d3IqFthukIu2YddOkZpOSD1xPyglK3HFJ2vd0qcEVOcmNKB6
3rUCXRLRhfXHQbYzt2rRG62nb2iFpUK04EDG3fQvj6ZfeivT+MIvItEo2fb9pd+chFHKSEzZtvrzY2hyzSqXqleuEZ3j96g6BufrRjLVXnQS/tx0nUsWmoUMUTKvkKN9vkOBTuO2gtUJyIaxfpB/do7ZTEq2lEXScNBjZWsOLB2/D66N78OLB23ClByIOtMXw8IkLsZ843DBD20qAYZRFcM82vdjafr+u4IfVZiHBcEVOcsG0uja9bnrUNtW17kXKtbuS7KUX3mi2sO/YVOTVeFmut5qbePM9q073Ycwqhafmj+N/SjF9j/7vJUjwuRq3h0JOcsHUBgxoxxDr/FDdo7abNOMnSnRDnHh0m3ZjYVScMu7ZVsN3XroU6kMnxbsh+ejIFhz76SVt1mRUFIAnz1zqmBh08egu/u9lsdY+6TW0VkguBEVMRHm8ThrdEPfRPk5q/qpBpytiY3j96tSsjiC8Ajo+WU9FxINQ0Ef4TF+b6bi3pgmXYYbR4Iqc5EItZPPS9vE6aXRD3Ef7qCtGpyR45M7NXcfcOXo6NIMzKYL2BOU+6Rx5zq5CY1IUujdX359udWx6LtbaJ72GKfokF2ysCX9n+CwIq+ltmhhMKeQmTKnlGw6ejDBaPVESj5ySxJ44on62Nr+qNqXaDy5rlxeoDjpQCrjS7K65TjoxpejTWiG54E0OMdGLx+vqoLnqX5DVEnXFqEstTysyI4osxxXxWrWCsXtvtX6/oH2PTE8u70+3Fuys96db+GhmDkf3DuHFg7dRxGNAISe54YbxPb53KLcsPpsHUl3Y3MjWGlYFTAJ+/JPS+GQdDz5zzvrzeSLAgsAGTbxe/uGnVmNka816Ms46ZHSxQyEnuRMldTsK45P10CbAtjHgupXlI3duDm3oC7QzP72T0qHx89h/bCpxDfE0qVYco0h7xdim8TEAvPGrZqT3A8H7Djbf5VKGm52kL0g7iy8osxC4vjlaCgiD9KJbWfo3WldWHLRm53D1Wqfv35pVC7HfFaeEZtYhIwGUAPjPXnHKOHzXZky8+R6eOnOpqwytdxLyX7PpzrmirNuMNmXCBqX6R8kCXopQyMmiY3yyjn//zFRXbHazNYt9x6Y6NgdtRDzI5vFPQDtHT+PqtaBonPRF3Hazszpf+nfOUxdGANyzrT3+42frXTXLdZmb3ms2bfp6Rdl/j3Qb3UH3mElD4dBaIYXA9tF6fLKOA987F5hgE8XQqFYc/IN1K/HgM+ew4eBJfOqhH+HQuLldWdJ+lACwfCDan6X3epyStBtTeKg4ZTy+dwgrlg90FfdSAF549bJWLN3fBd17U+mE6Wszxu8qqpXGpKFwuCInfU+UR+uxUxdTrUTYmp3Di6+9t/CztxmEv8N8Gr5txSlDEiT9t+ZUR+cgbzhfnHIG7r023XudvXT12sxCZUbTdxXFSmNt8nC4Iid9T5QCW2mv0vx+t4uuFGzSqAt3ZTqd0H5x/Wd/OJ9J+KqD7bK4OsoioffeW0RMt+pPGpHC2uThpCLkInKHiFwUkV+IyME0jkmIS5RH66C48DTReetxJ5GKU8Ib85UU00IX/661QcqCK82W1opyyuaN4Kh2R5IJNquopsVEYmtFRMoA/gTAPwfwNoCficgJpdT/TXpsQoBoj9ZpRvQ5ZcHMrNIaHbr6KHFbw3k3QNOMpfZvCEaJIAGAFcsGsGL5QCRbY6Wh3rnu/VGKlbE2eTBprMg/B+AXSqm/VkpdA/CnAL6UwnEJAWD/aD0+WU+3wYNqJ7bouG/7Wu04/RuNALBM85qJoJXroFMytpqzPV6UWupXmq1Itsb4ZB1Xr810ve6UrsfRuxunGw6exP5jU6xDnhJpbHbWAHgNw7cBbPe/SUQeAPAAAKxbty6F05Klgk1hLHdDNE1acwpv/KqJr+1Yh6dfeguzSqEsgvu2r+3a6FzAt3x3SoJrIZuvywdKVrXN43jnuoxS732sDjrGlnE3ebrb26ycTRvNN94wsND
VyLtx6n8nQwrjk4aQ6xYJXd+mUuoJAE8A7aJZKZyXLCFMj9auMAVZGrokGFveaTTx6MiWjsYJbkieNmLGZza35lRg7XUA+GhmLlbbNj+63qFXP5pZGKsu+scpCcolwaxv3CVcryfjF3PX/vFfv+lpwq0zY1P6lyGF8UhDyN8G4H3O/CSAd1I4LiGB2DZ3WBmw6gxDARg68jyuXptZEEl/SF3YZDKrFCpOOVETijDcfp5HnrvQca2N5vWysTohbc2pdrapT8jLHjvINvwzbC/DRqQZUhiPNDzynwH4tIjcLCLLAPwGgBMpHJeQQGxWeLVqRVt5MAqNZssYUudtTBE0hsfu3pJZAwm30uDI1hoGl3WvzZqtWTz4zDnjGHXZpq1ZtbDytg3/DPPTw0SaIYXxSSzkSqkZAL8L4BSAVwA8o5TqTeV6sqQJW+G5GYZhPl5cfXVthqDJxBWnka017QZpGihgYVIx3ZNZpWJvlNqGFIaFCeqE3h0TQwqTkUpmp1LqRwB+lMaxyNIjKAwt6HdB4X4ibdvAxlKJG7JYEgldiXvH+8Krl+OdyALX7jCF/wHX26/5i2Ld4JS098ldQUcJ/wwKE0zazYmYYYo+yZUg/9Vfjc/vzerahDklAQSJ0vRti1C5q1zde2vVSleCT9Ybec3WLG5wSoF+vJofm1dIAQQWsUqzHRvjwbOBQk5yxeS/HnnuAhrTrcAQNd0Kb9pT5yMu1UEHH7bmrDYnTatcncjFTRiKQmO6haN7h/DgM+e0kTJlkcDwQd1KWXefd21ag7FTF7H/2BRX1n0Ae3aSxETJ0PO/N46wBfXyDOrB6f180HsE7TolYWGNXtxVrr//5K5Na/DCq5cXfvfBhzOZNlt2e2HWG03jdVacciI/2lSGlh539rBnJ8kEb9RGWIaeW2LW+944BEU/2ISvufZCUEccNwNy0An/E3FtlKN7h/Bhaw6NZmvh+p48c6mjNyWkXRpXAJQyCGK5Mt8LEzBPVkmLWEUpYkZ6A60VkogoRf+PPHfByruuOGUsHygZN+0O7N7YsbJfWWlX72tMt7Cy4mgTY/y802ji/h3rFkrSetm1ac31a5kJTiXy2ig24ZCtWYUVywcw9cjt2HDwZOB742Cb+JTEr2d98P6DQk5iEZYEo/ujDvKudRtwB753rkuQS2hvgh4/W18QTa/gN5otOCXBqkEHjekWRKBvMiHAyZff1Y7FG10S5DyWpHMlaitk/SB4SRJvWB+8/6CQLwGieNg2nwW6oxz8RP2j1pVwPXziQteqfA5YqHtiojWnMLhsAJNfvx1DR57XruyVMk8s9UYTNx88iZuqFYiYxdydIFw7KahuiRf33lQDQgXDsI2s0ZE08SbNKBaSDvTIFzlRPGzbzx4+ccEqCcZPtaKvFW563VSZz6bPprvqDaruF4R7vbY2drM1iw8+DD+X994cvmtz6B9gWQRvjO7B43uHOhJt7t+xzro7vZc0Em9YH7z/4Ip8kZOkca3ps0Ei7k+C8XL4rs048N1zHVEbTklw+K7N2mOZHuFLJrvEg9tgImnI35xql4/9aEYtVD80TSRhxQnLIh2C5/6v7snDxT2XLv56eP3qwM/68ce2J3lSYzx4f8EV+SInycZUVC/XFYqgzL6xe29diBYpi6A1pxbSy/2YanfYNCd2tda7camjWnEWVpYmmq05vPbYF/DG6B689tgXjNEuQVScMr751Vu1HemnHrndeMygc41srWHFcru1mP8pKcmTGuk/KOSLHJNXbeNhm96zatCxbvTg777uZmRWnPLCatMkIqZH+A8t6nJfabYwPlnH8bPBwiTSnjBeH90TGI7oJaoX7F+J6zDVIak3mgv3Tnc/gyZbN8xRZ30whHBxwYSgRU6S5I2gzwJ2jR50nzVFu+jS2nXsHD0dape4omxjq3ivSTfme7bVFpJ63Gv1l4sN4vG9Q1Y2hDcSyL+ZqSs9UHHKKIm+QfSqQQeTX7/deC5T4lRQspV
pvKyb0jtMCUH0yBc5SQoVhX026BhBK76kcci7Nq3Rxn9HfY93XIdPXMCK5QNotmYXfPDafGamN9TRfXq4Z1ut43UT1YoT2Xfe+gfPd00SumzQZmvWaAmFrc+ShhDa1ignvYFCvgRIsjEV97NBYp1URGyqCP7wnD5G3ESj2VrYNHSLYbnp9boJ6emX3sJ929cGhkJWnDK+eOsnsHP0tPUkOj5Zj1QrxqTXYdE6uhBC95ptSLKJTtKHQk4yIUisbeOQTdmbNmZg0ibMCuiovOhnVikcP1vHfdvXalfmqwYd7PnsJ7SreaCzs5BX5KN61KYomrBJcWRrrau6pAJw/Gwdw+tXh4oxszv7C252kkwI6hZjE4fsj6poNFt431LE00IBgV19mq1ZPHnmEm5wSh0bi4/vHcLk1283rub9nYW8USNBnr5T7hxLxSnjvu1rrbvc+3nh1cvG6pJhJNlEJ+lDISeZECbWblGqo3uHAAD7j00tRGIAdnVLkuAKdK1awapBfUIScL3fZhDvT7fw0cwcju4d6gi/DFq1mqwJ08RRrTgY+8qtXffz0ZEtsZNzkqyqw9q6kd5Ca4UEEicy4dD4+QXvuCyC+3esw6MjW7TH9iYI1RtNHPjuOQDZP6LPKrXQCi6sBsyB3RuN9b1ddP5wkL0U1pLNX9/88F2bjfsVcfcxkuxVsNtPf8HwQ2JEF0Loiowpg/PQ+HlttMjXNGL+mf/4Z5jWxIRXKw5WLB/IvAlDGE5JMHbvrQt+9r5jU4Hvd0P3gsIIw0IwvccKus9pwLrixYP1yElkdI///rZr/iSep196S3ss/+vjk3WtiANtP1z36N5rbrxhoMMKCrJggPZK9tD4eew/NtVRE9zbYPiebbUOkTfhirg/U1aXFBQX1kxZPNBaWQRklZgRZm/47YTxybrRfvC/Hrah5n90X1lxcOXDVuxGyXFo+CyXR+7cbKz6WHHK2LVpjTbSxbuy9n5e1ybOi//+ZxG7zZopiwMKecHJMjHDpuCUKzbuOEz4N/GCJgl35esXmZszaMQQhN8r9k4u9UazI3HIDR0MEuWgJxyb8zN2m5igkBecLP+4dfHeflyxCYsyuW/72q7PmSaJR+6MVg2xLII5pRZ6ZNpkdNrU89YlxwStYPcHeOhBG5w6nLJ0RYDEiTJhGv3SgB55wUkzMcPvvwLAY3dvMYbECa4XkAo6n26j01Qk6ms71hmFxhTy9s2v3orXR/fgxYO3WWV97vzU6oUyt0HYHMuLKdrDvU+RYqw1s0zU2G1WOFw6JBJyEblXRC6IyJyIdO2kkuxJKzHD9EcPAHMGY1rhut1gOl+tWtGGHuo22o7uHTKGKXzylAoAAA8MSURBVO4cPY39x6awfKCEVYPmqn42E9j/fu09qzR407FMG44Hdm9sF7fycf/85BRlA9ct7+slauw2KxwuHZJaKz8HcDeA/5rCWEgM0mq7FfRHb7I0vGVfTZZGUO0Ok03htQMGl5U7qvs1mi0IYIxNt/H1bfdLdZPTofHzHRua3gkPAPyhKE5ZMLx+NYBujz0M/0QSNXabafRLh0RCrpR6BQAkII2ZZEtaiRlBf/RH9w6FThYmGyKqPeFPEtKVaFUAnjxzCSdffheN6VbHNdv4+jaYar/oolKarVk8+Mw5/FploKtZdGtWdexXuJOXLobbj24iiRJlwibJS4eebXaKyAMAHgCAdevW9eq0S4I0QsiC/uhtJgvTROA2RtDVK9cd7/CJC9qSrTpce0QXqWO76nXxto8T6bQgvMcMKqJlsmt098Y9pilj1Lv/EBc2SV46hAq5iPwEwMc1v3pYKfUD2xMppZ4A8ATQzuy0HuESIe/oAt0fvZvC7naUDxpTkKWhq/pnCpmMW7XQG6njjjHKytw7d7i66h93XEvCtAIe2VozRrp49x/iwjT6pUOokCulPt+LgSxl+qFIvy4B56qnDknYmMIsDa/Qmvz4B585l+gavEJrW3QrqJmyOy533HE
aOYetgG32H0zYTP5M+FkaMPywD+iX6AK3IuHro3uwYnm33xs0Jm8UiglXaIMKRiXBu/K1WT3XqhVjRI4X91g2USfeZs42Ke9xqwgytJB4SeSRi8iXAfwxgDUATorIlFJqdyojW0L0Y3RBnDEFtSoDrgttnJVtGH7xCzuH+34bL90dt99/D6pSaEtc+4NZnsRL0qiV7wP4fkpjWbL0Y3RB3DGNT9bxwYczXa97MxXTiixxWbGsjG98uXPla2plpjDfXV7amZgrKw6csnQ9fbj4JwivVZHWvkYc+6MfJ3+SH0zR7wP6Mbog7pjGTl3URp3MeIRS12YsCdXBZVpv2B2PV2iBzk3QRrMFpyRYNeigMd3qaCkXJs42cfBZbTD24+RP8oNC3gekHV2QhpDEHZNpRaiAjs1SXZuxuJjOqRPanaOnu54EWnMKg8sGMPn12xOPpVcb1/04+ZP8oJD3CWlFF6QpJHHGFORNN1uzOHziQuQYb5tz2pK1JdEr75qhhcQLhXyRYSMkSVbs/s/u2rQGL7x6ueNnXVd5l0azFRgrXhbB8gExNp3wIzAnHenI2pLopXfN0ELiwvDDRUaYkCQJW9N99skzlzp+PvbTt6CpG2WFW8nw7m2fNPy+/X9XtxqjN2rE9jqybhrM7vIkDyjki4wwIUkSs26TZNOaU9r6KGGURXDPttqCf65j9YrleGN0D1577AuoVSvamic2nYeybG8WNFGk2aaNEC+0VgqEjSUStgkWVhMl6NhZhrbNKoXjZ+sYXr/ayp5IYmH4/WV/TZUkBEXLeIuB1RtNHPjuudTOS5Y2FPKCYLuJGbYJFlYTJejYaSTyBHXmabZmse/YlDFt3vu0kcTrzjqyROddDx15vissszWncPjEBQo5SQytlYIQxRLxptr7u7DbNjfQHTtpZ/uKU8b9O9aF1hHRibjfxzaNZfraTKhlkUdJBNMGb9wiYYR44Yq8IKQVDRGlzKtNYwNv1IpbaEuXJVnzPRnsHD0den5/L84jz13AvvlqgdWKg3u21fDDc+92iOH70y3sPzaFiTff0zae0F1X2OuE9DsU8oKQRtic32O3Oafps0f3DmktAX8HHeD6ajoshd7PnFJ4fXRPu9nE9851TBCNZgvfOXOpqyMP0LZunjpzCcPrV2vHmEdW5KpBR1t/ZpVF71BCwqC1UhCShs3pQgeDogS9x44SsqjL2PTaFrr+mya8kTa6Vf4cOuuIe1Hzn9ORdQiijkfu3Ayn3HnHnbLgkTs3Z3ZOsnSgkBeEpGFzOl9YQbugRbXidBw7iqccZFv4J4RGs4UPW3P42o51gcIa1/IISt3PMgTRdM6xr9zacc6xr9zKjU6SCrRWCkSSTL6gGii1aiVW2KHu9ZUVR7uBd1O1YpwQXnj1Mh67e0usSJsggqySPLIimYlJsoJCvkQI6kTz4sHbYn3WL5Tjk3VcvaYpYVtql7A1tTV7p9EMFLkDuzd2eeRhsIAUWUrQWlkiJPGFd21aY/W6ycu+8YaBhVZpOsI2GV1bwuunO5r/57o2US+sEkL6Ca7Ic6SXDZeTVMszpcz7XzdZMI35aI0kpVf9DR0eevY8WnOdTSPu37HOGHJIyGKGQp4TeTRcjuvR2nrkYRZMWqVXTRu3J19+t6sSo/dnlnklixUKeU7Y1q3u5ardhK1HbrPiTmPDzzSxvD/dWojVdiszuvRioiQkLyjkOWGzyjWt2ifefK+nK01bS8SU+Tl26iL2H5sKHavtpBU3ioXNiclihUKeEzarXNOq3Zs5mXSlaSOeUSwRfwkA27FGsZqSNG9mGj5ZjFDIc8JmlRsU++3FdqWp6+7j7eYTJJ62lohfkG3HGqVFmm5iufrRjFUBKjZ4IIsRCnkPCFr1Bq1yo1gIYe/TrXh1XeyT2g82zSd0E1TUQlb+icV/fToYW04WKxTyjAmzDIIEM4qF4LY/M2GK9NDhptPH2WS1sS50q2LTpLWy4oQ2vADCKzMyaoUsZhI
JuYiMAbgTwDUArwH4N0qpRhoDWywk6aquEyfTyltXw9tLFG94ZcWJHRoZ9hRhWhXrJi2nJLh67bplEjYOpsCTpUrSzM4fA7hFKfVZAH8F4KHkQ1pcJK197W8SYWrKENasweQN+9fxFacMEcRuvKDLILXJuNQVsrrxhoGuTNGsG0AQUkQSCblS6nmllFtc4wwAffvzJUzaXdXjptqbPud27PFWAXQzMf3Y9sP0C/LRvUN4Q9OtSPdZ76SVZByELCXS9Mh/C8Ax0y9F5AEADwDAunXrUjxtf5MkLV1H3OzIKJ8zdQ+ynXzSsjhMNk2VzRgI6UBUiLcqIj8B8HHNrx5WSv1g/j0PAxgGcLcKOyCA4eFhNTExEWO4xaQfsjOjoIsAqTjlnhei0nUGAtre+di9rOVNlh4iclYpNex/PdRaUUp9Xil1i+Y/V8R/E8AXAdxvI+Kk/8mj8YJpHCuWdT80tuYUfXJCPCSNWrkDwO8D+CdKqel0hrS4yKM4lmkcUZ4K+iUC5IohyYc+OSHXSRq18p8B/D0APxaRKRH5LymMaVERpU1aVkTpudlvpL1ZTMhiJGnUyt9XSq1VSg3N//fbaQ1ssZA0/DAN+mEyiUsejZIJKRrsEJQxppVjSaRnK+J+mEzi0i9+PSH9DFP0M8aUZj+rVM+8ctt64v1Kv/j1hPQrXJFnjLui1NVC6ZW90a/2xPhkHTtHT+Pmgyexc/R0ITx7QvoRCnkPGNlaw5whMrMX9obOnrhnWw1jpy7mJqJF3oAlpN+gtdIj8rY3dM2L8wyJTFJMjBDSCVfkPaKf7I1+iGIp8gYsIf0GhbxH9Ev0xfhk3VhmtpciyvhwQtKD1koAaddIyTv6wrVUTPRSRNMuJkbIUoZCbiBLHzmvIlpBbdh6LaJxqzgSQrqhkBvIajMuz43GIOskr6JYFG5CkkOP3EBWm3F5bjSarJNatUJBJaTAUMgNZLUZl2e0Rj9FzhBC0oNCbiAr0cszWqNfImcIIelCj9xAVptxeUdr0JcmZPFBIQ8gC9GznSCK1h6OEJIfFPIcCJsg+iGFnhBSHOiR9yH9kEJPCCkOhVqRF8luSDJW1iEhhEShMEJeJLsh6VjzrpRICCkWhbFWimQ3JB0r470JIVEozIq8SHZD0rGyDgkhJAqFEfIi2Q1pjJXx3oQQWwpjrRTJbijSWAkhxSfRilxE/hDAlwDMAfglgH+tlHonjYH5KZLdUKSxEkKKjyhDU2CrD4v8mlLqb+f//e8AfEYp9dthnxseHlYTExOxz0sIIUsRETmrlBr2v57IWnFFfJ4VAOLPCoQQQmKReLNTRL4B4F8BuAJgV+IREUIIiUToilxEfiIiP9f89yUAUEo9rJRaC+ApAL8bcJwHRGRCRCYuX76c3hUQQsgSJ5FH3nEgkfUATiqlbgl7Lz1yQgiJTiYeuYh82vPjXQBeTXI8Qggh0UnqkY+KyEa0ww/fBBAasUIIISRdUrNWIp1U5DLawt9PfAzA3+Q9iJTgtfQnvJb+pEjXsl4ptcb/Yi5C3o+IyITOeyoivJb+hNfSnyyGaylMij4hhBA9FHJCCCk4FPLrPJH3AFKE19Kf8Fr6k8JfCz1yQggpOFyRE0JIwaGQE0JIwaGQexCRPxSRl0VkSkSeF5Gb8h5TXERkTERenb+e74tINe8xxUVE7hWRCyIyJyKFCxMTkTtE5KKI/EJEDuY9niSIyLdF5Jci8vO8x5IEEVkrIi+IyCvz/9/6vbzHlAQKeSdjSqnPKqWGAPwQwNfzHlACfgzgFqXUZwH8FYCHch5PEn4O4G4Af573QKIiImUAfwLgXwD4DID7ROQz+Y4qEf8NwB15DyIFZgA8qJT6dQA7APxOkb8XCrmHxVRfXSn1vFJqZv7HMwA+med4kqCUekUpdTHvccTkcwB+oZT6a6XUNQB/inZXrUKilPpzAO/lPY6kKKXeVUr95fy//w7AKwAK28KrMM2Xe8Uira/+WwCO5T2IJUo
NwFuen98GsD2nsRANIrIBwFYAL+U7kvgsOSEXkZ8A+LjmVw8rpX6glHoYwMMi8hDa9dUf6ekAIxB2LfPveRjtx8inejm2qNhcS0ERzWuFfdJbbIjIjQCOA9jneyIvFEtOyJVSn7d863cAnEQfC3nYtYjIbwL4IoB/pvo8YSDC91I03gaw1vPzJwFk0qCcRENEHLRF/Cml1LN5jycJ9Mg9LKb66iJyB4DfB3CXUmo67/EsYX4G4NMicrOILAPwGwBO5DymJY+ICIBvAXhFKfVHeY8nKczs9CAixwF01FdXStXzHVU8ROQXAJYD+NX8S2eUUoWsFy8iXwbwxwDWAGgAmFJK7c53VPaIyBcAPA6gDODbSqlv5Dyk2IjI0wD+KdqlX/8fgEeUUt/KdVAxEJF/BOB/ATiP9t87APwHpdSP8htVfCjkhBBScGitEEJIwaGQE0JIwaGQE0JIwaGQE0JIwaGQE0JIwaGQE0JIwaGQE0JIwfn/y1oYnpTCWjgAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "samples = gaussian.sample((500,))\n", "plt.scatter(samples[:, 0].numpy(), samples[:, 1].numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "vXCJs0dCHn54" }, "source": [ "NOTE: if you want to use a batch of `MultivariateNormal` distributions, you'll need to construct a batch of covariance matrices (i.e. shape `[BATCH_SIZE, DIM, DIM]`)." ] }, { "cell_type": "markdown", "metadata": { "id": "KsZZu33nfDB5" }, "source": [ "## Categorical Distribution" ] }, { "cell_type": "code", "execution_count": 106, "metadata": { "id": "IfoFevpBdy9X" }, "outputs": [], "source": [ "from torch import distributions" ] }, { "cell_type": "markdown", "metadata": { "id": "ePKpc3akHxit" }, "source": [ "Another useful distribution is the categorical distribution." ] }, { "cell_type": "code", "execution_count": 107, "metadata": { "id": "QgRw31OMfEcV" }, "outputs": [], "source": [ "probs = torch.tensor([0.1, 0.2, 0.7])\n", "dist = distributions.Categorical(probs=probs)" ] }, { "cell_type": "code", "execution_count": 108, "metadata": { "id": "K5ucMwVKfNv9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 1, 2, 2, 2])\n" ] } ], "source": [ "sample = dist.sample([20])\n", "print(sample)" ] }, { "cell_type": "code", "execution_count": 109, "metadata": { "id": "ye8mqQa4iDa5" }, "outputs": [ { "data": { "text/plain": [ "tensor([-0.3567, -0.3567, -0.3567, -1.6094, -1.6094, -0.3567, -0.3567, -0.3567,\n", " -0.3567, -0.3567, -2.3026, -0.3567, -0.3567, -0.3567, -0.3567, -0.3567,\n", " -1.6094, -0.3567, -0.3567, -0.3567])" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dist.log_prob(sample)" ] }, { "cell_type": "markdown", "metadata": { "id": "lNEq80qTfZ6s" }, "source": [ "## Distributions and Modules" ] }, { "cell_type": "markdown", "metadata": { "id": "OJm0r8WFIwqf" }, 
"source": [ "Typically, your network will output the parameters of a distribution." ] }, { "cell_type": "code", "execution_count": 110, "metadata": { "id": "TkKsDyCmIzvF" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "class Net(nn.Module):\n", "\n", "    def __init__(self, input_size, output_size):\n", "        super(Net, self).__init__()\n", "        self.fc1 = nn.Linear(input_size, 32)\n", "        self.fc2 = nn.Linear(32, 32)\n", "        self.fc3 = nn.Linear(32, output_size)\n", "\n", "    def forward(self, x):\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 111, "metadata": { "id": "3nCEfh-II2zV" }, "outputs": [], "source": [ "mean_network = Net(1, 1)\n", "x = torch.randn(100, 1)\n", "mean = mean_network(x)\n", "# Use the network output (not the input x) as the mean of the distribution\n", "distribution = distributions.Normal(mean, scale=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "G9FDsX3DIpFn" }, "source": [ "If you want, your `nn.Module` can return a distribution directly from its `forward` function!" 
] }, { "cell_type": "code", "execution_count": 112, "metadata": { "id": "Gz2eMYZydMyC" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "class Net(nn.Module):\n", "\n", "    def __init__(self, input_size, output_size):\n", "        super(Net, self).__init__()\n", "        self.fc1 = nn.Linear(input_size, 32)\n", "        self.fc2 = nn.Linear(32, 32)\n", "        self.fc3 = nn.Linear(32, output_size)\n", "\n", "    def forward(self, x):\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        # Treat the network output as the mean of a unit-variance Normal\n", "        return distributions.Normal(x, scale=1)" ] }, { "cell_type": "code", "execution_count": 113, "metadata": { "id": "L5KpD2reccWJ" }, "outputs": [], "source": [ "distribution_network = Net(1, 1)\n", "x = torch.randn(100, 1)\n", "distribution = distribution_network(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "4Wh09beAu6d5" }, "source": [ "## `gather`\n", "This function will be useful for the DQN assignment.\n", "It lets you select, for each position in a batch, the entry at a given index along one dimension." 
] }, { "cell_type": "code", "execution_count": 114, "metadata": { "id": "l0A4vz6iJKab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0, 1, 2],\n", "        [3, 4, 5]])\n", "tensor([[0],\n", "        [1]])\n", "tensor([[0],\n", "        [4]])\n" ] } ], "source": [ "x = torch.arange(6).reshape(2, 3)\n", "y = torch.tensor([0, 1]).reshape(2, 1)\n", "print(x)\n", "print(y)\n", "# Gather along dim 1: row i keeps the column named by y[i]\n", "print(torch.gather(x, 1, y))" ] }, { "cell_type": "code", "execution_count": 116, "metadata": { "id": "NfMgAVAYu7n_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0, 1, 2],\n", "        [3, 4, 5]])\n", "tensor([[0, 1, 0]])\n", "tensor([[0, 4, 2]])\n" ] } ], "source": [ "x = torch.arange(6).reshape(2, 3)\n", "y = torch.tensor([0, 1, 0]).reshape(1, 3)\n", "print(x)\n", "print(y)\n", "print(torch.gather(x, 0, y))" ] }, { "cell_type": "markdown", 
"metadata": { "id": "hUmPV_O_cFz4" }, "source": [ "For a 3-D tensor, the output is specified by:\n", "\n", "    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0\n", "    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1\n", "    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "PyTorch_Tutorial.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }