Skip to content

Commit b33ab76

Browse files
Add HeaderFunc to allow modifying headers before every request (#298)
Adds a new HeaderFunc to the StartSettings that allows for dynamically editing the headers before each HTTP request made by the OpAMP library. Closes #297
1 parent 7cdd395 commit b33ab76

File tree

5 files changed

+120
-13
lines changed

5 files changed

+120
-13
lines changed

client/clientimpl_test.go

+73
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,79 @@ func TestConnectWithHeader(t *testing.T) {
340340
})
341341
}
342342

343+
func TestConnectWithHeaderFunc(t *testing.T) {
344+
testClients(t, func(t *testing.T, client OpAMPClient) {
345+
// Start a server.
346+
srv := internal.StartMockServer(t)
347+
var conn atomic.Value
348+
srv.OnConnect = func(r *http.Request) {
349+
authHdr := r.Header.Get("Authorization")
350+
assert.EqualValues(t, "Bearer 12345678", authHdr)
351+
userAgentHdr := r.Header.Get("User-Agent")
352+
assert.EqualValues(t, "custom-agent/1.0", userAgentHdr)
353+
conn.Store(true)
354+
}
355+
356+
hf := func(header http.Header) http.Header {
357+
header.Set("Authorization", "Bearer 12345678")
358+
header.Set("User-Agent", "custom-agent/1.0")
359+
return header
360+
}
361+
362+
// Start a client.
363+
settings := types.StartSettings{
364+
OpAMPServerURL: "ws://" + srv.Endpoint,
365+
HeaderFunc: hf,
366+
}
367+
startClient(t, settings, client)
368+
369+
// Wait for connection to be established.
370+
eventually(t, func() bool { return conn.Load() != nil })
371+
372+
// Shutdown the Server and the client.
373+
srv.Close()
374+
_ = client.Stop(context.Background())
375+
})
376+
}
377+
378+
func TestConnectWithHeaderAndHeaderFunc(t *testing.T) {
379+
testClients(t, func(t *testing.T, client OpAMPClient) {
380+
// Start a server.
381+
srv := internal.StartMockServer(t)
382+
var conn atomic.Value
383+
srv.OnConnect = func(r *http.Request) {
384+
authHdr := r.Header.Get("Authorization")
385+
assert.EqualValues(t, "Bearer 12345678", authHdr)
386+
userAgentHdr := r.Header.Get("User-Agent")
387+
assert.EqualValues(t, "custom-agent/1.0", userAgentHdr)
388+
conn.Store(true)
389+
}
390+
391+
baseHeader := http.Header{}
392+
baseHeader.Set("User-Agent", "custom-agent/1.0")
393+
394+
hf := func(header http.Header) http.Header {
395+
header.Set("Authorization", "Bearer 12345678")
396+
return header
397+
}
398+
399+
// Start a client.
400+
settings := types.StartSettings{
401+
OpAMPServerURL: "ws://" + srv.Endpoint,
402+
Header: baseHeader,
403+
HeaderFunc: hf,
404+
}
405+
startClient(t, settings, client)
406+
407+
// Wait for connection to be established.
408+
eventually(t, func() bool { return conn.Load() != nil })
409+
410+
// Shutdown the Server and the client.
411+
srv.Close()
412+
_ = client.Stop(context.Background())
413+
})
414+
}
415+
343416
func TestConnectWithTLS(t *testing.T) {
344417
testClients(t, func(t *testing.T, client OpAMPClient) {
345418
// Start a server.

client/httpclient.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (c *httpClient) Start(ctx context.Context, settings types.StartSettings) er
4444
c.opAMPServerURL = settings.OpAMPServerURL
4545

4646
// Prepare Server connection settings.
47-
c.sender.SetRequestHeader(settings.Header)
47+
c.sender.SetRequestHeader(settings.Header, settings.HeaderFunc)
4848

4949
// Add TLS configuration into httpClient
5050
c.sender.AddTLSConfig(settings.TLSConfig)

client/internal/httpsender.go

+24-9
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ type HTTPSender struct {
5959
compressionEnabled bool
6060

6161
// Headers to send with all requests.
62-
requestHeader http.Header
62+
getHeader func() http.Header
6363

6464
// Processor to handle received messages.
6565
receiveProcessor receivedProcessor
@@ -75,7 +75,7 @@ func NewHTTPSender(logger types.Logger) *HTTPSender {
7575
pollingIntervalMs: defaultPollingIntervalMs,
7676
}
7777
// initialize the headers with no additional headers
78-
h.SetRequestHeader(nil)
78+
h.SetRequestHeader(nil, nil)
7979
return h
8080
}
8181

@@ -121,12 +121,26 @@ func (h *HTTPSender) Run(
121121

122122
// SetRequestHeader sets additional HTTP headers to send with all future requests.
123123
// Should not be called concurrently with any other method.
124-
func (h *HTTPSender) SetRequestHeader(header http.Header) {
125-
if header == nil {
126-
header = http.Header{}
124+
func (h *HTTPSender) SetRequestHeader(baseHeaders http.Header, headerFunc func(http.Header) http.Header) {
125+
if baseHeaders == nil {
126+
baseHeaders = http.Header{}
127+
}
128+
129+
if headerFunc == nil {
130+
headerFunc = func(h http.Header) http.Header {
131+
return h
132+
}
133+
}
134+
135+
h.getHeader = func() http.Header {
136+
requestHeader := headerFunc(baseHeaders.Clone())
137+
requestHeader.Set(headerContentType, contentTypeProtobuf)
138+
if h.compressionEnabled {
139+
requestHeader.Set(headerContentEncoding, encodingTypeGZip)
140+
}
141+
142+
return requestHeader
127143
}
128-
h.requestHeader = header
129-
h.requestHeader.Set(headerContentType, contentTypeProtobuf)
130144
}
131145

132146
// makeOneRequestRoundtrip sends a request and receives a response.
@@ -255,7 +269,7 @@ func (h *HTTPSender) prepareRequest(ctx context.Context) (*requestWrapper, error
255269
return nil, err
256270
}
257271

258-
req.Header = h.requestHeader
272+
req.Header = h.getHeader()
259273
return &req, nil
260274
}
261275

@@ -295,9 +309,10 @@ func (h *HTTPSender) SetPollingInterval(duration time.Duration) {
295309
atomic.StoreInt64(&h.pollingIntervalMs, duration.Milliseconds())
296310
}
297311

312+
// EnableCompression enables compression for the sender.
313+
// Should not be called concurrently with Run.
298314
func (h *HTTPSender) EnableCompression() {
299315
h.compressionEnabled = true
300-
h.requestHeader.Set(headerContentEncoding, encodingTypeGZip)
301316
}
302317

303318
func (h *HTTPSender) AddTLSConfig(config *tls.Config) {

client/types/startsettings.go

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ type StartSettings struct {
1717
// Optional additional HTTP headers to send with all HTTP requests.
1818
Header http.Header
1919

20+
// Optional function that can be used to modify the HTTP headers
21+
// before each HTTP request.
22+
// Can modify and return the argument or return the argument without modifying.
23+
HeaderFunc func(http.Header) http.Header
24+
2025
// Optional TLS config for HTTP connection.
2126
TLSConfig *tls.Config
2227

client/wsclient.go

+17-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ type wsClient struct {
3131
url *url.URL
3232

3333
// HTTP request headers to use when connecting to OpAMP Server.
34-
requestHeader http.Header
34+
getHeader func() http.Header
3535

3636
// Websocket dialer and connection.
3737
dialer websocket.Dialer
@@ -86,7 +86,21 @@ func (c *wsClient) Start(ctx context.Context, settings types.StartSettings) erro
8686
}
8787
c.dialer.TLSClientConfig = settings.TLSConfig
8888

89-
c.requestHeader = settings.Header
89+
headerFunc := settings.HeaderFunc
90+
if headerFunc == nil {
91+
headerFunc = func(h http.Header) http.Header {
92+
return h
93+
}
94+
}
95+
96+
baseHeader := settings.Header
97+
if baseHeader == nil {
98+
baseHeader = http.Header{}
99+
}
100+
101+
c.getHeader = func() http.Header {
102+
return headerFunc(baseHeader.Clone())
103+
}
90104

91105
c.common.StartConnectAndRun(c.runUntilStopped)
92106

@@ -142,7 +156,7 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS
142156
// by the Server.
143157
func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) {
144158
var resp *http.Response
145-
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.requestHeader)
159+
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader())
146160
if err != nil {
147161
if c.common.Callbacks != nil && !c.common.IsStopping() {
148162
c.common.Callbacks.OnConnectFailed(ctx, err)

0 commit comments

Comments
 (0)