diff --git a/.flake8 b/.flake8 index f6398cc..171ef66 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] -max-line-length = 100 +max-line-length = 120 exclude = __pycache__,.git,.venv,venv,build,dist ignore = E203,W503 diff --git a/pyserve/routing.py b/pyserve/routing.py index f9da3b3..af222b2 100644 --- a/pyserve/routing.py +++ b/pyserve/routing.py @@ -8,34 +8,36 @@ from .logging_utils import get_logger logger = get_logger(__name__) + class RouteMatch: def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None): self.config = config self.params = params or {} + class Router: def __init__(self, static_dir: str = "./static"): self.static_dir = Path(static_dir) self.routes: Dict[Pattern, Dict[str, Any]] = {} self.exact_routes: Dict[str, Dict[str, Any]] = {} self.default_route: Optional[Dict[str, Any]] = None - + def add_route(self, pattern: str, config: Dict[str, Any]) -> None: if pattern.startswith("="): exact_path = pattern[1:] self.exact_routes[exact_path] = config logger.debug(f"Added exact route: {exact_path}") return - + if pattern == "__default__": self.default_route = config logger.debug("Added default route") return - + if pattern.startswith("~"): case_insensitive = pattern.startswith("~*") regex_pattern = pattern[2:] if case_insensitive else pattern[1:] - + flags = re.IGNORECASE if case_insensitive else 0 try: compiled_pattern = re.compile(regex_pattern, flags) @@ -43,47 +45,48 @@ class Router: logger.debug(f"Added regex route: {pattern}") except re.error as e: logger.error(f"Regex compilation error {pattern}: {e}") - + def match(self, path: str) -> Optional[RouteMatch]: if path in self.exact_routes: return RouteMatch(self.exact_routes[path]) - + for pattern, config in self.routes.items(): match = pattern.search(path) if match: params = match.groupdict() return RouteMatch(config, params) - + if self.default_route: return RouteMatch(self.default_route) - + return None -class RequestHandler: +class RequestHandler: def __init__(self, router: Router, static_dir: str = "./static"): self.router = router self.static_dir = Path(static_dir) - + async def handle(self, request: Request) -> Response: path = request.url.path - + logger.info(f"{request.method} {path}") - + route_match = self.router.match(path) if not route_match: return PlainTextResponse("404 Not Found", status_code=404) - + try: return await self._process_route(request, route_match) except Exception as e: logger.error(f"Request processing error {path}: {e}") return PlainTextResponse("500 Internal Server Error", status_code=500) - + async def _process_route(self, request: Request, route_match: RouteMatch) -> Response: config = route_match.config - path = request.url.path - + # HINT: Not using it right now + # path = request.url.path + if "return" in config: status_text = config["return"] if " " in status_text: @@ -92,32 +95,32 @@ class RequestHandler: else: status_code = int(status_text) text = "" - + content_type = config.get("content_type", "text/plain") - return PlainTextResponse(text, status_code=status_code, - media_type=content_type) - + return PlainTextResponse(text, status_code=status_code, + media_type=content_type) + if "proxy_pass" in config: return await self._handle_proxy(request, config, route_match.params) - + if "root" in config: return await self._handle_static(request, config) - + if config.get("spa_fallback"): return await self._handle_spa_fallback(request, config) - + return PlainTextResponse("404 Not Found", status_code=404) - + async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response: root = Path(config["root"]) path = request.url.path.lstrip("/") - + if not path or path == "/": index_file = config.get("index_file", "index.html") file_path = root / index_file else: file_path = root / path - + try: file_path = file_path.resolve() root = root.resolve() @@ -125,59 +128,59 @@ class RequestHandler: return PlainTextResponse("403 Forbidden", status_code=403) except OSError: return PlainTextResponse("404 Not Found", status_code=404) - + if not file_path.exists() or not file_path.is_file(): return PlainTextResponse("404 Not Found", status_code=404) - + content_type, _ = mimetypes.guess_type(str(file_path)) - + response = FileResponse(str(file_path), media_type=content_type) - + if "headers" in config: for header in config["headers"]: if ":" in header: name, value = header.split(":", 1) response.headers[name.strip()] = value.strip() - + if "cache_control" in config: response.headers["Cache-Control"] = config["cache_control"] - + return response - + async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response: path = request.url.path - + exclude_patterns = config.get("exclude_patterns", []) for pattern in exclude_patterns: if path.startswith(pattern): return PlainTextResponse("404 Not Found", status_code=404) - + root = Path(config.get("root", self.static_dir)) index_file = config.get("index_file", "index.html") file_path = root / index_file - + if file_path.exists() and file_path.is_file(): return FileResponse(str(file_path), media_type="text/html") - + return PlainTextResponse("404 Not Found", status_code=404) - - async def _handle_proxy(self, request: Request, config: Dict[str, Any], - params: Dict[str, str]) -> Response: + + async def _handle_proxy(self, request: Request, config: Dict[str, Any], + params: Dict[str, str]) -> Response: # TODO: Реализовать полноценное проксирование proxy_url = config["proxy_pass"] - + for key, value in params.items(): proxy_url = proxy_url.replace(f"{{{key}}}", value) - + logger.info(f"Proxying request to: {proxy_url}") - + return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200) def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router: router = Router() - + for pattern, config in regex_locations.items(): router.add_route(pattern, config) - + return router diff --git a/pyserve/server.py b/pyserve/server.py index 0ad753b..f31432f 100644 --- a/pyserve/server.py +++ b/pyserve/server.py @@ -4,8 +4,8 @@ import time from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response, PlainTextResponse -from starlette.middleware.base import BaseHTTPMiddleware from starlette.routing import Route +from starlette.types import ASGIApp, Receive, Scope, Send from pathlib import Path from typing import Optional, Dict, Any @@ -17,22 +17,28 @@ from . import __version__ logger = get_logger(__name__) -class PyServeMiddleware(BaseHTTPMiddleware): - def __init__(self, app, extension_manager: ExtensionManager): - super().__init__(app) +class PyServeMiddleware: + def __init__(self, app: ASGIApp, extension_manager: ExtensionManager): + self.app = app self.extension_manager = extension_manager self.access_logger = get_logger('pyserve.access') - - async def dispatch(self, request: Request, call_next): + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + start_time = time.time() + request = Request(scope, receive) response = await self.extension_manager.process_request(request) - + if response is None: - response = await call_next(request) - + await self.app(scope, receive, send) + return + response = await self.extension_manager.process_response(request, response) - response.headers["Server"] = f"pyserve/{__version__}" + client_ip = request.client.host if request.client else "unknown" method = request.method path = str(request.url.path) @@ -41,13 +47,13 @@ class PyServeMiddleware(BaseHTTPMiddleware): path += f"?{query}" status_code = response.status_code process_time = round((time.time() - start_time) * 1000, 2) - + self.access_logger.info(f"{client_ip} - {method} {path} - {status_code} - {process_time}ms") - - return response + + await response(scope, receive, send) -class PyServeServer: +class PyServeServer: def __init__(self, config: Config): self.config = config self.extension_manager = ExtensionManager() @@ -55,34 +61,45 @@ class PyServeServer: self._setup_logging() self._load_extensions() self._create_app() - + def _setup_logging(self) -> None: self.config.setup_logging() - logger.info("PyServe сервер инициализирован") - + logger.info("PyServe server initialized") + def _load_extensions(self) -> None: for ext_config in self.config.extensions: self.extension_manager.load_extension( - ext_config.type, + ext_config.type, ext_config.config ) - + def _create_app(self) -> None: routes = [ Route("/health", self._health_check, methods=["GET"]), Route("/metrics", self._metrics, methods=["GET"]), - Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]), + Route( + "/{path:path}", + self._catch_all, + methods=[ + "GET", + "POST", + "PUT", + "DELETE", + "PATCH", + "OPTIONS" + ] + ), ] - + self.app = Starlette(routes=routes) self.app.add_middleware(PyServeMiddleware, extension_manager=self.extension_manager) - + async def _health_check(self, request: Request) -> Response: return PlainTextResponse("OK", status_code=200) - + async def _metrics(self, request: Request) -> Response: metrics = {} - + for extension in self.extension_manager.extensions: if hasattr(extension, 'get_metrics'): try: @@ -96,22 +113,22 @@ class PyServeServer: json.dumps(metrics, ensure_ascii=False, indent=2), media_type="application/json" ) - + async def _catch_all(self, request: Request) -> Response: return PlainTextResponse("404 Not Found", status_code=404) - + def _create_ssl_context(self) -> Optional[ssl.SSLContext]: if not self.config.ssl.enabled: return None - + if not Path(self.config.ssl.cert_file).exists(): logger.error(f"SSL certificate not found: {self.config.ssl.cert_file}") return None - + if not Path(self.config.ssl.key_file).exists(): logger.error(f"SSL key not found: {self.config.ssl.key_file}") return None - + try: context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context.load_cert_chain( @@ -123,17 +140,16 @@ class PyServeServer: except Exception as e: logger.error(f"Error creating SSL context: {e}") return None - + def run(self) -> None: if not self.config.validate(): logger.error("Configuration is invalid, server cannot be started") return - + self._ensure_directories() ssl_context = self._create_ssl_context() - - uvicorn_config = { - "app": self.app, + + uvicorn_config: Dict[str, Any] = { "host": self.config.server.host, "port": self.config.server.port, "log_level": "critical", @@ -141,7 +157,7 @@ class PyServeServer: "use_colors": False, "server_header": False, } - + if ssl_context: uvicorn_config.update({ "ssl_keyfile": self.config.ssl.key_file, @@ -154,21 +170,22 @@ class PyServeServer: logger.info(f"Starting PyServe server at {protocol}://{self.config.server.host}:{self.config.server.port}") try: - uvicorn.run(**uvicorn_config) + assert self.app is not None, "App not initialized" + uvicorn.run(self.app, **uvicorn_config) except KeyboardInterrupt: logger.info("Received shutdown signal") except Exception as e: logger.error(f"Error starting server: {e}") finally: self.shutdown() - + async def run_async(self) -> None: if not self.config.validate(): logger.error("Configuration is invalid, server cannot be started") return - + self._ensure_directories() - + config = uvicorn.Config( app=self.app, # type: ignore host=self.config.server.host, @@ -177,24 +194,24 @@ class PyServeServer: access_log=False, use_colors=False, ) - + server = uvicorn.Server(config) - + try: await server.serve() finally: self.shutdown() - + def _ensure_directories(self) -> None: directories = [ self.config.http.static_dir, self.config.http.templates_dir, ] - + log_dir = Path(self.config.logging.log_file).parent if log_dir != Path("."): directories.append(str(log_dir)) - + for directory in directories: Path(directory).mkdir(parents=True, exist_ok=True) logger.debug(f"Created/checked directory: {directory}") @@ -202,7 +219,7 @@ class PyServeServer: def shutdown(self) -> None: logger.info("Shutting down PyServe server") self.extension_manager.cleanup() - + from .logging_utils import shutdown_logging shutdown_logging() @@ -210,10 +227,10 @@ class PyServeServer: def add_extension(self, extension_type: str, config: Dict[str, Any]) -> None: self.extension_manager.load_extension(extension_type, config) - + def get_metrics(self) -> Dict[str, Any]: metrics = {"server_status": "running"} - + for extension in self.extension_manager.extensions: if hasattr(extension, 'get_metrics'): try: