"""Pluggable request/response extensions and the manager that runs them."""

import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type

from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response

from .logging_utils import get_logger

logger = get_logger(__name__)


class Extension(ABC):
    """Base class for a request/response hook run by ``ExtensionManager``.

    Subclasses receive their configuration dict at construction time and
    implement the two async hooks below. Setting ``enabled`` to ``False``
    makes ``ExtensionManager`` skip the extension without unloading it.

    Args:
        config: Extension-specific configuration mapping.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.enabled = True

    @abstractmethod
    async def process_request(self, request: Request) -> Optional[Response]:
        """Inspect an incoming request.

        Returns:
            A ``Response`` to stop ``ExtensionManager.process_request`` from
            running any further extensions' request hooks, or ``None`` to
            let processing continue.
        """

    @abstractmethod
    async def process_response(self, request: Request, response: Response) -> Response:
        """Inspect or modify an outgoing response and return it (or a replacement)."""

    def initialize(self) -> None:
        """Called once by ``ExtensionManager.load_extension`` after construction; no-op by default."""

    def cleanup(self) -> None:
        """Called by ``ExtensionManager.cleanup``; no-op by default."""


class RoutingExtension(Extension):
    """Delegate request handling to a ``RequestHandler`` built from config.

    Config keys:
        regex_locations: mapping passed to ``create_router_from_config``
            (defaults to an empty dict).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Kept as a function-scope import as in the original design —
        # presumably to avoid a circular import with .routing; verify before hoisting.
        from .routing import RequestHandler, create_router_from_config

        regex_locations = config.get("regex_locations", {})
        self.router = create_router_from_config(regex_locations)
        self.handler = RequestHandler(self.router)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Return the handler's response; on any handler error, log it and return ``None``.

        Returning ``None`` on failure lets the remaining extensions continue
        instead of aborting the request pipeline.
        """
        try:
            return await self.handler.handle(request)
        except Exception as e:
            logger.error(f"Error in RoutingExtension: {e}")
            return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Pass the response through unchanged."""
        return response


class SecurityExtension(Extension):
    """IP allow/deny filtering plus fixed security headers on every response.

    Config keys:
        allowed_ips: if non-empty, only these client IPs are admitted.
        blocked_ips: client IPs that are always rejected (checked first).
        security_headers: header name -> value mapping written onto every
            response; defaults to the three headers below.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.allowed_ips = config.get("allowed_ips", [])
        self.blocked_ips = config.get("blocked_ips", [])
        # NOTE(review): X-XSS-Protection is deprecated in modern browsers;
        # consider a Content-Security-Policy instead.
        self.security_headers = config.get("security_headers", {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
        })

    async def process_request(self, request: Request) -> Optional[Response]:
        """Return a 403 response for blocked or non-allow-listed clients, else ``None``."""
        # NOTE(review): this is the direct peer address; behind a reverse
        # proxy it is the proxy's IP (forwarding headers are not consulted).
        client_ip = request.client.host if request.client else "unknown"

        if self.blocked_ips and client_ip in self.blocked_ips:
            logger.warning(f"Blocked request from IP: {client_ip}")
            return PlainTextResponse("403 Forbidden", status_code=403)

        if self.allowed_ips and client_ip not in self.allowed_ips:
            logger.warning(f"Access denied for IP: {client_ip}")
            return PlainTextResponse("403 Forbidden", status_code=403)

        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Set (overwriting any existing value) each configured security header."""
        for header, value in self.security_headers.items():
            response.headers[header] = value
        return response


class CachingExtension(Extension):
    """Placeholder for response caching; both hooks are currently pass-through.

    Config keys:
        cache_patterns: patterns selecting cacheable requests (unused so far).
        cache_ttl: entry lifetime in seconds, default 3600 (unused so far).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.cache: Dict[str, Any] = {}
        self.cache_patterns = config.get("cache_patterns", [])
        self.cache_ttl = config.get("cache_ttl", 3600)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Currently always ``None`` (no cache lookup yet)."""
        # TODO: Implement cache check
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Currently returns the response unchanged (nothing is stored yet)."""
        # TODO: Implement response caching
        return response


class MonitoringExtension(Extension):
    """Count requests/errors, record per-request durations and log each response.

    Config keys:
        enable_metrics: set to ``False`` to disable all counting and logging
            (default ``True``).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.request_count = 0
        self.error_count = 0
        # Durations in seconds, one per measured response.
        # NOTE(review): grows without bound for the lifetime of the process.
        self.response_times: List[float] = []
        self.enable_metrics = config.get("enable_metrics", True)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Count the request and stamp its start time on ``request.state``; never short-circuits."""
        if self.enable_metrics:
            self.request_count += 1
            # Wall-clock timestamp, kept for any code that reads
            # request.state.start_time.
            request.state.start_time = time.time()
            # Durations are measured with the monotonic clock so system
            # clock adjustments cannot make them negative or inflated.
            request.state.monitoring_perf_start = time.perf_counter()
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Record the duration, count 4xx/5xx statuses as errors, and log a summary line.

        Requests whose ``process_request`` hook did not run (e.g. an earlier
        extension short-circuited) carry no start stamp and are skipped.
        """
        if not self.enable_metrics:
            return response

        started = getattr(request.state, "monitoring_perf_start", None)
        if started is None:
            return response

        response_time = time.perf_counter() - started
        self.response_times.append(response_time)

        if response.status_code >= 400:
            self.error_count += 1

        logger.info(f"Request: {request.method} {request.url.path} - "
                    f"Status: {response.status_code} - "
                    f"Time: {response_time:.3f}s")
        return response

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of the collected counters.

        ``error_rate`` divides by ``max(request_count, 1)`` so it is 0 before
        any requests; ``avg_response_time`` is 0 when no durations exist.
        """
        avg_response_time = (sum(self.response_times) / len(self.response_times)
                             if self.response_times else 0)
        return {
            "request_count": self.request_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / max(self.request_count, 1),
            "avg_response_time": avg_response_time,
            "total_response_times": len(self.response_times),
        }


class ExtensionManager:
    """Load extensions by type name and run their hooks in load order."""

    def __init__(self):
        self.extensions: List[Extension] = []
        self.extension_registry: Dict[str, Type[Extension]] = {
            "routing": RoutingExtension,
            "security": SecurityExtension,
            "caching": CachingExtension,
            "monitoring": MonitoringExtension,
        }

    def register_extension_type(self, name: str, extension_class: Type[Extension]) -> None:
        """Register (or replace) the class used for ``load_extension(name, ...)``."""
        self.extension_registry[name] = extension_class

    def load_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Instantiate, initialize and append an extension of the given type.

        Unknown types, and exceptions raised while constructing or
        initializing, are logged and swallowed; in those cases the extension
        is not added.
        """
        if extension_type not in self.extension_registry:
            logger.error(f"Неизвестный тип расширения: {extension_type}")
            return

        try:
            extension_class = self.extension_registry[extension_type]
            extension = extension_class(config)
            extension.initialize()
            self.extensions.append(extension)
            logger.info(f"Загружено расширение: {extension_type}")
        except Exception as e:
            logger.error(f"Ошибка загрузки расширения {extension_type}: {e}")

    async def process_request(self, request: Request) -> Optional[Response]:
        """Run enabled extensions' request hooks in order; return the first non-``None`` response.

        An extension that raises is logged and skipped, so one faulty
        extension does not abort the pipeline.
        """
        for extension in self.extensions:
            if not extension.enabled:
                continue
            try:
                response = await extension.process_request(request)
                if response is not None:
                    return response
            except Exception as e:
                logger.error(f"Ошибка в расширении {type(extension).__name__}: {e}")
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Thread the response through enabled extensions' response hooks in load order.

        If an extension raises, the error is logged and the response from the
        previous step is kept.
        """
        for extension in self.extensions:
            if not extension.enabled:
                continue
            try:
                response = await extension.process_response(request, response)
            except Exception as e:
                logger.error(f"Ошибка в расширении {type(extension).__name__}: {e}")
        return response

    def cleanup(self) -> None:
        """Call every loaded extension's ``cleanup`` (logging failures), then drop them all."""
        for extension in self.extensions:
            try:
                extension.cleanup()
            except Exception as e:
                logger.error(f"Ошибка при очистке расширения {type(extension).__name__}: {e}")
        self.extensions.clear()