{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/31_gradient_accumulation.ipynb)\n", "\n", "# 🟢 Easy: Gradient Accumulation\n", "\n", "Implement a **training step with gradient accumulation**, simulating a large batch with limited memory.\n", "\n", "### Signature\n", "```python\n", "def accumulated_step(model, optimizer, loss_fn, micro_batches) -> float:\n", "    # micro_batches: list of (input, target) tuples\n", "    # Returns: average loss (float)\n", "```\n", "\n", "### Algorithm\n", "1. `optimizer.zero_grad()`\n", "2. For each `(x, y)` in micro_batches: `loss = loss_fn(model(x), y) / len(micro_batches)`, then `loss.backward()`\n", "3. `optimizer.step()`\n", "4. Return the sum of the scaled losses (since each was divided by `len(micro_batches)`, this sum equals the average loss the signature promises)\n", "\n", "The key insight: dividing each loss by `n = len(micro_batches)` before calling `backward()` makes the accumulated gradients equal to the gradient of a single large batch."
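,
 "\n",
 "\n",
 "One possible sketch of the algorithm above (a reference for checking your work, not necessarily the official solution; try the exercise cell first):\n",
 "\n",
 "```python\n",
 "def accumulated_step(model, optimizer, loss_fn, micro_batches):\n",
 "    optimizer.zero_grad()                 # clear gradients from the previous step\n",
 "    total = 0.0\n",
 "    n = len(micro_batches)\n",
 "    for x, y in micro_batches:\n",
 "        loss = loss_fn(model(x), y) / n   # scale so accumulated grads match one large batch\n",
 "        loss.backward()                   # gradients add up in each param's .grad\n",
 "        total += loss.item()\n",
 "    optimizer.step()                      # single update using the accumulated gradients\n",
 "    return total                          # sum of scaled losses == average loss\n",
 "```"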
], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✍️ YOUR IMPLEMENTATION HERE\n", "\n", "def accumulated_step(model, optimizer, loss_fn, micro_batches):\n", "    pass  # zero_grad, loop (forward, scale loss, backward), step" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "model = nn.Linear(4, 2)\n", "opt = torch.optim.SGD(model.parameters(), lr=0.01)\n", "loss = accumulated_step(model, opt, nn.MSELoss(),\n", "                        [(torch.randn(2, 4), torch.randn(2, 2)) for _ in range(4)])\n", "print('Loss:', loss)" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('gradient_accumulation')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }