{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/15_mlp.ipynb)\n", "\n", "# 🟠 Medium: SwiGLU MLP\n", "\n", "Implement the **SwiGLU MLP** (feed-forward network) used in modern LLMs such as LLaMA.\n", "\n", "$$\\text{SwiGLU}(x) = \\text{down\\_proj}\\big(\\text{SiLU}(\\text{gate\\_proj}(x)) \\odot \\text{up\\_proj}(x)\\big)$$\n", "\n", "where $\\text{SiLU}(x) = x \\cdot \\sigma(x)$, $\\sigma$ is the logistic sigmoid, and $\\odot$ denotes element-wise multiplication.\n", "\n", "### Signature\n", "```python\n", "class SwiGLUMLP(nn.Module):\n", " def __init__(self, d_model: int, d_ff: int): ...\n", " def forward(self, x: torch.Tensor) -> torch.Tensor: ...\n", "```\n", "\n", "### Requirements\n", "- Inherit from `nn.Module`\n", "- `self.gate_proj`: `nn.Linear(d_model, d_ff)`\n", "- `self.up_proj`: `nn.Linear(d_model, d_ff)`\n", "- `self.down_proj`: `nn.Linear(d_ff, d_model)`\n", "- Activation: **SiLU** (a.k.a. Swish): use `F.silu`, or implement it as `x * torch.sigmoid(x)`\n", "\n", "### Why SwiGLU?\n", "Unlike the classic `Linear → ReLU/GELU → Linear` FFN, SwiGLU uses a **gating mechanism**:\n", "the gate projection controls how much information flows through, while the up projection provides the content.\n", "In practice this consistently outperforms standard FFNs (PaLM, LLaMA, and Mistral all use it)."
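, "\n", "As a sanity check on the formula above, here is a minimal sketch with plain tensors and bias-free projections (an illustration of the math, not the exercise solution):\n", "\n", "```python\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "d_model, d_ff = 4, 8\n", "x = torch.randn(2, d_model)\n", "W_gate = torch.randn(d_ff, d_model)  # plays the role of gate_proj\n", "W_up = torch.randn(d_ff, d_model)    # plays the role of up_proj\n", "W_down = torch.randn(d_model, d_ff)  # plays the role of down_proj\n", "\n", "hidden = F.silu(x @ W_gate.T) * (x @ W_up.T)  # SiLU(gate) * up, element-wise\n", "out = hidden @ W_down.T                       # project back to d_model\n", "assert out.shape == (2, d_model)\n", "```\n"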
] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "class SwiGLUMLP(nn.Module):\n", " def __init__(self, d_model: int, d_ff: int):\n", " super().__init__()\n", " pass # TODO: initialize gate_proj, up_proj, down_proj\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " pass # TODO: return down_proj(silu(gate_proj(x)) * up_proj(x))" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "mlp = SwiGLUMLP(d_model=64, d_ff=128)\n", "x = torch.randn(2, 8, 64)\n", "out = mlp(x)\n", "print(\"Output shape:\", out.shape) # (2, 8, 64)\n", "print(\"Params:\", sum(p.numel() for p in mlp.parameters())) # 24896 = 2*(64*128+128) + (128*64+64)" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('mlp')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }