{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Einops tutorial, part 2: deep learning\n",
    "\n",
    "The previous part of the tutorial provides visual examples with numpy.\n",
    "\n",
    "## What's in this tutorial?\n",
    "\n",
    "- working with deep learning packages\n",
    "- important cases for deep learning models\n",
    "- `einops.asnumpy` and `einops.layers`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from einops import rearrange, reduce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x = np.random.RandomState(42).normal(size=[10, 32, 100, 200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# utility to hide answers\n",
    "from utils import guess"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Select your flavour\n",
    "\n",
    "Switch to the framework you're most comfortable with. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# select \"tensorflow\" or \"pytorch\"\n",
    "flavour = \"pytorch\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "selected pytorch backend\n"
     ]
    }
   ],
   "source": [
    "print(\"selected {} backend\".format(flavour))\n",
    "if flavour == \"tensorflow\":\n",
    "    import tensorflow as tf\n",
    "\n",
    "    tape = tf.GradientTape(persistent=True)\n",
    "    tape.__enter__()\n",
    "    x = tf.Variable(x) + 0\n",
    "else:\n",
    "    assert flavour == \"pytorch\"\n",
    "    import torch\n",
    "\n",
    "    x = torch.from_numpy(x)\n",
    "    x.requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Tensor, torch.Size([10, 32, 100, 200]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(x), x.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple computations \n",
    "\n",
    "- converting bchw to bhwc format and back is a common operation in CV\n",
    "- try to predict output shape and then check your guess!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 100, 200, 32)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, \"b c h w -> b h w c\")\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Worked!\n",
    "\n",
    "Did you notice? The code above worked for your backend of choice. <br />\n",
    "Einops functions work with any tensor as if they were native to the framework.\n",
    "\n",
    "## Backpropagation\n",
    "\n",
    "- gradients are a cornerstone of deep learning\n",
    "- you can back-propagate through einops operations <br />\n",
    "  (just as with framework-native operations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(320., dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "y0 = x\n",
    "y1 = reduce(y0, \"b c h w -> b c\", \"max\")\n",
    "y2 = rearrange(y1, \"b c -> c b\")\n",
    "y3 = reduce(y2, \"c b -> \", \"sum\")\n",
    "\n",
    "if flavour == \"tensorflow\":\n",
    "    print(reduce(tape.gradient(y3, x), \"b c h w -> \", \"sum\"))\n",
    "else:\n",
    "    y3.backward()\n",
    "    print(reduce(x.grad, \"b c h w -> \", \"sum\"))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Meet `einops.asnumpy`\n",
    "\n",
    "It simply converts tensors to numpy (and pulls them from the GPU if necessary)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "from einops import asnumpy\n",
    "\n",
    "y3_numpy = asnumpy(y3)\n",
    "\n",
    "print(type(y3_numpy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Common building blocks of deep learning\n",
    "\n",
    "Let's check how some familiar operations can be written with `einops`.\n",
    "\n",
    "**Flattening** is a common operation that frequently appears at the boundary\n",
    "between convolutional layers and fully connected layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 640000)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, \"b c h w -> b (c h w)\")\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**space-to-depth**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 128, 50, 100)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, \"b c (h h1) (w w1) -> b (h1 w1 c) h w\", h1=2, w1=2)\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**depth-to-space** (notice that it's the reverse of the previous operation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 8, 200, 400)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, \"b (h1 w1 c) h w -> b c (h h1) (w w1)\", h1=2, w1=2)\n",
    "guess(y.shape)"
   ]
  },
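  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two patterns above mirror each other, so applying one after the other should give back the original tensor.\n",
    "Below is a minimal sketch of that check on a plain numpy array; the name `t` and its small shape are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "# hypothetical small input: batch=2, channels=3, height=4, width=6\n",
    "t = np.arange(2 * 3 * 4 * 6).reshape(2, 3, 4, 6)\n",
    "\n",
    "# space-to-depth: each 2x2 spatial block moves into the channel axis\n",
    "packed = rearrange(t, 'b c (h h1) (w w1) -> b (h1 w1 c) h w', h1=2, w1=2)\n",
    "# depth-to-space: the same pattern with its two sides swapped\n",
    "restored = rearrange(packed, 'b (h1 w1 c) h w -> b c (h h1) (w w1)', h1=2, w1=2)\n",
    "\n",
    "print(packed.shape)  # (2, 12, 2, 3)\n",
    "assert np.array_equal(restored, t)\n",
    "```"
   ]
  },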
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reductions\n",
    "\n",
    "Simple **global average pooling**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = reduce(x, \"b c h w -> b c\", reduction=\"mean\")\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**max-pooling** with a kernel 2x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 50, 100)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = reduce(x, \"b c (h h1) (w w1) -> b c h w\", reduction=\"max\", h1=2, w1=2)\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 50, 100)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# you can skip names for reduced axes\n",
    "y = reduce(x, \"b c (h 2) (w 2) -> b c h w\", reduction=\"max\")\n",
    "guess(y.shape)"
   ]
  },
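  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the pooling pattern computes, it can be compared with the same operation spelled out by hand.\n",
    "This is only a sketch on a small numpy array (the names `t`, `pooled` and `manual` are illustrative): the height and width axes are each split into (position, offset-in-window), and the maximum is taken over both offsets.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import reduce\n",
    "\n",
    "# hypothetical input: batch=2, channels=3, height=4, width=6\n",
    "t = np.random.RandomState(0).normal(size=[2, 3, 4, 6])\n",
    "\n",
    "pooled = reduce(t, 'b c (h 2) (w 2) -> b c h w', 'max')\n",
    "\n",
    "# the same 2x2 max-pooling with plain reshape and max over the two window axes\n",
    "b, c, h, w = t.shape\n",
    "manual = t.reshape(b, c, h // 2, 2, w // 2, 2).max(axis=(3, 5))\n",
    "\n",
    "assert pooled.shape == (2, 3, 2, 3)\n",
    "assert np.allclose(pooled, manual)\n",
    "```"
   ]
  },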
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1d, 2d and 3d pooling are defined in a similar way\n",
    "\n",
    "For sequential 1-d models, you'll probably want pooling over time:\n",
    "```python\n",
    "reduce(x, '(t 2) b c -> t b c', reduction='max')\n",
    "```\n",
    "\n",
    "For volumetric models, all three dimensions are pooled:\n",
    "```python\n",
    "reduce(x, 'b c (x 2) (y 2) (z 2) -> b c x y z', reduction='max')\n",
    "```\n",
    "\n",
    "Uniformity is a strong point of `einops`: you don't need a specific operation for each particular case.\n",
    "\n",
    "\n",
    "### Good exercises \n",
    "\n",
    "- write a version of space-to-depth for 1d and 3d (2d is provided above)\n",
    "- write an average / max pooling for 1d models. "
   ]
  },
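  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 1d and 3d pooling patterns above can be tried on dummy numpy arrays to confirm the resulting shapes.\n",
    "A small sketch; the array names and sizes are assumptions, chosen only so that every pooled axis has even length.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import reduce\n",
    "\n",
    "# hypothetical inputs: a sequence in (time, batch, channel) layout\n",
    "# and a volume in (batch, channel, x, y, z) layout\n",
    "sequence = np.zeros([8, 2, 3])\n",
    "volume = np.zeros([2, 3, 4, 6, 8])\n",
    "\n",
    "pooled_sequence = reduce(sequence, '(t 2) b c -> t b c', reduction='max')\n",
    "pooled_volume = reduce(volume, 'b c (x 2) (y 2) (z 2) -> b c x y z', reduction='max')\n",
    "\n",
    "print(pooled_sequence.shape)  # (4, 2, 3)\n",
    "print(pooled_volume.shape)  # (2, 3, 2, 3, 4)\n",
    "```"
   ]
  },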
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Squeeze and unsqueeze (expand_dims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# models typically work only with batches,\n",
    "# so to predict a single image ...\n",
    "image = rearrange(x[0, :3], \"c h w -> h w c\")\n",
    "# ... create a dummy 1-element axis ...\n",
    "y = rearrange(image, \"h w c -> () c h w\")\n",
    "# ... imagine you predicted this with a convolutional network for classification,\n",
    "# we'll just flatten axes ...\n",
    "predictions = rearrange(y, \"b c h w -> b (c h w)\")\n",
    "# ... finally, decompose (remove) dummy axis\n",
    "predictions = rearrange(predictions, \"() classes -> classes\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## keepdims-like behavior for reductions\n",
    "\n",
    "- the empty composition `()` provides dimensions of length 1, which are broadcastable\n",
    "- alternatively, you can use just `1` to introduce a new axis; it is a synonym for `()`\n",
    "\n",
    "per-channel mean-normalization for *each image*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = x - reduce(x, \"b c h w -> b c 1 1\", \"mean\")\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "per-channel mean-normalization for *whole batch*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = x - reduce(x, \"b c h w -> 1 c 1 1\", \"mean\")\n",
    "guess(y.shape)"
   ]
  },
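  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following sketch checks, on a small numpy array, that `1` and `()` produce the same result and that both match numpy's own `keepdims=True`.\n",
    "All names and sizes here are illustrative assumptions.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import reduce\n",
    "\n",
    "# hypothetical input: batch=2, channels=3, height=4, width=5\n",
    "t = np.random.RandomState(0).normal(size=[2, 3, 4, 5])\n",
    "\n",
    "mean_ones = reduce(t, 'b c h w -> b c 1 1', 'mean')\n",
    "mean_parens = reduce(t, 'b c h w -> b c () ()', 'mean')\n",
    "mean_numpy = t.mean(axis=(2, 3), keepdims=True)\n",
    "\n",
    "assert mean_ones.shape == (2, 3, 1, 1)\n",
    "assert np.allclose(mean_ones, mean_parens)\n",
    "assert np.allclose(mean_ones, mean_numpy)\n",
    "\n",
    "# the length-1 axes broadcast, so every (image, channel) slice ends up with zero mean\n",
    "centered = t - mean_ones\n",
    "assert np.allclose(centered.mean(axis=(2, 3)), 0)\n",
    "```"
   ]
  },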
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stacking\n",
    "\n",
    "let's take a list of tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "list_of_tensors = list(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The new axis (the one that enumerates tensors) appears *first* on the left side of the expression,\n",
    "just as if you were indexing the list: first you'd get a tensor by index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 100, 200, 32)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tensors = rearrange(list_of_tensors, \"b c h w -> b h w c\")\n",
    "guess(tensors.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(100, 200, 32, 10)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# or maybe stack along last dimension?\n",
    "tensors = rearrange(list_of_tensors, \"b c h w -> h w c b\")\n",
    "guess(tensors.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concatenation\n",
    "\n",
    "concatenate over the first dimension?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(1000, 200, 32)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tensors = rearrange(list_of_tensors, \"b c h w -> (b h) w c\")\n",
    "guess(tensors.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "or maybe concatenate along last dimension?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(100, 200, 320)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tensors = rearrange(list_of_tensors, \"b c h w -> h w (b c)\")\n",
    "guess(tensors.shape)"
   ]
  },
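  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, here is a sketch of how the stacking and concatenation patterns above relate to `np.stack` and `np.concatenate`.\n",
    "The list `tensors_list` and its sizes are made up for this example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "# hypothetical list of 3 tensors, each in (c, h, w) layout with c=2, h=4, w=5\n",
    "tensors_list = [np.full([2, 4, 5], fill_value=i) for i in range(3)]\n",
    "# the same tensors moved to (h, w, c) layout by hand\n",
    "hwc = [t.transpose(1, 2, 0) for t in tensors_list]\n",
    "\n",
    "# the list axis b stays a separate axis: stacking\n",
    "stacked = rearrange(tensors_list, 'b c h w -> b h w c')\n",
    "assert np.array_equal(stacked, np.stack(hwc, axis=0))\n",
    "\n",
    "# the list axis b is merged with c: concatenation along the last axis\n",
    "concatenated = rearrange(tensors_list, 'b c h w -> h w (b c)')\n",
    "assert np.array_equal(concatenated, np.concatenate(hwc, axis=-1))\n",
    "\n",
    "print(stacked.shape, concatenated.shape)  # (3, 4, 5, 2) (4, 5, 6)\n",
    "```"
   ]
  },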
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shuffling within a dimension\n",
    "\n",
    "**channel shuffle** (as it is drawn in the ShuffleNet paper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, \"b (g1 g2 c) h w-> b (g2 g1 c) h w\", g1=4, g2=4)\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "simpler version of **channel shuffle**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, \"b (g c) h w-> b (c g) h w\", g=4)\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Split a dimension\n",
    "\n",
    "Here's a super-convenient trick.\n",
    "\n",
    "Example: when a network predicts several bboxes for each position\n",
    "\n",
    "Assume we got 8 bboxes, 4 coordinates each. <br />\n",
    "To get coordinated into 4 separate variables, you move corresponding dimension to front and unpack tuple."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 8, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "bbox_x, bbox_y, bbox_w, bbox_h = rearrange(x, \"b (coord bbox) h w -> coord b bbox h w\", coord=4, bbox=8)\n",
    "# now you can operate on individual variables\n",
    "max_bbox_area = reduce(bbox_w * bbox_h, \"b bbox h w -> b h w\", \"max\")\n",
    "guess(bbox_x.shape)\n",
    "guess(max_bbox_area.shape)"
   ]
  },
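  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Which channels end up in which variable depends on the order inside the parentheses.\n",
    "The sketch below uses a small numpy array named `pred` (an assumption for this example) to show that, with `(coord bbox)`, each coordinate is a contiguous block of 8 channels.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "# hypothetical prediction: batch=2, 4 coords * 8 bboxes = 32 channels, 5x6 positions\n",
    "pred = np.random.RandomState(0).normal(size=[2, 32, 5, 6])\n",
    "\n",
    "box_x, box_y, box_w, box_h = rearrange(pred, 'b (coord bbox) h w -> coord b bbox h w', coord=4, bbox=8)\n",
    "\n",
    "# coord is the outer factor of the channel axis, so slicing gives the same tensors\n",
    "assert np.array_equal(box_x, pred[:, 0:8])\n",
    "assert np.array_equal(box_y, pred[:, 8:16])\n",
    "assert np.array_equal(box_w, pred[:, 16:24])\n",
    "assert np.array_equal(box_h, pred[:, 24:32])\n",
    "```"
   ]
  },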
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Getting into the weeds of tensor packing\n",
    "\n",
    "You can skip this part. It explains why it pays to make a habit of defining splits and packs explicitly.\n",
    "\n",
    "When implementing a custom gated activation (like GLU), a split is needed:\n",
    "```python\n",
    "y1, y2 = rearrange(x, 'b (split c) h w -> split b c h w', split=2)\n",
    "result = y1 * sigmoid(y2) # or tanh\n",
    "```\n",
    "... but we could split differently\n",
    "```python\n",
    "y1, y2 = rearrange(x, 'b (c split) h w -> split b c h w', split=2)\n",
    "```\n",
    "\n",
    "- the first one splits channels into consecutive groups: `y1 = x[:, :x.shape[1] // 2, :, :]`\n",
    "- while the second takes channels with a step: `y1 = x[:, 0::2, :, :]`\n",
    "\n",
    "This may lead to very *surprising* results when the input is\n",
    "- a result of group convolution\n",
    "- a result of bidirectional LSTM/RNN\n",
    "- multi-head attention\n",
    "\n",
    "Let's focus on the second case (LSTM/RNN), since it is less obvious.\n",
    "\n",
    "For instance, cudnn concatenates LSTM outputs for forward-in-time and backward-in-time.\n",
    "\n",
    "Also, in pytorch GLU splits channels into consecutive groups (the first way).\n",
    "So when the LSTM's output comes to GLU,\n",
    "- forward-in-time produces the linear part, and backward-in-time produces the activation ...\n",
    "- and the roles of the directions are different, and the gradients coming to the two parts are different\n",
    "  - that's not what you expect from a simple `GLU(BLSTM(x))`, right?\n",
    "\n",
    "`einops` notation makes such inconsistencies explicit and easy to detect."
   ]
  },
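  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the two splits can be verified directly.\n",
    "A minimal sketch on a small numpy array with 6 channels (the names and sizes are illustrative assumptions):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "# hypothetical tensor: batch=2, channels=6, height=1, width=1\n",
    "t = np.arange(2 * 6).reshape(2, 6, 1, 1)\n",
    "\n",
    "first_a, second_a = rearrange(t, 'b (split c) h w -> split b c h w', split=2)\n",
    "first_b, second_b = rearrange(t, 'b (c split) h w -> split b c h w', split=2)\n",
    "\n",
    "# (split c): two consecutive halves of the channel axis\n",
    "assert np.array_equal(first_a, t[:, :3])\n",
    "assert np.array_equal(second_a, t[:, 3:])\n",
    "# (c split): even and odd channels\n",
    "assert np.array_equal(first_b, t[:, 0::2])\n",
    "assert np.array_equal(second_b, t[:, 1::2])\n",
    "```"
   ]
  },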
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shape parsing\n",
    "just a handy utility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from einops import parse_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convolve_2d(x):\n",
    "    # imagine we have a simple 2d convolution with padding,\n",
    "    # so output has same shape as input.\n",
    "    # Sorry for laziness, use imagination!\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imagine we are working with 3d data\n",
    "x_5d = rearrange(x, \"b c x (y z) -> b c x y z\", z=20)\n",
    "# but we have only 2d convolutions.\n",
    "# That's not a problem: fold z into the batch axis and convolve each z-slice separately\n",
    "y = rearrange(x_5d, \"b c x y z -> (b z) c x y\")\n",
    "y = convolve_2d(y)\n",
    "# parse_shape not only supplies the axis lengths, but also lets rearrange verify that all dimensions match\n",
    "y = rearrange(y, \"(b z) c x y -> b c x y z\", **parse_shape(x_5d, \"b c x y z\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'b': 10, 'c': 32, 'x': 100, 'y': 10, 'z': 20}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parse_shape(x_5d, \"b c x y z\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'batch': 10, 'c': 32}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# we can skip some dimensions by writing underscore\n",
    "parse_shape(x_5d, \"batch c _ _ _\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Striding anything\n",
    "\n",
    "Finally, how do you convert any operation into a strided operation? <br />\n",
    "(applied to a convolution, the recipe below gives a dilated/atrous convolution)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# each image is split into subgrids, each subgrid now is a separate \"image\"\n",
    "y = rearrange(x, \"b c (h hs) (w ws) -> (hs ws b) c h w\", hs=2, ws=2)\n",
    "y = convolve_2d(y)\n",
    "# pack subgrids back to an image\n",
    "y = rearrange(y, \"(hs ws b) c h w -> b c (h hs) (w ws)\", hs=2, ws=2)\n",
    "\n",
    "assert y.shape == x.shape"
   ]
  },
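  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the subgrid idea concrete, the sketch below checks on a tiny numpy array that each subgrid is a strided slice of the original image, and that packing the subgrids restores it.\n",
    "The array `t` with a single image and a single channel is an assumption made to keep the indexing readable.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "# hypothetical input: batch=1, channels=1, height=4, width=6\n",
    "t = np.arange(4 * 6).reshape(1, 1, 4, 6)\n",
    "\n",
    "subgrids = rearrange(t, 'b c (h hs) (w ws) -> (hs ws b) c h w', hs=2, ws=2)\n",
    "\n",
    "# with b = 1, position hs_i * 2 + ws_i on the first axis holds the pixels at offset (hs_i, ws_i)\n",
    "assert np.array_equal(subgrids[0], t[0, :, 0::2, 0::2])\n",
    "assert np.array_equal(subgrids[1], t[0, :, 0::2, 1::2])\n",
    "assert np.array_equal(subgrids[2], t[0, :, 1::2, 0::2])\n",
    "assert np.array_equal(subgrids[3], t[0, :, 1::2, 1::2])\n",
    "\n",
    "# packing the subgrids back gives the original image\n",
    "restored = rearrange(subgrids, '(hs ws b) c h w -> b c (h hs) (w ws)', hs=2, ws=2)\n",
    "assert np.array_equal(restored, t)\n",
    "```"
   ]
  },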
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layers\n",
    "\n",
    "For frameworks that prefer operating with layers, layers are available.\n",
    "\n",
    "You'll need to import a proper one depending on your backend:\n",
    "\n",
    "```python\n",
    "from einops.layers.torch import Rearrange, Reduce\n",
    "from einops.layers.flax import Rearrange, Reduce\n",
    "from einops.layers.tensorflow import Rearrange, Reduce\n",
    "from einops.layers.chainer import Rearrange, Reduce\n",
    "```\n",
    "\n",
    "`Einops` layers behave identically to the operations and take the same parameters <br />\n",
    "(with the exception of the first argument, the tensor, which is passed during the call)\n",
    "\n",
    "```python\n",
    "layer = Rearrange(pattern, **axes_lengths)\n",
    "layer = Reduce(pattern, reduction, **axes_lengths)\n",
    "\n",
    "# apply layer to tensor\n",
    "x = layer(x)\n",
    "```\n",
    "\n",
    "Usually it is more convenient to use layers, not operations, to build models:\n",
    "```python\n",
    "# example given for pytorch, but code in other frameworks is almost identical\n",
    "from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU\n",
    "from einops.layers.torch import Reduce\n",
    "\n",
    "model = Sequential(\n",
    "    Conv2d(3, 6, kernel_size=5),\n",
    "    MaxPool2d(kernel_size=2),\n",
    "    Conv2d(6, 16, kernel_size=5),\n",
    "    # combined pooling and flattening in a single step\n",
    "    Reduce('b c (h 2) (w 2) -> b (c h w)', 'max'), \n",
    "    Linear(16*5*5, 120), \n",
    "    ReLU(),\n",
    "    Linear(120, 10), \n",
    "    # In flax, the {'axis': value} syntax for specifying values for axes is mandatory:\n",
    "    # Rearrange('(b1 b2) d -> b1 b2 d', {'b1': 12}), \n",
    ")\n",
    "```"
   ]
  },
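  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the model above can be run on a dummy batch.\n",
    "This sketch assumes 32x32 RGB inputs, the size for which `16*5*5` matches the number of flattened features; `images` and `logits` are illustrative names.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU\n",
    "from einops.layers.torch import Reduce\n",
    "\n",
    "model = Sequential(\n",
    "    Conv2d(3, 6, kernel_size=5),  # 32x32 -> 28x28\n",
    "    MaxPool2d(kernel_size=2),  # 28x28 -> 14x14\n",
    "    Conv2d(6, 16, kernel_size=5),  # 14x14 -> 10x10\n",
    "    # 2x2 max-pooling (10x10 -> 5x5) and flattening in a single step\n",
    "    Reduce('b c (h 2) (w 2) -> b (c h w)', 'max'),\n",
    "    Linear(16 * 5 * 5, 120),\n",
    "    ReLU(),\n",
    "    Linear(120, 10),\n",
    ")\n",
    "\n",
    "# assumed input: a batch of 4 RGB images, 32x32 pixels each\n",
    "images = torch.zeros(4, 3, 32, 32)\n",
    "logits = model(images)\n",
    "print(logits.shape)  # torch.Size([4, 10])\n",
    "```"
   ]
  },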
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What's next?\n",
    "\n",
    "- rush through [writing better code with einops+pytorch](https://arogozhnikov.github.io/einops/pytorch-examples.html)\n",
    "\n",
    "Using a different framework? Not a big issue: most recommendations transfer well to other frameworks. <br />\n",
    "`einops` works the same way in any framework.\n",
    "\n",
    "Finally - just write your code with einops!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}