{ "cells": [ { "cell_type": "markdown", "id": "b69996dc-49af-4a0e-a4e6-36d81b51f2b4", "metadata": {}, "source": [ "# Porting a PyTorch model to JAX\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)\n", "\n", "**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.**\n", "\n", "In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module and porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`." ] }, { "cell_type": "code", "execution_count": 1, "id": "NHqB3sNbrygd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/424.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r", "\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m419.8/424.2 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.2/424.2 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/175.6 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m175.6/175.6 kB\u001b[0m \u001b[31m10.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } ], "source": [ "!pip install -Uq flax treescope" ] }, { "cell_type": "markdown", "id": "ABCg5TvPr1pm", "metadata": {}, "source": 
[ "Say we have a trained PyTorch computer-vision model to classify images that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697).\n", "\n", "First, we set up the model using TorchVision and briefly explore the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model." ] },
{ "cell_type": "code", "execution_count": 2, "id": "38504f77-4150-47bd-9cf9-3116fe370746", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from flax import nnx" ] },
{ "cell_type": "markdown", "id": "95a364c2-d34e-4820-8a86-f43f59c911bf", "metadata": {}, "source": [ "## MaxViT PyTorch model setup\n", "\n", "### Model's architecture\n", "\n", "The MaxViT model is [implemented in TorchVision](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L568). If we inspect the [forward pass](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L707-L712) of the model, we can see that it contains three high-level parts:\n", "- [stem](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L641-L655): a few convolutions, batch normalizations and GELU activations.\n", "- [blocks](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L672-L692): a list of MaxViT blocks.\n", "- [classifier](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L696-L703): adaptive average pooling, a few linear layers and a Tanh activation."
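, "\n", "\n", "The classifier's adaptive average pooling (`nn.AdaptiveAvgPool2d(1)`) reduces each channel to its spatial mean. As a minimal sketch, assuming inputs in the NHWC layout used by Flax and an arbitrary example shape, the same reduction can be written in JAX either with `nnx.avg_pool` over the full spatial window or with a mean over the height and width axes:\n", "\n", "```python\n", "import jax.numpy as jnp\n", "from flax import nnx\n", "\n", "# Arbitrary NHWC example input: batch of 2, 7x7 spatial grid, 3 channels\n", "x = jnp.arange(2 * 7 * 7 * 3, dtype=jnp.float32).reshape(2, 7, 7, 3)\n", "\n", "# Average over a window that covers the whole H x W grid -> shape (2, 1, 1, 3)\n", "pooled = nnx.avg_pool(x, window_shape=(x.shape[1], x.shape[2]))\n", "\n", "# The same reduction written as a mean over the spatial axes\n", "mean = x.mean(axis=(1, 2), keepdims=True)\n", "\n", "print(pooled.shape, bool(jnp.allclose(pooled, mean)))  # (2, 1, 1, 3) True\n", "```"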
] },
{ "cell_type": "code", "execution_count": 3, "id": "9b1be406-d21c-410d-a2ac-9bd690e5ad60", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n", " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", "Downloading: \"https://download.pytorch.org/models/maxvit_t-bc5ab103.pth\" to /root/.cache/torch/hub/checkpoints/maxvit_t-bc5ab103.pth\n", "100%|██████████| 119M/119M [00:02<00:00, 53.9MB/s]\n" ] } ], "source": [ "from torchvision.models import maxvit_t, MaxVit_T_Weights\n", "\n", "torch_model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)" ] },
{ "cell_type": "markdown", "id": "45635b2d-a77a-4368-9ecb-dbb440e647ee", "metadata": {}, "source": [ "We can use `flax.nnx.display` to display the model's architecture:" ] },
{ "cell_type": "code", "execution_count": 4, "id": "sZ9x7NpHtBcx", "metadata": {}, "outputs": [], "source": [ "# nnx.display(torch_model)" ] },
{ "cell_type": "markdown", "id": "0a36676a-1561-4de0-8e25-38bab90581d0", "metadata": {}, "source": [ "We can see that there are four MaxViT blocks in the model and each block contains:\n", "- MaxViT layers: two layers in blocks 0, 1 and 3, and five layers in block 2." ] },
{ "cell_type": "code", "execution_count": 5, "id": "0d5bf6aa-c720-4400-a276-602fff53b413", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(4, [2, 2, 5, 2])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(torch_model.blocks), [len(b.layers) for b in torch_model.blocks]" ] },
{ "cell_type": "markdown", "id": "a1d55688-5999-41de-a915-eae8b281eb18", "metadata": {}, "source": [ "A [MaxViT layer](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L386) is composed of: [`MBConv`](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L53), `window_attention` as [`PartitionAttentionLayer`](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L282) and `grid_attention` as `PartitionAttentionLayer`." ] },
{ "cell_type": "code", "execution_count": 6, "id": "03ce0555-888a-4086-bb6c-64c36ae60b14", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer']]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[[mod.__class__.__name__ for mod in maxvit_layer.layers] for b in torch_model.blocks for maxvit_layer in b.layers]" ] },
{ "cell_type": "markdown", "id": "d57f8545-43a4-423d-b701-c2e2ca0ebfc1", "metadata": {}, "source": [ "### Inference on data\n", "\n", "Let's check the model on a dummy input and on a real image:" ] },
{ "cell_type": "code", "execution_count": 7, "id": "d6c95620-bf50-47e4-b8d6-3a85262941ed", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream",
"text": [ "torch.Size([2, 1000])\n" ] } ], "source": [ "import torch\n", "\n", "\n", "torch_model.eval()\n", "with torch.inference_mode():\n", "    x = torch.rand(2, 3, 224, 224)\n", "    output = torch_model(x)\n", "\n", "print(output.shape)  # (2, 1000)" ] },
{ "cell_type": "markdown", "id": "133bcf21-8a9c-4c27-b551-39b7dfdcfe1c", "metadata": {}, "source": [ "We can download an image of a Pembroke Corgi dog from [TorchVision's gallery](https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true) together with the [ImageNet classes dictionary](https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json):" ] },
{ "cell_type": "code", "execution_count": 8, "id": "qC9hpYfNtOEF", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-01-15 21:10:00 URL:https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/dog1.jpg [97422/97422] -> \"dog1.jpg\" [1]\n", "2025-01-15 21:10:01 URL:https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json [35364/35364] -> \"imagenet_class_index.json\" [1]\n" ] } ], "source": [ "%%bash\n", "if [ -f \"dog1.jpg\" ]; then\n", "    echo \"dog1.jpg already exists.\"\n", "else\n", "    wget -nv \"https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true\" -O dog1.jpg\n", "fi\n", "if [ -f \"imagenet_class_index.json\" ]; then\n", "    echo \"imagenet_class_index.json already exists.\"\n", "else\n", "    wget -nv \"https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json\" -O imagenet_class_index.json\n", "fi" ] },
{ "cell_type": "code", "execution_count": 9, "id": "82be8baf-1292-4766-be34-28c510563d71", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction for the Dog: ['n02113023', 'Pembroke'], score: 0.7800846099853516\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9,
"metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "import json\n", "from torchvision.io import read_image\n", "\n", "\n", "preprocess = MaxVit_T_Weights.IMAGENET1K_V1.transforms()\n", "\n", "with open(\"imagenet_class_index.json\") as labels_file:\n", "    labels = json.load(labels_file)\n", "\n", "\n", "dog1 = read_image(\"dog1.jpg\")\n", "tensor = preprocess(dog1)\n", "\n", "torch_model.eval()\n", "with torch.inference_mode():\n", "    output = torch_model(tensor.unsqueeze(dim=0))\n", "\n", "class_id = output.argmax(dim=1).item()\n", "\n", "print(f\"Prediction for the Dog: {labels[str(class_id)]}, score: {output.softmax(dim=-1)[0, class_id]}\")\n", "\n", "plt.title(f\"{labels[str(class_id)]}\\nScore: {output.softmax(dim=-1)[0, class_id]}\")\n", "plt.imshow(dog1.permute(1, 2, 0))" ] },
{ "cell_type": "markdown", "id": "8cbe4ccc-224b-4e8a-a2a9-e2c756c9b207", "metadata": {}, "source": [ "## Port MaxViT model to JAX\n", "\n", "To port the [PyTorch implementation of the MaxViT model](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L568) to JAX using Flax, we will implement the following modules:\n", "\n", "- `MaxViT`\n", "    - `MaxVitBlock`\n", "        - `MaxVitLayer`\n", "            - `MBConv`\n", "                - `Conv2dNormActivation`\n", "                - `SqueezeExcitation`\n", "            - `PartitionAttentionLayer`\n", "                - `RelativePositionalMultiHeadAttention`\n", "                - `WindowDepartition`\n", "                - `WindowPartition`\n", "                - `SwapAxes`\n", "            - `StochasticDepth`\n", "\n", "Flax NNX is very similar to the PyTorch `torch.nn` module, and we can map the following modules between PyTorch and Flax:\n", "- `nn.Sequential` and `nn.ModuleList` -> `nnx.Sequential`\n", "- `nn.Linear` -> `nnx.Linear`\n", "- `nn.Conv2d` -> `nnx.Conv`\n", "- `nn.BatchNorm2d` -> `nnx.BatchNorm`\n", "- Activations like `nn.ReLU` -> `nnx.relu`\n", "- Pooling layers like `nn.AvgPool2d(...)` -> `lambda x: nnx.avg_pool(x, ...)`\n", "- `nn.AdaptiveAvgPool2d(1)` -> `lambda x: nnx.avg_pool(x, (x.shape[1], x.shape[2]))`, where `x` is in NHWC format\n", "- `nn.Flatten()` -> `lambda x: x.reshape(x.shape[0], -1)`\n", "\n", "\n", "If the PyTorch model defines a learnable parameter and a buffer:\n", "```python\n", "class Model(nn.Module):\n", "    def __init__(self, ...):\n", "        ...\n", "        self.p = nn.Parameter(torch.ones(10))\n", "        self.register_buffer(\"b\", torch.ones(5))\n", "```\n", "the equivalent Flax code would be:\n", "```python\n", "class Buffer(nnx.Variable):\n", "    pass\n", "\n", "\n", "class Model(nnx.Module):\n", "    def __init__(self, ...):\n", "        ...\n", "        self.p = nnx.Param(jnp.ones((10,)))\n", "        self.b = Buffer(jnp.ones(5))\n", "```\n", "\n", "To inspect an NNX module's learnable parameters and buffers, we can use `nnx.state`:\n", "```python\n", "nnx_module = ...\n", "for k, v in nnx.state(nnx_module, nnx.Param).flat_state():\n", "    print(\n", "        k,\n", "        v.value.mean() if v.value is not None else None\n", "    )\n", "\n", "for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state():\n", "    print(\n", "        k,\n", "        v.value.mean() if v.value.dtype == \"float32\" else v.value.sum()\n", "    )\n", "```\n", "The equivalent PyTorch code is:\n", "```python\n", "torch_module = ...\n", "\n", "for m, p in torch_module.named_parameters():\n", "    print(m, p.detach().mean())\n", "\n", "for m, b in torch_module.named_buffers():\n", "    print(\n", "        m,\n", "        b.mean() if b.dtype == torch.float32 else b.sum()\n", "    )\n", "```" ] },
{ "cell_type": "markdown", "id": "305ac55b-62ed-4f4d-902c-c6f3082afb02", "metadata": {}, "source": [ "Please note some differences between `torch.nn` and Flax when porting models:\n", "- We should pass `rngs` to all NNX modules with parameters, e.g. `nnx.Linear(..., rngs=nnx.Rngs(0))`.\n", "- For a 2D convolution:\n", "    - In Flax, we need to explicitly define `kernel_size` and `strides` as tuples of two ints, e.g. `(3, 3)`.\n", "    - If the PyTorch code defines `padding` as an integer, e.g. 2, in Flax it should be explicitly defined as a tuple of two ints per dimension, i.e. `((2, 2), (2, 2))`.\n", "- For batch normalization, a `momentum` value `m` in `torch.nn` should be defined as `1.0 - m` in Flax.\n", "- 4D input arrays in Flax should be in NHWC format, i.e. of shape (N, H, W, C), compared to the NCHW format (shape (N, C, H, W)) in PyTorch." ] },
{ "cell_type": "markdown", "id": "8d7e3479-bffe-4cb6-81e1-ed8f972c5bf0", "metadata": {}, "source": [ "Below we implement all the modules from the above list one by one and add simple forward pass checks.\n", "Let's first implement the equivalent of `nn.Identity`." ] },
{ "cell_type": "code", "execution_count": 10, "id": "54ece7f1-14c1-41ef-980a-fc279d1702f2", "metadata": {}, "outputs": [], "source": [ "class Identity(nnx.Module):\n", "    def __call__(self, x: jax.Array) -> jax.Array:\n", "        return x" ] },
{ "cell_type": "markdown", "id": "dd87b2aa-0285-4995-a9aa-ebd58ae00de6", "metadata": {}, "source": [ "### `Conv2dNormActivation` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/ops/misc.py#L125)."
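, "\n", "\n", "`Conv2dNormActivation` relies on two of the conversions listed above: an integer PyTorch `padding` becomes one `(low, high)` pair per spatial dimension, and inputs are NHWC instead of NCHW. The following sketch applies both to a single `nnx.Conv`; the helper `to_flax_padding` is hypothetical and only serves this illustration. With the default padding `(kernel_size - 1) // 2 * dilation`, a stride-1 3x3 convolution keeps the spatial size unchanged:\n", "\n", "```python\n", "import jax.numpy as jnp\n", "from flax import nnx\n", "\n", "\n", "def to_flax_padding(padding: int):\n", "    # Hypothetical helper: PyTorch int padding -> ((low, high), (low, high))\n", "    return ((padding, padding), (padding, padding))\n", "\n", "\n", "kernel_size, dilation = 3, 1\n", "padding = (kernel_size - 1) // 2 * dilation  # same default as Conv2dNormActivation\n", "\n", "conv = nnx.Conv(\n", "    32, 64,\n", "    kernel_size=(kernel_size, kernel_size),\n", "    strides=(1, 1),\n", "    padding=to_flax_padding(padding),\n", "    rngs=nnx.Rngs(0),\n", ")\n", "\n", "x_nchw = jnp.ones((4, 32, 28, 28))            # PyTorch layout (N, C, H, W)\n", "x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))  # Flax layout (N, H, W, C)\n", "y = conv(x_nhwc)\n", "print(y.shape)  # (4, 28, 28, 64)\n", "```"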
] },
{ "cell_type": "code", "execution_count": 11, "id": "69d71163-676e-4ad3-8d8c-45efaafd76e7", "metadata": {}, "outputs": [], "source": [ "from typing import Callable, List, Optional, Tuple\n", "from flax import nnx\n", "\n", "\n", "class Conv2dNormActivation(nnx.Sequential):\n", "    def __init__(\n", "        self,\n", "        in_channels: int,\n", "        out_channels: int,\n", "        kernel_size: int = 3,\n", "        stride: int = 1,\n", "        padding: Optional[int] = None,\n", "        groups: int = 1,\n", "        norm_layer: Callable[..., nnx.Module] = nnx.BatchNorm,\n", "        activation_layer: Callable = nnx.relu,\n", "        dilation: int = 1,\n", "        bias: Optional[bool] = None,\n", "        rngs: nnx.Rngs = nnx.Rngs(0),\n", "    ):\n", "        self.out_channels = out_channels\n", "\n", "        if padding is None:\n", "            padding = (kernel_size - 1) // 2 * dilation\n", "        if bias is None:\n", "            bias = norm_layer is None\n", "\n", "        # sequence of integer pairs that give the padding to apply before\n", "        # and after each spatial dimension\n", "        padding = ((padding, padding), (padding, padding))\n", "\n", "        layers = [\n", "            nnx.Conv(\n", "                in_channels,\n", "                out_channels,\n", "                kernel_size=(kernel_size, kernel_size),\n", "                strides=(stride, stride),\n", "                padding=padding,\n", "                kernel_dilation=(dilation, dilation),\n", "                feature_group_count=groups,\n", "                use_bias=bias,\n", "                rngs=rngs,\n", "            )\n", "        ]\n", "\n", "        if norm_layer is not None:\n", "            layers.append(norm_layer(out_channels, rngs=rngs))\n", "\n", "        if activation_layer is not None:\n", "            layers.append(activation_layer)\n", "\n", "        super().__init__(*layers)" ] },
{ "cell_type": "code", "execution_count": 12, "id": "e5269a0a-f43f-4fdf-9955-aa3fcde60c01", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 14, 14, 64)\n" ] } ], "source": [ "x = jnp.ones((4, 28, 28, 32))\n", "mod = Conv2dNormActivation(32, 64, 3, 2, 1)\n", "y = mod(x)\n", "print(y.shape)" ] },
{ "cell_type": "markdown", "id": "2d0cd827-ad40-4cd3-9560-6565a3df10bc", "metadata": {},
"source": [ "### `SqueezeExcitation` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/ops/misc.py#L224)." ] },
{ "cell_type": "code", "execution_count": 13, "id": "4232689e-e6cc-4ffd-8a2a-41fbc34e57c2", "metadata": {}, "outputs": [], "source": [ "class SqueezeExcitation(nnx.Module):\n", "    def __init__(\n", "        self,\n", "        input_channels: int,\n", "        squeeze_channels: int,\n", "        activation: Callable = nnx.relu,\n", "        scale_activation: Callable = nnx.sigmoid,\n", "        rngs: nnx.Rngs = nnx.Rngs(0),\n", "    ):\n", "        self.avgpool = nnx.avg_pool\n", "        self.fc1 = nnx.Conv(input_channels, squeeze_channels, (1, 1), rngs=rngs)\n", "        self.fc2 = nnx.Conv(squeeze_channels, input_channels, (1, 1), rngs=rngs)\n", "        self.activation = activation\n", "        self.scale_activation = scale_activation\n", "\n", "    def _scale(self, x: jax.Array) -> jax.Array:\n", "        scale = self.avgpool(x, (x.shape[1], x.shape[2]))\n", "        scale = self.fc1(scale)\n", "        scale = self.activation(scale)\n", "        scale = self.fc2(scale)\n", "        return self.scale_activation(scale)\n", "\n", "    def __call__(self, x: jax.Array) -> jax.Array:\n", "        scale = self._scale(x)\n", "        return scale * x" ] },
{ "cell_type": "code", "execution_count": 14, "id": "83c55286-b92e-49aa-bd5f-c2448a787673", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 28, 28, 32)\n" ] } ], "source": [ "x = jnp.ones((4, 28, 28, 32))\n", "mod = SqueezeExcitation(32, 4)\n", "y = mod(x)\n", "print(y.shape)" ] },
{ "cell_type": "markdown", "id": "7935790a-4cb1-46dc-ab73-12d3cb8fc636", "metadata": {}, "source": [ "### `StochasticDepth` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/ops/stochastic_depth.py#L50)."
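, "\n", "\n", "Stochastic depth randomly drops an entire residual branch and rescales the kept branches by `1 / (1 - p)`, so the expected output is unchanged. In the `row` mode, one Bernoulli sample is drawn per batch element and broadcast over the remaining axes. A minimal sketch of this sampling with `jax.random` (the drop probability, input shape and seed are arbitrary example values):\n", "\n", "```python\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "p = 0.2  # drop probability (example value)\n", "survival_rate = 1.0 - p\n", "x = jnp.ones((8, 4, 4, 3))  # NHWC batch of 8 samples\n", "\n", "# Row mode: one Bernoulli draw per sample, broadcast over H, W and C\n", "size = (x.shape[0],) + (1,) * (x.ndim - 1)\n", "keep = jax.random.bernoulli(jax.random.key(0), p=survival_rate, shape=size)\n", "y = x * keep.astype(x.dtype) / survival_rate\n", "\n", "# Each sample is either fully dropped (all zeros) or kept and scaled by 1 / 0.8\n", "per_sample = y.reshape(x.shape[0], -1)\n", "print(per_sample.min(axis=1), per_sample.max(axis=1))\n", "```"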
] },
{ "cell_type": "code", "execution_count": 15, "id": "96834419-eec1-4690-8bb0-447524f6bdde", "metadata": {}, "outputs": [], "source": [ "def stochastic_depth(\n", "    x: jax.Array,\n", "    p: float,\n", "    mode: str,\n", "    deterministic: bool = False,\n", "    rngs: nnx.Rngs = nnx.Rngs(0),\n", ") -> jax.Array:\n", "    if p < 0.0 or p > 1.0:\n", "        raise ValueError(f\"drop probability has to be between 0 and 1, but got {p}\")\n", "    if mode not in [\"batch\", \"row\"]:\n", "        raise ValueError(f\"mode has to be either 'batch' or 'row', but got {mode}\")\n", "    if deterministic or p == 0.0:\n", "        return x\n", "\n", "    survival_rate = 1.0 - p\n", "    if mode == \"row\":\n", "        size = [x.shape[0]] + [1] * (x.ndim - 1)\n", "    else:\n", "        size = [1] * x.ndim\n", "\n", "    noise = jax.random.bernoulli(\n", "        rngs.dropout(), p=survival_rate, shape=size\n", "    ).astype(dtype=x.dtype)\n", "\n", "    if survival_rate > 0.0:\n", "        noise = noise / survival_rate\n", "\n", "    return x * noise\n", "\n", "\n", "class StochasticDepth(nnx.Module):\n", "    def __init__(\n", "        self,\n", "        p: float,\n", "        mode: str,\n", "        rngs: nnx.Rngs = nnx.Rngs(0),\n", "    ):\n", "        self.p = p\n", "        self.mode = mode\n", "        self.deterministic = False\n", "        self.rngs = rngs\n", "\n", "    def __call__(self, x: jax.Array) -> jax.Array:\n", "        return stochastic_depth(\n", "            x, self.p, self.mode, self.deterministic, rngs=self.rngs\n", "        )" ] },
{ "cell_type": "code", "execution_count": 16, "id": "fd95babb-95b4-4015-957d-11b9c7b9957d", "metadata": {}, "outputs": [], "source": [ "x = jnp.ones((4, 28, 28, 32))\n", "mod = StochasticDepth(0.5, \"row\")\n", "\n", "mod.eval()\n", "y = mod(x)\n", "assert (y == x).all()\n", "\n", "mod.train()\n", "y = mod(x)\n", "assert (y != x).any()" ] },
{ "cell_type": "markdown", "id": "0ce251eb-a8dc-4415-9856-d16421c1d646", "metadata": {}, "source": [ "### `MBConv` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L53)." ] },
{ "cell_type": "code", "execution_count": 17, "id": "636c713c-4a21-439a-b220-2b9407a06dfc", "metadata": {}, "outputs": [], "source": [ "class MBConv(nnx.Module):\n", "    def __init__(\n", "        self,\n", "        in_channels: int,\n", "        out_channels: int,\n", "        expansion_ratio: float,\n", "        squeeze_ratio: float,\n", "        stride: int,\n", "        activation_layer: Callable,\n", "        norm_layer: Callable[..., nnx.Module],\n", "        p_stochastic_dropout: float = 0.0,\n", "        rngs: nnx.Rngs = nnx.Rngs(0),\n", "    ):\n", "        should_proj = stride != 1 or in_channels != out_channels\n", "        if should_proj:\n", "            proj = [nnx.Conv(\n", "                in_channels, out_channels, kernel_size=(1, 1), strides=(1, 1), use_bias=True, rngs=rngs\n", "            )]\n", "            if stride == 2:\n", "                padding = ((1, 1), (1, 1))\n", "                proj = [\n", "                    lambda x: nnx.avg_pool(\n", "                        x, window_shape=(3, 3), strides=(stride, stride), padding=padding\n", "                    )\n", "                ] + proj\n", "            self.proj = nnx.Sequential(*proj)\n", "        else:\n", "            self.proj = Identity()\n", "\n", "        mid_channels = int(out_channels * expansion_ratio)\n", "        sqz_channels = int(out_channels * squeeze_ratio)\n", "\n", "        if p_stochastic_dropout:\n", "            self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode=\"row\", rngs=rngs)\n", "        else:\n", "            self.stochastic_depth = Identity()\n", "\n", "        _layers = [\n", "            norm_layer(in_channels, rngs=rngs),  # pre_norm\n", "            Conv2dNormActivation(  # conv_a\n", "                in_channels,\n", "                mid_channels,\n", "                kernel_size=1,\n", "                stride=1,\n", "                padding=0,\n", "                activation_layer=activation_layer,\n", "                norm_layer=norm_layer,\n", "                rngs=rngs,\n", "            ),\n", "            Conv2dNormActivation(  # conv_b\n", "                mid_channels,\n", "                mid_channels,\n", "                kernel_size=3,\n", "                stride=stride,\n", "                padding=1,\n", "                activation_layer=activation_layer,\n", "                norm_layer=norm_layer,\n", "                groups=mid_channels,\n", "                rngs=rngs,\n", "            ),\n", "            SqueezeExcitation(  # squeeze_excitation\n", "                mid_channels, sqz_channels, activation=nnx.silu, rngs=rngs\n", "            ),\n", "            nnx.Conv(  # conv_c\n", "                mid_channels, out_channels, kernel_size=(1, 1), use_bias=True, rngs=rngs\n", "            )\n", "        ]\n", "        self.layers = nnx.Sequential(*_layers)\n", "\n", "    def __call__(self, x: jax.Array) -> jax.Array:\n", "        res = self.proj(x)\n", "        x = self.stochastic_depth(self.layers(x))\n", "        return res + x" ] },
{ "cell_type": "code", "execution_count": 18, "id": "5cd24b07-f160-422c-bea3-2baf5ebca5b0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(4, 14, 14, 64)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from functools import partial\n", "\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "x = jnp.ones((4, 28, 28, 32))\n", "mod = MBConv(32, 64, 4, 0.25, 2, activation_layer=nnx.gelu, norm_layer=norm_layer)\n", "y = mod(x)\n", "y.shape" ] },
{ "cell_type": "markdown", "id": "3a8d9cb4-795b-4cb2-a014-bb440acc800b", "metadata": {}, "source": [ "### `RelativePositionalMultiHeadAttention` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L140). First, we reimplement the helper function `_get_relative_position_index`:" ] },
{ "cell_type": "code", "execution_count": 19, "id": "df647057-8c6f-4c6b-84f9-d6f78e649343", "metadata": {}, "outputs": [], "source": [ "def _get_relative_position_index(height: int, width: int) -> jax.Array:\n", "    # PyTorch code:\n", "    # coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))\n", "\n", "    coords = jnp.stack(\n", "        jnp.meshgrid(*[jnp.arange(height), jnp.arange(width)], indexing=\"ij\")\n", "    )\n", "    # PyTorch code: coords_flat = torch.flatten(coords, 1)\n", "    coords_flat = coords.reshape(coords.shape[0], -1)\n", "\n", "    relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]\n", "    relative_coords = jnp.permute_dims(relative_coords, (1, 2, 0))\n", "\n", "    # PyTorch code:\n", "    # relative_coords[:, :, 0] += height - 1\n", "    # relative_coords[:, :, 1] += width - 1\n", "    # relative_coords[:, :, 0] *= 2 * width - 1\n", "    relative_coords = relative_coords + jnp.array((height - 1, width - 1))\n", "    relative_coords = relative_coords * jnp.array((2 * width - 1, 1))\n", "\n", "    return relative_coords.sum(-1)" ] },
{ "cell_type": "markdown", "id": "2670d86b", "metadata": {}, "source": [ "Let us check our implementation against the PyTorch implementation:" ] },
{ "cell_type": "code", "execution_count": 20, "id": "5ce55b8b-5305-4a57-a413-8df43392ec3a", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import _get_relative_position_index as pytorch_get_relative_position_index\n", "\n", "\n", "output = _get_relative_position_index(13, 12)\n", "expected = pytorch_get_relative_position_index(13, 12)\n", "assert (output == jnp.asarray(expected)).all()" ] },
{ "cell_type": "markdown", "id": "5518bfc4", "metadata": {}, "source": [ "Next, we port the `RelativePositionalMultiHeadAttention` module, which has a learnable parameter and a buffer:" ] },
{ "cell_type": "code", "execution_count": 21, "id": "1f46b3e4-fd69-42c2-8ca7-d242a20d13de",
"metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "\n", "class Buffer(nnx.Variable):\n", "    pass\n", "\n", "\n", "class RelativePositionalMultiHeadAttention(nnx.Module):\n", "    def __init__(\n", "        self,\n", "        feat_dim: int,\n", "        head_dim: int,\n", "        max_seq_len: int,\n", "        rngs: nnx.Rngs = nnx.Rngs(0),\n", "    ):\n", "        if feat_dim % head_dim != 0:\n", "            raise ValueError(f\"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}\")\n", "\n", "        self.n_heads = feat_dim // head_dim\n", "        self.head_dim = head_dim\n", "        self.size = int(math.sqrt(max_seq_len))\n", "        self.max_seq_len = max_seq_len\n", "\n", "        self.to_qkv = nnx.Linear(feat_dim, self.n_heads * self.head_dim * 3, rngs=rngs)\n", "        self.scale_factor = feat_dim**-0.5\n", "\n", "        self.merge = nnx.Linear(self.head_dim * self.n_heads, feat_dim, rngs=rngs)\n", "\n", "        self.relative_position_index = Buffer(_get_relative_position_index(self.size, self.size))\n", "\n", "        # initialize with truncated normal bias\n", "        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)\n", "        shape = ((2 * self.size - 1) * (2 * self.size - 1), self.n_heads)\n", "        self.relative_position_bias_table = nnx.Param(initializer(rngs.params(), shape, jnp.float32))\n", "\n", "    def get_relative_positional_bias(self) -> jax.Array:\n", "        bias_index = self.relative_position_index.value.ravel()\n", "        relative_bias = self.relative_position_bias_table[bias_index].reshape((self.max_seq_len, self.max_seq_len, -1))\n", "        relative_bias = jnp.permute_dims(relative_bias, (2, 0, 1))\n", "        return jnp.expand_dims(relative_bias, axis=0)\n", "\n", "    def __call__(self, x: jax.Array) -> jax.Array:\n", "        B, G, P, D = x.shape\n", "        H, DH = self.n_heads, self.head_dim\n", "\n", "        qkv = self.to_qkv(x)\n", "\n", "        q, k, v = jnp.split(qkv, 3, axis=-1)\n", "        q = jnp.permute_dims(q.reshape((B, G, P, H, DH)), (0, 1, 3, 2, 4))\n", "        k = jnp.permute_dims(k.reshape((B, G, P, H, DH)), (0, 1, 3, 2, 4))\n", "        v = jnp.permute_dims(v.reshape((B, G, P, H, DH)), (0, 1, 3, 2, 4))\n", "\n", "        k = k * self.scale_factor\n", "\n", "        dot_prod = jnp.einsum(\"B G H I D, B G H J D -> B G H I J\", q, k)\n", "        pos_bias = self.get_relative_positional_bias()\n", "\n", "        dot_prod = jax.nn.softmax(dot_prod + pos_bias, axis=-1)\n", "\n", "        out = jnp.einsum(\"B G H I J, B G H J D -> B G H I D\", dot_prod, v)\n", "        out = jnp.permute_dims(out, (0, 1, 3, 2, 4)).reshape((B, G, P, D))\n", "\n", "        out = self.merge(out)\n", "        return out" ] },
{ "cell_type": "code", "execution_count": 22, "id": "18d0c993", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 32, 49, 64)\n" ] } ], "source": [ "x = jnp.ones((4, 32, 49, 64))\n", "\n", "mod = RelativePositionalMultiHeadAttention(64, 16, 49)\n", "y = mod(x)\n", "print(y.shape)" ] },
{ "cell_type": "markdown", "id": "875aba65-53d0-4241-bdd7-36384054ca59", "metadata": {}, "source": [ "### `SwapAxes`, `WindowPartition`, `WindowDepartition` implementations\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L213)."
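, "\n", "\n", "For grid attention, `PartitionAttentionLayer` applies the same window partition, but with the number of partitions as the window size, and then swaps the two middle axes. The sketch below shows the effect on a toy 4x4 image; `window_partition` is a standalone copy of the reshape logic written only for this illustration. Each row of `windows` holds a contiguous 2x2 block, whereas each row of `grid` holds pixels spaced 2 apart across the whole image:\n", "\n", "```python\n", "import jax.numpy as jnp\n", "\n", "\n", "def window_partition(x, p):\n", "    # (B, H, W, C) -> (B, (H/p) * (W/p), p * p, C): contiguous p x p windows\n", "    B, H, W, C = x.shape\n", "    x = x.reshape(B, H // p, p, W // p, p, C)\n", "    x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))\n", "    return x.reshape(B, (H // p) * (W // p), p * p, C)\n", "\n", "\n", "# Toy 4x4 single-channel image whose pixel values are row * 4 + col\n", "x = jnp.arange(16).reshape(1, 4, 4, 1)\n", "\n", "# Window partition: each row holds one contiguous 2x2 block\n", "windows = window_partition(x, 2)\n", "print(windows[0, :, :, 0].tolist())\n", "# [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]\n", "\n", "# Swapping the two middle axes regroups pixels into a strided (dilated) grid\n", "grid = jnp.swapaxes(windows, -2, -3)\n", "print(grid[0, :, :, 0].tolist())\n", "# [[0, 2, 8, 10], [1, 3, 9, 11], [4, 6, 12, 14], [5, 7, 13, 15]]\n", "```"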
] }, { "cell_type": "code", "execution_count": 23, "id": "d8a19362-733a-4359-9658-53dcffa25220", "metadata": {}, "outputs": [], "source": [ "class SwapAxes(nnx.Module):\n", " def __init__(self, a: int, b: int):\n", " self.a = a\n", " self.b = b\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " res = jnp.swapaxes(x, self.a, self.b)\n", " return res\n", "\n", "\n", "class WindowPartition(nnx.Module):\n", " def __call__(self, x: jax.Array, p: int) -> jax.Array:\n", " # Output array with expected layout of [B, H/P, W/P, P*P, C].\n", " B, H, W, C = x.shape\n", " P = p\n", " # chunk up H and W dimensions\n", " x = x.reshape((B, H // P, P, W // P, P, C))\n", " x = jnp.permute_dims(x, (0, 1, 3, 2, 4, 5))\n", " # colapse P * P dimension\n", " x = x.reshape((B, (H // P) * (W // P), P * P, C))\n", " return x\n", "\n", "\n", "class WindowDepartition(nnx.Module):\n", " def __call__(self, x: jax.Array, p: int, h_partitions: int, w_partitions: int) -> jax.Array:\n", " # Output array with expected layout of [B, H, W, C].\n", " B, G, PP, C = x.shape\n", " P = p\n", " HP, WP = h_partitions, w_partitions\n", " # split P * P dimension into 2 P tile dimensions\n", " x = x.reshape((B, HP, WP, P, P, C))\n", " # permute into B, HP, P, WP, P, C\n", " x = jnp.permute_dims(x, (0, 1, 3, 2, 4, 5))\n", " # reshape into B, H, W, C\n", " x = x.reshape((B, HP * P, WP * P, C))\n", " return x" ] }, { "cell_type": "code", "execution_count": 24, "id": "daee5b6b-595f-4344-af93-6e4bd44c217f", "metadata": {}, "outputs": [], "source": [ "x = jnp.ones((3, 4, 5, 6))\n", "mod = SwapAxes(1, 2)\n", "y = mod(x)\n", "assert y.shape == (3, 5, 4, 6)\n", "\n", "x = jnp.ones((4, 128, 128, 3))\n", "mod = WindowPartition()\n", "y = mod(x, p=16)\n", "assert y.shape == (4, (128 // 16) * (128 // 16), 16 * 16, 3)\n", "\n", "x = jnp.ones((4, (128 // 16) * (128 // 16), 16 * 16, 3))\n", "mod = WindowDepartition()\n", "y = mod(x, p=16, h_partitions=128 // 16, w_partitions=128 // 16)\n", "assert y.shape == (4, 
128, 128, 3)" ] }, { "cell_type": "markdown", "id": "fe9643dd-b328-43c9-a82f-7180ee2b9a00", "metadata": {}, "source": [ "### `PartitionAttentionLayer` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L282)." ] }, { "cell_type": "code", "execution_count": 25, "id": "dfb3c640-4b51-4ca5-a7ba-2ad5f9907c57", "metadata": {}, "outputs": [], "source": [ "class PartitionAttentionLayer(nnx.Module):\n", " \"\"\"\n", " Layer for partitioning the input tensor into non-overlapping windows and\n", " applying attention to each window.\n", " \"\"\"\n", " def __init__(\n", " self,\n", " in_channels: int,\n", " head_dim: int,\n", " # partitioning parameters\n", " partition_size: int,\n", " partition_type: str,\n", " # grid size needs to be known at initialization time\n", " # because we need to know hamy relative offsets there are in the grid\n", " grid_size: Tuple[int, int],\n", " mlp_ratio: int,\n", " activation_layer: Callable,\n", " norm_layer: Callable[..., nnx.Module],\n", " attention_dropout: float,\n", " mlp_dropout: float,\n", " p_stochastic_dropout: float,\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " self.n_heads = in_channels // head_dim\n", " self.head_dim = head_dim\n", " self.n_partitions = grid_size[0] // partition_size\n", " self.partition_type = partition_type\n", " self.grid_size = grid_size\n", "\n", " if partition_type not in [\"grid\", \"window\"]:\n", " raise ValueError(\"partition_type must be either 'grid' or 'window'\")\n", "\n", " if partition_type == \"window\":\n", " self.p, self.g = partition_size, self.n_partitions\n", " else:\n", " self.p, self.g = self.n_partitions, partition_size\n", "\n", " self.partition_op = WindowPartition()\n", " self.departition_op = WindowDepartition()\n", " self.partition_swap = SwapAxes(-2, -3) if partition_type == \"grid\" else Identity()\n", " self.departition_swap = SwapAxes(-2, -3) if partition_type 
== \"grid\" else Identity()\n", "\n", " self.attn_layer = nnx.Sequential(\n", " norm_layer(in_channels, rngs=rngs),\n", " # it's always going to be partition_size ** 2 because\n", " # of the axis swap in the case of grid partitioning\n", " RelativePositionalMultiHeadAttention(\n", " in_channels, head_dim, partition_size**2, rngs=rngs\n", " ),\n", " nnx.Dropout(attention_dropout, rngs=rngs),\n", " )\n", "\n", " # pre-normalization similar to transformer layers\n", " self.mlp_layer = nnx.Sequential(\n", " nnx.LayerNorm(in_channels, rngs=rngs),\n", " nnx.Linear(in_channels, in_channels * mlp_ratio, rngs=rngs),\n", " activation_layer,\n", " nnx.Linear(in_channels * mlp_ratio, in_channels, rngs=rngs),\n", " nnx.Dropout(mlp_dropout, rngs=rngs),\n", " )\n", "\n", " # layer scale factors\n", " self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode=\"row\", rngs=rngs)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " # Undefined behavior if H or W are not divisible by p\n", " # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766\n", " gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p\n", " assert (\n", " self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0\n", " ), \"Grid size must be divisible by partition size. 
Got grid size of {} and partition size of {}\".format(\n", " self.grid_size, self.p\n", " )\n", " x = self.partition_op(x, self.p) # (B, H, W, C) -> (B, H/P, W/P, P*P, C)\n", " x = self.partition_swap(x) # -> grid: (B, H/P, P*P, W/P, C)\n", " x = x + self.stochastic_dropout(self.attn_layer(x))\n", " x = x + self.stochastic_dropout(self.mlp_layer(x))\n", " x = self.departition_swap(x) # grid: (B, H/P, P*P, W/P, C) -> (B, H/P, W/P, P*P, C)\n", " x = self.departition_op(x, self.p, gh, gw) # -> (B, H, W, C)\n", "\n", " return x" ] }, { "cell_type": "code", "execution_count": 26, "id": "d6feac34-35be-420b-a7cb-78995aed4c7a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 224, 224, 36)\n", "(4, 224, 224, 36)\n" ] } ], "source": [ "x = jnp.ones((4, 224, 224, 36))\n", "\n", "grid_size = (224, 224)\n", "mod = PartitionAttentionLayer(\n", " 36, 6, 7, \"window\", grid_size=grid_size, mlp_ratio=4,\n", " activation_layer=nnx.gelu, norm_layer=nnx.LayerNorm,\n", " attention_dropout=0.4, mlp_dropout=0.3, p_stochastic_dropout=0.2,\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)\n", "\n", "mod = PartitionAttentionLayer(\n", " 36, 6, 7, \"grid\", grid_size=grid_size, mlp_ratio=4,\n", " activation_layer=nnx.gelu, norm_layer=nnx.LayerNorm,\n", " attention_dropout=0.4, mlp_dropout=0.3, p_stochastic_dropout=0.2,\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "b89b4ca6-c17a-4c0f-859a-de7134348818", "metadata": {}, "source": [ "### `MaxVitLayer` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L386)." 
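, "\n", "\n", "A `MaxVitLayer` chains an `MBConv` block with two `PartitionAttentionLayer`s that differ only in how tokens are grouped before attention. With `window` partitioning, each group is a contiguous `P x P` patch (local attention); with `grid` partitioning, the layer partitions with patch size `H / P` and then swaps the two middle axes, so each group collects `P * P` tokens spread evenly across the whole feature map (dilated, global attention). The cell below is a minimal sketch of this grouping on a `6 x 6` single-channel image with `partition_size=2`. `window_partition` is a hypothetical helper that follows the same grouping as TorchVision's `WindowPartition`, adapted to NHWC input; the notebook's own `WindowPartition` may lay out the axes differently." ] }, { "cell_type": "code", "execution_count": null, "id": "window-grid-partition-sketch", "metadata": {}, "outputs": [], "source": [
"import jax.numpy as jnp\n",
"\n",
"\n",
"def window_partition(x, p):\n",
"    # Hypothetical helper: (B, H, W, C) -> (B, (H/p)*(W/p), p*p, C), one group per p x p window\n",
"    B, H, W, C = x.shape\n",
"    x = x.reshape(B, H // p, p, W // p, p, C)\n",
"    x = x.transpose(0, 1, 3, 2, 4, 5)\n",
"    return x.reshape(B, (H // p) * (W // p), p * p, C)\n",
"\n",
"\n",
"partition_size = 2\n",
"# 6x6 image whose pixel values are their row-major index 0..35\n",
"img = jnp.arange(36).reshape(1, 6, 6, 1)\n",
"\n",
"# window: p = partition_size, each group is a contiguous 2x2 patch\n",
"windows = window_partition(img, partition_size)\n",
"\n",
"# grid: p = H // partition_size, then swap axes so each group takes one token per window\n",
"grid = jnp.swapaxes(window_partition(img, 6 // partition_size), -2, -3)\n",
"\n",
"print('window groups:', windows.shape)\n",
"print(windows[0, :, :, 0])\n",
"print('grid groups:', grid.shape)\n",
"print(grid[0, :, :, 0])"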
] }, { "cell_type": "code", "execution_count": 27, "id": "45b3199e-711d-4125-86b9-22e90fafa28c", "metadata": {}, "outputs": [], "source": [ "class MaxVitLayer(nnx.Module):\n", " \"\"\"\n", " MaxVit layer consisting of a MBConv layer followed by a PartitionAttentionLayer with `window`\n", " and a PartitionAttentionLayer with `grid`.\n", " \"\"\"\n", " def __init__(\n", " self,\n", " # conv parameters\n", " in_channels: int,\n", " out_channels: int,\n", " squeeze_ratio: float,\n", " expansion_ratio: float,\n", " stride: int,\n", " # conv + transformer parameters\n", " norm_layer: Callable[..., nnx.Module],\n", " activation_layer: Callable,\n", " # transformer parameters\n", " head_dim: int,\n", " mlp_ratio: int,\n", " mlp_dropout: float,\n", " attention_dropout: float,\n", " p_stochastic_dropout: float,\n", " # partitioning parameters\n", " partition_size: int,\n", " grid_size: Tuple[int, int],\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " layers = [\n", " # convolutional layer\n", " MBConv(\n", " in_channels=in_channels,\n", " out_channels=out_channels,\n", " expansion_ratio=expansion_ratio,\n", " squeeze_ratio=squeeze_ratio,\n", " stride=stride,\n", " activation_layer=activation_layer,\n", " norm_layer=norm_layer,\n", " p_stochastic_dropout=p_stochastic_dropout,\n", " rngs=rngs,\n", " ),\n", " # window_attention\n", " PartitionAttentionLayer(\n", " in_channels=out_channels,\n", " head_dim=head_dim,\n", " partition_size=partition_size,\n", " partition_type=\"window\",\n", " grid_size=grid_size,\n", " mlp_ratio=mlp_ratio,\n", " activation_layer=activation_layer,\n", " norm_layer=nnx.LayerNorm,\n", " attention_dropout=attention_dropout,\n", " mlp_dropout=mlp_dropout,\n", " p_stochastic_dropout=p_stochastic_dropout,\n", " rngs=rngs,\n", " ),\n", " # grid_attention\n", " PartitionAttentionLayer(\n", " in_channels=out_channels,\n", " head_dim=head_dim,\n", " partition_size=partition_size,\n", " partition_type=\"grid\",\n", " grid_size=grid_size,\n", " 
mlp_ratio=mlp_ratio,\n", " activation_layer=activation_layer,\n", " norm_layer=nnx.LayerNorm,\n", " attention_dropout=attention_dropout,\n", " mlp_dropout=mlp_dropout,\n", " p_stochastic_dropout=p_stochastic_dropout,\n", " rngs=rngs,\n", " )\n", " ]\n", " self.layers = nnx.Sequential(*layers)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " return self.layers(x)\n", "\n", "\n", "def _get_conv_output_shape(\n", " input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int\n", ") -> Tuple[int, int]:\n", " return (\n", " (input_size[0] - kernel_size + 2 * padding) // stride + 1,\n", " (input_size[1] - kernel_size + 2 * padding) // stride + 1,\n", " )" ] }, { "cell_type": "code", "execution_count": 28, "id": "6a130b58-95cf-4ad7-8a42-5044a37c7c09", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 112, 112, 36)\n" ] } ], "source": [ "x = jnp.ones((4, 224, 224, 3))\n", "\n", "grid_size = _get_conv_output_shape((224, 224), kernel_size=3, stride=2, padding=1)\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "\n", "mod = MaxVitLayer(\n", " 3, 36, squeeze_ratio=0.25, expansion_ratio=4,\n", " stride=2, norm_layer=norm_layer, activation_layer=nnx.gelu,\n", " head_dim=6, mlp_ratio=4, mlp_dropout=0.5,\n", " attention_dropout=0.4, p_stochastic_dropout=0.3,\n", " partition_size=7, grid_size=grid_size,\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "21460039-0ed8-4c37-8382-7d91655f1086", "metadata": {}, "source": [ "### `MaxVitBlock` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L483)." 
] }, { "cell_type": "code", "execution_count": 29, "id": "e4fd31d2-4354-4694-87b4-3d0644388d3d", "metadata": {}, "outputs": [], "source": [ "class MaxVitBlock(nnx.Module):\n", " \"\"\"\n", " A MaxVit block consisting of `n_layers` MaxVit layers.\n", " \"\"\"\n", " def __init__(\n", " self,\n", " # conv parameters\n", " in_channels: int,\n", " out_channels: int,\n", " squeeze_ratio: float,\n", " expansion_ratio: float,\n", " # conv + transformer parameters\n", " norm_layer: Callable[..., nnx.Module],\n", " activation_layer: Callable,\n", " # transformer parameters\n", " head_dim: int,\n", " mlp_ratio: int,\n", " mlp_dropout: float,\n", " attention_dropout: float,\n", " # partitioning parameters\n", " partition_size: int,\n", " input_grid_size: Tuple[int, int],\n", " # number of layers\n", " n_layers: int,\n", " p_stochastic: List[float],\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " if not len(p_stochastic) == n_layers:\n", " raise ValueError(f\"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.\")\n", "\n", " # account for the first stride of the first layer\n", " self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1)\n", "\n", " layers = []\n", " for idx, p in enumerate(p_stochastic):\n", " stride = 2 if idx == 0 else 1\n", " layers.append(\n", " MaxVitLayer(\n", " in_channels=in_channels if idx == 0 else out_channels,\n", " out_channels=out_channels,\n", " squeeze_ratio=squeeze_ratio,\n", " expansion_ratio=expansion_ratio,\n", " stride=stride,\n", " norm_layer=norm_layer,\n", " activation_layer=activation_layer,\n", " head_dim=head_dim,\n", " mlp_ratio=mlp_ratio,\n", " mlp_dropout=mlp_dropout,\n", " attention_dropout=attention_dropout,\n", " partition_size=partition_size,\n", " grid_size=self.grid_size,\n", " p_stochastic_dropout=p,\n", " rngs=rngs,\n", " ),\n", " )\n", " self.layers = nnx.Sequential(*layers)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " return self.layers(x)" ] }, { 
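"cell_type": "markdown", "id": "conv-output-shape-note", "metadata": {}, "source": [ "Only the first `MaxVitLayer` of a block uses `stride=2`, so `MaxVitBlock` computes its `grid_size` once with `_get_conv_output_shape`, i.e. `(size - kernel_size + 2 * padding) // stride + 1` along each spatial axis, and all of its layers attend over that same grid. As a quick sanity check, the sketch below compares this formula with the output shape of a strided `nnx.Conv` that uses explicit `((1, 1), (1, 1))` padding; `conv_output_shape` is a hypothetical stand-alone copy of the helper, and the channel counts are arbitrary." ] }, { "cell_type": "code", "execution_count": null, "id": "conv-output-shape-sketch", "metadata": {}, "outputs": [], "source": [
"import jax.numpy as jnp\n",
"from flax import nnx\n",
"\n",
"\n",
"def conv_output_shape(input_size, kernel_size, stride, padding):\n",
"    # Same arithmetic as _get_conv_output_shape, applied to each spatial axis\n",
"    return tuple((s - kernel_size + 2 * padding) // stride + 1 for s in input_size)\n",
"\n",
"\n",
"# 3x3 convolution, stride 2, one pixel of padding on each side (NHWC input)\n",
"conv = nnx.Conv(\n",
"    4, 8, kernel_size=(3, 3), strides=(2, 2),\n",
"    padding=((1, 1), (1, 1)), rngs=nnx.Rngs(0),\n",
")\n",
"\n",
"for size in [(56, 56), (224, 224)]:\n",
"    y = conv(jnp.ones((1, *size, 4)))\n",
"    print(size, '->', y.shape[1:3], conv_output_shape(size, 3, 2, 1))"
] }, { 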
"cell_type": "code", "execution_count": 30, "id": "e168c27f-98db-4831-9723-dffac88f3226", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 112, 112, 36)\n" ] } ], "source": [ "x = jnp.ones((4, 224, 224, 3))\n", "\n", "input_grid_size = (224, 224)\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "\n", "mod = MaxVitBlock(\n", " 3, 36, squeeze_ratio=0.25, expansion_ratio=4,\n", " norm_layer=norm_layer, activation_layer=nnx.gelu,\n", " head_dim=6, mlp_ratio=4, mlp_dropout=0.5, attention_dropout=0.4,\n", " partition_size=7, input_grid_size=input_grid_size,\n", " n_layers=2,\n", " p_stochastic=[0.0, 0.2],\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "cef5687d-e390-438b-95b3-e66406e2c000", "metadata": {}, "source": [ "### `MaxVit` implementation\n", "\n", "Finally, we can assemble everything together and define the MaxVit model.\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L568)." 
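, "\n", "\n", "Before looking at the model, it helps to check two computations that its constructor performs. First, the stem and the first layer of every block each halve the spatial resolution, so `_make_block_input_shapes` lists the resolution at which each block's attention layers operate; every entry must be divisible by `partition_size`, otherwise the constructor raises a `ValueError`. Second, the stochastic depth probability grows linearly from `0` to `stochastic_depth_prob` over all `sum(block_layers)` layers, and each block receives a consecutive slice. The cell below is a minimal sketch that reproduces both with plain Python and NumPy for the configuration used later (`partition_size=7`, `block_layers=[2, 2, 5, 2]`, `stochastic_depth_prob=0.2`); `conv_out` and `block_input_sizes` are hypothetical names." ] }, { "cell_type": "code", "execution_count": null, "id": "maxvit-config-sketch", "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"\n",
"def conv_out(size, kernel_size=3, stride=2, padding=1):\n",
"    # Spatial size after a 3x3, stride-2, padding-1 convolution\n",
"    return tuple((s - kernel_size + 2 * padding) // stride + 1 for s in size)\n",
"\n",
"\n",
"def block_input_sizes(input_size, n_blocks):\n",
"    # The stem halves the input, then the first layer of each block halves it again\n",
"    size = conv_out(input_size)\n",
"    sizes = []\n",
"    for _ in range(n_blocks):\n",
"        size = conv_out(size)\n",
"        sizes.append(size)\n",
"    return sizes\n",
"\n",
"\n",
"partition_size = 7\n",
"for input_size in [(224, 224), (256, 256)]:\n",
"    sizes = block_input_sizes(input_size, n_blocks=4)\n",
"    ok = all(h % partition_size == 0 and w % partition_size == 0 for h, w in sizes)\n",
"    print(input_size, sizes, 'valid' if ok else 'invalid')\n",
"\n",
"# Linearly spaced stochastic depth probabilities, one per layer, sliced per block\n",
"block_layers = [2, 2, 5, 2]\n",
"p_stochastic = np.linspace(0, 0.2, sum(block_layers)).tolist()\n",
"offsets = np.cumsum([0] + block_layers).tolist()\n",
"per_block = [p_stochastic[a:b] for a, b in zip(offsets[:-1], offsets[1:])]\n",
"print([len(p) for p in per_block], per_block[-1])"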
] }, { "cell_type": "code", "execution_count": 31, "id": "0e874e63-0eb7-40ea-82f3-bf10ac33d7a6", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]:\n", " \"\"\"Util function to check that the input size is correct for a MaxVit configuration.\"\"\"\n", " shapes = []\n", " block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1)\n", " for _ in range(n_blocks):\n", " block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1)\n", " shapes.append(block_input_shape)\n", " return shapes\n", "\n", "\n", "class MaxVit(nnx.Module):\n", " \"\"\"\n", " Implements MaxVit Transformer from the \"MaxViT: Multi-Axis Vision Transformer\" paper.\n", " \"\"\"\n", " def __init__(\n", " self,\n", " # input size parameters\n", " input_size: Tuple[int, int],\n", " # stem and task parameters\n", " stem_channels: int,\n", " # partitioning parameters\n", " partition_size: int,\n", " # block parameters\n", " block_channels: List[int],\n", " block_layers: List[int],\n", " # attention head dimensions\n", " head_dim: int,\n", " stochastic_depth_prob: float,\n", " # conv + transformer parameters\n", " # norm_layer is applied only to the conv layers\n", " # activation_layer is applied both to conv and transformer layers\n", " norm_layer: Optional[Callable[..., nnx.Module]] = None,\n", " activation_layer: Callable = nnx.gelu,\n", " # conv parameters\n", " squeeze_ratio: float = 0.25,\n", " expansion_ratio: float = 4,\n", " # transformer parameters\n", " mlp_ratio: int = 4,\n", " mlp_dropout: float = 0.0,\n", " attention_dropout: float = 0.0,\n", " # task parameters\n", " num_classes: int = 1000,\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " input_channels = 3\n", "\n", " if norm_layer is None:\n", " norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "\n", " # Make sure input size will be divisible by the partition size in all 
blocks\n", " # Undefined behavior if H or W are not divisible by p\n", " block_input_sizes = _make_block_input_shapes(input_size, len(block_channels))\n", " for idx, block_input_size in enumerate(block_input_sizes):\n", " if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0:\n", " raise ValueError(\n", " f\"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. \"\n", " f\"Consider changing the partition size or the input size.\\n\"\n", " f\"Current configuration yields the following block input sizes: {block_input_sizes}.\"\n", " )\n", "\n", " # stem\n", " self.stem = nnx.Sequential(\n", " Conv2dNormActivation(\n", " input_channels,\n", " stem_channels,\n", " 3,\n", " stride=2,\n", " norm_layer=norm_layer,\n", " activation_layer=activation_layer,\n", " bias=False,\n", " rngs=rngs,\n", " ),\n", " Conv2dNormActivation(\n", " stem_channels,\n", " stem_channels,\n", " 3,\n", " stride=1,\n", " norm_layer=None,\n", " activation_layer=None,\n", " bias=True,\n", " rngs=rngs,\n", " ),\n", " )\n", "\n", " # account for stem stride\n", " input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1)\n", " self.partition_size = partition_size\n", "\n", " # blocks\n", " blocks = []\n", " in_channels = [stem_channels] + block_channels[:-1]\n", " out_channels = block_channels\n", "\n", " # precompute the stochastic depth probabilities from 0 to stochastic_depth_prob\n", " # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed\n", " # over the range [0, stochastic_depth_prob]\n", " p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist()\n", "\n", " p_idx = 0\n", " for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers):\n", " blocks.append(\n", " MaxVitBlock(\n", " in_channels=in_channel,\n", " out_channels=out_channel,\n", " squeeze_ratio=squeeze_ratio,\n", " 
expansion_ratio=expansion_ratio,\n", " norm_layer=norm_layer,\n", " activation_layer=activation_layer,\n", " head_dim=head_dim,\n", " mlp_ratio=mlp_ratio,\n", " mlp_dropout=mlp_dropout,\n", " attention_dropout=attention_dropout,\n", " partition_size=partition_size,\n", " input_grid_size=input_size,\n", " n_layers=num_layers,\n", " p_stochastic=p_stochastic[p_idx : p_idx + num_layers],\n", " rngs=rngs,\n", " ),\n", " )\n", " input_size = blocks[-1].grid_size\n", " p_idx += num_layers\n", " self.blocks = nnx.Sequential(*blocks)\n", "\n", " self.classifier = nnx.Sequential(\n", " lambda x: nnx.avg_pool(x, (x.shape[1], x.shape[2])), # nn.AdaptiveAvgPool2d(1)\n", " lambda x: x.reshape(x.shape[0], -1), # nn.Flatten()\n", " nnx.LayerNorm(block_channels[-1], rngs=rngs),\n", " nnx.Linear(block_channels[-1], block_channels[-1], rngs=rngs),\n", " nnx.tanh,\n", " nnx.Linear(block_channels[-1], num_classes, use_bias=False, rngs=rngs),\n", " )\n", "\n", " self._init_weights(rngs)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " x = self.stem(x)\n", " x = self.blocks(x)\n", " x = self.classifier(x)\n", " return x\n", "\n", " def _init_weights(self, rngs):\n", " normal_initializer = nnx.initializers.normal(stddev=0.02)\n", " for name, module in self.iter_modules():\n", " if isinstance(module, (nnx.Conv, nnx.Linear)):\n", " module.kernel.value = normal_initializer(\n", " rngs(), module.kernel.value.shape, module.kernel.value.dtype\n", " )\n", " if module.bias.value is not None:\n", " module.bias.value = jnp.zeros(\n", " module.bias.value.shape, dtype=module.bias.value.dtype\n", " )\n", " elif isinstance(module, nnx.BatchNorm):\n", " module.scale.value = jnp.ones(module.scale.value.shape, module.scale.value.dtype)\n", " module.bias.value = jnp.zeros(module.bias.value.shape, module.bias.value.dtype)" ] }, { "cell_type": "code", "execution_count": 32, "id": "7e0f08b8-03a8-4941-8ca3-10d960783486", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "(4, 
1000)\n" ] } ], "source": [ "x = jnp.ones((4, 224, 224, 3))\n", "\n", "mod = MaxVit(\n", " input_size=(224, 224),\n", " stem_channels=64,\n", " block_channels=[64, 128, 256, 512],\n", " block_layers=[2, 2, 5, 2],\n", " head_dim=32,\n", " stochastic_depth_prob=0.2,\n", " partition_size=7,\n", " num_classes=1000,\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "code", "execution_count": 33, "id": "fa2a4a47-b6c9-43ba-822b-002e0c03e85c", "metadata": {}, "outputs": [], "source": [ "def maxvit_t(\n", " input_size=(224, 224),\n", " stem_channels=64,\n", " block_channels=[64, 128, 256, 512],\n", " block_layers=[2, 2, 5, 2],\n", " head_dim=32,\n", " stochastic_depth_prob=0.2,\n", " partition_size=7,\n", " num_classes=1000,\n", "):\n", " model = MaxVit(\n", " input_size=input_size,\n", " stem_channels=stem_channels,\n", " block_channels=block_channels,\n", " block_layers=block_layers,\n", " head_dim=head_dim,\n", " stochastic_depth_prob=stochastic_depth_prob,\n", " partition_size=partition_size,\n", " num_classes=num_classes,\n", " )\n", " return model" ] }, { "cell_type": "markdown", "id": "25ff32f7-e4a1-4029-b114-8ecafb4378fd", "metadata": {}, "source": [ "### Test JAX implementation" ] }, { "cell_type": "markdown", "id": "b3e02373-c3b6-4ffd-a98e-e425824f2f88", "metadata": {}, "source": [ "Let us import the equivalent PyTorch modules and check our implementations against them. Note that\n", "the PyTorch modules are initialized with random parameters and buffers, which we need to copy into our Flax modules before comparing outputs.\n", "\n", "Below we define a helper class `Torch2Flax` that copies parameters and buffers from a PyTorch module into the equivalent Flax module." 
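, "\n", "\n", "The only non-trivial part of this mapping is the weight layout. PyTorch stores `nn.Conv2d` weights as `(out_channels, in_channels, kH, kW)` and `nn.Linear` weights as `(out_features, in_features)`, while Flax `nnx.Conv` kernels are `(kH, kW, in_features, out_features)` and `nnx.Linear` kernels are `(in_features, out_features)`. This is why `Torch2Flax` permutes convolution weights with `(2, 3, 1, 0)` and transposes linear weights. The cell below is a small sketch that checks these permutations without PyTorch: it emulates PyTorch's NCHW/OIHW convolution with `jax.lax.conv_general_dilated` and compares it with an `nnx.Conv` whose kernel was permuted the same way (random weights, biases omitted)." ] }, { "cell_type": "code", "execution_count": null, "id": "weight-layout-sketch", "metadata": {}, "outputs": [], "source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import nnx\n",
"\n",
"key_w, key_x = jax.random.split(jax.random.key(0))\n",
"\n",
"# Convolution weight and input in PyTorch layouts: OIHW and NCHW\n",
"w_oihw = jax.random.normal(key_w, (8, 4, 3, 3))\n",
"x_nchw = jax.random.normal(key_x, (2, 4, 16, 16))\n",
"y_ref = jax.lax.conv_general_dilated(\n",
"    x_nchw, w_oihw, window_strides=(1, 1), padding=((1, 1), (1, 1)),\n",
"    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),\n",
")\n",
"\n",
"# Flax convolution: HWIO kernel (same permutation as conv_params_permute), NHWC input\n",
"conv = nnx.Conv(4, 8, kernel_size=(3, 3), padding=((1, 1), (1, 1)), use_bias=False, rngs=nnx.Rngs(0))\n",
"conv.kernel.value = w_oihw.transpose(2, 3, 1, 0)\n",
"y = conv(x_nchw.transpose(0, 2, 3, 1))\n",
"conv_err = jnp.abs(y - y_ref.transpose(0, 2, 3, 1)).max()\n",
"\n",
"# Linear layer: a PyTorch weight W of shape (out, in) computes v @ W.T, so the Flax kernel is W.T\n",
"w_linear = jax.random.normal(key_w, (8, 4))\n",
"v = jax.random.normal(key_x, (3, 4))\n",
"linear = nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(0))\n",
"linear.kernel.value = w_linear.T\n",
"linear_err = jnp.abs(linear(v) - v @ w_linear.T).max()\n",
"\n",
"print(conv_err, linear_err)"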
] }, { "cell_type": "code", "execution_count": 34, "id": "22f49ecd-4999-4c1c-b1d8-d16faeb60389", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "\n", "class Torch2Flax:\n", " @staticmethod\n", " def conv_params_permute(name, torch_param):\n", " if name == \"weight\":\n", " return torch_param.permute(2, 3, 1, 0)\n", " return torch_param\n", "\n", " @staticmethod\n", " def linear_params_permute(name, torch_param):\n", " if name == \"weight\":\n", " return torch_param.permute(1, 0)\n", " return torch_param\n", "\n", " @staticmethod\n", " def default_params_transform(name, param):\n", " return param\n", "\n", " modules_mapping_info = {\n", " nn.Conv2d: {\n", " \"type\": nnx.Conv,\n", " \"params_mapping\": {\n", " \"weight\": \"kernel\",\n", " \"bias\": \"bias\",\n", " },\n", " \"params_transform\": conv_params_permute,\n", " },\n", " nn.BatchNorm2d: {\n", " \"type\": nnx.BatchNorm,\n", " \"params_mapping\": {\n", " \"weight\": \"scale\",\n", " \"bias\": \"bias\",\n", " \"running_mean\": \"mean\",\n", " \"running_var\": \"var\",\n", " },\n", " },\n", " nn.Linear: {\n", " \"type\": nnx.Linear,\n", " \"params_mapping\": {\n", " \"weight\": \"kernel\",\n", " \"bias\": \"bias\",\n", " },\n", " \"params_transform\": linear_params_permute,\n", " },\n", " nn.LayerNorm: {\n", " \"type\": nnx.LayerNorm,\n", " \"params_mapping\": {\n", " \"weight\": \"scale\",\n", " \"bias\": \"bias\",\n", " },\n", " }\n", " } | {\n", " torch_mod: {\n", " \"type\": nnx_fn_type,\n", " \"params_mapping\": {},\n", " } for torch_mod, nnx_fn_type in [\n", " (nn.Identity, Identity),\n", " (nn.Flatten, type(lambda x: x)),\n", " (nn.ReLU, type(nnx.relu)),\n", " (nn.GELU, type(nnx.gelu)),\n", " (nn.SELU, type(nnx.selu)),\n", " (nn.SiLU, type(nnx.silu)),\n", " (nn.Tanh, type(nnx.tanh)),\n", " (nn.Dropout, nnx.Dropout),\n", " (nn.Sigmoid, type(nnx.sigmoid)),\n", " (nn.AvgPool2d, type(lambda x: nnx.avg_pool(x, (2, 2)))),\n", " (nn.AdaptiveAvgPool2d, type(lambda x: nnx.avg_pool(x, 
(x.shape[1], x.shape[2])))),\n", " ]\n", " }\n", "\n", " def _copy_params_buffers(self, torch_nn_module, nnx_module):\n", " torch_module_type = type(torch_nn_module)\n", " assert torch_module_type in self.modules_mapping_info, torch_module_type\n", " module_mapping_info = self.modules_mapping_info[torch_module_type]\n", " assert isinstance(nnx_module, module_mapping_info[\"type\"]), (\n", " nnx_module, type(nnx_module), module_mapping_info[\"type\"]\n", " )\n", "\n", " for torch_key, nnx_key in module_mapping_info[\"params_mapping\"].items():\n", "\n", " torch_value = getattr(torch_nn_module, torch_key)\n", " nnx_param = getattr(nnx_module, nnx_key)\n", " assert nnx_param is not None, (torch_key, nnx_key, nnx_module)\n", "\n", " if torch_value is None:\n", " assert nnx_param.value is None, nnx_param\n", " continue\n", "\n", " params_transform = module_mapping_info.get(\"params_transform\", Torch2Flax.default_params_transform)\n", " torch_value = params_transform(torch_key, torch_value)\n", "\n", " assert nnx_param.value.shape == torch_value.data.shape, (\n", " nnx_key, nnx_param.value.shape, torch_key, torch_value.data.shape\n", " )\n", " nnx_param.value = jnp.asarray(torch_value.data)\n", "\n", " def _copy_sequential(self, torch_nn_seq, nnx_seq, skip_modules=None):\n", " assert isinstance(torch_nn_seq, (nn.Sequential, nn.ModuleList)), type(torch_nn_seq)\n", " assert isinstance(nnx_seq, nnx.Sequential), type(nnx_seq)\n", " for i, torch_module in enumerate(torch_nn_seq):\n", " nnx_module = nnx_seq.layers[i]\n", " self.copy_module(torch_module, nnx_module, skip_modules=skip_modules)\n", "\n", " def copy_module(self, torch_module, nnx_module, skip_modules=None):\n", " if skip_modules is None:\n", " skip_modules = []\n", "\n", " if isinstance(torch_module, (nn.Sequential, nn.ModuleList)):\n", " self._copy_sequential(torch_module, nnx_module, skip_modules=skip_modules)\n", " elif type(torch_module) in self.modules_mapping_info:\n", " 
self._copy_params_buffers(torch_module, nnx_module)\n", " else:\n", " if skip_modules is not None:\n", " if torch_module.__class__.__name__ in skip_modules:\n", " return\n", "\n", " named_children = list(torch_module.named_children())\n", " assert len(named_children) > 0, type(torch_module)\n", " for name, torch_child in named_children:\n", " nnx_child = getattr(nnx_module, name, None)\n", " assert nnx_child is not None, (name, nnx_module)\n", " self.copy_module(torch_child, nnx_child, skip_modules=skip_modules)\n", " # Copy buffers and params of the module itself (not its children)\n", " for name, torch_buffer in torch_module.named_buffers():\n", " if \".\" in name:\n", " # This is child's buffer\n", " continue\n", " nnx_buffer = getattr(nnx_module, name)\n", " assert isinstance(nnx_buffer, nnx.Variable), (name, nnx_buffer, nnx_module)\n", "\n", " assert nnx_buffer.value.shape == torch_buffer.shape, (\n", " name, nnx_buffer.value.shape, torch_buffer.shape\n", " )\n", " nnx_buffer.value = jnp.asarray(torch_buffer)\n", "\n", " for name, torch_param in torch_module.named_parameters():\n", " if \".\" in name:\n", " # This is child's parameter\n", " continue\n", " nnx_param = getattr(nnx_module, name)\n", " assert isinstance(nnx_param, nnx.Param), (name, nnx_param, nnx_module)\n", "\n", " assert nnx_param.value.shape == torch_param.data.shape, (\n", " name, nnx_param.value.shape, torch_param.data.shape\n", " )\n", " nnx_param.value = jnp.asarray(torch_param.data)\n", "\n", "\n", "def test_modules(\n", " nnx_module, torch_module, torch_input, atol=1e-3, mode=\"eval\", permute_torch_input=True, device=\"cuda\"\n", "):\n", " assert torch_input.ndim == 4\n", " assert mode in (\"eval\", \"train\")\n", "\n", " torch_input = torch_input.to(device)\n", " torch_module = torch_module.to(device)\n", "\n", " if mode == \"eval\":\n", " torch_module.eval()\n", " nnx_module.eval()\n", " else:\n", " torch_module.train()\n", " nnx_module.train()\n", "\n", " with 
torch.inference_mode(mode=mode==\"eval\"):\n", " torch_output = torch_module(torch_input)\n", "\n", " if permute_torch_input:\n", " torch_input = torch_input.permute(0, 2, 3, 1)\n", "\n", " jax_input = jnp.asarray(torch_input, device=jax.devices(device)[0])\n", " jax_output = nnx_module(jax_input)\n", " assert jax_output.device == jax.devices(device)[0]\n", "\n", " torch_output = torch_output.detach()\n", " if permute_torch_input and torch_output.ndim == 4:\n", " torch_output = torch_output.permute(0, 2, 3, 1)\n", " jax_expected = jnp.asarray(torch_output)\n", "\n", " assert jnp.allclose(jax_output, jax_expected, atol=atol), (\n", " jnp.abs(jax_output - jax_expected).max(),\n", " jnp.abs(jax_output - jax_expected).mean(),\n", " )\n", "\n", "\n", "t2f = Torch2Flax()" ] }, { "cell_type": "markdown", "id": "a323a19e-fc64-4f8d-8be2-c7886b6191b9", "metadata": {}, "source": [ "Let us now test our JAX modules. We only compare forward passes in inference (eval) mode, which avoids discrepancies caused by random layers such as `Dropout` and `StochasticDepth`.\n", "By default, we use an absolute error tolerance of `1e-3` when comparing the JAX output against the expected PyTorch result.\n", "For larger modules, we pass `device=\"cpu\"` so that both the PyTorch and the JAX modules run on CPU, which reduces the numerical discrepancies observed when running on CUDA." 
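, "\n", "\n", "`test_modules` feeds the same batch to both frameworks: the NCHW PyTorch input is permuted with `(0, 2, 3, 1)` to NHWC for JAX, a 4-D PyTorch output is permuted the same way, and the two are compared with `jnp.allclose`, which accepts `|a - b| <= atol + rtol * |b|` elementwise with a default `rtol` of `1e-5`. The cell below is a small sketch of both steps on toy arrays; the numbers are arbitrary and not taken from the models above." ] }, { "cell_type": "code", "execution_count": null, "id": "layout-tolerance-sketch", "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"# A toy NCHW batch, as PyTorch would provide it\n",
"x_nchw = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)\n",
"# Same permutation as torch_input.permute(0, 2, 3, 1) in test_modules\n",
"x_nhwc = jnp.asarray(x_nchw.transpose(0, 2, 3, 1))\n",
"# Element (n, c, h, w) moves to (n, h, w, c)\n",
"print(x_nhwc.shape, bool(x_nhwc[1, 2, 3, 0] == x_nchw[1, 0, 2, 3]))\n",
"\n",
"# jnp.allclose tolerates |a - b| <= atol + rtol * |b| (rtol defaults to 1e-5)\n",
"expected = jnp.array([1.0, 10.0, 100.0])\n",
"actual = expected + jnp.array([5e-4, 5e-4, 5e-3])\n",
"strict = jnp.allclose(actual, expected, atol=1e-3)  # last entry fails: 5e-3 > 1e-3 + 1e-5 * 100\n",
"loose = jnp.allclose(actual, expected, atol=1e-2)\n",
"print(strict, loose, jnp.abs(actual - expected).max())"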
] }, { "cell_type": "code", "execution_count": 35, "id": "2e10fb43-6ae6-47c7-81fe-5027b115b25f", "metadata": {}, "outputs": [], "source": [ "from torchvision.ops.misc import Conv2dNormActivation as PyTorchConv2dNormActivation\n", "\n", "\n", "torch_module = PyTorchConv2dNormActivation(32, 64, 3, 2, 1)\n", "nnx_module = Conv2dNormActivation(32, 64, 3, 2, 1)\n", "\n", "t2f.copy_module(torch_module, nnx_module)\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))" ] }, { "cell_type": "code", "execution_count": 36, "id": "7ac4b49a-712f-4725-a8d8-33e72b8d0b66", "metadata": {}, "outputs": [], "source": [ "from torchvision.ops.misc import SqueezeExcitation as PyTorchSqueezeExcitation\n", "\n", "\n", "torch_module = PyTorchSqueezeExcitation(32, 4)\n", "nnx_module = SqueezeExcitation(32, 4)\n", "\n", "t2f.copy_module(torch_module, nnx_module)\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))" ] }, { "cell_type": "code", "execution_count": 37, "id": "746c8882-0001-4c97-b5cf-576dc5c87c02", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "from functools import partial\n", "from torchvision.models.maxvit import MBConv as PyTorchMBConv\n", "\n", "\n", "norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01)\n", "torch_module = PyTorchMBConv(32, 64, 4, 0.25, 2, activation_layer=nn.GELU, norm_layer=norm_layer)\n", "\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "nnx_module = MBConv(32, 64, 4, 0.25, 2, activation_layer=nnx.gelu, norm_layer=norm_layer)\n", "\n", "\n", "t2f.copy_module(torch_module, nnx_module)\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))" ] }, { "cell_type": "code", "execution_count": 38, "id": "249f6d28-57b6-4d36-9079-cd60964e6afc", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import RelativePositionalMultiHeadAttention as PyTorchRelativePositionalMultiHeadAttention\n", "\n", "\n", "torch_module 
= PyTorchRelativePositionalMultiHeadAttention(64, 16, 49)\n", "nnx_module = RelativePositionalMultiHeadAttention(64, 16, 49)\n", "\n", "t2f.copy_module(torch_module, nnx_module)\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 32, 49, 64), permute_torch_input=False)" ] }, { "cell_type": "code", "execution_count": 39, "id": "f48fc475-c556-4101-ad2b-19480a73c6ba", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import PartitionAttentionLayer as PyTorchPartitionAttentionLayer\n", "\n", "\n", "grid_size = (224, 224)\n", "for partition_type in [\"window\", \"grid\"]:\n", "\n", " torch_module = PyTorchPartitionAttentionLayer(\n", " 36, 6, 7, partition_type, grid_size=grid_size, mlp_ratio=4,\n", " activation_layer=nn.GELU, norm_layer=nn.LayerNorm,\n", " attention_dropout=0.4, mlp_dropout=0.3, p_stochastic_dropout=0.2,\n", " )\n", "\n", " nnx_module = PartitionAttentionLayer(\n", " 36, 6, 7, partition_type, grid_size=grid_size, mlp_ratio=4,\n", " activation_layer=nnx.gelu, norm_layer=nnx.LayerNorm,\n", " attention_dropout=0.4, mlp_dropout=0.3, p_stochastic_dropout=0.2,\n", " )\n", "\n", " t2f.copy_module(torch_module, nnx_module, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", " ])\n", "\n", " test_modules(nnx_module, torch_module, torch.randn(4, 36, 224, 224))" ] }, { "cell_type": "code", "execution_count": 40, "id": "7ab2de6a-9790-444b-b175-535cdb05f5d8", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import MaxVitLayer as PyTorchMaxVitLayer\n", "\n", "\n", "stride = 2\n", "\n", "grid_size = _get_conv_output_shape((224, 224), kernel_size=3, stride=2, padding=1)\n", "norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01)\n", "\n", "torch_module = PyTorchMaxVitLayer(\n", " 36, 36, squeeze_ratio=0.25, expansion_ratio=4,\n", " stride=stride, norm_layer=norm_layer, activation_layer=nn.GELU,\n", " head_dim=6, mlp_ratio=4, mlp_dropout=0.5,\n", " 
attention_dropout=0.4, p_stochastic_dropout=0.3,\n", " partition_size=7, grid_size=grid_size,\n", ")\n", "\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "nnx_module = MaxVitLayer(\n", " 36, 36, squeeze_ratio=0.25, expansion_ratio=4,\n", " stride=stride, norm_layer=norm_layer, activation_layer=nnx.gelu,\n", " head_dim=6, mlp_ratio=4, mlp_dropout=0.5,\n", " attention_dropout=0.4, p_stochastic_dropout=0.3,\n", " partition_size=7, grid_size=grid_size,\n", ")\n", "\n", "t2f.copy_module(torch_module, nnx_module, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", "])\n", "\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 36, 224, 224), device=\"cpu\")" ] }, { "cell_type": "code", "execution_count": 41, "id": "e8e8f997-0184-4af2-82b0-ceaa580645c8", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import MaxVitBlock as PyTorchMaxVitBlock\n", "\n", "\n", "norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01)\n", "torch_module = PyTorchMaxVitBlock(\n", " 64, 128, squeeze_ratio=0.25, expansion_ratio=4,\n", " norm_layer=norm_layer, activation_layer=nn.GELU,\n", " head_dim=32, mlp_ratio=4, mlp_dropout=0.0, attention_dropout=0.0,\n", " partition_size=7, input_grid_size=(56, 56),\n", " n_layers=2,\n", " p_stochastic=[0.13333333333333333, 0.2],\n", ")\n", "\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "nnx_module = MaxVitBlock(\n", " 64, 128, squeeze_ratio=0.25, expansion_ratio=4,\n", " norm_layer=norm_layer, activation_layer=nnx.gelu,\n", " head_dim=32, mlp_ratio=4, mlp_dropout=0.0, attention_dropout=0.0,\n", " partition_size=7, input_grid_size=(56, 56),\n", " n_layers=2,\n", " p_stochastic=[0.13333333333333333, 0.2],\n", ")\n", "\n", "t2f.copy_module(torch_module, nnx_module, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", "])\n", "\n", "test_modules(nnx_module, 
torch_module, torch.randn(4, 64, 56, 56), device=\"cpu\")" ] }, { "cell_type": "markdown", "id": "e313819a-e93a-4201-806d-783bd1336c78", "metadata": {}, "source": [ "Finally, we can check the MaxVit implementation. Note that we raise the absolute tolerance to `1e-1` when comparing the JAX output logits against the expected PyTorch logits." ] }, { "cell_type": "code", "execution_count": 42, "id": "e2af63a4-b16b-40a3-ac00-bfc23d532c82", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import MaxVit as PyTorchMaxVit\n", "\n", "\n", "torch.manual_seed(77)\n", "\n", "\n", "torch_module = PyTorchMaxVit(\n", " input_size=(224, 224),\n", " stem_channels=64,\n", " block_channels=[64, 128, 256, 512],\n", " block_layers=[2, 2, 5, 2],\n", " head_dim=32,\n", " stochastic_depth_prob=0.2,\n", " partition_size=7,\n", " num_classes=1000,\n", ")\n", "\n", "nnx_module = MaxVit(\n", " input_size=(224, 224),\n", " stem_channels=64,\n", " block_channels=[64, 128, 256, 512],\n", " block_layers=[2, 2, 5, 2],\n", " head_dim=32,\n", " stochastic_depth_prob=0.2,\n", " partition_size=7,\n", " num_classes=1000,\n", ")\n", "\n", "t2f.copy_module(torch_module, nnx_module, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", "])\n", "\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 3, 224, 224), device=\"cpu\", atol=1e-1)" ] }, { "cell_type": "markdown", "id": "0d3c4f4d-2a50-46f4-814c-42c1a423cfd0", "metadata": {}, "source": [ "### Check Flax model\n", "Let us now reuse the trained weights of TorchVision's MaxViT model to compare the output logits and the predictions on our example image:" ] }, { "cell_type": "code", "execution_count": 43, "id": "7975f311-7a02-4c82-99db-b0b50fb37528", "metadata": {}, "outputs": [], "source": [ "from torchvision.models import maxvit_t as pytorch_maxvit_t, MaxVit_T_Weights\n", "\n", "torch_model = pytorch_maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)\n", "flax_model = maxvit_t()\n", 
"\n", "t2f = Torch2Flax()\n", "t2f.copy_module(torch_model, flax_model, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", "])" ] }, { "cell_type": "code", "execution_count": 44, "id": "922cc4b5-f181-4865-9043-fd3b56bafe43", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction for the Dog:\n", "- PyTorch model result: ['n02113023', 'Pembroke'], score: 0.7800846099853516\n", "- Flax model result: ['n02113023', 'Pembroke'], score: 0.7799879908561707\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import json\n", "from torchvision.io import read_image\n", "\n", "\n", "preprocess = MaxVit_T_Weights.IMAGENET1K_V1.transforms()\n", "\n", "with open(\"imagenet_class_index.json\") as labels_file:\n", " labels = json.load(labels_file)\n", "\n", "\n", "dog1 = read_image(\"dog1.jpg\")\n", "tensor = preprocess(dog1).unsqueeze(dim=0)\n", "\n", "torch_model.eval()\n", "with torch.inference_mode():\n", " torch_output = torch_model(tensor)\n", "\n", "torch_class_id = torch_output.argmax(dim=1).item()\n", "\n", "# Flax convolutions expect channels-last (NHWC) inputs, so permute from PyTorch's NCHW layout\n", "jax_array = jnp.asarray(tensor.permute(0, 2, 3, 1), device=jax.devices(\"cpu\")[0])\n", "flax_model.eval()\n", "flax_output = flax_model(jax_array)\n", "\n", "# Take the predicted class from the Flax logits, not the PyTorch ones\n", "flax_class_id = flax_output.argmax(axis=1).item()\n", "\n", "print(\"Prediction for the Dog:\")\n", "print(f\"- PyTorch model result: {labels[str(torch_class_id)]}, score: {torch_output.softmax(dim=1)[0, torch_class_id]}\")\n", "print(f\"- Flax model result: {labels[str(flax_class_id)]}, score: {jax.nn.softmax(flax_output, axis=1)[0, flax_class_id]}\")\n", "\n", "\n", "plt.subplot(121)\n", "plt.title(f\"{labels[str(torch_class_id)]}\\nScore: {torch_output.softmax(dim=1)[0, torch_class_id]:.4f}\")\n", "plt.imshow(dog1.permute(1, 2, 0))\n", "\n", "plt.subplot(122)\n", "plt.title(f\"{labels[str(flax_class_id)]}\\nScore: {jax.nn.softmax(flax_output, axis=1)[0, flax_class_id]:.4f}\")\n", "plt.imshow(dog1.permute(1, 2, 0))" ] }, { "cell_type": "markdown", "id": "c77f3244", "metadata": {}, "source": [ "Let's compute the cosine similarity between the PyTorch and Flax logits; a value close to `1` means the two outputs are nearly aligned:" ] }, { "cell_type": "code", "execution_count": 45, "id": "36801241-11cc-4850-8ea2-f0a306eaba2a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array(0.99999857, dtype=float32)" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "expected = jnp.asarray(torch_output)\n", "\n", "cosine_dist = (expected * flax_output).sum() / 
(jnp.linalg.norm(flax_output) * jnp.linalg.norm(expected))\n", "cosine_dist" ] }, { "cell_type": "markdown", "id": "65e57aa6-1572-4805-9207-bc8a5f9f3ab1", "metadata": {}, "source": [ "## Further reading\n", "\n", "- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)\n", "- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/getting_started.html)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }