Skip to content

Commit 0f80a44

Browse files
committed
Resolves conflicts(gorilla#515)
2 parents dde8a3e + d07530f commit 0f80a44

13 files changed

+370
-174
lines changed

.circleci/config.yml

+43-60
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,70 @@
1-
version: 2.0
1+
version: 2.1
22

33
jobs:
4-
# Base test configuration for Go library tests Each distinct version should
5-
# inherit this base, and override (at least) the container image used.
6-
"test": &test
4+
"test":
5+
parameters:
6+
version:
7+
type: string
8+
default: "latest"
9+
golint:
10+
type: boolean
11+
default: true
12+
modules:
13+
type: boolean
14+
default: true
15+
goproxy:
16+
type: string
17+
default: ""
718
docker:
8-
- image: circleci/golang:latest
19+
- image: "circleci/golang:<< parameters.version >>"
920
working_directory: /go/src/github.com/gorilla/mux
10-
steps: &steps
11-
# Our build steps: we checkout the repo, fetch our deps, lint, and finally
12-
# run "go test" on the package.
21+
environment:
22+
GO111MODULE: "on"
23+
GOPROXY: "<< parameters.goproxy >>"
24+
steps:
1325
- checkout
14-
# Logs the version in our build logs, for posterity
15-
- run: go version
26+
- run:
27+
name: "Print the Go version"
28+
command: >
29+
go version
1630
- run:
1731
name: "Fetch dependencies"
1832
command: >
19-
go get -t -v ./...
33+
if [[ << parameters.modules >> = true ]]; then
34+
go mod download
35+
export GO111MODULE=on
36+
else
37+
go get -v ./...
38+
fi
2039
# Only run gofmt, vet & lint against the latest Go version
2140
- run:
2241
name: "Run golint"
2342
command: >
24-
if [ "${LATEST}" = true ] && [ -z "${SKIP_GOLINT}" ]; then
43+
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
2544
go get -u golang.org/x/lint/golint
2645
golint ./...
2746
fi
2847
- run:
2948
name: "Run gofmt"
3049
command: >
31-
if [[ "${LATEST}" = true ]]; then
50+
if [[ << parameters.version >> = "latest" ]]; then
3251
diff -u <(echo -n) <(gofmt -d -e .)
3352
fi
3453
- run:
3554
name: "Run go vet"
36-
command: >
37-
if [[ "${LATEST}" = true ]]; then
55+
command: >
56+
if [[ << parameters.version >> = "latest" ]]; then
3857
go vet -v ./...
3958
fi
40-
- run: go test -v -race ./...
41-
42-
"latest":
43-
<<: *test
44-
environment:
45-
LATEST: true
46-
47-
"1.12":
48-
<<: *test
49-
docker:
50-
- image: circleci/golang:1.12
51-
52-
"1.11":
53-
<<: *test
54-
docker:
55-
- image: circleci/golang:1.11
56-
57-
"1.10":
58-
<<: *test
59-
docker:
60-
- image: circleci/golang:1.10
61-
62-
"1.9":
63-
<<: *test
64-
docker:
65-
- image: circleci/golang:1.9
66-
67-
"1.8":
68-
<<: *test
69-
docker:
70-
- image: circleci/golang:1.8
71-
72-
"1.7":
73-
<<: *test
74-
docker:
75-
- image: circleci/golang:1.7
59+
- run:
60+
name: "Run go test (+ race detector)"
61+
command: >
62+
go test -v -race ./...
7663
7764
workflows:
78-
version: 2
79-
build:
65+
tests:
8066
jobs:
81-
- "latest"
82-
- "1.12"
83-
- "1.11"
84-
- "1.10"
85-
- "1.9"
86-
- "1.8"
87-
- "1.7"
67+
- test:
68+
matrix:
69+
parameters:
70+
version: ["latest", "1.15", "1.14", "1.13", "1.12", "1.11"]

context.go

-18
This file was deleted.

context_test.go

-30
This file was deleted.

middleware.go

+10-15
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,17 @@ func CORSMethodMiddleware(r *Router) MiddlewareFunc {
5858
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
5959
var allMethods []string
6060

61-
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
62-
for _, m := range route.matchers {
63-
if _, ok := m.(*routeRegexp); ok {
64-
if m.Match(req, &RouteMatch{}) {
65-
methods, err := route.GetMethods()
66-
if err != nil {
67-
return err
68-
}
69-
70-
allMethods = append(allMethods, methods...)
71-
}
72-
break
61+
for _, route := range r.routes {
62+
var match RouteMatch
63+
if route.Match(req, &match) || match.MatchErr == ErrMethodMismatch {
64+
methods, err := route.GetMethods()
65+
if err != nil {
66+
return nil, err
7367
}
68+
69+
allMethods = append(allMethods, methods...)
7470
}
75-
return nil
76-
})
71+
}
7772

78-
return allMethods, err
73+
return allMethods, nil
7974
}

middleware_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,26 @@ func TestCORSMethodMiddleware(t *testing.T) {
478478
}
479479
}
480480

481+
func TestCORSMethodMiddlewareSubrouter(t *testing.T) {
482+
router := NewRouter().StrictSlash(true)
483+
484+
subrouter := router.PathPrefix("/test").Subrouter()
485+
subrouter.HandleFunc("/hello", stringHandler("a")).Methods(http.MethodGet, http.MethodOptions, http.MethodPost)
486+
subrouter.HandleFunc("/hello/{name}", stringHandler("b")).Methods(http.MethodGet, http.MethodOptions)
487+
488+
subrouter.Use(CORSMethodMiddleware(subrouter))
489+
490+
rw := NewRecorder()
491+
req := newRequest("GET", "/test/hello/asdf")
492+
router.ServeHTTP(rw, req)
493+
494+
actualMethods := rw.Header().Get("Access-Control-Allow-Methods")
495+
expectedMethods := "GET,OPTIONS"
496+
if actualMethods != expectedMethods {
497+
t.Fatalf("expected methods %q but got: %q", expectedMethods, actualMethods)
498+
}
499+
}
500+
481501
func TestMiddlewareOnMultiSubrouter(t *testing.T) {
482502
first := "first"
483503
second := "second"

mux.go

+13-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package mux
66

77
import (
8+
"context"
89
"errors"
910
"fmt"
1011
"net/http"
@@ -58,8 +59,7 @@ type Router struct {
5859

5960
// If true, do not clear the request context after handling the request.
6061
//
61-
// Deprecated: No effect when go1.7+ is used, since the context is stored
62-
// on the request itself.
62+
// Deprecated: No effect, since the context is stored on the request itself.
6363
KeepContext bool
6464

6565
// Slice of middlewares to be called after a match is found
@@ -195,8 +195,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
195195
var handler http.Handler
196196
if r.Match(req, &match) {
197197
handler = match.Handler
198-
req = setVars(req, match.Vars)
199-
req = setCurrentRoute(req, match.Route)
198+
req = requestWithVars(req, match.Vars)
199+
req = requestWithRoute(req, match.Route)
200200
}
201201

202202
if handler == nil && match.MatchErr == ErrMethodMismatch {
@@ -426,7 +426,7 @@ const (
426426

427427
// Vars returns the route variables for the current request, if any.
428428
func Vars(r *http.Request) map[string]string {
429-
if rv := contextGet(r, varsKey); rv != nil {
429+
if rv := r.Context().Value(varsKey); rv != nil {
430430
return rv.(map[string]string)
431431
}
432432
return nil
@@ -435,21 +435,22 @@ func Vars(r *http.Request) map[string]string {
435435
// CurrentRoute returns the matched route for the current request, if any.
436436
// This only works when called inside the handler of the matched route
437437
// because the matched route is stored in the request context which is cleared
438-
// after the handler returns, unless the KeepContext option is set on the
439-
// Router.
438+
// after the handler returns.
440439
func CurrentRoute(r *http.Request) *Route {
441-
if rv := contextGet(r, routeKey); rv != nil {
440+
if rv := r.Context().Value(routeKey); rv != nil {
442441
return rv.(*Route)
443442
}
444443
return nil
445444
}
446445

447-
func setVars(r *http.Request, val interface{}) *http.Request {
448-
return contextSet(r, varsKey, val)
446+
func requestWithVars(r *http.Request, vars map[string]string) *http.Request {
447+
ctx := context.WithValue(r.Context(), varsKey, vars)
448+
return r.WithContext(ctx)
449449
}
450450

451-
func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
452-
return contextSet(r, routeKey, val)
451+
func requestWithRoute(r *http.Request, route *Route) *http.Request {
452+
ctx := context.WithValue(r.Context(), routeKey, route)
453+
return r.WithContext(ctx)
453454
}
454455

455456
// ----------------------------------------------------------------------------

mux_httpserver_test.go

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// +build go1.9
2+
3+
package mux
4+
5+
import (
6+
"bytes"
7+
"io/ioutil"
8+
"net/http"
9+
"net/http/httptest"
10+
"testing"
11+
)
12+
13+
func TestSchemeMatchers(t *testing.T) {
14+
router := NewRouter()
15+
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
16+
rw.Write([]byte("hello http world"))
17+
}).Schemes("http")
18+
router.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) {
19+
rw.Write([]byte("hello https world"))
20+
}).Schemes("https")
21+
22+
assertResponseBody := func(t *testing.T, s *httptest.Server, expectedBody string) {
23+
resp, err := s.Client().Get(s.URL)
24+
if err != nil {
25+
t.Fatalf("unexpected error getting from server: %v", err)
26+
}
27+
if resp.StatusCode != 200 {
28+
t.Fatalf("expected a status code of 200, got %v", resp.StatusCode)
29+
}
30+
body, err := ioutil.ReadAll(resp.Body)
31+
if err != nil {
32+
t.Fatalf("unexpected error reading body: %v", err)
33+
}
34+
if !bytes.Equal(body, []byte(expectedBody)) {
35+
t.Fatalf("response should be hello world, was: %q", string(body))
36+
}
37+
}
38+
39+
t.Run("httpServer", func(t *testing.T) {
40+
s := httptest.NewServer(router)
41+
defer s.Close()
42+
assertResponseBody(t, s, "hello http world")
43+
})
44+
t.Run("httpsServer", func(t *testing.T) {
45+
s := httptest.NewTLSServer(router)
46+
defer s.Close()
47+
assertResponseBody(t, s, "hello https world")
48+
})
49+
}

0 commit comments

Comments
 (0)