konduktor/pyserve/server.py
2025-12-03 12:54:45 +03:00

304 lines
11 KiB
Python

import ssl
import time
from pathlib import Path
from typing import Any, Dict, Optional
import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.routing import Route
from starlette.types import ASGIApp, Receive, Scope, Send
from . import __version__
from .config import Config
from .extensions import ASGIExtension, ExtensionManager
from .logging_utils import get_logger
# Module-level logger for server lifecycle, routing and error messages.
# Per-request access entries go to the separate "pyserve.access" logger.
logger = get_logger(__name__)
class PyServeMiddleware:
    """ASGI middleware that runs every HTTP request through PyServe's extensions.

    Dispatch order for an ``http`` scope:

    1. the first ``ASGIExtension`` whose ``get_asgi_handler`` returns a mount
       handles the request with that mount's ASGI app;
    2. otherwise ``ExtensionManager.process_request`` may produce a response,
       which is post-processed, stamped with a ``Server`` header, logged and
       sent;
    3. otherwise the wrapped application handles the request.

    Any non-``http`` scope is passed straight to the wrapped application.
    """

    def __init__(self, app: ASGIApp, extension_manager: ExtensionManager):
        """Wrap *app*; *extension_manager* supplies the loaded extensions."""
        self.app = app
        self.extension_manager = extension_manager
        self.access_logger = get_logger("pyserve.access")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one ASGI scope according to the dispatch order above."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        start_time = time.time()
        request = Request(scope, receive)

        if await self._try_asgi_mount(scope, receive, send, request, start_time):
            return

        response = await self.extension_manager.process_request(request)
        if response is None:
            # No extension produced a response: fall through to the app's routes.
            await self.app(scope, receive, send)
            return

        response = await self.extension_manager.process_response(request, response)
        response.headers["Server"] = f"pyserve/{__version__}"
        self._log_access(request, response, start_time)
        await response(scope, receive, send)

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Return the peer host of *request*, or ``"unknown"`` when not available."""
        return request.client.host if request.client else "unknown"

    def _find_asgi_mount(self, request: Request) -> Any:
        """Return the first ASGI mount that claims *request*, or ``None``."""
        for extension in self.extension_manager.extensions:
            if not isinstance(extension, ASGIExtension):
                continue
            mount = extension.get_asgi_handler(request)
            if mount is not None:
                return mount
        return None

    async def _try_asgi_mount(
        self, scope: Scope, receive: Receive, send: Send, request: Request, start_time: float
    ) -> bool:
        """Dispatch *request* to a matching ASGI mount, if there is one.

        Returns ``True`` when a mount claimed the request (whether it
        succeeded or failed) and ``False`` when no ASGI extension matched.
        If the mounted app raises before sending response headers, a plain
        500 response is sent instead. If it raises after headers were
        sent, the partial response is abandoned, because a second
        ``http.response.start`` is not allowed on the same connection.
        """
        mount = self._find_asgi_mount(request)
        if mount is None:
            return False

        modified_scope = dict(scope)
        if mount.strip_path:
            # Give the mounted app the path relative to its mount point and
            # record the stripped prefix in root_path.
            modified_scope["path"] = mount.get_modified_path(request.url.path)
            modified_scope["root_path"] = scope.get("root_path", "") + mount.path
        logger.debug(
            "Routing to ASGI mount",
            mount=mount.name,
            path=request.url.path,
            target_path=modified_scope["path"],
        )

        response_started = False
        status_code = 0

        async def send_wrapper(message: Dict[str, Any]) -> None:
            # Remember whether headers went out and with which status, so the
            # error path and the access log can rely on it.
            nonlocal response_started, status_code
            if message["type"] == "http.response.start":
                response_started = True
                status_code = message.get("status", 0)
            await send(message)

        try:
            await mount.app(modified_scope, receive, send_wrapper)
        except Exception as e:
            logger.error("Error in ASGI mount", mount=mount.name, error=str(e), exc_info=True)
            if not response_started:
                error_response = PlainTextResponse("500 Internal Server Error", status_code=500)
                # Routed through send_wrapper so status_code records the 500.
                await error_response(scope, receive, send_wrapper)

        self._log_asgi_access(request, mount.name, status_code, start_time)
        return True

    def _log_asgi_access(self, request: Request, mount_name: str, status_code: int, start_time: float) -> None:
        """Write one access-log entry for a request served by an ASGI mount."""
        process_time = round((time.time() - start_time) * 1000, 2)
        self.access_logger.info(
            "ASGI request",
            client_ip=self._client_ip(request),
            method=request.method,
            path=str(request.url.path),
            mount=mount_name,
            status_code=status_code,
            process_time_ms=process_time,
            user_agent=request.headers.get("user-agent", ""),
        )

    def _log_access(self, request: Request, response: Response, start_time: float) -> None:
        """Write one access-log entry for an extension-produced *response*.

        The logged path includes the query string when one is present.
        """
        path = str(request.url.path)
        if request.url.query:
            path = f"{path}?{request.url.query}"
        process_time = round((time.time() - start_time) * 1000, 2)
        self.access_logger.info(
            "HTTP request",
            client_ip=self._client_ip(request),
            method=request.method,
            path=path,
            status_code=response.status_code,
            process_time_ms=process_time,
            user_agent=request.headers.get("user-agent", ""),
        )
class PyServeServer:
    """Top-level PyServe server: builds the Starlette app and serves it with uvicorn.

    Construction applies the logging configuration, loads the configured
    extensions and creates the ASGI application. ``run`` (blocking) and
    ``run_async`` (awaitable) then start serving it with identical uvicorn
    settings.
    """

    def __init__(self, config: Config):
        """Initialise the server from an already-loaded *config*."""
        self.config = config
        self.extension_manager = ExtensionManager()
        self.app: Optional[Starlette] = None
        self._setup_logging()
        self._load_extensions()
        self._create_app()

    def _setup_logging(self) -> None:
        """Apply the configured logging setup and announce the server version."""
        self.config.setup_logging()
        logger.info("PyServe server initialized", version=__version__)

    def _load_extensions(self) -> None:
        """Load every extension declared in the configuration.

        Routing extensions inherit ``server.proxy_timeout`` as
        ``default_proxy_timeout`` unless their own config already sets it.
        """
        for ext_config in self.config.extensions:
            # Copy so the injected default does not leak back into the Config object.
            config = ext_config.config.copy()
            if ext_config.type == "routing":
                config.setdefault("default_proxy_timeout", self.config.server.proxy_timeout)
            self.extension_manager.load_extension(ext_config.type, config)

    def _create_app(self) -> None:
        """Create the Starlette app with built-in routes and the PyServe middleware."""
        routes = [
            Route("/health", self._health_check, methods=["GET"]),
            Route("/metrics", self._metrics, methods=["GET"]),
            Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]),
        ]
        self.app = Starlette(routes=routes)
        self.app.add_middleware(PyServeMiddleware, extension_manager=self.extension_manager)

    async def _health_check(self, request: Request) -> Response:
        """Liveness endpoint: always answers ``200 OK``."""
        return PlainTextResponse("OK", status_code=200)

    def _collect_extension_metrics(self) -> Dict[str, Any]:
        """Merge the ``get_metrics()`` output of every extension that provides one.

        An extension whose ``get_metrics`` raises is logged and skipped, so
        one broken extension cannot hide the others' metrics.
        """
        metrics: Dict[str, Any] = {}
        for extension in self.extension_manager.extensions:
            if not hasattr(extension, "get_metrics"):
                continue
            try:
                metrics.update(getattr(extension, "get_metrics")())
            except Exception as e:
                logger.error("Error getting metrics from extension", extension=type(extension).__name__, error=str(e))
        return metrics

    async def _metrics(self, request: Request) -> Response:
        """Serve the merged extension metrics as pretty-printed JSON."""
        import json

        metrics = self._collect_extension_metrics()
        return Response(json.dumps(metrics, ensure_ascii=False, indent=2), media_type="application/json")

    async def _catch_all(self, request: Request) -> Response:
        """Fallback route for requests that nothing else handled."""
        return PlainTextResponse("404 Not Found", status_code=404)

    def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
        """Build a server-side SSL context from the configured certificate and key.

        Returns ``None`` when SSL is disabled, when either file is missing,
        or when loading the certificate chain fails. The failure cases are
        logged.
        """
        if not self.config.ssl.enabled:
            return None
        if not Path(self.config.ssl.cert_file).exists():
            logger.error("SSL certificate not found", cert_file=self.config.ssl.cert_file)
            return None
        if not Path(self.config.ssl.key_file).exists():
            logger.error("SSL key not found", key_file=self.config.ssl.key_file)
            return None
        try:
            context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            context.load_cert_chain(self.config.ssl.cert_file, self.config.ssl.key_file)
            logger.info("SSL context created successfully")
            return context
        except Exception as e:
            logger.error("Error creating SSL context", error=str(e), exc_info=True)
            return None

    def _uvicorn_options(self) -> Dict[str, Any]:
        """Build the keyword arguments shared by ``uvicorn.run`` and ``uvicorn.Config``.

        PyServe does its own access logging and sets its own ``Server``
        header, so uvicorn's logging, colours and server header are turned
        off. TLS file options are added only when the configured certificate
        and key could be loaded into an SSL context.
        """
        options: Dict[str, Any] = {
            "host": self.config.server.host,
            "port": self.config.server.port,
            "log_level": "critical",
            "access_log": False,
            "use_colors": False,
            "server_header": False,
            "backlog": self.config.server.backlog or 2048,
        }
        # NOTE(review): if SSL is enabled but the cert/key cannot be loaded,
        # this falls back to plain HTTP (the cause is logged above) — confirm
        # that is the intended policy.
        if self._create_ssl_context() is not None:
            options["ssl_keyfile"] = self.config.ssl.key_file
            options["ssl_certfile"] = self.config.ssl.cert_file
        return options

    def _log_startup(self, options: Dict[str, Any]) -> None:
        """Log the protocol and address the server is about to listen on."""
        protocol = "https" if "ssl_certfile" in options else "http"
        logger.info("Starting PyServe server", protocol=protocol, host=options["host"], port=options["port"])

    def run(self) -> None:
        """Validate the configuration and serve until uvicorn returns (blocking).

        Startup errors and Ctrl-C are logged rather than raised. ``shutdown``
        runs once uvicorn has been invoked, however it exits.
        """
        if not self.config.validate():
            logger.error("Configuration is invalid, server cannot be started")
            return
        self._ensure_directories()
        options = self._uvicorn_options()
        self._log_startup(options)
        try:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if self.app is None:
                raise RuntimeError("App not initialized")
            uvicorn.run(self.app, **options)
        except KeyboardInterrupt:
            logger.info("Received shutdown signal")
        except Exception as e:
            logger.error("Error starting server", error=str(e), exc_info=True)
        finally:
            self.shutdown()

    async def run_async(self) -> None:
        """Awaitable counterpart of ``run`` for callers that own the event loop.

        Uses the same uvicorn options as ``run``, including TLS, backlog and
        the disabled uvicorn server header. ``shutdown`` runs when serving
        ends.
        """
        if not self.config.validate():
            logger.error("Configuration is invalid, server cannot be started")
            return
        self._ensure_directories()
        options = self._uvicorn_options()
        self._log_startup(options)
        config = uvicorn.Config(app=self.app, **options)  # type: ignore
        server = uvicorn.Server(config)
        try:
            await server.serve()
        finally:
            self.shutdown()

    def _ensure_directories(self) -> None:
        """Create the static, template and log-file directories if they are missing."""
        directories = [
            self.config.http.static_dir,
            self.config.http.templates_dir,
        ]
        for file_config in self.config.logging.files:
            log_dir = Path(file_config.path).parent
            if log_dir != Path("."):
                directories.append(str(log_dir))
        for directory in directories:
            Path(directory).mkdir(parents=True, exist_ok=True)
            logger.debug("Created/checked directory", directory=directory)

    def shutdown(self) -> None:
        """Clean up extensions, then shut logging down.

        The final message is emitted before ``shutdown_logging()``, because
        nothing can be logged after the handlers are closed. Logging is shut
        down even if extension cleanup raises.
        """
        from .logging_utils import shutdown_logging

        logger.info("Shutting down PyServe server")
        try:
            self.extension_manager.cleanup()
            logger.info("Server stopped")
        finally:
            shutdown_logging()

    def add_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Load one more extension of *extension_type* at runtime."""
        self.extension_manager.load_extension(extension_type, config)

    def get_metrics(self) -> Dict[str, Any]:
        """Return ``{"server_status": "running"}`` merged with all extension metrics."""
        metrics: Dict[str, Any] = {"server_status": "running"}
        metrics.update(self._collect_extension_metrics())
        return metrics
def create_server(config_path: str = "config.yaml") -> PyServeServer:
    """Read the YAML configuration at *config_path* and build a server from it."""
    return PyServeServer(Config.from_yaml(config_path))
def run_server(config_path: str = "config.yaml") -> None:
    """Convenience entry point: build a server from *config_path* and run it."""
    create_server(config_path).run()