fix: Update max-line-length in .flake8 and refactor routing and server code for improved readability and functionality

This commit is contained in:
Илья Глазунов 2025-09-02 14:23:01 +03:00
parent 84cd1c974f
commit 6b157d7626
3 changed files with 113 additions and 93 deletions

View File

@ -1,4 +1,4 @@
[flake8]
max-line-length = 100
max-line-length = 120
exclude = __pycache__,.git,.venv,venv,build,dist
ignore = E203,W503

View File

@ -8,34 +8,36 @@ from .logging_utils import get_logger
logger = get_logger(__name__)
class RouteMatch:
def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
self.config = config
self.params = params or {}
class Router:
def __init__(self, static_dir: str = "./static"):
self.static_dir = Path(static_dir)
self.routes: Dict[Pattern, Dict[str, Any]] = {}
self.exact_routes: Dict[str, Dict[str, Any]] = {}
self.default_route: Optional[Dict[str, Any]] = None
def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
if pattern.startswith("="):
exact_path = pattern[1:]
self.exact_routes[exact_path] = config
logger.debug(f"Added exact route: {exact_path}")
return
if pattern == "__default__":
self.default_route = config
logger.debug("Added default route")
return
if pattern.startswith("~"):
case_insensitive = pattern.startswith("~*")
regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
flags = re.IGNORECASE if case_insensitive else 0
try:
compiled_pattern = re.compile(regex_pattern, flags)
@ -43,47 +45,48 @@ class Router:
logger.debug(f"Added regex route: {pattern}")
except re.error as e:
logger.error(f"Regex compilation error {pattern}: {e}")
def match(self, path: str) -> Optional[RouteMatch]:
if path in self.exact_routes:
return RouteMatch(self.exact_routes[path])
for pattern, config in self.routes.items():
match = pattern.search(path)
if match:
params = match.groupdict()
return RouteMatch(config, params)
if self.default_route:
return RouteMatch(self.default_route)
return None
class RequestHandler:
class RequestHandler:
def __init__(self, router: Router, static_dir: str = "./static"):
self.router = router
self.static_dir = Path(static_dir)
async def handle(self, request: Request) -> Response:
path = request.url.path
logger.info(f"{request.method} {path}")
route_match = self.router.match(path)
if not route_match:
return PlainTextResponse("404 Not Found", status_code=404)
try:
return await self._process_route(request, route_match)
except Exception as e:
logger.error(f"Request processing error {path}: {e}")
return PlainTextResponse("500 Internal Server Error", status_code=500)
async def _process_route(self, request: Request, route_match: RouteMatch) -> Response:
config = route_match.config
path = request.url.path
# HINT: Not using it right now
# path = request.url.path
if "return" in config:
status_text = config["return"]
if " " in status_text:
@ -92,32 +95,32 @@ class RequestHandler:
else:
status_code = int(status_text)
text = ""
content_type = config.get("content_type", "text/plain")
return PlainTextResponse(text, status_code=status_code,
media_type=content_type)
return PlainTextResponse(text, status_code=status_code,
media_type=content_type)
if "proxy_pass" in config:
return await self._handle_proxy(request, config, route_match.params)
if "root" in config:
return await self._handle_static(request, config)
if config.get("spa_fallback"):
return await self._handle_spa_fallback(request, config)
return PlainTextResponse("404 Not Found", status_code=404)
async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response:
root = Path(config["root"])
path = request.url.path.lstrip("/")
if not path or path == "/":
index_file = config.get("index_file", "index.html")
file_path = root / index_file
else:
file_path = root / path
try:
file_path = file_path.resolve()
root = root.resolve()
@ -125,59 +128,59 @@ class RequestHandler:
return PlainTextResponse("403 Forbidden", status_code=403)
except OSError:
return PlainTextResponse("404 Not Found", status_code=404)
if not file_path.exists() or not file_path.is_file():
return PlainTextResponse("404 Not Found", status_code=404)
content_type, _ = mimetypes.guess_type(str(file_path))
response = FileResponse(str(file_path), media_type=content_type)
if "headers" in config:
for header in config["headers"]:
if ":" in header:
name, value = header.split(":", 1)
response.headers[name.strip()] = value.strip()
if "cache_control" in config:
response.headers["Cache-Control"] = config["cache_control"]
return response
async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response:
path = request.url.path
exclude_patterns = config.get("exclude_patterns", [])
for pattern in exclude_patterns:
if path.startswith(pattern):
return PlainTextResponse("404 Not Found", status_code=404)
root = Path(config.get("root", self.static_dir))
index_file = config.get("index_file", "index.html")
file_path = root / index_file
if file_path.exists() and file_path.is_file():
return FileResponse(str(file_path), media_type="text/html")
return PlainTextResponse("404 Not Found", status_code=404)
async def _handle_proxy(self, request: Request, config: Dict[str, Any],
params: Dict[str, str]) -> Response:
async def _handle_proxy(self, request: Request, config: Dict[str, Any],
params: Dict[str, str]) -> Response:
# TODO: Реализовать полноценное проксирование
proxy_url = config["proxy_pass"]
for key, value in params.items():
proxy_url = proxy_url.replace(f"{{{key}}}", value)
logger.info(f"Proxying request to: {proxy_url}")
return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200)
def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router:
router = Router()
for pattern, config in regex_locations.items():
router.add_route(pattern, config)
return router

View File

@ -4,8 +4,8 @@ import time
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response, PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Route
from starlette.types import ASGIApp, Receive, Scope, Send
from pathlib import Path
from typing import Optional, Dict, Any
@ -17,22 +17,28 @@ from . import __version__
logger = get_logger(__name__)
class PyServeMiddleware(BaseHTTPMiddleware):
def __init__(self, app, extension_manager: ExtensionManager):
super().__init__(app)
class PyServeMiddleware:
def __init__(self, app: ASGIApp, extension_manager: ExtensionManager):
self.app = app
self.extension_manager = extension_manager
self.access_logger = get_logger('pyserve.access')
async def dispatch(self, request: Request, call_next):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
start_time = time.time()
request = Request(scope, receive)
response = await self.extension_manager.process_request(request)
if response is None:
response = await call_next(request)
await self.app(scope, receive, send)
return
response = await self.extension_manager.process_response(request, response)
response.headers["Server"] = f"pyserve/{__version__}"
client_ip = request.client.host if request.client else "unknown"
method = request.method
path = str(request.url.path)
@ -41,13 +47,13 @@ class PyServeMiddleware(BaseHTTPMiddleware):
path += f"?{query}"
status_code = response.status_code
process_time = round((time.time() - start_time) * 1000, 2)
self.access_logger.info(f"{client_ip} - {method} {path} - {status_code} - {process_time}ms")
return response
await response(scope, receive, send)
class PyServeServer:
class PyServeServer:
def __init__(self, config: Config):
self.config = config
self.extension_manager = ExtensionManager()
@ -55,34 +61,45 @@ class PyServeServer:
self._setup_logging()
self._load_extensions()
self._create_app()
def _setup_logging(self) -> None:
self.config.setup_logging()
logger.info("PyServe сервер инициализирован")
logger.info("PyServe server initialized")
def _load_extensions(self) -> None:
for ext_config in self.config.extensions:
self.extension_manager.load_extension(
ext_config.type,
ext_config.type,
ext_config.config
)
def _create_app(self) -> None:
routes = [
Route("/health", self._health_check, methods=["GET"]),
Route("/metrics", self._metrics, methods=["GET"]),
Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]),
Route(
"/{path:path}",
self._catch_all,
methods=[
"GET",
"POST",
"PUT",
"DELETE",
"PATCH",
"OPTIONS"
]
),
]
self.app = Starlette(routes=routes)
self.app.add_middleware(PyServeMiddleware, extension_manager=self.extension_manager)
async def _health_check(self, request: Request) -> Response:
return PlainTextResponse("OK", status_code=200)
async def _metrics(self, request: Request) -> Response:
metrics = {}
for extension in self.extension_manager.extensions:
if hasattr(extension, 'get_metrics'):
try:
@ -96,22 +113,22 @@ class PyServeServer:
json.dumps(metrics, ensure_ascii=False, indent=2),
media_type="application/json"
)
async def _catch_all(self, request: Request) -> Response:
return PlainTextResponse("404 Not Found", status_code=404)
def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
if not self.config.ssl.enabled:
return None
if not Path(self.config.ssl.cert_file).exists():
logger.error(f"SSL certificate not found: {self.config.ssl.cert_file}")
return None
if not Path(self.config.ssl.key_file).exists():
logger.error(f"SSL key not found: {self.config.ssl.key_file}")
return None
try:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(
@ -123,17 +140,16 @@ class PyServeServer:
except Exception as e:
logger.error(f"Error creating SSL context: {e}")
return None
def run(self) -> None:
if not self.config.validate():
logger.error("Configuration is invalid, server cannot be started")
return
self._ensure_directories()
ssl_context = self._create_ssl_context()
uvicorn_config = {
"app": self.app,
uvicorn_config: Dict[str, Any] = {
"host": self.config.server.host,
"port": self.config.server.port,
"log_level": "critical",
@ -141,7 +157,7 @@ class PyServeServer:
"use_colors": False,
"server_header": False,
}
if ssl_context:
uvicorn_config.update({
"ssl_keyfile": self.config.ssl.key_file,
@ -154,21 +170,22 @@ class PyServeServer:
logger.info(f"Starting PyServe server at {protocol}://{self.config.server.host}:{self.config.server.port}")
try:
uvicorn.run(**uvicorn_config)
assert self.app is not None, "App not initialized"
uvicorn.run(self.app, **uvicorn_config)
except KeyboardInterrupt:
logger.info("Received shutdown signal")
except Exception as e:
logger.error(f"Error starting server: {e}")
finally:
self.shutdown()
async def run_async(self) -> None:
if not self.config.validate():
logger.error("Configuration is invalid, server cannot be started")
return
self._ensure_directories()
config = uvicorn.Config(
app=self.app, # type: ignore
host=self.config.server.host,
@ -177,24 +194,24 @@ class PyServeServer:
access_log=False,
use_colors=False,
)
server = uvicorn.Server(config)
try:
await server.serve()
finally:
self.shutdown()
def _ensure_directories(self) -> None:
directories = [
self.config.http.static_dir,
self.config.http.templates_dir,
]
log_dir = Path(self.config.logging.log_file).parent
if log_dir != Path("."):
directories.append(str(log_dir))
for directory in directories:
Path(directory).mkdir(parents=True, exist_ok=True)
logger.debug(f"Created/checked directory: {directory}")
@ -202,7 +219,7 @@ class PyServeServer:
def shutdown(self) -> None:
logger.info("Shutting down PyServe server")
self.extension_manager.cleanup()
from .logging_utils import shutdown_logging
shutdown_logging()
@ -210,10 +227,10 @@ class PyServeServer:
def add_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
self.extension_manager.load_extension(extension_type, config)
def get_metrics(self) -> Dict[str, Any]:
metrics = {"server_status": "running"}
for extension in self.extension_manager.extensions:
if hasattr(extension, 'get_metrics'):
try: