In [None]:
import sqlite3
from contextlib import contextmanager
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from pathlib import Path

import modelcontextprotocol as mcp
from pydantic import BaseModel, Field

# Schema for our database tool
class QueryInput(BaseModel):
    # Input payload for the database tool: one SQL statement plus the
    # optional values to bind to its placeholders.
    # (Kept as comments rather than a docstring so the generated model
    # schema is not altered.)

    # Required: the SQL text to run.
    query: str = Field(default=..., description="SQL query to execute")

    # Optional: values for the statement's placeholders; None when there are none.
    parameters: Optional[List[Any]] = Field(None, description="Query parameters")
 
class QueryResult(BaseModel):
 # Result payload built by DatabaseTool.execute_query.
 # columns: column names read from cursor.description (empty list when the
 # statement produced no result set).
 columns: List[str]
 # rows: the fetched rows, each converted from sqlite3.Row to a plain list.
 rows: List[List[Any]]
 # row_count: len(rows), i.e. the number of rows fetched (not rows affected).
 row_count: int
 
# Database connection manager
@contextmanager
def get_db_connection(db_path: str):
    """Yield a SQLite connection to ``db_path`` and close it on exit.

    The connection's row factory is set to ``sqlite3.Row`` so result rows
    can be read by column name as well as by position. This manager does
    not commit; it only guarantees the connection is closed, whether the
    ``with`` body finishes normally or raises.
    """
    connection = sqlite3.connect(db_path)
    try:
        # sqlite3.Row supports both row["name"] and row[0] access.
        connection.row_factory = sqlite3.Row
        yield connection
    finally:
        connection.close()

# Create a test database
DB_PATH = "example.db"

# DDL for the demo table. IF NOT EXISTS makes the statement a no-op when the
# table is already present in the database file.
CREATE_USERS_TABLE_SQL = """
 CREATE TABLE IF NOT EXISTS users (
 id INTEGER PRIMARY KEY,
 name TEXT NOT NULL,
 email TEXT UNIQUE NOT NULL
 )
 """

with get_db_connection(DB_PATH) as conn:
    conn.execute(CREATE_USERS_TABLE_SQL)


In [None]:
class DatabaseTool:
    """Runs SQL statements against a SQLite database file.

    Every call opens its own connection through ``get_db_connection`` and
    closes it before returning, so no connection is shared between calls.
    """

    def __init__(self, db_path: str):
        """Store the path of the SQLite database file to operate on."""
        self.db_path = db_path

    def execute_query(self, query: str, parameters: Optional[List[Any]] = None) -> QueryResult:
        """Execute a SQL query and return the results.

        Args:
            query: SQL text, optionally containing ``?`` placeholders. It is
                executed verbatim, so it must come from a trusted source.
            parameters: Values bound to the placeholders; ``None`` or an
                empty list when the statement has none.

        Returns:
            A ``QueryResult`` holding the column names, the fetched rows
            (each as a plain list) and the number of fetched rows. A
            statement without a result set (e.g. ``INSERT``) yields empty
            ``columns`` and ``rows``.

        Raises:
            ValueError: If SQLite reports an error while executing or
                committing; the original ``sqlite3.Error`` is chained as
                ``__cause__``.
        """
        with get_db_connection(self.db_path) as conn:
            try:
                cursor = conn.execute(query, parameters or ())

                # cursor.description is None for statements with no result set.
                columns = [description[0] for description in cursor.description or ()]

                # Convert sqlite3.Row objects to plain lists for the result model.
                rows = [list(row) for row in cursor.fetchall()]

                # Persist writes. Without this, closing the connection discards
                # the implicit transaction opened by INSERT/UPDATE/DELETE and
                # the change is silently lost. Done after fetching so the
                # result set is fully read first.
                conn.commit()

                return QueryResult(
                    columns=columns,
                    rows=rows,
                    row_count=len(rows)
                )
            except sqlite3.Error as e:
                # Nothing was committed on this path; the pending transaction
                # is dropped when the connection closes.
                raise ValueError(f"Database error: {e}") from e

    async def handle_query(self, input_data: QueryInput) -> QueryResult:
        """MCP handler for database queries.

        Unpacks ``input_data`` and delegates to ``execute_query``. The call
        is made synchronously inside this coroutine.
        """
        return self.execute_query(input_data.query, input_data.parameters)

# Create our MCP tool
db_tool = DatabaseTool(DB_PATH)

# Create the MCP server and register the handler with its input/output models.
# NOTE(review): confirm `mcp.Server()` and `add_tool(...)` against the API of
# the installed MCP package; the signatures are not verifiable from this file.
server = mcp.Server()
server.add_tool("database", db_tool.handle_query, QueryInput, QueryResult)

# Example usage.
# INSERT OR IGNORE keeps this cell re-runnable: example.db lives on disk and
# `email` is UNIQUE, so a plain INSERT would fail with a constraint error as
# soon as this row already exists from an earlier run.
example_query = QueryInput(
    query="INSERT OR IGNORE INTO users (name, email) VALUES (?, ?)",
    parameters=["John Doe", "john@example.com"]
)

result = db_tool.execute_query(example_query.query, example_query.parameters)
print("Query executed successfully!")


In [None]:
# Let's add some test data
test_users = [
    ("Alice Smith", "alice@example.com"),
    ("Bob Johnson", "bob@example.com"),
    ("Carol White", "carol@example.com")
]

# Insert the first test user (Alice) only.
# INSERT OR IGNORE keeps the cell re-runnable against the on-disk database,
# where `email` is UNIQUE and Alice may already exist from an earlier run.
insert_query = QueryInput(
    query="INSERT OR IGNORE INTO users (name, email) VALUES (?, ?)",
    # `parameters` is declared as a list, so convert the tuple explicitly.
    parameters=list(test_users[0])  # Insert Alice
)

result = db_tool.execute_query(insert_query.query, insert_query.parameters)
print("Inserted first user successfully!")

# Query the data
select_query = QueryInput(
    query="SELECT * FROM users WHERE email = ?",
    parameters=["alice@example.com"]
)

result = db_tool.execute_query(select_query.query, select_query.parameters)
print("\nQuery Result:")
print(f"Columns: {result.columns}")
print(f"Rows: {result.rows}")
print(f"Row count: {result.row_count}")


In [None]:
def report_query_outcome(
    label: str,
    query: str,
    parameters: Optional[List[Any]] = None,
    expect_error: bool = True,
) -> None:
    """Run one query through ``db_tool`` and print whether it raised ``ValueError``.

    Args:
        label: Short name of the scenario, used in the printed message.
        query: SQL text to execute.
        parameters: Values for the query's placeholders, if any.
        expect_error: Whether a ``ValueError`` is the expected outcome. The
            printed marker is a check mark when the outcome matches the
            expectation and a cross when it does not.

    Both outcomes are reported, so a scenario that unexpectedly succeeds no
    longer passes silently.
    """
    try:
        demo_input = QueryInput(query=query, parameters=parameters)
        outcome = db_tool.execute_query(demo_input.query, demo_input.parameters)
    except ValueError as e:
        marker = "✓" if expect_error else "✗"
        print(f"{marker} Caught {label}:", e)
    else:
        marker = "✗" if expect_error else "✓"
        print(f"{marker} No error raised for {label}; rows returned: {outcome.row_count}")


# 1. Invalid SQL
report_query_outcome("invalid SQL", "SELECT * FROMM users")  # Intentional typo

# 2. Constraint Violation (duplicate email)
report_query_outcome(
    "duplicate email",
    "INSERT INTO users (name, email) VALUES (?, ?)",
    ["Another Alice", "alice@example.com"],  # Email already exists
)

# 3. Type Mismatch
# SQLite does not reject a text value compared against an INTEGER column; the
# comparison simply matches no rows. No error is expected for this case.
report_query_outcome(
    "type mismatch",
    "SELECT * FROM users WHERE id = ?",
    ["not_a_number"],  # id should be integer
    expect_error=False,
)

# 4. Missing Table
report_query_outcome("missing table", "SELECT * FROM nonexistent_table")
