{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/02_softmax.ipynb)\n", "\n", "# 🟢 Easy: Implement Softmax\n", "\n", "Implement the **Softmax** function from scratch, applied along dimension `dim` (the sum over $j$ runs along that dimension):\n", "\n", "$$\\text{softmax}(x_i) = \\frac{e^{x_i}}{\\sum_j e^{x_j}}$$\n", "\n", "### Signature\n", "```python\n", "def my_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:\n", "    ...\n", "```\n", "\n", "### Rules\n", "- Do **NOT** use `torch.softmax`, `F.softmax`, or `torch.nn.Softmax`\n", "- Must be **numerically stable** (hint: subtract the maximum along `dim` before calling `exp`; the shift cancels between numerator and denominator, so the result is unchanged)\n", "\n", "### Example\n", "```\n", "Input:  tensor([1., 2., 3.])\n", "Output: tensor([0.0900, 0.2447, 0.6652])  # sums to 1.0\n", "```" ], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def my_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:\n", "    pass  # Replace this" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "x = torch.tensor([1.0, 2.0, 3.0])\n", "print(\"Output:\", my_softmax(x, dim=-1))\n", "print(\"Sum:   \", my_softmax(x, dim=-1).sum())  # should be ~1.0\n", "print(\"Ref:   \", torch.softmax(x, dim=-1))" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check(\"softmax\")" ], "execution_count": null } ], 
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }