# konduktor/tests/test_reverse_proxy.py
"""
Tests for reverse proxy functionality.
These tests start a backend test server and the main PyServe server,
then verify that requests are correctly proxied between them.
"""
import asyncio
import pytest
import httpx
import socket
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Any, Optional
from unittest.mock import patch
import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Route
from pyserve.config import Config, ServerConfig, HttpConfig, LoggingConfig, ExtensionConfig
from pyserve.server import PyServeServer
def get_free_port() -> int:
    """Return a currently free TCP port (racy by nature, but good enough for tests)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return s.getsockname()[1]
# ============== Backend Test Application ==============
class BackendTestApp:
def __init__(self, port: int):
self.port = port
self.request_log: list[Dict[str, Any]] = []
self.app = self._create_app()
self._server_task = None
def _create_app(self) -> Starlette:
routes = [
Route("/", self._root, methods=["GET"]),
Route("/api/v{version}/users", self._users, methods=["GET", "POST"]),
Route("/api/v{version}/users/{user_id}", self._user_detail, methods=["GET", "PUT", "DELETE"]),
Route("/api/v{version}/data", self._data, methods=["GET", "POST"]),
Route("/echo", self._echo, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
Route("/headers", self._headers, methods=["GET"]),
Route("/slow", self._slow, methods=["GET"]),
Route("/status/{code:int}", self._status, methods=["GET"]),
Route("/json", self._json, methods=["POST"]),
Route("/backend2/", self._backend2_root, methods=["GET"]),
Route("/{path:path}", self._catch_all, methods=["GET", "POST", "PUT", "DELETE", "PATCH"]),
]
return Starlette(routes=routes)
async def _root(self, request: Request) -> Response:
self._log_request(request)
return JSONResponse({"message": "Backend root", "port": self.port})
async def _users(self, request: Request) -> Response:
self._log_request(request)
version = request.path_params.get("version", "1")
if request.method == "POST":
body = await request.json()
return JSONResponse({"action": "create_user", "version": version, "data": body}, status_code=201)
return JSONResponse({"users": [{"id": 1, "name": "Test User"}], "version": version})
async def _user_detail(self, request: Request) -> Response:
self._log_request(request)
version = request.path_params.get("version", "1")
user_id = request.path_params.get("user_id", "0")
if request.method == "DELETE":
return JSONResponse({"action": "delete_user", "user_id": user_id, "version": version})
if request.method == "PUT":
body = await request.json()
return JSONResponse({"action": "update_user", "user_id": user_id, "version": version, "data": body})
return JSONResponse({"user": {"id": user_id, "name": f"User {user_id}"}, "version": version})
async def _data(self, request: Request) -> Response:
self._log_request(request)
version = request.path_params.get("version", "1")
return JSONResponse({"data": "test data", "version": version})
async def _echo(self, request: Request) -> Response:
self._log_request(request)
body = await request.body()
return Response(
content=body,
status_code=200,
media_type=request.headers.get("content-type", "text/plain")
)
async def _headers(self, request: Request) -> Response:
self._log_request(request)
headers = dict(request.headers)
return JSONResponse({
"received_headers": headers,
"client_ip": request.client.host if request.client else None,
})
async def _slow(self, request: Request) -> Response:
self._log_request(request)
delay = float(request.query_params.get("delay", "2"))
await asyncio.sleep(delay)
return JSONResponse({"message": "slow response", "delay": delay})
async def _status(self, request: Request) -> Response:
self._log_request(request)
code = request.path_params.get("code", 200)
return PlainTextResponse(f"Status: {code}", status_code=code)
async def _json(self, request: Request) -> Response:
self._log_request(request)
body = await request.json()
return JSONResponse({"received": body, "processed": True})
async def _backend2_root(self, request: Request) -> Response:
self._log_request(request)
return JSONResponse({"message": "Backend2 root", "port": self.port})
async def _catch_all(self, request: Request) -> Response:
"""Catch-all handler for debugging unmatched routes."""
self._log_request(request)
return JSONResponse({
"message": "Catch-all",
"path": str(request.url.path),
"method": request.method,
"port": self.port
})
def _log_request(self, request: Request) -> None:
self.request_log.append({
"method": request.method,
"path": str(request.url.path),
"query": str(request.url.query),
"headers": dict(request.headers),
})
async def start(self) -> None:
config = uvicorn.Config(
app=self.app,
host="127.0.0.1",
port=self.port,
log_level="critical",
access_log=False,
)
server = uvicorn.Server(config)
self._server_task = asyncio.create_task(server.serve())
# Wait for server to be ready
for _ in range(50): # 5 seconds max
try:
async with httpx.AsyncClient() as client:
await client.get(f"http://127.0.0.1:{self.port}/")
return
except httpx.ConnectError:
await asyncio.sleep(0.1)
raise RuntimeError(f"Backend server failed to start on port {self.port}")
async def stop(self) -> None:
if self._server_task:
self._server_task.cancel()
try:
await self._server_task
except asyncio.CancelledError:
pass
# ============== PyServe Test Server ==============
class PyServeTestServer:
def __init__(self, config: Config):
self.config = config
self.server = PyServeServer(config)
self._server_task = None
async def start(self) -> None:
assert self.server.app is not None, "Server app not initialized"
config = uvicorn.Config(
app=self.server.app,
host=self.config.server.host,
port=self.config.server.port,
log_level="critical",
access_log=False,
)
server = uvicorn.Server(config)
self._server_task = asyncio.create_task(server.serve())
# Wait for server to be ready
for _ in range(50): # 5 seconds max
try:
async with httpx.AsyncClient() as client:
await client.get(f"http://127.0.0.1:{self.config.server.port}/health")
return
except httpx.ConnectError:
await asyncio.sleep(0.1)
raise RuntimeError(f"PyServe server failed to start on port {self.config.server.port}")
async def stop(self) -> None:
if self._server_task:
self._server_task.cancel()
try:
await self._server_task
except asyncio.CancelledError:
pass
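# Both start() methods above wait for readiness by polling the server with HTTP
# GETs. A minimal alternative sketch that checks uvicorn's own readiness flag
# instead (uvicorn.Server exposes a `started` attribute once startup completes);
# the helper name `wait_until_started` is ours, not part of PyServe:
async def wait_until_started(server: uvicorn.Server, timeout: float = 5.0) -> None:
    """Wait until a uvicorn.Server launched via asyncio.create_task(server.serve()) is up."""
    deadline = time.monotonic() + timeout
    while not server.started:
        if time.monotonic() > deadline:
            raise RuntimeError("uvicorn server did not start in time")
        await asyncio.sleep(0.05)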
# ============== Fixtures ==============
@pytest.fixture
def backend_port() -> int:
return get_free_port()
@pytest.fixture
def pyserve_port() -> int:
return get_free_port()
@pytest.fixture
def create_proxy_config():
def _create_config(pyserve_port: int, backend_port: int, extra_locations: Optional[Dict[str, Any]] = None) -> Config:
locations = {
# API proxy with version capture
"~^/api/v(?P<version>\\d+)/": {
"proxy_pass": f"http://127.0.0.1:{backend_port}",
"headers": [
"X-API-Version: {version}",
"X-Forwarded-For: $remote_addr"
]
},
# Echo endpoint proxy
"=/echo": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/echo",
},
# Headers test proxy
"=/headers": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/headers",
},
# Slow endpoint proxy with timeout
"~^/slow": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow",
"timeout": 5.0,
},
# Status endpoint proxy
"~^/status/": {
"proxy_pass": f"http://127.0.0.1:{backend_port}",
},
# JSON endpoint proxy
"=/json": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/json",
},
# Health check
"=/health": {
"return": "200 OK",
"content_type": "text/plain"
},
# Default fallback
"__default__": {
"return": "404 Not Found",
"content_type": "text/plain"
}
}
if extra_locations:
locations.update(extra_locations)
config = Config(
http=HttpConfig(static_dir="./static", templates_dir="./templates"),
server=ServerConfig(host="127.0.0.1", port=pyserve_port),
logging=LoggingConfig(level="ERROR", console_output=False),
extensions=[
ExtensionConfig(
type="routing",
config={"regex_locations": locations}
)
]
)
return config
return _create_config
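# The regex location keys above use named capture groups (e.g. (?P<version>\d+))
# whose values are substituted into header templates such as "X-API-Version: {version}".
# A minimal sketch of that substitution idea using only the standard library; it
# illustrates the mechanism, not PyServe's actual implementation, and the helper
# name `render_headers_for` is ours:
def render_headers_for(path: str, pattern: str, header_templates: list[str]) -> list[str]:
    """Fill {group} placeholders in header templates from a regex match on the path."""
    import re  # local import keeps this illustrative helper self-contained
    match = re.search(pattern, path)
    if match is None:
        return []
    return [template.format(**match.groupdict()) for template in header_templates]
# e.g. render_headers_for("/api/v2/users", r"^/api/v(?P<version>\d+)/",
#                         ["X-API-Version: {version}"]) == ["X-API-Version: 2"]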
@asynccontextmanager
async def running_servers(
pyserve_port: int,
backend_port: int,
config_factory,
extra_locations: Optional[Dict[str, Any]] = None
) -> AsyncGenerator[tuple[PyServeTestServer, BackendTestApp], None]:
backend = BackendTestApp(backend_port)
config = config_factory(pyserve_port, backend_port, extra_locations)
pyserve = PyServeTestServer(config)
await backend.start()
await pyserve.start()
try:
yield pyserve, backend
finally:
await pyserve.stop()
await backend.stop()
# ============== Tests ==============
@pytest.mark.asyncio
async def test_basic_proxy_get(backend_port, pyserve_port, create_proxy_config):
"""Test basic GET request proxying."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/users")
assert response.status_code == 200
data = response.json()
assert "users" in data
assert data["version"] == "1"
@pytest.mark.asyncio
async def test_proxy_api_versions(backend_port, pyserve_port, create_proxy_config):
"""Test that API version is correctly captured and passed."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
# Test v1
response_v1 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v1/data")
assert response_v1.status_code == 200
assert response_v1.json()["version"] == "1"
# Test v2
response_v2 = await client.get(f"http://127.0.0.1:{pyserve_port}/api/v2/data")
assert response_v2.status_code == 200
assert response_v2.json()["version"] == "2"
@pytest.mark.asyncio
async def test_proxy_post_with_body(backend_port, pyserve_port, create_proxy_config):
"""Test POST request with JSON body."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
payload = {"name": "New User", "email": "test@example.com"}
response = await client.post(
f"http://127.0.0.1:{pyserve_port}/api/v1/users",
json=payload
)
assert response.status_code == 201
data = response.json()
assert data["action"] == "create_user"
assert data["data"]["name"] == "New User"
@pytest.mark.asyncio
async def test_proxy_put_request(backend_port, pyserve_port, create_proxy_config):
"""Test PUT request proxying."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
payload = {"name": "Updated User"}
response = await client.put(
f"http://127.0.0.1:{pyserve_port}/api/v1/users/123",
json=payload
)
assert response.status_code == 200
data = response.json()
assert data["action"] == "update_user"
assert data["user_id"] == "123"
@pytest.mark.asyncio
async def test_proxy_delete_request(backend_port, pyserve_port, create_proxy_config):
"""Test DELETE request proxying."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.delete(
f"http://127.0.0.1:{pyserve_port}/api/v1/users/456"
)
assert response.status_code == 200
data = response.json()
assert data["action"] == "delete_user"
assert data["user_id"] == "456"
@pytest.mark.asyncio
async def test_proxy_headers_forwarding(backend_port, pyserve_port, create_proxy_config):
"""Test that headers are correctly forwarded."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/headers",
headers={"X-Custom-Header": "test-value"}
)
assert response.status_code == 200
data = response.json()
# Check that custom header was forwarded
assert data["received_headers"].get("x-custom-header") == "test-value"
# Check that X-Forwarded headers were added
assert "x-forwarded-for" in data["received_headers"]
assert "x-forwarded-proto" in data["received_headers"]
@pytest.mark.asyncio
async def test_proxy_echo_endpoint(backend_port, pyserve_port, create_proxy_config):
"""Test echo endpoint that returns request body."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
test_data = "Hello, Proxy!"
response = await client.post(
f"http://127.0.0.1:{pyserve_port}/echo",
content=test_data,
headers={"Content-Type": "text/plain"}
)
assert response.status_code == 200
assert response.text == test_data
@pytest.mark.asyncio
async def test_proxy_json_endpoint(backend_port, pyserve_port, create_proxy_config):
"""Test JSON processing through proxy."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
payload = {"key": "value", "numbers": [1, 2, 3]}
response = await client.post(
f"http://127.0.0.1:{pyserve_port}/json",
json=payload
)
assert response.status_code == 200
data = response.json()
assert data["processed"] is True
assert data["received"]["key"] == "value"
@pytest.mark.asyncio
async def test_proxy_status_codes(backend_port, pyserve_port, create_proxy_config):
"""Test that various status codes are correctly proxied."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
for status in [200, 201, 400, 404, 500]:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/status/{status}"
)
assert response.status_code == status
@pytest.mark.asyncio
async def test_proxy_query_params(backend_port, pyserve_port, create_proxy_config):
"""Test that query parameters are passed through."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/slow?delay=0.1"
)
assert response.status_code == 200
data = response.json()
assert data["delay"] == 0.1
@pytest.mark.asyncio
async def test_proxy_backend_unavailable(pyserve_port, create_proxy_config):
"""Test handling when backend is unavailable (502 Bad Gateway)."""
# Use a port where nothing is listening
unavailable_port = get_free_port()
config = create_proxy_config(pyserve_port, unavailable_port)
pyserve = PyServeTestServer(config)
await pyserve.start()
try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/api/v1/users"
)
assert response.status_code == 502
assert "Bad Gateway" in response.text
finally:
await pyserve.stop()
@pytest.mark.asyncio
async def test_proxy_timeout(backend_port, pyserve_port, create_proxy_config):
"""Test proxy timeout handling (504 Gateway Timeout)."""
# Create config with very short timeout
extra_locations = {
"=/timeout-test": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow?delay=5",
"timeout": 0.5, # Very short timeout
}
}
async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations):
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/timeout-test"
)
assert response.status_code == 504
assert "Gateway Timeout" in response.text
@pytest.mark.asyncio
async def test_proxy_health_check_not_proxied(backend_port, pyserve_port, create_proxy_config):
"""Test that health check endpoint is handled locally, not proxied."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/health"
)
assert response.status_code == 200
assert response.text == "OK"
@pytest.mark.asyncio
async def test_proxy_default_fallback(backend_port, pyserve_port, create_proxy_config):
"""Test default fallback for unmatched routes."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/nonexistent/path"
)
assert response.status_code == 404
assert "Not Found" in response.text
@pytest.mark.asyncio
async def test_complex_config_multiple_proxies(backend_port, pyserve_port, create_proxy_config):
"""Test complex configuration with multiple proxy rules."""
# Create a second backend port for testing multiple backends
backend2_port = get_free_port()
backend2 = BackendTestApp(backend2_port)
extra_locations = {
"~^/backend2/": {
"proxy_pass": f"http://127.0.0.1:{backend2_port}",
}
}
async with running_servers(pyserve_port, backend_port, create_proxy_config, extra_locations):
await backend2.start()
try:
async with httpx.AsyncClient() as client:
# Request to first backend
response1 = await client.get(
f"http://127.0.0.1:{pyserve_port}/api/v1/users"
)
assert response1.status_code == 200
assert response1.json()["users"][0]["id"] == 1
# Request to second backend
response2 = await client.get(
f"http://127.0.0.1:{pyserve_port}/backend2/"
)
assert response2.status_code == 200
data2 = response2.json()
assert data2["port"] == backend2_port
finally:
await backend2.stop()
@pytest.mark.asyncio
async def test_proxy_large_request_body(backend_port, pyserve_port, create_proxy_config):
"""Test proxying large request bodies."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
# Create a large payload
large_data = "x" * 100000 # 100KB of data
response = await client.post(
f"http://127.0.0.1:{pyserve_port}/echo",
content=large_data,
headers={"Content-Type": "text/plain"}
)
assert response.status_code == 200
assert response.text == large_data
@pytest.mark.asyncio
async def test_proxy_content_type_preservation(backend_port, pyserve_port, create_proxy_config):
"""Test that content-type is correctly preserved."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{pyserve_port}/api/v1/users"
)
assert response.status_code == 200
assert "application/json" in response.headers.get("content-type", "")
@pytest.mark.asyncio
async def test_proxy_concurrent_requests(backend_port, pyserve_port, create_proxy_config):
"""Test handling multiple concurrent proxy requests."""
async with running_servers(pyserve_port, backend_port, create_proxy_config):
async with httpx.AsyncClient() as client:
# Send multiple concurrent requests
tasks = [
client.get(f"http://127.0.0.1:{pyserve_port}/api/v{i % 3 + 1}/users")
for i in range(10)
]
responses = await asyncio.gather(*tasks)
# All should succeed
for response in responses:
assert response.status_code == 200
assert "users" in response.json()
@pytest.mark.asyncio
async def test_server_default_proxy_timeout(backend_port, pyserve_port):
"""Test that server-level proxy_timeout is used as default."""
locations = {
"~^/slow": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow",
# No timeout specified - should use server default
},
"=/health": {
"return": "200 OK",
"content_type": "text/plain"
},
}
config = Config(
http=HttpConfig(static_dir="./static", templates_dir="./templates"),
server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.5), # Very short timeout
logging=LoggingConfig(level="ERROR", console_output=False),
extensions=[
ExtensionConfig(
type="routing",
config={"regex_locations": locations}
)
]
)
backend = BackendTestApp(backend_port)
pyserve = PyServeTestServer(config)
await backend.start()
await pyserve.start()
try:
async with httpx.AsyncClient(timeout=10.0) as client:
# This should timeout because server default is 0.5s and slow endpoint takes 2s
response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=2")
assert response.status_code == 504
assert "Gateway Timeout" in response.text
finally:
await pyserve.stop()
await backend.stop()
@pytest.mark.asyncio
async def test_route_timeout_overrides_server_default(backend_port, pyserve_port):
"""Test that route-level timeout overrides server default proxy_timeout."""
locations = {
"~^/slow": {
"proxy_pass": f"http://127.0.0.1:{backend_port}/slow",
"timeout": 5.0, # Route-level timeout overrides server default
},
"=/health": {
"return": "200 OK",
"content_type": "text/plain"
},
}
config = Config(
http=HttpConfig(static_dir="./static", templates_dir="./templates"),
server=ServerConfig(host="127.0.0.1", port=pyserve_port, proxy_timeout=0.1), # Very short server default
logging=LoggingConfig(level="ERROR", console_output=False),
extensions=[
ExtensionConfig(
type="routing",
config={"regex_locations": locations}
)
]
)
backend = BackendTestApp(backend_port)
pyserve = PyServeTestServer(config)
await backend.start()
await pyserve.start()
try:
async with httpx.AsyncClient(timeout=10.0) as client:
# This should succeed because route timeout (5s) > delay (0.5s), even though server default is 0.1s
response = await client.get(f"http://127.0.0.1:{pyserve_port}/slow?delay=0.5")
assert response.status_code == 200
data = response.json()
assert data["delay"] == 0.5
finally:
await pyserve.stop()
await backend.stop()
@pytest.mark.asyncio
async def test_proxy_timeout_config_parsing():
"""Test that proxy_timeout is correctly parsed from config."""
# Test default value
config = Config()
assert config.server.proxy_timeout == 30.0
# Test custom value from dict
config_with_timeout = Config._from_dict({
"server": {
"host": "127.0.0.1",
"port": 8080,
"proxy_timeout": 60.0
}
})
assert config_with_timeout.server.proxy_timeout == 60.0
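# For reference, the dict above would correspond to a YAML config fragment along
# these lines (assuming the on-disk YAML mirrors the structure handed to
# Config._from_dict; shown purely as an illustration):
#
#   server:
#     host: 127.0.0.1
#     port: 8080
#     proxy_timeout: 60.0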