# pyserveX/pyserve/process_extension.py
"""Process Orchestration Extension
Extension that manages ASGI/WSGI applications as isolated processes
and routes requests to them via reverse proxy.
"""
import asyncio
import logging
import time
import uuid
from typing import Any, Dict, Optional
import httpx
from starlette.requests import Request
from starlette.responses import Response
from .extensions import Extension
from .logging_utils import get_logger
from .process_manager import ProcessConfig, ProcessManager
logger = get_logger(__name__)
class ProcessOrchestrationExtension(Extension):
"""
Extension that orchestrates ASGI/WSGI applications as separate processes.
    Unlike ASGIExtension, which runs apps in-process, this extension:
- Runs each app in its own isolated process
- Provides health monitoring and auto-restart
- Routes requests via HTTP reverse proxy
- Supports multiple workers per app
Configuration example:
```yaml
extensions:
- type: process_orchestration
config:
port_range: [9000, 9999]
health_check_enabled: true
apps:
- name: api
path: /api
app_path: myapp.api:app
workers: 4
health_check_path: /health
- name: admin
path: /admin
app_path: myapp.admin:create_app
factory: true
workers: 2
```
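
    A minimal programmatic sketch (an assumption about how a host might drive
    this module directly; `setup_process_orchestration` and
    `shutdown_process_orchestration` are defined at the bottom of this file,
    and the config keys match the YAML above):
    ```python
    # inside an async context
    ext = await setup_process_orchestration({
        "port_range": [9000, 9999],
        "apps": [
            {"name": "api", "path": "/api", "app_path": "myapp.api:app", "workers": 4},
        ],
    })
    # ... serve traffic; requests under /api are proxied to the child process ...
    await shutdown_process_orchestration(ext)
    ```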
"""
name = "process_orchestration"
def __init__(self, config: Dict[str, Any]) -> None:
super().__init__(config)
self._manager: Optional[ProcessManager] = None
self._mounts: Dict[str, MountConfig] = {} # path -> config
self._http_client: Optional[httpx.AsyncClient] = None
self._started = False
self._proxy_timeout: float = config.get("proxy_timeout", 60.0)
self._pending_config = config # Store for async setup
logging_config = config.get("logging", {})
self._log_proxy_requests: bool = logging_config.get("proxy_logs", True)
self._log_health_checks: bool = logging_config.get("health_check_logs", False)
httpx_level = logging_config.get("httpx_level", "warning").upper()
logging.getLogger("httpx").setLevel(getattr(logging, httpx_level, logging.WARNING))
logging.getLogger("httpcore").setLevel(getattr(logging, httpx_level, logging.WARNING))
async def setup(self, config: Optional[Dict[str, Any]] = None) -> None:
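        """Create the process manager and the shared proxy HTTP client, then
        register each app from `config` (falling back to the constructor config)."""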
if config is None:
config = self._pending_config
port_range = tuple(config.get("port_range", [9000, 9999]))
health_check_enabled = config.get("health_check_enabled", True)
self._proxy_timeout = config.get("proxy_timeout", 60.0)
self._manager = ProcessManager(
port_range=port_range,
health_check_enabled=health_check_enabled,
)
self._http_client = httpx.AsyncClient(
timeout=httpx.Timeout(self._proxy_timeout),
follow_redirects=False,
limits=httpx.Limits(
max_keepalive_connections=100,
max_connections=200,
),
)
apps_config = config.get("apps", [])
for app_config in apps_config:
await self._register_app(app_config)
logger.info(
"Process orchestration extension initialized",
app_count=len(self._mounts),
)
async def _register_app(self, app_config: Dict[str, Any]) -> None:
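        """Translate a single `apps` entry into a ProcessConfig, register it
        with the manager, and record its mount path for request routing."""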
if not self._manager:
return
name = app_config.get("name")
path = app_config.get("path", "").rstrip("/")
app_path = app_config.get("app_path")
if not name or not app_path:
logger.error("App config missing 'name' or 'app_path'")
return
process_config = ProcessConfig(
name=name,
app_path=app_path,
app_type=app_config.get("app_type", "asgi"),
workers=app_config.get("workers", 1),
module_path=app_config.get("module_path"),
factory=app_config.get("factory", False),
factory_args=app_config.get("factory_args"),
env=app_config.get("env", {}),
health_check_enabled=app_config.get("health_check_enabled", True),
health_check_path=app_config.get("health_check_path", "/health"),
health_check_interval=app_config.get("health_check_interval", 10.0),
health_check_timeout=app_config.get("health_check_timeout", 5.0),
health_check_retries=app_config.get("health_check_retries", 3),
max_memory_mb=app_config.get("max_memory_mb"),
max_restart_count=app_config.get("max_restart_count", 5),
restart_delay=app_config.get("restart_delay", 1.0),
shutdown_timeout=app_config.get("shutdown_timeout", 30.0),
)
await self._manager.register(process_config)
self._mounts[path] = MountConfig(
path=path,
process_name=name,
strip_path=app_config.get("strip_path", True),
)
logger.info(f"Registered app '{name}' at path '{path}'")
async def start(self) -> None:
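        """Start the process manager and launch all registered app processes."""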
if self._started or not self._manager:
return
await self._manager.start()
results = await self._manager.start_all()
self._started = True
success = sum(1 for v in results.values() if v)
failed = len(results) - success
logger.info(
"Process orchestration started",
success=success,
failed=failed,
)
async def stop(self) -> None:
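        """Close the proxy HTTP client and shut down all managed processes."""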
if not self._started:
return
if self._http_client:
await self._http_client.aclose()
self._http_client = None
if self._manager:
await self._manager.stop()
self._started = False
logger.info("Process orchestration stopped")
def cleanup(self) -> None:
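        """Synchronous cleanup hook: schedule stop() on the running event loop,
        or run it to completion when no loop is active."""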
        try:
            loop = asyncio.get_running_loop()
            # Keep a reference so the task is not garbage-collected before it runs.
            self._cleanup_task = loop.create_task(self.stop())
        except RuntimeError:
            # No running loop: stop synchronously.
            asyncio.run(self.stop())
async def process_request(self, request: Request) -> Optional[Response]:
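        """Proxy the request to the process mounted at its path. Returns None
        when no mount matches (so other handlers can run) and a 503 response
        when the target process is not running."""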
if not self._started or not self._manager:
logger.debug(
"Process orchestration not ready",
started=self._started,
has_manager=self._manager is not None,
)
return None
mount = self._get_mount(request.url.path)
if not mount:
logger.debug(
"No mount found for path",
path=request.url.path,
available_mounts=list(self._mounts.keys()),
)
return None
upstream_url = self._manager.get_upstream_url(mount.process_name)
if not upstream_url:
logger.warning(
f"Process '{mount.process_name}' not running",
path=request.url.path,
)
return Response("Service Unavailable", status_code=503)
request_id = request.headers.get("X-Request-ID", str(uuid.uuid4())[:8])
start_time = time.perf_counter()
response = await self._proxy_request(request, upstream_url, mount, request_id)
latency_ms = (time.perf_counter() - start_time) * 1000
if self._log_proxy_requests:
logger.info(
"Proxy request completed",
request_id=request_id,
method=request.method,
path=request.url.path,
process=mount.process_name,
upstream=upstream_url,
status=response.status_code,
latency_ms=round(latency_ms, 2),
)
return response
def _get_mount(self, path: str) -> Optional["MountConfig"]:
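        """Return the mount whose path is the longest prefix of `path`;
        an empty mount path acts as a catch-all."""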
for mount_path in sorted(self._mounts.keys(), key=len, reverse=True):
if mount_path == "":
return self._mounts[mount_path]
if path == mount_path or path.startswith(f"{mount_path}/"):
return self._mounts[mount_path]
return None
async def _proxy_request(
self,
request: Request,
upstream_url: str,
mount: "MountConfig",
request_id: str = "",
) -> Response:
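        """Forward the request to the upstream process and return its response;
        timeouts map to 504, connection failures to 502, other errors to 500."""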
path = request.url.path
if mount.strip_path and mount.path:
path = path[len(mount.path) :] or "/"
target_url = f"{upstream_url}{path}"
if request.url.query:
target_url += f"?{request.url.query}"
        headers = dict(request.headers)  # Starlette header keys are lowercase
        headers.pop("host", None)
        # Drop incoming forwarding headers so the values set below are not duplicated,
        # and append the immediate client to any existing X-Forwarded-For chain.
        client_host = request.client.host if request.client else "unknown"
        prior_xff = headers.pop("x-forwarded-for", None)
        headers.pop("x-forwarded-proto", None)
        headers.pop("x-forwarded-host", None)
        headers.pop("x-request-id", None)
        headers["X-Forwarded-For"] = f"{prior_xff}, {client_host}" if prior_xff else client_host
        headers["X-Forwarded-Proto"] = request.url.scheme
        headers["X-Forwarded-Host"] = request.headers.get("host", "")
        if request_id:
            headers["X-Request-ID"] = request_id
try:
if not self._http_client:
return Response("Service Unavailable", status_code=503)
body = await request.body()
response = await self._http_client.request(
method=request.method,
url=target_url,
headers=headers,
content=body,
)
            response_headers = dict(response.headers)
            # httpx decompresses the body, so also strip content-encoding/length
            # along with the hop-by-hop headers before re-sending the content.
            drop_headers = ["transfer-encoding", "connection", "keep-alive", "content-encoding", "content-length"]
            for header in drop_headers:
                response_headers.pop(header, None)
return Response(
content=response.content,
status_code=response.status_code,
headers=response_headers,
)
except httpx.TimeoutException:
logger.error(f"Proxy timeout to {upstream_url}")
return Response("Gateway Timeout", status_code=504)
except httpx.ConnectError as e:
logger.error(f"Proxy connection error to {upstream_url}: {e}")
return Response("Bad Gateway", status_code=502)
except Exception as e:
logger.error(f"Proxy error to {upstream_url}: {e}")
return Response("Internal Server Error", status_code=500)
async def process_response(
self,
request: Request,
response: Response,
) -> Response:
return response
def get_metrics(self) -> Dict[str, Any]:
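        """Return extension metrics, merged with the process manager's metrics."""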
metrics = {
"process_orchestration": {
"enabled": self._started,
"mounts": len(self._mounts),
}
}
if self._manager:
metrics["process_orchestration"].update(self._manager.get_metrics())
return metrics
async def get_process_status(self, name: str) -> Optional[Dict[str, Any]]:
if not self._manager:
return None
info = self._manager.get_process(name)
return info.to_dict() if info else None
async def get_all_status(self) -> Dict[str, Any]:
if not self._manager:
return {}
return {name: info.to_dict() for name, info in self._manager.get_all_processes().items()}
async def restart_process(self, name: str) -> bool:
if not self._manager:
return False
return await self._manager.restart_process(name)
async def scale_process(self, name: str, workers: int) -> bool:
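        """Set a new worker count on the process config and restart the process
        to apply the change."""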
if not self._manager:
return False
info = self._manager.get_process(name)
if not info:
return False
info.config.workers = workers
return await self._manager.restart_process(name)
class MountConfig:
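    """Associates a URL path prefix with a managed process and controls whether
    the prefix is stripped before proxying."""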
def __init__(
self,
path: str,
process_name: str,
strip_path: bool = True,
):
self.path = path
self.process_name = process_name
self.strip_path = strip_path
async def setup_process_orchestration(config: Dict[str, Any]) -> ProcessOrchestrationExtension:
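    """Create, set up, and start a ProcessOrchestrationExtension from a config dict."""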
ext = ProcessOrchestrationExtension(config)
await ext.setup(config)
await ext.start()
return ext
async def shutdown_process_orchestration(ext: ProcessOrchestrationExtension) -> None:
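    """Stop the extension and all processes it manages."""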
await ext.stop()