{ "cells": [ { "cell_type": "raw", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "# 🔄 Advanced Protocol Features in MCP\n", "\n", "Learn how to leverage advanced protocol features in MCP for building sophisticated tools and applications. This notebook covers protocol extensions, custom serialization, middleware, and advanced communication patterns.\n", "\n", "## 🎯 Learning Objectives\n", "\n", "By the end of this notebook, you will:\n", "- Implement custom protocol extensions\n", "- Create middleware components\n", "- Handle complex data types\n", "- Use advanced communication patterns\n", "- Optimize protocol performance\n", "\n", "## 📋 Prerequisites\n", "\n", "- Completed notebooks 01-12\n", "- Understanding of async programming\n", "- Knowledge of serialization\n", "- Familiarity with middleware concepts\n", "\n", "## 🔑 Key Concepts\n", "\n", "1. **Protocol Extensions**\n", "   - Custom serializers\n", "   - Protocol hooks\n", "   - Message transformers\n", "   - Type converters\n", "\n", "2. **Middleware**\n", "   - Request/response pipeline\n", "   - Authentication\n", "   - Caching\n", "   - Logging\n", "\n", "3. **Communication Patterns**\n", "   - Request-response\n", "   - Streaming\n", "   - Pub-sub\n", "   - Bidirectional\n", "\n", "## 📚 Table of Contents\n", "\n", "1. [Protocol Extensions](#extensions)\n", "2. [Middleware Components](#middleware)\n", "3. [Communication Patterns](#patterns)\n", "4. [Performance Optimization](#performance)\n", "5. 
[Best Practices](#practices)\n", "\n", "---\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from typing import Any, Dict, List, Optional, Type, TypeVar, Generic\n", "import json\n", "import asyncio\n", "from datetime import datetime\n", "from dataclasses import dataclass\n", "from abc import ABC, abstractmethod\n", "import modelcontextprotocol as mcp\n", "from pydantic import BaseModel, Field\n", "\n", "T = TypeVar('T')\n", "\n", "# Custom serializer\n", "class JSONSerializer:\n", "    \"\"\"Custom JSON serializer with advanced type handling.\"\"\"\n", "\n", "    @staticmethod\n", "    def serialize(obj: Any) -> str:\n", "        \"\"\"Serialize an object to JSON string.\"\"\"\n", "        if isinstance(obj, datetime):\n", "            # Emit a quoted JSON string so deserialize() can json.loads() it\n", "            return json.dumps(obj.isoformat())\n", "        elif isinstance(obj, BaseModel):\n", "            return obj.json()\n", "        return json.dumps(obj)\n", "\n", "    @staticmethod\n", "    def deserialize(data: str, cls: Type[T]) -> T:\n", "        \"\"\"Deserialize JSON string to object.\"\"\"\n", "        raw = json.loads(data)\n", "        if issubclass(cls, BaseModel):\n", "            return cls.parse_obj(raw)\n", "        elif cls == datetime:\n", "            return datetime.fromisoformat(raw)\n", "        return cls(raw)\n", "\n", "# Middleware base class\n", "class Middleware(ABC):\n", "    \"\"\"Base class for middleware components.\"\"\"\n", "\n", "    @abstractmethod\n", "    async def process_request(self, request: Any) -> Any:\n", "        \"\"\"Process incoming request.\"\"\"\n", "        pass\n", "\n", "    @abstractmethod\n", "    async def process_response(self, response: Any) -> Any:\n", "        \"\"\"Process outgoing response.\"\"\"\n", "        pass\n", "\n", "# Authentication middleware\n", "class AuthMiddleware(Middleware):\n", "    \"\"\"Middleware for authentication.\"\"\"\n", "\n", "    def __init__(self, api_key: str):\n", "        self.api_key = api_key\n", "\n", "    async def process_request(self, request: Any) -> Any:\n", "        \"\"\"Validate API key.\"\"\"\n", "        if not hasattr(request, \"api_key\") or request.api_key != 
self.api_key:\n", "            raise ValueError(\"Invalid API key\")\n", "        return request\n", "\n", "    async def process_response(self, response: Any) -> Any:\n", "        \"\"\"Pass through response.\"\"\"\n", "        return response\n", "\n", "# Caching middleware\n", "class CacheMiddleware(Middleware):\n", "    \"\"\"Middleware for response caching.\"\"\"\n", "\n", "    def __init__(self, ttl: int = 300):\n", "        self.cache: Dict[str, Any] = {}\n", "        self.ttl = ttl\n", "        self.timestamps: Dict[str, datetime] = {}\n", "        # Key of the request currently in flight (not safe for concurrent requests)\n", "        self._pending_key: Optional[str] = None\n", "\n", "    def _is_cached(self, key: str) -> bool:\n", "        \"\"\"Check if key is cached and not expired.\"\"\"\n", "        if key not in self.cache:\n", "            return False\n", "        if (datetime.now() - self.timestamps[key]).total_seconds() > self.ttl:\n", "            del self.cache[key]\n", "            del self.timestamps[key]\n", "            return False\n", "        return True\n", "\n", "    async def process_request(self, request: Any) -> Any:\n", "        \"\"\"Return the cached response on a hit, otherwise pass the request through.\"\"\"\n", "        cache_key = JSONSerializer.serialize(request)\n", "        if self._is_cached(cache_key):\n", "            return self.cache[cache_key]\n", "        # Remember the request key so process_response stores under it\n", "        self._pending_key = cache_key\n", "        return request\n", "\n", "    async def process_response(self, response: Any) -> Any:\n", "        \"\"\"Cache response under the key of the request that produced it.\"\"\"\n", "        if self._pending_key is not None:\n", "            self.cache[self._pending_key] = response\n", "            self.timestamps[self._pending_key] = datetime.now()\n", "            self._pending_key = None\n", "        return response\n", "\n", "# Logging middleware\n", "class LoggingMiddleware(Middleware):\n", "    \"\"\"Middleware for request/response logging.\"\"\"\n", "\n", "    async def process_request(self, request: Any) -> Any:\n", "        \"\"\"Log request.\"\"\"\n", "        print(f\"Request: {request}\")\n", "        return request\n", "\n", "    async def process_response(self, response: Any) -> Any:\n", "        \"\"\"Log response.\"\"\"\n", "        print(f\"Response: {response}\")\n", "        return response\n", "\n", "# Middleware chain\n", "class MiddlewareChain:\n", "    \"\"\"Chain of middleware components.\"\"\"\n", "\n", "    def __init__(self):\n", "        self.middlewares: List[Middleware] = []\n", "\n", "    def 
add(self, middleware: Middleware) -> None:\n", "        \"\"\"Add middleware to chain.\"\"\"\n", "        self.middlewares.append(middleware)\n", "\n", "    async def process_request(self, request: Any) -> Any:\n", "        \"\"\"Process request through middleware chain.\"\"\"\n", "        for middleware in self.middlewares:\n", "            request = await middleware.process_request(request)\n", "        return request\n", "\n", "    async def process_response(self, response: Any) -> Any:\n", "        \"\"\"Process response through middleware chain.\"\"\"\n", "        for middleware in reversed(self.middlewares):\n", "            response = await middleware.process_response(response)\n", "        return response\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Example models\n", "class WeatherRequest(BaseModel):\n", "    \"\"\"Weather request with authentication.\"\"\"\n", "    api_key: str = Field(..., description=\"API key for authentication\")\n", "    location: str = Field(..., description=\"Location to get weather for\")\n", "    units: str = Field(default=\"metric\", description=\"Units (metric/imperial)\")\n", "\n", "class WeatherResponse(BaseModel):\n", "    \"\"\"Weather response with timestamp.\"\"\"\n", "    location: str = Field(..., description=\"Location\")\n", "    temperature: float = Field(..., description=\"Temperature\")\n", "    humidity: float = Field(..., description=\"Humidity percentage\")\n", "    timestamp: datetime = Field(default_factory=datetime.now, description=\"Response timestamp\")\n", "\n", "# Example tool\n", "class WeatherTool:\n", "    \"\"\"Weather service with middleware support.\"\"\"\n", "\n", "    def __init__(self, api_key: str):\n", "        self.middleware = MiddlewareChain()\n", "\n", "        # Add middleware components\n", "        self.middleware.add(LoggingMiddleware())\n", "        self.middleware.add(AuthMiddleware(api_key))\n", "        self.middleware.add(CacheMiddleware(ttl=60))  # Cache for 1 minute\n", "\n", "    async def get_weather(self, request: WeatherRequest) -> WeatherResponse:\n", "        \"\"\"Get 
weather for location.\"\"\"\n", "        # Process request through middleware\n", "        processed_request = await self.middleware.process_request(request)\n", "\n", "        # A cache hit comes back from the chain as a ready-made response\n", "        if isinstance(processed_request, WeatherResponse):\n", "            return processed_request\n", "\n", "        # Simulate API call\n", "        await asyncio.sleep(1)\n", "        response = WeatherResponse(\n", "            location=processed_request.location,\n", "            temperature=20.5,\n", "            humidity=65.0\n", "        )\n", "\n", "        # Process response through middleware\n", "        processed_response = await self.middleware.process_response(response)\n", "        return processed_response\n", "\n", "# Create MCP server with weather tool\n", "weather_tool = WeatherTool(api_key=\"secret-key\")\n", "server = mcp.Server()\n", "server.add_tool(\"weather\", weather_tool.get_weather, WeatherRequest, WeatherResponse)\n", "\n", "# Test the weather tool\n", "async def test_weather_tool():\n", "    # Test with valid API key\n", "    print(\"Testing with valid API key...\")\n", "    request = WeatherRequest(\n", "        api_key=\"secret-key\",\n", "        location=\"London\",\n", "        units=\"metric\"\n", "    )\n", "    response = await weather_tool.get_weather(request)\n", "    print(f\"Weather: {response}\\n\")\n", "\n", "    # Test caching (should be instant)\n", "    print(\"Testing cache...\")\n", "    start = datetime.now()\n", "    response = await weather_tool.get_weather(request)\n", "    duration = (datetime.now() - start).total_seconds()\n", "    print(f\"Cached response time: {duration:.3f} seconds\")\n", "    print(f\"Weather: {response}\\n\")\n", "\n", "    # Test with invalid API key\n", "    print(\"Testing with invalid API key...\")\n", "    try:\n", "        request = WeatherRequest(\n", "            api_key=\"wrong-key\",\n", "            location=\"London\",\n", "            units=\"metric\"\n", "        )\n", "        await weather_tool.get_weather(request)\n", "    except ValueError as e:\n", "        print(f\"Caught expected error: {e}\")\n", "\n", "# Run tests\n", "await test_weather_tool()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Stream models\n", "class DataStreamRequest(BaseModel):\n", "    \"\"\"Request for 
streaming data.\"\"\"\n", "    stream_id: str = Field(..., description=\"Stream identifier\")\n", "    batch_size: int = Field(default=10, description=\"Number of items per batch\")\n", "\n", "class DataStreamResponse(BaseModel):\n", "    \"\"\"Response with streaming data.\"\"\"\n", "    stream_id: str = Field(..., description=\"Stream identifier\")\n", "    data: List[float] = Field(..., description=\"Batch of data\")\n", "    is_last: bool = Field(default=False, description=\"Whether this is the last batch\")\n", "\n", "# Pub-sub models\n", "class PublishRequest(BaseModel):\n", "    \"\"\"Request to publish data.\"\"\"\n", "    topic: str = Field(..., description=\"Topic to publish to\")\n", "    message: Any = Field(..., description=\"Message to publish\")\n", "\n", "class SubscribeRequest(BaseModel):\n", "    \"\"\"Request to subscribe to topic.\"\"\"\n", "    topic: str = Field(..., description=\"Topic to subscribe to\")\n", "\n", "class Message(BaseModel):\n", "    \"\"\"Message received from subscription.\"\"\"\n", "    topic: str = Field(..., description=\"Topic of the message\")\n", "    data: Any = Field(..., description=\"Message data\")\n", "    timestamp: datetime = Field(default_factory=datetime.now, description=\"Message timestamp\")\n", "\n", "# Stream handler\n", "class StreamHandler:\n", "    \"\"\"Handler for data streaming.\"\"\"\n", "\n", "    def __init__(self):\n", "        self.streams: Dict[str, asyncio.Queue] = {}\n", "\n", "    async def create_stream(self, stream_id: str) -> None:\n", "        \"\"\"Create a new data stream.\"\"\"\n", "        if stream_id not in self.streams:\n", "            self.streams[stream_id] = asyncio.Queue()\n", "\n", "    async def write_data(self, stream_id: str, data: List[float]) -> None:\n", "        \"\"\"Write data to stream.\"\"\"\n", "        if stream_id in self.streams:\n", "            await self.streams[stream_id].put(data)\n", "\n", "    async def read_data(self, stream_id: str, batch_size: int) -> List[float]:\n", "        \"\"\"Read data from stream.\"\"\"\n", "        if stream_id not in self.streams:\n", "            raise 
ValueError(f\"Stream {stream_id} not found\")\n", "\n", "        queue = self.streams[stream_id]\n", "        # Leftover items from an earlier read, created lazily on first use\n", "        if not hasattr(self, \"_buffers\"):\n", "            self._buffers = {}\n", "\n", "        data = self._buffers.pop(stream_id, [])\n", "        while len(data) < batch_size and not queue.empty():\n", "            batch = await queue.get()\n", "            data.extend(batch)\n", "        # Keep items beyond batch_size for the next read instead of dropping them\n", "        self._buffers[stream_id] = data[batch_size:]\n", "        return data[:batch_size]\n", "\n", "# Pub-sub handler\n", "class PubSubHandler:\n", "    \"\"\"Handler for publish-subscribe pattern.\"\"\"\n", "\n", "    def __init__(self):\n", "        self.topics: Dict[str, List[asyncio.Queue]] = {}\n", "\n", "    async def publish(self, topic: str, message: Any) -> None:\n", "        \"\"\"Publish message to topic.\"\"\"\n", "        if topic not in self.topics:\n", "            self.topics[topic] = []\n", "\n", "        msg = Message(topic=topic, data=message)\n", "        for queue in self.topics[topic]:\n", "            await queue.put(msg)\n", "\n", "    async def subscribe(self, topic: str) -> asyncio.Queue:\n", "        \"\"\"Subscribe to topic.\"\"\"\n", "        if topic not in self.topics:\n", "            self.topics[topic] = []\n", "\n", "        queue = asyncio.Queue()\n", "        self.topics[topic].append(queue)\n", "        return queue\n", "\n", "    async def unsubscribe(self, topic: str, queue: asyncio.Queue) -> None:\n", "        \"\"\"Unsubscribe from topic.\"\"\"\n", "        if topic in self.topics and queue in self.topics[topic]:\n", "            self.topics[topic].remove(queue)\n", "\n", "# Example usage\n", "async def test_streaming():\n", "    print(\"Testing data streaming...\")\n", "    handler = StreamHandler()\n", "\n", "    # Create stream\n", "    stream_id = \"test-stream\"\n", "    await handler.create_stream(stream_id)\n", "\n", "    # Write data\n", "    for i in range(3):\n", "        data = [float(x) for x in range(i*5, (i+1)*5)]\n", "        await handler.write_data(stream_id, data)\n", "        print(f\"Wrote batch {i}: {data}\")\n", "\n", "    # Read data in batches\n", "    batch_size = 7\n", "    while True:\n", "        data = await handler.read_data(stream_id, batch_size)\n", "        if not data:\n", "            break\n", "        print(f\"Read batch (size={batch_size}): {data}\")\n", "\n", "async def test_pubsub():\n", "    
print(\"\\nTesting pub-sub...\")\n", "    handler = PubSubHandler()\n", "\n", "    # Create subscribers\n", "    sub1 = await handler.subscribe(\"weather\")\n", "    sub2 = await handler.subscribe(\"weather\")\n", "\n", "    # Publish messages\n", "    messages = [\n", "        {\"location\": \"London\", \"temp\": 20},\n", "        {\"location\": \"Paris\", \"temp\": 25},\n", "        {\"location\": \"Tokyo\", \"temp\": 30}\n", "    ]\n", "\n", "    for msg in messages:\n", "        await handler.publish(\"weather\", msg)\n", "        print(f\"Published: {msg}\")\n", "\n", "    # Read messages from subscribers\n", "    print(\"\\nSubscriber 1 messages:\")\n", "    while not sub1.empty():\n", "        msg = await sub1.get()\n", "        print(f\"Received: {msg}\")\n", "\n", "    print(\"\\nSubscriber 2 messages:\")\n", "    while not sub2.empty():\n", "        msg = await sub2.get()\n", "        print(f\"Received: {msg}\")\n", "\n", "# Run tests\n", "print(\"Testing advanced communication patterns...\\n\")\n", "await test_streaming()\n", "await test_pubsub()\n" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 }