import mimetypes
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse

import httpx
from starlette.requests import Request
from starlette.responses import FileResponse, PlainTextResponse, Response

from .logging_utils import get_logger

try:
    from pyserve._routing import FastRouteMatch, FastRouter, fast_match  # type: ignore

    CYTHON_ROUTING_AVAILABLE = True
except ImportError:
    from pyserve._routing_py import FastRouteMatch, FastRouter, fast_match

    CYTHON_ROUTING_AVAILABLE = False

logger = get_logger(__name__)

# Aliases for backward compatibility
RouteMatch = FastRouteMatch
Router = FastRouter

# Header names (lowercase) removed from both the forwarded request and the
# relayed response. "host" is not hop-by-hop; it is listed because the proxy
# replaces it with the upstream netloc.
_STRIPPED_PROXY_HEADERS = (
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "host",
)


def _not_found() -> PlainTextResponse:
    """Return the plain-text 404 response used throughout this module."""
    return PlainTextResponse("404 Not Found", status_code=404)


def _split_header_line(line: str) -> Optional[Tuple[str, str]]:
    """Split a ``"Name: value"`` config entry into ``(name, value)``.

    Both parts are stripped of surrounding whitespace. Returns ``None`` when
    the entry contains no colon, so callers can skip it.
    """
    name, colon, value = line.partition(":")
    if not colon:
        return None
    return name.strip(), value.strip()


def _set_header(headers: Dict[str, str], name: str, value: str) -> None:
    """Set ``name`` in ``headers``, replacing any entry that differs only by case.

    ``dict(request.headers)`` yields lowercase keys, so assigning a mixed-case
    key directly would leave the client-supplied lowercase entry in place and
    both would be sent upstream.
    """
    lowered = name.lower()
    for existing in [key for key in headers if key.lower() == lowered]:
        del headers[existing]
    headers[name] = value


class RequestHandler:
    """Dispatch requests to a fixed response, a proxy, static files or an SPA index.

    The behaviour for each request is chosen from the ``config`` dict attached
    to the route that ``router`` matches for the request path.
    """

    def __init__(self, router: Router, static_dir: str = "./static", default_proxy_timeout: float = 30.0):
        """Store the router, the default static directory and the proxy timeout.

        Args:
            router: Router whose ``match(path)`` result drives dispatch.
            static_dir: Root used by the SPA fallback when a route has no ``root``.
            default_proxy_timeout: Timeout passed to httpx when a route has no
                ``timeout`` key.
        """
        self.router = router
        self.static_dir = Path(static_dir)
        self.default_proxy_timeout = default_proxy_timeout

    async def handle(self, request: Request) -> Response:
        """Match the request path and return the routed response.

        Returns a 404 when no route matches and a 500 when route processing
        raises; the exception is logged, not propagated.
        """
        path = request.url.path
        logger.info(f"{request.method} {path}")

        route_match = self.router.match(path)
        if not route_match:
            return _not_found()

        try:
            return await self._process_route(request, route_match)
        except Exception as e:
            logger.error(f"Request processing error {path}: {e}")
            return PlainTextResponse("500 Internal Server Error", status_code=500)

    async def _process_route(self, request: Request, route_match: RouteMatch) -> Response:
        """Pick the handler for a matched route from the keys in its config.

        Keys are checked in this order: ``return``, ``proxy_pass``, ``root``,
        ``spa_fallback``. A config with none of them yields a 404.
        """
        config = route_match.config

        if "return" in config:
            # "<status>" or "<status> <body text>". str() lets a bare integer
            # status (e.g. 204 from a config file) work as well.
            code_text, _, text = str(config["return"]).partition(" ")
            content_type = config.get("content_type", "text/plain")
            return PlainTextResponse(text, status_code=int(code_text), media_type=content_type)

        if "proxy_pass" in config:
            return await self._handle_proxy(request, config, route_match.params)

        if "root" in config:
            return await self._handle_static(request, config)

        if config.get("spa_fallback"):
            return await self._handle_spa_fallback(request, config)

        return _not_found()

    async def _handle_static(self, request: Request, config: Dict[str, Any]) -> Response:
        """Serve a file from ``config["root"]`` for the request path.

        An empty path, or a path that names a directory, is mapped to
        ``config["index_file"]`` (default ``index.html``). Returns 403 when the
        resolved file lies outside the resolved root and 404 when it cannot be
        resolved or is not a regular file. Entries in ``config["headers"]`` and
        ``config["cache_control"]`` are applied to the response.
        """
        root = Path(config["root"])
        index_file = config.get("index_file", "index.html")
        relative = request.url.path.lstrip("/")

        if not relative:
            file_path = root / index_file
        else:
            file_path = root / relative
            # If path is a directory, look for index file
            if file_path.is_dir():
                file_path = file_path / index_file

        try:
            resolved_root = root.resolve()
            resolved_file = file_path.resolve()
        except (OSError, RuntimeError, ValueError):
            # ValueError: embedded NUL in the path; RuntimeError: symlink loop.
            return _not_found()

        try:
            # Component-wise containment check. A string prefix test would
            # accept a sibling such as "/srv/www-private" for root "/srv/www".
            resolved_file.relative_to(resolved_root)
        except ValueError:
            return PlainTextResponse("403 Forbidden", status_code=403)

        if not resolved_file.is_file():
            return _not_found()

        content_type, _ = mimetypes.guess_type(str(resolved_file))
        response = FileResponse(str(resolved_file), media_type=content_type)

        for line in config.get("headers", ()):
            parsed = _split_header_line(line)
            if parsed is not None:
                name, value = parsed
                response.headers[name] = value

        if "cache_control" in config:
            response.headers["Cache-Control"] = config["cache_control"]

        return response

    async def _handle_spa_fallback(self, request: Request, config: Dict[str, Any]) -> Response:
        """Serve the SPA index file unless the path starts with an excluded prefix.

        The index is looked up under ``config["root"]`` (falling back to
        ``self.static_dir``). Returns 404 for excluded paths or a missing index.
        """
        path = request.url.path

        if any(path.startswith(prefix) for prefix in config.get("exclude_patterns", [])):
            return _not_found()

        root = Path(config.get("root", self.static_dir))
        index_path = root / config.get("index_file", "index.html")

        if index_path.exists() and index_path.is_file():
            return FileResponse(str(index_path), media_type="text/html")

        return _not_found()

    async def _handle_proxy(self, request: Request, config: Dict[str, Any], params: Dict[str, str]) -> Response:
        """Forward the request to ``config["proxy_pass"]`` and relay the reply.

        ``{name}`` placeholders in the proxy URL and in configured header values
        are replaced from ``params``; ``$remote_addr`` in header values is
        replaced with the client address. When the proxy URL has no path beyond
        ``/``, the original request path is appended to its scheme and netloc.

        Returns 502 on connection or other errors and 504 on timeout.
        """
        proxy_url = config["proxy_pass"]
        for key, value in params.items():
            proxy_url = proxy_url.replace(f"{{{key}}}", value)

        parsed_proxy = urlparse(proxy_url)
        if parsed_proxy.path not in ("", "/"):
            target_url = proxy_url
        else:
            target_url = f"{parsed_proxy.scheme}://{parsed_proxy.netloc}{request.url.path}"

        if request.url.query:
            separator = "&" if "?" in target_url else "?"
            target_url = f"{target_url}{separator}{request.url.query}"

        logger.info(f"Proxying request to: {target_url}")

        proxy_headers = dict(request.headers)
        for name in _STRIPPED_PROXY_HEADERS:
            proxy_headers.pop(name, None)

        client_ip = request.client.host if request.client else "unknown"
        # _set_header drops any client-supplied variant of the same header so
        # the upstream sees only the values computed here.
        _set_header(proxy_headers, "X-Forwarded-For", client_ip)
        _set_header(proxy_headers, "X-Forwarded-Proto", request.url.scheme)
        _set_header(proxy_headers, "X-Forwarded-Host", request.headers.get("host", ""))
        _set_header(proxy_headers, "X-Real-IP", client_ip)
        _set_header(proxy_headers, "Host", parsed_proxy.netloc)

        for line in config.get("headers", ()):
            parsed = _split_header_line(line)
            if parsed is None:
                continue
            name, value = parsed
            for key, param_value in params.items():
                value = value.replace(f"{{{key}}}", param_value)
            value = value.replace("$remote_addr", client_ip)
            _set_header(proxy_headers, name, value)

        body = await request.body()
        timeout = config.get("timeout", self.default_proxy_timeout)

        try:
            async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
                async with client.stream(
                    method=request.method,
                    url=target_url,
                    headers=proxy_headers,
                    content=body if body else None,
                ) as upstream:
                    # Relay the bytes exactly as sent by the upstream. Reading
                    # the decoded body while still forwarding the upstream's
                    # Content-Encoding / Content-Length would give the client
                    # headers that do not describe the body.
                    raw_body = b"".join([chunk async for chunk in upstream.aiter_raw()])
                    status_code = upstream.status_code
                    content_type = upstream.headers.get("content-type")
                    # NOTE(review): a dict keeps one entry per header name, so
                    # repeated upstream headers (e.g. Set-Cookie) are collapsed
                    # into a single entry here.
                    response_headers = dict(upstream.headers)

            for name in _STRIPPED_PROXY_HEADERS:
                response_headers.pop(name, None)

            return Response(
                content=raw_body,
                status_code=status_code,
                headers=response_headers,
                media_type=content_type,
            )

        except httpx.ConnectError as e:
            logger.error(f"Proxy connection error to {target_url}: {e}")
            return PlainTextResponse("502 Bad Gateway", status_code=502)
        except httpx.TimeoutException as e:
            logger.error(f"Proxy timeout to {target_url}: {e}")
            return PlainTextResponse("504 Gateway Timeout", status_code=504)
        except Exception as e:
            logger.error(f"Proxy error to {target_url}: {e}")
            return PlainTextResponse("502 Bad Gateway", status_code=502)


def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router:
    """Build a ``Router`` with one route per ``pattern -> config`` entry.

    Routes are added in the mapping's iteration order.
    """
    router = Router()
    for pattern, config in regex_locations.items():
        router.add_route(pattern, config)
    return router