forked from aegis/pyserveX
feat: Add reverse proxy functionality with enhanced routing capabilities
- Introduced IgnoreRequestPath option in proxy configuration to allow exact match routing. - Implemented proxy_pass directive in routing extension to handle backend requests. - Enhanced error handling for backend unavailability and timeouts. - Added integration tests for reverse proxy, including basic requests, exact match routes, regex routes, header forwarding, and query string preservation. - Created helper functions for setting up test servers and backends, along with assertion utilities for response validation. - Updated server initialization to support extension management and middleware chaining. - Improved logging for debugging purposes during request handling.
This commit is contained in:
parent
8f5b9a5cd1
commit
881028c1e6
@ -12,4 +12,7 @@ require (
|
|||||||
require (
|
require (
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.10 // indirect
|
github.com/spf13/pflag v1.0.10 // indirect
|
||||||
|
go.uber.org/multierr v1.10.0 // indirect
|
||||||
|
go.uber.org/zap v1.27.1 // indirect
|
||||||
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
20
go/go.sum
Normal file
20
go/go.sum
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||||
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
|
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||||
|
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||||
|
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
|
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||||
|
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
|
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||||
|
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
|
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||||
|
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||||
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
427
go/internal/extension/caching.go
Normal file
427
go/internal/extension/caching.go
Normal file
@ -0,0 +1,427 @@
|
|||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CachingExtension struct {
|
||||||
|
BaseExtension
|
||||||
|
cache map[string]*cacheEntry
|
||||||
|
cachePatterns []*cachePattern
|
||||||
|
defaultTTL time.Duration
|
||||||
|
maxSize int
|
||||||
|
currentSize int
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
hits int64
|
||||||
|
misses int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheEntry struct {
|
||||||
|
key string
|
||||||
|
body []byte
|
||||||
|
headers http.Header
|
||||||
|
statusCode int
|
||||||
|
contentType string
|
||||||
|
createdAt time.Time
|
||||||
|
expiresAt time.Time
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
type cachePattern struct {
|
||||||
|
pattern *regexp.Regexp
|
||||||
|
ttl time.Duration
|
||||||
|
methods []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CachingConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
DefaultTTL string `yaml:"default_ttl"` // e.g., "5m", "1h"
|
||||||
|
MaxSizeMB int `yaml:"max_size_mb"` // Max cache size in MB
|
||||||
|
CachePatterns []PatternConfig `yaml:"cache_patterns"` // Patterns to cache
|
||||||
|
}
|
||||||
|
|
||||||
|
type PatternConfig struct {
|
||||||
|
Pattern string `yaml:"pattern"` // Regex pattern
|
||||||
|
TTL string `yaml:"ttl"` // TTL for this pattern
|
||||||
|
Methods []string `yaml:"methods"` // HTTP methods to cache (default: GET)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCachingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
|
||||||
|
ext := &CachingExtension{
|
||||||
|
BaseExtension: NewBaseExtension("caching", 20, logger),
|
||||||
|
cache: make(map[string]*cacheEntry),
|
||||||
|
cachePatterns: make([]*cachePattern, 0),
|
||||||
|
defaultTTL: 5 * time.Minute,
|
||||||
|
maxSize: 100 * 1024 * 1024, // 100MB default
|
||||||
|
}
|
||||||
|
|
||||||
|
if ttl, ok := config["default_ttl"].(string); ok {
|
||||||
|
if duration, err := time.ParseDuration(ttl); err == nil {
|
||||||
|
ext.defaultTTL = duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxSize, ok := config["max_size_mb"].(int); ok {
|
||||||
|
ext.maxSize = maxSize * 1024 * 1024
|
||||||
|
} else if maxSizeFloat, ok := config["max_size_mb"].(float64); ok {
|
||||||
|
ext.maxSize = int(maxSizeFloat) * 1024 * 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
if patterns, ok := config["cache_patterns"].([]interface{}); ok {
|
||||||
|
for _, p := range patterns {
|
||||||
|
if patternCfg, ok := p.(map[string]interface{}); ok {
|
||||||
|
pattern := &cachePattern{
|
||||||
|
ttl: ext.defaultTTL,
|
||||||
|
methods: []string{"GET"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if patternStr, ok := patternCfg["pattern"].(string); ok {
|
||||||
|
re, err := regexp.Compile(patternStr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Invalid cache pattern", "pattern", patternStr, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pattern.pattern = re
|
||||||
|
}
|
||||||
|
|
||||||
|
if ttl, ok := patternCfg["ttl"].(string); ok {
|
||||||
|
if duration, err := time.ParseDuration(ttl); err == nil {
|
||||||
|
pattern.ttl = duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if methods, ok := patternCfg["methods"].([]interface{}); ok {
|
||||||
|
pattern.methods = make([]string, 0)
|
||||||
|
for _, m := range methods {
|
||||||
|
if method, ok := m.(string); ok {
|
||||||
|
pattern.methods = append(pattern.methods, strings.ToUpper(method))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ext.cachePatterns = append(ext.cachePatterns, pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go ext.cleanupLoop()
|
||||||
|
|
||||||
|
return ext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
||||||
|
if !e.shouldCache(r) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
key := e.cacheKey(r)
|
||||||
|
|
||||||
|
e.mu.RLock()
|
||||||
|
entry, exists := e.cache[key]
|
||||||
|
e.mu.RUnlock()
|
||||||
|
|
||||||
|
if exists && time.Now().Before(entry.expiresAt) {
|
||||||
|
e.mu.Lock()
|
||||||
|
e.hits++
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
for k, values := range entry.headers {
|
||||||
|
for _, v := range values {
|
||||||
|
w.Header().Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Header().Set("X-Cache", "HIT")
|
||||||
|
w.Header().Set("Content-Type", entry.contentType)
|
||||||
|
w.WriteHeader(entry.statusCode)
|
||||||
|
w.Write(entry.body)
|
||||||
|
|
||||||
|
e.logger.Debug("Cache hit", "key", key, "path", r.URL.Path)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
e.mu.Lock()
|
||||||
|
e.misses++
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessResponse caches the response if applicable
|
||||||
|
func (e *CachingExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Response caching is handled by the CachingResponseWriter
|
||||||
|
// This is called after the response is written
|
||||||
|
w.Header().Set("X-Cache", "MISS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapResponseWriter wraps the response writer to capture the response for caching
|
||||||
|
func (e *CachingExtension) WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter {
|
||||||
|
if !e.shouldCache(r) {
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cachingResponseWriter{
|
||||||
|
ResponseWriter: w,
|
||||||
|
ext: e,
|
||||||
|
request: r,
|
||||||
|
buffer: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type cachingResponseWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
ext *CachingExtension
|
||||||
|
request *http.Request
|
||||||
|
buffer *bytes.Buffer
|
||||||
|
statusCode int
|
||||||
|
wroteHeader bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *cachingResponseWriter) WriteHeader(code int) {
|
||||||
|
if !cw.wroteHeader {
|
||||||
|
cw.statusCode = code
|
||||||
|
cw.wroteHeader = true
|
||||||
|
cw.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *cachingResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
if !cw.wroteHeader {
|
||||||
|
cw.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.buffer.Write(b)
|
||||||
|
|
||||||
|
return cw.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *cachingResponseWriter) Finalize() {
|
||||||
|
if cw.statusCode < 200 || cw.statusCode >= 400 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body := cw.buffer.Bytes()
|
||||||
|
if len(body) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
key := cw.ext.cacheKey(cw.request)
|
||||||
|
ttl := cw.ext.getTTL(cw.request)
|
||||||
|
|
||||||
|
entry := &cacheEntry{
|
||||||
|
key: key,
|
||||||
|
body: body,
|
||||||
|
headers: cw.Header().Clone(),
|
||||||
|
statusCode: cw.statusCode,
|
||||||
|
contentType: cw.Header().Get("Content-Type"),
|
||||||
|
createdAt: time.Now(),
|
||||||
|
expiresAt: time.Now().Add(ttl),
|
||||||
|
size: len(body),
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.ext.store(entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CachingExtension) shouldCache(r *http.Request) bool {
|
||||||
|
path := r.URL.Path
|
||||||
|
method := r.Method
|
||||||
|
|
||||||
|
for _, pattern := range e.cachePatterns {
|
||||||
|
if pattern.pattern.MatchString(path) {
|
||||||
|
for _, m := range pattern.methods {
|
||||||
|
if m == method {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default: only cache GET requests
|
||||||
|
return method == "GET" && len(e.cachePatterns) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CachingExtension) cacheKey(r *http.Request) string {
|
||||||
|
// Create cache key from method + URL + relevant headers
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(r.Method))
|
||||||
|
h.Write([]byte(r.URL.String()))
|
||||||
|
|
||||||
|
// Include Accept-Encoding for vary
|
||||||
|
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
|
||||||
|
h.Write([]byte(ae))
|
||||||
|
}
|
||||||
|
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CachingExtension) getTTL(r *http.Request) time.Duration {
|
||||||
|
path := r.URL.Path
|
||||||
|
|
||||||
|
for _, pattern := range e.cachePatterns {
|
||||||
|
if pattern.pattern.MatchString(path) {
|
||||||
|
return pattern.ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.defaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CachingExtension) store(entry *cacheEntry) {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
// Evict old entries if needed
|
||||||
|
for e.currentSize+entry.size > e.maxSize && len(e.cache) > 0 {
|
||||||
|
e.evictOldest()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store new entry
|
||||||
|
if existing, ok := e.cache[entry.key]; ok {
|
||||||
|
e.currentSize -= existing.size
|
||||||
|
}
|
||||||
|
|
||||||
|
e.cache[entry.key] = entry
|
||||||
|
e.currentSize += entry.size
|
||||||
|
|
||||||
|
e.logger.Debug("Cached response",
|
||||||
|
"key", entry.key[:16],
|
||||||
|
"size", entry.size,
|
||||||
|
"ttl", entry.expiresAt.Sub(entry.createdAt).String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CachingExtension) evictOldest() {
|
||||||
|
var oldestKey string
|
||||||
|
var oldestTime time.Time
|
||||||
|
|
||||||
|
for key, entry := range e.cache {
|
||||||
|
if oldestKey == "" || entry.createdAt.Before(oldestTime) {
|
||||||
|
oldestKey = key
|
||||||
|
oldestTime = entry.createdAt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldestKey != "" {
|
||||||
|
entry := e.cache[oldestKey]
|
||||||
|
e.currentSize -= entry.size
|
||||||
|
delete(e.cache, oldestKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CachingExtension) cleanupLoop() {
|
||||||
|
ticker := time.NewTicker(time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
e.cleanupExpired()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CachingExtension) cleanupExpired() {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for key, entry := range e.cache {
|
||||||
|
if now.After(entry.expiresAt) {
|
||||||
|
e.currentSize -= entry.size
|
||||||
|
delete(e.cache, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CachingExtension) Invalidate(key string) {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
if entry, ok := e.cache[key]; ok {
|
||||||
|
e.currentSize -= entry.size
|
||||||
|
delete(e.cache, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidatePattern removes all entries matching a pattern, unlike Invalidate
|
||||||
|
func (e *CachingExtension) InvalidatePattern(pattern string) error {
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
for key, entry := range e.cache {
|
||||||
|
if re.MatchString(key) {
|
||||||
|
e.currentSize -= entry.size
|
||||||
|
delete(e.cache, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all entries from cache
|
||||||
|
func (e *CachingExtension) Clear() {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
e.cache = make(map[string]*cacheEntry)
|
||||||
|
e.currentSize = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics returns caching metrics
|
||||||
|
func (e *CachingExtension) GetMetrics() map[string]interface{} {
|
||||||
|
e.mu.RLock()
|
||||||
|
defer e.mu.RUnlock()
|
||||||
|
|
||||||
|
hitRate := float64(0)
|
||||||
|
total := e.hits + e.misses
|
||||||
|
if total > 0 {
|
||||||
|
hitRate = float64(e.hits) / float64(total) * 100
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"entries": len(e.cache),
|
||||||
|
"size_bytes": e.currentSize,
|
||||||
|
"max_size": e.maxSize,
|
||||||
|
"hits": e.hits,
|
||||||
|
"misses": e.misses,
|
||||||
|
"hit_rate": hitRate,
|
||||||
|
"patterns": len(e.cachePatterns),
|
||||||
|
"default_ttl": e.defaultTTL.String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup stops the cleanup goroutine
|
||||||
|
func (e *CachingExtension) Cleanup() error {
|
||||||
|
e.Clear()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheReader wraps an io.ReadCloser to cache the body
|
||||||
|
type CacheReader struct {
|
||||||
|
io.ReadCloser
|
||||||
|
buffer *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *CacheReader) Read(p []byte) (int, error) {
|
||||||
|
n, err := cr.ReadCloser.Read(p)
|
||||||
|
if n > 0 {
|
||||||
|
cr.buffer.Write(p[:n])
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *CacheReader) GetBody() []byte {
|
||||||
|
return cr.buffer.Bytes()
|
||||||
|
}
|
||||||
111
go/internal/extension/extension.go
Normal file
111
go/internal/extension/extension.go
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Extension is the interface that all extensions must implement
|
||||||
|
type Extension interface {
|
||||||
|
// Name returns the unique name of the extension
|
||||||
|
Name() string
|
||||||
|
|
||||||
|
// Initialize is called when the extension is loaded
|
||||||
|
Initialize() error
|
||||||
|
|
||||||
|
// ProcessRequest processes an incoming request before routing.
|
||||||
|
// Returns:
|
||||||
|
// - response: if non-nil, the request is handled and no further processing occurs
|
||||||
|
// - handled: if true, the request was handled by this extension
|
||||||
|
// - err: any error that occurred
|
||||||
|
ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (handled bool, err error)
|
||||||
|
|
||||||
|
// ProcessResponse is called after the response is generated but before it's sent.
|
||||||
|
// Extensions can modify the response here.
|
||||||
|
ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request)
|
||||||
|
|
||||||
|
// Cleanup is called when the extension is being unloaded
|
||||||
|
Cleanup() error
|
||||||
|
|
||||||
|
// Enabled returns whether the extension is currently enabled
|
||||||
|
Enabled() bool
|
||||||
|
|
||||||
|
// SetEnabled enables or disables the extension
|
||||||
|
SetEnabled(enabled bool)
|
||||||
|
|
||||||
|
// Priority returns the extension's priority (lower = earlier execution)
|
||||||
|
Priority() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// BaseExtension provides a default implementation for common Extension methods
|
||||||
|
type BaseExtension struct {
|
||||||
|
name string
|
||||||
|
enabled bool
|
||||||
|
priority int
|
||||||
|
logger *logging.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBaseExtension creates a new BaseExtension
|
||||||
|
func NewBaseExtension(name string, priority int, logger *logging.Logger) BaseExtension {
|
||||||
|
return BaseExtension{
|
||||||
|
name: name,
|
||||||
|
enabled: true,
|
||||||
|
priority: priority,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the extension name
|
||||||
|
func (b *BaseExtension) Name() string {
|
||||||
|
return b.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enabled returns whether the extension is enabled
|
||||||
|
func (b *BaseExtension) Enabled() bool {
|
||||||
|
return b.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnabled sets the enabled state
|
||||||
|
func (b *BaseExtension) SetEnabled(enabled bool) {
|
||||||
|
b.enabled = enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority returns the extension priority
|
||||||
|
func (b *BaseExtension) Priority() int {
|
||||||
|
return b.priority
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize default implementation (no-op)
|
||||||
|
func (b *BaseExtension) Initialize() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup default implementation (no-op)
|
||||||
|
func (b *BaseExtension) Cleanup() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessRequest default implementation (pass-through)
|
||||||
|
func (b *BaseExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessResponse default implementation (no-op)
|
||||||
|
func (b *BaseExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger returns the extension's logger
|
||||||
|
func (b *BaseExtension) Logger() *logging.Logger {
|
||||||
|
return b.logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtensionConfig holds configuration for creating extensions
|
||||||
|
type ExtensionConfig struct {
|
||||||
|
Type string
|
||||||
|
Config map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtensionFactory is a function that creates an extension from config
|
||||||
|
type ExtensionFactory func(config map[string]interface{}, logger *logging.Logger) (Extension, error)
|
||||||
234
go/internal/extension/manager.go
Normal file
234
go/internal/extension/manager.go
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager manages all loaded extensions
|
||||||
|
type Manager struct {
|
||||||
|
extensions []Extension
|
||||||
|
registry map[string]ExtensionFactory
|
||||||
|
logger *logging.Logger
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new extension manager
|
||||||
|
func NewManager(logger *logging.Logger) *Manager {
|
||||||
|
m := &Manager{
|
||||||
|
extensions: make([]Extension, 0),
|
||||||
|
registry: make(map[string]ExtensionFactory),
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register built-in extensions
|
||||||
|
m.RegisterFactory("routing", NewRoutingExtension)
|
||||||
|
m.RegisterFactory("security", NewSecurityExtension)
|
||||||
|
m.RegisterFactory("caching", NewCachingExtension)
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterFactory registers an extension factory
|
||||||
|
func (m *Manager) RegisterFactory(name string, factory ExtensionFactory) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.registry[name] = factory
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadExtension loads an extension by type and config
|
||||||
|
func (m *Manager) LoadExtension(extType string, config map[string]interface{}) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
factory, ok := m.registry[extType]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unknown extension type: %s", extType)
|
||||||
|
}
|
||||||
|
|
||||||
|
ext, err := factory(config, m.logger)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create extension %s: %w", extType, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ext.Initialize(); err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize extension %s: %w", extType, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.extensions = append(m.extensions, ext)
|
||||||
|
|
||||||
|
// Sort by priority (lower first)
|
||||||
|
sort.Slice(m.extensions, func(i, j int) bool {
|
||||||
|
return m.extensions[i].Priority() < m.extensions[j].Priority()
|
||||||
|
})
|
||||||
|
|
||||||
|
m.logger.Info("Loaded extension", "type", extType, "name", ext.Name(), "priority", ext.Priority())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddExtension adds a pre-created extension
|
||||||
|
func (m *Manager) AddExtension(ext Extension) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if err := ext.Initialize(); err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize extension %s: %w", ext.Name(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.extensions = append(m.extensions, ext)
|
||||||
|
|
||||||
|
// Sort by priority
|
||||||
|
sort.Slice(m.extensions, func(i, j int) bool {
|
||||||
|
return m.extensions[i].Priority() < m.extensions[j].Priority()
|
||||||
|
})
|
||||||
|
|
||||||
|
m.logger.Info("Added extension", "name", ext.Name(), "priority", ext.Priority())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessRequest runs all extensions' ProcessRequest in order
|
||||||
|
// Returns true if any extension handled the request
|
||||||
|
func (m *Manager) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
extensions := m.extensions
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, ext := range extensions {
|
||||||
|
if !ext.Enabled() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
handled, err := ext.ProcessRequest(ctx, w, r)
|
||||||
|
if err != nil {
|
||||||
|
m.logger.Error("Extension error", "extension", ext.Name(), "error", err)
|
||||||
|
// Continue to next extension on error
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if handled {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessResponse runs all extensions' ProcessResponse in reverse order
|
||||||
|
func (m *Manager) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
m.mu.RLock()
|
||||||
|
extensions := m.extensions
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
// Process in reverse order for response
|
||||||
|
for i := len(extensions) - 1; i >= 0; i-- {
|
||||||
|
ext := extensions[i]
|
||||||
|
if !ext.Enabled() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ext.ProcessResponse(ctx, w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup cleans up all extensions
|
||||||
|
func (m *Manager) Cleanup() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
for _, ext := range m.extensions {
|
||||||
|
if err := ext.Cleanup(); err != nil {
|
||||||
|
m.logger.Error("Extension cleanup error", "extension", ext.Name(), "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.extensions = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExtension returns an extension by name
|
||||||
|
func (m *Manager) GetExtension(name string) Extension {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, ext := range m.extensions {
|
||||||
|
if ext.Name() == name {
|
||||||
|
return ext
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extensions returns all loaded extensions
|
||||||
|
func (m *Manager) Extensions() []Extension {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]Extension, len(m.extensions))
|
||||||
|
copy(result, m.extensions)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler returns an http.Handler that processes requests through all extensions
|
||||||
|
func (m *Manager) Handler(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
// Create response wrapper to capture response for ProcessResponse
|
||||||
|
wrapper := newResponseWrapper(w)
|
||||||
|
|
||||||
|
// Process request through extensions
|
||||||
|
handled, err := m.ProcessRequest(ctx, wrapper, r)
|
||||||
|
if err != nil {
|
||||||
|
m.logger.Error("Error processing request", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if handled {
|
||||||
|
// Extension handled the request, process response
|
||||||
|
m.ProcessResponse(ctx, wrapper, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No extension handled, pass to next handler
|
||||||
|
next.ServeHTTP(wrapper, r)
|
||||||
|
|
||||||
|
// Process response
|
||||||
|
m.ProcessResponse(ctx, wrapper, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseWrapper wraps http.ResponseWriter to allow response modification
|
||||||
|
type responseWrapper struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
statusCode int
|
||||||
|
written bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResponseWrapper(w http.ResponseWriter) *responseWrapper {
|
||||||
|
return &responseWrapper{
|
||||||
|
ResponseWriter: w,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWrapper) WriteHeader(code int) {
|
||||||
|
if !rw.written {
|
||||||
|
rw.statusCode = code
|
||||||
|
rw.ResponseWriter.WriteHeader(code)
|
||||||
|
rw.written = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWrapper) Write(b []byte) (int, error) {
|
||||||
|
if !rw.written {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
return rw.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWrapper) StatusCode() int {
|
||||||
|
return rw.statusCode
|
||||||
|
}
|
||||||
176
go/internal/extension/manager_test.go
Normal file
176
go/internal/extension/manager_test.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestLogger() *logging.Logger {
|
||||||
|
logger, _ := logging.New(logging.Config{Level: "DEBUG"})
|
||||||
|
return logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewManager(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
manager := NewManager(logger)
|
||||||
|
|
||||||
|
if manager == nil {
|
||||||
|
t.Fatal("Expected manager, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check built-in factories are registered
|
||||||
|
if _, ok := manager.registry["routing"]; !ok {
|
||||||
|
t.Error("Expected routing factory to be registered")
|
||||||
|
}
|
||||||
|
if _, ok := manager.registry["security"]; !ok {
|
||||||
|
t.Error("Expected security factory to be registered")
|
||||||
|
}
|
||||||
|
if _, ok := manager.registry["caching"]; !ok {
|
||||||
|
t.Error("Expected caching factory to be registered")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_LoadExtension(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
manager := NewManager(logger)
|
||||||
|
|
||||||
|
err := manager.LoadExtension("security", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to load security extension: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exts := manager.Extensions()
|
||||||
|
if len(exts) != 1 {
|
||||||
|
t.Errorf("Expected 1 extension, got %d", len(exts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_LoadExtension_Unknown(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
manager := NewManager(logger)
|
||||||
|
|
||||||
|
err := manager.LoadExtension("unknown", map[string]interface{}{})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for unknown extension type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_GetExtension(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
manager := NewManager(logger)
|
||||||
|
|
||||||
|
manager.LoadExtension("security", map[string]interface{}{})
|
||||||
|
|
||||||
|
ext := manager.GetExtension("security")
|
||||||
|
if ext == nil {
|
||||||
|
t.Error("Expected to find security extension")
|
||||||
|
}
|
||||||
|
|
||||||
|
ext = manager.GetExtension("nonexistent")
|
||||||
|
if ext != nil {
|
||||||
|
t.Error("Expected nil for nonexistent extension")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_ProcessRequest(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
manager := NewManager(logger)
|
||||||
|
|
||||||
|
// Load security extension with blocked IP
|
||||||
|
manager.LoadExtension("security", map[string]interface{}{
|
||||||
|
"blocked_ips": []interface{}{"192.168.1.1"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create test request
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handled, err := manager.ProcessRequest(context.Background(), rr, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !handled {
|
||||||
|
t.Error("Expected request to be handled (blocked)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rr.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Expected status 403, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Handler(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
manager := NewManager(logger)
|
||||||
|
|
||||||
|
// Load routing extension with a simple route
|
||||||
|
manager.LoadExtension("routing", map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"=/health": map[string]interface{}{
|
||||||
|
"return": "200 OK",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := manager.Handler(baseHandler)
|
||||||
|
|
||||||
|
// Test health route
|
||||||
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Priority(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
manager := NewManager(logger)
|
||||||
|
|
||||||
|
// Load extensions in any order
|
||||||
|
manager.LoadExtension("routing", map[string]interface{}{}) // Priority 50
|
||||||
|
manager.LoadExtension("security", map[string]interface{}{}) // Priority 10
|
||||||
|
manager.LoadExtension("caching", map[string]interface{}{}) // Priority 20
|
||||||
|
|
||||||
|
exts := manager.Extensions()
|
||||||
|
if len(exts) != 3 {
|
||||||
|
t.Fatalf("Expected 3 extensions, got %d", len(exts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check order by priority
|
||||||
|
if exts[0].Name() != "security" {
|
||||||
|
t.Errorf("Expected security first, got %s", exts[0].Name())
|
||||||
|
}
|
||||||
|
if exts[1].Name() != "caching" {
|
||||||
|
t.Errorf("Expected caching second, got %s", exts[1].Name())
|
||||||
|
}
|
||||||
|
if exts[2].Name() != "routing" {
|
||||||
|
t.Errorf("Expected routing third, got %s", exts[2].Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Cleanup(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
manager := NewManager(logger)
|
||||||
|
|
||||||
|
manager.LoadExtension("security", map[string]interface{}{})
|
||||||
|
manager.LoadExtension("routing", map[string]interface{}{})
|
||||||
|
|
||||||
|
manager.Cleanup()
|
||||||
|
|
||||||
|
exts := manager.Extensions()
|
||||||
|
if len(exts) != 0 {
|
||||||
|
t.Errorf("Expected 0 extensions after cleanup, got %d", len(exts))
|
||||||
|
}
|
||||||
|
}
|
||||||
428
go/internal/extension/routing.go
Normal file
428
go/internal/extension/routing.go
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
"github.com/konduktor/konduktor/internal/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RoutingExtension handles request routing based on patterns
|
||||||
|
type RoutingExtension struct {
|
||||||
|
BaseExtension
|
||||||
|
exactRoutes map[string]RouteConfig
|
||||||
|
regexRoutes []*regexRoute
|
||||||
|
defaultRoute *RouteConfig
|
||||||
|
staticDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteConfig holds configuration for a route
|
||||||
|
type RouteConfig struct {
|
||||||
|
ProxyPass string
|
||||||
|
Root string
|
||||||
|
IndexFile string
|
||||||
|
Return string
|
||||||
|
ContentType string
|
||||||
|
Headers []string
|
||||||
|
CacheControl string
|
||||||
|
SPAFallback bool
|
||||||
|
ExcludePatterns []string
|
||||||
|
Timeout float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type regexRoute struct {
|
||||||
|
pattern *regexp.Regexp
|
||||||
|
config RouteConfig
|
||||||
|
caseSensitive bool
|
||||||
|
originalExpr string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoutingExtension creates a new routing extension
|
||||||
|
func NewRoutingExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
|
||||||
|
ext := &RoutingExtension{
|
||||||
|
BaseExtension: NewBaseExtension("routing", 50, logger), // Middle priority
|
||||||
|
exactRoutes: make(map[string]RouteConfig),
|
||||||
|
regexRoutes: make([]*regexRoute, 0),
|
||||||
|
staticDir: "./static",
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Routing extension config", "config", config)
|
||||||
|
|
||||||
|
// Parse regex_locations from config
|
||||||
|
if locations, ok := config["regex_locations"].(map[string]interface{}); ok {
|
||||||
|
logger.Debug("Found regex_locations", "count", len(locations))
|
||||||
|
for pattern, routeCfg := range locations {
|
||||||
|
logger.Debug("Adding route", "pattern", pattern)
|
||||||
|
if rc, ok := routeCfg.(map[string]interface{}); ok {
|
||||||
|
ext.addRoute(pattern, parseRouteConfig(rc))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Warn("No regex_locations found in config", "config_keys", getKeys(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse static_dir if provided
|
||||||
|
if staticDir, ok := config["static_dir"].(string); ok {
|
||||||
|
ext.staticDir = staticDir
|
||||||
|
}
|
||||||
|
|
||||||
|
return ext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRouteConfig(cfg map[string]interface{}) RouteConfig {
|
||||||
|
rc := RouteConfig{
|
||||||
|
IndexFile: "index.html",
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := cfg["proxy_pass"].(string); ok {
|
||||||
|
rc.ProxyPass = v
|
||||||
|
}
|
||||||
|
if v, ok := cfg["root"].(string); ok {
|
||||||
|
rc.Root = v
|
||||||
|
}
|
||||||
|
if v, ok := cfg["index_file"].(string); ok {
|
||||||
|
rc.IndexFile = v
|
||||||
|
}
|
||||||
|
if v, ok := cfg["return"].(string); ok {
|
||||||
|
rc.Return = v
|
||||||
|
}
|
||||||
|
if v, ok := cfg["content_type"].(string); ok {
|
||||||
|
rc.ContentType = v
|
||||||
|
}
|
||||||
|
if v, ok := cfg["cache_control"].(string); ok {
|
||||||
|
rc.CacheControl = v
|
||||||
|
}
|
||||||
|
if v, ok := cfg["spa_fallback"].(bool); ok {
|
||||||
|
rc.SPAFallback = v
|
||||||
|
}
|
||||||
|
if v, ok := cfg["timeout"].(float64); ok {
|
||||||
|
rc.Timeout = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse headers
|
||||||
|
if headers, ok := cfg["headers"].([]interface{}); ok {
|
||||||
|
for _, h := range headers {
|
||||||
|
if header, ok := h.(string); ok {
|
||||||
|
rc.Headers = append(rc.Headers, header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse exclude_patterns
|
||||||
|
if patterns, ok := cfg["exclude_patterns"].([]interface{}); ok {
|
||||||
|
for _, p := range patterns {
|
||||||
|
if pattern, ok := p.(string); ok {
|
||||||
|
rc.ExcludePatterns = append(rc.ExcludePatterns, pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RoutingExtension) addRoute(pattern string, config RouteConfig) {
|
||||||
|
switch {
|
||||||
|
case pattern == "__default__":
|
||||||
|
e.defaultRoute = &config
|
||||||
|
|
||||||
|
case strings.HasPrefix(pattern, "="):
|
||||||
|
// Exact match
|
||||||
|
path := strings.TrimPrefix(pattern, "=")
|
||||||
|
e.exactRoutes[path] = config
|
||||||
|
|
||||||
|
case strings.HasPrefix(pattern, "~*"):
|
||||||
|
// Case-insensitive regex
|
||||||
|
expr := strings.TrimPrefix(pattern, "~*")
|
||||||
|
re, err := regexp.Compile("(?i)" + expr)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
e.regexRoutes = append(e.regexRoutes, ®exRoute{
|
||||||
|
pattern: re,
|
||||||
|
config: config,
|
||||||
|
caseSensitive: false,
|
||||||
|
originalExpr: expr,
|
||||||
|
})
|
||||||
|
|
||||||
|
case strings.HasPrefix(pattern, "~"):
|
||||||
|
// Case-sensitive regex
|
||||||
|
expr := strings.TrimPrefix(pattern, "~")
|
||||||
|
re, err := regexp.Compile(expr)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("Invalid regex pattern", "pattern", pattern, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
e.regexRoutes = append(e.regexRoutes, ®exRoute{
|
||||||
|
pattern: re,
|
||||||
|
config: config,
|
||||||
|
caseSensitive: true,
|
||||||
|
originalExpr: expr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessRequest handles the request routing
|
||||||
|
func (e *RoutingExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
||||||
|
path := r.URL.Path
|
||||||
|
|
||||||
|
// 1. Check exact routes (ignore request path for proxy)
|
||||||
|
if config, ok := e.exactRoutes[path]; ok {
|
||||||
|
return e.handleRoute(w, r, config, nil, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check regex routes
|
||||||
|
for _, route := range e.regexRoutes {
|
||||||
|
match := route.pattern.FindStringSubmatch(path)
|
||||||
|
if match != nil {
|
||||||
|
params := make(map[string]string)
|
||||||
|
names := route.pattern.SubexpNames()
|
||||||
|
for i, name := range names {
|
||||||
|
if i > 0 && name != "" && i < len(match) {
|
||||||
|
params[name] = match[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return e.handleRoute(w, r, route.config, params, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Check default route
|
||||||
|
if e.defaultRoute != nil {
|
||||||
|
return e.handleRoute(w, r, *e.defaultRoute, nil, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RoutingExtension) handleRoute(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
|
||||||
|
// Handle "return" directive
|
||||||
|
if config.Return != "" {
|
||||||
|
return e.handleReturn(w, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle proxy_pass
|
||||||
|
if config.ProxyPass != "" {
|
||||||
|
return e.handleProxy(w, r, config, params, exactMatch)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle static files with root
|
||||||
|
if config.Root != "" {
|
||||||
|
return e.handleStatic(w, r, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle SPA fallback
|
||||||
|
if config.SPAFallback {
|
||||||
|
return e.handleSPAFallback(w, r, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RoutingExtension) handleReturn(w http.ResponseWriter, config RouteConfig) (bool, error) {
|
||||||
|
parts := strings.SplitN(config.Return, " ", 2)
|
||||||
|
statusCode := 200
|
||||||
|
body := "OK"
|
||||||
|
|
||||||
|
if len(parts) >= 1 {
|
||||||
|
switch parts[0] {
|
||||||
|
case "200":
|
||||||
|
statusCode = 200
|
||||||
|
case "201":
|
||||||
|
statusCode = 201
|
||||||
|
case "301":
|
||||||
|
statusCode = 301
|
||||||
|
case "302":
|
||||||
|
statusCode = 302
|
||||||
|
case "400":
|
||||||
|
statusCode = 400
|
||||||
|
case "404":
|
||||||
|
statusCode = 404
|
||||||
|
case "500":
|
||||||
|
statusCode = 500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) >= 2 {
|
||||||
|
body = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := "text/plain"
|
||||||
|
if config.ContentType != "" {
|
||||||
|
contentType = config.ContentType
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", contentType)
|
||||||
|
w.WriteHeader(statusCode)
|
||||||
|
w.Write([]byte(body))
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RoutingExtension) handleProxy(w http.ResponseWriter, r *http.Request, config RouteConfig, params map[string]string, exactMatch bool) (bool, error) {
|
||||||
|
target := config.ProxyPass
|
||||||
|
|
||||||
|
// Check if target URL contains parameter placeholders
|
||||||
|
hasParams := strings.Contains(target, "{") && strings.Contains(target, "}")
|
||||||
|
|
||||||
|
// Substitute params in target URL
|
||||||
|
for key, value := range params {
|
||||||
|
target = strings.ReplaceAll(target, "{"+key+"}", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create proxy config
|
||||||
|
// IgnoreRequestPath=true when:
|
||||||
|
// - exact match route (=/path)
|
||||||
|
// - target URL had parameter substitutions (the target path is fully specified)
|
||||||
|
proxyConfig := &proxy.Config{
|
||||||
|
Target: target,
|
||||||
|
Headers: make(map[string]string),
|
||||||
|
IgnoreRequestPath: exactMatch || hasParams,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set timeout if specified
|
||||||
|
if config.Timeout > 0 {
|
||||||
|
proxyConfig.Timeout = time.Duration(config.Timeout * float64(time.Second))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse headers
|
||||||
|
clientIP := getClientIP(r)
|
||||||
|
for _, header := range config.Headers {
|
||||||
|
parts := strings.SplitN(header, ": ", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
value := parts[1]
|
||||||
|
// Substitute params
|
||||||
|
for key, pValue := range params {
|
||||||
|
value = strings.ReplaceAll(value, "{"+key+"}", pValue)
|
||||||
|
}
|
||||||
|
// Substitute special variables
|
||||||
|
value = strings.ReplaceAll(value, "$remote_addr", clientIP)
|
||||||
|
proxyConfig.Headers[parts[0]] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := proxy.New(proxyConfig, e.logger)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("Failed to create proxy", "target", target, "error", err)
|
||||||
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.ProxyRequest(w, r, params)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RoutingExtension) handleStatic(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
|
||||||
|
path := r.URL.Path
|
||||||
|
|
||||||
|
// Handle index file for root or directory paths
|
||||||
|
if path == "/" || strings.HasSuffix(path, "/") {
|
||||||
|
path = "/" + config.IndexFile
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get absolute path for root dir
|
||||||
|
absRoot, err := filepath.Abs(config.Root)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(absRoot, filepath.Clean("/"+path))
|
||||||
|
cleanPath := filepath.Clean(filePath)
|
||||||
|
|
||||||
|
// Prevent directory traversal
|
||||||
|
if !strings.HasPrefix(cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator)) {
|
||||||
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if file exists
|
||||||
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||||
|
return false, nil // Let other handlers try
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set cache control header
|
||||||
|
if config.CacheControl != "" {
|
||||||
|
w.Header().Set("Cache-Control", config.CacheControl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set custom headers
|
||||||
|
for _, header := range config.Headers {
|
||||||
|
parts := strings.SplitN(header, ": ", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
w.Header().Set(parts[0], parts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
http.ServeFile(w, r, filePath)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RoutingExtension) handleSPAFallback(w http.ResponseWriter, r *http.Request, config RouteConfig) (bool, error) {
|
||||||
|
path := r.URL.Path
|
||||||
|
|
||||||
|
// Check exclude patterns
|
||||||
|
for _, pattern := range config.ExcludePatterns {
|
||||||
|
if strings.HasPrefix(path, pattern) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
root := config.Root
|
||||||
|
if root == "" {
|
||||||
|
root = e.staticDir
|
||||||
|
}
|
||||||
|
|
||||||
|
indexFile := config.IndexFile
|
||||||
|
if indexFile == "" {
|
||||||
|
indexFile = "index.html"
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(root, indexFile)
|
||||||
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
http.ServeFile(w, r, filePath)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getClientIP(r *http.Request) string {
|
||||||
|
// Check X-Forwarded-For header first
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
parts := strings.Split(xff, ",")
|
||||||
|
return strings.TrimSpace(parts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check X-Real-IP header
|
||||||
|
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||||
|
return xri
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to RemoteAddr
|
||||||
|
ip := r.RemoteAddr
|
||||||
|
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||||
|
ip = ip[:idx]
|
||||||
|
}
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics returns routing metrics
|
||||||
|
func (e *RoutingExtension) GetMetrics() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"exact_routes": len(e.exactRoutes),
|
||||||
|
"regex_routes": len(e.regexRoutes),
|
||||||
|
"has_default": e.defaultRoute != nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKeys(m map[string]interface{}) []string {
|
||||||
|
keys := make([]string, 0, len(m))
|
||||||
|
for k := range m {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
312
go/internal/extension/security.go
Normal file
312
go/internal/extension/security.go
Normal file
@ -0,0 +1,312 @@
|
|||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SecurityExtension provides security features like IP filtering and security headers
|
||||||
|
type SecurityExtension struct {
|
||||||
|
BaseExtension
|
||||||
|
allowedIPs map[string]bool
|
||||||
|
blockedIPs map[string]bool
|
||||||
|
allowedCIDRs []*net.IPNet
|
||||||
|
blockedCIDRs []*net.IPNet
|
||||||
|
securityHeaders map[string]string
|
||||||
|
|
||||||
|
// Rate limiting
|
||||||
|
rateLimitEnabled bool
|
||||||
|
rateLimitRequests int
|
||||||
|
rateLimitWindow time.Duration
|
||||||
|
rateLimitByIP map[string]*rateLimitEntry
|
||||||
|
rateLimitMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type rateLimitEntry struct {
|
||||||
|
count int
|
||||||
|
resetTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecurityConfig holds security extension configuration
|
||||||
|
type SecurityConfig struct {
|
||||||
|
AllowedIPs []string `yaml:"allowed_ips"`
|
||||||
|
BlockedIPs []string `yaml:"blocked_ips"`
|
||||||
|
SecurityHeaders map[string]string `yaml:"security_headers"`
|
||||||
|
RateLimit *RateLimitConfig `yaml:"rate_limit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimitConfig holds rate limiting configuration
|
||||||
|
type RateLimitConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
Requests int `yaml:"requests"`
|
||||||
|
Window string `yaml:"window"` // e.g., "1m", "1h"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default security headers
|
||||||
|
var defaultSecurityHeaders = map[string]string{
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
"X-Frame-Options": "DENY",
|
||||||
|
"X-XSS-Protection": "1; mode=block",
|
||||||
|
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSecurityExtension creates a new security extension
|
||||||
|
func NewSecurityExtension(config map[string]interface{}, logger *logging.Logger) (Extension, error) {
|
||||||
|
ext := &SecurityExtension{
|
||||||
|
BaseExtension: NewBaseExtension("security", 10, logger), // High priority (early execution)
|
||||||
|
allowedIPs: make(map[string]bool),
|
||||||
|
blockedIPs: make(map[string]bool),
|
||||||
|
allowedCIDRs: make([]*net.IPNet, 0),
|
||||||
|
blockedCIDRs: make([]*net.IPNet, 0),
|
||||||
|
securityHeaders: make(map[string]string),
|
||||||
|
rateLimitByIP: make(map[string]*rateLimitEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy default security headers
|
||||||
|
for k, v := range defaultSecurityHeaders {
|
||||||
|
ext.securityHeaders[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse allowed_ips
|
||||||
|
if allowedIPs, ok := config["allowed_ips"].([]interface{}); ok {
|
||||||
|
for _, ip := range allowedIPs {
|
||||||
|
if ipStr, ok := ip.(string); ok {
|
||||||
|
ext.addAllowedIP(ipStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse blocked_ips
|
||||||
|
if blockedIPs, ok := config["blocked_ips"].([]interface{}); ok {
|
||||||
|
for _, ip := range blockedIPs {
|
||||||
|
if ipStr, ok := ip.(string); ok {
|
||||||
|
ext.addBlockedIP(ipStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse security_headers
|
||||||
|
if headers, ok := config["security_headers"].(map[string]interface{}); ok {
|
||||||
|
for k, v := range headers {
|
||||||
|
if vStr, ok := v.(string); ok {
|
||||||
|
ext.securityHeaders[k] = vStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse rate_limit
|
||||||
|
if rateLimit, ok := config["rate_limit"].(map[string]interface{}); ok {
|
||||||
|
if enabled, ok := rateLimit["enabled"].(bool); ok && enabled {
|
||||||
|
ext.rateLimitEnabled = true
|
||||||
|
|
||||||
|
if requests, ok := rateLimit["requests"].(int); ok {
|
||||||
|
ext.rateLimitRequests = requests
|
||||||
|
} else if requestsFloat, ok := rateLimit["requests"].(float64); ok {
|
||||||
|
ext.rateLimitRequests = int(requestsFloat)
|
||||||
|
} else {
|
||||||
|
ext.rateLimitRequests = 100 // default
|
||||||
|
}
|
||||||
|
|
||||||
|
if window, ok := rateLimit["window"].(string); ok {
|
||||||
|
if duration, err := time.ParseDuration(window); err == nil {
|
||||||
|
ext.rateLimitWindow = duration
|
||||||
|
} else {
|
||||||
|
ext.rateLimitWindow = time.Minute // default
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ext.rateLimitWindow = time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Rate limiting enabled",
|
||||||
|
"requests", ext.rateLimitRequests,
|
||||||
|
"window", ext.rateLimitWindow.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SecurityExtension) addAllowedIP(ip string) {
|
||||||
|
if strings.Contains(ip, "/") {
|
||||||
|
// CIDR notation
|
||||||
|
_, cidr, err := net.ParseCIDR(ip)
|
||||||
|
if err == nil {
|
||||||
|
e.allowedCIDRs = append(e.allowedCIDRs, cidr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
e.allowedIPs[ip] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SecurityExtension) addBlockedIP(ip string) {
|
||||||
|
if strings.Contains(ip, "/") {
|
||||||
|
// CIDR notation
|
||||||
|
_, cidr, err := net.ParseCIDR(ip)
|
||||||
|
if err == nil {
|
||||||
|
e.blockedCIDRs = append(e.blockedCIDRs, cidr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
e.blockedIPs[ip] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessRequest checks security rules
|
||||||
|
func (e *SecurityExtension) ProcessRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) (bool, error) {
|
||||||
|
clientIP := getClientIP(r)
|
||||||
|
parsedIP := net.ParseIP(clientIP)
|
||||||
|
|
||||||
|
// Check blocked IPs first
|
||||||
|
if e.isBlocked(clientIP, parsedIP) {
|
||||||
|
e.logger.Warn("Blocked request from IP", "ip", clientIP)
|
||||||
|
http.Error(w, "403 Forbidden", http.StatusForbidden)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check allowed IPs (if configured, only these IPs are allowed)
|
||||||
|
if len(e.allowedIPs) > 0 || len(e.allowedCIDRs) > 0 {
|
||||||
|
if !e.isAllowed(clientIP, parsedIP) {
|
||||||
|
e.logger.Warn("Access denied for IP", "ip", clientIP)
|
||||||
|
http.Error(w, "403 Forbidden", http.StatusForbidden)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check rate limit
|
||||||
|
if e.rateLimitEnabled {
|
||||||
|
if !e.checkRateLimit(clientIP) {
|
||||||
|
e.logger.Warn("Rate limit exceeded", "ip", clientIP)
|
||||||
|
w.Header().Set("Retry-After", "60")
|
||||||
|
http.Error(w, "429 Too Many Requests", http.StatusTooManyRequests)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessResponse adds security headers to the response
|
||||||
|
func (e *SecurityExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
for header, value := range e.securityHeaders {
|
||||||
|
w.Header().Set(header, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SecurityExtension) isBlocked(ip string, parsedIP net.IP) bool {
|
||||||
|
// Check exact match
|
||||||
|
if e.blockedIPs[ip] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check CIDR ranges
|
||||||
|
if parsedIP != nil {
|
||||||
|
for _, cidr := range e.blockedCIDRs {
|
||||||
|
if cidr.Contains(parsedIP) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SecurityExtension) isAllowed(ip string, parsedIP net.IP) bool {
|
||||||
|
// Check exact match
|
||||||
|
if e.allowedIPs[ip] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check CIDR ranges
|
||||||
|
if parsedIP != nil {
|
||||||
|
for _, cidr := range e.allowedCIDRs {
|
||||||
|
if cidr.Contains(parsedIP) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SecurityExtension) checkRateLimit(ip string) bool {
|
||||||
|
e.rateLimitMu.Lock()
|
||||||
|
defer e.rateLimitMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
entry, exists := e.rateLimitByIP[ip]
|
||||||
|
|
||||||
|
if !exists || now.After(entry.resetTime) {
|
||||||
|
// Create new entry or reset expired one
|
||||||
|
e.rateLimitByIP[ip] = &rateLimitEntry{
|
||||||
|
count: 1,
|
||||||
|
resetTime: now.Add(e.rateLimitWindow),
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment counter
|
||||||
|
entry.count++
|
||||||
|
return entry.count <= e.rateLimitRequests
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBlockedIP adds an IP to the blocked list at runtime
|
||||||
|
func (e *SecurityExtension) AddBlockedIP(ip string) {
|
||||||
|
e.addBlockedIP(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveBlockedIP removes an IP from the blocked list
|
||||||
|
func (e *SecurityExtension) RemoveBlockedIP(ip string) {
|
||||||
|
delete(e.blockedIPs, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAllowedIP adds an IP to the allowed list at runtime
|
||||||
|
func (e *SecurityExtension) AddAllowedIP(ip string) {
|
||||||
|
e.addAllowedIP(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllowedIP removes an IP from the allowed list
|
||||||
|
func (e *SecurityExtension) RemoveAllowedIP(ip string) {
|
||||||
|
delete(e.allowedIPs, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSecurityHeader sets or updates a security header
|
||||||
|
func (e *SecurityExtension) SetSecurityHeader(name, value string) {
|
||||||
|
e.securityHeaders[name] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics returns security metrics
|
||||||
|
func (e *SecurityExtension) GetMetrics() map[string]interface{} {
|
||||||
|
e.rateLimitMu.RLock()
|
||||||
|
activeRateLimits := len(e.rateLimitByIP)
|
||||||
|
e.rateLimitMu.RUnlock()
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"allowed_ips": len(e.allowedIPs),
|
||||||
|
"allowed_cidrs": len(e.allowedCIDRs),
|
||||||
|
"blocked_ips": len(e.blockedIPs),
|
||||||
|
"blocked_cidrs": len(e.blockedCIDRs),
|
||||||
|
"security_headers": len(e.securityHeaders),
|
||||||
|
"rate_limit_enabled": e.rateLimitEnabled,
|
||||||
|
"active_rate_limits": activeRateLimits,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup cleans up rate limit entries periodically
|
||||||
|
func (e *SecurityExtension) Cleanup() error {
|
||||||
|
e.rateLimitMu.Lock()
|
||||||
|
defer e.rateLimitMu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for ip, entry := range e.rateLimitByIP {
|
||||||
|
if now.After(entry.resetTime) {
|
||||||
|
delete(e.rateLimitByIP, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
213
go/internal/extension/security_test.go
Normal file
213
go/internal/extension/security_test.go
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSecurityExtension(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
|
||||||
|
ext, err := NewSecurityExtension(map[string]interface{}{}, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create security extension: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ext.Name() != "security" {
|
||||||
|
t.Errorf("Expected name 'security', got %s", ext.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
if ext.Priority() != 10 {
|
||||||
|
t.Errorf("Expected priority 10, got %d", ext.Priority())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecurityExtension_BlockedIP(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
|
||||||
|
ext, _ := NewSecurityExtension(map[string]interface{}{
|
||||||
|
"blocked_ips": []interface{}{"192.168.1.100"},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handled, err := ext.ProcessRequest(context.Background(), rr, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !handled {
|
||||||
|
t.Error("Expected blocked request to be handled")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rr.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Expected status 403, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecurityExtension_AllowedIP(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
|
||||||
|
ext, _ := NewSecurityExtension(map[string]interface{}{
|
||||||
|
"allowed_ips": []interface{}{"192.168.1.50"},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
// Allowed IP
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.50:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handled, _ := ext.ProcessRequest(context.Background(), rr, req)
|
||||||
|
if handled {
|
||||||
|
t.Error("Expected allowed IP to pass through")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not allowed IP
|
||||||
|
req = httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.51:12345"
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
|
||||||
|
handled, _ = ext.ProcessRequest(context.Background(), rr, req)
|
||||||
|
if !handled {
|
||||||
|
t.Error("Expected non-allowed IP to be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rr.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("Expected status 403, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecurityExtension_CIDR(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
|
||||||
|
ext, _ := NewSecurityExtension(map[string]interface{}{
|
||||||
|
"blocked_ips": []interface{}{"10.0.0.0/8"},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
// IP in blocked CIDR
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "10.1.2.3:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handled, _ := ext.ProcessRequest(context.Background(), rr, req)
|
||||||
|
if !handled {
|
||||||
|
t.Error("Expected IP in blocked CIDR to be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP not in blocked CIDR
|
||||||
|
req = httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
|
||||||
|
handled, _ = ext.ProcessRequest(context.Background(), rr, req)
|
||||||
|
if handled {
|
||||||
|
t.Error("Expected IP not in blocked CIDR to pass through")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecurityExtension_SecurityHeaders(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
|
||||||
|
ext, _ := NewSecurityExtension(map[string]interface{}{
|
||||||
|
"security_headers": map[string]interface{}{
|
||||||
|
"X-Custom-Header": "custom-value",
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ext.ProcessResponse(context.Background(), rr, req)
|
||||||
|
|
||||||
|
// Check default headers
|
||||||
|
if rr.Header().Get("X-Content-Type-Options") != "nosniff" {
|
||||||
|
t.Error("Expected X-Content-Type-Options header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check custom header
|
||||||
|
if rr.Header().Get("X-Custom-Header") != "custom-value" {
|
||||||
|
t.Error("Expected custom header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecurityExtension_RateLimit(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
|
||||||
|
ext, _ := NewSecurityExtension(map[string]interface{}{
|
||||||
|
"rate_limit": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"requests": 2,
|
||||||
|
"window": "1m",
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
securityExt := ext.(*SecurityExtension)
|
||||||
|
clientIP := "192.168.1.1"
|
||||||
|
|
||||||
|
// First request - should pass
|
||||||
|
if !securityExt.checkRateLimit(clientIP) {
|
||||||
|
t.Error("First request should pass")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second request - should pass
|
||||||
|
if !securityExt.checkRateLimit(clientIP) {
|
||||||
|
t.Error("Second request should pass")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third request - should be rate limited
|
||||||
|
if securityExt.checkRateLimit(clientIP) {
|
||||||
|
t.Error("Third request should be rate limited")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecurityExtension_GetMetrics(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
|
||||||
|
ext, _ := NewSecurityExtension(map[string]interface{}{
|
||||||
|
"blocked_ips": []interface{}{"192.168.1.1"},
|
||||||
|
"allowed_ips": []interface{}{"192.168.1.2"},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
securityExt := ext.(*SecurityExtension)
|
||||||
|
metrics := securityExt.GetMetrics()
|
||||||
|
|
||||||
|
if metrics["blocked_ips"].(int) != 1 {
|
||||||
|
t.Errorf("Expected 1 blocked IP, got %v", metrics["blocked_ips"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if metrics["allowed_ips"].(int) != 1 {
|
||||||
|
t.Errorf("Expected 1 allowed IP, got %v", metrics["allowed_ips"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecurityExtension_AddRemoveIPs(t *testing.T) {
|
||||||
|
logger := newTestLogger()
|
||||||
|
|
||||||
|
ext, _ := NewSecurityExtension(map[string]interface{}{}, logger)
|
||||||
|
securityExt := ext.(*SecurityExtension)
|
||||||
|
|
||||||
|
// Add blocked IP
|
||||||
|
securityExt.AddBlockedIP("192.168.1.100")
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handled, _ := ext.ProcessRequest(context.Background(), rr, req)
|
||||||
|
if !handled {
|
||||||
|
t.Error("Expected dynamically blocked IP to be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove blocked IP
|
||||||
|
securityExt.RemoveBlockedIP("192.168.1.100")
|
||||||
|
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
handled, _ = ext.ProcessRequest(context.Background(), rr, req)
|
||||||
|
if handled {
|
||||||
|
t.Error("Expected removed blocked IP to pass through")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,136 +1,338 @@
|
|||||||
|
// Package logging provides structured logging with zap
|
||||||
package logging
|
package logging
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/config"
|
"github.com/konduktor/konduktor/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Config is a simple configuration for basic logger setup
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Level string
|
Level string
|
||||||
TimestampFormat string
|
TimestampFormat string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logger wraps zap.SugaredLogger with additional functionality
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
level string
|
*zap.SugaredLogger
|
||||||
timestampFormat string
|
zap *zap.Logger
|
||||||
configFull *config.LoggingConfig
|
config *config.LoggingConfig
|
||||||
|
name string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New creates a new Logger with basic configuration
|
||||||
func New(cfg Config) (*Logger, error) {
|
func New(cfg Config) (*Logger, error) {
|
||||||
|
level := parseLevel(cfg.Level)
|
||||||
|
|
||||||
timestampFormat := cfg.TimestampFormat
|
timestampFormat := cfg.TimestampFormat
|
||||||
if timestampFormat == "" {
|
if timestampFormat == "" {
|
||||||
timestampFormat = "2006-01-02 15:04:05"
|
timestampFormat = "2006-01-02 15:04:05"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
encoderConfig := zap.NewProductionEncoderConfig()
|
||||||
|
encoderConfig.TimeKey = "timestamp"
|
||||||
|
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(timestampFormat)
|
||||||
|
encoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||||
|
|
||||||
|
core := zapcore.NewCore(
|
||||||
|
zapcore.NewConsoleEncoder(encoderConfig),
|
||||||
|
zapcore.AddSync(os.Stdout),
|
||||||
|
level,
|
||||||
|
)
|
||||||
|
|
||||||
|
zapLogger := zap.New(core)
|
||||||
return &Logger{
|
return &Logger{
|
||||||
level: cfg.Level,
|
SugaredLogger: zapLogger.Sugar(),
|
||||||
timestampFormat: timestampFormat,
|
zap: zapLogger,
|
||||||
|
name: "konduktor",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewFromConfig creates a Logger from full LoggingConfig
|
||||||
func NewFromConfig(cfg config.LoggingConfig) (*Logger, error) {
|
func NewFromConfig(cfg config.LoggingConfig) (*Logger, error) {
|
||||||
timestampFormat := cfg.Format.TimestampFormat
|
var cores []zapcore.Core
|
||||||
|
|
||||||
|
// Parse main level
|
||||||
|
mainLevel := parseLevel(cfg.Level)
|
||||||
|
|
||||||
|
// Add console core if enabled
|
||||||
|
if cfg.ConsoleOutput {
|
||||||
|
consoleLevel := mainLevel
|
||||||
|
if cfg.Console != nil && cfg.Console.Level != "" {
|
||||||
|
consoleLevel = parseLevel(cfg.Console.Level)
|
||||||
|
}
|
||||||
|
|
||||||
|
var consoleEncoder zapcore.Encoder
|
||||||
|
formatConfig := cfg.Format
|
||||||
|
if cfg.Console != nil {
|
||||||
|
formatConfig = mergeFormatConfig(cfg.Format, cfg.Console.Format)
|
||||||
|
}
|
||||||
|
|
||||||
|
encoderCfg := createEncoderConfig(formatConfig)
|
||||||
|
if formatConfig.Type == "json" {
|
||||||
|
consoleEncoder = zapcore.NewJSONEncoder(encoderCfg)
|
||||||
|
} else {
|
||||||
|
if formatConfig.UseColors {
|
||||||
|
encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||||
|
}
|
||||||
|
consoleEncoder = zapcore.NewConsoleEncoder(encoderCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
consoleSyncer := zapcore.AddSync(os.Stdout)
|
||||||
|
cores = append(cores, zapcore.NewCore(consoleEncoder, consoleSyncer, consoleLevel))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add file cores
|
||||||
|
for _, fileConfig := range cfg.Files {
|
||||||
|
fileCore, err := createFileCore(fileConfig, cfg.Format, mainLevel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create file logger for %s: %w", fileConfig.Path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If specific loggers are configured, wrap with filter
|
||||||
|
if len(fileConfig.Loggers) > 0 {
|
||||||
|
fileCore = &filteredCore{
|
||||||
|
Core: fileCore,
|
||||||
|
loggers: fileConfig.Loggers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cores = append(cores, fileCore)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no cores configured, add default console
|
||||||
|
if len(cores) == 0 {
|
||||||
|
encoderCfg := zap.NewProductionEncoderConfig()
|
||||||
|
encoderCfg.EncodeTime = zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05")
|
||||||
|
encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||||
|
cores = append(cores, zapcore.NewCore(
|
||||||
|
zapcore.NewConsoleEncoder(encoderCfg),
|
||||||
|
zapcore.AddSync(os.Stdout),
|
||||||
|
mainLevel,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine all cores
|
||||||
|
core := zapcore.NewTee(cores...)
|
||||||
|
zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1))
|
||||||
|
|
||||||
|
return &Logger{
|
||||||
|
SugaredLogger: zapLogger.Sugar(),
|
||||||
|
zap: zapLogger,
|
||||||
|
config: &cfg,
|
||||||
|
name: "konduktor",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Named returns a logger with a specific name (for filtering)
|
||||||
|
func (l *Logger) Named(name string) *Logger {
|
||||||
|
return &Logger{
|
||||||
|
SugaredLogger: l.SugaredLogger.Named(name),
|
||||||
|
zap: l.zap.Named(name),
|
||||||
|
config: l.config,
|
||||||
|
name: name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// With returns a logger with additional fields
|
||||||
|
func (l *Logger) With(args ...interface{}) *Logger {
|
||||||
|
return &Logger{
|
||||||
|
SugaredLogger: l.SugaredLogger.With(args...),
|
||||||
|
zap: l.zap.Sugar().With(args...).Desugar(),
|
||||||
|
config: l.config,
|
||||||
|
name: l.name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync flushes any buffered log entries
|
||||||
|
func (l *Logger) Sync() error {
|
||||||
|
return l.zap.Sync()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetZap returns the underlying zap.Logger
|
||||||
|
func (l *Logger) GetZap() *zap.Logger {
|
||||||
|
return l.zap
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug logs a debug message
|
||||||
|
func (l *Logger) Debug(msg string, keysAndValues ...interface{}) {
|
||||||
|
l.SugaredLogger.Debugw(msg, keysAndValues...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info logs an info message
|
||||||
|
func (l *Logger) Info(msg string, keysAndValues ...interface{}) {
|
||||||
|
l.SugaredLogger.Infow(msg, keysAndValues...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn logs a warning message
|
||||||
|
func (l *Logger) Warn(msg string, keysAndValues ...interface{}) {
|
||||||
|
l.SugaredLogger.Warnw(msg, keysAndValues...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error logs an error message
|
||||||
|
func (l *Logger) Error(msg string, keysAndValues ...interface{}) {
|
||||||
|
l.SugaredLogger.Errorw(msg, keysAndValues...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fatal logs a fatal message and exits
|
||||||
|
func (l *Logger) Fatal(msg string, keysAndValues ...interface{}) {
|
||||||
|
l.SugaredLogger.Fatalw(msg, keysAndValues...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helper functions ---
|
||||||
|
|
||||||
|
func parseLevel(level string) zapcore.Level {
|
||||||
|
switch strings.ToUpper(level) {
|
||||||
|
case "DEBUG":
|
||||||
|
return zapcore.DebugLevel
|
||||||
|
case "INFO":
|
||||||
|
return zapcore.InfoLevel
|
||||||
|
case "WARN", "WARNING":
|
||||||
|
return zapcore.WarnLevel
|
||||||
|
case "ERROR":
|
||||||
|
return zapcore.ErrorLevel
|
||||||
|
case "CRITICAL", "FATAL":
|
||||||
|
return zapcore.FatalLevel
|
||||||
|
default:
|
||||||
|
return zapcore.InfoLevel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createEncoderConfig(format config.LogFormatConfig) zapcore.EncoderConfig {
|
||||||
|
timestampFormat := format.TimestampFormat
|
||||||
if timestampFormat == "" {
|
if timestampFormat == "" {
|
||||||
timestampFormat = "2006-01-02 15:04:05"
|
timestampFormat = "2006-01-02 15:04:05"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Logger{
|
cfg := zapcore.EncoderConfig{
|
||||||
level: cfg.Level,
|
TimeKey: "timestamp",
|
||||||
timestampFormat: timestampFormat,
|
LevelKey: "level",
|
||||||
configFull: &cfg,
|
NameKey: "logger",
|
||||||
}, nil
|
CallerKey: "caller",
|
||||||
|
FunctionKey: zapcore.OmitKey,
|
||||||
|
MessageKey: "msg",
|
||||||
|
StacktraceKey: "stacktrace",
|
||||||
|
LineEnding: zapcore.DefaultLineEnding,
|
||||||
|
EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||||
|
EncodeTime: zapcore.TimeEncoderOfLayout(timestampFormat),
|
||||||
|
EncodeDuration: zapcore.SecondsDurationEncoder,
|
||||||
|
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) formatTime() string {
|
if !format.ShowModule {
|
||||||
return time.Now().Format(l.timestampFormat)
|
cfg.NameKey = zapcore.OmitKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) log(level string, msg string, fields ...interface{}) {
|
return cfg
|
||||||
timestamp := l.formatTime()
|
|
||||||
|
|
||||||
// Simple console output for now
|
|
||||||
// TODO: Implement proper structured logging with zap
|
|
||||||
output := timestamp + " [" + level + "] " + msg
|
|
||||||
|
|
||||||
if len(fields) > 0 {
|
|
||||||
output += " {"
|
|
||||||
for i := 0; i < len(fields); i += 2 {
|
|
||||||
if i > 0 {
|
|
||||||
output += ", "
|
|
||||||
}
|
|
||||||
if i+1 < len(fields) {
|
|
||||||
output += fields[i].(string) + "=" + formatValue(fields[i+1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
output += "}"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Stdout.WriteString(output + "\n")
|
func mergeFormatConfig(base, override config.LogFormatConfig) config.LogFormatConfig {
|
||||||
|
result := base
|
||||||
|
if override.Type != "" {
|
||||||
|
result.Type = override.Type
|
||||||
|
}
|
||||||
|
if override.TimestampFormat != "" {
|
||||||
|
result.TimestampFormat = override.TimestampFormat
|
||||||
|
}
|
||||||
|
// UseColors and ShowModule are bool - check if override has non-default
|
||||||
|
result.UseColors = override.UseColors
|
||||||
|
result.ShowModule = override.ShowModule
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func formatValue(v interface{}) string {
|
func createFileCore(fileConfig config.FileLogConfig, defaultFormat config.LogFormatConfig, defaultLevel zapcore.Level) (zapcore.Core, error) {
|
||||||
switch val := v.(type) {
|
// Ensure directory exists
|
||||||
case string:
|
dir := filepath.Dir(fileConfig.Path)
|
||||||
return val
|
if dir != "" && dir != "." {
|
||||||
case int:
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
return fmt.Sprintf("%d", val)
|
return nil, fmt.Errorf("failed to create log directory %s: %w", dir, err)
|
||||||
case int64:
|
|
||||||
return fmt.Sprintf("%d", val)
|
|
||||||
case float64:
|
|
||||||
return fmt.Sprintf("%.2f", val)
|
|
||||||
case bool:
|
|
||||||
return fmt.Sprintf("%t", val)
|
|
||||||
case error:
|
|
||||||
return val.Error()
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("%v", val)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Debug(msg string, fields ...interface{}) {
|
// Configure log rotation with lumberjack
|
||||||
if l.shouldLog("DEBUG") {
|
maxSize := 10 // MB
|
||||||
l.log("DEBUG", msg, fields...)
|
if fileConfig.MaxBytes > 0 {
|
||||||
|
maxSize = int(fileConfig.MaxBytes / (1024 * 1024))
|
||||||
|
if maxSize < 1 {
|
||||||
|
maxSize = 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Info(msg string, fields ...interface{}) {
|
backupCount := 5
|
||||||
if l.shouldLog("INFO") {
|
if fileConfig.BackupCount > 0 {
|
||||||
l.log("INFO", msg, fields...)
|
backupCount = fileConfig.BackupCount
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Warn(msg string, fields ...interface{}) {
|
rotator := &lumberjack.Logger{
|
||||||
if l.shouldLog("WARN") {
|
Filename: fileConfig.Path,
|
||||||
l.log("WARN", msg, fields...)
|
MaxSize: maxSize,
|
||||||
}
|
MaxBackups: backupCount,
|
||||||
|
MaxAge: 30, // days
|
||||||
|
Compress: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Error(msg string, fields ...interface{}) {
|
// Determine level
|
||||||
if l.shouldLog("ERROR") {
|
level := defaultLevel
|
||||||
l.log("ERROR", msg, fields...)
|
if fileConfig.Level != "" {
|
||||||
}
|
level = parseLevel(fileConfig.Level)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) shouldLog(level string) bool {
|
// Create encoder
|
||||||
levels := map[string]int{
|
format := defaultFormat
|
||||||
"DEBUG": 0,
|
if fileConfig.Format.Type != "" {
|
||||||
"INFO": 1,
|
format = mergeFormatConfig(defaultFormat, fileConfig.Format)
|
||||||
"WARN": 2,
|
}
|
||||||
"ERROR": 3,
|
// Files should not use colors
|
||||||
|
format.UseColors = false
|
||||||
|
|
||||||
|
encoderConfig := createEncoderConfig(format)
|
||||||
|
var encoder zapcore.Encoder
|
||||||
|
if format.Type == "json" {
|
||||||
|
encoder = zapcore.NewJSONEncoder(encoderConfig)
|
||||||
|
} else {
|
||||||
|
encoder = zapcore.NewConsoleEncoder(encoderConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
currentLevel, ok := levels[l.level]
|
return zapcore.NewCore(encoder, zapcore.AddSync(rotator), level), nil
|
||||||
if !ok {
|
|
||||||
currentLevel = 1 // Default to INFO
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msgLevel, ok := levels[level]
|
// filteredCore wraps a Core to filter by logger name
|
||||||
if !ok {
|
type filteredCore struct {
|
||||||
msgLevel = 1
|
zapcore.Core
|
||||||
|
loggers []string
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgLevel >= currentLevel
|
func (c *filteredCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
||||||
|
if !c.shouldLog(entry.LoggerName) {
|
||||||
|
return ce
|
||||||
|
}
|
||||||
|
return c.Core.Check(entry, ce)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *filteredCore) shouldLog(loggerName string) bool {
|
||||||
|
if len(c.loggers) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowed := range c.loggers {
|
||||||
|
if loggerName == allowed || strings.HasPrefix(loggerName, allowed+".") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *filteredCore) With(fields []zapcore.Field) zapcore.Core {
|
||||||
|
return &filteredCore{
|
||||||
|
Core: c.Core.With(fields),
|
||||||
|
loggers: c.loggers,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,6 +2,8 @@ package logging
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
@ -15,71 +17,84 @@ func TestNew(t *testing.T) {
|
|||||||
t.Fatal("Expected logger, got nil")
|
t.Fatal("Expected logger, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if logger.level != "INFO" {
|
if logger.name != "konduktor" {
|
||||||
t.Errorf("Expected level INFO, got %s", logger.level)
|
t.Errorf("Expected name konduktor, got %s", logger.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNew_DefaultTimestampFormat(t *testing.T) {
|
func TestNew_DefaultTimestampFormat(t *testing.T) {
|
||||||
logger, _ := New(Config{Level: "DEBUG"})
|
logger, err := New(Config{Level: "DEBUG"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if logger.timestampFormat != "2006-01-02 15:04:05" {
|
// Logger should be created successfully
|
||||||
t.Errorf("Expected default timestamp format, got %s", logger.timestampFormat)
|
if logger == nil {
|
||||||
|
t.Fatal("Expected logger, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNew_CustomTimestampFormat(t *testing.T) {
|
func TestNew_CustomTimestampFormat(t *testing.T) {
|
||||||
logger, _ := New(Config{
|
logger, err := New(Config{
|
||||||
Level: "DEBUG",
|
Level: "DEBUG",
|
||||||
TimestampFormat: "15:04:05",
|
TimestampFormat: "15:04:05",
|
||||||
})
|
})
|
||||||
|
|
||||||
if logger.timestampFormat != "15:04:05" {
|
if err != nil {
|
||||||
t.Errorf("Expected custom timestamp format, got %s", logger.timestampFormat)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if logger == nil {
|
||||||
|
t.Fatal("Expected logger, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLogger_ShouldLog(t *testing.T) {
|
func TestNewFromConfig(t *testing.T) {
|
||||||
tests := []struct {
|
cfg := config.LoggingConfig{
|
||||||
loggerLevel string
|
Level: "DEBUG",
|
||||||
msgLevel string
|
ConsoleOutput: true,
|
||||||
shouldLog bool
|
Format: config.LogFormatConfig{
|
||||||
}{
|
Type: "standard",
|
||||||
{"DEBUG", "DEBUG", true},
|
UseColors: true,
|
||||||
{"DEBUG", "INFO", true},
|
ShowModule: true,
|
||||||
{"DEBUG", "WARN", true},
|
TimestampFormat: "2006-01-02 15:04:05",
|
||||||
{"DEBUG", "ERROR", true},
|
},
|
||||||
{"INFO", "DEBUG", false},
|
|
||||||
{"INFO", "INFO", true},
|
|
||||||
{"INFO", "WARN", true},
|
|
||||||
{"INFO", "ERROR", true},
|
|
||||||
{"WARN", "DEBUG", false},
|
|
||||||
{"WARN", "INFO", false},
|
|
||||||
{"WARN", "WARN", true},
|
|
||||||
{"WARN", "ERROR", true},
|
|
||||||
{"ERROR", "DEBUG", false},
|
|
||||||
{"ERROR", "INFO", false},
|
|
||||||
{"ERROR", "WARN", false},
|
|
||||||
{"ERROR", "ERROR", true},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
logger, err := NewFromConfig(cfg)
|
||||||
t.Run(tt.loggerLevel+"_"+tt.msgLevel, func(t *testing.T) {
|
if err != nil {
|
||||||
logger, _ := New(Config{Level: tt.loggerLevel})
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
|
||||||
if got := logger.shouldLog(tt.msgLevel); got != tt.shouldLog {
|
|
||||||
t.Errorf("shouldLog(%s) = %v, want %v", tt.msgLevel, got, tt.shouldLog)
|
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
if logger == nil {
|
||||||
|
t.Fatal("Expected logger, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLogger_ShouldLog_InvalidLevel(t *testing.T) {
|
func TestNewFromConfig_WithConsole(t *testing.T) {
|
||||||
logger, _ := New(Config{Level: "INVALID"})
|
cfg := config.LoggingConfig{
|
||||||
|
Level: "INFO",
|
||||||
|
ConsoleOutput: true,
|
||||||
|
Format: config.LogFormatConfig{
|
||||||
|
Type: "standard",
|
||||||
|
UseColors: true,
|
||||||
|
},
|
||||||
|
Console: &config.ConsoleLogConfig{
|
||||||
|
Level: "DEBUG",
|
||||||
|
Format: config.LogFormatConfig{
|
||||||
|
Type: "standard",
|
||||||
|
UseColors: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Should default to INFO level
|
logger, err := NewFromConfig(cfg)
|
||||||
if !logger.shouldLog("INFO") {
|
if err != nil {
|
||||||
t.Error("Invalid level should default to INFO")
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if logger == nil {
|
||||||
|
t.Fatal("Expected logger, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,34 +126,65 @@ func TestLogger_Error(t *testing.T) {
|
|||||||
logger.Error("test message", "key", "value")
|
logger.Error("test message", "key", "value")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatValue(t *testing.T) {
|
func TestLogger_Named(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: "INFO"})
|
||||||
|
named := logger.Named("test.module")
|
||||||
|
|
||||||
|
if named == nil {
|
||||||
|
t.Fatal("Expected named logger, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if named.name != "test.module" {
|
||||||
|
t.Errorf("Expected name 'test.module', got %s", named.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
named.Info("test from named logger")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogger_With(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: "INFO"})
|
||||||
|
withFields := logger.With("service", "test")
|
||||||
|
|
||||||
|
if withFields == nil {
|
||||||
|
t.Fatal("Expected logger with fields, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
withFields.Info("test with fields")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogger_Sync(t *testing.T) {
|
||||||
|
logger, _ := New(Config{Level: "INFO"})
|
||||||
|
|
||||||
|
// Should not panic
|
||||||
|
err := logger.Sync()
|
||||||
|
// Sync may return an error for stdout on some systems, ignore it
|
||||||
|
_ = err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseLevel(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input interface{}
|
input string
|
||||||
expected string
|
expected string
|
||||||
}{
|
}{
|
||||||
{"test", "test"},
|
{"DEBUG", "debug"},
|
||||||
{42, "*"}, // int converts to rune
|
{"INFO", "info"},
|
||||||
{nil, ""},
|
{"WARN", "warn"},
|
||||||
|
{"WARNING", "warn"},
|
||||||
|
{"ERROR", "error"},
|
||||||
|
{"CRITICAL", "fatal"},
|
||||||
|
{"FATAL", "fatal"},
|
||||||
|
{"invalid", "info"}, // defaults to INFO
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
got := formatValue(tt.input)
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
// Just check it doesn't panic
|
level := parseLevel(tt.input)
|
||||||
_ = got
|
if level.String() != tt.expected {
|
||||||
|
t.Errorf("parseLevel(%s) = %s, want %s", tt.input, level.String(), tt.expected)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestLogger_FormatTime(t *testing.T) {
|
|
||||||
logger, _ := New(Config{
|
|
||||||
Level: "INFO",
|
|
||||||
TimestampFormat: "2006-01-02",
|
|
||||||
})
|
})
|
||||||
|
|
||||||
result := logger.formatTime()
|
|
||||||
|
|
||||||
// Should be in expected format (YYYY-MM-DD)
|
|
||||||
if len(result) != 10 {
|
|
||||||
t.Errorf("Expected date format YYYY-MM-DD, got %s", result)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,12 +207,3 @@ func BenchmarkLogger_Debug_Filtered(b *testing.B) {
|
|||||||
logger.Debug("test message", "key", "value")
|
logger.Debug("test message", "key", "value")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLogger_ShouldLog(b *testing.B) {
|
|
||||||
logger, _ := New(Config{Level: "INFO"})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
logger.shouldLog("DEBUG")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -29,6 +29,10 @@ type Config struct {
|
|||||||
|
|
||||||
// PreserveHost keeps the original Host header
|
// PreserveHost keeps the original Host header
|
||||||
PreserveHost bool
|
PreserveHost bool
|
||||||
|
|
||||||
|
// IgnoreRequestPath ignores the request path and uses only the target path
|
||||||
|
// This is useful for exact match routes where target URL should be used as-is
|
||||||
|
IgnoreRequestPath bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReverseProxy struct {
|
type ReverseProxy struct {
|
||||||
@ -116,6 +120,13 @@ func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, par
|
|||||||
func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
|
func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
|
||||||
targetURL := *rp.targetURL
|
targetURL := *rp.targetURL
|
||||||
|
|
||||||
|
// If ignoring request path, use target URL path as-is
|
||||||
|
if rp.config.IgnoreRequestPath {
|
||||||
|
// Preserve query string only
|
||||||
|
targetURL.RawQuery = r.URL.RawQuery
|
||||||
|
return &targetURL
|
||||||
|
}
|
||||||
|
|
||||||
// Strip prefix if configured
|
// Strip prefix if configured
|
||||||
path := r.URL.Path
|
path := r.URL.Path
|
||||||
if rp.config.StripPrefix != "" {
|
if rp.config.StripPrefix != "" {
|
||||||
@ -125,10 +136,12 @@ func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If target has a path, append request path to it
|
// If target URL has a non-empty path, combine it with the request path
|
||||||
if rp.targetURL.Path != "" && rp.targetURL.Path != "/" {
|
if rp.targetURL.Path != "" && rp.targetURL.Path != "/" {
|
||||||
targetURL.Path = singleJoiningSlash(rp.targetURL.Path, path)
|
// Combine target path with request path
|
||||||
|
targetURL.Path = strings.TrimSuffix(rp.targetURL.Path, "/") + path
|
||||||
} else {
|
} else {
|
||||||
|
// No path in target, use request path as-is
|
||||||
targetURL.Path = path
|
targetURL.Path = path
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -11,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/config"
|
"github.com/konduktor/konduktor/internal/config"
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
"github.com/konduktor/konduktor/internal/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RouteMatch represents a matched route with captured parameters
|
// RouteMatch represents a matched route with captured parameters
|
||||||
@ -55,6 +57,24 @@ func New(cfg *config.Config, logger *logging.Logger) *Router {
|
|||||||
regexRoutes: make([]*RegexRoute, 0),
|
regexRoutes: make([]*RegexRoute, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load routes from extensions
|
||||||
|
if cfg != nil {
|
||||||
|
for _, ext := range cfg.Extensions {
|
||||||
|
if ext.Type == "routing" && ext.Config != nil {
|
||||||
|
if locations, ok := ext.Config["regex_locations"].(map[string]interface{}); ok {
|
||||||
|
for pattern, routeCfg := range locations {
|
||||||
|
if rc, ok := routeCfg.(map[string]interface{}); ok {
|
||||||
|
r.AddRoute(pattern, rc)
|
||||||
|
if logger != nil {
|
||||||
|
logger.Debug("Added route", "pattern", pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
r.setupRoutes()
|
r.setupRoutes()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@ -250,17 +270,27 @@ func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
// Try to match against configured routes
|
// Try to match against configured routes
|
||||||
match := r.Match(path)
|
match := r.Match(path)
|
||||||
|
fmt.Printf("DEBUG defaultHandler: path=%q match=%v defaultRoute=%v\n", path, match != nil, r.defaultRoute != nil)
|
||||||
if match != nil {
|
if match != nil {
|
||||||
|
fmt.Printf("DEBUG: matched config: %v\n", match.Config)
|
||||||
r.handleRouteMatch(w, req, match)
|
r.handleRouteMatch(w, req, match)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to serve static file
|
// Try to serve static file
|
||||||
if r.staticDir != "" {
|
if r.staticDir != "" {
|
||||||
filePath := filepath.Join(r.staticDir, path)
|
// Get absolute path for static dir
|
||||||
|
absStaticDir, err := filepath.Abs(r.staticDir)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Prevent directory traversal
|
filePath := filepath.Join(absStaticDir, filepath.Clean("/"+path))
|
||||||
if !strings.HasPrefix(filepath.Clean(filePath), filepath.Clean(r.staticDir)) {
|
cleanPath := filepath.Clean(filePath)
|
||||||
|
|
||||||
|
// Prevent directory traversal - ensure path is within static dir
|
||||||
|
if !strings.HasPrefix(cleanPath+string(filepath.Separator), absStaticDir+string(filepath.Separator)) {
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -290,6 +320,12 @@ func (r *Router) defaultHandler(w http.ResponseWriter, req *http.Request) {
|
|||||||
func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, match *RouteMatch) {
|
func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, match *RouteMatch) {
|
||||||
cfg := match.Config
|
cfg := match.Config
|
||||||
|
|
||||||
|
// Handle proxy_pass directive
|
||||||
|
if proxyTarget, ok := cfg["proxy_pass"].(string); ok {
|
||||||
|
r.handleProxyPass(w, req, proxyTarget, cfg, match.Params)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Handle "return" directive
|
// Handle "return" directive
|
||||||
if ret, ok := cfg["return"].(string); ok {
|
if ret, ok := cfg["return"].(string); ok {
|
||||||
parts := strings.SplitN(ret, " ", 2)
|
parts := strings.SplitN(ret, " ", 2)
|
||||||
@ -338,7 +374,26 @@ func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, matc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
filePath := filepath.Join(root, path)
|
// Get absolute path for root dir
|
||||||
|
absRoot, err := filepath.Abs(root)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(absRoot, filepath.Clean("/"+path))
|
||||||
|
cleanPath := filepath.Clean(filePath)
|
||||||
|
|
||||||
|
// DEBUG
|
||||||
|
fmt.Printf("DEBUG: path=%q absRoot=%q filePath=%q cleanPath=%q\n", path, absRoot, filePath, cleanPath)
|
||||||
|
fmt.Printf("DEBUG: check1=%q check2=%q\n", cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator))
|
||||||
|
|
||||||
|
// Prevent directory traversal
|
||||||
|
if !strings.HasPrefix(cleanPath+string(filepath.Separator), absRoot+string(filepath.Separator)) {
|
||||||
|
fmt.Printf("DEBUG: FORBIDDEN - path not within root\n")
|
||||||
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if cacheControl, ok := cfg["cache_control"].(string); ok {
|
if cacheControl, ok := cfg["cache_control"].(string); ok {
|
||||||
w.Header().Set("Cache-Control", cacheControl)
|
w.Header().Set("Cache-Control", cacheControl)
|
||||||
@ -379,6 +434,48 @@ func (r *Router) handleRouteMatch(w http.ResponseWriter, req *http.Request, matc
|
|||||||
http.NotFound(w, req)
|
http.NotFound(w, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleProxyPass proxies the request to the target backend
|
||||||
|
func (r *Router) handleProxyPass(w http.ResponseWriter, req *http.Request, target string, cfg map[string]interface{}, params map[string]string) {
|
||||||
|
// Substitute params in target URL (e.g., {version} -> actual version)
|
||||||
|
for key, value := range params {
|
||||||
|
target = strings.ReplaceAll(target, "{"+key+"}", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create proxy
|
||||||
|
proxyConfig := &proxy.Config{
|
||||||
|
Target: target,
|
||||||
|
Headers: make(map[string]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse headers from config
|
||||||
|
if headers, ok := cfg["headers"].([]interface{}); ok {
|
||||||
|
for _, h := range headers {
|
||||||
|
if header, ok := h.(string); ok {
|
||||||
|
parts := strings.SplitN(header, ": ", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
// Substitute params in header values
|
||||||
|
headerValue := parts[1]
|
||||||
|
for key, value := range params {
|
||||||
|
headerValue = strings.ReplaceAll(headerValue, "{"+key+"}", value)
|
||||||
|
}
|
||||||
|
proxyConfig.Headers[parts[0]] = headerValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := proxy.New(proxyConfig, r.logger)
|
||||||
|
if err != nil {
|
||||||
|
if r.logger != nil {
|
||||||
|
r.logger.Error("Failed to create proxy", "target", target, "error", err)
|
||||||
|
}
|
||||||
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.ProxyRequest(w, req, params)
|
||||||
|
}
|
||||||
|
|
||||||
// CreateRouterFromConfig creates a router from extension config
|
// CreateRouterFromConfig creates a router from extension config
|
||||||
func CreateRouterFromConfig(cfg map[string]interface{}) *Router {
|
func CreateRouterFromConfig(cfg map[string]interface{}) *Router {
|
||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
|
|||||||
@ -11,18 +11,18 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/konduktor/konduktor/internal/config"
|
"github.com/konduktor/konduktor/internal/config"
|
||||||
|
"github.com/konduktor/konduktor/internal/extension"
|
||||||
"github.com/konduktor/konduktor/internal/logging"
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
"github.com/konduktor/konduktor/internal/middleware"
|
"github.com/konduktor/konduktor/internal/middleware"
|
||||||
"github.com/konduktor/konduktor/internal/routing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const Version = "0.1.0"
|
const Version = "0.2.0"
|
||||||
|
|
||||||
// Server represents the Konduktor HTTP server
|
// Server represents the Konduktor HTTP server
|
||||||
type Server struct {
|
type Server struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
router *routing.Router
|
extensionManager *extension.Manager
|
||||||
logger *logging.Logger
|
logger *logging.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,11 +37,30 @@ func New(cfg *config.Config) (*Server, error) {
|
|||||||
return nil, fmt.Errorf("failed to create logger: %w", err)
|
return nil, fmt.Errorf("failed to create logger: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
router := routing.New(cfg, logger)
|
// Create extension manager
|
||||||
|
extManager := extension.NewManager(logger)
|
||||||
|
|
||||||
|
// Load extensions from config
|
||||||
|
for _, extCfg := range cfg.Extensions {
|
||||||
|
// Add static_dir to routing config if not present
|
||||||
|
if extCfg.Type == "routing" {
|
||||||
|
if extCfg.Config == nil {
|
||||||
|
extCfg.Config = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
if _, ok := extCfg.Config["static_dir"]; !ok {
|
||||||
|
extCfg.Config["static_dir"] = cfg.HTTP.StaticDir
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := extManager.LoadExtension(extCfg.Type, extCfg.Config); err != nil {
|
||||||
|
logger.Error("Failed to load extension", "type", extCfg.Type, "error", err)
|
||||||
|
// Continue loading other extensions
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
config: cfg,
|
config: cfg,
|
||||||
router: router,
|
extensionManager: extManager,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,9 +105,15 @@ func (s *Server) Run() error {
|
|||||||
|
|
||||||
// buildHandler builds the HTTP handler chain
|
// buildHandler builds the HTTP handler chain
|
||||||
func (s *Server) buildHandler() http.Handler {
|
func (s *Server) buildHandler() http.Handler {
|
||||||
var handler http.Handler = s.router
|
// Create base handler that returns 404
|
||||||
|
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
// Add middleware
|
// Wrap with extension manager
|
||||||
|
var handler http.Handler = s.extensionManager.Handler(baseHandler)
|
||||||
|
|
||||||
|
// Add middleware (applied in reverse order)
|
||||||
handler = middleware.AccessLog(handler, s.logger)
|
handler = middleware.AccessLog(handler, s.logger)
|
||||||
handler = middleware.ServerHeader(handler, Version)
|
handler = middleware.ServerHeader(handler, Version)
|
||||||
handler = middleware.Recovery(handler, s.logger)
|
handler = middleware.Recovery(handler, s.logger)
|
||||||
@ -115,6 +140,9 @@ func (s *Server) waitForShutdown(errChan <-chan error) error {
|
|||||||
|
|
||||||
s.logger.Info("Shutting down server...")
|
s.logger.Info("Shutting down server...")
|
||||||
|
|
||||||
|
// Cleanup extensions
|
||||||
|
s.extensionManager.Cleanup()
|
||||||
|
|
||||||
if err := s.httpServer.Shutdown(ctx); err != nil {
|
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||||
s.logger.Error("Error during shutdown", "error", err)
|
s.logger.Error("Error during shutdown", "error", err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
127
go/tests/integration/README.md
Normal file
127
go/tests/integration/README.md
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
# Integration Tests
|
||||||
|
|
||||||
|
Интеграционные тесты для Konduktor — полноценное тестирование сервера с реальными HTTP запросами.
|
||||||
|
|
||||||
|
## Отличие от unit-тестов
|
||||||
|
|
||||||
|
| Аспект | Unit-тесты | Интеграционные тесты |
|
||||||
|
|--------|------------|---------------------|
|
||||||
|
| Scope | Отдельный модуль в изоляции | Весь сервер целиком |
|
||||||
|
| Backend | Mock (httptest.Server) | Реальные HTTP серверы |
|
||||||
|
| Config | Программный | YAML конфигурация |
|
||||||
|
| Extensions | Не тестируются | Полная цепочка обработки |
|
||||||
|
|
||||||
|
## Структура тестов
|
||||||
|
|
||||||
|
```
|
||||||
|
tests/integration/
|
||||||
|
├── README.md # Эта документация
|
||||||
|
├── helpers_test.go # Общие хелперы и утилиты
|
||||||
|
├── reverse_proxy_test.go # Тесты reverse proxy
|
||||||
|
├── routing_test.go # Тесты маршрутизации (TODO)
|
||||||
|
├── security_test.go # Тесты security extension (TODO)
|
||||||
|
├── caching_test.go # Тесты caching extension (TODO)
|
||||||
|
└── static_files_test.go # Тесты статических файлов (TODO)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Что тестируют интеграционные тесты
|
||||||
|
|
||||||
|
### 1. Reverse Proxy (`reverse_proxy_test.go`)
|
||||||
|
|
||||||
|
- [ ] Базовое проксирование GET/POST/PUT/DELETE
|
||||||
|
- [ ] Exact match routes (`=/api/version`)
|
||||||
|
- [ ] Regex routes с параметрами (`~^/api/resource/(?P<id>\d+)$`)
|
||||||
|
- [ ] Подстановка параметров в target URL (`{id}`, `{tag}`)
|
||||||
|
- [ ] Подстановка переменных в заголовки (`$remote_addr`)
|
||||||
|
- [ ] Передача заголовков X-Forwarded-For, X-Real-IP
|
||||||
|
- [ ] Сохранение query string
|
||||||
|
- [ ] Обработка ошибок backend (502, 504)
|
||||||
|
- [ ] Таймауты соединения
|
||||||
|
|
||||||
|
### 2. Routing Extension (`routing_test.go`)
|
||||||
|
|
||||||
|
- [ ] Приоритет маршрутов (exact > regex > default)
|
||||||
|
- [ ] Case-sensitive regex (`~`)
|
||||||
|
- [ ] Case-insensitive regex (`~*`)
|
||||||
|
- [ ] Default route (`__default__`)
|
||||||
|
- [ ] Return directive (`return 200 "OK"`)
|
||||||
|
- [ ] Конфликт маршрутов
|
||||||
|
|
||||||
|
### 3. Security Extension (`security_test.go`)
|
||||||
|
|
||||||
|
- [ ] IP whitelist
|
||||||
|
- [ ] IP blacklist
|
||||||
|
- [ ] CIDR нотация (10.0.0.0/8)
|
||||||
|
- [ ] Security headers (X-Frame-Options, X-Content-Type-Options)
|
||||||
|
- [ ] Rate limiting
|
||||||
|
- [ ] Комбинация с другими extensions
|
||||||
|
|
||||||
|
### 4. Caching Extension (`caching_test.go`)
|
||||||
|
|
||||||
|
- [ ] Cache hit/miss
|
||||||
|
- [ ] TTL expiration
|
||||||
|
- [ ] Pattern-based caching
|
||||||
|
- [ ] Cache-Control headers
|
||||||
|
- [ ] Cache invalidation
|
||||||
|
- [ ] Max cache size и eviction
|
||||||
|
|
||||||
|
### 5. Static Files (`static_files_test.go`)
|
||||||
|
|
||||||
|
- [ ] Serving статических файлов
|
||||||
|
- [ ] Index file (index.html)
|
||||||
|
- [ ] MIME types
|
||||||
|
- [ ] Cache-Control для static
|
||||||
|
- [ ] SPA fallback
|
||||||
|
- [ ] Directory traversal protection
|
||||||
|
- [ ] 404 для несуществующих файлов
|
||||||
|
|
||||||
|
### 6. Extension Chain (`extension_chain_test.go`)
|
||||||
|
|
||||||
|
- [ ] Порядок выполнения extensions (security → caching → routing)
|
||||||
|
- [ ] Прерывание цепочки при ошибке
|
||||||
|
- [ ] Совместная работа extensions
|
||||||
|
|
||||||
|
## Запуск тестов
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Все интеграционные тесты
|
||||||
|
go test ./tests/integration/... -v
|
||||||
|
|
||||||
|
# Конкретный файл
|
||||||
|
go test ./tests/integration/... -v -run TestReverseProxy
|
||||||
|
|
||||||
|
# С таймаутом (интеграционные тесты медленнее)
|
||||||
|
go test ./tests/integration/... -v -timeout 60s
|
||||||
|
|
||||||
|
# С покрытием
|
||||||
|
go test ./tests/integration/... -v -coverprofile=coverage.out
|
||||||
|
```
|
||||||
|
|
||||||
|
## Требования
|
||||||
|
|
||||||
|
- Свободные порты: тесты используют случайные порты (`:0`)
|
||||||
|
- Сетевой доступ: для localhost соединений
|
||||||
|
- Время: интеграционные тесты занимают больше времени (~5-10 сек)
|
||||||
|
|
||||||
|
## Добавление новых тестов
|
||||||
|
|
||||||
|
1. Создайте файл `*_test.go` в `tests/integration/`
|
||||||
|
2. Используйте хелперы из `helpers_test.go`:
|
||||||
|
- `startTestServer()` — запуск Konduktor сервера
|
||||||
|
- `startBackend()` — запуск mock backend
|
||||||
|
- `makeRequest()` — отправка HTTP запроса
|
||||||
|
3. Добавьте описание в этот README
|
||||||
|
|
||||||
|
## CI/CD
|
||||||
|
|
||||||
|
Интеграционные тесты запускаются отдельно от unit-тестов:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# .github/workflows/test.yml
|
||||||
|
jobs:
|
||||||
|
unit-tests:
|
||||||
|
run: go test ./internal/...
|
||||||
|
|
||||||
|
integration-tests:
|
||||||
|
run: go test ./tests/integration/... -timeout 120s
|
||||||
|
```
|
||||||
408
go/tests/integration/helpers_test.go
Normal file
408
go/tests/integration/helpers_test.go
Normal file
@ -0,0 +1,408 @@
|
|||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/extension"
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
"github.com/konduktor/konduktor/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestServer represents a running Konduktor server for testing
|
||||||
|
type TestServer struct {
|
||||||
|
Server *http.Server
|
||||||
|
URL string
|
||||||
|
Port int
|
||||||
|
listener net.Listener
|
||||||
|
handler http.Handler
|
||||||
|
t *testing.T
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBackend represents a mock backend server
|
||||||
|
type TestBackend struct {
|
||||||
|
server *httptest.Server
|
||||||
|
requestLog []RequestLogEntry
|
||||||
|
mu sync.Mutex
|
||||||
|
requestCount int64
|
||||||
|
handler http.HandlerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestLogEntry stores information about a received request
|
||||||
|
type RequestLogEntry struct {
|
||||||
|
Method string
|
||||||
|
Path string
|
||||||
|
Query string
|
||||||
|
Headers http.Header
|
||||||
|
Body string
|
||||||
|
Timestamp time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig holds configuration for starting a test server
|
||||||
|
type ServerConfig struct {
|
||||||
|
Extensions []extension.Extension
|
||||||
|
StaticDir string
|
||||||
|
Middleware []func(http.Handler) http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Test Server ==============
|
||||||
|
|
||||||
|
// StartTestServer creates and starts a Konduktor server for testing
|
||||||
|
func StartTestServer(t *testing.T, cfg *ServerConfig) *TestServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
logger, err := logging.New(logging.Config{
|
||||||
|
Level: "DEBUG",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create logger: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create extension manager
|
||||||
|
extManager := extension.NewManager(logger)
|
||||||
|
|
||||||
|
// Add extensions if provided
|
||||||
|
if cfg != nil && len(cfg.Extensions) > 0 {
|
||||||
|
for _, ext := range cfg.Extensions {
|
||||||
|
if err := extManager.AddExtension(ext); err != nil {
|
||||||
|
t.Fatalf("Failed to add extension: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a fallback handler for when no extension handles the request
|
||||||
|
fallback := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create handler chain
|
||||||
|
var handler http.Handler = extManager.Handler(fallback)
|
||||||
|
|
||||||
|
// Add middleware
|
||||||
|
handler = middleware.AccessLog(handler, logger)
|
||||||
|
handler = middleware.Recovery(handler, logger)
|
||||||
|
|
||||||
|
// Add custom middleware if provided
|
||||||
|
if cfg != nil {
|
||||||
|
for _, mw := range cfg.Middleware {
|
||||||
|
handler = mw(handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find available port
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to find available port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
server := &http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := &TestServer{
|
||||||
|
Server: server,
|
||||||
|
URL: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
Port: port,
|
||||||
|
listener: listener,
|
||||||
|
handler: handler,
|
||||||
|
t: t,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start server in goroutine
|
||||||
|
go func() {
|
||||||
|
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||||
|
// Don't fail test here as server might be intentionally closed
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for server to be ready
|
||||||
|
ts.waitReady()
|
||||||
|
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitReady waits for the server to be ready to accept connections
|
||||||
|
func (ts *TestServer) waitReady() {
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", ts.Port), 100*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
ts.t.Fatal("Server failed to start within timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the test server
|
||||||
|
func (ts *TestServer) Close() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
ts.Server.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Test Backend ==============
|
||||||
|
|
||||||
|
// StartBackend creates and starts a mock backend server
|
||||||
|
func StartBackend(handler http.HandlerFunc) *TestBackend {
|
||||||
|
tb := &TestBackend{
|
||||||
|
requestLog: make([]RequestLogEntry, 0),
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
if handler == nil {
|
||||||
|
handler = tb.defaultHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
tb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tb.logRequest(r)
|
||||||
|
handler(w, r)
|
||||||
|
}))
|
||||||
|
|
||||||
|
return tb
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tb *TestBackend) logRequest(r *http.Request) {
|
||||||
|
tb.mu.Lock()
|
||||||
|
defer tb.mu.Unlock()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
|
||||||
|
tb.requestLog = append(tb.requestLog, RequestLogEntry{
|
||||||
|
Method: r.Method,
|
||||||
|
Path: r.URL.Path,
|
||||||
|
Query: r.URL.RawQuery,
|
||||||
|
Headers: r.Header.Clone(),
|
||||||
|
Body: string(body),
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
atomic.AddInt64(&tb.requestCount, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tb *TestBackend) defaultHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"backend": "default",
|
||||||
|
"path": r.URL.Path,
|
||||||
|
"method": r.Method,
|
||||||
|
"query": r.URL.RawQuery,
|
||||||
|
"received": time.Now().Unix(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// URL returns the backend server URL
|
||||||
|
func (tb *TestBackend) URL() string {
|
||||||
|
return tb.server.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the backend server
|
||||||
|
func (tb *TestBackend) Close() {
|
||||||
|
tb.server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestCount returns the number of requests received
|
||||||
|
func (tb *TestBackend) RequestCount() int64 {
|
||||||
|
return atomic.LoadInt64(&tb.requestCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastRequest returns the most recent request
|
||||||
|
func (tb *TestBackend) LastRequest() *RequestLogEntry {
|
||||||
|
tb.mu.Lock()
|
||||||
|
defer tb.mu.Unlock()
|
||||||
|
if len(tb.requestLog) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &tb.requestLog[len(tb.requestLog)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllRequests returns all logged requests
|
||||||
|
func (tb *TestBackend) AllRequests() []RequestLogEntry {
|
||||||
|
tb.mu.Lock()
|
||||||
|
defer tb.mu.Unlock()
|
||||||
|
result := make([]RequestLogEntry, len(tb.requestLog))
|
||||||
|
copy(result, tb.requestLog)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== HTTP Client Helpers ==============
|
||||||
|
|
||||||
|
// HTTPClient is a configured HTTP client for testing
|
||||||
|
type HTTPClient struct {
|
||||||
|
client *http.Client
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPClient creates a new test HTTP client
|
||||||
|
func NewHTTPClient(baseURL string) *HTTPClient {
|
||||||
|
return &HTTPClient{
|
||||||
|
client: &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse // Don't follow redirects
|
||||||
|
},
|
||||||
|
},
|
||||||
|
baseURL: baseURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get performs a GET request
|
||||||
|
func (c *HTTPClient) Get(path string, headers map[string]string) (*http.Response, error) {
|
||||||
|
return c.Do("GET", path, nil, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Post performs a POST request
|
||||||
|
func (c *HTTPClient) Post(path string, body []byte, headers map[string]string) (*http.Response, error) {
|
||||||
|
return c.Do("POST", path, body, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do performs an HTTP request
|
||||||
|
func (c *HTTPClient) Do(method, path string, body []byte, headers map[string]string) (*http.Response, error) {
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if body != nil {
|
||||||
|
bodyReader = bytes.NewReader(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(method, c.baseURL+path, bodyReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range headers {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.client.Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJSON performs GET and decodes JSON response
|
||||||
|
func (c *HTTPClient) GetJSON(path string, result interface{}) (*http.Response, error) {
|
||||||
|
resp, err := c.Get(path, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== File System Helpers ==============
|
||||||
|
|
||||||
|
// CreateTempDir creates a temporary directory for static files
|
||||||
|
func CreateTempDir(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
dir, err := os.MkdirTemp("", "konduktor-test-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { os.RemoveAll(dir) })
|
||||||
|
return dir
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTempFile creates a temporary file with given content
|
||||||
|
func CreateTempFile(t *testing.T, dir, name, content string) string {
|
||||||
|
t.Helper()
|
||||||
|
path := filepath.Join(dir, name)
|
||||||
|
|
||||||
|
// Create parent directories if needed
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||||
|
t.Fatalf("Failed to create directories: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write file: %v", err)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Assertion Helpers ==============
|
||||||
|
|
||||||
|
// AssertStatus checks if response has expected status code
|
||||||
|
func AssertStatus(t *testing.T, resp *http.Response, expected int) {
|
||||||
|
t.Helper()
|
||||||
|
if resp.StatusCode != expected {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Errorf("Expected status %d, got %d. Body: %s", expected, resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertHeader checks if response has expected header value
|
||||||
|
func AssertHeader(t *testing.T, resp *http.Response, header, expected string) {
|
||||||
|
t.Helper()
|
||||||
|
actual := resp.Header.Get(header)
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Expected header %s=%q, got %q", header, expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertHeaderContains checks if header contains substring
|
||||||
|
func AssertHeaderContains(t *testing.T, resp *http.Response, header, substring string) {
|
||||||
|
t.Helper()
|
||||||
|
actual := resp.Header.Get(header)
|
||||||
|
if actual == "" || !contains(actual, substring) {
|
||||||
|
t.Errorf("Expected header %s to contain %q, got %q", header, substring, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertJSONField checks if JSON response has expected field value
|
||||||
|
func AssertJSONField(t *testing.T, body []byte, field string, expected interface{}) {
|
||||||
|
t.Helper()
|
||||||
|
var data map[string]interface{}
|
||||||
|
if err := json.Unmarshal(body, &data); err != nil {
|
||||||
|
t.Fatalf("Failed to parse JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
actual, ok := data[field]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Field %q not found in JSON", field)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Expected %s=%v, got %v", field, expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsAt(s, substr, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsAt(s, substr string, start int) bool {
|
||||||
|
for i := start; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadBody reads and returns response body
|
||||||
|
func ReadBody(t *testing.T, resp *http.Response) []byte {
|
||||||
|
t.Helper()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read body: %v", err)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
562
go/tests/integration/reverse_proxy_test.go
Normal file
562
go/tests/integration/reverse_proxy_test.go
Normal file
@ -0,0 +1,562 @@
|
|||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/konduktor/konduktor/internal/extension"
|
||||||
|
"github.com/konduktor/konduktor/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createTestLogger creates a logger for tests
|
||||||
|
func createTestLogger(t *testing.T) *logging.Logger {
|
||||||
|
t.Helper()
|
||||||
|
logger, err := logging.New(logging.Config{Level: "DEBUG"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create logger: %v", err)
|
||||||
|
}
|
||||||
|
return logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Basic Reverse Proxy Tests ==============
|
||||||
|
|
||||||
|
func TestReverseProxy_BasicGET(t *testing.T) {
|
||||||
|
// Start backend server
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"message": "Hello from backend",
|
||||||
|
"path": r.URL.Path,
|
||||||
|
"method": r.Method,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
// Create routing extension with proxy to backend
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, err := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create routing extension: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start Konduktor server
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Make request through Konduktor
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
resp, err := client.Get("/api/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
AssertStatus(t, resp, http.StatusOK)
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
t.Fatalf("Failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["message"] != "Hello from backend" {
|
||||||
|
t.Errorf("Unexpected message: %v", result["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["path"] != "/api/test" {
|
||||||
|
t.Errorf("Expected path /api/test, got %v", result["path"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify backend received request
|
||||||
|
if backend.RequestCount() != 1 {
|
||||||
|
t.Errorf("Expected 1 backend request, got %d", backend.RequestCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_POST(t *testing.T) {
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var body map[string]interface{}
|
||||||
|
json.NewDecoder(r.Body).Decode(&body)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"received": body,
|
||||||
|
"method": r.Method,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
body := []byte(`{"name":"test","value":123}`)
|
||||||
|
resp, err := client.Post("/api/data", body, map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
AssertStatus(t, resp, http.StatusOK)
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
if result["method"] != "POST" {
|
||||||
|
t.Errorf("Expected method POST, got %v", result["method"])
|
||||||
|
}
|
||||||
|
|
||||||
|
received := result["received"].(map[string]interface{})
|
||||||
|
if received["name"] != "test" {
|
||||||
|
t.Errorf("Expected name 'test', got %v", received["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Exact Match Routes ==============
|
||||||
|
|
||||||
|
func TestReverseProxy_ExactMatchRoute(t *testing.T) {
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"endpoint": "version",
|
||||||
|
"path": r.URL.Path,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
// Exact match - should use backend URL as-is
|
||||||
|
"=/api/version": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL() + "/releases/latest",
|
||||||
|
},
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
|
||||||
|
// Test exact match route
|
||||||
|
resp, err := client.Get("/api/version", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
AssertStatus(t, resp, http.StatusOK)
|
||||||
|
|
||||||
|
lastReq := backend.LastRequest()
|
||||||
|
if lastReq == nil {
|
||||||
|
t.Fatal("No request received by backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
// For exact match, the target path should be used as-is (IgnoreRequestPath=true)
|
||||||
|
if lastReq.Path != "/releases/latest" {
|
||||||
|
t.Errorf("Expected backend path /releases/latest, got %s", lastReq.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Regex Routes with Parameters ==============
|
||||||
|
|
||||||
|
func TestReverseProxy_RegexRouteWithParams(t *testing.T) {
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"path": r.URL.Path,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
// Regex with named group
|
||||||
|
"~^/api/users/(?P<id>\\d+)$": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL() + "/v2/users/{id}",
|
||||||
|
},
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
|
||||||
|
// Test regex route with parameter
|
||||||
|
resp, err := client.Get("/api/users/42", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
AssertStatus(t, resp, http.StatusOK)
|
||||||
|
|
||||||
|
lastReq := backend.LastRequest()
|
||||||
|
if lastReq == nil {
|
||||||
|
t.Fatal("No request received by backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parameter {id} should be substituted
|
||||||
|
if lastReq.Path != "/v2/users/42" {
|
||||||
|
t.Errorf("Expected backend path /v2/users/42, got %s", lastReq.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Header Forwarding ==============
|
||||||
|
|
||||||
|
func TestReverseProxy_HeaderForwarding(t *testing.T) {
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"x-forwarded-for": r.Header.Get("X-Forwarded-For"),
|
||||||
|
"x-real-ip": r.Header.Get("X-Real-IP"),
|
||||||
|
"x-custom": r.Header.Get("X-Custom"),
|
||||||
|
"x-forwarded-host": r.Header.Get("X-Forwarded-Host"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
"headers": []interface{}{
|
||||||
|
"X-Forwarded-For: $remote_addr",
|
||||||
|
"X-Real-IP: $remote_addr",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
resp, err := client.Get("/test", map[string]string{
|
||||||
|
"X-Custom": "custom-value",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
// X-Custom should be forwarded
|
||||||
|
if result["x-custom"] != "custom-value" {
|
||||||
|
t.Errorf("Expected X-Custom header to be forwarded, got %v", result["x-custom"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// X-Forwarded-For should be set (will contain 127.0.0.1)
|
||||||
|
if result["x-forwarded-for"] == "" {
|
||||||
|
t.Error("Expected X-Forwarded-For header to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Query String ==============
|
||||||
|
|
||||||
|
func TestReverseProxy_QueryStringPreservation(t *testing.T) {
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"query": r.URL.RawQuery,
|
||||||
|
"foo": r.URL.Query().Get("foo"),
|
||||||
|
"bar": r.URL.Query().Get("bar"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
resp, err := client.Get("/search?foo=hello&bar=world", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
if result["foo"] != "hello" {
|
||||||
|
t.Errorf("Expected foo=hello, got %v", result["foo"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["bar"] != "world" {
|
||||||
|
t.Errorf("Expected bar=world, got %v", result["bar"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Error Handling ==============
|
||||||
|
|
||||||
|
func TestReverseProxy_BackendUnavailable(t *testing.T) {
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
// Non-existent backend
|
||||||
|
"proxy_pass": "http://127.0.0.1:59999",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
resp, err := client.Get("/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Should return 502 Bad Gateway
|
||||||
|
AssertStatus(t, resp, http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_BackendTimeout(t *testing.T) {
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Simulate slow backend
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
"timeout": 0.5, // 500ms timeout
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
resp, err := client.Get("/slow", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Should return 504 Gateway Timeout
|
||||||
|
AssertStatus(t, resp, http.StatusGatewayTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== HTTP Methods ==============
|
||||||
|
|
||||||
|
func TestReverseProxy_AllMethods(t *testing.T) {
|
||||||
|
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"}
|
||||||
|
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"method": r.Method,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
|
||||||
|
for _, method := range methods {
|
||||||
|
t.Run(method, func(t *testing.T) {
|
||||||
|
resp, err := client.Do(method, "/resource", nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
AssertStatus(t, resp, http.StatusOK)
|
||||||
|
|
||||||
|
if method != "HEAD" {
|
||||||
|
var result map[string]string
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
if result["method"] != method {
|
||||||
|
t.Errorf("Expected method %s, got %v", method, result["method"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Large Bodies ==============
|
||||||
|
|
||||||
|
func TestReverseProxy_LargeRequestBody(t *testing.T) {
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
json.NewEncoder(w).Encode(map[string]int{
|
||||||
|
"received": len(body),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
|
||||||
|
// 1MB body
|
||||||
|
largeBody := []byte(strings.Repeat("x", 1024*1024))
|
||||||
|
resp, err := client.Post("/upload", largeBody, map[string]string{
|
||||||
|
"Content-Type": "application/octet-stream",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
AssertStatus(t, resp, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============== Concurrent Requests ==============
|
||||||
|
|
||||||
|
func TestReverseProxy_ConcurrentRequests(t *testing.T) {
|
||||||
|
backend := StartBackend(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Small delay to simulate work
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||||
|
})
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
logger := createTestLogger(t)
|
||||||
|
routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{
|
||||||
|
"regex_locations": map[string]interface{}{
|
||||||
|
"__default__": map[string]interface{}{
|
||||||
|
"proxy_pass": backend.URL(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
server := StartTestServer(t, &ServerConfig{
|
||||||
|
Extensions: []extension.Extension{routingExt},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
const numRequests = 50
|
||||||
|
results := make(chan error, numRequests)
|
||||||
|
|
||||||
|
for i := 0; i < numRequests; i++ {
|
||||||
|
go func(n int) {
|
||||||
|
client := NewHTTPClient(server.URL)
|
||||||
|
resp, err := client.Get(fmt.Sprintf("/concurrent/%d", n), nil)
|
||||||
|
if err != nil {
|
||||||
|
results <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
results <- fmt.Errorf("unexpected status: %d", resp.StatusCode)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
results <- nil
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect results
|
||||||
|
var errors []error
|
||||||
|
for i := 0; i < numRequests; i++ {
|
||||||
|
if err := <-results; err != nil {
|
||||||
|
errors = append(errors, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
t.Errorf("Got %d errors in concurrent requests: %v", len(errors), errors[:min(5, len(errors))])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all requests reached backend
|
||||||
|
if backend.RequestCount() != numRequests {
|
||||||
|
t.Errorf("Expected %d backend requests, got %d", numRequests, backend.RequestCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user