{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/34_speculative_decoding.ipynb)\n", "\n", "# 🔴 Hard: Speculative Decoding\n", "\n", "Implement the **acceptance/rejection step** of speculative decoding, a technique for accelerating LLM inference by letting a small draft model propose tokens that the large target model then verifies.\n", "\n", "### Signature\n", "```python\n", "def speculative_decode(target_probs, draft_probs, draft_tokens) -> list[int]:\n", "    # target_probs: (K, V) from target (large) model\n", "    # draft_probs: (K, V) from draft (small) model\n", "    # draft_tokens: (K,) tokens sampled by draft model\n", "    # Returns: list of accepted tokens (length 1 to K)\n", "```\n", "\n", "### Algorithm\n", "For each position i = 0, ..., K-1:\n", "1. `ratio = target_probs[i, token_i] / draft_probs[i, token_i]`\n", "2. Accept `token_i` with probability `min(1, ratio)`\n", "3. If rejected: sample a replacement from `normalize(max(0, target - draft))`, append it, and stop" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def speculative_decode(target_probs, draft_probs, draft_tokens):\n", "    # TODO: loop over positions; accept draft_tokens[i] with probability min(1, ratio);\n", "    # on the first rejection, sample from normalize(max(0, target - draft)) and stop.\n", "    pass" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "torch.manual_seed(0)\n", "probs = torch.softmax(torch.randn(4, 10), dim=-1)\n", "tokens = torch.tensor([2, 5, 1, 8])\n", "print('Perfect draft:', speculative_decode(probs, probs, tokens))\n", 
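"\n",
"# Hedged reference sketch of the accept/reject rule described above, for\n",
"# self-checking only. Assumption: the standard speculative-sampling rule\n",
"# (accept with prob min(1, ratio); on rejection, resample from the\n",
"# normalized residual and stop). The graded solution may differ in details.\n",
"def _ref_speculative_decode(target_probs, draft_probs, draft_tokens):\n",
"    accepted = []\n",
"    for i, tok in enumerate(draft_tokens.tolist()):\n",
"        ratio = (target_probs[i, tok] / draft_probs[i, tok]).item()\n",
"        if torch.rand(()).item() < min(1.0, ratio):\n",
"            accepted.append(tok)  # accept the draft token\n",
"        else:\n",
"            # resample from normalize(max(0, target - draft)), then stop\n",
"            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)\n",
"            accepted.append(torch.multinomial(residual / residual.sum(), 1).item())\n",
"            break\n",
"    return accepted\n",
"# With a perfect draft (target == draft), ratio is 1 everywhere, so every\n",
"# draft token should be accepted.\n",
"print('Reference, perfect draft:', _ref_speculative_decode(probs, probs, tokens))\n",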
"target = torch.softmax(torch.randn(4, 10), dim=-1)\n", "draft = torch.softmax(torch.randn(4, 10), dim=-1)\n", "print('Random draft:', speculative_decode(target, draft, tokens))" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('speculative_decoding')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }