forked from aegis/pyserveX
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eeeccd57da | ||
|
|
fe541778f1 |
@ -22,6 +22,11 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Install system dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libpcre2-dev
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
@ -45,6 +50,9 @@ jobs:
|
|||||||
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
|
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
|
||||||
run: poetry install --with dev
|
run: poetry install --with dev
|
||||||
|
|
||||||
|
- name: Build Cython extensions
|
||||||
|
run: poetry run python scripts/build_cython.py build_ext --inplace
|
||||||
|
|
||||||
- name: Build package
|
- name: Build package
|
||||||
run: |
|
run: |
|
||||||
poetry build
|
poetry build
|
||||||
|
|||||||
@ -17,6 +17,11 @@ jobs:
|
|||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install system dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libpcre2-dev
|
||||||
|
|
||||||
- name: Setup Python ${{ matrix.python-version }}
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
@ -40,6 +45,9 @@ jobs:
|
|||||||
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
|
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
|
||||||
run: poetry install --with dev
|
run: poetry install --with dev
|
||||||
|
|
||||||
|
- name: Build Cython extensions
|
||||||
|
run: poetry run python scripts/build_cython.py build_ext --inplace
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: poetry run pytest tests/ -v
|
run: poetry run pytest tests/ -v
|
||||||
|
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -28,6 +28,3 @@ build/
|
|||||||
.vscode/
|
.vscode/
|
||||||
*.swp
|
*.swp
|
||||||
*.swo
|
*.swo
|
||||||
|
|
||||||
# Go binaries
|
|
||||||
go/bin
|
|
||||||
153
benchmarks/bench_routing.py
Normal file
153
benchmarks/bench_routing.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Benchmark script for routing performance comparison.
|
||||||
|
|
||||||
|
Compares:
|
||||||
|
- Pure Python implementation with standard re (_routing_py)
|
||||||
|
- Cython implementation with PCRE2 JIT (_routing)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python benchmarks/bench_routing.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import statistics
|
||||||
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
|
from pyserve._routing_py import (
|
||||||
|
FastRouter as PyFastRouter,
|
||||||
|
FastRouteMatch as PyFastRouteMatch,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyserve._routing import (
|
||||||
|
FastRouter as CyFastRouter,
|
||||||
|
FastRouteMatch as CyFastRouteMatch,
|
||||||
|
)
|
||||||
|
CYTHON_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
CYTHON_AVAILABLE = False
|
||||||
|
print("Cython module not compiled. Run: poetry run python scripts/build_cython.py\n")
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark(func: Callable, iterations: int = 100000) -> Tuple[float, float]:
|
||||||
|
"""Benchmark a function and return mean/stdev in nanoseconds."""
|
||||||
|
times = []
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(1000):
|
||||||
|
func()
|
||||||
|
|
||||||
|
# Actual benchmark
|
||||||
|
for _ in range(iterations):
|
||||||
|
start = time.perf_counter_ns()
|
||||||
|
func()
|
||||||
|
end = time.perf_counter_ns()
|
||||||
|
times.append(end - start)
|
||||||
|
|
||||||
|
return statistics.mean(times), statistics.stdev(times)
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(ns: float) -> str:
|
||||||
|
"""Format time in nanoseconds to human readable format."""
|
||||||
|
if ns < 1000:
|
||||||
|
return f"{ns:.1f} ns"
|
||||||
|
elif ns < 1_000_000:
|
||||||
|
return f"{ns/1000:.2f} µs"
|
||||||
|
else:
|
||||||
|
return f"{ns/1_000_000:.2f} ms"
|
||||||
|
|
||||||
|
|
||||||
|
def setup_router(router_class):
|
||||||
|
"""Setup a router with typical routes."""
|
||||||
|
router = router_class()
|
||||||
|
|
||||||
|
# Exact routes
|
||||||
|
router.add_route("=/health", {"return": "200 OK"})
|
||||||
|
router.add_route("=/api/status", {"return": "200 OK"})
|
||||||
|
router.add_route("=/favicon.ico", {"return": "204"})
|
||||||
|
|
||||||
|
# Regex routes
|
||||||
|
router.add_route("~^/api/v1/users/(?P<user_id>\\d+)$", {"proxy_pass": "http://users-service"})
|
||||||
|
router.add_route("~^/api/v1/posts/(?P<post_id>\\d+)$", {"proxy_pass": "http://posts-service"})
|
||||||
|
router.add_route("~\\.(css|js|png|jpg|gif|svg|woff2?)$", {"root": "./static"})
|
||||||
|
router.add_route("~^/api/", {"proxy_pass": "http://api-gateway"})
|
||||||
|
|
||||||
|
# Default route
|
||||||
|
router.add_route("__default__", {"spa_fallback": True, "root": "./dist"})
|
||||||
|
|
||||||
|
return router
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmarks():
|
||||||
|
print("=" * 70)
|
||||||
|
print("ROUTING BENCHMARK")
|
||||||
|
print("=" * 70)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test paths with different matching scenarios
|
||||||
|
test_cases = [
|
||||||
|
("/health", "Exact match (first)"),
|
||||||
|
("/api/status", "Exact match (middle)"),
|
||||||
|
("/api/v1/users/12345", "Regex match with groups"),
|
||||||
|
("/static/app.js", "Regex match (file extension)"),
|
||||||
|
("/api/v2/other", "Regex match (simple prefix)"),
|
||||||
|
("/some/random/path", "Default route (fallback)"),
|
||||||
|
("/nonexistent", "Default route (fallback)"),
|
||||||
|
]
|
||||||
|
|
||||||
|
iterations = 100000
|
||||||
|
|
||||||
|
print(f"Iterations: {iterations:,}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Setup routers
|
||||||
|
py_router = setup_router(PyFastRouter)
|
||||||
|
cy_router = setup_router(CyFastRouter) if CYTHON_AVAILABLE else None
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for path, description in test_cases:
|
||||||
|
print(f"Path: {path}")
|
||||||
|
print(f" {description}")
|
||||||
|
|
||||||
|
# Python implementation (standard re)
|
||||||
|
py_mean, py_std = benchmark(lambda p=path: py_router.match(p), iterations)
|
||||||
|
results[(path, "Python (re)")] = py_mean
|
||||||
|
print(f" Python (re): {format_time(py_mean):>12} ± {format_time(py_std)}")
|
||||||
|
|
||||||
|
# Cython implementation (PCRE2 JIT)
|
||||||
|
if CYTHON_AVAILABLE and cy_router:
|
||||||
|
cy_mean, cy_std = benchmark(lambda p=path: cy_router.match(p), iterations)
|
||||||
|
results[(path, "Cython (PCRE2)")] = cy_mean
|
||||||
|
speedup = py_mean / cy_mean if cy_mean > 0 else 0
|
||||||
|
print(f" Cython (PCRE2): {format_time(cy_mean):>12} ± {format_time(cy_std)} ({speedup:.2f}x faster)")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
if CYTHON_AVAILABLE:
|
||||||
|
print("=" * 70)
|
||||||
|
print("SUMMARY")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
py_total = sum(v for k, v in results.items() if k[1] == "Python (re)")
|
||||||
|
cy_total = sum(v for k, v in results.items() if k[1] == "Cython (PCRE2)")
|
||||||
|
|
||||||
|
print(f" Python (re) total: {format_time(py_total)}")
|
||||||
|
print(f" Cython (PCRE2) total: {format_time(cy_total)}")
|
||||||
|
print(f" Overall speedup: {py_total / cy_total:.2f}x")
|
||||||
|
|
||||||
|
# Show JIT compilation status
|
||||||
|
print()
|
||||||
|
print("PCRE2 JIT Status:")
|
||||||
|
for route in cy_router.list_routes(): # type: ignore False linter error
|
||||||
|
if route["type"] == "regex":
|
||||||
|
jit = route.get("jit_compiled", False)
|
||||||
|
status = "✓ JIT" if jit else "✗ No JIT"
|
||||||
|
print(f" {status}: {route['pattern']}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_benchmarks()
|
||||||
@ -1,34 +0,0 @@
|
|||||||
# Multi-stage build for Konduktor
|
|
||||||
FROM golang:1.23-alpine AS builder
|
|
||||||
|
|
||||||
RUN apk add --no-cache git make
|
|
||||||
|
|
||||||
WORKDIR /build
|
|
||||||
|
|
||||||
COPY go.mod go.sum* ./
|
|
||||||
RUN go mod download
|
|
||||||
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
RUN make build
|
|
||||||
|
|
||||||
FROM alpine:3.19
|
|
||||||
|
|
||||||
RUN apk add --no-cache ca-certificates tzdata
|
|
||||||
|
|
||||||
RUN adduser -D -g '' konduktor
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY --from=builder /build/bin/konduktor /usr/local/bin/
|
|
||||||
COPY --from=builder /build/bin/konduktorctl /usr/local/bin/
|
|
||||||
|
|
||||||
RUN mkdir -p /app/static /app/templates /app/logs && \
|
|
||||||
chown -R konduktor:konduktor /app
|
|
||||||
|
|
||||||
USER konduktor
|
|
||||||
|
|
||||||
EXPOSE 8080
|
|
||||||
|
|
||||||
ENTRYPOINT ["konduktor"]
|
|
||||||
CMD ["-c", "/app/config.yaml"]
|
|
||||||
108
go/Makefile
108
go/Makefile
@ -1,108 +0,0 @@
|
|||||||
# Konduktor Go Build
|
|
||||||
# Makefile for building and testing Konduktor
|
|
||||||
|
|
||||||
.PHONY: all build build-konduktor build-konduktorctl test clean deps fmt lint run
|
|
||||||
|
|
||||||
# Build configuration
|
|
||||||
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
|
||||||
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
|
||||||
BUILD_TIME ?= $(shell date -u '+%Y-%m-%dT%H:%M:%SZ')
|
|
||||||
LDFLAGS := -X main.Version=$(VERSION) -X main.GitCommit=$(GIT_COMMIT) -X main.BuildTime=$(BUILD_TIME)
|
|
||||||
|
|
||||||
# Output directories
|
|
||||||
BIN_DIR := bin
|
|
||||||
|
|
||||||
all: deps build
|
|
||||||
|
|
||||||
# Download dependencies
|
|
||||||
deps:
|
|
||||||
@echo "==> Downloading dependencies..."
|
|
||||||
go mod download
|
|
||||||
go mod tidy
|
|
||||||
|
|
||||||
# Build all binaries
|
|
||||||
build: build-konduktor build-konduktorctl
|
|
||||||
|
|
||||||
# Build konduktor server
|
|
||||||
build-konduktor:
|
|
||||||
@echo "==> Building konduktor..."
|
|
||||||
@mkdir -p $(BIN_DIR)
|
|
||||||
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktor ./cmd/konduktor
|
|
||||||
|
|
||||||
# Build konduktorctl CLI
|
|
||||||
build-konduktorctl:
|
|
||||||
@echo "==> Building konduktorctl..."
|
|
||||||
@mkdir -p $(BIN_DIR)
|
|
||||||
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktorctl ./cmd/konduktorctl
|
|
||||||
|
|
||||||
# Run tests
|
|
||||||
test:
|
|
||||||
@echo "==> Running tests..."
|
|
||||||
go test -v -race -cover ./...
|
|
||||||
|
|
||||||
# Run tests with coverage report
|
|
||||||
test-coverage:
|
|
||||||
@echo "==> Running tests with coverage..."
|
|
||||||
go test -v -race -coverprofile=coverage.out ./...
|
|
||||||
go tool cover -html=coverage.out -o coverage.html
|
|
||||||
@echo "Coverage report: coverage.html"
|
|
||||||
|
|
||||||
# Format code
|
|
||||||
fmt:
|
|
||||||
@echo "==> Formatting code..."
|
|
||||||
go fmt ./...
|
|
||||||
goimports -w .
|
|
||||||
|
|
||||||
# Lint code
|
|
||||||
lint:
|
|
||||||
@echo "==> Linting code..."
|
|
||||||
golangci-lint run ./...
|
|
||||||
|
|
||||||
# Run the server (development)
|
|
||||||
run: build-konduktor
|
|
||||||
@echo "==> Running konduktor..."
|
|
||||||
./$(BIN_DIR)/konduktor -c ../config.yaml
|
|
||||||
|
|
||||||
# Clean build artifacts
|
|
||||||
clean:
|
|
||||||
@echo "==> Cleaning..."
|
|
||||||
rm -rf $(BIN_DIR)
|
|
||||||
rm -f coverage.out coverage.html
|
|
||||||
|
|
||||||
# Install binaries to GOPATH/bin
|
|
||||||
install: build
|
|
||||||
@echo "==> Installing binaries..."
|
|
||||||
cp $(BIN_DIR)/konduktor $(GOPATH)/bin/
|
|
||||||
cp $(BIN_DIR)/konduktorctl $(GOPATH)/bin/
|
|
||||||
|
|
||||||
# Generate mocks (for testing)
|
|
||||||
generate:
|
|
||||||
@echo "==> Generating code..."
|
|
||||||
go generate ./...
|
|
||||||
|
|
||||||
# Docker build
|
|
||||||
docker-build:
|
|
||||||
@echo "==> Building Docker image..."
|
|
||||||
docker build -t konduktor:$(VERSION) .
|
|
||||||
|
|
||||||
# Show help
|
|
||||||
help:
|
|
||||||
@echo "Konduktor Build System"
|
|
||||||
@echo ""
|
|
||||||
@echo "Usage: make [target]"
|
|
||||||
@echo ""
|
|
||||||
@echo "Targets:"
|
|
||||||
@echo " all Download deps and build all binaries"
|
|
||||||
@echo " deps Download and tidy dependencies"
|
|
||||||
@echo " build Build all binaries"
|
|
||||||
@echo " build-konduktor Build the server binary"
|
|
||||||
@echo " build-konduktorctl Build the CLI binary"
|
|
||||||
@echo " test Run tests"
|
|
||||||
@echo " test-coverage Run tests with coverage report"
|
|
||||||
@echo " fmt Format code"
|
|
||||||
@echo " lint Lint code"
|
|
||||||
@echo " run Build and run the server"
|
|
||||||
@echo " clean Clean build artifacts"
|
|
||||||
@echo " install Install binaries to GOPATH/bin"
|
|
||||||
@echo " docker-build Build Docker image"
|
|
||||||
@echo " help Show this help"
|
|
||||||
149
go/README.md
149
go/README.md
@ -1,149 +0,0 @@
|
|||||||
# Konduktor (Go)
|
|
||||||
|
|
||||||
High-performance HTTP web server with extensible routing and process orchestration. (Previously known as PyServe in Python)
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
go/
|
|
||||||
├── cmd/
|
|
||||||
│ ├── konduktor/ # Main server binary
|
|
||||||
│ └── konduktorctl/ # CLI management tool
|
|
||||||
├── internal/
|
|
||||||
│ ├── config/ # Configuration management
|
|
||||||
│ ├── logging/ # Structured logging
|
|
||||||
│ ├── middleware/ # HTTP middleware
|
|
||||||
│ ├── routing/ # HTTP routing
|
|
||||||
│ ├── extensions/ # Extension system (TODO)
|
|
||||||
│ └── process/ # Process management (TODO)
|
|
||||||
├── pkg/ # Public packages (TODO)
|
|
||||||
├── go.mod
|
|
||||||
├── go.sum
|
|
||||||
└── Makefile
|
|
||||||
```
|
|
||||||
|
|
||||||
## Building
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd go
|
|
||||||
|
|
||||||
# Download dependencies
|
|
||||||
make deps
|
|
||||||
|
|
||||||
# Build all binaries
|
|
||||||
make build
|
|
||||||
|
|
||||||
# Or build individually
|
|
||||||
make build-konduktor
|
|
||||||
make build-konduktorctl
|
|
||||||
```
|
|
||||||
|
|
||||||
## Running
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run with default config
|
|
||||||
./bin/konduktor
|
|
||||||
|
|
||||||
# Run with custom config
|
|
||||||
./bin/konduktor -c ../config.yaml
|
|
||||||
|
|
||||||
# Run with flags
|
|
||||||
./bin/konduktor --host 127.0.0.1 --port 3000 --debug
|
|
||||||
```
|
|
||||||
|
|
||||||
## CLI Commands (konduktorctl)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Start services
|
|
||||||
konduktorctl up
|
|
||||||
|
|
||||||
# Stop services
|
|
||||||
konduktorctl down
|
|
||||||
|
|
||||||
# View status
|
|
||||||
konduktorctl status
|
|
||||||
|
|
||||||
# View logs
|
|
||||||
konduktorctl logs -f
|
|
||||||
|
|
||||||
# Health check
|
|
||||||
konduktorctl health
|
|
||||||
|
|
||||||
# Scale services
|
|
||||||
konduktorctl scale api=3
|
|
||||||
|
|
||||||
# Configuration management
|
|
||||||
konduktorctl config show
|
|
||||||
konduktorctl config validate
|
|
||||||
|
|
||||||
# Initialize new project
|
|
||||||
konduktorctl init
|
|
||||||
```
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
Uses the same YAML configuration format as the Python version:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
server:
|
|
||||||
host: 0.0.0.0
|
|
||||||
port: 8080
|
|
||||||
|
|
||||||
http:
|
|
||||||
static_dir: ./static
|
|
||||||
templates_dir: ./templates
|
|
||||||
|
|
||||||
ssl:
|
|
||||||
enabled: false
|
|
||||||
cert_file: ./ssl/cert.pem
|
|
||||||
key_file: ./ssl/key.pem
|
|
||||||
|
|
||||||
logging:
|
|
||||||
level: INFO
|
|
||||||
console_output: true
|
|
||||||
|
|
||||||
extensions:
|
|
||||||
- type: routing
|
|
||||||
config:
|
|
||||||
regex_locations:
|
|
||||||
"=/health":
|
|
||||||
return: "200 OK"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Development
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Format code
|
|
||||||
make fmt
|
|
||||||
|
|
||||||
# Run linter
|
|
||||||
make lint
|
|
||||||
|
|
||||||
# Run tests
|
|
||||||
make test
|
|
||||||
|
|
||||||
# Run with coverage
|
|
||||||
make test-coverage
|
|
||||||
```
|
|
||||||
|
|
||||||
## Migration from Python
|
|
||||||
|
|
||||||
This is a gradual rewrite of PyServe to Go. The project is now called **Konduktor**.
|
|
||||||
|
|
||||||
### Completed
|
|
||||||
- [x] Basic project structure
|
|
||||||
- [x] Configuration loading
|
|
||||||
- [x] HTTP server with graceful shutdown
|
|
||||||
- [x] Basic routing
|
|
||||||
- [x] Middleware (access log, recovery, server header)
|
|
||||||
- [x] CLI structure (konduktor, konduktorctl)
|
|
||||||
|
|
||||||
### TODO
|
|
||||||
- [ ] Extension system
|
|
||||||
- [x] Regex routing
|
|
||||||
- [x] Reverse proxy
|
|
||||||
- [ ] Process orchestration
|
|
||||||
- [ ] ASGI/WSGI adapter support
|
|
||||||
- [ ] WebSocket support
|
|
||||||
- [ ] Hot reload
|
|
||||||
- [ ] Metrics and monitoring
|
|
||||||
@ -1,79 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/config"
|
|
||||||
"github.com/konduktor/konduktor/internal/server"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
Version = "0.1.0"
|
|
||||||
BuildTime = "unknown"
|
|
||||||
GitCommit = "unknown"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
cfgFile string
|
|
||||||
host string
|
|
||||||
port int
|
|
||||||
debug bool
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
rootCmd := &cobra.Command{
|
|
||||||
Use: "konduktor",
|
|
||||||
Short: "Konduktor - HTTP web server",
|
|
||||||
Long: `Konduktor is a high-performance HTTP web server with extensible routing and process orchestration.`,
|
|
||||||
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
|
|
||||||
RunE: runServer,
|
|
||||||
}
|
|
||||||
|
|
||||||
rootCmd.Flags().StringVarP(&cfgFile, "config", "c", "config.yaml", "Path to configuration file")
|
|
||||||
rootCmd.Flags().StringVar(&host, "host", "", "Host to bind the server to")
|
|
||||||
rootCmd.Flags().IntVar(&port, "port", 0, "Port to bind the server to")
|
|
||||||
rootCmd.Flags().BoolVar(&debug, "debug", false, "Enable debug mode")
|
|
||||||
|
|
||||||
if err := rootCmd.Execute(); err != nil {
|
|
||||||
fmt.Fprintln(os.Stderr, err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func runServer(cmd *cobra.Command, args []string) error {
|
|
||||||
cfg, err := config.Load(cfgFile)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
fmt.Printf("Configuration file %s not found, using defaults\n", cfgFile)
|
|
||||||
cfg = config.Default()
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("configuration loading error: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if host != "" {
|
|
||||||
cfg.Server.Host = host
|
|
||||||
}
|
|
||||||
if port != 0 {
|
|
||||||
cfg.Server.Port = port
|
|
||||||
}
|
|
||||||
if debug {
|
|
||||||
cfg.Logging.Level = "DEBUG"
|
|
||||||
}
|
|
||||||
|
|
||||||
srv, err := server.New(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("server creation error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Starting Konduktor server on %s:%d\n", cfg.Server.Host, cfg.Server.Port)
|
|
||||||
|
|
||||||
if err := srv.Run(); err != nil {
|
|
||||||
return fmt.Errorf("server startup error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@ -1,180 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
Version = "0.1.0"
|
|
||||||
BuildTime = "unknown"
|
|
||||||
GitCommit = "unknown"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
rootCmd := &cobra.Command{
|
|
||||||
Use: "konduktorctl",
|
|
||||||
Short: "Konduktorctl - Service management CLI",
|
|
||||||
Long: `Konduktorctl is a CLI tool for managing Konduktor services.`,
|
|
||||||
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
|
|
||||||
}
|
|
||||||
|
|
||||||
rootCmd.AddCommand(
|
|
||||||
newUpCmd(),
|
|
||||||
newDownCmd(),
|
|
||||||
newStatusCmd(),
|
|
||||||
newLogsCmd(),
|
|
||||||
newHealthCmd(),
|
|
||||||
newScaleCmd(),
|
|
||||||
newConfigCmd(),
|
|
||||||
newInitCmd(),
|
|
||||||
newTopCmd(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := rootCmd.Execute(); err != nil {
|
|
||||||
fmt.Fprintln(os.Stderr, err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUpCmd() *cobra.Command {
|
|
||||||
cmd := &cobra.Command{
|
|
||||||
Use: "up [service...]",
|
|
||||||
Short: "Start services",
|
|
||||||
Long: `Start one or more services. If no service is specified, all services are started.`,
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Println("Starting services...")
|
|
||||||
// TODO: Implement service start logic
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cmd.Flags().BoolP("detach", "d", false, "Run in background")
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDownCmd() *cobra.Command {
|
|
||||||
return &cobra.Command{
|
|
||||||
Use: "down [service...]",
|
|
||||||
Short: "Stop services",
|
|
||||||
Long: `Stop one or more services. If no service is specified, all services are stopped.`,
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Println("Stopping services...")
|
|
||||||
// TODO: Implement service stop logic
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newStatusCmd() *cobra.Command {
|
|
||||||
return &cobra.Command{
|
|
||||||
Use: "status [service...]",
|
|
||||||
Short: "Show service status",
|
|
||||||
Long: `Show the status of one or more services.`,
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Println("Service status:")
|
|
||||||
// TODO: Implement status display logic
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newLogsCmd() *cobra.Command {
|
|
||||||
cmd := &cobra.Command{
|
|
||||||
Use: "logs [service]",
|
|
||||||
Short: "View service logs",
|
|
||||||
Long: `View logs for a specific service.`,
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Println("Fetching logs...")
|
|
||||||
// TODO: Implement logs viewing logic
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cmd.Flags().BoolP("follow", "f", false, "Follow log output")
|
|
||||||
cmd.Flags().IntP("tail", "n", 100, "Number of lines to show from the end")
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHealthCmd() *cobra.Command {
|
|
||||||
return &cobra.Command{
|
|
||||||
Use: "health",
|
|
||||||
Short: "Check service health",
|
|
||||||
Long: `Check the health status of all services.`,
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Println("Health check:")
|
|
||||||
// TODO: Implement health check logic
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newScaleCmd() *cobra.Command {
|
|
||||||
return &cobra.Command{
|
|
||||||
Use: "scale <service>=<count>",
|
|
||||||
Short: "Scale a service",
|
|
||||||
Long: `Scale a service to a specific number of instances.`,
|
|
||||||
Args: cobra.MinimumNArgs(1),
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Printf("Scaling: %v\n", args)
|
|
||||||
// TODO: Implement scaling logic
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newConfigCmd() *cobra.Command {
|
|
||||||
cmd := &cobra.Command{
|
|
||||||
Use: "config",
|
|
||||||
Short: "Manage configuration",
|
|
||||||
Long: `View and validate configuration.`,
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.AddCommand(&cobra.Command{
|
|
||||||
Use: "show",
|
|
||||||
Short: "Show current configuration",
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Println("Current configuration:")
|
|
||||||
// TODO: Implement config show logic
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
cmd.AddCommand(&cobra.Command{
|
|
||||||
Use: "validate",
|
|
||||||
Short: "Validate configuration file",
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Println("Validating configuration...")
|
|
||||||
// TODO: Implement config validation logic
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
func newInitCmd() *cobra.Command {
|
|
||||||
return &cobra.Command{
|
|
||||||
Use: "init",
|
|
||||||
Short: "Initialize a new project",
|
|
||||||
Long: `Create a new Konduktor project with default configuration.`,
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Println("Initializing new project...")
|
|
||||||
// TODO: Implement init logic
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTopCmd() *cobra.Command {
|
|
||||||
return &cobra.Command{
|
|
||||||
Use: "top",
|
|
||||||
Short: "Display running processes",
|
|
||||||
Long: `Display real-time view of running processes and resource usage.`,
|
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
|
||||||
fmt.Println("Process monitor:")
|
|
||||||
// TODO: Implement top-like display
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
18
go/go.mod
18
go/go.mod
@ -1,18 +0,0 @@
|
|||||||
module github.com/konduktor/konduktor
|
|
||||||
|
|
||||||
go 1.23.0
|
|
||||||
|
|
||||||
toolchain go1.24.2
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/spf13/cobra v1.10.2
|
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
|
||||||
)
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
|
||||||
github.com/spf13/pflag v1.0.10 // indirect
|
|
||||||
go.uber.org/multierr v1.10.0 // indirect
|
|
||||||
go.uber.org/zap v1.27.1 // indirect
|
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
|
||||||
)
|
|
||||||
20
go/go.sum
20
go/go.sum
@ -1,20 +0,0 @@
|
|||||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
|
||||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
|
||||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
|
||||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
|
||||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
|
||||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
|
||||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
|
||||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
|
||||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
|
||||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
|
||||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
|
||||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
|
||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
||||||
@ -1,134 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
HTTP HTTPConfig `yaml:"http"`
|
|
||||||
Server ServerConfig `yaml:"server"`
|
|
||||||
SSL SSLConfig `yaml:"ssl"`
|
|
||||||
Logging LoggingConfig `yaml:"logging"`
|
|
||||||
Extensions []ExtensionConfig `yaml:"extensions"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type HTTPConfig struct {
|
|
||||||
StaticDir string `yaml:"static_dir"`
|
|
||||||
TemplatesDir string `yaml:"templates_dir"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServerConfig struct {
|
|
||||||
Host string `yaml:"host"`
|
|
||||||
Port int `yaml:"port"`
|
|
||||||
Backlog int `yaml:"backlog"`
|
|
||||||
DefaultRoot bool `yaml:"default_root"`
|
|
||||||
ProxyTimeout time.Duration `yaml:"proxy_timeout"`
|
|
||||||
RedirectInstructions map[string]string `yaml:"redirect_instructions"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SSLConfig struct {
|
|
||||||
Enabled bool `yaml:"enabled"`
|
|
||||||
CertFile string `yaml:"cert_file"`
|
|
||||||
KeyFile string `yaml:"key_file"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type LoggingConfig struct {
|
|
||||||
Level string `yaml:"level"`
|
|
||||||
ConsoleOutput bool `yaml:"console_output"`
|
|
||||||
Format LogFormatConfig `yaml:"format"`
|
|
||||||
Console *ConsoleLogConfig `yaml:"console"`
|
|
||||||
Files []FileLogConfig `yaml:"files"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type LogFormatConfig struct {
|
|
||||||
Type string `yaml:"type"`
|
|
||||||
UseColors bool `yaml:"use_colors"`
|
|
||||||
ShowModule bool `yaml:"show_module"`
|
|
||||||
TimestampFormat string `yaml:"timestamp_format"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConsoleLogConfig struct {
|
|
||||||
Format LogFormatConfig `yaml:"format"`
|
|
||||||
Level string `yaml:"level"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FileLogConfig struct {
|
|
||||||
Path string `yaml:"path"`
|
|
||||||
Level string `yaml:"level"`
|
|
||||||
Loggers []string `yaml:"loggers"`
|
|
||||||
Format LogFormatConfig `yaml:"format"`
|
|
||||||
MaxBytes int64 `yaml:"max_bytes"`
|
|
||||||
BackupCount int `yaml:"backup_count"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ExtensionConfig struct {
|
|
||||||
Type string `yaml:"type"`
|
|
||||||
Config map[string]interface{} `yaml:"config"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func Load(path string) (*Config, error) {
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := Default()
|
|
||||||
if err := yaml.Unmarshal(data, cfg); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cfg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Default() *Config {
|
|
||||||
return &Config{
|
|
||||||
HTTP: HTTPConfig{
|
|
||||||
StaticDir: "./static",
|
|
||||||
TemplatesDir: "./templates",
|
|
||||||
},
|
|
||||||
Server: ServerConfig{
|
|
||||||
Host: "0.0.0.0",
|
|
||||||
Port: 8080,
|
|
||||||
Backlog: 5,
|
|
||||||
DefaultRoot: false,
|
|
||||||
ProxyTimeout: 30 * time.Second,
|
|
||||||
},
|
|
||||||
SSL: SSLConfig{
|
|
||||||
Enabled: false,
|
|
||||||
CertFile: "./ssl/cert.pem",
|
|
||||||
KeyFile: "./ssl/key.pem",
|
|
||||||
},
|
|
||||||
Logging: LoggingConfig{
|
|
||||||
Level: "INFO",
|
|
||||||
ConsoleOutput: true,
|
|
||||||
Format: LogFormatConfig{
|
|
||||||
Type: "standard",
|
|
||||||
UseColors: true,
|
|
||||||
ShowModule: true,
|
|
||||||
TimestampFormat: "2006-01-02 15:04:05",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Extensions: []ExtensionConfig{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
|
||||||
if c.Server.Port < 1 || c.Server.Port > 65535 {
|
|
||||||
return fmt.Errorf("invalid port: %d", c.Server.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.SSL.Enabled {
|
|
||||||
if c.SSL.CertFile == "" {
|
|
||||||
return fmt.Errorf("SSL enabled but cert_file not specified")
|
|
||||||
}
|
|
||||||
if c.SSL.KeyFile == "" {
|
|
||||||
return fmt.Errorf("SSL enabled but key_file not specified")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@ -1,127 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDefault(t *testing.T) {
|
|
||||||
cfg := Default()
|
|
||||||
|
|
||||||
if cfg.Server.Host != "0.0.0.0" {
|
|
||||||
t.Errorf("Expected host 0.0.0.0, got %s", cfg.Server.Host)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.Server.Port != 8080 {
|
|
||||||
t.Errorf("Expected port 8080, got %d", cfg.Server.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.SSL.Enabled {
|
|
||||||
t.Error("Expected SSL to be disabled by default")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
modify func(*Config)
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid default config",
|
|
||||||
modify: func(c *Config) {},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid port - too low",
|
|
||||||
modify: func(c *Config) {
|
|
||||||
c.Server.Port = 0
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid port - too high",
|
|
||||||
modify: func(c *Config) {
|
|
||||||
c.Server.Port = 70000
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "SSL enabled without cert",
|
|
||||||
modify: func(c *Config) {
|
|
||||||
c.SSL.Enabled = true
|
|
||||||
c.SSL.CertFile = ""
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "SSL enabled without key",
|
|
||||||
modify: func(c *Config) {
|
|
||||||
c.SSL.Enabled = true
|
|
||||||
c.SSL.CertFile = "cert.pem"
|
|
||||||
c.SSL.KeyFile = ""
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := Default()
|
|
||||||
tt.modify(cfg)
|
|
||||||
|
|
||||||
err := cfg.Validate()
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoad(t *testing.T) {
|
|
||||||
// Create temporary config file
|
|
||||||
content := `
|
|
||||||
server:
|
|
||||||
host: 127.0.0.1
|
|
||||||
port: 3000
|
|
||||||
|
|
||||||
logging:
|
|
||||||
level: DEBUG
|
|
||||||
`
|
|
||||||
tmpfile, err := os.CreateTemp("", "config-*.yaml")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer os.Remove(tmpfile.Name())
|
|
||||||
|
|
||||||
if _, err := tmpfile.Write([]byte(content)); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := tmpfile.Close(); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := Load(tmpfile.Name())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.Server.Host != "127.0.0.1" {
|
|
||||||
t.Errorf("Expected host 127.0.0.1, got %s", cfg.Server.Host)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.Server.Port != 3000 {
|
|
||||||
t.Errorf("Expected port 3000, got %d", cfg.Server.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.Logging.Level != "DEBUG" {
|
|
||||||
t.Errorf("Expected level DEBUG, got %s", cfg.Logging.Level)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadNotFound(t *testing.T) {
|
|
||||||
_, err := Load("/nonexistent/config.yaml")
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error for non-existent file")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,466 +0,0 @@
|
|||||||
package extension
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CachingExtension struct {
|
|
||||||
BaseExtension
|
|
||||||
cache map[string]*cacheEntry
|
|
||||||
cachePatterns []*cachePattern
|
|
||||||
defaultTTL time.Duration
|
|
||||||
maxSize int
|
|
||||||
currentSize int
|
|
||||||
mu sync.RWMutex
|
|
||||||
|
|
||||||
hits int64
|
|
||||||
misses int64
|
|
||||||
}
|
|
||||||
|
|
||||||
type cacheEntry struct {
|
|
||||||
key string
|
|
||||||
body []byte
|
|
||||||
headers http.Header
|
|
||||||
statusCode int
|
|
||||||
contentType string
|
|
||||||
createdAt time.Time
|
|
||||||
expiresAt time.Time
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
|
|
||||||
type cachePattern struct {
|
|
||||||
pattern *regexp.Regexp
|
|
||||||
ttl time.Duration
|
|
||||||
methods []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type CachingConfig struct {
|
|
||||||
Enabled bool `yaml:"enabled"`
|
|
||||||
DefaultTTL string `yaml:"default_ttl"` // e.g., "5m", "1h"
|
|
||||||
MaxSizeMB int `yaml:"max_size_mb"` // Max cache size in MB
|
|
||||||
CachePatterns []PatternConfig `yaml:"cache_patterns"` // Patterns to cache
|
|
||||||
}
|
|
||||||
|
|
||||||
type PatternConfig struct {
|
|
||||||
Pattern string `yaml:"pattern"` // Regex pattern
|
|
||||||
TTL string `yaml:"ttl"` // TTL for this pattern
|
|
||||||
Methods []string `yaml:"methods"` // HTTP methods to cache (default: GET)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCachingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
|
|
||||||
ext := &CachingExtension{
|
|
||||||
BaseExtension: NewBaseExtension("caching", 20, logger),
|
|
||||||
cache: make(map[string]*cacheEntry),
|
|
||||||
cachePatterns: make([]*cachePattern, 0),
|
|
||||||
defaultTTL: 5 * time.Minute,
|
|
||||||
maxSize: 100 * 1024 * 1024, // 100MB default
|
|
||||||
}
|
|
||||||
|
|
||||||
if ttl, ok := config["default_ttl"].(string); ok {
|
|
||||||
if duration, err := time.ParseDuration(ttl); err == nil {
|
|
||||||
ext.defaultTTL = duration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxSize, ok := config["max_size_mb"].(int); ok {
|
|
||||||
ext.maxSize = maxSize * 1024 * 1024
|
|
||||||
} else if maxSizeFloat, ok := config["max_size_mb"].(float64); ok {
|
|
||||||
ext.maxSize = int(maxSizeFloat) * 1024 * 1024
|
|
||||||
}
|
|
||||||
|
|
||||||
if patterns, ok := config["cache_patterns"].([]interface{}); ok {
|
|
||||||
for _, p := range patterns {
|
|
||||||
if patternCfg, ok := p.(map[string]interface{}); ok {
|
|
||||||
pattern := &cachePattern{
|
|
||||||
ttl: ext.defaultTTL,
|
|
||||||
methods: []string{"GET"},
|
|
||||||
}
|
|
||||||
|
|
||||||
if patternStr, ok := patternCfg["pattern"].(string); ok {
|
|
||||||
re, err := regexp.Compile(patternStr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Invalid cache pattern", "pattern", patternStr, "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pattern.pattern = re
|
|
||||||
}
|
|
||||||
|
|
||||||
if ttl, ok := patternCfg["ttl"].(string); ok {
|
|
||||||
if duration, err := time.ParseDuration(ttl); err == nil {
|
|
||||||
pattern.ttl = duration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if methods, ok := patternCfg["methods"].([]interface{}); ok {
|
|
||||||
pattern.methods = make([]string, 0)
|
|
||||||
for _, m := range methods {
|
|
||||||
if method, ok := m.(string); ok {
|
|
||||||
pattern.methods = append(pattern.methods, strings.ToUpper(method))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ext.cachePatterns = append(ext.cachePatterns, pattern)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
go ext.cleanupLoop()
|
|
||||||
|
|
||||||
return ext, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
|
||||||
if !e.shouldCache(r) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
key := e.cacheKey(r)
|
|
||||||
|
|
||||||
e.mu.RLock()
|
|
||||||
entry, exists := e.cache[key]
|
|
||||||
e.mu.RUnlock()
|
|
||||||
|
|
||||||
if exists && time.Now().Before(entry.expiresAt) {
|
|
||||||
e.mu.Lock()
|
|
||||||
e.hits++
|
|
||||||
e.mu.Unlock()
|
|
||||||
|
|
||||||
// Mark as cache hit to prevent setting X-Cache: MISS
|
|
||||||
// Try to find cachingResponseWriter in the wrapper chain
|
|
||||||
setCacheHitFlag(w)
|
|
||||||
|
|
||||||
for k, values := range entry.headers {
|
|
||||||
for _, v := range values {
|
|
||||||
w.Header().Add(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.Header().Set("X-Cache", "HIT")
|
|
||||||
w.Header().Set("Content-Type", entry.contentType)
|
|
||||||
w.WriteHeader(entry.statusCode)
|
|
||||||
w.Write(entry.body)
|
|
||||||
|
|
||||||
e.logger.Debug("Cache hit", "key", key, "path", r.URL.Path)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
e.mu.Lock()
|
|
||||||
e.misses++
|
|
||||||
e.mu.Unlock()
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setCacheHitFlag tries to find cachingResponseWriter and set cache hit flag
|
|
||||||
func setCacheHitFlag(w http.ResponseWriter) {
|
|
||||||
// Direct match
|
|
||||||
if cw, ok := w.(*cachingResponseWriter); ok {
|
|
||||||
cw.SetCacheHit()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try unwrapping
|
|
||||||
type unwrapper interface {
|
|
||||||
Unwrap() http.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
if u, ok := w.(unwrapper); ok {
|
|
||||||
w = u.Unwrap()
|
|
||||||
if cw, ok := w.(*cachingResponseWriter); ok {
|
|
||||||
cw.SetCacheHit()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessResponse caches the response if applicable
|
|
||||||
func (e *CachingExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Response caching is handled by the CachingResponseWriter
|
|
||||||
// X-Cache header is set in the cachingResponseWriter.WriteHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
// WrapResponseWriter wraps the response writer to capture the response for caching
|
|
||||||
func (e *CachingExtension) WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter {
|
|
||||||
if !e.shouldCache(r) {
|
|
||||||
return w
|
|
||||||
}
|
|
||||||
|
|
||||||
return &cachingResponseWriter{
|
|
||||||
ResponseWriter: w,
|
|
||||||
ext: e,
|
|
||||||
request: r,
|
|
||||||
buffer: &bytes.Buffer{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type cachingResponseWriter struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
ext *CachingExtension
|
|
||||||
request *http.Request
|
|
||||||
buffer *bytes.Buffer
|
|
||||||
statusCode int
|
|
||||||
wroteHeader bool
|
|
||||||
cacheHit bool // Flag to indicate if this was a cache hit
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *cachingResponseWriter) WriteHeader(code int) {
|
|
||||||
if !cw.wroteHeader {
|
|
||||||
cw.statusCode = code
|
|
||||||
cw.wroteHeader = true
|
|
||||||
// Set X-Cache: MISS header before writing headers (only if not a cache hit)
|
|
||||||
if !cw.cacheHit {
|
|
||||||
cw.ResponseWriter.Header().Set("X-Cache", "MISS")
|
|
||||||
}
|
|
||||||
cw.ResponseWriter.WriteHeader(code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCacheHit marks this response as a cache hit (to avoid setting X-Cache: MISS)
|
|
||||||
func (cw *cachingResponseWriter) SetCacheHit() {
|
|
||||||
cw.cacheHit = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *cachingResponseWriter) Write(b []byte) (int, error) {
|
|
||||||
if !cw.wroteHeader {
|
|
||||||
cw.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
cw.buffer.Write(b)
|
|
||||||
|
|
||||||
return cw.ResponseWriter.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *cachingResponseWriter) Finalize() {
|
|
||||||
if cw.statusCode < 200 || cw.statusCode >= 400 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
body := cw.buffer.Bytes()
|
|
||||||
if len(body) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
key := cw.ext.cacheKey(cw.request)
|
|
||||||
ttl := cw.ext.getTTL(cw.request)
|
|
||||||
|
|
||||||
entry := &cacheEntry{
|
|
||||||
key: key,
|
|
||||||
body: body,
|
|
||||||
headers: cw.Header().Clone(),
|
|
||||||
statusCode: cw.statusCode,
|
|
||||||
contentType: cw.Header().Get("Content-Type"),
|
|
||||||
createdAt: time.Now(),
|
|
||||||
expiresAt: time.Now().Add(ttl),
|
|
||||||
size: len(body),
|
|
||||||
}
|
|
||||||
|
|
||||||
cw.ext.store(entry)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CachingExtension) shouldCache(r *http.Request) bool {
|
|
||||||
path := r.URL.Path
|
|
||||||
method := r.Method
|
|
||||||
|
|
||||||
for _, pattern := range e.cachePatterns {
|
|
||||||
if pattern.pattern.MatchString(path) {
|
|
||||||
for _, m := range pattern.methods {
|
|
||||||
if m == method {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default: only cache GET requests
|
|
||||||
return method == "GET" && len(e.cachePatterns) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CachingExtension) cacheKey(r *http.Request) string {
|
|
||||||
// Create cache key from method + URL + relevant headers
|
|
||||||
h := sha256.New()
|
|
||||||
h.Write([]byte(r.Method))
|
|
||||||
h.Write([]byte(r.URL.String()))
|
|
||||||
|
|
||||||
// Include Accept-Encoding for vary
|
|
||||||
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
|
|
||||||
h.Write([]byte(ae))
|
|
||||||
}
|
|
||||||
|
|
||||||
return hex.EncodeToString(h.Sum(nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CachingExtension) getTTL(r *http.Request) time.Duration {
|
|
||||||
path := r.URL.Path
|
|
||||||
|
|
||||||
for _, pattern := range e.cachePatterns {
|
|
||||||
if pattern.pattern.MatchString(path) {
|
|
||||||
return pattern.ttl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return e.defaultTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CachingExtension) store(entry *cacheEntry) {
|
|
||||||
e.mu.Lock()
|
|
||||||
defer e.mu.Unlock()
|
|
||||||
|
|
||||||
// Evict old entries if needed
|
|
||||||
for e.currentSize+entry.size > e.maxSize && len(e.cache) > 0 {
|
|
||||||
e.evictOldest()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store new entry
|
|
||||||
if existing, ok := e.cache[entry.key]; ok {
|
|
||||||
e.currentSize -= existing.size
|
|
||||||
}
|
|
||||||
|
|
||||||
e.cache[entry.key] = entry
|
|
||||||
e.currentSize += entry.size
|
|
||||||
|
|
||||||
e.logger.Debug("Cached response",
|
|
||||||
"key", entry.key[:16],
|
|
||||||
"size", entry.size,
|
|
||||||
"ttl", entry.expiresAt.Sub(entry.createdAt).String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CachingExtension) evictOldest() {
|
|
||||||
var oldestKey string
|
|
||||||
var oldestTime time.Time
|
|
||||||
|
|
||||||
for key, entry := range e.cache {
|
|
||||||
if oldestKey == "" || entry.createdAt.Before(oldestTime) {
|
|
||||||
oldestKey = key
|
|
||||||
oldestTime = entry.createdAt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if oldestKey != "" {
|
|
||||||
entry := e.cache[oldestKey]
|
|
||||||
e.currentSize -= entry.size
|
|
||||||
delete(e.cache, oldestKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CachingExtension) cleanupLoop() {
|
|
||||||
ticker := time.NewTicker(time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for range ticker.C {
|
|
||||||
e.cleanupExpired()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CachingExtension) cleanupExpired() {
|
|
||||||
e.mu.Lock()
|
|
||||||
defer e.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
for key, entry := range e.cache {
|
|
||||||
if now.After(entry.expiresAt) {
|
|
||||||
e.currentSize -= entry.size
|
|
||||||
delete(e.cache, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *CachingExtension) Invalidate(key string) {
|
|
||||||
e.mu.Lock()
|
|
||||||
defer e.mu.Unlock()
|
|
||||||
|
|
||||||
if entry, ok := e.cache[key]; ok {
|
|
||||||
e.currentSize -= entry.size
|
|
||||||
delete(e.cache, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvalidatePattern removes all entries matching a pattern, unlike Invalidate
|
|
||||||
func (e *CachingExtension) InvalidatePattern(pattern string) error {
|
|
||||||
re, err := regexp.Compile(pattern)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
e.mu.Lock()
|
|
||||||
defer e.mu.Unlock()
|
|
||||||
|
|
||||||
for key, entry := range e.cache {
|
|
||||||
if re.MatchString(key) {
|
|
||||||
e.currentSize -= entry.size
|
|
||||||
delete(e.cache, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear removes all entries from cache
|
|
||||||
func (e *CachingExtension) Clear() {
|
|
||||||
e.mu.Lock()
|
|
||||||
defer e.mu.Unlock()
|
|
||||||
|
|
||||||
e.cache = make(map[string]*cacheEntry)
|
|
||||||
e.currentSize = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMetrics returns caching metrics
|
|
||||||
func (e *CachingExtension) GetMetrics() map[string]interface{} {
|
|
||||||
e.mu.RLock()
|
|
||||||
defer e.mu.RUnlock()
|
|
||||||
|
|
||||||
hitRate := float64(0)
|
|
||||||
total := e.hits + e.misses
|
|
||||||
if total > 0 {
|
|
||||||
hitRate = float64(e.hits) / float64(total) * 100
|
|
||||||
}
|
|
||||||
|
|
||||||
return map[string]interface{}{
|
|
||||||
"entries": len(e.cache),
|
|
||||||
"size_bytes": e.currentSize,
|
|
||||||
"max_size": e.maxSize,
|
|
||||||
"hits": e.hits,
|
|
||||||
"misses": e.misses,
|
|
||||||
"hit_rate": hitRate,
|
|
||||||
"patterns": len(e.cachePatterns),
|
|
||||||
"default_ttl": e.defaultTTL.String(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup stops the cleanup goroutine
|
|
||||||
func (e *CachingExtension) Cleanup() error {
|
|
||||||
e.Clear()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheReader wraps an io.ReadCloser to cache the body
|
|
||||||
type CacheReader struct {
|
|
||||||
io.ReadCloser
|
|
||||||
buffer *bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cr *CacheReader) Read(p []byte) (int, error) {
|
|
||||||
n, err := cr.ReadCloser.Read(p)
|
|
||||||
if n > 0 {
|
|
||||||
cr.buffer.Write(p[:n])
|
|
||||||
}
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cr *CacheReader) GetBody() []byte {
|
|
||||||
return cr.buffer.Bytes()
|
|
||||||
}
|
|
||||||
@ -1,123 +0,0 @@
|
|||||||
package extension
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Extension is the interface that all extensions must implement
|
|
||||||
type Extension interface {
|
|
||||||
// Name returns the unique name of the extension
|
|
||||||
Name() string
|
|
||||||
|
|
||||||
// Initialize is called when the extension is loaded
|
|
||||||
Initialize() error
|
|
||||||
|
|
||||||
// ProcessRequest processes an incoming request before routing.
|
|
||||||
// Returns:
|
|
||||||
// - response: if non-nil, the request is handled and no further processing occurs
|
|
||||||
// - handled: if true, the request was handled by this extension
|
|
||||||
// - err: any error that occurred
|
|
||||||
ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (handled bool, err error)
|
|
||||||
|
|
||||||
// ProcessResponse is called after the response is generated but before it's sent.
|
|
||||||
// Extensions can modify the response here.
|
|
||||||
ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request)
|
|
||||||
|
|
||||||
// Cleanup is called when the extension is being unloaded
|
|
||||||
Cleanup() error
|
|
||||||
|
|
||||||
// Enabled returns whether the extension is currently enabled
|
|
||||||
Enabled() bool
|
|
||||||
|
|
||||||
// SetEnabled enables or disables the extension
|
|
||||||
SetEnabled(enabled bool)
|
|
||||||
|
|
||||||
// Priority returns the extension's priority (lower = earlier execution)
|
|
||||||
Priority() int
|
|
||||||
}
|
|
||||||
|
|
||||||
// BaseExtension provides a default implementation for common Extension methods
|
|
||||||
type BaseExtension struct {
|
|
||||||
name string
|
|
||||||
enabled bool
|
|
||||||
priority int
|
|
||||||
logger *logging.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBaseExtension creates a new BaseExtension
|
|
||||||
func NewBaseExtension(name string, priority int, logger *logging.Logger) BaseExtension {
|
|
||||||
return BaseExtension{
|
|
||||||
name: name,
|
|
||||||
enabled: true,
|
|
||||||
priority: priority,
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name returns the extension name
|
|
||||||
func (b *BaseExtension) Name() string {
|
|
||||||
return b.name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enabled returns whether the extension is enabled
|
|
||||||
func (b *BaseExtension) Enabled() bool {
|
|
||||||
return b.enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetEnabled sets the enabled state
|
|
||||||
func (b *BaseExtension) SetEnabled(enabled bool) {
|
|
||||||
b.enabled = enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// Priority returns the extension priority
|
|
||||||
func (b *BaseExtension) Priority() int {
|
|
||||||
return b.priority
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize default implementation (no-op)
|
|
||||||
func (b *BaseExtension) Initialize() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup default implementation (no-op)
|
|
||||||
func (b *BaseExtension) Cleanup() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessRequest default implementation (pass-through)
|
|
||||||
func (b *BaseExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessResponse default implementation (no-op)
|
|
||||||
func (b *BaseExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logger returns the extension's logger
|
|
||||||
func (b *BaseExtension) Logger() *logging.Logger {
|
|
||||||
return b.logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtensionConfig holds configuration for creating extensions
|
|
||||||
type ExtensionConfig struct {
|
|
||||||
Type string
|
|
||||||
Config map[string]interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtensionFactory is a function that creates an extension from config
|
|
||||||
type ExtensionFactory func(config map[string]interface{}, logger *logging.Logger) (Extension, error)
|
|
||||||
|
|
||||||
// ResponseWriterWrapper is an optional interface that extensions can implement
|
|
||||||
// to wrap the response writer for capturing/modifying responses
|
|
||||||
type ResponseWriterWrapper interface {
|
|
||||||
WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResponseFinalizer is an optional interface for response writers that need
|
|
||||||
// to perform finalization after the response is written (e.g., caching)
|
|
||||||
type ResponseFinalizer interface {
|
|
||||||
Finalize()
|
|
||||||
}
|
|
||||||
@ -1,271 +0,0 @@
|
|||||||
package extension
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"sort"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Manager manages all loaded extensions
|
|
||||||
type Manager struct {
|
|
||||||
extensions []Extension
|
|
||||||
registry map[string]ExtensionFactory
|
|
||||||
logger *logging.Logger
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManager creates a new extension manager
|
|
||||||
func NewManager(logger *logging.Logger) *Manager {
|
|
||||||
m := &Manager{
|
|
||||||
extensions: make([]Extension, 0),
|
|
||||||
registry: make(map[string]ExtensionFactory),
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register built-in extensions
|
|
||||||
m.RegisterFactory("routing", NewRoutingExtension)
|
|
||||||
m.RegisterFactory("security", NewSecurityExtension)
|
|
||||||
m.RegisterFactory("caching", NewCachingExtension)
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterFactory registers an extension factory
|
|
||||||
func (m *Manager) RegisterFactory(name string, factory ExtensionFactory) {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
m.registry[name] = factory
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadExtension loads an extension by type and config
|
|
||||||
func (m *Manager) LoadExtension(extType string, config map[string]interface{}) error {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
factory, ok := m.registry[extType]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unknown extension type: %s", extType)
|
|
||||||
}
|
|
||||||
|
|
||||||
ext, err := factory(config, m.logger)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create extension %s: %w", extType, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ext.Initialize(); err != nil {
|
|
||||||
return fmt.Errorf("failed to initialize extension %s: %w", extType, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.extensions = append(m.extensions, ext)
|
|
||||||
|
|
||||||
// Sort by priority (lower first)
|
|
||||||
sort.Slice(m.extensions, func(i, j int) bool {
|
|
||||||
return m.extensions[i].Priority() < m.extensions[j].Priority()
|
|
||||||
})
|
|
||||||
|
|
||||||
m.logger.Info("Loaded extension", "type", extType, "name", ext.Name(), "priority", ext.Priority())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddExtension adds a pre-created extension
|
|
||||||
func (m *Manager) AddExtension(ext Extension) error {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if err := ext.Initialize(); err != nil {
|
|
||||||
return fmt.Errorf("failed to initialize extension %s: %w", ext.Name(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.extensions = append(m.extensions, ext)
|
|
||||||
|
|
||||||
// Sort by priority
|
|
||||||
sort.Slice(m.extensions, func(i, j int) bool {
|
|
||||||
return m.extensions[i].Priority() < m.extensions[j].Priority()
|
|
||||||
})
|
|
||||||
|
|
||||||
m.logger.Info("Added extension", "name", ext.Name(), "priority", ext.Priority())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessRequest runs all extensions' ProcessRequest in order
|
|
||||||
// Returns true if any extension handled the request
|
|
||||||
func (m *Manager) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
|
||||||
m.mu.RLock()
|
|
||||||
extensions := m.extensions
|
|
||||||
m.mu.RUnlock()
|
|
||||||
|
|
||||||
for _, ext := range extensions {
|
|
||||||
if !ext.Enabled() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
handled, err := ext.ProcessRequest(ctx, w, r)
|
|
||||||
if err != nil {
|
|
||||||
m.logger.Error("Extension error", "extension", ext.Name(), "error", err)
|
|
||||||
// Continue to next extension on error
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if handled {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessResponse runs all extensions' ProcessResponse in reverse order
|
|
||||||
func (m *Manager) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
||||||
m.mu.RLock()
|
|
||||||
extensions := m.extensions
|
|
||||||
m.mu.RUnlock()
|
|
||||||
|
|
||||||
// Process in reverse order for response
|
|
||||||
for i := len(extensions) - 1; i >= 0; i-- {
|
|
||||||
ext := extensions[i]
|
|
||||||
if !ext.Enabled() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ext.ProcessResponse(ctx, w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup cleans up all extensions
|
|
||||||
func (m *Manager) Cleanup() {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
for _, ext := range m.extensions {
|
|
||||||
if err := ext.Cleanup(); err != nil {
|
|
||||||
m.logger.Error("Extension cleanup error", "extension", ext.Name(), "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.extensions = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetExtension returns an extension by name
|
|
||||||
func (m *Manager) GetExtension(name string) Extension {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
|
|
||||||
for _, ext := range m.extensions {
|
|
||||||
if ext.Name() == name {
|
|
||||||
return ext
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extensions returns all loaded extensions
|
|
||||||
func (m *Manager) Extensions() []Extension {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
|
|
||||||
result := make([]Extension, len(m.extensions))
|
|
||||||
copy(result, m.extensions)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handler returns an http.Handler that processes requests through all extensions
|
|
||||||
func (m *Manager) Handler(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := r.Context()
|
|
||||||
|
|
||||||
// Wrap response writer through all extensions that support it
|
|
||||||
// Process in reverse priority order so highest priority wrapper is outermost
|
|
||||||
wrappedWriter := w
|
|
||||||
var finalizers []ResponseFinalizer
|
|
||||||
|
|
||||||
m.mu.RLock()
|
|
||||||
extensions := m.extensions
|
|
||||||
m.mu.RUnlock()
|
|
||||||
|
|
||||||
// Wrap response writer (lowest priority first, so they wrap in correct order)
|
|
||||||
for _, ext := range extensions {
|
|
||||||
if !ext.Enabled() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if wrapper, ok := ext.(ResponseWriterWrapper); ok {
|
|
||||||
wrappedWriter = wrapper.WrapResponseWriter(wrappedWriter, r)
|
|
||||||
// Check if the wrapped writer implements Finalizer
|
|
||||||
if finalizer, ok := wrappedWriter.(ResponseFinalizer); ok {
|
|
||||||
finalizers = append(finalizers, finalizer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create response wrapper to capture status code
|
|
||||||
responseWrapper := newResponseWrapper(wrappedWriter)
|
|
||||||
|
|
||||||
// Process request through extensions
|
|
||||||
handled, err := m.ProcessRequest(ctx, responseWrapper, r)
|
|
||||||
if err != nil {
|
|
||||||
m.logger.Error("Error processing request", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if handled {
|
|
||||||
// Extension handled the request, process response
|
|
||||||
m.ProcessResponse(ctx, responseWrapper, r)
|
|
||||||
// Finalize all response writers
|
|
||||||
for i := len(finalizers) - 1; i >= 0; i-- {
|
|
||||||
finalizers[i].Finalize()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// No extension handled, pass to next handler
|
|
||||||
next.ServeHTTP(responseWrapper, r)
|
|
||||||
|
|
||||||
// Process response
|
|
||||||
m.ProcessResponse(ctx, responseWrapper, r)
|
|
||||||
|
|
||||||
// Finalize all response writers
|
|
||||||
for i := len(finalizers) - 1; i >= 0; i-- {
|
|
||||||
finalizers[i].Finalize()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// responseWrapper wraps http.ResponseWriter to allow response modification
|
|
||||||
type responseWrapper struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
statusCode int
|
|
||||||
written bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newResponseWrapper(w http.ResponseWriter) *responseWrapper {
|
|
||||||
return &responseWrapper{
|
|
||||||
ResponseWriter: w,
|
|
||||||
statusCode: http.StatusOK,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *responseWrapper) WriteHeader(code int) {
|
|
||||||
if !rw.written {
|
|
||||||
rw.statusCode = code
|
|
||||||
rw.ResponseWriter.WriteHeader(code)
|
|
||||||
rw.written = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *responseWrapper) Write(b []byte) (int, error) {
|
|
||||||
if !rw.written {
|
|
||||||
rw.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
return rw.ResponseWriter.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *responseWrapper) StatusCode() int {
|
|
||||||
return rw.statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unwrap returns the underlying ResponseWriter (for type assertions)
|
|
||||||
func (rw *responseWrapper) Unwrap() http.ResponseWriter {
|
|
||||||
return rw.ResponseWriter
|
|
||||||
}
|
|
||||||
@ -1,176 +0,0 @@
|
|||||||
package extension
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestLogger() *logging.Logger {
|
|
||||||
logger, _ := logging.New(logging.Config{Level: "DEBUG"})
|
|
||||||
return logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewManager(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
manager := NewManager(logger)
|
|
||||||
|
|
||||||
if manager == nil {
|
|
||||||
t.Fatal("Expected manager, got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check built-in factories are registered
|
|
||||||
if _, ok := manager.registry["routing"]; !ok {
|
|
||||||
t.Error("Expected routing factory to be registered")
|
|
||||||
}
|
|
||||||
if _, ok := manager.registry["security"]; !ok {
|
|
||||||
t.Error("Expected security factory to be registered")
|
|
||||||
}
|
|
||||||
if _, ok := manager.registry["caching"]; !ok {
|
|
||||||
t.Error("Expected caching factory to be registered")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_LoadExtension(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
manager := NewManager(logger)
|
|
||||||
|
|
||||||
err := manager.LoadExtension("security", map[string]interface{}{})
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Failed to load security extension: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
exts := manager.Extensions()
|
|
||||||
if len(exts) != 1 {
|
|
||||||
t.Errorf("Expected 1 extension, got %d", len(exts))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_LoadExtension_Unknown(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
manager := NewManager(logger)
|
|
||||||
|
|
||||||
err := manager.LoadExtension("unknown", map[string]interface{}{})
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error for unknown extension type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_GetExtension(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
manager := NewManager(logger)
|
|
||||||
|
|
||||||
manager.LoadExtension("security", map[string]interface{}{})
|
|
||||||
|
|
||||||
ext := manager.GetExtension("security")
|
|
||||||
if ext == nil {
|
|
||||||
t.Error("Expected to find security extension")
|
|
||||||
}
|
|
||||||
|
|
||||||
ext = manager.GetExtension("nonexistent")
|
|
||||||
if ext != nil {
|
|
||||||
t.Error("Expected nil for nonexistent extension")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_ProcessRequest(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
manager := NewManager(logger)
|
|
||||||
|
|
||||||
// Load security extension with blocked IP
|
|
||||||
manager.LoadExtension("security", map[string]interface{}{
|
|
||||||
"blocked_ips": []interface{}{"192.168.1.1"},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create test request
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
req.RemoteAddr = "192.168.1.1:12345"
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handled, err := manager.ProcessRequest(context.Background(), rr, req)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !handled {
|
|
||||||
t.Error("Expected request to be handled (blocked)")
|
|
||||||
}
|
|
||||||
|
|
||||||
if rr.Code != http.StatusForbidden {
|
|
||||||
t.Errorf("Expected status 403, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_Handler(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
manager := NewManager(logger)
|
|
||||||
|
|
||||||
// Load routing extension with a simple route
|
|
||||||
manager.LoadExtension("routing", map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"=/health": map[string]interface{}{
|
|
||||||
"return": "200 OK",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
})
|
|
||||||
|
|
||||||
handler := manager.Handler(baseHandler)
|
|
||||||
|
|
||||||
// Test health route
|
|
||||||
req := httptest.NewRequest("GET", "/health", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_Priority(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
manager := NewManager(logger)
|
|
||||||
|
|
||||||
// Load extensions in any order
|
|
||||||
manager.LoadExtension("routing", map[string]interface{}{}) // Priority 50
|
|
||||||
manager.LoadExtension("security", map[string]interface{}{}) // Priority 10
|
|
||||||
manager.LoadExtension("caching", map[string]interface{}{}) // Priority 20
|
|
||||||
|
|
||||||
exts := manager.Extensions()
|
|
||||||
if len(exts) != 3 {
|
|
||||||
t.Fatalf("Expected 3 extensions, got %d", len(exts))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check order by priority
|
|
||||||
if exts[0].Name() != "security" {
|
|
||||||
t.Errorf("Expected security first, got %s", exts[0].Name())
|
|
||||||
}
|
|
||||||
if exts[1].Name() != "caching" {
|
|
||||||
t.Errorf("Expected caching second, got %s", exts[1].Name())
|
|
||||||
}
|
|
||||||
if exts[2].Name() != "routing" {
|
|
||||||
t.Errorf("Expected routing third, got %s", exts[2].Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_Cleanup(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
manager := NewManager(logger)
|
|
||||||
|
|
||||||
manager.LoadExtension("security", map[string]interface{}{})
|
|
||||||
manager.LoadExtension("routing", map[string]interface{}{})
|
|
||||||
|
|
||||||
manager.Cleanup()
|
|
||||||
|
|
||||||
exts := manager.Extensions()
|
|
||||||
if len(exts) != 0 {
|
|
||||||
t.Errorf("Expected 0 extensions after cleanup, got %d", len(exts))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,428 +0,0 @@
|
|||||||
package extension
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
"github.com/konduktor/konduktor/internal/proxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RoutingExtension handles request routing based on patterns
|
|
||||||
type RoutingExtension struct {
|
|
||||||
BaseExtension
|
|
||||||
exactRoutes map[string]RouteConfig
|
|
||||||
regexRoutes []*regexRoute
|
|
||||||
defaultRoute *RouteConfig
|
|
||||||
staticDir string
|
|
||||||
}
|
|
||||||
|
|
||||||
// RouteConfig holds configuration for a route
|
|
||||||
type RouteConfig struct {
|
|
||||||
ProxyPass string
|
|
||||||
Root string
|
|
||||||
IndexFile string
|
|
||||||
Return string
|
|
||||||
ContentType string
|
|
||||||
Headers []string
|
|
||||||
CacheControl string
|
|
||||||
SPAFallback bool
|
|
||||||
ExcludePatterns []string
|
|
||||||
Timeout float64
|
|
||||||
}
|
|
||||||
|
|
||||||
type regexRoute struct {
|
|
||||||
pattern *regexp.Regexp
|
|
||||||
config RouteConfig
|
|
||||||
caseSensitive bool
|
|
||||||
originalExpr string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRoutingExtension creates a new routing extension
|
|
||||||
func NewRoutingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
|
|
||||||
ext := &RoutingExtension{
|
|
||||||
BaseExtension: NewBaseExtension("routing", 50, logger), // Middle priority
|
|
||||||
exactRoutes: make(map[string]RouteConfig),
|
|
||||||
regexRoutes: make([]*regexRoute, 0),
|
|
||||||
staticDir: "./static",
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Routing extension config", "config", config)
|
|
||||||
|
|
||||||
// Parse regex_locations from config
|
|
||||||
if locations, ok := config["regex_locations"].(map[string]interface{}); ok {
|
|
||||||
logger.Debug("Found regex_locations", "count", len(locations))
|
|
||||||
for pattern, routeCfg := range locations {
|
|
||||||
logger.Debug("Adding route", "pattern", pattern)
|
|
||||||
if rc, ok := routeCfg.(map[string]interface{}); ok {
|
|
||||||
ext.addRoute(pattern, parseRouteConfig(rc))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Warn("No regex_locations found in config", "config_keys", getKeys(config))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse static_dir if provided
|
|
||||||
if staticDir, ok := config["static_dir"].(string); ok {
|
|
||||||
ext.staticDir = staticDir
|
|
||||||
}
|
|
||||||
|
|
||||||
return ext, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseRouteConfig(cfg map[string]interface{}) RouteConfig {
|
|
||||||
rc := RouteConfig{
|
|
||||||
IndexFile: "index.html",
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := cfg["proxy_pass"].(string); ok {
|
|
||||||
rc.ProxyPass = v
|
|
||||||
}
|
|
||||||
if v, ok := cfg["root"].(string); ok {
|
|
||||||
rc.Root = v
|
|
||||||
}
|
|
||||||
if v, ok := cfg["index_file"].(string); ok {
|
|
||||||
rc.IndexFile = v
|
|
||||||
}
|
|
||||||
if v, ok := cfg["return"].(string); ok {
|
|
||||||
rc.Return = v
|
|
||||||
}
|
|
||||||
if v, ok := cfg["content_type"].(string); ok {
|
|
||||||
rc.ContentType = v
|
|
||||||
}
|
|
||||||
if v, ok := cfg["cache_control"].(string); ok {
|
|
||||||
rc.CacheControl = v
|
|
||||||
}
|
|
||||||
if v, ok := cfg["spa_fallback"].(bool); ok {
|
|
||||||
rc.SPAFallback = v
|
|
||||||
}
|
|
||||||
if v, ok := cfg["timeout"].(float64); ok {
|
|
||||||
rc.Timeout = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse headers
|
|
||||||
if headers, ok := cfg["headers"].([]interface{}); ok {
|
|
||||||
for _, h := range headers {
|
|
||||||
if header, ok := h.(string); ok {
|
|
||||||
rc.Headers = append(rc.Headers, header)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse exclude_patterns
|
|
||||||
if patterns, ok := cfg["exclude_patterns"].([]interface{}); ok {
|
|
||||||
for _, p := range patterns {
|
|
||||||
if pattern, ok := p.(string); ok {
|
|
||||||
rc.ExcludePatterns = append(rc.ExcludePatterns, pattern)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return rc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *RoutingExtension) addRoute(pattern string, config RouteConfig) {
|
|
||||||
switch {
|
|
||||||
case pattern == "__default__":
|
|
||||||
e.defaultRoute = &config
|
|
||||||
|
|
||||||
case strings.HasPrefix(pattern, "="):
|
|
||||||
// Exact match
|
|
||||||
path := strings.TrimPrefix(pattern, "=")
|
|
||||||
e.exactRoutes[path] = config
|
|
||||||
|
|
||||||
case strings.HasPrefix(pattern, "~*"):
|
|
||||||
// Case-insensitive regex
|
|
||||||
expr := strings.TrimPrefix(pattern, "~*")
|
|
||||||
re, err := regexp.Compile("(?i)" + expr)
|
|
||||||
if err != nil {
|
|
||||||
e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
e.regexRoutes = append(e.regexRoutes, ®exRoute{
|
|
||||||
pattern: re,
|
|
||||||
config: config,
|
|
||||||
caseSensitive: false,
|
|
||||||
originalExpr: expr,
|
|
||||||
})
|
|
||||||
|
|
||||||
case strings.HasPrefix(pattern, "~"):
|
|
||||||
// Case-sensitive regex
|
|
||||||
expr := strings.TrimPrefix(pattern, "~")
|
|
||||||
re, err := regexp.Compile(expr)
|
|
||||||
if err != nil {
|
|
||||||
e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
e.regexRoutes = append(e.regexRoutes, ®exRoute{
|
|
||||||
pattern: re,
|
|
||||||
config: config,
|
|
||||||
caseSensitive: true,
|
|
||||||
originalExpr: expr,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessRequest handles the request routing
|
|
||||||
func (e *RoutingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
|
||||||
path := r.URL.Path
|
|
||||||
|
|
||||||
// 1. Check exact routes (ignore request path for proxy)
|
|
||||||
if config, ok := e.exactRoutes[path]; ok {
|
|
||||||
return e.handleRoute(w, r, config, nil, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Check regex routes
|
|
||||||
for _, route := range e.regexRoutes {
|
|
||||||
match := route.pattern.FindStringSubmatch(path)
|
|
||||||
if match != nil {
|
|
||||||
params := make(map[string]string)
|
|
||||||
names := route.pattern.SubexpNames()
|
|
||||||
for i, name := range names {
|
|
||||||
if i > 0 && name != "" && i < len(match) {
|
|
||||||
params[name] = match[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return e.handleRoute(w, r, route.config, params, false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Check default route
|
|
||||||
if e.defaultRoute != nil {
|
|
||||||
return e.handleRoute(w, r, *e.defaultRoute, nil, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *RoutingExtension) handleRoute(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
|
|
||||||
// Handle "return" directive
|
|
||||||
if config.Return != "" {
|
|
||||||
return e.handleReturn(w, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle proxy_pass
|
|
||||||
if config.ProxyPass != "" {
|
|
||||||
return e.handleProxy(w, r, config, params, exactMatch)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle static files with root
|
|
||||||
if config.Root != "" {
|
|
||||||
return e.handleStatic(w, r, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle SPA fallback
|
|
||||||
if config.SPAFallback {
|
|
||||||
return e.handleSPAFallback(w, r, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *RoutingExtension) handleReturn(w http.ResponseWriter, config RouteConfig) (bool, error) {
|
|
||||||
parts := strings.SplitN(config.Return, " ", 2)
|
|
||||||
statusCode := 200
|
|
||||||
body := "OK"
|
|
||||||
|
|
||||||
if len(parts) >= 1 {
|
|
||||||
switch parts[0] {
|
|
||||||
case "200":
|
|
||||||
statusCode = 200
|
|
||||||
case "201":
|
|
||||||
statusCode = 201
|
|
||||||
case "301":
|
|
||||||
statusCode = 301
|
|
||||||
case "302":
|
|
||||||
statusCode = 302
|
|
||||||
case "400":
|
|
||||||
statusCode = 400
|
|
||||||
case "404":
|
|
||||||
statusCode = 404
|
|
||||||
case "500":
|
|
||||||
statusCode = 500
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(parts) >= 2 {
|
|
||||||
body = parts[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
contentType := "text/plain"
|
|
||||||
if config.ContentType != "" {
|
|
||||||
contentType = config.ContentType
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", contentType)
|
|
||||||
w.WriteHeader(statusCode)
|
|
||||||
w.Write([]byte(body))
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *RoutingExtension) handleProxy(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
|
|
||||||
target := config.ProxyPass
|
|
||||||
|
|
||||||
// Check if target URL contains parameter placeholders
|
|
||||||
hasParams := strings.Contains(target, "{") && strings.Contains(target, "}")
|
|
||||||
|
|
||||||
// Substitute params in target URL
|
|
||||||
for key, value := range params {
|
|
||||||
target = strings.ReplaceAll(target, "{"+key+"}", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create proxy config
|
|
||||||
// IgnoreRequestPath=true when:
|
|
||||||
// - exact match route (=/path)
|
|
||||||
// - target URL had parameter substitutions (the target path is fully specified)
|
|
||||||
proxyConfig := &proxy.Config{
|
|
||||||
Target: target,
|
|
||||||
Headers: make(map[string]string),
|
|
||||||
IgnoreRequestPath: exactMatch || hasParams,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set timeout if specified
|
|
||||||
if config.Timeout > 0 {
|
|
||||||
proxyConfig.Timeout = time.Duration(config.Timeout * float64(time.Second))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse headers
|
|
||||||
clientIP := getClientIP(r)
|
|
||||||
for _, header := range config.Headers {
|
|
||||||
parts := strings.SplitN(header, ": ", 2)
|
|
||||||
if len(parts) == 2 {
|
|
||||||
value := parts[1]
|
|
||||||
// Substitute params
|
|
||||||
for key, pValue := range params {
|
|
||||||
value = strings.ReplaceAll(value, "{"+key+"}", pValue)
|
|
||||||
}
|
|
||||||
// Substitute special variables
|
|
||||||
value = strings.ReplaceAll(value, "$remote_addr", clientIP)
|
|
||||||
proxyConfig.Headers[parts[0]] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p, err := proxy.New(proxyConfig, e.logger)
|
|
||||||
if err != nil {
|
|
||||||
e.logger.Error("Failed to create proxy", "target", target, "error", err)
|
|
||||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
p.ProxyRequest(w, r, params)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *RoutingExtension) handleStatic(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
|
|
||||||
path := r.URL.Path
|
|
||||||
|
|
||||||
// Handle index file for root or directory paths
|
|
||||||
if path == "/" || strings.HasSuffix(path, "/") {
|
|
||||||
path = "/" + config.IndexFile
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get absolute path for root dir
|
|
||||||
absRoot, err := filepath.Abs(config.Root)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
filePath := filepath.Join(absRoot, filepath.Clean("/"+path))
|
|
||||||
cleanPath := filepath.Clean(filePath)
|
|
||||||
|
|
||||||
// Prevent directory traversal
|
|
||||||
if !strings.HasPrefix(cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator)) {
|
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if file exists
|
|
||||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
|
||||||
return false, nil // Let other handlers try
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set cache control header
|
|
||||||
if config.CacheControl != "" {
|
|
||||||
w.Header().Set("Cache-Control", config.CacheControl)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set custom headers
|
|
||||||
for _, header := range config.Headers {
|
|
||||||
parts := strings.SplitN(header, ": ", 2)
|
|
||||||
if len(parts) == 2 {
|
|
||||||
w.Header().Set(parts[0], parts[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
http.ServeFile(w, r, filePath)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *RoutingExtension) handleSPAFallback(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
|
|
||||||
path := r.URL.Path
|
|
||||||
|
|
||||||
// Check exclude patterns
|
|
||||||
for _, pattern := range config.ExcludePatterns {
|
|
||||||
if strings.HasPrefix(path, pattern) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
root := config.Root
|
|
||||||
if root == "" {
|
|
||||||
root = e.staticDir
|
|
||||||
}
|
|
||||||
|
|
||||||
indexFile := config.IndexFile
|
|
||||||
if indexFile == "" {
|
|
||||||
indexFile = "index.html"
|
|
||||||
}
|
|
||||||
|
|
||||||
filePath := filepath.Join(root, indexFile)
|
|
||||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
http.ServeFile(w, r, filePath)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getClientIP(r *http.Request) string {
|
|
||||||
// Check X-Forwarded-For header first
|
|
||||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
||||||
parts := strings.Split(xff, ",")
|
|
||||||
return strings.TrimSpace(parts[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check X-Real-IP header
|
|
||||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
|
||||||
return xri
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to RemoteAddr
|
|
||||||
ip := r.RemoteAddr
|
|
||||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
|
||||||
ip = ip[:idx]
|
|
||||||
}
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMetrics returns routing metrics
|
|
||||||
func (e *RoutingExtension) GetMetrics() map[string]interface{} {
|
|
||||||
return map[string]interface{}{
|
|
||||||
"exact_routes": len(e.exactRoutes),
|
|
||||||
"regex_routes": len(e.regexRoutes),
|
|
||||||
"has_default": e.defaultRoute != nil,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getKeys(m map[string]interface{}) []string {
|
|
||||||
keys := make([]string, 0, len(m))
|
|
||||||
for k := range m {
|
|
||||||
keys = append(keys, k)
|
|
||||||
}
|
|
||||||
return keys
|
|
||||||
}
|
|
||||||
@ -1,312 +0,0 @@
|
|||||||
package extension
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SecurityExtension provides security features like IP filtering and security headers
|
|
||||||
type SecurityExtension struct {
|
|
||||||
BaseExtension
|
|
||||||
allowedIPs map[string]bool
|
|
||||||
blockedIPs map[string]bool
|
|
||||||
allowedCIDRs []*net.IPNet
|
|
||||||
blockedCIDRs []*net.IPNet
|
|
||||||
securityHeaders map[string]string
|
|
||||||
|
|
||||||
// Rate limiting
|
|
||||||
rateLimitEnabled bool
|
|
||||||
rateLimitRequests int
|
|
||||||
rateLimitWindow time.Duration
|
|
||||||
rateLimitByIP map[string]*rateLimitEntry
|
|
||||||
rateLimitMu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
type rateLimitEntry struct {
|
|
||||||
count int
|
|
||||||
resetTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// SecurityConfig holds security extension configuration
|
|
||||||
type SecurityConfig struct {
|
|
||||||
AllowedIPs []string `yaml:"allowed_ips"`
|
|
||||||
BlockedIPs []string `yaml:"blocked_ips"`
|
|
||||||
SecurityHeaders map[string]string `yaml:"security_headers"`
|
|
||||||
RateLimit *RateLimitConfig `yaml:"rate_limit"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// RateLimitConfig holds rate limiting configuration
|
|
||||||
type RateLimitConfig struct {
|
|
||||||
Enabled bool `yaml:"enabled"`
|
|
||||||
Requests int `yaml:"requests"`
|
|
||||||
Window string `yaml:"window"` // e.g., "1m", "1h"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default security headers
|
|
||||||
var defaultSecurityHeaders = map[string]string{
|
|
||||||
"X-Content-Type-Options": "nosniff",
|
|
||||||
"X-Frame-Options": "DENY",
|
|
||||||
"X-XSS-Protection": "1; mode=block",
|
|
||||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSecurityExtension creates a new security extension
|
|
||||||
func NewSecurityExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
|
|
||||||
ext := &SecurityExtension{
|
|
||||||
BaseExtension: NewBaseExtension("security", 10, logger), // High priority (early execution)
|
|
||||||
allowedIPs: make(map[string]bool),
|
|
||||||
blockedIPs: make(map[string]bool),
|
|
||||||
allowedCIDRs: make([]*net.IPNet, 0),
|
|
||||||
blockedCIDRs: make([]*net.IPNet, 0),
|
|
||||||
securityHeaders: make(map[string]string),
|
|
||||||
rateLimitByIP: make(map[string]*rateLimitEntry),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy default security headers
|
|
||||||
for k, v := range defaultSecurityHeaders {
|
|
||||||
ext.securityHeaders[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse allowed_ips
|
|
||||||
if allowedIPs, ok := config["allowed_ips"].([]interface{}); ok {
|
|
||||||
for _, ip := range allowedIPs {
|
|
||||||
if ipStr, ok := ip.(string); ok {
|
|
||||||
ext.addAllowedIP(ipStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse blocked_ips
|
|
||||||
if blockedIPs, ok := config["blocked_ips"].([]interface{}); ok {
|
|
||||||
for _, ip := range blockedIPs {
|
|
||||||
if ipStr, ok := ip.(string); ok {
|
|
||||||
ext.addBlockedIP(ipStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse security_headers
|
|
||||||
if headers, ok := config["security_headers"].(map[string]interface{}); ok {
|
|
||||||
for k, v := range headers {
|
|
||||||
if vStr, ok := v.(string); ok {
|
|
||||||
ext.securityHeaders[k] = vStr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse rate_limit
|
|
||||||
if rateLimit, ok := config["rate_limit"].(map[string]interface{}); ok {
|
|
||||||
if enabled, ok := rateLimit["enabled"].(bool); ok && enabled {
|
|
||||||
ext.rateLimitEnabled = true
|
|
||||||
|
|
||||||
if requests, ok := rateLimit["requests"].(int); ok {
|
|
||||||
ext.rateLimitRequests = requests
|
|
||||||
} else if requestsFloat, ok := rateLimit["requests"].(float64); ok {
|
|
||||||
ext.rateLimitRequests = int(requestsFloat)
|
|
||||||
} else {
|
|
||||||
ext.rateLimitRequests = 100 // default
|
|
||||||
}
|
|
||||||
|
|
||||||
if window, ok := rateLimit["window"].(string); ok {
|
|
||||||
if duration, err := time.ParseDuration(window); err == nil {
|
|
||||||
ext.rateLimitWindow = duration
|
|
||||||
} else {
|
|
||||||
ext.rateLimitWindow = time.Minute // default
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ext.rateLimitWindow = time.Minute
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Rate limiting enabled",
|
|
||||||
"requests", ext.rateLimitRequests,
|
|
||||||
"window", ext.rateLimitWindow.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ext, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *SecurityExtension) addAllowedIP(ip string) {
|
|
||||||
if strings.Contains(ip, "/") {
|
|
||||||
// CIDR notation
|
|
||||||
_, cidr, err := net.ParseCIDR(ip)
|
|
||||||
if err == nil {
|
|
||||||
e.allowedCIDRs = append(e.allowedCIDRs, cidr)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
e.allowedIPs[ip] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *SecurityExtension) addBlockedIP(ip string) {
|
|
||||||
if strings.Contains(ip, "/") {
|
|
||||||
// CIDR notation
|
|
||||||
_, cidr, err := net.ParseCIDR(ip)
|
|
||||||
if err == nil {
|
|
||||||
e.blockedCIDRs = append(e.blockedCIDRs, cidr)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
e.blockedIPs[ip] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessRequest checks security rules
|
|
||||||
func (e *SecurityExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
|
||||||
clientIP := getClientIP(r)
|
|
||||||
parsedIP := net.ParseIP(clientIP)
|
|
||||||
|
|
||||||
// Check blocked IPs first
|
|
||||||
if e.isBlocked(clientIP, parsedIP) {
|
|
||||||
e.logger.Warn("Blocked request from IP", "ip", clientIP)
|
|
||||||
http.Error(w, "403 Forbidden", http.StatusForbidden)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check allowed IPs (if configured, only these IPs are allowed)
|
|
||||||
if len(e.allowedIPs) > 0 || len(e.allowedCIDRs) > 0 {
|
|
||||||
if !e.isAllowed(clientIP, parsedIP) {
|
|
||||||
e.logger.Warn("Access denied for IP", "ip", clientIP)
|
|
||||||
http.Error(w, "403 Forbidden", http.StatusForbidden)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check rate limit
|
|
||||||
if e.rateLimitEnabled {
|
|
||||||
if !e.checkRateLimit(clientIP) {
|
|
||||||
e.logger.Warn("Rate limit exceeded", "ip", clientIP)
|
|
||||||
w.Header().Set("Retry-After", "60")
|
|
||||||
http.Error(w, "429 Too Many Requests", http.StatusTooManyRequests)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessResponse adds security headers to the response
|
|
||||||
func (e *SecurityExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
|
||||||
for header, value := range e.securityHeaders {
|
|
||||||
w.Header().Set(header, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *SecurityExtension) isBlocked(ip string, parsedIP net.IP) bool {
|
|
||||||
// Check exact match
|
|
||||||
if e.blockedIPs[ip] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check CIDR ranges
|
|
||||||
if parsedIP != nil {
|
|
||||||
for _, cidr := range e.blockedCIDRs {
|
|
||||||
if cidr.Contains(parsedIP) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *SecurityExtension) isAllowed(ip string, parsedIP net.IP) bool {
|
|
||||||
// Check exact match
|
|
||||||
if e.allowedIPs[ip] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check CIDR ranges
|
|
||||||
if parsedIP != nil {
|
|
||||||
for _, cidr := range e.allowedCIDRs {
|
|
||||||
if cidr.Contains(parsedIP) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *SecurityExtension) checkRateLimit(ip string) bool {
|
|
||||||
e.rateLimitMu.Lock()
|
|
||||||
defer e.rateLimitMu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
entry, exists := e.rateLimitByIP[ip]
|
|
||||||
|
|
||||||
if !exists || now.After(entry.resetTime) {
|
|
||||||
// Create new entry or reset expired one
|
|
||||||
e.rateLimitByIP[ip] = &rateLimitEntry{
|
|
||||||
count: 1,
|
|
||||||
resetTime: now.Add(e.rateLimitWindow),
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increment counter
|
|
||||||
entry.count++
|
|
||||||
return entry.count <= e.rateLimitRequests
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddBlockedIP adds an IP to the blocked list at runtime
|
|
||||||
func (e *SecurityExtension) AddBlockedIP(ip string) {
|
|
||||||
e.addBlockedIP(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveBlockedIP removes an IP from the blocked list
|
|
||||||
func (e *SecurityExtension) RemoveBlockedIP(ip string) {
|
|
||||||
delete(e.blockedIPs, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAllowedIP adds an IP to the allowed list at runtime
|
|
||||||
func (e *SecurityExtension) AddAllowedIP(ip string) {
|
|
||||||
e.addAllowedIP(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAllowedIP removes an IP from the allowed list
|
|
||||||
func (e *SecurityExtension) RemoveAllowedIP(ip string) {
|
|
||||||
delete(e.allowedIPs, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSecurityHeader sets or updates a security header
|
|
||||||
func (e *SecurityExtension) SetSecurityHeader(name, value string) {
|
|
||||||
e.securityHeaders[name] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMetrics returns security metrics
|
|
||||||
func (e *SecurityExtension) GetMetrics() map[string]interface{} {
|
|
||||||
e.rateLimitMu.RLock()
|
|
||||||
activeRateLimits := len(e.rateLimitByIP)
|
|
||||||
e.rateLimitMu.RUnlock()
|
|
||||||
|
|
||||||
return map[string]interface{}{
|
|
||||||
"allowed_ips": len(e.allowedIPs),
|
|
||||||
"allowed_cidrs": len(e.allowedCIDRs),
|
|
||||||
"blocked_ips": len(e.blockedIPs),
|
|
||||||
"blocked_cidrs": len(e.blockedCIDRs),
|
|
||||||
"security_headers": len(e.securityHeaders),
|
|
||||||
"rate_limit_enabled": e.rateLimitEnabled,
|
|
||||||
"active_rate_limits": activeRateLimits,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup cleans up rate limit entries periodically
|
|
||||||
func (e *SecurityExtension) Cleanup() error {
|
|
||||||
e.rateLimitMu.Lock()
|
|
||||||
defer e.rateLimitMu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
for ip, entry := range e.rateLimitByIP {
|
|
||||||
if now.After(entry.resetTime) {
|
|
||||||
delete(e.rateLimitByIP, ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@ -1,213 +0,0 @@
|
|||||||
package extension
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewSecurityExtension(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
|
|
||||||
ext, err := NewSecurityExtension(map[string]interface{}{}, logger)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create security extension: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ext.Name() != "security" {
|
|
||||||
t.Errorf("Expected name 'security', got %s", ext.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
if ext.Priority() != 10 {
|
|
||||||
t.Errorf("Expected priority 10, got %d", ext.Priority())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSecurityExtension_BlockedIP(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
|
|
||||||
ext, _ := NewSecurityExtension(map[string]interface{}{
|
|
||||||
"blocked_ips": []interface{}{"192.168.1.100"},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
req.RemoteAddr = "192.168.1.100:12345"
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handled, err := ext.ProcessRequest(context.Background(), rr, req)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !handled {
|
|
||||||
t.Error("Expected blocked request to be handled")
|
|
||||||
}
|
|
||||||
|
|
||||||
if rr.Code != http.StatusForbidden {
|
|
||||||
t.Errorf("Expected status 403, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSecurityExtension_AllowedIP(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
|
|
||||||
ext, _ := NewSecurityExtension(map[string]interface{}{
|
|
||||||
"allowed_ips": []interface{}{"192.168.1.50"},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
// Allowed IP
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
req.RemoteAddr = "192.168.1.50:12345"
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handled, _ := ext.ProcessRequest(context.Background(), rr, req)
|
|
||||||
if handled {
|
|
||||||
t.Error("Expected allowed IP to pass through")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Not allowed IP
|
|
||||||
req = httptest.NewRequest("GET", "/test", nil)
|
|
||||||
req.RemoteAddr = "192.168.1.51:12345"
|
|
||||||
rr = httptest.NewRecorder()
|
|
||||||
|
|
||||||
handled, _ = ext.ProcessRequest(context.Background(), rr, req)
|
|
||||||
if !handled {
|
|
||||||
t.Error("Expected non-allowed IP to be blocked")
|
|
||||||
}
|
|
||||||
|
|
||||||
if rr.Code != http.StatusForbidden {
|
|
||||||
t.Errorf("Expected status 403, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSecurityExtension_CIDR(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
|
|
||||||
ext, _ := NewSecurityExtension(map[string]interface{}{
|
|
||||||
"blocked_ips": []interface{}{"10.0.0.0/8"},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
// IP in blocked CIDR
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
req.RemoteAddr = "10.1.2.3:12345"
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handled, _ := ext.ProcessRequest(context.Background(), rr, req)
|
|
||||||
if !handled {
|
|
||||||
t.Error("Expected IP in blocked CIDR to be blocked")
|
|
||||||
}
|
|
||||||
|
|
||||||
// IP not in blocked CIDR
|
|
||||||
req = httptest.NewRequest("GET", "/test", nil)
|
|
||||||
req.RemoteAddr = "192.168.1.1:12345"
|
|
||||||
rr = httptest.NewRecorder()
|
|
||||||
|
|
||||||
handled, _ = ext.ProcessRequest(context.Background(), rr, req)
|
|
||||||
if handled {
|
|
||||||
t.Error("Expected IP not in blocked CIDR to pass through")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSecurityExtension_SecurityHeaders(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
|
|
||||||
ext, _ := NewSecurityExtension(map[string]interface{}{
|
|
||||||
"security_headers": map[string]interface{}{
|
|
||||||
"X-Custom-Header": "custom-value",
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
ext.ProcessResponse(context.Background(), rr, req)
|
|
||||||
|
|
||||||
// Check default headers
|
|
||||||
if rr.Header().Get("X-Content-Type-Options") != "nosniff" {
|
|
||||||
t.Error("Expected X-Content-Type-Options header")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check custom header
|
|
||||||
if rr.Header().Get("X-Custom-Header") != "custom-value" {
|
|
||||||
t.Error("Expected custom header")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSecurityExtension_RateLimit(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
|
|
||||||
ext, _ := NewSecurityExtension(map[string]interface{}{
|
|
||||||
"rate_limit": map[string]interface{}{
|
|
||||||
"enabled": true,
|
|
||||||
"requests": 2,
|
|
||||||
"window": "1m",
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
securityExt := ext.(*SecurityExtension)
|
|
||||||
clientIP := "192.168.1.1"
|
|
||||||
|
|
||||||
// First request - should pass
|
|
||||||
if !securityExt.checkRateLimit(clientIP) {
|
|
||||||
t.Error("First request should pass")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Second request - should pass
|
|
||||||
if !securityExt.checkRateLimit(clientIP) {
|
|
||||||
t.Error("Second request should pass")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Third request - should be rate limited
|
|
||||||
if securityExt.checkRateLimit(clientIP) {
|
|
||||||
t.Error("Third request should be rate limited")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSecurityExtension_GetMetrics(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
|
|
||||||
ext, _ := NewSecurityExtension(map[string]interface{}{
|
|
||||||
"blocked_ips": []interface{}{"192.168.1.1"},
|
|
||||||
"allowed_ips": []interface{}{"192.168.1.2"},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
securityExt := ext.(*SecurityExtension)
|
|
||||||
metrics := securityExt.GetMetrics()
|
|
||||||
|
|
||||||
if metrics["blocked_ips"].(int) != 1 {
|
|
||||||
t.Errorf("Expected 1 blocked IP, got %v", metrics["blocked_ips"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if metrics["allowed_ips"].(int) != 1 {
|
|
||||||
t.Errorf("Expected 1 allowed IP, got %v", metrics["allowed_ips"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSecurityExtension_AddRemoveIPs(t *testing.T) {
|
|
||||||
logger := newTestLogger()
|
|
||||||
|
|
||||||
ext, _ := NewSecurityExtension(map[string]interface{}{}, logger)
|
|
||||||
securityExt := ext.(*SecurityExtension)
|
|
||||||
|
|
||||||
// Add blocked IP
|
|
||||||
securityExt.AddBlockedIP("192.168.1.100")
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
req.RemoteAddr = "192.168.1.100:12345"
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
handled, _ := ext.ProcessRequest(context.Background(), rr, req)
|
|
||||||
if !handled {
|
|
||||||
t.Error("Expected dynamically blocked IP to be blocked")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove blocked IP
|
|
||||||
securityExt.RemoveBlockedIP("192.168.1.100")
|
|
||||||
|
|
||||||
rr = httptest.NewRecorder()
|
|
||||||
handled, _ = ext.ProcessRequest(context.Background(), rr, req)
|
|
||||||
if handled {
|
|
||||||
t.Error("Expected removed blocked IP to pass through")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,338 +0,0 @@
|
|||||||
// Package logging provides structured logging with zap
|
|
||||||
package logging
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
|
||||||
"go.uber.org/zap/zapcore"
|
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Config is a simple configuration for basic logger setup
|
|
||||||
type Config struct {
|
|
||||||
Level string
|
|
||||||
TimestampFormat string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logger wraps zap.SugaredLogger with additional functionality
|
|
||||||
type Logger struct {
|
|
||||||
*zap.SugaredLogger
|
|
||||||
zap *zap.Logger
|
|
||||||
config *config.LoggingConfig
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new Logger with basic configuration
|
|
||||||
func New(cfg Config) (*Logger, error) {
|
|
||||||
level := parseLevel(cfg.Level)
|
|
||||||
|
|
||||||
timestampFormat := cfg.TimestampFormat
|
|
||||||
if timestampFormat == "" {
|
|
||||||
timestampFormat = "2006-01-02 15:04:05"
|
|
||||||
}
|
|
||||||
|
|
||||||
encoderConfig := zap.NewProductionEncoderConfig()
|
|
||||||
encoderConfig.TimeKey = "timestamp"
|
|
||||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(timestampFormat)
|
|
||||||
encoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
|
||||||
|
|
||||||
core := zapcore.NewCore(
|
|
||||||
zapcore.NewConsoleEncoder(encoderConfig),
|
|
||||||
zapcore.AddSync(os.Stdout),
|
|
||||||
level,
|
|
||||||
)
|
|
||||||
|
|
||||||
zapLogger := zap.New(core)
|
|
||||||
return &Logger{
|
|
||||||
SugaredLogger: zapLogger.Sugar(),
|
|
||||||
zap: zapLogger,
|
|
||||||
name: "konduktor",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFromConfig creates a Logger from full LoggingConfig
|
|
||||||
func NewFromConfig(cfg config.LoggingConfig) (*Logger, error) {
|
|
||||||
var cores []zapcore.Core
|
|
||||||
|
|
||||||
// Parse main level
|
|
||||||
mainLevel := parseLevel(cfg.Level)
|
|
||||||
|
|
||||||
// Add console core if enabled
|
|
||||||
if cfg.ConsoleOutput {
|
|
||||||
consoleLevel := mainLevel
|
|
||||||
if cfg.Console != nil && cfg.Console.Level != "" {
|
|
||||||
consoleLevel = parseLevel(cfg.Console.Level)
|
|
||||||
}
|
|
||||||
|
|
||||||
var consoleEncoder zapcore.Encoder
|
|
||||||
formatConfig := cfg.Format
|
|
||||||
if cfg.Console != nil {
|
|
||||||
formatConfig = mergeFormatConfig(cfg.Format, cfg.Console.Format)
|
|
||||||
}
|
|
||||||
|
|
||||||
encoderCfg := createEncoderConfig(formatConfig)
|
|
||||||
if formatConfig.Type == "json" {
|
|
||||||
consoleEncoder = zapcore.NewJSONEncoder(encoderCfg)
|
|
||||||
} else {
|
|
||||||
if formatConfig.UseColors {
|
|
||||||
encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
|
||||||
}
|
|
||||||
consoleEncoder = zapcore.NewConsoleEncoder(encoderCfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
consoleSyncer := zapcore.AddSync(os.Stdout)
|
|
||||||
cores = append(cores, zapcore.NewCore(consoleEncoder, consoleSyncer, consoleLevel))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add file cores
|
|
||||||
for _, fileConfig := range cfg.Files {
|
|
||||||
fileCore, err := createFileCore(fileConfig, cfg.Format, mainLevel)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create file logger for %s: %w", fileConfig.Path, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If specific loggers are configured, wrap with filter
|
|
||||||
if len(fileConfig.Loggers) > 0 {
|
|
||||||
fileCore = &filteredCore{
|
|
||||||
Core: fileCore,
|
|
||||||
loggers: fileConfig.Loggers,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cores = append(cores, fileCore)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no cores configured, add default console
|
|
||||||
if len(cores) == 0 {
|
|
||||||
encoderCfg := zap.NewProductionEncoderConfig()
|
|
||||||
encoderCfg.EncodeTime = zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05")
|
|
||||||
encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
|
||||||
cores = append(cores, zapcore.NewCore(
|
|
||||||
zapcore.NewConsoleEncoder(encoderCfg),
|
|
||||||
zapcore.AddSync(os.Stdout),
|
|
||||||
mainLevel,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Combine all cores
|
|
||||||
core := zapcore.NewTee(cores...)
|
|
||||||
zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1))
|
|
||||||
|
|
||||||
return &Logger{
|
|
||||||
SugaredLogger: zapLogger.Sugar(),
|
|
||||||
zap: zapLogger,
|
|
||||||
config: &cfg,
|
|
||||||
name: "konduktor",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Named returns a logger with a specific name (for filtering)
|
|
||||||
func (l *Logger) Named(name string) *Logger {
|
|
||||||
return &Logger{
|
|
||||||
SugaredLogger: l.SugaredLogger.Named(name),
|
|
||||||
zap: l.zap.Named(name),
|
|
||||||
config: l.config,
|
|
||||||
name: name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// With returns a logger with additional fields
|
|
||||||
func (l *Logger) With(args ...interface{}) *Logger {
|
|
||||||
return &Logger{
|
|
||||||
SugaredLogger: l.SugaredLogger.With(args...),
|
|
||||||
zap: l.zap.Sugar().With(args...).Desugar(),
|
|
||||||
config: l.config,
|
|
||||||
name: l.name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sync flushes any buffered log entries
|
|
||||||
func (l *Logger) Sync() error {
|
|
||||||
return l.zap.Sync()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetZap returns the underlying zap.Logger
|
|
||||||
func (l *Logger) GetZap() *zap.Logger {
|
|
||||||
return l.zap
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug logs a debug message
|
|
||||||
func (l *Logger) Debug(msg string, keysAndValues ...interface{}) {
|
|
||||||
l.SugaredLogger.Debugw(msg, keysAndValues...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Info logs an info message
|
|
||||||
func (l *Logger) Info(msg string, keysAndValues ...interface{}) {
|
|
||||||
l.SugaredLogger.Infow(msg, keysAndValues...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warn logs a warning message
|
|
||||||
func (l *Logger) Warn(msg string, keysAndValues ...interface{}) {
|
|
||||||
l.SugaredLogger.Warnw(msg, keysAndValues...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error logs an error message
|
|
||||||
func (l *Logger) Error(msg string, keysAndValues ...interface{}) {
|
|
||||||
l.SugaredLogger.Errorw(msg, keysAndValues...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fatal logs a fatal message and exits
|
|
||||||
func (l *Logger) Fatal(msg string, keysAndValues ...interface{}) {
|
|
||||||
l.SugaredLogger.Fatalw(msg, keysAndValues...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Helper functions ---
|
|
||||||
|
|
||||||
func parseLevel(level string) zapcore.Level {
|
|
||||||
switch strings.ToUpper(level) {
|
|
||||||
case "DEBUG":
|
|
||||||
return zapcore.DebugLevel
|
|
||||||
case "INFO":
|
|
||||||
return zapcore.InfoLevel
|
|
||||||
case "WARN", "WARNING":
|
|
||||||
return zapcore.WarnLevel
|
|
||||||
case "ERROR":
|
|
||||||
return zapcore.ErrorLevel
|
|
||||||
case "CRITICAL", "FATAL":
|
|
||||||
return zapcore.FatalLevel
|
|
||||||
default:
|
|
||||||
return zapcore.InfoLevel
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createEncoderConfig(format config.LogFormatConfig) zapcore.EncoderConfig {
|
|
||||||
timestampFormat := format.TimestampFormat
|
|
||||||
if timestampFormat == "" {
|
|
||||||
timestampFormat = "2006-01-02 15:04:05"
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := zapcore.EncoderConfig{
|
|
||||||
TimeKey: "timestamp",
|
|
||||||
LevelKey: "level",
|
|
||||||
NameKey: "logger",
|
|
||||||
CallerKey: "caller",
|
|
||||||
FunctionKey: zapcore.OmitKey,
|
|
||||||
MessageKey: "msg",
|
|
||||||
StacktraceKey: "stacktrace",
|
|
||||||
LineEnding: zapcore.DefaultLineEnding,
|
|
||||||
EncodeLevel: zapcore.CapitalLevelEncoder,
|
|
||||||
EncodeTime: zapcore.TimeEncoderOfLayout(timestampFormat),
|
|
||||||
EncodeDuration: zapcore.SecondsDurationEncoder,
|
|
||||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !format.ShowModule {
|
|
||||||
cfg.NameKey = zapcore.OmitKey
|
|
||||||
}
|
|
||||||
|
|
||||||
return cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
func mergeFormatConfig(base, override config.LogFormatConfig) config.LogFormatConfig {
|
|
||||||
result := base
|
|
||||||
if override.Type != "" {
|
|
||||||
result.Type = override.Type
|
|
||||||
}
|
|
||||||
if override.TimestampFormat != "" {
|
|
||||||
result.TimestampFormat = override.TimestampFormat
|
|
||||||
}
|
|
||||||
// UseColors and ShowModule are bool - check if override has non-default
|
|
||||||
result.UseColors = override.UseColors
|
|
||||||
result.ShowModule = override.ShowModule
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func createFileCore(fileConfig config.FileLogConfig, defaultFormat config.LogFormatConfig, defaultLevel zapcore.Level) (zapcore.Core, error) {
|
|
||||||
// Ensure directory exists
|
|
||||||
dir := filepath.Dir(fileConfig.Path)
|
|
||||||
if dir != "" && dir != "." {
|
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create log directory %s: %w", dir, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Configure log rotation with lumberjack
|
|
||||||
maxSize := 10 // MB
|
|
||||||
if fileConfig.MaxBytes > 0 {
|
|
||||||
maxSize = int(fileConfig.MaxBytes / (1024 * 1024))
|
|
||||||
if maxSize < 1 {
|
|
||||||
maxSize = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
backupCount := 5
|
|
||||||
if fileConfig.BackupCount > 0 {
|
|
||||||
backupCount = fileConfig.BackupCount
|
|
||||||
}
|
|
||||||
|
|
||||||
rotator := &lumberjack.Logger{
|
|
||||||
Filename: fileConfig.Path,
|
|
||||||
MaxSize: maxSize,
|
|
||||||
MaxBackups: backupCount,
|
|
||||||
MaxAge: 30, // days
|
|
||||||
Compress: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine level
|
|
||||||
level := defaultLevel
|
|
||||||
if fileConfig.Level != "" {
|
|
||||||
level = parseLevel(fileConfig.Level)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create encoder
|
|
||||||
format := defaultFormat
|
|
||||||
if fileConfig.Format.Type != "" {
|
|
||||||
format = mergeFormatConfig(defaultFormat, fileConfig.Format)
|
|
||||||
}
|
|
||||||
// Files should not use colors
|
|
||||||
format.UseColors = false
|
|
||||||
|
|
||||||
encoderConfig := createEncoderConfig(format)
|
|
||||||
var encoder zapcore.Encoder
|
|
||||||
if format.Type == "json" {
|
|
||||||
encoder = zapcore.NewJSONEncoder(encoderConfig)
|
|
||||||
} else {
|
|
||||||
encoder = zapcore.NewConsoleEncoder(encoderConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
return zapcore.NewCore(encoder, zapcore.AddSync(rotator), level), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// filteredCore wraps a Core to filter by logger name
|
|
||||||
type filteredCore struct {
|
|
||||||
zapcore.Core
|
|
||||||
loggers []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *filteredCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
|
||||||
if !c.shouldLog(entry.LoggerName) {
|
|
||||||
return ce
|
|
||||||
}
|
|
||||||
return c.Core.Check(entry, ce)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *filteredCore) shouldLog(loggerName string) bool {
|
|
||||||
if len(c.loggers) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, allowed := range c.loggers {
|
|
||||||
if loggerName == allowed || strings.HasPrefix(loggerName, allowed+".") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *filteredCore) With(fields []zapcore.Field) zapcore.Core {
|
|
||||||
return &filteredCore{
|
|
||||||
Core: c.Core.With(fields),
|
|
||||||
loggers: c.loggers,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,209 +0,0 @@
|
|||||||
package logging
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
logger, err := New(Config{Level: "INFO"})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if logger == nil {
|
|
||||||
t.Fatal("Expected logger, got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
if logger.name != "konduktor" {
|
|
||||||
t.Errorf("Expected name konduktor, got %s", logger.name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew_DefaultTimestampFormat(t *testing.T) {
|
|
||||||
logger, err := New(Config{Level: "DEBUG"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logger should be created successfully
|
|
||||||
if logger == nil {
|
|
||||||
t.Fatal("Expected logger, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew_CustomTimestampFormat(t *testing.T) {
|
|
||||||
logger, err := New(Config{
|
|
||||||
Level: "DEBUG",
|
|
||||||
TimestampFormat: "15:04:05",
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if logger == nil {
|
|
||||||
t.Fatal("Expected logger, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewFromConfig(t *testing.T) {
|
|
||||||
cfg := config.LoggingConfig{
|
|
||||||
Level: "DEBUG",
|
|
||||||
ConsoleOutput: true,
|
|
||||||
Format: config.LogFormatConfig{
|
|
||||||
Type: "standard",
|
|
||||||
UseColors: true,
|
|
||||||
ShowModule: true,
|
|
||||||
TimestampFormat: "2006-01-02 15:04:05",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
logger, err := NewFromConfig(cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if logger == nil {
|
|
||||||
t.Fatal("Expected logger, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewFromConfig_WithConsole(t *testing.T) {
|
|
||||||
cfg := config.LoggingConfig{
|
|
||||||
Level: "INFO",
|
|
||||||
ConsoleOutput: true,
|
|
||||||
Format: config.LogFormatConfig{
|
|
||||||
Type: "standard",
|
|
||||||
UseColors: true,
|
|
||||||
},
|
|
||||||
Console: &config.ConsoleLogConfig{
|
|
||||||
Level: "DEBUG",
|
|
||||||
Format: config.LogFormatConfig{
|
|
||||||
Type: "standard",
|
|
||||||
UseColors: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
logger, err := NewFromConfig(cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if logger == nil {
|
|
||||||
t.Fatal("Expected logger, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogger_Debug(t *testing.T) {
|
|
||||||
logger, _ := New(Config{Level: "DEBUG"})
|
|
||||||
|
|
||||||
// Should not panic
|
|
||||||
logger.Debug("test message", "key", "value")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogger_Info(t *testing.T) {
|
|
||||||
logger, _ := New(Config{Level: "INFO"})
|
|
||||||
|
|
||||||
// Should not panic
|
|
||||||
logger.Info("test message", "key", "value")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogger_Warn(t *testing.T) {
|
|
||||||
logger, _ := New(Config{Level: "WARN"})
|
|
||||||
|
|
||||||
// Should not panic
|
|
||||||
logger.Warn("test message", "key", "value")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogger_Error(t *testing.T) {
|
|
||||||
logger, _ := New(Config{Level: "ERROR"})
|
|
||||||
|
|
||||||
// Should not panic
|
|
||||||
logger.Error("test message", "key", "value")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogger_Named(t *testing.T) {
|
|
||||||
logger, _ := New(Config{Level: "INFO"})
|
|
||||||
named := logger.Named("test.module")
|
|
||||||
|
|
||||||
if named == nil {
|
|
||||||
t.Fatal("Expected named logger, got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
if named.name != "test.module" {
|
|
||||||
t.Errorf("Expected name 'test.module', got %s", named.name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should not panic
|
|
||||||
named.Info("test from named logger")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogger_With(t *testing.T) {
|
|
||||||
logger, _ := New(Config{Level: "INFO"})
|
|
||||||
withFields := logger.With("service", "test")
|
|
||||||
|
|
||||||
if withFields == nil {
|
|
||||||
t.Fatal("Expected logger with fields, got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should not panic
|
|
||||||
withFields.Info("test with fields")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogger_Sync(t *testing.T) {
|
|
||||||
logger, _ := New(Config{Level: "INFO"})
|
|
||||||
|
|
||||||
// Should not panic
|
|
||||||
err := logger.Sync()
|
|
||||||
// Sync may return an error for stdout on some systems, ignore it
|
|
||||||
_ = err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseLevel(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"DEBUG", "debug"},
|
|
||||||
{"INFO", "info"},
|
|
||||||
{"WARN", "warn"},
|
|
||||||
{"WARNING", "warn"},
|
|
||||||
{"ERROR", "error"},
|
|
||||||
{"CRITICAL", "fatal"},
|
|
||||||
{"FATAL", "fatal"},
|
|
||||||
{"invalid", "info"}, // defaults to INFO
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.input, func(t *testing.T) {
|
|
||||||
level := parseLevel(tt.input)
|
|
||||||
if level.String() != tt.expected {
|
|
||||||
t.Errorf("parseLevel(%s) = %s, want %s", tt.input, level.String(), tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Benchmarks ==============
|
|
||||||
|
|
||||||
func BenchmarkLogger_Info(b *testing.B) {
|
|
||||||
logger, _ := New(Config{Level: "INFO"})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
logger.Info("test message", "key", "value")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLogger_Debug_Filtered(b *testing.B) {
|
|
||||||
logger, _ := New(Config{Level: "ERROR"})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
logger.Debug("test message", "key", "value")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,74 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"runtime/debug"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
type responseWriter struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
status int
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *responseWriter) WriteHeader(code int) {
|
|
||||||
rw.status = code
|
|
||||||
rw.ResponseWriter.WriteHeader(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
|
||||||
size, err := rw.ResponseWriter.Write(b)
|
|
||||||
rw.size += size
|
|
||||||
return size, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func ServerHeader(next http.Handler, version string) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Server", fmt.Sprintf("konduktor/%s", version))
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func AccessLog(next http.Handler, logger *logging.Logger) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
wrapped := &responseWriter{
|
|
||||||
ResponseWriter: w,
|
|
||||||
status: http.StatusOK,
|
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(wrapped, r)
|
|
||||||
|
|
||||||
duration := time.Since(start)
|
|
||||||
|
|
||||||
logger.Info("HTTP request",
|
|
||||||
"method", r.Method,
|
|
||||||
"path", r.URL.Path,
|
|
||||||
"status", wrapped.status,
|
|
||||||
"duration_ms", duration.Milliseconds(),
|
|
||||||
"client_ip", r.RemoteAddr,
|
|
||||||
"user_agent", r.UserAgent(),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func Recovery(next http.Handler, logger *logging.Logger) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
logger.Error("Panic recovered",
|
|
||||||
"error", fmt.Sprintf("%v", err),
|
|
||||||
"stack", string(debug.Stack()),
|
|
||||||
)
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -1,244 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ============== ServerHeader Tests ==============
|
|
||||||
|
|
||||||
func TestServerHeader(t *testing.T) {
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
wrapped := ServerHeader(handler, "1.0.0")
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
wrapped.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
serverHeader := rr.Header().Get("Server")
|
|
||||||
if serverHeader != "konduktor/1.0.0" {
|
|
||||||
t.Errorf("Expected Server header 'konduktor/1.0.0', got '%s'", serverHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== AccessLog Tests ==============
|
|
||||||
|
|
||||||
func TestAccessLog(t *testing.T) {
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("Hello"))
|
|
||||||
})
|
|
||||||
|
|
||||||
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
|
||||||
wrapped := AccessLog(handler, logger)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
wrapped.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccessLog_CapturesStatusCode(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
statusCode int
|
|
||||||
}{
|
|
||||||
{"OK", http.StatusOK},
|
|
||||||
{"NotFound", http.StatusNotFound},
|
|
||||||
{"InternalError", http.StatusInternalServerError},
|
|
||||||
{"Redirect", http.StatusMovedPermanently},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(tt.statusCode)
|
|
||||||
})
|
|
||||||
|
|
||||||
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
|
||||||
wrapped := AccessLog(handler, logger)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
wrapped.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != tt.statusCode {
|
|
||||||
t.Errorf("Expected status %d, got %d", tt.statusCode, rr.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Recovery Tests ==============
|
|
||||||
|
|
||||||
func TestRecovery_NoPanic(t *testing.T) {
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
})
|
|
||||||
|
|
||||||
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
|
||||||
wrapped := Recovery(handler, logger)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
wrapped.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rr.Body.String() != "OK" {
|
|
||||||
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRecovery_WithPanic(t *testing.T) {
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
panic("test panic")
|
|
||||||
})
|
|
||||||
|
|
||||||
logger, _ := logging.New(logging.Config{Level: "ERROR"})
|
|
||||||
wrapped := Recovery(handler, logger)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
// Should not panic
|
|
||||||
wrapped.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusInternalServerError {
|
|
||||||
t.Errorf("Expected status 500, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(rr.Body.String(), "Internal Server Error") {
|
|
||||||
t.Errorf("Expected 'Internal Server Error' in body, got '%s'", rr.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== responseWriter Tests ==============
|
|
||||||
|
|
||||||
func TestResponseWriter_WriteHeader(t *testing.T) {
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
|
|
||||||
|
|
||||||
rw.WriteHeader(http.StatusNotFound)
|
|
||||||
|
|
||||||
if rw.status != http.StatusNotFound {
|
|
||||||
t.Errorf("Expected status 404, got %d", rw.status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResponseWriter_Write(t *testing.T) {
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
|
|
||||||
|
|
||||||
n, err := rw.Write([]byte("Hello World"))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if n != 11 {
|
|
||||||
t.Errorf("Expected 11 bytes written, got %d", n)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rw.size != 11 {
|
|
||||||
t.Errorf("Expected size 11, got %d", rw.size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Middleware Chain Tests ==============
|
|
||||||
|
|
||||||
func TestMiddlewareChain(t *testing.T) {
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
})
|
|
||||||
|
|
||||||
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
|
||||||
|
|
||||||
// Apply middleware chain
|
|
||||||
wrapped := Recovery(AccessLog(ServerHeader(handler, "1.0.0"), logger), logger)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
wrapped.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
// Check all middleware worked
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rr.Header().Get("Server") != "konduktor/1.0.0" {
|
|
||||||
t.Errorf("Expected Server header")
|
|
||||||
}
|
|
||||||
|
|
||||||
if rr.Body.String() != "OK" {
|
|
||||||
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Benchmarks ==============
|
|
||||||
|
|
||||||
func BenchmarkServerHeader(b *testing.B) {
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
wrapped := ServerHeader(handler, "1.0.0")
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
wrapped.ServeHTTP(rr, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkAccessLog(b *testing.B) {
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
logger, _ := logging.New(logging.Config{Level: "ERROR"}) // Minimize logging overhead
|
|
||||||
wrapped := AccessLog(handler, logger)
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
wrapped.ServeHTTP(rr, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkRecovery(b *testing.B) {
|
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
logger, _ := logging.New(logging.Config{Level: "ERROR"})
|
|
||||||
wrapped := Recovery(handler, logger)
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
wrapped.ServeHTTP(rr, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,263 +0,0 @@
|
|||||||
package pathmatcher
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MountedPath struct {
|
|
||||||
path string
|
|
||||||
name string
|
|
||||||
stripPath bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMountedPath(path string, opts ...MountedPathOption) *MountedPath {
|
|
||||||
// Normalize: remove trailing slash (except for root)
|
|
||||||
normalizedPath := strings.TrimSuffix(path, "/")
|
|
||||||
if normalizedPath == "" {
|
|
||||||
normalizedPath = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
m := &MountedPath{
|
|
||||||
path: normalizedPath,
|
|
||||||
name: normalizedPath,
|
|
||||||
stripPath: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.name == "" {
|
|
||||||
m.name = normalizedPath
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
type MountedPathOption func(*MountedPath)
|
|
||||||
|
|
||||||
func WithName(name string) MountedPathOption {
|
|
||||||
return func(m *MountedPath) {
|
|
||||||
m.name = name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithStripPath(strip bool) MountedPathOption {
|
|
||||||
return func(m *MountedPath) {
|
|
||||||
m.stripPath = strip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MountedPath) Path() string {
|
|
||||||
return m.path
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MountedPath) Name() string {
|
|
||||||
return m.name
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MountedPath) StripPath() bool {
|
|
||||||
return m.stripPath
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MountedPath) Matches(requestPath string) bool {
|
|
||||||
// Empty or "/" mount matches everything
|
|
||||||
if m.path == "" || m.path == "/" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request path must be at least as long as mount path
|
|
||||||
if len(requestPath) < len(m.path) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if request path starts with mount path
|
|
||||||
if !strings.HasPrefix(requestPath, m.path) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// If paths are equal length, it's a match
|
|
||||||
if len(requestPath) == len(m.path) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, next char must be '/' to prevent /api matching /api-v2
|
|
||||||
return requestPath[len(m.path)] == '/'
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MountedPath) GetModifiedPath(requestPath string) string {
|
|
||||||
if !m.stripPath {
|
|
||||||
return requestPath
|
|
||||||
}
|
|
||||||
|
|
||||||
// Root mount doesn't strip anything
|
|
||||||
if m.path == "" || m.path == "/" {
|
|
||||||
return requestPath
|
|
||||||
}
|
|
||||||
|
|
||||||
// Strip the prefix
|
|
||||||
modified := strings.TrimPrefix(requestPath, m.path)
|
|
||||||
|
|
||||||
// Ensure result starts with /
|
|
||||||
if modified == "" || modified[0] != '/' {
|
|
||||||
modified = "/" + modified
|
|
||||||
}
|
|
||||||
|
|
||||||
return modified
|
|
||||||
}
|
|
||||||
|
|
||||||
type MountManager struct {
|
|
||||||
mounts []*MountedPath
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMountManager() *MountManager {
|
|
||||||
return &MountManager{
|
|
||||||
mounts: make([]*MountedPath, 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mm *MountManager) AddMount(mount *MountedPath) {
|
|
||||||
mm.mu.Lock()
|
|
||||||
defer mm.mu.Unlock()
|
|
||||||
|
|
||||||
// Insert in sorted order (longer paths first)
|
|
||||||
inserted := false
|
|
||||||
for i, existing := range mm.mounts {
|
|
||||||
if len(mount.path) > len(existing.path) {
|
|
||||||
// Insert at position i
|
|
||||||
mm.mounts = append(mm.mounts[:i], append([]*MountedPath{mount}, mm.mounts[i:]...)...)
|
|
||||||
inserted = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !inserted {
|
|
||||||
mm.mounts = append(mm.mounts, mount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mm *MountManager) RemoveMount(path string) bool {
|
|
||||||
mm.mu.Lock()
|
|
||||||
defer mm.mu.Unlock()
|
|
||||||
|
|
||||||
normalizedPath := strings.TrimSuffix(path, "/")
|
|
||||||
|
|
||||||
for i, mount := range mm.mounts {
|
|
||||||
if mount.path == normalizedPath {
|
|
||||||
mm.mounts = append(mm.mounts[:i], mm.mounts[i+1:]...)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mm *MountManager) GetMount(requestPath string) *MountedPath {
|
|
||||||
mm.mu.RLock()
|
|
||||||
defer mm.mu.RUnlock()
|
|
||||||
|
|
||||||
// Mounts are sorted by path length (longest first)
|
|
||||||
// so the first match is the best match
|
|
||||||
for _, mount := range mm.mounts {
|
|
||||||
if mount.Matches(requestPath) {
|
|
||||||
return mount
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mm *MountManager) MountCount() int {
|
|
||||||
mm.mu.RLock()
|
|
||||||
defer mm.mu.RUnlock()
|
|
||||||
return len(mm.mounts)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mm *MountManager) Mounts() []*MountedPath {
|
|
||||||
mm.mu.RLock()
|
|
||||||
defer mm.mu.RUnlock()
|
|
||||||
|
|
||||||
result := make([]*MountedPath, len(mm.mounts))
|
|
||||||
copy(result, mm.mounts)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mm *MountManager) ListMounts() []map[string]interface{} {
|
|
||||||
mm.mu.RLock()
|
|
||||||
defer mm.mu.RUnlock()
|
|
||||||
|
|
||||||
result := make([]map[string]interface{}, len(mm.mounts))
|
|
||||||
for i, mount := range mm.mounts {
|
|
||||||
result[i] = map[string]interface{}{
|
|
||||||
"path": mount.path,
|
|
||||||
"name": mount.name,
|
|
||||||
"strip_path": mount.stripPath,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Utility functions
|
|
||||||
|
|
||||||
func PathMatchesPrefix(requestPath, prefix string) bool {
|
|
||||||
// Normalize prefix
|
|
||||||
prefix = strings.TrimSuffix(prefix, "/")
|
|
||||||
|
|
||||||
// Empty or "/" prefix matches everything
|
|
||||||
if prefix == "" || prefix == "/" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request path must be at least as long as prefix
|
|
||||||
if len(requestPath) < len(prefix) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if request path starts with prefix
|
|
||||||
if !strings.HasPrefix(requestPath, prefix) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// If paths are equal length, it's a match
|
|
||||||
if len(requestPath) == len(prefix) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, next char must be '/'
|
|
||||||
return requestPath[len(prefix)] == '/'
|
|
||||||
}
|
|
||||||
|
|
||||||
func StripPathPrefix(requestPath, prefix string) string {
|
|
||||||
// Normalize prefix
|
|
||||||
prefix = strings.TrimSuffix(prefix, "/")
|
|
||||||
|
|
||||||
// Empty or "/" prefix doesn't strip anything
|
|
||||||
if prefix == "" || prefix == "/" {
|
|
||||||
return requestPath
|
|
||||||
}
|
|
||||||
|
|
||||||
// Strip the prefix
|
|
||||||
modified := strings.TrimPrefix(requestPath, prefix)
|
|
||||||
|
|
||||||
// Ensure result starts with /
|
|
||||||
if modified == "" || modified[0] != '/' {
|
|
||||||
modified = "/" + modified
|
|
||||||
}
|
|
||||||
|
|
||||||
return modified
|
|
||||||
}
|
|
||||||
|
|
||||||
func MatchAndModifyPath(requestPath, prefix string, stripPath bool) (matches bool, modifiedPath string) {
|
|
||||||
if !PathMatchesPrefix(requestPath, prefix) {
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if stripPath {
|
|
||||||
return true, StripPathPrefix(requestPath, prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, requestPath
|
|
||||||
}
|
|
||||||
@ -1,460 +0,0 @@
|
|||||||
package pathmatcher
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ============== MountedPath Tests ==============
|
|
||||||
|
|
||||||
func TestMountedPath_RootMountMatchesEverything(t *testing.T) {
|
|
||||||
mount := NewMountedPath("")
|
|
||||||
|
|
||||||
tests := []string{"/", "/api", "/api/users", "/anything/at/all"}
|
|
||||||
|
|
||||||
for _, path := range tests {
|
|
||||||
if !mount.Matches(path) {
|
|
||||||
t.Errorf("Root mount should match %s", path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountedPath_SlashRootMountMatchesEverything(t *testing.T) {
|
|
||||||
mount := NewMountedPath("/")
|
|
||||||
|
|
||||||
tests := []string{"/", "/api", "/api/users"}
|
|
||||||
|
|
||||||
for _, path := range tests {
|
|
||||||
if !mount.Matches(path) {
|
|
||||||
t.Errorf("'/' mount should match %s", path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountedPath_ExactPathMatch(t *testing.T) {
|
|
||||||
mount := NewMountedPath("/api")
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
path string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{"/api", true},
|
|
||||||
{"/api/", true},
|
|
||||||
{"/api/users", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
if got := mount.Matches(tt.path); got != tt.expected {
|
|
||||||
t.Errorf("Matches(%s) = %v, want %v", tt.path, got, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountedPath_NoFalsePrefixMatch(t *testing.T) {
|
|
||||||
mount := NewMountedPath("/api")
|
|
||||||
|
|
||||||
tests := []string{"/api-v2", "/api2", "/apiv2"}
|
|
||||||
|
|
||||||
for _, path := range tests {
|
|
||||||
if mount.Matches(path) {
|
|
||||||
t.Errorf("/api should not match %s", path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountedPath_ShorterPathNoMatch(t *testing.T) {
|
|
||||||
mount := NewMountedPath("/api/v1")
|
|
||||||
|
|
||||||
tests := []string{"/api", "/ap", "/"}
|
|
||||||
|
|
||||||
for _, path := range tests {
|
|
||||||
if mount.Matches(path) {
|
|
||||||
t.Errorf("/api/v1 should not match shorter path %s", path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountedPath_TrailingSlashNormalized(t *testing.T) {
|
|
||||||
mount1 := NewMountedPath("/api/")
|
|
||||||
mount2 := NewMountedPath("/api")
|
|
||||||
|
|
||||||
if mount1.Path() != "/api" {
|
|
||||||
t.Errorf("Expected path /api, got %s", mount1.Path())
|
|
||||||
}
|
|
||||||
|
|
||||||
if mount2.Path() != "/api" {
|
|
||||||
t.Errorf("Expected path /api, got %s", mount2.Path())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !mount1.Matches("/api/users") {
|
|
||||||
t.Error("mount1 should match /api/users")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !mount2.Matches("/api/users") {
|
|
||||||
t.Error("mount2 should match /api/users")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountedPath_GetModifiedPathStripsPrefix(t *testing.T) {
|
|
||||||
mount := NewMountedPath("/api")
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"/api", "/"},
|
|
||||||
{"/api/", "/"},
|
|
||||||
{"/api/users", "/users"},
|
|
||||||
{"/api/users/123", "/users/123"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
|
|
||||||
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountedPath_GetModifiedPathNoStrip(t *testing.T) {
|
|
||||||
mount := NewMountedPath("/api", WithStripPath(false))
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"/api/users", "/api/users"},
|
|
||||||
{"/api", "/api"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
|
|
||||||
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountedPath_RootMountModifiedPath(t *testing.T) {
|
|
||||||
mount := NewMountedPath("")
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"/api/users", "/api/users"},
|
|
||||||
{"/", "/"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
|
|
||||||
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountedPath_NameProperty(t *testing.T) {
|
|
||||||
mount1 := NewMountedPath("/api")
|
|
||||||
mount2 := NewMountedPath("/api", WithName("API Mount"))
|
|
||||||
|
|
||||||
if mount1.Name() != "/api" {
|
|
||||||
t.Errorf("Expected name /api, got %s", mount1.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
if mount2.Name() != "API Mount" {
|
|
||||||
t.Errorf("Expected name 'API Mount', got %s", mount2.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== MountManager Tests ==============
|
|
||||||
|
|
||||||
func TestMountManager_EmptyManager(t *testing.T) {
|
|
||||||
manager := NewMountManager()
|
|
||||||
|
|
||||||
if got := manager.GetMount("/api"); got != nil {
|
|
||||||
t.Error("Empty manager should return nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
if got := manager.MountCount(); got != 0 {
|
|
||||||
t.Errorf("Expected mount count 0, got %d", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountManager_AddMount(t *testing.T) {
|
|
||||||
manager := NewMountManager()
|
|
||||||
mount := NewMountedPath("/api")
|
|
||||||
|
|
||||||
manager.AddMount(mount)
|
|
||||||
|
|
||||||
if manager.MountCount() != 1 {
|
|
||||||
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
|
|
||||||
}
|
|
||||||
|
|
||||||
if got := manager.GetMount("/api/users"); got != mount {
|
|
||||||
t.Error("GetMount should return the added mount")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountManager_LongestPrefixMatching(t *testing.T) {
|
|
||||||
manager := NewMountManager()
|
|
||||||
|
|
||||||
apiMount := NewMountedPath("/api", WithName("api"))
|
|
||||||
apiV1Mount := NewMountedPath("/api/v1", WithName("api_v1"))
|
|
||||||
apiV2Mount := NewMountedPath("/api/v2", WithName("api_v2"))
|
|
||||||
|
|
||||||
manager.AddMount(apiMount)
|
|
||||||
manager.AddMount(apiV2Mount)
|
|
||||||
manager.AddMount(apiV1Mount)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
path string
|
|
||||||
expectedName string
|
|
||||||
}{
|
|
||||||
{"/api/v1/users", "api_v1"},
|
|
||||||
{"/api/v2/items", "api_v2"},
|
|
||||||
{"/api/v3/other", "api"},
|
|
||||||
{"/api", "api"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
got := manager.GetMount(tt.path)
|
|
||||||
if got == nil {
|
|
||||||
t.Errorf("GetMount(%s) returned nil, want mount with name %s", tt.path, tt.expectedName)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if got.Name() != tt.expectedName {
|
|
||||||
t.Errorf("GetMount(%s).Name() = %s, want %s", tt.path, got.Name(), tt.expectedName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountManager_RemoveMount(t *testing.T) {
|
|
||||||
manager := NewMountManager()
|
|
||||||
|
|
||||||
manager.AddMount(NewMountedPath("/api"))
|
|
||||||
manager.AddMount(NewMountedPath("/admin"))
|
|
||||||
|
|
||||||
if manager.MountCount() != 2 {
|
|
||||||
t.Errorf("Expected mount count 2, got %d", manager.MountCount())
|
|
||||||
}
|
|
||||||
|
|
||||||
result := manager.RemoveMount("/api")
|
|
||||||
|
|
||||||
if !result {
|
|
||||||
t.Error("RemoveMount should return true")
|
|
||||||
}
|
|
||||||
|
|
||||||
if manager.MountCount() != 1 {
|
|
||||||
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
|
|
||||||
}
|
|
||||||
|
|
||||||
if manager.GetMount("/api/users") != nil {
|
|
||||||
t.Error("GetMount(/api/users) should return nil after removal")
|
|
||||||
}
|
|
||||||
|
|
||||||
if manager.GetMount("/admin/users") == nil {
|
|
||||||
t.Error("GetMount(/admin/users) should still work")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountManager_RemoveNonexistentMount(t *testing.T) {
|
|
||||||
manager := NewMountManager()
|
|
||||||
|
|
||||||
result := manager.RemoveMount("/api")
|
|
||||||
|
|
||||||
if result {
|
|
||||||
t.Error("RemoveMount should return false for nonexistent mount")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountManager_ListMounts(t *testing.T) {
|
|
||||||
manager := NewMountManager()
|
|
||||||
|
|
||||||
manager.AddMount(NewMountedPath("/api", WithName("API")))
|
|
||||||
manager.AddMount(NewMountedPath("/admin", WithName("Admin")))
|
|
||||||
|
|
||||||
mounts := manager.ListMounts()
|
|
||||||
|
|
||||||
if len(mounts) != 2 {
|
|
||||||
t.Errorf("Expected 2 mounts, got %d", len(mounts))
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range mounts {
|
|
||||||
if _, ok := m["path"]; !ok {
|
|
||||||
t.Error("Mount should have 'path' key")
|
|
||||||
}
|
|
||||||
if _, ok := m["name"]; !ok {
|
|
||||||
t.Error("Mount should have 'name' key")
|
|
||||||
}
|
|
||||||
if _, ok := m["strip_path"]; !ok {
|
|
||||||
t.Error("Mount should have 'strip_path' key")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMountManager_MountsReturnsCopy(t *testing.T) {
|
|
||||||
manager := NewMountManager()
|
|
||||||
manager.AddMount(NewMountedPath("/api"))
|
|
||||||
|
|
||||||
mounts1 := manager.Mounts()
|
|
||||||
mounts2 := manager.Mounts()
|
|
||||||
|
|
||||||
if &mounts1[0] == &mounts2[0] {
|
|
||||||
t.Error("Mounts() should return different slices")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Utility Functions Tests ==============
|
|
||||||
|
|
||||||
func TestPathMatchesPrefix_Basic(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
path string
|
|
||||||
prefix string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{"/api/users", "/api", true},
|
|
||||||
{"/api", "/api", true},
|
|
||||||
{"/api-v2", "/api", false},
|
|
||||||
{"/ap", "/api", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
|
|
||||||
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPathMatchesPrefix_Root(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
path string
|
|
||||||
prefix string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{"/anything", "", true},
|
|
||||||
{"/anything", "/", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
|
|
||||||
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStripPathPrefix_Basic(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
path string
|
|
||||||
prefix string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"/api/users", "/api", "/users"},
|
|
||||||
{"/api", "/api", "/"},
|
|
||||||
{"/api/", "/api", "/"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
|
|
||||||
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStripPathPrefix_Root(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
path string
|
|
||||||
prefix string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"/api/users", "", "/api/users"},
|
|
||||||
{"/api/users", "/", "/api/users"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
|
|
||||||
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMatchAndModifyPath_Combined(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
path string
|
|
||||||
prefix string
|
|
||||||
stripPath bool
|
|
||||||
wantMatches bool
|
|
||||||
wantModified string
|
|
||||||
}{
|
|
||||||
{"/api/users", "/api", true, true, "/users"},
|
|
||||||
{"/api", "/api", true, true, "/"},
|
|
||||||
{"/other", "/api", true, false, ""},
|
|
||||||
{"/api/users", "/api", false, true, "/api/users"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
matches, modified := MatchAndModifyPath(tt.path, tt.prefix, tt.stripPath)
|
|
||||||
if matches != tt.wantMatches {
|
|
||||||
t.Errorf("MatchAndModifyPath(%s, %s, %v) matches = %v, want %v",
|
|
||||||
tt.path, tt.prefix, tt.stripPath, matches, tt.wantMatches)
|
|
||||||
}
|
|
||||||
if modified != tt.wantModified {
|
|
||||||
t.Errorf("MatchAndModifyPath(%s, %s, %v) modified = %s, want %s",
|
|
||||||
tt.path, tt.prefix, tt.stripPath, modified, tt.wantModified)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Performance Tests ==============
|
|
||||||
|
|
||||||
func TestPerformance_ManyMatches(t *testing.T) {
|
|
||||||
mount := NewMountedPath("/api/v1/users")
|
|
||||||
|
|
||||||
for i := 0; i < 10000; i++ {
|
|
||||||
if !mount.Matches("/api/v1/users/123/posts") {
|
|
||||||
t.Fatal("Should match")
|
|
||||||
}
|
|
||||||
if mount.Matches("/other/path") {
|
|
||||||
t.Fatal("Should not match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPerformance_ManyMounts(t *testing.T) {
|
|
||||||
manager := NewMountManager()
|
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10)) + string(rune('0'+i/10))))
|
|
||||||
}
|
|
||||||
|
|
||||||
if manager.MountCount() != 100 {
|
|
||||||
t.Errorf("Expected 100 mounts, got %d", manager.MountCount())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Benchmarks ==============
|
|
||||||
|
|
||||||
func BenchmarkMountedPath_Matches(b *testing.B) {
|
|
||||||
mount := NewMountedPath("/api/v1")
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
mount.Matches("/api/v1/users/123")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkMountManager_GetMount(b *testing.B) {
|
|
||||||
manager := NewMountManager()
|
|
||||||
|
|
||||||
for i := 0; i < 20; i++ {
|
|
||||||
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10))))
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
manager.GetMount("/api/v5/users/123")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkPathMatchesPrefix(b *testing.B) {
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
PathMatchesPrefix("/api/v1/users/123", "/api/v1")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,333 +0,0 @@
|
|||||||
// Package proxy provides reverse proxy functionality for Konduktor
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
// Target is the backend server URL
|
|
||||||
Target string
|
|
||||||
|
|
||||||
// Timeout is the request timeout (default: 30s)
|
|
||||||
Timeout time.Duration
|
|
||||||
|
|
||||||
// Headers are additional headers to add to requests
|
|
||||||
Headers map[string]string
|
|
||||||
|
|
||||||
// StripPrefix removes this prefix from the request path
|
|
||||||
StripPrefix string
|
|
||||||
|
|
||||||
// PreserveHost keeps the original Host header
|
|
||||||
PreserveHost bool
|
|
||||||
|
|
||||||
// IgnoreRequestPath ignores the request path and uses only the target path
|
|
||||||
// This is useful for exact match routes where target URL should be used as-is
|
|
||||||
IgnoreRequestPath bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type ReverseProxy struct {
|
|
||||||
config *Config
|
|
||||||
targetURL *url.URL
|
|
||||||
httpClient *http.Client
|
|
||||||
logger *logging.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(cfg *Config, logger *logging.Logger) (*ReverseProxy, error) {
|
|
||||||
if cfg.Target == "" {
|
|
||||||
return nil, fmt.Errorf("proxy target is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
targetURL, err := url.Parse(cfg.Target)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid proxy target URL: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout := cfg.Timeout
|
|
||||||
if timeout == 0 {
|
|
||||||
timeout = 30 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := &http.Transport{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
DialContext: (&net.Dialer{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).DialContext,
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
MaxIdleConnsPerHost: 10,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
ResponseHeaderTimeout: timeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ReverseProxy{
|
|
||||||
config: cfg,
|
|
||||||
targetURL: targetURL,
|
|
||||||
httpClient: &http.Client{
|
|
||||||
Transport: transport,
|
|
||||||
Timeout: timeout,
|
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
||||||
return http.ErrUseLastResponse // Don't follow redirects
|
|
||||||
},
|
|
||||||
},
|
|
||||||
logger: logger,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rp *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
rp.ProxyRequest(w, r, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
|
||||||
ctx := r.Context()
|
|
||||||
|
|
||||||
// Build target URL
|
|
||||||
targetURL := rp.buildTargetURL(r)
|
|
||||||
|
|
||||||
// Create proxy request
|
|
||||||
proxyReq, err := rp.createProxyRequest(ctx, r, targetURL)
|
|
||||||
if err != nil {
|
|
||||||
rp.handleError(w, http.StatusInternalServerError, "Failed to create proxy request", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add custom headers with parameter substitution
|
|
||||||
rp.addCustomHeaders(proxyReq, r, params)
|
|
||||||
|
|
||||||
// Execute request
|
|
||||||
resp, err := rp.httpClient.Do(proxyReq)
|
|
||||||
if err != nil {
|
|
||||||
rp.handleProxyError(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// Copy response
|
|
||||||
rp.copyResponse(w, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
|
|
||||||
targetURL := *rp.targetURL
|
|
||||||
|
|
||||||
// If ignoring request path, use target URL path as-is
|
|
||||||
if rp.config.IgnoreRequestPath {
|
|
||||||
// Preserve query string only
|
|
||||||
targetURL.RawQuery = r.URL.RawQuery
|
|
||||||
return &targetURL
|
|
||||||
}
|
|
||||||
|
|
||||||
// Strip prefix if configured
|
|
||||||
path := r.URL.Path
|
|
||||||
if rp.config.StripPrefix != "" {
|
|
||||||
path = strings.TrimPrefix(path, rp.config.StripPrefix)
|
|
||||||
if path == "" || path[0] != '/' {
|
|
||||||
path = "/" + path
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If target URL has a non-empty path, combine it with the request path
|
|
||||||
if rp.targetURL.Path != "" && rp.targetURL.Path != "/" {
|
|
||||||
// Combine target path with request path
|
|
||||||
targetURL.Path = strings.TrimSuffix(rp.targetURL.Path, "/") + path
|
|
||||||
} else {
|
|
||||||
// No path in target, use request path as-is
|
|
||||||
targetURL.Path = path
|
|
||||||
}
|
|
||||||
|
|
||||||
// Preserve query string
|
|
||||||
targetURL.RawQuery = r.URL.RawQuery
|
|
||||||
|
|
||||||
return &targetURL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rp *ReverseProxy) createProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL) (*http.Request, error) {
|
|
||||||
proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy ContentLength
|
|
||||||
proxyReq.ContentLength = r.ContentLength
|
|
||||||
|
|
||||||
// Copy headers
|
|
||||||
for key, values := range r.Header {
|
|
||||||
for _, value := range values {
|
|
||||||
proxyReq.Header.Add(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set/update Host header
|
|
||||||
if rp.config.PreserveHost {
|
|
||||||
proxyReq.Host = r.Host
|
|
||||||
} else {
|
|
||||||
proxyReq.Host = targetURL.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove hop-by-hop headers
|
|
||||||
removeHopByHopHeaders(proxyReq.Header)
|
|
||||||
|
|
||||||
return proxyReq, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rp *ReverseProxy) addCustomHeaders(proxyReq *http.Request, originalReq *http.Request, params map[string]string) {
|
|
||||||
// Add X-Forwarded headers
|
|
||||||
clientIP := getClientIP(originalReq)
|
|
||||||
if prior := originalReq.Header.Get("X-Forwarded-For"); prior != "" {
|
|
||||||
clientIP = prior + ", " + clientIP
|
|
||||||
}
|
|
||||||
proxyReq.Header.Set("X-Forwarded-For", clientIP)
|
|
||||||
proxyReq.Header.Set("X-Forwarded-Proto", getScheme(originalReq))
|
|
||||||
proxyReq.Header.Set("X-Forwarded-Host", originalReq.Host)
|
|
||||||
|
|
||||||
// Add custom headers from config
|
|
||||||
for key, value := range rp.config.Headers {
|
|
||||||
// Substitute parameters like {version}
|
|
||||||
substituted := value
|
|
||||||
for paramKey, paramValue := range params {
|
|
||||||
substituted = strings.ReplaceAll(substituted, "{"+paramKey+"}", paramValue)
|
|
||||||
}
|
|
||||||
// Substitute $remote_addr
|
|
||||||
substituted = strings.ReplaceAll(substituted, "$remote_addr", clientIP)
|
|
||||||
proxyReq.Header.Set(key, substituted)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rp *ReverseProxy) copyResponse(w http.ResponseWriter, resp *http.Response) {
|
|
||||||
// Copy headers
|
|
||||||
for key, values := range resp.Header {
|
|
||||||
for _, value := range values {
|
|
||||||
w.Header().Add(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove hop-by-hop headers from response
|
|
||||||
removeHopByHopHeaders(w.Header())
|
|
||||||
|
|
||||||
// Write status code
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
// Copy body
|
|
||||||
io.Copy(w, resp.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rp *ReverseProxy) handleError(w http.ResponseWriter, status int, message string, err error) {
|
|
||||||
if rp.logger != nil {
|
|
||||||
rp.logger.Error(message, "error", err)
|
|
||||||
}
|
|
||||||
http.Error(w, message, status)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rp *ReverseProxy) handleProxyError(w http.ResponseWriter, err error) {
|
|
||||||
if rp.logger != nil {
|
|
||||||
rp.logger.Error("Proxy request failed", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for timeout
|
|
||||||
if err, ok := err.(net.Error); ok && err.Timeout() {
|
|
||||||
http.Error(w, "504 Gateway Timeout", http.StatusGatewayTimeout)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for connection errors
|
|
||||||
if isConnectionError(err) {
|
|
||||||
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Context cancelled (client disconnected)
|
|
||||||
if err == context.Canceled {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
|
|
||||||
func singleJoiningSlash(a, b string) string {
|
|
||||||
aslash := strings.HasSuffix(a, "/")
|
|
||||||
bslash := strings.HasPrefix(b, "/")
|
|
||||||
switch {
|
|
||||||
case aslash && bslash:
|
|
||||||
return a + b[1:]
|
|
||||||
case !aslash && !bslash:
|
|
||||||
return a + "/" + b
|
|
||||||
}
|
|
||||||
return a + b
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeHopByHopHeaders(h http.Header) {
|
|
||||||
hopByHopHeaders := []string{
|
|
||||||
"Connection",
|
|
||||||
"Proxy-Connection",
|
|
||||||
"Keep-Alive",
|
|
||||||
"Proxy-Authenticate",
|
|
||||||
"Proxy-Authorization",
|
|
||||||
"Te",
|
|
||||||
"Trailer",
|
|
||||||
"Transfer-Encoding",
|
|
||||||
"Upgrade",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, header := range hopByHopHeaders {
|
|
||||||
h.Del(header)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getClientIP(r *http.Request) string {
|
|
||||||
// Check X-Real-IP first
|
|
||||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get from RemoteAddr
|
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return r.RemoteAddr
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
||||||
func getScheme(r *http.Request) string {
|
|
||||||
if r.TLS != nil {
|
|
||||||
return "https"
|
|
||||||
}
|
|
||||||
if scheme := r.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
|
||||||
return scheme
|
|
||||||
}
|
|
||||||
return "http"
|
|
||||||
}
|
|
||||||
|
|
||||||
func isConnectionError(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
errStr := err.Error()
|
|
||||||
connectionErrors := []string{
|
|
||||||
"connection refused",
|
|
||||||
"no such host",
|
|
||||||
"network is unreachable",
|
|
||||||
"connection reset",
|
|
||||||
"broken pipe",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, connErr := range connectionErrors {
|
|
||||||
if strings.Contains(strings.ToLower(errStr), connErr) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@ -1,747 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ============== Test Backend Server ==============
|
|
||||||
|
|
||||||
type testBackend struct {
|
|
||||||
server *httptest.Server
|
|
||||||
requestLog []requestLogEntry
|
|
||||||
mu sync.Mutex
|
|
||||||
requestCount int64
|
|
||||||
}
|
|
||||||
|
|
||||||
type requestLogEntry struct {
|
|
||||||
Method string
|
|
||||||
Path string
|
|
||||||
Query string
|
|
||||||
Headers http.Header
|
|
||||||
Body string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestBackend(handler http.HandlerFunc) *testBackend {
|
|
||||||
tb := &testBackend{
|
|
||||||
requestLog: make([]requestLogEntry, 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
if handler == nil {
|
|
||||||
handler = tb.defaultHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
tb.logRequest(r)
|
|
||||||
handler(w, r)
|
|
||||||
}))
|
|
||||||
|
|
||||||
return tb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb *testBackend) logRequest(r *http.Request) {
|
|
||||||
tb.mu.Lock()
|
|
||||||
defer tb.mu.Unlock()
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
// Restore the body for the handler
|
|
||||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
|
||||||
|
|
||||||
tb.requestLog = append(tb.requestLog, requestLogEntry{
|
|
||||||
Method: r.Method,
|
|
||||||
Path: r.URL.Path,
|
|
||||||
Query: r.URL.RawQuery,
|
|
||||||
Headers: r.Header.Clone(),
|
|
||||||
Body: string(body),
|
|
||||||
})
|
|
||||||
atomic.AddInt64(&tb.requestCount, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb *testBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"message": "Backend response",
|
|
||||||
"path": r.URL.Path,
|
|
||||||
"method": r.Method,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb *testBackend) close() {
|
|
||||||
tb.server.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb *testBackend) URL() string {
|
|
||||||
return tb.server.URL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb *testBackend) getRequestCount() int64 {
|
|
||||||
return atomic.LoadInt64(&tb.requestCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb *testBackend) getLastRequest() *requestLogEntry {
|
|
||||||
tb.mu.Lock()
|
|
||||||
defer tb.mu.Unlock()
|
|
||||||
if len(tb.requestLog) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &tb.requestLog[len(tb.requestLog)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Proxy Creation Tests ==============
|
|
||||||
|
|
||||||
func TestNew_ValidConfig(t *testing.T) {
|
|
||||||
cfg := &Config{
|
|
||||||
Target: "http://localhost:8080",
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy, err := New(cfg, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create proxy: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if proxy == nil {
|
|
||||||
t.Fatal("Expected proxy instance")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew_EmptyTarget(t *testing.T) {
|
|
||||||
cfg := &Config{
|
|
||||||
Target: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := New(cfg, nil)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error for empty target")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew_InvalidTargetURL(t *testing.T) {
|
|
||||||
cfg := &Config{
|
|
||||||
Target: "://invalid-url",
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := New(cfg, nil)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error for invalid URL")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew_DefaultTimeout(t *testing.T) {
|
|
||||||
cfg := &Config{
|
|
||||||
Target: "http://localhost:8080",
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy, err := New(cfg, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create proxy: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if proxy.httpClient.Timeout != 30*time.Second {
|
|
||||||
t.Errorf("Expected default timeout 30s, got %v", proxy.httpClient.Timeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Basic Proxy Tests ==============
|
|
||||||
|
|
||||||
func TestProxy_BasicGET(t *testing.T) {
|
|
||||||
backend := newTestBackend(nil)
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
var response map[string]interface{}
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["path"] != "/test" {
|
|
||||||
t.Errorf("Expected path /test, got %v", response["path"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_BasicPOST(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"received": string(body),
|
|
||||||
"method": r.Method,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/data", strings.NewReader(`{"key":"value"}`))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected status 200, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
var response map[string]interface{}
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["method"] != "POST" {
|
|
||||||
t.Errorf("Expected method POST, got %v", response["method"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_PUT(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("PUT", "/resource/123", strings.NewReader(`{"name":"updated"}`))
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
var response map[string]string
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["method"] != "PUT" {
|
|
||||||
t.Errorf("Expected method PUT, got %v", response["method"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_DELETE(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("DELETE", "/resource/123", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
var response map[string]string
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["method"] != "DELETE" {
|
|
||||||
t.Errorf("Expected method DELETE, got %v", response["method"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Header Tests ==============
|
|
||||||
|
|
||||||
func TestProxy_HeadersForwarding(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"custom_header": r.Header.Get("X-Custom-Header"),
|
|
||||||
"forwarded_for": r.Header.Get("X-Forwarded-For"),
|
|
||||||
"forwarded_host": r.Header.Get("X-Forwarded-Host"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/headers", nil)
|
|
||||||
req.Header.Set("X-Custom-Header", "test-value")
|
|
||||||
req.RemoteAddr = "192.168.1.100:12345"
|
|
||||||
req.Host = "example.com"
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
var response map[string]interface{}
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["custom_header"] != "test-value" {
|
|
||||||
t.Errorf("Expected custom header, got %v", response["custom_header"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if response["forwarded_for"] != "192.168.1.100" {
|
|
||||||
t.Errorf("Expected X-Forwarded-For, got %v", response["forwarded_for"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if response["forwarded_host"] != "example.com" {
|
|
||||||
t.Errorf("Expected X-Forwarded-Host, got %v", response["forwarded_host"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_CustomHeaders(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"api_version": r.Header.Get("X-API-Version"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{
|
|
||||||
Target: backend.URL(),
|
|
||||||
Headers: map[string]string{
|
|
||||||
"X-API-Version": "{version}",
|
|
||||||
},
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/api", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
// Simulate parameter substitution
|
|
||||||
proxy.ProxyRequest(rr, req, map[string]string{"version": "2"})
|
|
||||||
|
|
||||||
var response map[string]string
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["api_version"] != "2" {
|
|
||||||
t.Errorf("Expected API version 2, got %v", response["api_version"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_RemoteAddrSubstitution(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"client_ip": r.Header.Get("X-Client-IP"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{
|
|
||||||
Target: backend.URL(),
|
|
||||||
Headers: map[string]string{
|
|
||||||
"X-Client-IP": "$remote_addr",
|
|
||||||
},
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/api", nil)
|
|
||||||
req.RemoteAddr = "10.0.0.1:54321"
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
var response map[string]string
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["client_ip"] != "10.0.0.1" {
|
|
||||||
t.Errorf("Expected client IP 10.0.0.1, got %v", response["client_ip"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Query String Tests ==============
|
|
||||||
|
|
||||||
func TestProxy_QueryStringPreservation(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"query": r.URL.RawQuery,
|
|
||||||
"param": r.URL.Query().Get("key"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/search?key=value&page=2", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
var response map[string]string
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["param"] != "value" {
|
|
||||||
t.Errorf("Expected query param 'value', got %v", response["param"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Status Code Tests ==============
|
|
||||||
|
|
||||||
func TestProxy_StatusCodePreservation(t *testing.T) {
|
|
||||||
statusCodes := []int{200, 201, 400, 404, 500}
|
|
||||||
|
|
||||||
for _, code := range statusCodes {
|
|
||||||
code := code // capture range variable
|
|
||||||
t.Run(fmt.Sprintf("Status_%d", code), func(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(code)
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/status", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != code {
|
|
||||||
t.Errorf("Expected status %d, got %d", code, rr.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Error Handling Tests ==============
|
|
||||||
|
|
||||||
func TestProxy_BackendUnavailable(t *testing.T) {
|
|
||||||
// Use a port that's definitely not listening
|
|
||||||
proxy, _ := New(&Config{
|
|
||||||
Target: "http://127.0.0.1:59999",
|
|
||||||
Timeout: 1 * time.Second,
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/test", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusBadGateway {
|
|
||||||
t.Errorf("Expected status 502 Bad Gateway, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_Timeout(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
time.Sleep(2 * time.Second)
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{
|
|
||||||
Target: backend.URL(),
|
|
||||||
Timeout: 100 * time.Millisecond,
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/slow", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusGatewayTimeout {
|
|
||||||
t.Errorf("Expected status 504 Gateway Timeout, got %d", rr.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Path Handling Tests ==============
|
|
||||||
|
|
||||||
func TestProxy_StripPrefix(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"path": r.URL.Path,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{
|
|
||||||
Target: backend.URL(),
|
|
||||||
StripPrefix: "/api/v1",
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/api/v1/users", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
var response map[string]string
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["path"] != "/users" {
|
|
||||||
t.Errorf("Expected stripped path /users, got %v", response["path"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_TargetWithPath(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"path": r.URL.Path,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{
|
|
||||||
Target: backend.URL() + "/backend",
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/resource", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
var response map[string]string
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["path"] != "/backend/resource" {
|
|
||||||
t.Errorf("Expected path /backend/resource, got %v", response["path"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Large Body Tests ==============
|
|
||||||
|
|
||||||
func TestProxy_LargeRequestBody(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
json.NewEncoder(w).Encode(map[string]int{
|
|
||||||
"received_bytes": len(body),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
// 100KB body
|
|
||||||
largeBody := strings.Repeat("x", 100000)
|
|
||||||
req := httptest.NewRequest("POST", "/upload", strings.NewReader(largeBody))
|
|
||||||
req.ContentLength = int64(len(largeBody))
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
var response map[string]int
|
|
||||||
json.NewDecoder(rr.Body).Decode(&response)
|
|
||||||
|
|
||||||
if response["received_bytes"] != 100000 {
|
|
||||||
t.Errorf("Expected 100000 bytes, got %d", response["received_bytes"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxy_LargeResponseBody(t *testing.T) {
|
|
||||||
largeResponse := strings.Repeat("y", 100000)
|
|
||||||
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Write([]byte(largeResponse))
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/large", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Body.Len() != 100000 {
|
|
||||||
t.Errorf("Expected 100000 bytes in response, got %d", rr.Body.Len())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Concurrent Requests Tests ==============
|
|
||||||
|
|
||||||
func TestProxy_ConcurrentRequests(t *testing.T) {
|
|
||||||
backend := newTestBackend(nil)
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
const numRequests = 50
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
errors := make(chan error, numRequests)
|
|
||||||
|
|
||||||
for i := 0; i < numRequests; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(n int) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/concurrent", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
|
||||||
errors <- &net.OpError{Op: "test", Err: context.DeadlineExceeded}
|
|
||||||
}
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
close(errors)
|
|
||||||
|
|
||||||
errorCount := 0
|
|
||||||
for range errors {
|
|
||||||
errorCount++
|
|
||||||
}
|
|
||||||
|
|
||||||
if errorCount > 0 {
|
|
||||||
t.Errorf("Got %d errors in concurrent requests", errorCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
if backend.getRequestCount() != numRequests {
|
|
||||||
t.Errorf("Expected %d requests at backend, got %d", numRequests, backend.getRequestCount())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Echo Tests ==============
|
|
||||||
|
|
||||||
func TestProxy_Echo(t *testing.T) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
|
|
||||||
w.Write(body)
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
testData := "Hello, Proxy!"
|
|
||||||
req := httptest.NewRequest("POST", "/echo", strings.NewReader(testData))
|
|
||||||
req.Header.Set("Content-Type", "text/plain")
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
|
|
||||||
if rr.Body.String() != testData {
|
|
||||||
t.Errorf("Expected echo of '%s', got '%s'", testData, rr.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if rr.Header().Get("Content-Type") != "text/plain" {
|
|
||||||
t.Errorf("Expected Content-Type text/plain, got %s", rr.Header().Get("Content-Type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Helper Function Tests ==============
|
|
||||||
|
|
||||||
func TestSingleJoiningSlash(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
a, b, expected string
|
|
||||||
}{
|
|
||||||
{"/api", "/users", "/api/users"},
|
|
||||||
{"/api/", "/users", "/api/users"},
|
|
||||||
{"/api", "users", "/api/users"},
|
|
||||||
{"/api/", "users", "/api/users"},
|
|
||||||
{"", "/users", "/users"},
|
|
||||||
{"/api", "", "/api/"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
result := singleJoiningSlash(tt.a, tt.b)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("singleJoiningSlash(%q, %q) = %q, want %q", tt.a, tt.b, result, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetClientIP(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
remoteAddr string
|
|
||||||
xRealIP string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"192.168.1.1:1234", "", "192.168.1.1"},
|
|
||||||
{"192.168.1.1:1234", "10.0.0.1", "10.0.0.1"},
|
|
||||||
{"invalid", "", "invalid"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
req.RemoteAddr = tt.remoteAddr
|
|
||||||
if tt.xRealIP != "" {
|
|
||||||
req.Header.Set("X-Real-IP", tt.xRealIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := getClientIP(req)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("getClientIP() = %q, want %q", result, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetScheme(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
tls bool
|
|
||||||
header string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"HTTP", false, "", "http"},
|
|
||||||
{"HTTPS from TLS", true, "", "https"},
|
|
||||||
{"HTTPS from header", false, "https", "https"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
|
||||||
if tt.header != "" {
|
|
||||||
req.Header.Set("X-Forwarded-Proto", tt.header)
|
|
||||||
}
|
|
||||||
// Note: httptest doesn't set TLS, so we can only test non-TLS cases fully
|
|
||||||
|
|
||||||
result := getScheme(req)
|
|
||||||
if !tt.tls && result != tt.expected {
|
|
||||||
t.Errorf("getScheme() = %q, want %q", result, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsConnectionError(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
err error
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{nil, false},
|
|
||||||
{&net.OpError{Op: "dial", Err: &net.DNSError{Err: "no such host"}}, true},
|
|
||||||
{context.DeadlineExceeded, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
result := isConnectionError(tt.err)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("isConnectionError(%v) = %v, want %v", tt.err, result, tt.expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Benchmarks ==============
|
|
||||||
|
|
||||||
func BenchmarkProxy_SimpleGET(b *testing.B) {
|
|
||||||
backend := newTestBackend(nil)
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
req := httptest.NewRequest("GET", "/bench", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkProxy_POSTWithBody(b *testing.B) {
|
|
||||||
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
io.Copy(io.Discard, r.Body)
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
defer backend.close()
|
|
||||||
|
|
||||||
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
|
||||||
body := strings.Repeat("x", 1024) // 1KB body
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
req := httptest.NewRequest("POST", "/bench", strings.NewReader(body))
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
proxy.ServeHTTP(rr, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,492 +0,0 @@
|
|||||||
// Package routing provides HTTP routing with regex support
|
|
||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/config"
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
"github.com/konduktor/konduktor/internal/proxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RouteMatch represents a matched route with captured parameters
|
|
||||||
type RouteMatch struct {
|
|
||||||
Config map[string]interface{}
|
|
||||||
Params map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegexRoute represents a compiled regex route
|
|
||||||
type RegexRoute struct {
|
|
||||||
Pattern *regexp.Regexp
|
|
||||||
Config map[string]interface{}
|
|
||||||
CaseSensitive bool
|
|
||||||
OriginalExpr string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Router handles HTTP routing with exact, regex, and default routes
|
|
||||||
type Router struct {
|
|
||||||
config *config.Config
|
|
||||||
logger *logging.Logger
|
|
||||||
mux *http.ServeMux
|
|
||||||
staticDir string
|
|
||||||
exactRoutes map[string]map[string]interface{}
|
|
||||||
regexRoutes []*RegexRoute
|
|
||||||
defaultRoute map[string]interface{}
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new router from config
|
|
||||||
func New(cfg *config.Config, logger *logging.Logger) *Router {
|
|
||||||
staticDir := "./static"
|
|
||||||
if cfg != nil && cfg.HTTP.StaticDir != "" {
|
|
||||||
staticDir = cfg.HTTP.StaticDir
|
|
||||||
}
|
|
||||||
|
|
||||||
r := &Router{
|
|
||||||
config: cfg,
|
|
||||||
logger: logger,
|
|
||||||
mux: http.NewServeMux(),
|
|
||||||
staticDir: staticDir,
|
|
||||||
exactRoutes: make(map[string]map[string]interface{}),
|
|
||||||
regexRoutes: make([]*RegexRoute, 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load routes from extensions
|
|
||||||
if cfg != nil {
|
|
||||||
for _, ext := range cfg.Extensions {
|
|
||||||
if ext.Type == "routing" && ext.Config != nil {
|
|
||||||
if locations, ok := ext.Config["regex_locations"].(map[string]interface{}); ok {
|
|
||||||
for pattern, routeCfg := range locations {
|
|
||||||
if rc, ok := routeCfg.(map[string]interface{}); ok {
|
|
||||||
r.AddRoute(pattern, rc)
|
|
||||||
if logger != nil {
|
|
||||||
logger.Debug("Added route", "pattern", pattern)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.setupRoutes()
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRouter creates a router without config (for testing)
|
|
||||||
func NewRouter(opts ...RouterOption) *Router {
|
|
||||||
r := &Router{
|
|
||||||
mux: http.NewServeMux(),
|
|
||||||
staticDir: "./static",
|
|
||||||
exactRoutes: make(map[string]map[string]interface{}),
|
|
||||||
regexRoutes: make([]*RegexRoute, 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// RouterOption is a functional option for Router
|
|
||||||
type RouterOption func(*Router)
|
|
||||||
|
|
||||||
// WithStaticDir sets the static directory
|
|
||||||
func WithStaticDir(dir string) RouterOption {
|
|
||||||
return func(r *Router) {
|
|
||||||
r.staticDir = dir
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// StaticDir returns the static directory path
|
|
||||||
func (r *Router) StaticDir() string {
|
|
||||||
return r.staticDir
|
|
||||||
}
|
|
||||||
|
|
||||||
// Routes returns the regex routes (for testing)
|
|
||||||
func (r *Router) Routes() []*RegexRoute {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
return r.regexRoutes
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExactRoutes returns the exact routes (for testing)
|
|
||||||
func (r *Router) ExactRoutes() map[string]map[string]interface{} {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
return r.exactRoutes
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultRoute returns the default route (for testing)
|
|
||||||
func (r *Router) DefaultRoute() map[string]interface{} {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
return r.defaultRoute
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddRoute adds a route with the given pattern and config
|
|
||||||
// Pattern formats:
|
|
||||||
// - "=/path" - exact match
|
|
||||||
// - "~regex" - case-sensitive regex
|
|
||||||
// - "~*regex" - case-insensitive regex
|
|
||||||
// - "__default__" - default/fallback route
|
|
||||||
func (r *Router) AddRoute(pattern string, routeConfig map[string]interface{}) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case pattern == "__default__":
|
|
||||||
r.defaultRoute = routeConfig
|
|
||||||
|
|
||||||
case strings.HasPrefix(pattern, "="):
|
|
||||||
// Exact match route
|
|
||||||
path := strings.TrimPrefix(pattern, "=")
|
|
||||||
r.exactRoutes[path] = routeConfig
|
|
||||||
|
|
||||||
case strings.HasPrefix(pattern, "~*"):
|
|
||||||
// Case-insensitive regex
|
|
||||||
expr := strings.TrimPrefix(pattern, "~*")
|
|
||||||
re, err := regexp.Compile("(?i)" + expr)
|
|
||||||
if err != nil {
|
|
||||||
if r.logger != nil {
|
|
||||||
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
|
|
||||||
Pattern: re,
|
|
||||||
Config: routeConfig,
|
|
||||||
CaseSensitive: false,
|
|
||||||
OriginalExpr: expr,
|
|
||||||
})
|
|
||||||
|
|
||||||
case strings.HasPrefix(pattern, "~"):
|
|
||||||
// Case-sensitive regex
|
|
||||||
expr := strings.TrimPrefix(pattern, "~")
|
|
||||||
re, err := regexp.Compile(expr)
|
|
||||||
if err != nil {
|
|
||||||
if r.logger != nil {
|
|
||||||
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
|
|
||||||
Pattern: re,
|
|
||||||
Config: routeConfig,
|
|
||||||
CaseSensitive: true,
|
|
||||||
OriginalExpr: expr,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Match finds the best matching route for a path
|
|
||||||
// Priority: exact match > regex match > default
|
|
||||||
func (r *Router) Match(path string) *RouteMatch {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
// 1. Check exact routes
|
|
||||||
if cfg, ok := r.exactRoutes[path]; ok {
|
|
||||||
return &RouteMatch{
|
|
||||||
Config: cfg,
|
|
||||||
Params: make(map[string]string),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Check regex routes
|
|
||||||
for _, route := range r.regexRoutes {
|
|
||||||
match := route.Pattern.FindStringSubmatch(path)
|
|
||||||
if match != nil {
|
|
||||||
params := make(map[string]string)
|
|
||||||
|
|
||||||
// Extract named groups
|
|
||||||
names := route.Pattern.SubexpNames()
|
|
||||||
for i, name := range names {
|
|
||||||
if i > 0 && name != "" && i < len(match) {
|
|
||||||
params[name] = match[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &RouteMatch{
|
|
||||||
Config: route.Config,
|
|
||||||
Params: params,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Check default route
|
|
||||||
if r.defaultRoute != nil {
|
|
||||||
return &RouteMatch{
|
|
||||||
Config: r.defaultRoute,
|
|
||||||
Params: make(map[string]string),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupRoutes configures the routes from config
|
|
||||||
func (r *Router) setupRoutes() {
|
|
||||||
// Health check endpoint
|
|
||||||
r.mux.HandleFunc("/health", r.healthHandler)
|
|
||||||
|
|
||||||
// Setup redirect instructions from config
|
|
||||||
if r.config != nil {
|
|
||||||
for from, to := range r.config.Server.RedirectInstructions {
|
|
||||||
fromPath := from
|
|
||||||
toPath := to
|
|
||||||
r.mux.HandleFunc(fromPath, func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
http.Redirect(w, req, toPath, http.StatusMovedPermanently)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default handler for all other routes
|
|
||||||
r.mux.HandleFunc("/", r.defaultHandler)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTP implements http.Handler
|
|
||||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
||||||
r.mux.ServeHTTP(w, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// healthHandler handles health check requests
|
|
||||||
func (r *Router) healthHandler(w http.ResponseWriter, req *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultHandler handles requests that don't match other routes
|
|
||||||
func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) {
|
|
||||||
path := req.URL.Path
|
|
||||||
|
|
||||||
// Try to match against configured routes
|
|
||||||
match := r.Match(path)
|
|
||||||
fmt.Printf("DEBUG defaultHandler: path=%q match=%v defaultRoute=%v\n", path, match != nil, r.defaultRoute != nil)
|
|
||||||
if match != nil {
|
|
||||||
fmt.Printf("DEBUG: matched config: %v\n", match.Config)
|
|
||||||
r.handleRouteMatch(w, req, match)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to serve static file
|
|
||||||
if r.staticDir != "" {
|
|
||||||
// Get absolute path for static dir
|
|
||||||
absStaticDir, err := filepath.Abs(r.staticDir)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
filePath := filepath.Join(absStaticDir, filepath.Clean("/"+path))
|
|
||||||
cleanPath := filepath.Clean(filePath)
|
|
||||||
|
|
||||||
// Prevent directory traversal - ensure path is within static dir
|
|
||||||
if !strings.HasPrefix(cleanPath+string(filepath.Separator), absStaticDir+string(filepath.Separator)) {
|
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if file exists
|
|
||||||
info, err := os.Stat(filePath)
|
|
||||||
if err == nil {
|
|
||||||
if info.IsDir() {
|
|
||||||
// Try index.html
|
|
||||||
indexPath := filepath.Join(filePath, "index.html")
|
|
||||||
if _, err := os.Stat(indexPath); err == nil {
|
|
||||||
http.ServeFile(w, req, indexPath)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
http.ServeFile(w, req, filePath)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 404 Not Found
|
|
||||||
http.NotFound(w, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleRouteMatch handles a matched route
|
|
||||||
func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, match *RouteMatch) {
|
|
||||||
cfg := match.Config
|
|
||||||
|
|
||||||
// Handle proxy_pass directive
|
|
||||||
if proxyTarget, ok := cfg["proxy_pass"].(string); ok {
|
|
||||||
r.handleProxyPass(w, req, proxyTarget, cfg, match.Params)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle "return" directive
|
|
||||||
if ret, ok := cfg["return"].(string); ok {
|
|
||||||
parts := strings.SplitN(ret, " ", 2)
|
|
||||||
statusCode := 200
|
|
||||||
body := "OK"
|
|
||||||
if len(parts) >= 1 {
|
|
||||||
switch parts[0] {
|
|
||||||
case "200":
|
|
||||||
statusCode = 200
|
|
||||||
case "201":
|
|
||||||
statusCode = 201
|
|
||||||
case "301":
|
|
||||||
statusCode = 301
|
|
||||||
case "302":
|
|
||||||
statusCode = 302
|
|
||||||
case "400":
|
|
||||||
statusCode = 400
|
|
||||||
case "404":
|
|
||||||
statusCode = 404
|
|
||||||
case "500":
|
|
||||||
statusCode = 500
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(parts) >= 2 {
|
|
||||||
body = parts[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
if ct, ok := cfg["content_type"].(string); ok {
|
|
||||||
w.Header().Set("Content-Type", ct)
|
|
||||||
} else {
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(statusCode)
|
|
||||||
w.Write([]byte(body))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle static files with root
|
|
||||||
if root, ok := cfg["root"].(string); ok {
|
|
||||||
path := req.URL.Path
|
|
||||||
|
|
||||||
if indexFile, ok := cfg["index_file"].(string); ok {
|
|
||||||
if path == "/" || strings.HasSuffix(path, "/") {
|
|
||||||
path = "/" + indexFile
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get absolute path for root dir
|
|
||||||
absRoot, err := filepath.Abs(root)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
filePath := filepath.Join(absRoot, filepath.Clean("/"+path))
|
|
||||||
cleanPath := filepath.Clean(filePath)
|
|
||||||
|
|
||||||
// DEBUG
|
|
||||||
fmt.Printf("DEBUG: path=%q absRoot=%q filePath=%q cleanPath=%q\n", path, absRoot, filePath, cleanPath)
|
|
||||||
fmt.Printf("DEBUG: check1=%q check2=%q\n", cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator))
|
|
||||||
|
|
||||||
// Prevent directory traversal
|
|
||||||
if !strings.HasPrefix(cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator)) {
|
|
||||||
fmt.Printf("DEBUG: FORBIDDEN - path not within root\n")
|
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if cacheControl, ok := cfg["cache_control"].(string); ok {
|
|
||||||
w.Header().Set("Cache-Control", cacheControl)
|
|
||||||
}
|
|
||||||
|
|
||||||
if headers, ok := cfg["headers"].([]interface{}); ok {
|
|
||||||
for _, h := range headers {
|
|
||||||
if header, ok := h.(string); ok {
|
|
||||||
parts := strings.SplitN(header, ": ", 2)
|
|
||||||
if len(parts) == 2 {
|
|
||||||
w.Header().Set(parts[0], parts[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
http.ServeFile(w, req, filePath)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle SPA fallback
|
|
||||||
if spaFallback, ok := cfg["spa_fallback"].(bool); ok && spaFallback {
|
|
||||||
root := r.staticDir
|
|
||||||
if rt, ok := cfg["root"].(string); ok {
|
|
||||||
root = rt
|
|
||||||
}
|
|
||||||
|
|
||||||
indexFile := "index.html"
|
|
||||||
if idx, ok := cfg["index_file"].(string); ok {
|
|
||||||
indexFile = idx
|
|
||||||
}
|
|
||||||
|
|
||||||
filePath := filepath.Join(root, indexFile)
|
|
||||||
http.ServeFile(w, req, filePath)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
http.NotFound(w, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleProxyPass proxies the request to the target backend
|
|
||||||
func (r *Router) handleProxyPass(w http.ResponseWriter, req *http.Request, target string, cfg map[string]interface{}, params map[string]string) {
|
|
||||||
// Substitute params in target URL (e.g., {version} -> actual version)
|
|
||||||
for key, value := range params {
|
|
||||||
target = strings.ReplaceAll(target, "{"+key+"}", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create proxy
|
|
||||||
proxyConfig := &proxy.Config{
|
|
||||||
Target: target,
|
|
||||||
Headers: make(map[string]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse headers from config
|
|
||||||
if headers, ok := cfg["headers"].([]interface{}); ok {
|
|
||||||
for _, h := range headers {
|
|
||||||
if header, ok := h.(string); ok {
|
|
||||||
parts := strings.SplitN(header, ": ", 2)
|
|
||||||
if len(parts) == 2 {
|
|
||||||
// Substitute params in header values
|
|
||||||
headerValue := parts[1]
|
|
||||||
for key, value := range params {
|
|
||||||
headerValue = strings.ReplaceAll(headerValue, "{"+key+"}", value)
|
|
||||||
}
|
|
||||||
proxyConfig.Headers[parts[0]] = headerValue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p, err := proxy.New(proxyConfig, r.logger)
|
|
||||||
if err != nil {
|
|
||||||
if r.logger != nil {
|
|
||||||
r.logger.Error("Failed to create proxy", "target", target, "error", err)
|
|
||||||
}
|
|
||||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
p.ProxyRequest(w, req, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateRouterFromConfig creates a router from extension config
|
|
||||||
func CreateRouterFromConfig(cfg map[string]interface{}) *Router {
|
|
||||||
router := NewRouter()
|
|
||||||
|
|
||||||
if locations, ok := cfg["regex_locations"].(map[string]interface{}); ok {
|
|
||||||
for pattern, routeCfg := range locations {
|
|
||||||
if rc, ok := routeCfg.(map[string]interface{}); ok {
|
|
||||||
router.AddRoute(pattern, rc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return router
|
|
||||||
}
|
|
||||||
@ -1,375 +0,0 @@
|
|||||||
package routing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ============== Router Initialization Tests ==============
|
|
||||||
|
|
||||||
func TestRouter_Initialization(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
|
|
||||||
if router.StaticDir() != "./static" {
|
|
||||||
t.Errorf("Expected static dir ./static, got %s", router.StaticDir())
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(router.Routes()) != 0 {
|
|
||||||
t.Error("Expected empty routes")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(router.ExactRoutes()) != 0 {
|
|
||||||
t.Error("Expected empty exact routes")
|
|
||||||
}
|
|
||||||
|
|
||||||
if router.DefaultRoute() != nil {
|
|
||||||
t.Error("Expected nil default route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_CustomStaticDir(t *testing.T) {
|
|
||||||
router := NewRouter(WithStaticDir("/custom/path"))
|
|
||||||
|
|
||||||
if router.StaticDir() != "/custom/path" {
|
|
||||||
t.Errorf("Expected static dir /custom/path, got %s", router.StaticDir())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Route Adding Tests ==============
|
|
||||||
|
|
||||||
func TestRouter_AddExactRoute(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"return": "200 OK"}
|
|
||||||
|
|
||||||
router.AddRoute("=/health", config)
|
|
||||||
|
|
||||||
exactRoutes := router.ExactRoutes()
|
|
||||||
if _, ok := exactRoutes["/health"]; !ok {
|
|
||||||
t.Error("Expected /health in exact routes")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_AddDefaultRoute(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"spa_fallback": true, "root": "./static"}
|
|
||||||
|
|
||||||
router.AddRoute("__default__", config)
|
|
||||||
|
|
||||||
if router.DefaultRoute() == nil {
|
|
||||||
t.Error("Expected default route to be set")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_AddRegexRoute(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"root": "./static"}
|
|
||||||
|
|
||||||
router.AddRoute("~^/api/", config)
|
|
||||||
|
|
||||||
if len(router.Routes()) != 1 {
|
|
||||||
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_AddCaseInsensitiveRegexRoute(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
|
|
||||||
|
|
||||||
router.AddRoute("~*\\.(css|js)$", config)
|
|
||||||
|
|
||||||
if len(router.Routes()) != 1 {
|
|
||||||
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
|
|
||||||
}
|
|
||||||
|
|
||||||
if router.Routes()[0].CaseSensitive {
|
|
||||||
t.Error("Expected case-insensitive route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_InvalidRegexPattern(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"root": "./static"}
|
|
||||||
|
|
||||||
// Invalid regex - unmatched bracket
|
|
||||||
router.AddRoute("~^/api/[invalid", config)
|
|
||||||
|
|
||||||
// Should not add invalid pattern
|
|
||||||
if len(router.Routes()) != 0 {
|
|
||||||
t.Error("Should not add invalid regex pattern")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Route Matching Tests ==============
|
|
||||||
|
|
||||||
func TestRouter_MatchExactRoute(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"return": "200 OK"}
|
|
||||||
router.AddRoute("=/health", config)
|
|
||||||
|
|
||||||
match := router.Match("/health")
|
|
||||||
|
|
||||||
if match == nil {
|
|
||||||
t.Fatal("Expected match for /health")
|
|
||||||
}
|
|
||||||
|
|
||||||
if match.Config["return"] != "200 OK" {
|
|
||||||
t.Error("Expected return config")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(match.Params) != 0 {
|
|
||||||
t.Error("Expected empty params for exact match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_MatchExactRouteNoMatch(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"return": "200 OK"}
|
|
||||||
router.AddRoute("=/health", config)
|
|
||||||
|
|
||||||
match := router.Match("/healthcheck")
|
|
||||||
|
|
||||||
if match != nil {
|
|
||||||
t.Error("Exact route should not match /healthcheck")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_MatchRegexRoute(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
|
|
||||||
router.AddRoute("~^/api/v\\d+/", config)
|
|
||||||
|
|
||||||
match := router.Match("/api/v1/users")
|
|
||||||
|
|
||||||
if match == nil {
|
|
||||||
t.Fatal("Expected match for /api/v1/users")
|
|
||||||
}
|
|
||||||
|
|
||||||
if match.Config["proxy_pass"] != "http://localhost:9001" {
|
|
||||||
t.Error("Expected proxy_pass config")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_MatchRegexRouteWithGroups(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
|
|
||||||
router.AddRoute("~^/api/v(?P<version>\\d+)/", config)
|
|
||||||
|
|
||||||
match := router.Match("/api/v2/data")
|
|
||||||
|
|
||||||
if match == nil {
|
|
||||||
t.Fatal("Expected match for /api/v2/data")
|
|
||||||
}
|
|
||||||
|
|
||||||
if match.Params["version"] != "2" {
|
|
||||||
t.Errorf("Expected version=2, got %s", match.Params["version"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_MatchCaseInsensitiveRegex(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
|
|
||||||
router.AddRoute("~*\\.(CSS|JS)$", config)
|
|
||||||
|
|
||||||
// Should match lowercase
|
|
||||||
match1 := router.Match("/styles/main.css")
|
|
||||||
if match1 == nil {
|
|
||||||
t.Error("Should match lowercase .css")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should match uppercase
|
|
||||||
match2 := router.Match("/scripts/app.JS")
|
|
||||||
if match2 == nil {
|
|
||||||
t.Error("Should match uppercase .JS")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_MatchCaseSensitiveRegex(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
config := map[string]interface{}{"root": "./static"}
|
|
||||||
router.AddRoute("~\\.(css)$", config)
|
|
||||||
|
|
||||||
// Should match lowercase
|
|
||||||
match1 := router.Match("/styles/main.css")
|
|
||||||
if match1 == nil {
|
|
||||||
t.Error("Should match lowercase .css")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should NOT match uppercase
|
|
||||||
match2 := router.Match("/styles/main.CSS")
|
|
||||||
if match2 != nil {
|
|
||||||
t.Error("Should not match uppercase .CSS for case-sensitive regex")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_MatchDefaultRoute(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
|
|
||||||
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
|
|
||||||
|
|
||||||
match := router.Match("/unknown/path")
|
|
||||||
|
|
||||||
if match == nil {
|
|
||||||
t.Fatal("Expected default route match")
|
|
||||||
}
|
|
||||||
|
|
||||||
if match.Config["spa_fallback"] != true {
|
|
||||||
t.Error("Expected spa_fallback config from default route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Priority Tests ==============
|
|
||||||
|
|
||||||
func TestRouter_PriorityExactOverRegex(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
router.AddRoute("=/api/status", map[string]interface{}{"return": "200 Exact"})
|
|
||||||
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
|
||||||
|
|
||||||
match := router.Match("/api/status")
|
|
||||||
|
|
||||||
if match == nil {
|
|
||||||
t.Fatal("Expected match")
|
|
||||||
}
|
|
||||||
|
|
||||||
if match.Config["return"] != "200 Exact" {
|
|
||||||
t.Error("Exact match should have priority over regex")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_PriorityRegexOverDefault(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
|
||||||
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
|
|
||||||
|
|
||||||
match := router.Match("/api/v1/users")
|
|
||||||
|
|
||||||
if match == nil {
|
|
||||||
t.Fatal("Expected match")
|
|
||||||
}
|
|
||||||
|
|
||||||
if match.Config["proxy_pass"] != "http://localhost:9001" {
|
|
||||||
t.Error("Regex match should have priority over default")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== CreateRouterFromConfig Tests ==============
|
|
||||||
|
|
||||||
func TestCreateRouterFromConfig(t *testing.T) {
|
|
||||||
config := map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"=/health": map[string]interface{}{
|
|
||||||
"return": "200 OK",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
"~^/api/": map[string]interface{}{
|
|
||||||
"proxy_pass": "http://localhost:9001",
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"spa_fallback": true,
|
|
||||||
"root": "./static",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
router := CreateRouterFromConfig(config)
|
|
||||||
|
|
||||||
// Check exact route
|
|
||||||
if _, ok := router.ExactRoutes()["/health"]; !ok {
|
|
||||||
t.Error("Expected /health exact route")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check regex route
|
|
||||||
if len(router.Routes()) != 1 {
|
|
||||||
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check default route
|
|
||||||
if router.DefaultRoute() == nil {
|
|
||||||
t.Error("Expected default route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Static Dir Path Tests ==============
|
|
||||||
|
|
||||||
func TestRouter_StaticDirPath(t *testing.T) {
|
|
||||||
router := NewRouter(WithStaticDir("/var/www/html"))
|
|
||||||
|
|
||||||
expected, _ := filepath.Abs("/var/www/html")
|
|
||||||
actual, _ := filepath.Abs(router.StaticDir())
|
|
||||||
|
|
||||||
if actual != expected {
|
|
||||||
t.Errorf("Expected static dir %s, got %s", expected, actual)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Concurrent Access Tests ==============
|
|
||||||
|
|
||||||
func TestRouter_ConcurrentAccess(t *testing.T) {
|
|
||||||
router := NewRouter()
|
|
||||||
|
|
||||||
// Add routes concurrently
|
|
||||||
done := make(chan bool, 10)
|
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
go func(n int) {
|
|
||||||
router.AddRoute("~^/api/v"+string(rune('0'+n))+"/", map[string]interface{}{
|
|
||||||
"proxy_pass": "http://localhost:900" + string(rune('0'+n)),
|
|
||||||
})
|
|
||||||
done <- true
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for all goroutines
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
<-done
|
|
||||||
}
|
|
||||||
|
|
||||||
// Match routes concurrently
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
go func(n int) {
|
|
||||||
router.Match("/api/v" + string(rune('0'+n)) + "/users")
|
|
||||||
done <- true
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
<-done
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Benchmarks ==============
|
|
||||||
|
|
||||||
func BenchmarkRouter_MatchExact(b *testing.B) {
|
|
||||||
router := NewRouter()
|
|
||||||
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
router.Match("/health")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkRouter_MatchRegex(b *testing.B) {
|
|
||||||
router := NewRouter()
|
|
||||||
router.AddRoute("~^/api/v(?P<version>\\d+)/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
router.Match("/api/v1/users/123")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkRouter_MatchWithManyRoutes(b *testing.B) {
|
|
||||||
router := NewRouter()
|
|
||||||
|
|
||||||
// Add many routes
|
|
||||||
for i := 0; i < 50; i++ {
|
|
||||||
router.AddRoute("~^/api/v"+string(rune('0'+i%10))+"/service"+string(rune('0'+i/10))+"/",
|
|
||||||
map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
|
||||||
}
|
|
||||||
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
router.Match("/api/v5/service3/users/123")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,158 +0,0 @@
|
|||||||
// Package server provides the HTTP server implementation
|
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/config"
|
|
||||||
"github.com/konduktor/konduktor/internal/extension"
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
"github.com/konduktor/konduktor/internal/middleware"
|
|
||||||
)
|
|
||||||
|
|
||||||
const Version = "0.2.0"
|
|
||||||
|
|
||||||
// Server represents the Konduktor HTTP server
|
|
||||||
type Server struct {
|
|
||||||
config *config.Config
|
|
||||||
httpServer *http.Server
|
|
||||||
extensionManager *extension.Manager
|
|
||||||
logger *logging.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new server instance
|
|
||||||
func New(cfg *config.Config) (*Server, error) {
|
|
||||||
if err := cfg.Validate(); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger, err := logging.NewFromConfig(cfg.Logging)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create logger: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create extension manager
|
|
||||||
extManager := extension.NewManager(logger)
|
|
||||||
|
|
||||||
// Load extensions from config
|
|
||||||
for _, extCfg := range cfg.Extensions {
|
|
||||||
// Add static_dir to routing config if not present
|
|
||||||
if extCfg.Type == "routing" {
|
|
||||||
if extCfg.Config == nil {
|
|
||||||
extCfg.Config = make(map[string]interface{})
|
|
||||||
}
|
|
||||||
if _, ok := extCfg.Config["static_dir"]; !ok {
|
|
||||||
extCfg.Config["static_dir"] = cfg.HTTP.StaticDir
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := extManager.LoadExtension(extCfg.Type, extCfg.Config); err != nil {
|
|
||||||
logger.Error("Failed to load extension", "type", extCfg.Type, "error", err)
|
|
||||||
// Continue loading other extensions
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := &Server{
|
|
||||||
config: cfg,
|
|
||||||
extensionManager: extManager,
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
|
|
||||||
return srv, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run starts the server and blocks until shutdown
|
|
||||||
func (s *Server) Run() error {
|
|
||||||
// Build handler chain with middleware
|
|
||||||
handler := s.buildHandler()
|
|
||||||
|
|
||||||
// Create HTTP server
|
|
||||||
addr := fmt.Sprintf("%s:%d", s.config.Server.Host, s.config.Server.Port)
|
|
||||||
s.httpServer = &http.Server{
|
|
||||||
Addr: addr,
|
|
||||||
Handler: handler,
|
|
||||||
ReadTimeout: 30 * time.Second,
|
|
||||||
WriteTimeout: 30 * time.Second,
|
|
||||||
IdleTimeout: 120 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start server in goroutine
|
|
||||||
errChan := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
s.logger.Info("Server starting", "addr", addr, "version", Version)
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if s.config.SSL.Enabled {
|
|
||||||
err = s.httpServer.ListenAndServeTLS(s.config.SSL.CertFile, s.config.SSL.KeyFile)
|
|
||||||
} else {
|
|
||||||
err = s.httpServer.ListenAndServe()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil && err != http.ErrServerClosed {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for shutdown signal
|
|
||||||
return s.waitForShutdown(errChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildHandler builds the HTTP handler chain
|
|
||||||
func (s *Server) buildHandler() http.Handler {
|
|
||||||
// Create base handler that returns 404
|
|
||||||
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.NotFound(w, r)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Wrap with extension manager
|
|
||||||
var handler http.Handler = s.extensionManager.Handler(baseHandler)
|
|
||||||
|
|
||||||
// Add middleware (applied in reverse order)
|
|
||||||
handler = middleware.AccessLog(handler, s.logger)
|
|
||||||
handler = middleware.ServerHeader(handler, Version)
|
|
||||||
handler = middleware.Recovery(handler, s.logger)
|
|
||||||
|
|
||||||
return handler
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitForShutdown waits for shutdown signal and gracefully stops the server
|
|
||||||
func (s *Server) waitForShutdown(errChan <-chan error) error {
|
|
||||||
// Listen for shutdown signals
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-errChan:
|
|
||||||
return err
|
|
||||||
case sig := <-sigChan:
|
|
||||||
s.logger.Info("Shutdown signal received", "signal", sig.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Graceful shutdown with timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
s.logger.Info("Shutting down server...")
|
|
||||||
|
|
||||||
// Cleanup extensions
|
|
||||||
s.extensionManager.Cleanup()
|
|
||||||
|
|
||||||
if err := s.httpServer.Shutdown(ctx); err != nil {
|
|
||||||
s.logger.Error("Error during shutdown", "error", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.logger.Info("Server stopped gracefully")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown gracefully shuts down the server
|
|
||||||
func (s *Server) Shutdown(ctx context.Context) error {
|
|
||||||
return s.httpServer.Shutdown(ctx)
|
|
||||||
}
|
|
||||||
@ -1,134 +0,0 @@
|
|||||||
# Integration Tests
|
|
||||||
|
|
||||||
Интеграционные тесты для Konduktor — полноценное тестирование сервера с реальными HTTP запросами.
|
|
||||||
|
|
||||||
## Отличие от unit-тестов
|
|
||||||
|
|
||||||
| Аспект | Unit-тесты | Интеграционные тесты |
|
|
||||||
|--------|------------|---------------------|
|
|
||||||
| Scope | Отдельный модуль в изоляции | Весь сервер целиком |
|
|
||||||
| Backend | Mock (httptest.Server) | Реальные HTTP серверы |
|
|
||||||
| Config | Программный | YAML конфигурация |
|
|
||||||
| Extensions | Не тестируются | Полная цепочка обработки |
|
|
||||||
|
|
||||||
## Структура тестов
|
|
||||||
|
|
||||||
```
|
|
||||||
tests/integration/
|
|
||||||
├── README.md # Эта документация
|
|
||||||
├── helpers_test.go # Общие хелперы и утилиты
|
|
||||||
├── reverse_proxy_test.go # Тесты reverse proxy
|
|
||||||
├── routing_test.go # Тесты маршрутизации (TODO)
|
|
||||||
├── security_test.go # Тесты security extension (TODO)
|
|
||||||
├── caching_test.go # Тесты caching extension (TODO)
|
|
||||||
└── static_files_test.go # Тесты статических файлов (TODO)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Что тестируют интеграционные тесты
|
|
||||||
|
|
||||||
### 1. Reverse Proxy (`reverse_proxy_test.go`)
|
|
||||||
|
|
||||||
- [x] Базовое проксирование GET/POST/PUT/DELETE
|
|
||||||
- [x] Exact match routes (`=/api/version`)
|
|
||||||
- [x] Regex routes с параметрами (`~^/api/resource/(?P<id>\d+)$`)
|
|
||||||
- [x] Подстановка параметров в target URL (`{id}`, `{tag}`)
|
|
||||||
- [x] Подстановка переменных в заголовки (`$remote_addr`)
|
|
||||||
- [x] Передача заголовков X-Forwarded-For, X-Real-IP
|
|
||||||
- [x] Сохранение query string
|
|
||||||
- [x] Обработка ошибок backend (502, 504)
|
|
||||||
- [x] Таймауты соединения
|
|
||||||
|
|
||||||
### 2. Routing Extension (`routing_test.go`)
|
|
||||||
|
|
||||||
- [x] Приоритет маршрутов (exact > regex > default)
|
|
||||||
- [x] Case-sensitive regex (`~`)
|
|
||||||
- [x] Case-insensitive regex (`~*`)
|
|
||||||
- [x] Default route (`__default__`)
|
|
||||||
- [x] Return directive (`return 200 "OK"`)
|
|
||||||
- [x] Regex с именованными группами
|
|
||||||
- [x] Множественные regex маршруты
|
|
||||||
- [x] Кастомные заголовки в маршрутах
|
|
||||||
- [x] Обработка отсутствия маршрута
|
|
||||||
|
|
||||||
### 3. Security Extension (`security_test.go`)
|
|
||||||
|
|
||||||
- [ ] IP whitelist
|
|
||||||
- [ ] IP blacklist
|
|
||||||
- [ ] CIDR нотация (10.0.0.0/8)
|
|
||||||
- [ ] Security headers (X-Frame-Options, X-Content-Type-Options)
|
|
||||||
- [ ] Rate limiting
|
|
||||||
- [ ] Комбинация с другими extensions
|
|
||||||
|
|
||||||
### 4. Caching Extension (`caching_test.go`)
|
|
||||||
|
|
||||||
- [x] Cache hit/miss
|
|
||||||
- [x] TTL expiration
|
|
||||||
- [x] Pattern-based caching
|
|
||||||
- [x] Cache-Control headers (X-Cache header)
|
|
||||||
- [x] Кэширование только GET запросов
|
|
||||||
- [x] Разные пути = разные ключи кэша
|
|
||||||
- [x] Query string влияет на ключ кэша
|
|
||||||
- [x] Ошибки не кэшируются
|
|
||||||
- [x] Конкурентный доступ к кэшу
|
|
||||||
- [x] Множественные паттерны кэширования
|
|
||||||
|
|
||||||
### 5. Static Files (`static_files_test.go`)
|
|
||||||
|
|
||||||
- [ ] Serving статических файлов
|
|
||||||
- [ ] Index file (index.html)
|
|
||||||
- [ ] MIME types
|
|
||||||
- [ ] Cache-Control для static
|
|
||||||
- [ ] SPA fallback
|
|
||||||
- [ ] Directory traversal protection
|
|
||||||
- [ ] 404 для несуществующих файлов
|
|
||||||
|
|
||||||
### 6. Extension Chain (`extension_chain_test.go`)
|
|
||||||
|
|
||||||
- [ ] Порядок выполнения extensions (security → caching → routing)
|
|
||||||
- [ ] Прерывание цепочки при ошибке
|
|
||||||
- [ ] Совместная работа extensions
|
|
||||||
|
|
||||||
## Запуск тестов
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Все интеграционные тесты
|
|
||||||
go test ./tests/integration/... -v
|
|
||||||
|
|
||||||
# Конкретный файл
|
|
||||||
go test ./tests/integration/... -v -run TestReverseProxy
|
|
||||||
|
|
||||||
# С таймаутом (интеграционные тесты медленнее)
|
|
||||||
go test ./tests/integration/... -v -timeout 60s
|
|
||||||
|
|
||||||
# С покрытием
|
|
||||||
go test ./tests/integration/... -v -coverprofile=coverage.out
|
|
||||||
```
|
|
||||||
|
|
||||||
## Требования
|
|
||||||
|
|
||||||
- Свободные порты: тесты используют случайные порты (`:0`)
|
|
||||||
- Сетевой доступ: для localhost соединений
|
|
||||||
- Время: интеграционные тесты занимают больше времени (~5-10 сек)
|
|
||||||
|
|
||||||
## Добавление новых тестов
|
|
||||||
|
|
||||||
1. Создайте файл `*_test.go` в `tests/integration/`
|
|
||||||
2. Используйте хелперы из `helpers_test.go`:
|
|
||||||
- `startTestServer()` — запуск Konduktor сервера
|
|
||||||
- `startBackend()` — запуск mock backend
|
|
||||||
- `makeRequest()` — отправка HTTP запроса
|
|
||||||
3. Добавьте описание в этот README
|
|
||||||
|
|
||||||
## CI/CD
|
|
||||||
|
|
||||||
Интеграционные тесты запускаются отдельно от unit-тестов:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# .github/workflows/test.yml
|
|
||||||
jobs:
|
|
||||||
unit-tests:
|
|
||||||
run: go test ./internal/...
|
|
||||||
|
|
||||||
integration-tests:
|
|
||||||
run: go test ./tests/integration/... -timeout 120s
|
|
||||||
```
|
|
||||||
@ -1,666 +0,0 @@
|
|||||||
package integration
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/extension"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ============== Basic Cache Hit/Miss Tests ==============
|
|
||||||
|
|
||||||
func TestCaching_BasicHitMiss(t *testing.T) {
|
|
||||||
var requestCount int64
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
count := atomic.AddInt64(&requestCount, 1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"request_number": count,
|
|
||||||
"timestamp": time.Now().UnixNano(),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
|
|
||||||
// Create caching extension
|
|
||||||
cachingExt, err := extension.NewCachingExtension(map[string]interface{}{
|
|
||||||
"default_ttl": "1m",
|
|
||||||
"cache_patterns": []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/api/.*",
|
|
||||||
"ttl": "30s",
|
|
||||||
"methods": []interface{}{"GET"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create caching extension: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create routing extension
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{cachingExt, routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// First request - should be MISS
|
|
||||||
resp1, err := client.Get("/api/data", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request 1 failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheHeader1 := resp1.Header.Get("X-Cache")
|
|
||||||
var result1 map[string]interface{}
|
|
||||||
json.NewDecoder(resp1.Body).Decode(&result1)
|
|
||||||
resp1.Body.Close()
|
|
||||||
|
|
||||||
if cacheHeader1 != "MISS" {
|
|
||||||
t.Errorf("Expected X-Cache: MISS for first request, got %q", cacheHeader1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Second request - should be HIT (same response)
|
|
||||||
resp2, err := client.Get("/api/data", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request 2 failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheHeader2 := resp2.Header.Get("X-Cache")
|
|
||||||
var result2 map[string]interface{}
|
|
||||||
json.NewDecoder(resp2.Body).Decode(&result2)
|
|
||||||
resp2.Body.Close()
|
|
||||||
|
|
||||||
if cacheHeader2 != "HIT" {
|
|
||||||
t.Errorf("Expected X-Cache: HIT for second request, got %q", cacheHeader2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify same response (from cache)
|
|
||||||
if result1["request_number"] != result2["request_number"] {
|
|
||||||
t.Errorf("Expected same request_number from cache, got %v and %v",
|
|
||||||
result1["request_number"], result2["request_number"])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backend should only receive 1 request
|
|
||||||
if atomic.LoadInt64(&requestCount) != 1 {
|
|
||||||
t.Errorf("Expected 1 backend request, got %d", requestCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== TTL Expiration Tests ==============
|
|
||||||
|
|
||||||
func TestCaching_TTLExpiration(t *testing.T) {
|
|
||||||
var requestCount int64
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
count := atomic.AddInt64(&requestCount, 1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"request_number": count,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
|
|
||||||
// Create caching extension with short TTL
|
|
||||||
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
|
|
||||||
"default_ttl": "100ms", // Very short TTL for testing
|
|
||||||
"cache_patterns": []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/api/.*",
|
|
||||||
"ttl": "100ms",
|
|
||||||
"methods": []interface{}{"GET"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{cachingExt, routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// First request
|
|
||||||
resp1, _ := client.Get("/api/data", nil)
|
|
||||||
var result1 map[string]interface{}
|
|
||||||
json.NewDecoder(resp1.Body).Decode(&result1)
|
|
||||||
resp1.Body.Close()
|
|
||||||
|
|
||||||
// Second request (within TTL) - should be HIT
|
|
||||||
resp2, _ := client.Get("/api/data", nil)
|
|
||||||
cacheHeader2 := resp2.Header.Get("X-Cache")
|
|
||||||
resp2.Body.Close()
|
|
||||||
|
|
||||||
if cacheHeader2 != "HIT" {
|
|
||||||
t.Errorf("Expected X-Cache: HIT before TTL expires, got %q", cacheHeader2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for TTL to expire
|
|
||||||
time.Sleep(150 * time.Millisecond)
|
|
||||||
|
|
||||||
// Third request (after TTL) - should be MISS
|
|
||||||
resp3, _ := client.Get("/api/data", nil)
|
|
||||||
cacheHeader3 := resp3.Header.Get("X-Cache")
|
|
||||||
var result3 map[string]interface{}
|
|
||||||
json.NewDecoder(resp3.Body).Decode(&result3)
|
|
||||||
resp3.Body.Close()
|
|
||||||
|
|
||||||
if cacheHeader3 != "MISS" {
|
|
||||||
t.Errorf("Expected X-Cache: MISS after TTL expires, got %q", cacheHeader3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify new request was made (different request_number)
|
|
||||||
if result1["request_number"] == result3["request_number"] {
|
|
||||||
t.Error("Expected different request_number after TTL expiration")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Pattern-Based Caching Tests ==============
|
|
||||||
|
|
||||||
func TestCaching_PatternBasedCaching(t *testing.T) {
|
|
||||||
var apiCount, staticCount int64
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
if r.URL.Path[:5] == "/api/" {
|
|
||||||
atomic.AddInt64(&apiCount, 1)
|
|
||||||
} else {
|
|
||||||
atomic.AddInt64(&staticCount, 1)
|
|
||||||
}
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
|
|
||||||
// Only cache /api/* paths
|
|
||||||
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
|
|
||||||
"default_ttl": "1m",
|
|
||||||
"cache_patterns": []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/api/.*",
|
|
||||||
"ttl": "1m",
|
|
||||||
"methods": []interface{}{"GET"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{cachingExt, routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Multiple requests to /api/ - should be cached
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
resp, _ := client.Get("/api/users", nil)
|
|
||||||
resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multiple requests to /static/ - should NOT be cached (not matching pattern)
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
resp, _ := client.Get("/static/file.js", nil)
|
|
||||||
resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// API should have only 1 request (cached)
|
|
||||||
if atomic.LoadInt64(&apiCount) != 1 {
|
|
||||||
t.Errorf("Expected 1 API request (cached), got %d", apiCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Static should have 3 requests (not cached)
|
|
||||||
if atomic.LoadInt64(&staticCount) != 3 {
|
|
||||||
t.Errorf("Expected 3 static requests (not cached), got %d", staticCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Method-Specific Caching Tests ==============
|
|
||||||
|
|
||||||
func TestCaching_OnlyGETMethodCached(t *testing.T) {
|
|
||||||
var getCount, postCount int64
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
if r.Method == "GET" {
|
|
||||||
atomic.AddInt64(&getCount, 1)
|
|
||||||
} else if r.Method == "POST" {
|
|
||||||
atomic.AddInt64(&postCount, 1)
|
|
||||||
}
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"method": r.Method,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
|
|
||||||
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
|
|
||||||
"default_ttl": "1m",
|
|
||||||
"cache_patterns": []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/api/.*",
|
|
||||||
"ttl": "1m",
|
|
||||||
"methods": []interface{}{"GET"}, // Only GET
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{cachingExt, routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Multiple GET requests - should be cached
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
resp, _ := client.Get("/api/data", nil)
|
|
||||||
resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multiple POST requests - should NOT be cached
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
resp, _ := client.Post("/api/data", []byte(`{}`), map[string]string{
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
})
|
|
||||||
resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if atomic.LoadInt64(&getCount) != 1 {
|
|
||||||
t.Errorf("Expected 1 GET request (cached), got %d", getCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
if atomic.LoadInt64(&postCount) != 3 {
|
|
||||||
t.Errorf("Expected 3 POST requests (not cached), got %d", postCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Different Paths Different Cache Keys ==============
|
|
||||||
|
|
||||||
func TestCaching_DifferentPathsDifferentCacheKeys(t *testing.T) {
|
|
||||||
var requestCount int64
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
count := atomic.AddInt64(&requestCount, 1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"path": r.URL.Path,
|
|
||||||
"request_number": count,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
|
|
||||||
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
|
|
||||||
"default_ttl": "1m",
|
|
||||||
"cache_patterns": []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/api/.*",
|
|
||||||
"ttl": "1m",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{cachingExt, routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Request different paths
|
|
||||||
paths := []string{"/api/users", "/api/posts", "/api/comments"}
|
|
||||||
|
|
||||||
for _, path := range paths {
|
|
||||||
resp, _ := client.Get(path, nil)
|
|
||||||
resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each path should result in a separate backend request
|
|
||||||
if atomic.LoadInt64(&requestCount) != 3 {
|
|
||||||
t.Errorf("Expected 3 backend requests (one per path), got %d", requestCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request same paths again - all should be cached
|
|
||||||
for _, path := range paths {
|
|
||||||
resp, _ := client.Get(path, nil)
|
|
||||||
cacheHeader := resp.Header.Get("X-Cache")
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
if cacheHeader != "HIT" {
|
|
||||||
t.Errorf("Expected X-Cache: HIT for %s, got %q", path, cacheHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// No additional backend requests
|
|
||||||
if atomic.LoadInt64(&requestCount) != 3 {
|
|
||||||
t.Errorf("Expected still 3 backend requests after cache hits, got %d", requestCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Query String Affects Cache Key ==============
|
|
||||||
|
|
||||||
func TestCaching_QueryStringAffectsCacheKey(t *testing.T) {
|
|
||||||
var requestCount int64
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
count := atomic.AddInt64(&requestCount, 1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"query": r.URL.RawQuery,
|
|
||||||
"request_number": count,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
|
|
||||||
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
|
|
||||||
"default_ttl": "1m",
|
|
||||||
"cache_patterns": []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/api/.*",
|
|
||||||
"ttl": "1m",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{cachingExt, routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Different query strings = different cache keys
|
|
||||||
queries := []string{
|
|
||||||
"/api/search?q=hello",
|
|
||||||
"/api/search?q=world",
|
|
||||||
"/api/search?q=test",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, query := range queries {
|
|
||||||
resp, _ := client.Get(query, nil)
|
|
||||||
resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each unique query should result in a separate backend request
|
|
||||||
if atomic.LoadInt64(&requestCount) != 3 {
|
|
||||||
t.Errorf("Expected 3 backend requests (one per query), got %d", requestCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Same query again should be cached
|
|
||||||
resp, _ := client.Get("/api/search?q=hello", nil)
|
|
||||||
cacheHeader := resp.Header.Get("X-Cache")
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
if cacheHeader != "HIT" {
|
|
||||||
t.Errorf("Expected X-Cache: HIT for repeated query, got %q", cacheHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Cache Does Not Store Error Responses ==============
|
|
||||||
|
|
||||||
func TestCaching_DoesNotCacheErrors(t *testing.T) {
|
|
||||||
var requestCount int64
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
atomic.AddInt64(&requestCount, 1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": "internal error"})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
|
|
||||||
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
|
|
||||||
"default_ttl": "1m",
|
|
||||||
"cache_patterns": []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/api/.*",
|
|
||||||
"ttl": "1m",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{cachingExt, routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Multiple requests to error endpoint
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
resp, _ := client.Get("/api/error", nil)
|
|
||||||
resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// All requests should reach backend (errors not cached)
|
|
||||||
if atomic.LoadInt64(&requestCount) != 3 {
|
|
||||||
t.Errorf("Expected 3 backend requests (errors not cached), got %d", requestCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Concurrent Cache Access ==============
|
|
||||||
|
|
||||||
func TestCaching_ConcurrentAccess(t *testing.T) {
|
|
||||||
var requestCount int64
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Small delay to increase chance of race conditions
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
count := atomic.AddInt64(&requestCount, 1)
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"request_number": count,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
|
|
||||||
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
|
|
||||||
"default_ttl": "1m",
|
|
||||||
"cache_patterns": []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/api/.*",
|
|
||||||
"ttl": "1m",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{cachingExt, routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
const numRequests = 20
|
|
||||||
results := make(chan error, numRequests)
|
|
||||||
|
|
||||||
// Make first request to populate cache
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
resp, _ := client.Get("/api/concurrent", nil)
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
// Now many concurrent requests should all hit cache
|
|
||||||
for i := 0; i < numRequests; i++ {
|
|
||||||
go func(n int) {
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
resp, err := client.Get("/api/concurrent", nil)
|
|
||||||
if err != nil {
|
|
||||||
results <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheHeader := resp.Header.Get("X-Cache")
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
if cacheHeader != "HIT" {
|
|
||||||
results <- fmt.Errorf("request %d: expected HIT, got %s", n, cacheHeader)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
results <- nil
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect results
|
|
||||||
var errors []error
|
|
||||||
for i := 0; i < numRequests; i++ {
|
|
||||||
if err := <-results; err != nil {
|
|
||||||
errors = append(errors, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errors) > 0 {
|
|
||||||
t.Errorf("Got %d errors in concurrent cache access: %v", len(errors), errors[:min(5, len(errors))])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only 1 request should reach backend (the initial one)
|
|
||||||
if atomic.LoadInt64(&requestCount) != 1 {
|
|
||||||
t.Errorf("Expected 1 backend request, got %d", requestCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Multiple Cache Patterns ==============
|
|
||||||
|
|
||||||
func TestCaching_MultipleCachePatterns(t *testing.T) {
|
|
||||||
var apiCount, staticCount int64
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
|
|
||||||
atomic.AddInt64(&apiCount, 1)
|
|
||||||
} else if len(r.URL.Path) >= 8 && r.URL.Path[:8] == "/static/" {
|
|
||||||
atomic.AddInt64(&staticCount, 1)
|
|
||||||
}
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
|
|
||||||
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
|
|
||||||
"default_ttl": "1m",
|
|
||||||
"cache_patterns": []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/api/.*",
|
|
||||||
"ttl": "30s",
|
|
||||||
"methods": []interface{}{"GET"},
|
|
||||||
},
|
|
||||||
map[string]interface{}{
|
|
||||||
"pattern": "^/static/.*",
|
|
||||||
"ttl": "1h", // Static files cached longer
|
|
||||||
"methods": []interface{}{"GET"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{cachingExt, routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Multiple requests to both patterns
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
resp1, _ := client.Get("/api/data", nil)
|
|
||||||
resp1.Body.Close()
|
|
||||||
|
|
||||||
resp2, _ := client.Get("/static/app.js", nil)
|
|
||||||
resp2.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Both should be cached (1 request each)
|
|
||||||
if atomic.LoadInt64(&apiCount) != 1 {
|
|
||||||
t.Errorf("Expected 1 API request, got %d", apiCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
if atomic.LoadInt64(&staticCount) != 1 {
|
|
||||||
t.Errorf("Expected 1 static request, got %d", staticCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,408 +0,0 @@
|
|||||||
package integration
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/extension"
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
"github.com/konduktor/konduktor/internal/middleware"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestServer represents a running Konduktor server for testing
|
|
||||||
type TestServer struct {
|
|
||||||
Server *http.Server
|
|
||||||
URL string
|
|
||||||
Port int
|
|
||||||
listener net.Listener
|
|
||||||
handler http.Handler
|
|
||||||
t *testing.T
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBackend represents a mock backend server
|
|
||||||
type TestBackend struct {
|
|
||||||
server *httptest.Server
|
|
||||||
requestLog []RequestLogEntry
|
|
||||||
mu sync.Mutex
|
|
||||||
requestCount int64
|
|
||||||
handler http.HandlerFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestLogEntry stores information about a received request
|
|
||||||
type RequestLogEntry struct {
|
|
||||||
Method string
|
|
||||||
Path string
|
|
||||||
Query string
|
|
||||||
Headers http.Header
|
|
||||||
Body string
|
|
||||||
Timestamp time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerConfig holds configuration for starting a test server
|
|
||||||
type ServerConfig struct {
|
|
||||||
Extensions []extension.Extension
|
|
||||||
StaticDir string
|
|
||||||
Middleware []func(http.Handler) http.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Test Server ==============
|
|
||||||
|
|
||||||
// StartTestServer creates and starts a Konduktor server for testing
|
|
||||||
func StartTestServer(t *testing.T, cfg *ServerConfig) *TestServer {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
logger, err := logging.New(logging.Config{
|
|
||||||
Level: "DEBUG",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create logger: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create extension manager
|
|
||||||
extManager := extension.NewManager(logger)
|
|
||||||
|
|
||||||
// Add extensions if provided
|
|
||||||
if cfg != nil && len(cfg.Extensions) > 0 {
|
|
||||||
for _, ext := range cfg.Extensions {
|
|
||||||
if err := extManager.AddExtension(ext); err != nil {
|
|
||||||
t.Fatalf("Failed to add extension: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a fallback handler for when no extension handles the request
|
|
||||||
fallback := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.NotFound(w, r)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create handler chain
|
|
||||||
var handler http.Handler = extManager.Handler(fallback)
|
|
||||||
|
|
||||||
// Add middleware
|
|
||||||
handler = middleware.AccessLog(handler, logger)
|
|
||||||
handler = middleware.Recovery(handler, logger)
|
|
||||||
|
|
||||||
// Add custom middleware if provided
|
|
||||||
if cfg != nil {
|
|
||||||
for _, mw := range cfg.Middleware {
|
|
||||||
handler = mw(handler)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find available port
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to find available port: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
port := listener.Addr().(*net.TCPAddr).Port
|
|
||||||
|
|
||||||
server := &http.Server{
|
|
||||||
Handler: handler,
|
|
||||||
ReadTimeout: 10 * time.Second,
|
|
||||||
WriteTimeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
ts := &TestServer{
|
|
||||||
Server: server,
|
|
||||||
URL: fmt.Sprintf("http://127.0.0.1:%d", port),
|
|
||||||
Port: port,
|
|
||||||
listener: listener,
|
|
||||||
handler: handler,
|
|
||||||
t: t,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start server in goroutine
|
|
||||||
go func() {
|
|
||||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
|
||||||
// Don't fail test here as server might be intentionally closed
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for server to be ready
|
|
||||||
ts.waitReady()
|
|
||||||
|
|
||||||
return ts
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitReady waits for the server to be ready to accept connections
|
|
||||||
func (ts *TestServer) waitReady() {
|
|
||||||
deadline := time.Now().Add(5 * time.Second)
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", ts.Port), 100*time.Millisecond)
|
|
||||||
if err == nil {
|
|
||||||
conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
}
|
|
||||||
ts.t.Fatal("Server failed to start within timeout")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close shuts down the test server
|
|
||||||
func (ts *TestServer) Close() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
ts.Server.Shutdown(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Test Backend ==============
|
|
||||||
|
|
||||||
// StartBackend creates and starts a mock backend server
|
|
||||||
func StartBackend(handler http.HandlerFunc) *TestBackend {
|
|
||||||
tb := &TestBackend{
|
|
||||||
requestLog: make([]RequestLogEntry, 0),
|
|
||||||
handler: handler,
|
|
||||||
}
|
|
||||||
|
|
||||||
if handler == nil {
|
|
||||||
handler = tb.defaultHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
tb.logRequest(r)
|
|
||||||
handler(w, r)
|
|
||||||
}))
|
|
||||||
|
|
||||||
return tb
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb *TestBackend) logRequest(r *http.Request) {
|
|
||||||
tb.mu.Lock()
|
|
||||||
defer tb.mu.Unlock()
|
|
||||||
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
|
||||||
|
|
||||||
tb.requestLog = append(tb.requestLog, RequestLogEntry{
|
|
||||||
Method: r.Method,
|
|
||||||
Path: r.URL.Path,
|
|
||||||
Query: r.URL.RawQuery,
|
|
||||||
Headers: r.Header.Clone(),
|
|
||||||
Body: string(body),
|
|
||||||
Timestamp: time.Now(),
|
|
||||||
})
|
|
||||||
atomic.AddInt64(&tb.requestCount, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tb *TestBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"backend": "default",
|
|
||||||
"path": r.URL.Path,
|
|
||||||
"method": r.Method,
|
|
||||||
"query": r.URL.RawQuery,
|
|
||||||
"received": time.Now().Unix(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// URL returns the backend server URL
|
|
||||||
func (tb *TestBackend) URL() string {
|
|
||||||
return tb.server.URL
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close shuts down the backend server
|
|
||||||
func (tb *TestBackend) Close() {
|
|
||||||
tb.server.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestCount returns the number of requests received
|
|
||||||
func (tb *TestBackend) RequestCount() int64 {
|
|
||||||
return atomic.LoadInt64(&tb.requestCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LastRequest returns the most recent request
|
|
||||||
func (tb *TestBackend) LastRequest() *RequestLogEntry {
|
|
||||||
tb.mu.Lock()
|
|
||||||
defer tb.mu.Unlock()
|
|
||||||
if len(tb.requestLog) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &tb.requestLog[len(tb.requestLog)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllRequests returns all logged requests
|
|
||||||
func (tb *TestBackend) AllRequests() []RequestLogEntry {
|
|
||||||
tb.mu.Lock()
|
|
||||||
defer tb.mu.Unlock()
|
|
||||||
result := make([]RequestLogEntry, len(tb.requestLog))
|
|
||||||
copy(result, tb.requestLog)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== HTTP Client Helpers ==============
|
|
||||||
|
|
||||||
// HTTPClient is a configured HTTP client for testing
|
|
||||||
type HTTPClient struct {
|
|
||||||
client *http.Client
|
|
||||||
baseURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHTTPClient creates a new test HTTP client
|
|
||||||
func NewHTTPClient(baseURL string) *HTTPClient {
|
|
||||||
return &HTTPClient{
|
|
||||||
client: &http.Client{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
||||||
return http.ErrUseLastResponse // Don't follow redirects
|
|
||||||
},
|
|
||||||
},
|
|
||||||
baseURL: baseURL,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get performs a GET request
|
|
||||||
func (c *HTTPClient) Get(path string, headers map[string]string) (*http.Response, error) {
|
|
||||||
return c.Do("GET", path, nil, headers)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Post performs a POST request
|
|
||||||
func (c *HTTPClient) Post(path string, body []byte, headers map[string]string) (*http.Response, error) {
|
|
||||||
return c.Do("POST", path, body, headers)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do performs an HTTP request
|
|
||||||
func (c *HTTPClient) Do(method, path string, body []byte, headers map[string]string) (*http.Response, error) {
|
|
||||||
var bodyReader io.Reader
|
|
||||||
if body != nil {
|
|
||||||
bodyReader = bytes.NewReader(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest(method, c.baseURL+path, bodyReader)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range headers {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.client.Do(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetJSON performs GET and decodes JSON response
|
|
||||||
func (c *HTTPClient) GetJSON(path string, result interface{}) (*http.Response, error) {
|
|
||||||
resp, err := c.Get(path, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== File System Helpers ==============
|
|
||||||
|
|
||||||
// CreateTempDir creates a temporary directory for static files
|
|
||||||
func CreateTempDir(t *testing.T) string {
|
|
||||||
t.Helper()
|
|
||||||
dir, err := os.MkdirTemp("", "konduktor-test-*")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create temp dir: %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { os.RemoveAll(dir) })
|
|
||||||
return dir
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTempFile creates a temporary file with given content
|
|
||||||
func CreateTempFile(t *testing.T, dir, name, content string) string {
|
|
||||||
t.Helper()
|
|
||||||
path := filepath.Join(dir, name)
|
|
||||||
|
|
||||||
// Create parent directories if needed
|
|
||||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
|
||||||
t.Fatalf("Failed to create directories: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to write file: %v", err)
|
|
||||||
}
|
|
||||||
return path
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Assertion Helpers ==============
|
|
||||||
|
|
||||||
// AssertStatus checks if response has expected status code
|
|
||||||
func AssertStatus(t *testing.T, resp *http.Response, expected int) {
|
|
||||||
t.Helper()
|
|
||||||
if resp.StatusCode != expected {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
t.Errorf("Expected status %d, got %d. Body: %s", expected, resp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AssertHeader checks if response has expected header value
|
|
||||||
func AssertHeader(t *testing.T, resp *http.Response, header, expected string) {
|
|
||||||
t.Helper()
|
|
||||||
actual := resp.Header.Get(header)
|
|
||||||
if actual != expected {
|
|
||||||
t.Errorf("Expected header %s=%q, got %q", header, expected, actual)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AssertHeaderContains checks if header contains substring
|
|
||||||
func AssertHeaderContains(t *testing.T, resp *http.Response, header, substring string) {
|
|
||||||
t.Helper()
|
|
||||||
actual := resp.Header.Get(header)
|
|
||||||
if actual == "" || !contains(actual, substring) {
|
|
||||||
t.Errorf("Expected header %s to contain %q, got %q", header, substring, actual)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AssertJSONField checks if JSON response has expected field value
|
|
||||||
func AssertJSONField(t *testing.T, body []byte, field string, expected interface{}) {
|
|
||||||
t.Helper()
|
|
||||||
var data map[string]interface{}
|
|
||||||
if err := json.Unmarshal(body, &data); err != nil {
|
|
||||||
t.Fatalf("Failed to parse JSON: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
actual, ok := data[field]
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("Field %q not found in JSON", field)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if actual != expected {
|
|
||||||
t.Errorf("Expected %s=%v, got %v", field, expected, actual)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func contains(s, substr string) bool {
|
|
||||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsAt(s, substr, 0))
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsAt(s, substr string, start int) bool {
|
|
||||||
for i := start; i <= len(s)-len(substr); i++ {
|
|
||||||
if s[i:i+len(substr)] == substr {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadBody reads and returns response body
|
|
||||||
func ReadBody(t *testing.T, resp *http.Response) []byte {
|
|
||||||
t.Helper()
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to read body: %v", err)
|
|
||||||
}
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
@ -1,562 +0,0 @@
|
|||||||
package integration
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/extension"
|
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
// createTestLogger creates a logger for tests
|
|
||||||
func createTestLogger(t *testing.T) *logging.Logger {
|
|
||||||
t.Helper()
|
|
||||||
logger, err := logging.New(logging.Config{Level: "DEBUG"})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create logger: %v", err)
|
|
||||||
}
|
|
||||||
return logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Basic Reverse Proxy Tests ==============
|
|
||||||
|
|
||||||
func TestReverseProxy_BasicGET(t *testing.T) {
|
|
||||||
// Start backend server
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"message": "Hello from backend",
|
|
||||||
"path": r.URL.Path,
|
|
||||||
"method": r.Method,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
// Create routing extension with proxy to backend
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create routing extension: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start Konduktor server
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Make request through Konduktor
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
resp, err := client.Get("/api/test", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// Verify response
|
|
||||||
AssertStatus(t, resp, http.StatusOK)
|
|
||||||
|
|
||||||
var result map[string]interface{}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
t.Fatalf("Failed to decode response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result["message"] != "Hello from backend" {
|
|
||||||
t.Errorf("Unexpected message: %v", result["message"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if result["path"] != "/api/test" {
|
|
||||||
t.Errorf("Expected path /api/test, got %v", result["path"])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify backend received request
|
|
||||||
if backend.RequestCount() != 1 {
|
|
||||||
t.Errorf("Expected 1 backend request, got %d", backend.RequestCount())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReverseProxy_POST(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var body map[string]interface{}
|
|
||||||
json.NewDecoder(r.Body).Decode(&body)
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
||||||
"received": body,
|
|
||||||
"method": r.Method,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
body := []byte(`{"name":"test","value":123}`)
|
|
||||||
resp, err := client.Post("/api/data", body, map[string]string{
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, http.StatusOK)
|
|
||||||
|
|
||||||
var result map[string]interface{}
|
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
|
||||||
|
|
||||||
if result["method"] != "POST" {
|
|
||||||
t.Errorf("Expected method POST, got %v", result["method"])
|
|
||||||
}
|
|
||||||
|
|
||||||
received := result["received"].(map[string]interface{})
|
|
||||||
if received["name"] != "test" {
|
|
||||||
t.Errorf("Expected name 'test', got %v", received["name"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Exact Match Routes ==============
|
|
||||||
|
|
||||||
func TestReverseProxy_ExactMatchRoute(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"endpoint": "version",
|
|
||||||
"path": r.URL.Path,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
// Exact match - should use backend URL as-is
|
|
||||||
"=/api/version": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL() + "/releases/latest",
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Test exact match route
|
|
||||||
resp, err := client.Get("/api/version", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, http.StatusOK)
|
|
||||||
|
|
||||||
lastReq := backend.LastRequest()
|
|
||||||
if lastReq == nil {
|
|
||||||
t.Fatal("No request received by backend")
|
|
||||||
}
|
|
||||||
|
|
||||||
// For exact match, the target path should be used as-is (IgnoreRequestPath=true)
|
|
||||||
if lastReq.Path != "/releases/latest" {
|
|
||||||
t.Errorf("Expected backend path /releases/latest, got %s", lastReq.Path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Regex Routes with Parameters ==============
|
|
||||||
|
|
||||||
func TestReverseProxy_RegexRouteWithParams(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"path": r.URL.Path,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
// Regex with named group
|
|
||||||
"~^/api/users/(?P<id>\\d+)$": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL() + "/v2/users/{id}",
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Test regex route with parameter
|
|
||||||
resp, err := client.Get("/api/users/42", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, http.StatusOK)
|
|
||||||
|
|
||||||
lastReq := backend.LastRequest()
|
|
||||||
if lastReq == nil {
|
|
||||||
t.Fatal("No request received by backend")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parameter {id} should be substituted
|
|
||||||
if lastReq.Path != "/v2/users/42" {
|
|
||||||
t.Errorf("Expected backend path /v2/users/42, got %s", lastReq.Path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Header Forwarding ==============
|
|
||||||
|
|
||||||
func TestReverseProxy_HeaderForwarding(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"x-forwarded-for": r.Header.Get("X-Forwarded-For"),
|
|
||||||
"x-real-ip": r.Header.Get("X-Real-IP"),
|
|
||||||
"x-custom": r.Header.Get("X-Custom"),
|
|
||||||
"x-forwarded-host": r.Header.Get("X-Forwarded-Host"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
"headers": []interface{}{
|
|
||||||
"X-Forwarded-For: $remote_addr",
|
|
||||||
"X-Real-IP: $remote_addr",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
resp, err := client.Get("/test", map[string]string{
|
|
||||||
"X-Custom": "custom-value",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var result map[string]string
|
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
|
||||||
|
|
||||||
// X-Custom should be forwarded
|
|
||||||
if result["x-custom"] != "custom-value" {
|
|
||||||
t.Errorf("Expected X-Custom header to be forwarded, got %v", result["x-custom"])
|
|
||||||
}
|
|
||||||
|
|
||||||
// X-Forwarded-For should be set (will contain 127.0.0.1)
|
|
||||||
if result["x-forwarded-for"] == "" {
|
|
||||||
t.Error("Expected X-Forwarded-For header to be set")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Query String ==============
|
|
||||||
|
|
||||||
func TestReverseProxy_QueryStringPreservation(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"query": r.URL.RawQuery,
|
|
||||||
"foo": r.URL.Query().Get("foo"),
|
|
||||||
"bar": r.URL.Query().Get("bar"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
resp, err := client.Get("/search?foo=hello&bar=world", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var result map[string]string
|
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
|
||||||
|
|
||||||
if result["foo"] != "hello" {
|
|
||||||
t.Errorf("Expected foo=hello, got %v", result["foo"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if result["bar"] != "world" {
|
|
||||||
t.Errorf("Expected bar=world, got %v", result["bar"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Error Handling ==============
|
|
||||||
|
|
||||||
func TestReverseProxy_BackendUnavailable(t *testing.T) {
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
// Non-existent backend
|
|
||||||
"proxy_pass": "http://127.0.0.1:59999",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
resp, err := client.Get("/test", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// Should return 502 Bad Gateway
|
|
||||||
AssertStatus(t, resp, http.StatusBadGateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReverseProxy_BackendTimeout(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Simulate slow backend
|
|
||||||
time.Sleep(3 * time.Second)
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
"timeout": 0.5, // 500ms timeout
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
resp, err := client.Get("/slow", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// Should return 504 Gateway Timeout
|
|
||||||
AssertStatus(t, resp, http.StatusGatewayTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== HTTP Methods ==============
|
|
||||||
|
|
||||||
func TestReverseProxy_AllMethods(t *testing.T) {
|
|
||||||
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}
|
|
||||||
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"method": r.Method,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
for _, method := range methods {
|
|
||||||
t.Run(method, func(t *testing.T) {
|
|
||||||
resp, err := client.Do(method, "/resource", nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, http.StatusOK)
|
|
||||||
|
|
||||||
if method != "HEAD" {
|
|
||||||
var result map[string]string
|
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
|
||||||
|
|
||||||
if result["method"] != method {
|
|
||||||
t.Errorf("Expected method %s, got %v", method, result["method"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Large Bodies ==============
|
|
||||||
|
|
||||||
func TestReverseProxy_LargeRequestBody(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
json.NewEncoder(w).Encode(map[string]int{
|
|
||||||
"received": len(body),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// 1MB body
|
|
||||||
largeBody := []byte(strings.Repeat("x", 1024*1024))
|
|
||||||
resp, err := client.Post("/upload", largeBody, map[string]string{
|
|
||||||
"Content-Type": "application/octet-stream",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Concurrent Requests ==============
|
|
||||||
|
|
||||||
func TestReverseProxy_ConcurrentRequests(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Small delay to simulate work
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
const numRequests = 50
|
|
||||||
results := make(chan error, numRequests)
|
|
||||||
|
|
||||||
for i := 0; i < numRequests; i++ {
|
|
||||||
go func(n int) {
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
resp, err := client.Get(fmt.Sprintf("/concurrent/%d", n), nil)
|
|
||||||
if err != nil {
|
|
||||||
results <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
results <- fmt.Errorf("unexpected status: %d", resp.StatusCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
results <- nil
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect results
|
|
||||||
var errors []error
|
|
||||||
for i := 0; i < numRequests; i++ {
|
|
||||||
if err := <-results; err != nil {
|
|
||||||
errors = append(errors, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errors) > 0 {
|
|
||||||
t.Errorf("Got %d errors in concurrent requests: %v", len(errors), errors[:min(5, len(errors))])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify all requests reached backend
|
|
||||||
if backend.RequestCount() != numRequests {
|
|
||||||
t.Errorf("Expected %d backend requests, got %d", numRequests, backend.RequestCount())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func min(a, b int) int {
|
|
||||||
if a < b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
@ -1,494 +0,0 @@
|
|||||||
package integration
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/extension"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ============== Route Priority Tests ==============
|
|
||||||
|
|
||||||
func TestRouting_ExactMatchPriority(t *testing.T) {
|
|
||||||
// Exact match should have highest priority
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"path": r.URL.Path,
|
|
||||||
"source": "default",
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
// Exact match - highest priority
|
|
||||||
"=/api/status": map[string]interface{}{
|
|
||||||
"return": "200 exact-match",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
// Regex that also matches /api/status
|
|
||||||
"~^/api/.*": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create routing extension: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Test exact match route - should return static response
|
|
||||||
resp, err := client.Get("/api/status", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, http.StatusOK)
|
|
||||||
|
|
||||||
body := ReadBody(t, resp)
|
|
||||||
if string(body) != "exact-match" {
|
|
||||||
t.Errorf("Expected 'exact-match', got %q", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regex route should be used for other /api/* paths
|
|
||||||
resp2, err := client.Get("/api/other", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp2.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp2, http.StatusOK)
|
|
||||||
|
|
||||||
// Verify it went to backend
|
|
||||||
if backend.RequestCount() != 1 {
|
|
||||||
t.Errorf("Expected 1 backend request, got %d", backend.RequestCount())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Case Sensitivity Tests ==============
|
|
||||||
|
|
||||||
func TestRouting_CaseSensitiveRegex(t *testing.T) {
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
// Case-sensitive regex (~)
|
|
||||||
"~^/API/test$": map[string]interface{}{
|
|
||||||
"return": "200 case-sensitive",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"return": "200 default",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Exact case match should work
|
|
||||||
resp, err := client.Get("/API/test", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body := ReadBody(t, resp)
|
|
||||||
if string(body) != "case-sensitive" {
|
|
||||||
t.Errorf("Expected 'case-sensitive' for /API/test, got %q", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Different case should NOT match
|
|
||||||
resp2, err := client.Get("/api/test", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp2.Body.Close()
|
|
||||||
|
|
||||||
body2 := ReadBody(t, resp2)
|
|
||||||
if string(body2) != "default" {
|
|
||||||
t.Errorf("Expected 'default' for /api/test (case mismatch), got %q", string(body2))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouting_CaseInsensitiveRegex(t *testing.T) {
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
// Case-insensitive regex (~*)
|
|
||||||
"~*^/api/test$": map[string]interface{}{
|
|
||||||
"return": "200 case-insensitive",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"return": "200 default",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
path string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"/api/test", "case-insensitive"},
|
|
||||||
{"/API/test", "case-insensitive"},
|
|
||||||
{"/Api/Test", "case-insensitive"},
|
|
||||||
{"/API/TEST", "case-insensitive"},
|
|
||||||
{"/api/other", "default"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.path, func(t *testing.T) {
|
|
||||||
resp, err := client.Get(tc.path, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body := ReadBody(t, resp)
|
|
||||||
if string(body) != tc.expected {
|
|
||||||
t.Errorf("Expected %q for %s, got %q", tc.expected, tc.path, string(body))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Default Route Tests ==============
|
|
||||||
|
|
||||||
func TestRouting_DefaultRoute(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"handler": "default",
|
|
||||||
"path": r.URL.Path,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"=/specific": map[string]interface{}{
|
|
||||||
"return": "200 specific",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Non-matching paths should go to default
|
|
||||||
paths := []string{"/", "/random", "/path/to/resource", "/api/v1/users"}
|
|
||||||
|
|
||||||
for _, path := range paths {
|
|
||||||
t.Run(path, func(t *testing.T) {
|
|
||||||
resp, err := client.Get(path, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, http.StatusOK)
|
|
||||||
|
|
||||||
var result map[string]string
|
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
|
||||||
|
|
||||||
if result["handler"] != "default" {
|
|
||||||
t.Errorf("Expected default handler, got %v", result["handler"])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Return Directive Tests ==============
|
|
||||||
|
|
||||||
func TestRouting_ReturnDirective(t *testing.T) {
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"=/health": map[string]interface{}{
|
|
||||||
"return": "200 OK",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
"=/status": map[string]interface{}{
|
|
||||||
"return": "200 {\"status\": \"healthy\"}",
|
|
||||||
"content_type": "application/json",
|
|
||||||
},
|
|
||||||
"=/forbidden": map[string]interface{}{
|
|
||||||
"return": "404 Not Found",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"return": "200 default",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
path string
|
|
||||||
expectedStatus int
|
|
||||||
expectedBody string
|
|
||||||
contentType string
|
|
||||||
}{
|
|
||||||
{"/health", 200, "OK", "text/plain"},
|
|
||||||
{"/status", 200, `{"status": "healthy"}`, "application/json"},
|
|
||||||
{"/forbidden", 404, "Not Found", "text/plain"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.path, func(t *testing.T) {
|
|
||||||
resp, err := client.Get(tc.path, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, tc.expectedStatus)
|
|
||||||
AssertHeaderContains(t, resp, "Content-Type", tc.contentType)
|
|
||||||
|
|
||||||
body := ReadBody(t, resp)
|
|
||||||
if string(body) != tc.expectedBody {
|
|
||||||
t.Errorf("Expected body %q, got %q", tc.expectedBody, string(body))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Multiple Regex Routes Tests ==============
|
|
||||||
|
|
||||||
func TestRouting_MultipleRegexRoutes(t *testing.T) {
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"~^/api/v1/.*": map[string]interface{}{
|
|
||||||
"return": "200 v1",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
"~^/api/v2/.*": map[string]interface{}{
|
|
||||||
"return": "200 v2",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
"~^/api/.*": map[string]interface{}{
|
|
||||||
"return": "200 api-generic",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"return": "200 default",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
path string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"/api/v1/users", "v1"},
|
|
||||||
{"/api/v2/users", "v2"},
|
|
||||||
{"/api/v3/users", "api-generic"},
|
|
||||||
{"/other", "default"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.path, func(t *testing.T) {
|
|
||||||
resp, err := client.Get(tc.path, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body := ReadBody(t, resp)
|
|
||||||
if string(body) != tc.expected {
|
|
||||||
t.Errorf("Expected %q for %s, got %q", tc.expected, tc.path, string(body))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Regex with Named Groups ==============
|
|
||||||
|
|
||||||
func TestRouting_RegexNamedGroups(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"path": r.URL.Path,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"~^/users/(?P<userId>\\d+)/posts/(?P<postId>\\d+)$": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL() + "/api/v2/users/{userId}/posts/{postId}",
|
|
||||||
},
|
|
||||||
"~^/items/(?P<category>[a-z]+)/(?P<id>\\d+)$": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL() + "/catalog/{category}/item/{id}",
|
|
||||||
},
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
requestPath string
|
|
||||||
expectedPath string
|
|
||||||
}{
|
|
||||||
{"/users/123/posts/456", "/api/v2/users/123/posts/456"},
|
|
||||||
{"/items/electronics/789", "/catalog/electronics/item/789"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.requestPath, func(t *testing.T) {
|
|
||||||
resp, err := client.Get(tc.requestPath, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, http.StatusOK)
|
|
||||||
|
|
||||||
lastReq := backend.LastRequest()
|
|
||||||
if lastReq == nil {
|
|
||||||
t.Fatal("No request received by backend")
|
|
||||||
}
|
|
||||||
|
|
||||||
if lastReq.Path != tc.expectedPath {
|
|
||||||
t.Errorf("Expected backend path %s, got %s", tc.expectedPath, lastReq.Path)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== No Matching Route Tests ==============
|
|
||||||
|
|
||||||
func TestRouting_NoMatchingRoute(t *testing.T) {
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"=/specific": map[string]interface{}{
|
|
||||||
"return": "200 specific",
|
|
||||||
"content_type": "text/plain",
|
|
||||||
},
|
|
||||||
// No default route
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
// Request to non-matching path should return 404
|
|
||||||
resp, err := client.Get("/other", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
AssertStatus(t, resp, http.StatusNotFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============== Headers in Return Tests ==============
|
|
||||||
|
|
||||||
func TestRouting_CustomHeaders(t *testing.T) {
|
|
||||||
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{
|
|
||||||
"x-custom-header": r.Header.Get("X-Custom-Header"),
|
|
||||||
"x-api-version": r.Header.Get("X-API-Version"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
logger := createTestLogger(t)
|
|
||||||
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
|
||||||
"regex_locations": map[string]interface{}{
|
|
||||||
"__default__": map[string]interface{}{
|
|
||||||
"proxy_pass": backend.URL(),
|
|
||||||
"headers": []interface{}{
|
|
||||||
"X-Custom-Header: custom-value",
|
|
||||||
"X-API-Version: v1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, logger)
|
|
||||||
|
|
||||||
server := StartTestServer(t, &ServerConfig{
|
|
||||||
Extensions: []extension.Extension{routingExt},
|
|
||||||
})
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewHTTPClient(server.URL)
|
|
||||||
|
|
||||||
resp, err := client.Get("/test", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Request failed: %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var result map[string]string
|
|
||||||
json.NewDecoder(resp.Body).Decode(&result)
|
|
||||||
|
|
||||||
if result["x-custom-header"] != "custom-value" {
|
|
||||||
t.Errorf("Expected X-Custom-Header=custom-value, got %v", result["x-custom-header"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if result["x-api-version"] != "v1" {
|
|
||||||
t.Errorf("Expected X-API-Version=v1, got %v", result["x-api-version"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
6
poetry.lock
generated
6
poetry.lock
generated
@ -1,4 +1,4 @@
|
|||||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "a2wsgi"
|
name = "a2wsgi"
|
||||||
@ -1720,5 +1720,5 @@ wsgi = ["a2wsgi"]
|
|||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.12"
|
python-versions = ">=3.12, <=3.13.7"
|
||||||
content-hash = "653d7b992e2bb133abde2e8b1c44265e948ed90487ab3f2670429510a8aa0683"
|
content-hash = "411b746f1a577ed635af9fd3e01daf1fa03950d27ef23888fc7cdd0b99762404"
|
||||||
|
|||||||
@ -3,11 +3,11 @@ name = "pyserve"
|
|||||||
version = "0.9.10"
|
version = "0.9.10"
|
||||||
description = "Python Application Orchestrator & HTTP Server - unified gateway for multiple Python web apps"
|
description = "Python Application Orchestrator & HTTP Server - unified gateway for multiple Python web apps"
|
||||||
authors = [
|
authors = [
|
||||||
{name = "Илья Глазунов",email = "i.glazunov@sapiens.solutions"}
|
{name = "Илья Глазунов",email = "lead@pyserve.org"}
|
||||||
]
|
]
|
||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12, <=3.13.7"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"starlette (>=0.47.3,<0.48.0)",
|
"starlette (>=0.47.3,<0.48.0)",
|
||||||
"uvicorn[standard] (>=0.35.0,<0.36.0)",
|
"uvicorn[standard] (>=0.35.0,<0.36.0)",
|
||||||
|
|||||||
@ -24,24 +24,24 @@ cdef class FastMountedPath:
|
|||||||
def __cinit__(self):
|
def __cinit__(self):
|
||||||
self._path = ""
|
self._path = ""
|
||||||
self._path_with_slash = "/"
|
self._path_with_slash = "/"
|
||||||
self._path_len = 0
|
self._path_len = <Py_ssize_t>0
|
||||||
self._is_root = 1
|
self._is_root = <bint>True
|
||||||
self.name = ""
|
self.name = ""
|
||||||
self.strip_path = 1
|
self.strip_path = <bint>True
|
||||||
|
|
||||||
def __init__(self, str path, str name="", bint strip_path=True):
|
def __init__(self, str path, str name="", bint strip_path=<bint>True):
|
||||||
cdef Py_ssize_t path_len
|
cdef Py_ssize_t path_len
|
||||||
|
|
||||||
path_len = len(path)
|
path_len = <Py_ssize_t>len(path)
|
||||||
if path_len > 1 and path[path_len - 1] == '/':
|
if path_len > 1 and path[path_len - 1] == '/':
|
||||||
path = path[:path_len - 1]
|
path = path[:path_len - 1]
|
||||||
|
|
||||||
self._path = path
|
self._path = path
|
||||||
self._path_len = len(path)
|
self._path_len = <Py_ssize_t>len(path)
|
||||||
self._is_root = 1 if (path == "" or path == "/") else 0
|
self._is_root = <bint>(path == "" or path == "/")
|
||||||
self._path_with_slash = path + "/" if self._is_root == 0 else "/"
|
self._path_with_slash = path + "/" if not self._is_root else "/"
|
||||||
self.name = name if name else path
|
self.name = name if name else path
|
||||||
self.strip_path = 1 if strip_path else 0
|
self.strip_path = <bint>strip_path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def path(self) -> str:
|
def path(self) -> str:
|
||||||
@ -51,20 +51,20 @@ cdef class FastMountedPath:
|
|||||||
cdef Py_ssize_t req_len
|
cdef Py_ssize_t req_len
|
||||||
|
|
||||||
if self._is_root:
|
if self._is_root:
|
||||||
return 1
|
return <bint>True
|
||||||
|
|
||||||
req_len = len(request_path)
|
req_len = <Py_ssize_t>len(request_path)
|
||||||
|
|
||||||
if req_len < self._path_len:
|
if req_len < self._path_len:
|
||||||
return 0
|
return <bint>False
|
||||||
|
|
||||||
if req_len == self._path_len:
|
if req_len == self._path_len:
|
||||||
return 1 if request_path == self._path else 0
|
return <bint>(request_path == self._path)
|
||||||
|
|
||||||
if request_path[self._path_len] == '/':
|
if request_path[self._path_len] == '/':
|
||||||
return 1 if request_path[:self._path_len] == self._path else 0
|
return <bint>(request_path[:self._path_len] == self._path)
|
||||||
|
|
||||||
return 0
|
return <bint>False
|
||||||
|
|
||||||
cpdef str get_modified_path(self, str original_path):
|
cpdef str get_modified_path(self, str original_path):
|
||||||
cdef str new_path
|
cdef str new_path
|
||||||
@ -108,7 +108,7 @@ cdef class FastMountManager:
|
|||||||
self._mounts = sorted(self._mounts, key=_get_path_len_neg, reverse=False)
|
self._mounts = sorted(self._mounts, key=_get_path_len_neg, reverse=False)
|
||||||
self._mount_count = len(self._mounts)
|
self._mount_count = len(self._mounts)
|
||||||
|
|
||||||
cpdef FastMountedPath get_mount(self, str request_path):
|
cpdef object get_mount(self, str request_path):
|
||||||
cdef:
|
cdef:
|
||||||
int i
|
int i
|
||||||
FastMountedPath mount
|
FastMountedPath mount
|
||||||
@ -126,7 +126,7 @@ cdef class FastMountManager:
|
|||||||
Py_ssize_t path_len
|
Py_ssize_t path_len
|
||||||
FastMountedPath mount
|
FastMountedPath mount
|
||||||
|
|
||||||
path_len = len(path)
|
path_len = <Py_ssize_t>len(path)
|
||||||
if path_len > 1 and path[path_len - 1] == '/':
|
if path_len > 1 and path[path_len - 1] == '/':
|
||||||
path = path[:path_len - 1]
|
path = path[:path_len - 1]
|
||||||
|
|
||||||
@ -135,9 +135,9 @@ cdef class FastMountManager:
|
|||||||
if mount._path == path:
|
if mount._path == path:
|
||||||
del self._mounts[i]
|
del self._mounts[i]
|
||||||
self._mount_count -= 1
|
self._mount_count -= 1
|
||||||
return 1
|
return <bint>True
|
||||||
|
|
||||||
return 0
|
return <bint>False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mounts(self) -> list:
|
def mounts(self) -> list:
|
||||||
@ -164,27 +164,27 @@ cdef class FastMountManager:
|
|||||||
|
|
||||||
cpdef bint path_matches_prefix(str request_path, str mount_path):
|
cpdef bint path_matches_prefix(str request_path, str mount_path):
|
||||||
cdef:
|
cdef:
|
||||||
Py_ssize_t mount_len = len(mount_path)
|
Py_ssize_t mount_len = <Py_ssize_t>len(mount_path)
|
||||||
Py_ssize_t req_len = len(request_path)
|
Py_ssize_t req_len = <Py_ssize_t>len(request_path)
|
||||||
|
|
||||||
if mount_len == 0 or mount_path == "/":
|
if mount_len == 0 or mount_path == "/":
|
||||||
return 1
|
return <bint>True
|
||||||
|
|
||||||
if req_len < mount_len:
|
if req_len < mount_len:
|
||||||
return 0
|
return <bint>False
|
||||||
|
|
||||||
if req_len == mount_len:
|
if req_len == mount_len:
|
||||||
return 1 if request_path == mount_path else 0
|
return <bint>(request_path == mount_path)
|
||||||
|
|
||||||
if request_path[mount_len] == '/':
|
if request_path[mount_len] == '/':
|
||||||
return 1 if request_path[:mount_len] == mount_path else 0
|
return <bint>(request_path[:mount_len] == mount_path)
|
||||||
|
|
||||||
return 0
|
return <bint>False
|
||||||
|
|
||||||
|
|
||||||
cpdef str strip_path_prefix(str original_path, str mount_path):
|
cpdef str strip_path_prefix(str original_path, str mount_path):
|
||||||
cdef:
|
cdef:
|
||||||
Py_ssize_t mount_len = len(mount_path)
|
Py_ssize_t mount_len = <Py_ssize_t>len(mount_path)
|
||||||
str result
|
str result
|
||||||
|
|
||||||
if mount_len == 0 or mount_path == "/":
|
if mount_len == 0 or mount_path == "/":
|
||||||
@ -198,11 +198,11 @@ cpdef str strip_path_prefix(str original_path, str mount_path):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
cpdef tuple match_and_modify_path(str request_path, str mount_path, bint strip_path=True):
|
cpdef tuple match_and_modify_path(str request_path, str mount_path, bint strip_path=<bint>True):
|
||||||
cdef:
|
cdef:
|
||||||
Py_ssize_t mount_len = len(mount_path)
|
Py_ssize_t mount_len = <Py_ssize_t>len(mount_path)
|
||||||
Py_ssize_t req_len = len(request_path)
|
Py_ssize_t req_len = <Py_ssize_t>len(request_path)
|
||||||
bint is_root = 1 if (mount_len == 0 or mount_path == "/") else 0
|
bint is_root = <bint>(mount_len == 0 or mount_path == "/")
|
||||||
str modified
|
str modified
|
||||||
|
|
||||||
if is_root:
|
if is_root:
|
||||||
|
|||||||
486
pyserve/_routing.pyx
Normal file
486
pyserve/_routing.pyx
Normal file
@ -0,0 +1,486 @@
|
|||||||
|
# cython: language_level=3
|
||||||
|
# cython: boundscheck=False
|
||||||
|
# cython: wraparound=False
|
||||||
|
# cython: cdivision=True
|
||||||
|
from libc.stdint cimport uint32_t
|
||||||
|
from libc.stddef cimport size_t
|
||||||
|
from cpython.bytes cimport PyBytes_AsString, PyBytes_GET_SIZE
|
||||||
|
|
||||||
|
cimport cython
|
||||||
|
from ._routing_pcre2 cimport *
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# Type aliases for cleaner NULL casts
|
||||||
|
ctypedef pcre2_compile_context* compile_ctx_ptr
|
||||||
|
ctypedef pcre2_match_context* match_ctx_ptr
|
||||||
|
ctypedef pcre2_general_context* general_ctx_ptr
|
||||||
|
|
||||||
|
|
||||||
|
# Buffer size for error messages
|
||||||
|
DEF ERROR_BUFFER_SIZE = 256
|
||||||
|
|
||||||
|
# Maximum capture groups we support
|
||||||
|
DEF MAX_CAPTURE_GROUPS = 32
|
||||||
|
|
||||||
|
|
||||||
|
cdef class PCRE2Pattern:
|
||||||
|
cdef:
|
||||||
|
pcre2_code* _code
|
||||||
|
pcre2_match_data* _match_data
|
||||||
|
bint _jit_available
|
||||||
|
str _pattern_str
|
||||||
|
uint32_t _capture_count
|
||||||
|
dict _name_to_index # Named capture groups
|
||||||
|
list _index_to_name # Index to name mapping
|
||||||
|
|
||||||
|
def __cinit__(self):
|
||||||
|
self._code = NULL
|
||||||
|
self._match_data = NULL
|
||||||
|
self._jit_available = <bint>False
|
||||||
|
self._capture_count = 0
|
||||||
|
self._name_to_index = {}
|
||||||
|
self._index_to_name = []
|
||||||
|
|
||||||
|
def __dealloc__(self):
|
||||||
|
if self._match_data is not NULL:
|
||||||
|
pcre2_match_data_free(self._match_data)
|
||||||
|
self._match_data = NULL
|
||||||
|
if self._code is not NULL:
|
||||||
|
pcre2_code_free(self._code)
|
||||||
|
self._code = NULL
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef PCRE2Pattern _create(str pattern, bint case_insensitive=<bint>False, bint use_jit=<bint>True):
|
||||||
|
cdef:
|
||||||
|
PCRE2Pattern self = PCRE2Pattern.__new__(PCRE2Pattern)
|
||||||
|
bytes pattern_bytes
|
||||||
|
const char* pattern_ptr
|
||||||
|
Py_ssize_t pattern_len
|
||||||
|
uint32_t options = 0
|
||||||
|
int errorcode = 0
|
||||||
|
PCRE2_SIZE erroroffset = 0
|
||||||
|
int jit_result
|
||||||
|
uint32_t capture_count = 0
|
||||||
|
|
||||||
|
self._pattern_str = pattern
|
||||||
|
self._name_to_index = {}
|
||||||
|
self._index_to_name = []
|
||||||
|
|
||||||
|
pattern_bytes = pattern.encode('utf-8')
|
||||||
|
pattern_ptr = PyBytes_AsString(pattern_bytes)
|
||||||
|
pattern_len = PyBytes_GET_SIZE(pattern_bytes)
|
||||||
|
|
||||||
|
options = PCRE2_UTF | PCRE2_UCP
|
||||||
|
if case_insensitive:
|
||||||
|
options |= PCRE2_CASELESS
|
||||||
|
|
||||||
|
self._code = pcre2_compile(
|
||||||
|
<PCRE2_SPTR>pattern_ptr,
|
||||||
|
<PCRE2_SIZE>pattern_len,
|
||||||
|
options,
|
||||||
|
&errorcode,
|
||||||
|
&erroroffset,
|
||||||
|
<compile_ctx_ptr>NULL
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._code is NULL:
|
||||||
|
error_msg = PCRE2Pattern._get_error_message(errorcode)
|
||||||
|
raise ValueError(f"PCRE2 compile error at offset {erroroffset}: {error_msg}")
|
||||||
|
|
||||||
|
if use_jit:
|
||||||
|
jit_result = pcre2_jit_compile(self._code, PCRE2_JIT_COMPLETE)
|
||||||
|
self._jit_available = <bint>(jit_result == 0)
|
||||||
|
|
||||||
|
pcre2_pattern_info(self._code, PCRE2_INFO_CAPTURECOUNT, <void*>&capture_count)
|
||||||
|
self._capture_count = capture_count
|
||||||
|
|
||||||
|
self._match_data = pcre2_match_data_create_from_pattern(self._code, <general_ctx_ptr>NULL)
|
||||||
|
if self._match_data is NULL:
|
||||||
|
pcre2_code_free(self._code)
|
||||||
|
self._code = NULL
|
||||||
|
raise MemoryError("Failed to create match data")
|
||||||
|
|
||||||
|
self._extract_named_groups()
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
cdef void _extract_named_groups(self):
|
||||||
|
cdef:
|
||||||
|
uint32_t namecount = 0
|
||||||
|
uint32_t nameentrysize = 0
|
||||||
|
PCRE2_SPTR nametable
|
||||||
|
uint32_t i
|
||||||
|
int group_num
|
||||||
|
bytes name_bytes
|
||||||
|
str name
|
||||||
|
|
||||||
|
pcre2_pattern_info(self._code, PCRE2_INFO_NAMECOUNT, <void*>&namecount)
|
||||||
|
|
||||||
|
if namecount == 0:
|
||||||
|
return # void return
|
||||||
|
|
||||||
|
pcre2_pattern_info(self._code, PCRE2_INFO_NAMEENTRYSIZE, <void*>&nameentrysize)
|
||||||
|
pcre2_pattern_info(self._code, PCRE2_INFO_NAMETABLE, <void*>&nametable)
|
||||||
|
|
||||||
|
self._index_to_name = [None] * (self._capture_count + 1)
|
||||||
|
|
||||||
|
for i in range(namecount):
|
||||||
|
group_num = (<int>nametable[0] << 8) | <int>nametable[1]
|
||||||
|
name_bytes = <bytes>(nametable + 2)
|
||||||
|
name = name_bytes.decode('utf-8')
|
||||||
|
|
||||||
|
self._name_to_index[name] = group_num
|
||||||
|
if <uint32_t>group_num <= self._capture_count:
|
||||||
|
self._index_to_name[<Py_ssize_t>group_num] = name
|
||||||
|
|
||||||
|
nametable += nameentrysize
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef str _get_error_message(int errorcode):
|
||||||
|
cdef:
|
||||||
|
PCRE2_UCHAR buffer[ERROR_BUFFER_SIZE]
|
||||||
|
int result
|
||||||
|
|
||||||
|
result = pcre2_get_error_message(errorcode, buffer, ERROR_BUFFER_SIZE)
|
||||||
|
if result < 0:
|
||||||
|
return f"Unknown error {errorcode}"
|
||||||
|
return (<bytes>buffer).decode('utf-8')
|
||||||
|
|
||||||
|
cpdef bint search(self, str subject):
|
||||||
|
"""
|
||||||
|
Search for pattern anywhere in subject.
|
||||||
|
Returns True if found, False otherwise.
|
||||||
|
"""
|
||||||
|
cdef:
|
||||||
|
bytes subject_bytes
|
||||||
|
const char* subject_ptr
|
||||||
|
Py_ssize_t subject_len
|
||||||
|
int result
|
||||||
|
|
||||||
|
if self._code is NULL:
|
||||||
|
return <bint>False
|
||||||
|
|
||||||
|
subject_bytes = subject.encode('utf-8')
|
||||||
|
subject_ptr = PyBytes_AsString(subject_bytes)
|
||||||
|
subject_len = PyBytes_GET_SIZE(subject_bytes)
|
||||||
|
|
||||||
|
if self._jit_available:
|
||||||
|
result = pcre2_jit_match(
|
||||||
|
self._code,
|
||||||
|
<PCRE2_SPTR>subject_ptr,
|
||||||
|
<PCRE2_SIZE>subject_len,
|
||||||
|
0, # start offset
|
||||||
|
0, # options
|
||||||
|
self._match_data,
|
||||||
|
<match_ctx_ptr>NULL
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = pcre2_match(
|
||||||
|
self._code,
|
||||||
|
<PCRE2_SPTR>subject_ptr,
|
||||||
|
<PCRE2_SIZE>subject_len,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
self._match_data,
|
||||||
|
<match_ctx_ptr>NULL
|
||||||
|
)
|
||||||
|
|
||||||
|
return <bint>(result >= 0)
|
||||||
|
|
||||||
|
cpdef dict groupdict(self, str subject):
|
||||||
|
"""
|
||||||
|
Match pattern and return dict of named groups.
|
||||||
|
Returns empty dict if no match or no named groups.
|
||||||
|
"""
|
||||||
|
cdef:
|
||||||
|
bytes subject_bytes
|
||||||
|
const char* subject_ptr
|
||||||
|
Py_ssize_t subject_len
|
||||||
|
int result
|
||||||
|
PCRE2_SIZE* ovector
|
||||||
|
dict groups = {}
|
||||||
|
str name
|
||||||
|
int index
|
||||||
|
PCRE2_SIZE start, end
|
||||||
|
|
||||||
|
if self._code is NULL or not self._name_to_index:
|
||||||
|
return groups
|
||||||
|
|
||||||
|
subject_bytes = subject.encode('utf-8')
|
||||||
|
subject_ptr = PyBytes_AsString(subject_bytes)
|
||||||
|
subject_len = PyBytes_GET_SIZE(subject_bytes)
|
||||||
|
|
||||||
|
if self._jit_available:
|
||||||
|
result = pcre2_jit_match(
|
||||||
|
self._code,
|
||||||
|
<PCRE2_SPTR>subject_ptr,
|
||||||
|
<PCRE2_SIZE>subject_len,
|
||||||
|
0, 0,
|
||||||
|
self._match_data,
|
||||||
|
<match_ctx_ptr>NULL
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = pcre2_match(
|
||||||
|
self._code,
|
||||||
|
<PCRE2_SPTR>subject_ptr,
|
||||||
|
<PCRE2_SIZE>subject_len,
|
||||||
|
0, 0,
|
||||||
|
self._match_data,
|
||||||
|
<match_ctx_ptr>NULL
|
||||||
|
)
|
||||||
|
|
||||||
|
if result < 0:
|
||||||
|
return groups
|
||||||
|
|
||||||
|
ovector = pcre2_get_ovector_pointer(self._match_data)
|
||||||
|
|
||||||
|
for name, index in self._name_to_index.items():
|
||||||
|
start = ovector[<Py_ssize_t>(2 * index)]
|
||||||
|
end = ovector[<Py_ssize_t>(2 * index + 1)]
|
||||||
|
if start != PCRE2_UNSET and end != PCRE2_UNSET:
|
||||||
|
groups[name] = subject_bytes[start:end].decode('utf-8')
|
||||||
|
else:
|
||||||
|
groups[name] = None
|
||||||
|
|
||||||
|
return groups
|
||||||
|
|
||||||
|
cpdef tuple search_with_groups(self, str subject):
|
||||||
|
cdef:
|
||||||
|
bytes subject_bytes
|
||||||
|
const char* subject_ptr
|
||||||
|
Py_ssize_t subject_len
|
||||||
|
int result
|
||||||
|
PCRE2_SIZE* ovector
|
||||||
|
dict groups = {}
|
||||||
|
str name
|
||||||
|
int index
|
||||||
|
PCRE2_SIZE start, end
|
||||||
|
|
||||||
|
if self._code is NULL:
|
||||||
|
return (False, {})
|
||||||
|
|
||||||
|
subject_bytes = subject.encode('utf-8')
|
||||||
|
subject_ptr = PyBytes_AsString(subject_bytes)
|
||||||
|
subject_len = PyBytes_GET_SIZE(subject_bytes)
|
||||||
|
|
||||||
|
if self._jit_available:
|
||||||
|
result = pcre2_jit_match(
|
||||||
|
self._code,
|
||||||
|
<PCRE2_SPTR>subject_ptr,
|
||||||
|
<PCRE2_SIZE>subject_len,
|
||||||
|
0, 0,
|
||||||
|
self._match_data,
|
||||||
|
<match_ctx_ptr>NULL
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = pcre2_match(
|
||||||
|
self._code,
|
||||||
|
<PCRE2_SPTR>subject_ptr,
|
||||||
|
<PCRE2_SIZE>subject_len,
|
||||||
|
0, 0,
|
||||||
|
self._match_data,
|
||||||
|
<match_ctx_ptr>NULL
|
||||||
|
)
|
||||||
|
|
||||||
|
if result < 0:
|
||||||
|
return (False, {})
|
||||||
|
|
||||||
|
if self._name_to_index:
|
||||||
|
ovector = pcre2_get_ovector_pointer(self._match_data)
|
||||||
|
for name, index in self._name_to_index.items():
|
||||||
|
start = ovector[<Py_ssize_t>(2 * index)]
|
||||||
|
end = ovector[<Py_ssize_t>(2 * index + 1)]
|
||||||
|
if start != PCRE2_UNSET and end != PCRE2_UNSET:
|
||||||
|
groups[name] = subject_bytes[start:end].decode('utf-8')
|
||||||
|
else:
|
||||||
|
groups[name] = None
|
||||||
|
|
||||||
|
return (True, groups)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pattern(self) -> str:
|
||||||
|
return self._pattern_str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def jit_compiled(self) -> bool:
|
||||||
|
return <bint>self._jit_available
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capture_count(self) -> int:
|
||||||
|
return self._capture_count
|
||||||
|
|
||||||
|
|
||||||
|
cdef class FastRouteMatch:
|
||||||
|
cdef:
|
||||||
|
public dict config
|
||||||
|
public dict params
|
||||||
|
|
||||||
|
def __cinit__(self):
|
||||||
|
self.config = {}
|
||||||
|
self.params = {}
|
||||||
|
|
||||||
|
def __init__(self, dict config, params=None):
|
||||||
|
self.config = config
|
||||||
|
self.params = params if params is not None else {}
|
||||||
|
|
||||||
|
|
||||||
|
cdef class FastRouter:
|
||||||
|
"""
|
||||||
|
High-performance router with PCRE2 JIT-compiled patterns.
|
||||||
|
|
||||||
|
Matching order (nginx-like):
|
||||||
|
1. Exact routes (prefix "=") - O(1) dict lookup
|
||||||
|
2. Regex routes (prefix "~" or "~*") - PCRE2 JIT matching
|
||||||
|
3. Default route (fallback)
|
||||||
|
"""
|
||||||
|
cdef:
|
||||||
|
dict _exact_routes
|
||||||
|
list _regex_routes
|
||||||
|
dict _default_route
|
||||||
|
bint _has_default
|
||||||
|
int _regex_count
|
||||||
|
|
||||||
|
def __cinit__(self):
|
||||||
|
self._exact_routes = {}
|
||||||
|
self._regex_routes = []
|
||||||
|
self._default_route = {}
|
||||||
|
self._has_default = <bint>False
|
||||||
|
self._regex_count = 0
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._exact_routes = {}
|
||||||
|
self._regex_routes = []
|
||||||
|
self._default_route = {}
|
||||||
|
self._has_default = <bint>False
|
||||||
|
self._regex_count = 0
|
||||||
|
|
||||||
|
def add_route(self, str pattern, dict config):
|
||||||
|
cdef:
|
||||||
|
str exact_path
|
||||||
|
str regex_pattern
|
||||||
|
bint case_insensitive
|
||||||
|
PCRE2Pattern compiled_pattern
|
||||||
|
|
||||||
|
if pattern.startswith("="):
|
||||||
|
exact_path = pattern[1:]
|
||||||
|
self._exact_routes[exact_path] = config
|
||||||
|
|
||||||
|
elif pattern == "__default__":
|
||||||
|
self._default_route = config
|
||||||
|
self._has_default = <bint>True
|
||||||
|
|
||||||
|
elif pattern.startswith("~"):
|
||||||
|
case_insensitive = <bint>pattern.startswith("~*")
|
||||||
|
regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
|
||||||
|
|
||||||
|
try:
|
||||||
|
compiled_pattern = PCRE2Pattern._create(regex_pattern, case_insensitive)
|
||||||
|
self._regex_routes.append((compiled_pattern, config))
|
||||||
|
self._regex_count = len(self._regex_routes)
|
||||||
|
except (ValueError, MemoryError):
|
||||||
|
pass # Skip invalid patterns
|
||||||
|
|
||||||
|
cpdef object match(self, str path):
|
||||||
|
cdef:
|
||||||
|
dict config
|
||||||
|
dict params
|
||||||
|
int i
|
||||||
|
PCRE2Pattern pattern
|
||||||
|
tuple route_entry
|
||||||
|
bint matched
|
||||||
|
|
||||||
|
if path in self._exact_routes:
|
||||||
|
config = self._exact_routes[path]
|
||||||
|
return FastRouteMatch(config, {})
|
||||||
|
|
||||||
|
for i in range(self._regex_count):
|
||||||
|
route_entry = <tuple>self._regex_routes[i]
|
||||||
|
pattern = <PCRE2Pattern>route_entry[0]
|
||||||
|
config = <dict>route_entry[1]
|
||||||
|
|
||||||
|
matched, params = pattern.search_with_groups(path)
|
||||||
|
if matched:
|
||||||
|
return FastRouteMatch(config, params)
|
||||||
|
|
||||||
|
if self._has_default:
|
||||||
|
return FastRouteMatch(self._default_route, {})
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exact_routes(self) -> dict:
|
||||||
|
return self._exact_routes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def routes(self) -> dict:
|
||||||
|
"""Return regex routes as dict (pattern_str -> config)."""
|
||||||
|
cdef:
|
||||||
|
dict result = {}
|
||||||
|
PCRE2Pattern pattern
|
||||||
|
for pattern, config in self._regex_routes:
|
||||||
|
result[pattern.pattern] = config
|
||||||
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_route(self) -> Optional[dict]:
|
||||||
|
return self._default_route if self._has_default else None
|
||||||
|
|
||||||
|
cpdef list list_routes(self):
|
||||||
|
cdef:
|
||||||
|
list result = []
|
||||||
|
str path_str
|
||||||
|
dict config
|
||||||
|
PCRE2Pattern pattern
|
||||||
|
|
||||||
|
for path_str, config in self._exact_routes.items():
|
||||||
|
result.append({
|
||||||
|
"type": "exact",
|
||||||
|
"pattern": f"={path_str}",
|
||||||
|
"config": config,
|
||||||
|
})
|
||||||
|
|
||||||
|
for pattern, config in self._regex_routes:
|
||||||
|
result.append({
|
||||||
|
"type": "regex",
|
||||||
|
"pattern": pattern.pattern,
|
||||||
|
"jit_compiled": pattern.jit_compiled,
|
||||||
|
"config": config,
|
||||||
|
})
|
||||||
|
|
||||||
|
if self._has_default:
|
||||||
|
result.append({
|
||||||
|
"type": "default",
|
||||||
|
"pattern": "__default__",
|
||||||
|
"config": self._default_route,
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def compile_pattern(str pattern, bint case_insensitive=<bint>False) -> PCRE2Pattern:
|
||||||
|
"""
|
||||||
|
Compile a PCRE2 pattern with JIT support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Regular expression pattern
|
||||||
|
case_insensitive: Whether to match case-insensitively
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compiled PCRE2Pattern object
|
||||||
|
"""
|
||||||
|
return PCRE2Pattern._create(pattern, case_insensitive)
|
||||||
|
|
||||||
|
|
||||||
|
def fast_match(router: FastRouter, str path):
|
||||||
|
"""
|
||||||
|
Convenience function for matching a path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
router: FastRouter instance
|
||||||
|
path: URL path to match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FastRouteMatch or None
|
||||||
|
"""
|
||||||
|
return router.match(path)
|
||||||
208
pyserve/_routing_pcre2.pxd
Normal file
208
pyserve/_routing_pcre2.pxd
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
# cython: language_level=3
|
||||||
|
|
||||||
|
from libc.stdint cimport uint8_t, uint32_t, int32_t
|
||||||
|
from libc.stddef cimport size_t
|
||||||
|
|
||||||
|
cdef extern from "pcre2.h":
|
||||||
|
pass
|
||||||
|
|
||||||
|
cdef extern from *:
|
||||||
|
ctypedef struct pcre2_code_8:
|
||||||
|
pass
|
||||||
|
ctypedef pcre2_code_8 pcre2_code
|
||||||
|
|
||||||
|
ctypedef struct pcre2_match_data_8:
|
||||||
|
pass
|
||||||
|
ctypedef pcre2_match_data_8 pcre2_match_data
|
||||||
|
|
||||||
|
ctypedef struct pcre2_compile_context_8:
|
||||||
|
pass
|
||||||
|
ctypedef pcre2_compile_context_8 pcre2_compile_context
|
||||||
|
|
||||||
|
ctypedef struct pcre2_match_context_8:
|
||||||
|
pass
|
||||||
|
ctypedef pcre2_match_context_8 pcre2_match_context
|
||||||
|
|
||||||
|
ctypedef struct pcre2_general_context_8:
|
||||||
|
pass
|
||||||
|
ctypedef pcre2_general_context_8 pcre2_general_context
|
||||||
|
|
||||||
|
ctypedef uint8_t PCRE2_UCHAR
|
||||||
|
ctypedef const uint8_t* PCRE2_SPTR
|
||||||
|
ctypedef size_t PCRE2_SIZE
|
||||||
|
|
||||||
|
uint32_t PCRE2_CASELESS
|
||||||
|
uint32_t PCRE2_MULTILINE
|
||||||
|
uint32_t PCRE2_DOTALL
|
||||||
|
uint32_t PCRE2_UTF
|
||||||
|
uint32_t PCRE2_UCP
|
||||||
|
uint32_t PCRE2_NO_UTF_CHECK
|
||||||
|
uint32_t PCRE2_ANCHORED
|
||||||
|
uint32_t PCRE2_ENDANCHORED
|
||||||
|
|
||||||
|
uint32_t PCRE2_JIT_COMPLETE
|
||||||
|
uint32_t PCRE2_JIT_PARTIAL_SOFT
|
||||||
|
uint32_t PCRE2_JIT_PARTIAL_HARD
|
||||||
|
|
||||||
|
int PCRE2_ERROR_NOMATCH
|
||||||
|
int PCRE2_ERROR_PARTIAL
|
||||||
|
int PCRE2_ERROR_JIT_STACKLIMIT
|
||||||
|
|
||||||
|
PCRE2_SIZE PCRE2_UNSET
|
||||||
|
PCRE2_SIZE PCRE2_ZERO_TERMINATED
|
||||||
|
|
||||||
|
pcre2_code* pcre2_compile_8(
|
||||||
|
PCRE2_SPTR pattern,
|
||||||
|
PCRE2_SIZE length,
|
||||||
|
uint32_t options,
|
||||||
|
int* errorcode,
|
||||||
|
PCRE2_SIZE* erroroffset,
|
||||||
|
pcre2_compile_context* ccontext
|
||||||
|
)
|
||||||
|
|
||||||
|
void pcre2_code_free_8(pcre2_code* code)
|
||||||
|
|
||||||
|
int pcre2_jit_compile_8(pcre2_code* code, uint32_t options)
|
||||||
|
|
||||||
|
pcre2_match_data* pcre2_match_data_create_from_pattern_8(
|
||||||
|
const pcre2_code* code,
|
||||||
|
pcre2_general_context* gcontext
|
||||||
|
)
|
||||||
|
|
||||||
|
pcre2_match_data* pcre2_match_data_create_8(
|
||||||
|
uint32_t ovecsize,
|
||||||
|
pcre2_general_context* gcontext
|
||||||
|
)
|
||||||
|
|
||||||
|
void pcre2_match_data_free_8(pcre2_match_data* match_data)
|
||||||
|
|
||||||
|
int pcre2_match_8(
|
||||||
|
const pcre2_code* code,
|
||||||
|
PCRE2_SPTR subject,
|
||||||
|
PCRE2_SIZE length,
|
||||||
|
PCRE2_SIZE startoffset,
|
||||||
|
uint32_t options,
|
||||||
|
pcre2_match_data* match_data,
|
||||||
|
pcre2_match_context* mcontext
|
||||||
|
)
|
||||||
|
|
||||||
|
int pcre2_jit_match_8(
|
||||||
|
const pcre2_code* code,
|
||||||
|
PCRE2_SPTR subject,
|
||||||
|
PCRE2_SIZE length,
|
||||||
|
PCRE2_SIZE startoffset,
|
||||||
|
uint32_t options,
|
||||||
|
pcre2_match_data* match_data,
|
||||||
|
pcre2_match_context* mcontext
|
||||||
|
)
|
||||||
|
|
||||||
|
PCRE2_SIZE* pcre2_get_ovector_pointer_8(pcre2_match_data* match_data)
|
||||||
|
uint32_t pcre2_get_ovector_count_8(pcre2_match_data* match_data)
|
||||||
|
|
||||||
|
int pcre2_pattern_info_8(
|
||||||
|
const pcre2_code* code,
|
||||||
|
uint32_t what,
|
||||||
|
void* where
|
||||||
|
)
|
||||||
|
|
||||||
|
uint32_t PCRE2_INFO_CAPTURECOUNT
|
||||||
|
uint32_t PCRE2_INFO_NAMECOUNT
|
||||||
|
uint32_t PCRE2_INFO_NAMETABLE
|
||||||
|
uint32_t PCRE2_INFO_NAMEENTRYSIZE
|
||||||
|
uint32_t PCRE2_INFO_JITSIZE
|
||||||
|
|
||||||
|
int pcre2_get_error_message_8(
|
||||||
|
int errorcode,
|
||||||
|
PCRE2_UCHAR* buffer,
|
||||||
|
PCRE2_SIZE bufflen
|
||||||
|
)
|
||||||
|
|
||||||
|
int pcre2_substring_copy_byname_8(
|
||||||
|
pcre2_match_data* match_data,
|
||||||
|
PCRE2_SPTR name,
|
||||||
|
PCRE2_UCHAR* buffer,
|
||||||
|
PCRE2_SIZE* bufflen
|
||||||
|
)
|
||||||
|
|
||||||
|
int pcre2_substring_copy_bynumber_8(
|
||||||
|
pcre2_match_data* match_data,
|
||||||
|
uint32_t number,
|
||||||
|
PCRE2_UCHAR* buffer,
|
||||||
|
PCRE2_SIZE* bufflen
|
||||||
|
)
|
||||||
|
|
||||||
|
int pcre2_substring_get_byname_8(
|
||||||
|
pcre2_match_data* match_data,
|
||||||
|
PCRE2_SPTR name,
|
||||||
|
PCRE2_UCHAR** bufferptr,
|
||||||
|
PCRE2_SIZE* bufflen
|
||||||
|
)
|
||||||
|
|
||||||
|
int pcre2_substring_get_bynumber_8(
|
||||||
|
pcre2_match_data* match_data,
|
||||||
|
uint32_t number,
|
||||||
|
PCRE2_UCHAR** bufferptr,
|
||||||
|
PCRE2_SIZE* bufflen
|
||||||
|
)
|
||||||
|
|
||||||
|
void pcre2_substring_free_8(PCRE2_UCHAR* buffer)
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline pcre2_code* pcre2_compile(
|
||||||
|
PCRE2_SPTR pattern,
|
||||||
|
PCRE2_SIZE length,
|
||||||
|
uint32_t options,
|
||||||
|
int* errorcode,
|
||||||
|
PCRE2_SIZE* erroroffset,
|
||||||
|
pcre2_compile_context* ccontext
|
||||||
|
) noexcept:
|
||||||
|
return pcre2_compile_8(pattern, length, options, errorcode, erroroffset, ccontext)
|
||||||
|
|
||||||
|
cdef inline void pcre2_code_free(pcre2_code* code) noexcept:
|
||||||
|
pcre2_code_free_8(code)
|
||||||
|
|
||||||
|
cdef inline int pcre2_jit_compile(pcre2_code* code, uint32_t options) noexcept:
|
||||||
|
return pcre2_jit_compile_8(code, options)
|
||||||
|
|
||||||
|
cdef inline pcre2_match_data* pcre2_match_data_create_from_pattern(
|
||||||
|
const pcre2_code* code,
|
||||||
|
pcre2_general_context* gcontext
|
||||||
|
) noexcept:
|
||||||
|
return pcre2_match_data_create_from_pattern_8(code, gcontext)
|
||||||
|
|
||||||
|
cdef inline void pcre2_match_data_free(pcre2_match_data* match_data) noexcept:
|
||||||
|
pcre2_match_data_free_8(match_data)
|
||||||
|
|
||||||
|
cdef inline int pcre2_match(
|
||||||
|
const pcre2_code* code,
|
||||||
|
PCRE2_SPTR subject,
|
||||||
|
PCRE2_SIZE length,
|
||||||
|
PCRE2_SIZE startoffset,
|
||||||
|
uint32_t options,
|
||||||
|
pcre2_match_data* match_data,
|
||||||
|
pcre2_match_context* mcontext
|
||||||
|
) noexcept:
|
||||||
|
return pcre2_match_8(code, subject, length, startoffset, options, match_data, mcontext)
|
||||||
|
|
||||||
|
cdef inline int pcre2_jit_match(
|
||||||
|
const pcre2_code* code,
|
||||||
|
PCRE2_SPTR subject,
|
||||||
|
PCRE2_SIZE length,
|
||||||
|
PCRE2_SIZE startoffset,
|
||||||
|
uint32_t options,
|
||||||
|
pcre2_match_data* match_data,
|
||||||
|
pcre2_match_context* mcontext
|
||||||
|
) noexcept:
|
||||||
|
return pcre2_jit_match_8(code, subject, length, startoffset, options, match_data, mcontext)
|
||||||
|
|
||||||
|
cdef inline PCRE2_SIZE* pcre2_get_ovector_pointer(pcre2_match_data* match_data) noexcept:
|
||||||
|
return pcre2_get_ovector_pointer_8(match_data)
|
||||||
|
|
||||||
|
cdef inline uint32_t pcre2_get_ovector_count(pcre2_match_data* match_data) noexcept:
|
||||||
|
return pcre2_get_ovector_count_8(match_data)
|
||||||
|
|
||||||
|
cdef inline int pcre2_pattern_info(const pcre2_code* code, uint32_t what, void* where) noexcept:
|
||||||
|
return pcre2_pattern_info_8(code, what, where)
|
||||||
|
|
||||||
|
cdef inline int pcre2_get_error_message(int errorcode, PCRE2_UCHAR* buffer, PCRE2_SIZE bufflen) noexcept:
|
||||||
|
return pcre2_get_error_message_8(errorcode, buffer, bufflen)
|
||||||
129
pyserve/_routing_py.py
Normal file
129
pyserve/_routing_py.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
"""
|
||||||
|
Pure Python fallback for _routing when PCRE2/Cython is not available.
|
||||||
|
|
||||||
|
This module provides the same interface using the standard library `re` module.
|
||||||
|
It's slower than the Cython+PCRE2 implementation but works everywhere.
|
||||||
|
In future we may add pcre2.py library support for better performance in this module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional, Pattern, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class FastRouteMatch:
|
||||||
|
__slots__ = ("config", "params")
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
|
||||||
|
self.config = config
|
||||||
|
self.params = params if params is not None else {}
|
||||||
|
|
||||||
|
|
||||||
|
class FastRouter:
|
||||||
|
"""
|
||||||
|
Router with regex pattern matching.
|
||||||
|
|
||||||
|
Matching order (nginx-like):
|
||||||
|
1. Exact routes (prefix "=") - O(1) dict lookup
|
||||||
|
2. Regex routes (prefix "~" or "~*") - linear scan
|
||||||
|
3. Default route (fallback)
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("_exact_routes", "_regex_routes", "_default_route", "_has_default", "_regex_count")
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._exact_routes: Dict[str, Dict[str, Any]] = {}
|
||||||
|
self._regex_routes: List[Tuple[Pattern[str], Dict[str, Any]]] = []
|
||||||
|
self._default_route: Dict[str, Any] = {}
|
||||||
|
self._has_default: bool = False
|
||||||
|
self._regex_count: int = 0
|
||||||
|
|
||||||
|
def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
|
||||||
|
if pattern.startswith("="):
|
||||||
|
exact_path = pattern[1:]
|
||||||
|
self._exact_routes[exact_path] = config
|
||||||
|
return
|
||||||
|
|
||||||
|
if pattern == "__default__":
|
||||||
|
self._default_route = config
|
||||||
|
self._has_default = True
|
||||||
|
return
|
||||||
|
|
||||||
|
if pattern.startswith("~"):
|
||||||
|
case_insensitive = pattern.startswith("~*")
|
||||||
|
regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
|
||||||
|
|
||||||
|
flags = re.IGNORECASE if case_insensitive else 0
|
||||||
|
try:
|
||||||
|
compiled_pattern = re.compile(regex_pattern, flags)
|
||||||
|
self._regex_routes.append((compiled_pattern, config))
|
||||||
|
self._regex_count = len(self._regex_routes)
|
||||||
|
except re.error:
|
||||||
|
pass # Ignore invalid patterns
|
||||||
|
|
||||||
|
def match(self, path: str) -> Optional[FastRouteMatch]:
|
||||||
|
if path in self._exact_routes:
|
||||||
|
config = self._exact_routes[path]
|
||||||
|
return FastRouteMatch(config, {})
|
||||||
|
|
||||||
|
for pattern, config in self._regex_routes:
|
||||||
|
match_obj = pattern.search(path)
|
||||||
|
if match_obj is not None:
|
||||||
|
params = match_obj.groupdict()
|
||||||
|
return FastRouteMatch(config, params)
|
||||||
|
|
||||||
|
if self._has_default:
|
||||||
|
return FastRouteMatch(self._default_route, {})
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exact_routes(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
return self._exact_routes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def routes(self) -> Dict[Pattern[str], Dict[str, Any]]:
|
||||||
|
return {p: c for p, c in self._regex_routes}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_route(self) -> Optional[Dict[str, Any]]:
|
||||||
|
return self._default_route if self._has_default else None
|
||||||
|
|
||||||
|
def list_routes(self) -> List[Dict[str, Any]]:
|
||||||
|
result: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
for path, config in self._exact_routes.items():
|
||||||
|
result.append({
|
||||||
|
"type": "exact",
|
||||||
|
"pattern": f"={path}",
|
||||||
|
"config": config,
|
||||||
|
})
|
||||||
|
|
||||||
|
for pattern, config in self._regex_routes:
|
||||||
|
result.append({
|
||||||
|
"type": "regex",
|
||||||
|
"pattern": pattern.pattern,
|
||||||
|
"config": config,
|
||||||
|
})
|
||||||
|
|
||||||
|
if self._has_default:
|
||||||
|
result.append({
|
||||||
|
"type": "default",
|
||||||
|
"pattern": "__default__",
|
||||||
|
"config": self._default_route,
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def fast_match(router: FastRouter, path: str) -> Optional[FastRouteMatch]:
|
||||||
|
"""
|
||||||
|
Convenience function for matching a path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
router: FastRouter instance
|
||||||
|
path: URL path to match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FastRouteMatch or None
|
||||||
|
"""
|
||||||
|
return router.match(path)
|
||||||
@ -1,7 +1,6 @@
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import re
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Pattern
|
from typing import Any, Dict
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -10,60 +9,19 @@ from starlette.responses import FileResponse, PlainTextResponse, Response
|
|||||||
|
|
||||||
from .logging_utils import get_logger
|
from .logging_utils import get_logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyserve._routing import FastRouteMatch, FastRouter, fast_match # type: ignore
|
||||||
|
CYTHON_ROUTING_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
from pyserve._routing_py import FastRouteMatch, FastRouter, fast_match
|
||||||
|
CYTHON_ROUTING_AVAILABLE = False
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RouteMatch:
|
# Aliases for backward compatibility
|
||||||
def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
|
RouteMatch = FastRouteMatch
|
||||||
self.config = config
|
Router = FastRouter
|
||||||
self.params = params or {}
|
|
||||||
|
|
||||||
|
|
||||||
class Router:
|
|
||||||
def __init__(self, static_dir: str = "./static"):
|
|
||||||
self.static_dir = Path(static_dir)
|
|
||||||
self.routes: Dict[Pattern, Dict[str, Any]] = {}
|
|
||||||
self.exact_routes: Dict[str, Dict[str, Any]] = {}
|
|
||||||
self.default_route: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
|
|
||||||
if pattern.startswith("="):
|
|
||||||
exact_path = pattern[1:]
|
|
||||||
self.exact_routes[exact_path] = config
|
|
||||||
logger.debug(f"Added exact route: {exact_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if pattern == "__default__":
|
|
||||||
self.default_route = config
|
|
||||||
logger.debug("Added default route")
|
|
||||||
return
|
|
||||||
|
|
||||||
if pattern.startswith("~"):
|
|
||||||
case_insensitive = pattern.startswith("~*")
|
|
||||||
regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
|
|
||||||
|
|
||||||
flags = re.IGNORECASE if case_insensitive else 0
|
|
||||||
try:
|
|
||||||
compiled_pattern = re.compile(regex_pattern, flags)
|
|
||||||
self.routes[compiled_pattern] = config
|
|
||||||
logger.debug(f"Added regex route: {pattern}")
|
|
||||||
except re.error as e:
|
|
||||||
logger.error(f"Regex compilation error {pattern}: {e}")
|
|
||||||
|
|
||||||
def match(self, path: str) -> Optional[RouteMatch]:
|
|
||||||
if path in self.exact_routes:
|
|
||||||
return RouteMatch(self.exact_routes[path])
|
|
||||||
|
|
||||||
for pattern, config in self.routes.items():
|
|
||||||
match = pattern.search(path)
|
|
||||||
if match:
|
|
||||||
params = match.groupdict()
|
|
||||||
return RouteMatch(config, params)
|
|
||||||
|
|
||||||
if self.default_route:
|
|
||||||
return RouteMatch(self.default_route)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler:
|
class RequestHandler:
|
||||||
|
|||||||
@ -9,9 +9,86 @@ Or via make:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def get_pcre2_config():
|
||||||
|
include_dirs = []
|
||||||
|
library_dirs = []
|
||||||
|
libraries = ["pcre2-8"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
cflags = subprocess.check_output(
|
||||||
|
["pkg-config", "--cflags", "libpcre2-8"],
|
||||||
|
stderr=subprocess.DEVNULL
|
||||||
|
).decode().strip()
|
||||||
|
libs = subprocess.check_output(
|
||||||
|
["pkg-config", "--libs", "libpcre2-8"],
|
||||||
|
stderr=subprocess.DEVNULL
|
||||||
|
).decode().strip()
|
||||||
|
|
||||||
|
for flag in cflags.split():
|
||||||
|
if flag.startswith("-I"):
|
||||||
|
include_dirs.append(flag[2:])
|
||||||
|
|
||||||
|
for flag in libs.split():
|
||||||
|
if flag.startswith("-L"):
|
||||||
|
library_dirs.append(flag[2:])
|
||||||
|
elif flag.startswith("-l"):
|
||||||
|
lib = flag[2:]
|
||||||
|
if lib not in libraries:
|
||||||
|
libraries.append(lib)
|
||||||
|
|
||||||
|
return include_dirs, library_dirs, libraries
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
cflags = subprocess.check_output(
|
||||||
|
["pcre2-config", "--cflags"],
|
||||||
|
stderr=subprocess.DEVNULL
|
||||||
|
).decode().strip()
|
||||||
|
libs = subprocess.check_output(
|
||||||
|
["pcre2-config", "--libs8"],
|
||||||
|
stderr=subprocess.DEVNULL
|
||||||
|
).decode().strip()
|
||||||
|
|
||||||
|
for flag in cflags.split():
|
||||||
|
if flag.startswith("-I"):
|
||||||
|
include_dirs.append(flag[2:])
|
||||||
|
|
||||||
|
for flag in libs.split():
|
||||||
|
if flag.startswith("-L"):
|
||||||
|
library_dirs.append(flag[2:])
|
||||||
|
elif flag.startswith("-l"):
|
||||||
|
lib = flag[2:]
|
||||||
|
if lib not in libraries:
|
||||||
|
libraries.append(lib)
|
||||||
|
|
||||||
|
return include_dirs, library_dirs, libraries
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: try common paths
|
||||||
|
common_paths = [
|
||||||
|
"/opt/homebrew", # macOS ARM
|
||||||
|
"/usr/local", # macOS Intel / Linux
|
||||||
|
"/usr", # Linux
|
||||||
|
]
|
||||||
|
|
||||||
|
for base in common_paths:
|
||||||
|
include_path = Path(base) / "include"
|
||||||
|
lib_path = Path(base) / "lib"
|
||||||
|
if (include_path / "pcre2.h").exists():
|
||||||
|
include_dirs.append(str(include_path))
|
||||||
|
library_dirs.append(str(lib_path))
|
||||||
|
break
|
||||||
|
|
||||||
|
return include_dirs, library_dirs, libraries
|
||||||
|
|
||||||
|
|
||||||
def build_extensions():
|
def build_extensions():
|
||||||
try:
|
try:
|
||||||
from Cython.Build import cythonize
|
from Cython.Build import cythonize
|
||||||
@ -29,6 +106,14 @@ def build_extensions():
|
|||||||
print("Install with: pip install setuptools")
|
print("Install with: pip install setuptools")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
pcre2_include, pcre2_libdir, pcre2_libs = get_pcre2_config()
|
||||||
|
|
||||||
|
if not pcre2_include:
|
||||||
|
print("WARNING: PCRE2 not found. Routing module may not compile.")
|
||||||
|
print("Install PCRE2: brew install pcre2 (macOS) or apt install libpcre2-dev (Linux)")
|
||||||
|
else:
|
||||||
|
print(f"Found PCRE2: includes={pcre2_include}, libs={pcre2_libdir}")
|
||||||
|
|
||||||
extensions = [
|
extensions = [
|
||||||
Extension(
|
Extension(
|
||||||
"pyserve._path_matcher",
|
"pyserve._path_matcher",
|
||||||
@ -36,6 +121,18 @@ def build_extensions():
|
|||||||
extra_compile_args=["-O3", "-ffast-math"],
|
extra_compile_args=["-O3", "-ffast-math"],
|
||||||
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
|
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
|
||||||
),
|
),
|
||||||
|
Extension(
|
||||||
|
"pyserve._routing",
|
||||||
|
sources=["pyserve/_routing.pyx"],
|
||||||
|
include_dirs=pcre2_include,
|
||||||
|
library_dirs=pcre2_libdir,
|
||||||
|
libraries=pcre2_libs,
|
||||||
|
extra_compile_args=["-O3", "-ffast-math"],
|
||||||
|
define_macros=[
|
||||||
|
("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION"),
|
||||||
|
("PCRE2_CODE_UNIT_WIDTH", "8"),
|
||||||
|
],
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
ext_modules = cythonize(
|
ext_modules = cythonize(
|
||||||
@ -59,7 +156,9 @@ def build_extensions():
|
|||||||
cmd.run()
|
cmd.run()
|
||||||
|
|
||||||
print("\nCython extensions built successfully!")
|
print("\nCython extensions built successfully!")
|
||||||
print(" - pyserve/_path_matcher" + (".pyd" if sys.platform == "win32" else ".so"))
|
ext_suffix = ".pyd" if sys.platform == "win32" else ".so"
|
||||||
|
print(f" - pyserve/_path_matcher{ext_suffix}")
|
||||||
|
print(f" - pyserve/_routing{ext_suffix}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -50,16 +50,10 @@ class TestRouter:
|
|||||||
def test_router_initialization(self):
|
def test_router_initialization(self):
|
||||||
"""Test router initializes with correct defaults."""
|
"""Test router initializes with correct defaults."""
|
||||||
router = Router()
|
router = Router()
|
||||||
assert router.static_dir == Path("./static")
|
|
||||||
assert router.routes == {}
|
assert router.routes == {}
|
||||||
assert router.exact_routes == {}
|
assert router.exact_routes == {}
|
||||||
assert router.default_route is None
|
assert router.default_route is None
|
||||||
|
|
||||||
def test_router_custom_static_dir(self):
|
|
||||||
"""Test router with custom static directory."""
|
|
||||||
router = Router(static_dir="/custom/path")
|
|
||||||
assert router.static_dir == Path("/custom/path")
|
|
||||||
|
|
||||||
def test_add_exact_route(self):
|
def test_add_exact_route(self):
|
||||||
"""Test adding exact match route."""
|
"""Test adding exact match route."""
|
||||||
router = Router()
|
router = Router()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user