{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "1caaee84-f043-42a1-b270-f1163efa290d", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "" ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# code for loading notebook's format\n", "import os\n", "\n", "# path : store the current path to convert back to it later\n", "path = os.getcwd()\n", "os.chdir(os.path.join('..', '..', '..', 'notebook_format'))\n", "\n", "from formats import load_style\n", "load_style(css_style='custom2.css', plot_style=False)" ] }, { "cell_type": "code", "execution_count": 2, "id": "c83efd8a-fbfc-44e4-af03-eb8d772aeb60", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Author: Ethen\n", "\n", "Last updated: 2024-10-12\n", "\n", "Python implementation: CPython\n", "Python version : 3.10.14\n", "IPython version : 8.26.0\n", "\n", "pytorch_lightning: 2.1.4\n", "transformers : 4.41.1\n", "datasets : 3.0.0\n", "torch : 2.1.2+cu121\n", "\n" ] } ], "source": [ "os.chdir(path)\n", "\n", "%load_ext watermark\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "import transformers\n", "import pytorch_lightning as pl\n", "from dataclasses import dataclass\n", "from torch.utils.data import DataLoader\n", "from transformers import (\n", " AutoTokenizer,\n", " AutoModelForCausalLM,\n", " GenerationConfig,\n", " PreTrainedTokenizerBase,\n", ")\n", "from typing import Any, Dict, List, Optional, Union\n", "from transformers.data.data_collator import DataCollatorMixin\n", "from transformers.utils import PaddingStrategy\n", "\n", "%watermark -a 'Ethen' -d -v -u -p pytorch_lightning,transformers,datasets,torch" ] }, { "cell_type": "markdown", "id": "59d2c9e4-fc9c-4c42-8c49-0359c9f0564a", "metadata": {}, "source": [ "# LLM Pairwise Judge" ] }, { "cell_type": "markdown", "id": "d2280d18-4ca4-479e-aedb-dd1d58672ccb", "metadata": {}, "source": [ "In this article, we'll be implementing a LLM pairwise judge, where a LLM is presented with a question and two answers, and tasked with determining which answer is better or declaring a tie. Using LLMs as judges for evaluation offers several benefits:\n", "\n", "- Scalability: Compared to obtaining ground truth labels from human evaluators, LLM inference is generally faster and more cost-effective.\n", "- Explainability: Unlike metrics such as BLEU or ROUGE, which primarily focus on variants of text overlap or re-ranker based relevance model, LLMs can also generate reasoning or explanations along with scores, providing more interpretable evaluations.\n", "- Versatility: LLMs can be fine-tuned or adapted to judge outputs across various domains, tasks, and languages, offering a versatile evaluation framework. This makes LLMs more suitable for evaluating diverse instruction-following and conversational abilities.\n", "\n", "Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena [[5]](https://arxiv.org/abs/2306.05685) provides a thorough examination of using LLMs as judges. The researchers curated two distinct benchmark suites for this purpose:\n", "\n", "- MT-Bench: A benchmark consisting of 80 high-quality multi-turn questions.\n", "- Chatbot Arena: An innovative crowdsourcing benchmark platform featuring anonymous battles. On this platform, users can interact with two anonymous chatbot models simultaneously by posing the same question to both. 
They then vote for the model that provides their preferred response, with the models' identities revealed only after the voting process. Unlike traditional benchmarks that rely on predefined questions, Chatbot Arena enables users to ask any question they desire, effectively capturing a wide range of evaluations \"in the wild\".\n", "\n", "Using a state-of-the-art LLM, GPT-4, as the judge, they verify that it can match human evaluations at an agreement rate exceeding 80%." ] }, { "cell_type": "markdown", "id": "7e97d8e4-bf34-443d-bde6-3311a3653399", "metadata": {}, "source": [ "## LLM Generation" ] }, { "cell_type": "markdown", "id": "79624461-eb27-4568-be85-0b83475fb4ec", "metadata": {}, "source": [ "We'll first implement a generation module for producing responses from an LLM. We use the Qwen 2.5 Collection [[2]](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e) in this article; feel free to pick your favorite LLM. While doing so, be sure to set the correct padding token and padding side, as well as configure max_new_tokens [[1]](https://huggingface.co/docs/transformers/llm_tutorial); a short sketch of these settings follows the list below.\n", "\n", "- Huggingface's generate function returns at most 20 new tokens by default if `max_new_tokens` is not explicitly specified in `GenerationConfig`.\n", "- LLMs (decoder only models) also return the input prompt as part of the output by default. We'll need some post-processing to crop those input prompts out if that is not the desired behaviour.\n", "- As with other tasks, when operating on a batch of inputs, prompts of varying lengths need to be padded to ensure a consistent length. Since LLMs often don't have a default pad token and are not trained to continue from pad tokens, be sure to assign a pad token (e.g. the eos token) and left pad our inputs."
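, "\n", "\n", "A minimal sketch of these points on a single prompt, assuming the same `Qwen/Qwen2.5-1.5B-Instruct` checkpoint used throughout this notebook; the data collator and lightning module below wrap the same ideas for batched prediction.\n", "\n", "```python\n", "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-1.5B-Instruct\", padding_side=\"left\")\n", "# assign a pad token (reuse eos) so batched prompts can be left padded\n", "tokenizer.pad_token = tokenizer.eos_token\n", "model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen2.5-1.5B-Instruct\")\n", "\n", "inputs = tokenizer([\"What is the capital of France?\"], padding=True, return_tensors=\"pt\")\n", "# without max_new_tokens, generate stops after 20 new tokens by default\n", "output = model.generate(**inputs, max_new_tokens=50)\n", "# crop the echoed input prompt so we decode only the newly generated tokens\n", "response = tokenizer.batch_decode(output[:, inputs[\"input_ids\"].shape[-1]:], skip_special_tokens=True)\n", "```"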
] }, { "cell_type": "code", "execution_count": 3, "id": "5dd8fbe4-f81f-4162-be7b-5b855ceb6dec", "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class DataCollatorForGeneration(DataCollatorMixin):\n", " \"\"\"\n", " tokenize raw text (prompt) as well as padding while forming a batch for data loader.\n", " \"\"\"\n", "\n", " tokenizer: PreTrainedTokenizerBase\n", " max_seq_len: int = 512\n", " padding: Union[bool, str, PaddingStrategy] = True\n", " return_tensors: str = \"pt\"\n", " prompt_col_name: str = \"prompt\"\n", "\n", " def __post_init__(self):\n", " self.tokenizer.padding_side = \"left\"\n", " self.tokenizer.pad_token = self.tokenizer.eos_token\n", "\n", " def __call__(\n", " self, features: List[Dict[str, Any]], return_tensors=None\n", " ) -> Dict[str, Any]:\n", "\n", " prompts = [feature[self.prompt_col_name] for feature in features]\n", " tokenized_text = self.tokenizer(\n", " prompts,\n", " padding=self.padding,\n", " max_length=self.max_seq_len,\n", " truncation=True,\n", " return_attention_mask=True,\n", " return_tensors=self.return_tensors,\n", " )\n", "\n", " batch = {\n", " \"prompts\": prompts,\n", " \"input_ids\": tokenized_text[\"input_ids\"],\n", " \"attention_mask\": tokenized_text[\"attention_mask\"],\n", " }\n", " return batch" ] }, { "cell_type": "code", "execution_count": 4, "id": "903935b5-8f75-4abb-b88c-42d2cf690d2a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] }, { "data": { "text/plain": [ "{'prompts': ['What is the capital of France?',\n", " 'What is the biggest planet in the solar system?'],\n", " 'input_ids': tensor([[151645, 151645, 151645, 3838, 374, 279, 6722, 315, 9625,\n", " 30],\n", " [ 3838, 374, 279, 8538, 11580, 304, 279, 12941, 1849,\n", " 30]]),\n", " 'attention_mask': tensor([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1],\n", " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "examples = [{\"prompt\": \"What is the capital of France?\"}, {\"prompt\": \"What is the biggest planet in the solar system?\"}]\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-1.5B-Instruct\")\n", "data_collator = DataCollatorForGeneration(tokenizer)\n", "data_loader = DataLoader(examples, batch_size=2, num_workers=2, collate_fn=data_collator)\n", "batch = next(iter(data_loader))\n", "batch" ] }, { "cell_type": "code", "execution_count": 5, "id": "803286c3-4d5a-4e5f-897f-e957df95c0f4", "metadata": {}, "outputs": [], "source": [ "class LLMGenerateLightningModule(pl.LightningModule):\n", " \"\"\"\n", " Generate responses from LLM. 
Expects input prompts, tokenized input_ids, attention_mask\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " pretrained_model_name_or_path,\n", " generation_config,\n", " prediction_config,\n", " cache_dir=\"/data\",\n", " ):\n", " super().__init__()\n", " self.model = AutoModelForCausalLM.from_pretrained(\n", " pretrained_model_name_or_path, cache_dir=cache_dir\n", " )\n", "\n", " self.tokenizer = AutoTokenizer.from_pretrained(\n", " pretrained_model_name_or_path, padding_side=\"left\", cache_dir=cache_dir\n", " )\n", " self.tokenizer.pad_token = self.tokenizer.eos_token\n", "\n", " self.generation_config = generation_config\n", " self._setup_prediction(prediction_config)\n", "\n", " def predict_step(self, batch, batch_idx, dataloader_idx=None):\n", " prompts = batch[\"prompts\"]\n", " input_ids = batch[\"input_ids\"]\n", " attention_mask = batch[\"attention_mask\"]\n", "\n", " responses = self.generate(input_ids, attention_mask)\n", "\n", " prediction_output = {\n", " \"prompts\": prompts,\n", " \"responses\": responses,\n", " }\n", " self.prediction_outputs.append(prediction_output)\n", " return prediction_output\n", "\n", " def generate(self, input_ids, attention_mask):\n", " model_output = self.model.generate(\n", " input_ids,\n", " attention_mask=attention_mask,\n", " generation_config=self.generation_config\n", " )\n", " # crop input prompt from generated response\n", " input_seq_length = input_ids.shape[-1]\n", " model_output_answer_only = model_output[:, input_seq_length:]\n", " responses = self.tokenizer.batch_decode(model_output_answer_only, skip_special_tokens=True)\n", " return responses\n", "\n", " def _setup_prediction(self, prediction_config):\n", " if prediction_config:\n", " self.prediction_outputs = []\n", " self._prediction_partition_idx = 0\n", " self.prediction_partition_format = prediction_config[\"prediction_partition_format\"]\n", " self.prediction_output_path = prediction_config[\"prediction_output_path\"]\n", " self.prediction_accumulation_steps = prediction_config.get(\"prediction_accumulation_steps\", 10)\n", "\n", " def _save_prediction_outputs(self):\n", " if self.prediction_output_path:\n", " data = {field: [] for field in self.prediction_outputs[0]}\n", " for prediction_output in self.prediction_outputs:\n", " for field in data:\n", " data[field].extend(prediction_output[field])\n", "\n", " partition_file_name = self.prediction_partition_format.format(\n", " rank=self.global_rank, partition=self._prediction_partition_idx\n", " )\n", " formatted_output_path = os.path.join(\n", " self.prediction_output_path, partition_file_name\n", " )\n", "\n", " # saves prediction batch locally via pandas data frame\n", " df_prediction_outputs = pd.DataFrame.from_dict(data)\n", " os.makedirs(self.prediction_output_path, exist_ok=True)\n", " df_prediction_outputs.to_parquet(formatted_output_path, index=False)\n", "\n", " self._prediction_partition_idx += 1\n", " self.prediction_outputs.clear()\n", "\n", " def on_predict_batch_end(self, outputs, batch, batch_idx):\n", " if len(self.prediction_outputs) == self.prediction_accumulation_steps:\n", " self._save_prediction_outputs()\n", "\n", " def on_predict_epoch_end(self):\n", " if len(self.prediction_outputs) > 0:\n", " self._save_prediction_outputs()" ] }, { "cell_type": "code", "execution_count": 6, "id": "9d786955-e5f4-438a-b978-f54beef409c1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings 
are fine-tuned or trained.\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Predicting DataLoader 0: 100%|██████████| 1/1 [00:10<00:00, 0.09it/s]\n" ] }, { "data": { "text/plain": [ "[{'prompts': ['What is the capital of France?',\n", " 'What is the biggest planet in the solar system?'],\n", " 'responses': [' The capital of France is Paris. It is located in the north of the country and is the largest city in France. Paris is known for its famous landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. It is also the political, cultural, and financial center of France. Paris is home to many museums, theaters, and other cultural institutions, and is known for its fashion, cuisine, and nightlife. The city is also famous for its art, music, and literature, and is home to many famous artists and writers. Paris is also known for its beautiful parks, including the Luxembourg Gardens and the Bois de Boulogne. The city is also home to many universities and research institutions, including the Sorbonne University and the École Normale Supérieure. Paris is also known for its fashion, cuisine, and nightlife, and is home to many famous artists and writers. The city is also known for its beautiful parks, including the Luxembourg Gardens and the Bois de Boulogne. The city is also home to many universities and research institutions, including the Sorbonne University and the École Normale Supérieure. Paris is also known for its fashion, cuisine, and nightlife',\n", " ' The biggest planet in the solar system is Jupiter. It is a gas giant planet with a diameter of about 86,881 miles (139,822 kilometers) and a mass of about 1.90 x 10^27 kilograms. Jupiter is also the fifth planet from the sun and is the largest planet in the solar system. It is composed mostly of hydrogen and helium, with a small amount of other elements. Jupiter has a strong magnetic field and a ring system, and it has at least 79 known moons. It is also known for its Great Red Spot, a giant storm that has been raging on Jupiter for at least 400 years. Jupiter is also known for its powerful winds, which can reach speeds of up to 430 miles per hour (690 kilometers per hour). Jupiter is also known for its powerful winds, which can reach speeds of up to 430 miles per hour (690 kilometers per hour). Jupiter is also known for its powerful winds, which can reach speeds of up to 430 miles per hour (690 kilometers per hour). Jupiter is also known for its powerful winds, which can reach speeds']}]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generation_config = GenerationConfig(\n", " max_new_tokens=250\n", ")\n", "llm_generate_module = LLMGenerateLightningModule(\n", " pretrained_model_name_or_path=\"Qwen/Qwen2.5-1.5B-Instruct\",\n", " generation_config=generation_config,\n", " prediction_config={\n", " \"prediction_output_path\": \"prediction\",\n", " \"prediction_partition_format\": \"rank-{rank:02d}-partition-{partition:06d}.parquet\"\n", " }\n", ")\n", "trainer = pl.Trainer()\n", "prediction_output = trainer.predict(llm_generate_module, data_loader)\n", "prediction_output" ] }, { "cell_type": "code", "execution_count": 7, "id": "7b60de75-7bd1-4203-b1cb-bd69a94cfed7", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
promptsresponses
0What is the capital of France?The capital of France is Paris. It is located...
1What is the biggest planet in the solar system?The biggest planet in the solar system is Jup...
\n", "
" ], "text/plain": [ " prompts \\\n", "0 What is the capital of France? \n", "1 What is the biggest planet in the solar system? \n", "\n", " responses \n", "0 The capital of France is Paris. It is located... \n", "1 The biggest planet in the solar system is Jup... " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_prediction_output = pd.read_parquet(\"prediction\")\n", "df_prediction_output" ] }, { "cell_type": "markdown", "id": "a14775ac-11e1-467c-8c11-038bf8d6b5eb", "metadata": {}, "source": [ "## LLM Pairwise Judge" ] }, { "cell_type": "markdown", "id": "6db36da8-8004-4ec8-b6a1-c68acd9a47b6", "metadata": {}, "source": [ "The pairwise judge's implementation (prompt used) is inspired by huggingface's [HfPairwiseJudge](https://huggingface.co/docs/trl/main/en/judges#trl.HfPairwiseJudge). At the time of writing this, its backend relies on their own inference client which has poses some restriction on the model size free tier users are allowed to use.\n", "\n", "Our judge will also make an attempt to handle position bias. Position bias is when an LLM exhibits a propensity to favor certain positions over others, regardless of the actual content or quality of the answers. A conservative approach for addressing this issue is to call the judge twice, swapping the two answers' order, and only declare a win when an answer is preferred in both orders. If results are inconsistent after swapping, a tie can be declared. A more aggressive approach is to assign positions randomly, which can be effective at a large scale with the correct expectations. In the following experiments, we use the conservative approach." ] }, { "cell_type": "code", "execution_count": 8, "id": "21407936-e62f-4f0c-8692-fc2e4d91020d", "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class DataCollatorForPairwiseJudge(DataCollatorMixin):\n", " \"\"\"\n", " tokenize raw text (prompt) as well as padding while forming a batch for data loader.\n", "\n", " Parameters\n", " ----------\n", " system_prompt :\n", " System prompt to be used for the judge. If not provided, a default prompt is used.\n", " System prompt should contain following placeholders: `{prompt}`, `{response1}`, and `{response2}`.\n", " \"\"\"\n", "\n", " default_system_prompt = '''\n", " I require a leaderboard for various large language models. I'll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.\n", "\n", " Instruction: {prompt}\n", "\n", " Model Outputs: Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.\n", "\n", " \"model_identifier\": \"1\", \"output\": \"\"\"{response1}\"\"\" \"model_identifier\": \"2\", \"output\": \"\"\"{response2}\"\"\"\n", "\n", " Task Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. 
Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...).\n", " '''\n", "\n", " tokenizer: PreTrainedTokenizerBase\n", " max_seq_len: int = 1024\n", " padding: Union[bool, str, PaddingStrategy] = True\n", " return_tensors: str = \"pt\"\n", " prompt_col_name: str = \"prompts\"\n", " response1_col_name: str = \"responses1\"\n", " response2_col_name: str = \"responses2\"\n", " system_prompt: Optional[str] = None\n", "\n", " def __post_init__(self):\n", " self.tokenizer.padding_side = \"left\"\n", " self.tokenizer.pad_token = self.tokenizer.eos_token\n", "\n", " self.system_prompt = self.system_prompt if self.system_prompt is not None else self.default_system_prompt\n", "\n", " def __call__(\n", " self, features: List[Dict[str, Any]], return_tensors=None\n", " ) -> Dict[str, Any]:\n", "\n", " judge_prompts = []\n", " judge_swapped_position_prompts = []\n", " for feature in features:\n", " prompt = feature[self.prompt_col_name]\n", " response1 = feature[self.response1_col_name]\n", " response2 = feature[self.response2_col_name]\n", " judge_prompt = self.system_prompt.format(\n", " prompt=prompt, response1=response1, response2=response2\n", " )\n", " judge_swapped_position_prompt = self.system_prompt.format(\n", " prompt=prompt, response1=response2, response2=response1\n", " )\n", " judge_prompts.append(judge_prompt)\n", " judge_swapped_position_prompts.append(judge_swapped_position_prompt)\n", "\n", " tokenized_text = self.tokenizer(\n", " judge_prompts,\n", " padding=self.padding,\n", " max_length=self.max_seq_len,\n", " truncation=True,\n", " return_attention_mask=True,\n", " return_tensors=self.return_tensors,\n", " )\n", "\n", " tokenized_swapped_position_text = self.tokenizer(\n", " judge_swapped_position_prompts,\n", " padding=self.padding,\n", " max_length=self.max_seq_len,\n", " truncation=True,\n", " return_attention_mask=True,\n", " return_tensors=self.return_tensors,\n", " )\n", "\n", " batch = {\n", " \"prompts\": judge_prompts,\n", " \"input_ids\": tokenized_text[\"input_ids\"],\n", " \"attention_mask\": tokenized_text[\"attention_mask\"],\n", " \"input_ids_swapped_position\": tokenized_swapped_position_text[\"input_ids\"],\n", " \"attention_mask_swapped_position\": tokenized_swapped_position_text[\"attention_mask\"],\n", " }\n", " return batch" ] }, { "cell_type": "code", "execution_count": 9, "id": "d51c561f-802e-4a80-9a28-0f4ca33006a8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['prompts', 'input_ids', 'attention_mask', 'input_ids_swapped_position', 'attention_mask_swapped_position'])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "examples = [\n", " {\"prompts\": \"What is the capital of France?\", \"responses1\": \"Paris\", \"responses2\": \"Taipei\"},\n", " {\"prompts\": \"What is the biggest planet in the solar system?\", \"responses1\": \"Saturn\", \"responses2\": \"Jupiter\"}\n", "]\n", "data_collator = DataCollatorForPairwiseJudge(tokenizer)\n", "data_loader = DataLoader(examples, batch_size=2, num_workers=2, collate_fn=data_collator)\n", "batch = next(iter(data_loader))\n", "batch.keys()" ] }, { "cell_type": "code", "execution_count": 10, "id": "bdcb82aa-852f-460b-ab9e-1924b85424e7", "metadata": {}, "outputs": [], "source": [ "class PairwiseLLMJudgeLightningModule(pl.LightningModule):\n", "\n", " def __init__(\n", " self,\n", " 
pretrained_model_name_or_path,\n", " generation_config,\n", " prediction_config,\n", " cache_dir=\"/data\",\n", " ):\n", " super().__init__()\n", " self.model = AutoModelForCausalLM.from_pretrained(\n", " pretrained_model_name_or_path, cache_dir=cache_dir\n", " )\n", "\n", " self.tokenizer = AutoTokenizer.from_pretrained(\n", " pretrained_model_name_or_path, padding_side=\"left\", cache_dir=cache_dir\n", " )\n", " self.tokenizer.pad_token = self.tokenizer.eos_token\n", "\n", " self.generation_config = generation_config\n", " self._setup_prediction(prediction_config)\n", "\n", " def predict_step(self, batch, batch_idx, dataloader_idx=None):\n", " prompts = batch[\"prompts\"]\n", "\n", " input_ids = batch[\"input_ids\"]\n", " attention_mask = batch[\"attention_mask\"]\n", " responses = self.generate(input_ids, attention_mask)\n", "\n", " input_ids_swapped_position = batch[\"input_ids_swapped_position\"]\n", " attention_mask_swapped_position = batch[\"attention_mask_swapped_position\"]\n", " responses_swapped_position = self.generate(input_ids_swapped_position, attention_mask_swapped_position)\n", " \n", " prediction_output = {\n", " \"prompts\": prompts,\n", " \"responses\": responses,\n", " \"responses_swapped_position\": responses_swapped_position,\n", " }\n", " self.prediction_outputs.append(prediction_output)\n", " return prediction_output\n", "\n", " def generate(self, input_ids, attention_mask):\n", " model_output = self.model.generate(\n", " input_ids,\n", " attention_mask=attention_mask,\n", " generation_config=self.generation_config\n", " )\n", " # crop input prompt from generated response\n", " input_seq_length = input_ids.shape[-1]\n", " model_output_answer_only = model_output[:, input_seq_length:]\n", " responses = self.tokenizer.batch_decode(model_output_answer_only, skip_special_tokens=True)\n", " return responses\n", "\n", " def _setup_prediction(self, prediction_config):\n", " if prediction_config:\n", " self.prediction_outputs = []\n", " self._prediction_partition_idx = 0\n", " self.prediction_partition_format = prediction_config[\"prediction_partition_format\"]\n", " self.prediction_output_path = prediction_config[\"prediction_output_path\"]\n", " self.prediction_accumulation_steps = prediction_config.get(\"prediction_accumulation_steps\", 10)\n", "\n", " def _save_prediction_outputs(self):\n", " if self.prediction_output_path:\n", " data = {field: [] for field in self.prediction_outputs[0]}\n", " for prediction_output in self.prediction_outputs:\n", " for field in data:\n", " data[field].extend(prediction_output[field])\n", "\n", " partition_file_name = self.prediction_partition_format.format(\n", " rank=self.global_rank, partition=self._prediction_partition_idx\n", " )\n", " formatted_output_path = os.path.join(\n", " self.prediction_output_path, partition_file_name\n", " )\n", "\n", " # saves prediction batch locally via pandas data frame\n", " df_prediction_outputs = pd.DataFrame.from_dict(data)\n", " os.makedirs(self.prediction_output_path, exist_ok=True)\n", " df_prediction_outputs.to_parquet(formatted_output_path, index=False)\n", "\n", " self._prediction_partition_idx += 1\n", " self.prediction_outputs.clear()\n", "\n", " def on_predict_batch_end(self, outputs, batch, batch_idx):\n", " if len(self.prediction_outputs) == self.prediction_accumulation_steps:\n", " self._save_prediction_outputs()\n", "\n", " def on_predict_epoch_end(self):\n", " if len(self.prediction_outputs) > 0:\n", " self._save_prediction_outputs()" ] }, { "cell_type": "code", "execution_count": 11, 
"id": "bf548610-8ac0-4d78-af27-95ee6abb5911", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.31s/it]\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 1.83it/s]\n" ] }, { "data": { "text/plain": [ "[{'prompts': ['\\n I require a leaderboard for various large language models. I\\'ll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.\\n\\n Instruction: What is the capital of France?\\n\\n Model Outputs: Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.\\n\\n \"model_identifier\": \"1\", \"output\": \"\"\"Paris\"\"\" \"model_identifier\": \"2\", \"output\": \"\"\"Taipei\"\"\"\\n\\n Task Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...).\\n ',\n", " '\\n I require a leaderboard for various large language models. I\\'ll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.\\n\\n Instruction: What is the biggest planet in the solar system?\\n\\n Model Outputs: Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.\\n\\n \"model_identifier\": \"1\", \"output\": \"\"\"Saturn\"\"\" \"model_identifier\": \"2\", \"output\": \"\"\"Jupiter\"\"\"\\n\\n Task Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. 
Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...).\\n '],\n", " 'responses': [' 1', ' 2'],\n", " 'responses_swapped_position': [' 2', ' 1']}]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generation_config = GenerationConfig(\n", " max_new_tokens=2,\n", ")\n", "pairwise_judge_module = PairwiseLLMJudgeLightningModule(\n", " pretrained_model_name_or_path=\"Qwen/Qwen2.5-3B-Instruct\",\n", " generation_config=generation_config,\n", " prediction_config={\n", " \"prediction_output_path\": \"judge\",\n", " \"prediction_partition_format\": \"rank-{rank:02d}-partition-{partition:06d}.parquet\"\n", " }\n", ")\n", "trainer = pl.Trainer()\n", "prediction_output = trainer.predict(pairwise_judge_module, data_loader)\n", "prediction_output" ] }, { "cell_type": "code", "execution_count": 12, "id": "dc3ebe32-0fa4-4c73-85af-7b1faab9d70b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
promptsresponsesresponses_swapped_position
0\\n I require a leaderboard for various larg...12
1\\n I require a leaderboard for various larg...21
\n", "
" ], "text/plain": [ " prompts responses \\\n", "0 \\n I require a leaderboard for various larg... 1 \n", "1 \\n I require a leaderboard for various larg... 2 \n", "\n", " responses_swapped_position \n", "0 2 \n", "1 1 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_prediction_output = pd.read_parquet(\"judge\")\n", "df_prediction_output" ] }, { "cell_type": "markdown", "id": "e8c82daf-20b6-449d-80df-cfa7292f1c6f", "metadata": {}, "source": [ "# Reference" ] }, { "cell_type": "markdown", "id": "1927f72c-9036-4f6b-b372-9bbe018fc964", "metadata": {}, "source": [ "- [[1]](https://huggingface.co/docs/transformers/llm_tutorial) Huggingface Documentation: Generation with LLMs\n", "- [[2]](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e) Huggingface Space Qwen 2.5 Collection\n", "- [[3]](https://huggingface.co/learn/cookbook/en/llm_judge) Using LLM-as-a-judge 🧑‍⚖️ for an automated and versatile evaluation\n", "- [[4]](https://www.databricks.com/blog/LLM-auto-eval-best-practices-RAG) Best Practices for LLM Evaluation of RAG Applications\n", "- [[5]](https://arxiv.org/abs/2306.05685) Lianmin Zheng, Wei-Lin Chiang, Ying Sheng et al. - Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena (2023)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }