{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"id": "4db23f5f-90f5-4f80-8efa-d9137af4a447",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('/home/jovyan/work/d2l/notebooks/d2l_utils')\n",
"import d2l\n",
"import torch\n",
"import warnings\n",
"import matplotlib.pyplot as plt\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"\n",
"class HighDimData(d2l.DataModule):\n",
"    \"\"\"Synthetic linear-regression data with many input features.\n",
"\n",
"    Targets follow y = X @ w + b + noise, where every weight is 0.01, the bias\n",
"    is 0.05 and the noise is standard-normal scaled by 0.01. The first\n",
"    ``num_train`` rows are the training split; the remaining ``num_val`` rows\n",
"    are the validation split.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_train, num_val, num_inputs, batch_size):\n",
"        super().__init__()\n",
"        # Stores the constructor arguments as attributes (get_dataloader reads\n",
"        # self.num_train). Upstream d2l does this by inspecting the caller's\n",
"        # locals, so keep the call ahead of any new local variable.\n",
"        self.save_hyperparameters()\n",
"        num_examples = num_train + num_val\n",
"        # Features are drawn before the noise; keep this order so seeded runs\n",
"        # reproduce the same data.\n",
"        self.X = torch.randn(num_examples, num_inputs)\n",
"        eps = 0.01 * torch.randn(num_examples, 1)\n",
"        self.w = 0.01 * torch.ones(num_inputs, 1)\n",
"        self.b = 0.05\n",
"        self.y = self.X @ self.w + self.b + eps\n",
"\n",
"    def get_dataloader(self, train):\n",
"        \"\"\"Return a loader over the training rows (train=True) or the validation rows.\"\"\"\n",
"        if train:\n",
"            rows = slice(0, self.num_train)\n",
"        else:\n",
"            rows = slice(self.num_train, None)\n",
"        return self.get_tensorloader([self.X, self.y], train, rows)\n",
"\n",
"\n",
"class WeightDecayScratch(d2l.LinearRegressScratch):\n",
"    \"\"\"From-scratch linear regression whose loss adds an L2 penalty on the weights.\n",
"\n",
"    ``lambd`` scales the penalty; ``num_inputs``, ``lr`` and ``sigma`` are\n",
"    forwarded unchanged to the base class.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_inputs, lambd, lr, sigma=0.01):\n",
"        super().__init__(num_inputs, lr, sigma)\n",
"        self.save_hyperparameters()  # exposes self.lambd for loss()\n",
"\n",
"    def loss(self, y_hat, y):\n",
"        \"\"\"Base regression loss plus ``lambd`` times ``d2l.l2_penalty(self.w)``.\"\"\"\n",
"        data_loss = super().loss(y_hat, y)\n",
"        penalty = d2l.l2_penalty(self.w)\n",
"        return data_loss + self.lambd * penalty\n",
"\n",
"\n",
"class WeightDecay(d2l.LinearRegression):\n",
"    \"\"\"Concise linear regression that relies on the optimizer's built-in weight decay.\n",
"\n",
"    Only the layer's weight is decayed (by ``wd``); the bias group keeps SGD's\n",
"    default weight_decay of 0.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, wd, lr):\n",
"        super().__init__(lr)\n",
"        self.save_hyperparameters()\n",
"\n",
"    def configure_optimizers(self):\n",
"        # Two parameter groups so weight_decay applies to the weights only.\n",
"        param_groups = [\n",
"            {'params': self.net.weight, 'weight_decay': self.wd},\n",
"            {'params': self.net.bias},\n",
"        ]\n",
"        return torch.optim.SGD(param_groups, lr=self.lr)\n",
"\n",
"\n",
"def train_strach(lambd, trainer, data, num_inputs=None, lr=0.01):\n",
"    \"\"\"Train a WeightDecayScratch model on ``data`` and return its losses.\n",
"\n",
"    Args:\n",
"        lambd: L2-penalty coefficient passed to the model.\n",
"        trainer: object whose ``fit(model, data)`` returns ``(train_loss, valid_loss)``\n",
"            and which exposes a ``plot_flag`` attribute.\n",
"        data: dataset to fit. When ``num_inputs`` is not given, the model is sized\n",
"            from ``data.num_inputs`` (HighDimData stores it as a hyperparameter).\n",
"        num_inputs: feature count for the model; defaults to ``data.num_inputs``,\n",
"            falling back to 200 (the previously hard-coded value).\n",
"        lr: learning rate (default 0.01, unchanged from before).\n",
"\n",
"    Returns:\n",
"        ``(train_loss, valid_loss)`` exactly as returned by ``trainer.fit``.\n",
"    \"\"\"\n",
"    if num_inputs is None:\n",
"        # Size the model from the data so w always matches X's feature count.\n",
"        num_inputs = getattr(data, 'num_inputs', 200)\n",
"    model = WeightDecayScratch(num_inputs=num_inputs, lambd=lambd, lr=lr)\n",
"    train_loss, valid_loss = trainer.fit(model, data)\n",
"    if trainer.plot_flag:\n",
"        # NOTE(review): upstream d2l's l2_penalty is sum(w**2) / 2 rather than the\n",
"        # L2 norm; confirm against d2l_utils before trusting this label.\n",
"        print(f'l2 norm of w:{d2l.l2_penalty(model.w):.2g}')\n",
"    return train_loss, valid_loss\n",
"\n",
"\n",
"# Correctly spelled alias; train_strach is kept for existing callers.\n",
"train_scratch = train_strach"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4f33664e-2337-42bd-b62e-9a9fcd4a59bd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"l2 norm of w:0.012\n"
]
},
{
"data": {
"text/plain": [
"(4.6591752834501676e-05, 0.4330679289996624)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"