Илья Глазунов 881028c1e6 feat: Add reverse proxy functionality with enhanced routing capabilities
- Introduced IgnoreRequestPath option in proxy configuration to allow exact match routing.
- Implemented proxy_pass directive in routing extension to handle backend requests.
- Enhanced error handling for backend unavailability and timeouts.
- Added integration tests for reverse proxy, including basic requests, exact match routes, regex routes, header forwarding, and query string preservation.
- Created helper functions for setting up test servers and backends, along with assertion utilities for response validation.
- Updated server initialization to support extension management and middleware chaining.
- Improved logging for debugging purposes during request handling.
2025-12-12 00:38:30 +03:00

334 lines
7.6 KiB
Go

// Package proxy provides reverse proxy functionality for Konduktor.
package proxy
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/konduktor/konduktor/internal/logging"
)
// Config describes how a ReverseProxy reaches its backend and how requests
// are rewritten on the way there.
type Config struct {
// Target is the backend server URL. New rejects an empty value.
Target string
// Timeout is the request timeout (default: 30s). New applies the default
// when the field is zero.
Timeout time.Duration
// Headers are additional headers to set on every proxied request. Values may
// contain {name} placeholders, which are filled from the route parameters
// passed to ProxyRequest, and the literal $remote_addr.
Headers map[string]string
// StripPrefix removes this prefix from the request path before the path is
// appended to the target. The result always starts with "/".
StripPrefix string
// PreserveHost keeps the original Host header; when false the Host of the
// target URL is sent instead.
PreserveHost bool
// IgnoreRequestPath ignores the request path and uses only the target path.
// This is useful for exact match routes where target URL should be used as-is.
IgnoreRequestPath bool
}
// ReverseProxy forwards incoming HTTP requests to the single backend named in
// its Config and relays the backend's response to the client. Instances are
// created with New.
type ReverseProxy struct {
// config is the caller-supplied configuration; it is stored by pointer, not copied.
config *Config
// targetURL is config.Target parsed once in New.
targetURL *url.URL
// httpClient performs the backend round trips; New configures it not to
// follow redirects.
httpClient *http.Client
// logger receives error reports; methods check it for nil before use.
logger *logging.Logger
}
// New builds a ReverseProxy for the backend described by cfg.
//
// cfg must be non-nil and cfg.Target must be an absolute http or https URL
// with a host. Such targets are rejected here rather than on the first proxied
// request, where they would only surface as a 502. A non-positive cfg.Timeout
// falls back to 30 seconds. The timeout bounds both the wait for response
// headers and the complete round trip, including reading the body.
//
// Redirect responses from the backend are not followed; they are relayed to
// the client as-is. logger may be nil, in which case nothing is logged.
func New(cfg *Config, logger *logging.Logger) (*ReverseProxy, error) {
	if cfg == nil {
		return nil, fmt.Errorf("proxy config is required")
	}
	if cfg.Target == "" {
		return nil, fmt.Errorf("proxy target is required")
	}

	targetURL, err := url.Parse(cfg.Target)
	if err != nil {
		return nil, fmt.Errorf("invalid proxy target URL: %w", err)
	}
	// url.Parse happily accepts inputs such as "localhost:8080" (scheme
	// "localhost", no host), which the HTTP client cannot dial.
	if targetURL.Scheme != "http" && targetURL.Scheme != "https" {
		return nil, fmt.Errorf("invalid proxy target URL %q: scheme must be http or https", cfg.Target)
	}
	if targetURL.Host == "" {
		return nil, fmt.Errorf("invalid proxy target URL %q: missing host", cfg.Target)
	}

	timeout := cfg.Timeout
	if timeout <= 0 {
		timeout = 30 * time.Second
	}

	transport := &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		DialContext: (&net.Dialer{
			Timeout:   10 * time.Second,
			KeepAlive: 30 * time.Second,
		}).DialContext,
		MaxIdleConns:          100,
		MaxIdleConnsPerHost:   10,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		ResponseHeaderTimeout: timeout,
	}

	return &ReverseProxy{
		config:    cfg,
		targetURL: targetURL,
		httpClient: &http.Client{
			Transport: transport,
			Timeout:   timeout,
			CheckRedirect: func(req *http.Request, via []*http.Request) error {
				return http.ErrUseLastResponse // Don't follow redirects
			},
		},
		logger: logger,
	}, nil
}
// ServeHTTP makes ReverseProxy usable as an http.Handler. It proxies the
// request with no route parameters.
func (rp *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var noParams map[string]string
	rp.ProxyRequest(w, r, noParams)
}
// ProxyRequest forwards r to the backend and relays the backend's response
// to w. params holds route parameters that are substituted into the
// configured custom headers; it may be nil.
//
// A failure to build the outbound request is answered with 500. A failed
// round trip is delegated to handleProxyError.
func (rp *ReverseProxy) ProxyRequest(w http.ResponseWriter, r *http.Request, params map[string]string) {
	upstreamURL := rp.buildTargetURL(r)

	// The outbound request inherits the inbound context.
	outbound, err := rp.createProxyRequest(r.Context(), r, upstreamURL)
	if err != nil {
		rp.handleError(w, http.StatusInternalServerError, "Failed to create proxy request", err)
		return
	}

	// Forwarding headers and configured headers go on after the copy.
	rp.addCustomHeaders(outbound, r, params)

	backendResp, err := rp.httpClient.Do(outbound)
	if err != nil {
		rp.handleProxyError(w, err)
		return
	}
	defer backendResp.Body.Close()

	rp.copyResponse(w, backendResp)
}
// buildTargetURL derives the backend URL for r from the configured target.
//
// The query string is the target's own query followed by the request's query,
// joined with "&" when both are present. A query configured on the target is
// therefore kept instead of being overwritten by the request.
//
// With IgnoreRequestPath set, the target path is used unchanged. Otherwise the
// request path, minus StripPrefix if configured, is appended to the target
// path. The stripped path always starts with "/".
//
// The returned URL is a fresh copy; the proxy's stored target is not modified.
func (rp *ReverseProxy) buildTargetURL(r *http.Request) *url.URL {
	targetURL := *rp.targetURL

	// Merge rather than replace, so that e.g. an API key configured on the
	// target survives.
	switch {
	case targetURL.RawQuery == "":
		targetURL.RawQuery = r.URL.RawQuery
	case r.URL.RawQuery != "":
		targetURL.RawQuery += "&" + r.URL.RawQuery
	}

	if rp.config.IgnoreRequestPath {
		return &targetURL
	}

	path := r.URL.Path
	if prefix := rp.config.StripPrefix; prefix != "" {
		path = strings.TrimPrefix(path, prefix)
		// Stripping may leave "" or a path without its leading slash.
		if !strings.HasPrefix(path, "/") {
			path = "/" + path
		}
	}

	if base := rp.targetURL.Path; base != "" && base != "/" {
		// Avoid a double slash where the target path ends with "/".
		targetURL.Path = strings.TrimSuffix(base, "/") + path
	} else {
		targetURL.Path = path
	}
	return &targetURL
}
// createProxyRequest builds the outbound request sent to the backend.
//
// It reuses the inbound method, body and ContentLength, copies every inbound
// header, and then removes hop-by-hop headers. The Host is the inbound Host
// when PreserveHost is set, otherwise the host of targetURL. An error is
// returned only when the request cannot be constructed.
func (rp *ReverseProxy) createProxyRequest(ctx context.Context, r *http.Request, targetURL *url.URL) (*http.Request, error) {
	outbound, err := http.NewRequestWithContext(ctx, r.Method, targetURL.String(), r.Body)
	if err != nil {
		return nil, err
	}
	outbound.ContentLength = r.ContentLength

	for name, vals := range r.Header {
		for i := range vals {
			outbound.Header.Add(name, vals[i])
		}
	}

	// Default to the backend's host; keep the client's only when asked to.
	outbound.Host = targetURL.Host
	if rp.config.PreserveHost {
		outbound.Host = r.Host
	}

	removeHopByHopHeaders(outbound.Header)
	return outbound, nil
}
// addCustomHeaders sets the X-Forwarded-* headers and the headers from the
// configuration on proxyReq.
//
// X-Forwarded-For becomes every X-Forwarded-For value of the original
// request, in order, followed by the client IP. X-Forwarded-Proto and
// X-Forwarded-Host describe the original request.
//
// In configured header values, each {name} placeholder is replaced with
// params[name], and $remote_addr with the client IP alone, never the
// forwarded chain. Substitution is a single pass, so text inserted from a
// route parameter is not scanned for further placeholders. params may be nil.
func (rp *ReverseProxy) addCustomHeaders(proxyReq *http.Request, originalReq *http.Request, params map[string]string) {
	clientIP := getClientIP(originalReq)

	// Keep the chain separate from clientIP, which $remote_addr needs below.
	forwardedFor := clientIP
	if prior := strings.Join(originalReq.Header.Values("X-Forwarded-For"), ", "); prior != "" {
		forwardedFor = prior + ", " + clientIP
	}
	proxyReq.Header.Set("X-Forwarded-For", forwardedFor)
	proxyReq.Header.Set("X-Forwarded-Proto", getScheme(originalReq))
	proxyReq.Header.Set("X-Forwarded-Host", originalReq.Host)

	if len(rp.config.Headers) == 0 {
		return
	}

	// One replacer serves every configured header.
	pairs := make([]string, 0, 2*len(params)+2)
	for name, value := range params {
		pairs = append(pairs, "{"+name+"}", value)
	}
	pairs = append(pairs, "$remote_addr", clientIP)
	substitute := strings.NewReplacer(pairs...)

	for key, value := range rp.config.Headers {
		proxyReq.Header.Set(key, substitute.Replace(value))
	}
}
// copyResponse relays the backend response to the client: headers first
// (minus hop-by-hop headers), then the status code, then the body as a stream.
func (rp *ReverseProxy) copyResponse(w http.ResponseWriter, resp *http.Response) {
	out := w.Header()
	for name, vals := range resp.Header {
		for i := range vals {
			out.Add(name, vals[i])
		}
	}
	removeHopByHopHeaders(out)

	w.WriteHeader(resp.StatusCode)

	// The status line is already written, so a copy failure can no longer be
	// reported to the client; the result is deliberately discarded.
	_, _ = io.Copy(w, resp.Body)
}
// handleError logs err under message, if a logger is configured, and answers
// the client with message as the body and status as the HTTP status code.
func (rp *ReverseProxy) handleError(w http.ResponseWriter, status int, message string, err error) {
	if lg := rp.logger; lg != nil {
		lg.Error(message, "error", err)
	}
	http.Error(w, message, status)
}
// handleProxyError translates a failed backend round trip into a client
// response. The error is logged when a logger is configured.
//
//   - timeouts (any error in the chain reporting Timeout) -> 504 Gateway Timeout
//   - connection-level failures                           -> 502 Bad Gateway
//   - cancelled request context (client went away)        -> nothing is written
//   - anything else                                       -> 502 Bad Gateway
//
// The error chain is inspected with errors.As and errors.Is because
// http.Client wraps the underlying cause in *url.Error, which a direct
// comparison with context.Canceled would never match.
func (rp *ReverseProxy) handleProxyError(w http.ResponseWriter, err error) {
	if rp.logger != nil {
		rp.logger.Error("Proxy request failed", "error", err)
	}

	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		http.Error(w, "504 Gateway Timeout", http.StatusGatewayTimeout)
		return
	}

	if isConnectionError(err) {
		http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
		return
	}

	// The client disconnected; there is nobody left to answer.
	if errors.Is(err, context.Canceled) {
		return
	}

	http.Error(w, "502 Bad Gateway", http.StatusBadGateway)
}
// Helper functions
// singleJoiningSlash concatenates a and b so that exactly one "/" separates
// them, whether or not a ends with a slash or b starts with one.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")

	if endsWithSlash && startsWithSlash {
		// Both sides supply a slash: drop the one from b.
		return a + strings.TrimPrefix(b, "/")
	}
	if !endsWithSlash && !startsWithSlash {
		// Neither side supplies a slash: add it.
		return a + "/" + b
	}
	return a + b
}
// removeHopByHopHeaders deletes from h every header that applies only to a
// single transport-level connection and must not be forwarded by a proxy.
//
// This covers the fixed set below and any header named as a token in a
// Connection header (for example "Connection: close, X-Internal"), as
// required by RFC 9110 section 7.6.1. h is modified in place.
func removeHopByHopHeaders(h http.Header) {
	// Read the Connection tokens before Connection itself is deleted below.
	for _, field := range h.Values("Connection") {
		for _, token := range strings.Split(field, ",") {
			if token = strings.TrimSpace(token); token != "" {
				h.Del(token)
			}
		}
	}

	for _, name := range [...]string{
		"Connection",
		"Proxy-Connection",
		"Keep-Alive",
		"Proxy-Authenticate",
		"Proxy-Authorization",
		"Te",
		"Trailer",
		"Transfer-Encoding",
		"Upgrade",
	} {
		h.Del(name)
	}
}
// getClientIP reports the address of the client that sent r.
//
// A non-empty X-Real-IP header wins and is returned verbatim. Otherwise the
// host part of r.RemoteAddr is used, or RemoteAddr unchanged when it is not
// in host:port form.
//
// NOTE(review): X-Real-IP is taken from the request as-is; presumably this
// is only trustworthy behind a proxy that sets it. Confirm the deployment.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Real-IP"); forwarded != "" {
		return forwarded
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// getScheme reports the scheme the client used for r: "https" when the
// connection carries TLS state, otherwise the X-Forwarded-Proto header if it
// is set, otherwise "http".
func getScheme(r *http.Request) string {
	switch forwarded := r.Header.Get("X-Forwarded-Proto"); {
	case r.TLS != nil:
		return "https"
	case forwarded != "":
		return forwarded
	default:
		return "http"
	}
}
// isConnectionError reports whether err's message contains, ignoring case,
// one of a fixed set of connection-failure phrases. A nil error is never a
// connection error.
func isConnectionError(err error) bool {
	if err == nil {
		return false
	}

	// Lower-case once; every marker below is already lower-case.
	msg := strings.ToLower(err.Error())
	for _, marker := range [...]string{
		"connection refused",
		"no such host",
		"network is unreachable",
		"connection reset",
		"broken pipe",
	} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}