{
"cells": [
{
"cell_type": "markdown",
"id": "693de1e8",
"metadata": {},
"source": [
"## Andrej's makemore lecture - 3\n",
"\n",
"Implementation following Andrej Karpathy's lecture [Building makemore Part 3: Activations & Gradients, BatchNorm](https://youtu.be/P6sfmUTpUmc).\n",
"\n",
"I have taken some liberty to refactor and pythonise his implementation."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1cdd0780",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import black\n",
"import jupyter_black\n",
"\n",
"jupyter_black.load(\n",
" lab=False,\n",
" line_length=79,\n",
" target_version=black.TargetVersion.PY310,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "246c1917",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import display, HTML, clear_output\n",
"\n",
"display(HTML(\"\"))\n",
"\n",
"\n",
"from dataclasses import dataclass, field\n",
"import typing as t\n",
"import itertools as it\n",
"import collections as c\n",
"import json\n",
"from copy import deepcopy\n",
"import math\n",
"import time\n",
"import functools as ft\n",
"import numpy as np\n",
"import random\n",
"from tqdm.notebook import tqdm\n",
"import heapq\n",
"import torch as T\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import torch.utils.tensorboard as tb\n",
"\n",
"plt.rcParams[\"figure.figsize\"] = (12, 4)\n",
"plt.rcParams[\"font.size\"] = 14"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cae62e50",
"metadata": {},
"outputs": [],
"source": [
"# Number of past tokens to use to predict next token\n",
"CTX_WIN_SZ = 3"
]
},
{
"cell_type": "markdown",
"id": "d0a4107a",
"metadata": {},
"source": [
"### Load data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "918744d9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['emma', 'olivia', 'ava', 'isabella']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"DOT = \".\"\n",
"# Read the corpus inside a context manager so the file handle is closed\n",
"# as soon as the names are in memory.\n",
"with open(\"names.txt\", encoding=\"utf-8\") as names_file:\n",
"    words = names_file.read().splitlines()\n",
"words[:4]"
]
},
{
"cell_type": "markdown",
"id": "130b61aa",
"metadata": {},
"source": [
"#### Build mapping of character to index"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "40d8add8",
"metadata": {},
"outputs": [],
"source": [
"def build_ixes(words):\n",
"    \"\"\"Build the character <-> index lookup tables for ``words``.\n",
"\n",
"    Index 0 is given to ``DOT``; every other character found in ``words``\n",
"    follows in sorted order.\n",
"\n",
"    Returns:\n",
"        ``(ctoix, ixtoc)``: the char -> index and index -> char dicts.\n",
"    \"\"\"\n",
"    chars = [DOT] + sorted(set(it.chain.from_iterable(words)))\n",
"    ctoix = {ch: ix for ix, ch in enumerate(chars)}\n",
"    ixtoc = dict(enumerate(chars))\n",
"    return (ctoix, ixtoc)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7c18c2f8",
"metadata": {},
"outputs": [],
"source": [
"(ctoix, ixtoc) = build_ixes(words)"
]
},
{
"cell_type": "markdown",
"id": "03333c0e",
"metadata": {},
"source": [
"#### Create training data with context window size"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ca5067d1",
"metadata": {},
"outputs": [],
"source": [
"def build_train_data(words, ctoix, ctx_win=CTX_WIN_SZ, num_preview=2):\n",
"    \"\"\"Turn words into (context, next-character) training pairs.\n",
"\n",
"    Each word is left-padded with ``ctx_win`` DOTs and terminated by one\n",
"    DOT.  A window of ``ctx_win`` characters slides over the padded word;\n",
"    the character right after the window is the prediction target.\n",
"\n",
"    Args:\n",
"        words: iterable of strings to encode.\n",
"        ctoix: char -> index mapping; must contain DOT and every\n",
"            character that occurs in ``words``.\n",
"        ctx_win: number of preceding characters used as context.\n",
"        num_preview: the windows of this many leading words are printed\n",
"            as a sanity check (``0`` silences the output).\n",
"\n",
"    Returns:\n",
"        ``(Xs, Ys)`` integer tensors: one row of ``ctx_win`` indices per\n",
"        example in ``Xs`` and the flat vector of target indices in ``Ys``.\n",
"    \"\"\"\n",
"    Xs, Ys = [], []\n",
"    pad = DOT * ctx_win\n",
"    for wnum, w in enumerate(words):\n",
"        pw = pad + w + DOT\n",
"        preview = wnum < num_preview\n",
"        if preview:\n",
"            print(pw)\n",
"        for i in range(len(w) + 1):\n",
"            ctx, target = pw[i : i + ctx_win], pw[i + ctx_win]\n",
"            if preview:\n",
"                print(ctx, \"--->\", target)\n",
"            Xs.append([ctoix[ch] for ch in ctx])\n",
"            # targets are collected flat, so no flatten() is needed\n",
"            Ys.append(ctoix[target])\n",
"    return T.tensor(Xs, dtype=int), T.tensor(Ys, dtype=int)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "784d2742",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"...emma.\n",
"... ---> e\n",
"..e ---> m\n",
".em ---> m\n",
"emm ---> a\n",
"mma ---> .\n",
"...olivia.\n",
"... ---> o\n",
"..o ---> l\n",
".ol ---> i\n",
"oli ---> v\n",
"liv ---> i\n",
"ivi ---> a\n",
"via ---> .\n"
]
},
{
"data": {
"text/plain": [
"(228146, 3)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Xs, Ys = build_train_data(words, ctoix, ctx_win=CTX_WIN_SZ)\n",
"n, m = Xs.shape\n",
"n, m"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2bff6fc0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 5, 13, 13, 1, 0, 15, 12, 9, 22, 9, 1, 0, 1, 22, 1, 0, 9, 19,\n",
" 1, 2, 5, 12, 12, 1, 0, 19, 15, 16, 8, 9])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Ys[:30]"
]
},
{
"cell_type": "markdown",
"id": "5efb3e62",
"metadata": {},
"source": [
"#### Train, validation, test split"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "88e03a4b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"27"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"NCHARS = len(ctoix)\n",
"NCHARS"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f3b7b224",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ntrain=182516, nval=22814, ntest=22816\n"
]
}
],
"source": [
"ntrain, nval = int(n * 0.8), int(n * 0.1)\n",
"ntest = n - ntrain - nval\n",
"print(f\"{ntrain=}, {nval=}, {ntest=}\")\n",
"ixes = list(range(0, n))\n",
"# Shuffle with a dedicated, seeded RNG so that the train/val/test\n",
"# membership is the same after every \"Restart & Run All\".\n",
"SPLIT_SEED = 42\n",
"random.Random(SPLIT_SEED).shuffle(ixes)\n",
"valend = ntrain + nval\n",
"ixtr, ixval, ixtest = ixes[:ntrain], ixes[ntrain:valend], ixes[valend:]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "63041722",
"metadata": {},
"outputs": [],
"source": [
"(Xtr, Ytr), (Xval, Yval) = (Xs[ixtr], Ys[ixtr]), (Xs[ixval], Ys[ixval])\n",
"(Xtest, Ytest) = (Xs[ixtest], Ys[ixtest])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "6e639831",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[ 0, 0, 12],\n",
" [ 0, 13, 9],\n",
" [ 0, 20, 1],\n",
" ...,\n",
" [15, 13, 1],\n",
" [ 1, 22, 1],\n",
" [12, 15, 18]]),\n",
" tensor([15, 18, 2, ..., 25, 0, 1]))"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Xtest, Ytest"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "fc404c8e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([182516, 3]) torch.Size([22814, 3]) torch.Size([22816, 3])\n"
]
}
],
"source": [
"print(Xtr.shape, Xval.shape, Xtest.shape)"
]
},
{
"cell_type": "markdown",
"id": "c044dfc3",
"metadata": {},
"source": [
"### Model layer implementations\n",
"\n",
"Requires building the following layers:\n",
"\n",
"- Embedding\n",
"- Linear layer\n",
"- BatchNorm1D\n",
"- Tanh"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "e43faff1",
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Embedding:\n",
"    \"\"\"Lookup table that maps token indices to dense vectors.\n",
"\n",
"    The forward pass looks up every index of ``X`` in ``E`` and flattens\n",
"    the result per example, so a ``(batch, num_terms)`` index tensor\n",
"    becomes a ``(batch, num_terms * embed_dim)`` tensor.\n",
"    \"\"\"\n",
"\n",
"    #: number of rows of the table (distinct tokens)\n",
"    num_embed: int\n",
"    #: width of one embedding vector\n",
"    embed_dim: int\n",
"    #: the table itself, shape (num_embed, embed_dim), standard normal init\n",
"    E: T.Tensor = field(init=False, repr=False)\n",
"    _params: list[T.Tensor] = field(init=False, repr=False)\n",
"    #: last forward pass, mutated during forward pass\n",
"    out: t.Optional[T.Tensor] = field(default=None, repr=False)\n",
"\n",
"    def __post_init__(self):\n",
"        self.E = T.randn(self.num_embed, self.embed_dim)\n",
"        self._params = [self.E]\n",
"\n",
"    def parameters(self):\n",
"        \"\"\"Return the list of trainable tensors (only ``E``).\n",
"\n",
"        This is a regular method rather than a lambda stored on the\n",
"        instance: a lambda keeps a reference to the instance it was created\n",
"        on, so a ``deepcopy`` of the layer would keep handing out the\n",
"        tensors of the original.\n",
"        \"\"\"\n",
"        return self._params\n",
"\n",
"    def __call__(self, X):\n",
"        \"\"\"Embed the index tensor ``X`` and flatten it per example.\"\"\"\n",
"        batch_sz = X.shape[0]\n",
"        self.out = self.E[X].view(batch_sz, -1)\n",
"        return self.out\n",
"\n",
"\n",
"@dataclass\n",
"class Linear:\n",
"    \"\"\"Affine layer computing ``X @ W`` plus an optional bias ``b``.\n",
"\n",
"    ``W`` is drawn from ``U(-s, s)``.  When ``init_wt_scale`` is not given,\n",
"    ``s = wt_gain / sqrt(fanin)``; the resolved value is written back to\n",
"    ``init_wt_scale`` so that it shows up in the repr.  The bias, when\n",
"    enabled, starts at 0.01.\n",
"    \"\"\"\n",
"\n",
"    fanin: int\n",
"    fanout: int\n",
"    bias: bool = True\n",
"    # gain used in the Kaiming-He style weight initialisation\n",
"    wt_gain: float = 1.0\n",
"    # bound of the uniform init; derived from wt_gain and fanin if not set\n",
"    init_wt_scale: t.Optional[float] = None\n",
"    b: t.Optional[T.Tensor] = field(init=False, repr=False)\n",
"    W: T.Tensor = field(init=False, repr=False)\n",
"    _params: list[T.Tensor] = field(init=False, repr=False)\n",
"    #: last forward pass, mutated during forward pass\n",
"    out: t.Optional[T.Tensor] = field(default=None, repr=False)\n",
"\n",
"    def __post_init__(self):\n",
"        if self.init_wt_scale is None:\n",
"            self.init_wt_scale = self.wt_gain / (self.fanin**0.5)\n",
"        # sample from uniform random; T.empty replaces the legacy\n",
"        # T.FloatTensor(rows, cols) constructor\n",
"        self.W = T.empty(self.fanin, self.fanout).uniform_(\n",
"            -self.init_wt_scale, self.init_wt_scale\n",
"        )\n",
"        if self.bias:\n",
"            self.b = T.ones(1, self.fanout) * 0.01\n",
"            self._params = [self.W, self.b]\n",
"        else:\n",
"            self.b = None\n",
"            self._params = [self.W]\n",
"\n",
"    def parameters(self):\n",
"        \"\"\"Return the trainable tensors: ``[W, b]`` or ``[W]`` without bias.\n",
"\n",
"        A regular method (not a lambda stored on the instance) so that a\n",
"        ``deepcopy`` of the layer returns the copy's own tensors.\n",
"        \"\"\"\n",
"        return self._params\n",
"\n",
"    def __call__(self, X):\n",
"        \"\"\"Apply the layer to ``X`` and cache the result in ``out``.\"\"\"\n",
"        if self.bias:\n",
"            self.out = X @ self.W + self.b\n",
"        else:\n",
"            self.out = X @ self.W\n",
"        return self.out\n",
"\n",
"\n",
"@dataclass\n",
"class BatchNorm1D:\n",
"    \"\"\"Batch normalisation over dim 0 with running statistics.\n",
"\n",
"    In training mode every feature is standardised with the mean and\n",
"    variance of the current batch, and both are folded into exponential\n",
"    running averages.  In inference mode the running averages are used\n",
"    instead and nothing is updated.  The standardised values are scaled by\n",
"    ``gamma`` and shifted by ``beta``.\n",
"    \"\"\"\n",
"\n",
"    size: int\n",
"    #: weight of the newest batch in the running averages\n",
"    momentum: float = 0.01\n",
"    #: added to the batch variance before taking the square root\n",
"    eps: float = 1e-5\n",
"    #: scaling after standardising\n",
"    gamma: T.Tensor = field(init=False, repr=False)\n",
"    #: shift after standardising\n",
"    beta: T.Tensor = field(init=False, repr=False)\n",
"    _params: list[T.Tensor] = field(init=False, repr=False)\n",
"    #: running averages of mean and variance\n",
"    buffer_mean: T.Tensor = field(init=False, repr=False)\n",
"    buffer_var: T.Tensor = field(init=False, repr=False)\n",
"    #: last forward pass, mutated during forward pass\n",
"    out: t.Optional[T.Tensor] = field(default=None, repr=False)\n",
"\n",
"    def __post_init__(self):\n",
"        self.gamma = T.ones(1, self.size)\n",
"        self.beta = T.zeros(1, self.size)\n",
"        self._params = [self.gamma, self.beta]\n",
"        self.buffer_mean = T.zeros(1, self.size, requires_grad=False)\n",
"        self.buffer_var = T.ones(1, self.size, requires_grad=False)\n",
"\n",
"    def parameters(self):\n",
"        \"\"\"Return the trainable tensors ``[gamma, beta]``.\n",
"\n",
"        A regular method (not a lambda stored on the instance) so that a\n",
"        ``deepcopy`` of the layer returns the copy's own tensors.\n",
"        \"\"\"\n",
"        return self._params\n",
"\n",
"    def __call__(self, X, inference=None):\n",
"        \"\"\"Normalise ``X`` along dim 0.\n",
"\n",
"        Args:\n",
"            X: input whose dim 0 is the batch dimension.\n",
"            inference: ``True`` uses the running statistics; ``False`` uses\n",
"                the batch statistics and updates the running ones.  With\n",
"                the default ``None`` inference mode is chosen whenever\n",
"                autograd is disabled, so code running under\n",
"                ``T.no_grad()`` (loss evaluation, sampling) neither\n",
"                normalises with the statistics of the evaluated batch nor\n",
"                changes the running averages.\n",
"        \"\"\"\n",
"        if inference is None:\n",
"            inference = not T.is_grad_enabled()\n",
"        fwd_fn = self._fwd_inference if inference else self._fwd_train\n",
"        return fwd_fn(X)\n",
"\n",
"    def _fwd_train(self, X):\n",
"        \"\"\"Normalise with batch statistics and update the running ones.\"\"\"\n",
"        mu = X.mean(dim=0, keepdim=True)\n",
"        # eps is folded into the variance here, hence also into buffer_var\n",
"        var = X.var(dim=0, keepdim=True) + self.eps\n",
"        with T.no_grad():\n",
"            mom_old, mom_new = 1 - self.momentum, self.momentum\n",
"            self.buffer_mean = mom_old * self.buffer_mean + mom_new * mu\n",
"            self.buffer_var = mom_old * self.buffer_var + mom_new * var\n",
"        self.out = BatchNorm1D._fwd(\n",
"            X=X, mu=mu, var=var, gamma=self.gamma, beta=self.beta\n",
"        )\n",
"        return self.out\n",
"\n",
"    def _fwd_inference(self, X):\n",
"        \"\"\"Normalise with the running statistics; nothing is updated.\"\"\"\n",
"        with T.no_grad():\n",
"            self.out = BatchNorm1D._fwd(\n",
"                X=X,\n",
"                mu=self.buffer_mean,\n",
"                var=self.buffer_var,\n",
"                gamma=self.gamma,\n",
"                beta=self.beta,\n",
"            )\n",
"        return self.out\n",
"\n",
"    @staticmethod\n",
"    def _fwd(X, mu, var, gamma, beta):\n",
"        \"\"\"Standardise ``X`` with ``mu``/``var``, then scale and shift.\"\"\"\n",
"        X_std = (X - mu) / T.sqrt(var)\n",
"        return (X_std * gamma) + beta\n",
"\n",
"\n",
"@dataclass\n",
"class Tanh:\n",
"    \"\"\"Element-wise tanh activation; a layer without trainable tensors.\"\"\"\n",
"\n",
"    #: last forward pass, mutated during forward pass\n",
"    out: t.Optional[T.Tensor] = field(default=None, repr=False)\n",
"\n",
"    def __call__(self, X):\n",
"        \"\"\"Apply tanh to ``X`` and cache the result in ``out``.\"\"\"\n",
"        activated = T.tanh(X)\n",
"        self.out = activated\n",
"        return activated\n",
"\n",
"    def parameters(self):\n",
"        \"\"\"There is nothing to train in this layer.\"\"\"\n",
"        return []"
]
},
{
"cell_type": "markdown",
"id": "c1b154ed",
"metadata": {},
"source": [
"### Training loop\n",
"- Split data into batches\n",
"- Fwd pass, zero grad, loss.backward, batch gradient update.\n",
"- Keep track of learning losses for later plotting."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c39a65b9",
"metadata": {},
"outputs": [],
"source": [
"def train_loop(\n",
"    Xs,\n",
"    Ys,\n",
"    mdl,\n",
"    lr,\n",
"    num_iter,\n",
"    max_sub_iter=None,\n",
"    batch_sz=128,\n",
"    losses=None,\n",
"    wt_update_ratios=None,\n",
"    verbose=True,\n",
"):\n",
"    \"\"\"Train ``mdl`` in place with mini-batch gradient descent.\n",
"\n",
"    Args:\n",
"        Xs, Ys: input and target tensors, indexed along dim 0.\n",
"        mdl: dict with ``\"layers\"`` and ``\"params\"`` entries.\n",
"        lr: constant learning rate or schedule, resolved by ``_get_lr``.\n",
"        num_iter: number of outer iterations; each one reshuffles the rows\n",
"            and walks over them in batches of ``batch_sz``.\n",
"        max_sub_iter: stop an outer iteration after this many batches\n",
"            (``None`` or ``0`` means all batches).\n",
"        batch_sz: number of rows per batch.\n",
"        losses: list the per-batch losses are appended to; pass the list\n",
"            of an earlier call to extend its history.\n",
"        wt_update_ratios: dict filled by ``_update_params``.  It is not\n",
"            returned, so pass one in to read the ratios afterwards.\n",
"        verbose: show a tqdm progress bar with the latest batch loss.\n",
"\n",
"    Returns:\n",
"        ``losses`` with one float appended per processed batch.\n",
"    \"\"\"\n",
"    if wt_update_ratios is None:\n",
"        wt_update_ratios = {}\n",
"    if losses is None:\n",
"        losses = []\n",
"    max_sub_iter = max_sub_iter or np.inf\n",
"    nrows = Xs.shape[0]\n",
"    ixes = list(range(nrows))\n",
"    itrs = tqdm(range(num_iter)) if verbose else range(num_iter)\n",
"    # NOTE(review): starts at the number of recorded batch losses but is\n",
"    # advanced once per outer iteration; the two units only coincide when\n",
"    # max_sub_iter == 1.  Confirm which unit the lr schedule expects.\n",
"    total_iter = len(losses)\n",
"    for _ in itrs:\n",
"        total_iter += 1\n",
"        random.shuffle(ixes)\n",
"        loss = None\n",
"        # plain range: the slice bounds are ints rather than 0-dim tensors\n",
"        for sub_iter, begix in enumerate(\n",
"            range(0, nrows, batch_sz), start=1\n",
"        ):\n",
"            batch_ix = ixes[begix : begix + batch_sz]\n",
"            p = fwd_pass(Xs=Xs[batch_ix], mdl=mdl)\n",
"            loss = F.cross_entropy(input=p, target=Ys[batch_ix])\n",
"            losses.append(loss.item())\n",
"            # The output of each layer is an intermediate computation\n",
"            # the gradient is only needed for model params. We force\n",
"            # pytorch to keep these grads.\n",
"            retain_out_grad(mdl)\n",
"            _zero_grad(mdl=mdl)\n",
"            loss.backward()\n",
"            _update_params(\n",
"                mdl=mdl,\n",
"                lr=_get_lr(lr=lr, it=total_iter),\n",
"                wt_update_ratios=wt_update_ratios,\n",
"            )\n",
"            if sub_iter >= max_sub_iter:\n",
"                break\n",
"        # no batch ran (empty Xs): there is no loss to report\n",
"        if verbose and loss is not None:\n",
"            itrs.set_description(f\"Loss: {loss.item():.2f}\")\n",
"    return losses\n",
"\n",
"\n",
"def fwd_pass(Xs, mdl, act_fn=None, return_intermediates=False):\n",
"    \"\"\"Feed ``Xs`` through every layer of ``mdl[\"layers\"]`` in order.\n",
"\n",
"    ``act_fn`` and ``return_intermediates`` are accepted but not used.\n",
"    Returns the output of the last layer (``Xs`` itself if there is none).\n",
"    \"\"\"\n",
"    return ft.reduce(lambda acc, layer: layer(acc), mdl[\"layers\"], Xs)\n",
"\n",
"\n",
"def retain_out_grad(mdl):\n",
"    \"\"\"Call ``retain_grad()`` on the cached ``out`` of every layer.\n",
"\n",
"    Has to run after a forward pass has filled ``layer.out`` and before\n",
"    ``backward()``.\n",
"    \"\"\"\n",
"    cached_outputs = (layer.out for layer in mdl[\"layers\"])\n",
"    for out in cached_outputs:\n",
"        out.retain_grad()\n",
"\n",
"\n",
"def fwd_proba(Xs, mdl, act_fn=None):\n",
"    \"\"\"Run ``fwd_pass`` and turn its output into row-wise probabilities.\"\"\"\n",
"    logits = fwd_pass(Xs=Xs, mdl=mdl, act_fn=act_fn)\n",
"    return F.softmax(logits, dim=1)\n",
"\n",
"\n",
"def _get_lr(lr, it):\n",
"    \"\"\"Resolve the learning rate to use at iteration ``it``.\n",
"\n",
"    Args:\n",
"        lr: a constant (number or tensor), or a schedule given as a dict\n",
"            that maps ``(min_it, max_it)`` tuples to learning rates.\n",
"        it: current iteration counter.\n",
"\n",
"    Returns:\n",
"        ``lr`` itself when it is a constant, otherwise the rate of the\n",
"        first range with ``min_it <= it < max_it``.\n",
"\n",
"    Raises:\n",
"        ValueError: ``lr`` is a schedule and no range contains ``it``.\n",
"        NotImplementedError: ``lr`` is neither a constant nor a dict.\n",
"    \"\"\"\n",
"    if isinstance(lr, (int, float, T.TensorType, T.Tensor)):\n",
"        return lr\n",
"    if isinstance(lr, dict):\n",
"        for (min_it, max_it), rate in lr.items():\n",
"            # the iteration, not the rate, is tested against the bounds\n",
"            if min_it <= it < max_it:\n",
"                return rate\n",
"        raise ValueError(f\"Iteration {it} not in any range in {lr}\")\n",
"    raise NotImplementedError(\n",
"        f\"Don't know how to handle learning {lr} of type {type(lr)}\"\n",
"    )\n",
"\n",
"\n",
"def _zero_grad(mdl):\n",
"    \"\"\"Reset ``.grad`` of every tensor in ``mdl[\"params\"]`` to ``None``.\"\"\"\n",
"    for tensor in mdl[\"params\"]:\n",
"        setattr(tensor, \"grad\", None)\n",
"\n",
"\n",
"def _update_params(mdl, lr, wt_update_ratios):\n",
"    \"\"\"Take one gradient-descent step and log update/weight ratios.\n",
"\n",
"    Every tensor in ``mdl[\"params\"]`` is moved by ``-lr * grad``.  For the\n",
"    2-D tensors with more than one row, ``log10(std(lr * grad) /\n",
"    std(param))``, measured after the step, is appended to\n",
"    ``wt_update_ratios[\"W_<k>\"]``, where ``k`` counts those tensors in\n",
"    order of appearance.\n",
"    \"\"\"\n",
"    matrix_count = 0\n",
"    for tensor in mdl[\"params\"]:\n",
"        tensor.data -= lr * tensor.grad\n",
"        is_matrix = tensor.ndim == 2 and tensor.shape[0] != 1\n",
"        if not is_matrix:\n",
"            # 1-row tensors (biases, batchnorm gamma/beta) are not tracked\n",
"            continue\n",
"        matrix_count += 1\n",
"        step_std = (lr * tensor.grad).std()\n",
"        log_ratio = (step_std / tensor.data.std()).log10().item()\n",
"        wt_update_ratios.setdefault(f\"W_{matrix_count}\", []).append(log_ratio)\n",
"\n",
"\n",
"def _loss(Xs, Ys, mdl):\n",
"    \"\"\"Cross-entropy of ``mdl`` on ``(Xs, Ys)`` as a Python float.\"\"\"\n",
"    ce = F.cross_entropy(input=fwd_pass(Xs=Xs, mdl=mdl), target=Ys)\n",
"    return ce.item()\n",
"\n",
"\n",
"def _plot_losses(losses: list[float]):\n",
"    \"\"\"Plot ``losses`` against their position on the current axes.\n",
"\n",
"    ``losses`` is a sequence of floats, e.g. the list returned by\n",
"    ``train_loop``.\n",
"    \"\"\"\n",
"    plt.plot(losses)\n",
"    plt.title(\"Loss vs iteration\")\n",
"    plt.xlabel(\"Iter\")\n",
"    plt.ylabel(\"Cross entropy or NLL loss\")\n",
"    plt.grid()\n",
"\n",
"\n",
"@T.no_grad()\n",
"def generate_words(\n",
"    mdl, nwords, ctoix, ctx_win=CTX_WIN_SZ, generator=None, max_len=30\n",
"):\n",
"    \"\"\"Sample ``nwords`` words from ``mdl``, one character at a time.\n",
"\n",
"    All words are generated in parallel.  Each starts from a context of\n",
"    ``ctx_win`` DOTs; the next character is drawn from the softmax of the\n",
"    model output and appended to the context window.  A word is finished\n",
"    once DOT has been drawn for it.\n",
"\n",
"    Args:\n",
"        mdl: model dict as consumed by ``fwd_pass``.\n",
"        nwords: number of words to generate.\n",
"        ctoix: char -> index mapping; its inverse is used for decoding.\n",
"        ctx_win: context length fed to the model.\n",
"        generator: optional ``torch.Generator`` passed to the sampler.\n",
"        max_len: maximum number of sampling steps.\n",
"\n",
"    Returns:\n",
"        List of ``nwords`` strings without the terminating DOT.\n",
"    \"\"\"\n",
"    # decode with the inverse of the mapping that was passed in\n",
"    ixtoc = {ix: ch for ch, ix in ctoix.items()}\n",
"    lst_X = _ctx_to_X(ctx_chars=DOT * ctx_win, ctoix=ctoix).repeat(nwords, 1)\n",
"    words = [[] for _ in range(nwords)]\n",
"    for _ in range(max_len):\n",
"        new_X = T.multinomial(\n",
"            input=F.softmax(fwd_pass(lst_X, mdl), dim=1),\n",
"            num_samples=1,\n",
"            replacement=True,\n",
"            generator=generator,\n",
"        )\n",
"        char_added = False\n",
"        for w, ix in zip(words, new_X):\n",
"            if w and w[-1] == DOT:\n",
"                continue\n",
"            char_added = True\n",
"            w.append(ixtoc[ix.item()])\n",
"        # Build the shifted window as a new tensor.  An in-place\n",
"        # `lst_X[:, :-1] = lst_X[:, 1:]` copies between overlapping views\n",
"        # of the same storage, which is best avoided.\n",
"        lst_X = T.cat([lst_X[:, 1:], new_X], dim=1)\n",
"        if not char_added:\n",
"            break\n",
"    # Strip the end marker only when it is there: words that were cut off\n",
"    # at max_len keep their last character.\n",
"    return [\"\".join(w[:-1] if w and w[-1] == DOT else w) for w in words]\n",
"\n",
"\n",
"def num_params(mdl):\n",
"    \"\"\"Total number of scalar entries over all ``mdl[\"params\"]``.\"\"\"\n",
"    total = 0\n",
"    for tensor in mdl[\"params\"]:\n",
"        total += T.numel(tensor)\n",
"    return total\n",
"\n",
"\n",
"def _ctx_to_X(ctx_chars, ctoix):\n",
"    \"\"\"Encode ``ctx_chars`` as an index tensor with one leading row.\"\"\"\n",
"    indices = [ctoix[ch] for ch in ctx_chars]\n",
"    row = T.tensor(indices)\n",
"    return row.unsqueeze(dim=0)"
]
},
{
"cell_type": "markdown",
"id": "fb91021b",
"metadata": {},
"source": [
"#### Train, validation and test losses"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "e36e774d",
"metadata": {},
"outputs": [],
"source": [
"@T.no_grad()\n",
"def _split_loss(Xtr, Ytr, Xval, Yval, Xtest, Ytest, mdl, split):\n",
"    \"\"\"Print and return the loss of ``mdl`` on one named data split.\n",
"\n",
"    ``split`` selects the data: ``\"train\"``, ``\"val\"`` or ``\"test\"``; any\n",
"    other value raises ``NotImplementedError``.  Runs with autograd\n",
"    disabled.\n",
"    \"\"\"\n",
"    if split == \"train\":\n",
"        spl, X, Y = \"Train\", Xtr, Ytr\n",
"    elif split == \"val\":\n",
"        spl, X, Y = \"Validation\", Xval, Yval\n",
"    elif split == \"test\":\n",
"        spl, X, Y = \"Test\", Xtest, Ytest\n",
"    else:\n",
"        raise NotImplementedError(\"Split should be train, val or test\")\n",
"    loss = _loss(X, Y, mdl)\n",
"    print(f\"{spl} loss={loss:.4f}\")\n",
"    return loss\n",
"\n",
"\n",
"# Bind the three data splits once; callers only pass `mdl` and `split`.\n",
"split_loss = ft.partial(\n",
"    _split_loss,\n",
"    Xtr=Xtr,\n",
"    Ytr=Ytr,\n",
"    Xval=Xval,\n",
"    Yval=Yval,\n",
"    Xtest=Xtest,\n",
"    Ytest=Ytest,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "797704e9",
"metadata": {},
"source": [
"### Build model by stacking Layers"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "19c04dd3",
"metadata": {},
"outputs": [],
"source": [
"EMBED_DIM = 5\n",
"HIDDEN_DIM = 50\n",
"NUM_HIDDEN = 1"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "098fba0f",
"metadata": {},
"outputs": [],
"source": [
"def build_model(\n",
"    nchrs=NCHARS,\n",
"    ctx_win=CTX_WIN_SZ,\n",
"    embed_dim=EMBED_DIM,\n",
"    hidden_dim=HIDDEN_DIM,\n",
"    num_hidden=NUM_HIDDEN,\n",
"    lin_wt_gain=1.0,\n",
"    init_wt_scale=None,\n",
"    batchnorm=False,\n",
"):\n",
"    \"\"\"Assemble the model as an ordered list of layers plus its parameters.\n",
"\n",
"    Layout: Embedding, then Linear + activation (one block always, plus\n",
"    ``num_hidden - 1`` more), then an output Linear of width ``nchrs``.\n",
"    With ``batchnorm`` every activation is BatchNorm1D followed by Tanh and\n",
"    the output Linear is followed by a BatchNorm1D whose gamma is scaled\n",
"    by 0.01; without it the activation is Tanh alone and the output Linear\n",
"    uses a gain of 0.1 instead.\n",
"\n",
"    ``lin_wt_gain`` and ``init_wt_scale`` are forwarded to the hidden\n",
"    Linear layers.  All parameters get ``requires_grad = True`` and the\n",
"    layer list is printed.\n",
"\n",
"    Returns:\n",
"        ``dict(layers=..., params=...)``.\n",
"    \"\"\"\n",
"    assert hidden_dim > 0, \"at least one hidden dim is needed\"\n",
"    lkwrgs = dict(wt_gain=lin_wt_gain, init_wt_scale=init_wt_scale)\n",
"\n",
"    def activation_fn():\n",
"        # new layer objects on every call\n",
"        if batchnorm:\n",
"            return [BatchNorm1D(size=hidden_dim), Tanh()]\n",
"        return [Tanh()]\n",
"\n",
"    layers = [\n",
"        Embedding(num_embed=nchrs, embed_dim=embed_dim),\n",
"        Linear(fanin=embed_dim * ctx_win, fanout=hidden_dim, **lkwrgs),\n",
"        *activation_fn(),\n",
"    ]\n",
"    for _ in range(num_hidden - 1):\n",
"        layers.append(Linear(fanin=hidden_dim, fanout=hidden_dim, **lkwrgs))\n",
"        layers.extend(activation_fn())\n",
"    if batchnorm:\n",
"        out_linear = Linear(fanin=hidden_dim, fanout=nchrs, wt_gain=1.0)\n",
"        bn = BatchNorm1D(size=out_linear.fanout)\n",
"        # last layer we scale the gamma not the weights as they are standardised\n",
"        bn.gamma *= 0.01\n",
"        layers.extend([out_linear, bn])\n",
"    else:\n",
"        layers.append(Linear(fanin=hidden_dim, fanout=nchrs, wt_gain=0.1))\n",
"    params = []\n",
"    for lyr in layers:\n",
"        params.extend(lyr.parameters())\n",
"    for p in params:\n",
"        p.requires_grad = True\n",
"    lyrs_str = \"\\n\\t\".join(map(str, layers))\n",
"    print(f\"Layers: \\n\\t{lyrs_str}\")\n",
"    return dict(layers=layers, params=params)"
]
},
{
"cell_type": "markdown",
"id": "03e0ee0f",
"metadata": {},
"source": [
"### Test if it works and try to overfit"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "34d5c073",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Layers: \n",
"\tEmbedding(num_embed=27, embed_dim=5)\n",
"\tLinear(fanin=15, fanout=50, bias=True, wt_gain=1.6666666666666667, init_wt_scale=0.43033148291193524)\n",
"\tBatchNorm1D(size=50, momentum=0.01, eps=1e-05)\n",
"\tTanh()\n",
"\tLinear(fanin=50, fanout=27, bias=True, wt_gain=1.0, init_wt_scale=0.1414213562373095)\n",
"\tBatchNorm1D(size=27, momentum=0.01, eps=1e-05)\n"
]
}
],
"source": [
"mdl = build_model(num_hidden=1, lin_wt_gain=5 / 3, batchnorm=True)\n",
"Xtr_sml, Ytr_sml = Xtr[:50], Ytr[:50]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "20fde998",
"metadata": {},
"outputs": [],
"source": [
"losses = []"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "4cd6b15d",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "10249d59e96c48b29b826fe79bdca94f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"We expect this to be large as we overfitted on a small batch.\n",
"Train loss=5.8337\n",
"Validation loss=5.8518\n"
]
},
{
"data": {
"text/plain": [
"5.851779460906982"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"