{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "968cc37c",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/38_grpo_loss.ipynb)\n",
    "\n",
    "# 🔴 Hard: GRPO Loss\n",
    "\n",
    "Implement the **Group Relative Policy Optimization (GRPO)** loss: a group-wise, baseline-subtracted REINFORCE objective commonly used in RLAIF (reinforcement learning from AI feedback).\n",
    "\n",
    "Given a batch of log-probabilities, scalar rewards, and group ids (one group per prompt), define the within-group normalized advantages:\n",
    "\n",
    "$$A_i = \\frac{r_i - \\bar r_{g(i)}}{\\text{std}_{g(i)} + \\epsilon}$$\n",
    "\n",
    "where \\(\\bar r_{g(i)}\\) and \\(\\text{std}_{g(i)}\\) are the mean and standard deviation of the rewards in the group of example \\(i\\).\n",
    "\n",
    "The GRPO loss is then the negative advantage-weighted log-probability, with the advantages treated as constants (no gradient flows through them):\n",
    "\n",
    "$$\\mathcal{L}_{\\text{GRPO}} = -\\mathbb{E}_i \\big[\\,\\text{stop\\_grad}(A_i)\\, \\log \\pi_\\theta(y_i)\\big].$$\n",
    "\n",
    "### Signature\n",
    "```python\n",
    "from torch import Tensor\n",
    "\n",
    "def grpo_loss(logps: Tensor, rewards: Tensor, group_ids: Tensor,\n",
    "              eps: float = 1e-5) -> Tensor:\n",
    "    \"\"\"GRPO loss over a batch.\n",
    "\n",
    "    logps:     (B,) policy log-probs for each sampled response\n",
    "    rewards:   (B,) scalar rewards for each response\n",
    "    group_ids: (B,) integers, same id = same prompt/group\n",
    "    returns:   scalar loss (Tensor)\n",
    "    \"\"\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a7c51f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
    "try:\n",
    "    import google.colab  # noqa: F401  (present only on Colab)\n",
    "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
    "except ImportError:\n",
    "    pass\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1038dfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68d0bd84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ✍️ YOUR IMPLEMENTATION HERE\n",
    "\n",
    "from torch import Tensor\n",
    "\n",
    "def grpo_loss(logps: Tensor, rewards: Tensor, group_ids: Tensor,\n",
    "              eps: float = 1e-5) -> Tensor:\n",
    "    # Compute the normalized advantages per group, then\n",
    "    # return -mean(adv.detach() * logps).\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb215c40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 🧪 Debug\n",
    "logps = torch.tensor([0.0, -0.5, -1.0, -1.5])\n",
    "rewards = torch.tensor([1.0, 0.8, 0.2, 0.0])\n",
    "group_ids = torch.tensor([0, 0, 1, 1])\n",
    "print('Loss:', grpo_loss(logps, rewards, group_ids).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95b2e29e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ✅ SUBMIT\n",
    "from torch_judge import check\n",
    "check('grpo_loss')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}