Compare commits

..

2 Commits
golang ... main

Author SHA1 Message Date
Илья Глазунов
eeeccd57da Cython routing added 2026-01-31 02:44:50 +03:00
Илья Глазунов
fe541778f1 changes in pyprojects
fixed python version

fixed linter errors in Cython path_matcher
2026-01-30 23:11:42 +03:00
45 changed files with 1140 additions and 8895 deletions

View File

@ -22,6 +22,11 @@ jobs:
with:
fetch-depth: 0
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y libpcre2-dev
- name: Setup Python
uses: actions/setup-python@v4
with:
@ -45,6 +50,9 @@ jobs:
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --with dev
- name: Build Cython extensions
run: poetry run python scripts/build_cython.py build_ext --inplace
- name: Build package
run: |
poetry build

View File

@ -17,6 +17,11 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y libpcre2-dev
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
@ -40,6 +45,9 @@ jobs:
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --with dev
- name: Build Cython extensions
run: poetry run python scripts/build_cython.py build_ext --inplace
- name: Run tests
run: poetry run pytest tests/ -v

5
.gitignore vendored
View File

@ -27,7 +27,4 @@ build/
.idea/
.vscode/
*.swp
*.swo
# Go binaries
go/bin
*.swo

153
benchmarks/bench_routing.py Normal file
View File

@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Benchmark script for routing performance comparison.
Compares:
- Pure Python implementation with standard re (_routing_py)
- Cython implementation with PCRE2 JIT (_routing)
Usage:
python benchmarks/bench_routing.py
"""
import re
import time
import statistics
from typing import Callable, Tuple
from pyserve._routing_py import (
FastRouter as PyFastRouter,
FastRouteMatch as PyFastRouteMatch,
)
try:
from pyserve._routing import (
FastRouter as CyFastRouter,
FastRouteMatch as CyFastRouteMatch,
)
CYTHON_AVAILABLE = True
except ImportError:
CYTHON_AVAILABLE = False
print("Cython module not compiled. Run: poetry run python scripts/build_cython.py\n")
def benchmark(func: Callable, iterations: int = 100000) -> Tuple[float, float]:
"""Benchmark a function and return mean/stdev in nanoseconds."""
times = []
# Warmup
for _ in range(1000):
func()
# Actual benchmark
for _ in range(iterations):
start = time.perf_counter_ns()
func()
end = time.perf_counter_ns()
times.append(end - start)
return statistics.mean(times), statistics.stdev(times)
def format_time(ns: float) -> str:
"""Format time in nanoseconds to human readable format."""
if ns < 1000:
return f"{ns:.1f} ns"
elif ns < 1_000_000:
return f"{ns/1000:.2f} µs"
else:
return f"{ns/1_000_000:.2f} ms"
def setup_router(router_class):
"""Setup a router with typical routes."""
router = router_class()
# Exact routes
router.add_route("=/health", {"return": "200 OK"})
router.add_route("=/api/status", {"return": "200 OK"})
router.add_route("=/favicon.ico", {"return": "204"})
# Regex routes
router.add_route("~^/api/v1/users/(?P<user_id>\\d+)$", {"proxy_pass": "http://users-service"})
router.add_route("~^/api/v1/posts/(?P<post_id>\\d+)$", {"proxy_pass": "http://posts-service"})
router.add_route("~\\.(css|js|png|jpg|gif|svg|woff2?)$", {"root": "./static"})
router.add_route("~^/api/", {"proxy_pass": "http://api-gateway"})
# Default route
router.add_route("__default__", {"spa_fallback": True, "root": "./dist"})
return router
def run_benchmarks():
print("=" * 70)
print("ROUTING BENCHMARK")
print("=" * 70)
print()
# Test paths with different matching scenarios
test_cases = [
("/health", "Exact match (first)"),
("/api/status", "Exact match (middle)"),
("/api/v1/users/12345", "Regex match with groups"),
("/static/app.js", "Regex match (file extension)"),
("/api/v2/other", "Regex match (simple prefix)"),
("/some/random/path", "Default route (fallback)"),
("/nonexistent", "Default route (fallback)"),
]
iterations = 100000
print(f"Iterations: {iterations:,}")
print()
# Setup routers
py_router = setup_router(PyFastRouter)
cy_router = setup_router(CyFastRouter) if CYTHON_AVAILABLE else None
results = {}
for path, description in test_cases:
print(f"Path: {path}")
print(f" {description}")
# Python implementation (standard re)
py_mean, py_std = benchmark(lambda p=path: py_router.match(p), iterations)
results[(path, "Python (re)")] = py_mean
print(f" Python (re): {format_time(py_mean):>12} ± {format_time(py_std)}")
# Cython implementation (PCRE2 JIT)
if CYTHON_AVAILABLE and cy_router:
cy_mean, cy_std = benchmark(lambda p=path: cy_router.match(p), iterations)
results[(path, "Cython (PCRE2)")] = cy_mean
speedup = py_mean / cy_mean if cy_mean > 0 else 0
print(f" Cython (PCRE2): {format_time(cy_mean):>12} ± {format_time(cy_std)} ({speedup:.2f}x faster)")
print()
# Summary
if CYTHON_AVAILABLE:
print("=" * 70)
print("SUMMARY")
print("=" * 70)
py_total = sum(v for k, v in results.items() if k[1] == "Python (re)")
cy_total = sum(v for k, v in results.items() if k[1] == "Cython (PCRE2)")
print(f" Python (re) total: {format_time(py_total)}")
print(f" Cython (PCRE2) total: {format_time(cy_total)}")
print(f" Overall speedup: {py_total / cy_total:.2f}x")
# Show JIT compilation status
print()
print("PCRE2 JIT Status:")
for route in cy_router.list_routes(): # type: ignore False linter error
if route["type"] == "regex":
jit = route.get("jit_compiled", False)
status = "✓ JIT" if jit else "✗ No JIT"
print(f" {status}: {route['pattern']}")
if __name__ == "__main__":
run_benchmarks()

View File

@ -1,34 +0,0 @@
# Multi-stage build for Konduktor
FROM golang:1.23-alpine AS builder
RUN apk add --no-cache git make
WORKDIR /build
COPY go.mod go.sum* ./
RUN go mod download
COPY . .
RUN make build
FROM alpine:3.19
RUN apk add --no-cache ca-certificates tzdata
RUN adduser -D -g '' konduktor
WORKDIR /app
COPY --from=builder /build/bin/konduktor /usr/local/bin/
COPY --from=builder /build/bin/konduktorctl /usr/local/bin/
RUN mkdir -p /app/static /app/templates /app/logs && \
chown -R konduktor:konduktor /app
USER konduktor
EXPOSE 8080
ENTRYPOINT ["konduktor"]
CMD ["-c", "/app/config.yaml"]

View File

@ -1,108 +0,0 @@
# Konduktor Go Build
# Makefile for building and testing Konduktor
.PHONY: all build build-konduktor build-konduktorctl test clean deps fmt lint run
# Build configuration
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
BUILD_TIME ?= $(shell date -u '+%Y-%m-%dT%H:%M:%SZ')
LDFLAGS := -X main.Version=$(VERSION) -X main.GitCommit=$(GIT_COMMIT) -X main.BuildTime=$(BUILD_TIME)
# Output directories
BIN_DIR := bin
all: deps build
# Download dependencies
deps:
@echo "==> Downloading dependencies..."
go mod download
go mod tidy
# Build all binaries
build: build-konduktor build-konduktorctl
# Build konduktor server
build-konduktor:
@echo "==> Building konduktor..."
@mkdir -p $(BIN_DIR)
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktor ./cmd/konduktor
# Build konduktorctl CLI
build-konduktorctl:
@echo "==> Building konduktorctl..."
@mkdir -p $(BIN_DIR)
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktorctl ./cmd/konduktorctl
# Run tests
test:
@echo "==> Running tests..."
go test -v -race -cover ./...
# Run tests with coverage report
test-coverage:
@echo "==> Running tests with coverage..."
go test -v -race -coverprofile=coverage.out ./...
go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report: coverage.html"
# Format code
fmt:
@echo "==> Formatting code..."
go fmt ./...
goimports -w .
# Lint code
lint:
@echo "==> Linting code..."
golangci-lint run ./...
# Run the server (development)
run: build-konduktor
@echo "==> Running konduktor..."
./$(BIN_DIR)/konduktor -c ../config.yaml
# Clean build artifacts
clean:
@echo "==> Cleaning..."
rm -rf $(BIN_DIR)
rm -f coverage.out coverage.html
# Install binaries to GOPATH/bin
install: build
@echo "==> Installing binaries..."
cp $(BIN_DIR)/konduktor $(GOPATH)/bin/
cp $(BIN_DIR)/konduktorctl $(GOPATH)/bin/
# Generate mocks (for testing)
generate:
@echo "==> Generating code..."
go generate ./...
# Docker build
docker-build:
@echo "==> Building Docker image..."
docker build -t konduktor:$(VERSION) .
# Show help
help:
@echo "Konduktor Build System"
@echo ""
@echo "Usage: make [target]"
@echo ""
@echo "Targets:"
@echo " all Download deps and build all binaries"
@echo " deps Download and tidy dependencies"
@echo " build Build all binaries"
@echo " build-konduktor Build the server binary"
@echo " build-konduktorctl Build the CLI binary"
@echo " test Run tests"
@echo " test-coverage Run tests with coverage report"
@echo " fmt Format code"
@echo " lint Lint code"
@echo " run Build and run the server"
@echo " clean Clean build artifacts"
@echo " install Install binaries to GOPATH/bin"
@echo " docker-build Build Docker image"
@echo " help Show this help"

View File

@ -1,149 +0,0 @@
# Konduktor (Go)
High-performance HTTP web server with extensible routing and process orchestration. (Previously known as PyServe in Python)
## Project Structure
```
go/
├── cmd/
│ ├── konduktor/ # Main server binary
│ └── konduktorctl/ # CLI management tool
├── internal/
│ ├── config/ # Configuration management
│ ├── logging/ # Structured logging
│ ├── middleware/ # HTTP middleware
│ ├── routing/ # HTTP routing
│ ├── extensions/ # Extension system (TODO)
│ └── process/ # Process management (TODO)
├── pkg/ # Public packages (TODO)
├── go.mod
├── go.sum
└── Makefile
```
## Building
```bash
cd go
# Download dependencies
make deps
# Build all binaries
make build
# Or build individually
make build-konduktor
make build-konduktorctl
```
## Running
```bash
# Run with default config
./bin/konduktor
# Run with custom config
./bin/konduktor -c ../config.yaml
# Run with flags
./bin/konduktor --host 127.0.0.1 --port 3000 --debug
```
## CLI Commands (konduktorctl)
```bash
# Start services
konduktorctl up
# Stop services
konduktorctl down
# View status
konduktorctl status
# View logs
konduktorctl logs -f
# Health check
konduktorctl health
# Scale services
konduktorctl scale api=3
# Configuration management
konduktorctl config show
konduktorctl config validate
# Initialize new project
konduktorctl init
```
## Configuration
Uses the same YAML configuration format as the Python version:
```yaml
server:
host: 0.0.0.0
port: 8080
http:
static_dir: ./static
templates_dir: ./templates
ssl:
enabled: false
cert_file: ./ssl/cert.pem
key_file: ./ssl/key.pem
logging:
level: INFO
console_output: true
extensions:
- type: routing
config:
regex_locations:
"=/health":
return: "200 OK"
```
## Development
```bash
# Format code
make fmt
# Run linter
make lint
# Run tests
make test
# Run with coverage
make test-coverage
```
## Migration from Python
This is a gradual rewrite of PyServe to Go. The project is now called **Konduktor**.
### Completed
- [x] Basic project structure
- [x] Configuration loading
- [x] HTTP server with graceful shutdown
- [x] Basic routing
- [x] Middleware (access log, recovery, server header)
- [x] CLI structure (konduktor, konduktorctl)
### TODO
- [ ] Extension system
- [x] Regex routing
- [x] Reverse proxy
- [ ] Process orchestration
- [ ] ASGI/WSGI adapter support
- [ ] WebSocket support
- [ ] Hot reload
- [ ] Metrics and monitoring

View File

@ -1,79 +0,0 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/konduktor/konduktor/internal/config"
"github.com/konduktor/konduktor/internal/server"
)
var (
Version = "0.1.0"
BuildTime = "unknown"
GitCommit = "unknown"
)
var (
cfgFile string
host string
port int
debug bool
)
func main() {
rootCmd := &cobra.Command{
Use: "konduktor",
Short: "Konduktor - HTTP web server",
Long: `Konduktor is a high-performance HTTP web server with extensible routing and process orchestration.`,
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
RunE: runServer,
}
rootCmd.Flags().StringVarP(&cfgFile, "config", "c", "config.yaml", "Path to configuration file")
rootCmd.Flags().StringVar(&host, "host", "", "Host to bind the server to")
rootCmd.Flags().IntVar(&port, "port", 0, "Port to bind the server to")
rootCmd.Flags().BoolVar(&debug, "debug", false, "Enable debug mode")
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func runServer(cmd *cobra.Command, args []string) error {
cfg, err := config.Load(cfgFile)
if err != nil {
if os.IsNotExist(err) {
fmt.Printf("Configuration file %s not found, using defaults\n", cfgFile)
cfg = config.Default()
} else {
return fmt.Errorf("configuration loading error: %w", err)
}
}
if host != "" {
cfg.Server.Host = host
}
if port != 0 {
cfg.Server.Port = port
}
if debug {
cfg.Logging.Level = "DEBUG"
}
srv, err := server.New(cfg)
if err != nil {
return fmt.Errorf("server creation error: %w", err)
}
fmt.Printf("Starting Konduktor server on %s:%d\n", cfg.Server.Host, cfg.Server.Port)
if err := srv.Run(); err != nil {
return fmt.Errorf("server startup error: %w", err)
}
return nil
}

View File

@ -1,180 +0,0 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
)
var (
Version = "0.1.0"
BuildTime = "unknown"
GitCommit = "unknown"
)
func main() {
rootCmd := &cobra.Command{
Use: "konduktorctl",
Short: "Konduktorctl - Service management CLI",
Long: `Konduktorctl is a CLI tool for managing Konduktor services.`,
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
}
rootCmd.AddCommand(
newUpCmd(),
newDownCmd(),
newStatusCmd(),
newLogsCmd(),
newHealthCmd(),
newScaleCmd(),
newConfigCmd(),
newInitCmd(),
newTopCmd(),
)
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func newUpCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "up [service...]",
Short: "Start services",
Long: `Start one or more services. If no service is specified, all services are started.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Starting services...")
// TODO: Implement service start logic
return nil
},
}
cmd.Flags().BoolP("detach", "d", false, "Run in background")
return cmd
}
func newDownCmd() *cobra.Command {
return &cobra.Command{
Use: "down [service...]",
Short: "Stop services",
Long: `Stop one or more services. If no service is specified, all services are stopped.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Stopping services...")
// TODO: Implement service stop logic
return nil
},
}
}
func newStatusCmd() *cobra.Command {
return &cobra.Command{
Use: "status [service...]",
Short: "Show service status",
Long: `Show the status of one or more services.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Service status:")
// TODO: Implement status display logic
return nil
},
}
}
func newLogsCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "logs [service]",
Short: "View service logs",
Long: `View logs for a specific service.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Fetching logs...")
// TODO: Implement logs viewing logic
return nil
},
}
cmd.Flags().BoolP("follow", "f", false, "Follow log output")
cmd.Flags().IntP("tail", "n", 100, "Number of lines to show from the end")
return cmd
}
func newHealthCmd() *cobra.Command {
return &cobra.Command{
Use: "health",
Short: "Check service health",
Long: `Check the health status of all services.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Health check:")
// TODO: Implement health check logic
return nil
},
}
}
func newScaleCmd() *cobra.Command {
return &cobra.Command{
Use: "scale <service>=<count>",
Short: "Scale a service",
Long: `Scale a service to a specific number of instances.`,
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Printf("Scaling: %v\n", args)
// TODO: Implement scaling logic
return nil
},
}
}
func newConfigCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "config",
Short: "Manage configuration",
Long: `View and validate configuration.`,
}
cmd.AddCommand(&cobra.Command{
Use: "show",
Short: "Show current configuration",
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Current configuration:")
// TODO: Implement config show logic
return nil
},
})
cmd.AddCommand(&cobra.Command{
Use: "validate",
Short: "Validate configuration file",
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Validating configuration...")
// TODO: Implement config validation logic
return nil
},
})
return cmd
}
func newInitCmd() *cobra.Command {
return &cobra.Command{
Use: "init",
Short: "Initialize a new project",
Long: `Create a new Konduktor project with default configuration.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Initializing new project...")
// TODO: Implement init logic
return nil
},
}
}
func newTopCmd() *cobra.Command {
return &cobra.Command{
Use: "top",
Short: "Display running processes",
Long: `Display real-time view of running processes and resource usage.`,
RunE: func(cmd *cobra.Command, args []string) error {
fmt.Println("Process monitor:")
// TODO: Implement top-like display
return nil
},
}
}

View File

@ -1,18 +0,0 @@
module github.com/konduktor/konduktor
go 1.23.0
toolchain go1.24.2
require (
github.com/spf13/cobra v1.10.2
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.uber.org/zap v1.27.1 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)

View File

@ -1,20 +0,0 @@
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,134 +0,0 @@
package config
import (
"fmt"
"os"
"time"
"gopkg.in/yaml.v3"
)
type Config struct {
HTTP HTTPConfig `yaml:"http"`
Server ServerConfig `yaml:"server"`
SSL SSLConfig `yaml:"ssl"`
Logging LoggingConfig `yaml:"logging"`
Extensions []ExtensionConfig `yaml:"extensions"`
}
type HTTPConfig struct {
StaticDir string `yaml:"static_dir"`
TemplatesDir string `yaml:"templates_dir"`
}
type ServerConfig struct {
Host string `yaml:"host"`
Port int `yaml:"port"`
Backlog int `yaml:"backlog"`
DefaultRoot bool `yaml:"default_root"`
ProxyTimeout time.Duration `yaml:"proxy_timeout"`
RedirectInstructions map[string]string `yaml:"redirect_instructions"`
}
type SSLConfig struct {
Enabled bool `yaml:"enabled"`
CertFile string `yaml:"cert_file"`
KeyFile string `yaml:"key_file"`
}
type LoggingConfig struct {
Level string `yaml:"level"`
ConsoleOutput bool `yaml:"console_output"`
Format LogFormatConfig `yaml:"format"`
Console *ConsoleLogConfig `yaml:"console"`
Files []FileLogConfig `yaml:"files"`
}
type LogFormatConfig struct {
Type string `yaml:"type"`
UseColors bool `yaml:"use_colors"`
ShowModule bool `yaml:"show_module"`
TimestampFormat string `yaml:"timestamp_format"`
}
type ConsoleLogConfig struct {
Format LogFormatConfig `yaml:"format"`
Level string `yaml:"level"`
}
type FileLogConfig struct {
Path string `yaml:"path"`
Level string `yaml:"level"`
Loggers []string `yaml:"loggers"`
Format LogFormatConfig `yaml:"format"`
MaxBytes int64 `yaml:"max_bytes"`
BackupCount int `yaml:"backup_count"`
}
type ExtensionConfig struct {
Type string `yaml:"type"`
Config map[string]interface{} `yaml:"config"`
}
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
cfg := Default()
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("failed to parse config: %w", err)
}
return cfg, nil
}
func Default() *Config {
return &Config{
HTTP: HTTPConfig{
StaticDir: "./static",
TemplatesDir: "./templates",
},
Server: ServerConfig{
Host: "0.0.0.0",
Port: 8080,
Backlog: 5,
DefaultRoot: false,
ProxyTimeout: 30 * time.Second,
},
SSL: SSLConfig{
Enabled: false,
CertFile: "./ssl/cert.pem",
KeyFile: "./ssl/key.pem",
},
Logging: LoggingConfig{
Level: "INFO",
ConsoleOutput: true,
Format: LogFormatConfig{
Type: "standard",
UseColors: true,
ShowModule: true,
TimestampFormat: "2006-01-02 15:04:05",
},
},
Extensions: []ExtensionConfig{},
}
}
func (c *Config) Validate() error {
if c.Server.Port < 1 || c.Server.Port > 65535 {
return fmt.Errorf("invalid port: %d", c.Server.Port)
}
if c.SSL.Enabled {
if c.SSL.CertFile == "" {
return fmt.Errorf("SSL enabled but cert_file not specified")
}
if c.SSL.KeyFile == "" {
return fmt.Errorf("SSL enabled but key_file not specified")
}
}
return nil
}

View File

@ -1,127 +0,0 @@
package config
import (
"os"
"testing"
)
func TestDefault(t *testing.T) {
cfg := Default()
if cfg.Server.Host != "0.0.0.0" {
t.Errorf("Expected host 0.0.0.0, got %s", cfg.Server.Host)
}
if cfg.Server.Port != 8080 {
t.Errorf("Expected port 8080, got %d", cfg.Server.Port)
}
if cfg.SSL.Enabled {
t.Error("Expected SSL to be disabled by default")
}
}
func TestValidate(t *testing.T) {
tests := []struct {
name string
modify func(*Config)
wantErr bool
}{
{
name: "valid default config",
modify: func(c *Config) {},
wantErr: false,
},
{
name: "invalid port - too low",
modify: func(c *Config) {
c.Server.Port = 0
},
wantErr: true,
},
{
name: "invalid port - too high",
modify: func(c *Config) {
c.Server.Port = 70000
},
wantErr: true,
},
{
name: "SSL enabled without cert",
modify: func(c *Config) {
c.SSL.Enabled = true
c.SSL.CertFile = ""
},
wantErr: true,
},
{
name: "SSL enabled without key",
modify: func(c *Config) {
c.SSL.Enabled = true
c.SSL.CertFile = "cert.pem"
c.SSL.KeyFile = ""
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := Default()
tt.modify(cfg)
err := cfg.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestLoad(t *testing.T) {
// Create temporary config file
content := `
server:
host: 127.0.0.1
port: 3000
logging:
level: DEBUG
`
tmpfile, err := os.CreateTemp("", "config-*.yaml")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write([]byte(content)); err != nil {
t.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
}
cfg, err := Load(tmpfile.Name())
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if cfg.Server.Host != "127.0.0.1" {
t.Errorf("Expected host 127.0.0.1, got %s", cfg.Server.Host)
}
if cfg.Server.Port != 3000 {
t.Errorf("Expected port 3000, got %d", cfg.Server.Port)
}
if cfg.Logging.Level != "DEBUG" {
t.Errorf("Expected level DEBUG, got %s", cfg.Logging.Level)
}
}
func TestLoadNotFound(t *testing.T) {
_, err := Load("/nonexistent/config.yaml")
if err == nil {
t.Error("Expected error for non-existent file")
}
}

View File

@ -1,466 +0,0 @@
package extension
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"io"
"net/http"
"regexp"
"strings"
"sync"
"time"
"github.com/konduktor/konduktor/internal/logging"
)
type CachingExtension struct {
BaseExtension
cache map[string]*cacheEntry
cachePatterns []*cachePattern
defaultTTL time.Duration
maxSize int
currentSize int
mu sync.RWMutex
hits int64
misses int64
}
type cacheEntry struct {
key string
body []byte
headers http.Header
statusCode int
contentType string
createdAt time.Time
expiresAt time.Time
size int
}
type cachePattern struct {
pattern *regexp.Regexp
ttl time.Duration
methods []string
}
type CachingConfig struct {
Enabled bool `yaml:"enabled"`
DefaultTTL string `yaml:"default_ttl"` // e.g., "5m", "1h"
MaxSizeMB int `yaml:"max_size_mb"` // Max cache size in MB
CachePatterns []PatternConfig `yaml:"cache_patterns"` // Patterns to cache
}
type PatternConfig struct {
Pattern string `yaml:"pattern"` // Regex pattern
TTL string `yaml:"ttl"` // TTL for this pattern
Methods []string `yaml:"methods"` // HTTP methods to cache (default: GET)
}
func NewCachingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
ext := &CachingExtension{
BaseExtension: NewBaseExtension("caching", 20, logger),
cache: make(map[string]*cacheEntry),
cachePatterns: make([]*cachePattern, 0),
defaultTTL: 5 * time.Minute,
maxSize: 100 * 1024 * 1024, // 100MB default
}
if ttl, ok := config["default_ttl"].(string); ok {
if duration, err := time.ParseDuration(ttl); err == nil {
ext.defaultTTL = duration
}
}
if maxSize, ok := config["max_size_mb"].(int); ok {
ext.maxSize = maxSize * 1024 * 1024
} else if maxSizeFloat, ok := config["max_size_mb"].(float64); ok {
ext.maxSize = int(maxSizeFloat) * 1024 * 1024
}
if patterns, ok := config["cache_patterns"].([]interface{}); ok {
for _, p := range patterns {
if patternCfg, ok := p.(map[string]interface{}); ok {
pattern := &cachePattern{
ttl: ext.defaultTTL,
methods: []string{"GET"},
}
if patternStr, ok := patternCfg["pattern"].(string); ok {
re, err := regexp.Compile(patternStr)
if err != nil {
logger.Error("Invalid cache pattern", "pattern", patternStr, "error", err)
continue
}
pattern.pattern = re
}
if ttl, ok := patternCfg["ttl"].(string); ok {
if duration, err := time.ParseDuration(ttl); err == nil {
pattern.ttl = duration
}
}
if methods, ok := patternCfg["methods"].([]interface{}); ok {
pattern.methods = make([]string, 0)
for _, m := range methods {
if method, ok := m.(string); ok {
pattern.methods = append(pattern.methods, strings.ToUpper(method))
}
}
}
ext.cachePatterns = append(ext.cachePatterns, pattern)
}
}
}
go ext.cleanupLoop()
return ext, nil
}
func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
if !e.shouldCache(r) {
return false, nil
}
key := e.cacheKey(r)
e.mu.RLock()
entry, exists := e.cache[key]
e.mu.RUnlock()
if exists && time.Now().Before(entry.expiresAt) {
e.mu.Lock()
e.hits++
e.mu.Unlock()
// Mark as cache hit to prevent setting X-Cache: MISS
// Try to find cachingResponseWriter in the wrapper chain
setCacheHitFlag(w)
for k, values := range entry.headers {
for _, v := range values {
w.Header().Add(k, v)
}
}
w.Header().Set("X-Cache", "HIT")
w.Header().Set("Content-Type", entry.contentType)
w.WriteHeader(entry.statusCode)
w.Write(entry.body)
e.logger.Debug("Cache hit", "key", key, "path", r.URL.Path)
return true, nil
}
e.mu.Lock()
e.misses++
e.mu.Unlock()
return false, nil
}
// setCacheHitFlag tries to find cachingResponseWriter and set cache hit flag
func setCacheHitFlag(w http.ResponseWriter) {
// Direct match
if cw, ok := w.(*cachingResponseWriter); ok {
cw.SetCacheHit()
return
}
// Try unwrapping
type unwrapper interface {
Unwrap() http.ResponseWriter
}
for {
if u, ok := w.(unwrapper); ok {
w = u.Unwrap()
if cw, ok := w.(*cachingResponseWriter); ok {
cw.SetCacheHit()
return
}
} else {
return
}
}
}
// ProcessResponse caches the response if applicable
func (e *CachingExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
// Response caching is handled by the CachingResponseWriter
// X-Cache header is set in the cachingResponseWriter.WriteHeader
}
// WrapResponseWriter wraps the response writer to capture the response for caching
func (e *CachingExtension) WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter {
if !e.shouldCache(r) {
return w
}
return &cachingResponseWriter{
ResponseWriter: w,
ext: e,
request: r,
buffer: &bytes.Buffer{},
}
}
type cachingResponseWriter struct {
http.ResponseWriter
ext *CachingExtension
request *http.Request
buffer *bytes.Buffer
statusCode int
wroteHeader bool
cacheHit bool // Flag to indicate if this was a cache hit
}
func (cw *cachingResponseWriter) WriteHeader(code int) {
if !cw.wroteHeader {
cw.statusCode = code
cw.wroteHeader = true
// Set X-Cache: MISS header before writing headers (only if not a cache hit)
if !cw.cacheHit {
cw.ResponseWriter.Header().Set("X-Cache", "MISS")
}
cw.ResponseWriter.WriteHeader(code)
}
}
// SetCacheHit marks this response as a cache hit (to avoid setting X-Cache: MISS)
func (cw *cachingResponseWriter) SetCacheHit() {
cw.cacheHit = true
}
func (cw *cachingResponseWriter) Write(b []byte) (int, error) {
if !cw.wroteHeader {
cw.WriteHeader(http.StatusOK)
}
cw.buffer.Write(b)
return cw.ResponseWriter.Write(b)
}
func (cw *cachingResponseWriter) Finalize() {
if cw.statusCode < 200 || cw.statusCode >= 400 {
return
}
body := cw.buffer.Bytes()
if len(body) == 0 {
return
}
key := cw.ext.cacheKey(cw.request)
ttl := cw.ext.getTTL(cw.request)
entry := &cacheEntry{
key: key,
body: body,
headers: cw.Header().Clone(),
statusCode: cw.statusCode,
contentType: cw.Header().Get("Content-Type"),
createdAt: time.Now(),
expiresAt: time.Now().Add(ttl),
size: len(body),
}
cw.ext.store(entry)
}
func (e *CachingExtension) shouldCache(r *http.Request) bool {
path := r.URL.Path
method := r.Method
for _, pattern := range e.cachePatterns {
if pattern.pattern.MatchString(path) {
for _, m := range pattern.methods {
if m == method {
return true
}
}
}
}
// Default: only cache GET requests
return method == "GET" && len(e.cachePatterns) == 0
}
func (e *CachingExtension) cacheKey(r *http.Request) string {
// Create cache key from method + URL + relevant headers
h := sha256.New()
h.Write([]byte(r.Method))
h.Write([]byte(r.URL.String()))
// Include Accept-Encoding for vary
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
h.Write([]byte(ae))
}
return hex.EncodeToString(h.Sum(nil))
}
func (e *CachingExtension) getTTL(r *http.Request) time.Duration {
path := r.URL.Path
for _, pattern := range e.cachePatterns {
if pattern.pattern.MatchString(path) {
return pattern.ttl
}
}
return e.defaultTTL
}
func (e *CachingExtension) store(entry *cacheEntry) {
e.mu.Lock()
defer e.mu.Unlock()
// Evict old entries if needed
for e.currentSize+entry.size > e.maxSize && len(e.cache) > 0 {
e.evictOldest()
}
// Store new entry
if existing, ok := e.cache[entry.key]; ok {
e.currentSize -= existing.size
}
e.cache[entry.key] = entry
e.currentSize += entry.size
e.logger.Debug("Cached response",
"key", entry.key[:16],
"size", entry.size,
"ttl", entry.expiresAt.Sub(entry.createdAt).String())
}
func (e *CachingExtension) evictOldest() {
var oldestKey string
var oldestTime time.Time
for key, entry := range e.cache {
if oldestKey == "" || entry.createdAt.Before(oldestTime) {
oldestKey = key
oldestTime = entry.createdAt
}
}
if oldestKey != "" {
entry := e.cache[oldestKey]
e.currentSize -= entry.size
delete(e.cache, oldestKey)
}
}
func (e *CachingExtension) cleanupLoop() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
e.cleanupExpired()
}
}
func (e *CachingExtension) cleanupExpired() {
e.mu.Lock()
defer e.mu.Unlock()
now := time.Now()
for key, entry := range e.cache {
if now.After(entry.expiresAt) {
e.currentSize -= entry.size
delete(e.cache, key)
}
}
}
func (e *CachingExtension) Invalidate(key string) {
e.mu.Lock()
defer e.mu.Unlock()
if entry, ok := e.cache[key]; ok {
e.currentSize -= entry.size
delete(e.cache, key)
}
}
// InvalidatePattern removes all entries matching a pattern, unlike Invalidate
func (e *CachingExtension) InvalidatePattern(pattern string) error {
re, err := regexp.Compile(pattern)
if err != nil {
return err
}
e.mu.Lock()
defer e.mu.Unlock()
for key, entry := range e.cache {
if re.MatchString(key) {
e.currentSize -= entry.size
delete(e.cache, key)
}
}
return nil
}
// Clear removes all entries from cache
func (e *CachingExtension) Clear() {
e.mu.Lock()
defer e.mu.Unlock()
e.cache = make(map[string]*cacheEntry)
e.currentSize = 0
}
// GetMetrics returns caching metrics
func (e *CachingExtension) GetMetrics() map[string]interface{} {
e.mu.RLock()
defer e.mu.RUnlock()
hitRate := float64(0)
total := e.hits + e.misses
if total > 0 {
hitRate = float64(e.hits) / float64(total) * 100
}
return map[string]interface{}{
"entries": len(e.cache),
"size_bytes": e.currentSize,
"max_size": e.maxSize,
"hits": e.hits,
"misses": e.misses,
"hit_rate": hitRate,
"patterns": len(e.cachePatterns),
"default_ttl": e.defaultTTL.String(),
}
}
// Cleanup stops the cleanup goroutine
func (e *CachingExtension) Cleanup() error {
e.Clear()
return nil
}
// CacheReader wraps an io.ReadCloser to cache the body
type CacheReader struct {
io.ReadCloser
buffer *bytes.Buffer
}
func (cr *CacheReader) Read(p []byte) (int, error) {
n, err := cr.ReadCloser.Read(p)
if n > 0 {
cr.buffer.Write(p[:n])
}
return n, err
}
func (cr *CacheReader) GetBody() []byte {
return cr.buffer.Bytes()
}

View File

@ -1,123 +0,0 @@
package extension
import (
"context"
"net/http"
"github.com/konduktor/konduktor/internal/logging"
)
// Extension is the interface that all extensions must implement
type Extension interface {
// Name returns the unique name of the extension
Name() string
// Initialize is called when the extension is loaded
Initialize() error
// ProcessRequest processes an incoming request before routing.
// Returns:
// - response: if non-nil, the request is handled and no further processing occurs
// - handled: if true, the request was handled by this extension
// - err: any error that occurred
ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (handled bool, err error)
// ProcessResponse is called after the response is generated but before it's sent.
// Extensions can modify the response here.
ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request)
// Cleanup is called when the extension is being unloaded
Cleanup() error
// Enabled returns whether the extension is currently enabled
Enabled() bool
// SetEnabled enables or disables the extension
SetEnabled(enabled bool)
// Priority returns the extension's priority (lower = earlier execution)
Priority() int
}
// BaseExtension provides a default implementation for common Extension methods
type BaseExtension struct {
name string
enabled bool
priority int
logger *logging.Logger
}
// NewBaseExtension creates a new BaseExtension
func NewBaseExtension(name string, priority int, logger *logging.Logger) BaseExtension {
return BaseExtension{
name: name,
enabled: true,
priority: priority,
logger: logger,
}
}
// Name returns the extension name
func (b *BaseExtension) Name() string {
return b.name
}
// Enabled returns whether the extension is enabled
func (b *BaseExtension) Enabled() bool {
return b.enabled
}
// SetEnabled sets the enabled state
func (b *BaseExtension) SetEnabled(enabled bool) {
b.enabled = enabled
}
// Priority returns the extension priority
func (b *BaseExtension) Priority() int {
return b.priority
}
// Initialize default implementation (no-op)
func (b *BaseExtension) Initialize() error {
return nil
}
// Cleanup default implementation (no-op)
func (b *BaseExtension) Cleanup() error {
return nil
}
// ProcessRequest default implementation (pass-through)
func (b *BaseExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
return false, nil
}
// ProcessResponse default implementation (no-op)
func (b *BaseExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
}
// Logger returns the extension's logger
func (b *BaseExtension) Logger() *logging.Logger {
return b.logger
}
// ExtensionConfig holds configuration for creating extensions
type ExtensionConfig struct {
Type string
Config map[string]interface{}
}
// ExtensionFactory is a function that creates an extension from config
type ExtensionFactory func(config map[string]interface{}, logger *logging.Logger) (Extension, error)
// ResponseWriterWrapper is an optional interface that extensions can implement
// to wrap the response writer for capturing/modifying responses
type ResponseWriterWrapper interface {
WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter
}
// ResponseFinalizer is an optional interface for response writers that need
// to perform finalization after the response is written (e.g., caching)
type ResponseFinalizer interface {
Finalize()
}

View File

@ -1,271 +0,0 @@
package extension
import (
"context"
"fmt"
"net/http"
"sort"
"sync"
"github.com/konduktor/konduktor/internal/logging"
)
// Manager manages all loaded extensions
type Manager struct {
extensions []Extension
registry map[string]ExtensionFactory
logger *logging.Logger
mu sync.RWMutex
}
// NewManager creates a new extension manager
func NewManager(logger *logging.Logger) *Manager {
m := &Manager{
extensions: make([]Extension, 0),
registry: make(map[string]ExtensionFactory),
logger: logger,
}
// Register built-in extensions
m.RegisterFactory("routing", NewRoutingExtension)
m.RegisterFactory("security", NewSecurityExtension)
m.RegisterFactory("caching", NewCachingExtension)
return m
}
// RegisterFactory registers an extension factory
func (m *Manager) RegisterFactory(name string, factory ExtensionFactory) {
m.mu.Lock()
defer m.mu.Unlock()
m.registry[name] = factory
}
// LoadExtension loads an extension by type and config
func (m *Manager) LoadExtension(extType string, config map[string]interface{}) error {
m.mu.Lock()
defer m.mu.Unlock()
factory, ok := m.registry[extType]
if !ok {
return fmt.Errorf("unknown extension type: %s", extType)
}
ext, err := factory(config, m.logger)
if err != nil {
return fmt.Errorf("failed to create extension %s: %w", extType, err)
}
if err := ext.Initialize(); err != nil {
return fmt.Errorf("failed to initialize extension %s: %w", extType, err)
}
m.extensions = append(m.extensions, ext)
// Sort by priority (lower first)
sort.Slice(m.extensions, func(i, j int) bool {
return m.extensions[i].Priority() < m.extensions[j].Priority()
})
m.logger.Info("Loaded extension", "type", extType, "name", ext.Name(), "priority", ext.Priority())
return nil
}
// AddExtension adds a pre-created extension
func (m *Manager) AddExtension(ext Extension) error {
m.mu.Lock()
defer m.mu.Unlock()
if err := ext.Initialize(); err != nil {
return fmt.Errorf("failed to initialize extension %s: %w", ext.Name(), err)
}
m.extensions = append(m.extensions, ext)
// Sort by priority
sort.Slice(m.extensions, func(i, j int) bool {
return m.extensions[i].Priority() < m.extensions[j].Priority()
})
m.logger.Info("Added extension", "name", ext.Name(), "priority", ext.Priority())
return nil
}
// ProcessRequest runs all extensions' ProcessRequest in order
// Returns true if any extension handled the request
func (m *Manager) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
m.mu.RLock()
extensions := m.extensions
m.mu.RUnlock()
for _, ext := range extensions {
if !ext.Enabled() {
continue
}
handled, err := ext.ProcessRequest(ctx, w, r)
if err != nil {
m.logger.Error("Extension error", "extension", ext.Name(), "error", err)
// Continue to next extension on error
continue
}
if handled {
return true, nil
}
}
return false, nil
}
// ProcessResponse runs all extensions' ProcessResponse in reverse order
func (m *Manager) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
m.mu.RLock()
extensions := m.extensions
m.mu.RUnlock()
// Process in reverse order for response
for i := len(extensions) - 1; i >= 0; i-- {
ext := extensions[i]
if !ext.Enabled() {
continue
}
ext.ProcessResponse(ctx, w, r)
}
}
// Cleanup cleans up all extensions
func (m *Manager) Cleanup() {
m.mu.Lock()
defer m.mu.Unlock()
for _, ext := range m.extensions {
if err := ext.Cleanup(); err != nil {
m.logger.Error("Extension cleanup error", "extension", ext.Name(), "error", err)
}
}
m.extensions = nil
}
// GetExtension returns an extension by name
func (m *Manager) GetExtension(name string) Extension {
m.mu.RLock()
defer m.mu.RUnlock()
for _, ext := range m.extensions {
if ext.Name() == name {
return ext
}
}
return nil
}
// Extensions returns all loaded extensions
func (m *Manager) Extensions() []Extension {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]Extension, len(m.extensions))
copy(result, m.extensions)
return result
}
// Handler returns an http.Handler that processes requests through all extensions
func (m *Manager) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Wrap response writer through all extensions that support it
// Process in reverse priority order so highest priority wrapper is outermost
wrappedWriter := w
var finalizers []ResponseFinalizer
m.mu.RLock()
extensions := m.extensions
m.mu.RUnlock()
// Wrap response writer (lowest priority first, so they wrap in correct order)
for _, ext := range extensions {
if !ext.Enabled() {
continue
}
if wrapper, ok := ext.(ResponseWriterWrapper); ok {
wrappedWriter = wrapper.WrapResponseWriter(wrappedWriter, r)
// Check if the wrapped writer implements Finalizer
if finalizer, ok := wrappedWriter.(ResponseFinalizer); ok {
finalizers = append(finalizers, finalizer)
}
}
}
// Create response wrapper to capture status code
responseWrapper := newResponseWrapper(wrappedWriter)
// Process request through extensions
handled, err := m.ProcessRequest(ctx, responseWrapper, r)
if err != nil {
m.logger.Error("Error processing request", "error", err)
}
if handled {
// Extension handled the request, process response
m.ProcessResponse(ctx, responseWrapper, r)
// Finalize all response writers
for i := len(finalizers) - 1; i >= 0; i-- {
finalizers[i].Finalize()
}
return
}
// No extension handled, pass to next handler
next.ServeHTTP(responseWrapper, r)
// Process response
m.ProcessResponse(ctx, responseWrapper, r)
// Finalize all response writers
for i := len(finalizers) - 1; i >= 0; i-- {
finalizers[i].Finalize()
}
})
}
// responseWrapper wraps http.ResponseWriter to allow response modification
type responseWrapper struct {
http.ResponseWriter
statusCode int
written bool
}
func newResponseWrapper(w http.ResponseWriter) *responseWrapper {
return &responseWrapper{
ResponseWriter: w,
statusCode: http.StatusOK,
}
}
func (rw *responseWrapper) WriteHeader(code int) {
if !rw.written {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
rw.written = true
}
}
func (rw *responseWrapper) Write(b []byte) (int, error) {
if !rw.written {
rw.WriteHeader(http.StatusOK)
}
return rw.ResponseWriter.Write(b)
}
func (rw *responseWrapper) StatusCode() int {
return rw.statusCode
}
// Unwrap returns the underlying ResponseWriter (for type assertions)
func (rw *responseWrapper) Unwrap() http.ResponseWriter {
return rw.ResponseWriter
}

View File

@ -1,176 +0,0 @@
package extension
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/konduktor/konduktor/internal/logging"
)
func newTestLogger() *logging.Logger {
logger, _ := logging.New(logging.Config{Level: "DEBUG"})
return logger
}
func TestNewManager(t *testing.T) {
logger := newTestLogger()
manager := NewManager(logger)
if manager == nil {
t.Fatal("Expected manager, got nil")
}
// Check built-in factories are registered
if _, ok := manager.registry["routing"]; !ok {
t.Error("Expected routing factory to be registered")
}
if _, ok := manager.registry["security"]; !ok {
t.Error("Expected security factory to be registered")
}
if _, ok := manager.registry["caching"]; !ok {
t.Error("Expected caching factory to be registered")
}
}
func TestManager_LoadExtension(t *testing.T) {
logger := newTestLogger()
manager := NewManager(logger)
err := manager.LoadExtension("security", map[string]interface{}{})
if err != nil {
t.Errorf("Failed to load security extension: %v", err)
}
exts := manager.Extensions()
if len(exts) != 1 {
t.Errorf("Expected 1 extension, got %d", len(exts))
}
}
func TestManager_LoadExtension_Unknown(t *testing.T) {
logger := newTestLogger()
manager := NewManager(logger)
err := manager.LoadExtension("unknown", map[string]interface{}{})
if err == nil {
t.Error("Expected error for unknown extension type")
}
}
func TestManager_GetExtension(t *testing.T) {
logger := newTestLogger()
manager := NewManager(logger)
manager.LoadExtension("security", map[string]interface{}{})
ext := manager.GetExtension("security")
if ext == nil {
t.Error("Expected to find security extension")
}
ext = manager.GetExtension("nonexistent")
if ext != nil {
t.Error("Expected nil for nonexistent extension")
}
}
func TestManager_ProcessRequest(t *testing.T) {
logger := newTestLogger()
manager := NewManager(logger)
// Load security extension with blocked IP
manager.LoadExtension("security", map[string]interface{}{
"blocked_ips": []interface{}{"192.168.1.1"},
})
// Create test request
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handled, err := manager.ProcessRequest(context.Background(), rr, req)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !handled {
t.Error("Expected request to be handled (blocked)")
}
if rr.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d", rr.Code)
}
}
func TestManager_Handler(t *testing.T) {
logger := newTestLogger()
manager := NewManager(logger)
// Load routing extension with a simple route
manager.LoadExtension("routing", map[string]interface{}{
"regex_locations": map[string]interface{}{
"=/health": map[string]interface{}{
"return": "200 OK",
},
},
})
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
handler := manager.Handler(baseHandler)
// Test health route
req := httptest.NewRequest("GET", "/health", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
}
func TestManager_Priority(t *testing.T) {
logger := newTestLogger()
manager := NewManager(logger)
// Load extensions in any order
manager.LoadExtension("routing", map[string]interface{}{}) // Priority 50
manager.LoadExtension("security", map[string]interface{}{}) // Priority 10
manager.LoadExtension("caching", map[string]interface{}{}) // Priority 20
exts := manager.Extensions()
if len(exts) != 3 {
t.Fatalf("Expected 3 extensions, got %d", len(exts))
}
// Check order by priority
if exts[0].Name() != "security" {
t.Errorf("Expected security first, got %s", exts[0].Name())
}
if exts[1].Name() != "caching" {
t.Errorf("Expected caching second, got %s", exts[1].Name())
}
if exts[2].Name() != "routing" {
t.Errorf("Expected routing third, got %s", exts[2].Name())
}
}
func TestManager_Cleanup(t *testing.T) {
logger := newTestLogger()
manager := NewManager(logger)
manager.LoadExtension("security", map[string]interface{}{})
manager.LoadExtension("routing", map[string]interface{}{})
manager.Cleanup()
exts := manager.Extensions()
if len(exts) != 0 {
t.Errorf("Expected 0 extensions after cleanup, got %d", len(exts))
}
}

View File

@ -1,428 +0,0 @@
package extension
import (
"context"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/konduktor/konduktor/internal/logging"
"github.com/konduktor/konduktor/internal/proxy"
)
// RoutingExtension handles request routing based on patterns
type RoutingExtension struct {
BaseExtension
exactRoutes map[string]RouteConfig
regexRoutes []*regexRoute
defaultRoute *RouteConfig
staticDir string
}
// RouteConfig holds configuration for a route
type RouteConfig struct {
ProxyPass string
Root string
IndexFile string
Return string
ContentType string
Headers []string
CacheControl string
SPAFallback bool
ExcludePatterns []string
Timeout float64
}
type regexRoute struct {
pattern *regexp.Regexp
config RouteConfig
caseSensitive bool
originalExpr string
}
// NewRoutingExtension creates a new routing extension
func NewRoutingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
ext := &RoutingExtension{
BaseExtension: NewBaseExtension("routing", 50, logger), // Middle priority
exactRoutes: make(map[string]RouteConfig),
regexRoutes: make([]*regexRoute, 0),
staticDir: "./static",
}
logger.Debug("Routing extension config", "config", config)
// Parse regex_locations from config
if locations, ok := config["regex_locations"].(map[string]interface{}); ok {
logger.Debug("Found regex_locations", "count", len(locations))
for pattern, routeCfg := range locations {
logger.Debug("Adding route", "pattern", pattern)
if rc, ok := routeCfg.(map[string]interface{}); ok {
ext.addRoute(pattern, parseRouteConfig(rc))
}
}
} else {
logger.Warn("No regex_locations found in config", "config_keys", getKeys(config))
}
// Parse static_dir if provided
if staticDir, ok := config["static_dir"].(string); ok {
ext.staticDir = staticDir
}
return ext, nil
}
func parseRouteConfig(cfg map[string]interface{}) RouteConfig {
rc := RouteConfig{
IndexFile: "index.html",
}
if v, ok := cfg["proxy_pass"].(string); ok {
rc.ProxyPass = v
}
if v, ok := cfg["root"].(string); ok {
rc.Root = v
}
if v, ok := cfg["index_file"].(string); ok {
rc.IndexFile = v
}
if v, ok := cfg["return"].(string); ok {
rc.Return = v
}
if v, ok := cfg["content_type"].(string); ok {
rc.ContentType = v
}
if v, ok := cfg["cache_control"].(string); ok {
rc.CacheControl = v
}
if v, ok := cfg["spa_fallback"].(bool); ok {
rc.SPAFallback = v
}
if v, ok := cfg["timeout"].(float64); ok {
rc.Timeout = v
}
// Parse headers
if headers, ok := cfg["headers"].([]interface{}); ok {
for _, h := range headers {
if header, ok := h.(string); ok {
rc.Headers = append(rc.Headers, header)
}
}
}
// Parse exclude_patterns
if patterns, ok := cfg["exclude_patterns"].([]interface{}); ok {
for _, p := range patterns {
if pattern, ok := p.(string); ok {
rc.ExcludePatterns = append(rc.ExcludePatterns, pattern)
}
}
}
return rc
}
func (e *RoutingExtension) addRoute(pattern string, config RouteConfig) {
switch {
case pattern == "__default__":
e.defaultRoute = &config
case strings.HasPrefix(pattern, "="):
// Exact match
path := strings.TrimPrefix(pattern, "=")
e.exactRoutes[path] = config
case strings.HasPrefix(pattern, "~*"):
// Case-insensitive regex
expr := strings.TrimPrefix(pattern, "~*")
re, err := regexp.Compile("(?i)" + expr)
if err != nil {
e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
return
}
e.regexRoutes = append(e.regexRoutes, &regexRoute{
pattern: re,
config: config,
caseSensitive: false,
originalExpr: expr,
})
case strings.HasPrefix(pattern, "~"):
// Case-sensitive regex
expr := strings.TrimPrefix(pattern, "~")
re, err := regexp.Compile(expr)
if err != nil {
e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
return
}
e.regexRoutes = append(e.regexRoutes, &regexRoute{
pattern: re,
config: config,
caseSensitive: true,
originalExpr: expr,
})
}
}
// ProcessRequest handles the request routing
func (e *RoutingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
path := r.URL.Path
// 1. Check exact routes (ignore request path for proxy)
if config, ok := e.exactRoutes[path]; ok {
return e.handleRoute(w, r, config, nil, true)
}
// 2. Check regex routes
for _, route := range e.regexRoutes {
match := route.pattern.FindStringSubmatch(path)
if match != nil {
params := make(map[string]string)
names := route.pattern.SubexpNames()
for i, name := range names {
if i > 0 && name != "" && i < len(match) {
params[name] = match[i]
}
}
return e.handleRoute(w, r, route.config, params, false)
}
}
// 3. Check default route
if e.defaultRoute != nil {
return e.handleRoute(w, r, *e.defaultRoute, nil, false)
}
return false, nil
}
func (e *RoutingExtension) handleRoute(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
// Handle "return" directive
if config.Return != "" {
return e.handleReturn(w, config)
}
// Handle proxy_pass
if config.ProxyPass != "" {
return e.handleProxy(w, r, config, params, exactMatch)
}
// Handle static files with root
if config.Root != "" {
return e.handleStatic(w, r, config)
}
// Handle SPA fallback
if config.SPAFallback {
return e.handleSPAFallback(w, r, config)
}
return false, nil
}
func (e *RoutingExtension) handleReturn(w http.ResponseWriter, config RouteConfig) (bool, error) {
parts := strings.SplitN(config.Return, " ", 2)
statusCode := 200
body := "OK"
if len(parts) >= 1 {
switch parts[0] {
case "200":
statusCode = 200
case "201":
statusCode = 201
case "301":
statusCode = 301
case "302":
statusCode = 302
case "400":
statusCode = 400
case "404":
statusCode = 404
case "500":
statusCode = 500
}
}
if len(parts) >= 2 {
body = parts[1]
}
contentType := "text/plain"
if config.ContentType != "" {
contentType = config.ContentType
}
w.Header().Set("Content-Type", contentType)
w.WriteHeader(statusCode)
w.Write([]byte(body))
return true, nil
}
func (e *RoutingExtension) handleProxy(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
target := config.ProxyPass
// Check if target URL contains parameter placeholders
hasParams := strings.Contains(target, "{") && strings.Contains(target, "}")
// Substitute params in target URL
for key, value := range params {
target = strings.ReplaceAll(target, "{"+key+"}", value)
}
// Create proxy config
// IgnoreRequestPath=true when:
// - exact match route (=/path)
// - target URL had parameter substitutions (the target path is fully specified)
proxyConfig := &proxy.Config{
Target: target,
Headers: make(map[string]string),
IgnoreRequestPath: exactMatch || hasParams,
}
// Set timeout if specified
if config.Timeout > 0 {
proxyConfig.Timeout = time.Duration(config.Timeout * float64(time.Second))
}
// Parse headers
clientIP := getClientIP(r)
for _, header := range config.Headers {
parts := strings.SplitN(header, ": ", 2)
if len(parts) == 2 {
value := parts[1]
// Substitute params
for key, pValue := range params {
value = strings.ReplaceAll(value, "{"+key+"}", pValue)
}
// Substitute special variables
value = strings.ReplaceAll(value, "$remote_addr", clientIP)
proxyConfig.Headers[parts[0]] = value
}
}
p, err := proxy.New(proxyConfig, e.logger)
if err != nil {
e.logger.Error("Failed to create proxy", "target", target, "error", err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return true, nil
}
p.ProxyRequest(w, r, params)
return true, nil
}
func (e *RoutingExtension) handleStatic(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
path := r.URL.Path
// Handle index file for root or directory paths
if path == "/" || strings.HasSuffix(path, "/") {
path = "/" + config.IndexFile
}
// Get absolute path for root dir
absRoot, err := filepath.Abs(config.Root)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return true, nil
}
filePath := filepath.Join(absRoot, filepath.Clean("/"+path))
cleanPath := filepath.Clean(filePath)
// Prevent directory traversal
if !strings.HasPrefix(cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator)) {
http.Error(w, "Forbidden", http.StatusForbidden)
return true, nil
}
// Check if file exists
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return false, nil // Let other handlers try
}
// Set cache control header
if config.CacheControl != "" {
w.Header().Set("Cache-Control", config.CacheControl)
}
// Set custom headers
for _, header := range config.Headers {
parts := strings.SplitN(header, ": ", 2)
if len(parts) == 2 {
w.Header().Set(parts[0], parts[1])
}
}
http.ServeFile(w, r, filePath)
return true, nil
}
func (e *RoutingExtension) handleSPAFallback(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
path := r.URL.Path
// Check exclude patterns
for _, pattern := range config.ExcludePatterns {
if strings.HasPrefix(path, pattern) {
return false, nil
}
}
root := config.Root
if root == "" {
root = e.staticDir
}
indexFile := config.IndexFile
if indexFile == "" {
indexFile = "index.html"
}
filePath := filepath.Join(root, indexFile)
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return false, nil
}
http.ServeFile(w, r, filePath)
return true, nil
}
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header first
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
parts := strings.Split(xff, ",")
return strings.TrimSpace(parts[0])
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// Fall back to RemoteAddr
ip := r.RemoteAddr
if idx := strings.LastIndex(ip, ":"); idx != -1 {
ip = ip[:idx]
}
return ip
}
// GetMetrics returns routing metrics
func (e *RoutingExtension) GetMetrics() map[string]interface{} {
return map[string]interface{}{
"exact_routes": len(e.exactRoutes),
"regex_routes": len(e.regexRoutes),
"has_default": e.defaultRoute != nil,
}
}
func getKeys(m map[string]interface{}) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}

View File

@ -1,312 +0,0 @@
package extension
import (
"context"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/konduktor/konduktor/internal/logging"
)
// SecurityExtension provides security features like IP filtering and security headers
type SecurityExtension struct {
BaseExtension
allowedIPs map[string]bool
blockedIPs map[string]bool
allowedCIDRs []*net.IPNet
blockedCIDRs []*net.IPNet
securityHeaders map[string]string
// Rate limiting
rateLimitEnabled bool
rateLimitRequests int
rateLimitWindow time.Duration
rateLimitByIP map[string]*rateLimitEntry
rateLimitMu sync.RWMutex
}
type rateLimitEntry struct {
count int
resetTime time.Time
}
// SecurityConfig holds security extension configuration
type SecurityConfig struct {
AllowedIPs []string `yaml:"allowed_ips"`
BlockedIPs []string `yaml:"blocked_ips"`
SecurityHeaders map[string]string `yaml:"security_headers"`
RateLimit *RateLimitConfig `yaml:"rate_limit"`
}
// RateLimitConfig holds rate limiting configuration
type RateLimitConfig struct {
Enabled bool `yaml:"enabled"`
Requests int `yaml:"requests"`
Window string `yaml:"window"` // e.g., "1m", "1h"
}
// Default security headers
var defaultSecurityHeaders = map[string]string{
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Referrer-Policy": "strict-origin-when-cross-origin",
}
// NewSecurityExtension creates a new security extension
func NewSecurityExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
ext := &SecurityExtension{
BaseExtension: NewBaseExtension("security", 10, logger), // High priority (early execution)
allowedIPs: make(map[string]bool),
blockedIPs: make(map[string]bool),
allowedCIDRs: make([]*net.IPNet, 0),
blockedCIDRs: make([]*net.IPNet, 0),
securityHeaders: make(map[string]string),
rateLimitByIP: make(map[string]*rateLimitEntry),
}
// Copy default security headers
for k, v := range defaultSecurityHeaders {
ext.securityHeaders[k] = v
}
// Parse allowed_ips
if allowedIPs, ok := config["allowed_ips"].([]interface{}); ok {
for _, ip := range allowedIPs {
if ipStr, ok := ip.(string); ok {
ext.addAllowedIP(ipStr)
}
}
}
// Parse blocked_ips
if blockedIPs, ok := config["blocked_ips"].([]interface{}); ok {
for _, ip := range blockedIPs {
if ipStr, ok := ip.(string); ok {
ext.addBlockedIP(ipStr)
}
}
}
// Parse security_headers
if headers, ok := config["security_headers"].(map[string]interface{}); ok {
for k, v := range headers {
if vStr, ok := v.(string); ok {
ext.securityHeaders[k] = vStr
}
}
}
// Parse rate_limit
if rateLimit, ok := config["rate_limit"].(map[string]interface{}); ok {
if enabled, ok := rateLimit["enabled"].(bool); ok && enabled {
ext.rateLimitEnabled = true
if requests, ok := rateLimit["requests"].(int); ok {
ext.rateLimitRequests = requests
} else if requestsFloat, ok := rateLimit["requests"].(float64); ok {
ext.rateLimitRequests = int(requestsFloat)
} else {
ext.rateLimitRequests = 100 // default
}
if window, ok := rateLimit["window"].(string); ok {
if duration, err := time.ParseDuration(window); err == nil {
ext.rateLimitWindow = duration
} else {
ext.rateLimitWindow = time.Minute // default
}
} else {
ext.rateLimitWindow = time.Minute
}
logger.Info("Rate limiting enabled",
"requests", ext.rateLimitRequests,
"window", ext.rateLimitWindow.String())
}
}
return ext, nil
}
func (e *SecurityExtension) addAllowedIP(ip string) {
if strings.Contains(ip, "/") {
// CIDR notation
_, cidr, err := net.ParseCIDR(ip)
if err == nil {
e.allowedCIDRs = append(e.allowedCIDRs, cidr)
}
} else {
e.allowedIPs[ip] = true
}
}
func (e *SecurityExtension) addBlockedIP(ip string) {
if strings.Contains(ip, "/") {
// CIDR notation
_, cidr, err := net.ParseCIDR(ip)
if err == nil {
e.blockedCIDRs = append(e.blockedCIDRs, cidr)
}
} else {
e.blockedIPs[ip] = true
}
}
// ProcessRequest checks security rules
func (e *SecurityExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
clientIP := getClientIP(r)
parsedIP := net.ParseIP(clientIP)
// Check blocked IPs first
if e.isBlocked(clientIP, parsedIP) {
e.logger.Warn("Blocked request from IP", "ip", clientIP)
http.Error(w, "403 Forbidden", http.StatusForbidden)
return true, nil
}
// Check allowed IPs (if configured, only these IPs are allowed)
if len(e.allowedIPs) > 0 || len(e.allowedCIDRs) > 0 {
if !e.isAllowed(clientIP, parsedIP) {
e.logger.Warn("Access denied for IP", "ip", clientIP)
http.Error(w, "403 Forbidden", http.StatusForbidden)
return true, nil
}
}
// Check rate limit
if e.rateLimitEnabled {
if !e.checkRateLimit(clientIP) {
e.logger.Warn("Rate limit exceeded", "ip", clientIP)
w.Header().Set("Retry-After", "60")
http.Error(w, "429 Too Many Requests", http.StatusTooManyRequests)
return true, nil
}
}
return false, nil
}
// ProcessResponse adds security headers to the response
func (e *SecurityExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
for header, value := range e.securityHeaders {
w.Header().Set(header, value)
}
}
func (e *SecurityExtension) isBlocked(ip string, parsedIP net.IP) bool {
// Check exact match
if e.blockedIPs[ip] {
return true
}
// Check CIDR ranges
if parsedIP != nil {
for _, cidr := range e.blockedCIDRs {
if cidr.Contains(parsedIP) {
return true
}
}
}
return false
}
func (e *SecurityExtension) isAllowed(ip string, parsedIP net.IP) bool {
// Check exact match
if e.allowedIPs[ip] {
return true
}
// Check CIDR ranges
if parsedIP != nil {
for _, cidr := range e.allowedCIDRs {
if cidr.Contains(parsedIP) {
return true
}
}
}
return false
}
func (e *SecurityExtension) checkRateLimit(ip string) bool {
e.rateLimitMu.Lock()
defer e.rateLimitMu.Unlock()
now := time.Now()
entry, exists := e.rateLimitByIP[ip]
if !exists || now.After(entry.resetTime) {
// Create new entry or reset expired one
e.rateLimitByIP[ip] = &rateLimitEntry{
count: 1,
resetTime: now.Add(e.rateLimitWindow),
}
return true
}
// Increment counter
entry.count++
return entry.count <= e.rateLimitRequests
}
// AddBlockedIP adds an IP to the blocked list at runtime
func (e *SecurityExtension) AddBlockedIP(ip string) {
e.addBlockedIP(ip)
}
// RemoveBlockedIP removes an IP from the blocked list
func (e *SecurityExtension) RemoveBlockedIP(ip string) {
delete(e.blockedIPs, ip)
}
// AddAllowedIP adds an IP to the allowed list at runtime
func (e *SecurityExtension) AddAllowedIP(ip string) {
e.addAllowedIP(ip)
}
// RemoveAllowedIP removes an IP from the allowed list
func (e *SecurityExtension) RemoveAllowedIP(ip string) {
delete(e.allowedIPs, ip)
}
// SetSecurityHeader sets or updates a security header
func (e *SecurityExtension) SetSecurityHeader(name, value string) {
e.securityHeaders[name] = value
}
// GetMetrics returns security metrics
func (e *SecurityExtension) GetMetrics() map[string]interface{} {
e.rateLimitMu.RLock()
activeRateLimits := len(e.rateLimitByIP)
e.rateLimitMu.RUnlock()
return map[string]interface{}{
"allowed_ips": len(e.allowedIPs),
"allowed_cidrs": len(e.allowedCIDRs),
"blocked_ips": len(e.blockedIPs),
"blocked_cidrs": len(e.blockedCIDRs),
"security_headers": len(e.securityHeaders),
"rate_limit_enabled": e.rateLimitEnabled,
"active_rate_limits": activeRateLimits,
}
}
// Cleanup cleans up rate limit entries periodically
func (e *SecurityExtension) Cleanup() error {
e.rateLimitMu.Lock()
defer e.rateLimitMu.Unlock()
now := time.Now()
for ip, entry := range e.rateLimitByIP {
if now.After(entry.resetTime) {
delete(e.rateLimitByIP, ip)
}
}
return nil
}

View File

@ -1,213 +0,0 @@
package extension
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
func TestNewSecurityExtension(t *testing.T) {
logger := newTestLogger()
ext, err := NewSecurityExtension(map[string]interface{}{}, logger)
if err != nil {
t.Fatalf("Failed to create security extension: %v", err)
}
if ext.Name() != "security" {
t.Errorf("Expected name 'security', got %s", ext.Name())
}
if ext.Priority() != 10 {
t.Errorf("Expected priority 10, got %d", ext.Priority())
}
}
func TestSecurityExtension_BlockedIP(t *testing.T) {
logger := newTestLogger()
ext, _ := NewSecurityExtension(map[string]interface{}{
"blocked_ips": []interface{}{"192.168.1.100"},
}, logger)
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
rr := httptest.NewRecorder()
handled, err := ext.ProcessRequest(context.Background(), rr, req)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !handled {
t.Error("Expected blocked request to be handled")
}
if rr.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d", rr.Code)
}
}
func TestSecurityExtension_AllowedIP(t *testing.T) {
logger := newTestLogger()
ext, _ := NewSecurityExtension(map[string]interface{}{
"allowed_ips": []interface{}{"192.168.1.50"},
}, logger)
// Allowed IP
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.50:12345"
rr := httptest.NewRecorder()
handled, _ := ext.ProcessRequest(context.Background(), rr, req)
if handled {
t.Error("Expected allowed IP to pass through")
}
// Not allowed IP
req = httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.51:12345"
rr = httptest.NewRecorder()
handled, _ = ext.ProcessRequest(context.Background(), rr, req)
if !handled {
t.Error("Expected non-allowed IP to be blocked")
}
if rr.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d", rr.Code)
}
}
func TestSecurityExtension_CIDR(t *testing.T) {
logger := newTestLogger()
ext, _ := NewSecurityExtension(map[string]interface{}{
"blocked_ips": []interface{}{"10.0.0.0/8"},
}, logger)
// IP in blocked CIDR
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "10.1.2.3:12345"
rr := httptest.NewRecorder()
handled, _ := ext.ProcessRequest(context.Background(), rr, req)
if !handled {
t.Error("Expected IP in blocked CIDR to be blocked")
}
// IP not in blocked CIDR
req = httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr = httptest.NewRecorder()
handled, _ = ext.ProcessRequest(context.Background(), rr, req)
if handled {
t.Error("Expected IP not in blocked CIDR to pass through")
}
}
func TestSecurityExtension_SecurityHeaders(t *testing.T) {
logger := newTestLogger()
ext, _ := NewSecurityExtension(map[string]interface{}{
"security_headers": map[string]interface{}{
"X-Custom-Header": "custom-value",
},
}, logger)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
ext.ProcessResponse(context.Background(), rr, req)
// Check default headers
if rr.Header().Get("X-Content-Type-Options") != "nosniff" {
t.Error("Expected X-Content-Type-Options header")
}
// Check custom header
if rr.Header().Get("X-Custom-Header") != "custom-value" {
t.Error("Expected custom header")
}
}
func TestSecurityExtension_RateLimit(t *testing.T) {
logger := newTestLogger()
ext, _ := NewSecurityExtension(map[string]interface{}{
"rate_limit": map[string]interface{}{
"enabled": true,
"requests": 2,
"window": "1m",
},
}, logger)
securityExt := ext.(*SecurityExtension)
clientIP := "192.168.1.1"
// First request - should pass
if !securityExt.checkRateLimit(clientIP) {
t.Error("First request should pass")
}
// Second request - should pass
if !securityExt.checkRateLimit(clientIP) {
t.Error("Second request should pass")
}
// Third request - should be rate limited
if securityExt.checkRateLimit(clientIP) {
t.Error("Third request should be rate limited")
}
}
func TestSecurityExtension_GetMetrics(t *testing.T) {
logger := newTestLogger()
ext, _ := NewSecurityExtension(map[string]interface{}{
"blocked_ips": []interface{}{"192.168.1.1"},
"allowed_ips": []interface{}{"192.168.1.2"},
}, logger)
securityExt := ext.(*SecurityExtension)
metrics := securityExt.GetMetrics()
if metrics["blocked_ips"].(int) != 1 {
t.Errorf("Expected 1 blocked IP, got %v", metrics["blocked_ips"])
}
if metrics["allowed_ips"].(int) != 1 {
t.Errorf("Expected 1 allowed IP, got %v", metrics["allowed_ips"])
}
}
func TestSecurityExtension_AddRemoveIPs(t *testing.T) {
logger := newTestLogger()
ext, _ := NewSecurityExtension(map[string]interface{}{}, logger)
securityExt := ext.(*SecurityExtension)
// Add blocked IP
securityExt.AddBlockedIP("192.168.1.100")
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
rr := httptest.NewRecorder()
handled, _ := ext.ProcessRequest(context.Background(), rr, req)
if !handled {
t.Error("Expected dynamically blocked IP to be blocked")
}
// Remove blocked IP
securityExt.RemoveBlockedIP("192.168.1.100")
rr = httptest.NewRecorder()
handled, _ = ext.ProcessRequest(context.Background(), rr, req)
if handled {
t.Error("Expected removed blocked IP to pass through")
}
}

View File

@ -1,338 +0,0 @@
// Package logging provides structured logging with zap
package logging
import (
"fmt"
"os"
"path/filepath"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
"github.com/konduktor/konduktor/internal/config"
)
// Config is a simple configuration for basic logger setup
type Config struct {
Level string
TimestampFormat string
}
// Logger wraps zap.SugaredLogger with additional functionality
type Logger struct {
*zap.SugaredLogger
zap *zap.Logger
config *config.LoggingConfig
name string
}
// New creates a new Logger with basic configuration
func New(cfg Config) (*Logger, error) {
level := parseLevel(cfg.Level)
timestampFormat := cfg.TimestampFormat
if timestampFormat == "" {
timestampFormat = "2006-01-02 15:04:05"
}
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = "timestamp"
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(timestampFormat)
encoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
core := zapcore.NewCore(
zapcore.NewConsoleEncoder(encoderConfig),
zapcore.AddSync(os.Stdout),
level,
)
zapLogger := zap.New(core)
return &Logger{
SugaredLogger: zapLogger.Sugar(),
zap: zapLogger,
name: "konduktor",
}, nil
}
// NewFromConfig creates a Logger from full LoggingConfig
func NewFromConfig(cfg config.LoggingConfig) (*Logger, error) {
var cores []zapcore.Core
// Parse main level
mainLevel := parseLevel(cfg.Level)
// Add console core if enabled
if cfg.ConsoleOutput {
consoleLevel := mainLevel
if cfg.Console != nil && cfg.Console.Level != "" {
consoleLevel = parseLevel(cfg.Console.Level)
}
var consoleEncoder zapcore.Encoder
formatConfig := cfg.Format
if cfg.Console != nil {
formatConfig = mergeFormatConfig(cfg.Format, cfg.Console.Format)
}
encoderCfg := createEncoderConfig(formatConfig)
if formatConfig.Type == "json" {
consoleEncoder = zapcore.NewJSONEncoder(encoderCfg)
} else {
if formatConfig.UseColors {
encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder
}
consoleEncoder = zapcore.NewConsoleEncoder(encoderCfg)
}
consoleSyncer := zapcore.AddSync(os.Stdout)
cores = append(cores, zapcore.NewCore(consoleEncoder, consoleSyncer, consoleLevel))
}
// Add file cores
for _, fileConfig := range cfg.Files {
fileCore, err := createFileCore(fileConfig, cfg.Format, mainLevel)
if err != nil {
return nil, fmt.Errorf("failed to create file logger for %s: %w", fileConfig.Path, err)
}
// If specific loggers are configured, wrap with filter
if len(fileConfig.Loggers) > 0 {
fileCore = &filteredCore{
Core: fileCore,
loggers: fileConfig.Loggers,
}
}
cores = append(cores, fileCore)
}
// If no cores configured, add default console
if len(cores) == 0 {
encoderCfg := zap.NewProductionEncoderConfig()
encoderCfg.EncodeTime = zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05")
encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder
cores = append(cores, zapcore.NewCore(
zapcore.NewConsoleEncoder(encoderCfg),
zapcore.AddSync(os.Stdout),
mainLevel,
))
}
// Combine all cores
core := zapcore.NewTee(cores...)
zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1))
return &Logger{
SugaredLogger: zapLogger.Sugar(),
zap: zapLogger,
config: &cfg,
name: "konduktor",
}, nil
}
// Named returns a logger with a specific name (for filtering)
func (l *Logger) Named(name string) *Logger {
return &Logger{
SugaredLogger: l.SugaredLogger.Named(name),
zap: l.zap.Named(name),
config: l.config,
name: name,
}
}
// With returns a logger with additional fields
func (l *Logger) With(args ...interface{}) *Logger {
return &Logger{
SugaredLogger: l.SugaredLogger.With(args...),
zap: l.zap.Sugar().With(args...).Desugar(),
config: l.config,
name: l.name,
}
}
// Sync flushes any buffered log entries
func (l *Logger) Sync() error {
return l.zap.Sync()
}
// GetZap returns the underlying zap.Logger
func (l *Logger) GetZap() *zap.Logger {
return l.zap
}
// Debug logs a debug message
func (l *Logger) Debug(msg string, keysAndValues ...interface{}) {
l.SugaredLogger.Debugw(msg, keysAndValues...)
}
// Info logs an info message
func (l *Logger) Info(msg string, keysAndValues ...interface{}) {
l.SugaredLogger.Infow(msg, keysAndValues...)
}
// Warn logs a warning message
func (l *Logger) Warn(msg string, keysAndValues ...interface{}) {
l.SugaredLogger.Warnw(msg, keysAndValues...)
}
// Error logs an error message
func (l *Logger) Error(msg string, keysAndValues ...interface{}) {
l.SugaredLogger.Errorw(msg, keysAndValues...)
}
// Fatal logs a fatal message and exits
func (l *Logger) Fatal(msg string, keysAndValues ...interface{}) {
l.SugaredLogger.Fatalw(msg, keysAndValues...)
}
// --- Helper functions ---
func parseLevel(level string) zapcore.Level {
switch strings.ToUpper(level) {
case "DEBUG":
return zapcore.DebugLevel
case "INFO":
return zapcore.InfoLevel
case "WARN", "WARNING":
return zapcore.WarnLevel
case "ERROR":
return zapcore.ErrorLevel
case "CRITICAL", "FATAL":
return zapcore.FatalLevel
default:
return zapcore.InfoLevel
}
}
func createEncoderConfig(format config.LogFormatConfig) zapcore.EncoderConfig {
timestampFormat := format.TimestampFormat
if timestampFormat == "" {
timestampFormat = "2006-01-02 15:04:05"
}
cfg := zapcore.EncoderConfig{
TimeKey: "timestamp",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
FunctionKey: zapcore.OmitKey,
MessageKey: "msg",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.CapitalLevelEncoder,
EncodeTime: zapcore.TimeEncoderOfLayout(timestampFormat),
EncodeDuration: zapcore.SecondsDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
if !format.ShowModule {
cfg.NameKey = zapcore.OmitKey
}
return cfg
}
func mergeFormatConfig(base, override config.LogFormatConfig) config.LogFormatConfig {
result := base
if override.Type != "" {
result.Type = override.Type
}
if override.TimestampFormat != "" {
result.TimestampFormat = override.TimestampFormat
}
// UseColors and ShowModule are bool - check if override has non-default
result.UseColors = override.UseColors
result.ShowModule = override.ShowModule
return result
}
func createFileCore(fileConfig config.FileLogConfig, defaultFormat config.LogFormatConfig, defaultLevel zapcore.Level) (zapcore.Core, error) {
// Ensure directory exists
dir := filepath.Dir(fileConfig.Path)
if dir != "" && dir != "." {
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory %s: %w", dir, err)
}
}
// Configure log rotation with lumberjack
maxSize := 10 // MB
if fileConfig.MaxBytes > 0 {
maxSize = int(fileConfig.MaxBytes / (1024 * 1024))
if maxSize < 1 {
maxSize = 1
}
}
backupCount := 5
if fileConfig.BackupCount > 0 {
backupCount = fileConfig.BackupCount
}
rotator := &lumberjack.Logger{
Filename: fileConfig.Path,
MaxSize: maxSize,
MaxBackups: backupCount,
MaxAge: 30, // days
Compress: true,
}
// Determine level
level := defaultLevel
if fileConfig.Level != "" {
level = parseLevel(fileConfig.Level)
}
// Create encoder
format := defaultFormat
if fileConfig.Format.Type != "" {
format = mergeFormatConfig(defaultFormat, fileConfig.Format)
}
// Files should not use colors
format.UseColors = false
encoderConfig := createEncoderConfig(format)
var encoder zapcore.Encoder
if format.Type == "json" {
encoder = zapcore.NewJSONEncoder(encoderConfig)
} else {
encoder = zapcore.NewConsoleEncoder(encoderConfig)
}
return zapcore.NewCore(encoder, zapcore.AddSync(rotator), level), nil
}
// filteredCore wraps a Core to filter by logger name
type filteredCore struct {
zapcore.Core
loggers []string
}
func (c *filteredCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
if !c.shouldLog(entry.LoggerName) {
return ce
}
return c.Core.Check(entry, ce)
}
func (c *filteredCore) shouldLog(loggerName string) bool {
if len(c.loggers) == 0 {
return true
}
for _, allowed := range c.loggers {
if loggerName == allowed || strings.HasPrefix(loggerName, allowed+".") {
return true
}
}
return false
}
func (c *filteredCore) With(fields []zapcore.Field) zapcore.Core {
return &filteredCore{
Core: c.Core.With(fields),
loggers: c.loggers,
}
}

View File

@ -1,209 +0,0 @@
package logging
import (
"testing"
"github.com/konduktor/konduktor/internal/config"
)
func TestNew(t *testing.T) {
logger, err := New(Config{Level: "INFO"})
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if logger == nil {
t.Fatal("Expected logger, got nil")
}
if logger.name != "konduktor" {
t.Errorf("Expected name konduktor, got %s", logger.name)
}
}
func TestNew_DefaultTimestampFormat(t *testing.T) {
logger, err := New(Config{Level: "DEBUG"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Logger should be created successfully
if logger == nil {
t.Fatal("Expected logger, got nil")
}
}
func TestNew_CustomTimestampFormat(t *testing.T) {
logger, err := New(Config{
Level: "DEBUG",
TimestampFormat: "15:04:05",
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if logger == nil {
t.Fatal("Expected logger, got nil")
}
}
func TestNewFromConfig(t *testing.T) {
cfg := config.LoggingConfig{
Level: "DEBUG",
ConsoleOutput: true,
Format: config.LogFormatConfig{
Type: "standard",
UseColors: true,
ShowModule: true,
TimestampFormat: "2006-01-02 15:04:05",
},
}
logger, err := NewFromConfig(cfg)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if logger == nil {
t.Fatal("Expected logger, got nil")
}
}
func TestNewFromConfig_WithConsole(t *testing.T) {
cfg := config.LoggingConfig{
Level: "INFO",
ConsoleOutput: true,
Format: config.LogFormatConfig{
Type: "standard",
UseColors: true,
},
Console: &config.ConsoleLogConfig{
Level: "DEBUG",
Format: config.LogFormatConfig{
Type: "standard",
UseColors: false,
},
},
}
logger, err := NewFromConfig(cfg)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if logger == nil {
t.Fatal("Expected logger, got nil")
}
}
func TestLogger_Debug(t *testing.T) {
logger, _ := New(Config{Level: "DEBUG"})
// Should not panic
logger.Debug("test message", "key", "value")
}
func TestLogger_Info(t *testing.T) {
logger, _ := New(Config{Level: "INFO"})
// Should not panic
logger.Info("test message", "key", "value")
}
func TestLogger_Warn(t *testing.T) {
logger, _ := New(Config{Level: "WARN"})
// Should not panic
logger.Warn("test message", "key", "value")
}
func TestLogger_Error(t *testing.T) {
logger, _ := New(Config{Level: "ERROR"})
// Should not panic
logger.Error("test message", "key", "value")
}
func TestLogger_Named(t *testing.T) {
logger, _ := New(Config{Level: "INFO"})
named := logger.Named("test.module")
if named == nil {
t.Fatal("Expected named logger, got nil")
}
if named.name != "test.module" {
t.Errorf("Expected name 'test.module', got %s", named.name)
}
// Should not panic
named.Info("test from named logger")
}
func TestLogger_With(t *testing.T) {
logger, _ := New(Config{Level: "INFO"})
withFields := logger.With("service", "test")
if withFields == nil {
t.Fatal("Expected logger with fields, got nil")
}
// Should not panic
withFields.Info("test with fields")
}
func TestLogger_Sync(t *testing.T) {
logger, _ := New(Config{Level: "INFO"})
// Should not panic
err := logger.Sync()
// Sync may return an error for stdout on some systems, ignore it
_ = err
}
func TestParseLevel(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"DEBUG", "debug"},
{"INFO", "info"},
{"WARN", "warn"},
{"WARNING", "warn"},
{"ERROR", "error"},
{"CRITICAL", "fatal"},
{"FATAL", "fatal"},
{"invalid", "info"}, // defaults to INFO
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
level := parseLevel(tt.input)
if level.String() != tt.expected {
t.Errorf("parseLevel(%s) = %s, want %s", tt.input, level.String(), tt.expected)
}
})
}
}
// ============== Benchmarks ==============
func BenchmarkLogger_Info(b *testing.B) {
logger, _ := New(Config{Level: "INFO"})
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Info("test message", "key", "value")
}
}
func BenchmarkLogger_Debug_Filtered(b *testing.B) {
logger, _ := New(Config{Level: "ERROR"})
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Debug("test message", "key", "value")
}
}

View File

@ -1,74 +0,0 @@
package middleware
import (
"fmt"
"net/http"
"runtime/debug"
"time"
"github.com/konduktor/konduktor/internal/logging"
)
type responseWriter struct {
http.ResponseWriter
status int
size int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}
func (rw *responseWriter) Write(b []byte) (int, error) {
size, err := rw.ResponseWriter.Write(b)
rw.size += size
return size, err
}
func ServerHeader(next http.Handler, version string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Server", fmt.Sprintf("konduktor/%s", version))
next.ServeHTTP(w, r)
})
}
func AccessLog(next http.Handler, logger *logging.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
wrapped := &responseWriter{
ResponseWriter: w,
status: http.StatusOK,
}
next.ServeHTTP(wrapped, r)
duration := time.Since(start)
logger.Info("HTTP request",
"method", r.Method,
"path", r.URL.Path,
"status", wrapped.status,
"duration_ms", duration.Milliseconds(),
"client_ip", r.RemoteAddr,
"user_agent", r.UserAgent(),
)
})
}
func Recovery(next http.Handler, logger *logging.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
logger.Error("Panic recovered",
"error", fmt.Sprintf("%v", err),
"stack", string(debug.Stack()),
)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}

View File

@ -1,244 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/konduktor/konduktor/internal/logging"
)
// ============== ServerHeader Tests ==============
func TestServerHeader(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
wrapped := ServerHeader(handler, "1.0.0")
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
serverHeader := rr.Header().Get("Server")
if serverHeader != "konduktor/1.0.0" {
t.Errorf("Expected Server header 'konduktor/1.0.0', got '%s'", serverHeader)
}
}
// ============== AccessLog Tests ==============
func TestAccessLog(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello"))
})
logger, _ := logging.New(logging.Config{Level: "INFO"})
wrapped := AccessLog(handler, logger)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
}
func TestAccessLog_CapturesStatusCode(t *testing.T) {
tests := []struct {
name string
statusCode int
}{
{"OK", http.StatusOK},
{"NotFound", http.StatusNotFound},
{"InternalError", http.StatusInternalServerError},
{"Redirect", http.StatusMovedPermanently},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.statusCode)
})
logger, _ := logging.New(logging.Config{Level: "INFO"})
wrapped := AccessLog(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
if rr.Code != tt.statusCode {
t.Errorf("Expected status %d, got %d", tt.statusCode, rr.Code)
}
})
}
}
// ============== Recovery Tests ==============
func TestRecovery_NoPanic(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
logger, _ := logging.New(logging.Config{Level: "INFO"})
wrapped := Recovery(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
if rr.Body.String() != "OK" {
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
}
}
func TestRecovery_WithPanic(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("test panic")
})
logger, _ := logging.New(logging.Config{Level: "ERROR"})
wrapped := Recovery(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
// Should not panic
wrapped.ServeHTTP(rr, req)
if rr.Code != http.StatusInternalServerError {
t.Errorf("Expected status 500, got %d", rr.Code)
}
if !strings.Contains(rr.Body.String(), "Internal Server Error") {
t.Errorf("Expected 'Internal Server Error' in body, got '%s'", rr.Body.String())
}
}
// ============== responseWriter Tests ==============
func TestResponseWriter_WriteHeader(t *testing.T) {
rr := httptest.NewRecorder()
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
rw.WriteHeader(http.StatusNotFound)
if rw.status != http.StatusNotFound {
t.Errorf("Expected status 404, got %d", rw.status)
}
}
func TestResponseWriter_Write(t *testing.T) {
rr := httptest.NewRecorder()
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
n, err := rw.Write([]byte("Hello World"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if n != 11 {
t.Errorf("Expected 11 bytes written, got %d", n)
}
if rw.size != 11 {
t.Errorf("Expected size 11, got %d", rw.size)
}
}
// ============== Middleware Chain Tests ==============
func TestMiddlewareChain(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
logger, _ := logging.New(logging.Config{Level: "INFO"})
// Apply middleware chain
wrapped := Recovery(AccessLog(ServerHeader(handler, "1.0.0"), logger), logger)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
// Check all middleware worked
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
if rr.Header().Get("Server") != "konduktor/1.0.0" {
t.Errorf("Expected Server header")
}
if rr.Body.String() != "OK" {
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
}
}
// ============== Benchmarks ==============
func BenchmarkServerHeader(b *testing.B) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
wrapped := ServerHeader(handler, "1.0.0")
req := httptest.NewRequest("GET", "/", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
}
}
func BenchmarkAccessLog(b *testing.B) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
logger, _ := logging.New(logging.Config{Level: "ERROR"}) // Minimize logging overhead
wrapped := AccessLog(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
}
}
func BenchmarkRecovery(b *testing.B) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
logger, _ := logging.New(logging.Config{Level: "ERROR"})
wrapped := Recovery(handler, logger)
req := httptest.NewRequest("GET", "/", nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rr := httptest.NewRecorder()
wrapped.ServeHTTP(rr, req)
}
}

View File

@ -1,263 +0,0 @@
package pathmatcher
import (
"strings"
"sync"
)
type MountedPath struct {
path string
name string
stripPath bool
}
func NewMountedPath(path string, opts ...MountedPathOption) *MountedPath {
// Normalize: remove trailing slash (except for root)
normalizedPath := strings.TrimSuffix(path, "/")
if normalizedPath == "" {
normalizedPath = ""
}
m := &MountedPath{
path: normalizedPath,
name: normalizedPath,
stripPath: true,
}
for _, opt := range opts {
opt(m)
}
if m.name == "" {
m.name = normalizedPath
}
return m
}
type MountedPathOption func(*MountedPath)
func WithName(name string) MountedPathOption {
return func(m *MountedPath) {
m.name = name
}
}
func WithStripPath(strip bool) MountedPathOption {
return func(m *MountedPath) {
m.stripPath = strip
}
}
func (m *MountedPath) Path() string {
return m.path
}
func (m *MountedPath) Name() string {
return m.name
}
func (m *MountedPath) StripPath() bool {
return m.stripPath
}
func (m *MountedPath) Matches(requestPath string) bool {
// Empty or "/" mount matches everything
if m.path == "" || m.path == "/" {
return true
}
// Request path must be at least as long as mount path
if len(requestPath) < len(m.path) {
return false
}
// Check if request path starts with mount path
if !strings.HasPrefix(requestPath, m.path) {
return false
}
// If paths are equal length, it's a match
if len(requestPath) == len(m.path) {
return true
}
// Otherwise, next char must be '/' to prevent /api matching /api-v2
return requestPath[len(m.path)] == '/'
}
func (m *MountedPath) GetModifiedPath(requestPath string) string {
if !m.stripPath {
return requestPath
}
// Root mount doesn't strip anything
if m.path == "" || m.path == "/" {
return requestPath
}
// Strip the prefix
modified := strings.TrimPrefix(requestPath, m.path)
// Ensure result starts with /
if modified == "" || modified[0] != '/' {
modified = "/" + modified
}
return modified
}
type MountManager struct {
mounts []*MountedPath
mu sync.RWMutex
}
func NewMountManager() *MountManager {
return &MountManager{
mounts: make([]*MountedPath, 0),
}
}
func (mm *MountManager) AddMount(mount *MountedPath) {
mm.mu.Lock()
defer mm.mu.Unlock()
// Insert in sorted order (longer paths first)
inserted := false
for i, existing := range mm.mounts {
if len(mount.path) > len(existing.path) {
// Insert at position i
mm.mounts = append(mm.mounts[:i], append([]*MountedPath{mount}, mm.mounts[i:]...)...)
inserted = true
break
}
}
if !inserted {
mm.mounts = append(mm.mounts, mount)
}
}
func (mm *MountManager) RemoveMount(path string) bool {
mm.mu.Lock()
defer mm.mu.Unlock()
normalizedPath := strings.TrimSuffix(path, "/")
for i, mount := range mm.mounts {
if mount.path == normalizedPath {
mm.mounts = append(mm.mounts[:i], mm.mounts[i+1:]...)
return true
}
}
return false
}
func (mm *MountManager) GetMount(requestPath string) *MountedPath {
mm.mu.RLock()
defer mm.mu.RUnlock()
// Mounts are sorted by path length (longest first)
// so the first match is the best match
for _, mount := range mm.mounts {
if mount.Matches(requestPath) {
return mount
}
}
return nil
}
func (mm *MountManager) MountCount() int {
mm.mu.RLock()
defer mm.mu.RUnlock()
return len(mm.mounts)
}
func (mm *MountManager) Mounts() []*MountedPath {
mm.mu.RLock()
defer mm.mu.RUnlock()
result := make([]*MountedPath, len(mm.mounts))
copy(result, mm.mounts)
return result
}
func (mm *MountManager) ListMounts() []map[string]interface{} {
mm.mu.RLock()
defer mm.mu.RUnlock()
result := make([]map[string]interface{}, len(mm.mounts))
for i, mount := range mm.mounts {
result[i] = map[string]interface{}{
"path": mount.path,
"name": mount.name,
"strip_path": mount.stripPath,
}
}
return result
}
// Utility functions
func PathMatchesPrefix(requestPath, prefix string) bool {
// Normalize prefix
prefix = strings.TrimSuffix(prefix, "/")
// Empty or "/" prefix matches everything
if prefix == "" || prefix == "/" {
return true
}
// Request path must be at least as long as prefix
if len(requestPath) < len(prefix) {
return false
}
// Check if request path starts with prefix
if !strings.HasPrefix(requestPath, prefix) {
return false
}
// If paths are equal length, it's a match
if len(requestPath) == len(prefix) {
return true
}
// Otherwise, next char must be '/'
return requestPath[len(prefix)] == '/'
}
func StripPathPrefix(requestPath, prefix string) string {
// Normalize prefix
prefix = strings.TrimSuffix(prefix, "/")
// Empty or "/" prefix doesn't strip anything
if prefix == "" || prefix == "/" {
return requestPath
}
// Strip the prefix
modified := strings.TrimPrefix(requestPath, prefix)
// Ensure result starts with /
if modified == "" || modified[0] != '/' {
modified = "/" + modified
}
return modified
}
func MatchAndModifyPath(requestPath, prefix string, stripPath bool) (matches bool, modifiedPath string) {
if !PathMatchesPrefix(requestPath, prefix) {
return false, ""
}
if stripPath {
return true, StripPathPrefix(requestPath, prefix)
}
return true, requestPath
}

View File

@ -1,460 +0,0 @@
package pathmatcher
import (
"testing"
)
// ============== MountedPath Tests ==============
func TestMountedPath_RootMountMatchesEverything(t *testing.T) {
mount := NewMountedPath("")
tests := []string{"/", "/api", "/api/users", "/anything/at/all"}
for _, path := range tests {
if !mount.Matches(path) {
t.Errorf("Root mount should match %s", path)
}
}
}
func TestMountedPath_SlashRootMountMatchesEverything(t *testing.T) {
mount := NewMountedPath("/")
tests := []string{"/", "/api", "/api/users"}
for _, path := range tests {
if !mount.Matches(path) {
t.Errorf("'/' mount should match %s", path)
}
}
}
func TestMountedPath_ExactPathMatch(t *testing.T) {
mount := NewMountedPath("/api")
tests := []struct {
path string
expected bool
}{
{"/api", true},
{"/api/", true},
{"/api/users", true},
}
for _, tt := range tests {
if got := mount.Matches(tt.path); got != tt.expected {
t.Errorf("Matches(%s) = %v, want %v", tt.path, got, tt.expected)
}
}
}
func TestMountedPath_NoFalsePrefixMatch(t *testing.T) {
mount := NewMountedPath("/api")
tests := []string{"/api-v2", "/api2", "/apiv2"}
for _, path := range tests {
if mount.Matches(path) {
t.Errorf("/api should not match %s", path)
}
}
}
func TestMountedPath_ShorterPathNoMatch(t *testing.T) {
mount := NewMountedPath("/api/v1")
tests := []string{"/api", "/ap", "/"}
for _, path := range tests {
if mount.Matches(path) {
t.Errorf("/api/v1 should not match shorter path %s", path)
}
}
}
func TestMountedPath_TrailingSlashNormalized(t *testing.T) {
mount1 := NewMountedPath("/api/")
mount2 := NewMountedPath("/api")
if mount1.Path() != "/api" {
t.Errorf("Expected path /api, got %s", mount1.Path())
}
if mount2.Path() != "/api" {
t.Errorf("Expected path /api, got %s", mount2.Path())
}
if !mount1.Matches("/api/users") {
t.Error("mount1 should match /api/users")
}
if !mount2.Matches("/api/users") {
t.Error("mount2 should match /api/users")
}
}
func TestMountedPath_GetModifiedPathStripsPrefix(t *testing.T) {
mount := NewMountedPath("/api")
tests := []struct {
input string
expected string
}{
{"/api", "/"},
{"/api/", "/"},
{"/api/users", "/users"},
{"/api/users/123", "/users/123"},
}
for _, tt := range tests {
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
}
}
}
func TestMountedPath_GetModifiedPathNoStrip(t *testing.T) {
mount := NewMountedPath("/api", WithStripPath(false))
tests := []struct {
input string
expected string
}{
{"/api/users", "/api/users"},
{"/api", "/api"},
}
for _, tt := range tests {
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
}
}
}
func TestMountedPath_RootMountModifiedPath(t *testing.T) {
mount := NewMountedPath("")
tests := []struct {
input string
expected string
}{
{"/api/users", "/api/users"},
{"/", "/"},
}
for _, tt := range tests {
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
}
}
}
func TestMountedPath_NameProperty(t *testing.T) {
mount1 := NewMountedPath("/api")
mount2 := NewMountedPath("/api", WithName("API Mount"))
if mount1.Name() != "/api" {
t.Errorf("Expected name /api, got %s", mount1.Name())
}
if mount2.Name() != "API Mount" {
t.Errorf("Expected name 'API Mount', got %s", mount2.Name())
}
}
// ============== MountManager Tests ==============
func TestMountManager_EmptyManager(t *testing.T) {
manager := NewMountManager()
if got := manager.GetMount("/api"); got != nil {
t.Error("Empty manager should return nil")
}
if got := manager.MountCount(); got != 0 {
t.Errorf("Expected mount count 0, got %d", got)
}
}
func TestMountManager_AddMount(t *testing.T) {
manager := NewMountManager()
mount := NewMountedPath("/api")
manager.AddMount(mount)
if manager.MountCount() != 1 {
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
}
if got := manager.GetMount("/api/users"); got != mount {
t.Error("GetMount should return the added mount")
}
}
func TestMountManager_LongestPrefixMatching(t *testing.T) {
manager := NewMountManager()
apiMount := NewMountedPath("/api", WithName("api"))
apiV1Mount := NewMountedPath("/api/v1", WithName("api_v1"))
apiV2Mount := NewMountedPath("/api/v2", WithName("api_v2"))
manager.AddMount(apiMount)
manager.AddMount(apiV2Mount)
manager.AddMount(apiV1Mount)
tests := []struct {
path string
expectedName string
}{
{"/api/v1/users", "api_v1"},
{"/api/v2/items", "api_v2"},
{"/api/v3/other", "api"},
{"/api", "api"},
}
for _, tt := range tests {
got := manager.GetMount(tt.path)
if got == nil {
t.Errorf("GetMount(%s) returned nil, want mount with name %s", tt.path, tt.expectedName)
continue
}
if got.Name() != tt.expectedName {
t.Errorf("GetMount(%s).Name() = %s, want %s", tt.path, got.Name(), tt.expectedName)
}
}
}
func TestMountManager_RemoveMount(t *testing.T) {
manager := NewMountManager()
manager.AddMount(NewMountedPath("/api"))
manager.AddMount(NewMountedPath("/admin"))
if manager.MountCount() != 2 {
t.Errorf("Expected mount count 2, got %d", manager.MountCount())
}
result := manager.RemoveMount("/api")
if !result {
t.Error("RemoveMount should return true")
}
if manager.MountCount() != 1 {
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
}
if manager.GetMount("/api/users") != nil {
t.Error("GetMount(/api/users) should return nil after removal")
}
if manager.GetMount("/admin/users") == nil {
t.Error("GetMount(/admin/users) should still work")
}
}
func TestMountManager_RemoveNonexistentMount(t *testing.T) {
manager := NewMountManager()
result := manager.RemoveMount("/api")
if result {
t.Error("RemoveMount should return false for nonexistent mount")
}
}
func TestMountManager_ListMounts(t *testing.T) {
manager := NewMountManager()
manager.AddMount(NewMountedPath("/api", WithName("API")))
manager.AddMount(NewMountedPath("/admin", WithName("Admin")))
mounts := manager.ListMounts()
if len(mounts) != 2 {
t.Errorf("Expected 2 mounts, got %d", len(mounts))
}
for _, m := range mounts {
if _, ok := m["path"]; !ok {
t.Error("Mount should have 'path' key")
}
if _, ok := m["name"]; !ok {
t.Error("Mount should have 'name' key")
}
if _, ok := m["strip_path"]; !ok {
t.Error("Mount should have 'strip_path' key")
}
}
}
func TestMountManager_MountsReturnsCopy(t *testing.T) {
manager := NewMountManager()
manager.AddMount(NewMountedPath("/api"))
mounts1 := manager.Mounts()
mounts2 := manager.Mounts()
if &mounts1[0] == &mounts2[0] {
t.Error("Mounts() should return different slices")
}
}
// ============== Utility Functions Tests ==============
func TestPathMatchesPrefix_Basic(t *testing.T) {
tests := []struct {
path string
prefix string
expected bool
}{
{"/api/users", "/api", true},
{"/api", "/api", true},
{"/api-v2", "/api", false},
{"/ap", "/api", false},
}
for _, tt := range tests {
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
}
}
}
func TestPathMatchesPrefix_Root(t *testing.T) {
tests := []struct {
path string
prefix string
expected bool
}{
{"/anything", "", true},
{"/anything", "/", true},
}
for _, tt := range tests {
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
}
}
}
func TestStripPathPrefix_Basic(t *testing.T) {
tests := []struct {
path string
prefix string
expected string
}{
{"/api/users", "/api", "/users"},
{"/api", "/api", "/"},
{"/api/", "/api", "/"},
}
for _, tt := range tests {
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
}
}
}
func TestStripPathPrefix_Root(t *testing.T) {
tests := []struct {
path string
prefix string
expected string
}{
{"/api/users", "", "/api/users"},
{"/api/users", "/", "/api/users"},
}
for _, tt := range tests {
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
}
}
}
func TestMatchAndModifyPath_Combined(t *testing.T) {
tests := []struct {
path string
prefix string
stripPath bool
wantMatches bool
wantModified string
}{
{"/api/users", "/api", true, true, "/users"},
{"/api", "/api", true, true, "/"},
{"/other", "/api", true, false, ""},
{"/api/users", "/api", false, true, "/api/users"},
}
for _, tt := range tests {
matches, modified := MatchAndModifyPath(tt.path, tt.prefix, tt.stripPath)
if matches != tt.wantMatches {
t.Errorf("MatchAndModifyPath(%s, %s, %v) matches = %v, want %v",
tt.path, tt.prefix, tt.stripPath, matches, tt.wantMatches)
}
if modified != tt.wantModified {
t.Errorf("MatchAndModifyPath(%s, %s, %v) modified = %s, want %s",
tt.path, tt.prefix, tt.stripPath, modified, tt.wantModified)
}
}
}
// ============== Performance Tests ==============
func TestPerformance_ManyMatches(t *testing.T) {
mount := NewMountedPath("/api/v1/users")
for i := 0; i < 10000; i++ {
if !mount.Matches("/api/v1/users/123/posts") {
t.Fatal("Should match")
}
if mount.Matches("/other/path") {
t.Fatal("Should not match")
}
}
}
func TestPerformance_ManyMounts(t *testing.T) {
manager := NewMountManager()
for i := 0; i < 100; i++ {
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10)) + string(rune('0'+i/10))))
}
if manager.MountCount() != 100 {
t.Errorf("Expected 100 mounts, got %d", manager.MountCount())
}
}
// ============== Benchmarks ==============
func BenchmarkMountedPath_Matches(b *testing.B) {
mount := NewMountedPath("/api/v1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
mount.Matches("/api/v1/users/123")
}
}
func BenchmarkMountManager_GetMount(b *testing.B) {
manager := NewMountManager()
for i := 0; i < 20; i++ {
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10))))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.GetMount("/api/v5/users/123")
}
}
func BenchmarkPathMatchesPrefix(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
PathMatchesPrefix("/api/v1/users/123", "/api/v1")
}
}

View File

@ -1,333 +0,0 @@
// Package proxy provides reverse proxy functionality for Konduktor
package proxy
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/konduktor/konduktor/internal/logging"
)
type Config struct {
// Target is the backend server URL
Target string
// Timeout is the request timeout (default: 30s)
Timeout time.Duration
// Headers are additional headers to add to requests
Headers map[string]string
// StripPrefix removes this prefix from the request path
StripPrefix string
// PreserveHost keeps the original Host header
PreserveHost bool
// IgnoreRequestPath ignores the request path and uses only the target path
// This is useful for exact match routes where target URL should be used as-is
IgnoreRequestPath bool
}
type ReverseProxy struct {
config *Config
targetURL *url.URL
httpClient *http.Client
logger *logging.Logger
}
func New(cfg *Config, logger *logging.Logger) (*ReverseProxy, error) {
if cfg.Target == "" {
return nil, fmt.Errorf("proxy target is required")
}
targetURL, err := url.Parse(cfg.Target)
if err != nil {
return nil, fmt.Errorf("invalid proxy target URL: %w", err)
}
timeout := cfg.Timeout
if timeout == 0 {
timeout = 30 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: timeout,
}
return &ReverseProxy{
config: cfg,
targetURL: targetURL,
httpClient: &http.Client{
Transport: transport,
Timeout: timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // Don't follow redirects
},
},
logger: logger,
}, nil
}
func (rp *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rp.ProxyRequest(w, r, nil)
}
func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, params map[string]string) {
ctx := r.Context()
// Build target URL
targetURL := rp.buildTargetURL(r)
// Create proxy request
proxyReq, err := rp.createProxyRequest(ctx, r, targetURL)
if err != nil {
rp.handleError(w, http.StatusInternalServerError, "Failed to create proxy request", err)
return
}
// Add custom headers with parameter substitution
rp.addCustomHeaders(proxyReq, r, params)
// Execute request
resp, err := rp.httpClient.Do(proxyReq)
if err != nil {
rp.handleProxyError(w, err)
return
}
defer resp.Body.Close()
// Copy response
rp.copyResponse(w, resp)
}
func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
targetURL := *rp.targetURL
// If ignoring request path, use target URL path as-is
if rp.config.IgnoreRequestPath {
// Preserve query string only
targetURL.RawQuery = r.URL.RawQuery
return &targetURL
}
// Strip prefix if configured
path := r.URL.Path
if rp.config.StripPrefix != "" {
path = strings.TrimPrefix(path, rp.config.StripPrefix)
if path == "" || path[0] != '/' {
path = "/" + path
}
}
// If target URL has a non-empty path, combine it with the request path
if rp.targetURL.Path != "" && rp.targetURL.Path != "/" {
// Combine target path with request path
targetURL.Path = strings.TrimSuffix(rp.targetURL.Path, "/") + path
} else {
// No path in target, use request path as-is
targetURL.Path = path
}
// Preserve query string
targetURL.RawQuery = r.URL.RawQuery
return &targetURL
}
func (rp *ReverseProxy) createProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL) (*http.Request, error) {
proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body)
if err != nil {
return nil, err
}
// Copy ContentLength
proxyReq.ContentLength = r.ContentLength
// Copy headers
for key, values := range r.Header {
for _, value := range values {
proxyReq.Header.Add(key, value)
}
}
// Set/update Host header
if rp.config.PreserveHost {
proxyReq.Host = r.Host
} else {
proxyReq.Host = targetURL.Host
}
// Remove hop-by-hop headers
removeHopByHopHeaders(proxyReq.Header)
return proxyReq, nil
}
func (rp *ReverseProxy) addCustomHeaders(proxyReq *http.Request, originalReq *http.Request, params map[string]string) {
// Add X-Forwarded headers
clientIP := getClientIP(originalReq)
if prior := originalReq.Header.Get("X-Forwarded-For"); prior != "" {
clientIP = prior + ", " + clientIP
}
proxyReq.Header.Set("X-Forwarded-For", clientIP)
proxyReq.Header.Set("X-Forwarded-Proto", getScheme(originalReq))
proxyReq.Header.Set("X-Forwarded-Host", originalReq.Host)
// Add custom headers from config
for key, value := range rp.config.Headers {
// Substitute parameters like {version}
substituted := value
for paramKey, paramValue := range params {
substituted = strings.ReplaceAll(substituted, "{"+paramKey+"}", paramValue)
}
// Substitute $remote_addr
substituted = strings.ReplaceAll(substituted, "$remote_addr", clientIP)
proxyReq.Header.Set(key, substituted)
}
}
func (rp *ReverseProxy) copyResponse(w http.ResponseWriter, resp *http.Response) {
// Copy headers
for key, values := range resp.Header {
for _, value := range values {
w.Header().Add(key, value)
}
}
// Remove hop-by-hop headers from response
removeHopByHopHeaders(w.Header())
// Write status code
w.WriteHeader(resp.StatusCode)
// Copy body
io.Copy(w, resp.Body)
}
func (rp *ReverseProxy) handleError(w http.ResponseWriter, status int, message string, err error) {
if rp.logger != nil {
rp.logger.Error(message, "error", err)
}
http.Error(w, message, status)
}
func (rp *ReverseProxy) handleProxyError(w http.ResponseWriter, err error) {
if rp.logger != nil {
rp.logger.Error("Proxy request failed", "error", err)
}
// Check for timeout
if err, ok := err.(net.Error); ok && err.Timeout() {
http.Error(w, "504 Gateway Timeout", http.StatusGatewayTimeout)
return
}
// Check for connection errors
if isConnectionError(err) {
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
return
}
// Context cancelled (client disconnected)
if err == context.Canceled {
return
}
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
}
// Helper functions
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func removeHopByHopHeaders(h http.Header) {
hopByHopHeaders := []string{
"Connection",
"Proxy-Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te",
"Trailer",
"Transfer-Encoding",
"Upgrade",
}
for _, header := range hopByHopHeaders {
h.Del(header)
}
}
func getClientIP(r *http.Request) string {
// Check X-Real-IP first
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
// Get from RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
func getScheme(r *http.Request) string {
if r.TLS != nil {
return "https"
}
if scheme := r.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
return "http"
}
func isConnectionError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
connectionErrors := []string{
"connection refused",
"no such host",
"network is unreachable",
"connection reset",
"broken pipe",
}
for _, connErr := range connectionErrors {
if strings.Contains(strings.ToLower(errStr), connErr) {
return true
}
}
return false
}

View File

@ -1,747 +0,0 @@
package proxy
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// ============== Test Backend Server ==============
type testBackend struct {
server *httptest.Server
requestLog []requestLogEntry
mu sync.Mutex
requestCount int64
}
type requestLogEntry struct {
Method string
Path string
Query string
Headers http.Header
Body string
}
func newTestBackend(handler http.HandlerFunc) *testBackend {
tb := &testBackend{
requestLog: make([]requestLogEntry, 0),
}
if handler == nil {
handler = tb.defaultHandler
}
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tb.logRequest(r)
handler(w, r)
}))
return tb
}
func (tb *testBackend) logRequest(r *http.Request) {
tb.mu.Lock()
defer tb.mu.Unlock()
body, _ := io.ReadAll(r.Body)
// Restore the body for the handler
r.Body = io.NopCloser(bytes.NewReader(body))
tb.requestLog = append(tb.requestLog, requestLogEntry{
Method: r.Method,
Path: r.URL.Path,
Query: r.URL.RawQuery,
Headers: r.Header.Clone(),
Body: string(body),
})
atomic.AddInt64(&tb.requestCount, 1)
}
func (tb *testBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"message": "Backend response",
"path": r.URL.Path,
"method": r.Method,
})
}
func (tb *testBackend) close() {
tb.server.Close()
}
func (tb *testBackend) URL() string {
return tb.server.URL
}
func (tb *testBackend) getRequestCount() int64 {
return atomic.LoadInt64(&tb.requestCount)
}
func (tb *testBackend) getLastRequest() *requestLogEntry {
tb.mu.Lock()
defer tb.mu.Unlock()
if len(tb.requestLog) == 0 {
return nil
}
return &tb.requestLog[len(tb.requestLog)-1]
}
// ============== Proxy Creation Tests ==============
func TestNew_ValidConfig(t *testing.T) {
cfg := &Config{
Target: "http://localhost:8080",
Timeout: 10 * time.Second,
}
proxy, err := New(cfg, nil)
if err != nil {
t.Fatalf("Failed to create proxy: %v", err)
}
if proxy == nil {
t.Fatal("Expected proxy instance")
}
}
func TestNew_EmptyTarget(t *testing.T) {
cfg := &Config{
Target: "",
}
_, err := New(cfg, nil)
if err == nil {
t.Error("Expected error for empty target")
}
}
func TestNew_InvalidTargetURL(t *testing.T) {
cfg := &Config{
Target: "://invalid-url",
}
_, err := New(cfg, nil)
if err == nil {
t.Error("Expected error for invalid URL")
}
}
func TestNew_DefaultTimeout(t *testing.T) {
cfg := &Config{
Target: "http://localhost:8080",
}
proxy, err := New(cfg, nil)
if err != nil {
t.Fatalf("Failed to create proxy: %v", err)
}
if proxy.httpClient.Timeout != 30*time.Second {
t.Errorf("Expected default timeout 30s, got %v", proxy.httpClient.Timeout)
}
}
// ============== Basic Proxy Tests ==============
func TestProxy_BasicGET(t *testing.T) {
backend := newTestBackend(nil)
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
var response map[string]interface{}
json.NewDecoder(rr.Body).Decode(&response)
if response["path"] != "/test" {
t.Errorf("Expected path /test, got %v", response["path"])
}
}
func TestProxy_BasicPOST(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"received": string(body),
"method": r.Method,
})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("POST", "/api/data", strings.NewReader(`{"key":"value"}`))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", rr.Code)
}
var response map[string]interface{}
json.NewDecoder(rr.Body).Decode(&response)
if response["method"] != "POST" {
t.Errorf("Expected method POST, got %v", response["method"])
}
}
func TestProxy_PUT(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("PUT", "/resource/123", strings.NewReader(`{"name":"updated"}`))
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["method"] != "PUT" {
t.Errorf("Expected method PUT, got %v", response["method"])
}
}
func TestProxy_DELETE(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("DELETE", "/resource/123", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["method"] != "DELETE" {
t.Errorf("Expected method DELETE, got %v", response["method"])
}
}
// ============== Header Tests ==============
func TestProxy_HeadersForwarding(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]interface{}{
"custom_header": r.Header.Get("X-Custom-Header"),
"forwarded_for": r.Header.Get("X-Forwarded-For"),
"forwarded_host": r.Header.Get("X-Forwarded-Host"),
})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/headers", nil)
req.Header.Set("X-Custom-Header", "test-value")
req.RemoteAddr = "192.168.1.100:12345"
req.Host = "example.com"
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]interface{}
json.NewDecoder(rr.Body).Decode(&response)
if response["custom_header"] != "test-value" {
t.Errorf("Expected custom header, got %v", response["custom_header"])
}
if response["forwarded_for"] != "192.168.1.100" {
t.Errorf("Expected X-Forwarded-For, got %v", response["forwarded_for"])
}
if response["forwarded_host"] != "example.com" {
t.Errorf("Expected X-Forwarded-Host, got %v", response["forwarded_host"])
}
}
func TestProxy_CustomHeaders(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"api_version": r.Header.Get("X-API-Version"),
})
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL(),
Headers: map[string]string{
"X-API-Version": "{version}",
},
}, nil)
req := httptest.NewRequest("GET", "/api", nil)
rr := httptest.NewRecorder()
// Simulate parameter substitution
proxy.ProxyRequest(rr, req, map[string]string{"version": "2"})
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["api_version"] != "2" {
t.Errorf("Expected API version 2, got %v", response["api_version"])
}
}
func TestProxy_RemoteAddrSubstitution(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"client_ip": r.Header.Get("X-Client-IP"),
})
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL(),
Headers: map[string]string{
"X-Client-IP": "$remote_addr",
},
}, nil)
req := httptest.NewRequest("GET", "/api", nil)
req.RemoteAddr = "10.0.0.1:54321"
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["client_ip"] != "10.0.0.1" {
t.Errorf("Expected client IP 10.0.0.1, got %v", response["client_ip"])
}
}
// ============== Query String Tests ==============
func TestProxy_QueryStringPreservation(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"query": r.URL.RawQuery,
"param": r.URL.Query().Get("key"),
})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/search?key=value&page=2", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["param"] != "value" {
t.Errorf("Expected query param 'value', got %v", response["param"])
}
}
// ============== Status Code Tests ==============
func TestProxy_StatusCodePreservation(t *testing.T) {
statusCodes := []int{200, 201, 400, 404, 500}
for _, code := range statusCodes {
code := code // capture range variable
t.Run(fmt.Sprintf("Status_%d", code), func(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(code)
w.Write([]byte("OK"))
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/status", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != code {
t.Errorf("Expected status %d, got %d", code, rr.Code)
}
})
}
}
// ============== Error Handling Tests ==============
func TestProxy_BackendUnavailable(t *testing.T) {
// Use a port that's definitely not listening
proxy, _ := New(&Config{
Target: "http://127.0.0.1:59999",
Timeout: 1 * time.Second,
}, nil)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusBadGateway {
t.Errorf("Expected status 502 Bad Gateway, got %d", rr.Code)
}
}
func TestProxy_Timeout(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(2 * time.Second)
w.Write([]byte("OK"))
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL(),
Timeout: 100 * time.Millisecond,
}, nil)
req := httptest.NewRequest("GET", "/slow", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusGatewayTimeout {
t.Errorf("Expected status 504 Gateway Timeout, got %d", rr.Code)
}
}
// ============== Path Handling Tests ==============
func TestProxy_StripPrefix(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"path": r.URL.Path,
})
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL(),
StripPrefix: "/api/v1",
}, nil)
req := httptest.NewRequest("GET", "/api/v1/users", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["path"] != "/users" {
t.Errorf("Expected stripped path /users, got %v", response["path"])
}
}
func TestProxy_TargetWithPath(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"path": r.URL.Path,
})
})
defer backend.close()
proxy, _ := New(&Config{
Target: backend.URL() + "/backend",
}, nil)
req := httptest.NewRequest("GET", "/resource", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]string
json.NewDecoder(rr.Body).Decode(&response)
if response["path"] != "/backend/resource" {
t.Errorf("Expected path /backend/resource, got %v", response["path"])
}
}
// ============== Large Body Tests ==============
func TestProxy_LargeRequestBody(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.NewEncoder(w).Encode(map[string]int{
"received_bytes": len(body),
})
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
// 100KB body
largeBody := strings.Repeat("x", 100000)
req := httptest.NewRequest("POST", "/upload", strings.NewReader(largeBody))
req.ContentLength = int64(len(largeBody))
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
var response map[string]int
json.NewDecoder(rr.Body).Decode(&response)
if response["received_bytes"] != 100000 {
t.Errorf("Expected 100000 bytes, got %d", response["received_bytes"])
}
}
func TestProxy_LargeResponseBody(t *testing.T) {
largeResponse := strings.Repeat("y", 100000)
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(largeResponse))
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
req := httptest.NewRequest("GET", "/large", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Body.Len() != 100000 {
t.Errorf("Expected 100000 bytes in response, got %d", rr.Body.Len())
}
}
// ============== Concurrent Requests Tests ==============
func TestProxy_ConcurrentRequests(t *testing.T) {
backend := newTestBackend(nil)
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
const numRequests = 50
var wg sync.WaitGroup
errors := make(chan error, numRequests)
for i := 0; i < numRequests; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
req := httptest.NewRequest("GET", "/concurrent", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
errors <- &net.OpError{Op: "test", Err: context.DeadlineExceeded}
}
}(i)
}
wg.Wait()
close(errors)
errorCount := 0
for range errors {
errorCount++
}
if errorCount > 0 {
t.Errorf("Got %d errors in concurrent requests", errorCount)
}
if backend.getRequestCount() != numRequests {
t.Errorf("Expected %d requests at backend, got %d", numRequests, backend.getRequestCount())
}
}
// ============== Echo Tests ==============
func TestProxy_Echo(t *testing.T) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
w.Write(body)
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
testData := "Hello, Proxy!"
req := httptest.NewRequest("POST", "/echo", strings.NewReader(testData))
req.Header.Set("Content-Type", "text/plain")
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
if rr.Body.String() != testData {
t.Errorf("Expected echo of '%s', got '%s'", testData, rr.Body.String())
}
if rr.Header().Get("Content-Type") != "text/plain" {
t.Errorf("Expected Content-Type text/plain, got %s", rr.Header().Get("Content-Type"))
}
}
// ============== Helper Function Tests ==============
func TestSingleJoiningSlash(t *testing.T) {
tests := []struct {
a, b, expected string
}{
{"/api", "/users", "/api/users"},
{"/api/", "/users", "/api/users"},
{"/api", "users", "/api/users"},
{"/api/", "users", "/api/users"},
{"", "/users", "/users"},
{"/api", "", "/api/"},
}
for _, tt := range tests {
result := singleJoiningSlash(tt.a, tt.b)
if result != tt.expected {
t.Errorf("singleJoiningSlash(%q, %q) = %q, want %q", tt.a, tt.b, result, tt.expected)
}
}
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
remoteAddr string
xRealIP string
expected string
}{
{"192.168.1.1:1234", "", "192.168.1.1"},
{"192.168.1.1:1234", "10.0.0.1", "10.0.0.1"},
{"invalid", "", "invalid"},
}
for _, tt := range tests {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xRealIP != "" {
req.Header.Set("X-Real-IP", tt.xRealIP)
}
result := getClientIP(req)
if result != tt.expected {
t.Errorf("getClientIP() = %q, want %q", result, tt.expected)
}
}
}
func TestGetScheme(t *testing.T) {
tests := []struct {
name string
tls bool
header string
expected string
}{
{"HTTP", false, "", "http"},
{"HTTPS from TLS", true, "", "https"},
{"HTTPS from header", false, "https", "https"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if tt.header != "" {
req.Header.Set("X-Forwarded-Proto", tt.header)
}
// Note: httptest doesn't set TLS, so we can only test non-TLS cases fully
result := getScheme(req)
if !tt.tls && result != tt.expected {
t.Errorf("getScheme() = %q, want %q", result, tt.expected)
}
})
}
}
func TestIsConnectionError(t *testing.T) {
tests := []struct {
err error
expected bool
}{
{nil, false},
{&net.OpError{Op: "dial", Err: &net.DNSError{Err: "no such host"}}, true},
{context.DeadlineExceeded, false},
}
for _, tt := range tests {
result := isConnectionError(tt.err)
if result != tt.expected {
t.Errorf("isConnectionError(%v) = %v, want %v", tt.err, result, tt.expected)
}
}
}
// ============== Benchmarks ==============
func BenchmarkProxy_SimpleGET(b *testing.B) {
backend := newTestBackend(nil)
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
req := httptest.NewRequest("GET", "/bench", nil)
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
}
}
func BenchmarkProxy_POSTWithBody(b *testing.B) {
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
io.Copy(io.Discard, r.Body)
w.WriteHeader(http.StatusOK)
})
defer backend.close()
proxy, _ := New(&Config{Target: backend.URL()}, nil)
body := strings.Repeat("x", 1024) // 1KB body
b.ResetTimer()
for i := 0; i < b.N; i++ {
req := httptest.NewRequest("POST", "/bench", strings.NewReader(body))
rr := httptest.NewRecorder()
proxy.ServeHTTP(rr, req)
}
}

View File

@ -1,492 +0,0 @@
// Package routing provides HTTP routing with regex support
package routing
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"github.com/konduktor/konduktor/internal/config"
"github.com/konduktor/konduktor/internal/logging"
"github.com/konduktor/konduktor/internal/proxy"
)
// RouteMatch represents a matched route with captured parameters
type RouteMatch struct {
Config map[string]interface{}
Params map[string]string
}
// RegexRoute represents a compiled regex route
type RegexRoute struct {
Pattern *regexp.Regexp
Config map[string]interface{}
CaseSensitive bool
OriginalExpr string
}
// Router handles HTTP routing with exact, regex, and default routes
type Router struct {
config *config.Config
logger *logging.Logger
mux *http.ServeMux
staticDir string
exactRoutes map[string]map[string]interface{}
regexRoutes []*RegexRoute
defaultRoute map[string]interface{}
mu sync.RWMutex
}
// New creates a new router from config
func New(cfg *config.Config, logger *logging.Logger) *Router {
staticDir := "./static"
if cfg != nil && cfg.HTTP.StaticDir != "" {
staticDir = cfg.HTTP.StaticDir
}
r := &Router{
config: cfg,
logger: logger,
mux: http.NewServeMux(),
staticDir: staticDir,
exactRoutes: make(map[string]map[string]interface{}),
regexRoutes: make([]*RegexRoute, 0),
}
// Load routes from extensions
if cfg != nil {
for _, ext := range cfg.Extensions {
if ext.Type == "routing" && ext.Config != nil {
if locations, ok := ext.Config["regex_locations"].(map[string]interface{}); ok {
for pattern, routeCfg := range locations {
if rc, ok := routeCfg.(map[string]interface{}); ok {
r.AddRoute(pattern, rc)
if logger != nil {
logger.Debug("Added route", "pattern", pattern)
}
}
}
}
}
}
}
r.setupRoutes()
return r
}
// NewRouter creates a router without config (for testing)
func NewRouter(opts ...RouterOption) *Router {
r := &Router{
mux: http.NewServeMux(),
staticDir: "./static",
exactRoutes: make(map[string]map[string]interface{}),
regexRoutes: make([]*RegexRoute, 0),
}
for _, opt := range opts {
opt(r)
}
return r
}
// RouterOption is a functional option for Router
type RouterOption func(*Router)
// WithStaticDir sets the static directory
func WithStaticDir(dir string) RouterOption {
return func(r *Router) {
r.staticDir = dir
}
}
// StaticDir returns the static directory path
func (r *Router) StaticDir() string {
return r.staticDir
}
// Routes returns the regex routes (for testing)
func (r *Router) Routes() []*RegexRoute {
r.mu.RLock()
defer r.mu.RUnlock()
return r.regexRoutes
}
// ExactRoutes returns the exact routes (for testing)
func (r *Router) ExactRoutes() map[string]map[string]interface{} {
r.mu.RLock()
defer r.mu.RUnlock()
return r.exactRoutes
}
// DefaultRoute returns the default route (for testing)
func (r *Router) DefaultRoute() map[string]interface{} {
r.mu.RLock()
defer r.mu.RUnlock()
return r.defaultRoute
}
// AddRoute adds a route with the given pattern and config
// Pattern formats:
// - "=/path" - exact match
// - "~regex" - case-sensitive regex
// - "~*regex" - case-insensitive regex
// - "__default__" - default/fallback route
func (r *Router) AddRoute(pattern string, routeConfig map[string]interface{}) {
r.mu.Lock()
defer r.mu.Unlock()
switch {
case pattern == "__default__":
r.defaultRoute = routeConfig
case strings.HasPrefix(pattern, "="):
// Exact match route
path := strings.TrimPrefix(pattern, "=")
r.exactRoutes[path] = routeConfig
case strings.HasPrefix(pattern, "~*"):
// Case-insensitive regex
expr := strings.TrimPrefix(pattern, "~*")
re, err := regexp.Compile("(?i)" + expr)
if err != nil {
if r.logger != nil {
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
}
return
}
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
Pattern: re,
Config: routeConfig,
CaseSensitive: false,
OriginalExpr: expr,
})
case strings.HasPrefix(pattern, "~"):
// Case-sensitive regex
expr := strings.TrimPrefix(pattern, "~")
re, err := regexp.Compile(expr)
if err != nil {
if r.logger != nil {
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
}
return
}
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
Pattern: re,
Config: routeConfig,
CaseSensitive: true,
OriginalExpr: expr,
})
}
}
// Match finds the best matching route for a path
// Priority: exact match > regex match > default
func (r *Router) Match(path string) *RouteMatch {
r.mu.RLock()
defer r.mu.RUnlock()
// 1. Check exact routes
if cfg, ok := r.exactRoutes[path]; ok {
return &RouteMatch{
Config: cfg,
Params: make(map[string]string),
}
}
// 2. Check regex routes
for _, route := range r.regexRoutes {
match := route.Pattern.FindStringSubmatch(path)
if match != nil {
params := make(map[string]string)
// Extract named groups
names := route.Pattern.SubexpNames()
for i, name := range names {
if i > 0 && name != "" && i < len(match) {
params[name] = match[i]
}
}
return &RouteMatch{
Config: route.Config,
Params: params,
}
}
}
// 3. Check default route
if r.defaultRoute != nil {
return &RouteMatch{
Config: r.defaultRoute,
Params: make(map[string]string),
}
}
return nil
}
// setupRoutes configures the routes from config
func (r *Router) setupRoutes() {
// Health check endpoint
r.mux.HandleFunc("/health", r.healthHandler)
// Setup redirect instructions from config
if r.config != nil {
for from, to := range r.config.Server.RedirectInstructions {
fromPath := from
toPath := to
r.mux.HandleFunc(fromPath, func(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, toPath, http.StatusMovedPermanently)
})
}
}
// Default handler for all other routes
r.mux.HandleFunc("/", r.defaultHandler)
}
// ServeHTTP implements http.Handler
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.mux.ServeHTTP(w, req)
}
// healthHandler handles health check requests
func (r *Router) healthHandler(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
// defaultHandler handles requests that don't match other routes
func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path
// Try to match against configured routes
match := r.Match(path)
fmt.Printf("DEBUG defaultHandler: path=%q match=%v defaultRoute=%v\n", path, match != nil, r.defaultRoute != nil)
if match != nil {
fmt.Printf("DEBUG: matched config: %v\n", match.Config)
r.handleRouteMatch(w, req, match)
return
}
// Try to serve static file
if r.staticDir != "" {
// Get absolute path for static dir
absStaticDir, err := filepath.Abs(r.staticDir)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
filePath := filepath.Join(absStaticDir, filepath.Clean("/"+path))
cleanPath := filepath.Clean(filePath)
// Prevent directory traversal - ensure path is within static dir
if !strings.HasPrefix(cleanPath+string(filepath.Separator), absStaticDir+string(filepath.Separator)) {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// Check if file exists
info, err := os.Stat(filePath)
if err == nil {
if info.IsDir() {
// Try index.html
indexPath := filepath.Join(filePath, "index.html")
if _, err := os.Stat(indexPath); err == nil {
http.ServeFile(w, req, indexPath)
return
}
} else {
http.ServeFile(w, req, filePath)
return
}
}
}
// 404 Not Found
http.NotFound(w, req)
}
// handleRouteMatch handles a matched route
func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, match *RouteMatch) {
cfg := match.Config
// Handle proxy_pass directive
if proxyTarget, ok := cfg["proxy_pass"].(string); ok {
r.handleProxyPass(w, req, proxyTarget, cfg, match.Params)
return
}
// Handle "return" directive
if ret, ok := cfg["return"].(string); ok {
parts := strings.SplitN(ret, " ", 2)
statusCode := 200
body := "OK"
if len(parts) >= 1 {
switch parts[0] {
case "200":
statusCode = 200
case "201":
statusCode = 201
case "301":
statusCode = 301
case "302":
statusCode = 302
case "400":
statusCode = 400
case "404":
statusCode = 404
case "500":
statusCode = 500
}
}
if len(parts) >= 2 {
body = parts[1]
}
if ct, ok := cfg["content_type"].(string); ok {
w.Header().Set("Content-Type", ct)
} else {
w.Header().Set("Content-Type", "text/plain")
}
w.WriteHeader(statusCode)
w.Write([]byte(body))
return
}
// Handle static files with root
if root, ok := cfg["root"].(string); ok {
path := req.URL.Path
if indexFile, ok := cfg["index_file"].(string); ok {
if path == "/" || strings.HasSuffix(path, "/") {
path = "/" + indexFile
}
}
// Get absolute path for root dir
absRoot, err := filepath.Abs(root)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
filePath := filepath.Join(absRoot, filepath.Clean("/"+path))
cleanPath := filepath.Clean(filePath)
// DEBUG
fmt.Printf("DEBUG: path=%q absRoot=%q filePath=%q cleanPath=%q\n", path, absRoot, filePath, cleanPath)
fmt.Printf("DEBUG: check1=%q check2=%q\n", cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator))
// Prevent directory traversal
if !strings.HasPrefix(cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator)) {
fmt.Printf("DEBUG: FORBIDDEN - path not within root\n")
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
if cacheControl, ok := cfg["cache_control"].(string); ok {
w.Header().Set("Cache-Control", cacheControl)
}
if headers, ok := cfg["headers"].([]interface{}); ok {
for _, h := range headers {
if header, ok := h.(string); ok {
parts := strings.SplitN(header, ": ", 2)
if len(parts) == 2 {
w.Header().Set(parts[0], parts[1])
}
}
}
}
http.ServeFile(w, req, filePath)
return
}
// Handle SPA fallback
if spaFallback, ok := cfg["spa_fallback"].(bool); ok && spaFallback {
root := r.staticDir
if rt, ok := cfg["root"].(string); ok {
root = rt
}
indexFile := "index.html"
if idx, ok := cfg["index_file"].(string); ok {
indexFile = idx
}
filePath := filepath.Join(root, indexFile)
http.ServeFile(w, req, filePath)
return
}
http.NotFound(w, req)
}
// handleProxyPass proxies the request to the target backend
func (r *Router) handleProxyPass(w http.ResponseWriter, req *http.Request, target string, cfg map[string]interface{}, params map[string]string) {
// Substitute params in target URL (e.g., {version} -> actual version)
for key, value := range params {
target = strings.ReplaceAll(target, "{"+key+"}", value)
}
// Create proxy
proxyConfig := &proxy.Config{
Target: target,
Headers: make(map[string]string),
}
// Parse headers from config
if headers, ok := cfg["headers"].([]interface{}); ok {
for _, h := range headers {
if header, ok := h.(string); ok {
parts := strings.SplitN(header, ": ", 2)
if len(parts) == 2 {
// Substitute params in header values
headerValue := parts[1]
for key, value := range params {
headerValue = strings.ReplaceAll(headerValue, "{"+key+"}", value)
}
proxyConfig.Headers[parts[0]] = headerValue
}
}
}
}
p, err := proxy.New(proxyConfig, r.logger)
if err != nil {
if r.logger != nil {
r.logger.Error("Failed to create proxy", "target", target, "error", err)
}
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return
}
p.ProxyRequest(w, req, params)
}
// CreateRouterFromConfig creates a router from extension config
func CreateRouterFromConfig(cfg map[string]interface{}) *Router {
router := NewRouter()
if locations, ok := cfg["regex_locations"].(map[string]interface{}); ok {
for pattern, routeCfg := range locations {
if rc, ok := routeCfg.(map[string]interface{}); ok {
router.AddRoute(pattern, rc)
}
}
}
return router
}

View File

@ -1,375 +0,0 @@
package routing
import (
"path/filepath"
"testing"
)
// ============== Router Initialization Tests ==============
func TestRouter_Initialization(t *testing.T) {
router := NewRouter()
if router.StaticDir() != "./static" {
t.Errorf("Expected static dir ./static, got %s", router.StaticDir())
}
if len(router.Routes()) != 0 {
t.Error("Expected empty routes")
}
if len(router.ExactRoutes()) != 0 {
t.Error("Expected empty exact routes")
}
if router.DefaultRoute() != nil {
t.Error("Expected nil default route")
}
}
func TestRouter_CustomStaticDir(t *testing.T) {
router := NewRouter(WithStaticDir("/custom/path"))
if router.StaticDir() != "/custom/path" {
t.Errorf("Expected static dir /custom/path, got %s", router.StaticDir())
}
}
// ============== Route Adding Tests ==============
func TestRouter_AddExactRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"return": "200 OK"}
router.AddRoute("=/health", config)
exactRoutes := router.ExactRoutes()
if _, ok := exactRoutes["/health"]; !ok {
t.Error("Expected /health in exact routes")
}
}
func TestRouter_AddDefaultRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"spa_fallback": true, "root": "./static"}
router.AddRoute("__default__", config)
if router.DefaultRoute() == nil {
t.Error("Expected default route to be set")
}
}
func TestRouter_AddRegexRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static"}
router.AddRoute("~^/api/", config)
if len(router.Routes()) != 1 {
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
}
}
func TestRouter_AddCaseInsensitiveRegexRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
router.AddRoute("~*\\.(css|js)$", config)
if len(router.Routes()) != 1 {
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
}
if router.Routes()[0].CaseSensitive {
t.Error("Expected case-insensitive route")
}
}
func TestRouter_InvalidRegexPattern(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static"}
// Invalid regex - unmatched bracket
router.AddRoute("~^/api/[invalid", config)
// Should not add invalid pattern
if len(router.Routes()) != 0 {
t.Error("Should not add invalid regex pattern")
}
}
// ============== Route Matching Tests ==============
func TestRouter_MatchExactRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"return": "200 OK"}
router.AddRoute("=/health", config)
match := router.Match("/health")
if match == nil {
t.Fatal("Expected match for /health")
}
if match.Config["return"] != "200 OK" {
t.Error("Expected return config")
}
if len(match.Params) != 0 {
t.Error("Expected empty params for exact match")
}
}
func TestRouter_MatchExactRouteNoMatch(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"return": "200 OK"}
router.AddRoute("=/health", config)
match := router.Match("/healthcheck")
if match != nil {
t.Error("Exact route should not match /healthcheck")
}
}
func TestRouter_MatchRegexRoute(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
router.AddRoute("~^/api/v\\d+/", config)
match := router.Match("/api/v1/users")
if match == nil {
t.Fatal("Expected match for /api/v1/users")
}
if match.Config["proxy_pass"] != "http://localhost:9001" {
t.Error("Expected proxy_pass config")
}
}
func TestRouter_MatchRegexRouteWithGroups(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
router.AddRoute("~^/api/v(?P<version>\\d+)/", config)
match := router.Match("/api/v2/data")
if match == nil {
t.Fatal("Expected match for /api/v2/data")
}
if match.Params["version"] != "2" {
t.Errorf("Expected version=2, got %s", match.Params["version"])
}
}
func TestRouter_MatchCaseInsensitiveRegex(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
router.AddRoute("~*\\.(CSS|JS)$", config)
// Should match lowercase
match1 := router.Match("/styles/main.css")
if match1 == nil {
t.Error("Should match lowercase .css")
}
// Should match uppercase
match2 := router.Match("/scripts/app.JS")
if match2 == nil {
t.Error("Should match uppercase .JS")
}
}
func TestRouter_MatchCaseSensitiveRegex(t *testing.T) {
router := NewRouter()
config := map[string]interface{}{"root": "./static"}
router.AddRoute("~\\.(css)$", config)
// Should match lowercase
match1 := router.Match("/styles/main.css")
if match1 == nil {
t.Error("Should match lowercase .css")
}
// Should NOT match uppercase
match2 := router.Match("/styles/main.CSS")
if match2 != nil {
t.Error("Should not match uppercase .CSS for case-sensitive regex")
}
}
func TestRouter_MatchDefaultRoute(t *testing.T) {
router := NewRouter()
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
match := router.Match("/unknown/path")
if match == nil {
t.Fatal("Expected default route match")
}
if match.Config["spa_fallback"] != true {
t.Error("Expected spa_fallback config from default route")
}
}
// ============== Priority Tests ==============
func TestRouter_PriorityExactOverRegex(t *testing.T) {
router := NewRouter()
router.AddRoute("=/api/status", map[string]interface{}{"return": "200 Exact"})
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
match := router.Match("/api/status")
if match == nil {
t.Fatal("Expected match")
}
if match.Config["return"] != "200 Exact" {
t.Error("Exact match should have priority over regex")
}
}
func TestRouter_PriorityRegexOverDefault(t *testing.T) {
router := NewRouter()
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
match := router.Match("/api/v1/users")
if match == nil {
t.Fatal("Expected match")
}
if match.Config["proxy_pass"] != "http://localhost:9001" {
t.Error("Regex match should have priority over default")
}
}
// ============== CreateRouterFromConfig Tests ==============
func TestCreateRouterFromConfig(t *testing.T) {
config := map[string]interface{}{
"regex_locations": map[string]interface{}{
"=/health": map[string]interface{}{
"return": "200 OK",
"content_type": "text/plain",
},
"~^/api/": map[string]interface{}{
"proxy_pass": "http://localhost:9001",
},
"__default__": map[string]interface{}{
"spa_fallback": true,
"root": "./static",
},
},
}
router := CreateRouterFromConfig(config)
// Check exact route
if _, ok := router.ExactRoutes()["/health"]; !ok {
t.Error("Expected /health exact route")
}
// Check regex route
if len(router.Routes()) != 1 {
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
}
// Check default route
if router.DefaultRoute() == nil {
t.Error("Expected default route")
}
}
// ============== Static Dir Path Tests ==============
func TestRouter_StaticDirPath(t *testing.T) {
router := NewRouter(WithStaticDir("/var/www/html"))
expected, _ := filepath.Abs("/var/www/html")
actual, _ := filepath.Abs(router.StaticDir())
if actual != expected {
t.Errorf("Expected static dir %s, got %s", expected, actual)
}
}
// ============== Concurrent Access Tests ==============
func TestRouter_ConcurrentAccess(t *testing.T) {
router := NewRouter()
// Add routes concurrently
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(n int) {
router.AddRoute("~^/api/v"+string(rune('0'+n))+"/", map[string]interface{}{
"proxy_pass": "http://localhost:900" + string(rune('0'+n)),
})
done <- true
}(i)
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
// Match routes concurrently
for i := 0; i < 10; i++ {
go func(n int) {
router.Match("/api/v" + string(rune('0'+n)) + "/users")
done <- true
}(i)
}
for i := 0; i < 10; i++ {
<-done
}
}
// ============== Benchmarks ==============
func BenchmarkRouter_MatchExact(b *testing.B) {
router := NewRouter()
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
b.ResetTimer()
for i := 0; i < b.N; i++ {
router.Match("/health")
}
}
func BenchmarkRouter_MatchRegex(b *testing.B) {
router := NewRouter()
router.AddRoute("~^/api/v(?P<version>\\d+)/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
b.ResetTimer()
for i := 0; i < b.N; i++ {
router.Match("/api/v1/users/123")
}
}
func BenchmarkRouter_MatchWithManyRoutes(b *testing.B) {
router := NewRouter()
// Add many routes
for i := 0; i < 50; i++ {
router.AddRoute("~^/api/v"+string(rune('0'+i%10))+"/service"+string(rune('0'+i/10))+"/",
map[string]interface{}{"proxy_pass": "http://localhost:9001"})
}
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
b.ResetTimer()
for i := 0; i < b.N; i++ {
router.Match("/api/v5/service3/users/123")
}
}

View File

@ -1,158 +0,0 @@
// Package server provides the HTTP server implementation
package server
import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/konduktor/konduktor/internal/config"
"github.com/konduktor/konduktor/internal/extension"
"github.com/konduktor/konduktor/internal/logging"
"github.com/konduktor/konduktor/internal/middleware"
)
const Version = "0.2.0"
// Server represents the Konduktor HTTP server
type Server struct {
config *config.Config
httpServer *http.Server
extensionManager *extension.Manager
logger *logging.Logger
}
// New creates a new server instance
func New(cfg *config.Config) (*Server, error) {
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
logger, err := logging.NewFromConfig(cfg.Logging)
if err != nil {
return nil, fmt.Errorf("failed to create logger: %w", err)
}
// Create extension manager
extManager := extension.NewManager(logger)
// Load extensions from config
for _, extCfg := range cfg.Extensions {
// Add static_dir to routing config if not present
if extCfg.Type == "routing" {
if extCfg.Config == nil {
extCfg.Config = make(map[string]interface{})
}
if _, ok := extCfg.Config["static_dir"]; !ok {
extCfg.Config["static_dir"] = cfg.HTTP.StaticDir
}
}
if err := extManager.LoadExtension(extCfg.Type, extCfg.Config); err != nil {
logger.Error("Failed to load extension", "type", extCfg.Type, "error", err)
// Continue loading other extensions
}
}
srv := &Server{
config: cfg,
extensionManager: extManager,
logger: logger,
}
return srv, nil
}
// Run starts the server and blocks until shutdown
func (s *Server) Run() error {
// Build handler chain with middleware
handler := s.buildHandler()
// Create HTTP server
addr := fmt.Sprintf("%s:%d", s.config.Server.Host, s.config.Server.Port)
s.httpServer = &http.Server{
Addr: addr,
Handler: handler,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
// Start server in goroutine
errChan := make(chan error, 1)
go func() {
s.logger.Info("Server starting", "addr", addr, "version", Version)
var err error
if s.config.SSL.Enabled {
err = s.httpServer.ListenAndServeTLS(s.config.SSL.CertFile, s.config.SSL.KeyFile)
} else {
err = s.httpServer.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
errChan <- err
}
}()
// Wait for shutdown signal
return s.waitForShutdown(errChan)
}
// buildHandler builds the HTTP handler chain
func (s *Server) buildHandler() http.Handler {
// Create base handler that returns 404
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
})
// Wrap with extension manager
var handler http.Handler = s.extensionManager.Handler(baseHandler)
// Add middleware (applied in reverse order)
handler = middleware.AccessLog(handler, s.logger)
handler = middleware.ServerHeader(handler, Version)
handler = middleware.Recovery(handler, s.logger)
return handler
}
// waitForShutdown waits for shutdown signal and gracefully stops the server
func (s *Server) waitForShutdown(errChan <-chan error) error {
// Listen for shutdown signals
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
select {
case err := <-errChan:
return err
case sig := <-sigChan:
s.logger.Info("Shutdown signal received", "signal", sig.String())
}
// Graceful shutdown with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
s.logger.Info("Shutting down server...")
// Cleanup extensions
s.extensionManager.Cleanup()
if err := s.httpServer.Shutdown(ctx); err != nil {
s.logger.Error("Error during shutdown", "error", err)
return err
}
s.logger.Info("Server stopped gracefully")
return nil
}
// Shutdown gracefully shuts down the server
func (s *Server) Shutdown(ctx context.Context) error {
return s.httpServer.Shutdown(ctx)
}

View File

@ -1,134 +0,0 @@
# Integration Tests
Интеграционные тесты для Konduktor — полноценное тестирование сервера с реальными HTTP запросами.
## Отличие от unit-тестов
| Аспект | Unit-тесты | Интеграционные тесты |
|--------|------------|---------------------|
| Scope | Отдельный модуль в изоляции | Весь сервер целиком |
| Backend | Mock (httptest.Server) | Реальные HTTP серверы |
| Config | Программный | YAML конфигурация |
| Extensions | Не тестируются | Полная цепочка обработки |
## Структура тестов
```
tests/integration/
├── README.md # Эта документация
├── helpers_test.go # Общие хелперы и утилиты
├── reverse_proxy_test.go # Тесты reverse proxy
├── routing_test.go # Тесты маршрутизации (TODO)
├── security_test.go # Тесты security extension (TODO)
├── caching_test.go # Тесты caching extension (TODO)
└── static_files_test.go # Тесты статических файлов (TODO)
```
## Что тестируют интеграционные тесты
### 1. Reverse Proxy (`reverse_proxy_test.go`)
- [x] Базовое проксирование GET/POST/PUT/DELETE
- [x] Exact match routes (`=/api/version`)
- [x] Regex routes с параметрами (`~^/api/resource/(?P<id>\d+)$`)
- [x] Подстановка параметров в target URL (`{id}`, `{tag}`)
- [x] Подстановка переменных в заголовки (`$remote_addr`)
- [x] Передача заголовков X-Forwarded-For, X-Real-IP
- [x] Сохранение query string
- [x] Обработка ошибок backend (502, 504)
- [x] Таймауты соединения
### 2. Routing Extension (`routing_test.go`)
- [x] Приоритет маршрутов (exact > regex > default)
- [x] Case-sensitive regex (`~`)
- [x] Case-insensitive regex (`~*`)
- [x] Default route (`__default__`)
- [x] Return directive (`return 200 "OK"`)
- [x] Regex с именованными группами
- [x] Множественные regex маршруты
- [x] Кастомные заголовки в маршрутах
- [x] Обработка отсутствия маршрута
### 3. Security Extension (`security_test.go`)
- [ ] IP whitelist
- [ ] IP blacklist
- [ ] CIDR нотация (10.0.0.0/8)
- [ ] Security headers (X-Frame-Options, X-Content-Type-Options)
- [ ] Rate limiting
- [ ] Комбинация с другими extensions
### 4. Caching Extension (`caching_test.go`)
- [x] Cache hit/miss
- [x] TTL expiration
- [x] Pattern-based caching
- [x] Cache-Control headers (X-Cache header)
- [x] Кэширование только GET запросов
- [x] Разные пути = разные ключи кэша
- [x] Query string влияет на ключ кэша
- [x] Ошибки не кэшируются
- [x] Конкурентный доступ к кэшу
- [x] Множественные паттерны кэширования
### 5. Static Files (`static_files_test.go`)
- [ ] Serving статических файлов
- [ ] Index file (index.html)
- [ ] MIME types
- [ ] Cache-Control для static
- [ ] SPA fallback
- [ ] Directory traversal protection
- [ ] 404 для несуществующих файлов
### 6. Extension Chain (`extension_chain_test.go`)
- [ ] Порядок выполнения extensions (security → caching → routing)
- [ ] Прерывание цепочки при ошибке
- [ ] Совместная работа extensions
## Запуск тестов
```bash
# Все интеграционные тесты
go test ./tests/integration/... -v
# Конкретный файл
go test ./tests/integration/... -v -run TestReverseProxy
# С таймаутом (интеграционные тесты медленнее)
go test ./tests/integration/... -v -timeout 60s
# С покрытием
go test ./tests/integration/... -v -coverprofile=coverage.out
```
## Требования
- Свободные порты: тесты используют случайные порты (`:0`)
- Сетевой доступ: для localhost соединений
- Время: интеграционные тесты занимают больше времени (~5-10 сек)
## Добавление новых тестов
1. Создайте файл `*_test.go` в `tests/integration/`
2. Используйте хелперы из `helpers_test.go`:
- `startTestServer()` — запуск Konduktor сервера
- `startBackend()` — запуск mock backend
- `makeRequest()` — отправка HTTP запроса
3. Добавьте описание в этот README
## CI/CD
Интеграционные тесты запускаются отдельно от unit-тестов:
```yaml
# .github/workflows/test.yml
jobs:
unit-tests:
run: go test ./internal/...
integration-tests:
run: go test ./tests/integration/... -timeout 120s
```

View File

@ -1,666 +0,0 @@
package integration
import (
"encoding/json"
"fmt"
"net/http"
"sync/atomic"
"testing"
"time"
"github.com/konduktor/konduktor/internal/extension"
)
// ============== Basic Cache Hit/Miss Tests ==============
func TestCaching_BasicHitMiss(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"request_number": count,
"timestamp": time.Now().UnixNano(),
})
})
defer backend.Close()
logger := createTestLogger(t)
// Create caching extension
cachingExt, err := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "30s",
"methods": []interface{}{"GET"},
},
},
}, logger)
if err != nil {
t.Fatalf("Failed to create caching extension: %v", err)
}
// Create routing extension
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// First request - should be MISS
resp1, err := client.Get("/api/data", nil)
if err != nil {
t.Fatalf("Request 1 failed: %v", err)
}
cacheHeader1 := resp1.Header.Get("X-Cache")
var result1 map[string]interface{}
json.NewDecoder(resp1.Body).Decode(&result1)
resp1.Body.Close()
if cacheHeader1 != "MISS" {
t.Errorf("Expected X-Cache: MISS for first request, got %q", cacheHeader1)
}
// Second request - should be HIT (same response)
resp2, err := client.Get("/api/data", nil)
if err != nil {
t.Fatalf("Request 2 failed: %v", err)
}
cacheHeader2 := resp2.Header.Get("X-Cache")
var result2 map[string]interface{}
json.NewDecoder(resp2.Body).Decode(&result2)
resp2.Body.Close()
if cacheHeader2 != "HIT" {
t.Errorf("Expected X-Cache: HIT for second request, got %q", cacheHeader2)
}
// Verify same response (from cache)
if result1["request_number"] != result2["request_number"] {
t.Errorf("Expected same request_number from cache, got %v and %v",
result1["request_number"], result2["request_number"])
}
// Backend should only receive 1 request
if atomic.LoadInt64(&requestCount) != 1 {
t.Errorf("Expected 1 backend request, got %d", requestCount)
}
}
// ============== TTL Expiration Tests ==============
func TestCaching_TTLExpiration(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"request_number": count,
})
})
defer backend.Close()
logger := createTestLogger(t)
// Create caching extension with short TTL
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "100ms", // Very short TTL for testing
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "100ms",
"methods": []interface{}{"GET"},
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// First request
resp1, _ := client.Get("/api/data", nil)
var result1 map[string]interface{}
json.NewDecoder(resp1.Body).Decode(&result1)
resp1.Body.Close()
// Second request (within TTL) - should be HIT
resp2, _ := client.Get("/api/data", nil)
cacheHeader2 := resp2.Header.Get("X-Cache")
resp2.Body.Close()
if cacheHeader2 != "HIT" {
t.Errorf("Expected X-Cache: HIT before TTL expires, got %q", cacheHeader2)
}
// Wait for TTL to expire
time.Sleep(150 * time.Millisecond)
// Third request (after TTL) - should be MISS
resp3, _ := client.Get("/api/data", nil)
cacheHeader3 := resp3.Header.Get("X-Cache")
var result3 map[string]interface{}
json.NewDecoder(resp3.Body).Decode(&result3)
resp3.Body.Close()
if cacheHeader3 != "MISS" {
t.Errorf("Expected X-Cache: MISS after TTL expires, got %q", cacheHeader3)
}
// Verify new request was made (different request_number)
if result1["request_number"] == result3["request_number"] {
t.Error("Expected different request_number after TTL expiration")
}
}
// ============== Pattern-Based Caching Tests ==============
func TestCaching_PatternBasedCaching(t *testing.T) {
var apiCount, staticCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.URL.Path[:5] == "/api/" {
atomic.AddInt64(&apiCount, 1)
} else {
atomic.AddInt64(&staticCount, 1)
}
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path})
})
defer backend.Close()
logger := createTestLogger(t)
// Only cache /api/* paths
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
"methods": []interface{}{"GET"},
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Multiple requests to /api/ - should be cached
for i := 0; i < 3; i++ {
resp, _ := client.Get("/api/users", nil)
resp.Body.Close()
}
// Multiple requests to /static/ - should NOT be cached (not matching pattern)
for i := 0; i < 3; i++ {
resp, _ := client.Get("/static/file.js", nil)
resp.Body.Close()
}
// API should have only 1 request (cached)
if atomic.LoadInt64(&apiCount) != 1 {
t.Errorf("Expected 1 API request (cached), got %d", apiCount)
}
// Static should have 3 requests (not cached)
if atomic.LoadInt64(&staticCount) != 3 {
t.Errorf("Expected 3 static requests (not cached), got %d", staticCount)
}
}
// ============== Method-Specific Caching Tests ==============
func TestCaching_OnlyGETMethodCached(t *testing.T) {
var getCount, postCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Method == "GET" {
atomic.AddInt64(&getCount, 1)
} else if r.Method == "POST" {
atomic.AddInt64(&postCount, 1)
}
json.NewEncoder(w).Encode(map[string]string{
"method": r.Method,
})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
"methods": []interface{}{"GET"}, // Only GET
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Multiple GET requests - should be cached
for i := 0; i < 3; i++ {
resp, _ := client.Get("/api/data", nil)
resp.Body.Close()
}
// Multiple POST requests - should NOT be cached
for i := 0; i < 3; i++ {
resp, _ := client.Post("/api/data", []byte(`{}`), map[string]string{
"Content-Type": "application/json",
})
resp.Body.Close()
}
if atomic.LoadInt64(&getCount) != 1 {
t.Errorf("Expected 1 GET request (cached), got %d", getCount)
}
if atomic.LoadInt64(&postCount) != 3 {
t.Errorf("Expected 3 POST requests (not cached), got %d", postCount)
}
}
// ============== Different Paths Different Cache Keys ==============
func TestCaching_DifferentPathsDifferentCacheKeys(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"path": r.URL.Path,
"request_number": count,
})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Request different paths
paths := []string{"/api/users", "/api/posts", "/api/comments"}
for _, path := range paths {
resp, _ := client.Get(path, nil)
resp.Body.Close()
}
// Each path should result in a separate backend request
if atomic.LoadInt64(&requestCount) != 3 {
t.Errorf("Expected 3 backend requests (one per path), got %d", requestCount)
}
// Request same paths again - all should be cached
for _, path := range paths {
resp, _ := client.Get(path, nil)
cacheHeader := resp.Header.Get("X-Cache")
resp.Body.Close()
if cacheHeader != "HIT" {
t.Errorf("Expected X-Cache: HIT for %s, got %q", path, cacheHeader)
}
}
// No additional backend requests
if atomic.LoadInt64(&requestCount) != 3 {
t.Errorf("Expected still 3 backend requests after cache hits, got %d", requestCount)
}
}
// ============== Query String Affects Cache Key ==============
func TestCaching_QueryStringAffectsCacheKey(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"query": r.URL.RawQuery,
"request_number": count,
})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Different query strings = different cache keys
queries := []string{
"/api/search?q=hello",
"/api/search?q=world",
"/api/search?q=test",
}
for _, query := range queries {
resp, _ := client.Get(query, nil)
resp.Body.Close()
}
// Each unique query should result in a separate backend request
if atomic.LoadInt64(&requestCount) != 3 {
t.Errorf("Expected 3 backend requests (one per query), got %d", requestCount)
}
// Same query again should be cached
resp, _ := client.Get("/api/search?q=hello", nil)
cacheHeader := resp.Header.Get("X-Cache")
resp.Body.Close()
if cacheHeader != "HIT" {
t.Errorf("Expected X-Cache: HIT for repeated query, got %q", cacheHeader)
}
}
// ============== Cache Does Not Store Error Responses ==============
func TestCaching_DoesNotCacheErrors(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": "internal error"})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Multiple requests to error endpoint
for i := 0; i < 3; i++ {
resp, _ := client.Get("/api/error", nil)
resp.Body.Close()
}
// All requests should reach backend (errors not cached)
if atomic.LoadInt64(&requestCount) != 3 {
t.Errorf("Expected 3 backend requests (errors not cached), got %d", requestCount)
}
}
// ============== Concurrent Cache Access ==============
func TestCaching_ConcurrentAccess(t *testing.T) {
var requestCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
// Small delay to increase chance of race conditions
time.Sleep(10 * time.Millisecond)
count := atomic.AddInt64(&requestCount, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"request_number": count,
})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "1m",
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
const numRequests = 20
results := make(chan error, numRequests)
// Make first request to populate cache
client := NewHTTPClient(server.URL)
resp, _ := client.Get("/api/concurrent", nil)
resp.Body.Close()
// Now many concurrent requests should all hit cache
for i := 0; i < numRequests; i++ {
go func(n int) {
client := NewHTTPClient(server.URL)
resp, err := client.Get("/api/concurrent", nil)
if err != nil {
results <- err
return
}
cacheHeader := resp.Header.Get("X-Cache")
resp.Body.Close()
if cacheHeader != "HIT" {
results <- fmt.Errorf("request %d: expected HIT, got %s", n, cacheHeader)
return
}
results <- nil
}(i)
}
// Collect results
var errors []error
for i := 0; i < numRequests; i++ {
if err := <-results; err != nil {
errors = append(errors, err)
}
}
if len(errors) > 0 {
t.Errorf("Got %d errors in concurrent cache access: %v", len(errors), errors[:min(5, len(errors))])
}
// Only 1 request should reach backend (the initial one)
if atomic.LoadInt64(&requestCount) != 1 {
t.Errorf("Expected 1 backend request, got %d", requestCount)
}
}
// ============== Multiple Cache Patterns ==============
func TestCaching_MultipleCachePatterns(t *testing.T) {
var apiCount, staticCount int64
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
atomic.AddInt64(&apiCount, 1)
} else if len(r.URL.Path) >= 8 && r.URL.Path[:8] == "/static/" {
atomic.AddInt64(&staticCount, 1)
}
json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path})
})
defer backend.Close()
logger := createTestLogger(t)
cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{
"default_ttl": "1m",
"cache_patterns": []interface{}{
map[string]interface{}{
"pattern": "^/api/.*",
"ttl": "30s",
"methods": []interface{}{"GET"},
},
map[string]interface{}{
"pattern": "^/static/.*",
"ttl": "1h", // Static files cached longer
"methods": []interface{}{"GET"},
},
},
}, logger)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{cachingExt, routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Multiple requests to both patterns
for i := 0; i < 3; i++ {
resp1, _ := client.Get("/api/data", nil)
resp1.Body.Close()
resp2, _ := client.Get("/static/app.js", nil)
resp2.Body.Close()
}
// Both should be cached (1 request each)
if atomic.LoadInt64(&apiCount) != 1 {
t.Errorf("Expected 1 API request, got %d", apiCount)
}
if atomic.LoadInt64(&staticCount) != 1 {
t.Errorf("Expected 1 static request, got %d", staticCount)
}
}

View File

@ -1,408 +0,0 @@
package integration
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/konduktor/konduktor/internal/extension"
"github.com/konduktor/konduktor/internal/logging"
"github.com/konduktor/konduktor/internal/middleware"
)
// TestServer represents a running Konduktor server for testing
type TestServer struct {
Server *http.Server
URL string
Port int
listener net.Listener
handler http.Handler
t *testing.T
}
// TestBackend represents a mock backend server
type TestBackend struct {
server *httptest.Server
requestLog []RequestLogEntry
mu sync.Mutex
requestCount int64
handler http.HandlerFunc
}
// RequestLogEntry stores information about a received request
type RequestLogEntry struct {
Method string
Path string
Query string
Headers http.Header
Body string
Timestamp time.Time
}
// ServerConfig holds configuration for starting a test server
type ServerConfig struct {
Extensions []extension.Extension
StaticDir string
Middleware []func(http.Handler) http.Handler
}
// ============== Test Server ==============
// StartTestServer creates and starts a Konduktor server for testing
func StartTestServer(t *testing.T, cfg *ServerConfig) *TestServer {
t.Helper()
logger, err := logging.New(logging.Config{
Level: "DEBUG",
})
if err != nil {
t.Fatalf("Failed to create logger: %v", err)
}
// Create extension manager
extManager := extension.NewManager(logger)
// Add extensions if provided
if cfg != nil && len(cfg.Extensions) > 0 {
for _, ext := range cfg.Extensions {
if err := extManager.AddExtension(ext); err != nil {
t.Fatalf("Failed to add extension: %v", err)
}
}
}
// Create a fallback handler for when no extension handles the request
fallback := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
})
// Create handler chain
var handler http.Handler = extManager.Handler(fallback)
// Add middleware
handler = middleware.AccessLog(handler, logger)
handler = middleware.Recovery(handler, logger)
// Add custom middleware if provided
if cfg != nil {
for _, mw := range cfg.Middleware {
handler = mw(handler)
}
}
// Find available port
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to find available port: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
server := &http.Server{
Handler: handler,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
}
ts := &TestServer{
Server: server,
URL: fmt.Sprintf("http://127.0.0.1:%d", port),
Port: port,
listener: listener,
handler: handler,
t: t,
}
// Start server in goroutine
go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
// Don't fail test here as server might be intentionally closed
}
}()
// Wait for server to be ready
ts.waitReady()
return ts
}
// waitReady waits for the server to be ready to accept connections
func (ts *TestServer) waitReady() {
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", ts.Port), 100*time.Millisecond)
if err == nil {
conn.Close()
return
}
time.Sleep(10 * time.Millisecond)
}
ts.t.Fatal("Server failed to start within timeout")
}
// Close shuts down the test server
func (ts *TestServer) Close() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ts.Server.Shutdown(ctx)
}
// ============== Test Backend ==============
// StartBackend creates and starts a mock backend server
func StartBackend(handler http.HandlerFunc) *TestBackend {
tb := &TestBackend{
requestLog: make([]RequestLogEntry, 0),
handler: handler,
}
if handler == nil {
handler = tb.defaultHandler
}
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tb.logRequest(r)
handler(w, r)
}))
return tb
}
func (tb *TestBackend) logRequest(r *http.Request) {
tb.mu.Lock()
defer tb.mu.Unlock()
body, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewReader(body))
tb.requestLog = append(tb.requestLog, RequestLogEntry{
Method: r.Method,
Path: r.URL.Path,
Query: r.URL.RawQuery,
Headers: r.Header.Clone(),
Body: string(body),
Timestamp: time.Now(),
})
atomic.AddInt64(&tb.requestCount, 1)
}
func (tb *TestBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"backend": "default",
"path": r.URL.Path,
"method": r.Method,
"query": r.URL.RawQuery,
"received": time.Now().Unix(),
})
}
// URL returns the backend server URL
func (tb *TestBackend) URL() string {
return tb.server.URL
}
// Close shuts down the backend server
func (tb *TestBackend) Close() {
tb.server.Close()
}
// RequestCount returns the number of requests received
func (tb *TestBackend) RequestCount() int64 {
return atomic.LoadInt64(&tb.requestCount)
}
// LastRequest returns the most recent request
func (tb *TestBackend) LastRequest() *RequestLogEntry {
tb.mu.Lock()
defer tb.mu.Unlock()
if len(tb.requestLog) == 0 {
return nil
}
return &tb.requestLog[len(tb.requestLog)-1]
}
// AllRequests returns all logged requests
func (tb *TestBackend) AllRequests() []RequestLogEntry {
tb.mu.Lock()
defer tb.mu.Unlock()
result := make([]RequestLogEntry, len(tb.requestLog))
copy(result, tb.requestLog)
return result
}
// ============== HTTP Client Helpers ==============
// HTTPClient is a configured HTTP client for testing
type HTTPClient struct {
client *http.Client
baseURL string
}
// NewHTTPClient creates a new test HTTP client
func NewHTTPClient(baseURL string) *HTTPClient {
return &HTTPClient{
client: &http.Client{
Timeout: 10 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // Don't follow redirects
},
},
baseURL: baseURL,
}
}
// Get performs a GET request
func (c *HTTPClient) Get(path string, headers map[string]string) (*http.Response, error) {
return c.Do("GET", path, nil, headers)
}
// Post performs a POST request
func (c *HTTPClient) Post(path string, body []byte, headers map[string]string) (*http.Response, error) {
return c.Do("POST", path, body, headers)
}
// Do performs an HTTP request
func (c *HTTPClient) Do(method, path string, body []byte, headers map[string]string) (*http.Response, error) {
var bodyReader io.Reader
if body != nil {
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequest(method, c.baseURL+path, bodyReader)
if err != nil {
return nil, err
}
for k, v := range headers {
req.Header.Set(k, v)
}
return c.client.Do(req)
}
// GetJSON performs GET and decodes JSON response
func (c *HTTPClient) GetJSON(path string, result interface{}) (*http.Response, error) {
resp, err := c.Get(path, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
return resp, err
}
return resp, nil
}
// ============== File System Helpers ==============
// CreateTempDir creates a temporary directory for static files
func CreateTempDir(t *testing.T) string {
t.Helper()
dir, err := os.MkdirTemp("", "konduktor-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
t.Cleanup(func() { os.RemoveAll(dir) })
return dir
}
// CreateTempFile creates a temporary file with given content
func CreateTempFile(t *testing.T, dir, name, content string) string {
t.Helper()
path := filepath.Join(dir, name)
// Create parent directories if needed
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
t.Fatalf("Failed to create directories: %v", err)
}
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
t.Fatalf("Failed to write file: %v", err)
}
return path
}
// ============== Assertion Helpers ==============
// AssertStatus checks if response has expected status code
func AssertStatus(t *testing.T, resp *http.Response, expected int) {
t.Helper()
if resp.StatusCode != expected {
body, _ := io.ReadAll(resp.Body)
t.Errorf("Expected status %d, got %d. Body: %s", expected, resp.StatusCode, string(body))
}
}
// AssertHeader checks if response has expected header value
func AssertHeader(t *testing.T, resp *http.Response, header, expected string) {
t.Helper()
actual := resp.Header.Get(header)
if actual != expected {
t.Errorf("Expected header %s=%q, got %q", header, expected, actual)
}
}
// AssertHeaderContains checks if header contains substring
func AssertHeaderContains(t *testing.T, resp *http.Response, header, substring string) {
t.Helper()
actual := resp.Header.Get(header)
if actual == "" || !contains(actual, substring) {
t.Errorf("Expected header %s to contain %q, got %q", header, substring, actual)
}
}
// AssertJSONField checks if JSON response has expected field value
func AssertJSONField(t *testing.T, body []byte, field string, expected interface{}) {
t.Helper()
var data map[string]interface{}
if err := json.Unmarshal(body, &data); err != nil {
t.Fatalf("Failed to parse JSON: %v", err)
}
actual, ok := data[field]
if !ok {
t.Errorf("Field %q not found in JSON", field)
return
}
if actual != expected {
t.Errorf("Expected %s=%v, got %v", field, expected, actual)
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsAt(s, substr, 0))
}
func containsAt(s, substr string, start int) bool {
for i := start; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// ReadBody reads and returns response body
func ReadBody(t *testing.T, resp *http.Response) []byte {
t.Helper()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read body: %v", err)
}
return body
}

View File

@ -1,562 +0,0 @@
package integration
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/konduktor/konduktor/internal/extension"
"github.com/konduktor/konduktor/internal/logging"
)
// createTestLogger creates a logger for tests
func createTestLogger(t *testing.T) *logging.Logger {
t.Helper()
logger, err := logging.New(logging.Config{Level: "DEBUG"})
if err != nil {
t.Fatalf("Failed to create logger: %v", err)
}
return logger
}
// ============== Basic Reverse Proxy Tests ==============
func TestReverseProxy_BasicGET(t *testing.T) {
// Start backend server
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"message": "Hello from backend",
"path": r.URL.Path,
"method": r.Method,
})
})
defer backend.Close()
// Create routing extension with proxy to backend
logger := createTestLogger(t)
routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
if err != nil {
t.Fatalf("Failed to create routing extension: %v", err)
}
// Start Konduktor server
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
// Make request through Konduktor
client := NewHTTPClient(server.URL)
resp, err := client.Get("/api/test", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
// Verify response
AssertStatus(t, resp, http.StatusOK)
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if result["message"] != "Hello from backend" {
t.Errorf("Unexpected message: %v", result["message"])
}
if result["path"] != "/api/test" {
t.Errorf("Expected path /api/test, got %v", result["path"])
}
// Verify backend received request
if backend.RequestCount() != 1 {
t.Errorf("Expected 1 backend request, got %d", backend.RequestCount())
}
}
func TestReverseProxy_POST(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
var body map[string]interface{}
json.NewDecoder(r.Body).Decode(&body)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"received": body,
"method": r.Method,
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
body := []byte(`{"name":"test","value":123}`)
resp, err := client.Post("/api/data", body, map[string]string{
"Content-Type": "application/json",
})
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
if result["method"] != "POST" {
t.Errorf("Expected method POST, got %v", result["method"])
}
received := result["received"].(map[string]interface{})
if received["name"] != "test" {
t.Errorf("Expected name 'test', got %v", received["name"])
}
}
// ============== Exact Match Routes ==============
func TestReverseProxy_ExactMatchRoute(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"endpoint": "version",
"path": r.URL.Path,
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
// Exact match - should use backend URL as-is
"=/api/version": map[string]interface{}{
"proxy_pass": backend.URL() + "/releases/latest",
},
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Test exact match route
resp, err := client.Get("/api/version", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
lastReq := backend.LastRequest()
if lastReq == nil {
t.Fatal("No request received by backend")
}
// For exact match, the target path should be used as-is (IgnoreRequestPath=true)
if lastReq.Path != "/releases/latest" {
t.Errorf("Expected backend path /releases/latest, got %s", lastReq.Path)
}
}
// ============== Regex Routes with Parameters ==============
func TestReverseProxy_RegexRouteWithParams(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"path": r.URL.Path,
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
// Regex with named group
"~^/api/users/(?P<id>\\d+)$": map[string]interface{}{
"proxy_pass": backend.URL() + "/v2/users/{id}",
},
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Test regex route with parameter
resp, err := client.Get("/api/users/42", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
lastReq := backend.LastRequest()
if lastReq == nil {
t.Fatal("No request received by backend")
}
// Parameter {id} should be substituted
if lastReq.Path != "/v2/users/42" {
t.Errorf("Expected backend path /v2/users/42, got %s", lastReq.Path)
}
}
// ============== Header Forwarding ==============
func TestReverseProxy_HeaderForwarding(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"x-forwarded-for": r.Header.Get("X-Forwarded-For"),
"x-real-ip": r.Header.Get("X-Real-IP"),
"x-custom": r.Header.Get("X-Custom"),
"x-forwarded-host": r.Header.Get("X-Forwarded-Host"),
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
"headers": []interface{}{
"X-Forwarded-For: $remote_addr",
"X-Real-IP: $remote_addr",
},
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
resp, err := client.Get("/test", map[string]string{
"X-Custom": "custom-value",
})
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
var result map[string]string
json.NewDecoder(resp.Body).Decode(&result)
// X-Custom should be forwarded
if result["x-custom"] != "custom-value" {
t.Errorf("Expected X-Custom header to be forwarded, got %v", result["x-custom"])
}
// X-Forwarded-For should be set (will contain 127.0.0.1)
if result["x-forwarded-for"] == "" {
t.Error("Expected X-Forwarded-For header to be set")
}
}
// ============== Query String ==============
func TestReverseProxy_QueryStringPreservation(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"query": r.URL.RawQuery,
"foo": r.URL.Query().Get("foo"),
"bar": r.URL.Query().Get("bar"),
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
resp, err := client.Get("/search?foo=hello&bar=world", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
var result map[string]string
json.NewDecoder(resp.Body).Decode(&result)
if result["foo"] != "hello" {
t.Errorf("Expected foo=hello, got %v", result["foo"])
}
if result["bar"] != "world" {
t.Errorf("Expected bar=world, got %v", result["bar"])
}
}
// ============== Error Handling ==============
func TestReverseProxy_BackendUnavailable(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
// Non-existent backend
"proxy_pass": "http://127.0.0.1:59999",
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
resp, err := client.Get("/test", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
// Should return 502 Bad Gateway
AssertStatus(t, resp, http.StatusBadGateway)
}
func TestReverseProxy_BackendTimeout(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
// Simulate slow backend
time.Sleep(3 * time.Second)
w.Write([]byte("OK"))
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
"timeout": 0.5, // 500ms timeout
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
resp, err := client.Get("/slow", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
// Should return 504 Gateway Timeout
AssertStatus(t, resp, http.StatusGatewayTimeout)
}
// ============== HTTP Methods ==============
func TestReverseProxy_AllMethods(t *testing.T) {
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"method": r.Method,
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
for _, method := range methods {
t.Run(method, func(t *testing.T) {
resp, err := client.Do(method, "/resource", nil, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
if method != "HEAD" {
var result map[string]string
json.NewDecoder(resp.Body).Decode(&result)
if result["method"] != method {
t.Errorf("Expected method %s, got %v", method, result["method"])
}
}
})
}
}
// ============== Large Bodies ==============
func TestReverseProxy_LargeRequestBody(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
json.NewEncoder(w).Encode(map[string]int{
"received": len(body),
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// 1MB body
largeBody := []byte(strings.Repeat("x", 1024*1024))
resp, err := client.Post("/upload", largeBody, map[string]string{
"Content-Type": "application/octet-stream",
})
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
}
// ============== Concurrent Requests ==============
func TestReverseProxy_ConcurrentRequests(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
// Small delay to simulate work
time.Sleep(10 * time.Millisecond)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
const numRequests = 50
results := make(chan error, numRequests)
for i := 0; i < numRequests; i++ {
go func(n int) {
client := NewHTTPClient(server.URL)
resp, err := client.Get(fmt.Sprintf("/concurrent/%d", n), nil)
if err != nil {
results <- err
return
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
results <- fmt.Errorf("unexpected status: %d", resp.StatusCode)
return
}
results <- nil
}(i)
}
// Collect results
var errors []error
for i := 0; i < numRequests; i++ {
if err := <-results; err != nil {
errors = append(errors, err)
}
}
if len(errors) > 0 {
t.Errorf("Got %d errors in concurrent requests: %v", len(errors), errors[:min(5, len(errors))])
}
// Verify all requests reached backend
if backend.RequestCount() != numRequests {
t.Errorf("Expected %d backend requests, got %d", numRequests, backend.RequestCount())
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@ -1,494 +0,0 @@
package integration
import (
"encoding/json"
"net/http"
"testing"
"github.com/konduktor/konduktor/internal/extension"
)
// ============== Route Priority Tests ==============
func TestRouting_ExactMatchPriority(t *testing.T) {
// Exact match should have highest priority
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"path": r.URL.Path,
"source": "default",
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
// Exact match - highest priority
"=/api/status": map[string]interface{}{
"return": "200 exact-match",
"content_type": "text/plain",
},
// Regex that also matches /api/status
"~^/api/.*": map[string]interface{}{
"proxy_pass": backend.URL(),
},
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
if err != nil {
t.Fatalf("Failed to create routing extension: %v", err)
}
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Test exact match route - should return static response
resp, err := client.Get("/api/status", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
body := ReadBody(t, resp)
if string(body) != "exact-match" {
t.Errorf("Expected 'exact-match', got %q", string(body))
}
// Regex route should be used for other /api/* paths
resp2, err := client.Get("/api/other", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp2.Body.Close()
AssertStatus(t, resp2, http.StatusOK)
// Verify it went to backend
if backend.RequestCount() != 1 {
t.Errorf("Expected 1 backend request, got %d", backend.RequestCount())
}
}
// ============== Case Sensitivity Tests ==============
func TestRouting_CaseSensitiveRegex(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
// Case-sensitive regex (~)
"~^/API/test$": map[string]interface{}{
"return": "200 case-sensitive",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"return": "200 default",
"content_type": "text/plain",
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Exact case match should work
resp, err := client.Get("/API/test", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
body := ReadBody(t, resp)
if string(body) != "case-sensitive" {
t.Errorf("Expected 'case-sensitive' for /API/test, got %q", string(body))
}
// Different case should NOT match
resp2, err := client.Get("/api/test", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp2.Body.Close()
body2 := ReadBody(t, resp2)
if string(body2) != "default" {
t.Errorf("Expected 'default' for /api/test (case mismatch), got %q", string(body2))
}
}
func TestRouting_CaseInsensitiveRegex(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
// Case-insensitive regex (~*)
"~*^/api/test$": map[string]interface{}{
"return": "200 case-insensitive",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"return": "200 default",
"content_type": "text/plain",
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
testCases := []struct {
path string
expected string
}{
{"/api/test", "case-insensitive"},
{"/API/test", "case-insensitive"},
{"/Api/Test", "case-insensitive"},
{"/API/TEST", "case-insensitive"},
{"/api/other", "default"},
}
for _, tc := range testCases {
t.Run(tc.path, func(t *testing.T) {
resp, err := client.Get(tc.path, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
body := ReadBody(t, resp)
if string(body) != tc.expected {
t.Errorf("Expected %q for %s, got %q", tc.expected, tc.path, string(body))
}
})
}
}
// ============== Default Route Tests ==============
func TestRouting_DefaultRoute(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"handler": "default",
"path": r.URL.Path,
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"=/specific": map[string]interface{}{
"return": "200 specific",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Non-matching paths should go to default
paths := []string{"/", "/random", "/path/to/resource", "/api/v1/users"}
for _, path := range paths {
t.Run(path, func(t *testing.T) {
resp, err := client.Get(path, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
var result map[string]string
json.NewDecoder(resp.Body).Decode(&result)
if result["handler"] != "default" {
t.Errorf("Expected default handler, got %v", result["handler"])
}
})
}
}
// ============== Return Directive Tests ==============
func TestRouting_ReturnDirective(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"=/health": map[string]interface{}{
"return": "200 OK",
"content_type": "text/plain",
},
"=/status": map[string]interface{}{
"return": "200 {\"status\": \"healthy\"}",
"content_type": "application/json",
},
"=/forbidden": map[string]interface{}{
"return": "404 Not Found",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"return": "200 default",
"content_type": "text/plain",
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
testCases := []struct {
path string
expectedStatus int
expectedBody string
contentType string
}{
{"/health", 200, "OK", "text/plain"},
{"/status", 200, `{"status": "healthy"}`, "application/json"},
{"/forbidden", 404, "Not Found", "text/plain"},
}
for _, tc := range testCases {
t.Run(tc.path, func(t *testing.T) {
resp, err := client.Get(tc.path, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, tc.expectedStatus)
AssertHeaderContains(t, resp, "Content-Type", tc.contentType)
body := ReadBody(t, resp)
if string(body) != tc.expectedBody {
t.Errorf("Expected body %q, got %q", tc.expectedBody, string(body))
}
})
}
}
// ============== Multiple Regex Routes Tests ==============
func TestRouting_MultipleRegexRoutes(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"~^/api/v1/.*": map[string]interface{}{
"return": "200 v1",
"content_type": "text/plain",
},
"~^/api/v2/.*": map[string]interface{}{
"return": "200 v2",
"content_type": "text/plain",
},
"~^/api/.*": map[string]interface{}{
"return": "200 api-generic",
"content_type": "text/plain",
},
"__default__": map[string]interface{}{
"return": "200 default",
"content_type": "text/plain",
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
testCases := []struct {
path string
expected string
}{
{"/api/v1/users", "v1"},
{"/api/v2/users", "v2"},
{"/api/v3/users", "api-generic"},
{"/other", "default"},
}
for _, tc := range testCases {
t.Run(tc.path, func(t *testing.T) {
resp, err := client.Get(tc.path, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
body := ReadBody(t, resp)
if string(body) != tc.expected {
t.Errorf("Expected %q for %s, got %q", tc.expected, tc.path, string(body))
}
})
}
}
// ============== Regex with Named Groups ==============
func TestRouting_RegexNamedGroups(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"path": r.URL.Path,
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"~^/users/(?P<userId>\\d+)/posts/(?P<postId>\\d+)$": map[string]interface{}{
"proxy_pass": backend.URL() + "/api/v2/users/{userId}/posts/{postId}",
},
"~^/items/(?P<category>[a-z]+)/(?P<id>\\d+)$": map[string]interface{}{
"proxy_pass": backend.URL() + "/catalog/{category}/item/{id}",
},
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
testCases := []struct {
requestPath string
expectedPath string
}{
{"/users/123/posts/456", "/api/v2/users/123/posts/456"},
{"/items/electronics/789", "/catalog/electronics/item/789"},
}
for _, tc := range testCases {
t.Run(tc.requestPath, func(t *testing.T) {
resp, err := client.Get(tc.requestPath, nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusOK)
lastReq := backend.LastRequest()
if lastReq == nil {
t.Fatal("No request received by backend")
}
if lastReq.Path != tc.expectedPath {
t.Errorf("Expected backend path %s, got %s", tc.expectedPath, lastReq.Path)
}
})
}
}
// ============== No Matching Route Tests ==============
func TestRouting_NoMatchingRoute(t *testing.T) {
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"=/specific": map[string]interface{}{
"return": "200 specific",
"content_type": "text/plain",
},
// No default route
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
// Request to non-matching path should return 404
resp, err := client.Get("/other", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
AssertStatus(t, resp, http.StatusNotFound)
}
// ============== Headers in Return Tests ==============
func TestRouting_CustomHeaders(t *testing.T) {
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"x-custom-header": r.Header.Get("X-Custom-Header"),
"x-api-version": r.Header.Get("X-API-Version"),
})
})
defer backend.Close()
logger := createTestLogger(t)
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
"regex_locations": map[string]interface{}{
"__default__": map[string]interface{}{
"proxy_pass": backend.URL(),
"headers": []interface{}{
"X-Custom-Header: custom-value",
"X-API-Version: v1",
},
},
},
}, logger)
server := StartTestServer(t, &ServerConfig{
Extensions: []extension.Extension{routingExt},
})
defer server.Close()
client := NewHTTPClient(server.URL)
resp, err := client.Get("/test", nil)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
var result map[string]string
json.NewDecoder(resp.Body).Decode(&result)
if result["x-custom-header"] != "custom-value" {
t.Errorf("Expected X-Custom-Header=custom-value, got %v", result["x-custom-header"])
}
if result["x-api-version"] != "v1" {
t.Errorf("Expected X-API-Version=v1, got %v", result["x-api-version"])
}
}

6
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
[[package]]
name = "a2wsgi"
@ -1720,5 +1720,5 @@ wsgi = ["a2wsgi"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.12"
content-hash = "653d7b992e2bb133abde2e8b1c44265e948ed90487ab3f2670429510a8aa0683"
python-versions = ">=3.12, <=3.13.7"
content-hash = "411b746f1a577ed635af9fd3e01daf1fa03950d27ef23888fc7cdd0b99762404"

View File

@ -3,11 +3,11 @@ name = "pyserve"
version = "0.9.10"
description = "Python Application Orchestrator & HTTP Server - unified gateway for multiple Python web apps"
authors = [
{name = "Илья Глазунов",email = "i.glazunov@sapiens.solutions"}
{name = "Илья Глазунов",email = "lead@pyserve.org"}
]
license = {text = "MIT"}
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.12, <=3.13.7"
dependencies = [
"starlette (>=0.47.3,<0.48.0)",
"uvicorn[standard] (>=0.35.0,<0.36.0)",

View File

@ -24,24 +24,24 @@ cdef class FastMountedPath:
def __cinit__(self):
self._path = ""
self._path_with_slash = "/"
self._path_len = 0
self._is_root = 1
self._path_len = <Py_ssize_t>0
self._is_root = <bint>True
self.name = ""
self.strip_path = 1
self.strip_path = <bint>True
def __init__(self, str path, str name="", bint strip_path=True):
def __init__(self, str path, str name="", bint strip_path=<bint>True):
cdef Py_ssize_t path_len
path_len = len(path)
path_len = <Py_ssize_t>len(path)
if path_len > 1 and path[path_len - 1] == '/':
path = path[:path_len - 1]
self._path = path
self._path_len = len(path)
self._is_root = 1 if (path == "" or path == "/") else 0
self._path_with_slash = path + "/" if self._is_root == 0 else "/"
self._path_len = <Py_ssize_t>len(path)
self._is_root = <bint>(path == "" or path == "/")
self._path_with_slash = path + "/" if not self._is_root else "/"
self.name = name if name else path
self.strip_path = 1 if strip_path else 0
self.strip_path = <bint>strip_path
@property
def path(self) -> str:
@ -51,20 +51,20 @@ cdef class FastMountedPath:
cdef Py_ssize_t req_len
if self._is_root:
return 1
return <bint>True
req_len = len(request_path)
req_len = <Py_ssize_t>len(request_path)
if req_len < self._path_len:
return 0
return <bint>False
if req_len == self._path_len:
return 1 if request_path == self._path else 0
return <bint>(request_path == self._path)
if request_path[self._path_len] == '/':
return 1 if request_path[:self._path_len] == self._path else 0
return <bint>(request_path[:self._path_len] == self._path)
return 0
return <bint>False
cpdef str get_modified_path(self, str original_path):
cdef str new_path
@ -108,7 +108,7 @@ cdef class FastMountManager:
self._mounts = sorted(self._mounts, key=_get_path_len_neg, reverse=False)
self._mount_count = len(self._mounts)
cpdef FastMountedPath get_mount(self, str request_path):
cpdef object get_mount(self, str request_path):
cdef:
int i
FastMountedPath mount
@ -126,7 +126,7 @@ cdef class FastMountManager:
Py_ssize_t path_len
FastMountedPath mount
path_len = len(path)
path_len = <Py_ssize_t>len(path)
if path_len > 1 and path[path_len - 1] == '/':
path = path[:path_len - 1]
@ -135,9 +135,9 @@ cdef class FastMountManager:
if mount._path == path:
del self._mounts[i]
self._mount_count -= 1
return 1
return <bint>True
return 0
return <bint>False
@property
def mounts(self) -> list:
@ -164,27 +164,27 @@ cdef class FastMountManager:
cpdef bint path_matches_prefix(str request_path, str mount_path):
cdef:
Py_ssize_t mount_len = len(mount_path)
Py_ssize_t req_len = len(request_path)
Py_ssize_t mount_len = <Py_ssize_t>len(mount_path)
Py_ssize_t req_len = <Py_ssize_t>len(request_path)
if mount_len == 0 or mount_path == "/":
return 1
return <bint>True
if req_len < mount_len:
return 0
return <bint>False
if req_len == mount_len:
return 1 if request_path == mount_path else 0
return <bint>(request_path == mount_path)
if request_path[mount_len] == '/':
return 1 if request_path[:mount_len] == mount_path else 0
return <bint>(request_path[:mount_len] == mount_path)
return 0
return <bint>False
cpdef str strip_path_prefix(str original_path, str mount_path):
cdef:
Py_ssize_t mount_len = len(mount_path)
Py_ssize_t mount_len = <Py_ssize_t>len(mount_path)
str result
if mount_len == 0 or mount_path == "/":
@ -198,11 +198,11 @@ cpdef str strip_path_prefix(str original_path, str mount_path):
return result
cpdef tuple match_and_modify_path(str request_path, str mount_path, bint strip_path=True):
cpdef tuple match_and_modify_path(str request_path, str mount_path, bint strip_path=<bint>True):
cdef:
Py_ssize_t mount_len = len(mount_path)
Py_ssize_t req_len = len(request_path)
bint is_root = 1 if (mount_len == 0 or mount_path == "/") else 0
Py_ssize_t mount_len = <Py_ssize_t>len(mount_path)
Py_ssize_t req_len = <Py_ssize_t>len(request_path)
bint is_root = <bint>(mount_len == 0 or mount_path == "/")
str modified
if is_root:

486
pyserve/_routing.pyx Normal file
View File

@ -0,0 +1,486 @@
# cython: language_level=3
# cython: boundscheck=False
# cython: wraparound=False
# cython: cdivision=True
from libc.stdint cimport uint32_t
from libc.stddef cimport size_t
from cpython.bytes cimport PyBytes_AsString, PyBytes_GET_SIZE
cimport cython
from ._routing_pcre2 cimport *
from typing import Optional
# Type aliases for cleaner NULL casts
ctypedef pcre2_compile_context* compile_ctx_ptr
ctypedef pcre2_match_context* match_ctx_ptr
ctypedef pcre2_general_context* general_ctx_ptr
# Buffer size for error messages
DEF ERROR_BUFFER_SIZE = 256
# Maximum capture groups we support
DEF MAX_CAPTURE_GROUPS = 32
cdef class PCRE2Pattern:
cdef:
pcre2_code* _code
pcre2_match_data* _match_data
bint _jit_available
str _pattern_str
uint32_t _capture_count
dict _name_to_index # Named capture groups
list _index_to_name # Index to name mapping
def __cinit__(self):
self._code = NULL
self._match_data = NULL
self._jit_available = <bint>False
self._capture_count = 0
self._name_to_index = {}
self._index_to_name = []
def __dealloc__(self):
if self._match_data is not NULL:
pcre2_match_data_free(self._match_data)
self._match_data = NULL
if self._code is not NULL:
pcre2_code_free(self._code)
self._code = NULL
@staticmethod
cdef PCRE2Pattern _create(str pattern, bint case_insensitive=<bint>False, bint use_jit=<bint>True):
cdef:
PCRE2Pattern self = PCRE2Pattern.__new__(PCRE2Pattern)
bytes pattern_bytes
const char* pattern_ptr
Py_ssize_t pattern_len
uint32_t options = 0
int errorcode = 0
PCRE2_SIZE erroroffset = 0
int jit_result
uint32_t capture_count = 0
self._pattern_str = pattern
self._name_to_index = {}
self._index_to_name = []
pattern_bytes = pattern.encode('utf-8')
pattern_ptr = PyBytes_AsString(pattern_bytes)
pattern_len = PyBytes_GET_SIZE(pattern_bytes)
options = PCRE2_UTF | PCRE2_UCP
if case_insensitive:
options |= PCRE2_CASELESS
self._code = pcre2_compile(
<PCRE2_SPTR>pattern_ptr,
<PCRE2_SIZE>pattern_len,
options,
&errorcode,
&erroroffset,
<compile_ctx_ptr>NULL
)
if self._code is NULL:
error_msg = PCRE2Pattern._get_error_message(errorcode)
raise ValueError(f"PCRE2 compile error at offset {erroroffset}: {error_msg}")
if use_jit:
jit_result = pcre2_jit_compile(self._code, PCRE2_JIT_COMPLETE)
self._jit_available = <bint>(jit_result == 0)
pcre2_pattern_info(self._code, PCRE2_INFO_CAPTURECOUNT, <void*>&capture_count)
self._capture_count = capture_count
self._match_data = pcre2_match_data_create_from_pattern(self._code, <general_ctx_ptr>NULL)
if self._match_data is NULL:
pcre2_code_free(self._code)
self._code = NULL
raise MemoryError("Failed to create match data")
self._extract_named_groups()
return self
cdef void _extract_named_groups(self):
cdef:
uint32_t namecount = 0
uint32_t nameentrysize = 0
PCRE2_SPTR nametable
uint32_t i
int group_num
bytes name_bytes
str name
pcre2_pattern_info(self._code, PCRE2_INFO_NAMECOUNT, <void*>&namecount)
if namecount == 0:
return # void return
pcre2_pattern_info(self._code, PCRE2_INFO_NAMEENTRYSIZE, <void*>&nameentrysize)
pcre2_pattern_info(self._code, PCRE2_INFO_NAMETABLE, <void*>&nametable)
self._index_to_name = [None] * (self._capture_count + 1)
for i in range(namecount):
group_num = (<int>nametable[0] << 8) | <int>nametable[1]
name_bytes = <bytes>(nametable + 2)
name = name_bytes.decode('utf-8')
self._name_to_index[name] = group_num
if <uint32_t>group_num <= self._capture_count:
self._index_to_name[<Py_ssize_t>group_num] = name
nametable += nameentrysize
@staticmethod
cdef str _get_error_message(int errorcode):
cdef:
PCRE2_UCHAR buffer[ERROR_BUFFER_SIZE]
int result
result = pcre2_get_error_message(errorcode, buffer, ERROR_BUFFER_SIZE)
if result < 0:
return f"Unknown error {errorcode}"
return (<bytes>buffer).decode('utf-8')
cpdef bint search(self, str subject):
"""
Search for pattern anywhere in subject.
Returns True if found, False otherwise.
"""
cdef:
bytes subject_bytes
const char* subject_ptr
Py_ssize_t subject_len
int result
if self._code is NULL:
return <bint>False
subject_bytes = subject.encode('utf-8')
subject_ptr = PyBytes_AsString(subject_bytes)
subject_len = PyBytes_GET_SIZE(subject_bytes)
if self._jit_available:
result = pcre2_jit_match(
self._code,
<PCRE2_SPTR>subject_ptr,
<PCRE2_SIZE>subject_len,
0, # start offset
0, # options
self._match_data,
<match_ctx_ptr>NULL
)
else:
result = pcre2_match(
self._code,
<PCRE2_SPTR>subject_ptr,
<PCRE2_SIZE>subject_len,
0,
0,
self._match_data,
<match_ctx_ptr>NULL
)
return <bint>(result >= 0)
cpdef dict groupdict(self, str subject):
"""
Match pattern and return dict of named groups.
Returns empty dict if no match or no named groups.
"""
cdef:
bytes subject_bytes
const char* subject_ptr
Py_ssize_t subject_len
int result
PCRE2_SIZE* ovector
dict groups = {}
str name
int index
PCRE2_SIZE start, end
if self._code is NULL or not self._name_to_index:
return groups
subject_bytes = subject.encode('utf-8')
subject_ptr = PyBytes_AsString(subject_bytes)
subject_len = PyBytes_GET_SIZE(subject_bytes)
if self._jit_available:
result = pcre2_jit_match(
self._code,
<PCRE2_SPTR>subject_ptr,
<PCRE2_SIZE>subject_len,
0, 0,
self._match_data,
<match_ctx_ptr>NULL
)
else:
result = pcre2_match(
self._code,
<PCRE2_SPTR>subject_ptr,
<PCRE2_SIZE>subject_len,
0, 0,
self._match_data,
<match_ctx_ptr>NULL
)
if result < 0:
return groups
ovector = pcre2_get_ovector_pointer(self._match_data)
for name, index in self._name_to_index.items():
start = ovector[<Py_ssize_t>(2 * index)]
end = ovector[<Py_ssize_t>(2 * index + 1)]
if start != PCRE2_UNSET and end != PCRE2_UNSET:
groups[name] = subject_bytes[start:end].decode('utf-8')
else:
groups[name] = None
return groups
cpdef tuple search_with_groups(self, str subject):
cdef:
bytes subject_bytes
const char* subject_ptr
Py_ssize_t subject_len
int result
PCRE2_SIZE* ovector
dict groups = {}
str name
int index
PCRE2_SIZE start, end
if self._code is NULL:
return (False, {})
subject_bytes = subject.encode('utf-8')
subject_ptr = PyBytes_AsString(subject_bytes)
subject_len = PyBytes_GET_SIZE(subject_bytes)
if self._jit_available:
result = pcre2_jit_match(
self._code,
<PCRE2_SPTR>subject_ptr,
<PCRE2_SIZE>subject_len,
0, 0,
self._match_data,
<match_ctx_ptr>NULL
)
else:
result = pcre2_match(
self._code,
<PCRE2_SPTR>subject_ptr,
<PCRE2_SIZE>subject_len,
0, 0,
self._match_data,
<match_ctx_ptr>NULL
)
if result < 0:
return (False, {})
if self._name_to_index:
ovector = pcre2_get_ovector_pointer(self._match_data)
for name, index in self._name_to_index.items():
start = ovector[<Py_ssize_t>(2 * index)]
end = ovector[<Py_ssize_t>(2 * index + 1)]
if start != PCRE2_UNSET and end != PCRE2_UNSET:
groups[name] = subject_bytes[start:end].decode('utf-8')
else:
groups[name] = None
return (True, groups)
@property
def pattern(self) -> str:
return self._pattern_str
@property
def jit_compiled(self) -> bool:
return <bint>self._jit_available
@property
def capture_count(self) -> int:
return self._capture_count
cdef class FastRouteMatch:
cdef:
public dict config
public dict params
def __cinit__(self):
self.config = {}
self.params = {}
def __init__(self, dict config, params=None):
self.config = config
self.params = params if params is not None else {}
cdef class FastRouter:
"""
High-performance router with PCRE2 JIT-compiled patterns.
Matching order (nginx-like):
1. Exact routes (prefix "=") - O(1) dict lookup
2. Regex routes (prefix "~" or "~*") - PCRE2 JIT matching
3. Default route (fallback)
"""
cdef:
dict _exact_routes
list _regex_routes
dict _default_route
bint _has_default
int _regex_count
def __cinit__(self):
self._exact_routes = {}
self._regex_routes = []
self._default_route = {}
self._has_default = <bint>False
self._regex_count = 0
def __init__(self):
self._exact_routes = {}
self._regex_routes = []
self._default_route = {}
self._has_default = <bint>False
self._regex_count = 0
def add_route(self, str pattern, dict config):
cdef:
str exact_path
str regex_pattern
bint case_insensitive
PCRE2Pattern compiled_pattern
if pattern.startswith("="):
exact_path = pattern[1:]
self._exact_routes[exact_path] = config
elif pattern == "__default__":
self._default_route = config
self._has_default = <bint>True
elif pattern.startswith("~"):
case_insensitive = <bint>pattern.startswith("~*")
regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
try:
compiled_pattern = PCRE2Pattern._create(regex_pattern, case_insensitive)
self._regex_routes.append((compiled_pattern, config))
self._regex_count = len(self._regex_routes)
except (ValueError, MemoryError):
pass # Skip invalid patterns
cpdef object match(self, str path):
cdef:
dict config
dict params
int i
PCRE2Pattern pattern
tuple route_entry
bint matched
if path in self._exact_routes:
config = self._exact_routes[path]
return FastRouteMatch(config, {})
for i in range(self._regex_count):
route_entry = <tuple>self._regex_routes[i]
pattern = <PCRE2Pattern>route_entry[0]
config = <dict>route_entry[1]
matched, params = pattern.search_with_groups(path)
if matched:
return FastRouteMatch(config, params)
if self._has_default:
return FastRouteMatch(self._default_route, {})
return None
@property
def exact_routes(self) -> dict:
return self._exact_routes
@property
def routes(self) -> dict:
"""Return regex routes as dict (pattern_str -> config)."""
cdef:
dict result = {}
PCRE2Pattern pattern
for pattern, config in self._regex_routes:
result[pattern.pattern] = config
return result
@property
def default_route(self) -> Optional[dict]:
return self._default_route if self._has_default else None
cpdef list list_routes(self):
cdef:
list result = []
str path_str
dict config
PCRE2Pattern pattern
for path_str, config in self._exact_routes.items():
result.append({
"type": "exact",
"pattern": f"={path_str}",
"config": config,
})
for pattern, config in self._regex_routes:
result.append({
"type": "regex",
"pattern": pattern.pattern,
"jit_compiled": pattern.jit_compiled,
"config": config,
})
if self._has_default:
result.append({
"type": "default",
"pattern": "__default__",
"config": self._default_route,
})
return result
def compile_pattern(str pattern, bint case_insensitive=<bint>False) -> PCRE2Pattern:
"""
Compile a PCRE2 pattern with JIT support.
Args:
pattern: Regular expression pattern
case_insensitive: Whether to match case-insensitively
Returns:
Compiled PCRE2Pattern object
"""
return PCRE2Pattern._create(pattern, case_insensitive)
def fast_match(router: FastRouter, str path):
"""
Convenience function for matching a path.
Args:
router: FastRouter instance
path: URL path to match
Returns:
FastRouteMatch or None
"""
return router.match(path)

208
pyserve/_routing_pcre2.pxd Normal file
View File

@ -0,0 +1,208 @@
# cython: language_level=3
from libc.stdint cimport uint8_t, uint32_t, int32_t
from libc.stddef cimport size_t
cdef extern from "pcre2.h":
pass
cdef extern from *:
ctypedef struct pcre2_code_8:
pass
ctypedef pcre2_code_8 pcre2_code
ctypedef struct pcre2_match_data_8:
pass
ctypedef pcre2_match_data_8 pcre2_match_data
ctypedef struct pcre2_compile_context_8:
pass
ctypedef pcre2_compile_context_8 pcre2_compile_context
ctypedef struct pcre2_match_context_8:
pass
ctypedef pcre2_match_context_8 pcre2_match_context
ctypedef struct pcre2_general_context_8:
pass
ctypedef pcre2_general_context_8 pcre2_general_context
ctypedef uint8_t PCRE2_UCHAR
ctypedef const uint8_t* PCRE2_SPTR
ctypedef size_t PCRE2_SIZE
uint32_t PCRE2_CASELESS
uint32_t PCRE2_MULTILINE
uint32_t PCRE2_DOTALL
uint32_t PCRE2_UTF
uint32_t PCRE2_UCP
uint32_t PCRE2_NO_UTF_CHECK
uint32_t PCRE2_ANCHORED
uint32_t PCRE2_ENDANCHORED
uint32_t PCRE2_JIT_COMPLETE
uint32_t PCRE2_JIT_PARTIAL_SOFT
uint32_t PCRE2_JIT_PARTIAL_HARD
int PCRE2_ERROR_NOMATCH
int PCRE2_ERROR_PARTIAL
int PCRE2_ERROR_JIT_STACKLIMIT
PCRE2_SIZE PCRE2_UNSET
PCRE2_SIZE PCRE2_ZERO_TERMINATED
pcre2_code* pcre2_compile_8(
PCRE2_SPTR pattern,
PCRE2_SIZE length,
uint32_t options,
int* errorcode,
PCRE2_SIZE* erroroffset,
pcre2_compile_context* ccontext
)
void pcre2_code_free_8(pcre2_code* code)
int pcre2_jit_compile_8(pcre2_code* code, uint32_t options)
pcre2_match_data* pcre2_match_data_create_from_pattern_8(
const pcre2_code* code,
pcre2_general_context* gcontext
)
pcre2_match_data* pcre2_match_data_create_8(
uint32_t ovecsize,
pcre2_general_context* gcontext
)
void pcre2_match_data_free_8(pcre2_match_data* match_data)
int pcre2_match_8(
const pcre2_code* code,
PCRE2_SPTR subject,
PCRE2_SIZE length,
PCRE2_SIZE startoffset,
uint32_t options,
pcre2_match_data* match_data,
pcre2_match_context* mcontext
)
int pcre2_jit_match_8(
const pcre2_code* code,
PCRE2_SPTR subject,
PCRE2_SIZE length,
PCRE2_SIZE startoffset,
uint32_t options,
pcre2_match_data* match_data,
pcre2_match_context* mcontext
)
PCRE2_SIZE* pcre2_get_ovector_pointer_8(pcre2_match_data* match_data)
uint32_t pcre2_get_ovector_count_8(pcre2_match_data* match_data)
int pcre2_pattern_info_8(
const pcre2_code* code,
uint32_t what,
void* where
)
uint32_t PCRE2_INFO_CAPTURECOUNT
uint32_t PCRE2_INFO_NAMECOUNT
uint32_t PCRE2_INFO_NAMETABLE
uint32_t PCRE2_INFO_NAMEENTRYSIZE
uint32_t PCRE2_INFO_JITSIZE
int pcre2_get_error_message_8(
int errorcode,
PCRE2_UCHAR* buffer,
PCRE2_SIZE bufflen
)
int pcre2_substring_copy_byname_8(
pcre2_match_data* match_data,
PCRE2_SPTR name,
PCRE2_UCHAR* buffer,
PCRE2_SIZE* bufflen
)
int pcre2_substring_copy_bynumber_8(
pcre2_match_data* match_data,
uint32_t number,
PCRE2_UCHAR* buffer,
PCRE2_SIZE* bufflen
)
int pcre2_substring_get_byname_8(
pcre2_match_data* match_data,
PCRE2_SPTR name,
PCRE2_UCHAR** bufferptr,
PCRE2_SIZE* bufflen
)
int pcre2_substring_get_bynumber_8(
pcre2_match_data* match_data,
uint32_t number,
PCRE2_UCHAR** bufferptr,
PCRE2_SIZE* bufflen
)
void pcre2_substring_free_8(PCRE2_UCHAR* buffer)
cdef inline pcre2_code* pcre2_compile(
PCRE2_SPTR pattern,
PCRE2_SIZE length,
uint32_t options,
int* errorcode,
PCRE2_SIZE* erroroffset,
pcre2_compile_context* ccontext
) noexcept:
return pcre2_compile_8(pattern, length, options, errorcode, erroroffset, ccontext)
cdef inline void pcre2_code_free(pcre2_code* code) noexcept:
pcre2_code_free_8(code)
cdef inline int pcre2_jit_compile(pcre2_code* code, uint32_t options) noexcept:
return pcre2_jit_compile_8(code, options)
cdef inline pcre2_match_data* pcre2_match_data_create_from_pattern(
const pcre2_code* code,
pcre2_general_context* gcontext
) noexcept:
return pcre2_match_data_create_from_pattern_8(code, gcontext)
cdef inline void pcre2_match_data_free(pcre2_match_data* match_data) noexcept:
pcre2_match_data_free_8(match_data)
cdef inline int pcre2_match(
const pcre2_code* code,
PCRE2_SPTR subject,
PCRE2_SIZE length,
PCRE2_SIZE startoffset,
uint32_t options,
pcre2_match_data* match_data,
pcre2_match_context* mcontext
) noexcept:
return pcre2_match_8(code, subject, length, startoffset, options, match_data, mcontext)
cdef inline int pcre2_jit_match(
const pcre2_code* code,
PCRE2_SPTR subject,
PCRE2_SIZE length,
PCRE2_SIZE startoffset,
uint32_t options,
pcre2_match_data* match_data,
pcre2_match_context* mcontext
) noexcept:
return pcre2_jit_match_8(code, subject, length, startoffset, options, match_data, mcontext)
cdef inline PCRE2_SIZE* pcre2_get_ovector_pointer(pcre2_match_data* match_data) noexcept:
return pcre2_get_ovector_pointer_8(match_data)
cdef inline uint32_t pcre2_get_ovector_count(pcre2_match_data* match_data) noexcept:
return pcre2_get_ovector_count_8(match_data)
cdef inline int pcre2_pattern_info(const pcre2_code* code, uint32_t what, void* where) noexcept:
return pcre2_pattern_info_8(code, what, where)
cdef inline int pcre2_get_error_message(int errorcode, PCRE2_UCHAR* buffer, PCRE2_SIZE bufflen) noexcept:
return pcre2_get_error_message_8(errorcode, buffer, bufflen)

129
pyserve/_routing_py.py Normal file
View File

@ -0,0 +1,129 @@
"""
Pure Python fallback for _routing when PCRE2/Cython is not available.
This module provides the same interface using the standard library `re` module.
It's slower than the Cython+PCRE2 implementation but works everywhere.
In future we may add pcre2.py library support for better performance in this module.
"""
import re
from typing import Any, Dict, List, Optional, Pattern, Tuple
class FastRouteMatch:
__slots__ = ("config", "params")
def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
self.config = config
self.params = params if params is not None else {}
class FastRouter:
"""
Router with regex pattern matching.
Matching order (nginx-like):
1. Exact routes (prefix "=") - O(1) dict lookup
2. Regex routes (prefix "~" or "~*") - linear scan
3. Default route (fallback)
"""
__slots__ = ("_exact_routes", "_regex_routes", "_default_route", "_has_default", "_regex_count")
def __init__(self) -> None:
self._exact_routes: Dict[str, Dict[str, Any]] = {}
self._regex_routes: List[Tuple[Pattern[str], Dict[str, Any]]] = []
self._default_route: Dict[str, Any] = {}
self._has_default: bool = False
self._regex_count: int = 0
def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
if pattern.startswith("="):
exact_path = pattern[1:]
self._exact_routes[exact_path] = config
return
if pattern == "__default__":
self._default_route = config
self._has_default = True
return
if pattern.startswith("~"):
case_insensitive = pattern.startswith("~*")
regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
flags = re.IGNORECASE if case_insensitive else 0
try:
compiled_pattern = re.compile(regex_pattern, flags)
self._regex_routes.append((compiled_pattern, config))
self._regex_count = len(self._regex_routes)
except re.error:
pass # Ignore invalid patterns
def match(self, path: str) -> Optional[FastRouteMatch]:
if path in self._exact_routes:
config = self._exact_routes[path]
return FastRouteMatch(config, {})
for pattern, config in self._regex_routes:
match_obj = pattern.search(path)
if match_obj is not None:
params = match_obj.groupdict()
return FastRouteMatch(config, params)
if self._has_default:
return FastRouteMatch(self._default_route, {})
return None
@property
def exact_routes(self) -> Dict[str, Dict[str, Any]]:
return self._exact_routes
@property
def routes(self) -> Dict[Pattern[str], Dict[str, Any]]:
return {p: c for p, c in self._regex_routes}
@property
def default_route(self) -> Optional[Dict[str, Any]]:
return self._default_route if self._has_default else None
def list_routes(self) -> List[Dict[str, Any]]:
result: List[Dict[str, Any]] = []
for path, config in self._exact_routes.items():
result.append({
"type": "exact",
"pattern": f"={path}",
"config": config,
})
for pattern, config in self._regex_routes:
result.append({
"type": "regex",
"pattern": pattern.pattern,
"config": config,
})
if self._has_default:
result.append({
"type": "default",
"pattern": "__default__",
"config": self._default_route,
})
return result
def fast_match(router: FastRouter, path: str) -> Optional[FastRouteMatch]:
"""
Convenience function for matching a path.
Args:
router: FastRouter instance
path: URL path to match
Returns:
FastRouteMatch or None
"""
return router.match(path)

View File

@ -1,7 +1,6 @@
import mimetypes
import re
from pathlib import Path
from typing import Any, Dict, Optional, Pattern
from typing import Any, Dict
from urllib.parse import urlparse
import httpx
@ -10,60 +9,19 @@ from starlette.responses import FileResponse, PlainTextResponse, Response
from .logging_utils import get_logger
try:
from pyserve._routing import FastRouteMatch, FastRouter, fast_match # type: ignore
CYTHON_ROUTING_AVAILABLE = True
except ImportError:
from pyserve._routing_py import FastRouteMatch, FastRouter, fast_match
CYTHON_ROUTING_AVAILABLE = False
logger = get_logger(__name__)
class RouteMatch:
def __init__(self, config: Dict[str, Any], params: Optional[Dict[str, str]] = None):
self.config = config
self.params = params or {}
class Router:
def __init__(self, static_dir: str = "./static"):
self.static_dir = Path(static_dir)
self.routes: Dict[Pattern, Dict[str, Any]] = {}
self.exact_routes: Dict[str, Dict[str, Any]] = {}
self.default_route: Optional[Dict[str, Any]] = None
def add_route(self, pattern: str, config: Dict[str, Any]) -> None:
if pattern.startswith("="):
exact_path = pattern[1:]
self.exact_routes[exact_path] = config
logger.debug(f"Added exact route: {exact_path}")
return
if pattern == "__default__":
self.default_route = config
logger.debug("Added default route")
return
if pattern.startswith("~"):
case_insensitive = pattern.startswith("~*")
regex_pattern = pattern[2:] if case_insensitive else pattern[1:]
flags = re.IGNORECASE if case_insensitive else 0
try:
compiled_pattern = re.compile(regex_pattern, flags)
self.routes[compiled_pattern] = config
logger.debug(f"Added regex route: {pattern}")
except re.error as e:
logger.error(f"Regex compilation error {pattern}: {e}")
def match(self, path: str) -> Optional[RouteMatch]:
if path in self.exact_routes:
return RouteMatch(self.exact_routes[path])
for pattern, config in self.routes.items():
match = pattern.search(path)
if match:
params = match.groupdict()
return RouteMatch(config, params)
if self.default_route:
return RouteMatch(self.default_route)
return None
# Aliases for backward compatibility
RouteMatch = FastRouteMatch
Router = FastRouter
class RequestHandler:

View File

@ -9,9 +9,86 @@ Or via make:
"""
import os
import subprocess
import sys
from pathlib import Path
def get_pcre2_config():
include_dirs = []
library_dirs = []
libraries = ["pcre2-8"]
try:
cflags = subprocess.check_output(
["pkg-config", "--cflags", "libpcre2-8"],
stderr=subprocess.DEVNULL
).decode().strip()
libs = subprocess.check_output(
["pkg-config", "--libs", "libpcre2-8"],
stderr=subprocess.DEVNULL
).decode().strip()
for flag in cflags.split():
if flag.startswith("-I"):
include_dirs.append(flag[2:])
for flag in libs.split():
if flag.startswith("-L"):
library_dirs.append(flag[2:])
elif flag.startswith("-l"):
lib = flag[2:]
if lib not in libraries:
libraries.append(lib)
return include_dirs, library_dirs, libraries
except (subprocess.CalledProcessError, FileNotFoundError):
pass
try:
cflags = subprocess.check_output(
["pcre2-config", "--cflags"],
stderr=subprocess.DEVNULL
).decode().strip()
libs = subprocess.check_output(
["pcre2-config", "--libs8"],
stderr=subprocess.DEVNULL
).decode().strip()
for flag in cflags.split():
if flag.startswith("-I"):
include_dirs.append(flag[2:])
for flag in libs.split():
if flag.startswith("-L"):
library_dirs.append(flag[2:])
elif flag.startswith("-l"):
lib = flag[2:]
if lib not in libraries:
libraries.append(lib)
return include_dirs, library_dirs, libraries
except (subprocess.CalledProcessError, FileNotFoundError):
pass
# Fallback: try common paths
common_paths = [
"/opt/homebrew", # macOS ARM
"/usr/local", # macOS Intel / Linux
"/usr", # Linux
]
for base in common_paths:
include_path = Path(base) / "include"
lib_path = Path(base) / "lib"
if (include_path / "pcre2.h").exists():
include_dirs.append(str(include_path))
library_dirs.append(str(lib_path))
break
return include_dirs, library_dirs, libraries
def build_extensions():
try:
from Cython.Build import cythonize
@ -29,6 +106,14 @@ def build_extensions():
print("Install with: pip install setuptools")
return False
pcre2_include, pcre2_libdir, pcre2_libs = get_pcre2_config()
if not pcre2_include:
print("WARNING: PCRE2 not found. Routing module may not compile.")
print("Install PCRE2: brew install pcre2 (macOS) or apt install libpcre2-dev (Linux)")
else:
print(f"Found PCRE2: includes={pcre2_include}, libs={pcre2_libdir}")
extensions = [
Extension(
"pyserve._path_matcher",
@ -36,6 +121,18 @@ def build_extensions():
extra_compile_args=["-O3", "-ffast-math"],
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
),
Extension(
"pyserve._routing",
sources=["pyserve/_routing.pyx"],
include_dirs=pcre2_include,
library_dirs=pcre2_libdir,
libraries=pcre2_libs,
extra_compile_args=["-O3", "-ffast-math"],
define_macros=[
("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION"),
("PCRE2_CODE_UNIT_WIDTH", "8"),
],
),
]
ext_modules = cythonize(
@ -59,7 +156,9 @@ def build_extensions():
cmd.run()
print("\nCython extensions built successfully!")
print(" - pyserve/_path_matcher" + (".pyd" if sys.platform == "win32" else ".so"))
ext_suffix = ".pyd" if sys.platform == "win32" else ".so"
print(f" - pyserve/_path_matcher{ext_suffix}")
print(f" - pyserve/_routing{ext_suffix}")
return True

View File

@ -50,16 +50,10 @@ class TestRouter:
def test_router_initialization(self):
"""Test router initializes with correct defaults."""
router = Router()
assert router.static_dir == Path("./static")
assert router.routes == {}
assert router.exact_routes == {}
assert router.default_route is None
def test_router_custom_static_dir(self):
"""Test router with custom static directory."""
router = Router(static_dir="/custom/path")
assert router.static_dir == Path("/custom/path")
def test_add_exact_route(self):
"""Test adding exact match route."""
router = Router()