In [None]:
from typing import Dict, Any, Optional, List, Type, TypeVar, Generic
import json
import asyncio
from datetime import datetime
from dataclasses import dataclass, asdict
from abc import ABC, abstractmethod
import sqlite3
import modelcontextprotocol as mcp
from pydantic import BaseModel, Field

# Generic type parameter for the kind of state object a StateStore holds.
T = TypeVar('T')

# State interface
class State(BaseModel):
    """Common base for state models.

    Every state carries a ``version`` number (default 1) plus creation and
    last-modification timestamps. Both timestamps come from ``datetime.now``,
    so they are naive local times.
    """

    version: int = Field(1, description="State version")
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="Creation timestamp",
    )
    updated_at: datetime = Field(
        default_factory=datetime.now,
        description="Last update timestamp",
    )

    def update(self):
        """Stamp ``updated_at`` with the current local time."""
        now = datetime.now()
        self.updated_at = now

# State store interface
class StateStore(ABC, Generic[T]):
    """Async key/value contract that every concrete state store fulfils.

    Keys are strings and values are state objects of type ``T``. Subclasses
    must implement all four coroutines below; the class itself cannot be
    instantiated.
    """

    @abstractmethod
    async def get(self, key: str) -> Optional[T]:
        """Fetch the state stored under ``key``; ``None`` when there is none."""
        ...

    @abstractmethod
    async def set(self, key: str, state: T) -> None:
        """Store ``state`` under ``key``."""
        ...

    @abstractmethod
    async def delete(self, key: str) -> None:
        """Remove whatever is stored under ``key``."""
        ...

    @abstractmethod
    async def list(self) -> List[str]:
        """Enumerate every key currently held by the store."""
        ...

# Memory state store
class MemoryStateStore(StateStore[T]):
    """In-memory state store implementation.

    States are kept in ``self.store`` (a plain dict) for the lifetime of the
    instance; nothing is persisted. Reads and writes of an existing key go
    through a per-key ``asyncio.Lock`` kept in ``self.locks``.
    """

    def __init__(self):
        self.store: Dict[str, T] = {}
        self.locks: Dict[str, asyncio.Lock] = {}

    def _get_lock(self, key: str) -> asyncio.Lock:
        """Return the lock for ``key``, creating it on first use."""
        if key not in self.locks:
            self.locks[key] = asyncio.Lock()
        return self.locks[key]

    async def get(self, key: str) -> Optional[T]:
        """Return the state stored under ``key``, or ``None`` if absent."""
        if key not in self.store:
            # Answer misses without creating a lock. Otherwise every lookup of
            # an unknown key (e.g. a bogus session token) would leave a
            # permanent entry in ``self.locks``.
            return None
        async with self._get_lock(key):
            return self.store.get(key)

    async def set(self, key: str, state: T) -> None:
        """Store ``state`` under ``key``, replacing any previous value."""
        async with self._get_lock(key):
            self.store[key] = state

    async def delete(self, key: str) -> None:
        """Remove ``key`` if present; deleting a missing key is a no-op."""
        async with self._get_lock(key):
            self.store.pop(key, None)
        # Drop the key's lock too, so deleted keys don't accumulate in
        # ``self.locks``. No critical section in this class awaits while
        # holding a lock, so no other coroutine can be queued on it here.
        self.locks.pop(key, None)

    async def list(self) -> List[str]:
        """Return a snapshot list of all stored keys (taken without locking)."""
        return list(self.store.keys())

# SQLite state store
class SQLiteStateStore(StateStore[T]):
    """SQLite-backed state store implementation.

    Each state is stored as a JSON string in the ``data`` column of the
    ``states`` table, keyed by ``key``. The table is created on construction
    if it is missing. Reads rebuild objects with ``state_type.parse_obj``.
    Writes use the state's own JSON serializer (see ``_serialize``), so
    states are expected to be pydantic models.

    Every operation opens its own connection and closes it before returning.
    A ``db_path`` of ``":memory:"`` would therefore give each call a fresh
    database that lacks the ``states`` table, so use a file path.

    NOTE(review): the ``sqlite3`` calls are synchronous and run on the event
    loop thread, blocking it for the duration of each query. Consider
    ``loop.run_in_executor`` if that becomes a bottleneck.
    """

    def __init__(self, db_path: str, state_type: Type[T]):
        self.db_path = db_path
        self.state_type = state_type
        self.locks: Dict[str, asyncio.Lock] = {}
        self._init_db()

    def _init_db(self) -> None:
        """Create the ``states`` table if it does not already exist."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS states (
                    key TEXT PRIMARY KEY,
                    data TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.commit()
        finally:
            conn.close()

    def _get_lock(self, key: str) -> asyncio.Lock:
        """Return the lock for ``key``, creating it on first use."""
        if key not in self.locks:
            self.locks[key] = asyncio.Lock()
        return self.locks[key]

    @staticmethod
    def _serialize(state: T) -> str:
        """Serialize ``state`` to a JSON string.

        ``json.dumps(state.dict())`` raises ``TypeError`` on ``datetime``
        fields such as ``created_at``/``updated_at``. The model's own JSON
        encoder writes them as ISO 8601 strings, which ``parse_obj`` accepts
        on the way back in.
        """
        dump_json = getattr(state, "model_dump_json", None)  # pydantic v2
        if dump_json is not None:
            return dump_json()
        return state.json()  # pydantic v1

    async def get(self, key: str) -> Optional[T]:
        """Return the state stored under ``key``, or ``None`` if absent."""
        async with self._get_lock(key):
            conn = sqlite3.connect(self.db_path)
            try:
                row = conn.execute(
                    "SELECT data FROM states WHERE key = ?", (key,)
                ).fetchone()
            finally:
                conn.close()
        if row is None:
            # Don't keep a lock for keys that don't exist. No critical section
            # in this class awaits while holding a lock, so nobody can be
            # queued on it.
            self.locks.pop(key, None)
            return None
        return self.state_type.parse_obj(json.loads(row[0]))

    async def set(self, key: str, state: T) -> None:
        """Store ``state`` under ``key``, inserting or updating its row.

        The original ``created_at`` of an existing row is preserved, and
        ``updated_at`` is bumped to the current time.
        """
        data = self._serialize(state)
        async with self._get_lock(key):
            conn = sqlite3.connect(self.db_path)
            try:
                # UPSERT (SQLite >= 3.24.0) instead of INSERT OR REPLACE.
                # REPLACE deletes and re-inserts the row, which reset
                # created_at on every update.
                conn.execute(
                    """
                    INSERT INTO states (key, data) VALUES (?, ?)
                    ON CONFLICT(key) DO UPDATE SET
                        data = excluded.data,
                        updated_at = CURRENT_TIMESTAMP
                    """,
                    (key, data),
                )
                conn.commit()
            finally:
                conn.close()

    async def delete(self, key: str) -> None:
        """Delete the row for ``key``; deleting a missing key is a no-op."""
        async with self._get_lock(key):
            conn = sqlite3.connect(self.db_path)
            try:
                conn.execute("DELETE FROM states WHERE key = ?", (key,))
                conn.commit()
            finally:
                conn.close()
        # Release the per-key lock entry so deleted keys don't leak.
        self.locks.pop(key, None)

    async def list(self) -> List[str]:
        """Return all stored keys (queried without taking any lock)."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.execute("SELECT key FROM states")
            return [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()


In [None]:
# Example state models
class UserState(State):
    """User state model.

    Adds profile fields on top of the inherited ``State`` fields. ``name``
    and ``email`` are required (``...``). ``preferences`` defaults to a new
    empty dict per instance, and ``last_login`` defaults to ``None``.
    """
    name: str = Field(..., description="User name")
    email: str = Field(..., description="User email")
    # Free-form settings; values are untyped (Any).
    preferences: Dict[str, Any] = Field(default_factory=dict, description="User preferences")
    last_login: Optional[datetime] = Field(default=None, description="Last login timestamp")

class SessionState(State):
    """Session state model.

    ``user_id``, ``token`` and ``expires_at`` are required (``...``).
    ``data`` holds arbitrary per-session values and defaults to a new empty
    dict per instance.
    """
    user_id: str = Field(..., description="User ID")
    token: str = Field(..., description="Session token")
    # NOTE(review): compared against naive datetime.now() by callers in this
    # file, so this is presumably a naive local time; confirm before mixing
    # in timezone-aware values.
    expires_at: datetime = Field(..., description="Expiration timestamp")
    data: Dict[str, Any] = Field(default_factory=dict, description="Session data")

# State manager
class StateManager:
    """Coordinates the user and session state stores.

    Users are persisted in SQLite. Sessions live only in memory and expire
    ``session_ttl_seconds`` after creation; expired sessions are evicted the
    next time they are looked up.

    Args:
        db_path: SQLite database file for user states. Defaults to
            ``"users.db"``, resolved relative to the current working directory.
        session_ttl_seconds: Lifetime of newly created sessions, in seconds.
            Defaults to 3600 (one hour).
    """

    def __init__(
        self,
        db_path: str = "users.db",
        session_ttl_seconds: float = 3600.0,
    ):
        # Create state stores
        self.user_store = SQLiteStateStore(db_path, UserState)
        self.session_store = MemoryStateStore[SessionState]()
        self.session_ttl_seconds = session_ttl_seconds

    async def create_user(self, user_id: str, name: str, email: str) -> UserState:
        """Create a user and store it under ``user_id``.

        Any existing user stored under the same ``user_id`` is overwritten.
        """
        user = UserState(
            name=name,
            email=email
        )
        await self.user_store.set(user_id, user)
        return user

    async def get_user(self, user_id: str) -> Optional[UserState]:
        """Return the user stored under ``user_id``, or ``None``."""
        return await self.user_store.get(user_id)

    async def update_user(self, user_id: str, **updates: Any) -> Optional[UserState]:
        """Apply ``updates`` as attribute assignments and save the user.

        Each keyword argument is assigned with ``setattr``. The method then
        refreshes ``updated_at`` and writes the user back.

        Returns:
            The updated user, or ``None`` if no user exists under ``user_id``.

        NOTE(review): whether unknown field names or wrongly-typed values are
        rejected depends on the pydantic model configuration; confirm.
        """
        user = await self.get_user(user_id)
        if user is None:
            return None
        for field_name, value in updates.items():
            setattr(user, field_name, value)
        user.update()
        await self.user_store.set(user_id, user)
        return user

    async def create_session(self, user_id: str) -> SessionState:
        """Create a session for ``user_id`` with a random UUID4 token.

        The session expires ``session_ttl_seconds`` from now. It does not
        check that ``user_id`` exists in the user store.
        """
        import uuid
        from datetime import timedelta

        session = SessionState(
            user_id=user_id,
            token=str(uuid.uuid4()),
            expires_at=datetime.now() + timedelta(seconds=self.session_ttl_seconds),
        )
        await self.session_store.set(session.token, session)
        return session

    async def get_session(self, token: str) -> Optional[SessionState]:
        """Return the live session for ``token``, or ``None``.

        ``None`` is returned when the token is unknown or the session has
        expired. Expired sessions are also deleted from the store.
        """
        session = await self.session_store.get(token)
        if session is None:
            return None
        if session.expires_at <= datetime.now():
            # Evict on read; nothing else ever removes sessions, so expired
            # ones would otherwise stay in the in-memory store forever.
            await self.session_store.delete(token)
            return None
        return session

    async def update_session(self, token: str, data: Dict[str, Any]) -> Optional[SessionState]:
        """Merge ``data`` into a live session's data and save it.

        Returns:
            The updated session, or ``None`` if the token is unknown or the
            session has expired.
        """
        session = await self.get_session(token)
        if session is None:
            return None
        session.data.update(data)
        session.update()
        await self.session_store.set(token, session)
        return session

# Test state management
async def test_state_management():
    """Run an end-to-end demo of StateManager, printing each step.

    The demo creates and updates user ``"user-1"`` in the manager's SQLite
    store, overwriting any existing record under that key. It then creates,
    updates and re-reads a session for that user.
    """
    print("Testing state management...")
    manager = StateManager()

    # Create user
    print("\nCreating user...")
    user = await manager.create_user(
        "user-1",
        "John Doe",
        "john@example.com"
    )
    print(f"Created user: {user}")

    # Update user
    print("\nUpdating user...")
    user = await manager.update_user(
        "user-1",
        preferences={"theme": "dark", "notifications": True}
    )
    print(f"Updated user: {user}")

    # Create session
    print("\nCreating session...")
    session = await manager.create_session("user-1")
    print(f"Created session: {session}")
    # Keep the token. update_session/get_session return None for a missing
    # or expired session, and reading `session.token` off that result would
    # raise AttributeError instead of printing the None.
    token = session.token

    # Update session
    print("\nUpdating session...")
    session = await manager.update_session(
        token,
        {"last_page": "/dashboard"}
    )
    print(f"Updated session: {session}")

    # Get user and session
    print("\nRetrieving states...")
    user = await manager.get_user("user-1")
    print(f"Retrieved user: {user}")

    session = await manager.get_session(token)
    print(f"Retrieved session: {session}")

# Run test
# NOTE: a top-level `await` only works where the environment supports it
# (e.g. IPython/Jupyter, which these `In [...]` cells come from). In a plain
# Python script this line is a SyntaxError; use
# `asyncio.run(test_state_management())` there instead.
await test_state_management()
