diff --git a/intercept/forward_headers.go b/intercept/forward_headers.go new file mode 100644 index 0000000..8bdf4c8 --- /dev/null +++ b/intercept/forward_headers.go @@ -0,0 +1,91 @@ +package intercept + +import ( + "net" + "net/http" +) + +// hopByHopHeaders are connection-level headers that must not be forwarded to an upstream. +var hopByHopHeaders = []string{ + "Connection", + "Transfer-Encoding", + "Upgrade", + "Keep-Alive", + "Proxy-Authorization", + "Proxy-Connection", + "Te", + "Trailer", +} + +// managedHeaders are headers managed by the SDK or the provider configuration. +// They must not be forwarded from the client so that our own values always win. +var managedHeaders = []string{ + "Authorization", + "X-Api-Key", + "Host", + "Content-Length", + "Content-Type", +} + +// PrepareForwardHeaders clones clientHeaders and returns a sanitized copy suitable +// for forwarding to an upstream provider. +// +// It removes: +// - hop-by-hop headers (Connection, Transfer-Encoding, …) +// - provider-managed headers (Authorization, X-Api-Key, Host, Content-Length, Content-Type) +// - actor headers (anything whose canonical name starts with X-AI-Bridge-Actor) +// +// It then sets standard proxy headers: +// - X-Forwarded-For (appends the client IP extracted from remoteAddr) +// - X-Forwarded-Host (set to host) +// - X-Forwarded-Proto (https or http based on isTLS) +// - User-Agent (set to "aibridge" if not already present) +func PrepareForwardHeaders(clientHeaders http.Header, remoteAddr, host string, isTLS bool) http.Header { + if clientHeaders == nil { + clientHeaders = http.Header{} + } + headers := clientHeaders.Clone() + + // Remove hop-by-hop headers. + for _, h := range hopByHopHeaders { + headers.Del(h) + } + + // Remove provider-managed headers. + for _, h := range managedHeaders { + headers.Del(h) + } + + // Remove actor headers. Deletion during range over a map is safe in Go. + for k := range headers { + if IsActorHeader(k) { + headers.Del(k) + } + } + + // X-Forwarded-For: append client IP to any existing value. + clientIP, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + clientIP = remoteAddr + } + if prior := headers.Get("X-Forwarded-For"); prior != "" { + headers.Set("X-Forwarded-For", prior+", "+clientIP) + } else { + headers.Set("X-Forwarded-For", clientIP) + } + + headers.Set("X-Forwarded-Host", host) + + if isTLS { + headers.Set("X-Forwarded-Proto", "https") + } else { + headers.Set("X-Forwarded-Proto", "http") + } + + // Use a consistent User-Agent if the client did not set one. + if _, ok := headers["User-Agent"]; !ok { + headers.Set("User-Agent", "aibridge") + } + + return headers +} diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 6f4f01f..9d90afb 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -18,6 +18,7 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" aibconfig "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -40,6 +41,13 @@ type interceptionBase struct { cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock + // clientHeaders holds the sanitized headers from the original client request. + // They are forwarded to the upstream provider at lowest priority. + clientHeaders http.Header + remoteAddr string + host string + isTLS bool + tracer trace.Tracer logger slog.Logger @@ -163,6 +171,17 @@ func (i *interceptionBase) isSmallFastModel() bool { } func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { + // Client headers are forwarded at lowest priority: prepend them so that + // actor headers and provider auth (appended below) win on any conflict. + fwd := intercept.PrepareForwardHeaders(i.clientHeaders, i.remoteAddr, i.host, i.isTLS) + var clientOpts []option.RequestOption + for k, vals := range fwd { + for _, v := range vals { + clientOpts = append(clientOpts, option.WithHeaderAdd(k, v)) + } + } + opts = append(clientOpts, opts...) + opts = append(opts, option.WithAPIKey(i.cfg.Key)) opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 7ab2bed..845167a 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -28,14 +28,18 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, remoteAddr, host string, isTLS bool, tracer trace.Tracer) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - payload: payload, - cfg: cfg, - bedrockCfg: bedrockCfg, - tracer: tracer, + id: id, + req: req, + payload: payload, + cfg: cfg, + bedrockCfg: bedrockCfg, + clientHeaders: clientHeaders, + remoteAddr: remoteAddr, + host: host, + isTLS: isTLS, + tracer: tracer, }} } diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 4fc19fd..a91255e 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -34,14 +34,18 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, remoteAddr, host string, isTLS bool, tracer trace.Tracer) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - payload: payload, - cfg: cfg, - bedrockCfg: bedrockCfg, - tracer: tracer, + id: id, + req: req, + payload: payload, + cfg: cfg, + bedrockCfg: bedrockCfg, + clientHeaders: clientHeaders, + remoteAddr: remoteAddr, + host: host, + isTLS: isTLS, + tracer: tracer, }} } diff --git a/passthrough.go b/passthrough.go index 0dcef9c..dc0ca78 100644 --- a/passthrough.go +++ b/passthrough.go @@ -1,13 +1,13 @@ package aibridge import ( - "net" "net/http" "net/http/httputil" "net/url" "time" "cdr.dev/slog/v3" + "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" @@ -68,29 +68,8 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met req.Host = upURL.Host span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, req.URL.String())) - // Copy headers from client. - req.Header = r.Header.Clone() - - // Standard proxy headers. - host, _, herr := net.SplitHostPort(r.RemoteAddr) - if herr != nil { - host = r.RemoteAddr - } - if prior := req.Header.Get("X-Forwarded-For"); prior != "" { - req.Header.Set("X-Forwarded-For", prior+", "+host) - } else { - req.Header.Set("X-Forwarded-For", host) - } - req.Header.Set("X-Forwarded-Host", r.Host) - if r.TLS != nil { - req.Header.Set("X-Forwarded-Proto", "https") - } else { - req.Header.Set("X-Forwarded-Proto", "http") - } - // Avoid default Go user-agent if none provided. - if _, ok := req.Header["User-Agent"]; !ok { - req.Header.Set("User-Agent", "aibridge") // TODO: use build tag. - } + // Copy and sanitize headers from client, then set standard proxy headers. + req.Header = intercept.PrepareForwardHeaders(r.Header, r.RemoteAddr, r.Host, r.TLS != nil) // Inject provider auth. provider.InjectAuthHeader(&req.Header) diff --git a/provider/anthropic.go b/provider/anthropic.go index be12583..02896c5 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -102,9 +102,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr var interceptor intercept.Interceptor if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, tracer) + interceptor = messages.NewStreamingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, r.Header.Clone(), r.RemoteAddr, r.Host, r.TLS != nil, tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, tracer) + interceptor = messages.NewBlockingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, r.Header.Clone(), r.RemoteAddr, r.Host, r.TLS != nil, tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil