Spaces:
Sleeping
Sleeping
| """ | |
| Enterprise Rate Limiting for MCP Servers | |
| Features: | |
| - Token bucket algorithm for smooth rate limiting | |
| - Per-client rate limiting | |
| - Global rate limiting | |
| - Different limits for different endpoints | |
| - Distributed rate limiting with Redis (optional) | |
| """ | |
import asyncio
import logging
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional

from aiohttp import web
| logger = logging.getLogger(__name__) | |
@dataclass
class TokenBucket:
    """Token bucket for rate limiting.

    The bucket starts full (``capacity`` tokens) and refills continuously
    at ``refill_rate`` tokens per second, capped at ``capacity``.

    Bug fix: the original class used ``field(...)`` defaults and
    ``__post_init__`` but was missing the ``@dataclass`` decorator, so
    instantiation with keyword arguments raised TypeError and the
    attributes held ``Field`` objects instead of values.
    """
    capacity: int        # Maximum tokens the bucket can hold
    refill_rate: float   # Tokens added per second
    tokens: float = field(default=0)                        # Current token balance
    last_refill: float = field(default_factory=time.time)   # Timestamp of last refill

    def __post_init__(self):
        # Start with a full bucket so clients get an initial burst allowance.
        self.tokens = self.capacity

    def _refill(self):
        """Refill tokens based on time elapsed since the last refill."""
        now = time.time()
        elapsed = now - self.last_refill
        # Add tokens proportionally to elapsed time, never exceeding capacity.
        self.tokens = min(
            self.capacity,
            self.tokens + (elapsed * self.refill_rate)
        )
        self.last_refill = now

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume tokens.

        Args:
            tokens: Number of tokens this operation costs.

        Returns:
            True if tokens were available (they are deducted), False otherwise.
        """
        self._refill()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def get_wait_time(self, tokens: int = 1) -> float:
        """
        Get time to wait until ``tokens`` are available.

        Returns:
            Seconds to wait; 0.0 if the tokens are already available.
        """
        self._refill()
        if self.tokens >= tokens:
            return 0.0
        tokens_needed = tokens - self.tokens
        return tokens_needed / self.refill_rate
class RateLimiter:
    """
    In-memory rate limiter with token bucket algorithm.

    Keeps one TokenBucket per (client, endpoint) pair and, optionally, a
    single global bucket that is checked before any per-client bucket.
    """

    def __init__(self):
        # Client-specific buckets, keyed by "client_id:path"
        self.client_buckets: Dict[str, TokenBucket] = {}
        # Global bucket for all requests (None = disabled)
        self.global_bucket: Optional[TokenBucket] = None
        # Endpoint-specific limits
        self.endpoint_limits: Dict[str, Dict] = {
            "/rpc": {"capacity": 100, "refill_rate": 10.0},  # 100 requests, 10/sec refill
            "default": {"capacity": 50, "refill_rate": 5.0}  # Default for other endpoints
        }
        # Global rate limit (disabled by default)
        # self.global_bucket = TokenBucket(capacity=1000, refill_rate=100.0)
        # Handle for the background bucket-pruning task (see start_cleanup_task)
        self._cleanup_task = None
        logger.info("Rate limiter initialized")

    def _get_client_id(self, request: web.Request) -> str:
        """
        Get client identifier for rate limiting.

        Uses (in order):
        1. API key (``request["api_key"].key_id`` — presumably set by auth
           middleware upstream; verify against the caller)
        2. IP address from the connection's peername

        Returns "unknown" when neither is available.
        """
        # Try API key first
        if "api_key" in request and hasattr(request["api_key"], "key_id"):
            return f"key:{request['api_key'].key_id}"
        # Fall back to IP address. request.transport can be None (connection
        # already closed, or synthetic/test requests) — guard before
        # dereferencing to avoid an AttributeError.
        transport = request.transport
        if transport is not None:
            peername = transport.get_extra_info('peername')
            if peername:
                return f"ip:{peername[0]}"
        return "unknown"

    def _get_endpoint_limits(self, path: str) -> Dict:
        """Get rate limits for endpoint, falling back to the "default" entry."""
        return self.endpoint_limits.get(path, self.endpoint_limits["default"])

    def _get_or_create_bucket(self, client_id: str, path: str) -> TokenBucket:
        """Get or lazily create the token bucket for a (client, endpoint) pair."""
        bucket_key = f"{client_id}:{path}"
        if bucket_key not in self.client_buckets:
            limits = self._get_endpoint_limits(path)
            self.client_buckets[bucket_key] = TokenBucket(
                capacity=limits["capacity"],
                refill_rate=limits["refill_rate"]
            )
        return self.client_buckets[bucket_key]

    async def check_rate_limit(
        self,
        request: web.Request,
        tokens: int = 1
    ) -> tuple[bool, Optional[float]]:
        """
        Check if request is within rate limit.

        Args:
            request: Incoming aiohttp request.
            tokens: Cost of this request in tokens.

        Returns:
            Tuple of (allowed, retry_after_seconds); retry_after_seconds
            is None when the request is allowed.
        """
        client_id = self._get_client_id(request)
        path = request.path
        # Check global rate limit first (if enabled)
        if self.global_bucket:
            if not self.global_bucket.consume(tokens):
                wait_time = self.global_bucket.get_wait_time(tokens)
                logger.warning(f"Global rate limit exceeded, retry after {wait_time:.2f}s")
                return False, wait_time
        # Check client-specific rate limit.
        # NOTE(review): if this per-client check fails, the tokens already
        # taken from the global bucket are not refunded — confirm intended.
        bucket = self._get_or_create_bucket(client_id, path)
        if not bucket.consume(tokens):
            wait_time = bucket.get_wait_time(tokens)
            logger.warning(f"Rate limit exceeded for {client_id} on {path}, retry after {wait_time:.2f}s")
            return False, wait_time
        return True, None

    async def start_cleanup_task(self):
        """Start the background cleanup task (idempotent)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("Rate limiter cleanup task started")

    async def _cleanup_loop(self):
        """Periodically drop buckets that have been idle for 10+ minutes."""
        while True:
            await asyncio.sleep(300)  # Every 5 minutes
            # A bucket's last_refill is updated on every consume, so it
            # doubles as a last-used timestamp for idleness detection.
            cutoff_time = time.time() - 600  # 10 minutes
            removed = 0
            for key in list(self.client_buckets.keys()):
                bucket = self.client_buckets[key]
                if bucket.last_refill < cutoff_time:
                    del self.client_buckets[key]
                    removed += 1
            if removed > 0:
                logger.info(f"Cleaned up {removed} unused rate limit buckets")
class RateLimitMiddleware:
    """aiohttp middleware for rate limiting.

    Rejects over-limit requests with HTTP 429 and a Retry-After header;
    paths in ``exempt_paths`` bypass limiting entirely.
    """

    def __init__(self, rate_limiter: RateLimiter, exempt_paths: Optional[set[str]] = None):
        """
        Args:
            rate_limiter: RateLimiter instance used for all checks.
            exempt_paths: Paths that bypass rate limiting; defaults to
                {"/health", "/metrics"}.  (Annotation fixed: the default
                is None, so the type must be Optional[set[str]].)
        """
        self.rate_limiter = rate_limiter
        self.exempt_paths = exempt_paths or {"/health", "/metrics"}
        logger.info("Rate limit middleware initialized")

    async def middleware(self, request: web.Request, handler):
        """Middleware handler: enforce the rate limit, then delegate."""
        # Skip rate limiting for exempt paths
        if request.path in self.exempt_paths:
            return await handler(request)
        # Check rate limit
        allowed, retry_after = await self.rate_limiter.check_rate_limit(request)
        if not allowed:
            return web.json_response(
                {
                    "error": "Rate limit exceeded",
                    "message": f"Too many requests. Please retry after {retry_after:.2f} seconds.",
                    "retry_after": retry_after
                },
                status=429,
                # Retry-After must be whole seconds; round up so clients
                # never retry before tokens are actually available.
                headers={"Retry-After": str(int(retry_after) + 1)}
            )
        response = await handler(request)
        # TODO: add X-RateLimit-Limit / X-RateLimit-Remaining headers here
        return response
| class RedisRateLimiter: | |
| """ | |
| Distributed rate limiter using Redis | |
| Suitable for multi-instance deployments | |
| """ | |
| def __init__(self, redis_client=None): | |
| """ | |
| Initialize with Redis client | |
| Args: | |
| redis_client: redis.asyncio.Redis client | |
| """ | |
| self.redis = redis_client | |
| logger.info("Redis rate limiter initialized" if redis_client else "Redis rate limiter (disabled)") | |
| async def check_rate_limit( | |
| self, | |
| key: str, | |
| limit: int, | |
| window_seconds: int | |
| ) -> tuple[bool, Optional[int]]: | |
| """ | |
| Check rate limit using Redis | |
| Uses sliding window algorithm with Redis sorted sets | |
| Returns: | |
| Tuple of (allowed, retry_after_seconds) | |
| """ | |
| if not self.redis: | |
| # If Redis is not available, allow all requests | |
| return True, None | |
| now = time.time() | |
| window_start = now - window_seconds | |
| try: | |
| # Redis pipeline for atomic operations | |
| pipe = self.redis.pipeline() | |
| # Remove old entries | |
| pipe.zremrangebyscore(key, 0, window_start) | |
| # Count current requests | |
| pipe.zcard(key) | |
| # Add current request | |
| pipe.zadd(key, {str(now): now}) | |
| # Set expiry | |
| pipe.expire(key, window_seconds) | |
| results = await pipe.execute() | |
| count = results[1] # Result from ZCARD | |
| if count < limit: | |
| return True, None | |
| else: | |
| # Calculate retry time | |
| oldest_entries = await self.redis.zrange(key, 0, 0, withscores=True) | |
| if oldest_entries: | |
| oldest_time = oldest_entries[0][1] | |
| retry_after = int(oldest_time + window_seconds - now) + 1 | |
| return False, retry_after | |
| return False, window_seconds | |
| except Exception as e: | |
| logger.error(f"Redis rate limit error: {e}") | |
| # On error, allow request (fail open) | |
| return True, None | |
# Module-level singleton, created on first use by get_rate_limiter()
_rate_limiter: Optional[RateLimiter] = None


def get_rate_limiter() -> RateLimiter:
    """Return the process-wide RateLimiter, constructing it lazily."""
    global _rate_limiter
    limiter = _rate_limiter
    if limiter is None:
        limiter = RateLimiter()
        _rate_limiter = limiter
    return limiter