Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to match by substring or regex in headers middleware #989

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 114 additions & 35 deletions middleware/route_headers.go
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ package middleware

import (
"net/http"
"regexp"
"strings"
)

@@ -20,6 +21,20 @@ import (
// r.Get("/", h)
// rSubdomain.Get("/", h2)
//
// Another example, lets say you'd like to route through some middleware based on
// presence of specific cookie and in request there are multiple cookies e.g.
// "firstcookie=one; secondcookie=two; thirdcookie=three", then you might use
// RouteHeadersContainsMatcher to be able to route this request:
//
// r := chi.NewRouter()
// routeMiddleware := middleware.RouteHeaders().
// SetMatcherType(middleware.RouteHeadersContainsMatcher).
// Route("Cookie", "secondcookie", MyCustomMiddleware).
// Handler
//
// r.Use(routeMiddleware)
// r.Get("/", h)
//
// Another example, imagine you want to setup multiple CORS handlers, where for
// your origin servers you allow authorized requests, but for third-party public
// requests, authorization is disabled.
@@ -39,70 +54,114 @@ import (
// AllowCredentials: false, // <----------<<< do not allow credentials
// })).
// Handler)
func RouteHeaders() HeaderRouter {
return HeaderRouter{}
func RouteHeaders() *HeaderRouter {
return &HeaderRouter{
routes: map[string][]HeaderRoute{},
matchingType: RouteHeadersClassicMatcher,
}
}

type HeaderRouter map[string][]HeaderRoute
type MatcherType int

const (
RouteHeadersClassicMatcher MatcherType = iota
RouteHeadersContainsMatcher
RouteHeadersRegexMatcher
)

type HeaderRouter struct {
routes map[string][]HeaderRoute
matchingType MatcherType
}

func (hr *HeaderRouter) SetMatchingType(matchingType MatcherType) *HeaderRouter {
hr.matchingType = matchingType
return hr
}

func (hr HeaderRouter) Route(header, match string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
func (hr *HeaderRouter) Route(
header,
match string,
middlewareHandler func(next http.Handler) http.Handler,
) *HeaderRouter {
header = strings.ToLower(header)
k := hr[header]

k := hr.routes[header]
if k == nil {
hr[header] = []HeaderRoute{}
hr.routes[header] = []HeaderRoute{}
}
hr[header] = append(hr[header], HeaderRoute{MatchOne: NewPattern(match), Middleware: middlewareHandler})

hr.routes[header] = append(
hr.routes[header],
HeaderRoute{
MatchOne: NewPattern(strings.ToLower(match), hr.matchingType),
Middleware: middlewareHandler,
},
)
return hr
}

func (hr HeaderRouter) RouteAny(header string, match []string, middlewareHandler func(next http.Handler) http.Handler) HeaderRouter {
func (hr *HeaderRouter) RouteAny(
header string,
match []string,
middlewareHandler func(next http.Handler) http.Handler,
) *HeaderRouter {
header = strings.ToLower(header)
k := hr[header]

k := hr.routes[header]
if k == nil {
hr[header] = []HeaderRoute{}
hr.routes[header] = []HeaderRoute{}
}

patterns := []Pattern{}
for _, m := range match {
patterns = append(patterns, NewPattern(m))
patterns = append(patterns, NewPattern(m, hr.matchingType))
}
hr[header] = append(hr[header], HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler})

hr.routes[header] = append(
hr.routes[header],
HeaderRoute{MatchAny: patterns, Middleware: middlewareHandler},
)

return hr
}

func (hr HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) HeaderRouter {
hr["*"] = []HeaderRoute{{Middleware: handler}}
func (hr *HeaderRouter) RouteDefault(handler func(next http.Handler) http.Handler) *HeaderRouter {
hr.routes["*"] = []HeaderRoute{{Middleware: handler}}
return hr
}

func (hr HeaderRouter) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(hr) == 0 {
func (hr *HeaderRouter) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) {
if len(hr.routes) == 0 {
// skip if no routes set
next.ServeHTTP(w, r)
next.ServeHTTP(wrt, req)
}

// find first matching header route, and continue
for header, matchers := range hr {
headerValue := r.Header.Get(header)
for header, matchers := range hr.routes {
headerValue := req.Header.Get(header)
if headerValue == "" {
continue
}

headerValue = strings.ToLower(headerValue)
for _, matcher := range matchers {
if matcher.IsMatch(headerValue) {
matcher.Middleware(next).ServeHTTP(w, r)
matcher.Middleware(next).ServeHTTP(wrt, req)
return
}
}
}

// if no match, check for "*" default route
matcher, ok := hr["*"]
matcher, ok := hr.routes["*"]
if !ok || matcher[0].Middleware == nil {
next.ServeHTTP(w, r)
next.ServeHTTP(wrt, req)
return
}
matcher[0].Middleware(next).ServeHTTP(w, r)

matcher[0].Middleware(next).ServeHTTP(wrt, req)
})
}

@@ -126,20 +185,40 @@ func (r HeaderRoute) IsMatch(value string) bool {
}

type Pattern struct {
prefix string
suffix string
wildcard bool
prefix string
suffix string
wildcard bool
value string
matchingType MatcherType
}

func NewPattern(value string) Pattern {
p := Pattern{}
p.prefix, p.suffix, p.wildcard = strings.Cut(value, "*")
return p
func NewPattern(value string, matchingType MatcherType) Pattern {
pat := Pattern{matchingType: matchingType}
switch matchingType {
case RouteHeadersClassicMatcher:
pat.prefix, pat.suffix, pat.wildcard = strings.Cut(value, "*")
case RouteHeadersContainsMatcher:
pat.value = value
case RouteHeadersRegexMatcher:
pat.value = value
}
return pat
}

func (p Pattern) Match(v string) bool {
if !p.wildcard {
return p.prefix == v
func (p Pattern) Match(mVal string) bool {
switch p.matchingType {
case RouteHeadersClassicMatcher:
if !p.wildcard {
return p.prefix == mVal
}
return len(mVal) >= len(p.prefix+p.suffix) &&
strings.HasPrefix(mVal, p.prefix) &&
strings.HasSuffix(mVal, p.suffix)
case RouteHeadersContainsMatcher:
return strings.Contains(mVal, p.value)
case RouteHeadersRegexMatcher:
reg := regexp.MustCompile(p.value)
return reg.MatchString(mVal)
}
return len(v) >= len(p.prefix+p.suffix) && strings.HasPrefix(v, p.prefix) && strings.HasSuffix(v, p.suffix)
return false
}
120 changes: 120 additions & 0 deletions middleware/route_headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package middleware

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chi/chi/v5"
)

func RouteHeadersDenyTestMiddleware() func(http.Handler) http.Handler {
return func(_ http.Handler) http.Handler {
return http.HandlerFunc(func(wrt http.ResponseWriter, _ *http.Request) {
wrt.WriteHeader(http.StatusForbidden)
})
}
}

func TestRouteHeadersMiddleware(t *testing.T) {
t.Parallel()

tests := []struct {
name string
header string
match string
matchType MatcherType
requestHeaders map[string]string
want int
}{
{
"TestClassicMatch",
"Authorization",
"Bearer *",
RouteHeadersClassicMatcher,
map[string]string{
"Authorization": "Bearer whatever",
"Other": "bera",
},
http.StatusForbidden,
},
{
"TestContainsMatch",
"Cookie",
"kc-access=",
RouteHeadersContainsMatcher,
map[string]string{
"Cookie": "some-cookie=tadadada; kc-access=mytoken",
},
http.StatusForbidden,
},
{
"TestRegexMatch",
"X-Custom-Header",
".*mycustom[4-9]+.*",
RouteHeadersRegexMatcher,
map[string]string{
"X-Custom-Header": "test1mycustom564other",
},
http.StatusForbidden,
},
{
"TestMatchAndValueIsLowered",
"Authorization",
"Bearer *",
RouteHeadersClassicMatcher,
map[string]string{
"Authorization": "bearer whatever",
},
http.StatusForbidden,
},
{
"TestNotMatch",
"Authorization",
"Bearer *",
RouteHeadersClassicMatcher,
map[string]string{
"Authorization": "Basic test",
},
http.StatusOK,
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()

recorder := httptest.NewRecorder()

headerRouterMiddleware := RouteHeaders().
SetMatchingType(test.matchType).
Route(
test.header,
test.match,
RouteHeadersDenyTestMiddleware(),
).
Handler

router := chi.NewRouter()
router.Use(headerRouterMiddleware)
router.Get("/", func(_ http.ResponseWriter, _ *http.Request) {})

var body []byte
req := httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(body))
for hName, hVal := range test.requestHeaders {
req.Header.Set(hName, hVal)
}

router.ServeHTTP(recorder, req)
res := recorder.Result()

res.Body.Close()

if res.StatusCode != test.want {
t.Errorf("response is incorrect, got %d, want %d", recorder.Code, test.want)
}
})
}
}