"""Path routing and request handling for a Starlette app.

Routes use location patterns modeled on nginx (``=`` exact, ``~`` / ``~*`` regex,
plus a ``__default__`` fallback), and route configs use directives named after
nginx's (``return``, ``proxy_pass``, ``root``) plus ``spa_fallback``.
"""

import re
import mimetypes
from pathlib import Path
from typing import Dict, Any, Optional, Pattern

from starlette.requests import Request
from starlette.responses import Response, FileResponse, PlainTextResponse

from .logging_utils import get_logger

logger = get_logger(__name__)


class RouteMatch:
    """Result of a successful route lookup."""

    def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
        # The config dict registered for the matched route.
        self.config = config
        # Named regex groups captured from the path; empty for exact/default routes.
        # A value may be None when an optional named group did not participate.
        self.params = params or {}


class Router:
    """Resolve request paths to route configs.

    Supported pattern forms (see ``add_route``):

    * ``"=/path"``      - exact match on the whole path.
    * ``"~regex"``      - case-sensitive regex, tested with ``re.search``.
    * ``"~*regex"``     - case-insensitive regex.
    * ``"__default__"`` - used when nothing else matches.

    ``match`` checks exact routes first, then regex routes in insertion order,
    then the default route.
    """

    def __init__(self, static_dir: str = "./static"):
        # Not read anywhere in this module; kept for API compatibility.
        self.static_dir = Path(static_dir)
        # Compiled regex -> config. Dict iteration order is the regex match order.
        self.routes: Dict[Pattern, Dict[str, Any]] = {}
        self.exact_routes: Dict[str, Dict[str, Any]] = {}
        self.default_route: Optional[Dict[str, Any]] = None

    def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
        """Register ``config`` under ``pattern``.

        An invalid regex is logged and skipped rather than raised, so one bad
        location does not stop the rest of the config from loading. Patterns in
        any other form are logged as a warning and ignored.
        """
        if pattern.startswith("="):
            exact_path = pattern[1:]
            self.exact_routes[exact_path] = config
            logger.debug(f"Added exact route: {exact_path}")
            return

        if pattern == "__default__":
            self.default_route = config
            logger.debug("Added default route")
            return

        if pattern.startswith("~"):
            case_insensitive = pattern.startswith("~*")
            regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
            flags = re.IGNORECASE if case_insensitive else 0
            try:
                compiled_pattern = re.compile(regex_pattern, flags)
            except re.error as e:
                logger.error(f"Regex compilation error {pattern}: {e}")
                return
            self.routes[compiled_pattern] = config
            logger.debug(f"Added regex route: {pattern}")
            return

        # Anything else (e.g. a bare prefix like "/api") is unsupported; warn so
        # config mistakes are visible instead of the route vanishing silently.
        logger.warning(f"Ignoring route with unsupported pattern: {pattern}")

    def match(self, path: str) -> Optional[RouteMatch]:
        """Return the route for ``path``, or None if nothing (not even a default) applies."""
        if path in self.exact_routes:
            return RouteMatch(self.exact_routes[path])

        for pattern, config in self.routes.items():
            m = pattern.search(path)
            if m:
                return RouteMatch(config, m.groupdict())

        # Compare with None: an empty config dict is still a registered default.
        if self.default_route is not None:
            return RouteMatch(self.default_route)

        return None


class RequestHandler:
    """Turn Starlette requests into responses according to a ``Router``."""

    def __init__(self, router: Router, static_dir: str = "./static"):
        self.router = router
        # Used as the SPA root when a spa_fallback route has no "root" key.
        self.static_dir = Path(static_dir)

    async def handle(self, request: Request) -> Response:
        """Route ``request`` and build its response.

        Returns 404 when no route matches and 500 when processing the matched
        route raises.
        """
        path = request.url.path
        logger.info(f"{request.method} {path}")

        route_match = self.router.match(path)
        if route_match is None:
            return PlainTextResponse("404 Not Found", status_code=404)

        try:
            return await self._process_route(request, route_match)
        except Exception as e:
            # Top-level boundary: never let a handler error escape, but keep the
            # traceback in the log so the failure can be diagnosed.
            logger.exception(f"Request processing error {path}: {e}")
            return PlainTextResponse("500 Internal Server Error", status_code=500)

    async def _process_route(self, request: Request, route_match: RouteMatch) -> Response:
        """Dispatch on the first directive present in the route config.

        Precedence: ``return`` > ``proxy_pass`` > ``root`` > ``spa_fallback``.
        When both ``root`` and ``spa_fallback`` are set, a 404 from the static
        handler falls through to the SPA index file.
        """
        config = route_match.config

        if "return" in config:
            return self._handle_return(config)

        if "proxy_pass" in config:
            return await self._handle_proxy(request, config, route_match.params)

        if "root" in config:
            response = await self._handle_static(request, config)
            # Only a missing file falls back; a 403 (path outside root) must not.
            if response.status_code == 404 and config.get("spa_fallback"):
                return await self._handle_spa_fallback(request, config)
            return response

        if config.get("spa_fallback"):
            return await self._handle_spa_fallback(request, config)

        return PlainTextResponse("404 Not Found", status_code=404)

    def _handle_return(self, config: Dict[str, Any]) -> Response:
        """Build the fixed response for a ``return`` directive: ``"CODE"`` or ``"CODE text"``.

        Raises ValueError if CODE is not an integer; ``handle`` turns that into a 500.
        """
        # str() so a numeric config value (e.g. `return: 404` parsed as int) also works.
        status_text = str(config["return"])
        code, _, text = status_text.partition(" ")
        content_type = config.get("content_type", "text/plain")
        return PlainTextResponse(text, status_code=int(code), media_type=content_type)

    async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response:
        """Serve the file under ``config["root"]`` that corresponds to the request path.

        An empty path maps to ``config.get("index_file", "index.html")``.
        Returns 403 if the resolved path lies outside the root, and 404 if it is
        missing or not a regular file. Extra response headers come from
        ``config["headers"]`` (``"Name: value"`` strings; entries without ``:``
        are skipped) and ``config["cache_control"]``.
        """
        root = Path(config["root"])
        relative = request.url.path.lstrip("/")
        if relative:
            file_path = root / relative
        else:
            file_path = root / config.get("index_file", "index.html")

        try:
            file_path = file_path.resolve()
            root = root.resolve()
        except (OSError, RuntimeError):
            # RuntimeError: symlink loops on Python versions before 3.13.
            return PlainTextResponse("404 Not Found", status_code=404)

        # Component-wise containment check. A string-prefix test would wrongly
        # accept sibling paths such as "<root>-private/..." for root "<root>".
        try:
            file_path.relative_to(root)
        except ValueError:
            return PlainTextResponse("403 Forbidden", status_code=403)

        if not file_path.is_file():
            return PlainTextResponse("404 Not Found", status_code=404)

        content_type, _ = mimetypes.guess_type(str(file_path))
        response = FileResponse(str(file_path), media_type=content_type)

        for header in config.get("headers") or []:
            if ":" in header:
                name, value = header.split(":", 1)
                response.headers[name.strip()] = value.strip()

        if "cache_control" in config:
            response.headers["Cache-Control"] = config["cache_control"]

        return response

    async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response:
        """Serve the SPA index file for any path not listed in ``exclude_patterns``.

        Paths starting with any entry of ``config["exclude_patterns"]`` get a 404.
        The index is ``config.get("index_file", "index.html")`` under
        ``config["root"]``, or under this handler's ``static_dir`` if no root is set.
        """
        path = request.url.path
        exclude_patterns = config.get("exclude_patterns", [])
        if any(path.startswith(prefix) for prefix in exclude_patterns):
            return PlainTextResponse("404 Not Found", status_code=404)

        root = Path(config.get("root", self.static_dir))
        file_path = root / config.get("index_file", "index.html")
        if file_path.is_file():
            return FileResponse(str(file_path), media_type="text/html")

        return PlainTextResponse("404 Not Found", status_code=404)

    async def _handle_proxy(self, request: Request, config: Dict[str, Any], params: Dict[str, str]) -> Response:
        """Placeholder proxy: substitute ``{name}`` params into ``proxy_pass`` and echo the URL.

        Does not forward the request yet; it returns a 200 text response naming
        the target URL.
        """
        # TODO: implement real proxying
        proxy_url = config["proxy_pass"]
        for key, value in params.items():
            # Optional named groups that did not match come back as None.
            proxy_url = proxy_url.replace(f"{{{key}}}", "" if value is None else value)

        logger.info(f"Proxying request to: {proxy_url}")
        return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200)


def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router:
    """Build a ``Router`` by registering every ``pattern -> config`` pair in order."""
    router = Router()
    for pattern, config in regex_locations.items():
        router.add_route(pattern, config)
    return router