Skip to content

Commit 8d4ac56

Browse files
committed
Improve route headers middleware, add possibility to match using contain or regex, add tests
1 parent d7034fd commit 8d4ac56

File tree

2 files changed

+234
-35
lines changed

2 files changed

+234
-35
lines changed

middleware/route_headers.go

+114-35
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package middleware
22

33
import (
44
"net/http"
5+
"regexp"
56
"strings"
67
)
78

@@ -20,6 +21,20 @@ import (
2021
// r.Get("/", h)
2122
// rSubdomain.Get("/", h2)
2223
//
24+
// Another example, lets say you'd like to route through some middleware based on
25+
// presence of specific cookie and in request there are multiple cookies e.g.
26+
// "firstcookie=one; secondcookie=two; thirdcookie=three", then you might use
27+
// RouteHeadersContainsMatcher to be able to route this request:
28+
//
29+
// r := chi.NewRouter()
30+
// routeMiddleware := middleware.RouteHeaders().
31+
// SetMatcherType(middleware.RouteHeadersContainsMatcher).
32+
// Route("Cookie", "secondcookie", MyCustomMiddleware).
33+
// Handler
34+
//
35+
// r.Use(routeMiddleware)
36+
// r.Get("/", h)
37+
//
2338
// Another example, imagine you want to setup multiple CORS handlers, where for
2439
// your origin servers you allow authorized requests, but for third-party public
2540
// requests, authorization is disabled.
@@ -39,70 +54,114 @@ import (
3954
// AllowCredentials: false, // <----------<<< do not allow credentials
4055
// })).
4156
// Handler)
42-
func RouteHeaders() HeaderRouter {
43-
return HeaderRouter{}
57+
func RouteHeaders() *HeaderRouter {
58+
return &HeaderRouter{
59+
routes: map[string][]HeaderRoute{},
60+
matchingType: RouteHeadersClassicMatcher,
61+
}
4462
}
4563

46-
type HeaderRouter map[string][]HeaderRoute
64+
type MatcherType int
65+
66+
const (
67+
RouteHeadersClassicMatcher MatcherType = iota
68+
RouteHeadersContainsMatcher
69+
RouteHeadersRegexMatcher
70+
)
71+
72+
type HeaderRouter struct {
73+
routes map[string][]HeaderRoute
74+
matchingType MatcherType
75+
}
76+
77+
func (hr *HeaderRouter) SetMatchingType(matchingType MatcherType) *HeaderRouter {
78+
hr.matchingType = matchingType
79+
return hr
80+
}
4781

48-
func (hr HeaderRouter) Route(header, match string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
82+
func (hr *HeaderRouter) Route(
83+
header,
84+
match string,
85+
middlewareHandler func(next http.Handler) http.Handler,
86+
) *HeaderRouter {
4987
header = strings.ToLower(header)
50-
k := hr[header]
88+
89+
k := hr.routes[header]
5190
if k == nil {
52-
hr[header] = []HeaderRoute{}
91+
hr.routes[header] = []HeaderRoute{}
5392
}
54-
hr[header] = append(hr[header], HeaderRoute{MatchOne: NewPattern(match), Middleware: middlewareHandler})
93+
94+
hr.routes[header] = append(
95+
hr.routes[header],
96+
HeaderRoute{
97+
MatchOne: NewPattern(strings.ToLower(match), hr.matchingType),
98+
Middleware: middlewareHandler,
99+
},
100+
)
55101
return hr
56102
}
57103

58-
func (hr HeaderRouter) RouteAny(header string, match []string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
104+
func (hr *HeaderRouter) RouteAny(
105+
header string,
106+
match []string,
107+
middlewareHandler func(next http.Handler) http.Handler,
108+
) *HeaderRouter {
59109
header = strings.ToLower(header)
60-
k := hr[header]
110+
111+
k := hr.routes[header]
61112
if k == nil {
62-
hr[header] = []HeaderRoute{}
113+
hr.routes[header] = []HeaderRoute{}
63114
}
115+
64116
patterns := []Pattern{}
65117
for _, m := range match {
66-
patterns = append(patterns, NewPattern(m))
118+
patterns = append(patterns, NewPattern(m, hr.matchingType))
67119
}
68-
hr[header] = append(hr[header], HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler})
120+
121+
hr.routes[header] = append(
122+
hr.routes[header],
123+
HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler},
124+
)
125+
69126
return hr
70127
}
71128

72-
func (hr HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) HeaderRouter {
73-
hr["*"] = []HeaderRoute{{Middleware: handler}}
129+
func (hr *HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) *HeaderRouter {
130+
hr.routes["*"] = []HeaderRoute{{Middleware: handler}}
74131
return hr
75132
}
76133

77-
func (hr HeaderRouter) Handler(next http.Handler) http.Handler {
78-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
79-
if len(hr) == 0 {
134+
func (hr *HeaderRouter) Handler(next http.Handler) http.Handler {
135+
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
136+
if len(hr.routes) == 0 {
80137
// skip if no routes set
81-
next.ServeHTTP(w, r)
138+
next.ServeHTTP(wrt, req)
82139
}
83140

84141
// find first matching header route, and continue
85-
for header, matchers := range hr {
86-
headerValue := r.Header.Get(header)
142+
for header, matchers := range hr.routes {
143+
headerValue := req.Header.Get(header)
87144
if headerValue == "" {
88145
continue
89146
}
147+
90148
headerValue = strings.ToLower(headerValue)
91149
for _, matcher := range matchers {
92150
if matcher.IsMatch(headerValue) {
93-
matcher.Middleware(next).ServeHTTP(w, r)
151+
matcher.Middleware(next).ServeHTTP(wrt, req)
94152
return
95153
}
96154
}
97155
}
98156

99157
// if no match, check for "*" default route
100-
matcher, ok := hr["*"]
158+
matcher, ok := hr.routes["*"]
101159
if !ok || matcher[0].Middleware == nil {
102-
next.ServeHTTP(w, r)
160+
next.ServeHTTP(wrt, req)
103161
return
104162
}
105-
matcher[0].Middleware(next).ServeHTTP(w, r)
163+
164+
matcher[0].Middleware(next).ServeHTTP(wrt, req)
106165
})
107166
}
108167

@@ -126,20 +185,40 @@ func (r HeaderRoute) IsMatch(value string) bool {
126185
}
127186

128187
type Pattern struct {
129-
prefix string
130-
suffix string
131-
wildcard bool
188+
prefix string
189+
suffix string
190+
wildcard bool
191+
value string
192+
matchingType MatcherType
132193
}
133194

134-
func NewPattern(value string) Pattern {
135-
p := Pattern{}
136-
p.prefix, p.suffix, p.wildcard = strings.Cut(value, "*")
137-
return p
195+
func NewPattern(value string, matchingType MatcherType) Pattern {
196+
pat := Pattern{matchingType: matchingType}
197+
switch matchingType {
198+
case RouteHeadersClassicMatcher:
199+
pat.prefix, pat.suffix, pat.wildcard = strings.Cut(value, "*")
200+
case RouteHeadersContainsMatcher:
201+
pat.value = value
202+
case RouteHeadersRegexMatcher:
203+
pat.value = value
204+
}
205+
return pat
138206
}
139207

140-
func (p Pattern) Match(v string) bool {
141-
if !p.wildcard {
142-
return p.prefix == v
208+
func (p Pattern) Match(mVal string) bool {
209+
switch p.matchingType {
210+
case RouteHeadersClassicMatcher:
211+
if !p.wildcard {
212+
return p.prefix == mVal
213+
}
214+
return len(mVal) >= len(p.prefix+p.suffix) &&
215+
strings.HasPrefix(mVal, p.prefix) &&
216+
strings.HasSuffix(mVal, p.suffix)
217+
case RouteHeadersContainsMatcher:
218+
return strings.Contains(mVal, p.value)
219+
case RouteHeadersRegexMatcher:
220+
reg := regexp.MustCompile(p.value)
221+
return reg.MatchString(mVal)
143222
}
144-
return len(v) >= len(p.prefix+p.suffix) && strings.HasPrefix(v, p.prefix) && strings.HasSuffix(v, p.suffix)
223+
return false
145224
}

middleware/route_headers_test.go

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package middleware
2+
3+
import (
4+
"bytes"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/go-chi/chi/v5"
10+
)
11+
12+
func RouteHeadersDenyTestMiddleware() func(http.Handler) http.Handler {
13+
return func(_ http.Handler) http.Handler {
14+
return http.HandlerFunc(func(wrt http.ResponseWriter, _ *http.Request) {
15+
wrt.WriteHeader(http.StatusForbidden)
16+
})
17+
}
18+
}
19+
20+
func TestRouteHeadersMiddleware(t *testing.T) {
21+
t.Parallel()
22+
23+
tests := []struct {
24+
name string
25+
header string
26+
match string
27+
matchType MatcherType
28+
requestHeaders map[string]string
29+
want int
30+
}{
31+
{
32+
"TestClassicMatch",
33+
"Authorization",
34+
"Bearer *",
35+
RouteHeadersClassicMatcher,
36+
map[string]string{
37+
"Authorization": "Bearer whatever",
38+
"Other": "bera",
39+
},
40+
http.StatusForbidden,
41+
},
42+
{
43+
"TestContainsMatch",
44+
"Cookie",
45+
"kc-access=",
46+
RouteHeadersContainsMatcher,
47+
map[string]string{
48+
"Cookie": "some-cookie=tadadada; kc-access=mytoken",
49+
},
50+
http.StatusForbidden,
51+
},
52+
{
53+
"TestRegexMatch",
54+
"X-Custom-Header",
55+
".*mycustom[4-9]+.*",
56+
RouteHeadersRegexMatcher,
57+
map[string]string{
58+
"X-Custom-Header": "test1mycustom564other",
59+
},
60+
http.StatusForbidden,
61+
},
62+
{
63+
"TestMatchAndValueIsLowered",
64+
"Authorization",
65+
"Bearer *",
66+
RouteHeadersClassicMatcher,
67+
map[string]string{
68+
"Authorization": "bearer whatever",
69+
},
70+
http.StatusForbidden,
71+
},
72+
{
73+
"TestNotMatch",
74+
"Authorization",
75+
"Bearer *",
76+
RouteHeadersClassicMatcher,
77+
map[string]string{
78+
"Authorization": "Basic test",
79+
},
80+
http.StatusOK,
81+
},
82+
}
83+
84+
for _, test := range tests {
85+
test := test
86+
t.Run(test.name, func(t *testing.T) {
87+
t.Parallel()
88+
89+
recorder := httptest.NewRecorder()
90+
91+
headerRouterMiddleware := RouteHeaders().
92+
SetMatchingType(test.matchType).
93+
Route(
94+
test.header,
95+
test.match,
96+
RouteHeadersDenyTestMiddleware(),
97+
).
98+
Handler
99+
100+
router := chi.NewRouter()
101+
router.Use(headerRouterMiddleware)
102+
router.Get("/", func(_ http.ResponseWriter, _ *http.Request) {})
103+
104+
var body []byte
105+
req := httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(body))
106+
for hName, hVal := range test.requestHeaders {
107+
req.Header.Set(hName, hVal)
108+
}
109+
110+
router.ServeHTTP(recorder, req)
111+
res := recorder.Result()
112+
113+
res.Body.Close()
114+
115+
if res.StatusCode != test.want {
116+
t.Errorf("response is incorrect, got %d, want %d", recorder.Code, test.want)
117+
}
118+
})
119+
}
120+
}

0 commit comments

Comments
 (0)