@@ -2,6 +2,7 @@ package middleware
2
2
3
3
import (
4
4
"net/http"
5
+ "regexp"
5
6
"strings"
6
7
)
7
8
@@ -20,6 +21,20 @@ import (
20
21
// r.Get("/", h)
21
22
// rSubdomain.Get("/", h2)
22
23
//
24
+ // Another example, lets say you'd like to route through some middleware based on
25
+ // presence of specific cookie and in request there are multiple cookies e.g.
26
+ // "firstcookie=one; secondcookie=two; thirdcookie=three", then you might use
27
+ // RouteHeadersContainsMatcher to be able to route this request:
28
+ //
29
+ // r := chi.NewRouter()
30
+ // routeMiddleware := middleware.RouteHeaders().
31
+ // SetMatcherType(middleware.RouteHeadersContainsMatcher).
32
+ // Route("Cookie", "secondcookie", MyCustomMiddleware).
33
+ // Handler
34
+ //
35
+ // r.Use(routeMiddleware)
36
+ // r.Get("/", h)
37
+ //
23
38
// Another example, imagine you want to setup multiple CORS handlers, where for
24
39
// your origin servers you allow authorized requests, but for third-party public
25
40
// requests, authorization is disabled.
@@ -39,70 +54,114 @@ import (
39
54
// AllowCredentials: false, // <----------<<< do not allow credentials
40
55
// })).
41
56
// Handler)
42
- func RouteHeaders () HeaderRouter {
43
- return HeaderRouter {}
57
+ func RouteHeaders () * HeaderRouter {
58
+ return & HeaderRouter {
59
+ routes : map [string ][]HeaderRoute {},
60
+ matchingType : RouteHeadersClassicMatcher ,
61
+ }
44
62
}
45
63
46
- type HeaderRouter map [string ][]HeaderRoute
64
+ type MatcherType int
65
+
66
+ const (
67
+ RouteHeadersClassicMatcher MatcherType = iota
68
+ RouteHeadersContainsMatcher
69
+ RouteHeadersRegexMatcher
70
+ )
71
+
72
+ type HeaderRouter struct {
73
+ routes map [string ][]HeaderRoute
74
+ matchingType MatcherType
75
+ }
76
+
77
+ func (hr * HeaderRouter ) SetMatchingType (matchingType MatcherType ) * HeaderRouter {
78
+ hr .matchingType = matchingType
79
+ return hr
80
+ }
47
81
48
- func (hr HeaderRouter ) Route (header , match string , middlewareHandler func (next http.Handler ) http.Handler ) HeaderRouter {
82
+ func (hr * HeaderRouter ) Route (
83
+ header ,
84
+ match string ,
85
+ middlewareHandler func (next http.Handler ) http.Handler ,
86
+ ) * HeaderRouter {
49
87
header = strings .ToLower (header )
50
- k := hr [header ]
88
+
89
+ k := hr .routes [header ]
51
90
if k == nil {
52
- hr [header ] = []HeaderRoute {}
91
+ hr . routes [header ] = []HeaderRoute {}
53
92
}
54
- hr [header ] = append (hr [header ], HeaderRoute {MatchOne : NewPattern (match ), Middleware : middlewareHandler })
93
+
94
+ hr .routes [header ] = append (
95
+ hr .routes [header ],
96
+ HeaderRoute {
97
+ MatchOne : NewPattern (strings .ToLower (match ), hr .matchingType ),
98
+ Middleware : middlewareHandler ,
99
+ },
100
+ )
55
101
return hr
56
102
}
57
103
58
- func (hr HeaderRouter ) RouteAny (header string , match []string , middlewareHandler func (next http.Handler ) http.Handler ) HeaderRouter {
104
+ func (hr * HeaderRouter ) RouteAny (
105
+ header string ,
106
+ match []string ,
107
+ middlewareHandler func (next http.Handler ) http.Handler ,
108
+ ) * HeaderRouter {
59
109
header = strings .ToLower (header )
60
- k := hr [header ]
110
+
111
+ k := hr .routes [header ]
61
112
if k == nil {
62
- hr [header ] = []HeaderRoute {}
113
+ hr . routes [header ] = []HeaderRoute {}
63
114
}
115
+
64
116
patterns := []Pattern {}
65
117
for _ , m := range match {
66
- patterns = append (patterns , NewPattern (m ))
118
+ patterns = append (patterns , NewPattern (m , hr . matchingType ))
67
119
}
68
- hr [header ] = append (hr [header ], HeaderRoute {MatchAny : patterns , Middleware : middlewareHandler })
120
+
121
+ hr .routes [header ] = append (
122
+ hr .routes [header ],
123
+ HeaderRoute {MatchAny : patterns , Middleware : middlewareHandler },
124
+ )
125
+
69
126
return hr
70
127
}
71
128
72
- func (hr HeaderRouter ) RouteDefault (handler func (next http.Handler ) http.Handler ) HeaderRouter {
73
- hr ["*" ] = []HeaderRoute {{Middleware : handler }}
129
+ func (hr * HeaderRouter ) RouteDefault (handler func (next http.Handler ) http.Handler ) * HeaderRouter {
130
+ hr . routes ["*" ] = []HeaderRoute {{Middleware : handler }}
74
131
return hr
75
132
}
76
133
77
- func (hr HeaderRouter ) Handler (next http.Handler ) http.Handler {
78
- return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
79
- if len (hr ) == 0 {
134
+ func (hr * HeaderRouter ) Handler (next http.Handler ) http.Handler {
135
+ return http .HandlerFunc (func (wrt http.ResponseWriter , req * http.Request ) {
136
+ if len (hr . routes ) == 0 {
80
137
// skip if no routes set
81
- next .ServeHTTP (w , r )
138
+ next .ServeHTTP (wrt , req )
82
139
}
83
140
84
141
// find first matching header route, and continue
85
- for header , matchers := range hr {
86
- headerValue := r .Header .Get (header )
142
+ for header , matchers := range hr . routes {
143
+ headerValue := req .Header .Get (header )
87
144
if headerValue == "" {
88
145
continue
89
146
}
147
+
90
148
headerValue = strings .ToLower (headerValue )
91
149
for _ , matcher := range matchers {
92
150
if matcher .IsMatch (headerValue ) {
93
- matcher .Middleware (next ).ServeHTTP (w , r )
151
+ matcher .Middleware (next ).ServeHTTP (wrt , req )
94
152
return
95
153
}
96
154
}
97
155
}
98
156
99
157
// if no match, check for "*" default route
100
- matcher , ok := hr ["*" ]
158
+ matcher , ok := hr . routes ["*" ]
101
159
if ! ok || matcher [0 ].Middleware == nil {
102
- next .ServeHTTP (w , r )
160
+ next .ServeHTTP (wrt , req )
103
161
return
104
162
}
105
- matcher [0 ].Middleware (next ).ServeHTTP (w , r )
163
+
164
+ matcher [0 ].Middleware (next ).ServeHTTP (wrt , req )
106
165
})
107
166
}
108
167
@@ -126,20 +185,40 @@ func (r HeaderRoute) IsMatch(value string) bool {
126
185
}
127
186
128
187
type Pattern struct {
129
- prefix string
130
- suffix string
131
- wildcard bool
188
+ prefix string
189
+ suffix string
190
+ wildcard bool
191
+ value string
192
+ matchingType MatcherType
132
193
}
133
194
134
- func NewPattern (value string ) Pattern {
135
- p := Pattern {}
136
- p .prefix , p .suffix , p .wildcard = strings .Cut (value , "*" )
137
- return p
195
+ func NewPattern (value string , matchingType MatcherType ) Pattern {
196
+ pat := Pattern {matchingType : matchingType }
197
+ switch matchingType {
198
+ case RouteHeadersClassicMatcher :
199
+ pat .prefix , pat .suffix , pat .wildcard = strings .Cut (value , "*" )
200
+ case RouteHeadersContainsMatcher :
201
+ pat .value = value
202
+ case RouteHeadersRegexMatcher :
203
+ pat .value = value
204
+ }
205
+ return pat
138
206
}
139
207
140
- func (p Pattern ) Match (v string ) bool {
141
- if ! p .wildcard {
142
- return p .prefix == v
208
+ func (p Pattern ) Match (mVal string ) bool {
209
+ switch p .matchingType {
210
+ case RouteHeadersClassicMatcher :
211
+ if ! p .wildcard {
212
+ return p .prefix == mVal
213
+ }
214
+ return len (mVal ) >= len (p .prefix + p .suffix ) &&
215
+ strings .HasPrefix (mVal , p .prefix ) &&
216
+ strings .HasSuffix (mVal , p .suffix )
217
+ case RouteHeadersContainsMatcher :
218
+ return strings .Contains (mVal , p .value )
219
+ case RouteHeadersRegexMatcher :
220
+ reg := regexp .MustCompile (p .value )
221
+ return reg .MatchString (mVal )
143
222
}
144
- return len ( v ) >= len ( p . prefix + p . suffix ) && strings . HasPrefix ( v , p . prefix ) && strings . HasSuffix ( v , p . suffix )
223
+ return false
145
224
}
0 commit comments