{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/30_cosine_lr.ipynb)\n", "\n", "# 🟠 Medium: Cosine LR Scheduler with Warmup\n", "\n", "Implement a **cosine learning rate schedule** with linear warmup: given the current training `step`, return the learning rate to use as a `float`.\n", "\n", "### Signature\n", "```python\n", "def cosine_lr_schedule(step, total_steps, warmup_steps, max_lr, min_lr=0.0) -> float:\n", "```\n", "\n", "### Schedule\n", "```\n", "step < warmup_steps:  lr = max_lr * step / warmup_steps (linear ramp)\n", "step >= warmup_steps: lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(π * progress)) (cosine decay)\n", "```\n", "where `progress = (step - warmup_steps) / (total_steps - warmup_steps)`.\n", "\n", "The warmup ramp starts from `0` (not `min_lr`) at `step = 0`. At `step = warmup_steps`, `progress = 0` and `lr = max_lr`; at `step = total_steps`, `progress = 1` and `lr = min_lr`." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import math" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def cosine_lr_schedule(step, total_steps, warmup_steps, max_lr, min_lr=0.0):\n", " pass # warmup then cosine decay" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "lrs = [cosine_lr_schedule(i, 100, 10, 0.001) for i in range(101)]\n", "print(f'Start: {lrs[0]:.6f}')\n", "print(f'Warmup end: {lrs[10]:.6f}')\n", "print(f'Mid: {lrs[55]:.6f}')\n", "print(f'End: {lrs[100]:.6f}')" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('cosine_lr')" ], "execution_count": null } ], "metadata": { "kernelspec": { 
"display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }