{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/32_topk_sampling.ipynb)\n", "\n", "# 🟠 Medium: Top-k / Top-p (Nucleus) Sampling\n", "\n", "Implement **sampling with top-k and top-p filtering**, the standard LLM decoding strategy.\n", "\n", "### Signature\n", "```python\n", "def sample_top_k_top_p(logits, top_k=0, top_p=1.0, temperature=1.0) -> int:\n", "    # logits: (V,) unnormalized log-probabilities\n", "    # Returns: sampled token index\n", "```\n", "\n", "### Algorithm\n", "1. Scale by temperature: `logits = logits / temperature`\n", "2. Top-k (if `top_k > 0`): keep only the `top_k` largest logits, set the rest to `-inf`\n", "3. Top-p (if `top_p < 1.0`): sort by probability (descending); keep the smallest prefix whose cumulative probability reaches `top_p` (always at least one token) and set the rest to `-inf`\n", "4. Sample from the softmax of the filtered logits" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def sample_top_k_top_p(logits, top_k=0, top_p=1.0, temperature=1.0):\n", "    pass  # temperature, top-k filter, top-p filter, sample" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "logits = torch.tensor([1.0, 5.0, 2.0, 0.5])\n", "print('top_k=1:', sample_top_k_top_p(logits.clone(), top_k=1))\n", "print('top_p=0.5:', sample_top_k_top_p(logits.clone(), top_p=0.5))\n", "print('temp=0.01:', sample_top_k_top_p(logits.clone(), temperature=0.01))" ], "execution_count": null }, { "cell_type": "code", 
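"metadata": {}, "outputs": [], "source": [ "# 💡 Reference sketch: ONE possible implementation of the four steps above,\n", "# not the graded solution — use it to compare against your version.\n", "# Assumes 1-D logits, as in the signature; uses the common 'shift right by one'\n", "# trick so the nucleus always keeps at least the top token.\n", "def _reference_sample_top_k_top_p(logits, top_k=0, top_p=1.0, temperature=1.0):\n", "    logits = logits / temperature  # 1. temperature scaling (new tensor, caller's logits untouched)\n", "    if top_k > 0:\n", "        kth = torch.topk(logits, top_k).values[-1]  # 2. k-th largest logit\n", "        logits[logits < kth] = float('-inf')\n", "    if top_p < 1.0:\n", "        sorted_logits, sorted_idx = torch.sort(logits, descending=True)\n", "        cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)\n", "        remove = cum > top_p  # 3. tokens past the nucleus boundary\n", "        remove[1:] = remove[:-1].clone()  # shift right: keep the token that crosses top_p\n", "        remove[0] = False  # never remove the most likely token\n", "        logits[sorted_idx[remove]] = float('-inf')\n", "    probs = torch.softmax(logits, dim=-1)  # 4. sample from the filtered distribution\n", "    return torch.multinomial(probs, 1).item()\n", "\n", "# With top_k=1 only the argmax survives, so sampling is deterministic:\n", "print(_reference_sample_top_k_top_p(torch.tensor([1.0, 5.0, 2.0, 0.5]), top_k=1))  # always 1" ], "execution_count": null }, { "cell_type": "code", 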
"metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('topk_sampling')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }