{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/28_moe.ipynb)\n", "\n", "# 🔴 Hard: Mixture of Experts (MoE)\n", "\n", "Implement a **Mixture of Experts** layer (Mixtral / Switch Transformer style).\n", "\n", "### Signature\n", "```python\n", "class MixtureOfExperts(nn.Module):\n", "    def __init__(self, d_model, d_ff, num_experts, top_k=2): ...\n", "    def forward(self, x: Tensor) -> Tensor:\n", "        # x: (B, S, D) -> (B, S, D)\n", "        ...\n", "```\n", "\n", "### Architecture\n", "- `self.router`: `nn.Linear(d_model, num_experts)`, the gating network that scores every expert for each token\n", "- `self.experts`: `nn.ModuleList` of `num_experts` MLPs `(Linear→ReLU→Linear)`, each mapping `d_model → d_ff → d_model`\n", "- For each token: select the `top_k` experts with the highest router scores and return the sum of their outputs, weighted by the corresponding gate values" ], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "class MixtureOfExperts(nn.Module):\n", "    def __init__(self, d_model, d_ff, num_experts, top_k=2):\n", "        super().__init__()\n", "        pass  # router + experts\n", "\n", "    def forward(self, x):\n", "        pass  # route tokens to top-k experts" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "moe = MixtureOfExperts(32, 64, num_experts=4, top_k=2)\n", "x = torch.randn(2, 8, 32)\n", "print('Output:', moe(x).shape)\n", "print('Params:', sum(p.numel() for p in moe.parameters()))" ], "execution_count": null }, { 
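"cell_type": "markdown", "metadata": {}, "source": [ "### Routing formulation (one common choice)\n", "\n", "A minimal sketch of the gating math, assuming Mixtral-style top-k routing. For a token $x$, the router (`self.router`) produces logits $r \\in \\mathbb{R}^{N}$, where $N$ is `num_experts`. Let $\\mathcal{T}(x)$ be the indices of the `top_k` largest logits. The gate weights are a softmax over the selected logits only, and the output is the gate-weighted sum of the selected experts' outputs:\n", "\n", "$$\n", "g_i(x) = \\frac{\\exp(r_i)}{\\sum_{j \\in \\mathcal{T}(x)} \\exp(r_j)}, \\qquad y = \\sum_{i \\in \\mathcal{T}(x)} g_i(x)\\, E_i(x)\n", "$$\n", "\n", "Here $E_i$ is the $i$-th expert MLP. This normalization is an assumption, not something the template specifies: Switch Transformer, for example, applies the softmax over all $N$ logits and routes each token to its single highest-scoring expert." ] }, {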
"cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('moe')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }