Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions intercept/forward_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package intercept

import (
"net"
"net/http"
)

// hopByHopHeaders are connection-level headers that must not be forwarded to an upstream.
var hopByHopHeaders = []string{
"Connection",
"Transfer-Encoding",
"Upgrade",
"Keep-Alive",
"Proxy-Authorization",
"Proxy-Connection",
"Te",
"Trailer",
}

// managedHeaders are headers managed by the SDK or the provider configuration.
// They must not be forwarded from the client so that our own values always win.
var managedHeaders = []string{
"Authorization",
"X-Api-Key",
"Host",
"Content-Length",
"Content-Type",
}

// PrepareForwardHeaders clones clientHeaders and returns a sanitized copy suitable
// for forwarding to an upstream provider.
//
// It removes:
// - hop-by-hop headers (Connection, Transfer-Encoding, …)
// - provider-managed headers (Authorization, X-Api-Key, Host, Content-Length, Content-Type)
// - actor headers (anything whose canonical name starts with X-AI-Bridge-Actor)
//
// It then sets standard proxy headers:
// - X-Forwarded-For (appends the client IP extracted from remoteAddr)
// - X-Forwarded-Host (set to host)
// - X-Forwarded-Proto (https or http based on isTLS)
// - User-Agent (set to "aibridge" if not already present)
func PrepareForwardHeaders(clientHeaders http.Header, remoteAddr, host string, isTLS bool) http.Header {
if clientHeaders == nil {
clientHeaders = http.Header{}
}
headers := clientHeaders.Clone()

// Remove hop-by-hop headers.
for _, h := range hopByHopHeaders {
headers.Del(h)
}

// Remove provider-managed headers.
for _, h := range managedHeaders {
headers.Del(h)
}

// Remove actor headers. Deletion during range over a map is safe in Go.
for k := range headers {
if IsActorHeader(k) {
headers.Del(k)
}
}

// X-Forwarded-For: append client IP to any existing value.
clientIP, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
clientIP = remoteAddr
}
if prior := headers.Get("X-Forwarded-For"); prior != "" {
headers.Set("X-Forwarded-For", prior+", "+clientIP)
} else {
headers.Set("X-Forwarded-For", clientIP)
}

headers.Set("X-Forwarded-Host", host)

if isTLS {
headers.Set("X-Forwarded-Proto", "https")
} else {
headers.Set("X-Forwarded-Proto", "http")
}

// Use a consistent User-Agent if the client did not set one.
if _, ok := headers["User-Agent"]; !ok {
headers.Set("User-Agent", "aibridge")
}

return headers
}
19 changes: 19 additions & 0 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials"
aibconfig "github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/intercept/apidump"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/recorder"
Expand All @@ -40,6 +41,13 @@ type interceptionBase struct {
cfg aibconfig.Anthropic
bedrockCfg *aibconfig.AWSBedrock

// clientHeaders holds the sanitized headers from the original client request.
// They are forwarded to the upstream provider at lowest priority.
clientHeaders http.Header
remoteAddr string
host string
isTLS bool

tracer trace.Tracer
logger slog.Logger

Expand Down Expand Up @@ -163,6 +171,17 @@ func (i *interceptionBase) isSmallFastModel() bool {
}

func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) {
// Client headers are forwarded at lowest priority: prepend them so that
// actor headers and provider auth (appended below) win on any conflict.
fwd := intercept.PrepareForwardHeaders(i.clientHeaders, i.remoteAddr, i.host, i.isTLS)
var clientOpts []option.RequestOption
for k, vals := range fwd {
for _, v := range vals {
clientOpts = append(clientOpts, option.WithHeaderAdd(k, v))
}
}
opts = append(clientOpts, opts...)

opts = append(opts, option.WithAPIKey(i.cfg.Key))
opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))

Expand Down
18 changes: 11 additions & 7 deletions intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,18 @@ type BlockingInterception struct {
interceptionBase
}

func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception {
func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, remoteAddr, host string, isTLS bool, tracer trace.Tracer) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
req: req,
payload: payload,
cfg: cfg,
bedrockCfg: bedrockCfg,
tracer: tracer,
id: id,
req: req,
payload: payload,
cfg: cfg,
bedrockCfg: bedrockCfg,
clientHeaders: clientHeaders,
remoteAddr: remoteAddr,
host: host,
isTLS: isTLS,
tracer: tracer,
}}
}

Expand Down
18 changes: 11 additions & 7 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,18 @@ type StreamingInterception struct {
interceptionBase
}

func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception {
func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, remoteAddr, host string, isTLS bool, tracer trace.Tracer) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
req: req,
payload: payload,
cfg: cfg,
bedrockCfg: bedrockCfg,
tracer: tracer,
id: id,
req: req,
payload: payload,
cfg: cfg,
bedrockCfg: bedrockCfg,
clientHeaders: clientHeaders,
remoteAddr: remoteAddr,
host: host,
isTLS: isTLS,
tracer: tracer,
}}
}

Expand Down
27 changes: 3 additions & 24 deletions passthrough.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package aibridge

import (
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"

"cdr.dev/slog/v3"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/metrics"
"github.com/coder/aibridge/provider"
"github.com/coder/aibridge/tracing"
Expand Down Expand Up @@ -68,29 +68,8 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met
req.Host = upURL.Host
span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, req.URL.String()))

// Copy headers from client.
req.Header = r.Header.Clone()

// Standard proxy headers.
host, _, herr := net.SplitHostPort(r.RemoteAddr)
if herr != nil {
host = r.RemoteAddr
}
if prior := req.Header.Get("X-Forwarded-For"); prior != "" {
req.Header.Set("X-Forwarded-For", prior+", "+host)
} else {
req.Header.Set("X-Forwarded-For", host)
}
req.Header.Set("X-Forwarded-Host", r.Host)
if r.TLS != nil {
req.Header.Set("X-Forwarded-Proto", "https")
} else {
req.Header.Set("X-Forwarded-Proto", "http")
}
// Avoid default Go user-agent if none provided.
if _, ok := req.Header["User-Agent"]; !ok {
req.Header.Set("User-Agent", "aibridge") // TODO: use build tag.
}
// Copy and sanitize headers from client, then set standard proxy headers.
req.Header = intercept.PrepareForwardHeaders(r.Header, r.RemoteAddr, r.Host, r.TLS != nil)

// Inject provider auth.
provider.InjectAuthHeader(&req.Header)
Expand Down
4 changes: 2 additions & 2 deletions provider/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr

var interceptor intercept.Interceptor
if req.Stream {
interceptor = messages.NewStreamingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, tracer)
interceptor = messages.NewStreamingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, r.Header.Clone(), r.RemoteAddr, r.Host, r.TLS != nil, tracer)
} else {
interceptor = messages.NewBlockingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, tracer)
interceptor = messages.NewBlockingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, r.Header.Clone(), r.RemoteAddr, r.Host, r.TLS != nil, tracer)
}
span.SetAttributes(interceptor.TraceAttributes(r)...)
return interceptor, nil
Expand Down