{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "245065c6",
   "metadata": {},
   "source": [
    "# 使用 MyScale 的向量 SQL 检索器\n",
    "\n",
    ">[MyScale](https://docs.myscale.com/en/) 是一个集成的向量数据库。您既可以通过 SQL 访问数据库，也可以通过 LangChain 访问。MyScale 支持[各种数据类型和过滤函数](https://blog.myscale.com/2023/06/06/why-integrated-database-solution-can-boost-your-llm-apps/#filter-on-anything-without-constraints)。无论您是在扩大数据规模，还是将系统扩展到更广泛的应用场景，它都能提升您的 LLM 应用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0246c5bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the required packages; %pip (unlike !pip3) installs into the running kernel's environment\n",
    "%pip install --quiet clickhouse-sqlalchemy InstructorEmbedding sentence_transformers openai langchain-experimental langchain-openai pandas"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f1c2a7e",
   "metadata": {},
   "source": [
    "## 环境设置\n",
    "\n",
    "导入本示例用到的全部模块，并连接到 MyScale 数据库。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7585d2c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All imports used in this notebook\n",
    "import getpass\n",
    "from os import environ\n",
    "\n",
    "import pandas as pd\n",
    "from langchain.callbacks import StdOutCallbackHandler\n",
    "from langchain.chains import LLMChain\n",
    "from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain_community.embeddings import HuggingFaceInstructEmbeddings\n",
    "from langchain_community.utilities import SQLDatabase\n",
    "from langchain_experimental.retrievers.vector_sql_database import (\n",
    "    VectorSQLDatabaseChainRetriever,\n",
    ")\n",
    "from langchain_experimental.sql.prompt import MYSCALE_PROMPT\n",
    "from langchain_experimental.sql.vector_sql import (\n",
    "    VectorSQLDatabaseChain,\n",
    "    VectorSQLOutputParser,\n",
    "    VectorSQLRetrieveAllOutputParser,\n",
    ")\n",
    "from langchain_openai import ChatOpenAI, OpenAI\n",
    "from sqlalchemy import MetaData, create_engine\n",
    "from sqlalchemy.engine import URL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c4f6a10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MyScale connection settings; each one can be overridden by an environment variable of the same name\n",
    "MYSCALE_HOST = environ.get(\n",
    "    \"MYSCALE_HOST\", \"msc-4a9e710a.us-east-1.aws.staging.myscale.cloud\"\n",
    ")\n",
    "MYSCALE_PORT = int(environ.get(\"MYSCALE_PORT\", \"443\"))\n",
    "MYSCALE_USER = environ.get(\"MYSCALE_USER\", \"chatdata\")\n",
    "MYSCALE_PASSWORD = environ.get(\"MYSCALE_PASSWORD\", \"myscale_rocks\")\n",
    "\n",
    "# Question asked by both chains below\n",
    "QUESTION = \"Please give me 10 papers to ask what is PageRank?\"\n",
    "\n",
    "# Prompt for the OpenAI key only when it is not already set in the environment\n",
    "if not environ.get(\"OPENAI_API_KEY\"):\n",
    "    environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
    "OPENAI_API_KEY = environ[\"OPENAI_API_KEY\"]\n",
    "\n",
    "# URL.create keeps credentials containing special characters (\"@\", \"/\", ...) from corrupting the URL\n",
    "engine = create_engine(\n",
    "    URL.create(\n",
    "        \"clickhouse\",\n",
    "        username=MYSCALE_USER,\n",
    "        password=MYSCALE_PASSWORD,\n",
    "        host=MYSCALE_HOST,\n",
    "        port=MYSCALE_PORT,\n",
    "        database=\"default\",\n",
    "        query={\"protocol\": \"https\"},\n",
    "    )\n",
    ")\n",
    "\n",
    "# MetaData(bind=...) was removed in SQLAlchemy 2.0; the engine is handed to SQLDatabase separately\n",
    "metadata = MetaData()\n",
    "sql_db = SQLDatabase(engine, None, metadata)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d27e5b93",
   "metadata": {},
   "source": [
    "## 使用向量 SQL 链直接查询\n",
    "\n",
    "下面的链直接返回 SQL 查询结果（`return_direct=True`），并将其显示为 DataFrame。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e08d9ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instructor embedding model shared by both SQL output parsers in this notebook\n",
    "embeddings = HuggingFaceInstructEmbeddings(\n",
    "    model_name=\"hkunlp/instructor-xl\", model_kwargs={\"device\": \"cpu\"}\n",
    ")\n",
    "output_parser = VectorSQLOutputParser.from_embeddings(model=embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84b705b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_sql_chain = VectorSQLDatabaseChain(\n",
    "    llm_chain=LLMChain(\n",
    "        llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
    "        prompt=MYSCALE_PROMPT,\n",
    "    ),\n",
    "    top_k=10,\n",
    "    return_direct=True,  # return the SQL query result itself rather than an LLM-written answer\n",
    "    sql_cmd_parser=output_parser,\n",
    "    database=sql_db,\n",
    "    # Return rows as a list of dicts so pd.DataFrame can tabulate them;\n",
    "    # NOTE(review): the non-native result appears to be the rows' string form, which pd.DataFrame rejects\n",
    "    native_format=True,\n",
    ")\n",
    "\n",
    "query_result = vector_sql_chain.invoke(\n",
    "    {\"query\": QUESTION}, config={\"callbacks\": [StdOutCallbackHandler()]}\n",
    ")\n",
    "pd.DataFrame(query_result[\"result\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c09cda0",
   "metadata": {},
   "source": [
    "## SQL 数据库作为检索器\n",
    "\n",
    "把向量 SQL 链包装成检索器，再交给 `RetrievalQAWithSourcesChain` 生成带来源的回答。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "734d7ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The retriever needs all of a row's keys to build documents, so use the retrieve-all parser\n",
    "output_parser_retrieve_all = VectorSQLRetrieveAllOutputParser.from_embeddings(\n",
    "    model=embeddings\n",
    ")\n",
    "\n",
    "retrieval_sql_chain = VectorSQLDatabaseChain.from_llm(\n",
    "    llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),\n",
    "    prompt=MYSCALE_PROMPT,\n",
    "    top_k=10,\n",
    "    return_direct=True,\n",
    "    db=sql_db,\n",
    "    sql_cmd_parser=output_parser_retrieve_all,\n",
    "    native_format=True,\n",
    ")\n",
    "\n",
    "# Use the \"abstract\" column as each document's page content\n",
    "retriever = VectorSQLDatabaseChainRetriever(\n",
    "    sql_db_chain=retrieval_sql_chain, page_content_key=\"abstract\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4948ff25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render each retrieved paper together with its metadata.\n",
    "# NOTE(review): the variable names other than page_content are assumed to match the table's columns -- confirm against the schema\n",
    "document_with_metadata_prompt = PromptTemplate(\n",
    "    input_variables=[\"page_content\", \"id\", \"title\", \"authors\", \"pubdate\", \"categories\"],\n",
    "    template=\"Content:\\n\\tTitle: {title}\\n\\tAbstract: {page_content}\\n\\tAuthors: {authors}\\n\\tDate of Publication: {pubdate}\\n\\tCategories: {categories}\\nSOURCE: {id}\",\n",
    ")\n",
    "\n",
    "qa_chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
    "    ChatOpenAI(\n",
    "        model_name=\"gpt-3.5-turbo-16k\", openai_api_key=OPENAI_API_KEY, temperature=0.6\n",
    "    ),\n",
    "    retriever=retriever,\n",
    "    chain_type=\"stuff\",\n",
    "    chain_type_kwargs={\"document_prompt\": document_with_metadata_prompt},\n",
    "    return_source_documents=True,\n",
    ")\n",
    "\n",
    "qa_result = qa_chain.invoke(\n",
    "    {\"question\": QUESTION}, config={\"callbacks\": [StdOutCallbackHandler()]}\n",
    ")\n",
    "print(qa_result[\"answer\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}