In [None]:
import copy
import json
import threading
from dataclasses import dataclass, asdict
from enum import Enum
from pathlib import Path
from typing import Dict, Any, Optional

import modelcontextprotocol as mcp
from pydantic import BaseModel, Field

# Define state types
class StateType(Enum):
    """Backend in which a StatefulTool keeps its key/value state."""

    MEMORY = "memory"  # plain in-process dict
    FILE = "file"  # dict mirrored to a JSON file on disk
    DATABASE = "database"  # declared; StatefulTool has no database backend for it

class StateInput(BaseModel):
    # Arguments accepted by the "state" MCP tool.
    # A non-None `value` requests a set; omitting it (or sending null) requests a get.
    key: str = Field(default=..., description="State key to get/set")
    value: Optional[Any] = Field(default=None, description="Value to set (if setting state)")
 
class StateOutput(BaseModel):
 # Result of a get/set operation on StatefulTool.
 # key: the key that was read or written.
 # value: the stored value (None when a get finds no entry for the key).
 # state_type: the StateType backend of the tool that served the request.
 key: str
 value: Any
 state_type: StateType
 
class StatefulTool:
    """Thread-safe key/value store exposed to MCP through ``handle_state``.

    State lives in ``self.state`` (a dict), and every read or write goes
    through ``self._lock``. With ``StateType.FILE`` the dict is loaded from
    ``state_file`` at construction time and rewritten in full after every
    successful ``set_state``. Any other ``StateType`` (including ``DATABASE``,
    which has no backend here) keeps state in memory only.

    Values are deep-copied on the way in and on the way out. Callers can
    therefore mutate what they pass in or get back without changing stored
    state behind the lock. Stored values must be deep-copyable, and in FILE
    mode they must also be JSON-serializable.
    """

    def __init__(
        self,
        state_type: StateType = StateType.MEMORY,
        state_file: Optional[Path] = None,
    ):
        """Create the store and, in FILE mode, load any existing state.

        Args:
            state_type: Persistence backend. Only ``StateType.FILE`` touches disk.
            state_file: JSON file used in FILE mode. A ``str`` is also accepted.
                Defaults to ``tool_state.json`` relative to the current working
                directory. Ignored for other state types.

        Raises:
            json.JSONDecodeError: If an existing state file is not valid JSON.
            ValueError: If the state file's top-level JSON value is not an object.
        """
        self.state_type = state_type
        self.state: Dict[str, Any] = {}
        self._lock = threading.Lock()

        if state_type == StateType.FILE:
            self.state_file = (
                Path(state_file) if state_file is not None else Path("tool_state.json")
            )
            self._load_state()

    def _load_state(self) -> None:
        """Replace ``self.state`` with the contents of ``self.state_file``, if present."""
        if not self.state_file.exists():
            return
        with open(self.state_file, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        # Keys are looked up by string, so anything other than a JSON object
        # would only fail later with a confusing error. Reject it up front.
        if not isinstance(loaded, dict):
            raise ValueError(
                f"{self.state_file} must contain a JSON object, "
                f"got {type(loaded).__name__}"
            )
        self.state = loaded

    def _save_state(self) -> None:
        """Persist ``self.state`` to ``self.state_file``. Does nothing outside FILE mode.

        The whole state is serialized to a string *before* any file is touched.
        A serialization error therefore never truncates the existing file. The
        text is written to a temporary sibling file, and that file is then
        renamed over the real one.

        Raises:
            TypeError, ValueError: If some stored value cannot be JSON-encoded.
            OSError: If writing or renaming the file fails.
        """
        if self.state_type != StateType.FILE:
            return
        payload = json.dumps(self.state)
        tmp_file = self.state_file.with_name(self.state_file.name + ".tmp")
        tmp_file.write_text(payload, encoding="utf-8")
        # Path.replace delegates to os.replace, which overwrites the target in
        # a single rename (atomic on POSIX).
        tmp_file.replace(self.state_file)

    def get_state(self, key: str) -> StateOutput:
        """Return the value stored under ``key``. The value is ``None`` if the key is absent.

        The returned value is a deep copy, so mutating it does not affect the store.
        """
        with self._lock:
            value = copy.deepcopy(self.state.get(key))
        return StateOutput(key=key, value=value, state_type=self.state_type)

    def set_state(self, key: str, value: Any) -> StateOutput:
        """Store ``value`` under ``key``. In FILE mode, also persist it.

        If persisting fails (for example because ``value`` is not
        JSON-serializable), the in-memory change is rolled back and the
        exception propagates. The store therefore never holds a value that
        could not be saved.

        Returns:
            A ``StateOutput`` echoing ``key`` and the caller's ``value``.
        """
        # Take a private copy so later mutations by the caller can't leak in.
        stored = copy.deepcopy(value)
        with self._lock:
            had_key = key in self.state
            previous = self.state.get(key)
            self.state[key] = stored
            try:
                self._save_state()
            except Exception:
                # Restore the pre-call state so memory and disk stay consistent.
                if had_key:
                    self.state[key] = previous
                else:
                    del self.state[key]
                raise
        return StateOutput(key=key, value=value, state_type=self.state_type)

    async def handle_state(self, input_data: StateInput) -> StateOutput:
        """MCP handler: set when ``input_data.value`` is not None, otherwise get.

        Because ``None`` selects the get path, this handler cannot store
        ``None`` as a value. The underlying calls are synchronous (and do file
        I/O in FILE mode), so they block the awaiting event loop while they run.
        """
        if input_data.value is None:
            return self.get_state(input_data.key)
        return self.set_state(input_data.key, input_data.value)

# Create our MCP tool.
# FILE mode: state is loaded from / saved to "tool_state.json", a relative path
# resolved against the current working directory when this cell runs.
state_tool = StatefulTool(state_type=StateType.FILE)

# Create the MCP server and register state_tool.handle_state under the tool
# name "state", presumably with StateInput/StateOutput as its input/output models.
# NOTE(review): `modelcontextprotocol`, `mcp.Server()` and this
# `add_tool(name, handler, input_model, output_model)` call can't be verified
# from this file. Confirm them against the installed SDK (the official Python
# MCP SDK is published as `mcp`).
server = mcp.Server()
server.add_tool("state", state_tool.handle_state, StateInput, StateOutput)


In [None]:
import threading
import time

# Basic state operations
print("Setting initial state...")
result = state_tool.set_state("user_preferences", {
    "theme": "dark",
    "language": "en",
    "notifications": True,
})
print(f"Set state: {result}")

print("\nGetting state...")
result = state_tool.get_state("user_preferences")
print(f"Got state: {result}")

print("\nUpdating state...")
# Build a fresh dict instead of editing result.value in place. get_state may
# hand back the same object the tool holds internally, and an in-place edit
# would then change stored state without the lock and without persisting it.
preferences = {**result.value, "theme": "light"}
result = state_tool.set_state("user_preferences", preferences)
print(f"Updated state: {result}")

# Test persistence: a second FILE-mode instance should see what was saved above.
print("\nTesting file persistence...")
new_tool = StatefulTool(state_type=StateType.FILE)
result = new_tool.get_state("user_preferences")
print(f"Loaded state from file: {result}")


# Test concurrent access: each thread counts 0..4 under its own key.
def concurrent_updates():
    """Write 0..4, pausing between writes, to a key named after the current thread."""
    key = f"counter_{threading.current_thread().name}"
    for i in range(5):
        state_tool.set_state(key, i)
        time.sleep(0.1)


print("\nTesting concurrent access...")
threads = [
    threading.Thread(target=concurrent_updates, name=f"thread_{i}")
    for i in range(3)
]

for thread in threads:
    thread.start()

for thread in threads:
    thread.join()

print("\nFinal state after concurrent updates:")
for i in range(3):
    result = state_tool.get_state(f"counter_thread_{i}")
    print(f"Thread {i} final count: {result.value}")
