Skip to content

Commit 0ce7302

Browse files
authored
[suggestion] Add helper interface for ProxyBalancer interface (#2316)
* [suggestion] Add helper interface for ProxyBalancer interface * Update proxy_test.go * addressed code review comments * address pr comments * clean up * return error
1 parent 8f2bf82 commit 0ce7302

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

middleware/proxy.go

+16-1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ type (
7272
Next(echo.Context) *ProxyTarget
7373
}
7474

75+
// TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target.
76+
TargetProvider interface {
77+
NextTarget(echo.Context) (*ProxyTarget, error)
78+
}
79+
7580
commonBalancer struct {
7681
targets []*ProxyTarget
7782
mutex sync.RWMutex
@@ -223,6 +228,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
223228
}
224229
}
225230

231+
provider, isTargetProvider := config.Balancer.(TargetProvider)
226232
return func(next echo.HandlerFunc) echo.HandlerFunc {
227233
return func(c echo.Context) (err error) {
228234
if config.Skipper(c) {
@@ -231,7 +237,16 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
231237

232238
req := c.Request()
233239
res := c.Response()
234-
tgt := config.Balancer.Next(c)
240+
241+
var tgt *ProxyTarget
242+
if isTargetProvider {
243+
tgt, err = provider.NextTarget(c)
244+
if err != nil {
245+
return err
246+
}
247+
} else {
248+
tgt = config.Balancer.Next(c)
249+
}
235250
c.Set(config.ContextKey, tgt)
236251

237252
if err := rewriteURL(config.RegexRewrite, req); err != nil {

middleware/proxy_test.go

+50-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
"github.com/stretchr/testify/assert"
1919
)
2020

21-
//Assert expected with url.EscapedPath method to obtain the path.
21+
// Assert expected with url.EscapedPath method to obtain the path.
2222
func TestProxy(t *testing.T) {
2323
// Setup
2424
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -31,7 +31,6 @@ func TestProxy(t *testing.T) {
3131
}))
3232
defer t2.Close()
3333
url2, _ := url.Parse(t2.URL)
34-
3534
targets := []*ProxyTarget{
3635
{
3736
Name: "target 1",
@@ -122,6 +121,55 @@ func TestProxy(t *testing.T) {
122121
e.ServeHTTP(rec, req)
123122
}
124123

124+
type testProvider struct {
125+
*commonBalancer
126+
target *ProxyTarget
127+
err error
128+
}
129+
130+
func (p *testProvider) Next(c echo.Context) *ProxyTarget {
131+
return &ProxyTarget{}
132+
}
133+
134+
func (p *testProvider) NextTarget(c echo.Context) (*ProxyTarget, error) {
135+
return p.target, p.err
136+
}
137+
138+
func TestTargetProvider(t *testing.T) {
139+
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
140+
fmt.Fprint(w, "target 1")
141+
}))
142+
defer t1.Close()
143+
url1, _ := url.Parse(t1.URL)
144+
145+
e := echo.New()
146+
tp := &testProvider{commonBalancer: new(commonBalancer)}
147+
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
148+
e.Use(Proxy(tp))
149+
rec := httptest.NewRecorder()
150+
req := httptest.NewRequest(http.MethodGet, "/", nil)
151+
e.ServeHTTP(rec, req)
152+
body := rec.Body.String()
153+
assert.Equal(t, "target 1", body)
154+
}
155+
156+
func TestFailNextTarget(t *testing.T) {
157+
url1, err := url.Parse("http://dummy:8080")
158+
assert.Nil(t, err)
159+
160+
e := echo.New()
161+
tp := &testProvider{commonBalancer: new(commonBalancer)}
162+
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
163+
tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
164+
165+
e.Use(Proxy(tp))
166+
rec := httptest.NewRecorder()
167+
req := httptest.NewRequest(http.MethodGet, "/", nil)
168+
e.ServeHTTP(rec, req)
169+
body := rec.Body.String()
170+
assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
171+
}
172+
125173
func TestProxyRealIPHeader(t *testing.T) {
126174
// Setup
127175
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

0 commit comments

Comments
 (0)