"""Streamlit page that runs natural-language questions against PostgreSQL.

The page reads ``PG_URI`` and ``OPENAI_API_KEY`` from the environment (loaded
from a ``.env`` file), builds a ``VectorSQLDatabaseChain`` and shows the raw
result of the generated SQL query.

NOTE(review): the page heading references CVE-2024-21513. Per the public
advisory, affected ``langchain-experimental`` releases evaluate values read
back from the database inside ``VectorSQLDatabaseChain``. The chain also
executes LLM-generated SQL. Confirm the installed version is patched and that
``PG_URI`` points at a least-privilege role before exposing this page.
"""

import os

import streamlit as st
from dotenv import load_dotenv
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import SQL_PROMPTS
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts.prompt import PromptTemplate
from langchain.utilities.sql_database import SQLDatabase
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser

# Prompt used by the chain's query-checker pass. The text is kept exactly as
# it appears in the original literal; it is only split across source lines.
QUERY_CHECKER = (
    " {query} Double check the {dialect} query above for common mistakes, including:"
    " - Using NOT IN with NULL values"
    " - Using UNION when UNION ALL should have been used"
    " - Using BETWEEN for exclusive ranges"
    " - Data type mismatch in predicates"
    " - Properly quoting identifiers"
    " - Using the correct number of arguments for functions"
    " - Casting to the correct data type"
    " - Using the proper columns for joins"
    " If there are any of the above mistakes, rewrite the query."
    " If there are no mistakes, just reproduce the original query."
    " Output the final SQL query only."
    " Remove ```sql and any other ``` from the output."
    " SQL Query: "
)

load_dotenv()


@st.cache_resource
def _build_chain(pg_uri: str, openai_api_key: str) -> VectorSQLDatabaseChain:
    """Create the database handle, model clients and SQL chain once.

    Streamlit re-executes the script on every widget interaction. Caching the
    chain as a resource avoids opening a fresh database engine and new OpenAI
    clients on each rerun.

    Args:
        pg_uri: Database URI passed to ``SQLDatabase.from_uri``.
        openai_api_key: API key handed to the chat model and the embedder.

    Returns:
        The configured ``VectorSQLDatabaseChain``.
    """
    db = SQLDatabase.from_uri(pg_uri)

    llm = ChatOpenAI(model="gpt-4o", temperature=0, openai_api_key=openai_api_key)
    encoder = OpenAIEmbeddings(
        model="text-embedding-3-small",
        openai_api_key=openai_api_key,
    )

    # The stock PostgreSQL SQL prompt drives query generation.
    llm_chain = LLMChain(llm=llm, prompt=SQL_PROMPTS["postgresql"])

    # Output parser for VectorSQLDatabaseChain, built around the embedding model.
    sql_cmd_parser = VectorSQLOutputParser(model=encoder)

    query_checker_prompt = PromptTemplate(
        template=QUERY_CHECKER,
        input_variables=["query", "dialect"],
    )

    return VectorSQLDatabaseChain(
        llm_chain=llm_chain,
        database=db,
        sql_cmd_parser=sql_cmd_parser,
        use_query_checker=True,
        query_checker_prompt=query_checker_prompt,
        return_direct=True,
        # Add return_sql=True here if you want the SQL queries returned.
    )


def main() -> None:
    """Render the page: check configuration, take a question, show the result."""
    st.markdown("\n\nCVE-2024-21513\n\n", unsafe_allow_html=True)

    pg_uri = os.getenv("PG_URI")
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not pg_uri or not openai_api_key:
        st.error("Missing environment variables. Please check your .env file.")
        return

    # UI boundary: report setup failures on the page instead of a raw traceback.
    try:
        db_chain = _build_chain(pg_uri, openai_api_key)
    except Exception as exc:
        st.error(f"Could not set up the database chain: {exc}")
        return
    st.success("Database connection successful!")

    query = st.text_area("Enter your query")

    if not st.button("Submit"):
        return
    if not query.strip():
        st.warning("Please enter a valid query.")
        return

    # UI boundary: a failed generation or SQL execution should not crash the page.
    try:
        response = db_chain({"query": query, "dialect": "PostgreSQL", "top_k": 3})
    except Exception as exc:
        st.error(f"Query failed: {exc}")
        return

    st.success("Result from the database:")
    result = response["result"]
    # Truthiness covers an empty list as well as other empty results.
    if not result:
        st.write("Nothing to return")
    else:
        st.write(result)


if __name__ == "__main__":
    st.set_page_config(page_title="Test", layout="wide")
    main()