{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "5ea243e3-7f3f-4181-814b-b964b413d43d", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jovyan/work/d2l/notebooks/d2l.py:119: SyntaxWarning: assertion is always true, perhaps remove parentheses?\n", " assert(self, 'net'), 'Neural network is defined'\n", "/home/jovyan/work/d2l/notebooks/d2l.py:123: SyntaxWarning: assertion is always true, perhaps remove parentheses?\n", " assert(self, 'trainer'), 'trainer is not inited'\n" ] } ], "source": [ "import random\n", "import torch\n", "import d2l\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from torch import autograd\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "id": "4afb1d6a-795b-40cd-9649-f8dfd1f94fcb", "metadata": {}, "source": [ "# 3.4.6. Exercises" ] }, { "cell_type": "markdown", "id": "1e8108cf-40cf-46df-92b1-b9dec8b024eb", "metadata": {}, "source": [ "## 1. What would happen if we were to initialize the weights to zero? Would the algorithm still work? What if we initialized the parameters with variance 1000 rather than 0.1?"
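,
"\n",
"\n",
"A minimal sketch, independent of the notebook's `d2l` module, of what the experiments below probe. It assumes plain full-batch gradient descent on the squared loss and reuses the true parameters of the synthetic data (w = [2, -3.4], b = 4.2); `fit_linear` is a hypothetical helper. For linear regression, the gradient of the loss with respect to w is the average of residual times input, which is nonzero whenever the residuals are correlated with the inputs, so an all-zero start still trains: there are no hidden units whose symmetry would need breaking. A very large initialization only moves the starting point far from the optimum. The loss is still a convex quadratic, so gradient descent with a small enough learning rate still converges, but it needs many more steps to get close.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"def fit_linear(init_std, steps=200, lr=0.03, seed=0):\n",
"    torch.manual_seed(seed)\n",
"    # Synthetic data y = Xw + b + noise with the notebook's true parameters.\n",
"    true_w, true_b = torch.tensor([2.0, -3.4]), 4.2\n",
"    X = torch.randn(1000, 2)\n",
"    y = X @ true_w + true_b + 0.01 * torch.randn(1000)\n",
"    # init_std = 0 gives an all-zero start; a large init_std starts far away.\n",
"    w = (init_std * torch.randn(2)).requires_grad_(True)\n",
"    b = torch.zeros(1, requires_grad=True)\n",
"    for _ in range(steps):\n",
"        loss = ((X @ w + b - y) ** 2 / 2).mean()\n",
"        loss.backward()\n",
"        with torch.no_grad():\n",
"            w -= lr * w.grad\n",
"            b -= lr * b.grad\n",
"            w.grad.zero_()\n",
"            b.grad.zero_()\n",
"    # Largest absolute error among the learned weights.\n",
"    return (w.detach() - true_w).abs().max().item()\n",
"\n",
"for std in [0.0, 0.01, 1000.0]:\n",
"    print(std, fit_linear(std))\n",
"```"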
] }, { "cell_type": "markdown", "id": "d4c1ddac-e27e-496a-a57c-cbf1c9a57fac", "metadata": {}, "source": [ "### Original model (weights drawn from a normal distribution with mean 0 and standard deviation `sigma=0.01`)" ] }, { "cell_type": "code", "execution_count": 7, "id": "172a12b0-698d-4a33-811b-7f70a89912a2", "metadata": { "tags": [] }, "outputs": [], "source": [ "data = d2l.SyntheticRegressionData(w=torch.tensor([2, -3.4]), b=4.2)\n", "model = d2l.LinearRegressScratch(2, lr=0.03)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model, data)" ] }, { "cell_type": "markdown", "id": "24747c53-f141-4b29-8c38-97c772a9b23f", "metadata": {}, "source": [ "### Initialize the weights to zero (`sigma=0`)" ] }, { "cell_type": "code", "execution_count": 5, "id": "f4b5673b-4551-4eee-9833-bb9b3bdba64b", "metadata": { "tags": [] }, "outputs": [], "source": [ "model = d2l.LinearRegressScratch(2, lr=0.03, sigma=0)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model, data)" ] }, { "cell_type": "markdown", "id": "54a48cf7-8244-4783-9c0d-9c1ace542b19", "metadata": {}, "source": [ "### Initialize the weights with standard deviation `sigma=1000`" ] }, { "cell_type": "code", "execution_count": 6, "id": "5f19435d-cd4e-480d-a425-5806ae69375d", "metadata": { "tags": [] }, "outputs": [], "source": [ "model = d2l.LinearRegressScratch(2, lr=0.03, sigma=1000)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model, data)" ] }, { "cell_type": "markdown", "id": "9bf31077-3511-464d-acc8-4a8478a24599", "metadata": {}, "source": [ "## 2. Assume that you are Georg Simon Ohm trying to come up with a model for resistors that relates voltage and current. Can you use automatic differentiation to learn the parameters of your model?" ] }, { "cell_type": "markdown", "id": "3e84b854-10dd-4eaa-a0eb-2a8e4a8b6948", "metadata": {}, "source": [ "Ohm's law states that the current through a resistor is directly proportional to the voltage across it:\n", "$$\n", "V = I \\cdot R\n", "$$\n" ] }, { "cell_type": "markdown", "id": "c8283ff3-c4ab-429a-9395-15c893cc39d4", "metadata": {}, "source": [ "where:\n", "* $V$ is the voltage\n", "* $I$ is the current\n", "* $R$ is the resistance" ] }, { "cell_type": "markdown", "id": "93074a64-9d2f-44d5-b1a3-91ad09a80cd9", "metadata": { "tags": [] }, "source": [ "We therefore generate synthetic currents $I$ and voltages $V$ from a known resistance $R$ and fit a linear regression model whose weight plays the role of $R$. The squared error between predicted and measured voltages serves as the loss. Each step of minibatch stochastic gradient descent (SGD) runs a forward pass, uses automatic differentiation to compute the gradients in the backward pass, and updates the parameters. As training progresses, the learned weight approaches $R$, so the predicted voltages approach the measured ones. Automatic differentiation is therefore all we need to learn the parameter of Ohm's model from data."
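,
"\n",
"\n",
"A minimal sketch of the same idea without the `d2l` helpers, assuming currents drawn from a standard normal distribution, a true resistance of 5.1 (as in the cell below), measurement noise with standard deviation 1, and a model with no bias term; `fit_resistance` is a hypothetical helper. Autograd supplies the derivative of the loss with respect to R and minibatch SGD updates R. With only 3 epochs at learning rate 0.03 the estimate is still noticeably below 5.1, which is in line with the error of about 0.3 reported further down; training for more epochs shrinks that gap.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"def fit_resistance(true_R=5.1, noise=1.0, n=1000, lr=0.03, epochs=3, batch_size=32, seed=0):\n",
"    torch.manual_seed(seed)\n",
"    # Hypothetical measurements: currents I and noisy voltages V = I * R + noise.\n",
"    I = torch.randn(n)\n",
"    V = true_R * I + noise * torch.randn(n)\n",
"    # The only learnable parameter is the resistance R, started at zero.\n",
"    R = torch.zeros(1, requires_grad=True)\n",
"    for _ in range(epochs):\n",
"        for idx in torch.randperm(n).split(batch_size):\n",
"            loss = ((I[idx] * R - V[idx]) ** 2 / 2).mean()\n",
"            loss.backward()  # autograd fills R.grad with dloss/dR\n",
"            with torch.no_grad():\n",
"                R -= lr * R.grad\n",
"                R.grad.zero_()\n",
"    return R.item()\n",
"\n",
"print(fit_resistance())            # 3 epochs, still short of 5.1\n",
"print(fit_resistance(epochs=20))   # more epochs, closer to 5.1\n",
"```"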
] }, { "cell_type": "code", "execution_count": 2, "id": "cd5b504f-3fb2-4c76-bed2-28ebaf481e7f", "metadata": { "tags": [] }, "outputs": [], "source": [ "data = d2l.SyntheticRegressionData(w=torch.tensor([5.1]), b=0, noise=1)\n", "model = d2l.LinearRegressScratch(1, lr=0.03)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model,data)" ] }, { "cell_type": "code", "execution_count": 3, "id": "fb1de441-01d6-40b7-a954-a29f1ea8232c", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUoAAADsCAYAAAASG+9CAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAauElEQVR4nO3dfXCU1b0H8O9uSDYvJguBkE1KCkvAahpJTDQxogzQIChFKB1Ga7HAdVARvMU4VvCikVEa30YZhYLObaGtQ3Vu61u9mlsaC9xqILdEC4FieQlgyW54CcnGQBLY3fsH7Eo2+7777HnO83w/M8yYzYb8jPjlPOf8zjkGt9vtBhERBWQUXQARkdoxKImIQmBQEhGFwKAkIgqBQUlEFAKDkogoBAYlEVEIDEoiohCGiC4gVi6XC21tbcjMzITBYBBdDhFJwu12o7u7G/n5+TAag48ZpQ/KtrY2FBQUiC6DiCT11VdfYdSoUUHfI31QZmZmArj0L5uVlSW4GiKShcPhQEFBgTdDgpE+KD2P21lZWQxKIopYOFN2XMwhIgpB+hElEemX0+VGU2sHTnb3YmRmKiqs2Ugyxn9Rl0FJRFKqb7Fh9R/3w9bV630tz5yK2llFmFGcF9fvxUdvIpJOfYsNS95sHhCSAGDr6sWDbzbjoz22uH4/BiURScXpcmP1H/cj2Injy37XjI/2tMXtezIoiUgqTa0dg0aSvlxu4KEtn6O+JT4jS85REpGq+S7Y/Glf+OG3+o/7Ma3IEvMCD4OSiFTlymA8evocftd0HHbHNyPISHYq27p60dTagarC4THVxKAkItXwt5LtK9LrEE92B39MDweDkogUE0mfo2clO97Xwo7MTI3592BQEpEigvU5TiuyDAjQ8tHDQq5kR8oAwGK+FM6xYlASUdwFGh16+hyHpiej89wF7+vZGcno6LmAeKudVRSXnToMSiKKq3D6HK8MSQBxD8l479BhUBJRXIXT56iE4RkpmF2aj2lFlrjv+WZQElFcRdLnGA9D05Kx/sdluGnscEUOxAC4M4eI4qi+xYZNnx1L6PfsPH8BRoNBsZAEGJREFCeeuUkR4tErGQyDkojiIh5zk9kZyVF9XTx6JYPhHCURRe3KhvKD7V/H/Pv9sOxb+HCPHfau3rB6KuPZKxkMg5KIggq0uyac7YaR+s//PYr7J1nxxo5WGICgYemZkYxXr2QwDEoinQu2zfCjPTaser8FHT393vfnmVNxZ0ke3tjRGvfthgDwwd9tWH9PGZ7574EhbDRcOj7Nw6LQaeb+MCiJdCzYNsPPj5/F6ztaB32NravX7+vx4L78+w/LSMFfH586aJvj7mNnFb8fxx8GJZFOBdpmaL+8zVCkk929SDIaBh2PFutxadHiqjeRDgXbZqjE43SklF7FjhRHlEQ6JGqbYSiJWsWOFEeURDqkdIN2NBK5ih0pjiiJdEhtj7ZAYlexI8WgJNKhCms28sypYTd2K2nK
d3Jw/6TChK5iR4qP3kQ6lGQ0oHZWkegyAAD3TypEVaFyJ//Eg6JBuWPHDsyaNQv5+fkwGAx47733Bnze7XbjqaeeQl5eHtLS0lBdXY2DBw8qWRIRXTatyIK5Zd8SXYYq50t9KRqUPT09KCkpwfr16/1+/oUXXsCrr76KjRs3YteuXcjIyMD06dPR26v+HxyRTJwuNxoPn8H7X5xA4+Ez+GiPDbc8/wn+0HxCdGmqnC/1pegc5e23347bb7/d7+fcbjfWrl2LVatWYfbs2QCA3/zmN8jNzcV7772Hu+++W8nSiHTjoz1tl7chxv9OmliotRXIH2FzlK2trbDb7aiurva+ZjabUVlZicbGxoBf19fXB4fDMeAXEflX99F+PLTlc9WFpIcaW4H8EbbqbbfbAQC5ubkDXs/NzfV+zp+6ujqsXr1a0dqIZHXlARetp3oU25Mdq3hf/qU06dqDVq5ciZqaGu/HDocDBQUFAisiUgcljj2Lt9KCLDw+o0jVrUD+CAtKi8UCAGhvb0de3jd/q7S3t6O0tDTg15lMJphMJqXLI5JKoAMu1Oax264VdrBFLITNUVqtVlgsFjQ0NHhfczgc2LVrF6qqqkSVRSSdcO7RVovTPX2iS4iKoiPKr7/+GocOHfJ+3Nraii+++ALZ2dn49re/jeXLl+PZZ5/F+PHjYbVa8eSTTyI/Px9z5sxRsiwiTdl55Izwx+2xI9Jx5PS5kO+ToRXIH0WD8m9/+xumTJni/dgzt7hgwQJs3rwZP/vZz9DT04P7778fnZ2duOWWW1BfX4/UVDl/mESJ5HS5se6TQ3h9+2GhdSQnGbD6zmI89vu/o93R53dkK1MrkD8Gt9stw4g9IIfDAbPZjK6uLmRlZYkuh0hxlwLyIF7fcQTn+p2iy/Eamp6MznMXBt1141my2TC/TFWr3JFkh3Sr3kR6Vt9iw4p39qLznPr6Irsu12S+HJgeaj4VKFwMSiJJ1LfYhF/REIwbl0aPaclJWH9fGU739CX8bhulMCiJVCDYTYiez694Z6/ACsPjuRzMaDRgdqn4AzfihUFJJFigmxCfnFmEYRkpONndi5OOPlU+bgciw4lAkWBQEgkUqFHc1tWLh7ao9zE7FFnbgAJhUBIJIlOjeLhkbwMKhCecEwmi1psQQ8lISfL7upovB4sVR5REgsg0j5eabETd3AmwZF0aLW7dbx80r6qFNqBAGJREgoy4Sp7DXV6eV4I7JuR7P55RnIdpRZagK/VawqAkEqC+xYanP9gnuoywPDDJOiAkPZKMBilPAooGg5JIQf76I7fut0txJNrwjBQ8M7sYd0zQ3qN0pBiURArx1x9pTkvG+f6Lqg7JoWnJWP/jMtw0Vt1XyCYSg5JIAYH6I7vOq79pvPP8BRgNBobkFdgeRBRn/RddeOLdFlWPGkORaUU+ERiURHFU32LDTXV/RkdPv+hSYqK1nTWx4qM3UZzIcm9NMFrdWRMrjiiJ4kAL2xG1vLMmVhxREkXBt+3H5XZLuR3xSlreWRMrBiVRhPy1/QyR8Nls8a1WTL0mVxc7a2LFoCSKQKB5yIsuIeVE7eV5JZhbPkp0GdKQ8O9BIjG0MA/pkTc0TXQJUuGIkuiyUNcxyHosmq88rmpHjEFJBP/zjkPTkrFoohXLpo5DktGgiSZsA7iqHQ0+epPueeYdfUeLnecv4JU//xPlz25FfYtNqibsFCOQ6rPClGdOVd3d2rLgiJJ0LZx5x85zF/Dgm834xT1lGJaejLMSXPL1/LxS3FmSr5vzIpXGoCTdcrrc2Pxpa9jzjo/9/u8YNSxViqC0ZKXq6rxIpTEoSZf8zUmG0tPvxJftPQpWFR9crIk/BiXpjhb2ZAfDxZr442IO6YqWeiF9DUtPxkYu1iiCI0rSFa30QgLAsPQhqBo7HGNzMlFVOJwnkiuIQUm6ooVeSAB4pPpqb38nKY+P3qQrR0+rfzEmFAOAt/7vuOgydIVBSbpR32LDK38+KLqM
mLkB2Lp60dTaIboU3WBQki44XW6seGev6DLiSivTCDLgHCVpku8BFxedLnRK0CgeCZm2VMqOQUma46+Z3Hffs8x4r03iMShJUwI1k/fKdrJuALzXRgwGJWmGlpvJPXivjRgMStIMLTWT+7pv4hhUF1l4ApAgDEqSmtPlxs4jZ9B4+Az+2e4QXU7c5XEEqQoMSpJWfYsNK97Zq7nV7CdnXosRmSaeIakiqlgKXL9+PcaMGYPU1FRUVlaiqalJdEmkcvUtNjz4ZrOmQtKASyPIhROtmF36LVQVcu+2WggPyrfffhs1NTWora1Fc3MzSkpKMH36dJw8eVJ0aaRSTpcbT3+wX3QZccXVbHUTHpQvv/wyFi9ejEWLFqGoqAgbN25Eeno6fvWrX4kujVRq5+EzsDu0tWhj4X02qiZ0jrK/vx+7d+/GypUrva8ZjUZUV1ejsbHR79f09fWhr6/P+7HDob0JfBqs/6ILv208ih0HT+GzQ6dFlxMXwzNSsGrmtbCY0zgXqXJCg/L06dNwOp3Izc0d8Hpubi4OHDjg92vq6uqwevXqRJRHgnm2Ib6x4zC2fXlKU/2RBgBrflDMEaQkpFv1XrlyJWpqarwfOxwOFBQUCKyIlBDNnTayYMuPfIQG5YgRI5CUlIT29vYBr7e3t8Nisfj9GpPJBJPJlIjySBAt32nz5MxrsXCilY/ZkhG6mJOSkoLy8nI0NDR4X3O5XGhoaEBVVZXAykgULW9D9LT+MCTlI/zRu6amBgsWLMANN9yAiooKrF27Fj09PVi0aJHo0iiBPPOR2/95UvrHbQMwIOjZ+iM/4UF511134dSpU3jqqadgt9tRWlqK+vr6QQs8pF1amY9cfKsV5aOHDfp34UEW8jO43W6pn3IcDgfMZjO6urqQlZUluhzy4XuArm8bjBbmIzNMSXjxhxNwx4R8AKH/nUkdIskO4SNK0i5/I8XsjBQ8O7sYd0zI8+6wkTkkAeCN+Tdg4vgR3o+TjAZUFQ4XWBHFG4OSFBFopNjR04+HtjTjgX9ZkZ4yROodNp6Txm9iKGoeg5LiLpyV69d3tCasHiVwgUZfGJQUd1o+QNeDCzT6wqCkuNPyNar/NnEMpvGkcd1hUFLMfFd5/7TPLrokRRgAfNxix3/M5OO23jAoKSZa6YEMhxuArasXTa0dXNXWGQYlRU0LPZDR0PLUAvkn/OBekpMW92RnpYY3bhiZmapwJaQ2DEqKitZWtg0Anps7AXnmVASaffTcaVNhzU5gZaQGfPSmsPgu2Ni7zosuKW6GpSejbu51mFGcB6MRWPJmMw+2oAEYlBSU0+XGuk8OYtOnR9F5/psbD7MzUgRWFT+zJliw9u4yb/jNKM7DhvllPNiCBmBQUkD1LTY8/oc96Dp/cdDnOnr6BVQUX6/96HrMKskf9PqM4jxMK7LwYAvyYlCSX557s7XqF/dc7z3txx8ebEFXYlDSIE6XGyve2Su6DEXwvhqKBoOSBtl5+Aw6z10I/UYJDEtPxsKbx2DMiAw+QlPUGJQ0yG93HhVdQlwMz0hB48rvIWUIu+AoNgxKGuCjPTbU72sP/UYV84wX1/ygmCFJccGgJK/6Fhse2iL/Ag5beSjeGJQEQBsLOPdNHINqHoFGCmBQEgBg3ScHpV3AuXJnDZESGJQEp8uNTZ8eFV1GVP596jj8tPpqjiBJUZzppkvtQOflHE1WFY5gSJLiOKLUEX/3Tf9Pix3L3/5cdGlR49mQlAgMSp3wdxJ5ekoSzvU7BVYVO54NSYnAoNSBQCeRyxySnju1eTYkJQLnKDVOiyeR82xISjSOKDVOayeRA2wop8RjUGqcFhY7hqYNwaKJVh5sQcIwKDVO5sWO24tzMb9yDG4qHM5gJKEYlBrmdLmx68iZQfe/qJ3RAKz7UfCDdYkSiUGpEZ4eSbujFx1f9+FfZ8/hv3b/C1/3ybeyve5HZbhjAucfST0YlJLxbRovHz0MG7YdxqZPW6XdXePB08dJ
rRiUEvHXNC7bY7U/c0rycFfFaC7SkGoxKCURqGlc9pA0GoAX5pXygF1SNf7plIAWm8Y9Ft9qZUiS6vFPqARkbxpPGWKE7wO10QA8MMmKlXcUCamJKBJ89JaAjE3jV5mG4JnZ34XFnIYKazacLjd+23gUxzrOYXR2Ou6tGsORJEmDQSkBGZvGv+67CIs5DVWFwwEASUYD7rt1rOCqiKLDv9IlUGHNRp45ddDjq9rJOBIm8odBKYEkowG1sy7N5ckUljKOhIn8USwo16xZg5tvvhnp6ekYOnSo3/ccP34cM2fORHp6OkaOHInHHnsMFy9eVKokqc0ozsOG+WVS9BkacKl5nGdFklYoNkfZ39+PefPmoaqqCr/85S8Hfd7pdGLmzJmwWCz47LPPYLPZ8JOf/ATJycn4+c9/rlRZUvLsxunpvYiLLnU0CXka3X0b3nlWJGmRwe12K/p/3ubNm7F8+XJ0dnYOeP3jjz/G97//fbS1tSE3NxcAsHHjRjz++OM4deoUUlJSwvr9HQ4HzGYzurq6kJWVFe/yhfO3G0cNfnFPGYxGDKqN2xBJFpFkh7BV78bGRlx33XXekASA6dOnY8mSJdi3bx+uv/56v1/X19eHvr4+78cOh0PxWkUJtBtHJN8gnFZkGXRhGUeSpDXCgtJutw8ISQDej+12e8Cvq6urw+rVqxWtTQR/h12obTfOkzOvxcKJ1gFBmGQ0eFuAiLQqoqBcsWIFnn/++aDv+cc//oFrrrkmpqKCWblyJWpqarwfOxwOFBQUKPb9EsHf43V2Rgo6evoFVvUNz0VeviFJpBcRBeWjjz6KhQsXBn3P2LHhNRVbLBY0NTUNeK29vd37uUBMJhNMJlNY30MGgR6v1RKSHlycIT2LKChzcnKQk5MTl29cVVWFNWvW4OTJkxg5ciQAYOvWrcjKykJRkT72//ZfdOGJd/eq6vHalyXLhKfv/C4XZ0jXFJujPH78ODo6OnD8+HE4nU588cUXAIBx48bhqquuwm233YaioiLce++9eOGFF2C327Fq1SosXbpUUyPGQOpbbHji3RZ09Kj3sN1Hqq/GsqnjOJIk3VOsPWjhwoX49a9/Pej1v/zlL5g8eTIA4NixY1iyZAm2bduGjIwMLFiwAM899xyGDAk/v2VsD1LjaravR6rH46fVV4sug0gxkWSH4n2USpMtKJ0uN255/hPV9UVeKc+cir8+PpUjSdK0SLKDe70TyOlyY/OnrcJC0jTEgJfnleCR6vEABu8bN1z+xYUbooF4zJqCruyNPHq6B79rOg67oy/0FyrkocnjMLd8FADgO5bMQS1JFu6qIfKLQakQNW49HDMiw/vPM4rzuKuGKEwMSgWodbHG99gz7qohCg/nKONMrReB8dgzougxKONMjReBcYGGKDZ89I4Dp8uNnUfO4LPDp7HryBnR5QxwlWkIXpo3gQs0RDFgUMaovsWGFe/sRec5de6weWY2tx8SxYpBGYP6FhsefLNZdBlBWcxpoksgkh6DMkpOlxtPf7BfdBkBeY5G4wIOUey4mBOlptYO2B2JXbQZmp4MIPRNjLy3hii+GJRRSvSd1Q9NLsTuVdOwcX4ZLOaB/ZC+WWgxp2LD/DLOTRLFCR+9o5ToO6tvHZ+DJKPB746a8tHDsPvYWe6wIVIIgzJKFdZsWLJSE/L47dss7m9HDXfYECmHj95R8Bx2ccd1ga+siBc2ixOJxxFlEL43I1ZYs7F1v33QYRcGAxDLqZ4GXFqoMQ0xDjhdiHdkE6kDgzIAf6f/DE1P9ttYHmtIAkDd3Ot4mg+RSjEo/Qh0+o8Su2+yM1Kw5gfF3lEj5xqJ1IdzlD4SffrPqpnX8tGaSOUYlD4SffoPtxgSqR+D0kesjeSPVF/t3UETjAE8I5JIFgxKH9E2knuCb9nUcdi9ahoeqR6P9JSkgO8F2PZDJAsGpY/y0cMGbQkMxTf4kowG/LT6aux9ejoeqR6PoWkDR5jcYkgkF656+9h9
7CxcEa7kBLq90BOYy6aOZ9sPkcQYlD7CnaNcNqUQ43Mzwwo+XuJFJDcGpY9w5ygnjsth+BHpBOcofZzt6Q86R8nVaiL94YjyCvUtNizdEvo+bq5WE+kLR5SXhbMjx2gA1t9zPVeriXSGQXlZODtyXG5gWIYpQRURkVro6tHb37FpnkfocFe7E30FBBGJp5ug9Hds2pXnPYa72p3oKyCISDxdPHp7jk3zfbS2dfViyZvNqG+xocKajTxzasAbDrnaTaRfmg/KUIs0bgAr39kL4NJqNjD4OljuzSbSN80HZTiLNGfPXcC6Tw5iRnEeNvi5DpZ7s4n0TfNzlOEuvmz69CiWTR3v9zpY7s0m0jfNB2W4iy+d5y+gqbUDVYXDuTebiAbQ/KN3hTV70DFngbD1h4j80XxQJhkNWDTRGtZ72fpDRP5oPigBYNnUcUGvZ2DrDxEFo4ugTDIa8Nzc6/x+jq0/RBSKLoISAGYU52Hj/DLksfWHiCKk+VXvK7H1h4iioaugBHgtAxFFTvqgdLsvbU50OByCKyEimXgyw5MhwUgflN3d3QCAgoICwZUQkYy6u7thNpuDvsfgDidOVczlcqGtrQ2ZmZkwGPQ51+hwOFBQUICvvvoKWVlZosuRCn920ZP9Z+d2u9Hd3Y38/HwYjcHXtaUfURqNRowaNUp0GaqQlZUl5R9YNeDPLnoy/+xCjSQ9dNMeREQULQYlEVEIDEoNMJlMqK2thcnEi88ixZ9d9PT0s5N+MYeISGkcURIRhcCgJCIKgUFJRBQCg5KIKAQGpYYcPXoU9913H6xWK9LS0lBYWIja2lr09/eLLk2V1q9fjzFjxiA1NRWVlZVoamoSXZLq1dXV4cYbb0RmZiZGjhyJOXPm4MsvvxRdluIYlBpy4MABuFwuvP7669i3bx9eeeUVbNy4EU888YTo0lTn7bffRk1NDWpra9Hc3IySkhJMnz4dJ0+eFF2aqm3fvh1Lly7Fzp07sXXrVly4cAG33XYbenp6RJemKLYHadyLL76IDRs24MiRI6JLUZXKykrceOONWLduHYBLZwYUFBTg4YcfxooVKwRXJ49Tp05h5MiR2L59OyZNmiS6HMVwRKlxXV1dyM7mXUBX6u/vx+7du1FdXe19zWg0orq6Go2NjQIrk09XVxcAaP7PGINSww4dOoTXXnsNDzzwgOhSVOX06dNwOp3Izc0d8Hpubi7sdrugquTjcrmwfPlyTJw4EcXFxaLLURSDUgIrVqyAwWAI+uvAgQMDvubEiROYMWMG5s2bh8WLFwuqnLRs6dKlaGlpwVtvvSW6FMVJf8yaHjz66KNYuHBh0PeMHTvW+89tbW2YMmUKbr75ZrzxxhsKVyefESNGICkpCe3t7QNeb29vh8ViEVSVXJYtW4YPP/wQO3bs0MUxhwxKCeTk5CAnJyes9544cQJTpkxBeXk5Nm3aFPJAUj1KSUlBeXk5GhoaMGfOHACXHiMbGhqwbNkyscWpnNvtxsMPP4x3330X27Ztg9VqFV1SQjAoNeTEiROYPHkyRo8ejZdeegmnTp3yfo4jpYFqamqwYMEC3HDDDaioqMDatWvR09ODRYsWiS5N1ZYuXYotW7bg/fffR2ZmpndO12w2Iy0tTXB1ymF7kIZs3rw54P/o/M882Lp16/Diiy/CbrejtLQUr776KiorK0WXpWqBrlvZtGlTyOkhmTEoiYhC4AQWEVEIDEoiohAYlEREITAoiYhCYFASEYXAoCQiCoFBSUQUAoOSiCgEBiURUQgMSiKiEBiUREQhMCiJiEL4f5xtcYmci3NEAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(3.5, 2.5))\n", "plt.scatter(data.X, data.y)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 29, "id": "033e7a05-cb0a-4310-a5a5-d4d2dd519f67", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "error in estimating w: tensor([0.3010], grad_fn=<SubBackward0>)\n", "error in estimating b: tensor([-0.0321], grad_fn=<RsubBackward1>)\n" ] } ], "source": [ "print(f'error in estimating w: {data.w - model.w.reshape(data.w.shape)}')\n", "print(f'error in estimating b: {data.b - model.b}')" ] }, { "cell_type": "markdown", "id": "023514a2-5b7b-4c9e-afe1-e1283f6c356f", "metadata": {}, "source": [ "## 3. Can you use [Planck’s Law](https://en.wikipedia.org/wiki/Planck%27s_law) to determine the temperature of an object using spectral energy density? For reference, the spectral density of radiation emanating from a black body is\n", "$$\n", "B(\\lambda, T)=\\frac{2 h c^2}{\\lambda^5} \\cdot\\left(\\exp \\frac{h c}{\\lambda k T}-1\\right)^{-1}\n", "$$\n", "\n", "Here:\n", "* $\\lambda$ is the wavelength\n", "* $T$ is the temperature\n", "* $c$ is the speed of light\n", "* $h$ is the Planck constant\n", "* $k$ is the Boltzmann constant\n", "\n", "You measure the energy for different wavelengths $\\lambda$ and now need to fit the spectral density curve to Planck’s law."
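,
"\n",
"\n",
"A minimal sketch of fitting T in plain PyTorch, separate from the `SyntheticPlankData` and `PlankModel` classes below. It deliberately swaps in two choices that are assumptions, not part of the exercise: it compares log spectra instead of raw spectral densities, and it uses `torch.optim.Adam` instead of plain SGD. Raw densities from Planck's law are of order 1e13 in SI units at visible wavelengths for a 5000 K body, so a squared loss on raw values makes the SGD learning rate hard to choose. The true temperature (5000 K), the initial guess (3000 K), the wavelength grid (300 to 2000 nm) and the 1% multiplicative noise are hypothetical, as are the helpers `planck` and `fit_temperature`. Because B increases with T at every wavelength, the noise-free loss has a single minimum at the true temperature, so a simple gradient method is enough.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"C = 299792458.0      # speed of light (m/s)\n",
"H = 6.62607015e-34   # Planck constant (J s)\n",
"K = 1.380649e-23     # Boltzmann constant (J/K)\n",
"\n",
"def planck(lam, T):\n",
"    # Spectral radiance B(lambda, T); lam in metres, T in kelvin.\n",
"    return 2 * H * C ** 2 / lam ** 5 / torch.expm1(H * C / (lam * K * T))\n",
"\n",
"def fit_temperature(true_T=5000.0, T0=3000.0, steps=1000, lr=0.05, seed=0):\n",
"    torch.manual_seed(seed)\n",
"    lam = torch.linspace(300e-9, 2000e-9, 500, dtype=torch.float64)\n",
"    # Hypothetical measurements with 1% multiplicative noise.\n",
"    B = planck(lam, true_T) * (1 + 0.01 * torch.randn(500, dtype=torch.float64))\n",
"    log_B = torch.log(B)\n",
"    # Learn T in units of 1000 K so that the parameter is of order one.\n",
"    s = torch.tensor(T0 / 1000, dtype=torch.float64, requires_grad=True)\n",
"    opt = torch.optim.Adam([s], lr=lr)\n",
"    for _ in range(steps):\n",
"        opt.zero_grad()\n",
"        # Compare log spectra so every wavelength contributes on a similar scale.\n",
"        loss = ((torch.log(planck(lam, 1000 * s)) - log_B) ** 2).mean() / 2\n",
"        loss.backward()\n",
"        opt.step()\n",
"    return 1000 * s.item()\n",
"\n",
"print(fit_temperature())\n",
"```"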
] }, { "cell_type": "code", "execution_count": 5, "id": "53e49858-a571-4a7b-be25-cd9333e881e3", "metadata": { "tags": [] }, "outputs": [], "source": [ "def f(t, x):\n", "    # Planck's law: spectral radiance at wavelength x (m) and temperature t (K)\n", "    c = 299792458       # speed of light (m/s)\n", "    h = 6.62607015e-34  # Planck constant (J s)\n", "    k = 1.380649e-23    # Boltzmann constant (J/K)\n", "    beta = h*c/(k*x)\n", "    alpha = 2*h*c**2/x**5\n", "    return alpha/(torch.exp(beta/t)-1)\n", "\n", "class SyntheticPlankData(d2l.DataModule):\n", "    def __init__(self, T, noise=0.05, num_train=1000, num_val=1000,\n", "                 batch_size=32):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "        n = num_train + num_val\n", "        # n wavelengths, one per nanometre starting at 300 nm, converted to metres\n", "        self.X = torch.arange(300, 300 + n) * 1e-9\n", "        # independent multiplicative noise for every sample\n", "        self.y = f(T, self.X) * (1 + noise * torch.randn(n))\n", "\n", "    def get_tensorloader(self, tensor, train, indices=slice(0, None)):\n", "        tensor = tuple(a[indices] for a in tensor)\n", "        dataset = torch.utils.data.TensorDataset(*tensor)\n", "        return torch.utils.data.DataLoader(dataset, self.batch_size,\n", "                                           shuffle=train)\n", "\n", "    def get_dataloader(self, train):\n", "        i = slice(0, self.num_train) if train else slice(self.num_train, None)\n", "        return self.get_tensorloader((self.X, self.y), train, i)\n", "\n", "\n", "class PlankModel(d2l.Module):\n", "    def __init__(self, T, lr, sigma=0.01):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "        # the temperature is the only learnable parameter, initialized to T\n", "        self.T = torch.tensor([float(T)], requires_grad=True)\n", "\n", "    def forward(self, X):\n", "        return f(self.T, X)\n", "\n", "    def loss(self, y_hat, y):\n", "        # squared loss on raw spectral densities, which are very large in SI units,\n", "        # so the SGD learning rate has to be scaled accordingly\n", "        l = (y_hat - y)**2 / 2\n", "        return l.mean()\n", "\n", "    def configure_optimizers(self):\n", "        return d2l.SGD([self.T], self.lr)" ] }, { "cell_type": "code", "execution_count": 7, "id": "8cba8689-5309-4ac2-bd28-53c7c14e01b8", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAUsAAAENCAYAAACGtkfvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5m0lEQVR4nO3deVxU5f4H8M8w7OCggGxKgvvKoiahmVYUGtckc/1pqGmLaamUBuZFNK+ouV2vlEumlpn7UmqYUriBUiAJbrngDogajMM2MnN+f0yMDtucczhnFub7fr3Oq8uZ53nmO3Pxyznn2SQMwzAghBBSLytjB0AIIeaAkiUhhLBAyZIQQligZEkIISxQsiSEEBYoWRJCCAuULAkhhAVKloQQwgIlS0IIYYGSJSGEsNDokuWxY8cwaNAg+Pj4QCKRYO/evZzql5eXY9y4cejWrRusra0RGRlZo8yJEyfQp08fuLm5wcHBAR07dsTy5cuF+QCEEJNkbewAhFZSUoLAwEC8/fbbGDJkCOf6KpUKDg4O+Oijj7Br165ayzg5OWHKlCkICAiAk5MTTpw4gffeew9OTk549913G/oRCCEmSNKYF9KQSCTYs2ePztVhRUUFPvvsM/zwww8oKipC165dsWjRIvTv379G/XHjxqGoqIjV1emQIUPg5OSE7777TrgPQAgxGY3uNlyfKVOmIC0tDVu3bsXZs2cxbNgwDBgwAJcvX+bd5pkzZ5Camop+/foJGCkhxJQ0utvw+ty8eRMbNmzAzZs34ePjAwD45JNPkJSUhA0bNmDBggWc2mvZsiUKCwtRWVmJ+Ph4TJw4UYywCSEmwKKSZXZ2NlQqFdq3b69zvqKiAm5ubpzbO378OBQKBU6dOoWYmBi0bdsWo0aNEipcQogJsahkqVAoIJVKkZGRAalUqvOas7Mz5/b8/f0BAN26dUNBQQHi4+MpWRLSSFlUsgwODoZKpcK9e/fQt29fQdtWq9WoqKgQtE1CiOlodMlSoVDgypUr2p9zc3ORlZUFV1dXtG/fHqNHj0ZUVBSWLl2K4OBgFBYWIjk5GQEBAYiIiAAAnD9/HkqlEg8fPsSjR4+QlZUFAAgKCgIAJCYm4plnnkHHjh0BaMZ2LlmyBB999JFBPyshxICYRua3335jANQ4xo4dyzAMwyiVSiYuLo7x8/NjbGxsGG9vb+aNN95gzp49q22jVatWtbZRZeXKlUyXLl0YR0dHRiaTMcHBwcyXX37JqFQqQ39cQoiBNOpxloQQIhSLG2dJCCF8ULIkhBAWGkUHj1qtxt27d9GkSRNIJBJjh0MIMSMMw+DRo0fw8fGBlVXd14+NIlnevXsXvr6+xg6DEGLGbt26hZYtW9b5eqNIlk2aNAGg+bAymczI0RBCzIlcLoevr682j9SlUSTLqltvmUxGyZIQwou+R3iNIlk2KuUKYGsUcD35yTmpHdB1KBCxFLB1MF5shFgwSpamoFIJnFgKpCys/XVVBfDn95oDEiBwDBDxBSVOQgyIkqWxHZwJpK/hUIEB/vxOc3h2B947AlhJ9VcjhDQIJUtjWuALKOX86xdkAvNcgTfWAYHDhYuLmASGYVBZWQmVSmXsUMyaVCqFtbV1g4cVUrI0lngX4dra8w5weC7wyTnh2iRGpVQqkZeXh9LSUmOH0ig4OjrC29sbtra2vNugZGkMQibKKorbmnbji4VvmxiUWq1Gbm4upFIpfHx8YGtrS5MteGIYBkqlEoWFhcjNzUW7du3qHXheH0qWhiZGoqzePiVMs6ZUKqFWq+Hr6wtHR0djh2P2HBwcYGNjgxs3bkCpVMLe3p5XOzQ33JDiXQ30PiInZGIQfK+ASE1CfJf0/4ahJD4PwIAP6ilhEiIoSpaGcHYnUJht+PelhEmIYChZik2tAnZPMN77U8IkRBCULMW2PYp/3S5DgWf6NzyGeO7b/BLzp1IzSLv6APuy7iDt6gOo1IbdFGHhwoWQSCSYNm2a9lx5eTk
mT54MNzc3ODs7480330RBQYFOvZs3byIiIgKOjo7w8PDAjBkzUFlZqVMmJSUF3bt3h52dHdq2bYuNGzeK/nmoN1xMlUrg4n7u9SK/AoL+T/dcuQJY2IJvIMCiNsCnV3nWJ+YmKScPc386j7zicu05bxd7zBnUGQO6eov+/r///jvWrFmDgIAAnfPTp0/HgQMHsGPHDri4uGDKlCkYMmQITp48CQBQqVSIiIiAl5cXUlNTkZeXh6ioKNjY2GDBggUANJsQRkRE4P3338f333+P5ORkTJw4Ed7e3ggPDxftM9GVpZi+6sO9zpvrayZKALB31gwJkj3DL5ay+8BmmuVjCZJy8jBpc6ZOogSA/OJyTNqciaScPFHfX6FQYPTo0Vi3bh2aNWumPV9cXIz169dj2bJleOmll9CjRw9s2LABqampOHXqFADgl19+wfnz57F582YEBQVh4MCB+Pzzz5GYmAilUgkAWL16Nfz9/bF06VJ06tQJU6ZMwdChQ7F8+XJRPxfnZHns2DEMGjQIPj4+kEgk2Lt3b73ld+/ejVdeeQXNmzeHTCZDaGgoDh06pFMmPj4eEolE56jaZtZsKcuAB39xq9M2HOg2tP4y0dlA21f5xXTlEJCzm19dYhZUagZzfzqP2m64q87N/em8qLfkkydPRkREBMLCwnTOZ2Rk4PHjxzrnO3bsiGeeeQZpaWkAgLS0NHTr1g2enp7aMuHh4ZDL5Th37py2TPW2w8PDtW2IhXOyLCkpQWBgIBITE1mVP3bsGF555RUcPHgQGRkZePHFFzFo0CCcOXNGp1yXLl2Ql5enPU6cOME1NNOy6llu5a0dgDHb2ZUdswN4bhL3mABg53hNpxNplNJzH9a4onwaAyCvuBzpuQ9Fef+tW7ciMzMTCQkJNV7Lz8+Hra0tmjZtqnPe09MT+fn52jJPJ8qq16teq6+MXC5HWVmZUB+lBs7PLAcOHIiBAweyLr9ixQqdnxcsWIB9+/bhp59+QnBw8JNArK3h5eXFqs2KigpUVFRof5bLG7AYhRiUZYD8Frc6MTe5lR+wEIAVcIrdHy0dC/2BWRzfj5iFe4/qTpR8ynFx69YtTJ06FYcPH+Y9S8aUGfyZpVqtxqNHj+Dqqjub5fLly/Dx8UHr1q0xevRo3LxZ9z/mhIQEuLi4aA+T23/n65e4le84CLDmMcF/wAIghMcVprIYOBjDvR4xeR5N2CUptuW4yMjIwL1799C9e3dYW1vD2toaR48excqVK2FtbQ1PT08olUoUFRXp1CsoKNBeKHl5edXoHa/6WV8ZmUwGBwfx1ng1eLJcsmQJFAoFhg9/0tkQEhKCjRs3IikpCV999RVyc3PRt29fPHr0qNY2YmNjUVxcrD1u3eJ4FSemSiVw7zy3OsM38X+/gQuBNjyeYaZ/pYmVNCq9/F3h7WKPupbdkEDTK97LX/ipty+//DKys7ORlZWlPXr27InRo0dr/7eNjQ2Sk5/sAnDp0iXcvHkToaGhAIDQ0FBkZ2fj3r172jKHDx+GTCZD586dtWWebqOqTFUbYjHo0KEtW7Zg7ty52LdvHzw8PLTnn76tDwgIQEhICFq1aoXt27djwoSaA7rt7OxgZ2dnkJg52zeFW/nnP2n44r1v7QC+6AKU3OZW7z9ewBxxnl0R45BaSTBnUGdM2pwJCaDT0VOVQOcM6gyplfCrGDVp0gRdu3bVOefk5AQ3Nzft+QkTJiA6Ohqurq6QyWT48MMPERoaiueeew4A8Oqrr6Jz58546623sHjxYuTn52P27NmYPHmy9t/8+++/j1WrVmHmzJl4++238euvv2L79u04cOCA4J/paQa7sty6dSsmTpyI7du31+jJqq5p06Zo3749rly5YqDoBKJWAdnbuNV5aZYw7z3jHDj/38mogNV9hXl/YjIGdPXGV2O6w8tF91bby8UeX43pbpBxlnVZvnw5/vWvf+HNN9/ECy+8AC8vL+ze/WSEhlQqxf79+yGVShE
aGooxY8YgKioK8+bN05bx9/fHgQMHcPjwYQQGBmLp0qX4+uuvRR1jCRjoyvKHH37A22+/ja1btyIiIkJveYVCgatXr+Ktt94yQHQCuvobt/JCXFU+Le6+ZuV0LvLPauauB+gZskTMyoCu3nilsxfScx/i3qNyeDTR3HqLcUVZn5SUFJ2f7e3tkZiYWO9omlatWuHgwYP1ttu/f/8aI2rExvnKUqFQaJ9HAJrR9FlZWdoOmdjYWERFPZnit2XLFkRFRWHp0qUICQlBfn4+8vPzUVz8ZM3FTz75BEePHsX169eRmpqKN954A1KpFKNGjWrgxzOwnz7iVl6oq8oqVlJgyHru9XZPoOFEjZDUSoLQNm4YHNQCoW3cDJ4oGxvOyfKPP/5AcHCwdthPdHQ0goODERcXBwDIy8vT6cleu3YtKisrMXnyZHh7e2uPqVOnasvcvn0bo0aNQocOHTB8+HC4ubnh1KlTaN68eUM/n+FUKgH5Hfblhb6qrBIwFGjejXu9xf7Cx0JIIyJhGMaws+tFIJfL4eLiguLiYshkMuMEcWIFcGQO+/JxD8XdlTG+KVDrPI569JoEvFbHdrzEYMrLy5Gbmwt/f/9GOV7RGOr7TtnmD5obLpSji9mX9e8n/va1s+/pL1MdDScipE6ULIWgLAMel7AvP4pjjzkf1rZAr3e511vSTvhYCGkEKFkKIYlDR42NE2Ar3iwDHa99Adg04VanvAj4k+UcdUIsCCVLIWRvZV+23wzx4qhN7A3udfa8Q73jhFRDybKhlGXA41L25Z+bLF4steE7nGjdK8LHQogZo2TZUFxuwZ08+S2Y0VABQ4GmrbnVycvQ/CEghACgZNlw53axL9vbwFeVT/voD+51FprYak6EGBEly4aoVAIVxfrLVeGznJpQ+NyOqx8DB2eKEw8Rn1oF5B4Hsndq/ivyc2h9Ox6Y82ZlACXLhjm1mn1ZY92CPy1gKODCcaZO+hoae2mOzv8IrOgKbPoXsGuC5r8rumrOi6i+HQ+mT5+On376CTt27MDRo0dx9+5dDBkyRPt61WZlSqUSqamp2LRpEzZu3KidHQg82azsxRdfRFZWFqZNm4aJEyfW2KpGDJQsGyLzW/ZljXkL/rSpGdzr0NhL83L+R80WzPK7uufleZrzIibMqh0Pqg53d3cA5r9ZGUDJkj+1Cnh4mX15Y96CP81KCryxjlsdGntpPtQqIOlT1D7V9Z9zSTGi3ZLXteOBuW9WBlCy5O/aMfZl7ZsZ/xb8aYHDARuOc+hp7KV5uJFa84pSB6NZ8OVGquBvXd+OB+a+WRlg4JXSG5Ws79mXbfuyeHHw9elVYD7HVZ2WdQI+4bi9LzEsRYH+MlzKcVDfjgdi7o1jKHRlydeVI+zLBv6feHHwZW0LPDuRWx1FAd2OmzpnT/1luJRrgKd3PPDy8jLrzcoASpb8VCqB8r9ZFpYAbfqLGQ1/EUsBcFz9iG7HTVur3oDMB6hvyzJZC005kVXteODt7Y0ePXqY9WZlACVLftLXsi/r2kb85dgaYhaHBYurrOexmyQxDCspMGDRPz9UT5j//DxgoSi/k/XteODi4qLdrOy3335DRkYGxo8fX+dmZX/++ScOHTpU62Zl165dw8yZM3Hx4kV8+eWX2L59O6ZPny7456mOkiUf10+yL9sjSn8ZY7J1AHx6cqtz5w+aCmnKOr8ODP8WkFXbmEzmoznf+XVR3lbfjgfmvFkZQCul87O4DVB6n13Z2YWm1RNeG7WK+0Znts78rkqJXoKtlK5WaXq9FQWaZ5Stepv2XY6IaKV0Y6hUsk+UpjZkqC58pkIqFdTZY+qspIB/X6DbUM1/LTRRCoWSJVdcpji6txcvDqEFDAWcOO4nTZ09xIJQsuTq0gH2ZTvp3yPdpEw/y70OrXtJLAQlS64KL7EvaypTHNnis28PrXtJLATnZHns2DEMGjQIPj4+kEgk2Lt3r946bJZUSkxMhJ+fH+zt7RESEoL09HS
uoYmPy/hKx+bm8byyute+AKw4xk3rXoqiEfS9mgwhvkvOybKkpASBgYFITExkVZ7Nkkrbtm1DdHQ05syZg8zMTAQGBiI8PFxncKpJ4DK+suWz4sUhtpib3MrTupeCsrGxAQCUlnLYroTUq+q7rPpu+WjQ0CGJRII9e/YgMjKyzjKffvopDhw4gJycHO25kSNHoqioCElJSQA0E/CfffZZrFq1CgCgVqvh6+uLDz/8EDExMXrjMNjQobUvAncz2ZV9dT7Q+0PxYhHbmpc0t9hcmMMwKTORl5eHoqIieHh4wNHRERJJXTNySH0YhkFpaSnu3buHpk2bwtu7Zicm2/wh+kIadS2pNG3aNACAUqlERkYGYmNjta9bWVkhLCyszmWXKioqUFFRof1ZLpcLH3h1ahVwN4t9+V7viRaKQbxzmPvYy/8GAB9fFCceC1M1F9rk7q7MVNOmTbXfKV+iJ0t9Syr9/fffUKlUtZa5eLH2f3gJCQmYO3euaDHX6voJAGp2ZZ29zP8Kq2rdyz3vsK/zKA84u1MzDIk0iEQigbe3Nzw8PPD48WNjh2PWbGxsIJU2fIypWS7RFhsbi+joaO3Pcrkcvr4idzJcTWFf1r+vaGEYVOBwYP/HwGMOV+67JwBd36AB0AKRSqWC/EMnDSf60CF9Syq5u7tDKpXWWqauy2Y7OzvIZDKdQ3Rsn1UCQMAo8eIwtE+vcq+zrJPwcRBiZKInS31LKtna2qJHjx46ZdRqNZKTkw2y7BJrBRfYlZNYme6SbHxY2wId/8WtDq17SRohzslSoVAgKysLWVlZADRDg7KysrR7bcTGxiIq6slKO2yWVIqOjsa6deuwadMmXLhwAZMmTUJJSQnGjx/fwI8nkEolUMpyZWkX38Z3Czqcw8ZsVWgqJGlkOCfLP/74A8HBwQgODgagSXTBwcHa7Srz8vK0iRNgt6TSiBEjsGTJEsTFxSEoKAhZWVlISkqq0eljNFzGV3p0ES8OY+GzyRlA616SRoWWaGNjyyjgr4Psypr7+Mr6LGoLlBVyqzMrX7NmJiEmipZoE1LBOfZlzX18ZX1mcJgXX2Wxn+BhEGIMlCz1UauA4hvsyjp7mv/4yvrwWfeyshw4qH8WFiGmjpKlPtdPsC/rbgFDZvise5n+laaTjBAzRslSHy6D0VsGixaGSeGz7uWSdsLHQYgBUbLUh8tgdP/+YkVhWvise1leRGMviVmjZKlP8W125aysG880Rzb4rHtJYy+JGaNkWR+1Cnh4jV1Zt/aNbzC6PlzXvQSAlRy33SXERFCyrA+XlYY8Oosaiknis+d40TXNykSEmBlKlvXh0rnj2kq0MEzaxF+419k9gW7HidmhZFkfLp07fi+IF4cp4zP2EgAW+gsfCyEiomRZH+rcYSdgKODCMfkpi2mwOjErlCzrolYBD1mu5ShrYXmdO9VN5bhfD0CD1YlZoWRZl+snALBcY6QxrjTEFd+ViRZZ6LNeYnYoWdaFS+eOX2/RwjArgcMB55bc6jwupdtxYhYoWdaFS+dOY15piKtPOKzQVIVux4kZoGRZF7adO41hJ0ehzcrnXiehhfBxECIgSpa14dK542wiq7mbElsHoPUr3OqolMDBmeLEQ4gAKFnWhkvnjoyuiGoVtROAhFud9DV0O05MFiXL2lDnjjBm5XGv85/atz8mxNgoWdaGOneEwed2nFEBqy14gD8xWZQsa0OdO8Lhczuef5YW2yAmh5JldVyWZXPvIG4sjcXse9zr0GIbxMTwSpaJiYnw8/ODvb09QkJCkJ6eXmfZ/v37QyKR1DgiIiK0ZcaNG1fj9QEDBvAJreG4LMvm1FzUUBoNa1vg2Ync683zED4WQnjinCy3bduG6OhozJkzB5mZmQgMDER4eDju3av96mH37t3Iy8vTHjk5OZBKpRg2bJhOuQEDBuiU++GHH/h9oobKPcq+rKUuy8ZHxFIA1hwrVQIrAsSIhhDOOCfLZcuW4Z133sH48eP
RuXNnrF69Go6Ojvjmm29qLe/q6govLy/tcfjwYTg6OtZIlnZ2djrlmjVrxu8TNdTfHFb/ttRl2fiazaN3vOgG8HOs8LEQwhGnZKlUKpGRkYGwsLAnDVhZISwsDGlpaazaWL9+PUaOHAknJyed8ykpKfDw8ECHDh0wadIkPHjwoM42KioqIJfLdQ7BKFg+X7P0Zdn44LPRGQCc/pLGXxKj45Qs79+/D5VKBU9P3Vkrnp6eyM/XP8UtPT0dOTk5mDhR9/nVgAED8O233yI5ORmLFi3C0aNHMXDgQKhUtT/gT0hIgIuLi/bw9fXl8jHqV3iRXblmfrQsGx+vfQHY8rhrmE8zpYhxGbQ3fP369ejWrRt69eqlc37kyJF4/fXX0a1bN0RGRmL//v34/fffkZKSUms7sbGxKC4u1h63bt0SJsBKJVBSwK5sE46r65AnZl3nUUlNzy+JUXFKlu7u7pBKpSgo0E0oBQUF8PKqf+ZFSUkJtm7digkTJuh9n9atW8Pd3R1Xrlyp9XU7OzvIZDKdQxDpa9mXbRkszHtaqtmF3OvQ80tiRJySpa2tLXr06IHk5GTtObVajeTkZISGhtZbd8eOHaioqMCYMWP0vs/t27fx4MEDeHt7cwmv4W6cZF/Wv79YUVgGen5JzAzn2/Do6GisW7cOmzZtwoULFzBp0iSUlJRg/PjxAICoqCjExtb8679+/XpERkbCzc1N57xCocCMGTNw6tQpXL9+HcnJyRg8eDDatm2L8PBwnh+LJ2Upu3LUuSMM3s8vaXwrMTyuA98wYsQIFBYWIi4uDvn5+QgKCkJSUpK20+fmzZuwstLNwZcuXcKJEyfwyy81t02VSqU4e/YsNm3ahKKiIvj4+ODVV1/F559/Djs7O54fiydlGbtyPj2pc0cos64D8S7c681vAcy+I3g4hNRFwjAMy7XITJdcLoeLiwuKi4v5P79Uq4B57mA1e8e/PzB2H7/3ITVVKvldLbq0AqafFT4eYlHY5g+aG16FyzRHWyf9ZQh7fJ9fFt8Avh8ufDyE1IKSZRUua1i2qr8zi/DA9/nl5UNAzm7h4yGkGkqWVWgNS+PjNf4SwM7xtEIRER0lyyq0hqVp4DP+EgDmuekvQ0gDULIEuG1Q5t5R3FgsnbUtEPI+j4qMpoecEJFQsgSAa8fAeoMymrkjvoGLANkz3OtVKoDlNCWSiIOSJQD8uYV9Wf/+YkVBnhadDdg05V6PesiJSChZApo5x2zQzB3D+uwGAB6D/6mHnIiAkiUAlD1iV86tPc3cMbT4h/zqUQ85ERglS7UKuH+JXVmPTuLGQmrHu4fcVdg4iEWjZHn9BACWVyAS+rqMgncPOfjNOyekFvSvn8sGZc149NASYQxcBDThuSI+JUwiAEqWtEGZ+fg4B5A48KtLCZM0ECVL2qDMvMzRv9dTnShhkgagZClnuSYibVBmOuKLG1DXSFssE7Nn2cmSyzRH2qDMtMTxHFIENRDvLmgoxDJYdrKkaY7my0oKDP+OZ+XHlDAJZ5adLGmao3nr/DowdBPPypQwCTeWnSxpmqP56xoJvLmeZ+XHQDwt7UbYsexkSdMcG4duQ4F2A3hWrgTimwoZDWmkLDdZ0jTHxmX0NsAzkGdlhoYVEb0sN1nSNMfGZ9IxwKMb//qUMEk9LDcL0DTHxumDE/ynRQKUMEmdeCXLxMRE+Pn5wd7eHiEhIUhPT6+z7MaNGyGRSHQOe3t7nTIMwyAuLg7e3t5wcHBAWFgYLl++zCc09miaY+P1cQ5g14AVh+JdaHk3UgPnZLlt2zZER0djzpw5yMzMRGBgIMLDw3HvXt3TBmUyGfLy8rTHjRu6vdCLFy/GypUrsXr1apw+fRpOTk4IDw9HeXk590/EFk1zbNxicxuWMOe5Atm0gDB5gnOyXLZsGd555x2MHz8enTt3xurVq+Ho6IhvvvmmzjoSiQReXl7aw9PTU/sawzBYsWIFZs+ejcGDByMgIADffvst7t6
9i7179/L6UKzQNMfGLzYXsG/A0KBd44HNtEUF0eCULJVKJTIyMhAWFvakASsrhIWFIS0trc56CoUCrVq1gq+vLwYPHoxz585pX8vNzUV+fr5Omy4uLggJCamzzYqKCsjlcp2DE5rmaDlirvHb/KzKlUPAF12Ei4eYLU7J8v79+1CpVDpXhgDg6emJ/PzaV4Pp0KEDvvnmG+zbtw+bN2+GWq1G7969cfu2Zp/uqnpc2kxISICLi4v28PXl+ECfpjlaluhswLMBuz6W3AbiadV1Syd6b3hoaCiioqIQFBSEfv36Yffu3WjevDnWrFnDu83Y2FgUFxdrj1u3bnFrgKY5Wp5JxwEvvuMwAUBFHT8WjlOydHd3h1QqRUFBgc75goICeHl5sWrDxsYGwcHBuHLlCgBo63Fp087ODjKZTOfghKY5Wqb3jwHtwhvWBnX8WCxOydLW1hY9evRAcnKy9pxarUZycjJCQ0NZtaFSqZCdnQ1vb28AgL+/P7y8vHTalMvlOH36NOs2OStj+YyTpjk2PqO3A6GTG9bGrvHAd8OEiYeYDWuuFaKjozF27Fj07NkTvXr1wooVK1BSUoLx48cDAKKiotCiRQskJCQAAObNm4fnnnsObdu2RVFREb744gvcuHEDEydOBKDpKZ82bRrmz5+Pdu3awd/fH//+97/h4+ODyMhI4T5pFbUKuH+RXVmn5sK/v4GVKVX4bM8Z7D5TUOM1W6kErwf64PPIbnCwtaA/CuELgJa9gB1j+bdx9RdgfitgNsu7FGL2OCfLESNGoLCwEHFxccjPz0dQUBCSkpK0HTQ3b96EldWTC9a///4b77zzDvLz89GsWTP06NEDqamp6Ny5s7bMzJkzUVJSgnfffRdFRUV4/vnnkZSUVGPwuiCunwDrzh07Z+HfXyQPFUq88b/fcKO4knUdpYrBzsw72JlZcxiVBEC3FjJ8N+E5uDjaCBipiegSCXR62LDtciuLNM8x4x7SHYgFkDAMwzJzmC65XA4XFxcUFxfrf36ZPA84vpRdw6/OB3p/2PAARaBSMzjyZx6m7zyDUgP1OTR1sMbCIQF4pYsXpFYSw7ypIQgxxfGNdUAgjck0R2zzh+Uly50TgZwd7BqeXajZs9qElClVGLchDadzG7APjQCsADzfzh1fju4BZ3vONyimZ54HoK5oWBuOPsDMC8LEQwyGbf5oBL/lHLGd5ihraVKJskypwktLfkWeXGnsUAAAagDHLt9H1/hDAIA27o7Y8X4fuDqbznfGSdw9YHk3oJjDmgHVld7VXKXOygdseW7ZS0yW5SXLQpadOzIfceNgSVmpxoD/HsW1wlJjh1Kvq/dL0X3+YQBAgLk+65yeDXw/HLh8qGHtLPAC/MOAsbuEiYuYBMtaoq1SCZTU7BWulbWjuLGwELcvB+1n/2zyibK6s3fkCJz3C/xiDuCT7VkoU5rRQO7R24GhGxreTu4RzVVmpWncCZCGs6xkmb6WfVkjTnNUVqrRdtYBfJtm/sNSdmbeQae4JLT77AB+ycmHSm0Gj8i7DmnAVrvVzG8OHPhYmLaIUVlWsrxxkn1Z//5iRVGvuT+eQ/vZP6NSbZS3F81jFfDu5gy0mXUQ/RYlo7j0sbFDqp+VFIgvBqQCDF/7/WsgvhmgLGt4W8RoLCtZVpSwK2ekaY495h3GhtTrBn9fQ7vxd7n2Nj0pO8+0rzb/XQB4NWARDi215lnmpjcFaIsYg2Uly/ssV183wjTHNjEH8KDU8p5vvf99JtrMOojXlh+Fopz9gHqDev84MITvdrvVVD3LLFcI0x4xGMtJlpVKQHGXXVm7JuLGUo1fzAG2W6fxJpUAL7RzR058OHLiw/FCa9Paa+Z8gQJd4w/B31SvNgOG/jNTR6ChUQtbAMsCaRUjM2I5g9JTVwG/fMauwW7DgDe/Fi7AOqjUDNrMOihom94yO/w4pS+ay+x4x2TomUF16ezpjO2T+pjeoPfNQ4Erh4Vrb/BXQPD/Cdc
e4YRm8FT3wyjgEsvENHo30O5l4QKsxcGzefhgS6YgbdlYAadnvSLagHA+886FJJUAp2LDeP8BEIWyTPMMUjBWwKy7NJjdCChZVrdpMJCbor8xiRXw7/uiPrP8z4HzWHc8t8HtONlIcPqzVw165VXfKkaGMD2sHSb1bwtbaxN5grQ8ACgWcIiXcwvgowxKmgZEybK6HROAczv1N+YbCkxIEjbAp/znwDmsO369QW1IAfwZH27029OqW/bJ28+g0sC/RTJ7KVI+eck0plf+uR3Y846wbXp2B947QqsZGQDbZGkif54NwLUVu3J+vUUL4eDZuw1OlMve7IarCyOMnigBQGolQXiwD64kROCv+QMxNay1wd5bXq5C9/mH4RdzAP898heUxhyYGjhc0/ljLeDVYEGmZvm4zM3CtUkaxHKuLK+mAN8N1t/YW/uANv2FDA+AMJ05Vxe8ZhZLoynKK/HBt6dw7JphV0Zyc7LBrx+/aNw56QdnAun895eqU/QVQGb+i1GbIroNr06tAr5oA5T9XXdDDq7AjCui3Pr4xRzgXdeniTVSP2vg3jFG8lChRP/FRyBXGvbXbHj3lpgb2dU4K8BXKoEl7YDyIuHb/iQXcKadJoVEybI2538Etr9V9+vDvwM6vy54fA1JlC91cMM3458TMBrjUFaqkZhyCf89cs2g7+tgI8HJT8OM82xTjGeZVehKUzCULOty/kfg55nAo7wn55r4AAMXiZIoO8w6gAqej9NWjQzCv4JaCBuQCXioUKLPwsMoM/BIpKHdWxh+vyG1Clj3MpB3Rpz2h20GOr1GHUENQMmyPmoVcCMVUBQAzp5Aq96i/LL1SfgFd4r5LRhhLs8nG6JMqcLsvVnYlZlv0Pe1lQInPzXwuE1lGbCwJaAW6S9ElzeBN1ab1ILV5oKSpZHN/SkbG07yW3X7+sIIgaMxfca62mzt7oidhlzhPXML8OMk8dp37wC8e5TGaXJAydKIlJVqtJ/9M6+6lpgon2bMQe8vGGpPIbFvzQHAyg6Ydo6ea7JAydKI+HboWMKtNxeF8gqEJhwx+IB3wEAzhZRlwPJO9Y/QEEK/WUDf6XSLXgdRB6UnJibCz88P9vb2CAkJQXp6ep1l161bh759+6JZs2Zo1qwZwsLCapQfN24cJBKJzjFgwAA+oRlda56J8sv/606JsprmMjtcSYhATnw4OnsadpuP5Ucuo/3sn9EmVsRB77YOwKfXgZg7gJWIY0OPLtCs2L64LVBq3F1BzRnnK8tt27YhKioKq1evRkhICFasWIEdO3bg0qVL8PDwqFF+9OjR6NOnD3r37g17e3ssWrQIe/bswblz59Cihaand9y4cSgoKMCGDU/2PrGzs0OzZs1YxWQqV5a9FxzCXTn3h25v9/FD3KAuIkTUuFRNr3xvm4i3r/WwlUqwcmSwePumKx4CS1oDMMClNF1taol2Gx4SEoJnn30Wq1atAgCo1Wr4+vriww8/RExMjN76KpUKzZo1w6pVqxAVFQVAkyyLioqwd+9eLqFomUKy5Nuh08W7CQ5MfUGEiBq34tLHeGnJr3hQapyVkJxsJFg6XKTEKS8ElrUVts26SO00Cxtb8PAjUfYNVyqVyMjIQGxsrPaclZUVwsLCkJaWxqqN0tJSPH78GK6uurMQUlJS4OHhgWbNmuGll17C/Pnz4ebmVmsbFRUVqKio0P4sl8u5fAzBKSvVvBKlk42UEiVPLo42yIgLN9pg95LHDN7/XrPEnrUEiAwWcAynrLlm/x9DJE1VBbBjjOZ/O3sD76fSDKE6cLqyvHv3Llq0aIHU1FSEhoZqz8+cORNHjx7F6dOn9bbxwQcf4NChQzh37hzs7TWbQW3duhWOjo7w9/fH1atXMWvWLDg7OyMtLQ1Sac1fvvj4eMydO7fGeWNdWfrHHOB842QF4JqF93wLTVFeieFfHcf5AuNuHSx4r7ohrzSr+HQHxuwFHE1rRX0xiHIb3tBkuXDhQixevBg
pKSkICKh7E6hr166hTZs2OHLkCF5+ueYivLVdWfr6+holWQ5c/isuFHDftc/ShwiJqerZ5vvbzhji6V+9BB3HKS8ElrWDQZ5pPs2tHTD+l0Z7xSnKbbi7uzukUikKCnTHwBUUFMDLq/5Vo5csWYKFCxfiyJEj9SZKAGjdujXc3d1x5cqVWpOlnZ0d7OyMv2r2j5m3eSXKv+YPFCEaUqVq6bjcYB+jrYBU5dr9UnSfr9mCwlYqweuBPvxv12XNgfgiTUfQ0nYAY6DntQ8uA0v8/4mhBTDxqEWO3+SULG1tbdGjRw8kJycjMjISgKaDJzk5GVOmTKmz3uLFi/Gf//wHhw4dQs+ePfW+z+3bt/HgwQN4e3tzCc+gVGoGH23/k3O98b39TGeVbwvgbG+Nb999HoBmllD4smQUlhpn7UulisHOzDvYmXkHQAOuOp1dgTkPNMOAlnUAKg24H7n8zpNHAjZOQO+pFtOrzmvo0NixY7FmzRr06tULK1aswPbt23Hx4kV4enoiKioKLVq0QEJCAgBg0aJFiIuLw5YtW9CnTx9tO87OznB2doZCocDcuXPx5ptvwsvLC1evXsXMmTPx6NEjZGdns7qCNEZveJd/H0AJx2nfbo62yIh7RZyACGtVt+mTtp2BEZcM1iGVAF19ZPh2wnPc1+MsVwDbxmq22TUmM71dF3UGz6pVq/DFF18gPz8fQUFBWLlyJUJCQgAA/fv3h5+fHzZu3AgA8PPzw40bNfcomTNnDuLj41FWVobIyEicOXMGRUVF8PHxwauvvorPP/8cnp6egn5Yocz58Sw2pd7iXI+eU5oeY+8pVBepBOjTlkdHkbwQWNYeMPafAakd0HUoELHU5Oep03RHkfCd9/3X/IF0+23iiksfY8zaE8jON25vem04J8/SYuB/PYCyQvGDY8OEkyclS5Hwmfc9vrcf5rxOM3TMSXHpY4xecxw5PDrwDIXVHvGVSuDEUiBloeECY8OEkiclSxEEzTmIogpuX5fM3hpn481zSwiiYex909nS29tergC2RgHXkw0fnF6Sf555HjL4M09KlgLj+5ySVhJqXB4qlBiyKgXXi/gt6mxojjZWeK9fm5orKCkeAonPAWWm9axWh4GuPilZCojvc8r/jgzC4Ea4LQTRMPYYTr6aOlhj4ZAAzbx2qIFz+4FdEwCYwR8AERIoJUsB8ZnO2MVbhgNT+woeCzFNptqrzpZUAvRtLcMa/xOwO7EMgNLYIbEnsQZa9weGbQLsnTlXp2QpED7TGW2sgMsLaJiQpVKpGfyanY8ZuzNRVKG/vCmyghp9rdKxRPoVXKWPYQXArB4mcZhpRMlSAD9m3uY1S4eeU5Knmevt+tNkkOMnmxloKXkEKwkgMZdfbzsZEFt/XwMlywZSqRm0mXWQcz16TknqUzV7aPrOMyhVGTsafuxRjhU2S/Gy5ByszSFx6kmYoiykYUn4JMou3jJKlKReVYt8nA/2AQDtepyrjlyDueTOctjj/cefAah2uy55bJpXnRVyzcymBi7+QVeWteAznlIqAa4m0HNK0jDm3lFUddX5ouQcrAHTSZ4uzwDTs2t9iW7DeRq3Pg0plx9yrkfTGYkYzPHK82kmkzxtnIDP7tb6Et2G8/D5/hxeiZKWXSNisbW2wvSwTpge1kl7rlBegdeWJ6OwzPSvc56+ZQeMmDwda9+ihgu6svzHwbN38cEW7rsG0rJrxNgU5ZWY/N1pHL1aZOxQOKv+zBMQNoEy+GfIU/SVOp9Z0m04B3x7vgFado2YJnNYCKQu1a8+AX4JtCqzySWOcInPq7McJUsO+KwkBNBzSmJeLC2BMgxQpHZE8OOv8fussDpXZ6JkyUJDrijf7uOHuEG07Boxb+bc+24LJeZZr8dgyXHYVkuadxhX/OvxAsihyQctm9rjREzN/bwASpZ67cu6g6lbs3i9n29TBxyPeYlXXULMQWOYdfQ0Rxspzn8+oNbXqDe8HhH/PY5zeXJedW2
tQImSNHpPb/T2NHPtTHJ14rivUS0sLll2nH0Q5ZX8L6b/ogUyiAVztrfGpnf61PqaKS+SvOeDmomfK4tKls8vPNygREk934TUzdXZFkdja98VwJidSzJ76/q33mDJYpJlcelj3C7iv0YfJUpC+HNxtMH+6XU/vhIrmQq5rYvFJMu3N6bzrkuJkhBx6UumXFdrYrWZG0cWkyzvFpfzqkeJkhDjq75akzHwGlGdmJgIPz8/2NvbIyQkBOnp9V+17dixAx07doS9vT26deuGgwd1xzYyDIO4uDh4e3vDwcEBYWFhuHz5Mp/Q6uTjYs+5DiVKQkgVzsly27ZtiI6Oxpw5c5CZmYnAwECEh4fj3r17tZZPTU3FqFGjMGHCBJw5cwaRkZGIjIxETk6OtszixYuxcuVKrF69GqdPn4aTkxPCw8NRXs7varA234zrxak8JUpCyNM4D0oPCQnBs88+i1WrVgEA1Go1fH198eGHHyImJqZG+REjRqCkpAT79+/XnnvuuecQFBSE1atXg2EY+Pj44OOPP8Ynn3wCACguLoanpyc2btyIkSNH6o2J7aDSfl/8ihsP9D9ApkRJiOVgmz84XVkqlUpkZGQgLCzsSQNWVggLC0NaWlqtddLS0nTKA0B4eLi2fG5uLvLz83XKuLi4ICQkpM42KyoqIJfLdQ42js54Ca3c6t4+s6OHPSVKQkitOHXw3L9/HyqVCp6enjrnPT09cfHixVrr5Ofn11o+Pz9f+3rVubrKVJeQkIC5c+dyCV3r6IyXUFz6GFFfp+JCQQmsJMBzrd2w6v96wNneYvq7CCEcmWV2iI2NRXR0tPZnuVwOX19f1vVdHG2w76N+YoRGCGmkON2Gu7u7QyqVoqBAd4WSgoICeHl51VrHy8ur3vJV/+XSpp2dHWQymc5BCCFi4nRlaWtrix49eiA5ORmRkZEANB08ycnJmDJlSq11QkNDkZycjGnTpmnPHT58GKGhoQAAf39/eHl5ITk5GUFBQQA0V4qnT5/GpEmTWMVV1UfF9tklIYRUqcobevu6GY62bt3K2NnZMRs3bmTOnz/PvPvuu0zTpk2Z/Px8hmEY5q233mJiYmK05U+ePMlYW1szS5YsYS5cuMDMmTOHsbGxYbKzs7VlFi5cyDRt2pTZt28fc/bsWWbw4MGMv78/U1ZWxiqmW7duMdCsIE8HHXTQweu4detWvXmG8zPLESNGoLCwEHFxccjPz0dQUBCSkpK0HTQ3b96EldWTu/vevXtjy5YtmD17NmbNmoV27dph79696Nq1q7bMzJkzUVJSgnfffRdFRUV4/vnnkZSUBHt7dgPJfXx8cOvWLTRp0gQSk9h3s3ZVz1Zv3bpFjw5qQd+PfvQd6cf1O2IYBo8ePYKPT/2zgxrF4r/mQuj9zRsb+n70o+9IP7G+I9pAhhBCWKBkSQghLFCyNCA7OzvMmTMHdnbCLRvVmND3ox99R/qJ9R3RM0tCCGGBriwJIYQFSpaEEMICJUtCCGGBkiUhhLBAyVJgXLbc2LhxIyQSic7BdtaSOTp27BgGDRoEHx8fSCQS7N27V2+dlJQUdO/eHXZ2dmjbti02btwoepzGxPU7SklJqfE7JJFI6lze0NwlJCTg2WefRZMmTeDh4YHIyEhcunRJbz19W9uwQclSQFy33AAAmUyGvLw87XHjxg0DRmxYJSUlCAwMRGJiIqvyubm5iIiIwIsvvoisrCxMmzYNEydOxKFDh0SO1Hi4fkdVLl26pPN75OHhIVKExnX06FFMnjwZp06dwuHDh/H48WO8+uqrKCkpqbMOm61tWOG2jAapT69evZjJkydrf1apVIyPjw+TkJBQa/kNGzYwLi4uBorOtABg9uzZU2+ZmTNnMl26dNE5N2LECCY8PFzEyEwHm+/ot99+YwAwf//9t0FiMjX37t1jADBHjx6ts8zw4cOZiIgInXMhISHMe++9x+m96MpSIHy23AAAhUKBVq1awdfXF4MHD8a5c+c
MEa5Z0LclCXkiKCgI3t7eeOWVV3Dy5Eljh2MwxcXFAABXV9c6ywj1e0TJUiD1bblR1/OjDh064JtvvsG+ffuwefNmqNVq9O7dG7dv3zZEyCavri1J5HI5ysr0bzxnCby9vbF69Wrs2rULu3btgq+vL/r374/MzExjhyY6tVqNadOmoU+fPjqrmFWnb2sbtsxyW4nGIjQ0VLsIMqBZzq5Tp05Ys2YNPv/8cyNGRsxFhw4d0KFDB+3PvXv3xtWrV7F8+XJ89913RoxMfJMnT0ZOTg5OnDhhkPejK0uB8NlyozobGxsEBwfjypUrYoRodurakkQmk8HBoe5dOi1dr169Gv3v0JQpU7B//3789ttvaNmyZb1l9W1twxYlS4E8veVGlaotN56+eqyPSqVCdnY2vL29xQrTrFRtSfK0p7ckIbXLyspqtL9DDMNgypQp2LNnD3799Vf4+/vrrSPY7xGfHihSO65bbsydO5c5dOgQc/XqVSYjI4MZOXIkY29vz5w7d85YH0FUjx49Ys6cOcOcOXOGAcAsW7aMOXPmDHPjxg2GYRgmJiaGeeutt7Tlr127xjg6OjIzZsxgLly4wCQmJjJSqZRJSkoy1kcQHdfvaPny5czevXuZy5cvM9nZ2czUqVMZKysr5siRI8b6CKKaNGkS4+LiwqSkpDB5eXnao7S0VFuGz9Y2bFCyFNj//vc/5plnnmFsbW2ZXr16MadOndK+1q9fP2bs2LHan6dNm6Yt6+npybz22mtMZmamEaI2jKphLtWPqu9k7NixTL9+/WrUCQoKYmxtbZnWrVszGzZsMHjchsT1O1q0aBHTpk0bxt7ennF1dWX69+/P/Prrr8YJ3gBq+24A6PxeVP93xjAMs337dqZ9+/aMra0t06VLF+bAgQOc35uWaCOEEBbomSUhhLBAyZIQQligZEkIISxQsiSEEBYoWRJCCAuULAkhhAVKloQQwgIlS0IIYYGSJSHEoPhsL9JQd+7cwZgxY+Dm5gYHBwd069YNf/zxB6c2KFkSQgyK79YZfP3999/o06cPbGxs8PPPP+P8+fNYunQpmjVrxqkdmu5ICDEaiUSCPXv2IDIyUnuuoqICn332GX744QcUFRWha9euWLRoEfr378/rPWJiYnDy5EkcP368QbHSlSUhxKRMmTIFaWlp2Lp1K86ePYthw4ZhwIABuHz5Mq/2fvzxR/Ts2RPDhg2Dh4cHgoODsW7dOs7t0JUlIcRoql9Z3rx5E61bt8bNmzfh4+OjLRcWFoZevXphwYIFnN+janvp6OhoDBs2DL///jumTp2K1atXY+zYsazboW0lCCEmIzs7GyqVCu3bt9c5X1FRATc3NwDAxYsX0alTp3rb+fTTT7Fw4UIAmkW4e/bsqU20wcHByMnJoWRJCDFfCoUCUqkUGRkZkEqlOq85OzsDAFq3bo0LFy7U205VYgU0m7p17txZ5/VOnTph165dnGKjZEkIMRnBwcFQqVS4d+8e+vbtW2sZW1tbdOzYkXWbffr0waVLl3TO/fXXX2jVqhWn2ChZEkIMSqFQ6Gyolpubi6ysLLi6uqJ9+/YYPXo0oqKisHTpUgQHB6OwsBDJyckICAhAREQE5/ebPn06evfujQULFmD48OFIT0/H2rVrsXbtWm4N8V7fnRBCeNC3dYZSqWTi4uIYPz8/xsbGhvH29mbeeOMN5uzZs7zf86effmK6du3K2NnZMR07dmTWrl3LuQ3qDSeEEBZonCUhhLBAyZIQQligZEkIISxQsiSEEBYoWRJCCAuULAkhhAVKloQQwgIlS0IIYYGSJSGEsEDJkhBCWKBkSQghLPw/w/ivnpHz8XUAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ts = [4000, 5000]\n", "datas = []\n", "plt.figure(figsize=(3.5, 2.5))\n", "for t in ts:\n", " datas.append(SyntheticPlankData(t))\n", " plt.scatter(datas[-1].X, datas[-1].y, label=t)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 20, "id": "e8ed4a09-2310-4933-a58e-30858aa92c50", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "tensor([5054.9624], requires_grad=True)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-15T08:17:48.145764\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = SyntheticPlankData(5020)\n", "model = PlankModel(T=3000, lr=1e-21)\n", "trainer = d2l.Trainer(max_epochs=5)\n", "trainer.fit(model, data)\n", "model.T" ] }, { "cell_type": "markdown", "id": "6cef624b-f9c9-415f-94cd-f8c7a0fc4a90", "metadata": {}, "source": [ "## 4. What are the problems you might encounter if you wanted to compute the second derivatives of the loss? How would you fix them?" ] }, { "cell_type": "markdown", "id": "122aa143-ba33-49e3-a876-759be649742f", "metadata": {}, "source": [ "use `autograd.grad` with `create_graph=True`" ] }, { "cell_type": "code", "execution_count": 6, "id": "933d0cfe-73f4-44f7-9f8d-7f83bda39992", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-15T09:05:23.251428\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def sin(x):\n", " return torch.sin(x)\n", "\n", "\n", "x = torch.arange(-5, 5, 0.1)\n", "x.requires_grad = True\n", "y = sin(x)\n", "first_derivative = autograd.grad(y.sum(), x, create_graph=True)[0]\n", "second_derivative = autograd.grad(first_derivative.sum(), x)[0]\n", "d2l.plot(x.detach(), [sin(x).detach(), first_derivative.detach(),\n", " second_derivative.detach()], 'x', 'f(x)',\n", " figsize=(5, 3), legend=['sinx', 'first_derivative',\n", " 'second_derivative'])" ] }, { "cell_type": "markdown", "id": "1fe39e45-d03b-40cd-b1ac-8a293bec713a", "metadata": {}, "source": [ "## 5. Why is the reshape method needed in the loss function?" ] }, { "cell_type": "markdown", "id": "257b89d7-6819-4fb3-b6ac-12b94bcb84b4", "metadata": {}, "source": [ "Reshaping is used to ensure that the dimensions of the predicted values match the dimensions of the ground truth values so that the loss calculation can be performed correctly." ] }, { "cell_type": "markdown", "id": "fadafbf7-dfab-4d7f-9057-bb6e45b2493c", "metadata": {}, "source": [ "## 6. Experiment using different learning rates to find out how quickly the loss function value drops. Can you reduce the error by increasing the number of epochs of training?" 
] }, { "cell_type": "markdown", "id": "e6069fa1-9834-4174-9ab8-3d541245722d", "metadata": {}, "source": [ "We ran experiments with lr in [0.003, 0.03, 0.3, 3], each with `max_epochs=3`:\n", "* When lr is small (e.g. 0.003), the loss drops very slowly, and the error can be reduced by training for more epochs.\n", "* As lr increases, the loss drops faster. Once training has converged, more epochs bring little further improvement.\n", "* When lr is too large (e.g. 3), the loss blows up, so training for more epochs does not help." ] }, { "cell_type": "markdown", "id": "07faae03-a87d-4efa-9e5e-553c9d08d23a", "metadata": {}, "source": [ "### lr = 0.003, epoch = 3" ] }, { "cell_type": "code", "execution_count": 12, "id": "ff94dda4-51a6-4a51-986f-fa8054dfc831", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 0.4372],\n", " [-0.8910]], requires_grad=True),\n", " tensor([1.0680], requires_grad=True))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " 2023-08-15T09:40:40.951760\n", " image/svg+xml\n", " Matplotlib v3.4.0, https://matplotlib.org/\n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = d2l.SyntheticRegressionData(w=torch.tensor([2, -3.4]), b=4.2)\n", "model = d2l.LinearRegressScratch(2, lr=0.003)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model,data)\n", "model.w,model.b" ] }, { "cell_type": "markdown", "id": "c28b98e6-270a-4903-94a7-9e159aab2be4", "metadata": {}, "source": [ "### lr = 0.003, epoch = 10" ] }, { "cell_type": "code", "execution_count": null, "id": "e817672b-8cfc-4481-8deb-0d28b8876629", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-15T09:43:14.992522\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = d2l.LinearRegressScratch(2, lr=0.003)\n", "trainer = d2l.Trainer(max_epochs=20)\n", "trainer.fit(model,data)\n", "model.w,model.b" ] }, { "cell_type": "markdown", "id": "b4a0214e-03c8-43c0-bf6f-3dafec124d65", "metadata": {}, "source": [ "### lr = 0.03, epoch = 3" ] }, { "cell_type": "code", "execution_count": 13, "id": "8c2fdc5d-0863-4789-b376-8d527ca737da", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 1.8396],\n", " [-3.2372]], requires_grad=True),\n", " tensor([3.9771], requires_grad=True))" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-15T09:40:45.802067\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = d2l.LinearRegressScratch(2, lr=0.03)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model,data)\n", "model.w,model.b" ] }, { "cell_type": "markdown", "id": "ad85b247-107d-47a2-b835-999d6c12a74d", "metadata": {}, "source": [ "### lr = 0.3, epoch = 3" ] }, { "cell_type": "code", "execution_count": 14, "id": "11d4f0ea-9136-469b-99f1-15854548e25e", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 2.0009],\n", " [-3.3990]], requires_grad=True),\n", " tensor([4.1986], requires_grad=True))" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-15T09:40:52.797039\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = d2l.LinearRegressScratch(2, lr=0.3)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model,data)\n", "model.w,model.b" ] }, { "cell_type": "markdown", "id": "fae78fec-298c-4fda-923e-2ce3f59a7a01", "metadata": {}, "source": [ "### lr = 3, epoch = 3" ] }, { "cell_type": "code", "execution_count": 15, "id": "dde30ead-99dd-421b-85fc-28fe35e4131c", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(tensor([[-3.2974e+29],\n", " [ 4.5504e+29]], requires_grad=True),\n", " tensor([-7.8700e+28], requires_grad=True))" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", " 2023-08-15T09:42:18.726971\n", " image/svg+xml\n", " Matplotlib v3.4.0, https://matplotlib.org/\n", 
" \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = d2l.LinearRegressScratch(2, lr=3)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model,data)\n", "model.w,model.b" ] }, { "cell_type": "markdown", "id": "8af385b7-4a10-44d3-afae-a50ddc3d816e", "metadata": {}, "source": [ "## 7. If the number of examples cannot be divided by the batch size, what happens to data_iter at the end of an epoch?" ] }, { "cell_type": "markdown", "id": "013dd938-d435-4cb9-a8ab-14ef428d583a", "metadata": { "tags": [] }, "source": [ "We add the following check inside the batch loop of `fit_epoch`:\n", "\n", "```python\n", "if len(batch[0]) != 32:\n", "    print(len(batch[0]))\n", "```\n", "It prints 8 three times, once per epoch (`max_epochs=3`). With 1000 training examples and a batch size of 32, the last minibatch of each epoch holds `1000 - 32 * (1000 // 32) = 1000 % 32 = 8` examples. So the data iterator does not discard the leftover examples: it yields them as a final, smaller minibatch.\n", "If we would rather ignore that incomplete batch, we can pass `drop_last=True` to `DataLoader`." ] }, { "cell_type": "code", "execution_count": 2, "id": "7892832f-0c56-4e57-bd79-b9e17e387392", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8\n", "8\n", "8\n" ] } ], "source": [ "data = d2l.SyntheticRegressionData(w=torch.tensor([2, -3.4]), b=4.2)\n", "model = d2l.LinearRegressScratch(2, lr=0.03)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model, data)" ] }, { "cell_type": "markdown", "id": "a703d2ce-7a27-407b-93eb-83af7ffe9856", "metadata": {}, "source": [ "## 8. 
Try implementing a different loss function, such as the absolute value loss `(y_hat - d2l.reshape(y, y_hat.shape)).abs().sum()`.\n", "* Check what happens for regular data.\n", "* Check whether there is a difference in behavior if you actively perturb some entries of $y$, such as $y_5=10000$.\n", "* Can you think of a cheap solution for combining the best aspects of squared loss and absolute value loss? Hint: how can you avoid really large gradient values?" ] }, { "cell_type": "code", "execution_count": 2, "id": "a84385c9-4f9f-4042-9e4e-b3f7c9fbc1bd", "metadata": { "tags": [] }, "outputs": [], "source": [ "class LinearRegressAbsLoss(d2l.Module):\n", "    # Same model as d2l.LinearRegressScratch; only the loss function differs\n", "    def __init__(self, num_inputs, lr, sigma=0.01):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "        self.w = torch.normal(0, sigma, (num_inputs, 1), requires_grad=True)\n", "        self.b = torch.zeros(1, requires_grad=True)\n", "\n", "    def forward(self, X):\n", "        return torch.matmul(X, self.w) + self.b\n", "\n", "    def loss(self, y_hat, y):\n", "        # Absolute value loss; reshaping y guards against (n,) vs (n, 1) broadcasting.\n", "        # Summed over the minibatch (d2l's squared loss takes the mean instead).\n", "        l = torch.abs(y_hat - y.reshape(y_hat.shape))\n", "        return l.sum()\n", "\n", "    def configure_optimizers(self):\n", "        return d2l.SGD([self.w, self.b], self.lr)" ] }, { "cell_type": "markdown", "id": "a10b0c7f-bd33-4ca4-b5d8-e2051affe2c2", "metadata": {}, "source": [ "### Check what happens for regular data.\n", "On clean data both losses work; they differ in how the gradient depends on the residual $\\hat{y} - y$:\n", "\n", "**Squared loss** $\\frac{1}{2}(\\hat{y} - y)^2$: the loss grows quadratically, so its gradient with respect to $\\hat{y}$ is the residual itself. Large deviations produce large updates, and the updates shrink automatically as the model approaches the optimum.\n", "\n", "**Absolute value loss** $|\\hat{y} - y|$: the loss grows only linearly, so its gradient is the sign of the residual, $\\pm 1$, whatever the size of the deviation. Every example therefore contributes an update of the same magnitude, which limits the influence of large errors; but with a fixed learning rate the updates do not shrink near the optimum, so the parameters tend to hover around it rather than settle. In the run below, the model reaches $w \\approx (1.92, -3.28)$ and $b \\approx 4.10$ after 3 epochs, close to the true values $w = (2, -3.4)$ and $b = 4.2$."
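, "\n",
"\n",
"A quick way to see this difference is to differentiate both losses with respect to the predictions. The sketch below is not part of the original exercise code, and its residuals are made-up values chosen only for illustration:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Hypothetical predictions whose residuals are 0.1, 1, 10 and 100 (targets are zero)\n",
"y = torch.zeros(4)\n",
"y_hat = torch.tensor([0.1, 1.0, 10.0, 100.0], requires_grad=True)\n",
"\n",
"# Squared loss with the 1/2 factor: the gradient equals the residual\n",
"sq_loss = ((y_hat - y) ** 2 / 2).sum()\n",
"(sq_grad,) = torch.autograd.grad(sq_loss, y_hat)\n",
"\n",
"# Absolute value loss: the gradient is the sign of the residual\n",
"abs_loss = (y_hat - y).abs().sum()\n",
"(abs_grad,) = torch.autograd.grad(abs_loss, y_hat)\n",
"\n",
"print(sq_grad)   # grows with the residual, from 0.1 up to 100\n",
"print(abs_grad)  # all ones, however large the residual is\n",
"```\n",
"\n",
"A residual of 100 thus pushes the squared-loss model 1000 times harder than a residual of 0.1, while the absolute value loss weighs both equally."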
] }, { "cell_type": "code", "execution_count": 17, "id": "4d7c9c82-c059-40d7-8b7c-1238ec0f008e", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 1.9236],\n", " [-3.2809]], requires_grad=True),\n", " tensor([4.1000], requires_grad=True))" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", " 2023-08-15T11:34:09.396224\n", " image/svg+xml\n", " Matplotlib v3.4.0, https://matplotlib.org/\n", 
" \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = d2l.SyntheticRegressionData(w=torch.tensor([2, -3.4]), b=4.2)\n", "model = LinearRegressAbsLoss(2, lr=0.01)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model, data)\n", "model.w,model.b" ] }, { "cell_type": "code", "execution_count": 19, "id": "91024f99-1285-4729-8afe-e8f0cdc49278", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", " 2023-08-15T11:34:57.459613\n", " image/svg+xml\n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " 
\n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = d2l.LinearRegressScratch(2, lr=0.01)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model, data)" ] }, { "cell_type": "markdown", "id": "abb0a52e-77e4-45e3-8e5f-af513db8e64d", "metadata": { "tags": [] }, "source": [ "### Check whether there is a difference in behavior if you actively perturb some entries of $y$, such as $y_5=10000$\n", "\n", "Both models are affected, but the squared loss fails far more badly. Note that the perturbation code below also multiplies the features $\\mathbf{x}_5$ by the same factor `alpha`, so example 5 becomes a high-leverage outlier rather than just a corrupted target.\n", "\n", "* **Absolute value loss**: the gradient contributed by example 5 is $\\operatorname{sign}(\\hat{y}_5 - y_5)\\, \\mathbf{x}_5$, which does not grow with the residual. The weights are still pulled far from the true values ($w \\approx (-50.8, 344.8)$ after 3 epochs), but they stay finite.\n", "* **Squared loss**: the gradient is $(\\hat{y}_5 - y_5)\\, \\mathbf{x}_5$, which additionally scales with the residual of roughly $10^4$. Training diverges, and after 10 epochs all parameters are `nan`.\n", "\n", "This is the sense in which the squared loss is more sensitive to outliers than the absolute value loss." ] }, { "cell_type": "code", "execution_count": 23, "id": "ee47b82c-5514-444e-af9d-f673561c92c4", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(tensor([[-50.7932],\n", " [344.7570]], requires_grad=True),\n", " tensor([4.0400], requires_grad=True))" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", " 2023-08-15T12:08:14.314580\n", " image/svg+xml\n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " 
\n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Set y_5 to 10000 and scale x_5 by the same factor (a high-leverage outlier).\n", "# The data is modified in place, so later cells also see the perturbed example.\n", "alpha = 10000/data.y[5]\n", "data.X[5] *= alpha\n", "data.y[5] = 10000\n", "model = LinearRegressAbsLoss(2, lr=0.01)\n", "trainer = d2l.Trainer(max_epochs=3)\n", "trainer.fit(model, data)\n", "model.w,model.b" ] }, { "cell_type": "code", "execution_count": 25, "id": "4b5ca291-37a3-44fb-b58c-7550cc5eadfe", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(tensor([[nan],\n", " [nan]], requires_grad=True),\n", " tensor([nan], requires_grad=True))" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", " 2023-08-15T12:08:40.416269\n", " image/svg+xml\n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " 
\n", 
" \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = d2l.LinearRegressScratch(2, lr=0.01)\n", "trainer = d2l.Trainer(max_epochs=10)\n", "trainer.fit(model, data)\n", "model.w,model.b" ] }, { "cell_type": "markdown", "id": "19b704c3-a5b0-4338-bba1-41459d99dd97", "metadata": {}, "source": [ "### Can you think of a cheap solution for combining the best aspects of squared loss and absolute value loss? Hint: how can you avoid really large gradient values?\n", "\n", "To combine the best aspects of squared loss and absolute value loss while avoiding really large gradient values, you can consider using the Huber loss (also known as smooth L1 loss), which behaves like the squared loss near zero but transitions to the absolute value loss for larger values. This can provide a compromise between the two loss functions, preventing extreme gradients while being robust to outliers. " ] }, { "cell_type": "code", "execution_count": 37, "id": "9d3788cb-eccf-4d8b-9f35-c189ceff6e7b", "metadata": { "tags": [] }, "outputs": [], "source": [ "class LineRegressionHuberLoss(d2l.Module):\n", " def __init__(self, num_inputs, lr, sigma=0.01):\n", " super().__init__()\n", " self.save_hyperparameters()\n", " self.w = torch.normal(0, sigma, (num_inputs,1), requires_grad=True)\n", " self.b = torch.zeros(1, requires_grad=True)\n", " \n", " def forward(self, X):\n", " return torch.matmul(X, self.w) + self.b\n", " \n", " def loss(self, y_hat, y, sigma=1):\n", " beta = 1.0/(sigma**2)\n", " diff = torch.abs(y_hat - y)\n", " l = torch.where(diff\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-15T12:12:21.621995\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " 
\n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = LineRegressionHuberLoss(2, lr=1e-7)\n", "trainer = d2l.Trainer(max_epochs=10)\n", "trainer.fit(model, data)\n", "model.w,model.b" ] }, { "cell_type": "markdown", "id": "94f22317-b075-4e44-8650-f989c2785d12", "metadata": {}, "source": [ "## 9. Why do we need to reshuffle the dataset? Can you design a case where a maliciously constructed dataset would break the optimization algorithm otherwise?" ] }, { "cell_type": "markdown", "id": "249a1109-65ae-47e8-9a79-04d8d5ef6b11", "metadata": {}, "source": [ "Reshuffling matters because minibatch SGD relies on each minibatch gradient being a noisy but unbiased estimate of the gradient over the whole training set. That holds only if every minibatch is a random sample of the data, which is exactly what reshuffling before each epoch provides. Concretely, shuffling serves several purposes:\n", "\n", "1. **Unbiased minibatches:** in a random order, no fixed block of similar examples drives a long run of consecutive updates. Without shuffling, any structure in the storage order (sorted by label, by time of collection, by source) turns directly into a systematic bias in the updates.\n", "\n", "2. **Smoother convergence:** shuffling breaks up clusters of similar examples. If the dataset has an inherent order, a fixed pass over it pulls the parameters back and forth between different parts of the data, so the optimization converges unevenly or slowly.\n", "\n", "3. **Varied trajectories:** a fresh order in each epoch yields different minibatch combinations, and the resulting gradient noise can help the optimizer move away from poor local minima and saddle points on non-convex losses. (For linear regression with squared loss the objective is convex, so this point matters less here.)\n", "\n", "Regarding a case where a maliciously constructed dataset could break the optimization algorithm, consider a dataset that is intentionally ordered to exploit SGD's reliance on random minibatches:\n", "\n", "Imagine a dataset whose examples are sorted from easy to hard, so that the loss grows steadily as an epoch proceeds. 
For instance, the dataset could consist of cat images ordered from easy to hard to classify. Without shuffling, the optimizer first sees a long run of easy examples and fits them quickly; it then meets progressively harder ones, progress slows, and each update reflects only the current slice of data rather than the whole dataset. An even more damaging ordering sorts the examples by their target value (or, for classification, by class label). Every minibatch then covers only a narrow slice of the data, each update pulls the parameters toward fitting that slice alone, and with a fixed learning rate the model at the end of every epoch is biased toward the examples it saw last instead of fitting the whole dataset.\n", "\n", "Shuffling the dataset before each epoch destroys such an ordering: every minibatch mixes examples of varying difficulty and target values, so the minibatch gradient is again an unbiased estimate of the full gradient.\n", "\n", "Overall, shuffling the dataset is a cheap and standard practice for making minibatch SGD more robust and helping it converge." ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:d2l]", "language": "python", "name": "conda-env-d2l-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }