{"cells": [{"attachments": {}, "cell_type": "markdown", "id": "6d2b5335", "metadata": {}, "source": ["\"在\n"]}, {"cell_type": "markdown", "id": "a0b5c122-d577-4045-b980-cab2eae2aa0c", "metadata": {}, "source": ["# 从零开始构建高级融合检索器\n", "\n", "在本教程中,我们将向您展示如何从零开始构建一个高级检索器。\n", "\n", "具体来说,我们将向您展示如何从头开始构建我们的`QueryFusionRetriever`。\n", "\n", "这在很大程度上受到了RAG-fusion仓库的启发,网址为:https://github.com/Raudaschl/rag-fusion。\n"]}, {"cell_type": "markdown", "id": "0d82203e-1aa0-4d85-8a0f-3854dfa81494", "metadata": {}, "source": ["## 设置\n", "\n", "我们加载文档并构建一个简单的向量索引。\n"]}, {"cell_type": "code", "execution_count": null, "id": "9bafe694", "metadata": {}, "outputs": [], "source": ["%pip install llama-index-readers-file pymupdf\n", "%pip install llama-index-llms-openai\n", "%pip install llama-index-retrievers-bm25"]}, {"cell_type": "code", "execution_count": null, "id": "c79e8b40-c963-46ee-9601-6c31e5901568", "metadata": {}, "outputs": [], "source": ["import nest_asyncio\n", "\n", "nest_asyncio.apply()"]}, {"cell_type": "markdown", "id": "41d6148f-3185-4a32-973c-316f23e45804", "metadata": {}, "source": ["#### 加载文档\n"]}, {"cell_type": "code", "execution_count": null, "id": "c054a492-56c9-4dae-bede-06739858ba57", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["--2024-04-03 09:32:31-- https://arxiv.org/pdf/2307.09288.pdf\n", "Resolving arxiv.org (arxiv.org)... 151.101.3.42, 151.101.131.42, 151.101.67.42, ...\n", "Connecting to arxiv.org (arxiv.org)|151.101.3.42|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 13661300 (13M) [application/pdf]\n", "Saving to: ‘data/llama2.pdf’\n", "\n", "data/llama2.pdf 100%[===================>] 13.03M 7.44MB/s in 1.8s \n", "\n", "2024-04-03 09:32:33 (7.44 MB/s) - ‘data/llama2.pdf’ saved [13661300/13661300]\n", "\n"]}], "source": ["!mkdir data\n", "!wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""]}, {"attachments": {}, "cell_type": "markdown", "id": "1126a6d3", "metadata": {}, "source": ["如果您在Colab上打开此笔记本,您可能需要安装LlamaIndex 🦙。\n"]}, {"cell_type": "code", "execution_count": null, "id": "0f03cf99", "metadata": {}, "outputs": [], "source": ["!pip install llama-index"]}, {"cell_type": "code", "execution_count": null, "id": "b3b7ec9e-30cf-49ba-9b3b-9beb9a2b6758", "metadata": {}, "outputs": [], "source": ["from pathlib import Path\n", "from llama_index.readers.file import PyMuPDFReader\n", "\n", "loader = PyMuPDFReader()\n", "documents = loader.load(file_path=\"./data/llama2.pdf\")"]}, {"cell_type": "markdown", "id": "46ea385d", "metadata": {}, "source": ["```python\n", "# 设置模型\n", "```\n", "\n", "这里是设置模型的部分。\n"]}, {"cell_type": "code", "execution_count": null, "id": "0bc3bd76", "metadata": {}, "outputs": [], "source": ["import os\n", "\n", "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\""]}, {"cell_type": "code", "execution_count": null, "id": "75c07062", "metadata": {}, "outputs": [], "source": ["from llama_index.llms.openai import OpenAI\n", "from llama_index.embeddings.openai import OpenAIEmbedding\n", "\n", "llm = OpenAI(model=\"gpt-3.5-turbo\", temperature=0.1)\n", "embed_model = OpenAIEmbedding(\n", " model=\"text-embedding-3-small\", embed_batch_size=256\n", ")"]}, {"cell_type": "markdown", "id": "f59a6cd1-802a-4e69-afd1-de6faf4b064b", "metadata": {}, "source": ["#### 加载到向量存储中\n"]}, {"cell_type": "code", "execution_count": null, "id": "b2423b4b-5c3b-4d36-b338-9bf74f0e6a82", "metadata": {}, "outputs": [], "source": ["from llama_index.core import VectorStoreIndex\n", "from llama_index.core.node_parser import SentenceSplitter\n", "\n", "splitter = SentenceSplitter(chunk_size=1024)\n", "index = VectorStoreIndex.from_documents(\n", " documents, transformations=[splitter], embed_model=embed_model\n", ")"]}, {"cell_type": "markdown", "id": "cbc7a9bf-0ef7-45a5-bc76-3c37a8a09c88", "metadata": {}, "source": ["## 定义高级检索器\n", "\n", "我们定义一个高级检索器,执行以下步骤:\n", "1. 查询生成/重写:根据原始用户查询生成多个查询。\n", "2. 对每个查询在一组检索器上执行检索。\n", "3. 重新排序/融合:融合所有查询的结果,并对“融合”出的前几个相关结果应用重新排序步骤!\n", "\n", "然后在下一节中,我们将把这个模块插入到我们的响应合成模块中。\n"]}, {"cell_type": "markdown", "id": "3586a793-3c5a-4d7c-b401-cd6fb71f87a1", "metadata": {}, "source": ["### 第一步:查询生成/重写\n", "\n", "第一步是从原始查询中生成查询,以更好地匹配查询意图,并提高检索结果的精确度/召回率。例如,我们可以将查询重写为更小的查询。\n", "\n", "我们可以通过提示ChatGPT来实现这一点。\n"]}, {"cell_type": "code", "execution_count": null, "id": "0a5183a0-58ce-4cc5-a74b-8428dfe12bb5", "metadata": {}, "outputs": [], "source": ["from llama_index.core import PromptTemplate"]}, {"cell_type": "code", "execution_count": null, "id": "9e745f03-5c06-43b5-ad9d-65c5b8150ba7", "metadata": {}, "outputs": [], "source": ["query_str = \"How do the models developed in this work compare to open-source chat models based on the benchmarks tested?\""]}, {"cell_type": "code", "execution_count": null, "id": "0a17441d-1f14-4de4-a4b3-55177d0a2dee", "metadata": {}, "outputs": [], "source": ["query_gen_prompt_str = (\n", " \"You are a helpful assistant that generates multiple search queries based on a \"\n", " \"single input query. Generate {num_queries} search queries, one on each line, \"\n", " \"related to the following input query:\\n\"\n", " \"Query: {query}\\n\"\n", " \"Queries:\\n\"\n", ")\n", "query_gen_prompt = PromptTemplate(query_gen_prompt_str)"]}, {"cell_type": "code", "execution_count": null, "id": "5c3a7b04-c4fb-456c-8584-5ea93bdc7bf0", "metadata": {}, "outputs": [], "source": ["def generate_queries(llm, query_str: str, num_queries: int = 4):\n", " fmt_prompt = query_gen_prompt.format(\n", " num_queries=num_queries - 1, query=query_str\n", " )\n", " response = llm.complete(fmt_prompt)\n", " queries = response.text.split(\"\\n\")\n", " return queries"]}, {"cell_type": "code", "execution_count": null, "id": "2a577a95-2b58-424d-aa7d-bed9aa9ceb98", "metadata": {}, "outputs": [], "source": ["queries = generate_queries(llm, query_str, num_queries=4)"]}, {"cell_type": "code", "execution_count": null, "id": "73fe53d2-c556-44ed-a255-374ae8eca494", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["['1. Comparison of models developed in this work to open-source chat models in benchmark testing', '2. Performance evaluation of models developed in this work versus open-source chat models on tested benchmarks', '3. Analysis of differences between models developed in this work and open-source chat models in benchmark assessments']\n"]}], "source": ["print(queries)"]}, {"cell_type": "markdown", "id": "a9876c65-1a90-4606-9104-5df10009d935", "metadata": {}, "source": ["### 步骤2:对每个查询执行向量搜索\n", "\n", "现在我们对每个查询运行检索。这意味着我们从每个向量存储中获取前k个最相关的结果。\n", "\n", "**注意**:我们也可以有多个检索器。那么我们运行的总查询数量为N*M,其中N是检索器的数量,M是生成的查询数量。因此将会有N*M个检索列表。\n", "\n", "在这里,我们将使用从我们的向量存储中提供的检索器。如果您想了解如何从头开始构建这个,请参阅[我们的教程](https://docs.llamaindex.ai/en/latest/examples/low_level/retrieval.html#put-this-into-a-retriever)。\n"]}, {"cell_type": "code", "execution_count": null, "id": "53114651-0b57-4cbc-a07f-d906c5820cb7", "metadata": {}, "outputs": [], "source": ["from tqdm.asyncio import tqdm", "", "", "async def run_queries(queries, retrievers):", " \"\"\"对检索器运行查询。\"\"\"", " tasks = []", " for query in queries:", " for i, retriever in enumerate(retrievers):", " tasks.append(retriever.aretrieve(query))", "", " task_results = await tqdm.gather(*tasks)", "", " results_dict = {}", " for i, (query, query_result) in enumerate(zip(queries, task_results)):", " results_dict[(query, i)] = query_result", "", " return results_dict"]}, {"cell_type": "code", "execution_count": null, "id": "d046d284-ab9e-4242-b91a-8d06c472dfaf", "metadata": {}, "outputs": [], "source": ["# 获取检索器", "from llama_index.retrievers.bm25 import BM25Retriever", "", "", "## 向量检索器", "vector_retriever = index.as_retriever(similarity_top_k=2)", "", "## bm25检索器", "bm25_retriever = BM25Retriever.from_defaults(", " docstore=index.docstore, similarity_top_k=2", ")"]}, {"cell_type": "code", "execution_count": null, "id": "6a0ddc59-1602-4078-b3e9-dadb852709fc", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": [" 0%| | 0/6 [00:00 None:", " \"\"\"初始化参数。\"\"\"", " self._retrievers = retrievers", " self._similarity_top_k = similarity_top_k", " self._llm = llm", " super().__init__()", "", " def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:", " \"\"\"检索。\"\"\"", " queries = generate_queries(", " self._llm, query_bundle.query_str, num_queries=4", " )", " results = asyncio.run(run_queries(queries, self._retrievers))", " final_results = fuse_results(", " results, similarity_top_k=self._similarity_top_k", " )", "", " return final_results"]}, {"cell_type": "code", "execution_count": null, "id": "b2b641b1-e64e-4ddf-9cc5-88ff5c57b70e", "metadata": {}, "outputs": [], "source": ["from llama_index.core.query_engine import RetrieverQueryEngine\n", "\n", "fusion_retriever = FusionRetriever(\n", " llm, [vector_retriever, bm25_retriever], similarity_top_k=2\n", ")\n", "\n", "query_engine = RetrieverQueryEngine(fusion_retriever)"]}, {"cell_type": "code", "execution_count": null, "id": "c0d81b5b-39f2-42da-92b9-ce6113fa43d9", "metadata": {}, "outputs": [], "source": ["response = query_engine.query(query_str)"]}, {"cell_type": "code", "execution_count": null, "id": "93daeaaa-ce68-465a-b246-287714b4b370", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["The models developed in this work, specifically the Llama 2-Chat models, outperform open-source chat models on most benchmarks that were tested.\n"]}], "source": ["print(str(response))"]}], "metadata": {"kernelspec": {"display_name": "venv", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3"}}, "nbformat": 4, "nbformat_minor": 5}