{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch Tutorial\n", "\n", "> In this post, we will cover the basics of PyTorch. This is a summary of the lecture CS285 \"Deep Reinforcement Learning\" from Berkeley.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, PyTorch, Berkeley]\n", "- image: " ] }, { "cell_type": "markdown", "metadata": { "id": "edLrpksDJfDg" }, "source": [ "## Intro\n", "This is a PyTorch Tutorial for UC Berkeley's CS285.\n", "There are already many great tutorials that you might want to check out, in particular [this tutorial](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).\n", "This tutorial covers a lot of the same material. If you're familiar with PyTorch basics, you might want to skip ahead to the PyTorch Advanced section.\n", "\n", "First, let's import some things and define a useful plotting function:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "id": "dq6HedFjckGr" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "import numpy as np\n", "\n", "def plot(xs, ys, xlim=(-3, 3), ylim=(-3, 3)):\n", "    fig, ax = plt.subplots()\n", "    ax.plot(xs, ys, linewidth=5)\n", "    # ax.set_aspect('equal')\n", "    ax.grid(True, which='both')\n", "\n", "    ax.axhline(y=0, color='k')\n", "    ax.axvline(x=0, color='k')\n", "    ax.set_xlim(*xlim)\n", "    ax.set_ylim(*ylim)" ] }, { "cell_type": "markdown", "metadata": { "id": "ynLs1RCduz_O" }, "source": [ "## PyTorch Basic" ] }, { "cell_type": "markdown", "metadata": { "id": "U5rl_7Kx5vk8" }, "source": [ "### Tensors" ] }, { "cell_type": "markdown", "metadata": { "id": "rjXm2HUv_rBm" }, "source": [ "NumPy arrays are objects that allow you to store and manipulate matrices." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 166 }, "id": "x55DbowZBcDY", "outputId": "a37168ca-d4b1-4ef9-cb8c-c29c2ec569f1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 0.]\n", " [0. 0. 0.]]\n", "+\n", "[[1. 1. 1.]\n", " [1. 1. 1.]]\n", "=\n", "[[1. 1. 1.]\n", " [1. 1. 1.]]\n" ] } ], "source": [ "shape = (2, 3)\n", "x = np.zeros(shape)\n", "y = np.ones(shape)\n", "z = x + y\n", "print(x)\n", "print(\"+\")\n", "print(y)\n", "print(\"=\")\n", "print(z)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.0\n" ] } ], "source": [ "print(z.sum())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "id": "g3Acexuuz0RS", "outputId": "3941e8b2-f5b5-465e-a910-dbbc83bf12f8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1. 1.]\n" ] } ], "source": [ "print(z[0, 1:])" ] }, { "cell_type": "markdown", "metadata": { "id": "kPGVgAOZ_mIq" }, "source": [ "PyTorch is built around _tensors_, which play a role similar to NumPy arrays. 
You can do many of the same operations in PyTorch:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 166 }, "id": "ujftyWqRCGtK", "outputId": "471f29d4-5da3-4620-c6e4-3f6ad3f9d872" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0., 0., 0.],\n", " [0., 0., 0.]])\n", "+\n", "tensor([[1., 1., 1.],\n", " [1., 1., 1.]])\n", "=\n", "tensor([[1., 1., 1.],\n", " [1., 1., 1.]])\n" ] } ], "source": [ "x = torch.zeros(shape)\n", "y = torch.ones(shape)\n", "z = x + y\n", "\n", "print(x)\n", "print(\"+\")\n", "print(y)\n", "print(\"=\")\n", "print(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "Nb6RWm3iAQjd" }, "source": [ "Many functions have alternate syntaxes that accomplish the same thing:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "ozmW7nxfKvue" }, "outputs": [ { "data": { "text/plain": [ "tensor([[1., 1., 1.],\n", " [1., 1., 1.]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.add(x, y)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "6l0MwUdvDNsU" }, "outputs": [ { "data": { "text/plain": [ "tensor(1.)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z.min()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "Kv65CHa_D-WK" }, "outputs": [ { "data": { "text/plain": [ "tensor([1.])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z[1:, 0]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "T6T5UEfJ0eKa" }, "outputs": [ { "data": { "text/plain": [ "tensor(6.)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.sum(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "NKs4pG0oAhro" }, "source": [ "Functions that reduce dimensions will by default reduce all dimensions unless a 
dimension is specified" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "rcLKnEh80gxY" }, "outputs": [ { "data": { "text/plain": [ "tensor([3., 3.])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.sum(z, dim=1)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "2hi_pF7a0i5Q" }, "outputs": [ { "data": { "text/plain": [ "tensor([2., 2., 2.])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.sum(z, dim=0)" ] }, { "cell_type": "markdown", "metadata": { "id": "VUjHBWOHAaXZ" }, "source": [ "Like numpy, pytorch will try to broadcast operations" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "ffy_uyGF0qZV" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[1.],\n", " [1.],\n", " [1.]])\n", "+\n", "tensor([[1., 1., 1.]])\n", "=\n", "tensor([[2., 2., 2.],\n", " [2., 2., 2.],\n", " [2., 2., 2.]])\n" ] } ], "source": [ "x = torch.ones((3, 1))\n", "y = torch.ones((1, 3))\n", "z = x + y\n", "\n", "print(x)\n", "print(\"+\")\n", "print(y)\n", "print(\"=\")\n", "print(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "U8AQHu-K_6Z9" }, "source": [ "Operations that end with an underscore denote in-place functions. Use these sparingly as they easily lead to bugs." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "ewIwYOhnnCC5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[2., 2., 2.],\n", " [2., 2., 2.],\n", " [2., 2., 2.]])\n", "tensor([[0., 0., 0.],\n", " [0., 0., 0.],\n", " [0., 0., 0.]])\n" ] } ], "source": [ "print(z)\n", "z.zero_()\n", "print(z)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "BsC0k_cgnFW-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[5., 5., 5.],\n", " [5., 5., 5.],\n", " [5., 5., 5.]])\n" ] } ], "source": [ "z.add_(5)\n", "print(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "DwA9Pmj-K4Pr" }, "source": [ "#### Moving between numpy and PyTorch" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "v3J1zw1WyHlG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-0.13298262 0.72862306 0.95932965]\n", " [ 2.24109395 0.53208287 -0.42554932]]\n" ] } ], "source": [ "x_np = np.random.randn(*shape)\n", "print(x_np)" ] }, { "cell_type": "markdown", "metadata": { "id": "YkIYPgPzA2-C" }, "source": [ "numpy -> pytorch is easy" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "EMPx2iKzybuk" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.1330, 0.7286, 0.9593],\n", " [ 2.2411, 0.5321, -0.4255]], dtype=torch.float64)\n" ] } ], "source": [ "x = torch.from_numpy(x_np)\n", "print(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "vInPpem8ArmF" }, "source": [ "By default, numpy arrays are float64. You'll probably want to convert arrays to float32, as most tensors in pytorch are float32." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "xZSegeSByh9D" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.1330, 0.7286, 0.9593],\n", " [ 2.2411, 0.5321, -0.4255]])\n" ] } ], "source": [ "x = torch.from_numpy(x_np).to(torch.float32)\n", "print(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "lUd3zm2yAzuI" }, "source": [ "pytorch -> numpy is also easy" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "mjxo4Lwryu-6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-0.13298263 0.7286231 0.95932966]\n", " [ 2.2410939 0.53208286 -0.42554933]]\n" ] } ], "source": [ "print(x.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "4NDB2ypyy9eE" }, "source": [ "#### GPU support" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "Kh-UHYiVzSKc" }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.cuda.is_available()" ] }, { "cell_type": "markdown", "metadata": { "id": "gvdb3bSIA7i-" }, "source": [ "The code below errors out because both tensors need to be on the same device." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "C2DzoBcizK7j" }, "outputs": [ { "ename": "RuntimeError", "evalue": "expected device cpu but got device cuda:0", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mones\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mz\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;31mRuntimeError\u001b[0m: expected device cpu but got device cuda:0" ] } ], "source": [ "device = torch.device(\"cuda\")\n", "x = torch.zeros(shape)\n", "y = torch.ones(shape, device=device)\n", "z = x + y" ] }, { "cell_type": "markdown", "metadata": { "id": "inYRaSWbBI5y" }, "source": [ "You can move a tensor to the GPU by using the `to` function." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8jbmhtOY5Oj4" }, "outputs": [], "source": [ "x = x.to(device)\n", "z = x + y\n", "print(z)" ] }, { "cell_type": "markdown", "metadata": { "id": "aXasNfBJBKFl" }, "source": [ "This code also errors out, because you can't convert tensors on a GPU into numpy arrays directly." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "ITCz002W5Vu0" }, "outputs": [ { "data": { "text/plain": [ "array([[5., 5., 5.],\n", " [5., 5., 5.],\n", " [5., 5., 5.]], dtype=float32)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "OEW3bs_ABaM6" }, "source": [ "First you need to move them to the CPU." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "z00AQCfG5Xaw" }, "outputs": [ { "data": { "text/plain": [ "array([[5., 5., 5.],\n", " [5., 5., 5.],\n", " [5., 5., 5.]], dtype=float32)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z_cpu = z.to('cpu')\n", "z_cpu.numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "PP-wQ8sDzIbn" }, "source": [ "### (homework aside)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "F5W8aKyiyPtc" }, "outputs": [ { "ename": "ModuleNotFoundError", "evalue": "No module named 'cs285'", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mcs285\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minfrastructure\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mpytorch_util\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mptu\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0mptu\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx_np\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m 
\u001b[0mptu\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto_numpy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'cs285'" ] } ], "source": [ "from cs285.infrastructure import pytorch_util as ptu\n", "ptu.from_numpy(x_np)\n", "ptu.to_numpy(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "lkH9ZA-s49Mg" }, "source": [ "## Neural-Network specific functions\n", "PyTorch has a bunch of built-in funcitons.\n", "See [the docs](https://pytorch.org/docs/stable/torch.html) for a full list." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "gpEeDQWcC2AT" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAZp0lEQVR4nO3deXSV9b3v8fc3E2GOQiDMs2FKWit1tkbFipaKgO3t3HO9Lceea6tdtoDiLCpgr73e1q6Wc7SnnuPx9pRBqKA4bhAVRTyQhCGIyBAEGTcQICTZ+d0/COfiKU8S2E/28MvntVbWYiff/Xu+P3f88OTJs7+Ycw4REfFHRrIbEBGRcCnYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8E3ewm1mumb1vZmvMbK2ZPRhGYyIicnYs3vvYzcyA9s65KjPLBpYDtzvnVoTRoIiInJmseBdwJ/5mqGp4mN3woXc9iYgkSdzBDmBmmcAqYDDwlHPuvdPUTAImAeTm5l7Qt2/fMA6dkurr68nI8PfXF77ub/v27Tjn0Pdm+vJlf8djsPdYPbX1n/98za5Ne51z+U09P+5LMZ9bzCwPmA/81DlXHlRXWFjoKioqQjtuqolEIpSUlCS7jRbj6/5KSkqIRqOsXr062a20GF9fu5PSfX9Hjtfxq1cq+Od3tnC6aN46c+wq59yoptYJ5Yz9JOdc1MwiwBggMNhFROTzln+0l6nzSqk8cCzutcK4Kya/4UwdM2sLjAY2xLuuiEhrcPBoLZPnrOF7T78XGOpZGcbPrhnS7DXDOGPvAfyp4Tp7BvDvzrkXQ1hXRMRrL5fv4t4F5ew5fDyw5gu9OzPz5mKGFnTizmauG8ZdMaXA+fGuIyLSWuw+XM0DC9eyuGxXYE1udgZ3XlvILZcPIDPDzmj9UK+xi4hIMOcc8z7cwUMvruPgsdrAuosHnsuMCcX079r+rI6jYBcRSYDKA0e5e345yzbuCazp2CaLu24Yxre+3IeMMzxLP5WCXUSkBdXXO/5lxVZmvryBozWxwLprhnZj+viR9OjcNu5jKthFRFrIx3uqmDq3lJVbDgTWnNs+h/u/Ppwbv9CTExNa4qdgFxEJWW2sntnLNvPk6x9RU1cfWDfuiz25b+xwunRoE+rxFewiIiEq33GQyXNKWbfzUGBNQadcpt80ktHDu7dIDwp2EZEQVNfG+D+vf8Qflm0mVh88quU7F/Vl6
vVD6ZSb3WK9KNhFROK0cst+pswpZfPeI4E1/bq0Y8aEYi4Z1KXF+1Gwi4icparjdcx6eQPPvrs1sCbD4EdXDOTno8+jbU5mQvpSsIuInIWlG/dw97wydkSDh3YNLejIzInFfKFPXgI7U7CLiJyRA0dqeHjROuZ9uCOwJjvTuO2qIfykZBA5WYmfD69gFxFpBuccL5Xv4r4F5eytqgms+2KfPGbdXMx53TsmsLvPU7CLiDRh96Fq7l1QzpK1nwXW5GZn8IuvFvLfLzvzoV1hU7CLiARwzvGXVZVMf3Edh6rrAusuHdSFGROK6dulXQK7C6ZgFxE5je37j3LXvDKWb9obWNMxN4tpNwzjv325T2jjAMKgYBcROUWs3vHsu1uY9XIFx2qDh3ZdO7w7028aSfdOuYlrrpkU7CIiDTbtPszkOaV8uC0aWNO1Qw4P3jiSG4oKUuos/VQKdhFp9Wpj9fw+8jG/eWMTNbHgoV3jz+/FfWOHc077nAR2d+YU7CLSqpVVHuSXc9awYdfhwJoenXN5dHwRVw3tlsDOzp6CXURaperaGL9+bSP/uGwzjczs4vsX92PymEI6tuDQrrAp2EWk1Xlv8z6mzivjk0aGdg3o2p4ZE4q4aGDLD+0Km4JdRFqNw9W1zHx5A/+6YltgTWaG8eMrBnLH6CHkZidmaFfYFOwi0iq8uWE3d88vY+fB6sCaYT06MWtiMUW9Oyews/Ap2EXEa/uP1PDwi+uY/x/BQ7tyMjP46dWDubVkENmZiR/aFTYFu4h4yTnHi6U7eWDhWvYdCR7a9aW+J4Z2De6WvKFdYVOwi4h3dh2s5p4XynltffDQrnY5mUy+rpDvX9I/6UO7wqZgFxFvOOd4/v1tPLpoPYePBw/tumJIVx4dX0Sfc1NjaFfY4g52M+sDPAsUAPXAbOfck/GuKyJyJrbuO8KsldWs318WWNMpN4t7xg7nGxf0TtlxAGEI44y9DrjTOfehmXUEVpnZq865dSGsLSLSqFi9449vf8KvXqmgujZ4HMCYEQU8NG4E3VJwaFfY4g5259xOYGfDnw+b2XqgF6BgF5EWVbHrMJPnlrJme2NDu9rw0LgR3FDUI4GdJVeo19jNrD9wPvBemOuKiJyqpq6e30U28dSbm6iNBc8DmPil3tw7dhh57VJ7aFfYQgt2M+sAzAXucM4dOs3XJwGTAPLz84lEImEdOuVUVVVpf2koGo0Si8W83NtJPrx2mw/GeKbsOJVVwYHeJdf4uxE5FOUfYPX77ySwu9RgzjUy/aa5i5hlAy8CS5xzTzRVX1hY6CoqKuI+bqqKRCKUlJQku40W4+v+SkpKiEajrF69OtmttJh0fu2O1cR44tUKnl7+SeDQLgN+eGl/fnFdIR3a+HfTn5mtcs6NaqoujLtiDHgaWN+cUBcROVPvfryPqfNK2brvaGDNwPz2fHtgjB/fOCKBnaWmMP5Kuwz4PlBmZidPde52zi0OYW0RacUOVdfy2OINPP9+40O7br1yID+9eggr3n4rgd2lrjDuilnOiZ+ARERC8/r6z5g2v5xdh4KHdo3o2YlZNxczomd6D+0Km38XoUQkre2rOs6Df13HwjWfBtbkZGVwx+gh/PiKgV4M7Qqbgl1EUoJzjoVrPuWBhWs5cLQ2sG5Uv3OYMbGYwd06JLC79KJgF5Gk23nwGPfML+f1DbsDa9rnZDLl+qF876J+ZHg2tCtsCnYRSZr6esfzK7fx2OINVDUytOvK8/J5ZPxIep/j59CusCnYRSQptuw9wtR5pazYvD+wpnPbbO4bO5wJX+rl9dCusCnYRSSh6mL1PPP2J/yvVzZyvC54aNfXinrwwI0jyO/YJoHd+UHBLiIJs2HXIabMKWVN5cHAmvyObXh43EjGjCxIYGd+UbCLSIs7XhfjqTc/5ndvbqIuaB4A8M1RvZl2w3A6t8tOYHf+UbCLSIv6cNsBpswp5aPdVYE1vc9py4wJxVw+pGsCO/OXgl1EWsTRm
jp+tWQjf3znE4JmDZrB313an19eV0i7HMVRWPRfUkRC9/amvUydV8r2/ccCawZ368DMicVc0O+cBHbWOijYRSQ0B4/V8uii9fz5g+2BNVkZxk9KBnHb1YNpk5WZwO5aDwW7iIRiydpd3PtCObsPHw+sKerVmZkTixnes1MCO2t9FOwiEpc9h4/zwMK1LCrbGVjTJiuDn197Hj+6fABZGtrV4hTsInJWnHO8sHoHD/51HdFGhnZdOOBcZkwoYmC+hnYlioJdRM7Yjugxps0vI1KxJ7CmQ5ssplw/lO9e2FdDuxJMwS4izVZf73ju/W3MWLyeIzWxwLqSwnweHV9Ez7y2CexOTlKwi0izbN5TxdS5Zby/JXho1zntsrn/6yMY98WeGtqVRAp2EWlUXayef1r+Cb9+tfGhXWOLTwzt6tpBQ7uSTcEuIoHWfXqIyXPXUL7jUGBN905tmH5TEdcO757AzqQxCnYR+RvVtTF++8Ymfr/040aHdn37wj7cdcMwOuVqaFcqUbCLyOes2rqfyXNK+XjPkcCavue2Y8aEIi4drKFdqUjBLiIAHDlex+NLKvjTu1sCh3ZlGNxy2QDu/GohbXM0DiBVKdhFhLc+2sNd88qoPBA8tOu87ieGdp3fV0O7Up2CXaQVO3i0lumL1vGXVZWBNdmZxj+UDOYfrhqkoV1pQsEu0kq9XL6TexesZU8jQ7u+0LszM28uZmiBhnalEwW7SCuz+3A19y9Yy0vluwJrcrMzuPPaQm65fACZGgeQdhTsIq2Ec47lO2q5fekyDh4LHtp1ycAuzJhYRL8u7RPYnYQplGA3s2eAscBu59zIMNYUkfBUHjjKXfPKeOujmsCajm2yuPtrw/jWl/toHECaC+uM/Z+B3wLPhrSeiISgvt7x7LtbmLWkgqONDO0aPawb028qoqBzbuKakxYTSrA755aZWf8w1hKRcGzaXcXUuaV8sPVAYM257XN44MYRfL24h87SPZKwa+xmNgmYBJCfn08kEknUoROuqqpK+0tD0WiUWCyW9nurq3e89EktCzbVUhc8DYCLe2Ty3WFZdDywkaVLNyauwRbk6/fmmUpYsDvnZgOzAQoLC11JSUmiDp1wkUgE7S/95OXlEY1G03pv5TsOMnlOKet2Hg2sKeiUyyPjR3LNMP+Gdvn6vXmmdFeMiAeqa2M8+fpHzF62mVgjQ7uu6pPFk//jKxra5TkFu0iaW7llP1PmlLJ5b/DQrv5d2jFjYjHV28oU6q1AWLc7Pg+UAF3NrBK43zn3dBhri8jpVR2vY9bLG3j23a2BNRkGP75iIHeMPo+2OZlEtiWwQUmasO6K+XYY64hI8yzduIe755WxIxo8tGtoQUdm3VxMce+8BHYmqUCXYkTSyIEjNTy8aB3zPtwRWJOTmcFtVw/m1isHkZOVkcDuJFUo2EXSgHOOl8p3cd+CcvZWBb979Py+ecyaWMyQ7h0T2J2kGgW7SIrbfaiaexeUs2TtZ4E1bbMz+eV1hfzw0v4a2iUKdpFU5ZzjLx9UMn3ROg5V1wXWXTa4C4+NL6Zvl3YJ7E5SmYJdJAVt339iaNfyTXsDazrmZnHP14bxzVEa2iWfp2AXSSGxesef3tnC40sqOFYbPLTr2uHdmX7TSLp30tAu+VsKdpEU8dFnh5kyt5QPt0UDa7p2yOHBG0dyQ1GBztIlkIJdJMlq6ur5w9KP+c0bm6iJ1QfWTTi/F/eOHc457XMS2J2kIwW7SBKVVkaZPKeUDbsOB9b07JzLIxOKuKqwWwI7k3SmYBdJguraGL9+dSP/+NZmGpnZxfcv7seU64fSoY3+V5Xm03eLSIKt2LyPqXNL2bIveLTugK7tmTmxmAsHnJvAzsQXCnaRBDlcXcuMlzbw3HvBk7gyM4wfXTGAn48+j9zszAR2Jz5RsIskwBsbPmPa/HJ2HqwOrBnWoxOzJhZT1LtzAjsTHynYRVrQ/iM1PPTXtbyw+tPAmpzMDH52zWD+/spBZ
GdqaJfET8Eu0gKcc/y1dCcPLFzL/iPBQ7su6HcOMycWM7hbhwR2J75TsIuEbNfBau55oZzX1gcP7WqXk8nk6wr5wSX9ydDQLgmZgl0kJM45/rxyO48sWs/h48FDu64Y0pVHxxfR51wN7ZKWoWAXCcHWfUeYOreMdzfvC6zplJvFvWOHc/MFvTUOQFqUgl0kDrF6xx/f/oRfvVJBdW3wOIDrRxbw4LgRdOuooV3S8hTsImepYtdhJs8tZc32xoZ2teHhcSO4vqhHAjuT1k7BLnKGaurq+V1kE0+9uYnaWPA8gJsv6M09XxtGXjsN7ZLEUrCLnIHV26NMmVNKxWfBQ7t65bXlsQlFfOW8/AR2JvL/KdhFmuFYTYwnXq3g6eWfBA7tMoMfXNyPyWOG0l5DuySJ9N0n0oR3Pt7L1LllbNsfPLRrUP6JoV2j+mtolySfgl0kwKHqWh5bvIHn3298aNdPrhzEbVcP1tAuSRkKdpHTeG3dZ0x7oYzPDh0PrBnRsxOzbi5mRE8N7ZLUomAXOUVNDH72/H+wcE0jQ7uyMrj9miH8/VcGkqWhXZKCQgl2MxsDPAlkAv/knJvRWP2hmhP/EruvPtpay1btL63srTpOaWWUozWxRkP9y/3PYcbEYgbla2iXpK64g93MMoGngGuBSmClmS10zq0Les7+asf9C9fGe+jUtl77SzdHa2KBX2ufk8mU64fyvYv6aWiXpDxzrpF/cLE5C5hdAjzgnLuu4fFdAM65x4Kek5HT1uUUDI7ruCJhq9m9GYCcbgM/9/m8djkM6NqeNlnpf9klGo2Sl5eX7DZajO/7W7p06Srn3Kim6sK4FNML2H7K40rgov9aZGaTgEkAlq15GZL6Mg26tcugc5sYx6oOcSzZDYUgFosRjQaPQEh3vu+vucII9tP9XPo3PwY452YDswHa9BjiCr7T6GV4kYTb9W9TAbhuyu+5fmQB3xzVh3Pb+zUOIBKJUFJSkuw2Wozv+2vuVNAwgr0S6HPK495A8G+fgI45xg8u6RfCoVPTjh076NWrV7LbaDE+7i/DjD+/0p6MumoW/M/Lkt2OSFzCCPaVwBAzGwDsAL4FfKexJ3TJNR4aNzKEQ6emSGQvJSXaX7qJPJFLNBr8j02LpIu4g905V2dmtwFLOHG74zPOOf9umRARSROh3MfunFsMLA5jLRERiU/6378lIiKfo2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDMKdhERzyjYRUQ8o2AXEfGMgl1ExDNxBbuZfcPM1ppZvZmNCqspERE5e/GesZcDE4BlIfQiIiIhyIrnyc659QBmFk43IiISt7iC/UyY2SRgEkB+fj6RSCRRh064qqoq7S8NRaNRYrGYl3s7ydfX7iTf99dcTQa7mb0GFJzmS9OccwuaeyDn3GxgNkBhYaErKSlp7lPTTiQSQftLP3l5eUSjUS/3dpKvr91Jvu+vuZoMdufc6EQ0IiIi4dDtjiIinon3dsfxZlYJXAIsMrMl4bQlIiJnK967YuYD80PqRUREQqBLMSIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuKZuILdzB43sw1mVmpm880sL6zGRETk7MR7xv4qMNI5VwxsBO6KvyUREYlHXMHunHvFOVfX8HAF0Dv+lkREJB5hXmO/BXgpxPVEROQsZDVVYGavAQWn+dI059yChpppQB3wXCPrTAImAeTn5xOJRM6m37RQV
VWl/aWhaDRKLBbzcm8n+franeT7/prLnHPxLWD2Q+BW4Brn3NHmPKewsNBVVFTEddxUFolEKCkpSXYbLcbX/ZWUlBCNRlm9enWyW2kxvr52J/m+PzNb5Zwb1VRdk2fsTRxkDDAFuLK5oS4iIi0r3mvsvwU6Aq+a2Woz+30IPYmISBziOmN3zg0OqxEREQmH3nkqIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4RsEuIuIZBbuIiGcU7CIinlGwi4h4Jq5gN7OHzazUzFab2Stm1jOsxkRE5OzEe8b+uHOu2Dn3ReBF4L4QehIRkTjEFezOuUOnPGwPuPjaERGReGXFu4CZPQL8ADgIXNVI3SRgUsPD42ZWHu+xU1hXYG+ym2hBPu+vq5n5ujfw+7UD//dX2Jwic67xk2wzew0oOM2XpjnnFpxSdxeQ65y7v8mDmn3gnBvVnAbTkfaXvnzeG2h/6a65+2vyjN05N7qZx/w3YBHQZLCLiEjLifeumCGnPLwR2BBfOyIiEq94r7HPMLNCoB7YCtzazOfNjvO4qU77S18+7w20v3TXrP01eY1dRETSi955KiLiGQW7iIhnkhbsPo8jMLPHzWxDw/7mm1lesnsKk5l9w8zWmlm9mXlza5mZjTGzCjPbZGZTk91PmMzsGTPb7ev7R8ysj5m9aWbrG743b092T2Exs1wze9/M1jTs7cEmn5Osa+xm1unkO1fN7GfAcOdcc3/5mtLM7KvAG865OjObCeCcm5LktkJjZsM48QvzPwC/cM59kOSW4mZmmcBG4FqgElgJfNs5ty6pjYXEzL4CVAHPOudGJrufsJlZD6CHc+5DM+sIrAJu8uH1MzMD2jvnqswsG1gO3O6cWxH0nKSdsfs8jsA594pzrq7h4QqgdzL7CZtzbr1zriLZfYTsQmCTc26zc64G+L/AuCT3FBrn3DJgf7L7aCnOuZ3OuQ8b/nwYWA/0Sm5X4XAnVDU8zG74aDQvk3qN3cweMbPtwHfxd4DYLcBLyW5CmtQL2H7K40o8CYbWxsz6A+cD7yW3k/CYWaaZrQZ2A6865xrdW4sGu5m9Zmblp/kYB+Ccm+ac6wM8B9zWkr2Eram9NdRMA+o4sb+00pz9ecZO8zlvfopsLcysAzAXuOO/XBVIa865WMMU3d7AhWbW6OW0uIeANdGMt+MImtqbmf0QGAtc49LwzQJn8Nr5ohLoc8rj3sCnSepFzkLD9ee5wHPOuXnJ7qclOOeiZhYBxgCBvwhP5l0x3o4jMLMxwBTgRufc0WT3I82yEhhiZgPMLAf4FrAwyT1JMzX8gvFpYL1z7olk9xMmM8s/eWedmbUFRtNEXibzrpi5nBhB+Z/jCJxzO5LSTMjMbBPQBtjX8KkVvtzxA2Bm44HfAPlAFFjtnLsuuV3Fz8xuAP43kAk845x7JMkthcbMngdKODHW9jPgfufc00ltKkRmdjnwFlDGiUwBuNs5tzh5XYXDzIqBP3Hi+zID+Hfn3EONPicNrxKIiEgj9M5TERHPKNhFRDyjYBcR8YyCXUTEMwp2ERHPKNhFRDyjYBcR8cz/AyCVB2MwYsZLAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xs = torch.linspace(-3, 3, 100)\n", "ys = torch.relu(xs)\n", "plot(xs.numpy(), ys.numpy())" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "of1U9ahX46eL" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xs = torch.linspace(-3, 3, 100)\n", "ys = torch.tanh(xs)\n", "plot(xs.numpy(), ys.numpy())" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "smjZcbzo4Ers" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xs = torch.linspace(-3, 3, 100)\n", "ys = torch.selu(xs)\n", "plot(xs.numpy(), ys.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "X97sJV2bEqT2" }, "source": [ "## Automatic differentiation" ] }, { "cell_type": "markdown", "metadata": { "id": "-1wKYOcNc_I-" }, "source": [ "Given some loss function\n", "$$L(\\vec x, \\vec y) = ||2 \\vec x + \\vec y||_2^2$$\n", "we want to evaluate\n", "$$\\frac{\\partial L}{\\partial \\vec x}$$\n", "and\n", "$$\\frac{\\partial L}{\\partial \\vec y}$$" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "ceo_JWyUD0Py" }, "outputs": [], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "5f0cycqiBt0h" }, "source": [ "PyTorch makes this easy by having tensors keep track of their data..." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "GroZ0prJgAZ0" }, "outputs": [ { "data": { "text/plain": [ "tensor([1., 2., 3.])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.data" ] }, { "cell_type": "markdown", "metadata": { "id": "5Ozf4qtnByLZ" }, "source": [ "...and their gradient:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "_9-FwL4wgCDG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "print(x.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "sdnlnBFjB0KT" }, "source": [ "However, right now `x` has no gradient because it does not know what loss it must be differentiated with respect to.\n", "Below, we define the loss." 
] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "GN2CeNSOCZw9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(83., grad_fn=<SumBackward0>)\n" ] } ], "source": [ "loss = ((2 * x + y)**2).sum()\n", "print(loss)" ] }, { "cell_type": "markdown", "metadata": { "id": "OFfIwK0jB8T3" }, "source": [ "And we perform back-propagation by calling `backward` on it." ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "UIr4j_hrgN7n" }, "outputs": [], "source": [ "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "sfEKcV4YB-s9" }, "source": [ "Now we see that the gradients are populated!" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "yUSm8_r5gOhi" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([12., 20., 28.])\n", "tensor([ 6., 10., 14.])\n" ] } ], "source": [ "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "rUnvZN9Jg0mg" }, "source": [ "### gradients accumulate\n", "Gradients accumulate, so if you call `backward` twice..." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "qwAZ8fyMgQ1z" }, "outputs": [], "source": [ "loss = ((2 * x + y)**2).sum()\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ndyknuW_CEGb" }, "source": [ "...you'll get twice the gradient." ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "5bT1AWRngYVJ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([24., 40., 56.])\n", "tensor([12., 20., 28.])\n" ] } ], "source": [ "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "WJ3rei78gyxd" }, "source": [ "### multiple losses" ] }, { "cell_type": "markdown", "metadata": { "id": "atVB4ihsCH7s" }, "source": [ "This accumulation makes it easy to add gradients from different losses, which might not even use the same parameters. 
For example, this loss is only a function of `x`..." ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "n1Lf5NOcgfN3" }, "outputs": [], "source": [ "other_loss = (x**2).sum()\n", "other_loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "j_TI2Db4CQ5N" }, "source": [ "...and so only `x.grad` changes." ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "4NX0Oj1Vgjia" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([26., 44., 62.])\n", "tensor([12., 20., 28.])\n" ] } ], "source": [ "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "_tVQv9nCgpxk" }, "source": [ "### stopping and starting gradients" ] }, { "cell_type": "markdown", "metadata": { "id": "cHO0aQioCXrC" }, "source": [ "If you don't specify `requires_grad=True`, the gradient will always be `None`." ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 54 }, "id": "ueSASTvNg6mP", "outputId": "3829650d-f4fc-4173-f97a-e27de7f4e262" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([12., 20., 28.])\n", "None\n" ] } ], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape)\n", "loss = ((2 * x + y)**2).sum()\n", "loss.backward()\n", "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "_o-HON3lCbjS" }, "source": [ "You can turn `requires_grad` on after initializing a tensor." 
] }, { "cell_type": "code", "execution_count": 38, "metadata": { "id": "tOUN_k0hhFU5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([12., 20., 28.])\n", "tensor([ 6., 10., 14.])\n" ] } ], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape)\n", "y.requires_grad = True\n", "loss = ((2 * x + y)**2).sum()\n", "loss.backward()\n", "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "yHfgIecCCeyq" }, "source": [ "You can cut a gradient by calling `y.detach()`, which will return a new tensor with `requires_grad=False`. Note that `detach` is not an in-place operation: `y` itself still tracks gradients." ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "id": "_bQbj4u8hL3-", "outputId": "eaa75c2f-8258-4620-e09c-95658546531a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([12., 20., 28.])\n", "tensor([ 6., 10., 14.])\n" ] } ], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape, requires_grad=True)\n", "y_detached = y.detach()\n", "loss = ((2 * x + y)**2).sum()\n", "loss.backward()\n", "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([1., 1., 1.])\n" ] } ], "source": [ "print(y_detached)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "id": "RD-zF6Xna2GJ", "outputId": "60591077-350b-4112-bbde-3cf902c1251d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 40., 72., 104.])\n", "tensor([10., 18., 26.])\n" ] } ], "source": [ "shape = (3, )\n", "x = torch.tensor([1., 2, 3], requires_grad=True)\n", "y = torch.ones(shape, requires_grad=True)\n", "z = 2 * x\n", "z.required_grad = True  # misspelling of requires_grad: this only sets an unused attribute\n", 
"loss = ((2 * z + y)**2).sum()\n", "loss.backward()\n", "print(x.grad)\n", "print(y.grad)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 71 }, "id": "YsxDGYysafqT", "outputId": "8d5f0ede-82ac-445e-d007-0ec1854fa5d8" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\kcsgo\\anaconda3\\lib\\site-packages\\torch\\tensor.py:746: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.\n", " warnings.warn(\"The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad \"\n" ] } ], "source": [ "z.grad" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "id": "AkXdmy9Fah_1", "outputId": "aae3245b-3d60-4eab-f8b6-af7767ad459b" }, "outputs": [ { "data": { "text/plain": [ "tensor([1, 2], dtype=torch.int32)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.from_numpy(np.array([1,2]))" ] }, { "cell_type": "markdown", "metadata": { "id": "g0nGMph3amPa" }, "source": [ "Any difference between p.data.add_(-0.001 + p.grad) and p.data+= -0.001 + p.grad?" 
] }, { "cell_type": "markdown", "metadata": { "id": "q8B8phm0hdrJ" }, "source": [ "## Modules" ] }, { "cell_type": "markdown", "metadata": { "id": "lLYr9XTskvut" }, "source": [ "`nn.Module`s are the building blocks of a computation graph.\n", "For example, in typical PyTorch code, each convolution block is its own module, each fully connected block is a module, and the whole network itself is also a module.\n", "Modules can contain other modules.\n", "The layer classes in `torch.nn` (such as `nn.Linear`) are subclasses of `nn.Module`.\n", "Below is an example definition of a module:" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "id": "vwmkBLL0dQiw" }, "outputs": [], "source": [ "import pdb" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "id": "LhkD9c-Uheqd" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "class Net(nn.Module):\n", "\n", "    def __init__(self, input_size, output_size):\n", "        super(Net, self).__init__()\n", "        self.fc1 = nn.Linear(input_size, 32)\n", "        self.fc2 = nn.Linear(32, 32)\n", "        self.fc3 = nn.Linear(32, output_size)\n", "\n", "    def forward(self, x):\n", "        pdb.set_trace()\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return x\n", "\n", "    def hook(self, gradient):\n", "        return 2*gradient" ] }, { "cell_type": "markdown", "metadata": { "id": "ulXK0yY-DCwt" }, "source": [ "The main function that you need to implement is the `forward` function.\n", "Otherwise, it's a normal Python object:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 104 }, "id": "CvcCU1s1itC3", "outputId": "06896f68-b7da-4ebe-cf8d-04650f2d6c7d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Net(\n", " (fc1): Linear(in_features=1, out_features=32, bias=True)\n", " (fc2): Linear(in_features=32, out_features=32, 
bias=True)\n", " (fc3): Linear(in_features=32, out_features=1, bias=True)\n", ")\n" ] } ], "source": [ "net = Net(input_size=1, output_size=1)\n", "print(net)" ] }, { "cell_type": "markdown", "metadata": { "id": "RaYLDhamEMKw" }, "source": [ "Here we create some dummy input. The first dimension will be the batch dimension." ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "id": "H2AGmUDrhy7i", "outputId": "da201e8d-b17c-4a57-f03a-0e7b75887fed" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 1])\n" ] } ], "source": [ "x = torch.linspace(-5, 5, 100).view(100, 1)\n", "print(x.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "vXIho7F8D9_x" }, "source": [ "To evaluate a neural network on some input, you pass an input through a module by calling it directly. In particular, don't call `net.forward(x)`." ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "JN4mWZlkiEmZ", "outputId": "6706faf8-dfd7-4a8e-8507-ed47d36a742e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "> (16)forward()\n", "-> x = F.relu(self.fc1(x))\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) print(x)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-5.0000],\n", " [-4.8990],\n", " [-4.7980],\n", " [-4.6970],\n", " [-4.5960],\n", " [-4.4949],\n", " [-4.3939],\n", " [-4.2929],\n", " [-4.1919],\n", " [-4.0909],\n", " [-3.9899],\n", " [-3.8889],\n", " [-3.7879],\n", " [-3.6869],\n", " [-3.5859],\n", " [-3.4848],\n", " [-3.3838],\n", " [-3.2828],\n", " [-3.1818],\n", " [-3.0808],\n", " [-2.9798],\n", " [-2.8788],\n", " [-2.7778],\n", " [-2.6768],\n", " [-2.5758],\n", " [-2.4747],\n", " [-2.3737],\n", " [-2.2727],\n", " [-2.1717],\n", " [-2.0707],\n", " [-1.9697],\n", " [-1.8687],\n", " [-1.7677],\n", " [-1.6667],\n", " 
[-1.5657],\n", " [-1.4646],\n", " [-1.3636],\n", " [-1.2626],\n", " [-1.1616],\n", " [-1.0606],\n", " [-0.9596],\n", " [-0.8586],\n", " [-0.7576],\n", " [-0.6566],\n", " [-0.5556],\n", " [-0.4545],\n", " [-0.3535],\n", " [-0.2525],\n", " [-0.1515],\n", " [-0.0505],\n", " [ 0.0505],\n", " [ 0.1515],\n", " [ 0.2525],\n", " [ 0.3535],\n", " [ 0.4545],\n", " [ 0.5556],\n", " [ 0.6566],\n", " [ 0.7576],\n", " [ 0.8586],\n", " [ 0.9596],\n", " [ 1.0606],\n", " [ 1.1616],\n", " [ 1.2626],\n", " [ 1.3636],\n", " [ 1.4646],\n", " [ 1.5657],\n", " [ 1.6667],\n", " [ 1.7677],\n", " [ 1.8687],\n", " [ 1.9697],\n", " [ 2.0707],\n", " [ 2.1717],\n", " [ 2.2727],\n", " [ 2.3737],\n", " [ 2.4747],\n", " [ 2.5758],\n", " [ 2.6768],\n", " [ 2.7778],\n", " [ 2.8788],\n", " [ 2.9798],\n", " [ 3.0808],\n", " [ 3.1818],\n", " [ 3.2828],\n", " [ 3.3838],\n", " [ 3.4848],\n", " [ 3.5859],\n", " [ 3.6869],\n", " [ 3.7879],\n", " [ 3.8889],\n", " [ 3.9899],\n", " [ 4.0909],\n", " [ 4.1919],\n", " [ 4.2929],\n", " [ 4.3939],\n", " [ 4.4949],\n", " [ 4.5960],\n", " [ 4.6970],\n", " [ 4.7980],\n", " [ 4.8990],\n", " [ 5.0000]])\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) l\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 11 \t self.fc2 = nn.Linear(32, 32)\n", " 12 \t self.fc3 = nn.Linear(32, output_size)\n", " 13 \t\n", " 14 \t def forward(self, x):\n", " 15 \t pdb.set_trace()\n", " 16 ->\t x = F.relu(self.fc1(x))\n", " 17 \t x = F.relu(self.fc2(x))\n", " 18 \t x = self.fc3(x)\n", " 19 \t return x\n", " 20 \t\n", " 21 \t def hook(self, gradient):\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) x.shape\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 1])\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) n\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "> (17)forward()\n", "-> x = F.relu(self.fc2(x))\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) 
n\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "> (18)forward()\n", "-> x = self.fc3(x)\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) p x\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "tensor([[1.8007, 0.0000, 0.0000, ..., 0.0000, 0.0000, 1.2302],\n", " [1.7679, 0.0000, 0.0000, ..., 0.0000, 0.0000, 1.2086],\n", " [1.7351, 0.0000, 0.0000, ..., 0.0000, 0.0000, 1.1871],\n", " ...,\n", " [0.0000, 0.0000, 1.0427, ..., 0.0000, 0.0000, 1.0551],\n", " [0.0000, 0.0000, 1.0706, ..., 0.0000, 0.0000, 1.0787],\n", " [0.0000, 0.0000, 1.0985, ..., 0.0000, 0.0000, 1.1023]],\n", " grad_fn=)\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) print(x.shape)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 32])\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "(Pdb) c\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([100, 1])\n" ] } ], "source": [ "y = net(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "IHO5oRs1EQxz" }, "source": [ "Let's visualize what the network's output looks like." 
] }, { "cell_type": "code", "execution_count": 50, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 269 }, "id": "vJu-Tt2jiF2y", "outputId": "0c88791c-975d-4f30-c7e9-dfa2ebafee05" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD8CAYAAAB6paOMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAf/UlEQVR4nO3dfXRc9Z3f8fdXz7ZlWZIfZGPLD2ATwGaDQZgNtIl4Ntkcm7QkhZ7dddrkuNmGbLJp0pClJ7slSw/p9pSeZtNNnIQN2aSBJNsk3q6NQwCFNASwAAUkg23hRyHLsi2NbNl6HH37x72yBnkkjTwjzdj38zpHR3N/93dnvv6d8Xx0H353zN0REZHoyst2ASIikl0KAhGRiFMQiIhEnIJARCTiFAQiIhGnIBARibiMBIGZPWZm7WbWOMZ6M7P/aWbNZva6mV2bsG6Tme0NfzZloh4REUldpvYIvgusH2f9XcCq8Gcz8LcAZlYJ/AVwA7AO+Aszq8hQTSIikoKMBIG7Pw90jNNlI/A9D7wIlJvZIuBO4Gl373D3TuBpxg8UERHJsIJpep3FwOGE5Zawbaz2c5jZZoK9CWbMmHFddXX11FSaoqGhIfLydIoFNBbDDh8+jLuzdOnSbJeSE/S+GJErY7Fnz57j7j5/dPt0BYElafNx2s9tdN8CbAGoqanx+vr6zFV3Hurq6qitrc1qDblCYxGora0lFovR0NCQ7VJygt4XI3JlLMzsYLL26YqoFiDxT/glQOs47SIiMk2mKwi2An8cXj30+0CXux8BdgB3mFlFeJL4jrBNRESmSUYODZnZD4FaYJ6ZtRBcCVQI4O7fALYBHwSagTPAvwnXdZjZV4Cd4VM95O7jnXQWEZEMy0gQuPt9E6x34FNjrHsMeCwTdYiIyORl/zS2iIhklYJARCTiFAQiIhGnIBARiTgFgYhIxCkIREQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiEScgkBEJOIyEgRmtt7MdptZs5k9kGT9o2bWEP7sMbNYwrp4wrqtmahHRERSl/Z3FptZPvB14HagBdhpZlvdfddwH3f/s4T+nwbWJjxFj7tfk24dIiJyfjKxR7AOaHb3fe7eDzwBbByn/33ADzPwuiIikgGZCILFwOGE5Zaw7RxmtgxYATyb0FxiZvVm9qKZ3Z2BekREZBLSPjQEWJI2H6PvvcBP3D2e0LbU3VvN7FLgWTN7w93fPudFzDYDmwGqqqqoq6tLs+z0dHd3Z72GXKGxCMRiMeLxuMYipPfFiFwfi0wEQQtQnbC8BGgdo++9wKcSG9y9Nfy9z8zqCM4fnBME7r4F2AJQU1PjtbW16dadlrq6OrJdQ67QWATKy8uJxWIai5DeFyNyfSwycWhoJ7DKzFaYWRHBh/05V/+Y2XuACuC3CW0VZlYcPp4H3ATsGr2tiIhMnbT3CNx90MzuB3YA+cBj7t5kZg8B9e4+HAr3AU+4e+JhoyuBb5rZEEEoPZJ4tZGIiEy9TBwawt23AdtGtX151PJfJtnuBeDqTNQgIiLnRzOLRUQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiEScgkBEJOIUBCIiEacgEBGJOAWBiEjEKQhERCJOQSAiEnEKAhGRiMtIEJjZe
jPbbWbNZvZAkvUfM7NjZtYQ/nwiYd0mM9sb/mzKRD0iIpK6tL+83szyga8DtwMtwE4z2+ruu0Z1fdLd7x+1bSXwF0AN4MAr4bad6dYlIiKpycQewTqg2d33uXs/8ASwMcVt7wSedveO8MP/aWB9BmoSEZEUpb1HACwGDicstwA3JOn3L83s/cAe4M/c/fAY2y5O9iJmthnYDFBVVUVdXV36laehu7s76zXkCo1FIBaLEY/HNRYhvS9G5PpYZCIILEmbj1r+R+CH7t5nZp8EHgduSXHboNF9C7AFoKamxmtra8+74Eyoq6sj2zXkCo1FoLy8nFgsprEI6X0xItfHIhOHhlqA6oTlJUBrYgd3P+HufeHit4DrUt1WRESmViaCYCewysxWmFkRcC+wNbGDmS1KWNwAvBk+3gHcYWYVZlYB3BG2iYjINEn70JC7D5rZ/QQf4PnAY+7eZGYPAfXuvhX4UzPbAAwCHcDHwm07zOwrBGEC8JC7d6Rbk4iIpC4T5whw923AtlFtX054/CXgS2Ns+xjwWCbqEBGRydPMYhGRiFMQiIhEnIJARCTiFAQiIhGnIBARiTgFgYhIxCkIREQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiERcRoLAzNab2W4zazazB5Ks/5yZ7TKz183sGTNblrAubmYN4c/W0duKiMjUSvs7i80sH/g6cDvQAuw0s63uviuh22tAjbufMbM/Af4r8K/CdT3ufk26dYiIyPnJxB7BOqDZ3fe5ez/wBLAxsYO7P+fuZ8LFF4ElGXhdERHJgLT3CIDFwOGE5RbghnH6fxzYnrBcYmb1wCDwiLv/LNlGZrYZ2AxQVVVFXV1dOjWnrbu7O+s15AqNRSAWixGPxzUWIb0vRuT6WGQiCCxJmyftaPaHQA3wgYTmpe7eamaXAs+a2Rvu/vY5T+i+BdgCUFNT47W1tWkXno66ujqyXUOu0FgEysvLicViGouQ3hcjcn0sMnFoqAWoTlheArSO7mRmtwEPAhvcvW+43d1bw9/7gDpgbQZqEhGRFGUiCHYCq8xshZkVAfcC77r6x8zWAt8kCIH2hPYKMysOH88DbgISTzKLiMgUS/vQkLsPmtn9wA4gH3jM3ZvM7CGg3t23An8NlAI/NjOAQ+6+AbgS+KaZDRGE0iOjrjYSEZEplolzBLj7NmDbqLYvJzy+bYztXgCuzkQNIiJyfjSzWEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiEScgkBEJOIUBCIiEacgEBGJOAWBiEjEKQhERCJOQSAiEnEKAhGRiFMQiIhEnIJARCTiFAQiIhGXkSAws/VmttvMms3sgSTri83syXD9S2a2PGHdl8L23WZ2ZybqERGR1KUdBGaWD3wduAu4CrjPzK4a1e3jQKe7rwQeBb4abnsVcC+wGlgP/K/w+UREZJpk4svr1wHN7r4PwMyeADYCuxL6bAT+Mnz8E+BvzMzC9ifcvQ/Yb2bN4fP9drwX3L17N7W1tRko/fzFYjHKy8uzWkOu0FgEGhoaGBwczPp7M1fofTEi18ciE0GwGDicsNwC3DBWH3cfNLMuYG7Y/uKobRcnexEz2wxsBigsLCQWi2Wg9PMXj8ezXkOu0FgEBgcHcXeNRUjvixG5PhaZCAJL0uYp9kll26DRfQuwBaCmpsbr6+snU2PG1dXV6S+/kMYiUFtbSywWo6GhIdul5AS9L0bkylgEB2LOlYmTxS1AdcLyEqB1rD5mVgDMATpS3FZERKZQJoJgJ7DKzFaYWRHByd+to/psBTaFj+8BnnV3D9vvDa8qWgGsAl7OQE0iIpKitA8Nhcf87wd2APnAY+7eZGYPAfXuvhX4DvD34cngDoKwIOz3I4ITy4PAp9w9nm5NIiKSukycI8DdtwHbRrV9OeFxL/CRMbZ9GHg4E3WIiMjka
WaxiEjEKQhERCJOQSAiEnEKAhGRiFMQiIhEnIJARCTiMnL56HRrjfXwX7a9SVF+HkUF4U/C4+LwJ2jPp7gw75y+Z9cX5FFckE9+XvKp1yLD+geHONbdR1tXL0dP9rL/+Gn2Hj3F3vZu3on10HygE/chrvvK08wqLqCqrJgFZSUsLCthaeVMls2dybK5s1g0p4SSQt1kV3LHBRkEJ073s+X5fRl9zjzjbCgkC4vRQdLV0cs/tv/ubPAk6zNWMI3uU1yQd87zFORrZ22qxYecE919HOnqpe1kL21dvRzp6uVU7wC9A0P0Dsbp7h2k/VQfx071cry7f9znGxwaAoL354nT/RzqODNm3zkzCpk/u5iFZSVUhyGxtHImZSWFzCzOZ2ZRPqXFBcyZUUhpccGY94gRyYQLMgimwpAT/OcfGEp5m5faWqasnuFgCsIi/10hkWwPJzF0RgIm/90BkyR8hoPp3XtIwbbFhXmUFORTmG8X7AfRqd4B9h07zf7jpzl44gyHOs5wuOMM78R6OHqyl8GhpPc4nHJdPQN09QzQ3N49Yd/8PKNiZiGLy2ewpGImSypnMHdWEWUlhZTNKKSspJDZJQXMLilgVnHwX9odhnzk32YGp/sGaT/ZR/upPrr7Blk2dybXLq04u41El94BOerdwTSY1VrMoCQhGEoKg6AY/n3mVC/fP1hPSWEeJYVBe2L/4eAqLsynKD+PwoI8ivKNwvy8hJ+R5aICoyAvj4KwLT/PyE8IoiF34kPO4JAzGHdO9g5wMvxgbensYd/xbvYdO82+46c5dqoviyOXGfEh53h3P8e7+/ldS1dGnzs/z1h9SRnXL69k3YpKrl9eSeWsooy+huQ+BYFMyB16BuL0DMSBgaR9Gk8cnd6iJCPiQ87rLV283tLFd/7ffgBWLijl+uUV1CwLwmFJxYwLdo9QUnNBBsGiOSV8cf0V9A8O0R+PB78Hh+iPD9E3/DhcHn7cN6otWI6f3cazc4RALjDzSouoCk8AL5xTwsoFpVxeNZvL5pdyzwuVdHV18eyDt9HV009bVx9tJ3tpjfVw4MTIoakT3X1k6YhUSprbu2lu7+aHLwffNzV/djHXLi3n2qUVXLusgqsXz9HJ7ovMBRkE80qL+ZPayzL6nIPxMEgGRofFuwNlOHhee72Jyy5/T9LQGdk+fvY5xgqo0cHUF7YpmKZHWUkBC+eUsGjODBbNKaGqrIS5pUWUFAbnZWYU5jN/djFVZSXMKy2mqGDsk/gFeUa+BR+c82cXs3LB7KT94kPOidN9tJ/s451YDwfDkGjr6uV0/yA9/XFO98c51Rsc7prMeaupcOxUHzuajrKjKdjrK8rP4/eWzKFmeSXXL6/gumUVlM/U4aQL2QUZBFOhILxSJ9X3c8nx3dTWVE/c8Ty4B8e/k4dFEDDJAmXsvaH4hH1Gh2DvQBBKvQPxrJ1QzYT8PKO6YgaXzi9l+dxZwdU5c2dSXTGTRXNKsnKiND/PWDC7hAWzS1izeM6E/XsH4hw71UdLZw+HO8/QGuvhZM/g2XMjp3qDx6d6BznTH8csuNjAMMw4+0dFYYExv7SYBbNLKMg3XjsU451Yz6Tr748PUX+wk/qDnXzjV0Hb5VWlZ4Ph+uWVLKmYOennlexREOQgMzt78nRWcbarCfaW+kaFQ+9AEDC9A0PsfPU1Lr9y9Tnr+sJLMIcDpi9sHxhyBgaHGIgPnQ28gXCPbDDu9MeD5Xg8PCE85O+6AgagIC84wZyfZ5QWF5y9gmburCJWzJ/FpfNmcen8WSytnDXuX/EXgpLCfKorZ1JdOZP3MTejz90a62HngQ5e3h/87E3hKqZk9hztZs/Rbv73S4cAuGROCUtnDnC45CDrlleyakEpeZqrk7MUBDKh4b2lsf567jtcQO2aRdNclWTCJeUz2HjNYjZesxiAztP9vHKwk50HO6g/0MnrLTEG4pPfI2zt6qW1C178WSMQzJuoWVbB9SuCvYarF5df8AF9M
VEQiMhZFbOKuO2qKm67qgoIDks1tZ7ktUOdvHqok50HOs/rktyungGeeaudZ95qB6C4II9rqsvPXrJ67bIKSjWfIWs08iIyppLCfK5bFpwQhuD81eGOHl4+0EH9gQ52Hujg7WOnJ/28fYNDvLS/g5f2dwDBeZOrFg3PZ6igZnkl80pz4LhoRKQVBGZWCTwJLAcOAB91985Rfa4B/hYoA+LAw+7+ZLjuu8AHgOFZMh9z94Z0ahKRqWNmLA1PuN9z3RIATnT3BSePD3Tw8oFOmt7pmvQFBvEh5413unjjnS4e+00wn+HS+bN4/6r5fOA983nfpXN1yeoUSneP4AHgGXd/xMweCJe/OKrPGeCP3X2vmV0CvGJmO9w9Fq7/grv/JM06RCRL5pYWc+fqhdy5eiEAZ/oHaTgU48d1r3KMObx6qJMz/fFJP+++Y6fZd+w0333hAMUFedy0ch53XFXFrVdWMX+29hYyKd0g2AjUho8fB+oYFQTuvifhcauZtQPzgRgictGZWVTAjSvn0d9SRG3tDQzEh9jVevLs1Un1BzvpOD3+DfxG6xsc4tm32nn2rXbM3uC6pRWsXxOET3WlLlVNl3kaM5fMLObu5QnLne5eMU7/dQSBsdrdh8JDQ+8D+oBngAfcPemZKDPbDGwGqKqquu6JJ54477ozobu7m9LS0qzWkCs0FoHPfvazxONxvva1r2W7lJww1vvC3Tly2tnTGWdv5xB7OuMc6zn/z6FlZXlcV5VPTVUBl5Tm5pVIufJ/5Oabb37F3WtGt08YBGb2S2BhklUPAo+nGgRmtohgj2GTu7+Y0NYGFAFbgLfd/aGJ/jE1NTVeX18/UbcpVVdXR21tbVZryBUai0BtbS2xWIyGBp3mgsm9L4509bDzQCc79wcnoHcfPXVes+tXLijlrnBPYfUlZTlzj6Rc+T9iZkmDYMJDQ+5+2zhPetTMFrn7kfBDvX2MfmXAPwH/aTgEwuc+Ej7sM7O/Az4/UT0icvFZNGcGG947gw3vvQQI5jP8uvk4dbvb+dXuY5xI8VBSc3s3X3u2ma8920x15QzWr17I+jWLWFtdrglt40j3HMFWYBPwSPj756M7mFkR8FPge+7+41HrhkPEgLuBxjTrEZGLQMWsIja89xI2vPcShoac1w7H+MWuNn7RdJT9x1O7XPVwRw/f+vV+vvXr/VSVBSe0169ZyLrllfrip1HSDYJHgB+Z2ceBQ8BHAMysBviku38C+CjwfmCumX0s3G74MtEfmNl8wIAG4JNp1iMiF5m8PDs7l+GB9VfQ3N7N9sY2nmpsY9eRkyk9x9GTfXzvtwf53m8PUjmriNuvrGL91Qu58bK5FBfostS0gsDdTwC3JmmvBz4RPv4+8P0xtr8lndcXkWgxM1ZVzWZV1Wz+9NZVHDpxhqeajvBUYxuvHkrtQsSO0/08WX+YJ+sPM7u4gFuvXMD6NQv5wOULmFEUzVDQzGIRuWAtnTuTze+/jM3vv4y2rl52NLWxvfEIL+/vSOk7H071DfKzhlZ+1tBKSWEetZcv4K6rF3LzFQsoKymc+n9AjlAQiMhFYeGcEjbduJxNNy7nRHcfT+86yvbGNl54+3hKN87rHRjiqaY2nmpqoyg/j5tWzuWuNYu47aqqi/7rOxUEInLRmVtazL3rlnLvuqV09Qzw7FtHeaqxjV/tOZbSF/30x4d4bvcxntt9jPyfGjesqOSuNQu5Y/VCqspKpuFfML0UBCJyUZszo5APr13Ch9cu4Uz/IL/afYynmtp45s12uvsGJ9w+PuS88PYJXnj7BF/e2sS1SyvCy1IvnlnNCgIRiYyZRQXcdfUi7rp6EX2DcV5oPsH2xiM8vesonWcGJtzeHV452MkrBzt5eNubrL6kjLvWBHMVVi7I/szh86UgEJFIKi7I5+YrFnDzFQsYjA/x8v4Otje2saOpjfYUv3OhqfUkTa0n+W+/2HN2VvP6NQu5alHuzGpOhYJARCKvID+PG1fO4
8aV8/jPG1bz2uFOtr/RxvbGtpS/1zlxVvPSyplnb4q3trp84o2zTEEgIpIgmMBWyXXLKnnwD66kqfUkTzUGl6Wm+iU8hzrOsOX5fWx5fh9VZcWsKY9TVH08Z2c1KwhERMZgZqxZPIc1i+fw+TvfQ3P7qbN7CpOZ1Xz0JDzzrZdGZjWvWciNK3NnVrOCQEQkRSsXzObTt87m0wmzmrc3tvHaBT6rWUEgInIeMjmreUZhPrXvmc/6NQu55YoFzJ7mWc0KAhGRNCWb1fxUUxu/aU5tVnPPQJztjcEhp2zMalYQiIhk0OhZzc+91c7f171BU4ef96zm4SuQpmpWs4JARGSKzJlRyN1rF1PetZd1N/6z9GY1/7yJa5eWc9eaRRmf1awgEBGZBqNnNf+m+ThPNbalPKsZ4NVDMV49FOPhbW+yZnHZ2W9gS3dWs4JARGSaFRfkc8sVVdxyRdV5z2pufOckje+8e1bz+X5Xs4JARCSLxprV/FRTGy2dk5/VfD7f1awgEBHJEclmNW9vDOYq7EtxVnPS72pevZB1KyrH3CatIDCzSuBJYDlwAPiou3cm6RcH3ggXD7n7hrB9BfAEUAm8CvyRu/enU5OIyMUgcVbzF+68gr1HT4W3ujj/72oeS7o3vXgAeMbdVwHPhMvJ9Lj7NeHPhoT2rwKPhtt3Ah9Psx4RkYvSqqpgRvO2z/xznv/Czfz5B69g7dLUb2jXcXrsv7HTDYKNwOPh48eBu1Pd0IKzGbcAPzmf7UVEomp4VvNP//1NvPilW3lo42red+lcUjgdkFS65wiq3P0IgLsfMbMFY/QrMbN6YBB4xN1/BswFYu4+fDFtC7B4rBcys83AZoCqqirq6urSLD093d3dWa8hV2gsArFYjHg8rrEI6X0xYqrHYinw7y6H+5bP5LX2QV5pi9N0Ik4Kk5qBFILAzH4JLEyy6sHJ1OnurWZ2KfCsmb0BJDvINWbZ7r4F2AJQU1PjtbW1k3j5zKurqyPbNeQKjUWgvLycWCymsQjpfTFiOsdi+Nj78Kzm7Y1HJvyu5gmDwN1vG2udmR01s0Xh3sAioH2M52gNf+8zszpgLfAPQLmZFYR7BUuA1onqERGRiQ3Par577eKz39X8wa8m75vuOYKtwKbw8Sbg56M7mFmFmRWHj+cBNwG73N2B54B7xtteRETSMzyreSzpBsEjwO1mthe4PVzGzGrM7NthnyuBejP7HcEH/yPuvitc90Xgc2bWTHDO4Dtp1iMiIpOU1slidz8B3JqkvR74RPj4BeDqMbbfB6xLpwYREUlP7n15poiITCsFgYhIxCkIREQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYk4BYGISMQpCEREIk5BICIScQoCEZGIUxCIiEScgkBEJOIUBCIiEacgEBGJOAWBiEjEpRUEZlZpZk+b2d7wd0WSPjebWUPCT6+Z3R2u+66Z7U9Yd0069YiIyOSlu0fwAPCMu68CngmX38Xdn3P3a9z9GuAW4Azwi4QuXxhe7+4NadYjIiKTlG4QbAQeDx8/Dtw9Qf97gO3ufibN1xURkQxJNwiq3P0IQPh7wQT97wV+OKrtYTN73cweNbPiNOsREZFJKpiog5n9EliYZNWDk3khM1sEXA3sSGj+EtAGFAFbgC8CD42x/WZgM0BVVRV1dXWTefmM6+7uznoNuUJjEYjFYsTjcY1FSO+LEbk+FhMGgbvfNtY6MztqZovc/Uj4Qd8+zlN9FPipuw8kPPeR8GGfmf0d8Plx6thCEBbU1NR4bW3tRKVPqbq6OrJdQ67QWATKy8uJxWIai5DeFyNyfSzSPTS0FdgUPt4E/Hycvvcx6rBQGB6YmRGcX2hMsx4REZmkdIPgEeB2M9sL3B4uY2Y1Zvbt4U5mthyoBn41avsfmNkbwBvAPOCv0qxHREQmacJDQ+Nx9xPArUna64FPJCwfABYn6XdLOq8vIiLp08xiEZGIUxCIiEScgkBEJ
OIUBCIiEacgEBGJOAWBiEjEKQhERCJOQSAiEnEKAhGRiFMQiIhEnIJARCTiFAQiIhGnIBARiTgFgYhIxCkIREQiTkEgIhJxCgIRkYhTEIiIRJyCQEQk4tIKAjP7iJk1mdmQmdWM02+9me02s2YzeyChfYWZvWRme83sSTMrSqceERGZvHT3CBqBfwE8P1YHM8sHvg7cBVwF3GdmV4Wrvwo86u6rgE7g42nWIyIik5RWELj7m+6+e4Ju64Bmd9/n7v3AE8BGMzPgFuAnYb/HgbvTqUdERCavYBpeYzFwOGG5BbgBmAvE3H0woX3xWE9iZpuBzeFit5lNFEBTbR5wPMs15AqNxYh5ZqaxCOh9MSJXxmJZssYJg8DMfgksTLLqQXf/eQovbEnafJz2pNx9C7AlhdebFmZW7+5jnheJEo3FCI3FCI3FiFwfiwmDwN1vS/M1WoDqhOUlQCtBOpabWUG4VzDcLiIi02g6Lh/dCawKrxAqAu4Ftrq7A88B94T9NgGp7GGIiEgGpXv56IfNrAV4H/BPZrYjbL/EzLYBhH/t3w/sAN4EfuTuTeFTfBH4nJk1E5wz+E469UyznDlMlQM0FiM0FiM0FiNyeiws+MNcRESiSjOLRUQiTkEgIhJxCoIMMLPPm5mb2bxs15ItZvbXZvaWmb1uZj81s/Js1zTdxrqVStSYWbWZPWdmb4a3oPlMtmvKJjPLN7PXzOz/ZruWsSgI0mRm1cDtwKFs15JlTwNr3P33gD3Al7Jcz7Sa4FYqUTMI/Ad3vxL4feBTER4LgM8QXCiTsxQE6XsU+I+MMxkuCtz9FwmzxF8kmBcSJUlvpZLlmrLC3Y+4+6vh41MEH4Jj3jXgYmZmS4A/AL6d7VrGoyBIg5ltAN5x999lu5Yc82+B7dkuYpolu5VKJD/8EpnZcmAt8FJ2K8ma/0Hwh+JQtgsZz3Tca+iCNt4tNoA/B+6Y3oqyJ5XbjZjZgwSHBn4wnbXlgEndMiUKzKwU+Afgs+5+Mtv1TDcz+xDQ7u6vmFlttusZj4JgAmPdYsPMrgZWAL8LbqTKEuBVM1vn7m3TWOK0meh2I2a2CfgQcKtHb4LKWLdSiSQzKyQIgR+4+//Jdj1ZchOwwcw+CJQAZWb2fXf/wyzXdQ5NKMsQMzsA1Lh7LtxhcNqZ2XrgvwMfcPdj2a5nuplZAcFJ8luBdwhurfKvE2bRR0Z4i/nHgQ53/2y268kF4R7B5939Q9muJRmdI5BM+RtgNvC0mTWY2TeyXdB0muBWKlFzE/BHwC3he6Eh/KtYcpT2CEREIk57BCIiEacgEBGJOAWBiEjEKQhERCJOQSAiEnEKAhGRiFMQiIhE3P8H6hvm0FfV+u8AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "ccm0jfOcETl0" }, "source": [ "The network keeps track of all the parameters and gradients!" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "id": "_nQx9_nDmKXs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "print(net.fc1.bias.grad)" ] }, { "cell_type": "markdown", "metadata": { "id": "oirQcFSmDVW5" }, "source": [ "In the `__init__` function, any variable that you assign to `self` that is also a module will be automatically added as a sub-module. The parameters of a module (and all sub-modules) can be accessed through the `parameters()` function:" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 121 }, "id": "Vwrjulr8mlwR", "outputId": "6cc62d8f-491f-482e-ba48-e02f0745cd6a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32, 1])\n", "torch.Size([32])\n", "torch.Size([32, 32])\n", "torch.Size([32])\n", "torch.Size([1, 32])\n", "torch.Size([1])\n" ] } ], "source": [ "for p in net.parameters():\n", " print(p.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "RAkasCBWDd2D" }, "source": [ "WARNING: if you want to have a list of modules use\n", "```\n", "def __init__(self, network1, network2):\n", " self.list = nn.ModuleList([network1, network 2])\n", "```\n", "and **not** \n", "```\n", "def __init__(self, network1, network2):\n", " self.list = [network1, network 2]\n", "```\n", "In the later case, `network1` and `network2` won't be added as sub-modules." ] }, { "cell_type": "markdown", "metadata": { "id": "PcP0KwA4EYH-" }, "source": [ "The output of the module is just a tensor. 
We can perform operations on the tensor like before to automatically compute derivatives.\n", "For example, below we use the sum of squares of the output as a loss to minimize." ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "id": "tqDLwkAXmSV3" }, "outputs": [], "source": [ "loss = (y**2).sum()\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "mzAJE131EkW5" }, "source": [ "We can manually update the parameters by adding the gradient (times a negative learning rate) and then zeroing out the gradients to prevent gradient accumulation." ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "id": "1tT_gl-QmfMx" }, "outputs": [], "source": [ "for p in net.parameters():\n", "    p.data.add_(-0.001 * p.grad)\n", "    p.grad.data.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "cBSnExmmEsy5" }, "source": [ "And we can do this in a loop to train our network!" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "# Same architecture as Net, without the pdb breakpoint in forward\n", "class Net2(nn.Module):\n", "\n", "    def __init__(self, input_size, output_size):\n", "        super(Net2, self).__init__()\n", "        self.fc1 = nn.Linear(input_size, 32)\n", "        self.fc2 = nn.Linear(32, 32)\n", "        self.fc3 = nn.Linear(32, output_size)\n", "\n", "    def forward(self, x):\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return x\n", "\n", "    def hook(self, gradient):\n", "        return 2*gradient\n", "\n", "net = Net2(input_size=1, output_size=1)" ] }, { "cell_type": "code", "execution_count": 71, "metadata": { "id": "mrhmNphPnWEQ" }, "outputs": [], "source": [ "for _ in range(100):\n", "    y = net(x)\n", "    loss = (y**2).sum()\n", "    loss.backward()\n", "    for p in net.parameters():\n", "        p.data.add_(-0.001 * p.grad)\n", "        p.grad.data.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "oGe1_gpCE1hs" }, "source": [ "Sure enough, our network learns to set everything to zero." 
] }, { "cell_type": "code", "execution_count": 72, "metadata": { "id": "YnZ9swUJnp1d" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "5onPqgL_o9N9" }, "source": [ "## Loss Functions\n", "PyTorch has a number of built-in loss functions, which are themselves modules that you can pass your data through." ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "id": "KutB4zljpCJz" }, "outputs": [], "source": [ "y_target = torch.sin(x)\n", "loss_fn = nn.SmoothL1Loss()" ] }, { "cell_type": "code", "execution_count": 74, "metadata": { "id": "LBtKp7iro-VR" }, "outputs": [], "source": [ "for _ in range(1000):\n", "    y = net(x)\n", "    loss = loss_fn(y, y_target)\n", "    loss.backward()\n", "    for p in net.parameters():\n", "        p.data.add_(-0.001 * p.grad)\n", "        p.grad.data.zero_()" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "id": "8DHmqNfxpPoQ" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "WX-qNKKPn2Ft" }, "source": [ "## Optimizers" ] }, { "cell_type": "markdown", "metadata": { "id": "pPhBZESWE-9M" }, "source": [ "We can use fancier optimizers with the `optim` package." ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "id": "pZ0ptKURn5TE" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "net = Net2(input_size=1, output_size=1)\n", "\n", "optimizer = optim.Adam(net.parameters(), lr=1e-3)\n", "\n", "x = torch.linspace(-5, 5, 100).view(-1, 1)\n", "y = net(x)\n", "y_target = torch.sin(x)\n", "loss_fn = nn.MSELoss()" ] }, { "cell_type": "markdown", "metadata": { "id": "TTmTVBaDFDP-" }, "source": [ "Here's the network before training:" ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "id": "vR278Ii7ooXa" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "CCCE8LCoFEiB" }, "source": [ "And here's how you can use the optimizer to train the network.\n", "Note that we call `zero_grad` _before_ calling `loss.backward()`, and then we just call `optimizer.step()`. This `step` function will take care of updating all the parameters that were passed to that optimizer's constructor." ] }, { "cell_type": "code", "execution_count": 78, "metadata": { "id": "sJfVNGhPod5i" }, "outputs": [], "source": [ "for _ in range(100):\n", "    y = net(x)\n", "    loss = loss_fn(y, y_target)\n", "\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()" ] }, { "cell_type": "markdown", "metadata": { "id": "lOM7_Xr6FUaV" }, "source": [ "And we see that this trained the network quite well." ] }, { "cell_type": "code", "execution_count": 79, "metadata": { "id": "219Pojw0os9X" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot(x.detach().numpy(), y.detach().numpy(), ylim=(-1, 1), xlim=(-5, 5))" ] }, { "cell_type": "markdown", "metadata": { "id": "5lS-wv_ABTky" }, "source": [ "## Regression example" ] }, { "cell_type": "markdown", "metadata": { "id": "j1KeajKoec60" }, "source": [ "We're going to train two separate neural networks to solve a prediction task:\n", "$$f_\\theta(x) \\approx y$$\n", "First we generate the data:" ] }, { "cell_type": "code", "execution_count": 81, "metadata": { "id": "0_TUikVIuBCA" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "N = 100\n", "d = 1\n", "X = np.random.randn(N, d)\n", "Y = X * 2 + 3 + np.random.randn(N, d)\n", "plt.scatter(X, Y)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 82, "metadata": { "id": "EJ3h8uW1WWYg" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(100, 1) (100, 1)\n" ] } ], "source": [ "print(X.shape, Y.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "c-2yrp4uFa43" }, "source": [ "Next we convert that data to PyTorch tensors:" ] }, { "cell_type": "code", "execution_count": 83, "metadata": { "id": "5Q6Mil6uUoJs" }, "outputs": [], "source": [ "X_pt = torch.from_numpy(X).to(torch.float32)\n", "Y_pt = torch.from_numpy(Y).to(torch.float32)\n", "\n", "loss_fn = nn.MSELoss()" ] }, { "cell_type": "markdown", "metadata": { "id": "d5hhaIlsFd_4" }, "source": [ "Define the training loop:" ] }, { "cell_type": "code", "execution_count": 84, "metadata": { "id": "JbQaSvbruUjo" }, "outputs": [], "source": [ "def train(net: nn.Module):\n", "    optimizer = optim.SGD(net.parameters(), lr=1e-2)\n", "    losses = []\n", "    for _ in range(100):\n", "        Y_hat_pt = net(X_pt)\n", "        loss = loss_fn(Y_hat_pt, Y_pt)\n", "\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        losses.append(loss.detach().numpy())\n", "    return np.array(losses)" ] }, { "cell_type": "markdown", "metadata": { "id": "J9UGOhquFfz-" }, "source": [ "Let's test two different networks:" ] }, { "cell_type": "code", "execution_count": 86, "metadata": { "id": "VH4ARnKUqVt0" }, "outputs": [], "source": [ "linear_network = nn.Linear(1, 1)\n", "linear_losses = train(linear_network)\n", "\n", "non_linear_network = Net2(1, 1)\n", "non_linear_losses = train(non_linear_network)" ] }, { "cell_type": "markdown", "metadata": { "id": "2kHLnn2TFiXO" }, "source": [ "Finally, we plot the losses and predictions." 
] }, { "cell_type": "code", "execution_count": 88, "metadata": { "id": "N3tTaRBkv573" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "t = np.arange(len(linear_losses))\n", "plt.plot(t, linear_losses, t, non_linear_losses)\n", "plt.legend(['linear', 'non_linear'])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 90, "metadata": { "id": "Wu_cKZ5PUzT0" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x_probe = torch.linspace(X.min(), X.max(), 300).reshape(-1, 1)\n", "y_lin = linear_network(x_probe).detach().numpy()\n", "y_non_lin = non_linear_network(x_probe).detach().numpy()\n", "\n", "plt.scatter(x_probe.numpy(), y_lin, s=3)\n", "plt.scatter(x_probe.numpy(), y_non_lin, s=3)\n", "plt.scatter(X, Y)\n", "plt.legend(['linear', 'non_linear', 'data'])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "N7Faplaeu1Ur" }, "source": [ "# PyTorch - Advanced" ] }, { "cell_type": "markdown", "metadata": { "id": "lSt5Pr5gu5qX" }, "source": [ "## Distributions\n", "PyTorch has a very convenient [distributions](https://pytorch.org/docs/stable/distributions.html) package." ] }, { "cell_type": "code", "execution_count": 91, "metadata": { "id": "qaCv2BGmbXWF" }, "outputs": [], "source": [ "from torch import distributions" ] }, { "cell_type": "markdown", "metadata": { "id": "C4p2WPuCGxww" }, "source": [ "You create distributions by passing the parameters of the distribution." ] }, { "cell_type": "code", "execution_count": 92, "metadata": { "id": "790r1b6tduyY" }, "outputs": [], "source": [ "mean = torch.zeros(1, requires_grad=True)\n", "std = torch.ones(1, requires_grad=True)\n", "gaussian = distributions.Normal(mean, std)" ] }, { "cell_type": "markdown", "metadata": { "id": "6Hoepp6iG0tl" }, "source": [ "These distributions are instances of the more general `Distribution` class, which you can read more about [here](https://pytorch.org/docs/stable/distributions.html#distribution)." 
] }, { "cell_type": "code", "execution_count": 93, "metadata": { "id": "1RO2ebhxgYUx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Normal(loc: tensor([0.], requires_grad=True), scale: tensor([1.], requires_grad=True))\n", "True\n" ] } ], "source": [ "print(gaussian)\n", "print(isinstance(gaussian, distributions.Distribution))" ] }, { "cell_type": "code", "execution_count": 94, "metadata": { "id": "4X2m8tlmd0Rb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.9333]])\n" ] } ], "source": [ "sample = gaussian.sample((1,))\n", "print(sample)" ] }, { "cell_type": "code", "execution_count": 95, "metadata": { "id": "0UbTN8I5Fmxs" }, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.3545]], grad_fn=)" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaussian.log_prob(sample)" ] }, { "cell_type": "markdown", "metadata": { "id": "z3BYXDwOG8Z-" }, "source": [ "The log probability depends on the parameters of the distribution. So, calling `backward` on a loss that depends on `log_prob` will back-propagate gradients into the parameters of the distribution.\n", "\n", "NOTE: this won't back-propagate through the samples (the \"reparameterization trick\") unless you use `rsample`, which is only implemented for some distributions." ] }, { "cell_type": "markdown", "metadata": { "id": "gLAaXWLHeLco" }, "source": [ "Question (at 5:18): we can set the loss to `-log_prob` to maximize the probability of an event. That makes sense to me, because the higher the probability is, the smaller the loss is. Now, if we want to incorporate a reward into this loss function, I think people usually use `loss = -log_prob * reward`. But that means the higher the reward is, the higher the loss is. Is this because, with a higher reward, we want a higher loss that pushes the probability of the event up more strongly? Or should a higher reward give a lower loss (`loss = log_prob * reward`)?" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "azs4-dQ3fWoP" }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": { "id": "IZ7F-74Be6Rg" }, "source": [ "$$s_1, a_1, s_2, a_2, ...$$" ] }, { "cell_type": "code", "execution_count": 96, "metadata": { "id": "p1Hm8zTjd9pR" }, "outputs": [], "source": [ "loss = - gaussian.log_prob(sample).sum()" ] }, { "cell_type": "code", "execution_count": 97, "metadata": { "id": "oxkylpG9hGZU" }, "outputs": [], "source": [ "loss.backward()" ] }, { "cell_type": "code", "execution_count": 98, "metadata": { "id": "3Ad0bfIChIgS" }, "outputs": [ { "data": { "text/plain": [ "tensor([-0.9333])" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mean.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "_KmS36DoeAgn" }, "source": [ "## Batch-Wise distribution" ] }, { "cell_type": "markdown", "metadata": { "id": "HB1GRLL3GZ-d" }, "source": [ "The distributions also support batch-operations. In this case, all the operations (`sample`, `log_prob`, etc.) are batch-wise." 
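, "\n",
 "\n",
 "As a sketch of what batch-wise means here (the indices $j$ and $i$ are notation assumed for this note): for a `Normal` built from parameter vectors of length 10, `sample((n,))` should return a tensor of shape `[n, 10]`, and `log_prob` is evaluated element by element:\n",
 "\n",
 "$$\\ell_{j,i} = \\log \\mathcal{N}(x_{j,i}; \\mu_i, \\sigma_i), \\qquad j = 1, \\dots, n, \\quad i = 1, \\dots, 10.$$\n",
 "\n",
 "If the 10 entries are meant to be independent components of a single random vector, the joint log-density of sample $j$ is the sum over the last dimension, $\\sum_{i} \\ell_{j,i}$."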
] }, { "cell_type": "code", "execution_count": 99, "metadata": { "id": "DfPMQRfWeDRs" }, "outputs": [], "source": [ "mean = torch.zeros(10)\n", "std = torch.ones(10)\n", "gaussian = distributions.Normal(mean, std)" ] }, { "cell_type": "code", "execution_count": 100, "metadata": { "id": "cFPZf9mBhXMa" }, "outputs": [ { "data": { "text/plain": [ "Normal(loc: torch.Size([10]), scale: torch.Size([10]))" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaussian" ] }, { "cell_type": "code", "execution_count": 101, "metadata": { "id": "EhwiqGtseGD-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-1.4607, -0.5094, -1.2975, -1.6653, -0.1941, 1.3166, 0.5912, -0.6193,\n", " -0.2724, -0.4517]])\n" ] } ], "source": [ "sample = gaussian.sample((1,))\n", "print(sample)" ] }, { "cell_type": "code", "execution_count": 102, "metadata": { "id": "fGVHHImreH5n" }, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.9857, -1.0487, -1.7606, -2.3055, -0.9378, -1.7857, -1.0937, -1.1107,\n", " -0.9560, -1.0210]])" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaussian.log_prob(sample)" ] }, { "cell_type": "markdown", "metadata": { "id": "wUn4p3aReLeR" }, "source": [ "## Multivariate Normal" ] }, { "cell_type": "markdown", "metadata": { "id": "OUchz-ayHlVz" }, "source": [ "Other distributions are available as well, for example the multivariate normal, which takes a mean vector and a covariance matrix." ] }, { "cell_type": "code", "execution_count": 103, "metadata": { "id": "Dv74f_XZbTv_" }, "outputs": [], "source": [ "mean = torch.zeros(2)\n", "covariance = torch.tensor(\n", "    [[1, 0.8],\n", "     [0.8, 1]]\n", ")\n", "gaussian = distributions.MultivariateNormal(mean, covariance)" ] }, { "cell_type": "code", "execution_count": 104, "metadata": { "id": "g1vfCg1McpOM" }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.3727, -0.0504]])" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaussian.sample((1,))" ] }, { 
"cell_type": "code", "execution_count": 105, "metadata": { "id": "XMG929mzeqJz" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 105, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "samples = gaussian.sample((500,))\n", "plt.scatter(samples[:, 0].numpy(), samples[:, 1].numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "vXCJs0dCHn54" }, "source": [ "NOTE: if you want to use a batch of `MultivariateNormal` distributions, you'll need to construct a batch of covariance matrices (i.e. shape `[BATCH_SIZE, DIM, DIM]`)." ] }, { "cell_type": "markdown", "metadata": { "id": "KsZZu33nfDB5" }, "source": [ "## Categorical Distribution" ] }, { "cell_type": "code", "execution_count": 106, "metadata": { "id": "IfoFevpBdy9X" }, "outputs": [], "source": [ "from torch import distributions" ] }, { "cell_type": "markdown", "metadata": { "id": "ePKpc3akHxit" }, "source": [ "Another useful distribution is the categorical distribution." ] }, { "cell_type": "code", "execution_count": 107, "metadata": { "id": "QgRw31OMfEcV" }, "outputs": [], "source": [ "probs = torch.tensor([0.1, 0.2, 0.7])\n", "dist = distributions.Categorical(probs=probs)" ] }, { "cell_type": "code", "execution_count": 108, "metadata": { "id": "K5ucMwVKfNv9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 1, 2, 2, 2])\n" ] } ], "source": [ "sample = dist.sample([20])\n", "print(sample)" ] }, { "cell_type": "code", "execution_count": 109, "metadata": { "id": "ye8mqQa4iDa5" }, "outputs": [ { "data": { "text/plain": [ "tensor([-0.3567, -0.3567, -0.3567, -1.6094, -1.6094, -0.3567, -0.3567, -0.3567,\n", " -0.3567, -0.3567, -2.3026, -0.3567, -0.3567, -0.3567, -0.3567, -0.3567,\n", " -1.6094, -0.3567, -0.3567, -0.3567])" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dist.log_prob(sample)" ] }, { "cell_type": "markdown", "metadata": { "id": "lNEq80qTfZ6s" }, "source": [ "## Distributions and Modules" ] }, { "cell_type": "markdown", "metadata": { "id": "OJm0r8WFIwqf" }, 
"source": [ "Typically, your network will output parameters of a distribution" ] }, { "cell_type": "code", "execution_count": 110, "metadata": { "id": "TkKsDyCmIzvF" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "class Net(nn.Module):\n", "\n", " def __init__(self, input_size, output_size):\n", " super(Net, self).__init__()\n", " self.fc1 = nn.Linear(input_size, 32)\n", " self.fc2 = nn.Linear(32, 32)\n", " self.fc3 = nn.Linear(32, output_size)\n", "\n", " def forward(self, x):\n", " x = F.relu(self.fc1(x))\n", " x = F.relu(self.fc2(x))\n", " x = self.fc3(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 111, "metadata": { "id": "3nCEfh-II2zV" }, "outputs": [], "source": [ "mean_network = Net(1, 1)\n", "x = torch.randn(100, 1)\n", "mean = mean_network(x)\n", "distribution = distributions.Normal(x, scale=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "G9FDsX3DIpFn" }, "source": [ "If you want, your nn.Module can return a distribution in the `forward` function!" 
] }, { "cell_type": "code", "execution_count": 112, "metadata": { "id": "Gz2eMYZydMyC" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "class Net(nn.Module):\n", "\n", "    def __init__(self, input_size, output_size):\n", "        super(Net, self).__init__()\n", "        self.fc1 = nn.Linear(input_size, 32)\n", "        self.fc2 = nn.Linear(32, 32)\n", "        self.fc3 = nn.Linear(32, output_size)\n", "\n", "    def forward(self, x):\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return distributions.Normal(x, scale=1)" ] }, { "cell_type": "code", "execution_count": 113, "metadata": { "id": "L5KpD2reccWJ" }, "outputs": [], "source": [ "distribution_network = Net(1, 1)\n", "x = torch.randn(100, 1)\n", "distribution = distribution_network(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "4Wh09beAu6d5" }, "source": [ "## `gather`\n", "This function will be useful for the DQN assignment.\n", "It lets you index into a tensor batch-wise: for each position, an index tensor says which entry to pick along a given dimension." 
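, "\n",
 "\n",
 "As a rough guide for the 2-D case (a sketch; the names $Q$, $a$, $B$ and $A$ below are assumptions for illustration, not taken from the assignment), `torch.gather(input, dim, index)` returns a tensor with the same shape as `index`, where\n",
 "\n",
 "$$\\mathrm{out}_{i,j} = \\mathrm{input}_{\\mathrm{index}_{i,j},\\, j} \\quad (\\mathrm{dim} = 0), \\qquad \\mathrm{out}_{i,j} = \\mathrm{input}_{i,\\, \\mathrm{index}_{i,j}} \\quad (\\mathrm{dim} = 1).$$\n",
 "\n",
 "For example, with input $[[0, 1, 2], [3, 4, 5]]$, index $[[0, 1, 0]]$ and $\\mathrm{dim} = 0$, the output is $[[\\mathrm{input}_{0,0}, \\mathrm{input}_{1,1}, \\mathrm{input}_{0,2}]] = [[0, 4, 2]]$. In a DQN-style setting with Q-values $Q$ of shape $[B, A]$ and chosen actions $a$ of shape $[B, 1]$, gathering along $\\mathrm{dim} = 1$ picks $\\mathrm{out}_{i,0} = Q_{i,\\, a_i}$, the Q-value of the action taken in each row."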
] }, { "cell_type": "code", "execution_count": 114, "metadata": { "id": "l0A4vz6iJKab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0, 1, 2],\n", "        [3, 4, 5]])\n", "tensor([[0],\n", "        [1]])\n", "tensor([[0],\n", "        [4]])\n" ] } ], "source": [ "x = torch.arange(6).reshape(2, 3)\n", "y = torch.tensor([0, 1]).reshape(2, 1)\n", "print(x)\n", "print(y)\n", "# Gather along dim 1: from row i, pick the column given by y[i]\n", "print(torch.gather(x, 1, y))" ] }, { "cell_type": "code", "execution_count": 116, "metadata": { "id": "NfMgAVAYu7n_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0, 1, 2],\n", "        [3, 4, 5]])\n", "tensor([[0, 1, 0]])\n", "tensor([[0, 4, 2]])\n" ] } ], "source": [ "x = torch.arange(6).reshape(2, 3)\n", "y = torch.tensor([0, 1, 0]).reshape(1, 3)\n", "print(x)\n", "print(y)\n", "# Gather along dim 0: for column j, pick the row given by y[0][j]\n", "print(torch.gather(x, 0, y))" ] }, { "cell_type": "markdown", 
"metadata": { "id": "hUmPV_O_cFz4" }, "source": [ "For a 3-D tensor the output is specified by::\n", "\n", " out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0\n", " out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1\n", " out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "PyTorch_Tutorial.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }