{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Using LangChain to get structured outputs\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from langchain.prompts import ChatPromptTemplate\n", "from langchain_anthropic import ChatAnthropic\n", "from langchain_ollama import ChatOllama\n", "from langchain_fireworks import ChatFireworks\n", "\n", "from experiment_xml import (\n", " pydantic_to_xml_instructions,\n", " run_xml_experiment,\n", ")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "ANTHROPIC_API_KEY = \"\"\n", "FIREWORKS_API_KEY = \"\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import streamlit as st\n", "\n", "LANGSMITH_API_KEY = st.secrets[\"api_keys\"][\"LANGSMITH_API_KEY\"]\n", "ANTHROPIC_API_KEY = st.secrets[\"api_keys\"][\"ANTHROPIC_API_KEY\"]\n", "FIREWORKS_API_KEY = st.secrets[\"api_keys\"][\"FIREWORKS_API_KEY\"]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "experiment_date = \"21-02-25\"\n", "n_iter = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by creating a LLM model to run our structured output queries. Use a temperature of 0 to improve structured output generation (but at the cost of \"creativity\").\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model setup\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# LLM parameters\n", "temperature = 0.8\n", "timeout = 30\n", "num_ctx = 8192\n", "num_predict = 4096" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "llm_models_test = {\n", " \"Ollama_llama32\": ChatOllama(\n", " model=\"llama3.2\",\n", " temperature=temperature,\n", " num_ctx=num_ctx,\n", " num_thread=1,\n", " num_predict=num_predict,\n", " ),\n", " \"Ollama_phi3\": ChatOllama(\n", " model=\"phi3\",\n", " temperature=temperature,\n", " num_ctx=num_ctx,\n", " num_thread=1,\n", " num_predict=num_predict,\n", " ),\n", " \"Ollama_deepseekr1\": ChatOllama(\n", " model=\"deepseek-r1\",\n", " temperature=temperature,\n", " num_ctx=num_ctx,\n", " num_thread=1,\n", " num_predict=num_predict,\n", " ),\n", "}\n", "llm_models = {\n", " \"Ollama_llama32\": ChatOllama(\n", " model=\"llama3.2\",\n", " temperature=temperature,\n", " num_ctx=num_ctx,\n", " num_thread=1,\n", " num_predict=num_predict,\n", " ),\n", " \"Ollama_nemotron\": ChatOllama(\n", " model=\"nemotron-mini\",\n", " temperature=temperature,\n", " num_ctx=num_ctx,\n", " num_thread=1,\n", " num_predict=num_predict,\n", " ),\n", " \"Ollama_phi3\": ChatOllama(\n", " model=\"phi3\",\n", " temperature=temperature,\n", " num_ctx=num_ctx,\n", " num_thread=1,\n", " num_predict=num_predict,\n", " ),\n", " \"Ollama_phi4\": ChatOllama(\n", " model=\"phi4\",\n", " temperature=temperature,\n", " num_ctx=num_ctx,\n", " num_thread=1,\n", " num_predict=num_predict,\n", " ),\n", " \"Ollama_deepseekr1\": ChatOllama(\n", " model=\"deepseek-r1\",\n", " temperature=temperature,\n", " num_ctx=num_ctx,\n", " num_thread=1,\n", " num_predict=num_predict,\n", " ),\n", " \"fireworks_llama31\": ChatFireworks(\n", " model_name=\"accounts/fireworks/models/llama-v3p1-70b-instruct\",\n", " api_key=FIREWORKS_API_KEY,\n", " temperature=temperature,\n", " timeout=timeout,\n", " ),\n", " \"fireworks_llama32\": ChatFireworks(\n", " model_name=\"accounts/fireworks/models/llama-v3p2-3b-instruct\",\n", " api_key=FIREWORKS_API_KEY,\n", " temperature=temperature,\n", " timeout=timeout,\n", " ),\n", " \"fireworks_llama33\": ChatFireworks(\n", " model_name=\"accounts/fireworks/models/llama-v3p3-70b-instruct\",\n", " api_key=FIREWORKS_API_KEY,\n", " temperature=temperature,\n", " timeout=timeout,\n", " ),\n", " \"fireworks_qwen25\": ChatFireworks(\n", " model_name=\"accounts/fireworks/models/qwen2p5-72b-instruct\",\n", " api_key=FIREWORKS_API_KEY,\n", " temperature=temperature,\n", " timeout=timeout,\n", " ),\n", "}\n", "llm_models_with_anthropic = {\n", " **llm_models,\n", " \"Anthropic_Sonnet_35\": ChatAnthropic(\n", " model=\"claude-3-5-sonnet-20241022\",\n", " api_key=ANTHROPIC_API_KEY,\n", " timeout=timeout,\n", " ),\n", " \"Anthropic_Haiku_35\": ChatAnthropic(\n", " model=\"claude-3-5-haiku-20241022\",\n", " api_key=ANTHROPIC_API_KEY,\n", " timeout=timeout,\n", " ),\n", " \"Anthropic_Haiku_3\": ChatAnthropic(\n", " model=\"claude-3-haiku-20240307\",\n", " api_key=ANTHROPIC_API_KEY,\n", " timeout=timeout,\n", " ),\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Problem setup and prompt\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "test_science_prompt_txt = \"\"\"\n", "You are a professional science writer tasked with responding to members of\n", "the general public who write in asking questions about science.\n", "Write an article responding to a writer's question for publication in a\n", "science magazine intended for a general readership with a high-school education.\n", "You should write clearly and compellingly, include all relavent context,\n", "and provide motivating stories where applicable.\n", "\n", "Your response must be less than 200 words.\n", "\n", "The question given to you is the following:\n", "{question}\n", "\"\"\"\n", "\n", "questions = [\n", " \"What is the oldest recorded fossil?\",\n", " \"What is a black hole?\",\n", " \"How far away is the sun?\",\n", " \"Which other planet in the Solar System has a surface gravity closest to that of the Earth?\",\n", " \"Eris, Haumea, Makemake and Ceres are all examples of what?\",\n", " \"Why does earth have seasons? Do other planets exhibit seasons too?\",\n", " \"What causes the aurora borealis?\",\n", " \"Why is the sky blue?\",\n", " \"How do bees communicate?\",\n", " \"What is the smallest unit of life?\",\n", " \"How do plants make their own food?\",\n", " \"Why do we dream?\",\n", " \"What is the theory of relativity?\",\n", " \"How do volcanoes erupt?\",\n", " \"What is the speed of light?\",\n", " \"How do magnets work?\",\n", " \"What is the purpose of DNA?\",\n", " \"What are the different types of galaxies?\",\n", " \"Why do some animals hibernate?\",\n", " \"How do vaccines work?\",\n", "]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "prompt_direct = ChatPromptTemplate.from_template(test_science_prompt_txt)\n", "\n", "prompt_system_format = ChatPromptTemplate.from_messages(\n", " [\n", " (\n", " \"system\",\n", " \"Return a publishable article in the requested format.\\n{format_instructions}\",\n", " ),\n", " (\"human\", test_science_prompt_txt),\n", " ]\n", ")\n", "\n", "prompt_system_plus_reminder_format = ChatPromptTemplate.from_messages(\n", " [\n", " (\n", " \"system\",\n", " \"Return a publishable article in the requested format.\\n{format_instructions}\",\n", " ),\n", " (\n", " \"human\",\n", " test_science_prompt_txt + \"\\nYour response must be in valid XML.\",\n", " ),\n", " ]\n", ")\n", "\n", "prompt_user_format = ChatPromptTemplate.from_template(\n", " test_science_prompt_txt + \"\\n{format_instructions}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Schema using Pydantic XML\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from pydantic_xml import BaseXmlModel, element, attr\n", "\n", "\n", "class ArticleResponse1XML(BaseXmlModel, tag=\"article\"):\n", " \"\"\"Structured article for publication answering a reader's question.\"\"\"\n", "\n", " title: str = element(description=\"Title of the article\")\n", " answer: str = element(\n", " description=\"Provide a detailed description of historical events to answer the question\"\n", " )\n", " number: int = element(description=\"A number that is most relevant to the question.\")\n", "\n", "\n", "class ArticleResponse1nointXML(BaseXmlModel, tag=\"article\"):\n", " \"\"\"Structured article for publication answering a reader's question.\"\"\"\n", "\n", " title: str = element(description=\"Title of the article\")\n", " answer: str = element(\n", " description=\"Provide a detailed description of historical events to answer the question\"\n", " )\n", " number: str = element(description=\"A number that is most relevant to the question.\")\n", "\n", "\n", "# Lists of simple types\n", "class ArticleResponse2XML(BaseXmlModel, tag=\"article\"):\n", " \"\"\"Structured article for publication answering a reader's question.\"\"\"\n", "\n", " title: str = element(description=\"Title of the article\")\n", " answer: str = element(description=\"Answer the writer's question\")\n", " further_questions: list[str] = element(\n", " tag=\"further_question\",\n", " description=\"A list of related questions that may be of interest to the readers.\",\n", " )\n", "\n", "\n", "# Nested types\n", "class HistoricalEventXML(BaseXmlModel):\n", " \"\"\"The year and explanation of a historical event.\"\"\"\n", "\n", " year: str = element(description=\"The year of the historical event\")\n", " event: str = element(\n", " description=\"A clear and concise explanation of what happened in this event\"\n", " )\n", "\n", "\n", "class ArticleResponse3XML(BaseXmlModel, tag=\"article\"):\n", " \"\"\"Structured article for publication answering a reader's question.\"\"\"\n", "\n", " title: str = element(description=\"[Title of the article]\")\n", " historical_event_1: HistoricalEventXML = element(\n", " description=\"A first historical event relevant to the question\"\n", " )\n", " historical_event_2: HistoricalEventXML = element(\n", " description=\"A second historical event relevant to the question\"\n", " )\n", "\n", "\n", "# Lists of custom types\n", "class ArticleResponse4XML(BaseXmlModel, tag=\"article\"):\n", " \"\"\"Structured article for publication answering a reader's question.\"\"\"\n", "\n", " title: str = element(description=\"Title of the article\")\n", " historical_timeline: list[HistoricalEventXML] = element(\n", " description=\"A list of historical events relevant to the question\"\n", " )" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "structured_formats_xml = [\n", " dict(pydantic=schema, format_instructions=pydantic_to_xml_instructions(schema))\n", " for schema in [\n", " ArticleResponse1XML,\n", " ArticleResponse1nointXML,\n", " ArticleResponse2XML,\n", " ArticleResponse3XML,\n", " ArticleResponse4XML,\n", " ]\n", "]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from langchain.output_parsers import XMLOutputParser" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "xml_output_parsers = [\n", " XMLOutputParser(name=\"article\", tags=[\"article\", \"title\", \"answer\", \"number\"]),\n", " XMLOutputParser(\n", " name=\"article\", tags=[\"article\", \"title\", \"answer\", \"further_question\"]\n", " ),\n", " XMLOutputParser(\n", " name=\"article\",\n", " tags=[\n", " \"article\",\n", " \"title\",\n", " \"historical_event_1\",\n", " \"year\",\n", " \"event\",\n", " \"historical_event_2\",\n", " \"year\",\n", " \"event\",\n", " ],\n", " ),\n", " XMLOutputParser(\n", " name=\"article\", tags=[\"article\", \"title\", \"historical_event\", \"year\", \"event\"]\n", " ),\n", "]\n", "structured_formats_output_parser_xml = [\n", " dict(pydantic=schema, format_instructions=op.get_format_instructions())\n", " for schema, op in zip(\n", " [\n", " ArticleResponse1nointXML,\n", " ArticleResponse2XML,\n", " ArticleResponse3XML,\n", " ArticleResponse4XML,\n", " ],\n", " xml_output_parsers,\n", " )\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Example prompt\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "You must respond only in XML using the following schema. Do not provide any explanation outside the XML.\n", "
\n", " \n", " \n", " {Title of the article - must be type str}\n", " \n", " \n", " {Answer the writer's question - must be type str}\n", " \n", "\n", " \n", " {A list of related questions that may be of interest to the readers. - must be type str}\n", " \n", "\n", " \n", " {A list of related questions that may be of interest to the readers. - must be type str}\n", " \n", "\n", " \n", " ...\n", " \n", "
\n" ] } ], "source": [ "print(pydantic_to_xml_instructions(ArticleResponse2XML))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluation\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "System prompt\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: Ollama_llama32 Output: ArticleResponse1XML Pos: 1\n", "Error: ValidationError\n", "..Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", "..Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", "..Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", ".Error: ValidationError\n", ".\n", "Model: Ollama_llama32 Output: ArticleResponse1nointXML Pos: 2\n", "....................\n", "Model: Ollama_llama32 Output: ArticleResponse2XML Pos: 3\n", ".........Error: XMLSyntaxError\n", "....Error: XMLSyntaxError\n", ".......\n", "Model: Ollama_llama32 Output: ArticleResponse3XML Pos: 4\n", "Error: ValidationError\n", "..Error: ValidationError\n", "." ] } ], "source": [ "if \"structure_support_by_model_sp\" not in locals():\n", " structure_support_by_model_sp = {}\n", "\n", "_ = run_xml_experiment(\n", " prompt_system_format,\n", " questions,\n", " llm_models,\n", " structured_formats_xml,\n", " n_iter=1,\n", " results_out=structure_support_by_model_sp,\n", " save_file_name=f\"exp5_xml_output_sys_{experiment_date}.pkl\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "User prompt\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if \"structure_support_by_model_up\" not in locals():\n", " structure_support_by_model_up = {}\n", "\n", "_ = run_xml_experiment(\n", " prompt_user_format,\n", " questions,\n", " llm_models,\n", " structured_formats_xml,\n", " n_iter=1,\n", " results_out=structure_support_by_model_up,\n", " save_file_name=f\"exp5_xml_output_user_{experiment_date}.pkl\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "System prompt with \"format reminder\"\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if \"structure_support_by_model_sprem\" not in locals():\n", " structure_support_by_model_sprem = {}\n", "\n", "_ = run_xml_experiment(\n", " prompt_system_plus_reminder_format,\n", " questions,\n", " llm_models,\n", " structured_formats_xml,\n", " n_iter=1,\n", " results_out=structure_support_by_model_sprem,\n", " save_file_name=f\"exp5_xml_output_sys_w_reminder_{experiment_date}.pkl\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Output parsers with system prompt\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if \"structure_support_by_model_parsers\" not in locals():\n", " structure_support_by_model_parsers = {}\n", "\n", "_ = run_xml_experiment(\n", " prompt_system_format,\n", " questions,\n", " llm_models_test,\n", " structured_formats_output_parser_xml,\n", " n_iter=1,\n", " results_out=structure_support_by_model_parsers,\n", " save_file_name=f\"exp5_xml_output_parser_{experiment_date}.pkl\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Extract all error messages & count\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def results_to_df(ss_results, key=\"valid\"):\n", " df = pd.DataFrame.from_dict(\n", " {\n", " mname: {\n", " tname: ss_results[mname][tname][key] * 100 / n_questions\n", " for tname in ss_results[mname].keys()\n", " }\n", " for mname in ss_results.keys()\n", " },\n", " orient=\"index\",\n", " )\n", " return df\n", "\n", "\n", "def analyse_errors_from_results(ss_results, method=\"code\"):\n", " error_counts = {}\n", " for mname in ss_results.keys():\n", " error_counts[mname] = {}\n", " for tname in ss_results[mname].keys():\n", " validation_error = 0\n", " json_error = 0\n", " unknown_error = 0\n", "\n", " # Count errors by failure code above\n", " if method == \"code\":\n", " error_types = pd.Series(\n", " output[\"error_type\"]\n", " for output in ss_results[mname][tname][\"outputs\"]\n", " )\n", " error_codes = error_types.value_counts()\n", "\n", " for e_name, e_count in error_codes.items():\n", " error_counts[mname][(tname, e_name)] = e_count\n", "\n", " elif method == \"parse\":\n", " # Count errors by parsing error message\n", " errors = (\n", " output[\"error_message\"]\n", " for output in ss_results[mname][tname][\"outputs\"]\n", " )\n", " for error in errors:\n", " if error is None:\n", " continue\n", " if error.lower().find(\"opening and ending tag mismatch\") >= 0:\n", " error_str = \"XML tag mismatch\"\n", " elif error.lower().find(\"extracterror\") >= 0:\n", " error_str = \"Missing main tags\"\n", " elif error.lower().find(\"input should be a valid integer\") >= 0:\n", " error_str = \"Validation error (int)\"\n", " elif error.lower().find(\"premature end of data in tag\") >= 0:\n", " error_str = \"Premature end\"\n", " elif error.lower().find(\"field required\") >= 0:\n", " error_str = \"Missing field\"\n", " elif error.lower().find(\"expected '>'\") >= 0:\n", " error_str = \"Tag malformed\"\n", " elif (\n", " error.lower().find(\"extra content at the end of the document\")\n", " >= 0\n", " ):\n", " error_str = \"Tag malformed\"\n", " else:\n", " error_str = error\n", "\n", " error_counts[mname][(tname, error_str)] = (\n", " error_counts[mname].get((tname, error_str), 0) + 1\n", " )\n", "\n", " else:\n", " raise NameError(f\"Method {method} not supported\")\n", "\n", " return pd.DataFrame.from_dict(error_counts, orient=\"index\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "analyse_errors_from_results(structure_support_by_model_sp, method=\"parse\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if 0:\n", " for output in structure_support_by_model_sp[\"Ollama_deepseekr1\"][\n", " \"ArticleResponse3XML\"\n", " ][\"outputs\"]:\n", " if output[\"error_type\"] != \"ok\":\n", " print(output[\"error_message\"], \"\\n\")\n", " print(output[\"raw\"].content)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "results_list = {\n", " \"System Prompt\": structure_support_by_model_sp,\n", " \"System Prompt (FM)\": structure_support_by_model_sprem,\n", " \"User Prompt\": structure_support_by_model_up,\n", " \"Output Parsers\": structure_support_by_model_parsers,\n", "}\n", "\n", "df_results = {}\n", "for name, ss_results in results_list.items():\n", " df_results[name] = pd.DataFrame.from_dict(\n", " {\n", " mname: {\n", " tname: ss_results[mname][tname][\"valid\"] * 100\n", " for tname in ss_results[mname].keys()\n", " }\n", " for mname in ss_results.keys()\n", " },\n", " orient=\"index\",\n", " )\n", " display(name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.concat(df_results).swaplevel(axis=0).sort_index(axis=0)\n", "df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }