"""PyServe ASGI server: extension-aware middleware and server bootstrap."""

import json
import ssl
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncIterator, Dict, Optional, Tuple

import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.routing import Route
from starlette.types import ASGIApp, Receive, Scope, Send

from . import __version__
from .config import Config
from .extensions import ASGIExtension, ExtensionManager
from .logging_utils import get_logger

logger = get_logger(__name__)


class PyServeMiddleware:
    """ASGI middleware that lets extensions answer HTTP requests before the app.

    For every HTTP request the middleware first looks for an ASGI mount
    offered by an ``ASGIExtension``. If no mount claims the request, the
    extension manager's ``process_request`` is consulted. When that returns
    ``None`` the request is forwarded to the wrapped application untouched.
    """

    def __init__(self, app: ASGIApp, extension_manager: ExtensionManager):
        """Store the wrapped app, the extension manager and the access logger."""
        self.app = app
        self.extension_manager = extension_manager
        self.access_logger = get_logger("pyserve.access")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one ASGI connection; non-HTTP scopes are passed straight through."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        start_time = time.time()
        request = Request(scope, receive)

        # Mounted ASGI apps take precedence over the request/response pipeline.
        if await self._try_asgi_mount(scope, receive, send, request, start_time):
            return

        response = await self.extension_manager.process_request(request)
        if response is None:
            # No extension produced a response: fall back to the wrapped app.
            await self.app(scope, receive, send)
            return

        response = await self.extension_manager.process_response(request, response)
        response.headers["Server"] = f"pyserve/{__version__}"
        self._log_access(request, response, start_time)
        await response(scope, receive, send)

    async def _try_asgi_mount(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
        request: Request,
        start_time: float,
    ) -> bool:
        """Dispatch the request to the first matching ASGI mount, if any.

        Args:
            scope: ASGI scope of the incoming request.
            receive: ASGI receive callable.
            send: ASGI send callable.
            request: Starlette request built from ``scope``.
            start_time: ``time.time()`` value captured when handling began;
                used to compute the logged processing time.

        Returns:
            ``True`` when a mount handled the request (including the case
            where it failed and a 500 response was sent), ``False`` when no
            extension offered a mount.

        Raises:
            Exception: Re-raises the mounted app's exception when it fails
                after the response has already started, because a second
                ``http.response.start`` cannot be sent at that point.
        """
        for extension in self.extension_manager.extensions:
            if not isinstance(extension, ASGIExtension):
                continue

            mount = extension.get_asgi_handler(request)
            if mount is None:
                continue

            # Work on a copy so the caller's scope is left untouched.
            modified_scope = dict(scope)
            if mount.strip_path:
                modified_scope["path"] = mount.get_modified_path(request.url.path)
                modified_scope["root_path"] = scope.get("root_path", "") + mount.path

            logger.debug(f"Routing to ASGI mount '{mount.name}': " f"{request.url.path} -> {modified_scope['path']}")

            response_started = False
            status_code = 0

            async def send_wrapper(message: Dict[str, Any]) -> None:
                # Record the status and whether headers were already sent.
                nonlocal response_started, status_code
                if message["type"] == "http.response.start":
                    response_started = True
                    status_code = message.get("status", 0)
                await send(message)

            try:
                await mount.app(modified_scope, receive, send_wrapper)
            except Exception as e:
                logger.error(f"Error in ASGI mount '{mount.name}': {e}", exc_info=True)
                if response_started:
                    # Headers are already on the wire; emitting a fresh 500
                    # would only raise a secondary protocol error and hide
                    # the real failure, so surface the original exception.
                    raise
                error_response = PlainTextResponse("500 Internal Server Error", status_code=500)
                await error_response(scope, receive, send)
                return True

            process_time = round((time.time() - start_time) * 1000, 2)
            self.access_logger.info(
                "ASGI request",
                client_ip=request.client.host if request.client else "unknown",
                method=request.method,
                path=str(request.url.path),
                mount=mount.name,
                status_code=status_code,
                process_time_ms=process_time,
                user_agent=request.headers.get("user-agent", ""),
            )
            return True

        return False

    def _log_access(self, request: Request, response: Response, start_time: float) -> None:
        """Write one access-log entry for a response produced by an extension.

        Args:
            request: The handled request.
            response: The response about to be sent.
            start_time: ``time.time()`` value captured when handling began.
        """
        path = str(request.url.path)
        query = str(request.url.query) if request.url.query else ""
        if query:
            path += f"?{query}"

        self.access_logger.info(
            "HTTP request",
            client_ip=request.client.host if request.client else "unknown",
            method=request.method,
            path=path,
            status_code=response.status_code,
            process_time_ms=round((time.time() - start_time) * 1000, 2),
            user_agent=request.headers.get("user-agent", ""),
        )


class PyServeServer:
    """Builds the Starlette application from a ``Config`` and runs it with uvicorn."""

    def __init__(self, config: Config):
        """Set up logging, load synchronous extensions and create the ASGI app."""
        self.config = config
        self.extension_manager = ExtensionManager()
        self.app: Optional[Starlette] = None
        self._async_extensions_loaded = False
        self._setup_logging()
        self._load_extensions()
        self._create_app()

    def _setup_logging(self) -> None:
        """Apply the logging configuration and announce initialization."""
        self.config.setup_logging()
        logger.info("PyServe server initialized", version=__version__)

    def _load_extensions(self) -> None:
        """Load every configured extension except ``process_orchestration``.

        ``process_orchestration`` extensions are skipped here and loaded by
        ``_load_async_extensions`` instead.
        """
        for ext_config in self.config.extensions:
            if ext_config.type == "process_orchestration":
                continue

            config = ext_config.config.copy()
            if ext_config.type == "routing":
                # Routing inherits the server-wide proxy timeout unless it sets its own.
                config.setdefault("default_proxy_timeout", self.config.server.proxy_timeout)
            self.extension_manager.load_extension(ext_config.type, config)

    async def _load_async_extensions(self) -> None:
        """Load ``process_orchestration`` extensions once; later calls are no-ops."""
        if self._async_extensions_loaded:
            return

        for ext_config in self.config.extensions:
            if ext_config.type == "process_orchestration":
                config = ext_config.config.copy()
                await self.extension_manager.load_extension_async(ext_config.type, config)

        self._async_extensions_loaded = True

    def _create_app(self) -> None:
        """Create the Starlette app, its routes and the PyServe middleware."""

        @asynccontextmanager
        async def lifespan(app: Starlette) -> AsyncIterator[None]:
            # Async extensions are loaded at application startup.
            await self._load_async_extensions()
            logger.info("Async extensions loaded")
            yield

        routes = [
            Route("/health", self._health_check, methods=["GET"]),
            Route("/metrics", self._metrics, methods=["GET"]),
            Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]),
        ]

        self.app = Starlette(routes=routes, lifespan=lifespan)
        self.app.add_middleware(PyServeMiddleware, extension_manager=self.extension_manager)

    async def _health_check(self, request: Request) -> Response:
        """Return a plain ``200 OK`` response."""
        return PlainTextResponse("OK", status_code=200)

    def _collect_extension_metrics(self, metrics: Dict[str, Any]) -> Dict[str, Any]:
        """Merge the metrics of every extension exposing ``get_metrics`` into *metrics*.

        A failing extension is logged and skipped so one faulty extension
        does not hide the metrics of the others.
        """
        for extension in self.extension_manager.extensions:
            if not hasattr(extension, "get_metrics"):
                continue
            try:
                metrics.update(extension.get_metrics())
            except Exception as e:
                logger.error("Error getting metrics from extension", extension=type(extension).__name__, error=str(e))
        return metrics

    async def _metrics(self, request: Request) -> Response:
        """Return the merged extension metrics as a JSON response."""
        metrics = self._collect_extension_metrics({})
        return Response(json.dumps(metrics, ensure_ascii=False, indent=2), media_type="application/json")

    async def _catch_all(self, request: Request) -> Response:
        """Return a plain ``404 Not Found`` for every otherwise unmatched path."""
        return PlainTextResponse("404 Not Found", status_code=404)

    def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
        """Build a server-side SSL context from the configured cert and key.

        Returns:
            The loaded context, or ``None`` when SSL is disabled, a file is
            missing, or the certificate chain cannot be loaded (each failure
            is logged).
        """
        ssl_config = self.config.ssl
        if not ssl_config.enabled:
            return None

        if not Path(ssl_config.cert_file).exists():
            logger.error("SSL certificate not found", cert_file=ssl_config.cert_file)
            return None

        if not Path(ssl_config.key_file).exists():
            logger.error("SSL key not found", key_file=ssl_config.key_file)
            return None

        try:
            context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            context.load_cert_chain(ssl_config.cert_file, ssl_config.key_file)
        except Exception as e:
            logger.error("Error creating SSL context", error=str(e), exc_info=True)
            return None

        logger.info("SSL context created successfully")
        return context

    def _build_uvicorn_options(self) -> Tuple[Dict[str, Any], str]:
        """Assemble the uvicorn keyword options shared by ``run`` and ``run_async``.

        Returns:
            A ``(options, protocol)`` pair where ``protocol`` is ``"https"``
            when a usable SSL context could be created and ``"http"`` otherwise.
        """
        server_config = self.config.server
        options: Dict[str, Any] = {
            "host": server_config.host,
            "port": server_config.port,
            "log_level": "critical",
            "access_log": False,
            "use_colors": False,
            # The middleware sets its own Server header.
            "server_header": False,
            "backlog": server_config.backlog if server_config.backlog else 2048,
        }

        # The context is built only to validate the cert/key pair up front;
        # uvicorn itself is handed the file paths.
        if self._create_ssl_context() is None:
            return options, "http"

        options.update(
            {
                "ssl_keyfile": self.config.ssl.key_file,
                "ssl_certfile": self.config.ssl.cert_file,
            }
        )
        return options, "https"

    def _require_app(self) -> Starlette:
        """Return the ASGI app, raising ``RuntimeError`` if it was never created."""
        if self.app is None:
            # An explicit raise is used because asserts are stripped under ``python -O``.
            raise RuntimeError("App not initialized")
        return self.app

    def run(self) -> None:
        """Validate the configuration and serve until interrupted (blocking).

        Startup errors are logged rather than raised, and ``shutdown`` is
        always called once the server was attempted.
        """
        if not self.config.validate():
            logger.error("Configuration is invalid, server cannot be started")
            return

        self._ensure_directories()
        uvicorn_config, protocol = self._build_uvicorn_options()

        logger.info("Starting PyServe server", protocol=protocol, host=self.config.server.host, port=self.config.server.port)

        try:
            uvicorn.run(self._require_app(), **uvicorn_config)
        except KeyboardInterrupt:
            logger.info("Received shutdown signal")
        except Exception as e:
            logger.error("Error starting server", error=str(e), exc_info=True)
        finally:
            self.shutdown()

    async def run_async(self) -> None:
        """Validate the configuration and serve inside the running event loop.

        Uses the same uvicorn options as ``run``, so SSL settings and the
        suppressed uvicorn ``Server`` header apply to both entry points.

        Raises:
            RuntimeError: If the ASGI application has not been created.
        """
        if not self.config.validate():
            logger.error("Configuration is invalid, server cannot be started")
            return

        self._ensure_directories()
        uvicorn_config, protocol = self._build_uvicorn_options()

        logger.info("Starting PyServe server", protocol=protocol, host=self.config.server.host, port=self.config.server.port)

        server = uvicorn.Server(uvicorn.Config(app=self._require_app(), **uvicorn_config))
        try:
            await server.serve()
        finally:
            self.shutdown()

    def _ensure_directories(self) -> None:
        """Create the static, template and log directories if they are missing."""
        directories = [
            self.config.http.static_dir,
            self.config.http.templates_dir,
        ]

        for file_config in self.config.logging.files:
            log_dir = Path(file_config.path).parent
            # A bare file name has "." as parent; nothing to create then.
            if log_dir != Path("."):
                directories.append(str(log_dir))

        for directory in directories:
            if not directory:
                # Skip unset entries instead of failing on ``Path(None)``.
                continue
            Path(directory).mkdir(parents=True, exist_ok=True)
            logger.debug("Created/checked directory", directory=directory)

    def shutdown(self) -> None:
        """Clean up extensions and shut the logging system down."""
        # Imported here as in the original; kept local to avoid changing import order.
        from .logging_utils import shutdown_logging

        logger.info("Shutting down PyServe server")
        try:
            self.extension_manager.cleanup()
            # Emit the final message while the logging system is still up.
            logger.info("Server stopped")
        finally:
            # Logging is shut down even when extension cleanup raises.
            shutdown_logging()

    def add_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Load one more extension of *extension_type* with *config*."""
        self.extension_manager.load_extension(extension_type, config)

    def get_metrics(self) -> Dict[str, Any]:
        """Return the server status merged with all extension metrics."""
        return self._collect_extension_metrics({"server_status": "running"})


def create_server(config_path: str = "config.yaml") -> PyServeServer:
    """Build a ``PyServeServer`` from the YAML configuration at *config_path*."""
    config = Config.from_yaml(config_path)
    return PyServeServer(config)


def run_server(config_path: str = "config.yaml") -> None:
    """Create a server from *config_path* and run it (blocking)."""
    server = create_server(config_path)
    server.run()