Skip to content

Commit 75ad232

Browse files
authored
Implement transport.If() for conditional transports (e.g. for debugging) (#8)
* Implement transport.If() for conditional transports (e.g. for debugging) * PR feedback
1 parent a2be5e0 commit 75ad232

File tree

5 files changed

+118
-9
lines changed

5 files changed

+118
-9
lines changed

README.md

+7-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ There are multiple use-cases where this pattern comes handy such as request logg
77

88
## Examples
99

10-
Set up HTTP client, which sets `User-Agent`, `Authorization` and `TraceID` headers automatically :
10+
Set up HTTP client, which sets `User-Agent`, `Authorization` and `TraceID` headers automatically:
1111
```go
1212
authClient := http.Client{
1313
Transport: transport.Chain(
@@ -22,12 +22,12 @@ authClient := http.Client{
2222

2323
Or debug all outgoing requests globally within your application:
2424
```go
25-
if debugMode {
26-
http.DefaultTransport = transport.Chain(
27-
http.DefaultTransport,
28-
transport.LogRequests,
29-
)
30-
}
25+
debugMode := os.Getenv("DEBUG") == "true"
26+
27+
http.DefaultTransport = transport.Chain(
28+
http.DefaultTransport,
29+
transport.If(debugMode, transport.LogRequests),
30+
)
3131
```
3232

3333
# Authors

if.go

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package transport
2+
3+
import (
4+
"net/http"
5+
)
6+
7+
// If sets given transport if given condition is true. Otherwise it sets nil transport, which will be ignored.
8+
//
9+
// Example:
10+
//
11+
// http.DefaultTransport = transport.Chain(
12+
// http.DefaultTransport,
13+
// transport.If(debugMode, transport.LogRequests),
14+
// )
15+
func If(condition bool, transport func(http.RoundTripper) http.RoundTripper) func(http.RoundTripper) http.RoundTripper {
16+
if condition {
17+
return transport
18+
}
19+
20+
return nil
21+
}

if_test.go

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package transport_test
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
"time"
9+
10+
"github.com/go-chi/transport"
11+
)
12+
13+
func TestIfTrue(t *testing.T) {
14+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15+
if r.Header.Get("debug") != "true" {
16+
t.Error("expected debug=true")
17+
w.WriteHeader(500)
18+
return
19+
}
20+
21+
fmt.Fprintf(w, "ok")
22+
}))
23+
defer server.Close()
24+
25+
client := &http.Client{
26+
Timeout: 15 * time.Second,
27+
Transport: transport.Chain(
28+
http.DefaultTransport,
29+
transport.If(true, transport.SetHeader("debug", "true")), // Set header.
30+
),
31+
}
32+
33+
request, err := http.NewRequest("GET", server.URL, nil)
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
resp, err := client.Do(request)
38+
if err != nil {
39+
t.Fatal(err)
40+
}
41+
42+
if resp.StatusCode != 200 {
43+
t.Fatal("unexpected response")
44+
}
45+
}
46+
47+
func TestIfFalse(t *testing.T) {
48+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49+
if r.Header.Get("debug") != "" {
50+
t.Error("expected no debug header")
51+
w.WriteHeader(500)
52+
return
53+
}
54+
55+
fmt.Fprintf(w, "ok")
56+
}))
57+
defer server.Close()
58+
59+
client := &http.Client{
60+
Timeout: 15 * time.Second,
61+
Transport: transport.Chain(
62+
http.DefaultTransport,
63+
transport.If(false, transport.SetHeader("debug", "true")), // Do not set header.
64+
),
65+
}
66+
67+
request, err := http.NewRequest("GET", server.URL, nil)
68+
if err != nil {
69+
t.Fatal(err)
70+
}
71+
resp, err := client.Do(request)
72+
if err != nil {
73+
t.Fatal(err)
74+
}
75+
76+
if resp.StatusCode != 200 {
77+
t.Fatal("unexpected response")
78+
}
79+
}

transport.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,22 @@ func Chain(base http.RoundTripper, mw ...func(http.RoundTripper) http.RoundTripp
4545
base = http.DefaultTransport
4646
}
4747

48+
// Filter out nil transports.
49+
mws := []func(http.RoundTripper) http.RoundTripper{}
50+
for _, fn := range mw {
51+
if fn != nil {
52+
mws = append(mws, fn)
53+
}
54+
}
55+
4856
if c, ok := base.(*chain); ok {
49-
c.middlewares = append(c.middlewares, mw...)
57+
c.middlewares = append(c.middlewares, mws...)
5058
return c
5159
}
5260

5361
return &chain{
5462
baseTransport: base,
55-
middlewares: mw,
63+
middlewares: mws,
5664
}
5765
}
5866

transport_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ func TestChain(t *testing.T) {
1313
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1414
if r.Header.Get("User-Agent") != "transport-chain/v1.0.0" {
1515
w.WriteHeader(500)
16+
return
1617
}
1718

1819
fmt.Fprintf(w, expected)

0 commit comments

Comments
 (0)