{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/37_dpo_loss.ipynb)\n", "\n", "# 🔴 Hard: DPO Loss\n", "\n", "Implement the **Direct Preference Optimization** loss — the standard loss for LLM alignment.\n", "\n", "$$\\mathcal{L}_{\\text{DPO}} = -\\log \\sigma\\Big(\\beta \\big[\\log\\frac{\\pi(y_w)}{\\text{ref}(y_w)} - \\log\\frac{\\pi(y_l)}{\\text{ref}(y_l)}\\big]\\Big)$$\n", "\n", "### Signature\n", "```python\n", "def dpo_loss(policy_chosen_logps, policy_rejected_logps,\n", " ref_chosen_logps, ref_rejected_logps, beta=0.1) -> Tensor:\n", " # All inputs: (B,) log-probabilities\n", " # Returns: scalar loss\n", "```" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def dpo_loss(policy_chosen_logps, policy_rejected_logps,\n", " ref_chosen_logps, ref_rejected_logps, beta=0.1):\n", " pass # -log(sigmoid(beta * (chosen_reward - rejected_reward)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "chosen = torch.tensor([0.0, 0.0])\n", "rejected = torch.tensor([-5.0, -5.0])\n", "ref_c = torch.tensor([-1.0, -1.0])\n", "ref_r = torch.tensor([-1.0, -1.0])\n", "print('Loss:', dpo_loss(chosen, rejected, ref_c, ref_r, beta=0.1).item())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {},
"outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('dpo_loss')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 4 }