# Forked from Shifty/pyserveX
# (original listing: 328 lines, 12 KiB, Python)
import ssl
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
|
|
import uvicorn
|
|
from starlette.applications import Starlette
|
|
from starlette.requests import Request
|
|
from starlette.responses import PlainTextResponse, Response
|
|
from starlette.routing import Route
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
|
|
from . import __version__
|
|
from .config import Config
|
|
from .extensions import ASGIExtension, ExtensionManager
|
|
from .logging_utils import get_logger
|
|
|
|
# Module-level logger for this file, obtained from .logging_utils.get_logger.
# The returned logger accepts structured keyword fields in addition to the
# message (e.g. logger.info("PyServe server initialized", version=__version__)).
logger = get_logger(__name__)
|
|
|
|
|
|
class PyServeMiddleware:
    """ASGI middleware that routes HTTP requests through PyServe extensions.

    For every ``http`` scope the middleware, in order:

    1. offers the request to each :class:`ASGIExtension`; the first one that
       returns a mount handles the request end-to-end;
    2. otherwise asks ``extension_manager.process_request`` for a response and,
       if one is produced, post-processes it, stamps the ``Server`` header,
       logs it and sends it;
    3. otherwise falls through to the wrapped ``app``.

    Any other scope type is passed straight to the wrapped app.

    Args:
        app: The wrapped ASGI application (the Starlette router).
        extension_manager: Source of the extensions consulted per request.
    """

    def __init__(self, app: ASGIApp, extension_manager: ExtensionManager):
        self.app = app
        self.extension_manager = extension_manager
        self.access_logger = get_logger("pyserve.access")

    @staticmethod
    def _elapsed_ms(start_time: float) -> float:
        """Milliseconds elapsed since *start_time* (a ``time.perf_counter()`` value)."""
        return round((time.perf_counter() - start_time) * 1000, 2)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Monotonic clock: request timing must not jump with wall-clock changes.
        start_time = time.perf_counter()
        request = Request(scope, receive)

        if await self._try_asgi_mount(scope, receive, send, request, start_time):
            return

        response = await self.extension_manager.process_request(request)
        if response is None:
            # No extension claimed the request: let the wrapped app route it.
            await self.app(scope, receive, send)
            return

        response = await self.extension_manager.process_response(request, response)
        response.headers["Server"] = f"pyserve/{__version__}"

        self._log_access(request, response, start_time)
        await response(scope, receive, send)

    async def _try_asgi_mount(self, scope: Scope, receive: Receive, send: Send, request: Request, start_time: float) -> bool:
        """Dispatch the request to the first ASGI mount that claims it.

        Args:
            scope, receive, send: The ASGI triple for the current request.
            request: Starlette view of *scope*, used for mount matching and logging.
            start_time: ``time.perf_counter()`` value taken when the request arrived.

        Returns:
            ``True`` if a mount handled the request (successfully or not),
            ``False`` if no ASGI extension returned a mount.
        """
        for extension in self.extension_manager.extensions:
            if not isinstance(extension, ASGIExtension):
                continue
            mount = extension.get_asgi_handler(request)
            if mount is None:
                continue

            modified_scope = dict(scope)
            if mount.strip_path:
                modified_scope["path"] = mount.get_modified_path(request.url.path)
                modified_scope["root_path"] = scope.get("root_path", "") + mount.path

            logger.debug(f"Routing to ASGI mount '{mount.name}': {request.url.path} -> {modified_scope['path']}")

            response_started = False
            status_code = 0

            async def send_wrapper(message: Dict[str, Any]) -> None:
                # Record whether/what the mounted app started responding with.
                nonlocal response_started, status_code
                if message["type"] == "http.response.start":
                    response_started = True
                    status_code = message.get("status", 0)
                await send(message)

            try:
                await mount.app(modified_scope, receive, send_wrapper)
            except Exception as e:
                logger.error(f"Error in ASGI mount '{mount.name}': {e}", exc_info=True)
                if not response_started:
                    error_response = PlainTextResponse("500 Internal Server Error", status_code=500)
                    await error_response(scope, receive, send)
                    status_code = 500
                # else: headers were already sent; a second http.response.start
                # would violate the ASGI protocol, so nothing more is sent here.

            # Outside the try: a logging failure must not trigger an error response
            # after the real response was already sent.
            self.access_logger.info(
                "ASGI request",
                client_ip=request.client.host if request.client else "unknown",
                method=request.method,
                path=str(request.url.path),
                mount=mount.name,
                status_code=status_code,
                process_time_ms=self._elapsed_ms(start_time),
                user_agent=request.headers.get("user-agent", ""),
            )
            return True

        return False

    def _log_access(self, request: Request, response: Response, start_time: float) -> None:
        """Emit one access-log record for an extension-produced response.

        Args:
            request: The incoming request (path and query are logged together).
            response: The response about to be sent.
            start_time: ``time.perf_counter()`` value taken when the request arrived.
        """
        path = str(request.url.path)
        if request.url.query:
            path += f"?{request.url.query}"

        self.access_logger.info(
            "HTTP request",
            client_ip=request.client.host if request.client else "unknown",
            method=request.method,
            path=path,
            status_code=response.status_code,
            process_time_ms=self._elapsed_ms(start_time),
            user_agent=request.headers.get("user-agent", ""),
        )
|
|
|
|
|
|
class PyServeServer:
    """Top-level PyServe server: configuration, extensions and the Starlette app.

    Construction configures logging, loads every extension except those of
    type ``process_orchestration`` and builds the Starlette application.
    ``process_orchestration`` extensions are loaded through the async
    extension-manager path from the application's lifespan handler.

    Args:
        config: Parsed server configuration.
    """

    def __init__(self, config: Config):
        self.config = config
        self.extension_manager = ExtensionManager()
        self.app: Optional[Starlette] = None
        # Guards _load_async_extensions against loading twice.
        self._async_extensions_loaded = False
        self._setup_logging()
        self._load_extensions()
        self._create_app()

    def _setup_logging(self) -> None:
        """Apply the logging configuration and log the server version."""
        self.config.setup_logging()
        logger.info("PyServe server initialized", version=__version__)

    def _load_extensions(self) -> None:
        """Synchronously load all configured extensions except ``process_orchestration``.

        ``routing`` extensions get the server-wide ``proxy_timeout`` as their
        ``default_proxy_timeout`` unless they configure one themselves.
        """
        for ext_config in self.config.extensions:
            if ext_config.type == "process_orchestration":
                continue  # loaded later by _load_async_extensions

            # Copy so the defaults applied here never leak into the shared config.
            config = ext_config.config.copy()
            if ext_config.type == "routing":
                config.setdefault("default_proxy_timeout", self.config.server.proxy_timeout)

            self.extension_manager.load_extension(ext_config.type, config)

    async def _load_async_extensions(self) -> None:
        """Load the ``process_orchestration`` extensions (at most once)."""
        if self._async_extensions_loaded:
            return

        for ext_config in self.config.extensions:
            if ext_config.type == "process_orchestration":
                config = ext_config.config.copy()
                await self.extension_manager.load_extension_async(ext_config.type, config)

        self._async_extensions_loaded = True

    def _create_app(self) -> None:
        """Build the Starlette app with built-in routes, lifespan and PyServeMiddleware."""
        from contextlib import asynccontextmanager
        from typing import AsyncIterator

        @asynccontextmanager
        async def lifespan(app: Starlette) -> AsyncIterator[None]:
            await self._load_async_extensions()
            logger.info("Async extensions loaded")
            yield

        routes = [
            Route("/health", self._health_check, methods=["GET"]),
            Route("/metrics", self._metrics, methods=["GET"]),
            # Reached only when no extension produced a response: always 404.
            Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]),
        ]

        self.app = Starlette(routes=routes, lifespan=lifespan)
        self.app.add_middleware(PyServeMiddleware, extension_manager=self.extension_manager)

    async def _health_check(self, request: Request) -> Response:
        """Liveness probe: always ``200 OK``."""
        return PlainTextResponse("OK", status_code=200)

    def _collect_extension_metrics(self) -> Dict[str, Any]:
        """Merge the ``get_metrics()`` dicts of all extensions that provide one.

        A failing extension is logged and skipped, so one broken extension
        cannot break metrics reporting for the others. Later extensions
        overwrite keys set by earlier ones.
        """
        metrics: Dict[str, Any] = {}
        for extension in self.extension_manager.extensions:
            if not hasattr(extension, "get_metrics"):
                continue
            try:
                metrics.update(getattr(extension, "get_metrics")())
            except Exception as e:
                logger.error("Error getting metrics from extension", extension=type(extension).__name__, error=str(e))
        return metrics

    async def _metrics(self, request: Request) -> Response:
        """Serve the merged extension metrics as pretty-printed JSON."""
        import json

        metrics = self._collect_extension_metrics()
        return Response(json.dumps(metrics, ensure_ascii=False, indent=2), media_type="application/json")

    async def _catch_all(self, request: Request) -> Response:
        """Fallback route: ``404 Not Found``."""
        return PlainTextResponse("404 Not Found", status_code=404)

    def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
        """Load the configured certificate/key into a server-side SSL context.

        Returns:
            The context, or ``None`` if SSL is disabled, a file is missing, or
            loading fails (each failure is logged).
        """
        if not self.config.ssl.enabled:
            return None

        if not Path(self.config.ssl.cert_file).exists():
            logger.error("SSL certificate not found", cert_file=self.config.ssl.cert_file)
            return None

        if not Path(self.config.ssl.key_file).exists():
            logger.error("SSL key not found", key_file=self.config.ssl.key_file)
            return None

        try:
            context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            context.load_cert_chain(self.config.ssl.cert_file, self.config.ssl.key_file)
            logger.info("SSL context created successfully")
            return context
        except Exception as e:
            logger.error("Error creating SSL context", error=str(e), exc_info=True)
            return None

    def _build_uvicorn_config(self) -> Optional[Dict[str, Any]]:
        """Build the uvicorn keyword arguments shared by :meth:`run` and :meth:`run_async`.

        Returns:
            The keyword arguments, or ``None`` when SSL is enabled but the
            certificate/key cannot be loaded. Falling back to plain HTTP in
            that case would silently serve unencrypted traffic the operator
            asked to protect, so the caller must not start the server.
        """
        uvicorn_config: Dict[str, Any] = {
            "host": self.config.server.host,
            "port": self.config.server.port,
            "log_level": "critical",
            "access_log": False,
            "use_colors": False,
            # Disable uvicorn's own Server header; PyServeMiddleware sets
            # "pyserve/<version>" itself.
            "server_header": False,
            "backlog": self.config.server.backlog if self.config.server.backlog else 2048,
        }

        if self.config.ssl.enabled:
            if self._create_ssl_context() is None:
                logger.error("SSL is enabled but could not be initialized, server cannot be started")
                return None
            # The context above only validates the files; uvicorn loads them itself.
            uvicorn_config["ssl_keyfile"] = self.config.ssl.key_file
            uvicorn_config["ssl_certfile"] = self.config.ssl.cert_file

        return uvicorn_config

    def _log_startup(self, uvicorn_config: Dict[str, Any]) -> None:
        """Log the protocol/host/port the server is about to listen on."""
        protocol = "https" if "ssl_certfile" in uvicorn_config else "http"
        logger.info("Starting PyServe server", protocol=protocol, host=self.config.server.host, port=self.config.server.port)

    def run(self) -> None:
        """Validate the configuration and serve with ``uvicorn.run`` (blocking).

        Returns early (without serving) if the configuration is invalid or SSL
        is enabled but unusable. Once serving has been attempted,
        :meth:`shutdown` is always called.
        """
        if not self.config.validate():
            logger.error("Configuration is invalid, server cannot be started")
            return

        self._ensure_directories()
        uvicorn_config = self._build_uvicorn_config()
        if uvicorn_config is None:
            return

        self._log_startup(uvicorn_config)

        try:
            # Explicit check rather than `assert`, which is stripped under -O.
            if self.app is None:
                raise RuntimeError("App not initialized")
            uvicorn.run(self.app, **uvicorn_config)
        except KeyboardInterrupt:
            logger.info("Received shutdown signal")
        except Exception as e:
            logger.error("Error starting server", error=str(e), exc_info=True)
        finally:
            self.shutdown()

    async def run_async(self) -> None:
        """Async counterpart of :meth:`run` using ``uvicorn.Server.serve``.

        Uses the same uvicorn settings as :meth:`run`, including SSL and
        the disabled uvicorn ``Server`` header.
        """
        if not self.config.validate():
            logger.error("Configuration is invalid, server cannot be started")
            return

        self._ensure_directories()
        uvicorn_config = self._build_uvicorn_config()
        if uvicorn_config is None:
            return

        self._log_startup(uvicorn_config)
        config = uvicorn.Config(app=self.app, **uvicorn_config)  # type: ignore
        server = uvicorn.Server(config)

        try:
            await server.serve()
        finally:
            self.shutdown()

    def _ensure_directories(self) -> None:
        """Create the static, templates and log-file directories if missing."""
        directories = [
            self.config.http.static_dir,
            self.config.http.templates_dir,
        ]

        for file_config in self.config.logging.files:
            log_dir = Path(file_config.path).parent
            if log_dir != Path("."):
                directories.append(str(log_dir))

        for directory in directories:
            if not directory:
                # Unset (None/empty) setting: nothing to create.
                continue
            Path(directory).mkdir(parents=True, exist_ok=True)
            logger.debug("Created/checked directory", directory=directory)

    def shutdown(self) -> None:
        """Clean up extensions, then shut down logging.

        "Server stopped" is logged before ``shutdown_logging()`` so that it is
        not emitted after logging has been torn down. ``shutdown_logging()``
        runs even if extension cleanup raises; the exception still propagates.
        """
        from .logging_utils import shutdown_logging

        logger.info("Shutting down PyServe server")
        try:
            self.extension_manager.cleanup()
            logger.info("Server stopped")
        finally:
            shutdown_logging()

    def add_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Load an additional extension of *extension_type* with *config*."""
        self.extension_manager.load_extension(extension_type, config)

    def get_metrics(self) -> Dict[str, Any]:
        """Return ``{"server_status": "running"}`` merged with all extension metrics.

        Extension metrics are applied last, so they override ``server_status``
        if an extension reports that key.
        """
        metrics: Dict[str, Any] = {"server_status": "running"}
        metrics.update(self._collect_extension_metrics())
        return metrics
|
|
|
|
|
|
def create_server(config_path: str = "config.yaml") -> PyServeServer:
    """Construct a :class:`PyServeServer` from the YAML file at *config_path*.

    The server is built but not started; call ``run()``/``run_async()`` on it.
    """
    loaded_config = Config.from_yaml(config_path)
    server = PyServeServer(loaded_config)
    return server
|
|
|
|
|
|
def run_server(config_path: str = "config.yaml") -> None:
    """Build a server from *config_path* and serve it in the foreground (blocking)."""
    create_server(config_path).run()
|