Skip to content

Commit c2354ea

Browse files
committed
ClientIP middleware proposal, intended to replace RealIP
1 parent 0a20a0e commit c2354ea

File tree

2 files changed

+326
-0
lines changed

2 files changed

+326
-0
lines changed

middleware/client_ip.go

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"net"
6+
"net/http"
7+
"net/netip"
8+
"strings"
9+
)
10+
11+
var (
12+
// clientIPCtxKey is the context key used to store the client IP address.
13+
clientIPCtxKey = &contextKey{"clientIP"}
14+
)
15+
16+
// ClientIPFromHeader parses the client IP address from a specified HTTP header
17+
// (e.g., X-Real-IP, CF-Connecting-IP) and injects it into the request context
18+
// if it is not already set. The parsed IP address can be retrieved using GetClientIP().
19+
//
20+
// The middleware validates the IP address to ignore loopback, private, and unspecified addresses.
21+
//
22+
// ### Important Notice:
23+
// - Use this middleware only when your infrastructure sets a trusted header containing the client IP.
24+
// - If the specified header is not securely set by your infrastructure, malicious clients could spoof it.
25+
//
26+
// Example trusted headers:
27+
// - "X-Real-IP" - Nginx (ngx_http_realip_module)
28+
// - "X-Client-IP" - Apache (mod_remoteip)
29+
// - "CF-Connecting-IP" - Cloudflare
30+
// - "True-Client-IP" - Akamai, Cloudflare Enterprise
31+
// - "X-Azure-ClientIP" - Azure Front Door
32+
// - "Fastly-Client-IP" - Fastly
33+
func ClientIPFromHeader(trustedHeader string) func(http.Handler) http.Handler {
34+
return func(h http.Handler) http.Handler {
35+
fn := func(w http.ResponseWriter, r *http.Request) {
36+
ctx := r.Context()
37+
38+
// Check if the client IP is already set in the context.
39+
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok {
40+
h.ServeHTTP(w, r)
41+
return
42+
}
43+
44+
// Parse the IP address from the trusted header.
45+
ip, err := netip.ParseAddr(r.Header.Get(trustedHeader))
46+
if err != nil || ip.IsLoopback() || ip.IsUnspecified() || ip.IsPrivate() {
47+
// Ignore invalid or private IPs.
48+
h.ServeHTTP(w, r)
49+
return
50+
}
51+
52+
// Store the valid client IP in the context.
53+
ctx = context.WithValue(ctx, clientIPCtxKey, ip)
54+
h.ServeHTTP(w, r.WithContext(ctx))
55+
}
56+
return http.HandlerFunc(fn)
57+
}
58+
}
59+
60+
// ClientIPFromXFFHeader parses the client IP address from the X-Forwarded-For
61+
// header and injects it into the request context if it is not already set. The
62+
// parsed IP address can be retrieved using GetClientIP().
63+
//
64+
// The middleware traverses the X-Forwarded-For chain (rightmost untrusted IP)
65+
// and excludes loopback, private, unspecified, and trusted IP ranges.
66+
//
67+
// ### Important Notice:
68+
// - Use this middleware only when your infrastructure sets and validates the X-Forwarded-For header.
69+
// - Malicious clients can spoof the header unless a trusted reverse proxy or load balancer sanitizes it.
70+
//
71+
// Parameters:
72+
// - `trustedIPPrefixes`: A list of CIDR prefixes that define trusted proxy IP ranges.
73+
//
74+
// Example trusted IP ranges:
75+
// - "203.0.113.0/24" - Example corporate proxy
76+
// - "198.51.100.0/24" - Example data center or hosting provider
77+
// - "2400:cb00::/32" - Cloudflare IPv6 range
78+
// - "2606:4700::/32" - Cloudflare IPv6 range
79+
// - "192.0.2.0/24" - Example VPN gateway
80+
//
81+
// Note: Private IP ranges (e.g., "10.0.0.0/8", "192.168.0.0/16", "172.16.0.0/12")
82+
// are automatically excluded by netip.Addr.IsPrivate() and do not need to be added here.
83+
func ClientIPFromXFFHeader(trustedIPPrefixes ...string) func(http.Handler) http.Handler {
84+
// Pre-parse trusted prefixes.
85+
trustedPrefixes := make([]netip.Prefix, len(trustedIPPrefixes))
86+
for i, ipRange := range trustedIPPrefixes {
87+
trustedPrefixes[i] = netip.MustParsePrefix(ipRange)
88+
}
89+
90+
return func(h http.Handler) http.Handler {
91+
fn := func(w http.ResponseWriter, r *http.Request) {
92+
ctx := r.Context()
93+
94+
// Check if the client IP is already set in the context.
95+
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok {
96+
h.ServeHTTP(w, r)
97+
return
98+
}
99+
100+
// Parse and split the X-Forwarded-For header(s).
101+
xff := strings.Split(strings.Join(r.Header.Values("X-Forwarded-For"), ","), ",")
102+
nextValue:
103+
for i := len(xff) - 1; i >= 0; i-- {
104+
ip, err := netip.ParseAddr(strings.TrimSpace(xff[i]))
105+
if err != nil {
106+
continue
107+
}
108+
109+
// Ignore loopback, private, or unspecified addresses.
110+
if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() {
111+
continue
112+
}
113+
114+
// Ignore trusted IPs within the given ranges.
115+
for _, prefix := range trustedPrefixes {
116+
if prefix.Contains(ip) {
117+
continue nextValue
118+
}
119+
}
120+
121+
// Store the valid client IP in the context.
122+
ctx = context.WithValue(ctx, clientIPCtxKey, ip)
123+
h.ServeHTTP(w, r.WithContext(ctx))
124+
return
125+
}
126+
127+
h.ServeHTTP(w, r)
128+
}
129+
return http.HandlerFunc(fn)
130+
}
131+
}
132+
133+
// ClientIPFromRemoteAddr extracts the client IP address from the RemoteAddr
134+
// field of the HTTP request and injects it into the request context if it is
135+
// not already set. The parsed IP address can be retrieved using GetClientIP().
136+
//
137+
// The middleware ignores invalid or private IPs.
138+
//
139+
// ### Use Case:
140+
// This middleware is useful when the client IP cannot be determined from headers
141+
// such as X-Forwarded-For or X-Real-IP, and you need to fall back to RemoteAddr.
142+
func ClientIPFromRemoteAddr(h http.Handler) http.Handler {
143+
fn := func(w http.ResponseWriter, r *http.Request) {
144+
ctx := r.Context()
145+
146+
// Check if the client IP is already set in the context.
147+
if _, ok := ctx.Value(clientIPCtxKey).(netip.Addr); ok {
148+
h.ServeHTTP(w, r)
149+
return
150+
}
151+
152+
// Extract the IP from request RemoteAddr.
153+
host, _, err := net.SplitHostPort(r.RemoteAddr)
154+
if err != nil {
155+
h.ServeHTTP(w, r)
156+
return
157+
}
158+
159+
ip, err := netip.ParseAddr(host)
160+
if err != nil {
161+
h.ServeHTTP(w, r)
162+
return
163+
}
164+
165+
// Store the valid client IP in the context.
166+
ctx = context.WithValue(ctx, clientIPCtxKey, ip)
167+
h.ServeHTTP(w, r.WithContext(ctx))
168+
}
169+
return http.HandlerFunc(fn)
170+
}
171+
172+
// GetClientIP retrieves the client IP address from the given context.
173+
// The IP address is set by one of the following middlewares:
174+
// - ClientIPFromHeader
175+
// - ClientIPFromXFFHeader
176+
// - ClientIPFromRemoteAddr
177+
//
178+
// Returns an empty string if no valid IP is found.
179+
func GetClientIP(ctx context.Context) string {
180+
ip, ok := ctx.Value(clientIPCtxKey).(netip.Addr)
181+
if !ok || !ip.IsValid() {
182+
return ""
183+
}
184+
return ip.String()
185+
}

middleware/client_ip_test.go

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/go-chi/chi/v5"
9+
)
10+
11+
func TestClientIPFromHeader(t *testing.T) {
12+
tt := []struct {
13+
name string
14+
in string
15+
out string
16+
}{
17+
// Empty header.
18+
{name: "empty", in: "", out: ""},
19+
20+
// Valid X-Real-IP header values.
21+
{name: "valid_ipv4", in: "100.100.100.100", out: "100.100.100.100"},
22+
{name: "valid_ipv4", in: "178.25.203.2", out: "178.25.203.2"},
23+
{name: "valid_ipv6_lower", in: "2345:0425:2ca1:0000:0000:0567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"},
24+
{name: "valid_ipv6_upper", in: "2345:0425:2CA1:0000:0000:0567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"},
25+
{name: "valid_ipv6_lower_short", in: "2345:425:2ca1::567:5673:23b5", out: "2345:425:2ca1::567:5673:23b5"},
26+
{name: "valid_ipv6_upper_short", in: "2345:425:2CA1::567:5673:23B5", out: "2345:425:2ca1::567:5673:23b5"},
27+
28+
// Invalid X-Real-IP header values.
29+
{name: "invalid_ip", in: "invalid", out: ""},
30+
{name: "invalid_ip_with_port", in: "100.100.100.100:80", out: ""},
31+
{name: "invalid_multiple_ips", in: "100.100.100.100;100.100.100.101;100.100.100.102", out: ""},
32+
{name: "invalid_loopback", in: "127.0.0.1", out: ""},
33+
{name: "invalid_zeroes", in: "0.0.0.0", out: ""},
34+
{name: "invalid_loopback", in: "127.0.0.1", out: ""},
35+
{name: "invalid_private_ipv4_1", in: "192.168.0.1", out: ""},
36+
{name: "invalid_private_ipv4_2", in: "192.168.10.12", out: ""},
37+
{name: "invalid_private_ipv4_3", in: "172.16.0.0", out: ""},
38+
{name: "invalid_private_ipv4_4", in: "172.25.203.2", out: ""},
39+
{name: "invalid_private_ipv4_5", in: "10.0.0.0", out: ""},
40+
{name: "invalid_private_ipv4_6", in: "10.0.1.10", out: ""},
41+
{name: "invalid_private_ipv6_1", in: "fc00::1", out: ""},
42+
{name: "invalid_private_ipv6_2", in: "fc00:0425:2ca1:0000:0000:0567:5673:23b5", out: ""},
43+
}
44+
45+
for _, tc := range tt {
46+
t.Run(tc.name, func(t *testing.T) {
47+
req, _ := http.NewRequest("GET", "/", nil)
48+
req.Header.Add("X-Real-IP", tc.in)
49+
w := httptest.NewRecorder()
50+
51+
r := chi.NewRouter()
52+
r.Use(ClientIPFromHeader("X-Real-IP"))
53+
54+
var clientIP string
55+
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
56+
clientIP = GetClientIP(r.Context())
57+
w.Write([]byte("Hello World"))
58+
})
59+
r.ServeHTTP(w, req)
60+
61+
if w.Code != 200 {
62+
t.Errorf("Response Code should be 200")
63+
}
64+
65+
if clientIP != tc.out {
66+
t.Errorf("expected %v, got %v", tc.out, clientIP)
67+
}
68+
})
69+
}
70+
}
71+
72+
func TestClientIPFromXFFHeader(t *testing.T) {
73+
tt := []struct {
74+
name string
75+
xff []string
76+
out string
77+
}{
78+
{name: "empty", xff: []string{""}, out: ""},
79+
80+
{name: "", xff: []string{"100.100.100.100"}, out: "100.100.100.100"},
81+
{name: "", xff: []string{"100.100.100.100, 200.200.200.200"}, out: "200.200.200.200"},
82+
{name: "", xff: []string{"100.100.100.100,200.200.200.200"}, out: "200.200.200.200"},
83+
{name: "", xff: []string{"100.100.100.100", "200.200.200.200"}, out: "200.200.200.200"},
84+
{name: "", xff: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"},
85+
{name: "", xff: []string{"203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348"}, out: "2001:db8:85a3:8d3:1319:8a2e:370:7348"},
86+
{name: "", xff: []string{"5.5.5.5, 203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348", "7.7.7.7, 4.4.4.4"}, out: "4.4.4.4"},
87+
}
88+
89+
r := chi.NewRouter()
90+
r.Use(ClientIPFromXFFHeader())
91+
92+
for _, tc := range tt {
93+
req, _ := http.NewRequest("GET", "/", nil)
94+
for _, v := range tc.xff {
95+
req.Header.Add("X-Forwarded-For", v)
96+
}
97+
98+
w := httptest.NewRecorder()
99+
100+
clientIP := ""
101+
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
102+
clientIP = GetClientIP(r.Context())
103+
w.Write([]byte("Hello World"))
104+
})
105+
r.ServeHTTP(w, req)
106+
107+
if w.Code != 200 {
108+
t.Errorf("Response Code should be 200")
109+
}
110+
111+
if clientIP != tc.out {
112+
t.Errorf("expected %v, got %v", tc.out, clientIP)
113+
}
114+
}
115+
}
116+
117+
func TestClientIPFromRemoteAddr(t *testing.T) {
118+
req, _ := http.NewRequest("GET", "/", nil)
119+
req.RemoteAddr = "192.0.2.1:1234" // Simulate the remote address set by http.Server.
120+
121+
w := httptest.NewRecorder()
122+
123+
r := chi.NewRouter()
124+
r.Use(ClientIPFromRemoteAddr)
125+
126+
var clientIP string
127+
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
128+
clientIP = GetClientIP(r.Context())
129+
w.Write([]byte("Hello World"))
130+
})
131+
r.ServeHTTP(w, req)
132+
133+
if w.Code != 200 {
134+
t.Errorf("Response Code should be 200")
135+
}
136+
137+
expected := "192.0.2.1"
138+
if clientIP != expected {
139+
t.Errorf("expected %v, got %v", expected, clientIP)
140+
}
141+
}

0 commit comments

Comments
 (0)