pyserveX/pyserve/server.py
Илья Глазунов 84cd1c974f feat: Add CLI for PyServe with configuration options
- Introduced a new CLI module (`cli.py`) to manage server configurations via command line arguments.
- Added script entry point in `pyproject.toml` for easy access to the CLI.
- Enhanced `Config` class to load configurations from a YAML file.
- Updated `__init__.py` to include `__version__` in the module exports.
- Added optional dependencies for development tools in `pyproject.toml`.
- Implemented logging improvements and error handling in various modules.
- Created tests for the CLI functionality to ensure proper behavior.
- Removed the old `run.py` implementation in favor of the new CLI approach.
2025-09-02 00:20:40 +03:00

236 lines
8.0 KiB
Python

import ssl
import uvicorn
import time
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response, PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Route
from pathlib import Path
from typing import Optional, Dict, Any
from .config import Config
from .extensions import ExtensionManager
from .logging_utils import get_logger
from . import __version__
# Module-level logger used for server lifecycle, SSL and error messages.
logger = get_logger(__name__)
class PyServeMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that runs the extension pipeline and writes access logs.

    Every request is first offered to the extension manager. If no extension
    produces a response, the request is forwarded to the wrapped application.
    The resulting response (from either source) is then passed through the
    extensions' response hook, stamped with a ``Server`` header and recorded
    in the ``pyserve.access`` log.
    """

    def __init__(self, app, extension_manager: ExtensionManager):
        """Wrap *app* and keep a reference to the shared extension manager.

        Args:
            app: The downstream ASGI application.
            extension_manager: Manager whose ``process_request`` /
                ``process_response`` hooks are awaited for every request.
        """
        super().__init__(app)
        self.extension_manager = extension_manager
        self.access_logger = get_logger('pyserve.access')

    async def dispatch(self, request: Request, call_next):
        """Handle one request through the extension pipeline.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the wrapped app.

        Returns:
            The (possibly extension-modified) response, carrying a
            ``Server: pyserve/<version>`` header.
        """
        # perf_counter is monotonic: the measured duration cannot go negative
        # or jump if the wall clock is adjusted while the request is running.
        start_time = time.perf_counter()

        # An extension may short-circuit the request by returning a response;
        # response hooks still run on it below.
        response = await self.extension_manager.process_request(request)
        if response is None:
            response = await call_next(request)
        response = await self.extension_manager.process_response(request, response)
        response.headers["Server"] = f"pyserve/{__version__}"

        client_ip = request.client.host if request.client else "unknown"
        method = request.method
        path = str(request.url.path)
        if request.url.query:
            path = f"{path}?{request.url.query}"
        status_code = response.status_code
        process_time = round((time.perf_counter() - start_time) * 1000, 2)
        self.access_logger.info(f"{client_ip} - {method} {path} - {status_code} - {process_time}ms")
        return response
class PyServeServer:
    """PyServe server: a Starlette application served by uvicorn.

    On construction the server applies the logging configuration, loads the
    extensions listed in the configuration and builds the Starlette app with
    ``/health`` and ``/metrics`` routes plus a 404 catch-all, all wrapped in
    :class:`PyServeMiddleware`.
    """

    def __init__(self, config: Config):
        """Initialise the server from a loaded configuration.

        Args:
            config: Parsed PyServe configuration.
        """
        self.config = config
        self.extension_manager = ExtensionManager()
        self.app: Optional[Starlette] = None
        self._setup_logging()
        self._load_extensions()
        self._create_app()

    def _setup_logging(self) -> None:
        """Apply the logging configuration and announce initialisation."""
        self.config.setup_logging()
        logger.info("PyServe сервер инициализирован")

    def _load_extensions(self) -> None:
        """Load every extension declared in the configuration, in order."""
        for ext_config in self.config.extensions:
            self.extension_manager.load_extension(ext_config.type, ext_config.config)

    def _create_app(self) -> None:
        """Build the Starlette app and install :class:`PyServeMiddleware`."""
        routes = [
            Route("/health", self._health_check, methods=["GET"]),
            Route("/metrics", self._metrics, methods=["GET"]),
            Route(
                "/{path:path}",
                self._catch_all,
                methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
            ),
        ]
        self.app = Starlette(routes=routes)
        self.app.add_middleware(PyServeMiddleware, extension_manager=self.extension_manager)

    async def _health_check(self, request: Request) -> Response:
        """Liveness probe: always answer ``200`` with body ``OK``."""
        return PlainTextResponse("OK", status_code=200)

    def _collect_extension_metrics(self) -> Dict[str, Any]:
        """Merge the metrics of every extension that exposes ``get_metrics``.

        Extensions are visited in load order, so later extensions overwrite
        keys reported by earlier ones. An extension whose ``get_metrics``
        raises is logged and skipped, so one faulty extension cannot break
        metrics reporting for the others.
        """
        metrics: Dict[str, Any] = {}
        for extension in self.extension_manager.extensions:
            if not hasattr(extension, 'get_metrics'):
                continue
            try:
                metrics.update(extension.get_metrics())
            except Exception as e:
                logger.error(f"Error getting metrics from {type(extension).__name__}: {e}")
        return metrics

    async def _metrics(self, request: Request) -> Response:
        """Serve the merged extension metrics as pretty-printed JSON."""
        import json

        metrics = self._collect_extension_metrics()
        return Response(
            json.dumps(metrics, ensure_ascii=False, indent=2),
            media_type="application/json",
        )

    async def _catch_all(self, request: Request) -> Response:
        """Fallback for any path not matched by another route: plain-text 404."""
        return PlainTextResponse("404 Not Found", status_code=404)

    def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
        """Build a server-side SSL context from the configured cert and key.

        Returns:
            The context, or ``None`` when SSL is disabled, when the certificate
            or key file is missing, or when loading the chain fails (each
            failure is logged).
        """
        if not self.config.ssl.enabled:
            return None
        if not Path(self.config.ssl.cert_file).exists():
            logger.error(f"SSL certificate not found: {self.config.ssl.cert_file}")
            return None
        if not Path(self.config.ssl.key_file).exists():
            logger.error(f"SSL key not found: {self.config.ssl.key_file}")
            return None
        try:
            context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            context.load_cert_chain(self.config.ssl.cert_file, self.config.ssl.key_file)
        except Exception as e:
            logger.error(f"Error creating SSL context: {e}")
            return None
        logger.info("SSL context created successfully")
        return context

    def _uvicorn_options(self) -> Optional[Dict[str, Any]]:
        """Build the uvicorn keyword arguments shared by ``run`` and ``run_async``.

        uvicorn's own logging and ``Server`` header are turned off because
        :class:`PyServeMiddleware` writes the access log and sets the
        ``Server`` header itself.

        Returns:
            The options dict (without ``app``), or ``None`` when SSL is
            enabled but the certificate/key cannot be loaded. In that case the
            server must not start rather than silently serving plain HTTP.
        """
        options: Dict[str, Any] = {
            "host": self.config.server.host,
            "port": self.config.server.port,
            "log_level": "critical",
            "access_log": False,
            "use_colors": False,
            "server_header": False,
        }
        if self.config.ssl.enabled:
            # The context itself is only used to validate the cert/key pair;
            # uvicorn loads the files again from the paths below.
            if self._create_ssl_context() is None:
                logger.error("SSL is enabled but the SSL context could not be created, server cannot be started")
                return None
            options["ssl_keyfile"] = self.config.ssl.key_file
            options["ssl_certfile"] = self.config.ssl.cert_file
        return options

    def _log_start(self, options: Dict[str, Any]) -> None:
        """Log the URL the server is about to listen on."""
        protocol = "https" if "ssl_certfile" in options else "http"
        logger.info(f"Starting PyServe server at {protocol}://{self.config.server.host}:{self.config.server.port}")

    def run(self) -> None:
        """Validate the configuration and serve the app, blocking until exit.

        Returns early (without serving) when the configuration is invalid or
        SSL is enabled but unusable. Once uvicorn has been started,
        :meth:`shutdown` is always called, whether uvicorn stopped normally,
        was interrupted, or failed.
        """
        if not self.config.validate():
            logger.error("Configuration is invalid, server cannot be started")
            return
        self._ensure_directories()
        options = self._uvicorn_options()
        if options is None:
            return
        self._log_start(options)
        try:
            uvicorn.run(self.app, **options)  # type: ignore
        except KeyboardInterrupt:
            logger.info("Received shutdown signal")
        except Exception as e:
            logger.error(f"Error starting server: {e}")
        finally:
            self.shutdown()

    async def run_async(self) -> None:
        """Async counterpart of :meth:`run` for use inside a running event loop.

        Uses the same uvicorn options as :meth:`run`, including SSL. Unlike
        :meth:`run`, errors raised while serving propagate to the caller;
        :meth:`shutdown` is still always called once serving has begun.
        """
        if not self.config.validate():
            logger.error("Configuration is invalid, server cannot be started")
            return
        self._ensure_directories()
        options = self._uvicorn_options()
        if options is None:
            return
        self._log_start(options)
        server = uvicorn.Server(uvicorn.Config(app=self.app, **options))  # type: ignore
        try:
            await server.serve()
        finally:
            self.shutdown()

    def _ensure_directories(self) -> None:
        """Create the static, templates and log-file directories if missing.

        Unset (``None`` or empty) paths are skipped instead of being passed
        to :class:`pathlib.Path`.
        """
        directories = [
            self.config.http.static_dir,
            self.config.http.templates_dir,
        ]
        log_file = self.config.logging.log_file
        if log_file:
            log_dir = Path(log_file).parent
            # A bare filename lives in the current directory, which exists.
            if log_dir != Path("."):
                directories.append(str(log_dir))
        for directory in directories:
            if not directory:
                continue
            Path(directory).mkdir(parents=True, exist_ok=True)
            logger.debug(f"Created/checked directory: {directory}")

    def shutdown(self) -> None:
        """Clean up extensions, then shut down logging."""
        logger.info("Shutting down PyServe server")
        self.extension_manager.cleanup()
        # Emit the final message before tearing logging down; anything logged
        # after shutdown_logging() may not reach its (presumably closed) handlers.
        logger.info("Server stopped")
        from .logging_utils import shutdown_logging
        shutdown_logging()

    def add_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Load an additional extension through the extension manager.

        Args:
            extension_type: Extension type name passed to ``load_extension``.
            config: Extension-specific configuration mapping.
        """
        self.extension_manager.load_extension(extension_type, config)

    def get_metrics(self) -> Dict[str, Any]:
        """Return ``server_status`` plus the merged metrics of all extensions.

        Extension metrics are merged after ``server_status`` and may
        overwrite it.
        """
        metrics: Dict[str, Any] = {"server_status": "running"}
        metrics.update(self._collect_extension_metrics())
        return metrics
def create_server(config_path: str = "config.yaml") -> PyServeServer:
    """Load the YAML configuration at *config_path* and build a server from it.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        An initialised :class:`PyServeServer` that has not started serving yet.
    """
    return PyServeServer(Config.from_yaml(config_path))
def run_server(config_path: str = "config.yaml") -> None:
    """Build a server from the YAML file at *config_path* and run it (blocking).

    Args:
        config_path: Path to the YAML configuration file.
    """
    create_server(config_path).run()