{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# einops.pack and einops.unpack\n", "\n", "einops 0.6 introduces two more functions to the family: `pack` and `unpack`.\n", "\n", "Here is what they do:\n", "\n", "- `unpack` reverses `pack`\n", "- `pack` reverses `unpack`\n", "\n", "Enlightened by this exhaustive description, let's move on to examples.\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# we'll use numpy for demo purposes\n", "# operations work the same way with other frameworks\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stacking data layers\n", "\n", "Assume we have an RGB image along with a corresponding depth image that we want to stack:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "from einops import pack, unpack\n", "\n", "h, w = 100, 200\n", "# image_rgb is 3-dimensional (h, w, 3) and depth is 2-dimensional (h, w)\n", "image_rgb = np.random.random([h, w, 3])\n", "image_depth = np.random.random([h, w])\n", "# but we can still stack them\n", "image_rgbd, ps = pack([image_rgb, image_depth], 'h w *')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to read packing patterns\n", "\n", "The pattern `h w *` means that:\n", "- the output is 3-dimensional\n", "- the first two axes (`h` and `w`) are shared across all inputs and also shared with the output\n", "- the inputs, however, do not have to be 3-dimensional. They can be 2-dim, 3-dim, 4-dim, etc. <br />\n", "  Regardless of their dimensionality, they will all be packed into the 3-dim output, and the information about how they were packed is stored in `ps`" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/plain": [ "((100, 200, 3), (100, 200), (100, 200, 4))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# as you can see, pack properly appended depth as one more layer\n", "# and correctly aligned the axes!\n", "# this won't work off the shelf with np.concatenate, torch.cat or the like\n", "image_rgb.shape, image_depth.shape, image_rgbd.shape" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/plain": [ "[(3,), ()]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# now let's see what ps keeps.\n", "# PS means Packed Shapes, not PlayStation or PostScript\n", "ps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "which reads: the first tensor had shape `h, w, *and 3*`, while the second tensor had shape `h, w *and nothing more*`.\n", "That's just enough to reverse the packing:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/plain": [ "((100, 200, 3), (100, 200))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the 1-sized axis of the depth image is removed during unpacking. Results are (h, w, 3) and (h, w)\n", "unpacked_rgb, unpacked_depth = unpack(image_rgbd, ps, 'h w *')\n", "unpacked_rgb.shape, unpacked_depth.shape" ] },
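{ "cell_type": "markdown", "metadata": {}, "source": [ "As promised above, inputs of any dimensionality can take part in packing: every axis that is not named in the pattern is flattened into `*`, and `unpack` restores it. A quick sketch (not part of the original example; `features` is a hypothetical 4-dimensional input):\n", "\n", "```python\n", "# hypothetical per-pixel features with an extra trailing axis, shape (h, w, 3, 5)\n", "features = np.random.random([h, w, 3, 5])\n", "packed, ps2 = pack([image_rgb, features], 'h w *')\n", "# the trailing axes of each input are flattened: packed.shape == (100, 200, 3 + 3 * 5)\n", "# ps2 == [(3,), (3, 5)], so unpack restores the original shapes, including (h, w, 3, 5)\n", "restored_rgb, restored_features = unpack(packed, ps2, 'h w *')\n", "```" ] },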
{ "cell_type": "markdown", "metadata": {}, "source": [ "we can also unpack the tensor in different ways manually:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# simple unpack by splitting the last axis. Results are (h, w, 3) and (h, w, 1)\n", "rgb, depth = unpack(image_rgbd, [[3], [1]], 'h w *')\n", "# a different split, both outputs have shape (h, w, 2)\n", "rg, bd = unpack(image_rgbd, [[2], [2]], 'h w *')\n", "# unpack into 4 tensors of shape (h, w). More like 'unstack over the last axis'\n", "[r, g, b, d] = unpack(image_rgbd, [[], [], [], []], 'h w *')" ] },
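{ "cell_type": "markdown", "metadata": {}, "source": [ "For inputs of identical shape, `pack` reduces to the familiar stacking and concatenation. A quick sketch of this correspondence (not from the original tutorial; `x` and `y` are throwaway arrays):\n", "\n", "```python\n", "x = np.random.random([h, w])\n", "y = np.random.random([h, w])\n", "# '* h w' packs along a new leading axis, like np.stack(..., axis=0): shape (2, h, w)\n", "stacked, _ = pack([x, y], '* h w')\n", "assert np.allclose(stacked, np.stack([x, y], axis=0))\n", "# 'h w *' packs along a new trailing axis, like np.stack(..., axis=-1): shape (h, w, 2)\n", "trailing, _ = pack([x, y], 'h w *')\n", "assert np.allclose(trailing, np.stack([x, y], axis=-1))\n", "# for inputs that already have a trailing axis, 'h w *' concatenates over it\n", "cat, _ = pack([image_rgb, image_rgb], 'h w *')\n", "assert np.allclose(cat, np.concatenate([image_rgb, image_rgb], axis=-1))\n", "```" ] },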
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Short summary so far\n", "\n", "- `einops.pack` is a 'more generic concatenation' (that can stack too)\n", "- `einops.unpack` is a 'more generic split'\n", "\n", "And, of course, `einops` functions are more verbose, and *reversing* concatenation is now *dead simple*.\n", "\n", "Compared to other `einops` functions, `pack` and `unpack` have a compact pattern without an arrow, and the same pattern can be used in both `pack` and `unpack`. These patterns are very simple: just a sequence of space-separated axis names.\n", "One axis is `*`; all other axes are valid identifiers.\n", "\n", "Now let's discuss some practical cases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Auto-batching\n", "\n", "ML models by default accept batches: a batch of images, a batch of sentences, a batch of audio clips, etc.\n", "\n", "During debugging or inference, however, it is common to pass a single image instead (and the output should then be a single prediction). <br />\n", "In this example we'll write `universal_predict`, which can handle both cases." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "from einops import reduce\n", "\n", "\n", "def image_classifier(images_bhwc):\n", "    # a mock for an image classifier\n", "    predictions = reduce(images_bhwc, 'b h w c -> b c', 'mean', h=100, w=200, c=3)\n", "    return predictions\n", "\n", "\n", "def universal_predict(x):\n", "    x_packed, ps = pack([x], '* h w c')\n", "    predictions_packed = image_classifier(x_packed)\n", "    [predictions] = unpack(predictions_packed, ps, '* cls')\n", "    return predictions" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(3,)\n", "(5, 3)\n", "(5, 7, 3)\n" ] } ], "source": [ "# works with a single image\n", "print(universal_predict(np.zeros([h, w, 3])).shape)\n", "# works with a batch of images\n", "batch = 5\n", "print(universal_predict(np.zeros([batch, h, w, 3])).shape)\n", "# or even a batch of videos\n", "n_frames = 7\n", "print(universal_predict(np.zeros([batch, n_frames, h, w, 3])).shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**what we can learn from this example**:\n", "\n", "- `pack` and `unpack` play nicely together. That's not a coincidence :)\n", "- patterns in `pack` and `unpack` may differ, and that's quite common in applications\n", "- unlike other operations in `einops`, `(un)pack` does not provide arbitrary reordering of axes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Class token in ViT\n", "\n", "Let's assume we have a simple transformer model that works with `BTC`-shaped tensors." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "def transformer_mock(x_btc):\n", "    # imagine this is a transformer model, a very efficient one\n", "    assert len(x_btc.shape) == 3\n", "    return x_btc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's implement a vision transformer (ViT) with a class token (i.e. a static token whose corresponding output is used to classify the image)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# below it is assumed that you have already\n", "# 1) split the batch of images into patches, 2) applied a linear projection and 3) added positional embeddings.\n", "\n", "# We'll skip that here. But hey, here is an einops-style way of doing all of that in a single shot!\n", "# from einops.layers.torch import EinMix\n", "# patcher_and_posembedder = EinMix('b (h h2) (w w2) c -> b h w c_out', weight_shape='h2 w2 c c_out',\n", "#                                  bias_shape='h w c_out', h2=..., w2=...)\n", "# patch_tokens_bhwc = patcher_and_posembedder(images_bhwc)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# preparations\n", "batch, height, width, c = 6, 16, 16, 256\n", "patch_tokens = np.random.random([batch, height, width, c])\n", "class_tokens = np.zeros([batch, c])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/plain": [ "((6, 256), (6, 16, 16, 256))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def vit_einops(class_tokens, patch_tokens):\n", "    input_packed, ps = pack([class_tokens, patch_tokens], 'b * c')\n", "    output_packed = transformer_mock(input_packed)\n", "    return unpack(output_packed, ps, 'b * c_out')\n", "\n", "\n", "class_token_emb, patch_tokens_emb = vit_einops(class_tokens, patch_tokens)\n", "\n", "class_token_emb.shape, patch_tokens_emb.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At this point, let's pause for a moment and appreciate the conveniences of this pipeline by contrasting it with more 'standard' code." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "def vit_vanilla(class_tokens, patch_tokens):\n", "    b, h, w, c = patch_tokens.shape\n", "    class_tokens_b1c = class_tokens[:, np.newaxis, :]\n", "    patch_tokens_btc = np.reshape(patch_tokens, [b, -1, c])\n", "    input_packed = np.concatenate([class_tokens_b1c, patch_tokens_btc], axis=1)\n", "    output_packed = transformer_mock(input_packed)\n", "    class_token_emb = np.squeeze(output_packed[:, :1, :], 1)\n", "    patch_tokens_emb = np.reshape(output_packed[:, 1:, :], [b, h, w, -1])\n", "    return class_token_emb, patch_tokens_emb\n", "\n", "\n", "class_token_emb2, patch_tokens_emb2 = vit_vanilla(class_tokens, patch_tokens)\n", "assert np.allclose(class_token_emb, class_token_emb2)\n", "assert np.allclose(patch_tokens_emb, patch_tokens_emb2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notably, we have put all the packing and unpacking, reshapes, and adding and removing of dummy axes into just a couple of lines." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packing different modalities together\n", "\n", "We can extend the previous example: it is quite common to mix tokens from different types of inputs in a transformer.\n", "\n", "The simplest approach is to mix tokens from all inputs:\n", "\n", "```python\n", "all_inputs = [text_tokens_btc, image_bhwc, task_token_bc, static_tokens_bnc]\n", "inputs_packed, ps = pack(all_inputs, 'b * c')\n", "```\n", "\n", "and you can `unpack` the resulting tokens back into the same structure." ] },
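{ "cell_type": "markdown", "metadata": {}, "source": [ "Here is a small runnable sketch of that idea (not from the original tutorial; the shapes and the reuse of `transformer_mock` are illustrative assumptions):\n", "\n", "```python\n", "# hypothetical inputs: text tokens, an image, one task token, and a few static tokens\n", "text_tokens_btc = np.random.random([batch, 77, c])\n", "image_bhwc = np.random.random([batch, height, width, c])\n", "task_token_bc = np.random.random([batch, c])\n", "static_tokens_bnc = np.random.random([batch, 16, c])\n", "\n", "all_inputs = [text_tokens_btc, image_bhwc, task_token_bc, static_tokens_bnc]\n", "inputs_packed, ps = pack(all_inputs, 'b * c')\n", "# a single call processes all modalities at once\n", "outputs_packed = transformer_mock(inputs_packed)\n", "text_out, image_out, task_out, static_out = unpack(outputs_packed, ps, 'b * c')\n", "# image_out.shape == (batch, height, width, c) again\n", "```" ] },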
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Packing data coming from different sources together\n", "\n", "The most notable example is, of course, GANs:\n", "\n", "```python\n", "input_ims, ps = pack([true_images, fake_images], '* h w c')\n", "true_pred, fake_pred = unpack(model(input_ims), ps, '* c')\n", "```\n", "\n", "`true_pred` and `fake_pred` are handled differently, which is why we separated them." ] },
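{ "cell_type": "markdown", "metadata": {}, "source": [ "A minimal sketch of the same pattern with mock data (not from the original tutorial; `discriminator_mock`, the image counts and the single-logit output are illustrative assumptions):\n", "\n", "```python\n", "def discriminator_mock(images_bhwc):\n", "    # pretend this is a discriminator that returns one logit per image\n", "    return reduce(images_bhwc, 'b h w c -> b 1', 'mean')\n", "\n", "\n", "true_images = np.random.random([8, h, w, 3])\n", "fake_images = np.random.random([16, h, w, 3])\n", "# the two sources may even have different batch sizes; ps remembers them\n", "input_ims, ps = pack([true_images, fake_images], '* h w c')\n", "true_pred, fake_pred = unpack(discriminator_mock(input_ims), ps, '* c')\n", "# true_pred.shape == (8, 1), fake_pred.shape == (16, 1)\n", "```" ] },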
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Predicting multiple outputs at the same time\n", "\n", "It is quite common to pack predictions of multiple target values into a single output layer.\n", "\n", "This is more efficient, but the code is less readable. For example, this is how detection code may look:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "def loss_detection(model_output_bhwc, mask_h: int, mask_w: int, n_classes: int):\n", "    output = model_output_bhwc\n", "\n", "    confidence = output[..., 0].sigmoid()\n", "    bbox_x_shift = output[..., 1].sigmoid()\n", "    bbox_y_shift = output[..., 2].sigmoid()\n", "    bbox_w = output[..., 3]\n", "    bbox_h = output[..., 4]\n", "    mask_logits = output[..., 5: 5 + mask_h * mask_w]\n", "    mask_logits = mask_logits.reshape([*mask_logits.shape[:-1], mask_h, mask_w])\n", "    class_logits = output[..., 5 + mask_h * mask_w:]\n", "    assert class_logits.shape[-1] == n_classes, class_logits.shape[-1]\n", "\n", "    # downstream computations\n", "    return confidence, bbox_x_shift, bbox_y_shift, bbox_h, bbox_w, mask_logits, class_logits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When the same logic is implemented in einops, there is no need to memorize offsets. <br />\n", "Additionally, reshapes and shape checks are automatic:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "def loss_detection_einops(model_output, mask_h: int, mask_w: int, n_classes: int):\n", "    confidence, bbox_x_shift, bbox_y_shift, bbox_w, bbox_h, mask_logits, class_logits \\\n", "        = unpack(model_output, [[]] * 5 + [[mask_h, mask_w], [n_classes]], 'b h w *')\n", "\n", "    confidence = confidence.sigmoid()\n", "    bbox_x_shift = bbox_x_shift.sigmoid()\n", "    bbox_y_shift = bbox_y_shift.sigmoid()\n", "\n", "    # downstream computations\n", "    return confidence, bbox_x_shift, bbox_y_shift, bbox_h, bbox_w, mask_logits, class_logits" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# check that the results are identical\n", "import torch\n", "\n", "dims = dict(mask_h=6, mask_w=8, n_classes=19)\n", "model_output = torch.randn([3, 5, 7, 5 + dims['mask_h'] * dims['mask_w'] + dims['n_classes']])\n", "for a, b in zip(loss_detection(model_output, **dims), loss_detection_einops(model_output, **dims)):\n", "    assert torch.allclose(a, b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Or maybe **reinforcement learning** is closer to your heart?\n", "\n", "If so, predicting multiple outputs is valuable there too:\n", "\n", "```python\n", "action_logits, reward_expectation, q_values, expected_entropy_after_action = \\\n", "    unpack(predictions_btc, [[n_actions], [], [n_actions], [n_actions]], 'b step *')\n", "```\n", "\n", "## That's all for today!\n", "\n", "Happy packing and unpacking!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 }