Skip to content

Commit 82a964c

Browse files
hakankutluayerhanakp
andauthoredFeb 1, 2023
Add context timeout middleware (#2380)
Add context timeout middleware Co-authored-by: Erhan Akpınar <erhan.akpinar@yemeksepeti.com> Co-authored-by: @erhanakp
1 parent 08093a4 commit 82a964c

File tree

2 files changed

+298
-0
lines changed

2 files changed

+298
-0
lines changed
 

‎middleware/context_timeout.go

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"errors"
6+
"time"
7+
8+
"github.com/labstack/echo/v4"
9+
)
10+
11+
// ContextTimeoutConfig defines the config for ContextTimeout middleware.
12+
type ContextTimeoutConfig struct {
13+
// Skipper defines a function to skip middleware.
14+
Skipper Skipper
15+
16+
// ErrorHandler is a function when error aries in middeware execution.
17+
ErrorHandler func(err error, c echo.Context) error
18+
19+
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
20+
Timeout time.Duration
21+
}
22+
23+
// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client
24+
// when underlying method returns context.DeadlineExceeded error.
25+
func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc {
26+
return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout})
27+
}
28+
29+
// ContextTimeoutWithConfig returns a Timeout middleware with config.
30+
func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc {
31+
mw, err := config.ToMiddleware()
32+
if err != nil {
33+
panic(err)
34+
}
35+
return mw
36+
}
37+
38+
// ToMiddleware converts Config to middleware.
39+
func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
40+
if config.Timeout == 0 {
41+
return nil, errors.New("timeout must be set")
42+
}
43+
if config.Skipper == nil {
44+
config.Skipper = DefaultSkipper
45+
}
46+
if config.ErrorHandler == nil {
47+
config.ErrorHandler = func(err error, c echo.Context) error {
48+
if err != nil && errors.Is(err, context.DeadlineExceeded) {
49+
return echo.ErrServiceUnavailable.WithInternal(err)
50+
}
51+
return err
52+
}
53+
}
54+
55+
return func(next echo.HandlerFunc) echo.HandlerFunc {
56+
return func(c echo.Context) error {
57+
if config.Skipper(c) {
58+
return next(c)
59+
}
60+
61+
timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
62+
defer cancel()
63+
64+
c.SetRequest(c.Request().WithContext(timeoutContext))
65+
66+
if err := next(c); err != nil {
67+
return config.ErrorHandler(err, c)
68+
}
69+
return nil
70+
}
71+
}, nil
72+
}

‎middleware/context_timeout_test.go

+226
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net/http"
7+
"net/http/httptest"
8+
"net/url"
9+
"strings"
10+
"testing"
11+
"time"
12+
13+
"github.com/labstack/echo/v4"
14+
"github.com/stretchr/testify/assert"
15+
)
16+
17+
func TestContextTimeoutSkipper(t *testing.T) {
18+
t.Parallel()
19+
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
20+
Skipper: func(context echo.Context) bool {
21+
return true
22+
},
23+
Timeout: 10 * time.Millisecond,
24+
})
25+
26+
req := httptest.NewRequest(http.MethodGet, "/", nil)
27+
rec := httptest.NewRecorder()
28+
29+
e := echo.New()
30+
c := e.NewContext(req, rec)
31+
32+
err := m(func(c echo.Context) error {
33+
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
34+
return err
35+
}
36+
37+
return errors.New("response from handler")
38+
})(c)
39+
40+
// if not skipped we would have not returned error due context timeout logic
41+
assert.EqualError(t, err, "response from handler")
42+
}
43+
44+
func TestContextTimeoutWithTimeout0(t *testing.T) {
45+
t.Parallel()
46+
assert.Panics(t, func() {
47+
ContextTimeout(time.Duration(0))
48+
})
49+
}
50+
51+
func TestContextTimeoutErrorOutInHandler(t *testing.T) {
52+
t.Parallel()
53+
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
54+
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
55+
Timeout: 10 * time.Millisecond,
56+
})
57+
58+
req := httptest.NewRequest(http.MethodGet, "/", nil)
59+
rec := httptest.NewRecorder()
60+
61+
e := echo.New()
62+
c := e.NewContext(req, rec)
63+
64+
rec.Code = 1 // we want to be sure that even 200 will not be sent
65+
err := m(func(c echo.Context) error {
66+
// this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
67+
// to handle returned error and this can be done only then handler has not yet committed (written status code)
68+
// the response.
69+
return echo.NewHTTPError(http.StatusTeapot, "err")
70+
})(c)
71+
72+
assert.Error(t, err)
73+
assert.EqualError(t, err, "code=418, message=err")
74+
assert.Equal(t, 1, rec.Code)
75+
assert.Equal(t, "", rec.Body.String())
76+
}
77+
78+
func TestContextTimeoutSuccessfulRequest(t *testing.T) {
79+
t.Parallel()
80+
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
81+
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
82+
Timeout: 10 * time.Millisecond,
83+
})
84+
85+
req := httptest.NewRequest(http.MethodGet, "/", nil)
86+
rec := httptest.NewRecorder()
87+
88+
e := echo.New()
89+
c := e.NewContext(req, rec)
90+
91+
err := m(func(c echo.Context) error {
92+
return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
93+
})(c)
94+
95+
assert.NoError(t, err)
96+
assert.Equal(t, http.StatusCreated, rec.Code)
97+
assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
98+
}
99+
100+
func TestContextTimeoutTestRequestClone(t *testing.T) {
101+
t.Parallel()
102+
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
103+
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
104+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
105+
rec := httptest.NewRecorder()
106+
107+
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
108+
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
109+
Timeout: 1 * time.Second,
110+
})
111+
112+
e := echo.New()
113+
c := e.NewContext(req, rec)
114+
115+
err := m(func(c echo.Context) error {
116+
// Cookie test
117+
cookie, err := c.Request().Cookie("cookie")
118+
if assert.NoError(t, err) {
119+
assert.EqualValues(t, "cookie", cookie.Name)
120+
assert.EqualValues(t, "value", cookie.Value)
121+
}
122+
123+
// Form values
124+
if assert.NoError(t, c.Request().ParseForm()) {
125+
assert.EqualValues(t, "value", c.Request().FormValue("form"))
126+
}
127+
128+
// Query string
129+
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
130+
return nil
131+
})(c)
132+
133+
assert.NoError(t, err)
134+
}
135+
136+
func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
137+
t.Parallel()
138+
139+
timeout := 10 * time.Millisecond
140+
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
141+
Timeout: timeout,
142+
})
143+
144+
req := httptest.NewRequest(http.MethodGet, "/", nil)
145+
rec := httptest.NewRecorder()
146+
147+
e := echo.New()
148+
c := e.NewContext(req, rec)
149+
150+
err := m(func(c echo.Context) error {
151+
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
152+
return err
153+
}
154+
return c.String(http.StatusOK, "Hello, World!")
155+
})(c)
156+
157+
assert.IsType(t, &echo.HTTPError{}, err)
158+
assert.Error(t, err)
159+
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
160+
assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
161+
}
162+
163+
func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
164+
t.Parallel()
165+
166+
timeoutErrorHandler := func(err error, c echo.Context) error {
167+
if err != nil {
168+
if errors.Is(err, context.DeadlineExceeded) {
169+
return &echo.HTTPError{
170+
Code: http.StatusServiceUnavailable,
171+
Message: "Timeout! change me",
172+
}
173+
}
174+
return err
175+
}
176+
return nil
177+
}
178+
179+
timeout := 10 * time.Millisecond
180+
m := ContextTimeoutWithConfig(ContextTimeoutConfig{
181+
Timeout: timeout,
182+
ErrorHandler: timeoutErrorHandler,
183+
})
184+
185+
req := httptest.NewRequest(http.MethodGet, "/", nil)
186+
rec := httptest.NewRecorder()
187+
188+
e := echo.New()
189+
c := e.NewContext(req, rec)
190+
191+
err := m(func(c echo.Context) error {
192+
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
193+
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
194+
// difference over 500microseconds (0.5millisecond) response seems to be reliable
195+
196+
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
197+
return err
198+
}
199+
200+
// The Request Context should have a Deadline set by http.ContextTimeoutHandler
201+
if _, ok := c.Request().Context().Deadline(); !ok {
202+
assert.Fail(t, "No timeout set on Request Context")
203+
}
204+
return c.String(http.StatusOK, "Hello, World!")
205+
})(c)
206+
207+
assert.IsType(t, &echo.HTTPError{}, err)
208+
assert.Error(t, err)
209+
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
210+
assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message)
211+
}
212+
213+
func sleepWithContext(ctx context.Context, d time.Duration) error {
214+
timer := time.NewTimer(d)
215+
216+
defer func() {
217+
_ = timer.Stop()
218+
}()
219+
220+
select {
221+
case <-ctx.Done():
222+
return context.DeadlineExceeded
223+
case <-timer.C:
224+
return nil
225+
}
226+
}

0 commit comments

Comments
 (0)
Please sign in to comment.