diff --git a/context.go b/context.go index f5dd5a69d..515865a59 100644 --- a/context.go +++ b/context.go @@ -40,6 +40,8 @@ type Context interface { // Scheme returns the HTTP protocol scheme, `http` or `https`. Scheme() string + SchemeForwarded() *Forwarded + // RealIP returns the client's network address based on `X-Forwarded-For` // or `X-Real-IP` request header. // The behavior can be configured using `Echo#IPExtractor`. @@ -234,6 +236,14 @@ const ( ContextKeyHeaderAllow = "echo_header_allow" ) +// Forwarded represents the structured format of the Forwarded HTTP header. +type Forwarded struct { + By []string + For []string + Host []string + Proto []string +} + const ( defaultMemory = 32 << 20 // 32 MB indexPage = "index.html" @@ -293,24 +303,85 @@ func (c *context) Scheme() string { return "http" } +func (c *context) SchemeForwarded() *Forwarded { + // Parse and get "Forwarded" header. + // See : https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Forwarded + if scheme := c.request.Header.Get(HeaderForwarded); scheme != "" { + f, err := c.parseForwarded(scheme) + if err != nil { + return nil + } + return &f + } + return nil +} + +func (c *context) parseForwarded(input string) (Forwarded, error) { + forwarded := Forwarded{} + entries := strings.Split(input, ",") + + for _, entry := range entries { + entry = strings.TrimSpace(entry) + pairs := strings.Split(entry, ";") + + for _, pair := range pairs { + parts := strings.SplitN(pair, "=", 2) + if len(parts) != 2 { + return forwarded, fmt.Errorf("invalid pair: %s", pair) + } + + key := strings.TrimSpace(parts[0]) + value, err := url.QueryUnescape(strings.TrimSpace(parts[1])) + if err != nil { + return forwarded, fmt.Errorf("failed to unescape value: %w", err) + } + value = strings.Trim(value, "\"[]") + switch key { + case "by": + forwarded.By = append(forwarded.By, value) + case "for": + forwarded.For = append(forwarded.For, value) + case "host": + forwarded.Host = append(forwarded.Host, value) + case "proto": + forwarded.Proto = append(forwarded.Proto, value) + default: + return forwarded, fmt.Errorf("unknown key: %s", key) + } + } + } + + return forwarded, nil +} + func (c *context) RealIP() string { if c.echo != nil && c.echo.IPExtractor != nil { return c.echo.IPExtractor(c.request) } + // Check if the "Forwarded" header is present in the request. + if d := c.request.Header.Get(HeaderForwarded); d != "" { + // Parse the "Forwarded" header. + scheme, err := c.parseForwarded(d) + if err != nil { + return "" // Return an empty string if parsing fails. + } + if len(scheme.For) > 0 { + return scheme.For[0] // Return first for item + } + return "" + } // Fall back to legacy behavior if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { i := strings.IndexAny(ip, ",") if i > 0 { xffip := strings.TrimSpace(ip[:i]) - xffip = strings.TrimPrefix(xffip, "[") - xffip = strings.TrimSuffix(xffip, "]") + xffip = strings.Trim(xffip, "\"[]") return xffip } return ip } if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { - ip = strings.TrimPrefix(ip, "[") - ip = strings.TrimSuffix(ip, "]") + ip = strings.Trim(ip, "\"[]") return ip } ra, _, _ := net.SplitHostPort(c.request.RemoteAddr) diff --git a/context_test.go b/context_test.go index 1fd89edb4..ce93256b5 100644 --- a/context_test.go +++ b/context_test.go @@ -961,6 +961,50 @@ func TestContext_Scheme(t *testing.T) { } } +func TestContext_SchemeForwarded(t *testing.T) { + tests := []struct { + c Context + s *Forwarded + }{ + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderForwarded: []string{"for=192.0.2.60;proto=http;by=203.0.113.43"}}, + }, + }, + &Forwarded{ + For: []string{"192.0.2.60"}, + Proto: []string{"http"}, + By: []string{"203.0.113.43"}, + }, + }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderForwarded: []string{"for=192.0.2.43, for=198.51.100.17"}}, + }, + }, + &Forwarded{ + For: []string{"192.0.2.43", "198.51.100.17"}, + }, + }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderForwarded: []string{"for=192.0.2.43, for=[2001:db8:cafe::17]"}}, + }, + }, + &Forwarded{ + For: []string{"192.0.2.43", "2001:db8:cafe::17"}, + }, + }, + } + + for _, tt := range tests { + assert.Equal(t, tt.s, tt.c.SchemeForwarded()) + } +} + func TestContext_IsWebSocket(t *testing.T) { tests := []struct { c Context @@ -1062,6 +1106,22 @@ func TestContext_RealIP(t *testing.T) { }, "127.0.0.1", }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderForwarded: []string{"for=192.0.2.43, for=198.51.100.17"}}, + }, + }, + "192.0.2.43", + }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderForwarded: []string{"for=[2001:db8:85a3:8d3:1319:8a2e:370:7348], for=2001:db8::1"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, { &context{ request: &http.Request{ diff --git a/echo.go b/echo.go index 60f7061d8..a199296c4 100644 --- a/echo.go +++ b/echo.go @@ -221,6 +221,7 @@ const ( HeaderUpgrade = "Upgrade" HeaderVary = "Vary" HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderForwarded = "Forwarded" HeaderXForwardedFor = "X-Forwarded-For" HeaderXForwardedProto = "X-Forwarded-Proto" HeaderXForwardedProtocol = "X-Forwarded-Protocol"