Skip to content

Commit 79fa918

Browse files
author
Peter Kieltyka
committed
Ensure a route context is available
1 parent 3031602 commit 79fa918

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
.idea
22
*.sw?
3+
.vscode

Diff for: mux.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,10 @@ func (tr treeRouter) ServeHTTPC(ctx context.Context, w http.ResponseWriter, r *h
314314
// Grab the root context object
315315
rctx, _ := ctx.(*Context)
316316
if rctx == nil {
317-
rctx = ctx.Value(routeCtxKey).(*Context)
317+
rctx, _ = ctx.Value(routeCtxKey).(*Context)
318+
if rctx == nil {
319+
panic("chi: route context is required.")
320+
}
318321
}
319322

320323
// The request path

Diff for: mux_test.go

+51
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,57 @@ import (
1515
"golang.org/x/net/context"
1616
)
1717

18+
func TestMuxServeHTTP(t *testing.T) {
19+
r := NewRouter()
20+
r.Get("/hi", func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
21+
w.Write([]byte("bye"))
22+
})
23+
r.NotFound(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
24+
w.WriteHeader(404)
25+
w.Write([]byte("nothing here"))
26+
})
27+
28+
// Thanks to https://github.com/mrcpvn for the nice table test code
29+
testcases := []struct {
30+
Method string
31+
Path string
32+
ExpectedStatus int
33+
ExpectedBody string
34+
}{
35+
{
36+
Method: "GET",
37+
Path: "/hi",
38+
ExpectedStatus: 200,
39+
ExpectedBody: "bye",
40+
},
41+
{
42+
Method: "GET",
43+
Path: "/hello",
44+
ExpectedStatus: 404,
45+
ExpectedBody: "nothing here",
46+
},
47+
}
48+
49+
for _, tc := range testcases {
50+
resp := httptest.NewRecorder()
51+
req, err := http.NewRequest(tc.Method, tc.Path, nil)
52+
if err != nil {
53+
t.Fatalf("%v", err)
54+
}
55+
r.ServeHTTP(resp, req)
56+
b, err := ioutil.ReadAll(resp.Body)
57+
if err != nil {
58+
t.Fatalf("%v", err)
59+
}
60+
if resp.Code != tc.ExpectedStatus {
61+
t.Fatalf("%v != %v", tc.ExpectedStatus, resp.Code)
62+
}
63+
if string(b) != tc.ExpectedBody {
64+
t.Fatalf("%s != %s", tc.ExpectedBody, b)
65+
}
66+
}
67+
}
68+
1869
func TestMux(t *testing.T) {
1970
var count uint64
2071
countermw := func(next http.Handler) http.Handler {

0 commit comments

Comments
 (0)