diff --git a/mux.go b/mux.go index eaff897a..b40ab0a5 100644 --- a/mux.go +++ b/mux.go @@ -97,6 +97,10 @@ type routeConf struct { // if true, the the http.Request context will not contain the router omitRouterFromContext bool + // If true, only ampersands (not semicolons) act as separators for + // query-parameter pairs. + strictQueryParamSep bool + // Manager for the variables from host and path. regexp routeRegexpGroup @@ -295,6 +299,24 @@ func (r *Router) OmitRouterFromContext(value bool) *Router { return r } +// StrictQueryParamSep defines which characters act as separators for +// query-parameter pairs. The initial value is false, but beware: a future +// version of this library will adopt true for the initial value. +// +// When true, only ampersands act as separators for query parameters. +// This behavior complies with [the URL Living Standard]. +// +// When false, both ampersands and semicolons act as separators for +// query-parameter pairs. This contravenes the URL Living Standard and causes +// interoperability problems that can lead to [security vulnerabilities]. +// +// [security vulnerabilities]: https://github.com/gorilla/mux/issues/781 +// [the URL Living Standard]: https://url.spec.whatwg.org/#urlencoded-parsing +func (r *Router) StrictQueryParamSep(value bool) *Router { + r.strictQueryParamSep = value + return r +} + // UseEncodedPath tells the router to match the encoded original path // to the routes. // For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to". diff --git a/mux_test.go b/mux_test.go index bac758bc..601070ab 100644 --- a/mux_test.go +++ b/mux_test.go @@ -12,10 +12,12 @@ import ( "fmt" "io" "log" + "maps" "net/http" "net/http/httptest" "net/url" "reflect" + "strconv" "strings" "testing" "time" @@ -2067,6 +2069,54 @@ func TestSkipClean(t *testing.T) { } } +func TestStrictQueryParamSep(t *testing.T) { + cases := []struct { + b bool + pairs []string + want map[string]string + }{ + { + b: false, + pairs: []string{"foo", "{foo}", "bar", "{bar}", "baz", "{baz}"}, + want: map[string]string{ + "foo": "foo", + "bar": "bar", + "baz": "baz", + }, + }, { + b: true, + pairs: []string{"foo", "{foo}", "baz", "{baz}"}, + want: map[string]string{ + "foo": "foo;bar=bar", + "baz": "baz", + }, + }, + } + for _, tc := range cases { + f := func(t *testing.T) { + var got map[string]string + handle := func(_ http.ResponseWriter, r *http.Request) { + got = Vars(r) + } + r := NewRouter() + if r.strictQueryParamSep { + t.Error("strickQueryParamSep should be false by default") + } + r.StrictQueryParamSep(tc.b) + r.HandleFunc("/", handle).Queries(tc.pairs...) + + req := httptest.NewRequest("GET", "http://localhost/?foo=foo;bar=bar&baz=baz", nil) + res := NewRecorder() + r.ServeHTTP(res, req) + + if !maps.Equal(got, tc.want) { + t.Errorf("unexpected query params: got %v; want %v", got, tc.want) + } + } + t.Run(strconv.FormatBool(tc.b), f) + } +} + // https://plus.google.com/101022900381697718949/posts/eWy6DjFJ6uW func TestSubrouterHeader(t *testing.T) { expected := "func1 response" diff --git a/regexp.go b/regexp.go index e0bcff6a..33f775f9 100644 --- a/regexp.go +++ b/regexp.go @@ -5,7 +5,6 @@ package mux import ( - "bytes" "fmt" "net/http" "net/url" @@ -15,8 +14,9 @@ import ( ) type routeRegexpOptions struct { - strictSlash bool - useEncodedPath bool + strictSlash bool + useEncodedPath bool + strictQueryParamSep bool } type regexpType int @@ -245,7 +245,8 @@ func (r *routeRegexp) getURLQuery(req *http.Request) string { return "" } templateKey := strings.SplitN(r.template, "=", 2)[0] - val, ok := findFirstQueryKey(req.URL.RawQuery, templateKey) + strict := r.options.strictQueryParamSep + val, ok := findFirstQueryKey(req.URL.RawQuery, templateKey, strict) if ok { return templateKey + "=" + val } @@ -254,34 +255,32 @@ func (r *routeRegexp) getURLQuery(req *http.Request) string { // findFirstQueryKey returns the same result as (*url.URL).Query()[key][0]. // If key was not found, empty string and false is returned. -func findFirstQueryKey(rawQuery, key string) (value string, ok bool) { - query := []byte(rawQuery) - for len(query) > 0 { - foundKey := query - if i := bytes.IndexAny(foundKey, "&;"); i >= 0 { - foundKey, query = foundKey[:i], foundKey[i+1:] +func findFirstQueryKey(rawQuery, key string, strict bool) (value string, ok bool) { + for len(rawQuery) > 0 { + foundKey := rawQuery + if strict { + foundKey, rawQuery, _ = strings.Cut(foundKey, "&") + } else if i := strings.IndexAny(foundKey, "&;"); i >= 0 { + foundKey, rawQuery = foundKey[:i], foundKey[i+1:] } else { - query = query[:0] + rawQuery = rawQuery[:0] } if len(foundKey) == 0 { continue } - var value []byte - if i := bytes.IndexByte(foundKey, '='); i >= 0 { - foundKey, value = foundKey[:i], foundKey[i+1:] - } + foundKey, value, _ := strings.Cut(foundKey, "=") if len(foundKey) < len(key) { // Cannot possibly be key. continue } - keyString, err := url.QueryUnescape(string(foundKey)) + keyString, err := url.QueryUnescape(foundKey) if err != nil { continue } if keyString != key { continue } - valueString, err := url.QueryUnescape(string(value)) + valueString, err := url.QueryUnescape(value) if err != nil { continue } diff --git a/regexp_test.go b/regexp_test.go index f7fce81c..9ac96385 100644 --- a/regexp_test.go +++ b/regexp_test.go @@ -52,7 +52,7 @@ func Test_findFirstQueryKey(t *testing.T) { all, _ := url.ParseQuery(query) for key, want := range all { t.Run(key, func(t *testing.T) { - got, ok := findFirstQueryKey(query, key) + got, ok := findFirstQueryKey(query, key, false) if !ok { t.Error("Did not get expected key", key) } @@ -81,7 +81,7 @@ func Benchmark_findQueryKey(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { for key := range all { - _, _ = findFirstQueryKey(query, key) + _, _ = findFirstQueryKey(query, key, false) } } }) diff --git a/route.go b/route.go index d10401e9..c13e3db6 100644 --- a/route.go +++ b/route.go @@ -257,8 +257,9 @@ func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error { } } rr, err := newRouteRegexp(tpl, typ, routeRegexpOptions{ - strictSlash: r.strictSlash, - useEncodedPath: r.useEncodedPath, + strictSlash: r.strictSlash, + useEncodedPath: r.useEncodedPath, + strictQueryParamSep: r.strictQueryParamSep, }) if err != nil { return err