{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/29_adam.ipynb)\n", "\n", "# 🟠 Medium: Adam Optimizer\n", "\n", "Implement the **Adam** optimizer from scratch.\n", "\n", "### Signature\n", "```python\n", "class MyAdam:\n", " def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8): ...\n", " def step(self): ...\n", " def zero_grad(self): ...\n", "```\n", "\n", "### Algorithm (per parameter)\n", "```\n", "t += 1 # step count; the first step() uses t = 1\n", "m = β1 * m + (1-β1) * grad\n", "v = β2 * v + (1-β2) * grad²\n", "m̂ = m / (1 - β1ᵗ) # bias correction\n", "v̂ = v / (1 - β2ᵗ)\n", "p -= lr * m̂ / (√v̂ + ε)\n", "```\n", "\n", "### Notes\n", "\n", "- `(β1, β2) = betas`. `m` and `v` start as zero tensors shaped like each parameter, and `t` starts at 0 so the first `step()` uses `t = 1` (with `t = 0` the bias correction would divide by zero).\n", "- Skip parameters whose `.grad` is `None`, as `torch.optim.Adam` does.\n", "- Apply the update under `torch.no_grad()`: an in-place update of a leaf tensor that requires grad raises an error otherwise.\n" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "class MyAdam:\n", " def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):\n", " pass # store params, init m and v to zeros\n", "\n", " def step(self):\n", " pass # update params using Adam rule\n", "\n", " def zero_grad(self):\n", " pass # zero all gradients" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "torch.manual_seed(0)\n", "w = torch.randn(4, 3, requires_grad=True)\n", "opt = MyAdam([w], lr=0.01)\n", "for i in range(5):\n", " loss = (w ** 2).sum()\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", " print(f'Step {i}: loss={loss.item():.4f}')" ], "execution_count": null }, { "cell_type": "code", 
"metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('adam')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }