{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Flash attention (forward)\n",
    "\n",
    "As seen in its dedicated notebook, `online softmax` limits the computation to 2 passes on global memory input.\n",
    "\n",
    "In [`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`](https://arxiv.org/pdf/2205.14135.pdf), the goal of the author is to have as few passes on input data as possible and, most importantly, avoid saving intermediate results to `GPU` global memory (the big slow `DRAM` memory).  \n",
    "The reason is that the intermediate results are basically the attention matrix `QK^t` of shape `seq length x seq length` (one before `softmax`, one after it) which scales by definition quadratically regarding the sequence length.  \n",
    "\n",
    "Fusing two `matmul` (technique known as kernel fusion) is quite simple when the number of columns of input matrices is low (which is the case in `transformer` for a single `head`).  \n",
    "That way, we avoid saving to global memory the output of the first `matmul` (which is an intermediate result).\n",
    "\n",
    "One more complicated challenge is to perform the `softmax` between the two fused `matmul`.  \n",
    "Traditional `softmax` formula requires several full passes on input data, and it's impossible in an operation between two fused `matmul` because intermediate result is not available.  \n",
    "\n",
    "In the original `online softmax` paper, the computation of input maximum value and the normalizer is progressive, using one part of input data at a time. When we load a new part of the input vector, we may discover a new input maximum value. If so, we need to adjust already computed part of the `softmax` (the past) in a way which makes it as we had applied the new row maximum since the beginning of the computation.\n",
    "\n",
    "As explained at the beginning of this notebook, we don't want to save those very large data. So, the trick is the following: the adjustment of the partial `softmax` output is a multiplication of the past data by a scalar, multiplication is commutative, so we can apply the adjustment not on `softmax` output itself, but to the output of the second `matmul` (`softmax` output times `V` matrix). This output is the self-attention output, it's quite small (`seq length x d`) compared to intermediate results (`seq length x seq length`), plus it is saved to global memory.\n",
    "\n",
    "As a reminder `seq length x d` is the size of `V` matrix (and `K`, `Q`). `d` is the number of columns of `V` (and `K`, `Q`), aka the number of dimensions per head for the model. This number is low: <= 128 even for a `Davinci GPT-3` size model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(456)\n",
    "\n",
    "N, d = 16, 8\n",
    "\n",
    "Q_mat = torch.rand((N, d))\n",
    "K_mat = torch.rand((N, d))\n",
    "V_mat = torch.rand((N, d))\n",
    "\n",
    "# tile size for matmul, no op bigger than this size can be stored in SRAM\n",
    "Br = 4\n",
    "Bc = d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Classic PyTorch implementation of attention\n",
    "\n",
    "We start by implementing attention mechanism in PyTorch.\n",
    "The code is simple and many read/write in global memory are performed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "expected_softmax = torch.softmax(Q_mat @ K_mat.T, dim=1)\n",
    "expected_attention = expected_softmax @ V_mat\n",
    "\n",
    "# 1st read\n",
    "S_mat = Q_mat @ K_mat.T\n",
    "row_max = torch.max(S_mat, dim=1).values[:, None]\n",
    "# 2nd read\n",
    "input_safe = S_mat - row_max\n",
    "softmax_numerator = torch.exp(input_safe)\n",
    "# 3rd read\n",
    "softmax_denominator = torch.sum(softmax_numerator, dim=1)[:, None]\n",
    "# 4th read\n",
    "naive_softmax = softmax_numerator / softmax_denominator\n",
    "# final matmul (another read / write)\n",
    "matmul_result = naive_softmax @ V_mat\n",
    "\n",
    "assert torch.allclose(naive_softmax, expected_softmax)\n",
    "assert torch.allclose(matmul_result, expected_attention)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Simple tiled matmul\n",
    "\n",
    "Tiling is a technique based on matrix partition, each block is called a tile.\n",
    "A dedicated notebook is in the `tutorial` folder of this repository.\n",
    "\n",
    "We start with the `matmul` of `QK^t`.\n",
    "One particularity of transformer models is that the number of columns (`d`) of those matrices is small enough that several complete rows can be stored in shared memory (a fast memory close to compute cores of the GPU), so we don't have to iterate across this axis. This is an important aspect that we will leverage later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "S_mat_for_check = torch.zeros((N, N))\n",
    "\n",
    "for block_start_Bc in range(0, N, Bc):\n",
    "    block_end_Bc = block_start_Bc + Bc\n",
    "    Kj = K_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d\n",
    "    for block_start_Br in range(0, N, Br):\n",
    "        block_end_Br = block_start_Br + Br\n",
    "        Qi = Q_mat[block_start_Br:block_end_Br, :]  # shape Br x d\n",
    "\n",
    "        # QKt at the tile level\n",
    "        Sij = Qi @ Kj.T  # shape Br x Bc\n",
    "        S_mat_for_check[block_start_Br:block_end_Br, block_start_Bc:block_end_Bc] += Sij\n",
    "\n",
    "assert torch.allclose(S_mat_for_check, Q_mat @ K_mat.T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Fused `matmul`\n",
    "\n",
    "We will perform the computation `O=SV` where `S=QK^t`. Therefore, we need to perform 2 `matmul`s.\n",
    "For now, we do not put any `softmax` in between.\n",
    "\n",
    "Our main challenge is to build on top of the previous notebook block and not save in `GPU` global memory the intermediate `matmul` output `S`. One trick to reduce global memory accesses is to reuse data as much as possible. That is why in the outer loop we load 2 blocks (`Kj` and `Vj`) and reuse both of them during the iteration in the inner loop where only a single block is loaded from global memory (if executed in `Cuda`).\n",
    "\n",
    "The input matrices are supposed to not be transposed.\n",
    "The transposition of `K` is done implicitely through the way we iterate over it. Because `d` is small, we have no non-coalesced memory access issue."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "O = torch.zeros((N, d))\n",
    "\n",
    "for block_start_Bc in range(0, N, Bc):\n",
    "    block_end_Bc = block_start_Bc + Bc\n",
    "    Kj = K_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d\n",
    "    Vj = V_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d\n",
    "    for block_start_Br in range(0, N, Br):\n",
    "        block_end_Br = block_start_Br + Br\n",
    "        Qi = Q_mat[block_start_Br:block_end_Br, :]  # shape Br x d\n",
    "\n",
    "        # QKt at the tile level\n",
    "        Sij = Qi @ Kj.T  # shape Br x Bc\n",
    "        Oi = Sij @ Vj  # shape Br x d\n",
    "        O[block_start_Br:block_end_Br, :] += Oi\n",
    "\n",
    "assert torch.allclose(O, (Q_mat @ K_mat.T) @ V_mat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Full flash attention (forward)\n",
    "\n",
    "In the previous block, we didn't apply the `softmax` on top of `QK^t` `matmul`.\n",
    "The challenge in introducing it is that both `softmax` input and output shall not be saved in global memory. Let remind us that both input and output of `softmax` have the shape `seq len X seq len`.\n",
    "\n",
    "For that purpose, we will leverage the `online softmax` technique presented in details in a dedicated notebook.\n",
    "The idea is that to compute `safe softmax` (a numerical stable version of the `softmax`) we need to know the input `max` value (and the `softmax` denominator which itself depends on this input `max` statistic), an information that we can only know by scanning whole input data.\n",
    "\n",
    "`online softmax` computes (safe) `softmax` progressively in a single pass over input data (which are themselves computes progressively), block of row after block of row. During the process, each time we discover in a block a `max` bigger than the currently known row `max`, we correct the already computed values in a way which simulates that we have applied the new row `max` for the `softmax` numerator and denominator since the beginning. \n",
    "\n",
    "The correction of the `softmax` denominator is applied here:\n",
    "\n",
    "```python\n",
    "# This line is exactly the same mechanism seen in `online softmax` notebook but applied to a vector instead of scalar (the math stays the same).\n",
    "li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat\n",
    "```\n",
    "\n",
    "The remaining question is where to correct the *past* partial `softmax` output?  \n",
    "Because we want to perform all computations in a single pass, we do not save first `matmul` output to `GPU` global memory, meaning there is no *past* part to adjust.\n",
    "\n",
    "It appears that we can apply the correction to the matrix `O` (output of the second `matmul`).  \n",
    "It works because multiplication by a scalar is commutative, aka we can change the order of operations and get the same result mathematically.  \n",
    "*Nb: this is not 100% true in our code, order matters because float numbers has limited precision and introduces some roundings, still, the effect is small and Okish in deep learning*. \n",
    "\n",
    "This is done in the line:\n",
    "\n",
    "```python\n",
    "Oi = (li * torch.exp(mi - mi_new) * Oi / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj\n",
    "```\n",
    "\n",
    "In the left part of the addition (`(li * torch.exp(mi - mi_new) * Oi / li_new)`), `Oi` contains the sum of past output tiles, and that's where we can correct the past. In the right part of the addition (`(torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj`), we first correct the current tile `softmax` and then compute the current `matmul` tile output.\n",
    "\n",
    "Lines below refer to `Algorithm 1` from [flash attention](https://arxiv.org/pdf/2205.14135.pdf) paper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# variables outside the for loop represent the global memory\n",
    "# they are the only ones bigger than what the SRAM can store\n",
    "O = torch.zeros((N, d))\n",
    "\n",
    "# For the 2 variables below, they may be removed in a serially executed code (in particular the outter for loop)\n",
    "# They are needed in parallelized execution where each thread block need to sync its findings with the others\n",
    "# line 4, l will store the denominator of the softmax for each row\n",
    "l = torch.zeros((N, 1))\n",
    "# line 4, m will store the row max (computed progressively, block after block)\n",
    "m = torch.full((N, 1), -torch.inf)\n",
    "\n",
    "for block_start_Bc in range(0, N, Bc):\n",
    "    block_end_Bc = block_start_Bc + Bc\n",
    "    # line 6, load a block from matmul input tensor\n",
    "    Kj = K_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d\n",
    "    Vj = V_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d\n",
    "    for block_start_Br in range(0, N, Br):\n",
    "        block_end_Br = block_start_Br + Br\n",
    "\n",
    "        # line 8, load stuff from globabl memory, aka the work of the other thread blocks\n",
    "        mi = m[block_start_Br:block_end_Br, :]  # shape Br x 1\n",
    "        li = l[block_start_Br:block_end_Br, :]  # shape Br x 1\n",
    "        Oi = O[block_start_Br:block_end_Br, :]  # shape Br x d\n",
    "        Qi = Q_mat[block_start_Br:block_end_Br, :]  # shape Br x d\n",
    "\n",
    "        # line 9, QKt at the tile level\n",
    "        Sij = Qi @ Kj.T  # shape Br x Bc\n",
    "\n",
    "        # line 10, find max of each row of the current loaded block (and only this block)\n",
    "        mij_hat = torch.max(Sij, dim=1).values[:, None]\n",
    "        # line 10, compute the softmax numerator like if we only had the data from this block (and nothing before or after)\n",
    "        pij_hat = torch.exp(Sij - mij_hat)\n",
    "        # line 10, compute the softmax denominator like if we only had the data from this block (and nothing before or after)\n",
    "        lij_hat = torch.sum(pij_hat, dim=1)[:, None]\n",
    "\n",
    "        # line 11, find max of each row regarding the current block and all the previous ones we have already visited\n",
    "        mi_new = torch.max(torch.column_stack([mi, mij_hat]), dim=1).values[:, None]\n",
    "        # line 11, adjusting factor (see online softmax computation above) leveraging the rule of exponentiation\n",
    "        li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat\n",
    "\n",
    "        # line 12, first part before the \"+\" is the adjustment of the past blocks\n",
    "        # second part after the \"+\" is the incorporation of the information from the current block and the matmul for this block\n",
    "        Oi = (li * torch.exp(mi - mi_new) * Oi / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj\n",
    "\n",
    "        # Note that we replace (=) and not update (+=) the global variables like we would do in tilted matmul\n",
    "        # line 13, save statistics\n",
    "        m[block_start_Br:block_end_Br, :] = mi_new  # row max\n",
    "        l[block_start_Br:block_end_Br, :] = li_new  # softmax denominator\n",
    "        # save attention block to global memory\n",
    "        O[block_start_Br:block_end_Br, :] = Oi\n",
    "\n",
    "assert torch.allclose(O, expected_attention)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Triton implementation of Flash attention and original Flash attention Cuda implementations differ on an important point: the way they are parallelized.\n",
    "\n",
    "In Cuda implementation, it's quite simple, algorithm above is executed in a serialized way. The parallelization only happens at the `head x batch` level (so it needs on `A100` at least head x batch >= 80 to keep the GPU busy).\n",
    "\n",
    "In Triton implementation, the inner and outer loops in the algo above are switched and the parallelization happens at the level of the outer loop, it increases the level of parallelization and it makes the GPU busy even for small batches / low number of heads. See https://github.com/HazyResearch/flash-attention/issues/40 for detailed analysis."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "3.8.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}