{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "8cb2d883-6db2-4a1f-b699-ff8343df0fc8", "metadata": {}, "outputs": [], "source": [ "import logging, torch, torchvision, torch.nn.functional as F, torchvision.transforms.functional as TF, matplotlib as mpl\n", "import fastcore.all as fc\n", "from matplotlib import pyplot as plt\n", "from functools import partial\n", "from torch import tensor,nn,optim, einsum\n", "from torch.utils.data import DataLoader, default_collate\n", "from torchvision.utils import make_grid\n", "from datasets import load_dataset,load_dataset_builder\n", "from miniai.datasets import *\n", "from miniai.learner import *\n", "from miniai.conv import *\n", "from fastcore.all import *\n", "from fastprogress import progress_bar\n", "from einops import rearrange" ] }, { "cell_type": "code", "execution_count": 2, "id": "e8273fb3", "metadata": {}, "outputs": [], "source": [ "mpl.rcParams['image.cmap'] = 'gray_r'\n", "logging.disable(logging.WARNING)" ] }, { "cell_type": "markdown", "id": "33e945bc-26a4-4194-ba12-4cbb7b79e49d", "metadata": {}, "source": [ "Load a dataset:" ] }, { "cell_type": "code", "execution_count": 3, "id": "99edd708", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6b08231c14e84b5daa7000741e36d79d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00 b c 1 1\") + h\n", "\n", " h = self.block2(h)\n", " return h + self.res_conv(x)\n", "\n", "class Attention(nn.Module):\n", " def __init__(self, dim, heads=4, dim_head=32):\n", " super().__init__()\n", " self.scale = dim_head**-0.5\n", " self.heads = heads\n", " hidden_dim = dim_head * heads\n", " self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n", " self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n", "\n", " def forward(self, x):\n", " b, c, h, w = x.shape\n", " qkv = self.to_qkv(x).chunk(3, dim=1)\n", " q, k, v = map(\n", " lambda t: rearrange(t, \"b (h c) x y -> b h c (x y)\", 
h=self.heads), qkv\n",
"        )\n",
"        q = q * self.scale\n",
"\n",
"        # sim has shape (b, heads, x*y, x*y): its memory grows quadratically with the number of pixels,\n",
"        # which is why the sampling cell's CUDA out-of-memory traceback is raised from this einsum.\n",
"        sim = einsum(\"b h d i, b h d j -> b h i j\", q, k)\n",
"        # Subtract each row's max before softmax for numerical stability (softmax is shift-invariant).\n",
"        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\n",
"        attn = sim.softmax(dim=-1)\n",
"\n",
"        out = einsum(\"b h i j, b h d j -> b h i d\", attn, v)\n",
"        # Merge the heads back into channels and restore the (h, w) spatial grid.\n",
"        out = rearrange(out, \"b h (x y) d -> b (h d) x y\", x=h, y=w)\n",
"        return self.to_out(out)\n",
"\n",
"# Normalize the input, then apply the wrapped module `fn`. GroupNorm with a single group\n",
"# normalizes each sample jointly over all channels and spatial positions.\n",
"class PreNorm(nn.Module):\n",
"    def __init__(self, dim, fn):\n",
"        super().__init__()\n",
"        self.fn = fn\n",
"        self.norm = nn.GroupNorm(1, dim)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.norm(x)\n",
"        return self.fn(x)\n",
"\n",
"# U-Net used as the noise predictor: forward(x, time) returns a tensor with `out_dim` channels\n",
"# (default: `channels`). Every level uses two ResnetBlocks plus full (quadratic) self-attention.\n",
"class Unet(nn.Module):\n",
"    def __init__(\n",
"        self,\n",
"        dim,\n",
"        init_dim = None,\n",
"        out_dim = None,\n",
"        dim_mults=(1, 2, 4, 8),\n",
"        channels = 3,\n",
"        resnet_block_groups = 8\n",
"    ):\n",
"        super().__init__()\n",
"\n",
"        # determine dimensions\n",
"\n",
"        self.channels = channels\n",
"\n",
"        # `default` (defined earlier in the notebook) presumably returns init_dim unless it is None.\n",
"        init_dim = default(init_dim, dim)\n",
"        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding = 3)\n",
"\n",
"        # Channel width per level, e.g. dim=32, dim_mults=(1, 2, 4) -> dims [32, 32, 64, 128] and\n",
"        # in_out [(32, 32), (32, 64), (64, 128)] when init_dim == dim.\n",
"        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]\n",
"        in_out = list(zip(dims[:-1], dims[1:]))\n",
"\n",
"        block_klass = partial(ResnetBlock, groups = resnet_block_groups)\n",
"\n",
"        # time embeddings\n",
"\n",
"        time_dim = dim * 4\n",
"\n",
"        self.time_mlp = nn.Sequential(\n",
"            SinusoidalPositionEmbeddings(dim),\n",
"            nn.Linear(dim, time_dim),\n",
"            nn.GELU(),\n",
"            nn.Linear(time_dim, time_dim)\n",
"        )\n",
"\n",
"        # layers\n",
"\n",
"        self.downs = nn.ModuleList([])\n",
"        self.ups = nn.ModuleList([])\n",
"        num_resolutions = len(in_out)\n",
"\n",
"        # Encoder: per level, two ResnetBlocks and attention at width dim_in, then a change of width to\n",
"        # dim_out (Downsample on every level but the last, which uses a plain 3x3 conv instead).\n",
"        for ind, (dim_in, dim_out) in enumerate(in_out):\n",
"            is_last = ind >= (num_resolutions - 1)\n",
"\n",
"            self.downs.append(nn.ModuleList([\n",
"                block_klass(dim_in, dim_in, time_emb_dim = time_dim),\n",
"                block_klass(dim_in, dim_in, time_emb_dim = time_dim),\n",
"                Residual(PreNorm(dim_in, Attention(dim_in))),\n",
"                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)\n",
"            ]))\n",
"\n",
"        mid_dim = dims[-1]\n",
"        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)\n",
"        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))\n",
"        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)\n",
"\n",
"        # Decoder mirrors the encoder. Each block's input is dim_out + dim_in channels because a\n",
"        # dim_in-channel skip tensor saved by the encoder is concatenated onto x first.\n",
"        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):\n",
"            is_last = ind == (len(in_out) - 1)\n",
"\n",
"            self.ups.append(nn.ModuleList([\n",
"                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),\n",
"                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),\n",
"                Residual(PreNorm(dim_out, Attention(dim_out))),\n",
"                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)\n",
"            ]))\n",
"\n",
"        default_out_dim = channels \n",
"        self.out_dim = default(out_dim, default_out_dim)\n",
"\n",
"        # Input is the decoder output concatenated with the init_conv output r (init_dim channels each).\n",
"        # NOTE(review): `dim * 2` only matches when init_dim == dim (the default); a different\n",
"        # init_dim would give this block the wrong input width.\n",
"        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)\n",
"        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)\n",
"\n",
"    # x: image batch; time: per-sample timesteps fed to the sinusoidal embedding.\n",
"    # x_self_cond is accepted but never used.\n",
"    def forward(self, x, time, x_self_cond = None):\n",
"        x = self.init_conv(x)\n",
"        r = x.clone() # kept for the final long skip connection\n",
"\n",
"        t = self.time_mlp(time)\n",
"\n",
"        h = [] # stack of skip tensors: two pushed per encoder level, two popped per decoder level\n",
"\n",
"        for block1, block2, attn, downsample in self.downs:\n",
"            x = block1(x, t)\n",
"            h.append(x)\n",
"\n",
"            x = block2(x, t)\n",
"            x = attn(x)\n",
"            h.append(x)\n",
"\n",
"            x = downsample(x)\n",
"\n",
"        x = self.mid_block1(x, t)\n",
"        x = self.mid_attn(x)\n",
"        x = self.mid_block2(x, t)\n",
"\n",
"        for block1, block2, attn, upsample in self.ups:\n",
"            x = torch.cat((x, h.pop()), dim = 1)\n",
"            x = block1(x, t)\n",
"\n",
"            x = torch.cat((x, h.pop()), dim = 1)\n",
"            x = block2(x, t)\n",
"            x = attn(x)\n",
"\n",
"            x = upsample(x)\n",
"\n",
"        x = torch.cat((x, r), dim = 1)\n",
"\n",
"        x = self.final_res_block(x, t)\n",
"        return self.final_conv(x)"
] }, { "cell_type": "code", "execution_count": 7, "id": "aa916302-00c5-4ec0-ac69-de4dccce755f", "metadata": {}, "outputs": [], "source": [
"class DDPMCB(Callback):\n",
"    # Run after DeviceCB so the batch has (presumably) already been moved to the device in before_batch.\n",
"    order = DeviceCB.order+1\n",
" def __init__(self, n_steps, beta_min, beta_max):\n", " store_attr()\n", " try: self.device = L(self.learn.cbs).filter(f=fc.risinstance(DeviceCB))[0].device\n", " except: self.device=def_device\n", " self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.device) # variance schedule, linearly increased with timestep\n", " self.alpha = 1. - self.beta \n", " self.alpha_bar = torch.cumprod(self.alpha, dim=0)\n", " self.sigma = torch.sqrt(self.beta)\n", "\n", " def before_batch(self):\n", " eps = torch.randn(self.learn.batch[0].shape, device=self.learn.batch[0].device) # noise, x_T\n", " x0 = self.learn.batch[0] # original images, x_0\n", " batch_size = x0.shape[0]\n", " t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long) # select random timesteps\n", " alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)\n", " \n", " xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps #noisify the image\n", " self.learn.batch = (xt, t, eps) # input to our model is noisy image and timestep, ground truth is the noise \n", " \n", " @torch.no_grad()\n", " def sample(self, image_size, batch_size=16, channels=3):\n", " shape = (batch_size, channels, image_size, image_size)\n", " self.learn.model.to(self.device)\n", " xt = torch.randn(shape, device=self.device)\n", " with torch.profiler.profile(\n", " schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),\n", " on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/ddpm_sampling_wo_no_grad'),\n", " record_shapes=True,\n", " profile_memory=True,\n", " with_stack=True\n", " ) as prof:\n", " for t in reversed(range(self.n_steps)):\n", " t_batch = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)\n", " z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)\n", " alpha_t = self.alpha[t] # get noise level at current timestep\n", " alpha_bar_t = self.alpha_bar[t]\n", " sigma_t = self.sigma[t]\n", 
" alpha_bar_t_1 = self.alpha_bar[t-1] if t > 0 else torch.tensor(1, device=xt.device)\n", " beta_bar_t = 1 - alpha_bar_t\n", " beta_bar_t_1 = 1 - alpha_bar_t_1\n", " x0hat = (xt - torch.sqrt(beta_bar_t) * self.learn.model(xt, t_batch))/torch.sqrt(alpha_bar_t)\n", " x0hat = torch.clamp(x0hat, -1, 1)\n", " xt = x0hat * torch.sqrt(alpha_bar_t_1)*(1-alpha_t)/beta_bar_t + xt * torch.sqrt(alpha_t)*beta_bar_t_1/beta_bar_t + sigma_t*z \n", " #xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * self.model(xt, t_batch)) + sigma_t*z # predict x_(t-1) in accordance to Algorithm 2 in paper\n", " prof.step()\n", " return xt\n", " \n", " def predict(self): self.learn.preds = self.learn.model(self.learn.batch[0],self.learn.batch[1])\n", " def get_loss(self): self.learn.loss = self.learn.loss_func(self.learn.preds, self.learn.batch[2])\n", " def backward(self): self.learn.loss.backward()\n", " def step(self): self.learn.opt.step()\n", " def zero_grad(self): self.learn.opt.zero_grad()" ] }, { "cell_type": "code", "execution_count": 8, "id": "bc78b703-9e50-452b-903c-218b08af2391", "metadata": {}, "outputs": [], "source": [ "class DDPMMetricsCB(MetricsCB):\n", " def __init__(self):\n", " super().__init__()\n", " def after_batch(self): self.loss.update(to_cpu(self.learn.loss), weight=len(x))" ] }, { "cell_type": "code", "execution_count": 9, "id": "30733743", "metadata": {}, "outputs": [], "source": [ "class ProfilerCB(Callback):\n", " order = 30\n", " def __init__(self, **kwargs): self.prof = torch.profiler.profile(**kwargs)\n", " def before_fit(self): self.prof.start()\n", " def after_batch(self): self.prof.step()\n", " def after_fit(self): self.prof.stop()" ] }, { "cell_type": "code", "execution_count": 10, "id": "07704f2c-2c5e-4422-9134-d81b9016c1a5", "metadata": {}, "outputs": [], "source": [ "model = Unet(dim=32, channels=1, dim_mults=(1,2,4,))" ] }, { "cell_type": "code", "execution_count": 11, "id": "e64d43f5", "metadata": {}, "outputs": [], "source": [ 
"profiler_args = {'schedule': torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),\n", " 'on_trace_ready': torch.profiler.tensorboard_trace_handler('./log/ddpm_training'),\n", " 'record_shapes': True,\n", " 'profile_memory': True,\n", " 'with_stack': True\n", " }" ] }, { "cell_type": "code", "execution_count": 12, "id": "b78c80e8-1bb5-4591-9021-40b2f41468be", "metadata": {}, "outputs": [], "source": [ "cbs = [DDPMCB(n_steps=1000, beta_min=0.0001, beta_max=0.02), DeviceCB(), ProgressCB(),DDPMMetricsCB(), ProfilerCB(**profiler_args)]\n", "learn = Learner(model, dls, nn.MSELoss(), lr=1e-3, cbs=cbs, opt_func=optim.Adam)" ] }, { "cell_type": "code", "execution_count": 13, "id": "1fbe1213-6e5f-4879-8414-574a2d393914", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "{'loss': '0.055', 'epoch': 0, 'train': True}

{'loss': '0.027', 'epoch': 0, 'train': False}

{'loss': '0.024', 'epoch': 1, 'train': True}

{'loss': '0.022', 'epoch': 1, 'train': False}

{'loss': '0.021', 'epoch': 2, 'train': True}

{'loss': '0.021', 'epoch': 2, 'train': False}

{'loss': '0.019', 'epoch': 3, 'train': True}

{'loss': '0.019', 'epoch': 3, 'train': False}

{'loss': '0.018', 'epoch': 4, 'train': True}

{'loss': '0.018', 'epoch': 4, 'train': False}" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(5) " ] }, { "cell_type": "markdown", "id": "3a335290-ad68-4c18-8bc5-85b4753dceda", "metadata": {}, "source": [ "Viewing the predictions on images with increasing noise levels:" ] }, { "cell_type": "code", "execution_count": 12, "id": "6e98b94f-38c5-4474-9e49-721201f2a188", "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 79.35 GiB total capacity; 49.87 GiB already allocated; 197.69 MiB free; 50.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn [12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m batch_size \u001b[39m=\u001b[39m \u001b[39m16\u001b[39m\n\u001b[0;32m----> 2\u001b[0m samples \u001b[39m=\u001b[39m learn\u001b[39m.\u001b[39;49mcbs[\u001b[39m0\u001b[39;49m]\u001b[39m.\u001b[39;49msample(\u001b[39m32\u001b[39;49m, batch_size\u001b[39m=\u001b[39;49mbatch_size,channels\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m)\n", "Cell \u001b[0;32mIn [8], line 43\u001b[0m, in \u001b[0;36mDDPMCB.sample\u001b[0;34m(self, image_size, batch_size, channels)\u001b[0m\n\u001b[1;32m 41\u001b[0m beta_bar_t \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m \u001b[39m-\u001b[39m alpha_bar_t\n\u001b[1;32m 42\u001b[0m beta_bar_t_1 \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m \u001b[39m-\u001b[39m alpha_bar_t_1\n\u001b[0;32m---> 43\u001b[0m x0hat \u001b[39m=\u001b[39m (xt \u001b[39m-\u001b[39m torch\u001b[39m.\u001b[39msqrt(beta_bar_t) \u001b[39m*\u001b[39m 
\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlearn\u001b[39m.\u001b[39;49mmodel(xt, t_batch))\u001b[39m/\u001b[39mtorch\u001b[39m.\u001b[39msqrt(alpha_bar_t)\n\u001b[1;32m 44\u001b[0m x0hat \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mclamp(x0hat, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, \u001b[39m1\u001b[39m)\n\u001b[1;32m 45\u001b[0m xt \u001b[39m=\u001b[39m x0hat \u001b[39m*\u001b[39m torch\u001b[39m.\u001b[39msqrt(alpha_bar_t_1)\u001b[39m*\u001b[39m(\u001b[39m1\u001b[39m\u001b[39m-\u001b[39malpha_t)\u001b[39m/\u001b[39mbeta_bar_t \u001b[39m+\u001b[39m xt \u001b[39m*\u001b[39m torch\u001b[39m.\u001b[39msqrt(alpha_t)\u001b[39m*\u001b[39mbeta_bar_t_1\u001b[39m/\u001b[39mbeta_bar_t \u001b[39m+\u001b[39m sigma_t\u001b[39m*\u001b[39mz \n", "File \u001b[0;32m~/anaconda3/envs/course22p2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, 
non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", "Cell \u001b[0;32mIn [6], line 206\u001b[0m, in \u001b[0;36mUnet.forward\u001b[0;34m(self, x, time, x_self_cond)\u001b[0m\n\u001b[1;32m 203\u001b[0m h\u001b[39m.\u001b[39mappend(x)\n\u001b[1;32m 205\u001b[0m x \u001b[39m=\u001b[39m block2(x, t)\n\u001b[0;32m--> 206\u001b[0m x \u001b[39m=\u001b[39m attn(x)\n\u001b[1;32m 207\u001b[0m h\u001b[39m.\u001b[39mappend(x)\n\u001b[1;32m 209\u001b[0m x \u001b[39m=\u001b[39m downsample(x)\n", "File \u001b[0;32m~/anaconda3/envs/course22p2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", "Cell \u001b[0;32mIn [6], line 18\u001b[0m, in \u001b[0;36mResidual.forward\u001b[0;34m(self, x, *args, **kwargs)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[39mdef\u001b[39;00m 
\u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, x, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m---> 18\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfn(x, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs) \u001b[39m+\u001b[39m x\n", "File \u001b[0;32m~/anaconda3/envs/course22p2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", "Cell \u001b[0;32mIn [6], line 119\u001b[0m, in \u001b[0;36mPreNorm.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, x):\n\u001b[1;32m 118\u001b[0m x \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnorm(x)\n\u001b[0;32m--> 
119\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfn(x)\n", "File \u001b[0;32m~/anaconda3/envs/course22p2/lib/python3.10/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", "Cell \u001b[0;32mIn [6], line 103\u001b[0m, in \u001b[0;36mAttention.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 98\u001b[0m q, k, v \u001b[39m=\u001b[39m \u001b[39mmap\u001b[39m(\n\u001b[1;32m 99\u001b[0m \u001b[39mlambda\u001b[39;00m t: rearrange(t, \u001b[39m\"\u001b[39m\u001b[39mb (h c) x y -> b h c (x y)\u001b[39m\u001b[39m\"\u001b[39m, h\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mheads), qkv\n\u001b[1;32m 100\u001b[0m )\n\u001b[1;32m 101\u001b[0m q \u001b[39m=\u001b[39m q \u001b[39m*\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscale\n\u001b[0;32m--> 
103\u001b[0m sim \u001b[39m=\u001b[39m einsum(\u001b[39m\"\u001b[39;49m\u001b[39mb h d i, b h d j -> b h i j\u001b[39;49m\u001b[39m\"\u001b[39;49m, q, k)\n\u001b[1;32m 104\u001b[0m sim \u001b[39m=\u001b[39m sim \u001b[39m-\u001b[39m sim\u001b[39m.\u001b[39mamax(dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, keepdim\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\u001b[39m.\u001b[39mdetach()\n\u001b[1;32m 105\u001b[0m attn \u001b[39m=\u001b[39m sim\u001b[39m.\u001b[39msoftmax(dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n", "File \u001b[0;32m~/anaconda3/envs/course22p2/lib/python3.10/site-packages/torch/functional.py:360\u001b[0m, in \u001b[0;36meinsum\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[39m# recurse incase operands contains value that has torch function\u001b[39;00m\n\u001b[1;32m 357\u001b[0m \u001b[39m# in the original implementation this line is omitted\u001b[39;00m\n\u001b[1;32m 358\u001b[0m \u001b[39mreturn\u001b[39;00m einsum(equation, \u001b[39m*\u001b[39m_operands)\n\u001b[0;32m--> 360\u001b[0m \u001b[39mreturn\u001b[39;00m _VF\u001b[39m.\u001b[39;49meinsum(equation, operands)\n", "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 79.35 GiB total capacity; 49.87 GiB already allocated; 197.69 MiB free; 50.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. 
See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF" ] } ], "source": [ "batch_size = 16\n", "samples = learn.cbs[0].sample(32, batch_size=batch_size,channels=1)" ] }, { "cell_type": "code", "execution_count": null, "id": "680058ba-9cc8-4327-b2de-c5a96f058c48", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([16, 1, 32, 32])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "samples.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "c18b3adb-cd33-42f0-b058-5496fb5e3508", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "

" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_images(-1*samples, figsize=(5,5))" ] } ], "metadata": { "kernelspec": { "display_name": "course22p2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "vscode": { "interpreter": { "hash": "0652906208a1dcd94e9ea7623081d93dd4d2f6cda070da042189d63fdc8dadfe" } } }, "nbformat": 4, "nbformat_minor": 5 }