forked from Shifty/pyserveX
307 lines
11 KiB
Python
307 lines
11 KiB
Python
import asyncio
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional, Type
|
|
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
|
|
from .logging_utils import get_logger
|
|
|
|
# Module-level logger used by the extension classes and ExtensionManager in this file.
logger = get_logger(__name__)
|
|
|
|
|
|
class Extension(ABC):
    """Abstract base class for request/response pipeline extensions.

    Subclasses must implement ``process_request`` and ``process_response``.
    ``initialize`` and ``cleanup`` are optional hooks that do nothing unless
    overridden.

    Attributes:
        config: The configuration mapping passed to the constructor.
        enabled: ``True`` after construction. ``ExtensionManager`` skips
            extensions whose flag is false when processing requests and
            responses.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and mark the extension as enabled."""
        self.enabled = True
        self.config = config

    @abstractmethod
    async def process_request(self, request: Request) -> Optional[Response]:
        """Handle an incoming request.

        Args:
            request: The incoming Starlette request.

        Returns:
            A response, or ``None`` when the extension produces none.
            ``ExtensionManager.process_request`` returns the first
            non-``None`` result and does not consult later extensions.
        """

    @abstractmethod
    async def process_response(self, request: Request, response: Response) -> Response:
        """Handle an outgoing response.

        Args:
            request: The request the response belongs to.
            response: The response produced so far.

        Returns:
            The response to pass on; ``ExtensionManager.process_response``
            feeds it to the next extension.
        """

    def initialize(self) -> None:
        """Optional setup hook; the base implementation does nothing."""

    def cleanup(self) -> None:
        """Optional teardown hook; the base implementation does nothing."""
|
|
|
|
|
|
class RoutingExtension(Extension):
    """Extension that hands each request to a ``RequestHandler`` built from config."""

    def __init__(self, config: Dict[str, Any]):
        """Create the router and the request handler.

        Args:
            config: Extension configuration. Keys read here:
                ``regex_locations`` (default ``{}``) is passed to
                ``create_router_from_config``;
                ``default_proxy_timeout`` (default ``30.0``) is passed to
                ``RequestHandler``.
        """
        super().__init__(config)

        # Function-scope import, as in the rest of this module
        # (presumably to avoid an import cycle with .routing — confirm).
        from .routing import RequestHandler, create_router_from_config

        locations = config.get("regex_locations", {})
        proxy_timeout = config.get("default_proxy_timeout", 30.0)

        self.router = create_router_from_config(locations)
        self.handler = RequestHandler(self.router, default_proxy_timeout=proxy_timeout)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Delegate the request to the handler.

        Returns:
            Whatever ``self.handler.handle(request)`` returns. If the handler
            raises, the error is logged and ``None`` is returned instead.
        """
        try:
            outcome = await self.handler.handle(request)
        except Exception as exc:
            logger.error(f"Error in RoutingExtension: {exc}")
            return None
        return outcome

    async def process_response(self, request: Request, response: Response) -> Response:
        """Return ``response`` unchanged."""
        return response
|
|
|
|
|
|
class SecurityExtension(Extension):
    """IP allow/block filtering on requests plus fixed headers on responses."""

    def __init__(self, config: Dict[str, Any]):
        """Read the IP lists and response headers from ``config``.

        Args:
            config: Extension configuration. Keys read here:
                ``allowed_ips`` (default ``[]``), ``blocked_ips``
                (default ``[]``) and ``security_headers`` (defaults to the
                three headers listed below).
        """
        super().__init__(config)

        # Used only when the config does not provide "security_headers".
        fallback_headers = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
        }

        self.allowed_ips = config.get("allowed_ips", [])
        self.blocked_ips = config.get("blocked_ips", [])
        self.security_headers = config.get("security_headers", fallback_headers)

    @staticmethod
    def _forbidden() -> Response:
        """Build the plain-text 403 response used for every rejected request."""
        from starlette.responses import PlainTextResponse

        return PlainTextResponse("403 Forbidden", status_code=403)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Reject the request with a 403 when the client IP is not permitted.

        The block list is checked before the allow list. When the request
        carries no client information the address is treated as the literal
        string ``"unknown"``, so a non-empty allow list rejects it unless
        that string is listed.

        Returns:
            A 403 ``PlainTextResponse`` when the request is rejected,
            otherwise ``None``.
        """
        peer = request.client
        client_ip = peer.host if peer else "unknown"

        # An empty list disables the corresponding check.
        if self.blocked_ips and client_ip in self.blocked_ips:
            logger.warning(f"Blocked request from IP: {client_ip}")
            return self._forbidden()

        if self.allowed_ips and client_ip not in self.allowed_ips:
            logger.warning(f"Access denied for IP: {client_ip}")
            return self._forbidden()

        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Set every configured security header on ``response`` and return it.

        Existing values for the same header names are overwritten.
        """
        for name, content in self.security_headers.items():
            response.headers[name] = content
        return response
|
|
|
|
|
|
class CachingExtension(Extension):
    """Placeholder caching extension.

    Only the configuration is stored; neither hook reads or writes the
    cache yet (see the TODO markers below).
    """

    def __init__(self, config: Dict[str, Any]):
        """Read caching settings from ``config``.

        Args:
            config: Extension configuration. Keys read here:
                ``cache_patterns`` (default ``[]``) and ``cache_ttl``
                (default ``3600``).
        """
        super().__init__(config)
        self.cache_ttl = config.get("cache_ttl", 3600)
        self.cache_patterns = config.get("cache_patterns", [])
        # Created empty; nothing in this class populates it yet.
        self.cache: Dict[str, Any] = {}

    async def process_request(self, request: Request) -> Optional[Response]:
        """Return ``None`` for every request; no cache lookup is performed."""
        # TODO: Implement cache check
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Return ``response`` unchanged; nothing is stored in the cache."""
        # TODO: Implement response caching
        return response
|
|
|
|
|
|
class MonitoringExtension(Extension):
    """Collect simple in-process request metrics.

    Counts requests and error responses and records one timing sample per
    response. ``get_metrics`` summarises the collected values.

    Attributes:
        request_count: Number of requests seen while metrics are enabled.
        error_count: Number of timed responses with status code >= 400.
        response_times: Timing samples in seconds, oldest first.
        enable_metrics: When false, both hooks are no-ops.
        max_response_samples: Optional cap on ``len(response_times)``;
            ``None`` or a non-positive value keeps every sample.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialise counters and read settings from ``config``.

        Args:
            config: Extension configuration. Keys read here:
                ``enable_metrics`` (default ``True``) and
                ``max_response_samples`` (default ``None``, meaning
                unbounded). A positive value bounds memory use by
                discarding the oldest samples.

        Raises:
            ValueError, TypeError: If ``max_response_samples`` is set but
                cannot be converted with ``int()``.
        """
        super().__init__(config)
        self.request_count = 0
        self.error_count = 0
        self.response_times: list[float] = []
        self.enable_metrics = config.get("enable_metrics", True)

        # Validate once here so a bad value fails at load time rather than
        # on every response.
        sample_cap = config.get("max_response_samples")
        self.max_response_samples: Optional[int] = None if sample_cap is None else int(sample_cap)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Count the request and stamp its start time on ``request.state``.

        Returns:
            Always ``None``.
        """
        if self.enable_metrics:
            import time

            self.request_count += 1
            # NOTE(review): wall-clock time, so a system clock adjustment
            # between the two hooks skews the sample. Kept because
            # request.state.start_time may be read elsewhere — confirm
            # before switching to a monotonic clock.
            request.state.start_time = time.time()
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Record the elapsed time and error status, then log the request.

        Nothing is recorded when metrics are disabled or when
        ``request.state`` has no ``start_time`` attribute.

        Returns:
            ``response`` unchanged.
        """
        if self.enable_metrics and hasattr(request.state, "start_time"):
            import time

            response_time = time.time() - request.state.start_time
            self._record_response_time(response_time)

            if response.status_code >= 400:
                self.error_count += 1

            logger.info(
                f"Request: {request.method} {request.url.path} - "
                f"Status: {response.status_code} - "
                f"Time: {response_time:.3f}s"
            )

        return response

    def _record_response_time(self, elapsed: float) -> None:
        """Append one timing sample, dropping the oldest beyond the cap."""
        samples = self.response_times
        samples.append(elapsed)

        cap = self.max_response_samples
        if cap is not None and cap > 0 and len(samples) > cap:
            # Keep only the newest `cap` samples.
            del samples[:-cap]

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of the collected metrics.

        ``avg_response_time`` and ``total_response_times`` are computed from
        ``response_times``; when ``max_response_samples`` is set they cover
        only the retained samples.

        Returns:
            A dict with ``request_count``, ``error_count``, ``error_rate``
            (errors divided by requests, with a divisor of at least 1),
            ``avg_response_time`` (``0`` when there are no samples) and
            ``total_response_times``.
        """
        avg_response_time = sum(self.response_times) / len(self.response_times) if self.response_times else 0

        return {
            "request_count": self.request_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / max(self.request_count, 1),
            "avg_response_time": avg_response_time,
            "total_response_times": len(self.response_times),
        }
|
|
|
|
|
|
class ASGIExtension(Extension):
    """Extension that registers ASGI application mounts from configuration."""

    def __init__(self, config: Dict[str, Any]):
        """Create the mount manager and register the configured mounts.

        Args:
            config: Extension configuration. The ``mounts`` key (default
                ``[]``) is a list of per-mount dicts; see ``_load_mounts``.
        """
        super().__init__(config)

        from .asgi_mount import ASGIMountManager

        self.mount_manager = ASGIMountManager()
        self._load_mounts(config.get("mounts", []))

    def _load_mounts(self, mounts: List[Dict[str, Any]]) -> None:
        """Register every entry of ``mounts`` with the mount manager.

        An entry containing the key ``django_settings`` is built with
        ``create_django_app`` and mounted as a ready application. Every
        other entry is mounted from its ``app_path`` and related options.
        ``path`` defaults to ``"/"`` and ``strip_path`` to ``True`` in both
        cases.
        """
        from .asgi_mount import create_django_app

        for entry in mounts:
            mount_path = entry.get("path", "/")

            if "django_settings" not in entry:
                self.mount_manager.mount(
                    path=mount_path,
                    app_path=entry.get("app_path"),
                    app_type=entry.get("app_type", "asgi"),
                    module_path=entry.get("module_path"),
                    factory=entry.get("factory", False),
                    factory_args=entry.get("factory_args"),
                    name=entry.get("name", ""),
                    strip_path=entry.get("strip_path", True),
                )
                continue

            settings_module = entry["django_settings"]
            django_app = create_django_app(
                settings_module=settings_module,
                module_path=entry.get("module_path"),
            )
            # A falsy result means no app is available; the entry is skipped
            # here without any further handling.
            if django_app:
                self.mount_manager.mount(
                    path=mount_path,
                    app=django_app,
                    name=entry.get("name", f"django:{settings_module}"),
                    strip_path=entry.get("strip_path", True),
                )

    async def process_request(self, request: Request) -> Optional[Response]:
        """Record the mount matching the request path, if any.

        When the mount manager returns a mount for the path it is stored on
        ``request.state.asgi_mount``. No response is ever produced here.

        Returns:
            Always ``None``.
        """
        matched = self.mount_manager.get_mount(request.url.path)
        if matched is not None:
            # Left on the request state for later stages to pick up.
            request.state.asgi_mount = matched
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Return ``response`` unchanged."""
        return response

    def get_asgi_handler(self, request: Request) -> Optional[Any]:
        """Return what the mount manager's ``get_mount`` gives for the request path."""
        return self.mount_manager.get_mount(request.url.path)

    def get_metrics(self) -> Dict[str, Any]:
        """Return the mount listing and the number of registered mounts."""
        manager = self.mount_manager
        return {
            "asgi_mounts": manager.list_mounts(),
            "asgi_mount_count": len(manager.mounts),
        }

    def cleanup(self) -> None:
        """Log that cleanup ran; no other action is taken."""
        logger.info("Cleaning up ASGI mounts")
|
|
|
|
|
|
class ExtensionManager:
    """Registry and dispatcher for ``Extension`` instances.

    Attributes:
        extensions: Loaded extensions, in the order they are consulted.
        extension_registry: Maps a type name to the class instantiated by
            ``load_extension`` / ``load_extension_async``.
    """

    def __init__(self) -> None:
        """Register the built-in extension types."""
        self.extensions: List[Extension] = []
        self.extension_registry: Dict[str, Type[Extension]] = {
            "routing": RoutingExtension,
            "security": SecurityExtension,
            "caching": CachingExtension,
            "monitoring": MonitoringExtension,
            "asgi": ASGIExtension,
        }
        self._register_process_orchestration()

    def _register_process_orchestration(self) -> None:
        """Register ``process_orchestration`` when its module can be imported."""
        try:
            from .process_extension import ProcessOrchestrationExtension

            self.extension_registry["process_orchestration"] = ProcessOrchestrationExtension  # type: ignore
        except ImportError:
            pass  # Optional dependency

    def register_extension_type(self, name: str, extension_class: Type[Extension]) -> None:
        """Add ``extension_class`` under ``name``, replacing any existing entry."""
        self.extension_registry[name] = extension_class

    def _is_registered(self, extension_type: str) -> bool:
        """Return whether ``extension_type`` is registered, logging an error if not."""
        if extension_type in self.extension_registry:
            return True
        logger.error(f"Unknown extension type: {extension_type}")
        return False

    def load_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Instantiate, initialise and append an extension.

        Unknown types and any exception raised while constructing or
        initialising the extension are logged; nothing is raised to the
        caller and the extension is not added.

        Args:
            extension_type: Key in ``extension_registry``.
            config: Configuration passed to the extension constructor.
        """
        if not self._is_registered(extension_type):
            return

        try:
            extension_class = self.extension_registry[extension_type]
            extension = extension_class(config)
            extension.initialize()
            self.extensions.append(extension)
            logger.info(f"Loaded extension: {extension_type}")
        except Exception as e:
            logger.error(f"Error loading extension {extension_type}: {e}")

    async def load_extension_async(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Load extension with async setup support (for ProcessOrchestration).

        If the extension has a coroutine method ``setup`` it is awaited with
        ``config``; otherwise ``initialize()`` is called. A coroutine method
        ``start``, when present, is awaited afterwards. The extension is
        inserted at the front of ``extensions``. Unknown types and any
        exception are logged rather than raised.

        Args:
            extension_type: Key in ``extension_registry``.
            config: Configuration passed to the constructor and to ``setup``.
        """
        if not self._is_registered(extension_type):
            return

        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        import inspect

        try:
            extension_class = self.extension_registry[extension_type]
            extension = extension_class(config)

            setup_method = getattr(extension, "setup", None)
            if setup_method is not None and inspect.iscoroutinefunction(setup_method):
                await setup_method(config)
            else:
                extension.initialize()

            start_method = getattr(extension, "start", None)
            if start_method is not None and inspect.iscoroutinefunction(start_method):
                await start_method()

            # Insert at the beginning so process_orchestration is checked first
            self.extensions.insert(0, extension)
            logger.info(f"Loaded extension (async): {extension_type}")
        except Exception as e:
            logger.error(f"Error loading extension {extension_type}: {e}")

    async def process_request(self, request: Request) -> Optional[Response]:
        """Offer the request to each enabled extension in order.

        An extension that raises is logged and skipped.

        Returns:
            The first non-``None`` response produced by an extension, or
            ``None`` when no extension produced one.
        """
        for extension in self.extensions:
            if not extension.enabled:
                continue

            try:
                response = await extension.process_request(request)
                if response is not None:
                    return response
            except Exception as e:
                logger.error(f"Error in extension {type(extension).__name__}: {e}")

        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Pass the response through each enabled extension in order.

        Each extension receives the response returned by the previous one.
        An extension that raises is logged and skipped, leaving the response
        as it was before that extension ran.

        Returns:
            The response after all extensions have been applied.
        """
        for extension in self.extensions:
            if not extension.enabled:
                continue

            try:
                response = await extension.process_response(request, response)
            except Exception as e:
                logger.error(f"Error in extension {type(extension).__name__}: {e}")

        return response

    def cleanup(self) -> None:
        """Call ``cleanup()`` on every extension, then empty the list.

        A failure in one extension is logged and does not stop the others
        from being cleaned up.
        """
        for extension in self.extensions:
            try:
                extension.cleanup()
            except Exception as e:
                logger.error(f"Error cleaning up extension {type(extension).__name__}: {e}")

        self.extensions.clear()
|