diff --git a/poetry.lock b/poetry.lock index 848b0b2..37ae383 100644 --- a/poetry.lock +++ b/poetry.lock @@ -65,6 +65,18 @@ d = ["aiohttp (>=3.10)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "certifi" +version = "2025.11.12" +description = "Python package for providing Mozilla's CA Bundle." +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b"}, + {file = "certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316"}, +] + [[package]] name = "click" version = "8.2.1" @@ -223,6 +235,28 @@ files = [ {file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"}, ] +[[package]] +name = "httpcore" +version = "1.0.9" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"}, + {file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.16" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<1.0)"] + [[package]] name = "httptools" version = "0.6.4" @@ -279,6 +313,32 @@ files = [ [package.extras] test = ["Cython (>=0.29.24)"] +[[package]] +name = "httpx" +version = "0.27.2" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, + {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] + [[package]] name = "idna" version = "3.10" @@ -524,6 +584,26 @@ pygments = ">=2.7.2" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.10" +groups = ["main", "dev"] +files = [ + {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, + {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, +] + +[package.dependencies] +pytest = ">=8.2,<10" +typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "6.2.1" @@ -967,9 +1047,9 @@ files = [ ] [extras] -dev = ["black", "flake8", "isort", "mypy", "pytest", "pytest-cov"] +dev = ["black", "flake8", "isort", "mypy", "pytest", "pytest-asyncio", "pytest-cov"] [metadata] lock-version = "2.1" python-versions = ">=3.12" -content-hash = "5eda39db8e3d119d03c8e6083d1f9cd14691669a7130fb17b1445a0dd7bb79e7" +content-hash = "e68108657ddfdc07ac0c4f5dbd9c5d2950e78b8b0053e4487ebf2327bbf4e020" diff --git a/pyproject.toml b/pyproject.toml index 8aa97e1..d3a7526 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "pyyaml (>=6.0,<7.0)", "types-pyyaml (>=6.0.12.20250822,<7.0.0.0)", "structlog (>=25.4.0,<26.0.0)", + "httpx (>=0.27.0,<0.28.0)", ] [project.scripts] @@ -23,6 +24,7 @@ pyserve = "pyserve.cli:main" dev = [ "pytest", "pytest-cov", + "pytest-asyncio", "black", "isort", "mypy", @@ -73,4 +75,5 @@ black = "^25.1.0" isort = "^6.0.1" mypy = "^1.17.1" flake8 = "^7.3.0" +pytest-asyncio = "^1.3.0" diff --git a/pyserve/config.py b/pyserve/config.py index f81963c..19bd34c 100644 --- a/pyserve/config.py +++ b/pyserve/config.py @@ -18,6 +18,7 @@ class ServerConfig: port: int = 8080 backlog: int = 5 default_root: bool = False + proxy_timeout: float = 30.0 redirect_instructions: Dict[str, str] = field(default_factory=dict) @@ -112,6 +113,7 @@ class Config: port=server_data.get('port', config.server.port), backlog=server_data.get('backlog', config.server.backlog), default_root=server_data.get('default_root', config.server.default_root), + proxy_timeout=server_data.get('proxy_timeout', config.server.proxy_timeout), redirect_instructions=server_data.get('redirect_instructions', {}) ) diff --git a/pyserve/extensions.py b/pyserve/extensions.py index c2ade69..b5d379b 100644 --- a/pyserve/extensions.py +++ b/pyserve/extensions.py @@ -33,9 +33,10 @@ class RoutingExtension(Extension): from .routing import create_router_from_config regex_locations = config.get("regex_locations", {}) + default_proxy_timeout = config.get("default_proxy_timeout", 30.0) self.router = create_router_from_config(regex_locations) from .routing import RequestHandler - self.handler = RequestHandler(self.router) + self.handler = RequestHandler(self.router, default_proxy_timeout=default_proxy_timeout) async def process_request(self, request: Request) -> Optional[Response]: try: diff --git a/pyserve/routing.py b/pyserve/routing.py index 3eadb41..df22085 100644 --- a/pyserve/routing.py +++ b/pyserve/routing.py @@ -2,6 +2,8 @@ import re import mimetypes from pathlib import Path from typing import Dict, Any, Optional, Pattern +from urllib.parse import urlparse +import httpx from starlette.requests import Request from starlette.responses import Response, FileResponse, PlainTextResponse from .logging_utils import get_logger @@ -63,9 +65,10 @@ class Router: class RequestHandler: - def __init__(self, router: Router, static_dir: str = "./static"): + def __init__(self, router: Router, static_dir: str = "./static", default_proxy_timeout: float = 30.0): self.router = router self.static_dir = Path(static_dir) + self.default_proxy_timeout = default_proxy_timeout async def handle(self, request: Request) -> Response: path = request.url.path @@ -166,15 +169,89 @@ class RequestHandler: async def _handle_proxy(self, request: Request, config: Dict[str, Any], params: Dict[str, str]) -> Response: - # TODO: implement real proxying proxy_url = config["proxy_pass"] for key, value in params.items(): proxy_url = proxy_url.replace(f"{{{key}}}", value) - logger.info(f"Proxying request to: {proxy_url}") + parsed_proxy = urlparse(proxy_url) - return PlainTextResponse(f"Proxy to: {proxy_url}", status_code=200) + original_path = request.url.path + + if parsed_proxy.path and parsed_proxy.path not in ("/", ""): + target_url = proxy_url + else: + base_url = f"{parsed_proxy.scheme}://{parsed_proxy.netloc}" + target_url = f"{base_url}{original_path}" + + if request.url.query: + separator = "&" if "?" in target_url else "?" + target_url = f"{target_url}{separator}{request.url.query}" + + logger.info(f"Proxying request to: {target_url}") + + proxy_headers = dict(request.headers) + + hop_by_hop_headers = [ + "connection", "keep-alive", "proxy-authenticate", + "proxy-authorization", "te", "trailers", "transfer-encoding", + "upgrade", "host" + ] + for header in hop_by_hop_headers: + proxy_headers.pop(header, None) + + client_ip = request.client.host if request.client else "unknown" + proxy_headers["X-Forwarded-For"] = client_ip + proxy_headers["X-Forwarded-Proto"] = request.url.scheme + proxy_headers["X-Forwarded-Host"] = request.headers.get("host", "") + proxy_headers["X-Real-IP"] = client_ip + + proxy_headers["Host"] = parsed_proxy.netloc + + if "headers" in config: + for header in config["headers"]: + if ":" in header: + name, value = header.split(":", 1) + value = value.strip() + for key, param_value in params.items(): + value = value.replace(f"{{{key}}}", param_value) + value = value.replace("$remote_addr", client_ip) + proxy_headers[name.strip()] = value + + body = await request.body() + + timeout = config.get("timeout", self.default_proxy_timeout) + + try: + async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client: + proxy_response = await client.request( + method=request.method, + url=target_url, + headers=proxy_headers, + content=body if body else None, + ) + + response_headers = dict(proxy_response.headers) + + for header in hop_by_hop_headers: + response_headers.pop(header, None) + + return Response( + content=proxy_response.content, + status_code=proxy_response.status_code, + headers=response_headers, + media_type=proxy_response.headers.get("content-type"), + ) + + except httpx.ConnectError as e: + logger.error(f"Proxy connection error to {target_url}: {e}") + return PlainTextResponse("502 Bad Gateway", status_code=502) + except httpx.TimeoutException as e: + logger.error(f"Proxy timeout to {target_url}: {e}") + return PlainTextResponse("504 Gateway Timeout", status_code=504) + except Exception as e: + logger.error(f"Proxy error to {target_url}: {e}") + return PlainTextResponse("502 Bad Gateway", status_code=502) def create_router_from_config(regex_locations: Dict[str, Dict[str, Any]]) -> Router: diff --git a/pyserve/server.py b/pyserve/server.py index d8696a9..294efe0 100644 --- a/pyserve/server.py +++ b/pyserve/server.py @@ -76,9 +76,13 @@ class PyServeServer: def _load_extensions(self) -> None: for ext_config in self.config.extensions: + config = ext_config.config.copy() + if ext_config.type == "routing": + config.setdefault("default_proxy_timeout", self.config.server.proxy_timeout) + self.extension_manager.load_extension( ext_config.type, - ext_config.config + config ) def _create_app(self) -> None: diff --git a/tests/test_reverse_proxy.py b/tests/test_reverse_proxy.py new file mode 100644 index 0000000..30d53aa --- /dev/null +++ b/tests/test_reverse_proxy.py @@ -0,0 +1,724 @@ +""" +Tests for reverse proxy functionality. + +These tests start a backend test server and the main PyServe server, +then verify that requests are correctly proxied between them. +""" + +import asyncio +import pytest +import httpx +import socket +import time +from contextlib import asynccontextmanager +from typing import AsyncGenerator, Dict, Any, Optional +from unittest.mock import patch + +import uvicorn +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse, Response +from starlette.routing import Route + +from pyserve.config import Config, ServerConfig, HttpConfig, LoggingConfig, ExtensionConfig +from pyserve.server import PyServeServer + + +def get_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +# ============== Backend Test Application ============== + +class BackendTestApp: + def __init__(self, port: int): + self.port = port + self.request_log: list[Dict[str, Any]] = [] + self.app = self._create_app() + self._server_task = None + + def _create_app(self) -> Starlette: + routes = [ + Route("/", self._root, methods=["GET"]), + Route("/api/v{version}/users", self._users, methods=["GET", "POST"]), + Route("/api/v{version}/users/{user_id}", self._user_detail, methods=["GET", "PUT", "DELETE"]), + Route("/api/v{version}/data", self._data, methods=["GET", "POST"]), + Route("/echo", self._echo, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]), + Route("/headers", self._headers, methods=["GET"]), + Route("/slow", self._slow, methods=["GET"]), + Route("/status/{code:int}", self._status, methods=["GET"]), + Route("/json", self._json, methods=["POST"]), + Route("/backend2/", self._backend2_root, methods=["GET"]), + Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]), + ] + return Starlette(routes=routes) + + async def _root(self, request: Request) -> Response: + self._log_request(request) + return JSONResponse({"message": "Backend root", "port": self.port}) + + async def _users(self, request: Request) -> Response: + self._log_request(request) + version = request.path_params.get("version", "1") + if request.method == "POST": + body = await request.json() + return JSONResponse({"action": "create_user", "version": version, "data": body}, status_code=201) + return JSONResponse({"users": [{"id": 1, "name": "Test User"}], "version": version}) + + async def _user_detail(self, request: Request) -> Response: + self._log_request(request) + version = request.path_params.get("version", "1") + user_id = request.path_params.get("user_id", "0") + if request.method == "DELETE": + return JSONResponse({"action": "delete_user", "user_id": user_id, "version": version}) + if request.method == "PUT": + body = await request.json() + return JSONResponse({"action": "update_user", "user_id": user_id, "version": version, "data": body}) + return JSONResponse({"user": {"id": user_id, "name": f"User {user_id}"}, "version": version}) + + async def _data(self, request: Request) -> Response: + self._log_request(request) + version = request.path_params.get("version", "1") + return JSONResponse({"data": "test data", "version": version}) + + async def _echo(self, request: Request) -> Response: + self._log_request(request) + body = await request.body() + return Response( + content=body, + status_code=200, + media_type=request.headers.get("content-type", "text/plain") + ) + + async def _headers(self, request: Request) -> Response: + self._log_request(request) + headers = dict(request.headers) + return JSONResponse({ + "received_headers": headers, + "client_ip": request.client.host if request.client else None, + }) + + async def _slow(self, request: Request) -> Response: + self._log_request(request) + delay = float(request.query_params.get("delay", "2")) + await asyncio.sleep(delay) + return JSONResponse({"message": "slow response", "delay": delay}) + + async def _status(self, request: Request) -> Response: + self._log_request(request) + code = request.path_params.get("code", 200) + return PlainTextResponse(f"Status: {code}", status_code=code) + + async def _json(self, request: Request) -> Response: + self._log_request(request) + body = await request.json() + return JSONResponse({"received": body, "processed": True}) + + async def _backend2_root(self, request: Request) -> Response: + self._log_request(request) + return JSONResponse({"message": "Backend2 root", "port": self.port}) + + async def _catch_all(self, request: Request) -> Response: + """Catch-all handler for debugging unmatched routes.""" + self._log_request(request) + return JSONResponse({ + "message": "Catch-all", + "path": str(request.url.path), + "method": request.method, + "port": self.port + }) + + def _log_request(self, request: Request) -> None: + self.request_log.append({ + "method": request.method, + "path": str(request.url.path), + "query": str(request.url.query), + "headers": dict(request.headers), + }) + + async def start(self) -> None: + config = uvicorn.Config( + app=self.app, + host="127.0.0.1", + port=self.port, + log_level="critical", + access_log=False, + ) + server = uvicorn.Server(config) + self._server_task = asyncio.create_task(server.serve()) + + # Wait for server to be ready + for _ in range(50): # 5 seconds max + try: + async with httpx.AsyncClient() as client: + await client.get(f"http://127.0.0.1:{self.port}/") + return + except httpx.ConnectError: + await asyncio.sleep(0.1) + raise RuntimeError(f"Backend server failed to start on port {self.port}") + + async def stop(self) -> None: + if self._server_task: + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass + + +# ============== PyServe Test Server ============== + +class PyServeTestServer: + def __init__(self, config: Config): + self.config = config + self.server = PyServeServer(config) + self._server_task = None + + async def start(self) -> None: + assert self.server.app is not None, "Server app not initialized" + config = uvicorn.Config( + app=self.server.app, + host=self.config.server.host, + port=self.config.server.port, + log_level="critical", + access_log=False, + ) + server = uvicorn.Server(config) + self._server_task = asyncio.create_task(server.serve()) + + # Wait for server to be ready + for _ in range(50): # 5 seconds max + try: + async with httpx.AsyncClient() as client: + await client.get(f"http://127.0.0.1:{self.config.server.port}/health") + return + except httpx.ConnectError: + await asyncio.sleep(0.1) + raise RuntimeError(f"PyServe server failed to start on port {self.config.server.port}") + + async def stop(self) -> None: + if self._server_task: + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass + + +# ============== Fixtures ============== + +@pytest.fixture +def backend_port() -> int: + return get_free_port() + + +@pytest.fixture +def pyserve_port() -> int: + return get_free_port() + + +@pytest.fixture +def create_proxy_config(): + def _create_config(pyserve_port: int, backend_port: int, extra_locations: Optional[Dict[str, Any]] = None) -> Config: + locations = { + # API proxy with version capture + "~^/api/v(?P\\d+)/": { + "proxy_pass": f"http://127.0.0.1:{backend_port}", + "headers": [ + "X-API-Version: {version}", + "X-Forwarded-For: $remote_addr" + ] + }, + # Echo endpoint proxy + "=/echo": { + "proxy_pass": f"http://127.0.0.1:{backend_port}/echo", + }, + # Headers test proxy + "=/headers": { + "proxy_pass": f"http://127.0.0.1:{backend_port}/headers", + }, + # Slow endpoint proxy with timeout + "~^/slow": { + "proxy_pass": f"http://127.0.0.1:{backend_port}/slow", + "timeout": 5.0, + }, + # Status endpoint proxy + "~^/status/": { + "proxy_pass": f"http://127.0.0.1:{backend_port}", + }, + # JSON endpoint proxy + "=/json": { + "proxy_pass": f"http://127.0.0.1:{backend_port}/json", + }, + # Health check + "=/health": { + "return": "200 OK", + "content_type": "text/plain" + }, + # Default fallback + "__default__": { + "return": "404 Not Found", + "content_type": "text/plain" + } + } + + if extra_locations: + locations.update(extra_locations) + + config = Config( + http=HttpConfig(static_dir="./static", templates_dir="./templates"), + server=ServerConfig(host="127.0.0.1", port=pyserve_port), + logging=LoggingConfig(level="ERROR", console_output=False), + extensions=[ + ExtensionConfig( + type="routing", + config={"regex_locations": locations} + ) + ] + ) + return config + + return _create_config + + +@asynccontextmanager +async def running_servers( + pyserve_port: int, + backend_port: int, + config_factory, + extra_locations: Optional[Dict[str, Any]] = None +) -> AsyncGenerator[tuple[PyServeTestServer, BackendTestApp], None]: + backend = BackendTestApp(backend_port) + config = config_factory(pyserve_port, backend_port, extra_locations) + pyserve = PyServeTestServer(config) + + await backend.start() + await pyserve.start() + + try: + yield pyserve, backend + finally: + await pyserve.stop() + await backend.stop() + + +# ============== Tests ============== + +@pytest.mark.asyncio +async def test_basic_proxy_get(backend_port, pyserve_port, create_proxy_config): + """Test basic GET request proxying.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + response = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/users") + + assert response.status_code == 200 + data = response.json() + assert "users" in data + assert data["version"] == "1" + + +@pytest.mark.asyncio +async def test_proxy_api_versions(backend_port, pyserve_port, create_proxy_config): + """Test that API version is correctly captured and passed.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + # Test v1 + response_v1 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/data") + assert response_v1.status_code == 200 + assert response_v1.json()["version"] == "1" + + # Test v2 + response_v2 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v2/data") + assert response_v2.status_code == 200 + assert response_v2.json()["version"] == "2" + + +@pytest.mark.asyncio +async def test_proxy_post_with_body(backend_port, pyserve_port, create_proxy_config): + """Test POST request with JSON body.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + payload = {"name": "New User", "email": "test@example.com"} + response = await client.post( + f"http://127.0.0.1:{pyserve_port}/api/v1/users", + json=payload + ) + + assert response.status_code == 201 + data = response.json() + assert data["action"] == "create_user" + assert data["data"]["name"] == "New User" + + +@pytest.mark.asyncio +async def test_proxy_put_request(backend_port, pyserve_port, create_proxy_config): + """Test PUT request proxying.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + payload = {"name": "Updated User"} + response = await client.put( + f"http://127.0.0.1:{pyserve_port}/api/v1/users/123", + json=payload + ) + + assert response.status_code == 200 + data = response.json() + assert data["action"] == "update_user" + assert data["user_id"] == "123" + + +@pytest.mark.asyncio +async def test_proxy_delete_request(backend_port, pyserve_port, create_proxy_config): + """Test DELETE request proxying.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + response = await client.delete( + f"http://127.0.0.1:{pyserve_port}/api/v1/users/456" + ) + + assert response.status_code == 200 + data = response.json() + assert data["action"] == "delete_user" + assert data["user_id"] == "456" + + +@pytest.mark.asyncio +async def test_proxy_headers_forwarding(backend_port, pyserve_port, create_proxy_config): + """Test that headers are correctly forwarded.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{pyserve_port}/headers", + headers={"X-Custom-Header": "test-value"} + ) + + assert response.status_code == 200 + data = response.json() + + # Check that custom header was forwarded + assert data["received_headers"].get("x-custom-header") == "test-value" + + # Check that X-Forwarded headers were added + assert "x-forwarded-for" in data["received_headers"] + assert "x-forwarded-proto" in data["received_headers"] + + +@pytest.mark.asyncio +async def test_proxy_echo_endpoint(backend_port, pyserve_port, create_proxy_config): + """Test echo endpoint that returns request body.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + test_data = "Hello, Proxy!" + response = await client.post( + f"http://127.0.0.1:{pyserve_port}/echo", + content=test_data, + headers={"Content-Type": "text/plain"} + ) + + assert response.status_code == 200 + assert response.text == test_data + + +@pytest.mark.asyncio +async def test_proxy_json_endpoint(backend_port, pyserve_port, create_proxy_config): + """Test JSON processing through proxy.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + payload = {"key": "value", "numbers": [1, 2, 3]} + response = await client.post( + f"http://127.0.0.1:{pyserve_port}/json", + json=payload + ) + + assert response.status_code == 200 + data = response.json() + assert data["processed"] is True + assert data["received"]["key"] == "value" + + +@pytest.mark.asyncio +async def test_proxy_status_codes(backend_port, pyserve_port, create_proxy_config): + """Test that various status codes are correctly proxied.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + for status in [200, 201, 400, 404, 500]: + response = await client.get( + f"http://127.0.0.1:{pyserve_port}/status/{status}" + ) + assert response.status_code == status + + +@pytest.mark.asyncio +async def test_proxy_query_params(backend_port, pyserve_port, create_proxy_config): + """Test that query parameters are passed through.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{pyserve_port}/slow?delay=0.1" + ) + + assert response.status_code == 200 + data = response.json() + assert data["delay"] == 0.1 + + +@pytest.mark.asyncio +async def test_proxy_backend_unavailable(pyserve_port, create_proxy_config): + """Test handling when backend is unavailable (502 Bad Gateway).""" + # Use a port where nothing is listening + unavailable_port = get_free_port() + + config = create_proxy_config(pyserve_port, unavailable_port) + pyserve = PyServeTestServer(config) + + await pyserve.start() + try: + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{pyserve_port}/api/v1/users" + ) + + assert response.status_code == 502 + assert "Bad Gateway" in response.text + finally: + await pyserve.stop() + + +@pytest.mark.asyncio +async def test_proxy_timeout(backend_port, pyserve_port, create_proxy_config): + """Test proxy timeout handling (504 Gateway Timeout).""" + # Create config with very short timeout + extra_locations = { + "=/timeout-test": { + "proxy_pass": f"http://127.0.0.1:{backend_port}/slow?delay=5", + "timeout": 0.5, # Very short timeout + } + } + + async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations): + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get( + f"http://127.0.0.1:{pyserve_port}/timeout-test" + ) + + assert response.status_code == 504 + assert "Gateway Timeout" in response.text + + +@pytest.mark.asyncio +async def test_proxy_health_check_not_proxied(backend_port, pyserve_port, create_proxy_config): + """Test that health check endpoint is handled locally, not proxied.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{pyserve_port}/health" + ) + + assert response.status_code == 200 + assert response.text == "OK" + + +@pytest.mark.asyncio +async def test_proxy_default_fallback(backend_port, pyserve_port, create_proxy_config): + """Test default fallback for unmatched routes.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{pyserve_port}/nonexistent/path" + ) + + assert response.status_code == 404 + assert "Not Found" in response.text + + +@pytest.mark.asyncio +async def test_complex_config_multiple_proxies(backend_port, pyserve_port, create_proxy_config): + """Test complex configuration with multiple proxy rules.""" + # Create a second backend port for testing multiple backends + backend2_port = get_free_port() + backend2 = BackendTestApp(backend2_port) + + extra_locations = { + "~^/backend2/": { + "proxy_pass": f"http://127.0.0.1:{backend2_port}", + } + } + + async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations): + await backend2.start() + try: + async with httpx.AsyncClient() as client: + # Request to first backend + response1 = await client.get( + f"http://127.0.0.1:{pyserve_port}/api/v1/users" + ) + assert response1.status_code == 200 + assert response1.json()["users"][0]["id"] == 1 + + # Request to second backend + response2 = await client.get( + f"http://127.0.0.1:{pyserve_port}/backend2/" + ) + assert response2.status_code == 200 + data2 = response2.json() + assert data2["port"] == backend2_port + finally: + await backend2.stop() + + +@pytest.mark.asyncio +async def test_proxy_large_request_body(backend_port, pyserve_port, create_proxy_config): + """Test proxying large request bodies.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + # Create a large payload + large_data = "x" * 100000 # 100KB of data + response = await client.post( + f"http://127.0.0.1:{pyserve_port}/echo", + content=large_data, + headers={"Content-Type": "text/plain"} + ) + + assert response.status_code == 200 + assert response.text == large_data + + +@pytest.mark.asyncio +async def test_proxy_content_type_preservation(backend_port, pyserve_port, create_proxy_config): + """Test that content-type is correctly preserved.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + response = await client.get( + f"http://127.0.0.1:{pyserve_port}/api/v1/users" + ) + + assert response.status_code == 200 + assert "application/json" in response.headers.get("content-type", "") + + +@pytest.mark.asyncio +async def test_proxy_concurrent_requests(backend_port, pyserve_port, create_proxy_config): + """Test handling multiple concurrent proxy requests.""" + async with running_servers(pyserve_port, backend_port, create_proxy_config): + async with httpx.AsyncClient() as client: + # Send multiple concurrent requests + tasks = [ + client.get(f"http://127.0.0.1:{pyserve_port}/api/v{i % 3 + 1}/users") + for i in range(10) + ] + responses = await asyncio.gather(*tasks) + + # All should succeed + for response in responses: + assert response.status_code == 200 + assert "users" in response.json() + + +@pytest.mark.asyncio +async def test_server_default_proxy_timeout(backend_port, pyserve_port): + """Test that server-level proxy_timeout is used as default.""" + locations = { + "~^/slow": { + "proxy_pass": f"http://127.0.0.1:{backend_port}/slow", + # No timeout specified - should use server default + }, + "=/health": { + "return": "200 OK", + "content_type": "text/plain" + }, + } + + config = Config( + http=HttpConfig(static_dir="./static", templates_dir="./templates"), + server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.5), # Very short timeout + logging=LoggingConfig(level="ERROR", console_output=False), + extensions=[ + ExtensionConfig( + type="routing", + config={"regex_locations": locations} + ) + ] + ) + + backend = BackendTestApp(backend_port) + pyserve = PyServeTestServer(config) + + await backend.start() + await pyserve.start() + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # This should timeout because server default is 0.5s and slow endpoint takes 2s + response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=2") + assert response.status_code == 504 + assert "Gateway Timeout" in response.text + finally: + await pyserve.stop() + await backend.stop() + + +@pytest.mark.asyncio +async def test_route_timeout_overrides_server_default(backend_port, pyserve_port): + """Test that route-level timeout overrides server default proxy_timeout.""" + locations = { + "~^/slow": { + "proxy_pass": f"http://127.0.0.1:{backend_port}/slow", + "timeout": 5.0, # Route-level timeout overrides server default + }, + "=/health": { + "return": "200 OK", + "content_type": "text/plain" + }, + } + + config = Config( + http=HttpConfig(static_dir="./static", templates_dir="./templates"), + server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.1), # Very short server default + logging=LoggingConfig(level="ERROR", console_output=False), + extensions=[ + ExtensionConfig( + type="routing", + config={"regex_locations": locations} + ) + ] + ) + + backend = BackendTestApp(backend_port) + pyserve = PyServeTestServer(config) + + await backend.start() + await pyserve.start() + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # This should succeed because route timeout (5s) > delay (0.5s), even though server default is 0.1s + response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=0.5") + assert response.status_code == 200 + data = response.json() + assert data["delay"] == 0.5 + finally: + await pyserve.stop() + await backend.stop() + + +@pytest.mark.asyncio +async def test_proxy_timeout_config_parsing(): + """Test that proxy_timeout is correctly parsed from config.""" + from pyserve.config import Config + + # Test default value + config = Config() + assert config.server.proxy_timeout == 30.0 + + # Test custom value from dict + config_with_timeout = Config._from_dict({ + "server": { + "host": "127.0.0.1", + "port": 8080, + "proxy_timeout": 60.0 + } + }) + assert config_with_timeout.server.proxy_timeout == 60.0