{ "cells": [
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
 "import json\n",
 "import stackprinter  # type: ignore\n",
 "import jupyter_black  # type: ignore\n",
 "from dotenv import load_dotenv  # type: ignore\n",
 "from typing import Literal\n",
 "from baml_client.async_client import b\n",
 "\n",
 "from baml_agents import ActionRunner\n",
 "from baml_agents.jupyter import JupyterBamlMonitor\n",
 "from baml_agents import init_logging, with_model, Action, Result\n",
 "from baml_client import types\n",
 "from notebooks._utils import (\n",
 "    celsius_to_fahrenheit,\n",
 "    city_to_number,\n",
 "    city_to_weather_condition,\n",
 ")\n",
 "\n",
 "init_logging(level=\"INFO\")\n",
 "# stackprinter.set_excepthook()\n",
 "load_dotenv()\n",
 "jupyter_black.load()"
] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
 "b_low_latency = with_model(b, \"gpt-4.1-nano\")\n",
 "\n",
 "\n",
 "async def summarize(action, result):\n",
 "    \"\"\"Ask the low-latency model for a short summary of one executed action.\n",
 "\n",
 "    `action` is the model's chosen next action; `result` is the `Result` of running it.\n",
 "    \"\"\"\n",
 "    # Always pass text: on failure send a labelled error message rather than\n",
 "    # the raw Result object.\n",
 "    result_text = result.content if not result.error else f\"ERROR: {result.content}\"\n",
 "    return await b_low_latency.SummarizeAction(\n",
 "        action=json.dumps(action.chosen_action, indent=4),\n",
 "        result=result_text,\n",
 "    )\n",
 "\n",
 "\n",
 "def new_interaction(action, result):\n",
 "    \"\"\"Record an executed action and its `Result` as a BAML `Interaction` (agent history).\"\"\"\n",
 "    return types.Interaction(\n",
 "        action=str(action),\n",
 "        result=types.Result(content=result.content, error=result.error),\n",
 "    )\n",
 "\n",
 "\n",
 "def is_result_available(action) -> str | None:\n",
 "    \"\"\"Return the final answer if the model chose the `Stop` action, otherwise None.\n",
 "\n",
 "    Note: `Stop` is defined in a later cell; the name is only looked up when this is called.\n",
 "    \"\"\"\n",
 "    if action.chosen_action[\"action_id\"] != Stop.get_action_id():  # type: ignore\n",
 "        return None\n",
 "    return action.chosen_action[\"final_result\"]"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Simple BAML Agent demo" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's put it all together:" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Result(content='The weather in New York is 57.2 degrees fahrenheit with foggy conditions.', error=False)" ] }, "execution_count": 3, "metadata": {}, 
"output_type": "execute_result" } ], "source": [
 "class Stop(Action):\n",
 "    \"\"\"You're ready to provide the final answer or are unable to continue (e.g. stuck in a loop).\"\"\"\n",
 "\n",
 "    final_result: str\n",
 "\n",
 "    def run(self) -> Result:\n",
 "        raise NotImplementedError(\"Stop action should not be called directly.\")\n",
 "\n",
 "\n",
 "class GetWeatherInfo(Action):\n",
 "    \"\"\"Get weather information for a given city.\"\"\"\n",
 "\n",
 "    city: str\n",
 "    measurement: Literal[\"celsius\", \"fahrenheit\"] | None = None\n",
 "\n",
 "    def run(self) -> Result:\n",
 "        # Resolve the unit locally instead of writing the default back onto\n",
 "        # `self.measurement`, so running the action does not mutate its arguments.\n",
 "        unit = (self.measurement or \"celsius\").lower()\n",
 "        # Demo data from notebooks._utils; presumably a per-city value in °C within [-10, 35].\n",
 "        temperature = city_to_number(self.city, -10, 35)\n",
 "        condition = city_to_weather_condition(self.city)\n",
 "        if unit == \"fahrenheit\":\n",
 "            temperature = celsius_to_fahrenheit(temperature)\n",
 "        else:\n",
 "            unit = \"celsius\"\n",
 "        content = f\"The weather in {self.city} is {round(temperature, 1)} degrees {unit} with {condition.lower()} conditions.\"\n",
 "        return Result(content=content, error=False)\n",
 "\n",
 "\n",
 "GetWeatherInfo(city=\"New York\", measurement=\"fahrenheit\").run()"
] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Queried the current local date and time.\n", "Retrieved current weather information for Los Angeles in Fahrenheit.\n", "Retrieved the weather information for New York in Fahrenheit (57.2°F, foggy conditions).\n", "Retrieved the current weather in Chicago: 37.4°F with cloudy conditions.\n", "Calculated the average of 50.0, 57.2, and 37.4.\n" ] }, { "data": { "text/plain": [ "'The current date is 2025-05-03. 
The average temperature between Los Angeles, New York, and Chicago is 48.2°F.'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "from baml_client.async_client import b\n",
 "from baml_client.type_builder import TypeBuilder\n",
 "\n",
 "# `b` is re-imported above so re-running this cell wraps the original client,\n",
 "# not one already wrapped by a previous run.\n",
 "b = with_model(b, \"gpt-4.1\")\n",
 "r = ActionRunner(TypeBuilder, b=b, cache=True)\n",
 "r.add_from_mcp_server(server=\"uvx mcp-server-calculator\")\n",
 "r.add_from_mcp_server(server=\"uvx mcp-timeserver\")  # Note: you can also add URLs\n",
 "r.add_action(GetWeatherInfo)\n",
 "r.add_action(Stop)\n",
 "\n",
 "\n",
 "async def execute_task(r, b, task: str, max_steps: int = 20) -> str:\n",
 "    \"\"\"Run the agent loop for `task` until the model chooses `Stop`; return its answer.\n",
 "\n",
 "    Each step asks `b` for the next action (using the `NextAction` type built by\n",
 "    runner `r`), runs it with `r`, records the interaction and prints a summary.\n",
 "\n",
 "    Raises:\n",
 "        RuntimeError: if no final answer is produced within `max_steps` steps.\n",
 "    \"\"\"\n",
 "    interactions = []\n",
 "    # Bounded loop: a model that never picks `Stop` would otherwise spin (and bill) forever.\n",
 "    for _ in range(max_steps):\n",
 "        action = await b.GetNextAction(\n",
 "            task, interactions, baml_options={\"tb\": r.tb(\"NextAction\")}\n",
 "        )\n",
 "\n",
 "        # Compare with None so an empty final answer still ends the loop instead of\n",
 "        # being handed to `r.run` (presumably reaching `Stop.run()`, which raises).\n",
 "        if (final_result := is_result_available(action)) is not None:\n",
 "            return final_result\n",
 "\n",
 "        result = r.run(action)\n",
 "\n",
 "        interactions.append(new_interaction(action, result))\n",
 "        print(await summarize(action, result))\n",
 "\n",
 "    raise RuntimeError(f\"No final result after {max_steps} steps for task: {task!r}\")\n",
 "\n",
 "\n",
 "task = \"State the current date along with average temperature between LA, NY and Chicago in Fahrenheit.\"\n",
 "await execute_task(r, b, task)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We got the answer, but we'd like more transparency into the exact prompts and completions." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Inspect prompts and completions" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
{'chosen_action': {'action_id': 'stop', 'final_result': 'The current date is 2025-05-03. The average temperature between Los Angeles, New York, and Chicago is approximately 48.2 degrees Fahrenheit.'}}"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Queried the current local date and time.\n",
"Retrieved current weather information for Los Angeles in Fahrenheit.\n",
"Retrieved the weather information for New York in Fahrenheit (57.2°F, foggy conditions).\n",
"Retrieved the current weather in Chicago: 37.4°F with cloudy conditions.\n",
"Calculated the average of 50.0, 57.2, and 37.4.\n"
]
},
{
"data": {
"application/javascript": "\n (function(){\n var el = document.getElementById(\"stream-d7b65667-d72b-4925-a254-b1ac667d927d\");\n if (el) el.remove();\n })();\n ",
"text/plain": [
"