{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Online softmax\n",
    "\n",
 **note**: this tutorial requires">
    "> **note**: this tutorial assumes familiarity with tiled `matmul`. A dedicated tutorial is available in the `tutorial` folder of this repository.\n",
    "\n",
    "Naive implementation of `softmax` computation requires to perform several passes on the whole input vector.\n",
    "Vector loading from global memory (`GM`, aka the GPU DRAM) for each pass is by far the operation bottleneck (compared to the computation part).\n",
    "\n",
    "`softmax` `triton` tutorial avoids multiple read/write operations on `GM` by assuming that the whole input vector is small enough to be loaded in shared memory (`SRAM`).\n",
    "\n",
    "Below, we describe an approach when this assumption doesn't stand, aka when the vector is too large for the `SRAM`. In the case of `transformer` model, the `softmax` is applied to each row of a matrix of shape `(sequence length, sequence length)`, and the `SRAM` limit for an `fp16` vector is around 128 tokens.\n",
    "\n",
    "We will start the tutorial with a naive approach and optimize it.\n",
    "\n",
    "## Problem setup\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(456)\n",
    "\n",
    "row_count, col_count = 4, 16\n",
    "\n",
    "long_input_vec: torch.Tensor = torch.rand((row_count, col_count))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Softmax computation\n",
    "\n",
    "### Safe softmax\n",
    "\n",
    "To avoid `FP16` or `FP32` overflow in `softmax` computation, it's usual to subtract to input vector its maximum value.\n",
    "This operation has no effect on the final output outside. It improves stability by reducing values amplitude.\n",
    "This is sometimes called `safe softmax` computation.\n",
    "\n",
    "### Global memory (DRAM) bandwidth bottleneck\n",
    "\n",
    "Computation of `safe softmax` on `PyTorch` requires multiple passes on the whole input vector if done manually:\n",
    "\n",
    "* one pass to find the maximum value\n",
    "* one pass to apply exponential operation to each value (numerator) and sum them (denominator)\n",
    "* one pass to perform the division `numerator / denominator`\n",
    "\n",
    "*Note that because of the eager execution model, on `PyTorch` step 2 requires 2 passes.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input row max\n",
      " tensor([[0.9820],\n",
      "        [0.8412],\n",
      "        [0.9198],\n",
      "        [0.9778]])\n",
      "Below we reduce values amplitude, that's the safe part of safe softmax\n",
      "original 1st row input:\n",
      " tensor([0.6815, 0.0039, 0.7451, 0.7946, 0.6127, 0.6803, 0.9820, 0.0019, 0.1609,\n",
      "        0.5916, 0.6531, 0.8855, 0.7397, 0.0681, 0.3341, 0.3200]) safe softmax input 1st row:\n",
      " tensor([-0.3005, -0.9780, -0.2369, -0.1874, -0.3693, -0.3017,  0.0000, -0.9800,\n",
      "        -0.8211, -0.3904, -0.3289, -0.0965, -0.2423, -0.9139, -0.6479, -0.6620])\n"
     ]
    }
   ],
   "source": [
    "# torch softmax as a reference\n",
    "expected_softmax = torch.softmax(long_input_vec, dim=1)\n",
    "\n",
    "# 1st read, torch max output both indexes and values, we only want the values\n",
    "# we transpose it to get a vertical tensor\n",
    "row_max = torch.max(long_input_vec, dim=1).values[:, None]\n",
    "print(\"input row max\\n\", row_max)\n",
    "# 2nd read\n",
    "input_safe = long_input_vec - row_max\n",
    "print(\"Below we reduce values amplitude, that's the safe part of safe softmax\")\n",
    "print(\"original 1st row input:\\n\", long_input_vec[0, :], \"safe softmax input 1st row:\\n\", input_safe[0, :])\n",
    "\n",
    "softmax_numerator = torch.exp(input_safe)\n",
    "# 3rd read\n",
    "normalizer_term = torch.sum(softmax_numerator, dim=1)[:, None]\n",
    "# 4th read\n",
    "naive_softmax = softmax_numerator / normalizer_term\n",
    "\n",
    "assert torch.allclose(naive_softmax, expected_softmax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Online softmax\n",
    "\n",
    "In their paper [Online normalizer calculation for softmax](https://arxiv.org/pdf/1805.02867.pdf), *M. Milakov & Al.* show an approach which makes parallelization possible by computing `softmax` progressively.\n",
    "Basically, we load the input vector in small blocks (adapted to the size of the `SRAM`) and compute 2 statistics in a single pass:\n",
    "\n",
    "* the maximum value\n",
    "* the denominator\n",
    "\n",
    "The achievement lies in the fact that you are supposed to know the maximum value of the vector to compute the denominator.\n",
    "At each step, our knowledge of the maximum value may evolve (we may meet a value bigger than our precedent maximum).\n",
    "When it happens, we just adjust the result of our computation of the precedent step.\n",
    "\n",
    "The adjustment procedure is based on rules of exponentiation: when multiplying a base raised to one exponent by the same base raised to another exponent, the exponents add."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- new row ---\n",
      "new max discovered\n",
      "current row max: 0.6815125346183777, denominator: 1.0\n",
      "current row max: 0.6815125346183777, denominator: 1.5078496932983398\n",
      "new max discovered\n",
      "current row max: 0.7450968623161316, denominator: 2.4149584770202637\n",
      "new max discovered\n",
      "current row max: 0.79459148645401, denominator: 3.2983407974243164\n",
      "current row max: 0.79459148645401, denominator: 4.132010459899902\n",
      "current row max: 0.79459148645401, denominator: 5.0240325927734375\n",
      "new max discovered\n",
      "current row max: 0.9819886684417725, denominator: 5.165497779846191\n",
      "current row max: 0.9819886684417725, denominator: 5.540792465209961\n",
      "current row max: 0.9819886684417725, denominator: 5.9807353019714355\n",
      "current row max: 0.9819886684417725, denominator: 6.657543182373047\n",
      "current row max: 0.9819886684417725, denominator: 7.377259731292725\n",
      "current row max: 0.9819886684417725, denominator: 8.285249710083008\n",
      "current row max: 0.9819886684417725, denominator: 9.070070266723633\n",
      "current row max: 0.9819886684417725, denominator: 9.471039772033691\n",
      "current row max: 0.9819886684417725, denominator: 9.994173049926758\n",
      "current row max: 0.9819886684417725, denominator: 10.509976387023926\n",
      "--- new row ---\n",
      "new max discovered\n",
      "current row max: 0.3628944754600525, denominator: 1.0\n",
      "current row max: 0.3628944754600525, denominator: 1.796286702156067\n",
      "new max discovered\n",
      "current row max: 0.44334477186203003, denominator: 2.6574349403381348\n",
      "new max discovered\n",
      "current row max: 0.5694260597229004, denominator: 3.3426434993743896\n",
      "current row max: 0.5694260597229004, denominator: 4.307759761810303\n",
      "new max discovered\n",
      "current row max: 0.8411758542060852, denominator: 4.282706260681152\n",
      "current row max: 0.8411758542060852, denominator: 5.092153549194336\n",
      "current row max: 0.8411758542060852, denominator: 5.754087924957275\n",
      "current row max: 0.8411758542060852, denominator: 6.385719299316406\n",
      "current row max: 0.8411758542060852, denominator: 7.075372695922852\n",
      "current row max: 0.8411758542060852, denominator: 7.718149185180664\n",
      "current row max: 0.8411758542060852, denominator: 8.450255393981934\n",
      "current row max: 0.8411758542060852, denominator: 9.37951946258545\n",
      "current row max: 0.8411758542060852, denominator: 9.812650680541992\n",
      "current row max: 0.8411758542060852, denominator: 10.249856948852539\n",
      "current row max: 0.8411758542060852, denominator: 11.185232162475586\n",
      "--- new row ---\n",
      "new max discovered\n",
      "current row max: 0.9197819828987122, denominator: 1.0\n",
      "current row max: 0.9197819828987122, denominator: 1.8743796348571777\n",
      "current row max: 0.9197819828987122, denominator: 2.4121508598327637\n",
      "current row max: 0.9197819828987122, denominator: 3.1733081340789795\n",
      "current row max: 0.9197819828987122, denominator: 3.648242712020874\n",
      "current row max: 0.9197819828987122, denominator: 4.4900665283203125\n",
      "current row max: 0.9197819828987122, denominator: 5.0027642250061035\n",
      "current row max: 0.9197819828987122, denominator: 5.762831687927246\n",
      "current row max: 0.9197819828987122, denominator: 6.3089094161987305\n",
      "current row max: 0.9197819828987122, denominator: 6.796399116516113\n",
      "current row max: 0.9197819828987122, denominator: 7.307489395141602\n",
      "current row max: 0.9197819828987122, denominator: 8.28607177734375\n",
      "current row max: 0.9197819828987122, denominator: 8.744580268859863\n",
      "current row max: 0.9197819828987122, denominator: 9.38587760925293\n",
      "current row max: 0.9197819828987122, denominator: 9.824188232421875\n",
      "current row max: 0.9197819828987122, denominator: 10.793715476989746\n",
      "--- new row ---\n",
      "new max discovered\n",
      "current row max: 0.177534282207489, denominator: 1.0\n",
      "new max discovered\n",
      "current row max: 0.9202759861946106, denominator: 1.4758076667785645\n",
      "current row max: 0.9202759861946106, denominator: 2.0623040199279785\n",
      "current row max: 0.9202759861946106, denominator: 2.7364466190338135\n",
      "new max discovered\n",
      "current row max: 0.9371026754379272, denominator: 3.690786600112915\n",
      "current row max: 0.9371026754379272, denominator: 4.633510112762451\n",
      "current row max: 0.9371026754379272, denominator: 5.228850841522217\n",
      "current row max: 0.9371026754379272, denominator: 5.776777744293213\n",
      "current row max: 0.9371026754379272, denominator: 6.281983852386475\n",
      "current row max: 0.9371026754379272, denominator: 6.7736921310424805\n",
      "current row max: 0.9371026754379272, denominator: 7.39810848236084\n",
      "current row max: 0.9371026754379272, denominator: 8.079381942749023\n",
      "current row max: 0.9371026754379272, denominator: 8.852633476257324\n",
      "new max discovered\n",
      "current row max: 0.9778261780738831, denominator: 9.49936580657959\n",
      "current row max: 0.9778261780738831, denominator: 9.926789283752441\n",
      "current row max: 0.9778261780738831, denominator: 10.47911262512207\n"
     ]
    }
   ],
   "source": [
    "online_softmax = torch.zeros_like(long_input_vec)\n",
    "\n",
    "for row in range(row_count):\n",
    "    row_max = 0.0\n",
    "    normalizer_term = 0.0\n",
    "    print(\"--- new row ---\")\n",
    "    for col in range(col_count):  # scalar level iteration\n",
    "        val = long_input_vec[row, col]\n",
    "        old_row_max = row_max\n",
    "        row_max = max(old_row_max, val)\n",
    "        # np.exp(old_max_row - max_row) is the adjustment factor of our precedent normalizer term,\n",
    "        # after this multiplication it's like we had always substracted row_max up to this point\n",
    "        # instead of old_row_max\n",
    "        normalizer_term = normalizer_term * torch.exp(old_row_max - row_max) + torch.exp(val - row_max)\n",
    "        if old_row_max != row_max:\n",
    "            print(\"new max discovered\")\n",
    "        print(f\"current row max: {row_max}, denominator: {normalizer_term}\")\n",
    "\n",
    "    # leverage our 2 statistics\n",
    "    online_softmax[row, :] = torch.exp(long_input_vec[row, :] - row_max) / normalizer_term\n",
    "\n",
    "assert torch.allclose(online_softmax, expected_softmax)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Instead of working on scalars, we may prefer to work on blocks of vectors as big as `GPU` `SRAM` can load for performance reasons. For that purpose, the code above needs very small modifications, something we will see in the Flash attention notebook."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  }
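  ,
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# block-level online softmax sketch: same two statistics as above, updated one block at a time\n",
    "block_size = 4  # arbitrary illustration; in practice, as many columns as fit in SRAM\n",
    "\n",
    "block_softmax = torch.zeros_like(long_input_vec)\n",
    "\n",
    "for row in range(row_count):\n",
    "    row_max = torch.tensor(float(\"-inf\"))\n",
    "    normalizer_term = torch.tensor(0.0)\n",
    "    for col in range(0, col_count, block_size):  # block-level iteration\n",
    "        block = long_input_vec[row, col:col + block_size]\n",
    "        old_row_max = row_max\n",
    "        # block-level reductions replace the scalar comparisons\n",
    "        row_max = torch.maximum(old_row_max, block.max())\n",
    "        # rescale the running denominator, then add the contribution of the current block\n",
    "        normalizer_term = normalizer_term * torch.exp(old_row_max - row_max) + torch.exp(block - row_max).sum()\n",
    "\n",
    "    block_softmax[row, :] = torch.exp(long_input_vec[row, :] - row_max) / normalizer_term\n",
    "\n",
    "assert torch.allclose(block_softmax, expected_softmax)"
   ]
  }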
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "3.10.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}