konduktor/pyserve/extensions.py
2025-12-04 01:25:13 +03:00

307 lines
11 KiB
Python

import asyncio
import inspect
import ipaddress
import time
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, Dict, List, Optional, Type

from starlette.requests import Request
from starlette.responses import Response

from .logging_utils import get_logger
# Module-level logger shared by every extension class and the manager in this file.
logger = get_logger(__name__)
class Extension(ABC):
    """Abstract base for request/response pipeline extensions.

    A subclass receives its configuration dict at construction time and
    implements the two async hooks below. The manager skips any instance
    whose ``enabled`` attribute is false.
    """

    def __init__(self, config: Dict[str, Any]):
        # Raw configuration for this extension; subclasses read their own keys.
        self.config = config
        # Checked by the manager before each hook call.
        self.enabled = True

    @abstractmethod
    async def process_request(self, request: Request) -> Optional[Response]:
        """Inspect an incoming request.

        Return a ``Response`` to short-circuit the pipeline (the manager stops
        at the first non-None result), or ``None`` to let the next extension run.
        """

    @abstractmethod
    async def process_response(self, request: Request, response: Response) -> Response:
        """Return the (possibly modified) response produced for ``request``."""

    def initialize(self) -> None:
        """Synchronous setup hook called right after construction; no-op by default."""

    def cleanup(self) -> None:
        """Teardown hook; no-op by default."""
class RoutingExtension(Extension):
    """Dispatches requests through the regex-location router from ``.routing``.

    Config keys:
        regex_locations: location table handed to ``create_router_from_config``
            (default: empty dict).
        default_proxy_timeout: forwarded to ``RequestHandler`` (default 30.0).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Lazy import, like the other project imports in this module
        # (presumably to avoid an import cycle with .routing - unverified).
        from .routing import RequestHandler, create_router_from_config

        locations = config.get("regex_locations", {})
        proxy_timeout = config.get("default_proxy_timeout", 30.0)
        self.router = create_router_from_config(locations)
        self.handler = RequestHandler(self.router, default_proxy_timeout=proxy_timeout)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Hand the request to the router's handler.

        Any exception from the handler is logged and turned into ``None``,
        which lets the remaining extensions in the pipeline run.
        """
        try:
            result = await self.handler.handle(request)
        except Exception as e:
            logger.error(f"Error in RoutingExtension: {e}")
            result = None
        return result

    async def process_response(self, request: Request, response: Response) -> Response:
        """Pass the response through unchanged."""
        return response
class SecurityExtension(Extension):
    """IP allow/block filtering plus fixed security headers on every response.

    Config keys:
        allowed_ips: if non-empty, only matching clients are admitted.
        blocked_ips: matching clients are always rejected (checked first).
        security_headers: mapping of header name -> value set on each response.

    Entries in ``allowed_ips``/``blocked_ips`` may be an exact client-host
    string (e.g. ``"unknown"``, the value used when the request has no client),
    a single IPv4/IPv6 address, or a CIDR network such as ``"10.0.0.0/8"``.
    Addresses are compared in parsed form, so differently written IPv6
    addresses (``"2001:DB8::1"`` vs ``"2001:db8::1"``) still match. Network
    entries are parsed once at construction; later reassignment of the lists
    only affects exact-string matching.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.allowed_ips = config.get("allowed_ips", [])
        self.blocked_ips = config.get("blocked_ips", [])
        self.security_headers = config.get(
            "security_headers", {"X-Content-Type-Options": "nosniff", "X-Frame-Options": "DENY", "X-XSS-Protection": "1; mode=block"}
        )
        # Parse address/network entries once rather than on every request.
        self._allowed_networks = self._parse_networks(self.allowed_ips)
        self._blocked_networks = self._parse_networks(self.blocked_ips)

    @staticmethod
    def _parse_networks(entries: Any) -> List[Any]:
        """Return the entries that parse as IP addresses/networks, as network objects.

        Entries that are not valid IPs or networks are skipped here; they can
        still match by exact string comparison.
        """
        # A bare string is a single entry, not a sequence of characters.
        if isinstance(entries, str):
            entries = [entries]
        networks = []
        for entry in entries or []:
            try:
                # strict=False accepts host bits ("10.0.0.5/8" -> 10.0.0.0/8)
                # and turns a bare address into a single-host network.
                networks.append(ipaddress.ip_network(str(entry).strip(), strict=False))
            except ValueError:
                continue
        return networks

    @staticmethod
    def _matches(client_ip: str, entries: Any, networks: List[Any]) -> bool:
        """Return True if ``client_ip`` equals an entry or lies inside one of ``networks``."""
        # Guard against substring matching when a bare string was configured
        # ("1.2.3.4" in "11.2.3.45" would be True).
        if isinstance(entries, str):
            entries = [entries]
        if client_ip in entries:
            return True
        if not networks:
            return False
        try:
            address = ipaddress.ip_address(client_ip)
        except ValueError:
            # Non-IP hosts (e.g. "unknown") can only match exactly.
            return False
        # Containment across IP versions is simply False, never an error.
        return any(address in network for network in networks)

    @staticmethod
    def _forbidden() -> Response:
        """Build the plain-text 403 returned for rejected clients."""
        from starlette.responses import PlainTextResponse

        return PlainTextResponse("403 Forbidden", status_code=403)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Reject blocked or non-allowed clients with a 403; otherwise return ``None``."""
        client_ip = request.client.host if request.client else "unknown"
        if self.blocked_ips and self._matches(client_ip, self.blocked_ips, self._blocked_networks):
            logger.warning(f"Blocked request from IP: {client_ip}")
            return self._forbidden()
        if self.allowed_ips and not self._matches(client_ip, self.allowed_ips, self._allowed_networks):
            logger.warning(f"Access denied for IP: {client_ip}")
            return self._forbidden()
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Set every configured security header on ``response``, overwriting existing values."""
        for header, value in self.security_headers.items():
            response.headers[header] = value
        return response
class CachingExtension(Extension):
    """Placeholder for response caching; currently a pure pass-through.

    The constructor records ``cache_ttl`` (default 3600, presumably seconds -
    unconfirmed) and ``cache_patterns`` (default ``[]``) from config, but
    neither hook reads them yet.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.cache_ttl = config.get("cache_ttl", 3600)
        self.cache_patterns = config.get("cache_patterns", [])
        # Not read or written by either hook yet.
        self.cache: Dict[str, Any] = {}

    async def process_request(self, request: Request) -> Optional[Response]:
        """Never serves from cache; always lets the pipeline continue."""
        # TODO: Implement cache check
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Return the response untouched (nothing is stored)."""
        # TODO: Implement response caching
        return response
class MonitoringExtension(Extension):
    """Collects simple request metrics and logs one line per measured response.

    Config keys:
        enable_metrics: collect and log metrics (default True).
        max_response_samples: size of the ``response_times`` window of recent
            durations (default 1000; ``None`` keeps every sample, which grows
            without bound). ``get_metrics`` averages over *all* measured
            requests regardless of this window.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.request_count = 0
        self.error_count = 0
        # Recent durations only: an unbounded list would grow by one float per
        # request for the lifetime of the process.
        self.response_times: deque[float] = deque(maxlen=config.get("max_response_samples", 1000))
        # Running aggregates keep the average exact after old samples fall
        # out of the bounded window.
        self._response_time_total = 0.0
        self._response_time_count = 0
        self.enable_metrics = config.get("enable_metrics", True)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Count the request and stamp its start time; never short-circuits."""
        if self.enable_metrics:
            self.request_count += 1
            # Wall-clock timestamp kept for any other reader of request.state;
            # durations are measured against the monotonic counter instead.
            request.state.start_time = time.time()
            request.state.monitoring_start = time.perf_counter()
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Record the request's duration, count status >= 400 as an error, and log it."""
        if self.enable_metrics and hasattr(request.state, "start_time"):
            started = getattr(request.state, "monitoring_start", None)
            if started is not None:
                # perf_counter is monotonic, so system clock adjustments can't skew it.
                response_time = time.perf_counter() - started
            else:
                # start_time was set by other code; fall back to the wall clock.
                response_time = time.time() - request.state.start_time
            self.response_times.append(response_time)
            self._response_time_total += response_time
            self._response_time_count += 1
            if response.status_code >= 400:
                self.error_count += 1
            logger.info(f"Request: {request.method} {request.url.path} - " f"Status: {response.status_code} - " f"Time: {response_time:.3f}s")
        return response

    def get_metrics(self) -> Dict[str, Any]:
        """Return counters, error rate and the all-time average response time.

        ``error_rate`` uses ``max(request_count, 1)`` to avoid dividing by zero;
        ``avg_response_time`` is 0 until a response has been measured.
        """
        measured = self._response_time_count
        avg_response_time = self._response_time_total / measured if measured else 0
        return {
            "request_count": self.request_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / max(self.request_count, 1),
            "avg_response_time": avg_response_time,
            "total_response_times": measured,
        }
class ASGIExtension(Extension):
    """Mounts ASGI applications (including Django projects) under path prefixes.

    ``config["mounts"]`` is a list of mount dicts. Every mount uses ``path``
    (default "/"), ``name`` and ``strip_path`` (default True). A mount with
    ``django_settings`` is built via ``create_django_app``; any other mount is
    passed to ``ASGIMountManager.mount`` using ``app_path``/``app_type``/
    ``module_path``/``factory``/``factory_args``.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        from .asgi_mount import ASGIMountManager

        self.mount_manager = ASGIMountManager()
        self._load_mounts(config.get("mounts", []))

    def _load_mounts(self, mounts: List[Dict[str, Any]]) -> None:
        """Register every configured mount with the mount manager.

        If ``create_django_app`` returns a falsy app, the mount falls back to
        the generic ``app_path`` branch only when the config supplies an
        ``app_path``; otherwise it is logged and skipped rather than mounted
        with ``app_path=None``.
        """
        from .asgi_mount import create_django_app

        for mount_config in mounts or []:
            path = mount_config.get("path", "/")
            strip_path = mount_config.get("strip_path", True)

            if "django_settings" in mount_config:
                settings_module = mount_config["django_settings"]
                app = create_django_app(
                    settings_module=settings_module,
                    module_path=mount_config.get("module_path"),
                )
                if app:
                    self.mount_manager.mount(
                        path=path,
                        app=app,
                        name=mount_config.get("name", f"django:{settings_module}"),
                        strip_path=strip_path,
                    )
                    continue
                if "app_path" not in mount_config:
                    logger.error(
                        f"Failed to create Django app for mount '{path}' "
                        f"(settings: {settings_module}); skipping mount"
                    )
                    continue
                # An explicit app_path was configured: fall through to the generic mount.

            self.mount_manager.mount(
                path=path,
                app_path=mount_config.get("app_path"),
                app_type=mount_config.get("app_type", "asgi"),
                module_path=mount_config.get("module_path"),
                factory=mount_config.get("factory", False),
                factory_args=mount_config.get("factory_args"),
                name=mount_config.get("name", ""),
                strip_path=strip_path,
            )

    async def process_request(self, request: Request) -> Optional[Response]:
        """Record the matching mount on ``request.state.asgi_mount``; never short-circuits.

        Dispatch to the mounted app happens elsewhere (presumably middleware
        using ``get_asgi_handler`` - not visible here).
        """
        mount = self.mount_manager.get_mount(request.url.path)
        if mount is not None:
            request.state.asgi_mount = mount
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Pass the response through unchanged."""
        return response

    def get_asgi_handler(self, request: Request) -> Optional[Any]:
        """Return the mount matching the request path, or whatever get_mount returns on no match."""
        return self.mount_manager.get_mount(request.url.path)

    def get_metrics(self) -> Dict[str, Any]:
        """Report the mount listing and the number of mounts."""
        return {
            "asgi_mounts": self.mount_manager.list_mounts(),
            "asgi_mount_count": len(self.mount_manager.mounts),
        }

    def cleanup(self) -> None:
        """Log cleanup.

        NOTE(review): this only logs; mounts are not removed or shut down here.
        """
        logger.info("Cleaning up ASGI mounts")
class ExtensionManager:
    """Registry of extension types plus the ordered list of loaded extensions.

    ``process_request`` runs enabled extensions in list order and stops at the
    first one that returns a response; ``process_response`` threads the
    response through every enabled extension in the same order.
    """

    def __init__(self) -> None:
        # Loaded extensions, in the order their hooks run.
        self.extensions: List[Extension] = []
        # Type name -> class instantiated by load_extension / load_extension_async.
        self.extension_registry: Dict[str, Type[Extension]] = {
            "routing": RoutingExtension,
            "security": SecurityExtension,
            "caching": CachingExtension,
            "monitoring": MonitoringExtension,
            "asgi": ASGIExtension,
        }
        self._register_process_orchestration()

    def _register_process_orchestration(self) -> None:
        """Register "process_orchestration" when its optional module is importable."""
        try:
            from .process_extension import ProcessOrchestrationExtension
        except ImportError:
            return  # Optional dependency
        self.extension_registry["process_orchestration"] = ProcessOrchestrationExtension  # type: ignore

    def register_extension_type(self, name: str, extension_class: Type[Extension]) -> None:
        """Register ``extension_class`` under ``name``, replacing any existing entry."""
        self.extension_registry[name] = extension_class

    def _resolve_extension_class(self, extension_type: str) -> Optional[Type[Extension]]:
        """Return the registered class for ``extension_type``; log and return None if unknown."""
        if extension_type not in self.extension_registry:
            logger.error(f"Unknown extension type: {extension_type}")
            return None
        return self.extension_registry[extension_type]

    def load_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Instantiate and ``initialize()`` an extension, then append it (runs last).

        Unknown types and any exception during construction/initialization are
        logged; in that case the extension is not added.
        """
        extension_class = self._resolve_extension_class(extension_type)
        if extension_class is None:
            return
        try:
            extension = extension_class(config)
            extension.initialize()
            self.extensions.append(extension)
            logger.info(f"Loaded extension: {extension_type}")
        except Exception as e:
            logger.error(f"Error loading extension {extension_type}: {e}")

    async def load_extension_async(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Load extension with async setup support (for ProcessOrchestration).

        If the instance has a coroutine ``setup(config)`` it is awaited instead
        of calling ``initialize()``; a coroutine ``start()`` is then awaited.
        The extension is inserted at the front so it runs before previously
        loaded ones. Errors are logged and the extension is not added.
        """
        extension_class = self._resolve_extension_class(extension_type)
        if extension_class is None:
            return
        try:
            extension = extension_class(config)
            # inspect.iscoroutinefunction: the asyncio variant is deprecated since Python 3.14.
            setup_method = getattr(extension, "setup", None)
            if setup_method is not None and inspect.iscoroutinefunction(setup_method):
                await setup_method(config)
            else:
                extension.initialize()
            start_method = getattr(extension, "start", None)
            if start_method is not None and inspect.iscoroutinefunction(start_method):
                await start_method()
            # Insert at the beginning so process_orchestration is checked first
            self.extensions.insert(0, extension)
            logger.info(f"Loaded extension (async): {extension_type}")
        except Exception as e:
            logger.error(f"Error loading extension {extension_type}: {e}")

    async def process_request(self, request: Request) -> Optional[Response]:
        """Return the first response produced by an enabled extension, else ``None``.

        An exception in one extension is logged and the next extension runs.
        """
        for extension in self.extensions:
            if not extension.enabled:
                continue
            try:
                response = await extension.process_request(request)
                if response is not None:
                    return response
            except Exception as e:
                logger.error(f"Error in extension {type(extension).__name__}: {e}")
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Pass ``response`` through each enabled extension and return the result.

        If an extension raises, the error is logged and the response from the
        previous step is kept.
        """
        for extension in self.extensions:
            if not extension.enabled:
                continue
            try:
                response = await extension.process_response(request, response)
            except Exception as e:
                logger.error(f"Error in extension {type(extension).__name__}: {e}")
        return response

    def cleanup(self) -> None:
        """Call each extension's synchronous ``cleanup()`` and then drop all extensions.

        Errors are logged and do not stop the remaining cleanups.
        NOTE(review): no async shutdown hook is awaited here for extensions
        started via ``load_extension_async`` - confirm that is sufficient.
        """
        for extension in self.extensions:
            try:
                extension.cleanup()
            except Exception as e:
                logger.error(f"Error cleaning up extension {type(extension).__name__}: {e}")
        self.extensions.clear()