import asyncio
import inspect
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type

from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response

from .logging_utils import get_logger

logger = get_logger(__name__)


class Extension(ABC):
    """Base class for request/response extensions run by ExtensionManager.

    Subclasses receive their configuration dict at construction time and hook
    into the pipeline through ``process_request`` (which may short-circuit by
    returning a response) and ``process_response``.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        # ExtensionManager skips extensions whose ``enabled`` flag is False.
        self.enabled = True

    @abstractmethod
    async def process_request(self, request: Request) -> Optional[Response]:
        """Inspect an incoming request.

        Returns:
            A response to send immediately (later extensions are skipped), or
            ``None`` to let processing continue.
        """

    @abstractmethod
    async def process_response(self, request: Request, response: Response) -> Response:
        """Inspect or modify an outgoing response and return the one to send."""

    def initialize(self) -> None:
        """Synchronous hook called by ExtensionManager after construction; no-op by default."""

    def cleanup(self) -> None:
        """Hook called by ExtensionManager on shutdown; no-op by default."""


class RoutingExtension(Extension):
    """Dispatch requests through a router built from ``regex_locations`` config.

    Config keys:
        regex_locations: passed to ``create_router_from_config`` (default ``{}``).
        default_proxy_timeout: passed to ``RequestHandler`` (default ``30.0``).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Imported lazily; presumably to avoid an import cycle with .routing.
        from .routing import RequestHandler, create_router_from_config

        regex_locations = config.get("regex_locations", {})
        default_proxy_timeout = config.get("default_proxy_timeout", 30.0)
        self.router = create_router_from_config(regex_locations)
        self.handler = RequestHandler(self.router, default_proxy_timeout=default_proxy_timeout)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Return the handler's response, or ``None`` if the handler raised."""
        try:
            return await self.handler.handle(request)
        except Exception as e:
            # Keep the traceback, then fall through so later extensions may still respond.
            logger.exception(f"Error in RoutingExtension: {e}")
            return None

    async def process_response(self, request: Request, response: Response) -> Response:
        return response


class SecurityExtension(Extension):
    """IP allow/deny filtering plus static security headers on every response.

    Config keys:
        allowed_ips: if non-empty, only these client IPs are admitted.
        blocked_ips: client IPs that are always rejected (checked first).
        security_headers: header -> value mapping set on every response.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.allowed_ips = config.get("allowed_ips", [])
        self.blocked_ips = config.get("blocked_ips", [])
        self.security_headers = config.get(
            "security_headers",
            {
                "X-Content-Type-Options": "nosniff",
                "X-Frame-Options": "DENY",
                "X-XSS-Protection": "1; mode=block",
            },
        )

    @staticmethod
    def _forbidden() -> Response:
        """Build the plain-text 403 response used for every rejection."""
        return PlainTextResponse("403 Forbidden", status_code=403)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Reject blocked or non-allowed client IPs with a 403; otherwise return ``None``."""
        # With no client info the IP is "unknown", which a non-empty allow list rejects.
        client_ip = request.client.host if request.client else "unknown"
        # The deny list wins over the allow list.
        if self.blocked_ips and client_ip in self.blocked_ips:
            logger.warning(f"Blocked request from IP: {client_ip}")
            return self._forbidden()
        if self.allowed_ips and client_ip not in self.allowed_ips:
            logger.warning(f"Access denied for IP: {client_ip}")
            return self._forbidden()
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Set (overwriting any existing value) each configured security header."""
        for header, value in self.security_headers.items():
            response.headers[header] = value
        return response


class CachingExtension(Extension):
    """Placeholder for response caching; currently a pure pass-through.

    Config keys ``cache_patterns`` and ``cache_ttl`` are read and stored but not
    yet used.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.cache: Dict[str, Any] = {}
        self.cache_patterns = config.get("cache_patterns", [])
        self.cache_ttl = config.get("cache_ttl", 3600)

    async def process_request(self, request: Request) -> Optional[Response]:
        # TODO: Implement cache check
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        # TODO: Implement response caching
        return response


class MonitoringExtension(Extension):
    """Count requests and errors, time responses, and log each completed request.

    Config keys:
        enable_metrics: turn collection on/off (default ``True``).
        max_response_samples: how many of the most recent response times to keep
            in ``response_times`` (default 10000). ``None`` or a non-positive
            value keeps every sample (unbounded).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.request_count = 0
        self.error_count = 0
        # Most recent response times in seconds, capped by max_response_samples.
        self.response_times: list[float] = []
        self.enable_metrics = config.get("enable_metrics", True)
        self.max_response_samples: Optional[int] = config.get("max_response_samples", 10000)
        # Running aggregates over ALL recorded samples, so get_metrics stays exact
        # and O(1) even after old samples are trimmed from response_times.
        self._response_time_total = 0.0
        self._response_time_count = 0

    async def process_request(self, request: Request) -> Optional[Response]:
        if self.enable_metrics:
            self.request_count += 1
            # Wall-clock stamp kept under the original attribute name for other readers.
            request.state.start_time = time.time()
            # Monotonic stamp used for the duration: unaffected by system clock changes.
            request.state.monitoring_perf_start = time.perf_counter()
        return None

    def _record_response_time(self, response_time: float) -> None:
        """Add one sample to the aggregates and drop the oldest retained samples over the cap."""
        self._response_time_total += response_time
        self._response_time_count += 1
        self.response_times.append(response_time)
        limit = self.max_response_samples
        if limit is not None and limit > 0 and len(self.response_times) > limit:
            del self.response_times[: len(self.response_times) - limit]

    async def process_response(self, request: Request, response: Response) -> Response:
        if not self.enable_metrics:
            return response
        perf_start = getattr(request.state, "monitoring_perf_start", None)
        if perf_start is None:
            # process_request did not run for this request (e.g. short-circuited earlier).
            return response

        response_time = time.perf_counter() - perf_start
        self._record_response_time(response_time)
        if response.status_code >= 400:
            self.error_count += 1
        logger.info(
            f"Request: {request.method} {request.url.path} - "
            f"Status: {response.status_code} - "
            f"Time: {response_time:.3f}s"
        )
        return response

    def get_metrics(self) -> Dict[str, Any]:
        """Return counters, error rate and the mean response time over all recorded samples."""
        count = self._response_time_count
        avg_response_time = self._response_time_total / count if count else 0
        return {
            "request_count": self.request_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / max(self.request_count, 1),
            "avg_response_time": avg_response_time,
            "total_response_times": count,
        }


class ASGIExtension(Extension):
    """Mount ASGI applications (including Django projects) under URL paths.

    Config key ``mounts`` is a list of mount dicts; see ``_load_mounts``.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        from .asgi_mount import ASGIMountManager

        self.mount_manager = ASGIMountManager()
        self._load_mounts(config.get("mounts", []))

    def _load_mounts(self, mounts: List[Dict[str, Any]]) -> None:
        """Register each configured mount with the mount manager.

        A mount with ``django_settings`` is first built via ``create_django_app``.
        If that yields no app, the generic ``app_path``-based mount is attempted
        with the same config instead.
        """
        from .asgi_mount import create_django_app

        for mount_config in mounts:
            path = mount_config.get("path", "/")
            if "django_settings" in mount_config:
                settings_module = mount_config["django_settings"]
                app = create_django_app(
                    settings_module=settings_module,
                    module_path=mount_config.get("module_path"),
                )
                if app:
                    self.mount_manager.mount(
                        path=path,
                        app=app,
                        name=mount_config.get("name", f"django:{settings_module}"),
                        strip_path=mount_config.get("strip_path", True),
                    )
                    continue
                logger.warning(
                    f"Could not create Django app for mount {path!r} ({settings_module}); "
                    f"falling back to app_path-based mount"
                )
            self.mount_manager.mount(
                path=path,
                app_path=mount_config.get("app_path"),
                app_type=mount_config.get("app_type", "asgi"),
                module_path=mount_config.get("module_path"),
                factory=mount_config.get("factory", False),
                factory_args=mount_config.get("factory_args"),
                name=mount_config.get("name", ""),
                strip_path=mount_config.get("strip_path", True),
            )

    async def process_request(self, request: Request) -> Optional[Response]:
        """Record the matching mount (if any) on ``request.state.asgi_mount``; never short-circuits."""
        mount = self.mount_manager.get_mount(request.url.path)
        if mount is not None:
            # Stored for middleware to use; the mounted app itself is dispatched
            # elsewhere (the caller is expected to use get_asgi_handler).
            request.state.asgi_mount = mount
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        return response

    def get_asgi_handler(self, request: Request) -> Optional[Any]:
        """Return the mount registered for the request path, or ``None``."""
        path = request.url.path
        return self.mount_manager.get_mount(path)

    def get_metrics(self) -> Dict[str, Any]:
        return {
            "asgi_mounts": self.mount_manager.list_mounts(),
            "asgi_mount_count": len(self.mount_manager.mounts),
        }

    def cleanup(self) -> None:
        # NOTE(review): only logs; mounted apps are not shut down here.
        logger.info("Cleaning up ASGI mounts")


class ExtensionManager:
    """Create extensions by type name and run them as a request/response pipeline."""

    def __init__(self) -> None:
        # Active extensions, in the order they are consulted.
        self.extensions: List[Extension] = []
        # Config type name -> extension class to instantiate.
        self.extension_registry: Dict[str, Type[Extension]] = {
            "routing": RoutingExtension,
            "security": SecurityExtension,
            "caching": CachingExtension,
            "monitoring": MonitoringExtension,
            "asgi": ASGIExtension,
        }
        self._register_process_orchestration()

    def _register_process_orchestration(self) -> None:
        """Register the optional ProcessOrchestrationExtension when it is importable."""
        try:
            from .process_extension import ProcessOrchestrationExtension
        except ImportError:
            return  # Optional dependency
        self.extension_registry["process_orchestration"] = ProcessOrchestrationExtension  # type: ignore

    def register_extension_type(self, name: str, extension_class: Type[Extension]) -> None:
        """Register (or replace) the class used for extension type ``name``."""
        self.extension_registry[name] = extension_class

    def _lookup_extension_class(self, extension_type: str) -> Optional[Type[Extension]]:
        """Return the registered class for ``extension_type``, logging an error if unknown."""
        if extension_type not in self.extension_registry:
            logger.error(f"Unknown extension type: {extension_type}")
            return None
        return self.extension_registry[extension_type]

    def load_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Instantiate and initialize an extension, then append it to the pipeline.

        Errors are logged (with traceback), not raised.
        """
        extension_class = self._lookup_extension_class(extension_type)
        if extension_class is None:
            return
        try:
            extension = extension_class(config)
            extension.initialize()
        except Exception as e:
            logger.exception(f"Error loading extension {extension_type}: {e}")
            return
        self.extensions.append(extension)
        logger.info(f"Loaded extension: {extension_type}")

    async def load_extension_async(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Load extension with async setup support (for ProcessOrchestration).

        If the extension has a coroutine ``setup(config)`` it is awaited instead
        of calling ``initialize()``. A coroutine ``start()`` is then awaited if
        present. The extension is inserted at the FRONT of the pipeline. Errors
        are logged (with traceback), not raised.
        """
        extension_class = self._lookup_extension_class(extension_type)
        if extension_class is None:
            return
        try:
            extension = extension_class(config)
            # inspect.iscoroutinefunction replaces the asyncio variant (deprecated in 3.14).
            setup_method = getattr(extension, "setup", None)
            if setup_method is not None and inspect.iscoroutinefunction(setup_method):
                await setup_method(config)
            else:
                extension.initialize()
            start_method = getattr(extension, "start", None)
            if start_method is not None and inspect.iscoroutinefunction(start_method):
                await start_method()
        except Exception as e:
            # NOTE(review): a failure in start() after a successful setup() does not
            # call cleanup() on the half-started extension.
            logger.exception(f"Error loading extension {extension_type}: {e}")
            return
        # Insert at the beginning so process_orchestration is checked first
        self.extensions.insert(0, extension)
        logger.info(f"Loaded extension (async): {extension_type}")

    async def process_request(self, request: Request) -> Optional[Response]:
        """Run enabled extensions in order; return the first non-``None`` response.

        An extension that raises is logged and skipped.
        """
        for extension in self.extensions:
            if not extension.enabled:
                continue
            try:
                response = await extension.process_request(request)
            except Exception as e:
                logger.exception(f"Error in extension {type(extension).__name__}: {e}")
                continue
            if response is not None:
                return response
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Thread the response through every enabled extension, in the same order as requests.

        If an extension raises, it is logged and the previous response is kept.
        """
        for extension in self.extensions:
            if not extension.enabled:
                continue
            try:
                response = await extension.process_response(request, response)
            except Exception as e:
                logger.exception(f"Error in extension {type(extension).__name__}: {e}")
        return response

    def cleanup(self) -> None:
        """Call ``cleanup()`` on every loaded extension (logging failures), then forget them all."""
        for extension in self.extensions:
            try:
                extension.cleanup()
            except Exception as e:
                logger.exception(f"Error cleaning up extension {type(extension).__name__}: {e}")
        self.extensions.clear()