forked from aegis/pyserveX
go implementation
This commit is contained in:
parent
c04ab283a6
commit
8f5b9a5cd1
5
.gitignore
vendored
5
.gitignore
vendored
@ -27,4 +27,7 @@ build/
|
|||||||
.idea/
|
.idea/
|
||||||
.vscode/
|
.vscode/
|
||||||
*.swp
|
*.swp
|
||||||
*.swo
|
*.swo
|
||||||
|
|
||||||
|
# Go binaries
|
||||||
|
go/bin
|
||||||
34
go/Dockerfile
Normal file
34
go/Dockerfile
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# Multi-stage build for Konduktor
|
||||||
|
FROM golang:1.23-alpine AS builder
|
||||||
|
|
||||||
|
RUN apk add --no-cache git make
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY go.mod go.sum* ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
RUN make build
|
||||||
|
|
||||||
|
FROM alpine:3.19
|
||||||
|
|
||||||
|
RUN apk add --no-cache ca-certificates tzdata
|
||||||
|
|
||||||
|
RUN adduser -D -g '' konduktor
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /build/bin/konduktor /usr/local/bin/
|
||||||
|
COPY --from=builder /build/bin/konduktorctl /usr/local/bin/
|
||||||
|
|
||||||
|
RUN mkdir -p /app/static /app/templates /app/logs && \
|
||||||
|
chown -R konduktor:konduktor /app
|
||||||
|
|
||||||
|
USER konduktor
|
||||||
|
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
ENTRYPOINT ["konduktor"]
|
||||||
|
CMD ["-c", "/app/config.yaml"]
|
||||||
108
go/Makefile
Normal file
108
go/Makefile
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# Konduktor Go Build
|
||||||
|
# Makefile for building and testing Konduktor
|
||||||
|
|
||||||
|
.PHONY: all build build-konduktor build-konduktorctl test clean deps fmt lint run
|
||||||
|
|
||||||
|
# Build configuration
|
||||||
|
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||||
|
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||||
|
BUILD_TIME ?= $(shell date -u '+%Y-%m-%dT%H:%M:%SZ')
|
||||||
|
LDFLAGS := -X main.Version=$(VERSION) -X main.GitCommit=$(GIT_COMMIT) -X main.BuildTime=$(BUILD_TIME)
|
||||||
|
|
||||||
|
# Output directories
|
||||||
|
BIN_DIR := bin
|
||||||
|
|
||||||
|
all: deps build
|
||||||
|
|
||||||
|
# Download dependencies
|
||||||
|
deps:
|
||||||
|
@echo "==> Downloading dependencies..."
|
||||||
|
go mod download
|
||||||
|
go mod tidy
|
||||||
|
|
||||||
|
# Build all binaries
|
||||||
|
build: build-konduktor build-konduktorctl
|
||||||
|
|
||||||
|
# Build konduktor server
|
||||||
|
build-konduktor:
|
||||||
|
@echo "==> Building konduktor..."
|
||||||
|
@mkdir -p $(BIN_DIR)
|
||||||
|
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktor ./cmd/konduktor
|
||||||
|
|
||||||
|
# Build konduktorctl CLI
|
||||||
|
build-konduktorctl:
|
||||||
|
@echo "==> Building konduktorctl..."
|
||||||
|
@mkdir -p $(BIN_DIR)
|
||||||
|
go build -ldflags "$(LDFLAGS)" -o $(BIN_DIR)/konduktorctl ./cmd/konduktorctl
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
test:
|
||||||
|
@echo "==> Running tests..."
|
||||||
|
go test -v -race -cover ./...
|
||||||
|
|
||||||
|
# Run tests with coverage report
|
||||||
|
test-coverage:
|
||||||
|
@echo "==> Running tests with coverage..."
|
||||||
|
go test -v -race -coverprofile=coverage.out ./...
|
||||||
|
go tool cover -html=coverage.out -o coverage.html
|
||||||
|
@echo "Coverage report: coverage.html"
|
||||||
|
|
||||||
|
# Format code
|
||||||
|
fmt:
|
||||||
|
@echo "==> Formatting code..."
|
||||||
|
go fmt ./...
|
||||||
|
goimports -w .
|
||||||
|
|
||||||
|
# Lint code
|
||||||
|
lint:
|
||||||
|
@echo "==> Linting code..."
|
||||||
|
golangci-lint run ./...
|
||||||
|
|
||||||
|
# Run the server (development)
|
||||||
|
run: build-konduktor
|
||||||
|
@echo "==> Running konduktor..."
|
||||||
|
./$(BIN_DIR)/konduktor -c ../config.yaml
|
||||||
|
|
||||||
|
# Clean build artifacts
|
||||||
|
clean:
|
||||||
|
@echo "==> Cleaning..."
|
||||||
|
rm -rf $(BIN_DIR)
|
||||||
|
rm -f coverage.out coverage.html
|
||||||
|
|
||||||
|
# Install binaries to GOPATH/bin
|
||||||
|
install: build
|
||||||
|
@echo "==> Installing binaries..."
|
||||||
|
cp $(BIN_DIR)/konduktor $(GOPATH)/bin/
|
||||||
|
cp $(BIN_DIR)/konduktorctl $(GOPATH)/bin/
|
||||||
|
|
||||||
|
# Generate mocks (for testing)
|
||||||
|
generate:
|
||||||
|
@echo "==> Generating code..."
|
||||||
|
go generate ./...
|
||||||
|
|
||||||
|
# Docker build
|
||||||
|
docker-build:
|
||||||
|
@echo "==> Building Docker image..."
|
||||||
|
docker build -t konduktor:$(VERSION) .
|
||||||
|
|
||||||
|
# Show help
|
||||||
|
help:
|
||||||
|
@echo "Konduktor Build System"
|
||||||
|
@echo ""
|
||||||
|
@echo "Usage: make [target]"
|
||||||
|
@echo ""
|
||||||
|
@echo "Targets:"
|
||||||
|
@echo " all Download deps and build all binaries"
|
||||||
|
@echo " deps Download and tidy dependencies"
|
||||||
|
@echo " build Build all binaries"
|
||||||
|
@echo " build-konduktor Build the server binary"
|
||||||
|
@echo " build-konduktorctl Build the CLI binary"
|
||||||
|
@echo " test Run tests"
|
||||||
|
@echo " test-coverage Run tests with coverage report"
|
||||||
|
@echo " fmt Format code"
|
||||||
|
@echo " lint Lint code"
|
||||||
|
@echo " run Build and run the server"
|
||||||
|
@echo " clean Clean build artifacts"
|
||||||
|
@echo " install Install binaries to GOPATH/bin"
|
||||||
|
@echo " docker-build Build Docker image"
|
||||||
|
@echo " help Show this help"
|
||||||
149
go/README.md
Normal file
149
go/README.md
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
# Konduktor (Go)
|
||||||
|
|
||||||
|
High-performance HTTP web server with extensible routing and process orchestration. (Previously known as PyServe in Python)
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
go/
|
||||||
|
├── cmd/
|
||||||
|
│ ├── konduktor/ # Main server binary
|
||||||
|
│ └── konduktorctl/ # CLI management tool
|
||||||
|
├── internal/
|
||||||
|
│ ├── config/ # Configuration management
|
||||||
|
│ ├── logging/ # Structured logging
|
||||||
|
│ ├── middleware/ # HTTP middleware
|
||||||
|
│ ├── routing/ # HTTP routing
|
||||||
|
│ ├── extensions/ # Extension system (TODO)
|
||||||
|
│ └── process/ # Process management (TODO)
|
||||||
|
├── pkg/ # Public packages (TODO)
|
||||||
|
├── go.mod
|
||||||
|
├── go.sum
|
||||||
|
└── Makefile
|
||||||
|
```
|
||||||
|
|
||||||
|
## Building
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd go
|
||||||
|
|
||||||
|
# Download dependencies
|
||||||
|
make deps
|
||||||
|
|
||||||
|
# Build all binaries
|
||||||
|
make build
|
||||||
|
|
||||||
|
# Or build individually
|
||||||
|
make build-konduktor
|
||||||
|
make build-konduktorctl
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run with default config
|
||||||
|
./bin/konduktor
|
||||||
|
|
||||||
|
# Run with custom config
|
||||||
|
./bin/konduktor -c ../config.yaml
|
||||||
|
|
||||||
|
# Run with flags
|
||||||
|
./bin/konduktor --host 127.0.0.1 --port 3000 --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
## CLI Commands (konduktorctl)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start services
|
||||||
|
konduktorctl up
|
||||||
|
|
||||||
|
# Stop services
|
||||||
|
konduktorctl down
|
||||||
|
|
||||||
|
# View status
|
||||||
|
konduktorctl status
|
||||||
|
|
||||||
|
# View logs
|
||||||
|
konduktorctl logs -f
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
konduktorctl health
|
||||||
|
|
||||||
|
# Scale services
|
||||||
|
konduktorctl scale api=3
|
||||||
|
|
||||||
|
# Configuration management
|
||||||
|
konduktorctl config show
|
||||||
|
konduktorctl config validate
|
||||||
|
|
||||||
|
# Initialize new project
|
||||||
|
konduktorctl init
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Uses the same YAML configuration format as the Python version:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 8080
|
||||||
|
|
||||||
|
http:
|
||||||
|
static_dir: ./static
|
||||||
|
templates_dir: ./templates
|
||||||
|
|
||||||
|
ssl:
|
||||||
|
enabled: false
|
||||||
|
cert_file: ./ssl/cert.pem
|
||||||
|
key_file: ./ssl/key.pem
|
||||||
|
|
||||||
|
logging:
|
||||||
|
level: INFO
|
||||||
|
console_output: true
|
||||||
|
|
||||||
|
extensions:
|
||||||
|
- type: routing
|
||||||
|
config:
|
||||||
|
regex_locations:
|
||||||
|
"=/health":
|
||||||
|
return: "200 OK"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Format code
|
||||||
|
make fmt
|
||||||
|
|
||||||
|
# Run linter
|
||||||
|
make lint
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
make test
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
make test-coverage
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration from Python
|
||||||
|
|
||||||
|
This is a gradual rewrite of PyServe to Go. The project is now called **Konduktor**.
|
||||||
|
|
||||||
|
### Completed
|
||||||
|
- [x] Basic project structure
|
||||||
|
- [x] Configuration loading
|
||||||
|
- [x] HTTP server with graceful shutdown
|
||||||
|
- [x] Basic routing
|
||||||
|
- [x] Middleware (access log, recovery, server header)
|
||||||
|
- [x] CLI structure (konduktor, konduktorctl)
|
||||||
|
|
||||||
|
### TODO
|
||||||
|
- [ ] Extension system
|
||||||
|
- [x] Regex routing
|
||||||
|
- [x] Reverse proxy
|
||||||
|
- [ ] Process orchestration
|
||||||
|
- [ ] ASGI/WSGI adapter support
|
||||||
|
- [ ] WebSocket support
|
||||||
|
- [ ] Hot reload
|
||||||
|
- [ ] Metrics and monitoring
|
||||||
79
go/cmd/konduktor/main.go
Normal file
79
go/cmd/konduktor/main.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/config"
|
||||||
|
"github.com/konduktor/konduktor/internal/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
Version = "0.1.0"
|
||||||
|
BuildTime = "unknown"
|
||||||
|
GitCommit = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cfgFile string
|
||||||
|
host string
|
||||||
|
port int
|
||||||
|
debug bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
rootCmd := &cobra.Command{
|
||||||
|
Use: "konduktor",
|
||||||
|
Short: "Konduktor - HTTP web server",
|
||||||
|
Long: `Konduktor is a high-performance HTTP web server with extensible routing and process orchestration.`,
|
||||||
|
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
|
||||||
|
RunE: runServer,
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd.Flags().StringVarP(&cfgFile, "config", "c", "config.yaml", "Path to configuration file")
|
||||||
|
rootCmd.Flags().StringVar(&host, "host", "", "Host to bind the server to")
|
||||||
|
rootCmd.Flags().IntVar(&port, "port", 0, "Port to bind the server to")
|
||||||
|
rootCmd.Flags().BoolVar(&debug, "debug", false, "Enable debug mode")
|
||||||
|
|
||||||
|
if err := rootCmd.Execute(); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runServer(cmd *cobra.Command, args []string) error {
|
||||||
|
cfg, err := config.Load(cfgFile)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
fmt.Printf("Configuration file %s not found, using defaults\n", cfgFile)
|
||||||
|
cfg = config.Default()
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("configuration loading error: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if host != "" {
|
||||||
|
cfg.Server.Host = host
|
||||||
|
}
|
||||||
|
if port != 0 {
|
||||||
|
cfg.Server.Port = port
|
||||||
|
}
|
||||||
|
if debug {
|
||||||
|
cfg.Logging.Level = "DEBUG"
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, err := server.New(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("server creation error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Starting Konduktor server on %s:%d\n", cfg.Server.Host, cfg.Server.Port)
|
||||||
|
|
||||||
|
if err := srv.Run(); err != nil {
|
||||||
|
return fmt.Errorf("server startup error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
180
go/cmd/konduktorctl/main.go
Normal file
180
go/cmd/konduktorctl/main.go
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
Version = "0.1.0"
|
||||||
|
BuildTime = "unknown"
|
||||||
|
GitCommit = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
rootCmd := &cobra.Command{
|
||||||
|
Use: "konduktorctl",
|
||||||
|
Short: "Konduktorctl - Service management CLI",
|
||||||
|
Long: `Konduktorctl is a CLI tool for managing Konduktor services.`,
|
||||||
|
Version: fmt.Sprintf("%s (commit: %s, built: %s)", Version, GitCommit, BuildTime),
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd.AddCommand(
|
||||||
|
newUpCmd(),
|
||||||
|
newDownCmd(),
|
||||||
|
newStatusCmd(),
|
||||||
|
newLogsCmd(),
|
||||||
|
newHealthCmd(),
|
||||||
|
newScaleCmd(),
|
||||||
|
newConfigCmd(),
|
||||||
|
newInitCmd(),
|
||||||
|
newTopCmd(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := rootCmd.Execute(); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUpCmd() *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "up [service...]",
|
||||||
|
Short: "Start services",
|
||||||
|
Long: `Start one or more services. If no service is specified, all services are started.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Println("Starting services...")
|
||||||
|
// TODO: Implement service start logic
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cmd.Flags().BoolP("detach", "d", false, "Run in background")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDownCmd() *cobra.Command {
|
||||||
|
return &cobra.Command{
|
||||||
|
Use: "down [service...]",
|
||||||
|
Short: "Stop services",
|
||||||
|
Long: `Stop one or more services. If no service is specified, all services are stopped.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Println("Stopping services...")
|
||||||
|
// TODO: Implement service stop logic
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStatusCmd() *cobra.Command {
|
||||||
|
return &cobra.Command{
|
||||||
|
Use: "status [service...]",
|
||||||
|
Short: "Show service status",
|
||||||
|
Long: `Show the status of one or more services.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Println("Service status:")
|
||||||
|
// TODO: Implement status display logic
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLogsCmd() *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "logs [service]",
|
||||||
|
Short: "View service logs",
|
||||||
|
Long: `View logs for a specific service.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Println("Fetching logs...")
|
||||||
|
// TODO: Implement logs viewing logic
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cmd.Flags().BoolP("follow", "f", false, "Follow log output")
|
||||||
|
cmd.Flags().IntP("tail", "n", 100, "Number of lines to show from the end")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHealthCmd() *cobra.Command {
|
||||||
|
return &cobra.Command{
|
||||||
|
Use: "health",
|
||||||
|
Short: "Check service health",
|
||||||
|
Long: `Check the health status of all services.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Println("Health check:")
|
||||||
|
// TODO: Implement health check logic
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newScaleCmd() *cobra.Command {
|
||||||
|
return &cobra.Command{
|
||||||
|
Use: "scale <service>=<count>",
|
||||||
|
Short: "Scale a service",
|
||||||
|
Long: `Scale a service to a specific number of instances.`,
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Printf("Scaling: %v\n", args)
|
||||||
|
// TODO: Implement scaling logic
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConfigCmd() *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "config",
|
||||||
|
Short: "Manage configuration",
|
||||||
|
Long: `View and validate configuration.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.AddCommand(&cobra.Command{
|
||||||
|
Use: "show",
|
||||||
|
Short: "Show current configuration",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Println("Current configuration:")
|
||||||
|
// TODO: Implement config show logic
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
cmd.AddCommand(&cobra.Command{
|
||||||
|
Use: "validate",
|
||||||
|
Short: "Validate configuration file",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Println("Validating configuration...")
|
||||||
|
// TODO: Implement config validation logic
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInitCmd() *cobra.Command {
|
||||||
|
return &cobra.Command{
|
||||||
|
Use: "init",
|
||||||
|
Short: "Initialize a new project",
|
||||||
|
Long: `Create a new Konduktor project with default configuration.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Println("Initializing new project...")
|
||||||
|
// TODO: Implement init logic
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTopCmd() *cobra.Command {
|
||||||
|
return &cobra.Command{
|
||||||
|
Use: "top",
|
||||||
|
Short: "Display running processes",
|
||||||
|
Long: `Display real-time view of running processes and resource usage.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
fmt.Println("Process monitor:")
|
||||||
|
// TODO: Implement top-like display
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
15
go/go.mod
Normal file
15
go/go.mod
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
module github.com/konduktor/konduktor
|
||||||
|
|
||||||
|
go 1.23.0
|
||||||
|
|
||||||
|
toolchain go1.24.2
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/spf13/cobra v1.10.2
|
||||||
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
|
github.com/spf13/pflag v1.0.10 // indirect
|
||||||
|
)
|
||||||
134
go/internal/config/config.go
Normal file
134
go/internal/config/config.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
HTTP HTTPConfig `yaml:"http"`
|
||||||
|
Server ServerConfig `yaml:"server"`
|
||||||
|
SSL SSLConfig `yaml:"ssl"`
|
||||||
|
Logging LoggingConfig `yaml:"logging"`
|
||||||
|
Extensions []ExtensionConfig `yaml:"extensions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTPConfig struct {
|
||||||
|
StaticDir string `yaml:"static_dir"`
|
||||||
|
TemplatesDir string `yaml:"templates_dir"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerConfig struct {
|
||||||
|
Host string `yaml:"host"`
|
||||||
|
Port int `yaml:"port"`
|
||||||
|
Backlog int `yaml:"backlog"`
|
||||||
|
DefaultRoot bool `yaml:"default_root"`
|
||||||
|
ProxyTimeout time.Duration `yaml:"proxy_timeout"`
|
||||||
|
RedirectInstructions map[string]string `yaml:"redirect_instructions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SSLConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
CertFile string `yaml:"cert_file"`
|
||||||
|
KeyFile string `yaml:"key_file"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoggingConfig struct {
|
||||||
|
Level string `yaml:"level"`
|
||||||
|
ConsoleOutput bool `yaml:"console_output"`
|
||||||
|
Format LogFormatConfig `yaml:"format"`
|
||||||
|
Console *ConsoleLogConfig `yaml:"console"`
|
||||||
|
Files []FileLogConfig `yaml:"files"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogFormatConfig struct {
|
||||||
|
Type string `yaml:"type"`
|
||||||
|
UseColors bool `yaml:"use_colors"`
|
||||||
|
ShowModule bool `yaml:"show_module"`
|
||||||
|
TimestampFormat string `yaml:"timestamp_format"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConsoleLogConfig struct {
|
||||||
|
Format LogFormatConfig `yaml:"format"`
|
||||||
|
Level string `yaml:"level"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileLogConfig struct {
|
||||||
|
Path string `yaml:"path"`
|
||||||
|
Level string `yaml:"level"`
|
||||||
|
Loggers []string `yaml:"loggers"`
|
||||||
|
Format LogFormatConfig `yaml:"format"`
|
||||||
|
MaxBytes int64 `yaml:"max_bytes"`
|
||||||
|
BackupCount int `yaml:"backup_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExtensionConfig struct {
|
||||||
|
Type string `yaml:"type"`
|
||||||
|
Config map[string]interface{} `yaml:"config"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Load(path string) (*Config, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := Default()
|
||||||
|
if err := yaml.Unmarshal(data, cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Default() *Config {
|
||||||
|
return &Config{
|
||||||
|
HTTP: HTTPConfig{
|
||||||
|
StaticDir: "./static",
|
||||||
|
TemplatesDir: "./templates",
|
||||||
|
},
|
||||||
|
Server: ServerConfig{
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
Port: 8080,
|
||||||
|
Backlog: 5,
|
||||||
|
DefaultRoot: false,
|
||||||
|
ProxyTimeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
SSL: SSLConfig{
|
||||||
|
Enabled: false,
|
||||||
|
CertFile: "./ssl/cert.pem",
|
||||||
|
KeyFile: "./ssl/key.pem",
|
||||||
|
},
|
||||||
|
Logging: LoggingConfig{
|
||||||
|
Level: "INFO",
|
||||||
|
ConsoleOutput: true,
|
||||||
|
Format: LogFormatConfig{
|
||||||
|
Type: "standard",
|
||||||
|
UseColors: true,
|
||||||
|
ShowModule: true,
|
||||||
|
TimestampFormat: "2006-01-02 15:04:05",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Extensions: []ExtensionConfig{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Validate() error {
|
||||||
|
if c.Server.Port < 1 || c.Server.Port > 65535 {
|
||||||
|
return fmt.Errorf("invalid port: %d", c.Server.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.SSL.Enabled {
|
||||||
|
if c.SSL.CertFile == "" {
|
||||||
|
return fmt.Errorf("SSL enabled but cert_file not specified")
|
||||||
|
}
|
||||||
|
if c.SSL.KeyFile == "" {
|
||||||
|
return fmt.Errorf("SSL enabled but key_file not specified")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
127
go/internal/config/config_test.go
Normal file
127
go/internal/config/config_test.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefault(t *testing.T) {
|
||||||
|
cfg := Default()
|
||||||
|
|
||||||
|
if cfg.Server.Host != "0.0.0.0" {
|
||||||
|
t.Errorf("Expected host 0.0.0.0, got %s", cfg.Server.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Server.Port != 8080 {
|
||||||
|
t.Errorf("Expected port 8080, got %d", cfg.Server.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.SSL.Enabled {
|
||||||
|
t.Error("Expected SSL to be disabled by default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
modify func(*Config)
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid default config",
|
||||||
|
modify: func(c *Config) {},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid port - too low",
|
||||||
|
modify: func(c *Config) {
|
||||||
|
c.Server.Port = 0
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid port - too high",
|
||||||
|
modify: func(c *Config) {
|
||||||
|
c.Server.Port = 70000
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SSL enabled without cert",
|
||||||
|
modify: func(c *Config) {
|
||||||
|
c.SSL.Enabled = true
|
||||||
|
c.SSL.CertFile = ""
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SSL enabled without key",
|
||||||
|
modify: func(c *Config) {
|
||||||
|
c.SSL.Enabled = true
|
||||||
|
c.SSL.CertFile = "cert.pem"
|
||||||
|
c.SSL.KeyFile = ""
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := Default()
|
||||||
|
tt.modify(cfg)
|
||||||
|
|
||||||
|
err := cfg.Validate()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoad(t *testing.T) {
|
||||||
|
// Create temporary config file
|
||||||
|
content := `
|
||||||
|
server:
|
||||||
|
host: 127.0.0.1
|
||||||
|
port: 3000
|
||||||
|
|
||||||
|
logging:
|
||||||
|
level: DEBUG
|
||||||
|
`
|
||||||
|
tmpfile, err := os.CreateTemp("", "config-*.yaml")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tmpfile.Name())
|
||||||
|
|
||||||
|
if _, err := tmpfile.Write([]byte(content)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := tmpfile.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(tmpfile.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Server.Host != "127.0.0.1" {
|
||||||
|
t.Errorf("Expected host 127.0.0.1, got %s", cfg.Server.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Server.Port != 3000 {
|
||||||
|
t.Errorf("Expected port 3000, got %d", cfg.Server.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Logging.Level != "DEBUG" {
|
||||||
|
t.Errorf("Expected level DEBUG, got %s", cfg.Logging.Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadNotFound(t *testing.T) {
|
||||||
|
_, err := Load("/nonexistent/config.yaml")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for non-existent file")
|
||||||
|
}
|
||||||
|
}
|
||||||
136
go/internal/logging/logger.go
Normal file
136
go/internal/logging/logger.go
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Level string
|
||||||
|
TimestampFormat string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Logger struct {
|
||||||
|
level string
|
||||||
|
timestampFormat string
|
||||||
|
configFull *config.LoggingConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(cfg Config) (*Logger, error) {
|
||||||
|
timestampFormat := cfg.TimestampFormat
|
||||||
|
if timestampFormat == "" {
|
||||||
|
timestampFormat = "2006-01-02 15:04:05"
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Logger{
|
||||||
|
level: cfg.Level,
|
||||||
|
timestampFormat: timestampFormat,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFromConfig(cfg config.LoggingConfig) (*Logger, error) {
|
||||||
|
timestampFormat := cfg.Format.TimestampFormat
|
||||||
|
if timestampFormat == "" {
|
||||||
|
timestampFormat = "2006-01-02 15:04:05"
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Logger{
|
||||||
|
level: cfg.Level,
|
||||||
|
timestampFormat: timestampFormat,
|
||||||
|
configFull: &cfg,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) formatTime() string {
|
||||||
|
return time.Now().Format(l.timestampFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) log(level string, msg string, fields ...interface{}) {
|
||||||
|
timestamp := l.formatTime()
|
||||||
|
|
||||||
|
// Simple console output for now
|
||||||
|
// TODO: Implement proper structured logging with zap
|
||||||
|
output := timestamp + " [" + level + "] " + msg
|
||||||
|
|
||||||
|
if len(fields) > 0 {
|
||||||
|
output += " {"
|
||||||
|
for i := 0; i < len(fields); i += 2 {
|
||||||
|
if i > 0 {
|
||||||
|
output += ", "
|
||||||
|
}
|
||||||
|
if i+1 < len(fields) {
|
||||||
|
output += fields[i].(string) + "=" + formatValue(fields[i+1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output += "}"
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Stdout.WriteString(output + "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatValue(v interface{}) string {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
return val
|
||||||
|
case int:
|
||||||
|
return fmt.Sprintf("%d", val)
|
||||||
|
case int64:
|
||||||
|
return fmt.Sprintf("%d", val)
|
||||||
|
case float64:
|
||||||
|
return fmt.Sprintf("%.2f", val)
|
||||||
|
case bool:
|
||||||
|
return fmt.Sprintf("%t", val)
|
||||||
|
case error:
|
||||||
|
return val.Error()
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Debug(msg string, fields ...interface{}) {
|
||||||
|
if l.shouldLog("DEBUG") {
|
||||||
|
l.log("DEBUG", msg, fields...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Info(msg string, fields ...interface{}) {
|
||||||
|
if l.shouldLog("INFO") {
|
||||||
|
l.log("INFO", msg, fields...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Warn(msg string, fields ...interface{}) {
|
||||||
|
if l.shouldLog("WARN") {
|
||||||
|
l.log("WARN", msg, fields...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Error(msg string, fields ...interface{}) {
|
||||||
|
if l.shouldLog("ERROR") {
|
||||||
|
l.log("ERROR", msg, fields...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) shouldLog(level string) bool {
|
||||||
|
levels := map[string]int{
|
||||||
|
"DEBUG": 0,
|
||||||
|
"INFO": 1,
|
||||||
|
"WARN": 2,
|
||||||
|
"ERROR": 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
currentLevel, ok := levels[l.level]
|
||||||
|
if !ok {
|
||||||
|
currentLevel = 1 // Default to INFO
|
||||||
|
}
|
||||||
|
|
||||||
|
msgLevel, ok := levels[level]
|
||||||
|
if !ok {
|
||||||
|
msgLevel = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return msgLevel >= currentLevel
|
||||||
|
}
|
||||||
172
go/internal/logging/logger_test.go
Normal file
172
go/internal/logging/logger_test.go
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
logger, err := New(Config{Level: "INFO"})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if logger == nil {
|
||||||
|
t.Fatal("Expected logger, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if logger.level != "INFO" {
|
||||||
|
t.Errorf("Expected level INFO, got %s", logger.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew_DefaultTimestampFormat(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: "DEBUG"})
|
||||||
|
|
||||||
|
if logger.timestampFormat != "2006-01-02 15:04:05" {
|
||||||
|
t.Errorf("Expected default timestamp format, got %s", logger.timestampFormat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew_CustomTimestampFormat(t *testing.T) {
|
||||||
|
logger, _ := New(Config{
|
||||||
|
Level: "DEBUG",
|
||||||
|
TimestampFormat: "15:04:05",
|
||||||
|
})
|
||||||
|
|
||||||
|
if logger.timestampFormat != "15:04:05" {
|
||||||
|
t.Errorf("Expected custom timestamp format, got %s", logger.timestampFormat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogger_ShouldLog(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
loggerLevel string
|
||||||
|
msgLevel string
|
||||||
|
shouldLog bool
|
||||||
|
}{
|
||||||
|
{"DEBUG", "DEBUG", true},
|
||||||
|
{"DEBUG", "INFO", true},
|
||||||
|
{"DEBUG", "WARN", true},
|
||||||
|
{"DEBUG", "ERROR", true},
|
||||||
|
{"INFO", "DEBUG", false},
|
||||||
|
{"INFO", "INFO", true},
|
||||||
|
{"INFO", "WARN", true},
|
||||||
|
{"INFO", "ERROR", true},
|
||||||
|
{"WARN", "DEBUG", false},
|
||||||
|
{"WARN", "INFO", false},
|
||||||
|
{"WARN", "WARN", true},
|
||||||
|
{"WARN", "ERROR", true},
|
||||||
|
{"ERROR", "DEBUG", false},
|
||||||
|
{"ERROR", "INFO", false},
|
||||||
|
{"ERROR", "WARN", false},
|
||||||
|
{"ERROR", "ERROR", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.loggerLevel+"_"+tt.msgLevel, func(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: tt.loggerLevel})
|
||||||
|
|
||||||
|
if got := logger.shouldLog(tt.msgLevel); got != tt.shouldLog {
|
||||||
|
t.Errorf("shouldLog(%s) = %v, want %v", tt.msgLevel, got, tt.shouldLog)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogger_ShouldLog_InvalidLevel(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: "INVALID"})
|
||||||
|
|
||||||
|
// Should default to INFO level
|
||||||
|
if !logger.shouldLog("INFO") {
|
||||||
|
t.Error("Invalid level should default to INFO")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogger_Debug(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: "DEBUG"})
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
logger.Debug("test message", "key", "value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogger_Info(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: "INFO"})
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
logger.Info("test message", "key", "value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogger_Warn(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: "WARN"})
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
logger.Warn("test message", "key", "value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogger_Error(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: "ERROR"})
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
logger.Error("test message", "key", "value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatValue(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input interface{}
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"test", "test"},
|
||||||
|
{42, "*"}, // int converts to rune
|
||||||
|
{nil, ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := formatValue(tt.input)
|
||||||
|
// Just check it doesn't panic
|
||||||
|
_ = got
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogger_FormatTime(t *testing.T) {
|
||||||
|
logger, _ := New(Config{
|
||||||
|
Level: "INFO",
|
||||||
|
TimestampFormat: "2006-01-02",
|
||||||
|
})
|
||||||
|
|
||||||
|
result := logger.formatTime()
|
||||||
|
|
||||||
|
// Should be in expected format (YYYY-MM-DD)
|
||||||
|
if len(result) != 10 {
|
||||||
|
t.Errorf("Expected date format YYYY-MM-DD, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Benchmarks ==============
|
||||||
|
|
||||||
|
func BenchmarkLogger_Info(b *testing.B) {
|
||||||
|
logger, _ := New(Config{Level: "INFO"})
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Info("test message", "key", "value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_Debug_Filtered(b *testing.B) {
|
||||||
|
logger, _ := New(Config{Level: "ERROR"})
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Debug("test message", "key", "value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger_ShouldLog(b *testing.B) {
|
||||||
|
logger, _ := New(Config{Level: "INFO"})
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.shouldLog("DEBUG")
|
||||||
|
}
|
||||||
|
}
|
||||||
74
go/internal/middleware/middleware.go
Normal file
74
go/internal/middleware/middleware.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"runtime/debug"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type responseWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
status int
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWriter) WriteHeader(code int) {
|
||||||
|
rw.status = code
|
||||||
|
rw.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||||
|
size, err := rw.ResponseWriter.Write(b)
|
||||||
|
rw.size += size
|
||||||
|
return size, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func ServerHeader(next http.Handler, version string) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Server", fmt.Sprintf("konduktor/%s", version))
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func AccessLog(next http.Handler, logger *logging.Logger) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
wrapped := &responseWriter{
|
||||||
|
ResponseWriter: w,
|
||||||
|
status: http.StatusOK,
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(wrapped, r)
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
logger.Info("HTTP request",
|
||||||
|
"method", r.Method,
|
||||||
|
"path", r.URL.Path,
|
||||||
|
"status", wrapped.status,
|
||||||
|
"duration_ms", duration.Milliseconds(),
|
||||||
|
"client_ip", r.RemoteAddr,
|
||||||
|
"user_agent", r.UserAgent(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Recovery(next http.Handler, logger *logging.Logger) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
logger.Error("Panic recovered",
|
||||||
|
"error", fmt.Sprintf("%v", err),
|
||||||
|
"stack", string(debug.Stack()),
|
||||||
|
)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
244
go/internal/middleware/middleware_test.go
Normal file
244
go/internal/middleware/middleware_test.go
Normal file
@ -0,0 +1,244 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============== ServerHeader Tests ==============
|
||||||
|
|
||||||
|
func TestServerHeader(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
wrapped := ServerHeader(handler, "1.0.0")
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrapped.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
serverHeader := rr.Header().Get("Server")
|
||||||
|
if serverHeader != "konduktor/1.0.0" {
|
||||||
|
t.Errorf("Expected Server header 'konduktor/1.0.0', got '%s'", serverHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== AccessLog Tests ==============
|
||||||
|
|
||||||
|
func TestAccessLog(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("Hello"))
|
||||||
|
})
|
||||||
|
|
||||||
|
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
||||||
|
wrapped := AccessLog(handler, logger)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrapped.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccessLog_CapturesStatusCode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"OK", http.StatusOK},
|
||||||
|
{"NotFound", http.StatusNotFound},
|
||||||
|
{"InternalError", http.StatusInternalServerError},
|
||||||
|
{"Redirect", http.StatusMovedPermanently},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(tt.statusCode)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
||||||
|
wrapped := AccessLog(handler, logger)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrapped.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != tt.statusCode {
|
||||||
|
t.Errorf("Expected status %d, got %d", tt.statusCode, rr.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Recovery Tests ==============
|
||||||
|
|
||||||
|
func TestRecovery_NoPanic(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
|
||||||
|
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
||||||
|
wrapped := Recovery(handler, logger)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrapped.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rr.Body.String() != "OK" {
|
||||||
|
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecovery_WithPanic(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
panic("test panic")
|
||||||
|
})
|
||||||
|
|
||||||
|
logger, _ := logging.New(logging.Config{Level: "ERROR"})
|
||||||
|
wrapped := Recovery(handler, logger)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
wrapped.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("Expected status 500, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(rr.Body.String(), "Internal Server Error") {
|
||||||
|
t.Errorf("Expected 'Internal Server Error' in body, got '%s'", rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== responseWriter Tests ==============
|
||||||
|
|
||||||
|
func TestResponseWriter_WriteHeader(t *testing.T) {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
|
||||||
|
|
||||||
|
rw.WriteHeader(http.StatusNotFound)
|
||||||
|
|
||||||
|
if rw.status != http.StatusNotFound {
|
||||||
|
t.Errorf("Expected status 404, got %d", rw.status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWriter_Write(t *testing.T) {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
rw := &responseWriter{ResponseWriter: rr, status: http.StatusOK}
|
||||||
|
|
||||||
|
n, err := rw.Write([]byte("Hello World"))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != 11 {
|
||||||
|
t.Errorf("Expected 11 bytes written, got %d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rw.size != 11 {
|
||||||
|
t.Errorf("Expected size 11, got %d", rw.size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Middleware Chain Tests ==============
|
||||||
|
|
||||||
|
func TestMiddlewareChain(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
|
||||||
|
logger, _ := logging.New(logging.Config{Level: "INFO"})
|
||||||
|
|
||||||
|
// Apply middleware chain
|
||||||
|
wrapped := Recovery(AccessLog(ServerHeader(handler, "1.0.0"), logger), logger)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrapped.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// Check all middleware worked
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rr.Header().Get("Server") != "konduktor/1.0.0" {
|
||||||
|
t.Errorf("Expected Server header")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rr.Body.String() != "OK" {
|
||||||
|
t.Errorf("Expected body 'OK', got '%s'", rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Benchmarks ==============
|
||||||
|
|
||||||
|
func BenchmarkServerHeader(b *testing.B) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
wrapped := ServerHeader(handler, "1.0.0")
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
wrapped.ServeHTTP(rr, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkAccessLog(b *testing.B) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger, _ := logging.New(logging.Config{Level: "ERROR"}) // Minimize logging overhead
|
||||||
|
wrapped := AccessLog(handler, logger)
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
wrapped.ServeHTTP(rr, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRecovery(b *testing.B) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger, _ := logging.New(logging.Config{Level: "ERROR"})
|
||||||
|
wrapped := Recovery(handler, logger)
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
wrapped.ServeHTTP(rr, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
263
go/internal/pathmatcher/pathmatcher.go
Normal file
263
go/internal/pathmatcher/pathmatcher.go
Normal file
@ -0,0 +1,263 @@
|
|||||||
|
package pathmatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MountedPath struct {
|
||||||
|
path string
|
||||||
|
name string
|
||||||
|
stripPath bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMountedPath(path string, opts ...MountedPathOption) *MountedPath {
|
||||||
|
// Normalize: remove trailing slash (except for root)
|
||||||
|
normalizedPath := strings.TrimSuffix(path, "/")
|
||||||
|
if normalizedPath == "" {
|
||||||
|
normalizedPath = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &MountedPath{
|
||||||
|
path: normalizedPath,
|
||||||
|
name: normalizedPath,
|
||||||
|
stripPath: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.name == "" {
|
||||||
|
m.name = normalizedPath
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
type MountedPathOption func(*MountedPath)
|
||||||
|
|
||||||
|
func WithName(name string) MountedPathOption {
|
||||||
|
return func(m *MountedPath) {
|
||||||
|
m.name = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithStripPath(strip bool) MountedPathOption {
|
||||||
|
return func(m *MountedPath) {
|
||||||
|
m.stripPath = strip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MountedPath) Path() string {
|
||||||
|
return m.path
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MountedPath) Name() string {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MountedPath) StripPath() bool {
|
||||||
|
return m.stripPath
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MountedPath) Matches(requestPath string) bool {
|
||||||
|
// Empty or "/" mount matches everything
|
||||||
|
if m.path == "" || m.path == "/" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request path must be at least as long as mount path
|
||||||
|
if len(requestPath) < len(m.path) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if request path starts with mount path
|
||||||
|
if !strings.HasPrefix(requestPath, m.path) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// If paths are equal length, it's a match
|
||||||
|
if len(requestPath) == len(m.path) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, next char must be '/' to prevent /api matching /api-v2
|
||||||
|
return requestPath[len(m.path)] == '/'
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MountedPath) GetModifiedPath(requestPath string) string {
|
||||||
|
if !m.stripPath {
|
||||||
|
return requestPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root mount doesn't strip anything
|
||||||
|
if m.path == "" || m.path == "/" {
|
||||||
|
return requestPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the prefix
|
||||||
|
modified := strings.TrimPrefix(requestPath, m.path)
|
||||||
|
|
||||||
|
// Ensure result starts with /
|
||||||
|
if modified == "" || modified[0] != '/' {
|
||||||
|
modified = "/" + modified
|
||||||
|
}
|
||||||
|
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
type MountManager struct {
|
||||||
|
mounts []*MountedPath
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMountManager() *MountManager {
|
||||||
|
return &MountManager{
|
||||||
|
mounts: make([]*MountedPath, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mm *MountManager) AddMount(mount *MountedPath) {
|
||||||
|
mm.mu.Lock()
|
||||||
|
defer mm.mu.Unlock()
|
||||||
|
|
||||||
|
// Insert in sorted order (longer paths first)
|
||||||
|
inserted := false
|
||||||
|
for i, existing := range mm.mounts {
|
||||||
|
if len(mount.path) > len(existing.path) {
|
||||||
|
// Insert at position i
|
||||||
|
mm.mounts = append(mm.mounts[:i], append([]*MountedPath{mount}, mm.mounts[i:]...)...)
|
||||||
|
inserted = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !inserted {
|
||||||
|
mm.mounts = append(mm.mounts, mount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mm *MountManager) RemoveMount(path string) bool {
|
||||||
|
mm.mu.Lock()
|
||||||
|
defer mm.mu.Unlock()
|
||||||
|
|
||||||
|
normalizedPath := strings.TrimSuffix(path, "/")
|
||||||
|
|
||||||
|
for i, mount := range mm.mounts {
|
||||||
|
if mount.path == normalizedPath {
|
||||||
|
mm.mounts = append(mm.mounts[:i], mm.mounts[i+1:]...)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mm *MountManager) GetMount(requestPath string) *MountedPath {
|
||||||
|
mm.mu.RLock()
|
||||||
|
defer mm.mu.RUnlock()
|
||||||
|
|
||||||
|
// Mounts are sorted by path length (longest first)
|
||||||
|
// so the first match is the best match
|
||||||
|
for _, mount := range mm.mounts {
|
||||||
|
if mount.Matches(requestPath) {
|
||||||
|
return mount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mm *MountManager) MountCount() int {
|
||||||
|
mm.mu.RLock()
|
||||||
|
defer mm.mu.RUnlock()
|
||||||
|
return len(mm.mounts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mm *MountManager) Mounts() []*MountedPath {
|
||||||
|
mm.mu.RLock()
|
||||||
|
defer mm.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]*MountedPath, len(mm.mounts))
|
||||||
|
copy(result, mm.mounts)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mm *MountManager) ListMounts() []map[string]interface{} {
|
||||||
|
mm.mu.RLock()
|
||||||
|
defer mm.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]map[string]interface{}, len(mm.mounts))
|
||||||
|
for i, mount := range mm.mounts {
|
||||||
|
result[i] = map[string]interface{}{
|
||||||
|
"path": mount.path,
|
||||||
|
"name": mount.name,
|
||||||
|
"strip_path": mount.stripPath,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Utility functions
|
||||||
|
|
||||||
|
func PathMatchesPrefix(requestPath, prefix string) bool {
|
||||||
|
// Normalize prefix
|
||||||
|
prefix = strings.TrimSuffix(prefix, "/")
|
||||||
|
|
||||||
|
// Empty or "/" prefix matches everything
|
||||||
|
if prefix == "" || prefix == "/" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request path must be at least as long as prefix
|
||||||
|
if len(requestPath) < len(prefix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if request path starts with prefix
|
||||||
|
if !strings.HasPrefix(requestPath, prefix) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// If paths are equal length, it's a match
|
||||||
|
if len(requestPath) == len(prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, next char must be '/'
|
||||||
|
return requestPath[len(prefix)] == '/'
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripPathPrefix(requestPath, prefix string) string {
|
||||||
|
// Normalize prefix
|
||||||
|
prefix = strings.TrimSuffix(prefix, "/")
|
||||||
|
|
||||||
|
// Empty or "/" prefix doesn't strip anything
|
||||||
|
if prefix == "" || prefix == "/" {
|
||||||
|
return requestPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the prefix
|
||||||
|
modified := strings.TrimPrefix(requestPath, prefix)
|
||||||
|
|
||||||
|
// Ensure result starts with /
|
||||||
|
if modified == "" || modified[0] != '/' {
|
||||||
|
modified = "/" + modified
|
||||||
|
}
|
||||||
|
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
func MatchAndModifyPath(requestPath, prefix string, stripPath bool) (matches bool, modifiedPath string) {
|
||||||
|
if !PathMatchesPrefix(requestPath, prefix) {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if stripPath {
|
||||||
|
return true, StripPathPrefix(requestPath, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, requestPath
|
||||||
|
}
|
||||||
460
go/internal/pathmatcher/pathmatcher_test.go
Normal file
460
go/internal/pathmatcher/pathmatcher_test.go
Normal file
@ -0,0 +1,460 @@
|
|||||||
|
package pathmatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============== MountedPath Tests ==============
|
||||||
|
|
||||||
|
func TestMountedPath_RootMountMatchesEverything(t *testing.T) {
|
||||||
|
mount := NewMountedPath("")
|
||||||
|
|
||||||
|
tests := []string{"/", "/api", "/api/users", "/anything/at/all"}
|
||||||
|
|
||||||
|
for _, path := range tests {
|
||||||
|
if !mount.Matches(path) {
|
||||||
|
t.Errorf("Root mount should match %s", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountedPath_SlashRootMountMatchesEverything(t *testing.T) {
|
||||||
|
mount := NewMountedPath("/")
|
||||||
|
|
||||||
|
tests := []string{"/", "/api", "/api/users"}
|
||||||
|
|
||||||
|
for _, path := range tests {
|
||||||
|
if !mount.Matches(path) {
|
||||||
|
t.Errorf("'/' mount should match %s", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountedPath_ExactPathMatch(t *testing.T) {
|
||||||
|
mount := NewMountedPath("/api")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"/api", true},
|
||||||
|
{"/api/", true},
|
||||||
|
{"/api/users", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := mount.Matches(tt.path); got != tt.expected {
|
||||||
|
t.Errorf("Matches(%s) = %v, want %v", tt.path, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountedPath_NoFalsePrefixMatch(t *testing.T) {
|
||||||
|
mount := NewMountedPath("/api")
|
||||||
|
|
||||||
|
tests := []string{"/api-v2", "/api2", "/apiv2"}
|
||||||
|
|
||||||
|
for _, path := range tests {
|
||||||
|
if mount.Matches(path) {
|
||||||
|
t.Errorf("/api should not match %s", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountedPath_ShorterPathNoMatch(t *testing.T) {
|
||||||
|
mount := NewMountedPath("/api/v1")
|
||||||
|
|
||||||
|
tests := []string{"/api", "/ap", "/"}
|
||||||
|
|
||||||
|
for _, path := range tests {
|
||||||
|
if mount.Matches(path) {
|
||||||
|
t.Errorf("/api/v1 should not match shorter path %s", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountedPath_TrailingSlashNormalized(t *testing.T) {
|
||||||
|
mount1 := NewMountedPath("/api/")
|
||||||
|
mount2 := NewMountedPath("/api")
|
||||||
|
|
||||||
|
if mount1.Path() != "/api" {
|
||||||
|
t.Errorf("Expected path /api, got %s", mount1.Path())
|
||||||
|
}
|
||||||
|
|
||||||
|
if mount2.Path() != "/api" {
|
||||||
|
t.Errorf("Expected path /api, got %s", mount2.Path())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !mount1.Matches("/api/users") {
|
||||||
|
t.Error("mount1 should match /api/users")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !mount2.Matches("/api/users") {
|
||||||
|
t.Error("mount2 should match /api/users")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountedPath_GetModifiedPathStripsPrefix(t *testing.T) {
|
||||||
|
mount := NewMountedPath("/api")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"/api", "/"},
|
||||||
|
{"/api/", "/"},
|
||||||
|
{"/api/users", "/users"},
|
||||||
|
{"/api/users/123", "/users/123"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
|
||||||
|
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountedPath_GetModifiedPathNoStrip(t *testing.T) {
|
||||||
|
mount := NewMountedPath("/api", WithStripPath(false))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"/api/users", "/api/users"},
|
||||||
|
{"/api", "/api"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
|
||||||
|
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountedPath_RootMountModifiedPath(t *testing.T) {
|
||||||
|
mount := NewMountedPath("")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"/api/users", "/api/users"},
|
||||||
|
{"/", "/"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := mount.GetModifiedPath(tt.input); got != tt.expected {
|
||||||
|
t.Errorf("GetModifiedPath(%s) = %s, want %s", tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountedPath_NameProperty(t *testing.T) {
|
||||||
|
mount1 := NewMountedPath("/api")
|
||||||
|
mount2 := NewMountedPath("/api", WithName("API Mount"))
|
||||||
|
|
||||||
|
if mount1.Name() != "/api" {
|
||||||
|
t.Errorf("Expected name /api, got %s", mount1.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
if mount2.Name() != "API Mount" {
|
||||||
|
t.Errorf("Expected name 'API Mount', got %s", mount2.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== MountManager Tests ==============
|
||||||
|
|
||||||
|
func TestMountManager_EmptyManager(t *testing.T) {
|
||||||
|
manager := NewMountManager()
|
||||||
|
|
||||||
|
if got := manager.GetMount("/api"); got != nil {
|
||||||
|
t.Error("Empty manager should return nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := manager.MountCount(); got != 0 {
|
||||||
|
t.Errorf("Expected mount count 0, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountManager_AddMount(t *testing.T) {
|
||||||
|
manager := NewMountManager()
|
||||||
|
mount := NewMountedPath("/api")
|
||||||
|
|
||||||
|
manager.AddMount(mount)
|
||||||
|
|
||||||
|
if manager.MountCount() != 1 {
|
||||||
|
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := manager.GetMount("/api/users"); got != mount {
|
||||||
|
t.Error("GetMount should return the added mount")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountManager_LongestPrefixMatching(t *testing.T) {
|
||||||
|
manager := NewMountManager()
|
||||||
|
|
||||||
|
apiMount := NewMountedPath("/api", WithName("api"))
|
||||||
|
apiV1Mount := NewMountedPath("/api/v1", WithName("api_v1"))
|
||||||
|
apiV2Mount := NewMountedPath("/api/v2", WithName("api_v2"))
|
||||||
|
|
||||||
|
manager.AddMount(apiMount)
|
||||||
|
manager.AddMount(apiV2Mount)
|
||||||
|
manager.AddMount(apiV1Mount)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
expectedName string
|
||||||
|
}{
|
||||||
|
{"/api/v1/users", "api_v1"},
|
||||||
|
{"/api/v2/items", "api_v2"},
|
||||||
|
{"/api/v3/other", "api"},
|
||||||
|
{"/api", "api"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := manager.GetMount(tt.path)
|
||||||
|
if got == nil {
|
||||||
|
t.Errorf("GetMount(%s) returned nil, want mount with name %s", tt.path, tt.expectedName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if got.Name() != tt.expectedName {
|
||||||
|
t.Errorf("GetMount(%s).Name() = %s, want %s", tt.path, got.Name(), tt.expectedName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountManager_RemoveMount(t *testing.T) {
|
||||||
|
manager := NewMountManager()
|
||||||
|
|
||||||
|
manager.AddMount(NewMountedPath("/api"))
|
||||||
|
manager.AddMount(NewMountedPath("/admin"))
|
||||||
|
|
||||||
|
if manager.MountCount() != 2 {
|
||||||
|
t.Errorf("Expected mount count 2, got %d", manager.MountCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
result := manager.RemoveMount("/api")
|
||||||
|
|
||||||
|
if !result {
|
||||||
|
t.Error("RemoveMount should return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.MountCount() != 1 {
|
||||||
|
t.Errorf("Expected mount count 1, got %d", manager.MountCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.GetMount("/api/users") != nil {
|
||||||
|
t.Error("GetMount(/api/users) should return nil after removal")
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.GetMount("/admin/users") == nil {
|
||||||
|
t.Error("GetMount(/admin/users) should still work")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountManager_RemoveNonexistentMount(t *testing.T) {
|
||||||
|
manager := NewMountManager()
|
||||||
|
|
||||||
|
result := manager.RemoveMount("/api")
|
||||||
|
|
||||||
|
if result {
|
||||||
|
t.Error("RemoveMount should return false for nonexistent mount")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountManager_ListMounts(t *testing.T) {
|
||||||
|
manager := NewMountManager()
|
||||||
|
|
||||||
|
manager.AddMount(NewMountedPath("/api", WithName("API")))
|
||||||
|
manager.AddMount(NewMountedPath("/admin", WithName("Admin")))
|
||||||
|
|
||||||
|
mounts := manager.ListMounts()
|
||||||
|
|
||||||
|
if len(mounts) != 2 {
|
||||||
|
t.Errorf("Expected 2 mounts, got %d", len(mounts))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range mounts {
|
||||||
|
if _, ok := m["path"]; !ok {
|
||||||
|
t.Error("Mount should have 'path' key")
|
||||||
|
}
|
||||||
|
if _, ok := m["name"]; !ok {
|
||||||
|
t.Error("Mount should have 'name' key")
|
||||||
|
}
|
||||||
|
if _, ok := m["strip_path"]; !ok {
|
||||||
|
t.Error("Mount should have 'strip_path' key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMountManager_MountsReturnsCopy(t *testing.T) {
|
||||||
|
manager := NewMountManager()
|
||||||
|
manager.AddMount(NewMountedPath("/api"))
|
||||||
|
|
||||||
|
mounts1 := manager.Mounts()
|
||||||
|
mounts2 := manager.Mounts()
|
||||||
|
|
||||||
|
if &mounts1[0] == &mounts2[0] {
|
||||||
|
t.Error("Mounts() should return different slices")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Utility Functions Tests ==============
|
||||||
|
|
||||||
|
func TestPathMatchesPrefix_Basic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
prefix string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"/api/users", "/api", true},
|
||||||
|
{"/api", "/api", true},
|
||||||
|
{"/api-v2", "/api", false},
|
||||||
|
{"/ap", "/api", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
|
||||||
|
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPathMatchesPrefix_Root(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
prefix string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"/anything", "", true},
|
||||||
|
{"/anything", "/", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := PathMatchesPrefix(tt.path, tt.prefix); got != tt.expected {
|
||||||
|
t.Errorf("PathMatchesPrefix(%s, %s) = %v, want %v", tt.path, tt.prefix, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripPathPrefix_Basic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
prefix string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"/api/users", "/api", "/users"},
|
||||||
|
{"/api", "/api", "/"},
|
||||||
|
{"/api/", "/api", "/"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
|
||||||
|
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripPathPrefix_Root(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
prefix string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"/api/users", "", "/api/users"},
|
||||||
|
{"/api/users", "/", "/api/users"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := StripPathPrefix(tt.path, tt.prefix); got != tt.expected {
|
||||||
|
t.Errorf("StripPathPrefix(%s, %s) = %s, want %s", tt.path, tt.prefix, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAndModifyPath_Combined(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
path string
|
||||||
|
prefix string
|
||||||
|
stripPath bool
|
||||||
|
wantMatches bool
|
||||||
|
wantModified string
|
||||||
|
}{
|
||||||
|
{"/api/users", "/api", true, true, "/users"},
|
||||||
|
{"/api", "/api", true, true, "/"},
|
||||||
|
{"/other", "/api", true, false, ""},
|
||||||
|
{"/api/users", "/api", false, true, "/api/users"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
matches, modified := MatchAndModifyPath(tt.path, tt.prefix, tt.stripPath)
|
||||||
|
if matches != tt.wantMatches {
|
||||||
|
t.Errorf("MatchAndModifyPath(%s, %s, %v) matches = %v, want %v",
|
||||||
|
tt.path, tt.prefix, tt.stripPath, matches, tt.wantMatches)
|
||||||
|
}
|
||||||
|
if modified != tt.wantModified {
|
||||||
|
t.Errorf("MatchAndModifyPath(%s, %s, %v) modified = %s, want %s",
|
||||||
|
tt.path, tt.prefix, tt.stripPath, modified, tt.wantModified)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Performance Tests ==============
|
||||||
|
|
||||||
|
func TestPerformance_ManyMatches(t *testing.T) {
|
||||||
|
mount := NewMountedPath("/api/v1/users")
|
||||||
|
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
if !mount.Matches("/api/v1/users/123/posts") {
|
||||||
|
t.Fatal("Should match")
|
||||||
|
}
|
||||||
|
if mount.Matches("/other/path") {
|
||||||
|
t.Fatal("Should not match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPerformance_ManyMounts(t *testing.T) {
|
||||||
|
manager := NewMountManager()
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10)) + string(rune('0'+i/10))))
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.MountCount() != 100 {
|
||||||
|
t.Errorf("Expected 100 mounts, got %d", manager.MountCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Benchmarks ==============
|
||||||
|
|
||||||
|
func BenchmarkMountedPath_Matches(b *testing.B) {
|
||||||
|
mount := NewMountedPath("/api/v1")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mount.Matches("/api/v1/users/123")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMountManager_GetMount(b *testing.B) {
|
||||||
|
manager := NewMountManager()
|
||||||
|
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
manager.AddMount(NewMountedPath("/api/v" + string(rune('0'+i%10))))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.GetMount("/api/v5/users/123")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPathMatchesPrefix(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
PathMatchesPrefix("/api/v1/users/123", "/api/v1")
|
||||||
|
}
|
||||||
|
}
|
||||||
320
go/internal/proxy/proxy.go
Normal file
320
go/internal/proxy/proxy.go
Normal file
@ -0,0 +1,320 @@
|
|||||||
|
// Package proxy provides reverse proxy functionality for Konduktor
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
// Target is the backend server URL
|
||||||
|
Target string
|
||||||
|
|
||||||
|
// Timeout is the request timeout (default: 30s)
|
||||||
|
Timeout time.Duration
|
||||||
|
|
||||||
|
// Headers are additional headers to add to requests
|
||||||
|
Headers map[string]string
|
||||||
|
|
||||||
|
// StripPrefix removes this prefix from the request path
|
||||||
|
StripPrefix string
|
||||||
|
|
||||||
|
// PreserveHost keeps the original Host header
|
||||||
|
PreserveHost bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReverseProxy struct {
|
||||||
|
config *Config
|
||||||
|
targetURL *url.URL
|
||||||
|
httpClient *http.Client
|
||||||
|
logger *logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(cfg *Config, logger *logging.Logger) (*ReverseProxy, error) {
|
||||||
|
if cfg.Target == "" {
|
||||||
|
return nil, fmt.Errorf("proxy target is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
targetURL, err := url.Parse(cfg.Target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid proxy target URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := cfg.Timeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
ResponseHeaderTimeout: timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ReverseProxy{
|
||||||
|
config: cfg,
|
||||||
|
targetURL: targetURL,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
Timeout: timeout,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse // Don't follow redirects
|
||||||
|
},
|
||||||
|
},
|
||||||
|
logger: logger,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
rp.ProxyRequest(w, r, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, params map[string]string) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
// Build target URL
|
||||||
|
targetURL := rp.buildTargetURL(r)
|
||||||
|
|
||||||
|
// Create proxy request
|
||||||
|
proxyReq, err := rp.createProxyRequest(ctx, r, targetURL)
|
||||||
|
if err != nil {
|
||||||
|
rp.handleError(w, http.StatusInternalServerError, "Failed to create proxy request", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add custom headers with parameter substitution
|
||||||
|
rp.addCustomHeaders(proxyReq, r, params)
|
||||||
|
|
||||||
|
// Execute request
|
||||||
|
resp, err := rp.httpClient.Do(proxyReq)
|
||||||
|
if err != nil {
|
||||||
|
rp.handleProxyError(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Copy response
|
||||||
|
rp.copyResponse(w, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
|
||||||
|
targetURL := *rp.targetURL
|
||||||
|
|
||||||
|
// Strip prefix if configured
|
||||||
|
path := r.URL.Path
|
||||||
|
if rp.config.StripPrefix != "" {
|
||||||
|
path = strings.TrimPrefix(path, rp.config.StripPrefix)
|
||||||
|
if path == "" || path[0] != '/' {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If target has a path, append request path to it
|
||||||
|
if rp.targetURL.Path != "" && rp.targetURL.Path != "/" {
|
||||||
|
targetURL.Path = singleJoiningSlash(rp.targetURL.Path, path)
|
||||||
|
} else {
|
||||||
|
targetURL.Path = path
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve query string
|
||||||
|
targetURL.RawQuery = r.URL.RawQuery
|
||||||
|
|
||||||
|
return &targetURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxy) createProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL) (*http.Request, error) {
|
||||||
|
proxyReq, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy ContentLength
|
||||||
|
proxyReq.ContentLength = r.ContentLength
|
||||||
|
|
||||||
|
// Copy headers
|
||||||
|
for key, values := range r.Header {
|
||||||
|
for _, value := range values {
|
||||||
|
proxyReq.Header.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set/update Host header
|
||||||
|
if rp.config.PreserveHost {
|
||||||
|
proxyReq.Host = r.Host
|
||||||
|
} else {
|
||||||
|
proxyReq.Host = targetURL.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove hop-by-hop headers
|
||||||
|
removeHopByHopHeaders(proxyReq.Header)
|
||||||
|
|
||||||
|
return proxyReq, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxy) addCustomHeaders(proxyReq *http.Request, originalReq *http.Request, params map[string]string) {
|
||||||
|
// Add X-Forwarded headers
|
||||||
|
clientIP := getClientIP(originalReq)
|
||||||
|
if prior := originalReq.Header.Get("X-Forwarded-For"); prior != "" {
|
||||||
|
clientIP = prior + ", " + clientIP
|
||||||
|
}
|
||||||
|
proxyReq.Header.Set("X-Forwarded-For", clientIP)
|
||||||
|
proxyReq.Header.Set("X-Forwarded-Proto", getScheme(originalReq))
|
||||||
|
proxyReq.Header.Set("X-Forwarded-Host", originalReq.Host)
|
||||||
|
|
||||||
|
// Add custom headers from config
|
||||||
|
for key, value := range rp.config.Headers {
|
||||||
|
// Substitute parameters like {version}
|
||||||
|
substituted := value
|
||||||
|
for paramKey, paramValue := range params {
|
||||||
|
substituted = strings.ReplaceAll(substituted, "{"+paramKey+"}", paramValue)
|
||||||
|
}
|
||||||
|
// Substitute $remote_addr
|
||||||
|
substituted = strings.ReplaceAll(substituted, "$remote_addr", clientIP)
|
||||||
|
proxyReq.Header.Set(key, substituted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxy) copyResponse(w http.ResponseWriter, resp *http.Response) {
|
||||||
|
// Copy headers
|
||||||
|
for key, values := range resp.Header {
|
||||||
|
for _, value := range values {
|
||||||
|
w.Header().Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove hop-by-hop headers from response
|
||||||
|
removeHopByHopHeaders(w.Header())
|
||||||
|
|
||||||
|
// Write status code
|
||||||
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
|
// Copy body
|
||||||
|
io.Copy(w, resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxy) handleError(w http.ResponseWriter, status int, message string, err error) {
|
||||||
|
if rp.logger != nil {
|
||||||
|
rp.logger.Error(message, "error", err)
|
||||||
|
}
|
||||||
|
http.Error(w, message, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxy) handleProxyError(w http.ResponseWriter, err error) {
|
||||||
|
if rp.logger != nil {
|
||||||
|
rp.logger.Error("Proxy request failed", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for timeout
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() {
|
||||||
|
http.Error(w, "504 Gateway Timeout", http.StatusGatewayTimeout)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for connection errors
|
||||||
|
if isConnectionError(err) {
|
||||||
|
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context cancelled (client disconnected)
|
||||||
|
if err == context.Canceled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
|
||||||
|
func singleJoiningSlash(a, b string) string {
|
||||||
|
aslash := strings.HasSuffix(a, "/")
|
||||||
|
bslash := strings.HasPrefix(b, "/")
|
||||||
|
switch {
|
||||||
|
case aslash && bslash:
|
||||||
|
return a + b[1:]
|
||||||
|
case !aslash && !bslash:
|
||||||
|
return a + "/" + b
|
||||||
|
}
|
||||||
|
return a + b
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeHopByHopHeaders(h http.Header) {
|
||||||
|
hopByHopHeaders := []string{
|
||||||
|
"Connection",
|
||||||
|
"Proxy-Connection",
|
||||||
|
"Keep-Alive",
|
||||||
|
"Proxy-Authenticate",
|
||||||
|
"Proxy-Authorization",
|
||||||
|
"Te",
|
||||||
|
"Trailer",
|
||||||
|
"Transfer-Encoding",
|
||||||
|
"Upgrade",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, header := range hopByHopHeaders {
|
||||||
|
h.Del(header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getClientIP(r *http.Request) string {
|
||||||
|
// Check X-Real-IP first
|
||||||
|
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get from RemoteAddr
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
func getScheme(r *http.Request) string {
|
||||||
|
if r.TLS != nil {
|
||||||
|
return "https"
|
||||||
|
}
|
||||||
|
if scheme := r.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||||
|
return scheme
|
||||||
|
}
|
||||||
|
return "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
func isConnectionError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
errStr := err.Error()
|
||||||
|
connectionErrors := []string{
|
||||||
|
"connection refused",
|
||||||
|
"no such host",
|
||||||
|
"network is unreachable",
|
||||||
|
"connection reset",
|
||||||
|
"broken pipe",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, connErr := range connectionErrors {
|
||||||
|
if strings.Contains(strings.ToLower(errStr), connErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
747
go/internal/proxy/proxy_test.go
Normal file
747
go/internal/proxy/proxy_test.go
Normal file
@ -0,0 +1,747 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============== Test Backend Server ==============
|
||||||
|
|
||||||
|
type testBackend struct {
|
||||||
|
server *httptest.Server
|
||||||
|
requestLog []requestLogEntry
|
||||||
|
mu sync.Mutex
|
||||||
|
requestCount int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestLogEntry struct {
|
||||||
|
Method string
|
||||||
|
Path string
|
||||||
|
Query string
|
||||||
|
Headers http.Header
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestBackend(handler http.HandlerFunc) *testBackend {
|
||||||
|
tb := &testBackend{
|
||||||
|
requestLog: make([]requestLogEntry, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
if handler == nil {
|
||||||
|
handler = tb.defaultHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tb.logRequest(r)
|
||||||
|
handler(w, r)
|
||||||
|
}))
|
||||||
|
|
||||||
|
return tb
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tb *testBackend) logRequest(r *http.Request) {
|
||||||
|
tb.mu.Lock()
|
||||||
|
defer tb.mu.Unlock()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
// Restore the body for the handler
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
|
||||||
|
tb.requestLog = append(tb.requestLog, requestLogEntry{
|
||||||
|
Method: r.Method,
|
||||||
|
Path: r.URL.Path,
|
||||||
|
Query: r.URL.RawQuery,
|
||||||
|
Headers: r.Header.Clone(),
|
||||||
|
Body: string(body),
|
||||||
|
})
|
||||||
|
atomic.AddInt64(&tb.requestCount, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tb *testBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"message": "Backend response",
|
||||||
|
"path": r.URL.Path,
|
||||||
|
"method": r.Method,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tb *testBackend) close() {
|
||||||
|
tb.server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tb *testBackend) URL() string {
|
||||||
|
return tb.server.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tb *testBackend) getRequestCount() int64 {
|
||||||
|
return atomic.LoadInt64(&tb.requestCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tb *testBackend) getLastRequest() *requestLogEntry {
|
||||||
|
tb.mu.Lock()
|
||||||
|
defer tb.mu.Unlock()
|
||||||
|
if len(tb.requestLog) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &tb.requestLog[len(tb.requestLog)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Proxy Creation Tests ==============
|
||||||
|
|
||||||
|
func TestNew_ValidConfig(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
Target: "http://localhost:8080",
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy, err := New(cfg, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if proxy == nil {
|
||||||
|
t.Fatal("Expected proxy instance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew_EmptyTarget(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
Target: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := New(cfg, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for empty target")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew_InvalidTargetURL(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
Target: "://invalid-url",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := New(cfg, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid URL")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew_DefaultTimeout(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
Target: "http://localhost:8080",
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy, err := New(cfg, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create proxy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if proxy.httpClient.Timeout != 30*time.Second {
|
||||||
|
t.Errorf("Expected default timeout 30s, got %v", proxy.httpClient.Timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Basic Proxy Tests ==============
|
||||||
|
|
||||||
|
func TestProxy_BasicGET(t *testing.T) {
|
||||||
|
backend := newTestBackend(nil)
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var response map[string]interface{}
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["path"] != "/test" {
|
||||||
|
t.Errorf("Expected path /test, got %v", response["path"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_BasicPOST(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"received": string(body),
|
||||||
|
"method": r.Method,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/data", strings.NewReader(`{"key":"value"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var response map[string]interface{}
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["method"] != "POST" {
|
||||||
|
t.Errorf("Expected method POST, got %v", response["method"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_PUT(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("PUT", "/resource/123", strings.NewReader(`{"name":"updated"}`))
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
var response map[string]string
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["method"] != "PUT" {
|
||||||
|
t.Errorf("Expected method PUT, got %v", response["method"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_DELETE(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"method": r.Method})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("DELETE", "/resource/123", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
var response map[string]string
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["method"] != "DELETE" {
|
||||||
|
t.Errorf("Expected method DELETE, got %v", response["method"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Header Tests ==============
|
||||||
|
|
||||||
|
func TestProxy_HeadersForwarding(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"custom_header": r.Header.Get("X-Custom-Header"),
|
||||||
|
"forwarded_for": r.Header.Get("X-Forwarded-For"),
|
||||||
|
"forwarded_host": r.Header.Get("X-Forwarded-Host"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/headers", nil)
|
||||||
|
req.Header.Set("X-Custom-Header", "test-value")
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
req.Host = "example.com"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
var response map[string]interface{}
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["custom_header"] != "test-value" {
|
||||||
|
t.Errorf("Expected custom header, got %v", response["custom_header"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if response["forwarded_for"] != "192.168.1.100" {
|
||||||
|
t.Errorf("Expected X-Forwarded-For, got %v", response["forwarded_for"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if response["forwarded_host"] != "example.com" {
|
||||||
|
t.Errorf("Expected X-Forwarded-Host, got %v", response["forwarded_host"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_CustomHeaders(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"api_version": r.Header.Get("X-API-Version"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{
|
||||||
|
Target: backend.URL(),
|
||||||
|
Headers: map[string]string{
|
||||||
|
"X-API-Version": "{version}",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Simulate parameter substitution
|
||||||
|
proxy.ProxyRequest(rr, req, map[string]string{"version": "2"})
|
||||||
|
|
||||||
|
var response map[string]string
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["api_version"] != "2" {
|
||||||
|
t.Errorf("Expected API version 2, got %v", response["api_version"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_RemoteAddrSubstitution(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"client_ip": r.Header.Get("X-Client-IP"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{
|
||||||
|
Target: backend.URL(),
|
||||||
|
Headers: map[string]string{
|
||||||
|
"X-Client-IP": "$remote_addr",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.1:54321"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
var response map[string]string
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["client_ip"] != "10.0.0.1" {
|
||||||
|
t.Errorf("Expected client IP 10.0.0.1, got %v", response["client_ip"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Query String Tests ==============
|
||||||
|
|
||||||
|
func TestProxy_QueryStringPreservation(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"query": r.URL.RawQuery,
|
||||||
|
"param": r.URL.Query().Get("key"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/search?key=value&page=2", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
var response map[string]string
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["param"] != "value" {
|
||||||
|
t.Errorf("Expected query param 'value', got %v", response["param"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Status Code Tests ==============
|
||||||
|
|
||||||
|
func TestProxy_StatusCodePreservation(t *testing.T) {
|
||||||
|
statusCodes := []int{200, 201, 400, 404, 500}
|
||||||
|
|
||||||
|
for _, code := range statusCodes {
|
||||||
|
code := code // capture range variable
|
||||||
|
t.Run(fmt.Sprintf("Status_%d", code), func(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(code)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/status", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != code {
|
||||||
|
t.Errorf("Expected status %d, got %d", code, rr.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Error Handling Tests ==============
|
||||||
|
|
||||||
|
func TestProxy_BackendUnavailable(t *testing.T) {
|
||||||
|
// Use a port that's definitely not listening
|
||||||
|
proxy, _ := New(&Config{
|
||||||
|
Target: "http://127.0.0.1:59999",
|
||||||
|
Timeout: 1 * time.Second,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusBadGateway {
|
||||||
|
t.Errorf("Expected status 502 Bad Gateway, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_Timeout(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{
|
||||||
|
Target: backend.URL(),
|
||||||
|
Timeout: 100 * time.Millisecond,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/slow", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusGatewayTimeout {
|
||||||
|
t.Errorf("Expected status 504 Gateway Timeout, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Path Handling Tests ==============
|
||||||
|
|
||||||
|
func TestProxy_StripPrefix(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"path": r.URL.Path,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{
|
||||||
|
Target: backend.URL(),
|
||||||
|
StripPrefix: "/api/v1",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/v1/users", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
var response map[string]string
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["path"] != "/users" {
|
||||||
|
t.Errorf("Expected stripped path /users, got %v", response["path"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_TargetWithPath(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"path": r.URL.Path,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{
|
||||||
|
Target: backend.URL() + "/backend",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/resource", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
var response map[string]string
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["path"] != "/backend/resource" {
|
||||||
|
t.Errorf("Expected path /backend/resource, got %v", response["path"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Large Body Tests ==============
|
||||||
|
|
||||||
|
func TestProxy_LargeRequestBody(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
json.NewEncoder(w).Encode(map[string]int{
|
||||||
|
"received_bytes": len(body),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
// 100KB body
|
||||||
|
largeBody := strings.Repeat("x", 100000)
|
||||||
|
req := httptest.NewRequest("POST", "/upload", strings.NewReader(largeBody))
|
||||||
|
req.ContentLength = int64(len(largeBody))
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
var response map[string]int
|
||||||
|
json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
|
||||||
|
if response["received_bytes"] != 100000 {
|
||||||
|
t.Errorf("Expected 100000 bytes, got %d", response["received_bytes"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_LargeResponseBody(t *testing.T) {
|
||||||
|
largeResponse := strings.Repeat("y", 100000)
|
||||||
|
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(largeResponse))
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/large", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Body.Len() != 100000 {
|
||||||
|
t.Errorf("Expected 100000 bytes in response, got %d", rr.Body.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Concurrent Requests Tests ==============
|
||||||
|
|
||||||
|
func TestProxy_ConcurrentRequests(t *testing.T) {
|
||||||
|
backend := newTestBackend(nil)
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
const numRequests = 50
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errors := make(chan error, numRequests)
|
||||||
|
|
||||||
|
for i := 0; i < numRequests; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(n int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/concurrent", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
errors <- &net.OpError{Op: "test", Err: context.DeadlineExceeded}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errors)
|
||||||
|
|
||||||
|
errorCount := 0
|
||||||
|
for range errors {
|
||||||
|
errorCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if errorCount > 0 {
|
||||||
|
t.Errorf("Got %d errors in concurrent requests", errorCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if backend.getRequestCount() != numRequests {
|
||||||
|
t.Errorf("Expected %d requests at backend, got %d", numRequests, backend.getRequestCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Echo Tests ==============
|
||||||
|
|
||||||
|
func TestProxy_Echo(t *testing.T) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
|
||||||
|
w.Write(body)
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
testData := "Hello, Proxy!"
|
||||||
|
req := httptest.NewRequest("POST", "/echo", strings.NewReader(testData))
|
||||||
|
req.Header.Set("Content-Type", "text/plain")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Body.String() != testData {
|
||||||
|
t.Errorf("Expected echo of '%s', got '%s'", testData, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if rr.Header().Get("Content-Type") != "text/plain" {
|
||||||
|
t.Errorf("Expected Content-Type text/plain, got %s", rr.Header().Get("Content-Type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Helper Function Tests ==============
|
||||||
|
|
||||||
|
func TestSingleJoiningSlash(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
a, b, expected string
|
||||||
|
}{
|
||||||
|
{"/api", "/users", "/api/users"},
|
||||||
|
{"/api/", "/users", "/api/users"},
|
||||||
|
{"/api", "users", "/api/users"},
|
||||||
|
{"/api/", "users", "/api/users"},
|
||||||
|
{"", "/users", "/users"},
|
||||||
|
{"/api", "", "/api/"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := singleJoiningSlash(tt.a, tt.b)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("singleJoiningSlash(%q, %q) = %q, want %q", tt.a, tt.b, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClientIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
remoteAddr string
|
||||||
|
xRealIP string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"192.168.1.1:1234", "", "192.168.1.1"},
|
||||||
|
{"192.168.1.1:1234", "10.0.0.1", "10.0.0.1"},
|
||||||
|
{"invalid", "", "invalid"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
if tt.xRealIP != "" {
|
||||||
|
req.Header.Set("X-Real-IP", tt.xRealIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := getClientIP(req)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("getClientIP() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetScheme(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tls bool
|
||||||
|
header string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"HTTP", false, "", "http"},
|
||||||
|
{"HTTPS from TLS", true, "", "https"},
|
||||||
|
{"HTTPS from header", false, "https", "https"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
if tt.header != "" {
|
||||||
|
req.Header.Set("X-Forwarded-Proto", tt.header)
|
||||||
|
}
|
||||||
|
// Note: httptest doesn't set TLS, so we can only test non-TLS cases fully
|
||||||
|
|
||||||
|
result := getScheme(req)
|
||||||
|
if !tt.tls && result != tt.expected {
|
||||||
|
t.Errorf("getScheme() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsConnectionError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
err error
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{nil, false},
|
||||||
|
{&net.OpError{Op: "dial", Err: &net.DNSError{Err: "no such host"}}, true},
|
||||||
|
{context.DeadlineExceeded, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := isConnectionError(tt.err)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("isConnectionError(%v) = %v, want %v", tt.err, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Benchmarks ==============
|
||||||
|
|
||||||
|
func BenchmarkProxy_SimpleGET(b *testing.B) {
|
||||||
|
backend := newTestBackend(nil)
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "/bench", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkProxy_POSTWithBody(b *testing.B) {
|
||||||
|
backend := newTestBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
io.Copy(io.Discard, r.Body)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
defer backend.close()
|
||||||
|
|
||||||
|
proxy, _ := New(&Config{Target: backend.URL()}, nil)
|
||||||
|
body := strings.Repeat("x", 1024) // 1KB body
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := httptest.NewRequest("POST", "/bench", strings.NewReader(body))
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(rr, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
395
go/internal/routing/router.go
Normal file
395
go/internal/routing/router.go
Normal file
@ -0,0 +1,395 @@
|
|||||||
|
// Package routing provides HTTP routing with regex support
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/config"
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RouteMatch represents a matched route with captured parameters
|
||||||
|
type RouteMatch struct {
|
||||||
|
Config map[string]interface{}
|
||||||
|
Params map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegexRoute represents a compiled regex route
|
||||||
|
type RegexRoute struct {
|
||||||
|
Pattern *regexp.Regexp
|
||||||
|
Config map[string]interface{}
|
||||||
|
CaseSensitive bool
|
||||||
|
OriginalExpr string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router handles HTTP routing with exact, regex, and default routes
|
||||||
|
type Router struct {
|
||||||
|
config *config.Config
|
||||||
|
logger *logging.Logger
|
||||||
|
mux *http.ServeMux
|
||||||
|
staticDir string
|
||||||
|
exactRoutes map[string]map[string]interface{}
|
||||||
|
regexRoutes []*RegexRoute
|
||||||
|
defaultRoute map[string]interface{}
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new router from config
|
||||||
|
func New(cfg *config.Config, logger *logging.Logger) *Router {
|
||||||
|
staticDir := "./static"
|
||||||
|
if cfg != nil && cfg.HTTP.StaticDir != "" {
|
||||||
|
staticDir = cfg.HTTP.StaticDir
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &Router{
|
||||||
|
config: cfg,
|
||||||
|
logger: logger,
|
||||||
|
mux: http.NewServeMux(),
|
||||||
|
staticDir: staticDir,
|
||||||
|
exactRoutes: make(map[string]map[string]interface{}),
|
||||||
|
regexRoutes: make([]*RegexRoute, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
r.setupRoutes()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouter creates a router without config (for testing)
|
||||||
|
func NewRouter(opts ...RouterOption) *Router {
|
||||||
|
r := &Router{
|
||||||
|
mux: http.NewServeMux(),
|
||||||
|
staticDir: "./static",
|
||||||
|
exactRoutes: make(map[string]map[string]interface{}),
|
||||||
|
regexRoutes: make([]*RegexRoute, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterOption is a functional option for Router
|
||||||
|
type RouterOption func(*Router)
|
||||||
|
|
||||||
|
// WithStaticDir sets the static directory
|
||||||
|
func WithStaticDir(dir string) RouterOption {
|
||||||
|
return func(r *Router) {
|
||||||
|
r.staticDir = dir
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaticDir returns the static directory path
|
||||||
|
func (r *Router) StaticDir() string {
|
||||||
|
return r.staticDir
|
||||||
|
}
|
||||||
|
|
||||||
|
// Routes returns the regex routes (for testing)
|
||||||
|
func (r *Router) Routes() []*RegexRoute {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
return r.regexRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExactRoutes returns the exact routes (for testing)
|
||||||
|
func (r *Router) ExactRoutes() map[string]map[string]interface{} {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
return r.exactRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRoute returns the default route (for testing)
|
||||||
|
func (r *Router) DefaultRoute() map[string]interface{} {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
return r.defaultRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRoute adds a route with the given pattern and config
|
||||||
|
// Pattern formats:
|
||||||
|
// - "=/path" - exact match
|
||||||
|
// - "~regex" - case-sensitive regex
|
||||||
|
// - "~*regex" - case-insensitive regex
|
||||||
|
// - "__default__" - default/fallback route
|
||||||
|
func (r *Router) AddRoute(pattern string, routeConfig map[string]interface{}) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case pattern == "__default__":
|
||||||
|
r.defaultRoute = routeConfig
|
||||||
|
|
||||||
|
case strings.HasPrefix(pattern, "="):
|
||||||
|
// Exact match route
|
||||||
|
path := strings.TrimPrefix(pattern, "=")
|
||||||
|
r.exactRoutes[path] = routeConfig
|
||||||
|
|
||||||
|
case strings.HasPrefix(pattern, "~*"):
|
||||||
|
// Case-insensitive regex
|
||||||
|
expr := strings.TrimPrefix(pattern, "~*")
|
||||||
|
re, err := regexp.Compile("(?i)" + expr)
|
||||||
|
if err != nil {
|
||||||
|
if r.logger != nil {
|
||||||
|
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
|
||||||
|
Pattern: re,
|
||||||
|
Config: routeConfig,
|
||||||
|
CaseSensitive: false,
|
||||||
|
OriginalExpr: expr,
|
||||||
|
})
|
||||||
|
|
||||||
|
case strings.HasPrefix(pattern, "~"):
|
||||||
|
// Case-sensitive regex
|
||||||
|
expr := strings.TrimPrefix(pattern, "~")
|
||||||
|
re, err := regexp.Compile(expr)
|
||||||
|
if err != nil {
|
||||||
|
if r.logger != nil {
|
||||||
|
r.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.regexRoutes = append(r.regexRoutes, &RegexRoute{
|
||||||
|
Pattern: re,
|
||||||
|
Config: routeConfig,
|
||||||
|
CaseSensitive: true,
|
||||||
|
OriginalExpr: expr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match finds the best matching route for a path
|
||||||
|
// Priority: exact match > regex match > default
|
||||||
|
func (r *Router) Match(path string) *RouteMatch {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
// 1. Check exact routes
|
||||||
|
if cfg, ok := r.exactRoutes[path]; ok {
|
||||||
|
return &RouteMatch{
|
||||||
|
Config: cfg,
|
||||||
|
Params: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check regex routes
|
||||||
|
for _, route := range r.regexRoutes {
|
||||||
|
match := route.Pattern.FindStringSubmatch(path)
|
||||||
|
if match != nil {
|
||||||
|
params := make(map[string]string)
|
||||||
|
|
||||||
|
// Extract named groups
|
||||||
|
names := route.Pattern.SubexpNames()
|
||||||
|
for i, name := range names {
|
||||||
|
if i > 0 && name != "" && i < len(match) {
|
||||||
|
params[name] = match[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RouteMatch{
|
||||||
|
Config: route.Config,
|
||||||
|
Params: params,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Check default route
|
||||||
|
if r.defaultRoute != nil {
|
||||||
|
return &RouteMatch{
|
||||||
|
Config: r.defaultRoute,
|
||||||
|
Params: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupRoutes configures the routes from config
|
||||||
|
func (r *Router) setupRoutes() {
|
||||||
|
// Health check endpoint
|
||||||
|
r.mux.HandleFunc("/health", r.healthHandler)
|
||||||
|
|
||||||
|
// Setup redirect instructions from config
|
||||||
|
if r.config != nil {
|
||||||
|
for from, to := range r.config.Server.RedirectInstructions {
|
||||||
|
fromPath := from
|
||||||
|
toPath := to
|
||||||
|
r.mux.HandleFunc(fromPath, func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
http.Redirect(w, req, toPath, http.StatusMovedPermanently)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default handler for all other routes
|
||||||
|
r.mux.HandleFunc("/", r.defaultHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP implements http.Handler
|
||||||
|
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
r.mux.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthHandler handles health check requests
|
||||||
|
func (r *Router) healthHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultHandler handles requests that don't match other routes
|
||||||
|
func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
|
path := req.URL.Path
|
||||||
|
|
||||||
|
// Try to match against configured routes
|
||||||
|
match := r.Match(path)
|
||||||
|
if match != nil {
|
||||||
|
r.handleRouteMatch(w, req, match)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to serve static file
|
||||||
|
if r.staticDir != "" {
|
||||||
|
filePath := filepath.Join(r.staticDir, path)
|
||||||
|
|
||||||
|
// Prevent directory traversal
|
||||||
|
if !strings.HasPrefix(filepath.Clean(filePath), filepath.Clean(r.staticDir)) {
|
||||||
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if file exists
|
||||||
|
info, err := os.Stat(filePath)
|
||||||
|
if err == nil {
|
||||||
|
if info.IsDir() {
|
||||||
|
// Try index.html
|
||||||
|
indexPath := filepath.Join(filePath, "index.html")
|
||||||
|
if _, err := os.Stat(indexPath); err == nil {
|
||||||
|
http.ServeFile(w, req, indexPath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
http.ServeFile(w, req, filePath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 404 Not Found
|
||||||
|
http.NotFound(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRouteMatch handles a matched route
|
||||||
|
func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, match *RouteMatch) {
|
||||||
|
cfg := match.Config
|
||||||
|
|
||||||
|
// Handle "return" directive
|
||||||
|
if ret, ok := cfg["return"].(string); ok {
|
||||||
|
parts := strings.SplitN(ret, " ", 2)
|
||||||
|
statusCode := 200
|
||||||
|
body := "OK"
|
||||||
|
if len(parts) >= 1 {
|
||||||
|
switch parts[0] {
|
||||||
|
case "200":
|
||||||
|
statusCode = 200
|
||||||
|
case "201":
|
||||||
|
statusCode = 201
|
||||||
|
case "301":
|
||||||
|
statusCode = 301
|
||||||
|
case "302":
|
||||||
|
statusCode = 302
|
||||||
|
case "400":
|
||||||
|
statusCode = 400
|
||||||
|
case "404":
|
||||||
|
statusCode = 404
|
||||||
|
case "500":
|
||||||
|
statusCode = 500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) >= 2 {
|
||||||
|
body = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if ct, ok := cfg["content_type"].(string); ok {
|
||||||
|
w.Header().Set("Content-Type", ct)
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(statusCode)
|
||||||
|
w.Write([]byte(body))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle static files with root
|
||||||
|
if root, ok := cfg["root"].(string); ok {
|
||||||
|
path := req.URL.Path
|
||||||
|
|
||||||
|
if indexFile, ok := cfg["index_file"].(string); ok {
|
||||||
|
if path == "/" || strings.HasSuffix(path, "/") {
|
||||||
|
path = "/" + indexFile
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(root, path)
|
||||||
|
|
||||||
|
if cacheControl, ok := cfg["cache_control"].(string); ok {
|
||||||
|
w.Header().Set("Cache-Control", cacheControl)
|
||||||
|
}
|
||||||
|
|
||||||
|
if headers, ok := cfg["headers"].([]interface{}); ok {
|
||||||
|
for _, h := range headers {
|
||||||
|
if header, ok := h.(string); ok {
|
||||||
|
parts := strings.SplitN(header, ": ", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
w.Header().Set(parts[0], parts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
http.ServeFile(w, req, filePath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle SPA fallback
|
||||||
|
if spaFallback, ok := cfg["spa_fallback"].(bool); ok && spaFallback {
|
||||||
|
root := r.staticDir
|
||||||
|
if rt, ok := cfg["root"].(string); ok {
|
||||||
|
root = rt
|
||||||
|
}
|
||||||
|
|
||||||
|
indexFile := "index.html"
|
||||||
|
if idx, ok := cfg["index_file"].(string); ok {
|
||||||
|
indexFile = idx
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(root, indexFile)
|
||||||
|
http.ServeFile(w, req, filePath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.NotFound(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateRouterFromConfig creates a router from extension config
|
||||||
|
func CreateRouterFromConfig(cfg map[string]interface{}) *Router {
|
||||||
|
router := NewRouter()
|
||||||
|
|
||||||
|
if locations, ok := cfg["regex_locations"].(map[string]interface{}); ok {
|
||||||
|
for pattern, routeCfg := range locations {
|
||||||
|
if rc, ok := routeCfg.(map[string]interface{}); ok {
|
||||||
|
router.AddRoute(pattern, rc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return router
|
||||||
|
}
|
||||||
375
go/internal/routing/router_test.go
Normal file
375
go/internal/routing/router_test.go
Normal file
@ -0,0 +1,375 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============== Router Initialization Tests ==============
|
||||||
|
|
||||||
|
func TestRouter_Initialization(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
|
||||||
|
if router.StaticDir() != "./static" {
|
||||||
|
t.Errorf("Expected static dir ./static, got %s", router.StaticDir())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(router.Routes()) != 0 {
|
||||||
|
t.Error("Expected empty routes")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(router.ExactRoutes()) != 0 {
|
||||||
|
t.Error("Expected empty exact routes")
|
||||||
|
}
|
||||||
|
|
||||||
|
if router.DefaultRoute() != nil {
|
||||||
|
t.Error("Expected nil default route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_CustomStaticDir(t *testing.T) {
|
||||||
|
router := NewRouter(WithStaticDir("/custom/path"))
|
||||||
|
|
||||||
|
if router.StaticDir() != "/custom/path" {
|
||||||
|
t.Errorf("Expected static dir /custom/path, got %s", router.StaticDir())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Route Adding Tests ==============
|
||||||
|
|
||||||
|
func TestRouter_AddExactRoute(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"return": "200 OK"}
|
||||||
|
|
||||||
|
router.AddRoute("=/health", config)
|
||||||
|
|
||||||
|
exactRoutes := router.ExactRoutes()
|
||||||
|
if _, ok := exactRoutes["/health"]; !ok {
|
||||||
|
t.Error("Expected /health in exact routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_AddDefaultRoute(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"spa_fallback": true, "root": "./static"}
|
||||||
|
|
||||||
|
router.AddRoute("__default__", config)
|
||||||
|
|
||||||
|
if router.DefaultRoute() == nil {
|
||||||
|
t.Error("Expected default route to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_AddRegexRoute(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"root": "./static"}
|
||||||
|
|
||||||
|
router.AddRoute("~^/api/", config)
|
||||||
|
|
||||||
|
if len(router.Routes()) != 1 {
|
||||||
|
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_AddCaseInsensitiveRegexRoute(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
|
||||||
|
|
||||||
|
router.AddRoute("~*\\.(css|js)$", config)
|
||||||
|
|
||||||
|
if len(router.Routes()) != 1 {
|
||||||
|
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if router.Routes()[0].CaseSensitive {
|
||||||
|
t.Error("Expected case-insensitive route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_InvalidRegexPattern(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"root": "./static"}
|
||||||
|
|
||||||
|
// Invalid regex - unmatched bracket
|
||||||
|
router.AddRoute("~^/api/[invalid", config)
|
||||||
|
|
||||||
|
// Should not add invalid pattern
|
||||||
|
if len(router.Routes()) != 0 {
|
||||||
|
t.Error("Should not add invalid regex pattern")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Route Matching Tests ==============
|
||||||
|
|
||||||
|
func TestRouter_MatchExactRoute(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"return": "200 OK"}
|
||||||
|
router.AddRoute("=/health", config)
|
||||||
|
|
||||||
|
match := router.Match("/health")
|
||||||
|
|
||||||
|
if match == nil {
|
||||||
|
t.Fatal("Expected match for /health")
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.Config["return"] != "200 OK" {
|
||||||
|
t.Error("Expected return config")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(match.Params) != 0 {
|
||||||
|
t.Error("Expected empty params for exact match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_MatchExactRouteNoMatch(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"return": "200 OK"}
|
||||||
|
router.AddRoute("=/health", config)
|
||||||
|
|
||||||
|
match := router.Match("/healthcheck")
|
||||||
|
|
||||||
|
if match != nil {
|
||||||
|
t.Error("Exact route should not match /healthcheck")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_MatchRegexRoute(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
|
||||||
|
router.AddRoute("~^/api/v\\d+/", config)
|
||||||
|
|
||||||
|
match := router.Match("/api/v1/users")
|
||||||
|
|
||||||
|
if match == nil {
|
||||||
|
t.Fatal("Expected match for /api/v1/users")
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.Config["proxy_pass"] != "http://localhost:9001" {
|
||||||
|
t.Error("Expected proxy_pass config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_MatchRegexRouteWithGroups(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"proxy_pass": "http://localhost:9001"}
|
||||||
|
router.AddRoute("~^/api/v(?P<version>\\d+)/", config)
|
||||||
|
|
||||||
|
match := router.Match("/api/v2/data")
|
||||||
|
|
||||||
|
if match == nil {
|
||||||
|
t.Fatal("Expected match for /api/v2/data")
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.Params["version"] != "2" {
|
||||||
|
t.Errorf("Expected version=2, got %s", match.Params["version"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_MatchCaseInsensitiveRegex(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"root": "./static", "cache_control": "public, max-age=3600"}
|
||||||
|
router.AddRoute("~*\\.(CSS|JS)$", config)
|
||||||
|
|
||||||
|
// Should match lowercase
|
||||||
|
match1 := router.Match("/styles/main.css")
|
||||||
|
if match1 == nil {
|
||||||
|
t.Error("Should match lowercase .css")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should match uppercase
|
||||||
|
match2 := router.Match("/scripts/app.JS")
|
||||||
|
if match2 == nil {
|
||||||
|
t.Error("Should match uppercase .JS")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_MatchCaseSensitiveRegex(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
config := map[string]interface{}{"root": "./static"}
|
||||||
|
router.AddRoute("~\\.(css)$", config)
|
||||||
|
|
||||||
|
// Should match lowercase
|
||||||
|
match1 := router.Match("/styles/main.css")
|
||||||
|
if match1 == nil {
|
||||||
|
t.Error("Should match lowercase .css")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should NOT match uppercase
|
||||||
|
match2 := router.Match("/styles/main.CSS")
|
||||||
|
if match2 != nil {
|
||||||
|
t.Error("Should not match uppercase .CSS for case-sensitive regex")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_MatchDefaultRoute(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
|
||||||
|
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
|
||||||
|
|
||||||
|
match := router.Match("/unknown/path")
|
||||||
|
|
||||||
|
if match == nil {
|
||||||
|
t.Fatal("Expected default route match")
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.Config["spa_fallback"] != true {
|
||||||
|
t.Error("Expected spa_fallback config from default route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Priority Tests ==============
|
||||||
|
|
||||||
|
func TestRouter_PriorityExactOverRegex(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
router.AddRoute("=/api/status", map[string]interface{}{"return": "200 Exact"})
|
||||||
|
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
||||||
|
|
||||||
|
match := router.Match("/api/status")
|
||||||
|
|
||||||
|
if match == nil {
|
||||||
|
t.Fatal("Expected match")
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.Config["return"] != "200 Exact" {
|
||||||
|
t.Error("Exact match should have priority over regex")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_PriorityRegexOverDefault(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
router.AddRoute("~^/api/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
||||||
|
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
|
||||||
|
|
||||||
|
match := router.Match("/api/v1/users")
|
||||||
|
|
||||||
|
if match == nil {
|
||||||
|
t.Fatal("Expected match")
|
||||||
|
}
|
||||||
|
|
||||||
|
if match.Config["proxy_pass"] != "http://localhost:9001" {
|
||||||
|
t.Error("Regex match should have priority over default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== CreateRouterFromConfig Tests ==============
|
||||||
|
|
||||||
|
func TestCreateRouterFromConfig(t *testing.T) {
|
||||||
|
config := map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"=/health": map[string]interface{}{
|
||||||
|
"return": "200 OK",
|
||||||
|
"content_type": "text/plain",
|
||||||
|
},
|
||||||
|
"~^/api/": map[string]interface{}{
|
||||||
|
"proxy_pass": "http://localhost:9001",
|
||||||
|
},
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"spa_fallback": true,
|
||||||
|
"root": "./static",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := CreateRouterFromConfig(config)
|
||||||
|
|
||||||
|
// Check exact route
|
||||||
|
if _, ok := router.ExactRoutes()["/health"]; !ok {
|
||||||
|
t.Error("Expected /health exact route")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check regex route
|
||||||
|
if len(router.Routes()) != 1 {
|
||||||
|
t.Errorf("Expected 1 regex route, got %d", len(router.Routes()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check default route
|
||||||
|
if router.DefaultRoute() == nil {
|
||||||
|
t.Error("Expected default route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Static Dir Path Tests ==============
|
||||||
|
|
||||||
|
func TestRouter_StaticDirPath(t *testing.T) {
|
||||||
|
router := NewRouter(WithStaticDir("/var/www/html"))
|
||||||
|
|
||||||
|
expected, _ := filepath.Abs("/var/www/html")
|
||||||
|
actual, _ := filepath.Abs(router.StaticDir())
|
||||||
|
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Expected static dir %s, got %s", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Concurrent Access Tests ==============
|
||||||
|
|
||||||
|
func TestRouter_ConcurrentAccess(t *testing.T) {
|
||||||
|
router := NewRouter()
|
||||||
|
|
||||||
|
// Add routes concurrently
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(n int) {
|
||||||
|
router.AddRoute("~^/api/v"+string(rune('0'+n))+"/", map[string]interface{}{
|
||||||
|
"proxy_pass": "http://localhost:900" + string(rune('0'+n)),
|
||||||
|
})
|
||||||
|
done <- true
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match routes concurrently
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(n int) {
|
||||||
|
router.Match("/api/v" + string(rune('0'+n)) + "/users")
|
||||||
|
done <- true
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Benchmarks ==============
|
||||||
|
|
||||||
|
func BenchmarkRouter_MatchExact(b *testing.B) {
|
||||||
|
router := NewRouter()
|
||||||
|
router.AddRoute("=/health", map[string]interface{}{"return": "200 OK"})
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
router.Match("/health")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRouter_MatchRegex(b *testing.B) {
|
||||||
|
router := NewRouter()
|
||||||
|
router.AddRoute("~^/api/v(?P<version>\\d+)/", map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
router.Match("/api/v1/users/123")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRouter_MatchWithManyRoutes(b *testing.B) {
|
||||||
|
router := NewRouter()
|
||||||
|
|
||||||
|
// Add many routes
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
router.AddRoute("~^/api/v"+string(rune('0'+i%10))+"/service"+string(rune('0'+i/10))+"/",
|
||||||
|
map[string]interface{}{"proxy_pass": "http://localhost:9001"})
|
||||||
|
}
|
||||||
|
router.AddRoute("__default__", map[string]interface{}{"spa_fallback": true})
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
router.Match("/api/v5/service3/users/123")
|
||||||
|
}
|
||||||
|
}
|
||||||
130
go/internal/server/server.go
Normal file
130
go/internal/server/server.go
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
// Package server provides the HTTP server implementation
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/config"
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
"github.com/konduktor/konduktor/internal/middleware"
|
||||||
|
"github.com/konduktor/konduktor/internal/routing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const Version = "0.1.0"
|
||||||
|
|
||||||
|
// Server represents the Konduktor HTTP server
|
||||||
|
type Server struct {
|
||||||
|
config *config.Config
|
||||||
|
httpServer *http.Server
|
||||||
|
router *routing.Router
|
||||||
|
logger *logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new server instance
|
||||||
|
func New(cfg *config.Config) (*Server, error) {
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger, err := logging.NewFromConfig(cfg.Logging)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create logger: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
router := routing.New(cfg, logger)
|
||||||
|
|
||||||
|
srv := &Server{
|
||||||
|
config: cfg,
|
||||||
|
router: router,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
return srv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the server and blocks until shutdown
|
||||||
|
func (s *Server) Run() error {
|
||||||
|
// Build handler chain with middleware
|
||||||
|
handler := s.buildHandler()
|
||||||
|
|
||||||
|
// Create HTTP server
|
||||||
|
addr := fmt.Sprintf("%s:%d", s.config.Server.Host, s.config.Server.Port)
|
||||||
|
s.httpServer = &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: handler,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start server in goroutine
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
s.logger.Info("Server starting", "addr", addr, "version", Version)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if s.config.SSL.Enabled {
|
||||||
|
err = s.httpServer.ListenAndServeTLS(s.config.SSL.CertFile, s.config.SSL.KeyFile)
|
||||||
|
} else {
|
||||||
|
err = s.httpServer.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for shutdown signal
|
||||||
|
return s.waitForShutdown(errChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildHandler builds the HTTP handler chain
|
||||||
|
func (s *Server) buildHandler() http.Handler {
|
||||||
|
var handler http.Handler = s.router
|
||||||
|
|
||||||
|
// Add middleware
|
||||||
|
handler = middleware.AccessLog(handler, s.logger)
|
||||||
|
handler = middleware.ServerHeader(handler, Version)
|
||||||
|
handler = middleware.Recovery(handler, s.logger)
|
||||||
|
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForShutdown waits for shutdown signal and gracefully stops the server
|
||||||
|
func (s *Server) waitForShutdown(errChan <-chan error) error {
|
||||||
|
// Listen for shutdown signals
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
return err
|
||||||
|
case sig := <-sigChan:
|
||||||
|
s.logger.Info("Shutdown signal received", "signal", sig.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful shutdown with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s.logger.Info("Shutting down server...")
|
||||||
|
|
||||||
|
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||||
|
s.logger.Error("Error during shutdown", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("Server stopped gracefully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully shuts down the server
|
||||||
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
|
return s.httpServer.Shutdown(ctx)
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user