import logging
import os

import streamlit as st
from dotenv import load_dotenv
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import SQL_PROMPTS
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts.prompt import PromptTemplate
from langchain.utilities.sql_database import SQLDatabase
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser

# Prompt template for the chain's query-checker step. It asks the LLM to review
# the SQL it generated ({query}) for the given SQL dialect ({dialect}), rewrite
# it if it contains one of the listed mistakes, and output only the final SQL
# without Markdown code fences. main() wraps it in a PromptTemplate with
# input_variables ["query", "dialect"] and passes it as `query_checker_prompt`.
QUERY_CHECKER = """
{query}
Double check the {dialect} query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
Output the final SQL query only. Remove ```sql and any other ``` from the output.
SQL Query: """
# Load variables from a local .env file into the process environment at import
# time, so main() can read PG_URI and OPENAI_API_KEY with os.getenv.
load_dotenv()
# Module logger: full tracebacks go here, while the UI shows only a short message.
_logger = logging.getLogger(__name__)


@st.cache_resource
def _build_chain(pg_uri: str, openai_api_key: str) -> VectorSQLDatabaseChain:
    """Connect to the database and assemble the natural-language-to-SQL chain.

    Decorated with ``st.cache_resource`` so the database connection and the
    OpenAI chat/embedding clients are created once per ``(pg_uri,
    openai_api_key)`` pair. Without the cache, Streamlit would rebuild them on
    every rerun, and it reruns the whole script on each widget interaction.

    Args:
        pg_uri: SQLAlchemy URI of the PostgreSQL database (``PG_URI``).
        openai_api_key: Key used for both the chat model and the embedding model.

    Returns:
        A ``VectorSQLDatabaseChain`` configured with ``return_direct=True``, so
        its ``"result"`` output is the raw database result rather than an LLM
        answer.

    Raises:
        Any exception from ``SQLDatabase.from_uri`` (e.g. unreachable server or
        bad credentials) propagates to the caller.
    """
    db = SQLDatabase.from_uri(pg_uri)
    llm = ChatOpenAI(model="gpt-4o", temperature=0, openai_api_key=openai_api_key)
    encoder = OpenAIEmbeddings(
        model="text-embedding-3-small", openai_api_key=openai_api_key
    )

    # SECURITY (CVE-2024-21513): in langchain-experimental >= 0.0.15, < 0.0.21,
    # VectorSQLDatabaseChain passes every value read from the database to eval().
    # The user's free-text query steers the SQL the LLM writes, so anyone who can
    # type into this page can get arbitrary Python executed on the server. Pin
    # langchain-experimental >= 0.0.21 (or use a chain that does not eval results).
    # NOTE: the LLM-generated SQL also runs with PG_URI's full privileges; connect
    # with a read-only database role.
    return VectorSQLDatabaseChain(
        llm_chain=LLMChain(llm=llm, prompt=SQL_PROMPTS["postgresql"]),
        database=db,
        sql_cmd_parser=VectorSQLOutputParser(model=encoder),
        use_query_checker=True,
        query_checker_prompt=PromptTemplate(
            template=QUERY_CHECKER, input_variables=["query", "dialect"]
        ),
        return_direct=True,
        # The chain fills the prompt's top_k from this field (default 5). A
        # "top_k" key in the call inputs is ignored, so the limit must be set here.
        top_k=3,
        # return_sql=True  # If you want to return SQL queries
    )


def main() -> None:
    """Render a page that answers natural-language questions from PostgreSQL.

    Reads ``PG_URI`` and ``OPENAI_API_KEY`` from the environment and builds
    the (cached) SQL chain. When "Submit" is clicked, it runs the user's text
    through the chain and displays the rows returned by the database. Errors
    while connecting or querying are logged and reported on the page instead
    of surfacing as a raw traceback.
    """
    # Same text as before; unsafe_allow_html is unnecessary because it has no HTML.
    st.markdown("\nCVE-2024-21513\n")

    pg_uri = os.getenv("PG_URI")
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not pg_uri or not openai_api_key:
        st.error("Missing environment variables. Please check your .env file.")
        return

    try:
        db_chain = _build_chain(pg_uri, openai_api_key)
    except Exception as exc:  # UI boundary: report instead of crashing the page.
        _logger.exception("Failed to connect to the database or build the chain")
        st.error(
            f"Could not connect to the database ({type(exc).__name__}). "
            "Check PG_URI and the server logs."
        )
        return
    st.success("Database connection successful!")

    query = st.text_area("Enter your query")
    if not st.button("Submit"):
        return
    if not query.strip():
        st.warning("Please enter a valid query.")
        return

    try:
        # Only "query" is read from the inputs. The dialect comes from the
        # database and the row limit from the chain's top_k field.
        response = db_chain({"query": query})
    except Exception as exc:  # LLM/API errors or invalid generated SQL.
        _logger.exception("SQL chain invocation failed")
        st.error(
            f"The query could not be answered ({type(exc).__name__}). "
            "Try rephrasing it or check the server logs."
        )
        return

    st.success("Result from the database:")
    result = response["result"]
    if not result:
        st.write("Nothing to return")
    else:
        st.write(result)
if __name__ == "__main__":
    # Page configuration must be the first Streamlit command of the run,
    # so it is applied before main() draws anything.
    st.set_page_config(layout="wide", page_title="Test")
    main()