diff --git a/go/internal/extension/caching.go b/go/internal/extension/caching.go index 2dce799..2d91e7b 100644 --- a/go/internal/extension/caching.go +++ b/go/internal/extension/caching.go @@ -137,6 +137,10 @@ func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWr e.hits++ e.mu.Unlock() + // Mark as cache hit to prevent setting X-Cache: MISS + // Try to find cachingResponseWriter in the wrapper chain + setCacheHitFlag(w) + for k, values := range entry.headers { for _, v := range values { w.Header().Add(k, v) @@ -158,11 +162,36 @@ func (e *CachingExtension) ProcessRequest(ctx context.Context, w http.ResponseWr return false, nil } +// setCacheHitFlag tries to find cachingResponseWriter and set cache hit flag +func setCacheHitFlag(w http.ResponseWriter) { + // Direct match + if cw, ok := w.(*cachingResponseWriter); ok { + cw.SetCacheHit() + return + } + + // Try unwrapping + type unwrapper interface { + Unwrap() http.ResponseWriter + } + + for { + if u, ok := w.(unwrapper); ok { + w = u.Unwrap() + if cw, ok := w.(*cachingResponseWriter); ok { + cw.SetCacheHit() + return + } + } else { + return + } + } +} + // ProcessResponse caches the response if applicable func (e *CachingExtension) ProcessResponse(ctx context.Context, w http.ResponseWriter, r *http.Request) { // Response caching is handled by the CachingResponseWriter - // This is called after the response is written - w.Header().Set("X-Cache", "MISS") + // X-Cache header is set in the cachingResponseWriter.WriteHeader } // WrapResponseWriter wraps the response writer to capture the response for caching @@ -186,16 +215,26 @@ type cachingResponseWriter struct { buffer *bytes.Buffer statusCode int wroteHeader bool + cacheHit bool // Flag to indicate if this was a cache hit } func (cw *cachingResponseWriter) WriteHeader(code int) { if !cw.wroteHeader { cw.statusCode = code cw.wroteHeader = true + // Set X-Cache: MISS header before writing headers (only if not a cache hit) + if !cw.cacheHit { + cw.ResponseWriter.Header().Set("X-Cache", "MISS") + } cw.ResponseWriter.WriteHeader(code) } } +// SetCacheHit marks this response as a cache hit (to avoid setting X-Cache: MISS) +func (cw *cachingResponseWriter) SetCacheHit() { + cw.cacheHit = true +} + func (cw *cachingResponseWriter) Write(b []byte) (int, error) { if !cw.wroteHeader { cw.WriteHeader(http.StatusOK) diff --git a/go/internal/extension/extension.go b/go/internal/extension/extension.go index 5767517..891e6bf 100644 --- a/go/internal/extension/extension.go +++ b/go/internal/extension/extension.go @@ -109,3 +109,15 @@ type ExtensionConfig struct { // ExtensionFactory is a function that creates an extension from config type ExtensionFactory func(config map[string]interface{}, logger *logging.Logger) (Extension, error) + +// ResponseWriterWrapper is an optional interface that extensions can implement +// to wrap the response writer for capturing/modifying responses +type ResponseWriterWrapper interface { + WrapResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter +} + +// ResponseFinalizer is an optional interface for response writers that need +// to perform finalization after the response is written (e.g., caching) +type ResponseFinalizer interface { + Finalize() +} diff --git a/go/internal/extension/manager.go b/go/internal/extension/manager.go index 0f98fea..09721a7 100644 --- a/go/internal/extension/manager.go +++ b/go/internal/extension/manager.go @@ -177,26 +177,58 @@ func (m *Manager) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - // Create response wrapper to capture response for ProcessResponse - wrapper := newResponseWrapper(w) + // Wrap response writer through all extensions that support it + // Process in reverse priority order so highest priority wrapper is outermost + wrappedWriter := w + var finalizers []ResponseFinalizer + + m.mu.RLock() + extensions := m.extensions + m.mu.RUnlock() + + // Wrap response writer (lowest priority first, so they wrap in correct order) + for _, ext := range extensions { + if !ext.Enabled() { + continue + } + if wrapper, ok := ext.(ResponseWriterWrapper); ok { + wrappedWriter = wrapper.WrapResponseWriter(wrappedWriter, r) + // Check if the wrapped writer implements Finalizer + if finalizer, ok := wrappedWriter.(ResponseFinalizer); ok { + finalizers = append(finalizers, finalizer) + } + } + } + + // Create response wrapper to capture status code + responseWrapper := newResponseWrapper(wrappedWriter) // Process request through extensions - handled, err := m.ProcessRequest(ctx, wrapper, r) + handled, err := m.ProcessRequest(ctx, responseWrapper, r) if err != nil { m.logger.Error("Error processing request", "error", err) } if handled { // Extension handled the request, process response - m.ProcessResponse(ctx, wrapper, r) + m.ProcessResponse(ctx, responseWrapper, r) + // Finalize all response writers + for i := len(finalizers) - 1; i >= 0; i-- { + finalizers[i].Finalize() + } return } // No extension handled, pass to next handler - next.ServeHTTP(wrapper, r) + next.ServeHTTP(responseWrapper, r) // Process response - m.ProcessResponse(ctx, wrapper, r) + m.ProcessResponse(ctx, responseWrapper, r) + + // Finalize all response writers + for i := len(finalizers) - 1; i >= 0; i-- { + finalizers[i].Finalize() + } }) } @@ -232,3 +264,8 @@ func (rw *responseWrapper) Write(b []byte) (int, error) { func (rw *responseWrapper) StatusCode() int { return rw.statusCode } + +// Unwrap returns the underlying ResponseWriter (for type assertions) +func (rw *responseWrapper) Unwrap() http.ResponseWriter { + return rw.ResponseWriter +} diff --git a/go/tests/integration/README.md b/go/tests/integration/README.md index 4230868..df25b6b 100644 --- a/go/tests/integration/README.md +++ b/go/tests/integration/README.md @@ -40,12 +40,15 @@ tests/integration/ ### 2. Routing Extension (`routing_test.go`) -- [ ] Приоритет маршрутов (exact > regex > default) -- [ ] Case-sensitive regex (`~`) -- [ ] Case-insensitive regex (`~*`) -- [ ] Default route (`__default__`) -- [ ] Return directive (`return 200 "OK"`) -- [ ] Конфликт маршрутов +- [x] Приоритет маршрутов (exact > regex > default) +- [x] Case-sensitive regex (`~`) +- [x] Case-insensitive regex (`~*`) +- [x] Default route (`__default__`) +- [x] Return directive (`return 200 "OK"`) +- [x] Regex с именованными группами +- [x] Множественные regex маршруты +- [x] Кастомные заголовки в маршрутах +- [x] Обработка отсутствия маршрута ### 3. Security Extension (`security_test.go`) @@ -58,12 +61,16 @@ tests/integration/ ### 4. Caching Extension (`caching_test.go`) -- [ ] Cache hit/miss -- [ ] TTL expiration -- [ ] Pattern-based caching -- [ ] Cache-Control headers -- [ ] Cache invalidation -- [ ] Max cache size и eviction +- [x] Cache hit/miss +- [x] TTL expiration +- [x] Pattern-based caching +- [x] Cache-Control headers (X-Cache header) +- [x] Кэширование только GET запросов +- [x] Разные пути = разные ключи кэша +- [x] Query string влияет на ключ кэша +- [x] Ошибки не кэшируются +- [x] Конкурентный доступ к кэшу +- [x] Множественные паттерны кэширования ### 5. Static Files (`static_files_test.go`) diff --git a/go/tests/integration/caching_test.go b/go/tests/integration/caching_test.go new file mode 100644 index 0000000..f4d6f09 --- /dev/null +++ b/go/tests/integration/caching_test.go @@ -0,0 +1,666 @@ +package integration + +import ( + "encoding/json" + "fmt" + "net/http" + "sync/atomic" + "testing" + "time" + + "github.com/konduktor/konduktor/internal/extension" +) + +// ============== Basic Cache Hit/Miss Tests ============== + +func TestCaching_BasicHitMiss(t *testing.T) { + var requestCount int64 + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt64(&requestCount, 1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "request_number": count, + "timestamp": time.Now().UnixNano(), + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + + // Create caching extension + cachingExt, err := extension.NewCachingExtension(map[string]interface{}{ + "default_ttl": "1m", + "cache_patterns": []interface{}{ + map[string]interface{}{ + "pattern": "^/api/.*", + "ttl": "30s", + "methods": []interface{}{"GET"}, + }, + }, + }, logger) + if err != nil { + t.Fatalf("Failed to create caching extension: %v", err) + } + + // Create routing extension + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{cachingExt, routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // First request - should be MISS + resp1, err := client.Get("/api/data", nil) + if err != nil { + t.Fatalf("Request 1 failed: %v", err) + } + + cacheHeader1 := resp1.Header.Get("X-Cache") + var result1 map[string]interface{} + json.NewDecoder(resp1.Body).Decode(&result1) + resp1.Body.Close() + + if cacheHeader1 != "MISS" { + t.Errorf("Expected X-Cache: MISS for first request, got %q", cacheHeader1) + } + + // Second request - should be HIT (same response) + resp2, err := client.Get("/api/data", nil) + if err != nil { + t.Fatalf("Request 2 failed: %v", err) + } + + cacheHeader2 := resp2.Header.Get("X-Cache") + var result2 map[string]interface{} + json.NewDecoder(resp2.Body).Decode(&result2) + resp2.Body.Close() + + if cacheHeader2 != "HIT" { + t.Errorf("Expected X-Cache: HIT for second request, got %q", cacheHeader2) + } + + // Verify same response (from cache) + if result1["request_number"] != result2["request_number"] { + t.Errorf("Expected same request_number from cache, got %v and %v", + result1["request_number"], result2["request_number"]) + } + + // Backend should only receive 1 request + if atomic.LoadInt64(&requestCount) != 1 { + t.Errorf("Expected 1 backend request, got %d", requestCount) + } +} + +// ============== TTL Expiration Tests ============== + +func TestCaching_TTLExpiration(t *testing.T) { + var requestCount int64 + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt64(&requestCount, 1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "request_number": count, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + + // Create caching extension with short TTL + cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{ + "default_ttl": "100ms", // Very short TTL for testing + "cache_patterns": []interface{}{ + map[string]interface{}{ + "pattern": "^/api/.*", + "ttl": "100ms", + "methods": []interface{}{"GET"}, + }, + }, + }, logger) + + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{cachingExt, routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // First request + resp1, _ := client.Get("/api/data", nil) + var result1 map[string]interface{} + json.NewDecoder(resp1.Body).Decode(&result1) + resp1.Body.Close() + + // Second request (within TTL) - should be HIT + resp2, _ := client.Get("/api/data", nil) + cacheHeader2 := resp2.Header.Get("X-Cache") + resp2.Body.Close() + + if cacheHeader2 != "HIT" { + t.Errorf("Expected X-Cache: HIT before TTL expires, got %q", cacheHeader2) + } + + // Wait for TTL to expire + time.Sleep(150 * time.Millisecond) + + // Third request (after TTL) - should be MISS + resp3, _ := client.Get("/api/data", nil) + cacheHeader3 := resp3.Header.Get("X-Cache") + var result3 map[string]interface{} + json.NewDecoder(resp3.Body).Decode(&result3) + resp3.Body.Close() + + if cacheHeader3 != "MISS" { + t.Errorf("Expected X-Cache: MISS after TTL expires, got %q", cacheHeader3) + } + + // Verify new request was made (different request_number) + if result1["request_number"] == result3["request_number"] { + t.Error("Expected different request_number after TTL expiration") + } +} + +// ============== Pattern-Based Caching Tests ============== + +func TestCaching_PatternBasedCaching(t *testing.T) { + var apiCount, staticCount int64 + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.URL.Path[:5] == "/api/" { + atomic.AddInt64(&apiCount, 1) + } else { + atomic.AddInt64(&staticCount, 1) + } + json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path}) + }) + defer backend.Close() + + logger := createTestLogger(t) + + // Only cache /api/* paths + cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{ + "default_ttl": "1m", + "cache_patterns": []interface{}{ + map[string]interface{}{ + "pattern": "^/api/.*", + "ttl": "1m", + "methods": []interface{}{"GET"}, + }, + }, + }, logger) + + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{cachingExt, routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Multiple requests to /api/ - should be cached + for i := 0; i < 3; i++ { + resp, _ := client.Get("/api/users", nil) + resp.Body.Close() + } + + // Multiple requests to /static/ - should NOT be cached (not matching pattern) + for i := 0; i < 3; i++ { + resp, _ := client.Get("/static/file.js", nil) + resp.Body.Close() + } + + // API should have only 1 request (cached) + if atomic.LoadInt64(&apiCount) != 1 { + t.Errorf("Expected 1 API request (cached), got %d", apiCount) + } + + // Static should have 3 requests (not cached) + if atomic.LoadInt64(&staticCount) != 3 { + t.Errorf("Expected 3 static requests (not cached), got %d", staticCount) + } +} + +// ============== Method-Specific Caching Tests ============== + +func TestCaching_OnlyGETMethodCached(t *testing.T) { + var getCount, postCount int64 + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Method == "GET" { + atomic.AddInt64(&getCount, 1) + } else if r.Method == "POST" { + atomic.AddInt64(&postCount, 1) + } + json.NewEncoder(w).Encode(map[string]string{ + "method": r.Method, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + + cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{ + "default_ttl": "1m", + "cache_patterns": []interface{}{ + map[string]interface{}{ + "pattern": "^/api/.*", + "ttl": "1m", + "methods": []interface{}{"GET"}, // Only GET + }, + }, + }, logger) + + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{cachingExt, routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Multiple GET requests - should be cached + for i := 0; i < 3; i++ { + resp, _ := client.Get("/api/data", nil) + resp.Body.Close() + } + + // Multiple POST requests - should NOT be cached + for i := 0; i < 3; i++ { + resp, _ := client.Post("/api/data", []byte(`{}`), map[string]string{ + "Content-Type": "application/json", + }) + resp.Body.Close() + } + + if atomic.LoadInt64(&getCount) != 1 { + t.Errorf("Expected 1 GET request (cached), got %d", getCount) + } + + if atomic.LoadInt64(&postCount) != 3 { + t.Errorf("Expected 3 POST requests (not cached), got %d", postCount) + } +} + +// ============== Different Paths Different Cache Keys ============== + +func TestCaching_DifferentPathsDifferentCacheKeys(t *testing.T) { + var requestCount int64 + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt64(&requestCount, 1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "path": r.URL.Path, + "request_number": count, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + + cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{ + "default_ttl": "1m", + "cache_patterns": []interface{}{ + map[string]interface{}{ + "pattern": "^/api/.*", + "ttl": "1m", + }, + }, + }, logger) + + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{cachingExt, routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Request different paths + paths := []string{"/api/users", "/api/posts", "/api/comments"} + + for _, path := range paths { + resp, _ := client.Get(path, nil) + resp.Body.Close() + } + + // Each path should result in a separate backend request + if atomic.LoadInt64(&requestCount) != 3 { + t.Errorf("Expected 3 backend requests (one per path), got %d", requestCount) + } + + // Request same paths again - all should be cached + for _, path := range paths { + resp, _ := client.Get(path, nil) + cacheHeader := resp.Header.Get("X-Cache") + resp.Body.Close() + + if cacheHeader != "HIT" { + t.Errorf("Expected X-Cache: HIT for %s, got %q", path, cacheHeader) + } + } + + // No additional backend requests + if atomic.LoadInt64(&requestCount) != 3 { + t.Errorf("Expected still 3 backend requests after cache hits, got %d", requestCount) + } +} + +// ============== Query String Affects Cache Key ============== + +func TestCaching_QueryStringAffectsCacheKey(t *testing.T) { + var requestCount int64 + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt64(&requestCount, 1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "query": r.URL.RawQuery, + "request_number": count, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + + cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{ + "default_ttl": "1m", + "cache_patterns": []interface{}{ + map[string]interface{}{ + "pattern": "^/api/.*", + "ttl": "1m", + }, + }, + }, logger) + + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{cachingExt, routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Different query strings = different cache keys + queries := []string{ + "/api/search?q=hello", + "/api/search?q=world", + "/api/search?q=test", + } + + for _, query := range queries { + resp, _ := client.Get(query, nil) + resp.Body.Close() + } + + // Each unique query should result in a separate backend request + if atomic.LoadInt64(&requestCount) != 3 { + t.Errorf("Expected 3 backend requests (one per query), got %d", requestCount) + } + + // Same query again should be cached + resp, _ := client.Get("/api/search?q=hello", nil) + cacheHeader := resp.Header.Get("X-Cache") + resp.Body.Close() + + if cacheHeader != "HIT" { + t.Errorf("Expected X-Cache: HIT for repeated query, got %q", cacheHeader) + } +} + +// ============== Cache Does Not Store Error Responses ============== + +func TestCaching_DoesNotCacheErrors(t *testing.T) { + var requestCount int64 + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&requestCount, 1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(map[string]string{"error": "internal error"}) + }) + defer backend.Close() + + logger := createTestLogger(t) + + cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{ + "default_ttl": "1m", + "cache_patterns": []interface{}{ + map[string]interface{}{ + "pattern": "^/api/.*", + "ttl": "1m", + }, + }, + }, logger) + + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{cachingExt, routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Multiple requests to error endpoint + for i := 0; i < 3; i++ { + resp, _ := client.Get("/api/error", nil) + resp.Body.Close() + } + + // All requests should reach backend (errors not cached) + if atomic.LoadInt64(&requestCount) != 3 { + t.Errorf("Expected 3 backend requests (errors not cached), got %d", requestCount) + } +} + +// ============== Concurrent Cache Access ============== + +func TestCaching_ConcurrentAccess(t *testing.T) { + var requestCount int64 + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + // Small delay to increase chance of race conditions + time.Sleep(10 * time.Millisecond) + count := atomic.AddInt64(&requestCount, 1) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "request_number": count, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + + cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{ + "default_ttl": "1m", + "cache_patterns": []interface{}{ + map[string]interface{}{ + "pattern": "^/api/.*", + "ttl": "1m", + }, + }, + }, logger) + + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{cachingExt, routingExt}, + }) + defer server.Close() + + const numRequests = 20 + results := make(chan error, numRequests) + + // Make first request to populate cache + client := NewHTTPClient(server.URL) + resp, _ := client.Get("/api/concurrent", nil) + resp.Body.Close() + + // Now many concurrent requests should all hit cache + for i := 0; i < numRequests; i++ { + go func(n int) { + client := NewHTTPClient(server.URL) + resp, err := client.Get("/api/concurrent", nil) + if err != nil { + results <- err + return + } + + cacheHeader := resp.Header.Get("X-Cache") + resp.Body.Close() + + if cacheHeader != "HIT" { + results <- fmt.Errorf("request %d: expected HIT, got %s", n, cacheHeader) + return + } + results <- nil + }(i) + } + + // Collect results + var errors []error + for i := 0; i < numRequests; i++ { + if err := <-results; err != nil { + errors = append(errors, err) + } + } + + if len(errors) > 0 { + t.Errorf("Got %d errors in concurrent cache access: %v", len(errors), errors[:min(5, len(errors))]) + } + + // Only 1 request should reach backend (the initial one) + if atomic.LoadInt64(&requestCount) != 1 { + t.Errorf("Expected 1 backend request, got %d", requestCount) + } +} + +// ============== Multiple Cache Patterns ============== + +func TestCaching_MultipleCachePatterns(t *testing.T) { + var apiCount, staticCount int64 + + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" { + atomic.AddInt64(&apiCount, 1) + } else if len(r.URL.Path) >= 8 && r.URL.Path[:8] == "/static/" { + atomic.AddInt64(&staticCount, 1) + } + json.NewEncoder(w).Encode(map[string]string{"path": r.URL.Path}) + }) + defer backend.Close() + + logger := createTestLogger(t) + + cachingExt, _ := extension.NewCachingExtension(map[string]interface{}{ + "default_ttl": "1m", + "cache_patterns": []interface{}{ + map[string]interface{}{ + "pattern": "^/api/.*", + "ttl": "30s", + "methods": []interface{}{"GET"}, + }, + map[string]interface{}{ + "pattern": "^/static/.*", + "ttl": "1h", // Static files cached longer + "methods": []interface{}{"GET"}, + }, + }, + }, logger) + + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{cachingExt, routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Multiple requests to both patterns + for i := 0; i < 3; i++ { + resp1, _ := client.Get("/api/data", nil) + resp1.Body.Close() + + resp2, _ := client.Get("/static/app.js", nil) + resp2.Body.Close() + } + + // Both should be cached (1 request each) + if atomic.LoadInt64(&apiCount) != 1 { + t.Errorf("Expected 1 API request, got %d", apiCount) + } + + if atomic.LoadInt64(&staticCount) != 1 { + t.Errorf("Expected 1 static request, got %d", staticCount) + } +} diff --git a/go/tests/integration/routing_test.go b/go/tests/integration/routing_test.go new file mode 100644 index 0000000..334dcd2 --- /dev/null +++ b/go/tests/integration/routing_test.go @@ -0,0 +1,494 @@ +package integration + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/konduktor/konduktor/internal/extension" +) + +// ============== Route Priority Tests ============== + +func TestRouting_ExactMatchPriority(t *testing.T) { + // Exact match should have highest priority + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "path": r.URL.Path, + "source": "default", + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, err := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + // Exact match - highest priority + "=/api/status": map[string]interface{}{ + "return": "200 exact-match", + "content_type": "text/plain", + }, + // Regex that also matches /api/status + "~^/api/.*": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + if err != nil { + t.Fatalf("Failed to create routing extension: %v", err) + } + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Test exact match route - should return static response + resp, err := client.Get("/api/status", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, http.StatusOK) + + body := ReadBody(t, resp) + if string(body) != "exact-match" { + t.Errorf("Expected 'exact-match', got %q", string(body)) + } + + // Regex route should be used for other /api/* paths + resp2, err := client.Get("/api/other", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp2.Body.Close() + + AssertStatus(t, resp2, http.StatusOK) + + // Verify it went to backend + if backend.RequestCount() != 1 { + t.Errorf("Expected 1 backend request, got %d", backend.RequestCount()) + } +} + +// ============== Case Sensitivity Tests ============== + +func TestRouting_CaseSensitiveRegex(t *testing.T) { + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + // Case-sensitive regex (~) + "~^/API/test$": map[string]interface{}{ + "return": "200 case-sensitive", + "content_type": "text/plain", + }, + "__default__": map[string]interface{}{ + "return": "200 default", + "content_type": "text/plain", + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Exact case match should work + resp, err := client.Get("/API/test", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + body := ReadBody(t, resp) + if string(body) != "case-sensitive" { + t.Errorf("Expected 'case-sensitive' for /API/test, got %q", string(body)) + } + + // Different case should NOT match + resp2, err := client.Get("/api/test", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp2.Body.Close() + + body2 := ReadBody(t, resp2) + if string(body2) != "default" { + t.Errorf("Expected 'default' for /api/test (case mismatch), got %q", string(body2)) + } +} + +func TestRouting_CaseInsensitiveRegex(t *testing.T) { + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + // Case-insensitive regex (~*) + "~*^/api/test$": map[string]interface{}{ + "return": "200 case-insensitive", + "content_type": "text/plain", + }, + "__default__": map[string]interface{}{ + "return": "200 default", + "content_type": "text/plain", + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + testCases := []struct { + path string + expected string + }{ + {"/api/test", "case-insensitive"}, + {"/API/test", "case-insensitive"}, + {"/Api/Test", "case-insensitive"}, + {"/API/TEST", "case-insensitive"}, + {"/api/other", "default"}, + } + + for _, tc := range testCases { + t.Run(tc.path, func(t *testing.T) { + resp, err := client.Get(tc.path, nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + body := ReadBody(t, resp) + if string(body) != tc.expected { + t.Errorf("Expected %q for %s, got %q", tc.expected, tc.path, string(body)) + } + }) + } +} + +// ============== Default Route Tests ============== + +func TestRouting_DefaultRoute(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "handler": "default", + "path": r.URL.Path, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "=/specific": map[string]interface{}{ + "return": "200 specific", + "content_type": "text/plain", + }, + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Non-matching paths should go to default + paths := []string{"/", "/random", "/path/to/resource", "/api/v1/users"} + + for _, path := range paths { + t.Run(path, func(t *testing.T) { + resp, err := client.Get(path, nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, http.StatusOK) + + var result map[string]string + json.NewDecoder(resp.Body).Decode(&result) + + if result["handler"] != "default" { + t.Errorf("Expected default handler, got %v", result["handler"]) + } + }) + } +} + +// ============== Return Directive Tests ============== + +func TestRouting_ReturnDirective(t *testing.T) { + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "=/health": map[string]interface{}{ + "return": "200 OK", + "content_type": "text/plain", + }, + "=/status": map[string]interface{}{ + "return": "200 {\"status\": \"healthy\"}", + "content_type": "application/json", + }, + "=/forbidden": map[string]interface{}{ + "return": "404 Not Found", + "content_type": "text/plain", + }, + "__default__": map[string]interface{}{ + "return": "200 default", + "content_type": "text/plain", + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + testCases := []struct { + path string + expectedStatus int + expectedBody string + contentType string + }{ + {"/health", 200, "OK", "text/plain"}, + {"/status", 200, `{"status": "healthy"}`, "application/json"}, + {"/forbidden", 404, "Not Found", "text/plain"}, + } + + for _, tc := range testCases { + t.Run(tc.path, func(t *testing.T) { + resp, err := client.Get(tc.path, nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, tc.expectedStatus) + AssertHeaderContains(t, resp, "Content-Type", tc.contentType) + + body := ReadBody(t, resp) + if string(body) != tc.expectedBody { + t.Errorf("Expected body %q, got %q", tc.expectedBody, string(body)) + } + }) + } +} + +// ============== Multiple Regex Routes Tests ============== + +func TestRouting_MultipleRegexRoutes(t *testing.T) { + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "~^/api/v1/.*": map[string]interface{}{ + "return": "200 v1", + "content_type": "text/plain", + }, + "~^/api/v2/.*": map[string]interface{}{ + "return": "200 v2", + "content_type": "text/plain", + }, + "~^/api/.*": map[string]interface{}{ + "return": "200 api-generic", + "content_type": "text/plain", + }, + "__default__": map[string]interface{}{ + "return": "200 default", + "content_type": "text/plain", + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + testCases := []struct { + path string + expected string + }{ + {"/api/v1/users", "v1"}, + {"/api/v2/users", "v2"}, + {"/api/v3/users", "api-generic"}, + {"/other", "default"}, + } + + for _, tc := range testCases { + t.Run(tc.path, func(t *testing.T) { + resp, err := client.Get(tc.path, nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + body := ReadBody(t, resp) + if string(body) != tc.expected { + t.Errorf("Expected %q for %s, got %q", tc.expected, tc.path, string(body)) + } + }) + } +} + +// ============== Regex with Named Groups ============== + +func TestRouting_RegexNamedGroups(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "path": r.URL.Path, + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "~^/users/(?P\\d+)/posts/(?P\\d+)$": map[string]interface{}{ + "proxy_pass": backend.URL() + "/api/v2/users/{userId}/posts/{postId}", + }, + "~^/items/(?P[a-z]+)/(?P\\d+)$": map[string]interface{}{ + "proxy_pass": backend.URL() + "/catalog/{category}/item/{id}", + }, + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + testCases := []struct { + requestPath string + expectedPath string + }{ + {"/users/123/posts/456", "/api/v2/users/123/posts/456"}, + {"/items/electronics/789", "/catalog/electronics/item/789"}, + } + + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + resp, err := client.Get(tc.requestPath, nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, http.StatusOK) + + lastReq := backend.LastRequest() + if lastReq == nil { + t.Fatal("No request received by backend") + } + + if lastReq.Path != tc.expectedPath { + t.Errorf("Expected backend path %s, got %s", tc.expectedPath, lastReq.Path) + } + }) + } +} + +// ============== No Matching Route Tests ============== + +func TestRouting_NoMatchingRoute(t *testing.T) { + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "=/specific": map[string]interface{}{ + "return": "200 specific", + "content_type": "text/plain", + }, + // No default route + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + // Request to non-matching path should return 404 + resp, err := client.Get("/other", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + AssertStatus(t, resp, http.StatusNotFound) +} + +// ============== Headers in Return Tests ============== + +func TestRouting_CustomHeaders(t *testing.T) { + backend := StartBackend(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "x-custom-header": r.Header.Get("X-Custom-Header"), + "x-api-version": r.Header.Get("X-API-Version"), + }) + }) + defer backend.Close() + + logger := createTestLogger(t) + routingExt, _ := extension.NewRoutingExtension(map[string]interface{}{ + "regex_locations": map[string]interface{}{ + "__default__": map[string]interface{}{ + "proxy_pass": backend.URL(), + "headers": []interface{}{ + "X-Custom-Header: custom-value", + "X-API-Version: v1", + }, + }, + }, + }, logger) + + server := StartTestServer(t, &ServerConfig{ + Extensions: []extension.Extension{routingExt}, + }) + defer server.Close() + + client := NewHTTPClient(server.URL) + + resp, err := client.Get("/test", nil) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + var result map[string]string + json.NewDecoder(resp.Body).Decode(&result) + + if result["x-custom-header"] != "custom-value" { + t.Errorf("Expected X-Custom-Header=custom-value, got %v", result["x-custom-header"]) + } + + if result["x-api-version"] != "v1" { + t.Errorf("Expected X-API-Version=v1, got %v", result["x-api-version"]) + } +}