{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-zYDqArMfEHp"
      },
      "source": [
        "![jax](https://repository-images.githubusercontent.com/154739597/90607180-e100-11e9-8642-c65819bec604)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9-nHAbksl7q5"
      },
      "source": [
        "# JAX Basics\n",
        "\n",
        "<a href=\"https://colab.research.google.com/github/ZohebAbai/Deep-Learning-Projects/blob/master/JAX_Basics.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RFdXp1YD-PDK"
      },
      "source": [
        "**What mostly constitutes of Deep Learning at granular level**?\n",
        "- `Bandwidth`: Loading and transfer of data from memory\n",
        "- `MatMul Ops`: Array manipulation and operations\n",
        "- `Non Matmul Ops`: Algorithms with support to auto-differentiation\n",
        "\n",
        "\n",
        "**If theoretically deep-learning is all applied-maths, then why scientific computing library like Numpy is not suitable?**\n",
        "**Or, Why powerful deep learning frameworks like Tensorflow and PyTorch are not built on Numpy?**\n",
        "- Numpy cannot run on accelerated hardware.\n",
        "- Numpy doesn't have auto differentiation.\n",
        "- Numpy doesn't saves memory by fusing of operations.\n",
        "- Numpy doesn't supports vectorized batching of operations\n",
        "- Numpy (mostly python inherently) doesn't supports parallelization of data and computation\n",
        "\n",
        "Mostly we work with CPU and GPU device data formats by copying the data into device memory, converting it into device format and then running operations on it. \n",
        "\n",
        "To let numpy work on GPU, several teams came up with their solutions (we won't be discussing them on details):\n",
        " - CuPy and other Cu-Libraries by [RAPIDS AI](https://github.com/rapidsai) and NVIDIA supports CUDA Array Interface (GPU) instead of Numpy Array Interface (CPU), thus running only on GPU devices.\n",
        " - [Numba](https://numba.pydata.org/) by Ananconda is a JIT compiler supporting both CUDA Array Interface (GPU) and Numpy Array Interface (CPU) but it transpiles python bytecode directly to LLVM for compilation, thus complicating things.\n",
        "\n",
        " This copy-and-converting data between different formats is an expensive and incredibly time-consuming task that adds zero value to data science pipelines. \n",
        "\n",
        " Let's look at a report from [this recent paper](https://arxiv.org/abs/2007.00072) on FLOP counts on SOTA model like BERT for different operator types:\n",
        "![Bert Perf](https://horace.io/img/perf_intro/bert_flops.png)\n",
        "\n",
        "You can see that altogether, our non-matmul ops only make up 0.2% of our FLOPS, but 40% of our runtimes. These are also called *memory-bound operations*.\n",
        "\n",
        "\n",
        "**Well, [JAX](https://github.com/google/jax) seems like a promising alternative to Numpy, fixing all of above mentioned issues.**\n",
        "\n",
        "**JAX is a high performance, numerical computing library which incorporates composable function transformations.**\n",
        "\n",
        "It lies at the intersection of Scientific Computing and Function Transformations, yielding a wide range of capability beyond the ability to train just Deep Learning models. \n",
        "\n",
        "*A function transformation is an operator on a function whose output is another function.*\n",
        "\n",
        "![image1](https://www.assemblyai.com/blog/content/images/2022/02/JAX-overview.svg)\n",
        "\n",
        "- It treates differentiation as first-class citizen \n",
        "- Its hardware accelerator agnostic\n",
        "- Its compiler oriented\n",
        "- It provides numpy like API\n",
        "- And a lot more...\n",
        "\n",
        "It mainly constitutes of JIT Compilation, Autograd And XLA Compiler\n",
        "\n",
        "- JIT: Just in Time compilation compiles code during execution of the program i.e. during runtime. \n",
        "- Autograd: It provides a framework for general Differentiable Programming.\n",
        "- XLA: Accelerated Linear Algebra  is a graph-based, whole-program optimizing compiler, designed specifically for linear algebra. It divides the code into sequence of computation kernels, significantly increases execution speed and lower memory usage by fusing low-level operations. For example: it uses single gpu kernel by fusing addition and multiplication, without writing intermediate values into memory (instead keeping in GPU registers and streaming them).\n",
        "\n",
        "**Deep Learning Community is embracing JAX**\n",
        "- [Huggingface supports almost all their models in JAX](https://discuss.huggingface.co/t/about-the-flax-jax-projects-category/7061)\n",
        "\n",
        "- [Deepmind shifted to JAX](https://www.deepmind.com/blog/using-jax-to-accelerate-our-research)\n",
        "\n",
        "- Google used JAX with its 4096 cores TPU Supercomputer to win six out of eight [MLPerf benchmark competitions](https://cloud.google.com/blog/products/ai-machine-learning/google-breaks-ai-performance-records-in-mlperf-with-worlds-fastest-training-supercomputer).\n",
        "\n",
        "- Recently Google launched [LaMDA](https://youtu.be/ayhJii34D38) which too is built on JAX.\n",
        "\n",
        "In this notebook, we shall go through the powers of JAX."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kJnHRsytlFqh"
      },
      "source": [
        "## Install Libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "55vuAdi5H6Yp"
      },
      "outputs": [],
      "source": [
        "!add-apt-repository ppa:longsleep/golang-backports -y\n",
        "!apt update\n",
        "!apt install golang-go\n",
        "%env GOPATH=/root/go\n",
        "\n",
        "!apt-get install graphviz gv\n",
        "!go install github.com/google/pprof@latest"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xZ6vuhDub_P5",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f139b6ab-f907-4512-cd51-ecdee5a6234c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
            "Requirement already satisfied: jax[cuda] in /usr/local/lib/python3.7/dist-packages (0.3.8)\n",
            "Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax[cuda]) (1.4.1)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from jax[cuda]) (4.1.1)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax[cuda]) (1.1.0)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax[cuda]) (3.3.0)\n",
            "Requirement already satisfied: numpy>=1.19 in /usr/local/lib/python3.7/dist-packages (from jax[cuda]) (1.21.6)\n",
            "Collecting jaxlib==0.3.7+cuda11.cudnn82\n",
            "  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.7%2Bcuda11.cudnn82-cp37-none-manylinux2014_x86_64.whl (158.1 MB)\n",
            "\u001b[K     |████████████████████████████████| 158.1 MB 29 kB/s \n",
            "\u001b[?25hRequirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.3.7+cuda11.cudnn82->jax[cuda]) (2.0)\n",
            "Installing collected packages: jaxlib\n",
            "  Attempting uninstall: jaxlib\n",
            "    Found existing installation: jaxlib 0.3.7+cuda11.cudnn805\n",
            "    Uninstalling jaxlib-0.3.7+cuda11.cudnn805:\n",
            "      Successfully uninstalled jaxlib-0.3.7+cuda11.cudnn805\n",
            "Successfully installed jaxlib-0.3.7+cuda11.cudnn82\n"
          ]
        }
      ],
      "source": [
        "!pip install --upgrade \"jax[cuda]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Y4Tqe8pMcIJG",
        "outputId": "bf55a577-b865-4c6c-c85f-1cbb3e9dce1f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Tue Jun 21 14:28:48 2022       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   36C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        }
      ],
      "source": [
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "p97YMkOoDFDi",
        "outputId": "bc880af2-ce4b-41c5-a353-2e25647eafa6"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[GpuDevice(id=0, process_index=0)]"
            ]
          },
          "metadata": {},
          "execution_count": 2
        }
      ],
      "source": [
        "import jax\n",
        "jax.devices()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cDMNUxkgd4Z1"
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "from jax import random\n",
        "from jax import grad, jit, make_jaxpr, vmap, pmap\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Uyez14DxJ50"
      },
      "source": [
        "## Syntax same, but BTS lies difference!!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6plfp5tqxNWv",
        "outputId": "fe6bd061-33a9-428a-aa00-3dbe37cd4409"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "x = np.zeros(10)\n",
        "x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wX6oGUfjxUsl",
        "outputId": "72c17f05-9c04-423f-d8ce-78925c31b38e"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "y = jnp.zeros(10)\n",
        "y"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aNMEDLmD-5u4"
      },
      "source": [
        "\n",
        "\n",
        "Unlike Numpy, JAX arrays are immutable, meaning that once created their contents cannot be changed.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "pSXD2qmR7jgA",
        "outputId": "8c7606db-b215-41b6-e81e-a065d052f2ac"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Numpy arrays are mutable\n",
            "Earlier memory address: 0x7fafbda22450\n",
            "[10.  0.  0.  0.  0.  0.  0.  0.  0.  0.]\n",
            "Current memory address: 0x7fafbda22450\n",
            "\n",
            "\n",
            "JAX cannot be in-place mutated. It returns a copy\n",
            "Earlier memory address: 0x7fafbe3554b0\n",
            "[10.  0.  0.  0.  0.  0.  0.  0.  0.  0.]\n",
            "Current memory address: 0x7fafbe355330\n"
          ]
        }
      ],
      "source": [
        "print(\"Numpy arrays are mutable\")\n",
        "print(f\"Earlier memory address: {hex(id(x))}\")\n",
        "x[0] = 10\n",
        "print(x)\n",
        "print(f\"Current memory address: {hex(id(x))}\")\n",
        "print(\"\\n\")\n",
        "print(\"JAX cannot be in-place mutated. It returns a copy\")\n",
        "print(f\"Earlier memory address: {hex(id(y))}\")\n",
        "y = y.at[0].set(10)\n",
        "print(y)\n",
        "print(f\"Current memory address: {hex(id(y))}\")"
      ]
    },
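    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Alongside `.at[...].set()`, JAX arrays expose other functional index-update operators such as `.at[...].add()` and `.at[...].mul()`. A minimal sketch (each call returns a new array; `y` itself stays unchanged):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Functional index updates: every call returns a fresh array.\n",
        "print(y.at[1].add(5))    # add 5 to element 1 of a copy of y\n",
        "print(y.at[:3].mul(2))   # scale the first three elements of a copy of y\n",
        "print(y)                 # y is untouched"
      ]
    },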
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ikIk2XNBKBjG"
      },
      "source": [
        "**Now let's understand JAX via few of its properties**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bc5OK-AXf2Ny"
      },
      "source": [
        "## How randomness is handled in JAX"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lYSyqGsge12T",
        "outputId": "68342da7-8d8f-4adc-c601-d8a7b7627e62"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "For Numpy:\n",
            "bar + 2 x car gives 1.9791922366721637\n",
            "2 x car + bar gives 1.7504099351401847\n"
          ]
        }
      ],
      "source": [
        "seed = 0\n",
        "np.random.seed(seed)\n",
        "\n",
        "print(\"For Numpy:\")\n",
        "# function def\n",
        "def bar(): return np.random.uniform()\n",
        "def car(): return np.random.uniform()\n",
        "\n",
        "def foo1(): return bar() + 2*car()\n",
        "print(f\"bar + 2 x car gives {foo1()}\")\n",
        "\n",
        "def foo2(): return 2*car() + bar()\n",
        "print(f\"2 x car + bar gives {foo2()}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a54RLQDSkEkH"
      },
      "source": [
        "Algorithm is same but the result is different. This is because the order of the execution of the functions is not the same anymore.\n",
        "\n",
        "This becomes a problem when trying to parallelize all of our complex functions. We cannot guarantee order or their executions and therefore, there is no way of enforcing reproducibility of results we are getting.\n",
        "\n",
        "JAX solves this by pseudo-random number generator keys."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dwoB3Ei1hHVy",
        "outputId": "71aa4f32-13f4-42df-a49a-0b9c2c13c2c8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "For JAX:\n",
            "bar + 2 x car gives 2.470635175704956\n",
            "2 x car + bar gives 2.470635175704956\n"
          ]
        }
      ],
      "source": [
        "state = 101\n",
        "key = random.PRNGKey(state)\n",
        "\n",
        "# subkeys for each functions\n",
        "subkeys = random.split(key, num=2)\n",
        "\n",
        "print(\"For JAX:\")\n",
        "# function def\n",
        "def bar(): return random.uniform(subkeys[0])\n",
        "def car(): return random.uniform(subkeys[1])\n",
        "\n",
        "def foo1(): return bar() + 2*car()\n",
        "print(f\"bar + 2 x car gives {foo1()}\")\n",
        "\n",
        "def foo2(): return 2*car() + bar()\n",
        "print(f\"2 x car + bar gives {foo2()}\")"
      ]
    },
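    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A common pattern is to split the key every time fresh randomness is needed, so each draw is independent yet fully reproducible. A minimal sketch (using a separate `loop_key` so we don't clobber the `key` used later):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Split the key at every step: `loop_key` carries the state forward,\n",
        "# `subkey` is consumed for this step's random draw.\n",
        "loop_key = random.PRNGKey(0)\n",
        "for step in range(3):\n",
        "    loop_key, subkey = random.split(loop_key)\n",
        "    print(step, random.normal(subkey, (2,)))"
      ]
    },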
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I6wBV7UmyCg9"
      },
      "source": [
        "## Speed Comparison"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d-6TPKRXx1fs"
      },
      "outputs": [],
      "source": [
        "x = np.random.rand(10000,10000).astype(np.float32) \n",
        "# For fair comparision \n",
        "# Numpy defaults tp 64-bit dtypes whule JAX to 32-bit.\n",
        "y = jnp.array(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gNI86Z-22Umk",
        "outputId": "b5f57275-b63e-42d2-fdb5-ee457fe49892"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1 loop, best of 1: 22.1 s per loop\n"
          ]
        }
      ],
      "source": [
        "%timeit -n 1 -r 1 np.dot(x,x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lJ4O1h863uKU",
        "outputId": "6c9eb82a-1769-4a2a-cb91-83deda59e2ca"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1 loop, best of 1: 3.62 s per loop\n"
          ]
        }
      ],
      "source": [
        "%timeit -n 1 -r 1 jnp.dot(y,y).block_until_ready()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "76f4eAKfOxWV"
      },
      "source": [
        "We are using `block_until_ready` for benchmarking for a reason we are about to cover. JAX compiles and caches it in the device memory. So next time you run the same operation, it gives results much faster.\n",
        "\n",
        "Its 7-8 times faster."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ITG_EnShOo6_",
        "outputId": "3764a796-4068-4536-86bb-4ac294cf2557"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1 loop, best of 1: 229 ms per loop\n"
          ]
        }
      ],
      "source": [
        "%timeit -n 1 -r 1 jnp.dot(y,y).block_until_ready()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cTReDIyzPyV0"
      },
      "source": [
        "Its 100x faster. Impressive, but watch this - let's remove `block until ready`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "befkzD3BOgDF",
        "outputId": "b3f3092e-d315-4690-f275-6282617a8409"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1 loop, best of 1: 931 µs per loop\n"
          ]
        }
      ],
      "source": [
        "%timeit -n 1 -r 1 jnp.dot(y,y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G4MmwsovP5Uc"
      },
      "source": [
        "Micro seconds for 10k dim x 10k dim matrix multiplication on single NVIDIA TESLA P100 GPU of 16GB RAM, isn't that suprising?\n",
        "\n",
        "Let's understand why."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yDL4guGxQX7V"
      },
      "source": [
        "## Asynchronous Dispatch\n",
        "\n",
        "JAX is async. What happened earlier was that JAX mislead us when we removed `block until ready`. We were not timing the execution of matrix multiplication, only the time to dispatch the work. To measure the  true cost of operation, we need to wait untill the execution is complete in order to properly measure the time. So we use `block until ready` during benchmarking.\n",
        "\n",
        "**Explaination:**\n",
        "\n",
        "JAX does not wait for the operation to complete before returning control to the Python program. Instead, JAX returns a `DeviceArray` value, which is a future, i.e., a value that will be produced in the future on an accelerator device but isn’t necessarily available immediately. We can inspect the shape or type of a DeviceArray without waiting for the computation that produced it to complete, and we can even pass it to another JAX computation, as we do with the addition operation here. Only if we actually inspect the value of the array from the host, for example by printing it or by converting it into a plain old `numpy.ndarray` will JAX force the Python code to wait for the computation to complete.\n",
        "\n",
        "Asynchronous dispatch is useful since it allows Python code to “run ahead” of an accelerator device, keeping Python code out of the critical path. Provided the Python code enqueues work on the device faster than it can be executed, and provided that the Python code does not actually need to inspect the output of a computation on the host, then a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.\n",
        "\n",
        "Let's breakdown the earlier operation:\n",
        "\n",
        "\n"
      ]
    },
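    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of the future-like behaviour described above, reusing the `y` array from the speed comparison: metadata is available immediately, while `block_until_ready()` (or reading values from the host) waits for the actual result."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "r = jnp.dot(y, y)          # returns almost immediately: only the work is dispatched\n",
        "print(r.shape, r.dtype)    # shape and dtype are known without waiting\n",
        "r.block_until_ready()      # only now do we wait for the computation to finish\n",
        "print(r[0, 0])             # inspecting values from the host also forces a wait"
      ]
    },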
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3RqE7mAUQaWJ"
      },
      "outputs": [],
      "source": [
        "# instead of using .dot function let's user define it\n",
        "def f(x):  \n",
        "  return x @ x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Od0zFGStXCBl",
        "outputId": "15a04c68-384e-40ef-d5ed-95dd3507c8ef"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1 loop, best of 1: 22.4 s per loop\n"
          ]
        }
      ],
      "source": [
        "# measure NumPy runtime\n",
        "x_np = np.random.rand(10000,10000).astype(np.float32) \n",
        "%timeit -n 1 -r 1 f(x_np) "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5Q9NvlJqcV5X"
      },
      "source": [
        "NumPy takes around 20 s per evaluation on the CPU"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ks6cmWVScGSH",
        "outputId": "3f8f35b0-c247-4ddc-f6b9-ac627e713f20"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "CPU times: user 9 ms, sys: 1.02 ms, total: 10 ms\n",
            "Wall time: 9.52 ms\n"
          ]
        }
      ],
      "source": [
        "# measure JAX device transfer time\n",
        "%time x_jax = jax.device_put(x_np)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P1xj2i46cZCZ"
      },
      "source": [
        "JAX takes around 5 ms to copy the NumPy arrays onto the GPU"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AtljWBKbciBu",
        "outputId": "e4aba371-e45e-4516-9229-74e72be9cf6f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "CPU times: user 92.5 ms, sys: 48 ms, total: 140 ms\n",
            "Wall time: 352 ms\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "DeviceArray([[2489.3577, 2463.2961, 2486.5251, ..., 2468.2246, 2485.337 ,\n",
              "              2486.9387],\n",
              "             [2515.0486, 2532.0928, 2548.9595, ..., 2523.4004, 2516.9792,\n",
              "              2537.574 ],\n",
              "             [2513.7805, 2521.4624, 2540.517 , ..., 2508.0244, 2507.6   ,\n",
              "              2521.0579],\n",
              "             ...,\n",
              "             [2485.6765, 2488.7998, 2513.743 , ..., 2491.3413, 2475.7664,\n",
              "              2484.1887],\n",
              "             [2493.5042, 2491.062 , 2519.9312, ..., 2481.506 , 2483.618 ,\n",
              "              2498.8936],\n",
              "             [2495.9058, 2481.7222, 2509.2488, ..., 2477.701 , 2462.4023,\n",
              "              2485.2917]], dtype=float32)"
            ]
          },
          "execution_count": 19,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# measure JAX compilation time\n",
        "f_jit = jit(f)\n",
        "%time f_jit(x_jax).block_until_ready()  "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gNv3CPQDcipB"
      },
      "source": [
        "JAX takes around 300 ms to compile the function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5vmNGrsOXwtL",
        "outputId": "4814cc25-2f1f-4117-a0a7-7cca664f3497"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1 loop, best of 1: 229 ms per loop\n"
          ]
        }
      ],
      "source": [
        "# measure JAX runtime\n",
        "%timeit -n 1 -r 1 f_jit(x_jax).block_until_ready()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EfEWsClEYFvw"
      },
      "source": [
        "JAX takes 200 ms per evaluation on the GPU.\n",
        "\n",
        "In this case, we see that once the data is transfered and the function is compiled, JAX on the GPU is about 100x faster for repeated evaluations.\n",
        "\n",
        "Is this a fair comparison on speed? Maybe. The performance that ultimately matters is for running full deep learning applications, which inevitably include some amount of both data transfer and compilation. \n",
        "\n",
        "\n",
        "Did you notice `jit`? \n",
        "\n",
        "**JAX incorporates an extensible system for such function transformations**, and has four main transformations of interest to the typical user:\n",
        "\n",
        "- `jit()` to transform functions into just-in-time compiled versions\n",
        "- `grad()` for evaluating the gradient function of the input function\n",
        "- `vmap()` for automatic vectorization of operations\n",
        "- `pmap()` for easy parallelization of computations\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gnIvPPVLBdRP"
      },
      "source": [
        "## JIT Compilation\n",
        "\n",
        "- NumPy operations are executed eagerly, synchronously, and only on CPU.\n",
        "\n",
        "- By default JAX executes operations one at a time, in sequence or eagerly, and dispatches asynchronously on all devices CPU/GPU/TPU. Using just-in-time (JIT) compilation, sequences of operations can be optimized together and run at once. JAX uses the XLA compiler to execute blocks of code very efficiently. \n",
        "\n",
        "But there is a catch - Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time. For ex: shape is not known for x for `def f(x): return x[x<0]` during compile time. \n",
        "\n",
        "Let's understand how `jit` works.\n"
      ]
    },
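    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of the static-shape constraint (the helper name `get_negatives` is just for illustration): boolean masking produces a value-dependent shape, so it fails under `jit`, while a shape-preserving alternative such as `jnp.where` compiles fine."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def get_negatives(x):\n",
        "  return x[x < 0]              # output length depends on the values in x\n",
        "\n",
        "x_demo = jnp.array([-1.0, 2.0, -3.0])\n",
        "print(get_negatives(x_demo))   # works eagerly\n",
        "\n",
        "try:\n",
        "  jit(get_negatives)(x_demo)   # fails: shape is not known at trace time\n",
        "except Exception as e:\n",
        "  print(f\"jit failed: {type(e).__name__}\")\n",
        "\n",
        "# A jit-friendly alternative keeps the shape fixed, e.g. zeroing out non-negatives:\n",
        "print(jit(lambda v: jnp.where(v < 0, v, 0.0))(x_demo))"
      ]
    },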
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-BV1ojg9qQ3s",
        "outputId": "a8c81ffb-c82f-4c5c-8850-a9317868803a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Running f():\n",
            "  x = [[-0.35960949  1.52744007 -0.44154394  0.02519481]\n",
            " [-0.57144428 -1.52495174  0.72102243 -0.21198663]\n",
            " [ 0.60934449 -1.01884849 -0.74926673  0.05947864]]\n",
            "\n",
            "  y = [[-1.06645266]\n",
            " [-0.26044164]\n",
            " [-1.17089024]\n",
            " [ 0.60112815]]\n",
            "\n",
            "  result = [[3.37266737]\n",
            " [0.55089334]\n",
            " [1.53262843]]\n",
            "\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "array([[3.37266737],\n",
              "       [0.55089334],\n",
              "       [1.53262843]])"
            ]
          },
          "execution_count": 25,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Let's first see how numpy works\n",
        "def f(x, y):\n",
        "  print(\"Running f():\")\n",
        "  print(f\"  x = {x}\\n\")\n",
        "  print(f\"  y = {y}\\n\")\n",
        "  result = np.dot(x + 1, y + 1)\n",
        "  print(f\"  result = {result}\\n\")\n",
        "  return result\n",
        "\n",
        "x = np.random.randn(3, 4)\n",
        "y = np.random.randn(4, 1)\n",
        "f(x, y)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WNC7mRh_yO_w",
        "outputId": "07c669fa-dca4-4d35-978c-b581df8b9858"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Running f():\n",
            "  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>\n",
            "  y = Traced<ShapedArray(float32[4,1])>with<DynamicJaxprTrace(level=0/1)>\n",
            "  result = Traced<ShapedArray(float32[3,1])>with<DynamicJaxprTrace(level=0/1)>\n",
            "\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "DeviceArray([[3.83735  ],\n",
              "             [1.4746052],\n",
              "             [5.683498 ]], dtype=float32)"
            ]
          },
          "execution_count": 26,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "def f(x, y):\n",
        "  print(\"Running f():\")\n",
        "  print(f\"  x = {x}\")\n",
        "  print(f\"  y = {y}\")\n",
        "  result = jnp.dot(x + 1, y + 1)\n",
        "  print(f\"  result = {result}\\n\")\n",
        "  return result\n",
        "\n",
        "x = random.normal(key, (3, 4))\n",
        "y = random.normal(key, (4, 1))\n",
        "f_jit = jit(f)\n",
        "f_jit(x, y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mS41kp2b0vPA"
      },
      "source": [
        "Notice that rather than printing the data we passed to the function, it prints `tracer` objects that stand-in for them.\n",
        "\n",
        "These tracer objects are what `jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are **agnostic to the values**. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.\n",
        "\n",
        "Let's call the compiled function again on another input value but having same shape and dtype."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8nTucCPp1SFs",
        "outputId": "53cd62fe-b8b1-46d3-f6a6-4c7391a2041e"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "DeviceArray([[6.8841944],\n",
              "             [8.1478615],\n",
              "             [5.9593167]], dtype=float32)"
            ]
          },
          "execution_count": 27,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "key2 = random.PRNGKey(202)\n",
        "x2 = random.normal(key2, (3, 4))\n",
        "y2 = random.normal(key2, (4, 1))\n",
        "f_jit(x2, y2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lqlwQ_Bm1S_k"
      },
      "source": [
        "Did you notice `print` statements didn't run. Which means it didn't re-compile. It's because the result is computed in compiled XLA rather than in Python.\n",
        "\n",
        "You can view the sequence of operations encoded in a JAX expression using the `jax.make_jaxpr` transformation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rzuLgx6M33ME",
        "outputId": "27d5f19f-09d5-45e9-f4f1-11d62b132c13"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{ lambda ; a:f32[3,4] b:f32[4,1]. let\n",
              "    c:f32[3,4] = add a 1.0\n",
              "    d:f32[4,1] = add b 1.0\n",
              "    e:f32[3,1] = dot_general[\n",
              "      dimension_numbers=(((1,), (0,)), ((), ()))\n",
              "      precision=None\n",
              "      preferred_element_type=None\n",
              "    ] c d\n",
              "  in (e,) }"
            ]
          },
          "execution_count": 28,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "def f(x, y):\n",
        "  return jnp.dot(x + 1, y + 1)\n",
        "\n",
        "make_jaxpr(f)(x, y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1cLg5M-Y4_QT"
      },
      "source": [
        "Remember:\n",
        "- **Static Manner**: Numpy executes operations only once at compile-time. \n",
        "- **Traced Manner**: JAX optimizes, compiles and executes operations at run-time.\n",
        "\n",
        "And ya, JIT is faster than default."
      ]
    },
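    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of static vs. traced (the function `g` below is just an illustrative example): the plain NumPy call is evaluated once at trace time and appears in the jaxpr as a constant, while the `jnp` call is traced and executed on the device."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def g(v):\n",
        "  a = np.sin(0.5)          # static: runs once in Python, during tracing\n",
        "  return jnp.sin(v) + a    # traced: compiled by XLA, runs on the device\n",
        "\n",
        "print(make_jaxpr(g)(1.0))  # the NumPy result shows up only as a baked-in constant"
      ]
    },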
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zLNxiRV4BivA"
      },
      "outputs": [],
      "source": [
        "x = random.normal(key, (1000, 1000))\n",
        "\n",
        "def f(x):\n",
        "    for _ in range(10):\n",
        "        x = 0.5*x + 0.1*jnp.sin(x)\n",
        "    return x\n",
        "\n",
        "g = jit(f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "u-hkjmVHCDZB",
        "outputId": "997f8c68-d499-4670-92ac-994a34c0c8c7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The slowest run took 83.29 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
            "1 loop, best of 5: 3.19 ms per loop\n"
          ]
        }
      ],
      "source": [
        "%timeit -n 1 -r 5 f(x).block_until_ready()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "som-mm6UCSxd",
        "outputId": "1cdb73b7-2bb6-4dc2-c410-505d7b69c8cc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The slowest run took 976.31 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
            "1 loop, best of 5: 229 µs per loop\n"
          ]
        }
      ],
      "source": [
        "%timeit -n 1 -r 5 g(x).block_until_ready()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s76EgPSH7kGT"
      },
      "source": [
        "Almost 30 times faster. \\m/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HPmdX1dP5694"
      },
      "source": [
        "## Auto differentiation with grad()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Zj8y3UjV5fh"
      },
      "source": [
        "### Scalar-Valued Functions\n",
        "\n",
        "Here we are looking for `Gradient`\n",
        "\n",
        "For example,\n",
        "\n",
        "gradient of \n",
        "$3x^2 + 2x + 5$ is $6x +2$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fkqz5hSh5_Th",
        "outputId": "2258b71d-c1e3-4fdd-8321-959182af3383"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "8.0\n"
          ]
        }
      ],
      "source": [
        "def f(x):\n",
        "    return 3*x**2 + 2*x + 5\n",
        "\n",
        "# derivative of f at 1\n",
        "print(grad(f)(1.0))"
      ]
    },
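    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`grad` composes with itself, so higher-order derivatives are just nested calls. A minimal sketch: for $f(x) = 3x^2 + 2x + 5$ the second derivative is the constant $6$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# second derivative of f at 1 (should be 6.0 for any x)\n",
        "print(grad(grad(f))(1.0))"
      ]
    },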
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i6iHdgTJWE1C"
      },
      "source": [
        "### Vector-Valued Functions\n",
        "Here we are looking for `Jacobian`\n",
        "\n",
        "Ex:\n",
        "For vector `[x*x, y*z]`, its jacobian is\n",
        "\n",
        "```\n",
        "[[d/dx x^2 , d/dy x^2, d/dz x^2]\n",
        "[d/dx y*z , d/dy y*z, d/dz y*z]]\n",
        "```\n",
        "\n",
        "which reduces to\n",
        "\n",
        "```\n",
        "[[2*x, 0, 0]\n",
        "[0, z, y]]\n",
        "```\n",
        "\n",
        "Let's code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "maAkSTRbWRqo",
        "outputId": "ea8fbe20-62b2-4958-f033-9996000780da"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[[8. 0. 0.]\n",
            " [0. 9. 5.]]\n"
          ]
        }
      ],
      "source": [
        "from jax import jacfwd, jacrev, hessian\n",
        "# forward mode differentiation, reverse mode differentiation, hessian\n",
        "\n",
        "def vec_f(v):\n",
        "  x = v[0]\n",
        "  y = v[1]\n",
        "  z = v[2]\n",
        "  return jnp.array([x*x, y*z])\n",
        "\n",
        "v = jnp.array([4., 5., 9.])\n",
        "f = jacfwd(vec_f)\n",
        "print(f(v))"
      ]
    },
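    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`jacrev` computes the same Jacobian via reverse-mode differentiation; as a rule of thumb it tends to be the cheaper choice for \"wide\" Jacobians (few outputs, many inputs), while `jacfwd` suits \"tall\" ones. A minimal sketch checking they agree:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Same Jacobian, computed with reverse-mode differentiation\n",
        "print(jacrev(vec_f)(v))"
      ]
    },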
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LD3OIVqWUwxU"
      },
      "source": [
        "### Hessians - Matrix of second order mixed partials\n",
        "\n",
        "\n",
        "JAX makes computing Hessians exceedingly easy and efficient. Because of XLA, it can compute Hessians remarkably faster than PyTorch, which makes it much more practical to implement higher-order optimization techniques like `AdaHessian`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LuTWrOP_ZlMQ",
        "outputId": "e6f5e93d-eec5-450b-e760-d25d84af46a0"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "10 loops, best of 5: 2.83 ms per loop\n"
          ]
        }
      ],
      "source": [
        "import torch as pt\n",
        "\n",
        "def torch_fn(X):\n",
        "  return pt.sum(pt.mul(X,X))\n",
        "  \n",
        "X = pt.randn((1000,))\n",
        "\n",
        "%timeit -n 10 -r 5 pt.autograd.functional.hessian(torch_fn, X, vectorize=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "j78WWk3aZ_ti",
        "outputId": "dc7c6fbd-55f8-48b3-e03a-231b80a65bbe"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "The slowest run took 43.66 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
            "10 loops, best of 5: 100 µs per loop\n"
          ]
        }
      ],
      "source": [
        "def jax_fn(X):\n",
        "  return jnp.sum(jnp.square(X))\n",
        "\n",
        "jit_jax_fn = jit(hessian(jax_fn))\n",
        "\n",
        "X = jnp.array(X)\n",
        "\n",
        "%timeit -n 10 -r 5 jit_jax_fn(X).block_until_ready()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q_uQB67ijOhI"
      },
      "source": [
        "Almost 30 times faster!!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NNUAmvblEC0M"
      },
      "source": [
        "## Automatic vectorization with Vmap\n",
        "We can take a function that operates on a single data point and vectorize it so it can accept a batch of these data points (or a vector) of arbitrary size. It basically promotes matrix-vector products into matrix-matrix products.\n",
        "\n",
        "Consider the task of adding two array\n",
        "\n",
        "Watch the difference:\n",
        "\n",
        "[Unvectorized Vector Addition](https://www.assemblyai.com/blog/content/media/2022/02/not_vectorized-1.mp4)\n",
        "\n",
        "[Vectorized Vector Addition](https://www.assemblyai.com/blog/content/media/2022/02/vectorized.mp4)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "etlOWGzoEGbd"
      },
      "outputs": [],
      "source": [
        "def f(x):\n",
        "  return x * x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ig9bZi5lEe34",
        "outputId": "6dab19d8-6384-49c8-e130-05ec6c43f952"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1 loop, best of 5: 2.87 s per loop\n"
          ]
        }
      ],
      "source": [
        "%timeit -n 1 -r 5 jnp.stack([f(x) for x in jnp.arange(10000)]).block_until_ready()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wjc5wLlBCLyz",
        "outputId": "0d023d18-48e4-4d61-c5a5-4745d666d94b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1 loop, best of 5: 2.9 s per loop\n"
          ]
        }
      ],
      "source": [
        "f_jit = jit(f)\n",
        "%timeit -n 1 -r 5 jnp.stack([f_jit(x) for x in jnp.arange(10000)]).block_until_ready()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VEi0aBDyEzS0",
        "outputId": "42779349-dec3-4d2b-b878-d772074391a2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The slowest run took 90.71 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
            "1 loop, best of 5: 839 µs per loop\n"
          ]
        }
      ],
      "source": [
        "%timeit -n 1 -r 5  vmap(f)(jnp.arange(10000)).block_until_ready()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RKWuRgJ1CWnH",
        "outputId": "a9efab60-32f8-4ea6-f933-129cd17dca86"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The slowest run took 101.37 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
            "1 loop, best of 5: 688 µs per loop\n"
          ]
        }
      ],
      "source": [
        "%timeit -n 1 -r 5  vmap(f_jit)(jnp.arange(10000)).block_until_ready()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B9esw_QVtF5f"
      },
      "source": [
        "Almost 400 times faster \\m/"
      ]
    },
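    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As mentioned above, `vmap` can promote a matrix-vector product into a matrix-matrix product. A minimal sketch with `in_axes` (the shapes below are just illustrative):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "W = random.normal(key, (5, 3))        # a fixed weight matrix\n",
        "batch = random.normal(key2, (8, 3))   # a batch of 8 vectors of length 3\n",
        "\n",
        "def mv(v):\n",
        "  return W @ v                        # maps a single length-3 vector to length 5\n",
        "\n",
        "# vmap over axis 0 of `batch` turns this into a matrix-matrix product\n",
        "print(vmap(mv, in_axes=0)(batch).shape)   # (8, 5)"
      ]
    },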
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_HkiKoVcDUAw"
      },
      "source": [
        "## SPMD Programming with Pmap\n",
        "\n",
        "Consider the example of vector-matrix multiplication.\n",
        "\n",
        "Watch the difference:\n",
        "\n",
        "[Unparallelized vector-matrix multiplication](https://www.assemblyai.com/blog/content/media/2022/02/not_parallel-2.mp4)\n",
        "\n",
        "[Parallelized vector-matrix multiplication](https://www.assemblyai.com/blog/content/media/2022/02/parallelized.mp4)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gFxdCwt0l7k-"
      },
      "source": [
        "Check [JAX TPU](https://nbviewer.org/github/ZohebAbai/Deep-Learning-Projects/blob/master/JAX_TPU.ipynb) Notebook."
      ]
    },
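    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal `pmap` sketch, assuming at least one device is visible (on this single-GPU Colab runtime `jax.device_count()` is 1, so the mapped axis has size 1; with multiple devices each slice would run on its own device in parallel):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "n_devices = jax.device_count()\n",
        "xs = random.normal(key, (n_devices, 3))   # one slice of data per device\n",
        "print(pmap(lambda v: v ** 2)(xs))         # each replica squares its own slice"
      ]
    },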
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "16YLvHOoGWmM"
      },
      "source": [
        "## Device Memory Profiler\n",
        "\n",
        "JAX’s built-in Device Memory Profiler, provides visibility into how the JAX code executes on GPUs and TPUs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jncf5FHktHR0"
      },
      "outputs": [],
      "source": [
        "import jax.profiler\n",
        "\n",
        "def func1(x):\n",
        "  return jnp.tile(x, 10) * 0.5\n",
        "\n",
        "def func2(x):\n",
        "  y = func1(x)\n",
        "  return y, jnp.tile(x, 10) + 1\n",
        "\n",
        "x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))\n",
        "y, z = func2(x)\n",
        "\n",
        "z.block_until_ready()\n",
        "jax.profiler.save_device_memory_profile(\"memory.prof\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dlRFoYeyGdU-",
        "outputId": "c03b4b88-4a7b-4115-d36b-76acd0423635"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[0;31mMain binary filename not available.\n",
            "\u001b[0m\u001b[0;31mGenerating report in profile001.png\n",
            "\u001b[0m"
          ]
        }
      ],
      "source": [
        "!go tool pprof -png memory.prof"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h8IGKJt6KJYE"
      },
      "source": [
        "## Overview:\n",
        "\n",
        "- JAX is easy to use with Numpy like API and is device agnostic\n",
        "- JAX is fast and easily parallelizable\n",
        "- Robust and powerful function transformations\n",
        "- Functional programming model which aligns well with maths, and are easier to debug and produce reproducible results.\n",
        "- It's still in dev mode, not officially released. \n",
        "- Its good for research purpose, but if you want to fast explore the JAX's deep learning capabilities try Flex or Haiku, built on JAX.\n",
        "\n",
        "**Are you convinced regarding awesomness of JAX?**\n",
        "\n",
        "If you love python, here's a fun fact you may like to know:\n",
        "\n",
        "- Numpy: 62.3% Python, 35.3% C\n",
        "- Pytorch: 52.4% C++, 37.4% Python\n",
        "- Tensorflow: 62.7% C++, 22.2% Python\n",
        "- Julia: 68.4% Julia, 16.4% C, 10.1% C++\n",
        "- Jax: 92.7% Python\n",
        "- Flax: 98.1% Python  \n",
        "\n",
        "**Next** - Train deep learning models using Flax/Haiku\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "JAX_Basics.ipynb",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}