{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/36_int8_quantization.ipynb)\n", "\n", "# 🔴 Hard: INT8 Quantized Linear\n", "\n", "Implement a **post-training quantized linear layer** using INT8 weights.\n", "\n", "### Signature\n", "```python\n", "class Int8Linear(nn.Module):\n", " def __init__(self, weight: Tensor, bias: Tensor = None): ...\n", " def forward(self, x: Tensor) -> Tensor: ...\n", "```\n", "\n", "### Quantization (per-channel)\n", "1. `scale = weight.abs().amax(dim=1, keepdim=True) / 127` gives one scale per output row with shape `(out_features, 1)`, so it broadcasts across each row. (Plain `.max(dim=1)` returns a `(values, indices)` tuple and drops the reduced dimension.) Clamp `scale` to a small positive minimum (e.g. `1e-8`) so an all-zero row does not divide by zero.\n", "2. `weight_int8 = torch.round(weight / scale).clamp(-128, 127).to(torch.int8)`\n", "3. Store `weight_int8` and `scale` (and `bias`, if given) with `register_buffer`, so they follow `.to(device)` and appear in the `state_dict` but are not trainable\n", "4. Forward: dequantize (`w = self.weight_int8.float() * self.scale`), then return `x @ w.T`, adding `bias` if present" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "class Int8Linear(nn.Module):\n", " def __init__(self, weight, bias=None):\n", " super().__init__()\n", " pass # quantize weight, register buffers\n", "\n", " def forward(self, x):\n", " pass # dequantize and matmul" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "w = torch.randn(8, 4)\n", "q = Int8Linear(w)\n", "x = torch.randn(2, 4)\n", "print('Output:', q(x).shape)\n", "print('dtype:', q.weight_int8.dtype)\n", "print('Max quant error:', (w - q.weight_int8.float() * 
q.scale).abs().max().item())" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('int8_quantization')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }