Skip to content

Commit 0d69b97

Browse files
committed
Fix passing custom redis client
1 parent fe06926 commit 0d69b97

File tree

3 files changed

+27
-24
lines changed

3 files changed

+27
-24
lines changed

_example/main.go

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ func main() {
3737
})
3838
})
3939

40+
// Rate-limit at 50 req/s per IP address.
4041
r.Use(httprate.Limit(
4142
50, time.Second,
4243
httprate.WithKeyByIP(),

config.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ type Config struct {
2222
// the system will use the local counter unless it is explicitly disabled.
2323
FallbackTimeout time.Duration `toml:"fallback_timeout"` // default: 50ms
2424

25-
// Client if supplied will be used and below fields will be ignored.
25+
// Client if supplied will be used and the below fields will be ignored.
2626
//
27-
// NOTE: It's recommended to set short Dial/Read/Write timeouts and disable
27+
// NOTE: It's recommended to set short dial/read/write timeouts and disable
2828
// retries on the client, so the local in-memory fallback can activate quickly.
2929
Client *redis.Client `toml:"-"`
3030
Host string `toml:"host"`

httprateredis.go

+24-22
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
4040
cfg.PrefixKey = "httprate"
4141
}
4242
if cfg.FallbackTimeout == 0 {
43+
// Activate local in-memory fallback fairly quickly, as this would slow down all requests.
4344
cfg.FallbackTimeout = 50 * time.Millisecond
4445
}
4546

@@ -50,29 +51,30 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
5051
rc.fallbackCounter = httprate.NewLocalLimitCounter(cfg.WindowLength)
5152
}
5253

53-
var maxIdle, maxActive = cfg.MaxIdle, cfg.MaxActive
54-
if maxIdle <= 0 {
55-
maxIdle = 20
56-
}
57-
if maxActive <= 0 {
58-
maxActive = 50
59-
}
54+
if cfg.Client == nil {
55+
maxIdle, maxActive := cfg.MaxIdle, cfg.MaxActive
56+
if maxIdle < 1 {
57+
maxIdle = 20
58+
}
59+
if maxActive < 1 {
60+
maxActive = 50
61+
}
6062

61-
address := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
62-
rc.client = redis.NewClient(&redis.Options{
63-
Addr: address,
64-
Password: cfg.Password,
65-
DB: cfg.DBIndex,
66-
PoolSize: maxActive,
67-
MaxIdleConns: maxIdle,
68-
ClientName: cfg.ClientName,
63+
rc.client = redis.NewClient(&redis.Options{
64+
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
65+
Password: cfg.Password,
66+
DB: cfg.DBIndex,
67+
ClientName: cfg.ClientName,
6968

70-
DialTimeout: cfg.FallbackTimeout,
71-
ReadTimeout: cfg.FallbackTimeout,
72-
WriteTimeout: cfg.FallbackTimeout,
73-
MinIdleConns: 1,
74-
MaxRetries: -1,
75-
})
69+
DialTimeout: cfg.FallbackTimeout,
70+
ReadTimeout: cfg.FallbackTimeout,
71+
WriteTimeout: cfg.FallbackTimeout,
72+
PoolSize: maxActive,
73+
MinIdleConns: 1,
74+
MaxIdleConns: maxIdle,
75+
MaxRetries: -1, // -1 disables retries
76+
})
77+
}
7678

7779
return rc, nil
7880
}
@@ -109,7 +111,7 @@ func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount i
109111
var netErr net.Error
110112
if errors.As(err, &netErr) || errors.Is(err, redis.ErrClosed) {
111113
go c.fallback()
112-
err = c.fallbackCounter.IncrementBy(key, currentWindow, amount) // = nil
114+
err = c.fallbackCounter.IncrementBy(key, currentWindow, amount)
113115
}
114116
}
115117
}()

0 commit comments

Comments
 (0)