{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "273a6ade", "metadata": {}, "outputs": [], "source": [ "import math,torch\n", "from torch import nn\n", "from miniai.activations import *" ] }, { "cell_type": "code", "execution_count": null, "id": "e2c95260", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "id": "64e5ea7b", "metadata": {}, "outputs": [], "source": [ "from diffusers.models.attention import AttentionBlock" ] }, { "cell_type": "code", "execution_count": null, "id": "43bd330b", "metadata": {}, "outputs": [], "source": [ "set_seed(42)\n", "x = torch.randn(64,32,16,16)" ] }, { "cell_type": "code", "execution_count": null, "id": "8174db82", "metadata": {}, "outputs": [], "source": [ "t = x.view(*x.shape[:2], -1).transpose(1, 2)\n", "t.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "51245dae", "metadata": {}, "outputs": [], "source": [ "ni = 32" ] }, { "cell_type": "code", "execution_count": null, "id": "6f10cb36", "metadata": {}, "outputs": [], "source": [ "sk = nn.Linear(ni, ni)\n", "sq = nn.Linear(ni, ni)\n", "sv = nn.Linear(ni, ni)" ] }, { "cell_type": "code", "execution_count": null, "id": "15237a98", "metadata": {}, "outputs": [], "source": [ "k = sk(t)\n", "q = sq(t)\n", "v = sv(t)" ] }, { "cell_type": "code", "execution_count": null, "id": "d34cd0dd", "metadata": {}, "outputs": [], "source": [ "(q@k.transpose(1,2)).shape" ] }, { "cell_type": "code", "execution_count": null, "id": "15062786", "metadata": {}, "outputs": [], "source": [ "class SelfAttention(nn.Module):\n", " def __init__(self, ni):\n", " super().__init__()\n", " self.scale = math.sqrt(ni)\n", " self.norm = nn.GroupNorm(1, ni)\n", " self.q = nn.Linear(ni, ni)\n", " self.k = nn.Linear(ni, ni)\n", " self.v = nn.Linear(ni, ni)\n", " self.proj = nn.Linear(ni, ni)\n", " \n", " def forward(self, x):\n", " inp = x\n", " n,c,h,w = x.shape\n", " x = self.norm(x)\n", " x = x.view(n, c, -1).transpose(1, 2)\n", " q = self.q(x)\n", " k = self.k(x)\n", " v = self.v(x)\n", " s = (q@k.transpose(1,2))/self.scale\n", " x = s.softmax(dim=-1)@v\n", " x = self.proj(x)\n", " x = x.transpose(1,2).reshape(n,c,h,w)\n", " return x+inp" ] }, { "cell_type": "code", "execution_count": null, "id": "fcb48706", "metadata": {}, "outputs": [], "source": [ "sa = SelfAttention(32)" ] }, { "cell_type": "code", "execution_count": null, "id": "3fb4ae6d", "metadata": {}, "outputs": [], "source": [ "ra = sa(x)\n", "ra.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "ce69830a", "metadata": {}, "outputs": [], "source": [ "ra[0,0,0]" ] }, { "cell_type": "code", "execution_count": null, "id": "f1a380fa", "metadata": {}, "outputs": [], "source": [ "def cp_parms(a,b):\n", " b.weight = a.weight\n", " b.bias = a.bias" ] }, { "cell_type": "code", "execution_count": null, "id": "bc6e3969", "metadata": {}, "outputs": [], "source": [ "at = AttentionBlock(32, norm_num_groups=1)\n", "src = sa.q,sa.k,sa.v,sa.proj,sa.norm\n", "dst = at.query,at.key,at.value,at.proj_attn,at.group_norm\n", "for s,d in zip(src,dst): cp_parms(s,d)" ] }, { "cell_type": "code", "execution_count": null, "id": "5bfc0087", "metadata": {}, "outputs": [], "source": [ "rb = at(x)\n", "rb[0,0,0]" ] }, { "cell_type": "code", "execution_count": null, "id": "5a4f25e8", "metadata": {}, "outputs": [], "source": [ "sqkv = nn.Linear(ni, ni*3)\n", "st = sqkv(t)\n", "st.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "64df0786", "metadata": {}, 
"outputs": [], "source": [ "q,k,v = torch.chunk(st, 3, dim=-1)\n", "q.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "afdd291d", "metadata": {}, "outputs": [], "source": [ "(k@q.transpose(1,2)).shape" ] }, { "cell_type": "code", "execution_count": null, "id": "cde31928", "metadata": {}, "outputs": [], "source": [ "class SelfAttention(nn.Module):\n", " def __init__(self, ni):\n", " super().__init__()\n", " self.scale = math.sqrt(ni)\n", " self.norm = nn.BatchNorm2d(ni)\n", " self.qkv = nn.Linear(ni, ni*3)\n", " self.proj = nn.Linear(ni, ni)\n", " \n", " def forward(self, inp):\n", " n,c,h,w = inp.shape\n", " x = self.norm(inp).view(n, c, -1).transpose(1, 2)\n", " q,k,v = torch.chunk(self.qkv(x), 3, dim=-1)\n", " s = (q@k.transpose(1,2))/self.scale\n", " x = s.softmax(dim=-1)@v\n", " x = self.proj(x).transpose(1,2).reshape(n,c,h,w)\n", " return x+inp" ] }, { "cell_type": "code", "execution_count": null, "id": "a1caa223", "metadata": {}, "outputs": [], "source": [ "class SelfAttention(nn.Module):\n", " def __init__(self, ni):\n", " super().__init__()\n", " self.scale = math.sqrt(ni)\n", " self.norm = nn.BatchNorm2d(ni)\n", " self.qkv = nn.Linear(ni, ni*3)\n", " self.proj = nn.Linear(ni, ni)\n", " \n", " def forward(self, x):\n", " x = self.norm(x).transpose(1, 2)\n", " q,k,v = torch.chunk(self.qkv(x), 3, dim=-1)\n", " s = (q@k.transpose(1,2))/self.scale\n", " x = s.softmax(dim=-1)@v\n", " return self.proj(x).transpose(1,2)" ] }, { "cell_type": "code", "execution_count": null, "id": "67202ea5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([64, 32, 16, 16])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sa = SelfAttention(32)\n", "sa(x).shape" ] }, { "cell_type": "code", "execution_count": null, "id": "eadf0c2e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1.0047, grad_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sa(x).std()" ] }, { "cell_type": "code", "execution_count": null, "id": "6a167b96", "metadata": {}, "outputs": [], "source": [ "def heads_to_batch(x, heads):\n", " n,sl,d = x.shape\n", " x = x.reshape(n, sl, heads, -1)\n", " return x.transpose(2, 1).reshape(n*heads,sl,-1)\n", "\n", "def batch_to_heads(x, heads):\n", " n,sl,d = x.shape\n", " x = x.reshape(-1, heads, sl, d)\n", " return x.transpose(2, 1).reshape(-1,sl,d*heads)" ] }, { "cell_type": "code", "execution_count": null, "id": "11734bcd", "metadata": {}, "outputs": [], "source": [ "from einops import rearrange" ] }, { "cell_type": "code", "execution_count": null, "id": "7c3466d6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 256, 32]), torch.Size([512, 256, 4]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t2 = rearrange(t , 'n s (h d) -> (n h) s d', h=8)\n", "t.shape, t2.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "64105f9b", "metadata": {}, "outputs": [], "source": [ "t3 = rearrange(t2, '(n h) s d -> n s (h d)', h=8)" ] }, { "cell_type": "code", "execution_count": null, "id": "0f34fabe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([512, 256, 4]), torch.Size([64, 256, 32]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t2.shape,t3.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "b4739d87", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(True)" ] }, 
"execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(t==t3).all()" ] }, { "cell_type": "code", "execution_count": null, "id": "ced6c513", "metadata": {}, "outputs": [], "source": [ "class SelfAttentionMultiHead(nn.Module):\n", " def __init__(self, ni, nheads):\n", " super().__init__()\n", " self.nheads = nheads\n", " self.scale = math.sqrt(ni/nheads)\n", " self.norm = nn.BatchNorm2d(ni)\n", " self.qkv = nn.Linear(ni, ni*3)\n", " self.proj = nn.Linear(ni, ni)\n", " \n", " def forward(self, inp):\n", " n,c,h,w = inp.shape\n", " x = self.norm(inp).view(n, c, -1).transpose(1, 2)\n", " x = self.qkv(x)\n", " x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)\n", " q,k,v = torch.chunk(x, 3, dim=-1)\n", " s = (q@k.transpose(1,2))/self.scale\n", " x = s.softmax(dim=-1)@v\n", " x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)\n", " x = self.proj(x).transpose(1,2).reshape(n,c,h,w)\n", " return x+inp" ] }, { "cell_type": "code", "execution_count": null, "id": "b3ed8798", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([64, 32, 16, 16])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sa = SelfAttentionMultiHead(32, 4)\n", "sx = sa(x)\n", "sx.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "18c46b50", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.0248, grad_fn=),\n", " tensor(1.0069, grad_fn=))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sx.mean(),sx.std()" ] }, { "cell_type": "code", "execution_count": null, "id": "cd9b3f90", "metadata": {}, "outputs": [], "source": [ "nm = nn.MultiheadAttention(32, num_heads=8, batch_first=True)\n", "nmx,nmw = nm(t,t,t)\n", "nmx = nmx+t" ] }, { "cell_type": "code", "execution_count": null, "id": "451f2d42", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(-0.0021, grad_fn=),\n", " tensor(1.0015, grad_fn=))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nmx.mean(),nmx.std()" ] }, { "cell_type": "code", "execution_count": null, "id": "6078e184", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "python3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }