forked from Shifty/pyserveX
fix: Update max-line-length in .flake8 and refactor routing and server code for improved readability and functionality
This commit is contained in:
parent
84cd1c974f
commit
6b157d7626
2
.flake8
2
.flake8
@ -1,4 +1,4 @@
|
||||
[flake8]
|
||||
max-line-length = 100
|
||||
max-line-length = 120
|
||||
exclude = __pycache__,.git,.venv,venv,build,dist
|
||||
ignore = E203,W503
|
||||
|
||||
@ -8,34 +8,36 @@ from .logging_utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RouteMatch:
|
||||
def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
|
||||
self.config = config
|
||||
self.params = params or {}
|
||||
|
||||
|
||||
class Router:
|
||||
def __init__(self, static_dir: str = "./static"):
|
||||
self.static_dir = Path(static_dir)
|
||||
self.routes: Dict[Pattern, Dict[str, Any]] = {}
|
||||
self.exact_routes: Dict[str, Dict[str, Any]] = {}
|
||||
self.default_route: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
|
||||
if pattern.startswith("="):
|
||||
exact_path = pattern[1:]
|
||||
self.exact_routes[exact_path] = config
|
||||
logger.debug(f"Added exact route: {exact_path}")
|
||||
return
|
||||
|
||||
|
||||
if pattern == "__default__":
|
||||
self.default_route = config
|
||||
logger.debug("Added default route")
|
||||
return
|
||||
|
||||
|
||||
if pattern.startswith("~"):
|
||||
case_insensitive = pattern.startswith("~*")
|
||||
regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
|
||||
|
||||
|
||||
flags = re.IGNORECASE if case_insensitive else 0
|
||||
try:
|
||||
compiled_pattern = re.compile(regex_pattern, flags)
|
||||
@ -43,47 +45,48 @@ class Router:
|
||||
logger.debug(f"Added regex route: {pattern}")
|
||||
except re.error as e:
|
||||
logger.error(f"Regex compilation error {pattern}: {e}")
|
||||
|
||||
|
||||
def match(self, path: str) -> Optional[RouteMatch]:
|
||||
if path in self.exact_routes:
|
||||
return RouteMatch(self.exact_routes[path])
|
||||
|
||||
|
||||
for pattern, config in self.routes.items():
|
||||
match = pattern.search(path)
|
||||
if match:
|
||||
params = match.groupdict()
|
||||
return RouteMatch(config, params)
|
||||
|
||||
|
||||
if self.default_route:
|
||||
return RouteMatch(self.default_route)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
class RequestHandler:
|
||||
def __init__(self, router: Router, static_dir: str = "./static"):
|
||||
self.router = router
|
||||
self.static_dir = Path(static_dir)
|
||||
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
path = request.url.path
|
||||
|
||||
|
||||
logger.info(f"{request.method} {path}")
|
||||
|
||||
|
||||
route_match = self.router.match(path)
|
||||
if not route_match:
|
||||
return PlainTextResponse("404 Not Found", status_code=404)
|
||||
|
||||
|
||||
try:
|
||||
return await self._process_route(request, route_match)
|
||||
except Exception as e:
|
||||
logger.error(f"Request processing error {path}: {e}")
|
||||
return PlainTextResponse("500 Internal Server Error", status_code=500)
|
||||
|
||||
|
||||
async def _process_route(self, request: Request, route_match: RouteMatch) -> Response:
|
||||
config = route_match.config
|
||||
path = request.url.path
|
||||
|
||||
# HINT: Not using it right now
|
||||
# path = request.url.path
|
||||
|
||||
if "return" in config:
|
||||
status_text = config["return"]
|
||||
if " " in status_text:
|
||||
@ -92,32 +95,32 @@ class RequestHandler:
|
||||
else:
|
||||
status_code = int(status_text)
|
||||
text = ""
|
||||
|
||||
|
||||
content_type = config.get("content_type", "text/plain")
|
||||
return PlainTextResponse(text, status_code=status_code,
|
||||
media_type=content_type)
|
||||
|
||||
return PlainTextResponse(text, status_code=status_code,
|
||||
media_type=content_type)
|
||||
|
||||
if "proxy_pass" in config:
|
||||
return await self._handle_proxy(request, config, route_match.params)
|
||||
|
||||
|
||||
if "root" in config:
|
||||
return await self._handle_static(request, config)
|
||||
|
||||
|
||||
if config.get("spa_fallback"):
|
||||
return await self._handle_spa_fallback(request, config)
|
||||
|
||||
|
||||
return PlainTextResponse("404 Not Found", status_code=404)
|
||||
|
||||
|
||||
async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response:
|
||||
root = Path(config["root"])
|
||||
path = request.url.path.lstrip("/")
|
||||
|
||||
|
||||
if not path or path == "/":
|
||||
index_file = config.get("index_file", "index.html")
|
||||
file_path = root / index_file
|
||||
else:
|
||||
file_path = root / path
|
||||
|
||||
|
||||
try:
|
||||
file_path = file_path.resolve()
|
||||
root = root.resolve()
|
||||
@ -125,59 +128,59 @@ class RequestHandler:
|
||||
return PlainTextResponse("403 Forbidden", status_code=403)
|
||||
except OSError:
|
||||
return PlainTextResponse("404 Not Found", status_code=404)
|
||||
|
||||
|
||||
if not file_path.exists() or not file_path.is_file():
|
||||
return PlainTextResponse("404 Not Found", status_code=404)
|
||||
|
||||
|
||||
content_type, _ = mimetypes.guess_type(str(file_path))
|
||||
|
||||
|
||||
response = FileResponse(str(file_path), media_type=content_type)
|
||||
|
||||
|
||||
if "headers" in config:
|
||||
for header in config["headers"]:
|
||||
if ":" in header:
|
||||
name, value = header.split(":", 1)
|
||||
response.headers[name.strip()] = value.strip()
|
||||
|
||||
|
||||
if "cache_control" in config:
|
||||
response.headers["Cache-Control"] = config["cache_control"]
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response:
|
||||
path = request.url.path
|
||||
|
||||
|
||||
exclude_patterns = config.get("exclude_patterns", [])
|
||||
for pattern in exclude_patterns:
|
||||
if path.startswith(pattern):
|
||||
return PlainTextResponse("404 Not Found", status_code=404)
|
||||
|
||||
|
||||
root = Path(config.get("root", self.static_dir))
|
||||
index_file = config.get("index_file", "index.html")
|
||||
file_path = root / index_file
|
||||
|
||||
|
||||
if file_path.exists() and file_path.is_file():
|
||||
return FileResponse(str(file_path), media_type="text/html")
|
||||
|
||||
|
||||
return PlainTextResponse("404 Not Found", status_code=404)
|
||||
|
||||
async def _handle_proxy(self, request: Request, config: Dict[str, Any],
|
||||
params: Dict[str, str]) -> Response:
|
||||
|
||||
async def _handle_proxy(self, request: Request, config: Dict[str, Any],
|
||||
params: Dict[str, str]) -> Response:
|
||||
# TODO: Реализовать полноценное проксирование
|
||||
proxy_url = config["proxy_pass"]
|
||||
|
||||
|
||||
for key, value in params.items():
|
||||
proxy_url = proxy_url.replace(f"{{{key}}}", value)
|
||||
|
||||
|
||||
logger.info(f"Proxying request to: {proxy_url}")
|
||||
|
||||
|
||||
return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200)
|
||||
|
||||
|
||||
def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router:
|
||||
router = Router()
|
||||
|
||||
|
||||
for pattern, config in regex_locations.items():
|
||||
router.add_route(pattern, config)
|
||||
|
||||
|
||||
return router
|
||||
|
||||
@ -4,8 +4,8 @@ import time
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, PlainTextResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.routing import Route
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
@ -17,22 +17,28 @@ from . import __version__
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PyServeMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(self, app, extension_manager: ExtensionManager):
|
||||
super().__init__(app)
|
||||
class PyServeMiddleware:
|
||||
def __init__(self, app: ASGIApp, extension_manager: ExtensionManager):
|
||||
self.app = app
|
||||
self.extension_manager = extension_manager
|
||||
self.access_logger = get_logger('pyserve.access')
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
request = Request(scope, receive)
|
||||
response = await self.extension_manager.process_request(request)
|
||||
|
||||
|
||||
if response is None:
|
||||
response = await call_next(request)
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response = await self.extension_manager.process_response(request, response)
|
||||
|
||||
response.headers["Server"] = f"pyserve/{__version__}"
|
||||
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
method = request.method
|
||||
path = str(request.url.path)
|
||||
@ -41,13 +47,13 @@ class PyServeMiddleware(BaseHTTPMiddleware):
|
||||
path += f"?{query}"
|
||||
status_code = response.status_code
|
||||
process_time = round((time.time() - start_time) * 1000, 2)
|
||||
|
||||
|
||||
self.access_logger.info(f"{client_ip} - {method} {path} - {status_code} - {process_time}ms")
|
||||
|
||||
return response
|
||||
|
||||
await response(scope, receive, send)
|
||||
|
||||
|
||||
class PyServeServer:
|
||||
class PyServeServer:
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.extension_manager = ExtensionManager()
|
||||
@ -55,34 +61,45 @@ class PyServeServer:
|
||||
self._setup_logging()
|
||||
self._load_extensions()
|
||||
self._create_app()
|
||||
|
||||
|
||||
def _setup_logging(self) -> None:
|
||||
self.config.setup_logging()
|
||||
logger.info("PyServe сервер инициализирован")
|
||||
|
||||
logger.info("PyServe server initialized")
|
||||
|
||||
def _load_extensions(self) -> None:
|
||||
for ext_config in self.config.extensions:
|
||||
self.extension_manager.load_extension(
|
||||
ext_config.type,
|
||||
ext_config.type,
|
||||
ext_config.config
|
||||
)
|
||||
|
||||
|
||||
def _create_app(self) -> None:
|
||||
routes = [
|
||||
Route("/health", self._health_check, methods=["GET"]),
|
||||
Route("/metrics", self._metrics, methods=["GET"]),
|
||||
Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]),
|
||||
Route(
|
||||
"/{path:path}",
|
||||
self._catch_all,
|
||||
methods=[
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"PATCH",
|
||||
"OPTIONS"
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
self.app = Starlette(routes=routes)
|
||||
self.app.add_middleware(PyServeMiddleware, extension_manager=self.extension_manager)
|
||||
|
||||
|
||||
async def _health_check(self, request: Request) -> Response:
|
||||
return PlainTextResponse("OK", status_code=200)
|
||||
|
||||
|
||||
async def _metrics(self, request: Request) -> Response:
|
||||
metrics = {}
|
||||
|
||||
|
||||
for extension in self.extension_manager.extensions:
|
||||
if hasattr(extension, 'get_metrics'):
|
||||
try:
|
||||
@ -96,22 +113,22 @@ class PyServeServer:
|
||||
json.dumps(metrics, ensure_ascii=False, indent=2),
|
||||
media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
async def _catch_all(self, request: Request) -> Response:
|
||||
return PlainTextResponse("404 Not Found", status_code=404)
|
||||
|
||||
|
||||
def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
|
||||
if not self.config.ssl.enabled:
|
||||
return None
|
||||
|
||||
|
||||
if not Path(self.config.ssl.cert_file).exists():
|
||||
logger.error(f"SSL certificate not found: {self.config.ssl.cert_file}")
|
||||
return None
|
||||
|
||||
|
||||
if not Path(self.config.ssl.key_file).exists():
|
||||
logger.error(f"SSL key not found: {self.config.ssl.key_file}")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
context.load_cert_chain(
|
||||
@ -123,17 +140,16 @@ class PyServeServer:
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating SSL context: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def run(self) -> None:
|
||||
if not self.config.validate():
|
||||
logger.error("Configuration is invalid, server cannot be started")
|
||||
return
|
||||
|
||||
|
||||
self._ensure_directories()
|
||||
ssl_context = self._create_ssl_context()
|
||||
|
||||
uvicorn_config = {
|
||||
"app": self.app,
|
||||
|
||||
uvicorn_config: Dict[str, Any] = {
|
||||
"host": self.config.server.host,
|
||||
"port": self.config.server.port,
|
||||
"log_level": "critical",
|
||||
@ -141,7 +157,7 @@ class PyServeServer:
|
||||
"use_colors": False,
|
||||
"server_header": False,
|
||||
}
|
||||
|
||||
|
||||
if ssl_context:
|
||||
uvicorn_config.update({
|
||||
"ssl_keyfile": self.config.ssl.key_file,
|
||||
@ -154,21 +170,22 @@ class PyServeServer:
|
||||
logger.info(f"Starting PyServe server at {protocol}://{self.config.server.host}:{self.config.server.port}")
|
||||
|
||||
try:
|
||||
uvicorn.run(**uvicorn_config)
|
||||
assert self.app is not None, "App not initialized"
|
||||
uvicorn.run(self.app, **uvicorn_config)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received shutdown signal")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting server: {e}")
|
||||
finally:
|
||||
self.shutdown()
|
||||
|
||||
|
||||
async def run_async(self) -> None:
|
||||
if not self.config.validate():
|
||||
logger.error("Configuration is invalid, server cannot be started")
|
||||
return
|
||||
|
||||
|
||||
self._ensure_directories()
|
||||
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=self.app, # type: ignore
|
||||
host=self.config.server.host,
|
||||
@ -177,24 +194,24 @@ class PyServeServer:
|
||||
access_log=False,
|
||||
use_colors=False,
|
||||
)
|
||||
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
|
||||
try:
|
||||
await server.serve()
|
||||
finally:
|
||||
self.shutdown()
|
||||
|
||||
|
||||
def _ensure_directories(self) -> None:
|
||||
directories = [
|
||||
self.config.http.static_dir,
|
||||
self.config.http.templates_dir,
|
||||
]
|
||||
|
||||
|
||||
log_dir = Path(self.config.logging.log_file).parent
|
||||
if log_dir != Path("."):
|
||||
directories.append(str(log_dir))
|
||||
|
||||
|
||||
for directory in directories:
|
||||
Path(directory).mkdir(parents=True, exist_ok=True)
|
||||
logger.debug(f"Created/checked directory: {directory}")
|
||||
@ -202,7 +219,7 @@ class PyServeServer:
|
||||
def shutdown(self) -> None:
|
||||
logger.info("Shutting down PyServe server")
|
||||
self.extension_manager.cleanup()
|
||||
|
||||
|
||||
from .logging_utils import shutdown_logging
|
||||
shutdown_logging()
|
||||
|
||||
@ -210,10 +227,10 @@ class PyServeServer:
|
||||
|
||||
def add_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
|
||||
self.extension_manager.load_extension(extension_type, config)
|
||||
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
metrics = {"server_status": "running"}
|
||||
|
||||
|
||||
for extension in self.extension_manager.extensions:
|
||||
if hasattr(extension, 'get_metrics'):
|
||||
try:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user