{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/21_gradient_clipping.ipynb)\n", "\n", "# 🟢 Easy: Gradient Norm Clipping\n", "\n", "Implement **gradient norm clipping**, a training stability technique.\n", "\n", "### Signature\n", "```python\n", "def clip_grad_norm(parameters, max_norm: float) -> float:\n", "    # Clip gradients in-place so total norm <= max_norm\n", "    # Returns the original (unclipped) total norm\n", "```\n", "\n", "### Algorithm\n", "1. Compute the total norm: `sqrt(sum(p.grad.norm() ** 2 for p in parameters))`\n", "2. If `total > max_norm`, scale all grads by `max_norm / total`\n", "3. Return the original total norm" ], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def clip_grad_norm(parameters, max_norm):\n", "    pass  # compute total norm, clip if needed, return original norm" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "p = torch.randn(100, requires_grad=True)\n", "(p * 10).sum().backward()\n", "print('Before:', p.grad.norm().item())\n", "orig = clip_grad_norm([p], max_norm=1.0)\n", "print('After: ', p.grad.norm().item())\n", "print('Original norm:', orig)" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n",
"check('gradient_clipping')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }