{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/33_beam_search.ipynb)\n", "\n", "# 🟠 Medium: Beam Search Decoding\n", "\n", "Implement **beam search**, the classic decoding algorithm for sequence generation. Instead of greedily committing to the single most likely token at each step, it keeps the `beam_width` highest-scoring partial sequences.\n", "\n", "### Signature\n", "```python\n", "def beam_search(log_prob_fn, start_token, max_len, beam_width, eos_token) -> list[int]:\n", "    # log_prob_fn: takes the current token list, returns a (V,) tensor of next-token log-probabilities\n", "    # Returns: the best sequence as a list of ints\n", "```\n", "\n", "### Algorithm\n", "1. Start with a single beam `[(0.0, [start_token])]`, i.e. (cumulative log-probability, tokens)\n", "2. At each step, expand each beam with its top-k next tokens (k = `beam_width` is enough), adding each token's log-probability to the beam's score\n", "3. Keep the top `beam_width` candidates across all expansions, ranked by total log-probability\n", "4. Stop when the best beam ends with `eos_token` or `max_len` is reached, then return that beam's tokens" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def beam_search(log_prob_fn, start_token, max_len, beam_width, eos_token):\n", "    pass  # maintain beams, expand, prune, return best" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "def simple_fn(tokens):\n", "    # Toy model over 5 tokens: the most likely next token is len(tokens), capped at 4 (the EOS token)\n", "    lp = torch.full((5,), -10.0)\n", "    lp[min(len(tokens), 4)] = 0.0\n", "    return lp\n", "\n", "seq = beam_search(simple_fn, start_token=0, max_len=5, beam_width=2, eos_token=4)\n", "print('Sequence:', seq)" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, 
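"outputs": [], "source": [ "# 💡 Optional hint: a minimal sketch of ONE expand-and-prune step (Algorithm steps 2-3), not the judged reference solution.\n", "# Assumptions: beams are (cumulative_log_prob, token_list) tuples, log_prob_fn returns a\n", "# 1-D tensor of log-probabilities, and `beam_step_sketch` / `toy_fn` are hypothetical\n", "# helpers for illustration only. EOS handling and the stopping rule (step 4) are not shown.\n", "import torch\n", "\n", "\n", "def beam_step_sketch(beams, log_prob_fn, beam_width):\n", "    candidates = []\n", "    for score, tokens in beams:\n", "        log_probs = log_prob_fn(tokens)  # (V,) next-token log-probabilities\n", "        k = min(beam_width, log_probs.numel())  # cannot take more than V tokens\n", "        top_lp, top_ids = torch.topk(log_probs, k)\n", "        for lp, tok in zip(top_lp.tolist(), top_ids.tolist()):\n", "            # Log-probabilities add along a path (probabilities multiply).\n", "            candidates.append((score + lp, tokens + [tok]))\n", "    # Prune across ALL candidates at once, not per parent beam.\n", "    candidates.sort(key=lambda c: c[0], reverse=True)\n", "    return candidates[:beam_width]\n", "\n", "\n", "def toy_fn(tokens):\n", "    # Hypothetical toy model over 3 tokens; its distribution ignores the prefix.\n", "    return torch.log(torch.tensor([0.5, 0.3, 0.2]))\n", "\n", "\n", "beams = [(0.0, [0])]\n", "for _ in range(2):\n", "    beams = beam_step_sketch(beams, toy_fn, beam_width=2)\n", "print(beams)" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, 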
"outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('beam_search')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }