{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# einops.pack and einops.unpack\n", "\n", "einops 0.6 introduces two more functions to the family: `pack` and `unpack`.\n", "\n", "Here is what they do:\n", "\n", "- `unpack` reverses `pack`\n", "- `pack` reverses `unpack`\n", "\n", "Enlightened by this exhaustive description, let's move on to examples.\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# we'll use numpy for demo purposes\n", "# operations work the same way with other frameworks\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stacking data layers\n", "\n", "Assume we have an RGB image along with a corresponding depth image that we want to stack:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "from einops import pack, unpack\n", "\n", "h, w = 100, 200\n", "# image_rgb is 3-dimensional (h, w, 3) and depth is 2-dimensional (h, w)\n", "image_rgb = np.random.random([h, w, 3])\n", "image_depth = np.random.random([h, w])\n", "# but we can still stack them\n", "image_rgbd, ps = pack([image_rgb, image_depth], 'h w *')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to read packing patterns\n", "\n", "The pattern `h w *` means that:\n", "- the output is 3-dimensional\n", "- the first two axes (`h` and `w`) are shared across all inputs and also shared with the output\n", "- the inputs, however, do not have to be 3-dimensional. They can be 2-dim, 3-dim, 4-dim, etc. <br />\n", "  Regardless of their dimensionality, they will all be packed into the 3-dim output, and the information about how they were packed is stored in `ps`" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/plain": [ "((100, 200, 3), (100, 200), (100, 200, 4))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# as you can see, pack properly appended depth as one more layer\n", "# and correctly aligned the axes!\n", "# this won't work off the shelf with np.concatenate, torch.cat or the like\n", "image_rgb.shape, image_depth.shape, image_rgbd.shape" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/plain": [ "[(3,), ()]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# now let's see what ps keeps.\n", "# PS means Packed Shapes, not PlayStation or PostScript\n", "ps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "which reads: the first tensor had shape `h, w, *and 3*`, while the second tensor had shape `h, w *and nothing more*`.\n", "That's just enough to reverse the packing:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/plain": [ "((100, 200, 3), (100, 200))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the 1-sized axis of the depth image is removed during unpacking. Results are (h, w, 3) and (h, w)\n", "unpacked_rgb, unpacked_depth = unpack(image_rgbd, ps, 'h w *')\n", "unpacked_rgb.shape, unpacked_depth.shape" ] },
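{ "cell_type": "markdown", "metadata": {}, "source": [ "As promised above, inputs of any dimensionality can take part in packing: every axis that is not named in the pattern is flattened into `*`, and `unpack` restores it. A quick sketch (not part of the original example; `features` is a hypothetical 4-dimensional input):\n", "\n", "```python\n", "# hypothetical per-pixel features with an extra trailing axis, shape (h, w, 3, 5)\n", "features = np.random.random([h, w, 3, 5])\n", "packed, ps2 = pack([image_rgb, features], 'h w *')\n", "# the trailing axes of each input are flattened: packed.shape == (100, 200, 3 + 3 * 5)\n", "# ps2 == [(3,), (3, 5)], so unpack restores the original shapes, including (h, w, 3, 5)\n", "restored_rgb, restored_features = unpack(packed, ps2, 'h w *')\n", "```" ] },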
{ "cell_type": "markdown", "metadata": {}, "source": [ "we can also unpack the tensor in different ways manually:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# simple unpack by splitting the last axis. Results are (h, w, 3) and (h, w, 1)\n", "rgb, depth = unpack(image_rgbd, [[3], [1]], 'h w *')\n", "# a different split, both outputs have shape (h, w, 2)\n", "rg, bd = unpack(image_rgbd, [[2], [2]], 'h w *')\n", "# unpack into 4 tensors of shape (h, w). More like 'unstack over the last axis'\n", "[r, g, b, d] = unpack(image_rgbd, [[], [], [], []], 'h w *')" ] },
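{ "cell_type": "markdown", "metadata": {}, "source": [ "For inputs of identical shape, `pack` reduces to the familiar stacking and concatenation. A quick sketch of this correspondence (not from the original tutorial; `x` and `y` are throwaway arrays):\n", "\n", "```python\n", "x = np.random.random([h, w])\n", "y = np.random.random([h, w])\n", "# '* h w' packs along a new leading axis, like np.stack(..., axis=0): shape (2, h, w)\n", "stacked, _ = pack([x, y], '* h w')\n", "assert np.allclose(stacked, np.stack([x, y], axis=0))\n", "# 'h w *' packs along a new trailing axis, like np.stack(..., axis=-1): shape (h, w, 2)\n", "trailing, _ = pack([x, y], 'h w *')\n", "assert np.allclose(trailing, np.stack([x, y], axis=-1))\n", "# for inputs that already have a trailing axis, 'h w *' concatenates over it\n", "cat, _ = pack([image_rgb, image_rgb], 'h w *')\n", "assert np.allclose(cat, np.concatenate([image_rgb, image_rgb], axis=-1))\n", "```" ] },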
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Short summary so far\n", "\n", "- `einops.pack` is a 'more generic concatenation' (that can stack too)\n", "- `einops.unpack` is a 'more generic split'\n", "\n", "And, of course, `einops` functions are more verbose, and *reversing* concatenation is now *dead simple*.\n", "\n", "Compared to other `einops` functions, `pack` and `unpack` have a compact pattern without an arrow, and the same pattern can be used in both `pack` and `unpack`. These patterns are very simple: just a sequence of space-separated axis names.\n", "One axis is `*`; all other axes are valid identifiers.\n", "\n", "Now let's discuss some practical cases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Auto-batching\n", "\n", "ML models by default accept batches: a batch of images, a batch of sentences, a batch of audio clips, etc.\n", "\n", "During debugging or inference, however, it is common to pass a single image instead (and the output should then be a single prediction). <br />\n", "In this example we'll write `universal_predict`, which can handle both cases." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "from einops import reduce\n", "\n", "\n", "def image_classifier(images_bhwc):\n", "    # a mock for an image classifier\n", "    predictions = reduce(images_bhwc, 'b h w c -> b c', 'mean', h=100, w=200, c=3)\n", "    return predictions\n", "\n", "\n", "def universal_predict(x):\n", "    x_packed, ps = pack([x], '* h w c')\n", "    predictions_packed = image_classifier(x_packed)\n", "    [predictions] = unpack(predictions_packed, ps, '* cls')\n", "    return predictions" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(3,)\n", "(5, 3)\n", "(5, 7, 3)\n" ] } ], "source": [ "# works with a single image\n", "print(universal_predict(np.zeros([h, w, 3])).shape)\n", "# works with a batch of images\n", "batch = 5\n", "print(universal_predict(np.zeros([batch, h, w, 3])).shape)\n", "# or even a batch of videos\n", "n_frames = 7\n", "print(universal_predict(np.zeros([batch, n_frames, h, w, 3])).shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**what we can learn from this example**:\n", "\n", "- `pack` and `unpack` play nicely together. That's not a coincidence :)\n", "- patterns in `pack` and `unpack` may differ, and that's quite common in applications\n", "- unlike other operations in `einops`, `(un)pack` does not provide arbitrary reordering of axes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Class token in ViT\n", "\n", "Let's assume we have a simple transformer model that works with `BTC`-shaped tensors." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "def transformer_mock(x_btc):\n", "    # imagine this is a transformer model, a very efficient one\n", "    assert len(x_btc.shape) == 3\n", "    return x_btc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's implement a vision transformer (ViT) with a class token (i.e. a static token whose corresponding output is used to classify the image)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# below it is assumed that you have already\n", "# 1) split the batch of images into patches, 2) applied a linear projection and 3) added positional embeddings.\n", "\n", "# We'll skip that here. But hey, here is an einops-style way of doing all of that in a single shot!\n", "# from einops.layers.torch import EinMix\n", "# patcher_and_posembedder = EinMix('b (h h2) (w w2) c -> b h w c_out', weight_shape='h2 w2 c c_out',\n", "#                                  bias_shape='h w c_out', h2=..., w2=...)\n", "# patch_tokens_bhwc = patcher_and_posembedder(images_bhwc)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# preparations\n", "batch, height, width, c = 6, 16, 16, 256\n", "patch_tokens = np.random.random([batch, height, width, c])\n", "class_tokens = np.zeros([batch, c])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/plain": [ "((6, 256), (6, 16, 16, 256))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def vit_einops(class_tokens, patch_tokens):\n", "    input_packed, ps = pack([class_tokens, patch_tokens], 'b * c')\n", "    output_packed = transformer_mock(input_packed)\n", "    return unpack(output_packed, ps, 'b * c_out')\n", "\n", "\n", "class_token_emb, patch_tokens_emb = vit_einops(class_tokens, patch_tokens)\n", "\n", "class_token_emb.shape, patch_tokens_emb.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At this point, let's pause for a moment and appreciate the conveniences of this pipeline by contrasting it with more 'standard' code." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "def vit_vanilla(class_tokens, patch_tokens):\n", "    b, h, w, c = patch_tokens.shape\n", "    class_tokens_b1c = class_tokens[:, np.newaxis, :]\n", "    patch_tokens_btc = np.reshape(patch_tokens, [b, -1, c])\n", "    input_packed = np.concatenate([class_tokens_b1c, patch_tokens_btc], axis=1)\n", "    output_packed = transformer_mock(input_packed)\n", "    class_token_emb = np.squeeze(output_packed[:, :1, :], 1)\n", "    patch_tokens_emb = np.reshape(output_packed[:, 1:, :], [b, h, w, -1])\n", "    return class_token_emb, patch_tokens_emb\n", "\n", "\n", "class_token_emb2, patch_tokens_emb2 = vit_vanilla(class_tokens, patch_tokens)\n", "assert np.allclose(class_token_emb, class_token_emb2)\n", "assert np.allclose(patch_tokens_emb, patch_tokens_emb2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notably, we have put all the packing and unpacking, reshapes, and adding and removing of dummy axes into just a couple of lines." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packing different modalities together\n", "\n", "We can extend the previous example: it is quite common to mix tokens from different types of inputs in a transformer.\n", "\n", "The simplest approach is to mix tokens from all inputs:\n", "\n", "```python\n", "all_inputs = [text_tokens_btc, image_bhwc, task_token_bc, static_tokens_bnc]\n", "inputs_packed, ps = pack(all_inputs, 'b * c')\n", "```\n", "\n", "and you can `unpack` the resulting tokens back into the same structure." ] },
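{ "cell_type": "markdown", "metadata": {}, "source": [ "Here is a small runnable sketch of that idea (not from the original tutorial; the shapes and the reuse of `transformer_mock` are illustrative assumptions):\n", "\n", "```python\n", "# hypothetical inputs: text tokens, an image, one task token, and a few static tokens\n", "text_tokens_btc = np.random.random([batch, 77, c])\n", "image_bhwc = np.random.random([batch, height, width, c])\n", "task_token_bc = np.random.random([batch, c])\n", "static_tokens_bnc = np.random.random([batch, 16, c])\n", "\n", "all_inputs = [text_tokens_btc, image_bhwc, task_token_bc, static_tokens_bnc]\n", "inputs_packed, ps = pack(all_inputs, 'b * c')\n", "# a single call processes all modalities at once\n", "outputs_packed = transformer_mock(inputs_packed)\n", "text_out, image_out, task_out, static_out = unpack(outputs_packed, ps, 'b * c')\n", "# image_out.shape == (batch, height, width, c) again\n", "```" ] },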
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Packing data coming from different sources together\n", "\n", "The most notable example is, of course, GANs:\n", "\n", "```python\n", "input_ims, ps = pack([true_images, fake_images], '* h w c')\n", "true_pred, fake_pred = unpack(model(input_ims), ps, '* c')\n", "```\n", "\n", "`true_pred` and `fake_pred` are handled differently, which is why we separated them." ] },
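{ "cell_type": "markdown", "metadata": {}, "source": [ "A minimal sketch of the same pattern with mock data (not from the original tutorial; `discriminator_mock`, the image counts and the single-logit output are illustrative assumptions):\n", "\n", "```python\n", "def discriminator_mock(images_bhwc):\n", "    # pretend this is a discriminator that returns one logit per image\n", "    return reduce(images_bhwc, 'b h w c -> b 1', 'mean')\n", "\n", "\n", "true_images = np.random.random([8, h, w, 3])\n", "fake_images = np.random.random([16, h, w, 3])\n", "# the two sources may even have different batch sizes; ps remembers them\n", "input_ims, ps = pack([true_images, fake_images], '* h w c')\n", "true_pred, fake_pred = unpack(discriminator_mock(input_ims), ps, '* c')\n", "# true_pred.shape == (8, 1), fake_pred.shape == (16, 1)\n", "```" ] },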
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Predicting multiple outputs at the same time\n", "\n", "It is quite common to pack predictions of multiple target values into a single output layer.\n", "\n", "This is more efficient, but the code is less readable. For example, this is how detection code may look:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "def loss_detection(model_output_bhwc, mask_h: int, mask_w: int, n_classes: int):\n", "    output = model_output_bhwc\n", "\n", "    confidence = output[..., 0].sigmoid()\n", "    bbox_x_shift = output[..., 1].sigmoid()\n", "    bbox_y_shift = output[..., 2].sigmoid()\n", "    bbox_w = output[..., 3]\n", "    bbox_h = output[..., 4]\n", "    mask_logits = output[..., 5: 5 + mask_h * mask_w]\n", "    mask_logits = mask_logits.reshape([*mask_logits.shape[:-1], mask_h, mask_w])\n", "    class_logits = output[..., 5 + mask_h * mask_w:]\n", "    assert class_logits.shape[-1] == n_classes, class_logits.shape[-1]\n", "\n", "    # downstream computations\n", "    return confidence, bbox_x_shift, bbox_y_shift, bbox_h, bbox_w, mask_logits, class_logits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When the same logic is implemented in einops, there is no need to memorize offsets. <br />\n", "Additionally, reshapes and shape checks are automatic:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "def loss_detection_einops(model_output, mask_h: int, mask_w: int, n_classes: int):\n", "    confidence, bbox_x_shift, bbox_y_shift, bbox_w, bbox_h, mask_logits, class_logits \\\n", "        = unpack(model_output, [[]] * 5 + [[mask_h, mask_w], [n_classes]], 'b h w *')\n", "\n", "    confidence = confidence.sigmoid()\n", "    bbox_x_shift = bbox_x_shift.sigmoid()\n", "    bbox_y_shift = bbox_y_shift.sigmoid()\n", "\n", "    # downstream computations\n", "    return confidence, bbox_x_shift, bbox_y_shift, bbox_h, bbox_w, mask_logits, class_logits" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# check that the results are identical\n", "import torch\n", "\n", "dims = dict(mask_h=6, mask_w=8, n_classes=19)\n", "model_output = torch.randn([3, 5, 7, 5 + dims['mask_h'] * dims['mask_w'] + dims['n_classes']])\n", "for a, b in zip(loss_detection(model_output, **dims), loss_detection_einops(model_output, **dims)):\n", "    assert torch.allclose(a, b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Or maybe **reinforcement learning** is closer to your heart?\n", "\n", "If so, predicting multiple outputs is valuable there too:\n", "\n", "```python\n", "action_logits, reward_expectation, q_values, expected_entropy_after_action = \\\n", "    unpack(predictions_btc, [[n_actions], [], [n_actions], [n_actions]], 'b step *')\n", "```\n", "\n", "## That's all for today!\n", "\n", "Happy packing and unpacking!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 }