{ "cells": [ { "cell_type": "markdown", "id": "d5dd7679-5ed6-4243-8e0f-549f8118bff7", "metadata": {}, "source": [ "# How to test and evaluate LLMs for SQL generation\n", "\n", "LLMs are fundamentally non-deterministic: the same prompt can produce different responses from one call to the next. This attribute makes them wonderfully creative and dynamic, but it also poses significant challenges in achieving consistency, a crucial requirement for integrating LLMs into production environments.\n", "\n", "The key to harnessing the potential of LLMs in practical applications lies in consistent and systematic evaluation. This enables the identification and rectification of inconsistencies and helps with monitoring progress over time as the application evolves.\n", "\n", "## Scope of this notebook\n", "\n", "This notebook aims to demonstrate a framework for evaluating LLMs, particularly focusing on:\n", "\n", "* **Unit Testing:** Essential for assessing individual components of the application.\n", "* **Evaluation Metrics:** Methods to quantitatively measure the model's effectiveness.\n", "* **Runbook Documentation:** A record of historical evaluations to track progress and regression.\n", "\n", "This example focuses on a natural language to SQL use case. Code generation use cases fit this approach well when you combine **code validation** with **code execution**, so your application can actually run the generated code to check it for consistency as it is produced.\n", "\n", "Although this notebook uses a SQL generation use case to demonstrate the concept, the approach is generic and can be applied to a wide variety of LLM-driven applications.\n", "\n", "We will use two versions of a prompt to perform SQL generation, and then use the unit tests and evaluation functions to compare the performance of the two prompts. Specifically, in this demonstration, we will evaluate:\n", "\n", "1. The consistency of the JSON response.\n", "2. The syntactic correctness of the SQL in the response.\n", "3. The relevance of the generated SQL to the user's question, scored by an LLM.\n", "\n", "\n", "## Table of contents\n", "\n", "1. 
**[Setup](#Setup):** Install required libraries, download data consisting of SQL queries and corresponding natural language translations.\n", "2. **[Test Development](#Test-development):** Create unit tests and define evaluation metrics for the SQL generation process.\n", "3. **[Evaluation](#Evaluation):** Conduct tests using different prompts to assess the impact on performance.\n", "4. **[Reporting](#Report):** Compile a report that succinctly presents the performance differences observed across various tests." ] }, { "cell_type": "markdown", "id": "2913d615", "metadata": {}, "source": [ "## Setup\n", "\n", "Import our libraries and the dataset we'll use, which is the natural language to SQL [b-mc2/sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context) dataset from HuggingFace." ] }, { "cell_type": "code", "execution_count": 1, "id": "c7f325fc", "metadata": {}, "outputs": [], "source": [ "# Uncomment this to install all necessary dependencies\n", "# !pip install openai datasets pandas pydantic matplotlib python-dotenv numpy tqdm" ] }, { "cell_type": "code", "execution_count": 2, "id": "245fcedb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "78577 rows\n" ] } ], "source": [ "from datasets import load_dataset\n", "from openai import OpenAI\n", "import pandas as pd\n", "import pydantic\n", "import os\n", "import sqlite3\n", "from sqlite3 import Error\n", "from pprint import pprint\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from dotenv import load_dotenv\n", "from tqdm.notebook import tqdm\n", "from IPython.display import HTML, display\n", "\n", "# Loads key from local .env file to setup API KEY in env variables\n", "%reload_ext dotenv\n", "%dotenv\n", " \n", "GPT_MODEL = 'gpt-4o'\n", "dataset = load_dataset(\"b-mc2/sql-create-context\")\n", "\n", "print(dataset['train'].num_rows, \"rows\")" ] }, { "cell_type": "markdown", "id": "04c7fde6-d7dc-4a0d-b9a0-32858f3bac25", "metadata": 
{}, "source": [ "### Looking at the dataset\n", "\n", "We use the Hugging Face `datasets` library to download the SQL create context dataset. This dataset consists of:\n", "\n", "1. Question, expressed in natural language.\n", "2. Answer, expressed as a SQL query that answers the natural language question.\n", "3. Context, expressed as a CREATE SQL statement that describes the table that may be used to answer the question.\n", "\n", "In this demonstration, we will use an LLM to answer the natural language question. The LLM is expected to generate a CREATE SQL statement that sets up a context (a table) suitable for answering the user's question, and a corresponding SELECT SQL query that answers the question completely.\n", "\n", "The dataset looks like this:" ] }, { "cell_type": "code", "execution_count": 3, "id": "f8027115", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " answer \\\n", "0 SELECT COUNT(*) FROM head WHERE age > 56 \n", "1 SELECT name, born_state, age FROM head ORDER B... \n", "2 SELECT creation, name, budget_in_billions FROM... \n", "3 SELECT MAX(budget_in_billions), MIN(budget_in_... \n", "4 SELECT AVG(num_employees) FROM department WHER... \n", "\n", " question \\\n", "0 How many heads of the departments are older th... \n", "1 List the name, born state and age of the heads... \n", "2 List the creation year, name and budget of eac... \n", "3 What are the maximum and minimum budget of the... \n", "4 What is the average number of employees of the... \n", "\n", " context \n", "0 CREATE TABLE head (age INTEGER) \n", "1 CREATE TABLE head (name VARCHAR, born_state VA... \n", "2 CREATE TABLE department (creation VARCHAR, nam... \n", "3 CREATE TABLE department (budget_in_billions IN... \n", "4 CREATE TABLE department (num_employees INTEGER... " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sql_df = dataset['train'].to_pandas()\n", "sql_df.head()" ] }, { "cell_type": "markdown", "id": "b04cb5eb", "metadata": {}, "source": [ "## Test development\n", "\n", "To test the output of the LLM generations, we'll develop two unit tests and an evaluation, which will combine to give us a basic evaluation framework to grade the quality of our LLM iterations.\n", "\n", "To re-iterate, our purpose is to measure the correctness and consistency of LLM output given our questions.\n", "\n", "### Unit tests\n", "\n", "Unit tests should test the most granular components of your LLM application.\n", "\n", "For this section we'll develop unit tests to test the following:\n", "- `test_valid_schema` will check that a parseable `create` and `select` statement are returned by the LLM.\n", "- `test_llm_sql` will execute both the `create` and `select` statements on a `sqlite` database to ensure they are syntactically correct." 
] }, { "cell_type": "code", "execution_count": 4, "id": "eb811101", "metadata": {}, "outputs": [], "source": [ "from pydantic import BaseModel\n", "\n", "\n", "class LLMResponse(BaseModel):\n", " \"\"\"This is the structure that we expect the LLM to respond with.\n", "\n", " The LLM should respond with a JSON string with `create` and `select` fields.\n", " \"\"\"\n", " create: str\n", " select: str" ] }, { "cell_type": "markdown", "id": "19fadf67-8b2f-4e17-95df-030a36aad90b", "metadata": {}, "source": [ "#### Prompting the LLM\n", "\n", "For demonstration purposes, we use a fairly simple prompt asking GPT to generate a `(context, answer)` pair, where `context` is the `CREATE` SQL statement and `answer` is the `SELECT` SQL statement. We supply the natural language question as the user message. We ask for the response in JSON format so that it can be parsed easily; here the format is enforced with Structured Outputs, by passing the `LLMResponse` model as the `response_format`." ] }, { "cell_type": "code", "execution_count": 5, "id": "c2be3ba4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Question: How many heads of the departments are older than 56 ?\n", "Answer: {\"create\":\"CREATE TABLE DepartmentHeads (\\n id INT PRIMARY KEY,\\n name VARCHAR(100),\\n age INT,\\n department VARCHAR(100)\\n);\",\"select\":\"SELECT COUNT(*) AS NumberOfHeadsOlderThan56 \\nFROM DepartmentHeads \\nWHERE age > 56;\"}\n" ] } ], "source": [ "system_prompt = \"\"\"Translate this natural language request into a JSON\n", "object containing two SQL queries. 
The first query should be a CREATE \n", "statement for a table answering the user's request, while the second\n", "should be a SELECT query answering their question.\"\"\"\n", "\n", "# Send the message array to GPT and request a response (make sure your\n", "# API key is loaded into the environment variables for this step)\n", "client = OpenAI()\n", "\n", "def get_response(system_prompt, user_message, model=GPT_MODEL):\n", "    messages = []\n", "    messages.append({\"role\": \"system\", \"content\": system_prompt})\n", "    messages.append({\"role\": \"user\", \"content\": user_message})\n", "\n", "    response = client.beta.chat.completions.parse(\n", "        model=model,\n", "        messages=messages,\n", "        response_format=LLMResponse,\n", "    )\n", "    return response.choices[0].message.content\n", "\n", "question = sql_df.iloc[0]['question']\n", "content = get_response(system_prompt, question)\n", "print(\"Question:\", question)\n", "print(\"Answer:\", content)" ] }, { "cell_type": "markdown", "id": "901e3bb7", "metadata": {}, "source": [ "#### Check JSON formatting\n", "\n", "Our first simple unit test checks that the LLM response is parseable into the `LLMResponse` Pydantic class that we've defined.\n", "\n", "We'll test that our first response passes, then create a failing example to confirm that the test catches it. This logic will be wrapped in a simple function `test_valid_schema`.\n", "\n", "We expect GPT to respond with valid JSON containing `create` and `select` fields. `test_valid_schema` checks this by attempting to parse the response into the `LLMResponse` Pydantic model."
] }, { "cell_type": "code", "execution_count": 6, "id": "4c7133f1-74d6-43f1-9443-09a3f8308c35", "metadata": {}, "outputs": [], "source": [ "def test_valid_schema(content):\n", "    \"\"\"Tests whether the content provided can be parsed into our Pydantic model.\"\"\"\n", "    try:\n", "        LLMResponse.model_validate_json(content)\n", "        return True\n", "    # Catch pydantic's validation errors:\n", "    except pydantic.ValidationError as exc:\n", "        print(f\"ERROR: Invalid schema: {exc}\")\n", "        return False" ] }, { "cell_type": "code", "execution_count": 7, "id": "6a9a9128", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_valid_schema(content)" ] }, { "cell_type": "markdown", "id": "78f1af23-4dd0-4860-8a1a-88e5146be6ed", "metadata": {}, "source": [ "#### Testing a negative scenario\n", "\n", "To simulate a scenario in which GPT returns an invalid JSON response, we hardcode a response string that is not valid JSON. We expect `test_valid_schema` to catch the resulting validation error, print it, and return `False`." ] }, { "cell_type": "code", "execution_count": 8, "id": "a0a26690", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ERROR: Invalid schema: 1 validation error for LLMResponse\n", " Invalid JSON: expected value at line 1 column 1 [type=json_invalid, input_value='CREATE departments, select * from departments', input_type=str]\n", " For further information visit https://errors.pydantic.dev/2.10/v/json_invalid\n" ] }, { "data": { "text/plain": [ "False" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "failing_query = 'CREATE departments, select * from departments'\n", "test_valid_schema(failing_query)" ] }, { "cell_type": "markdown", "id": "a5fdc420-94cc-4e47-80e1-82bf51e44f2a", "metadata": {}, "source": [ "As expected, `test_valid_schema` reports the validation error and returns `False`."
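] }, { "cell_type": "markdown", "id": "missing-field-negative-test", "metadata": {}, "source": [ "Invalid JSON is not the only way a response can fail the schema check. A response can also be well-formed JSON that is missing one of the fields `LLMResponse` requires. The extra negative test below uses a hypothetical response (not an actual model output) that has a `create` field but no `select` field. Because `LLMResponse` declares both fields as required, `test_valid_schema` should print a validation error and return `False` for this input too." ] }, { "cell_type": "code", "execution_count": null, "id": "missing-field-negative-test-code", "metadata": {}, "outputs": [], "source": [ "# Hypothetical response: well-formed JSON, but the required `select` field is missing\n", "missing_field_query = '{\"create\": \"CREATE TABLE departments (id INT, name VARCHAR(255))\"}'\n", "\n", "# Expected to print a validation error and return False\n", "test_valid_schema(missing_field_query)"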
] }, { "cell_type": "markdown", "id": "a4e972cd-5734-43c0-a103-b9ceb41552fd", "metadata": {}, "source": [ "### Test SQL queries\n", "\n", "Next we'll validate the correctness of the SQL. This test is designed to validate that:\n", "\n", "1. The CREATE SQL returned in the GPT response is syntactically correct.\n", "2. The SELECT SQL returned in the GPT response is syntactically correct.\n", "\n", "To achieve this, we will execute the returned SQL statements against a SQLite instance. If the statements are valid, SQLite will accept and execute them; otherwise, it raises an exception. Because the SELECT query runs against the table created by the CREATE statement, this also catches SELECT queries that reference tables or columns the CREATE statement did not define.\n", "\n", "The `create_connection` function below sets up a SQLite instance (in-memory by default) and creates a connection to be used later." ] }, { "cell_type": "code", "execution_count": 9, "id": "9cc95481", "metadata": {}, "outputs": [], "source": [ "# Set up SQLite to act as our test database\n", "def create_connection(db_file=\":memory:\"):\n", "    \"\"\"Create a database connection to a SQLite database.\"\"\"\n", "    try:\n", "        conn = sqlite3.connect(db_file)\n", "    except Error as e:\n", "        print(e)\n", "        return None\n", "\n", "    return conn\n", "\n", "def close_connection(conn):\n", "    \"\"\"Close a database connection.\"\"\"\n", "    try:\n", "        conn.close()\n", "    except Error as e:\n", "        print(e)\n", "\n", "\n", "conn = create_connection()" ] }, { "cell_type": "markdown", "id": "aa5c5cb8-1c81-403b-a3f2-f2784d8235fc", "metadata": {}, "source": [ "Next, we will create the following functions to carry out the syntactic correctness checks.\n", "\n", "\n", "- `test_create`: Function testing if the CREATE SQL statement succeeds.\n", "- `test_select`: Function testing if the SELECT SQL statement succeeds.\n", "- `test_llm_sql`: Wrapper function executing the two tests above."
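] }, { "cell_type": "markdown", "id": "sqlite-single-statement-note", "metadata": {}, "source": [ "One caveat applies to the checks listed above: `cursor.execute` runs a single SQL statement. A response's `create` field might bundle several statements, for example a `CREATE TABLE` followed by `INSERT` statements. Executing that field with `cursor.execute` fails even when every statement is valid, so the response is counted as incorrect. Depending on the Python version, `sqlite3` signals this with `sqlite3.Warning` or with `sqlite3.ProgrammingError`. Note that `sqlite3.Warning` is not a subclass of `sqlite3.Error`, so an `except sqlite3.Error` block does not catch it.\n", "\n", "The sketch below illustrates the difference on a throwaway in-memory database, using a hypothetical table `t` that is not part of the original notebook. `execute` rejects the multi-statement string, while `executescript` runs every statement in it. Whether multi-statement responses should pass is a design choice for your application; the batch evaluation in this notebook counts them as failures." ] }, { "cell_type": "code", "execution_count": null, "id": "sqlite-single-statement-demo", "metadata": {}, "outputs": [], "source": [ "import sqlite3\n", "\n", "# Hypothetical multi-statement `create` payload (not taken from the dataset)\n", "multi_statement_sql = 'CREATE TABLE IF NOT EXISTS t (x INTEGER); INSERT INTO t VALUES (1);'\n", "\n", "demo_conn = sqlite3.connect(':memory:')\n", "demo_cursor = demo_conn.cursor()\n", "\n", "try:\n", "    # execute() accepts exactly one statement, so this call raises\n", "    demo_cursor.execute(multi_statement_sql)\n", "except (sqlite3.Warning, sqlite3.ProgrammingError) as exc:\n", "    print('execute() rejected the SQL:', exc)\n", "\n", "# executescript() runs every statement in the string\n", "demo_cursor.executescript(multi_statement_sql)\n", "print(demo_cursor.execute('SELECT COUNT(*) FROM t').fetchall())\n", "\n", "demo_conn.close()"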
] }, { "cell_type": "code", "execution_count": 10, "id": "c6d2573d", "metadata": {}, "outputs": [], "source": [ "def test_select(conn, cursor, select, should_log=True):\n", "    \"\"\"Tests that a SQLite select query can be executed successfully.\"\"\"\n", "    try:\n", "        if should_log:\n", "            print(f\"Testing select query: {select}\")\n", "        cursor.execute(select)\n", "        record = cursor.fetchall()\n", "        if should_log:\n", "            print(f\"Result of query: {record}\")\n", "\n", "        return True\n", "\n", "    except sqlite3.Error as error:\n", "        if should_log:\n", "            print(\"Error while executing select query:\", error)\n", "        return False\n", "\n", "\n", "def test_create(conn, cursor, create, should_log=True):\n", "    \"\"\"Tests that a SQLite create query can be executed successfully.\"\"\"\n", "    try:\n", "        if should_log:\n", "            print(f\"Testing create query: {create}\")\n", "        cursor.execute(create)\n", "        conn.commit()\n", "\n", "        return True\n", "\n", "    except sqlite3.Error as error:\n", "        if should_log:\n", "            print(\"Error while creating the SQLite table:\", error)\n", "        return False\n", "\n", "\n", "def test_llm_sql(llm_response, should_log=True):\n", "    \"\"\"Runs the CREATE and SELECT tests against a fresh in-memory SQLite database.\"\"\"\n", "    conn = create_connection()\n", "    if conn is None:\n", "        return False\n", "\n", "    try:\n", "        cursor = conn.cursor()\n", "\n", "        create_response = test_create(conn, cursor, llm_response.create, should_log=should_log)\n", "        select_response = test_select(conn, cursor, llm_response.select, should_log=should_log)\n", "\n", "        # Both statements must execute for the response to pass\n", "        return create_response and select_response\n", "\n", "    except sqlite3.Error as error:\n", "        if should_log:\n", "            print(\"Error while running the SQLite tests:\", error)\n", "        return False\n", "\n", "    finally:\n", "        # Always release the connection, even if a test raised\n", "        close_connection(conn)" ] }, { "cell_type": "code", "execution_count": 11, "id": "a9266753-4646-4901-bc14-632d3bf47aaa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", 
"text": [ "CREATE SQL is: CREATE TABLE DepartmentHeads (\n", " id INT PRIMARY KEY,\n", " name VARCHAR(100),\n", " age INT,\n", " department VARCHAR(100)\n", ");\n", "SELECT SQL is: SELECT COUNT(*) AS NumberOfHeadsOlderThan56 \n", "FROM DepartmentHeads \n", "WHERE age > 56;\n" ] } ], "source": [ "# Viewing CREATE and SELECT sqls returned by GPT\n", "\n", "test_query = LLMResponse.model_validate_json(content)\n", "print(f\"CREATE SQL is: {test_query.create}\")\n", "print(f\"SELECT SQL is: {test_query.select}\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "83bc1f1b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing create query: CREATE TABLE DepartmentHeads (\n", " id INT PRIMARY KEY,\n", " name VARCHAR(100),\n", " age INT,\n", " department VARCHAR(100)\n", ");\n", "Testing select query: SELECT COUNT(*) AS NumberOfHeadsOlderThan56 \n", "FROM DepartmentHeads \n", "WHERE age > 56;\n", "Result of query: [(0,)]\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Testing the CREATE and SELECT sqls are valid (we expect this to be succesful)\n", "\n", "test_llm_sql(test_query)" ] }, { "cell_type": "code", "execution_count": 13, "id": "589c7cc7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing create query: CREATE TABLE departments (id INT, name VARCHAR(255), head_of_department VARCHAR(255))\n", "Testing select query: SELECT COUNT(*) FROM departments WHERE age > 56\n", "Error while executing select query: no such column: age\n" ] }, { "data": { "text/plain": [ "False" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Again we'll perform a negative test to confirm that a failing SELECT will return an error.\n", "\n", "test_failure_query = '{\"create\": \"CREATE TABLE departments (id INT, name VARCHAR(255), head_of_department VARCHAR(255))\", 
\"select\": \"SELECT COUNT(*) FROM departments WHERE age > 56\"}'\n", "test_failure_query = LLMResponse.model_validate_json(test_failure_query)\n", "test_llm_sql(test_failure_query)" ] }, { "cell_type": "markdown", "id": "8148f820", "metadata": {}, "source": [ "### Using an LLM to evaluate relevancy\n", "\n", "Next, we **evaluate** whether the generated SQL actually answers the user's question. This test will be performed by `gpt-4o-mini`, and will assess how **relevant** the produced SQL query is when compared to the initial user request.\n", "\n", "This is a simple example which adapts an approach outlined in the [G-Eval paper](https://arxiv.org/abs/2303.16634), and tested in one of our other [cookbooks](https://github.com/openai/openai-cookbook/blob/main/examples/evaluation/How_to_eval_abstractive_summarization.ipynb)." ] }, { "cell_type": "code", "execution_count": 14, "id": "029c8426", "metadata": {}, "outputs": [], "source": [ "EVALUATION_MODEL = \"gpt-4o-mini\"\n", "\n", "EVALUATION_PROMPT_TEMPLATE = \"\"\"\n", "You will be given one summary written for an article. Your task is to rate the summary on one metric.\n", "Please make sure you read and understand these instructions very carefully. \n", "Please keep this document open while reviewing, and refer to it as needed.\n", "\n", "Evaluation Criteria:\n", "\n", "{criteria}\n", "\n", "Evaluation Steps:\n", "\n", "{steps}\n", "\n", "Example:\n", "\n", "Request:\n", "\n", "{request}\n", "\n", "Queries:\n", "\n", "{queries}\n", "\n", "Evaluation Form (scores ONLY):\n", "\n", "- {metric_name}\n", "\"\"\"\n", "\n", "# Relevance\n", "\n", "RELEVANCY_SCORE_CRITERIA = \"\"\"\n", "Relevance(1-5) - review of how relevant the produced SQL queries are to the original question. \\\n", "The queries should contain all points highlighted in the user's request. 
\\\n", "Annotators were instructed to penalize queries which contained redundancies and excess information.\n", "\"\"\"\n", "\n", "RELEVANCY_SCORE_STEPS = \"\"\"\n", "1. Read the request and the queries carefully.\n", "2. Compare the queries to the request and identify the main points of the request.\n", "3. Assess how well the queries cover the main points of the request, and how much irrelevant or redundant information they contain.\n", "4. Assign a relevance score from 1 to 5.\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 15, "id": "85cfb78d", "metadata": {}, "outputs": [], "source": [ "def get_geval_score(\n", " criteria: str, steps: str, request: str, queries: str, metric_name: str\n", "):\n", " \"\"\"Given evaluation criteria and an observation, this function uses EVALUATION_MODEL to score the observation against those criteria.\n", "\"\"\"\n", " prompt = EVALUATION_PROMPT_TEMPLATE.format(\n", " criteria=criteria,\n", " steps=steps,\n", " request=request,\n", " queries=queries,\n", " metric_name=metric_name,\n", " )\n", " response = client.chat.completions.create(\n", " model=EVALUATION_MODEL,\n", " messages=[{\"role\": \"user\", \"content\": prompt}],\n", " temperature=0,\n", " max_tokens=5,\n", " top_p=1,\n", " frequency_penalty=0,\n", " presence_penalty=0,\n", " )\n", " return response.choices[0].message.content" ] }, { "cell_type": "code", "execution_count": 16, "id": "607ee304", "metadata": {}, "outputs": [], "source": [ "# Test out evaluation on a few records\n", "\n", "evaluation_results = []\n", "\n", "for x,y in sql_df.head(3).iterrows():\n", " score = get_geval_score(\n", " RELEVANCY_SCORE_CRITERIA,\n", " RELEVANCY_SCORE_STEPS,\n", " y['question'],\n", " y['context'] + '\\n' + y['answer'],'relevancy'\n", " )\n", " evaluation_results.append((y['question'],y['context'] + '\\n' + y['answer'],score))" ] }, { "cell_type": "code", "execution_count": 17, "id": "bd1002c2", "metadata": {}, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "User Question \t: How many heads of the departments are older than 56 ?\n", "CREATE SQL Returned \t: CREATE TABLE head (age INTEGER)\n", "SELECT SQL Returned \t: SELECT COUNT(*) FROM head WHERE age > 56\n", "5\n", "********************\n", "User Question \t: List the name, born state and age of the heads of departments ordered by age.\n", "CREATE SQL Returned \t: CREATE TABLE head (name VARCHAR, born_state VARCHAR, age VARCHAR)\n", "SELECT SQL Returned \t: SELECT name, born_state, age FROM head ORDER BY age\n", "4\n", "********************\n", "User Question \t: List the creation year, name and budget of each department.\n", "CREATE SQL Returned \t: CREATE TABLE department (creation VARCHAR, name VARCHAR, budget_in_billions VARCHAR)\n", "SELECT SQL Returned \t: SELECT creation, name, budget_in_billions FROM department\n", "4\n", "********************\n" ] } ], "source": [ "for result in evaluation_results:\n", " print(f\"User Question \\t: {result[0]}\")\n", " print(f\"CREATE SQL Returned \\t: {result[1].splitlines()[0]}\")\n", " print(f\"SELECT SQL Returned \\t: {result[1].splitlines()[1]}\")\n", " print(f\"{result[2]}\")\n", " print(\"*\" * 20)" ] }, { "cell_type": "markdown", "id": "61b68e2a", "metadata": {}, "source": [ "## Evaluation\n", "\n", "We will test these functions in combination including our unit test and evaluations to test out two system prompts.\n", "\n", "Each iteration of input/output and scores should be stored as a **run**. Optionally you can add GPT-4 annotation within your evaluations or as a separate step to review an entire run and highlight the reasons for errors.\n", "\n", "For this example, the second system prompt will include an extra line of clarification, so we can assess the impact of this for both SQL validity and quality of solution." 
] }, { "cell_type": "markdown", "id": "3b578b00-1b27-49de-8fd1-15c00ec99729", "metadata": {}, "source": [ "### Building the test framework\n", "\n", "We want to build a function, `test_system_prompt`, which will run our unit tests and evaluation against a given system prompt." ] }, { "cell_type": "code", "execution_count": 18, "id": "40be5fae-4eb3-40ce-8645-613c24d5e0b4", "metadata": {}, "outputs": [], "source": [ "def execute_unit_tests(input_df, output_list, system_prompt):\n", "    \"\"\"Unit testing function that takes in a dataframe and appends test results to an output_list.\"\"\"\n", "\n", "    for _, y in tqdm(input_df.iterrows(), total=len(input_df)):\n", "        model_response = get_response(system_prompt, y['question'])\n", "\n", "        format_valid = test_valid_schema(model_response)\n", "\n", "        try:\n", "            test_query = LLMResponse.model_validate_json(model_response)\n", "            # Avoid logging since we're executing many rows at once\n", "            sql_valid = test_llm_sql(test_query, should_log=False)\n", "        except Exception:\n", "            # Any parsing or execution error counts as invalid SQL\n", "            sql_valid = False\n", "\n", "        output_list.append((y['question'], model_response, format_valid, sql_valid))\n", "\n", "def evaluate_row(row):\n", "    \"\"\"Simple evaluation function to categorize unit testing results.\n", "\n", "    If the format or SQL are flagged it returns a label, otherwise it is correct\"\"\"\n", "    if not row['format']:\n", "        return 'Format incorrect'\n", "    elif not row['sql']:\n", "        return 'SQL incorrect'\n", "    else:\n", "        return 'SQL correct'\n", "\n", "def test_system_prompt(test_df, system_prompt):\n", " # Execute unit tests and capture results\n", " results = []\n", " execute_unit_tests(\n", " input_df=test_df,\n", " output_list=results,\n", " system_prompt=system_prompt\n", " )\n", " \n", " results_df = pd.DataFrame(results)\n", " results_df.columns = ['question','response','format','sql']\n", " \n", " # Use `apply` to calculate the geval score and unit test evaluation\n", " # for each generated response\n", " 
results_df['evaluation_score'] = results_df.apply(\n", " lambda x: get_geval_score(\n", " RELEVANCY_SCORE_CRITERIA,\n", " RELEVANCY_SCORE_STEPS,\n", " x['question'],\n", " x['response'],\n", " 'relevancy'\n", " ),\n", " axis=1\n", " )\n", " results_df['unit_test_evaluation'] = results_df.apply(\n", " lambda x: evaluate_row(x),\n", " axis=1\n", " )\n", " return results_df" ] }, { "cell_type": "markdown", "id": "6abc2c22-d7c6-4f15-b519-60cc58ff7774", "metadata": {}, "source": [ "### System Prompt 1\n", "\n", "The system under test is the first system prompt as shown below. This `run` will generate responses for this system prompt and evaluate the responses using the functions we've created so far." ] }, { "cell_type": "code", "execution_count": 19, "id": "85c44a17", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4d39ec72385f4b74bed652bfa54427f8", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/50 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "" ], "text/plain": [ " question \\\n", "0 What venue did the parntership of shoaib malik... \n", "1 What venue did the partnership of herschelle g... \n", "2 What is the number Played that has 310 Points ... \n", "3 What Losing bonus has a Points against of 588? \n", "4 What Tries against has a Losing bonus of 7? \n", "\n", " response format sql \\\n", "0 {\"create\":\"CREATE TABLE cricket_partnerships (... True True \n", "1 {\"create\":\"CREATE TABLE CricketPartnerships (\\... True True \n", "2 {\"create\":\"CREATE TABLE game_stats (\\n numb... True True \n", "3 {\"create\":\"CREATE TABLE BonusInfo (\\n id IN... True True \n", "4 {\"create\":\"CREATE TABLE matches (\\n id SERI... True True \n", "\n", " evaluation_score unit_test_evaluation run Evaluating Model \n", "0 5 SQL correct 1 gpt-4 \n", "1 5 SQL correct 1 gpt-4 \n", "2 5 SQL correct 1 gpt-4 \n", "3 5 SQL correct 1 gpt-4 \n", "4 5 SQL correct 1 gpt-4 " ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_df['run'] = 1\n", "results_df['Evaluating Model'] = 'gpt-4'\n", "\n", "results_2_df['run'] = 2\n", "results_2_df['Evaluating Model'] = 'gpt-4'\n", "\n", "run_df = pd.concat([results_df,results_2_df])\n", "run_df.head()" ] }, { "cell_type": "markdown", "id": "0162a009-fc43-484c-90f6-d59a8e52f365", "metadata": {}, "source": [ "#### Plotting unit test results\n", "\n", "We can create a simple bar chart to visualise the results of unit tests for both runs." ] }, { "cell_type": "code", "execution_count": 26, "id": "ed800f0c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Number of records\n", "run unit_test_evaluation \n", "1 SQL correct 46\n", " SQL incorrect 4\n", "2 SQL correct 44\n", " SQL incorrect 6" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "unittest_df_pivot = pd.pivot_table(\n", " run_df,\n", " values='format',\n", " index=['run','unit_test_evaluation'],\n", " aggfunc='count'\n", ")\n", "unittest_df_pivot.columns = ['Number of records']\n", "unittest_df_pivot" ] }, { "cell_type": "code", "execution_count": 27, "id": "e2b4aa03-42f5-4c30-a610-e553937bf160", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "unittest_df_pivot.reset_index(inplace=True)\n", "\n", "# Plotting\n", "plt.figure(figsize=(10, 6))\n", "\n", "# Set the width of each bar\n", "bar_width = 0.35\n", "\n", "# OpenAI brand colors\n", "openai_colors = ['#00D1B2', '#000000'] # Green and Black\n", "\n", "# Get unique runs and unit test evaluations\n", "unique_runs = unittest_df_pivot['run'].unique()\n", "unique_unit_test_evaluations = unittest_df_pivot['unit_test_evaluation'].unique()\n", "\n", "# Ensure we have enough colors (repeating the pattern if necessary)\n", "colors = openai_colors * (len(unique_runs) // len(openai_colors) + 1)\n", "\n", "# Iterate over each run to plot\n", "for i, run in enumerate(unique_runs):\n", " run_data = unittest_df_pivot[unittest_df_pivot['run'] == run]\n", "\n", " # Position of bars for this run\n", " positions = np.arange(len(unique_unit_test_evaluations)) + i * bar_width\n", "\n", " plt.bar(positions, run_data['Number of records'], width=bar_width, label=f'Run {run}', color=colors[i])\n", "\n", "# Setting the x-axis labels to be the unit test evaluations, centered under the groups\n", "plt.xticks(np.arange(len(unique_unit_test_evaluations)) + bar_width / 2, unique_unit_test_evaluations)\n", "\n", "plt.xlabel('Unit Test Evaluation')\n", "plt.ylabel('Number of Records')\n", "plt.title('Unit Test Evaluations vs Number of Records for Each Run')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "786515fa-6841-4820-98f9-aa29ae76cf76", "metadata": {}, "source": [ "#### Plotting evaluation results\n", "\n", "We can similarly plot the results of the evaluation." ] }, { "cell_type": "code", "execution_count": 28, "id": "7228eac7-e0a9-473d-9432-e558bbc91841", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
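<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr><th></th><th></th><th>Number of records</th></tr>\n", "    <tr><th>run</th><th>evaluation_score</th><th></th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th rowspan=\"3\" valign=\"top\">1</th><th>3</th><td>1</td></tr>\n", "    <tr><th>4</th><td>16</td></tr>\n", "    <tr><th>5</th><td>33</td></tr>\n", "    <tr><th rowspan=\"3\" valign=\"top\">2</th><th>3</th><td>1</td></tr>\n", "    <tr><th>4</th><td>15</td></tr>\n", "    <tr><th>5</th><td>34</td></tr>\n", "  </tbody>\n", "</table>\n", "</div>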
" ], "text/plain": [ " Number of records\n", "run evaluation_score \n", "1 3 1\n", " 4 16\n", " 5 33\n", "2 3 1\n", " 4 15\n", " 5 34" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evaluation_df_pivot = pd.pivot_table(\n", " run_df,\n", " values='format',\n", " index=['run','evaluation_score'],\n", " aggfunc='count'\n", ")\n", "evaluation_df_pivot.columns = ['Number of records']\n", "evaluation_df_pivot" ] }, { "cell_type": "code", "execution_count": 29, "id": "b2a18a78-55ec-43f6-9d62-929707a94364", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Reset index without dropping the 'run' and 'evaluation_score' columns\n", "evaluation_df_pivot.reset_index(inplace=True)\n", "\n", "# Plotting\n", "plt.figure(figsize=(10, 6))\n", "\n", "bar_width = 0.35\n", "\n", "# OpenAI brand colors\n", "openai_colors = ['#00D1B2', '#000000'] # Green, Black\n", "\n", "# Identify unique runs and evaluation scores\n", "unique_runs = evaluation_df_pivot['run'].unique()\n", "unique_evaluation_scores = evaluation_df_pivot['evaluation_score'].unique()\n", "\n", "# Repeat colors if there are more runs than colors\n", "colors = openai_colors * (len(unique_runs) // len(openai_colors) + 1)\n", "\n", "for i, run in enumerate(unique_runs):\n", " # Select rows for this run only\n", " run_data = evaluation_df_pivot[evaluation_df_pivot['run'] == run].copy()\n", " \n", " # Ensure every 'evaluation_score' is present\n", " run_data.set_index('evaluation_score', inplace=True)\n", " run_data = run_data.reindex(unique_evaluation_scores, fill_value=0)\n", " run_data.reset_index(inplace=True)\n", " \n", " # Plot each bar\n", " positions = np.arange(len(unique_evaluation_scores)) + i * bar_width\n", " plt.bar(\n", " positions,\n", " run_data['Number of records'],\n", " width=bar_width,\n", " label=f'Run {run}',\n", " color=colors[i]\n", " )\n", "\n", "# Configure the x-axis to show evaluation scores under the grouped bars\n", "plt.xticks(np.arange(len(unique_evaluation_scores)) + bar_width / 2, unique_evaluation_scores)\n", "\n", "plt.xlabel('Evaluation Score')\n", "plt.ylabel('Number of Records')\n", "plt.title('Evaluation Scores vs Number of Records for Each Run')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "992f9aa0", "metadata": {}, "source": [ "## Conclusion\n", "\n", "Now you have a framework to test SQL generation using LLMs, and with some tweaks this approach can be extended to many other code generation use cases. 
With GPT-4 as an evaluator and engaged human labelers, you can aim to automate the evaluation of these test cases, creating an iterative loop in which new examples are added to the test set and this framework detects any performance regressions.\n", "\n", "We hope you find this useful; please share any feedback." ] }, { "cell_type": "code", "execution_count": null, "id": "8368c786-38eb-4ca3-b5f4-cad63fec87bd", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 5 }