konduktor/pyserve/routing.py
2025-12-03 12:54:45 +03:00

274 lines
9.8 KiB
Python

import mimetypes
import re
from pathlib import Path
from typing import Any, Dict, Optional, Pattern
from urllib.parse import urlparse
import httpx
from starlette.requests import Request
from starlette.responses import FileResponse, PlainTextResponse, Response
from .logging_utils import get_logger
logger = get_logger(__name__)
class RouteMatch:
    """Result of a successful route lookup.

    Carries the matched route's config dict plus any named capture-group
    parameters extracted when the match came from a regex pattern.
    """

    def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
        self.config = config
        # Normalize missing params to an empty dict so callers can iterate
        # unconditionally.
        self.params = params if params else {}
class Router:
    """Maps URL paths to route configs using nginx-style location patterns.

    Supported pattern forms (checked in this priority order at match time):

    * ``=/path``      -- exact-match routes (highest priority)
    * ``~regex``      -- case-sensitive regex routes
    * ``~*regex``     -- case-insensitive regex routes
    * ``__default__`` -- fallback route used when nothing else matches
    """

    def __init__(self, static_dir: str = "./static"):
        self.static_dir = Path(static_dir)
        # dict preserves insertion order, so earlier-registered regex
        # patterns win when several would match the same path.
        self.routes: Dict[Pattern, Dict[str, Any]] = {}
        self.exact_routes: Dict[str, Dict[str, Any]] = {}
        self.default_route: Optional[Dict[str, Any]] = None

    def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
        """Register *config* under *pattern*.

        A regex pattern that fails to compile is logged and skipped; an
        unrecognized pattern form is logged and skipped as well (previously
        it fell through every branch and was dropped silently, surfacing
        only as unexplained 404s).
        """
        if pattern.startswith("="):
            exact_path = pattern[1:]
            self.exact_routes[exact_path] = config
            logger.debug(f"Added exact route: {exact_path}")
            return
        if pattern == "__default__":
            self.default_route = config
            logger.debug("Added default route")
            return
        if pattern.startswith("~"):
            case_insensitive = pattern.startswith("~*")
            regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
            flags = re.IGNORECASE if case_insensitive else 0
            try:
                compiled_pattern = re.compile(regex_pattern, flags)
                self.routes[compiled_pattern] = config
                logger.debug(f"Added regex route: {pattern}")
            except re.error as e:
                logger.error(f"Regex compilation error {pattern}: {e}")
            return
        # Fix: make unsupported pattern forms visible instead of silently
        # discarding them.
        logger.warning(f"Ignoring unsupported route pattern: {pattern}")

    def match(self, path: str) -> Optional[RouteMatch]:
        """Return a :class:`RouteMatch` for *path*, or ``None``.

        Lookup order: exact routes, then regex routes in registration
        order (named groups become the match params), then the default
        route if one was configured.
        """
        if path in self.exact_routes:
            return RouteMatch(self.exact_routes[path])
        for pattern, config in self.routes.items():
            match = pattern.search(path)
            if match:
                params = match.groupdict()
                return RouteMatch(config, params)
        if self.default_route:
            return RouteMatch(self.default_route)
        return None
class RequestHandler:
    """Dispatches matched routes to fixed-response, proxy, static-file, or
    SPA-fallback handling."""

    def __init__(self, router: Router, static_dir: str = "./static", default_proxy_timeout: float = 30.0):
        self.router = router
        self.static_dir = Path(static_dir)
        # Used when a proxy route's config carries no "timeout" of its own.
        self.default_proxy_timeout = default_proxy_timeout

    async def handle(self, request: Request) -> Response:
        """Entry point: match the request path and process the route.

        Any unexpected error is converted into a 500 so a single bad route
        cannot crash the server loop.
        """
        path = request.url.path
        logger.info(f"{request.method} {path}")
        route_match = self.router.match(path)
        if not route_match:
            return PlainTextResponse("404 Not Found", status_code=404)
        try:
            return await self._process_route(request, route_match)
        except Exception as e:
            logger.error(f"Request processing error {path}: {e}")
            return PlainTextResponse("500 Internal Server Error", status_code=500)

    async def _process_route(self, request: Request, route_match: RouteMatch) -> Response:
        """Dispatch on route config keys, in priority order:
        fixed ``return``, ``proxy_pass``, static ``root``, ``spa_fallback``.
        """
        config = route_match.config
        if "return" in config:
            # "return" is either "<status>" or "<status> <body text>".
            status_text = config["return"]
            if " " in status_text:
                status_code, text = status_text.split(" ", 1)
                status_code = int(status_code)
            else:
                status_code = int(status_text)
                text = ""
            content_type = config.get("content_type", "text/plain")
            return PlainTextResponse(text, status_code=status_code, media_type=content_type)
        if "proxy_pass" in config:
            return await self._handle_proxy(request, config, route_match.params)
        if "root" in config:
            return await self._handle_static(request, config)
        if config.get("spa_fallback"):
            return await self._handle_spa_fallback(request, config)
        return PlainTextResponse("404 Not Found", status_code=404)

    async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response:
        """Serve a file from config["root"], with directory-index support
        and a containment check that rejects path traversal."""
        root = Path(config["root"])
        path = request.url.path.lstrip("/")
        index_file = config.get("index_file", "index.html")
        # Empty path (request for "/") serves the index file directly.
        file_path = (root / path) if path else (root / index_file)
        # A directory request falls through to its index file.
        if file_path.is_dir():
            file_path = file_path / index_file
        try:
            file_path = file_path.resolve()
            root = root.resolve()
        except OSError:
            return PlainTextResponse("404 Not Found", status_code=404)
        # Fix: the previous check used str.startswith(), which accepted
        # sibling paths like "/srv/static-evil" for root "/srv/static".
        # relative_to() only succeeds for true descendants of root.
        try:
            file_path.relative_to(root)
        except ValueError:
            return PlainTextResponse("403 Forbidden", status_code=403)
        if not file_path.is_file():
            return PlainTextResponse("404 Not Found", status_code=404)
        content_type, _ = mimetypes.guess_type(str(file_path))
        response = FileResponse(str(file_path), media_type=content_type)
        # Extra response headers are configured as "Name: value" strings.
        for header in config.get("headers", []):
            if ":" in header:
                name, value = header.split(":", 1)
                response.headers[name.strip()] = value.strip()
        if "cache_control" in config:
            response.headers["Cache-Control"] = config["cache_control"]
        return response

    async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response:
        """Serve the SPA index file for client-side routed paths, except
        those matching a configured exclusion prefix."""
        path = request.url.path
        # Paths under excluded prefixes (e.g. "/api/") must 404 instead of
        # falling back to the SPA shell.
        for pattern in config.get("exclude_patterns", []):
            if path.startswith(pattern):
                return PlainTextResponse("404 Not Found", status_code=404)
        root = Path(config.get("root", self.static_dir))
        index_file = config.get("index_file", "index.html")
        file_path = root / index_file
        if file_path.exists() and file_path.is_file():
            return FileResponse(str(file_path), media_type="text/html")
        return PlainTextResponse("404 Not Found", status_code=404)

    async def _handle_proxy(self, request: Request, config: Dict[str, Any], params: Dict[str, str]) -> Response:
        """Forward the request to config["proxy_pass"] and relay the response.

        ``{name}`` placeholders in the target URL and configured headers are
        filled from regex capture params. Hop-by-hop headers are stripped in
        both directions; connection failures map to 502, timeouts to 504.
        """
        proxy_url = config["proxy_pass"]
        for key, value in params.items():
            proxy_url = proxy_url.replace(f"{{{key}}}", value)
        parsed_proxy = urlparse(proxy_url)
        original_path = request.url.path
        # If proxy_pass carries its own path, use it verbatim; otherwise
        # forward the original request path to the upstream host.
        if parsed_proxy.path and parsed_proxy.path not in ("/", ""):
            target_url = proxy_url
        else:
            base_url = f"{parsed_proxy.scheme}://{parsed_proxy.netloc}"
            target_url = f"{base_url}{original_path}"
        if request.url.query:
            separator = "&" if "?" in target_url else "?"
            target_url = f"{target_url}{separator}{request.url.query}"
        logger.info(f"Proxying request to: {target_url}")
        proxy_headers = dict(request.headers)
        # Hop-by-hop headers must not be forwarded (RFC 7230 section 6.1);
        # "host" is replaced with the upstream host below.
        hop_by_hop_headers = [
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailers",
            "transfer-encoding",
            "upgrade",
            "host",
        ]
        for header in hop_by_hop_headers:
            proxy_headers.pop(header, None)
        client_ip = request.client.host if request.client else "unknown"
        proxy_headers["X-Forwarded-For"] = client_ip
        proxy_headers["X-Forwarded-Proto"] = request.url.scheme
        proxy_headers["X-Forwarded-Host"] = request.headers.get("host", "")
        proxy_headers["X-Real-IP"] = client_ip
        proxy_headers["Host"] = parsed_proxy.netloc
        # Route-configured headers also support {param} and $remote_addr
        # substitution.
        if "headers" in config:
            for header in config["headers"]:
                if ":" in header:
                    name, value = header.split(":", 1)
                    value = value.strip()
                    for key, param_value in params.items():
                        value = value.replace(f"{{{key}}}", param_value)
                    value = value.replace("$remote_addr", client_ip)
                    proxy_headers[name.strip()] = value
        body = await request.body()
        timeout = config.get("timeout", self.default_proxy_timeout)
        try:
            async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
                proxy_response = await client.request(
                    method=request.method,
                    url=target_url,
                    headers=proxy_headers,
                    content=body if body else None,
                )
                response_headers = dict(proxy_response.headers)
                for header in hop_by_hop_headers:
                    response_headers.pop(header, None)
                # Fix: httpx transparently decompresses .content, so relaying
                # the upstream content-encoding/content-length would describe
                # the *compressed* payload and corrupt the response for the
                # client. Starlette recomputes the correct content-length.
                response_headers.pop("content-encoding", None)
                response_headers.pop("content-length", None)
                return Response(
                    content=proxy_response.content,
                    status_code=proxy_response.status_code,
                    headers=response_headers,
                    media_type=proxy_response.headers.get("content-type"),
                )
        except httpx.ConnectError as e:
            logger.error(f"Proxy connection error to {target_url}: {e}")
            return PlainTextResponse("502 Bad Gateway", status_code=502)
        except httpx.TimeoutException as e:
            logger.error(f"Proxy timeout to {target_url}: {e}")
            return PlainTextResponse("504 Gateway Timeout", status_code=504)
        except Exception as e:
            logger.error(f"Proxy error to {target_url}: {e}")
            return PlainTextResponse("502 Bad Gateway", status_code=502)
def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router:
    """Build a :class:`Router` and register every (pattern, config) pair
    from the given location mapping, preserving mapping order."""
    router = Router()
    for location_pattern, location_config in regex_locations.items():
        router.add_route(location_pattern, location_config)
    return router