{"cells": [{"attachments": {}, "cell_type": "markdown", "id": "1c0beb97", "metadata": {}, "source": ["\n"]}, {"attachments": {}, "cell_type": "markdown", "id": "8cd3f128-866a-4857-a00a-df19f926c952", "metadata": {}, "source": ["# Evaporate Demo\n", "\n", "This demo shows how to extract DataFrames from raw text using the approach from the Evaporate paper (Arora et al.): https://arxiv.org/abs/2304.09433.\n", "\n", "The idea is to first \"fit\" on a set of training texts: the fitting step uses an LLM to generate a set of parsing functions from the text.\n", "These fitted functions are then applied to new text at inference time.\n"]}, {"attachments": {}, "cell_type": "markdown", "id": "7c152e41", "metadata": {}, "source": ["If you are opening this notebook on Colab, you will probably need to install LlamaIndex 🦙.\n"]}, {"cell_type": "code", "execution_count": null, "id": "b6cc721a", "metadata": {}, "outputs": [], "source": ["%pip install llama-index-llms-openai\n", "%pip install llama-index-program-evaporate"]}, {"cell_type": "code", "execution_count": null, "id": "3096cbad", "metadata": {}, "outputs": [], "source": ["!pip install llama-index"]}, {"cell_type": "code", "execution_count": null, "id": "db7210f2-f19d-4112-ab72-ddb3afe282f7", "metadata": {}, "outputs": [], "source": ["%load_ext autoreload\n", "%autoreload 2"]}, {"attachments": {}, "cell_type": "markdown", "id": "da19d340-57b5-439f-9cb1-5ba9576ec304", "metadata": {}, "source": ["## Use `DFEvaporateProgram`\n", "\n", "The `DFEvaporateProgram` extracts a 2D dataframe from a set of datapoints, given a set of fields and some training data on which to \"fit\" the extraction functions.\n"]}, {"attachments": {}, "cell_type": "markdown", "id": "a299cad8-af81-4974-a3de-ed43877d3490", "metadata": {}, "source": ["### Load data\n", "\n", "Here we load a set of city articles from Wikipedia.\n"]}, {"cell_type": "code", "execution_count": null, "id": "daf434f6-3b27-4805-9de8-8fc92d7d776b", "metadata": {}, "outputs": [], "source": ["wiki_titles = [\"Toronto\", \"Seattle\", \"Chicago\", \"Boston\", \"Houston\"]"]}, {"cell_type": "code", "execution_count": null, "id": "8438168c-3b1b-425e-98b0-2c67a8a58a5f", "metadata": {}, "outputs": [], "source": ["from pathlib import Path\n", "\n", "import requests\n", "\n", "for title in wiki_titles:\n", "    response = requests.get(\n", "        \"https://en.wikipedia.org/w/api.php\",\n", 
"        params={\n", "            \"action\": \"query\",\n", "            \"format\": \"json\",\n", "            \"titles\": title,\n", "            \"prop\": \"extracts\",\n", "            # 'exintro': True,\n", "            \"explaintext\": True,\n", "        },\n", "    ).json()\n", "    page = next(iter(response[\"query\"][\"pages\"].values()))\n", "    wiki_text = page[\"extract\"]\n", "\n", "    data_path = Path(\"data\")\n", "    if not data_path.exists():\n", "        Path.mkdir(data_path)\n", "\n", "    with open(data_path / f\"{title}.txt\", \"w\") as fp:\n", "        fp.write(wiki_text)"]}, {"cell_type": "code", "execution_count": null, "id": "c01dbcb8-5ea1-4e76-b5de-ea5ebe4f0392", "metadata": {}, "outputs": [], "source": ["from llama_index.core import SimpleDirectoryReader\n", "\n", "# load all wiki documents\n", "city_docs = {}\n", "for wiki_title in wiki_titles:\n", "    city_docs[wiki_title] = SimpleDirectoryReader(\n", "        input_files=[f\"data/{wiki_title}.txt\"]\n", "    ).load_data()"]}, {"attachments": {}, "cell_type": "markdown", "id": "e7310883-2aeb-4a4d-b101-b3279e670ea8", "metadata": {}, "source": ["### Parse data\n"]}, {"cell_type": "code", "execution_count": null, "id": "b8e98279-b4c4-41ec-b696-13e6a6f841a4", "metadata": {}, "outputs": [], "source": ["from llama_index.llms.openai import OpenAI\n", "from llama_index.core import Settings\n", "\n", "# configure global settings\n", "Settings.llm = OpenAI(temperature=0, model=\"gpt-3.5-turbo\")\n", "Settings.chunk_size = 512"]}, {"cell_type": "code", "execution_count": null, "id": "74c6c1c3-b797-45c8-b692-7a6e4bd1898d", "metadata": {}, "outputs": [], "source": ["# get nodes for each document\n", "city_nodes = {}\n", "for wiki_title in wiki_titles:\n", "    docs = city_docs[wiki_title]\n", "    nodes = Settings.node_parser.get_nodes_from_documents(docs)\n", "    city_nodes[wiki_title] = nodes"]}, {"attachments": {}, "cell_type": "markdown", "id": "bb369a78-e634-43f4-805e-52f6ea0f3588", "metadata": {}, "source": ["### Run DFEvaporateProgram\n", "\n", "Here we demonstrate how to extract datapoints with our `DFEvaporateProgram`. Given a set of fields, the `DFEvaporateProgram` can first fit functions on a set of training data, and then run extraction over inference data.\n"]}, {"cell_type": "code", 
"execution_count": null, "id": "6c260836", "metadata": {}, "outputs": [], "source": ["from llama_index.program.evaporate import DFEvaporateProgram\n", "\n", "# 定义程序\n", "program = DFEvaporateProgram.from_defaults(\n", " fields_to_extract=[\"population\"],\n", ")"]}, {"attachments": {}, "cell_type": "markdown", "id": "c548768e-9d4a-4708-9c84-9266503edf01", "metadata": {}, "source": ["在本节中,我们将讨论如何使用Python中的`scipy.optimize.curve_fit`函数来拟合函数到数据。拟合函数是一种将数学模型与实际数据相匹配的方法,它可以用于预测、分析和理解数据。我们将使用`curve_fit`函数来拟合一个简单的线性函数,并讨论如何处理非线性函数的拟合。\n"]}, {"cell_type": "code", "execution_count": null, "id": "6c186eb7-116f-4b28-a508-8639cbc86633", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'population': 'def get_population_field(text: str):\\n \"\"\"\\n Function to extract population. \\n \"\"\"\\n \\n # Use regex to find the population field\\n pattern = r\\'(?<=population of )(\\\\d+,?\\\\d*)\\'\\n population_field = re.search(pattern, text).group(1)\\n \\n # Return the population field as a single value\\n return int(population_field.replace(\\',\\', \\'\\'))'}"]}, "execution_count": null, "metadata": {}, "output_type": "execute_result"}], "source": ["program.fit_fields(city_nodes[\"Toronto\"][:1])"]}, {"cell_type": "code", "execution_count": null, "id": "483676a4-4937-40a8-acd9-8fec4a991270", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["def get_population_field(text: str):\n", " \"\"\"\n", " Function to extract population. 
\n", " \"\"\"\n", " \n", " # Use regex to find the population field\n", " pattern = r'(?<=population of )(\\d+,?\\d*)'\n", " population_field = re.search(pattern, text).group(1)\n", " \n", " # Return the population field as a single value\n", " return int(population_field.replace(',', ''))\n"]}], "source": ["# view the extracted function\n", "print(program.get_function_str(\"population\"))"]}, {"attachments": {}, "cell_type": "markdown", "id": "508a442c-d7d8-4a27-8add-1d58f1ecc66b", "metadata": {}, "source": ["### Run inference\n"]}, {"cell_type": "code", "execution_count": null, "id": "83e38b62-bad0-4154-9597-555a27e976d9", "metadata": {}, "outputs": [], "source": ["seattle_df = program(nodes=city_nodes[\"Seattle\"][:1])"]}, {"cell_type": "code", "execution_count": null, "id": "dc72f611-da8b-4882-b532-69e46b9589bb", "metadata": {}, "outputs": [{"data": {"text/plain": ["DataFrameRowsOnly(rows=[DataFrameRow(row_values=[749256])])"]}, "execution_count": null, "metadata": {}, "output_type": "execute_result"}], "source": ["seattle_df"]}, {"attachments": {}, "cell_type": "markdown", "id": "9465ba41-8318-40bb-a202-49df6e3c16e3", "metadata": {}, "source": ["## Use `MultiValueEvaporateProgram`\n", "\n", "In contrast to the `DFEvaporateProgram`, which assumes the output follows a 2D tabular format (one row per node), the `MultiValueEvaporateProgram` returns a list of `DataFrameRow` objects, where each object corresponds to a column and can hold a variable number of values. This helps when we want to extract multiple values for one field from a given piece of text.\n", "\n", "In this example, we use this program to parse gold medal counts.\n"]}, {"cell_type": "code", "execution_count": null, "id": "c3d5e9dd-0d20-447b-96b2-a82f8350e430", "metadata": {}, "outputs": [], "source": ["Settings.llm = OpenAI(temperature=0, model=\"gpt-4\")\n", "Settings.chunk_size = 1024\n", "Settings.chunk_overlap = 0"]}, {"cell_type": "code", "execution_count": null, "id": "08b44698-4f7e-4686-9b6e-1b77c341a778", "metadata": {}, "outputs": [], "source": ["from llama_index.core.data_structs import Node\n", "\n", "# Olympic total medal counts: https://en.wikipedia.org/wiki/All-time_Olympic_Games_medal_table\n", "\n", "train_text = \"\"\"\n", "
Team (IOC code)\n", " | \n", "No. Summer\n", " | \n", "No. Winter\n", " | \n", "No. Games\n", " |
 ---|---|---|---|
 Albania (ALB)\n", " | \n", "9 | \n", "5 | \n", "14\n", " |
 American Samoa (ASA)\n", " | \n", "9 | \n", "2 | \n", "11\n", " |
 Andorra (AND)\n", " | \n", "12 | \n", "13 | \n", "25\n", " |
 Angola (ANG)\n", " | \n", "10 | \n", "0 | \n", "10\n", " |
 Antigua and Barbuda (ANT)\n", " | \n", "11 | \n", "0 | \n", "11\n", " |
 Aruba (ARU)\n", " | \n", "9 | \n", "0 | \n", "9\n", " | Bangladesh (BAN)\n", " | \n", "10 | \n", "0 | \n", "10\n", " | \n", "
Belize (BIZ) [BIZ]\n", " | \n", "13 | \n", "0 | \n", "13\n", " |
Benin (BEN) [BEN]\n", " | \n", "12 | \n", "0 | \n", "12\n", " |
Bhutan (BHU)\n", " | \n", "10 | \n", "0 | \n", "10\n", " |
Bolivia (BOL)\n", " | \n", "15 | \n", "7 | \n", "22\n", " |
Bosnia and Herzegovina (BIH)\n", " | \n", "8 | \n", "8 | \n", "16\n", " |
British Virgin Islands (IVB)\n", " | \n", "10 | \n", "2 | \n", "12\n", " |
Brunei (BRU) [A]\n", " | \n", "6 | \n", "0 | \n", "6\n", " |
Cambodia (CAM)\n", " | \n", "10 | \n", "0 | \n", "10\n", " |
Cape Verde (CPV)\n", " | \n", "7 | \n", "0 | \n", "7\n", " |
Cayman Islands (CAY)\n", " | \n", "11 | \n", "2 | \n", "13\n", " |
Central African Republic (CAF)\n", " | \n", "11 | \n", "0 | \n", "11\n", " |
Chad (CHA)\n", " | \n", "13 | \n", "0 | \n", "13\n", " |
Comoros (COM)\n", " | \n", "7 | \n", "0 | \n", "7\n", " |
Republic of the Congo (CGO)\n", " | \n", "13 | \n", "0 | \n", "13\n", " |
Democratic Republic of the Congo (COD) [COD]\n", " | \n", "11 | \n", "0 | \n", "11\n", " | (.*?) | \\', text)\\n \\n # Return the result as a list\\n return medal_count_field'}"]}, "execution_count": null, "metadata": {}, "output_type": "execute_result"}], "source": ["program.fit_fields(train_nodes[:1])"]}, {"cell_type": "code", "execution_count": null, "id": "cc32440c-910a-483c-81df-80ae81fedb2d", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["def get_countries_field(text: str):\n", " \"\"\"\n", " Function to extract countries. \n", " \"\"\"\n", " \n", " # Use regex to extract the countries field\n", " countries_field = re.findall(r'(.*)', text)\n", " \n", " # Return the result as a list\n", " return countries_field\n"]}], "source": ["print(program.get_function_str(\"countries\"))"]}, {"cell_type": "code", "execution_count": null, "id": "8ed16aa9-8b36-439a-a596-1b90d6775a30", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["def get_medal_count_field(text: str):\n", " \"\"\"\n", " Function to extract medal_count. \n", " \"\"\"\n", " \n", " # Use regex to extract the medal count field\n", " medal_count_field = re.findall(r'(.*?) 
 | ', text)\n", " \n", " # Return the result as a list\n", " return medal_count_field\n"]}], "source": ["print(program.get_function_str(\"medal_count\"))"]}, {"cell_type": "code", "execution_count": null, "id": "8f7bae5f-ee4e-4d9f-b551-1986efd317b3", "metadata": {}, "outputs": [], "source": ["result = program(nodes=infer_nodes[:1])"]}, {"cell_type": "code", "execution_count": null, "id": "85bc4f9c-9e6c-41da-b6fb-a8b227b3ce67", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Countries: ['Bangladesh', '[BIZ]', '[BEN]', 'Bhutan', 'Bolivia', 'Bosnia and Herzegovina', 'British Virgin Islands', '[A]', 'Cambodia', 'Cape Verde', 'Cayman Islands', 'Central African Republic', 'Chad', 'Comoros', 'Republic of the Congo', '[COD]']\n", "\n", "Medal Counts: ['Bangladesh', '[BIZ]', '[BEN]', 'Bhutan', 'Bolivia', 'Bosnia and Herzegovina', 'British Virgin Islands', '[A]', 'Cambodia', 'Cape Verde', 'Cayman Islands', 'Central African Republic', 'Chad', 'Comoros', 'Republic of the Congo', '[COD]']\n", "\n"]}], "source": ["# output countries\n", "print(f\"Countries: {result.columns[0].row_values}\\n\")\n", "# output medal counts (second extracted field)\n", "print(f\"Medal Counts: {result.columns[1].row_values}\\n\")"]}, {"attachments": {}, "cell_type": "markdown", "id": "820768fe-aa23-4999-bcc1-102e6fc817e5", "metadata": {}, "source": ["## Bonus: Use the underlying `EvaporateExtractor`\n", "\n", "The underlying `EvaporateExtractor` offers some additional functionality, such as helping to identify fields over a set of texts.\n", "\n", "Here we show how to use `identify_fields` to determine relevant fields around a general `topic`.\n"]}, {"cell_type": "code", "execution_count": null, "id": "7ff32b4b-a85b-4266-bdf1-7fa492925034", "metadata": {}, "outputs": [], "source": ["# a list of nodes, one node per city, corresponding to the intro paragraph\n", "# city_pop_nodes = []\n", "city_pop_nodes = [city_nodes[\"Toronto\"][0], city_nodes[\"Seattle\"][0]]"]}, {"cell_type": "code", "execution_count": null, "id": "dc96646f-ac7e-407f-87dd-c14c8d83aa84", "metadata": {}, "outputs": [], "source": ["extractor = program.extractor"]}, {"cell_type": "code", "execution_count": null, "id": "1df3a7df-6d00-4487-b114-f45a6dba4764", "metadata": {}, 
"outputs": [], "source": ["# 尝试使用多伦多和西雅图(应该提取“人口”)进行测试\n", "existing_fields = extractor.identify_fields(\n", " city_pop_nodes, topic=\"population\", fields_top_k=4\n", ")"]}, {"cell_type": "code", "execution_count": null, "id": "d8a56bb6-3a26-40db-9ca3-8aa9ed4f2c52", "metadata": {}, "outputs": [{"data": {"text/plain": ["[\"seattle metropolitan area's population\"]"]}, "execution_count": null, "metadata": {}, "output_type": "execute_result"}], "source": ["existing_fields"]}], "metadata": {"kernelspec": {"display_name": "llama_index_v2", "language": "python", "name": "llama_index_v2"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3"}}, "nbformat": 4, "nbformat_minor": 5}