{
"cells": [
{
"cell_type": "markdown",
"id": "0c6e8fa3-2900-4879-a610-3c038f8c8d23",
"metadata": {},
"source": [
"# EinMix: universal toolkit for advanced MLP architectures\n",
"\n",
"Recent progress in MLP-based architectures demonstrated that *very specific* MLPs can compete with convnets and transformers (and even outperform them).\n",
"\n",
"EinMix allows writing such architectures in a more uniform and readable way."
]
},
{
"cell_type": "markdown",
"id": "dc9a3913-5d5f-41da-8946-f17b3e537ec4",
"metadata": {},
"source": [
"## EinMix — building block of MLPs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f24a8f2-6680-4298-9736-081719c53f88",
"metadata": {},
"outputs": [],
"source": [
"from einops.layers.torch import EinMix as Mix\n",
"\n",
"# tutorial uses torch. EinMix is available for other frameworks too\n",
"from torch import nn\n",
"from torch.nn import functional as F"
]
},
{
"cell_type": "markdown",
"id": "eacd97e9-7ae2-4321-a3ad-3fb9b122fd81",
"metadata": {},
"source": [
"Logic of EinMix is very close to the one of `einsum`. \n",
"If you're not familiar with einsum, follow these guides first:\n",
"\n",
"- https://rockt.github.io/2018/04/30/einsum\n",
"- https://towardsdatascience.com/einsum-an-underestimated-function-99ca96e2942e\n",
"\n",
"Einsum uniformly describes a number of operations. \n",
"`EinMix` is a *layer* (not function) implementing a similar logic, it has some differences with `einsum`.\n",
"\n",
"Let's implement simple linear layer using einsum\n",
"\n",
"```python\n",
"weight = <...create and initialize parameter...>\n",
"bias = <...create and initialize parameter...>\n",
"result = torch.einsum('tbc,cd->tbd', embeddings, weight) + bias\n",
"```\n",
"\n",
"EinMix counter-part is:\n",
"\n",
"```python\n",
"mix_channels = Mix('t b c -> t b c_out', weight_shape='c c_out', bias_shape='c_out', ...)\n",
"result = mix_channels(embeddings)\n",
"```\n",
"\n",
"Main differences compared to plain `einsum` are:\n",
"\n",
"- layer takes care of the parameter initialization & management\n",
"- weight is not in the comprehension\n",
"- EinMix includes bias term\n",
"\n",
"We'll discuss other changes a bit later, now let's implement some elements from MLPMixer."
]
},
{
"cell_type": "markdown",
"id": "05428b35-5788-4e7f-867e-aa742d7215e7",
"metadata": {},
"source": [
"## TokenMixer from MLPMixer — original code\n",
"\n",
"We start from pytorch [implementation](https://github.com/jaketae/mlp-mixer/blob/e7d68dfc31e94721724689e6ec90f05806b50124/mlp_mixer/core.py) of MLPMixer by Jake Tae.\n",
"\n",
"We'll focus on two components of MLPMixer that don't exist in convnets. First component is TokenMixer:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8f6a6073-44af-40ec-8154-02ae8370f749",
"metadata": {},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
" def __init__(self, num_features, expansion_factor, dropout):\n",
" super().__init__()\n",
" num_hidden = num_features * expansion_factor\n",
" self.fc1 = nn.Linear(num_features, num_hidden)\n",
" self.dropout1 = nn.Dropout(dropout)\n",
" self.fc2 = nn.Linear(num_hidden, num_features)\n",
" self.dropout2 = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x):\n",
" x = self.dropout1(F.gelu(self.fc1(x)))\n",
" x = self.dropout2(self.fc2(x))\n",
" return x\n",
"\n",
"\n",
"class TokenMixer(nn.Module):\n",
" def __init__(self, num_features, num_patches, expansion_factor, dropout):\n",
" super().__init__()\n",
" self.norm = nn.LayerNorm(num_features)\n",
" self.mlp = MLP(num_patches, expansion_factor, dropout)\n",
"\n",
" def forward(self, x):\n",
" # x.shape == (batch_size, num_patches, num_features)\n",
" residual = x\n",
" x = self.norm(x)\n",
" x = x.transpose(1, 2)\n",
" # x.shape == (batch_size, num_features, num_patches)\n",
" x = self.mlp(x)\n",
" x = x.transpose(1, 2)\n",
" # x.shape == (batch_size, num_patches, num_features)\n",
" out = x + residual\n",
" return out"
]
},
{
"cell_type": "markdown",
"id": "362fa6ac-7146-4f7b-9e1d-0765486c26da",
"metadata": {},
"source": [
"## TokenMixer from MLPMixer — reimplemented\n",
"\n",
"We can significantly reduce amount of code by using `EinMix`. \n",
"\n",
"- Main caveat addressed by original code is that `nn.Linear` mixes only last axis. `EinMix` can mix any axis (or set of axes).\n",
"- Sequential structure is always preferred as it is easier to follow\n",
"- Intentionally there is no residual connection in `TokenMixer`, because honestly it's not work of Mixer and should be done by caller\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f447c6c5-932f-4929-991a-f97b14776618",
"metadata": {},
"outputs": [],
"source": [
"def TokenMixer(num_features: int, n_patches: int, expansion_factor: int, dropout: float):\n",
" n_hidden = n_patches * expansion_factor\n",
" return nn.Sequential(\n",
" nn.LayerNorm(num_features),\n",
" Mix(\"b hw c -> b hid c\", weight_shape=\"hw hid\", bias_shape=\"hid\", hw=n_patches, hidden=n_hidden),\n",
" nn.GELU(),\n",
" nn.Dropout(dropout),\n",
" Mix(\"b hid c -> b hw c\", weight_shape=\"hid hw\", bias_shape=\"hw\", hw=n_patches, hidden=n_hidden),\n",
" nn.Dropout(dropout),\n",
" )"
]
},
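{
"cell_type": "markdown",
"id": "einmix-token-mixer-usage",
"metadata": {},
"source": [
"A quick usage check, written as a sketch under a few assumptions: the sizes are arbitrary, the keyword for the hidden size is spelled `hid` so that it matches the axis name in the pattern, and the residual connection is added by the caller, as suggested above.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"from einops.layers.torch import EinMix as Mix\n",
"\n",
"def TokenMixer(num_features: int, n_patches: int, expansion_factor: int, dropout: float):\n",
"    n_hidden = n_patches * expansion_factor\n",
"    return nn.Sequential(\n",
"        nn.LayerNorm(num_features),\n",
"        Mix('b hw c -> b hid c', weight_shape='hw hid', bias_shape='hid', hw=n_patches, hid=n_hidden),\n",
"        nn.GELU(),\n",
"        nn.Dropout(dropout),\n",
"        Mix('b hid c -> b hw c', weight_shape='hid hw', bias_shape='hw', hw=n_patches, hid=n_hidden),\n",
"        nn.Dropout(dropout),\n",
"    )\n",
"\n",
"token_mixer = TokenMixer(num_features=32, n_patches=49, expansion_factor=2, dropout=0.0)\n",
"x = torch.randn(4, 49, 32)  # (batch, patches, features)\n",
"out = x + token_mixer(x)    # the residual connection is the caller's job\n",
"assert out.shape == x.shape\n",
"```"
]
},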
{
"cell_type": "markdown",
"id": "19caacfc-a066-4727-918e-b3a98a190014",
"metadata": {},
"source": [
"You may also check another [implementation](https://github.com/lucidrains/mlp-mixer-pytorch/blob/main/mlp_mixer_pytorch/mlp_mixer_pytorch.py) of MLPMixer from Phil Wang.
\n",
"Phil solves the issue by repurposing `nn.Conv1d` to mix on the second dimension. Hacky, but does the job\n"
]
},
{
"cell_type": "markdown",
"id": "5210dfd1-0195-458f-a8b9-819547123bad",
"metadata": {},
"source": [
"## MLPMixer's patch embeddings (aka ViT patch embeddings) — original\n",
"\n",
"Second interesting part of MLPMixer is derived from vision transformers.\n",
"\n",
"In the very beginning an image is split into patches, and each patch is linearly projected into embedding:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "24752920-006c-49e1-acb9-227543f84afb",
"metadata": {},
"outputs": [],
"source": [
"def check_sizes(image_size, patch_size):\n",
" sqrt_num_patches, remainder = divmod(image_size, patch_size)\n",
" assert remainder == 0, \"`image_size` must be divisibe by `patch_size`\"\n",
" num_patches = sqrt_num_patches ** 2\n",
" return num_patches\n",
"\n",
"class Patcher(nn.Module):\n",
" def __init__(\n",
" self,\n",
" image_size=256,\n",
" patch_size=16,\n",
" in_channels=3,\n",
" num_features=128,\n",
" ):\n",
" _num_patches = check_sizes(image_size, patch_size)\n",
" super().__init__()\n",
" # per-patch fully-connected is equivalent to strided conv2d\n",
" self.patcher = nn.Conv2d(\n",
" in_channels, num_features, kernel_size=patch_size, stride=patch_size\n",
" )\n",
"\n",
" def forward(self, x):\n",
" patches = self.patcher(x)\n",
" batch_size, num_features, _, _ = patches.shape\n",
" patches = patches.permute(0, 2, 3, 1)\n",
" patches = patches.view(batch_size, -1, num_features)\n",
"\n",
" return patches"
]
},
{
"cell_type": "markdown",
"id": "b5ecaac5-d5fe-48b7-9339-fe84fdccd9ac",
"metadata": {},
"source": [
"## MLPMixer's patch embeddings — reimplemented\n",
"\n",
"`EinMix` does this in a single operation. This may require some training at first to understand.\n",
"\n",
"Let's go step-by-step:\n",
"\n",
"- `b c_in (h hp) (w wp) ->` - 4-dimensional input tensor (BCHW-ordered) is split into patches of shape `hp x wp`\n",
"- `weight_shape='c_in hp wp c'`. Axes `c_in`, `hp` and `wp` are all absent in the output: three dimensional patch tensor was *mixed* to produce a vector of length `c`\n",
"- `-> b (h w) c` - output is 3-dimensional. All patches were reorganized from `h x w` grid to one-dimensional sequence of vectors\n",
"\n",
"\n",
"We don't need to provide image_size beforehead, new implementation handles images of different dimensions as long as they can be divided into patches"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2b786a62-a742-49d3-9e66-d02eb2f4e384",
"metadata": {},
"outputs": [],
"source": [
"def patcher(patch_size=16, in_channels=3, num_features=128):\n",
" return Mix(\"b c_in (h hp) (w wp) -> b (h w) c\", weight_shape=\"c_in hp wp c\", bias_shape=\"c\",\n",
" c=num_features, hp=patch_size, wp=patch_size, c_in=in_channels)"
]
},
{
"cell_type": "markdown",
"id": "a26b2f34-07cd-4fc2-bdf9-0043dd3451ea",
"metadata": {},
"source": [
"## Vision Permutator\n",
"\n",
"As a third example we consider pytorch-like code from [ViP paper](https://arxiv.org/pdf/2106.12368.pdf).\n",
"\n",
"Vision permutator is only slightly more nuanced than previous models, because \n",
"1. it operates on spatial dimensions separately, while MLPMixer and its friends just pack all spatial info into one axis. \n",
"2. it splits channels into groups called 'segments'"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2c94453b-bc78-4f06-8d51-f110e4b6501a",
"metadata": {},
"outputs": [],
"source": [
"class WeightedPermuteMLP(nn.Module):\n",
" def __init__(self, H, W, C, S):\n",
" super().__init__()\n",
"\n",
" self.proj_h = nn.Linear(H * S, H * S)\n",
" self.proj_w = nn.Linear(W * S, W * S)\n",
" self.proj_c = nn.Linear(C, C)\n",
" self.proj = nn.Linear(C, C)\n",
" self.S = S\n",
"\n",
" def forward(self, x):\n",
" B, H, W, C = x.shape\n",
" S = self.S\n",
" N = C // S\n",
" x_h = x.reshape(B, H, W, N, S).permute(0, 3, 2, 1, 4).reshape(B, N, W, H*S)\n",
" x_h = self.proj_h(x_h).reshape(B, N, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)\n",
"\n",
" x_w = x.reshape(B, H, W, N, S).permute(0, 1, 3, 2, 4).reshape(B, H, N, W*S)\n",
" x_w = self.proj_w(x_w).reshape(B, H, N, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)\n",
"\n",
" x_c = self.proj_c(x)\n",
"\n",
" x = x_h + x_w + x_c\n",
" x = self.proj(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "27a4aa53-e719-4079-973a-83b07d507c1f",
"metadata": {},
"source": [
"That didn't look readable, right? \n",
"\n",
"This code is also very inflexible: code in the paper did not support batch dimension, and multiple changes were necessary to allow batch processing.
\n",
"This process is fragile and easily can result in virtually uncatchable bugs.\n",
"\n",
"Now good news: each of these long method chains can be replaced with a single `EinMix` layer:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a0a699c0-f0b9-4e2c-8dc1-b33f5a7dfa1e",
"metadata": {},
"outputs": [],
"source": [
"class WeightedPermuteMLP_new(nn.Module):\n",
" def __init__(self, H, W, C, seg_len):\n",
" super().__init__()\n",
" assert C % seg_len == 0, f\"can't divide {C} into segments of length {seg_len}\"\n",
" self.mlp_c = Mix(\"b h w c -> b h w c0\", weight_shape=\"c c0\", bias_shape=\"c0\",\n",
" c=C, c0=C)\n",
" self.mlp_h = Mix(\"b h w (n c) -> b h0 w (n c0)\", weight_shape=\"h c h0 c0\", bias_shape=\"h0 c0\",\n",
" h=H, h0=H, c=seg_len, c0=seg_len)\n",
" self.mlp_w = Mix(\"b h w (n c) -> b h w0 (n c0)\", weight_shape=\"w c w0 c0\", bias_shape=\"w0 c0\",\n",
" w=W, w0=W, c=seg_len, c0=seg_len)\n",
" self.proj = nn.Linear(C, C)\n",
"\n",
" def forward(self, x):\n",
" x = self.mlp_c(x) + self.mlp_h(x) + self.mlp_w(x)\n",
" return self.proj(x)"
]
},
{
"cell_type": "markdown",
"id": "de9937d8",
"metadata": {},
"source": [
"### Multi-head attention, once again\n",
"\n",
"EinMix can be (mis)used to compute multiple projections and perform transpositions along the way.\n",
"\n",
"For example, F.scaled_dot_product_attention wants a specific order of axes, and an explicit head axis.\n",
"We can combine linear projection with providing desired order of arguments in a single operation."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d899ed57",
"metadata": {},
"outputs": [],
"source": [
"class MultiheadAttention(nn.Module):\n",
" def __init__(self, dim_input, n_heads, head_dim):\n",
" super().__init__()\n",
" self.input_to_qkv = Mix(\"b t c -> qkv b h t hid\", \"c qkv h hid\",\n",
" c=dim_input, qkv=3, h=n_heads, hid=head_dim)\n",
" self.out_proj = Mix(\"b h t hid -> b t c\", \"h hid c\",\n",
" h=n_heads, hid=head_dim, c=dim_input)\n",
"\n",
" def forward(self, x):\n",
" q, k, v = self.input_to_qkv(x) # fused projections, computed in one go\n",
" return self.out_proj(F.scaled_dot_product_attention(q, k, v)) # flash attention"
]
},
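{
"cell_type": "markdown",
"id": "einmix-qkv-shape-check",
"metadata": {},
"source": [
"To see what the fused projection returns, here is a small sketch with arbitrary sizes (64 input channels, four heads of size 16). Only the projection layer from the class above is reproduced, and the second argument is passed by name as `weight_shape`:\n",
"\n",
"```python\n",
"import torch\n",
"from einops.layers.torch import EinMix as Mix\n",
"\n",
"input_to_qkv = Mix('b t c -> qkv b h t hid', weight_shape='c qkv h hid',\n",
"                   c=64, qkv=3, h=4, hid=16)\n",
"\n",
"x = torch.randn(2, 10, 64)      # (batch, tokens, channels)\n",
"q, k, v = input_to_qkv(x)       # unpacking iterates over the leading qkv axis\n",
"# each projection already has an explicit head axis: (batch, heads, tokens, head_dim)\n",
"assert q.shape == k.shape == v.shape == (2, 4, 10, 16)\n",
"```"
]
},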
{
"cell_type": "markdown",
"id": "f05b34fa",
"metadata": {},
"source": [
"## Exercises\n",
"\n",
"1. Many normalizations (batch norm, layer norm, etc) use affine scaling afterwards. \n",
" Implement this scaling using `EinMix`.\n",
"\n",
"\n",
"2. let's assume you have an input tensor of shape `[b, t, n_groups, n_channels]`, and you want to apply a separate linear layer to every group of channels. \n",
"\n",
" This will introduce `n_groups` matrices of shape `[n_channels, n_channels]` and `n_groups` biases of shape `[n_channels]`. Can you perform this operation with just one `EinMix`?"
]
},
{
"cell_type": "markdown",
"id": "6694ade8-557d-4461-8f3a-4bd5f74dbea2",
"metadata": {},
"source": [
"## Final remarks\n",
"\n",
"`EinMix` helps with MLPs that don't fit into a limited 'mix all in the last axis' paradigm, and specially helpful for non-1d inputs (images, videos, etc).\n",
"\n",
"However existing research does not cover real possibilities of densely connected architectures.\n",
"\n",
"Most of its *systematic* novelty is \"mix along spatial axes actually works\". \n",
"But `EinMix` provides **an astonishing amount of other possibilities!**. Let me mention some examples:"
]
},
{
"cell_type": "markdown",
"id": "a1ee15e6-933a-4c3e-842e-e908a0aeee15",
"metadata": {},
"source": [
"\n",
"### Mixing within a patch on a grid\n",
"\n",
"What if you make mixing 'local' in space? Completely doable:\n",
"\n",
"```python\n",
"'b c (h hI) (w wI) -> b c (h hO) (w wO)', weight_shape='c hI wI hO wO'\n",
"```\n",
"\n",
"We split tensor into patches of shape `hI wI` and mixed per-channel.\n",
"\n",
"### Mixing in subgrids\n",
"\n",
"Opposite question: how to collect information from the whole image (without attention)?
\n",
"\n",
"Well, you can again densely connect all the tokens, but all-to-all connection can be too expensive.\n",
"\n",
"Here is EinMix-way: split the image into subgrids (each subgrid has steps `h` and `w`), and connect densely tokens within each subgrid\n",
"\n",
"```python\n",
"'b c (hI h) (wI w) -> b c (hO h) (wO w)', weight_shape='c hI wI hO wO'\n",
"```\n",
"\n",
"\n",
"### Going deeper\n",
"\n",
"And that's very top of the iceberg.\n",
"\n",
"- Want to mix part of axis? — No problems!\n",
"- ... in a grid-like manner — Supported! \n",
"- ... while mixing channels within group? — Welcome! \n",
"- In 2d/3d/4d? — Sure!\n",
"- Don't use pytorch? — EinMix is available for multiple frameworks!\n",
"\n",
"Hopefully this guide helped you to find MLPs more interesting and intriguing. And simpler to experiment with."
]
},
{
"cell_type": "markdown",
"id": "03790b14",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py12",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}