"""Nginx-style request routing: exact, regex, and default locations dispatched
to canned returns, static files, an SPA fallback, or an httpx reverse proxy."""

import re
import mimetypes
from pathlib import Path
from typing import Dict, Any, Optional, Pattern
from urllib.parse import urlparse

import httpx
from starlette.requests import Request
from starlette.responses import Response, FileResponse, PlainTextResponse

from .logging_utils import get_logger

logger = get_logger(__name__)


class RouteMatch:
    def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
        self.config = config
        self.params = params or {}


class Router:
    def __init__(self, static_dir: str = "./static"):
        self.static_dir = Path(static_dir)
        self.routes: Dict[Pattern, Dict[str, Any]] = {}
        self.exact_routes: Dict[str, Dict[str, Any]] = {}
        self.default_route: Optional[Dict[str, Any]] = None

    def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
        # "=/path" is an exact match, "~regex" / "~*regex" are case-sensitive /
        # case-insensitive regex matches, and "__default__" is the fallback.
        if pattern.startswith("="):
            exact_path = pattern[1:]
            self.exact_routes[exact_path] = config
            logger.debug(f"Added exact route: {exact_path}")
            return

        if pattern == "__default__":
            self.default_route = config
            logger.debug("Added default route")
            return

        if pattern.startswith("~"):
            case_insensitive = pattern.startswith("~*")
            regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
            flags = re.IGNORECASE if case_insensitive else 0
            try:
                compiled_pattern = re.compile(regex_pattern, flags)
                self.routes[compiled_pattern] = config
                logger.debug(f"Added regex route: {pattern}")
            except re.error as e:
                logger.error(f"Regex compilation error {pattern}: {e}")
            return

        # Anything else is unsupported; warn instead of dropping it silently.
        logger.warning(f"Ignoring route with unsupported pattern: {pattern}")

    def match(self, path: str) -> Optional[RouteMatch]:
        # Exact matches win, then regex routes in insertion order, then the default.
        if path in self.exact_routes:
            return RouteMatch(self.exact_routes[path])

        for pattern, config in self.routes.items():
            match = pattern.search(path)
            if match:
                params = match.groupdict()
                return RouteMatch(config, params)

        if self.default_route:
            return RouteMatch(self.default_route)

        return None


class RequestHandler:
    def __init__(self, router: Router, static_dir: str = "./static", default_proxy_timeout: float = 30.0):
        self.router = router
        self.static_dir = Path(static_dir)
        self.default_proxy_timeout = default_proxy_timeout

    async def handle(self, request: Request) -> Response:
        path = request.url.path
        logger.info(f"{request.method} {path}")

        route_match = self.router.match(path)
        if not route_match:
            return PlainTextResponse("404 Not Found", status_code=404)

        try:
            return await self._process_route(request, route_match)
        except Exception as e:
            logger.error(f"Request processing error {path}: {e}")
            return PlainTextResponse("500 Internal Server Error", status_code=500)

    async def _process_route(self, request: Request, route_match: RouteMatch) -> Response:
        config = route_match.config

        if "return" in config:
            # "return" is either "<status>" or "<status> <body>", e.g. "200 OK".
            status_text = config["return"]
            if " " in status_text:
                status_code, text = status_text.split(" ", 1)
                status_code = int(status_code)
            else:
                status_code = int(status_text)
                text = ""
            content_type = config.get("content_type", "text/plain")
            return PlainTextResponse(text, status_code=status_code, media_type=content_type)

        if "proxy_pass" in config:
            return await self._handle_proxy(request, config, route_match.params)

        if "root" in config:
            return await self._handle_static(request, config)

        if config.get("spa_fallback"):
            return await self._handle_spa_fallback(request, config)

        return PlainTextResponse("404 Not Found", status_code=404)

    async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response:
        root = Path(config["root"])
        path = request.url.path.lstrip("/")

        if not path or path == "/":
            index_file = config.get("index_file", "index.html")
            file_path = root / index_file
        else:
            file_path = root / path

        # If the path is a directory, serve its index file instead.
        if file_path.is_dir():
            index_file = config.get("index_file", "index.html")
            file_path = file_path / index_file

        try:
            file_path = file_path.resolve()
            root = root.resolve()
            # Compare path components rather than raw string prefixes, so that a
            # root of ./static does not also admit ./static-private.
            file_path.relative_to(root)
        except ValueError:
            return PlainTextResponse("403 Forbidden", status_code=403)
        except OSError:
            return PlainTextResponse("404 Not Found", status_code=404)

        if not file_path.exists() or not file_path.is_file():
            return PlainTextResponse("404 Not Found", status_code=404)

        content_type, _ = mimetypes.guess_type(str(file_path))
        response = FileResponse(str(file_path), media_type=content_type)

        if "headers" in config:
            for header in config["headers"]:
                if ":" in header:
                    name, value = header.split(":", 1)
                    response.headers[name.strip()] = value.strip()

        if "cache_control" in config:
            response.headers["Cache-Control"] = config["cache_control"]

        return response

    async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response:
        path = request.url.path

        # Paths matching any configured exclude prefix return 404 instead of
        # falling back to the SPA entry point.
        exclude_patterns = config.get("exclude_patterns", [])
        for pattern in exclude_patterns:
            if path.startswith(pattern):
                return PlainTextResponse("404 Not Found", status_code=404)

        root = Path(config.get("root", self.static_dir))
        index_file = config.get("index_file", "index.html")
        file_path = root / index_file

        if file_path.exists() and file_path.is_file():
            return FileResponse(str(file_path), media_type="text/html")

        return PlainTextResponse("404 Not Found", status_code=404)

    async def _handle_proxy(self, request: Request, config: Dict[str, Any], params: Dict[str, str]) -> Response:
        proxy_url = config["proxy_pass"]

        # Substitute named regex groups captured by the route, e.g. {service}.
        for key, value in params.items():
            proxy_url = proxy_url.replace(f"{{{key}}}", value)

        parsed_proxy = urlparse(proxy_url)
        original_path = request.url.path

        # If proxy_pass carries its own path, use it verbatim; otherwise append
        # the original request path to the upstream base URL.
        if parsed_proxy.path and parsed_proxy.path not in ("/", ""):
            target_url = proxy_url
        else:
            base_url = f"{parsed_proxy.scheme}://{parsed_proxy.netloc}"
            target_url = f"{base_url}{original_path}"

        if request.url.query:
            separator = "&" if "?" in target_url else "?"
target_url = f"{target_url}{separator}{request.url.query}" logger.info(f"Proxying request to: {target_url}") proxy_headers = dict(request.headers) hop_by_hop_headers = [ "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host" ] for header in hop_by_hop_headers: proxy_headers.pop(header, None) client_ip = request.client.host if request.client else "unknown" proxy_headers["X-Forwarded-For"] = client_ip proxy_headers["X-Forwarded-Proto"] = request.url.scheme proxy_headers["X-Forwarded-Host"] = request.headers.get("host", "") proxy_headers["X-Real-IP"] = client_ip proxy_headers["Host"] = parsed_proxy.netloc if "headers" in config: for header in config["headers"]: if ":" in header: name, value = header.split(":", 1) value = value.strip() for key, param_value in params.items(): value = value.replace(f"{{{key}}}", param_value) value = value.replace("$remote_addr", client_ip) proxy_headers[name.strip()] = value body = await request.body() timeout = config.get("timeout", self.default_proxy_timeout) try: async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client: proxy_response = await client.request( method=request.method, url=target_url, headers=proxy_headers, content=body if body else None, ) response_headers = dict(proxy_response.headers) for header in hop_by_hop_headers: response_headers.pop(header, None) return Response( content=proxy_response.content, status_code=proxy_response.status_code, headers=response_headers, media_type=proxy_response.headers.get("content-type"), ) except httpx.ConnectError as e: logger.error(f"Proxy connection error to {target_url}: {e}") return PlainTextResponse("502 Bad Gateway", status_code=502) except httpx.TimeoutException as e: logger.error(f"Proxy timeout to {target_url}: {e}") return PlainTextResponse("504 Gateway Timeout", status_code=504) except Exception as e: logger.error(f"Proxy error to {target_url}: {e}") return PlainTextResponse("502 Bad Gateway", status_code=502) def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router: router = Router() for pattern, config in regex_locations.items(): router.add_route(pattern, config) return router