#!/usr/bin/env python
"""FastMCP server that exposes Prometheus HTTP API queries as MCP tools."""

import json
import os
import time
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import dotenv
import requests
from fastmcp import FastMCP, Context
from starlette.requests import Request
from starlette.responses import JSONResponse

from prometheus_mcp_server.logging_config import get_logger

dotenv.load_dotenv()

# Get tool prefix from environment (empty string for backward compatibility)
TOOL_PREFIX = os.environ.get("TOOL_PREFIX", "")


def _tool_name(name: str) -> str:
    """Build tool name with optional prefix.

    Args:
        name: Base tool name.

    Returns:
        ``"<TOOL_PREFIX>_<name>"`` when TOOL_PREFIX is non-empty, otherwise
        ``name`` unchanged.
    """
    return f"{TOOL_PREFIX}_{name}" if TOOL_PREFIX else name


# Include prefix in MCP server name if set
mcp_name = f"Prometheus MCP ({TOOL_PREFIX})" if TOOL_PREFIX else "Prometheus MCP"
mcp = FastMCP(mcp_name)


@mcp.custom_route("/health", methods=["GET"])
async def health_endpoint(request: Request) -> JSONResponse:
    """Answer ``GET /health`` with a static ``{"status": "ok"}`` JSON body."""
    return JSONResponse({"status": "ok"})


# Cache for metrics list to improve completion performance
_metrics_cache = {"data": None, "timestamp": 0}
_CACHE_TTL = 300  # 5 minutes

# Header names (lower-cased) whose values must never reach the logs.
_SENSITIVE_HEADERS = frozenset(
    {"authorization", "proxy-authorization", "cookie", "x-api-key"}
)

# String values (lower-cased) that switch a boolean environment flag on.
_TRUTHY_VALUES = ("true", "1", "yes")


def clear_metrics_cache():
    """Reset the metrics cache, forcing the next fetch to hit Prometheus."""
    _metrics_cache["data"] = None
    _metrics_cache["timestamp"] = 0


# Get logger instance
logger = get_logger()


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Produces the same text as the deprecated ``datetime.utcnow().isoformat()``
    (no UTC offset suffix) without calling the deprecated API.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def _redact_headers(headers: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``headers`` that is safe to log.

    Values of headers listed in ``_SENSITIVE_HEADERS`` (matched
    case-insensitively) are replaced with a placeholder; all other entries are
    copied unchanged.
    """
    return {
        key: "***REDACTED***" if str(key).lower() in _SENSITIVE_HEADERS else value
        for key, value in headers.items()
    }


def _env_flag(name: str, default: str) -> bool:
    """Read environment variable ``name`` as a boolean.

    The value (or ``default`` when unset) is lower-cased and compared against
    ``_TRUTHY_VALUES``.
    """
    return os.environ.get(name, default).lower() in _TRUTHY_VALUES


def _load_custom_headers() -> Optional[Dict[str, str]]:
    """Parse PROMETHEUS_CUSTOM_HEADERS as JSON; None when unset or empty."""
    raw_headers = os.environ.get("PROMETHEUS_CUSTOM_HEADERS")
    return json.loads(raw_headers) if raw_headers else None


# Health check tool for Docker containers and monitoring
@mcp.tool(
    name=_tool_name("health_check"),
    description="Health check endpoint for container monitoring and status verification",
    annotations={
        "title": "Health Check",
        "icon": "❤️",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def health_check() -> Dict[str, Any]:
    """Return health status of the MCP server and Prometheus connection.

    When a Prometheus URL is configured, an ``up`` instant query is issued as
    a connectivity probe; a failing probe marks the status ``"degraded"``.
    A missing URL marks the status ``"unhealthy"``.

    Returns:
        Health status including service information, configuration, and
        connectivity. If building the status itself fails, a reduced
        ``"unhealthy"`` payload carrying the error text is returned instead
        of raising.
    """
    try:
        health_status = {
            "status": "healthy",
            "service": "prometheus-mcp-server",
            "version": "1.6.1",
            "timestamp": _utc_timestamp(),
            "transport": config.mcp_server_config.mcp_server_transport if config.mcp_server_config else "stdio",
            "configuration": {
                "prometheus_url_configured": bool(config.url),
                "authentication_configured": bool(config.username or config.token or config.client_cert),
                "org_id_configured": bool(config.org_id)
            }
        }

        # Test Prometheus connectivity if configured
        if config.url:
            try:
                # Quick connectivity test
                make_prometheus_request("query", params={"query": "up", "time": str(int(time.time()))})
                health_status["prometheus_connectivity"] = "healthy"
                health_status["prometheus_url"] = config.url
            except Exception as e:
                # Best effort: a failed probe degrades the status, it does not abort.
                health_status["prometheus_connectivity"] = "unhealthy"
                health_status["prometheus_error"] = str(e)
                health_status["status"] = "degraded"
        else:
            health_status["status"] = "unhealthy"
            health_status["error"] = "PROMETHEUS_URL not configured"

        logger.info("Health check completed", status=health_status["status"])
        return health_status

    except Exception as e:
        logger.error("Health check failed", error=str(e))
        return {
            "status": "unhealthy",
            "service": "prometheus-mcp-server",
            "error": str(e),
            "timestamp": _utc_timestamp()
        }


class TransportType(str, Enum):
    """Supported MCP server transport types."""

    STDIO = "stdio"
    HTTP = "http"
    SSE = "sse"

    @classmethod
    def values(cls) -> list[str]:
        """Get all valid transport values."""
        return [transport.value for transport in cls]


@dataclass
class MCPServerConfig:
    """Global Configuration for MCP."""

    # All three default to None so that __post_init__ can report which
    # required setting is missing.
    mcp_server_transport: Optional[TransportType] = None
    mcp_bind_host: Optional[str] = None
    mcp_bind_port: Optional[int] = None
    stateless_http: bool = False

    def __post_init__(self):
        """Validate mcp configuration.

        Raises:
            ValueError: If transport, bind host, or bind port is falsy.
        """
        if not self.mcp_server_transport:
            raise ValueError("MCP SERVER TRANSPORT is required")
        if not self.mcp_bind_host:
            raise ValueError("MCP BIND HOST is required")
        if not self.mcp_bind_port:
            raise ValueError("MCP BIND PORT is required")


@dataclass
class PrometheusConfig:
    """Connection and authentication settings for the Prometheus API."""

    url: str
    url_ssl_verify: bool = True
    disable_prometheus_links: bool = False
    # Optional credentials
    username: Optional[str] = None
    password: Optional[str] = None
    token: Optional[str] = None
    # Optional Org ID for multi-tenant setups
    org_id: Optional[str] = None
    # Optional client TLS certificate for mutual TLS authentication
    client_cert: Optional[str] = None
    client_key: Optional[str] = None
    # Optional Custom MCP Server Configuration
    mcp_server_config: Optional[MCPServerConfig] = None
    # Optional custom headers for Prometheus requests
    custom_headers: Optional[Dict[str, str]] = None
    # Request timeout in seconds to prevent hanging requests (DDoS protection)
    request_timeout: int = 30


config = PrometheusConfig(
    url=os.environ.get("PROMETHEUS_URL", ""),
    url_ssl_verify=_env_flag("PROMETHEUS_URL_SSL_VERIFY", "True"),
    disable_prometheus_links=_env_flag("PROMETHEUS_DISABLE_LINKS", "False"),
    username=os.environ.get("PROMETHEUS_USERNAME", ""),
    password=os.environ.get("PROMETHEUS_PASSWORD", ""),
    token=os.environ.get("PROMETHEUS_TOKEN", ""),
    org_id=os.environ.get("ORG_ID", ""),
    mcp_server_config=MCPServerConfig(
        mcp_server_transport=os.environ.get("PROMETHEUS_MCP_SERVER_TRANSPORT", "stdio").lower(),
        mcp_bind_host=os.environ.get("PROMETHEUS_MCP_BIND_HOST", "127.0.0.1"),
        mcp_bind_port=int(os.environ.get("PROMETHEUS_MCP_BIND_PORT", "8080")),
        stateless_http=_env_flag("PROMETHEUS_MCP_STATELESS_HTTP", "False"),
    ),
    client_cert=os.environ.get("PROMETHEUS_CLIENT_CERT", "") or None,
    client_key=os.environ.get("PROMETHEUS_CLIENT_KEY", "") or None,
    custom_headers=_load_custom_headers(),
    request_timeout=int(os.environ.get("PROMETHEUS_REQUEST_TIMEOUT", "30")),
)


def get_prometheus_auth():
    """Get authentication for Prometheus based on provided credentials.

    Returns:
        A ``{"Authorization": "Bearer ..."}`` header dict when a token is
        configured; otherwise an ``HTTPBasicAuth`` object when both username
        and password are configured; otherwise None. The token takes
        precedence over basic credentials.
    """
    if config.token:
        return {"Authorization": f"Bearer {config.token}"}
    elif config.username and config.password:
        return requests.auth.HTTPBasicAuth(config.username, config.password)
    return None


def make_prometheus_request(endpoint, params=None):
    """Make a request to the Prometheus API with proper authentication and headers.

    Args:
        endpoint: Path appended to ``<config.url>/api/v1/``.
        params: Optional query parameters forwarded to ``requests.get``.

    Returns:
        The ``data`` member of the decoded JSON response.

    Raises:
        ValueError: If no Prometheus URL is configured, if the response
            ``status`` is not ``"success"``, or if the body cannot be parsed
            as JSON (``json.JSONDecodeError`` path).
        requests.exceptions.RequestException: Re-raised after logging when
            the HTTP request fails or returns an error status code.
    """
    if not config.url:
        logger.error("Prometheus configuration missing", error="PROMETHEUS_URL not set")
        raise ValueError("Prometheus configuration is missing. Please set PROMETHEUS_URL environment variable.")

    if not config.url_ssl_verify:
        logger.warning("SSL certificate verification is disabled. This is insecure and should not be used in production environments.", endpoint=endpoint)

    url = f"{config.url.rstrip('/')}/api/v1/{endpoint}"
    url_ssl_verify = config.url_ssl_verify
    auth = get_prometheus_auth()
    headers = {}

    if isinstance(auth, dict):  # Token auth is passed via headers
        headers.update(auth)
        auth = None  # Clear auth for requests.get if it's already in headers

    # Add OrgID header if specified
    if config.org_id:
        headers["X-Scope-OrgID"] = config.org_id

    # Custom headers are applied last, so they override the ones set above.
    if config.custom_headers:
        headers.update(config.custom_headers)

    # Build client certificate tuple for mutual TLS authentication
    client_cert = None
    if config.client_cert:
        if config.client_key:
            client_cert = (config.client_cert, config.client_key)
        else:
            client_cert = config.client_cert

    try:
        # Credentials are redacted before logging so bearer tokens never
        # end up in debug output.
        logger.debug("Making Prometheus API request", endpoint=endpoint, url=url, params=params, headers=_redact_headers(headers), timeout=config.request_timeout)

        # Make the request with appropriate headers, auth, and timeout (DDoS protection)
        response = requests.get(url, params=params, auth=auth, headers=headers, verify=url_ssl_verify, cert=client_cert, timeout=config.request_timeout)

        response.raise_for_status()
        result = response.json()

        if result["status"] != "success":
            error_msg = result.get('error', 'Unknown error')
            logger.error("Prometheus API returned error", endpoint=endpoint, error=error_msg, status=result["status"])
            raise ValueError(f"Prometheus API error: {error_msg}")

        # Only dict payloads carry a resultType; anything else is logged as "list".
        data_field = result.get("data", {})
        if isinstance(data_field, dict):
            result_type = data_field.get("resultType")
        else:
            result_type = "list"
        logger.debug("Prometheus API request successful", endpoint=endpoint, result_type=result_type)

        return result["data"]

    except requests.exceptions.RequestException as e:
        logger.error("HTTP request to Prometheus failed", endpoint=endpoint, url=url, error=str(e), error_type=type(e).__name__)
        raise
    except json.JSONDecodeError as e:
        logger.error("Failed to parse Prometheus response as JSON", endpoint=endpoint, url=url, error=str(e))
        raise ValueError(f"Invalid JSON response from Prometheus: {str(e)}") from e
    except Exception as e:
        logger.error("Unexpected error during Prometheus request", endpoint=endpoint, url=url, error=str(e), error_type=type(e).__name__)
        raise


def get_cached_metrics() -> List[str]:
    """Get metrics list with caching to improve completion performance.

    This helper function is available for future completion support when
    FastMCP implements the completion capability. For now, it can be used
    internally to optimize repeated metric list requests.

    Returns:
        The cached metric-name list while it is younger than ``_CACHE_TTL``
        seconds; otherwise a fresh list fetched from the
        ``label/__name__/values`` endpoint, which also replaces the cache.
    """
    current_time = time.time()

    # snapshot for clarity
    cached_data = _metrics_cache["data"]
    cached_timestamp = _metrics_cache["timestamp"]

    if cached_data is not None and (current_time - cached_timestamp) < _CACHE_TTL:
        logger.debug("Using cached metrics list", cache_age=current_time - cached_timestamp)
        return cached_data

    # Fetch fresh metrics
    data = make_prometheus_request("label/__name__/values")
    _metrics_cache["data"] = data
    _metrics_cache["timestamp"] = current_time
    logger.debug("Refreshed metrics cache", metric_count=len(data))
    return data


# Note: Argument completions will be added when FastMCP supports the completion
# capability. The get_cached_metrics() function above is ready for that integration.

def _page_bounds(total: int, offset: int, limit: Optional[int]) -> tuple[int, int]:
    """Compute the slice bounds for one page of a ``total``-item sequence.

    Negative ``offset`` and ``limit`` values are clamped to 0 so that a bad
    argument cannot slice from the tail of the sequence. ``limit=None``
    means "everything from ``offset`` onwards".

    Returns:
        ``(start, end)`` suitable for ``seq[start:end]``; ``end`` may exceed
        ``total``.
    """
    start = max(offset, 0)
    if limit is None:
        return start, total
    return start, start + max(limit, 0)


def _prometheus_ui_links(ui_params: Dict[str, str]) -> List[Dict[str, str]]:
    """Build the ``links`` entry pointing at the Prometheus graph UI.

    Args:
        ui_params: Query-string parameters for the ``/graph`` page.
    """
    from urllib.parse import urlencode

    href = f"{config.url.rstrip('/')}/graph?{urlencode(ui_params)}"
    return [{
        "href": href,
        "rel": "prometheus-ui",
        "title": "View in Prometheus UI"
    }]


def _result_count(query_result: Any) -> int:
    """Return the length of a list result, or 1 for any non-list result."""
    return len(query_result) if isinstance(query_result, list) else 1


@mcp.tool(
    name=_tool_name("execute_query"),
    description="Execute a PromQL instant query against Prometheus",
    annotations={
        "title": "Execute PromQL Query",
        "icon": "📊",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def execute_query(query: str, time: Optional[str] = None) -> Dict[str, Any]:
    """Execute an instant query against Prometheus.

    Args:
        query: PromQL query string
        time: Optional RFC3339 or Unix timestamp (default: current time)

    Returns:
        Query result with type (vector, matrix, scalar, string) and values.
        A ``links`` entry for the Prometheus UI is added unless links are
        disabled in the configuration.
    """
    params = {"query": query}
    if time:
        params["time"] = time

    logger.info("Executing instant query", query=query, time=time)

    data = make_prometheus_request("query", params=params)

    result = {
        "resultType": data["resultType"],
        "result": data["result"]
    }

    if not config.disable_prometheus_links:
        ui_params = {"g0.expr": query, "g0.tab": "0"}
        if time:
            ui_params["g0.moment_input"] = time
        result["links"] = _prometheus_ui_links(ui_params)

    logger.info("Instant query completed",
                query=query,
                result_type=data["resultType"],
                result_count=_result_count(data["result"]))

    return result


@mcp.tool(
    name=_tool_name("execute_range_query"),
    description="Execute a PromQL range query with start time, end time, and step interval",
    annotations={
        "title": "Execute PromQL Range Query",
        "icon": "📈",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def execute_range_query(query: str, start: str, end: str, step: str, ctx: Context | None = None) -> Dict[str, Any]:
    """Execute a range query against Prometheus.

    Args:
        query: PromQL query string
        start: Start time as RFC3339 or Unix timestamp
        end: End time as RFC3339 or Unix timestamp
        step: Query resolution step width (e.g., '15s', '1m', '1h')
        ctx: Optional context; when given, progress is reported at 0, 50
            and 100 percent.

    Returns:
        Range query result with type (usually matrix) and values over time.
        A ``links`` entry for the Prometheus UI is added unless links are
        disabled in the configuration.
    """
    params = {
        "query": query,
        "start": start,
        "end": end,
        "step": step
    }

    logger.info("Executing range query", query=query, start=start, end=end, step=step)

    # Report progress if context available
    if ctx:
        await ctx.report_progress(progress=0, total=100, message="Initiating range query...")

    data = make_prometheus_request("query_range", params=params)

    # Report progress
    if ctx:
        await ctx.report_progress(progress=50, total=100, message="Processing query results...")

    result = {
        "resultType": data["resultType"],
        "result": data["result"]
    }

    if not config.disable_prometheus_links:
        result["links"] = _prometheus_ui_links({
            "g0.expr": query,
            "g0.tab": "0",
            "g0.range_input": f"{start} to {end}",
            "g0.step_input": step
        })

    # Report completion
    if ctx:
        await ctx.report_progress(progress=100, total=100, message="Range query completed")

    logger.info("Range query completed",
                query=query,
                result_type=data["resultType"],
                result_count=_result_count(data["result"]))

    return result


@mcp.tool(
    name=_tool_name("list_metrics"),
    description="List all available metrics in Prometheus with optional pagination support",
    annotations={
        "title": "List Available Metrics",
        "icon": "📋",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def list_metrics(
    limit: Optional[int] = None,
    offset: int = 0,
    filter_pattern: Optional[str] = None,
    ctx: Context | None = None,
    refresh_cache: bool = False,
) -> Dict[str, Any]:
    """Retrieve a list of all metric names available in Prometheus.

    Args:
        limit: Maximum number of metrics to return (default: all metrics);
            negative values are treated as 0
        offset: Number of metrics to skip for pagination (default: 0);
            negative values are treated as 0
        filter_pattern: Optional substring to filter metric names (case-insensitive)
        ctx: Optional context; when given, progress is reported at 0, 50
            and 100 percent
        refresh_cache: Force a cache refresh to pick up newly scraped metrics (default: False)

    Returns:
        Dictionary containing:
        - metrics: List of metric names
        - total_count: Total number of metrics after filtering (before pagination)
        - returned_count: Number of metrics returned
        - offset: Offset actually applied
        - has_more: Whether more metrics are available
    """
    logger.info("Listing available metrics", limit=limit, offset=offset, filter_pattern=filter_pattern, refresh_cache=refresh_cache)

    # Report progress if context available
    if ctx:
        await ctx.report_progress(progress=0, total=100, message="Fetching metrics list...")

    if refresh_cache:
        clear_metrics_cache()

    data = get_cached_metrics()

    if ctx:
        await ctx.report_progress(progress=50, total=100, message=f"Processing {len(data)} metrics...")

    # Apply filter if provided
    if filter_pattern:
        needle = filter_pattern.lower()
        filtered_data = [m for m in data if needle in m.lower()]
        logger.debug("Applied filter", original_count=len(data), filtered_count=len(filtered_data), pattern=filter_pattern)
        data = filtered_data

    total_count = len(data)

    # Apply pagination
    start_idx, end_idx = _page_bounds(total_count, offset, limit)
    paginated_data = data[start_idx:end_idx]

    result = {
        "metrics": paginated_data,
        "total_count": total_count,
        "returned_count": len(paginated_data),
        "offset": start_idx,
        "has_more": end_idx < total_count
    }

    if ctx:
        await ctx.report_progress(progress=100, total=100, message=f"Retrieved {len(paginated_data)} of {total_count} metrics")

    logger.info("Metrics list retrieved",
                total_count=total_count,
                returned_count=len(paginated_data),
                offset=start_idx,
                has_more=result["has_more"])

    return result


def _coerce_metadata_entries(value: Any) -> List[Dict[str, Any]]:
    """Normalize metadata value into a list of metadata dictionaries.

    A list keeps only its dict members, a single dict is wrapped in a list,
    and anything else yields an empty list.
    """
    if isinstance(value, list):
        return [entry for entry in value if isinstance(entry, dict)]
    if isinstance(value, dict):
        return [value]
    return []


def _normalize_metadata_map(raw_data: Any) -> Dict[str, List[Dict[str, Any]]]:
    """Normalize diverse metadata response shapes into {metric_name: [entries]}.

    Handled shapes:
    - a dict wrapping the payload under ``"metadata"`` or ``"data"`` (unwrapped
      recursively, ``"metadata"`` checked first);
    - a dict mapping metric names to an entry or a list of entries;
    - a single entry dict carrying its name under ``"metric"``;
    - a list of entry dicts, grouped by their ``"metric"`` value.

    Anything else yields an empty dict.
    """
    if isinstance(raw_data, dict):
        if "metadata" in raw_data:
            return _normalize_metadata_map(raw_data["metadata"])
        if "data" in raw_data:
            return _normalize_metadata_map(raw_data["data"])

        normalized: Dict[str, List[Dict[str, Any]]] = {}
        for metric_name, entries in raw_data.items():
            if not isinstance(metric_name, str):
                continue
            coerced_entries = _coerce_metadata_entries(entries)
            if coerced_entries:
                normalized[metric_name] = coerced_entries

        if normalized:
            return normalized

        # No name -> entries mapping found: maybe the dict is itself one entry.
        metric_name = raw_data.get("metric")
        if isinstance(metric_name, str):
            return {metric_name: [raw_data]}

    if isinstance(raw_data, list):
        grouped: Dict[str, List[Dict[str, Any]]] = {}
        for entry in raw_data:
            if not isinstance(entry, dict):
                continue
            metric_name = entry.get("metric")
            if not isinstance(metric_name, str):
                continue
            grouped.setdefault(metric_name, []).append(entry)
        return grouped

    return {}


def _metadata_matches_pattern(metric_name: str, entries: List[Dict[str, Any]], pattern: str) -> bool:
    """Return True when pattern matches metric name or metadata text fields.

    Matching is a case-insensitive substring test against the metric name and
    every string value found in ``entries``.
    """
    lowered_pattern = pattern.lower()
    if lowered_pattern in metric_name.lower():
        return True

    return any(
        isinstance(value, str) and lowered_pattern in value.lower()
        for entry in entries
        for value in entry.values()
    )


@mcp.tool(
    name=_tool_name("get_metric_metadata"),
    description=(
        "Get metadata (type, help, unit) for metrics. "
        "Returns all metric metadata when no metric name is provided. "
        "Use filter_pattern to search metric names and descriptions."
    ),
    annotations={
        "title": "Get Metric Metadata",
        "icon": "ℹ️",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def get_metric_metadata(
    metric: Optional[str] = None,
    filter_pattern: Optional[str] = None,
    limit: Optional[int] = None,
    offset: int = 0,
) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
    """Get metadata for one metric or bulk metadata for all metrics.

    Args:
        metric: Optional metric name. If provided, returns legacy list format.
        filter_pattern: Optional substring filter on metric name and descriptions.
        limit: Maximum number of metrics to return in bulk mode; negative
            values are treated as 0.
        offset: Number of metrics to skip in bulk mode; negative values are
            treated as 0.

    Returns:
        If metric is provided: list of metadata entries for that metric
        (empty when the metric is absent or filtered out).
        If metric is not provided: dict with filtered metadata and pagination info.
    """
    logger.info("Retrieving metric metadata", metric=metric, filter_pattern=filter_pattern, limit=limit, offset=offset)

    params = {"metric": metric} if metric else None
    raw_data = make_prometheus_request("metadata", params=params)
    metadata_by_metric = _normalize_metadata_map(raw_data)

    # Fallback for atypical single-metric response formats.
    if metric and metric not in metadata_by_metric:
        fallback_entries = _coerce_metadata_entries(raw_data)
        if fallback_entries:
            metadata_by_metric[metric] = fallback_entries

    if filter_pattern:
        metadata_by_metric = {
            metric_name: entries
            for metric_name, entries in metadata_by_metric.items()
            if _metadata_matches_pattern(metric_name, entries, filter_pattern)
        }

    if metric:
        metric_entries = metadata_by_metric.get(metric, [])
        logger.info("Metric metadata retrieved", metric=metric, metadata_count=len(metric_entries))
        return metric_entries

    metric_names = list(metadata_by_metric)
    total_count = len(metric_names)
    start_idx, end_idx = _page_bounds(total_count, offset, limit)
    selected_metric_names = metric_names[start_idx:end_idx]
    paginated_metadata = {name: metadata_by_metric[name] for name in selected_metric_names}

    result = {
        "metadata": paginated_metadata,
        "total_count": total_count,
        "returned_count": len(paginated_metadata),
        "offset": start_idx,
        "has_more": end_idx < total_count,
    }

    logger.info(
        "Bulk metric metadata retrieved",
        total_count=total_count,
        returned_count=result["returned_count"],
        offset=start_idx,
        has_more=result["has_more"],
    )
    return result


@mcp.tool(
    name=_tool_name("get_targets"),
    description="Get information about all scrape targets",
    annotations={
        "title": "Get Scrape Targets",
        "icon": "🎯",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def get_targets() -> Dict[str, List[Dict[str, Any]]]:
    """Get information about all Prometheus scrape targets.

    Returns:
        Dictionary with active and dropped targets information, taken from
        the ``activeTargets`` and ``droppedTargets`` members of the
        ``targets`` endpoint response.
    """
    logger.info("Retrieving scrape targets information")
    data = make_prometheus_request("targets")

    result = {
        "activeTargets": data["activeTargets"],
        "droppedTargets": data["droppedTargets"]
    }

    logger.info("Scrape targets retrieved",
                active_targets=len(data["activeTargets"]),
                dropped_targets=len(data["droppedTargets"]))

    return result


if __name__ == "__main__":
    logger.info("Starting Prometheus MCP Server", mode="direct")
    mcp.run()