Skip to content

Commit aa29c8f

Browse files
committed
Fallback to local in-memory counter if Redis is unavailable
1 parent 71c932d commit aa29c8f

File tree

5 files changed

+131
-70
lines changed

5 files changed

+131
-70
lines changed

config.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package httprateredis
2+
3+
import (
4+
"time"
5+
6+
"github.com/redis/go-redis/v9"
7+
)
8+
9+
type Config struct {
10+
Disabled bool `toml:"disabled"` // default: false
11+
12+
WindowLength time.Duration `toml:"window_length"` // default: 1m
13+
ClientName string `toml:"client_name"` // default: os.Args[0]
14+
PrefixKey string `toml:"prefix_key"` // default: "httprate"
15+
16+
// Disable the use of the local in-memory fallback mechanism. When enabled,
17+
// the system will return HTTP 428 for all requests when Redis is down.
18+
FallbackDisabled bool `toml:"fallback_disabled"` // default: false
19+
20+
// Client if supplied will be used and below fields will be ignored.
21+
//
22+
// It is recommended to disable retries and set short Dial/Read/Write
23+
// timeouts, so the local in-memory fallback can activate quickly.
24+
Client *redis.Client `toml:"-"`
25+
Host string `toml:"host"`
26+
Port uint16 `toml:"port"`
27+
Password string `toml:"password"` // optional
28+
DBIndex int `toml:"db_index"` // default: 0
29+
MaxIdle int `toml:"max_idle"` // default: 4
30+
MaxActive int `toml:"max_active"` // default: 8
31+
32+
// Timeout for each Redis command after which we fall back to a local
33+
// in-memory counter. If Redis does not respond within this duration,
34+
// the system will use the local counter unless it is explicitly disabled.
35+
CommandTimeout time.Duration `toml:"command_timeout"` // default: 50ms
36+
}

conn.go

-18
This file was deleted.

go.mod

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ module github.com/go-chi/httprate-redis
33
go 1.19
44

55
require (
6-
github.com/go-chi/httprate v0.9.0
6+
github.com/go-chi/httprate v0.12.0
77
github.com/redis/go-redis/v9 v9.6.0
8+
golang.org/x/sync v0.7.0
89
)
910

1011
require (
1112
github.com/cespare/xxhash/v2 v2.3.0 // indirect
1213
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
13-
golang.org/x/sync v0.7.0 // indirect
1414
)

go.sum

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
11
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
22
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
3-
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
4-
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
53
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
64
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
75
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
86
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
9-
github.com/go-chi/httprate v0.8.0 h1:CyKng28yhGnlGXH9EDGC/Qizj29afJQSNW15W/yj34o=
10-
github.com/go-chi/httprate v0.8.0/go.mod h1:6GOYBSwnpra4CQfAKXu8sQZg+nZ0M1g9QnyFvxrAB8A=
11-
github.com/go-chi/httprate v0.9.0 h1:21A+4WDMDA5FyWcg7mNrhj63aNT8CGh+Z1alOE/piU8=
12-
github.com/go-chi/httprate v0.9.0/go.mod h1:6GOYBSwnpra4CQfAKXu8sQZg+nZ0M1g9QnyFvxrAB8A=
13-
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
14-
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
7+
github.com/go-chi/httprate v0.12.0 h1:08D/te3pOTJe5+VAZTQrHxwdsH2NyliiUoRD1naKaMg=
8+
github.com/go-chi/httprate v0.12.0/go.mod h1:TUepLXaz/pCjmCtf/obgOQJ2Sz6rC8fSf5cAt5cnTt0=
159
github.com/redis/go-redis/v9 v9.6.0 h1:NLck+Rab3AOTHw21CGRpvQpgTrAU4sgdCswqGtlhGRA=
1610
github.com/redis/go-redis/v9 v9.6.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
1711
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=

httprateredis.go

+91-42
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package httprateredis
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"net"
68
"os"
79
"path/filepath"
810
"strconv"
11+
"sync/atomic"
912
"time"
1013

1114
"github.com/go-chi/httprate"
@@ -16,14 +19,11 @@ func WithRedisLimitCounter(cfg *Config) httprate.Option {
1619
if cfg.Disabled {
1720
return httprate.WithNoop()
1821
}
19-
rc, err := NewRedisLimitCounter(cfg)
20-
if err != nil {
21-
panic(err)
22-
}
22+
rc, _ := NewRedisLimitCounter(cfg)
2323
return httprate.WithLimitCounter(rc)
2424
}
2525

26-
func NewRedisLimitCounter(cfg *Config) (httprate.LimitCounter, error) {
26+
func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
2727
if cfg == nil {
2828
cfg = &Config{}
2929
}
@@ -39,20 +39,15 @@ func NewRedisLimitCounter(cfg *Config) (httprate.LimitCounter, error) {
3939
if cfg.PrefixKey == "" {
4040
cfg.PrefixKey = "httprate"
4141
}
42-
43-
c, err := newClient(cfg)
44-
if err != nil {
45-
return nil, err
42+
if cfg.CommandTimeout == 0 {
43+
cfg.CommandTimeout = 50 * time.Millisecond
4644
}
47-
return &redisCounter{
48-
client: c,
49-
prefixKey: cfg.PrefixKey,
50-
}, nil
51-
}
5245

53-
func newClient(cfg *Config) (*redis.Client, error) {
54-
if cfg.Client != nil {
55-
return cfg.Client, nil
46+
rc := &redisCounter{
47+
prefixKey: cfg.PrefixKey,
48+
}
49+
if !cfg.FallbackDisabled {
50+
rc.fallbackCounter = httprate.NewLocalLimitCounter(cfg.WindowLength)
5651
}
5752

5853
var maxIdle, maxActive = cfg.MaxIdle, cfg.MaxActive
@@ -64,80 +59,112 @@ func newClient(cfg *Config) (*redis.Client, error) {
6459
}
6560

6661
address := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
67-
c := redis.NewClient(&redis.Options{
62+
rc.client = redis.NewClient(&redis.Options{
6863
Addr: address,
6964
Password: cfg.Password,
7065
DB: cfg.DBIndex,
7166
PoolSize: maxActive,
7267
MaxIdleConns: maxIdle,
7368
ClientName: cfg.ClientName,
74-
})
7569

76-
status := c.Ping(context.Background())
77-
if status == nil || status.Err() != nil {
78-
return nil, fmt.Errorf("httprateredis: unable to dial redis host %v", address)
79-
}
70+
DialTimeout: cfg.CommandTimeout,
71+
ReadTimeout: cfg.CommandTimeout,
72+
WriteTimeout: cfg.CommandTimeout,
73+
MinIdleConns: 1,
74+
MaxRetries: -1,
75+
})
8076

81-
return c, nil
77+
return rc, nil
8278
}
8379

8480
type redisCounter struct {
85-
client *redis.Client
86-
windowLength time.Duration
87-
prefixKey string
81+
client *redis.Client
82+
windowLength time.Duration
83+
prefixKey string
84+
isRedisDown atomic.Bool
85+
fallbackCounter httprate.LimitCounter
8886
}
8987

90-
var _ httprate.LimitCounter = &redisCounter{}
88+
var _ httprate.LimitCounter = (*redisCounter)(nil)
9189

9290
func (c *redisCounter) Config(requestLimit int, windowLength time.Duration) {
9391
c.windowLength = windowLength
92+
if c.fallbackCounter != nil {
93+
c.fallbackCounter.Config(requestLimit, windowLength)
94+
}
9495
}
9596

9697
func (c *redisCounter) Increment(key string, currentWindow time.Time) error {
9798
return c.IncrementBy(key, currentWindow, 1)
9899
}
99100

100-
func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount int) error {
101-
ctx := context.Background()
102-
conn := c.client
103-
101+
func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount int) (err error) {
102+
if c.fallbackCounter != nil {
103+
if c.isRedisDown.Load() {
104+
return c.fallbackCounter.IncrementBy(key, currentWindow, amount)
105+
}
106+
defer func() {
107+
if err != nil {
108+
// On redis network error, fallback to local in-memory counter.
109+
var netErr net.Error
110+
if errors.As(err, &netErr) || errors.Is(err, redis.ErrClosed) {
111+
go c.fallback()
112+
err = c.fallbackCounter.IncrementBy(key, currentWindow, amount)
113+
}
114+
}
115+
}()
116+
}
117+
118+
ctx := context.Background() // Note: We use timeouts set up on the Redis client directly.
104119
hkey := c.limitCounterKey(key, currentWindow)
105120

106-
pipe := conn.TxPipeline()
121+
pipe := c.client.TxPipeline()
107122
incrCmd := pipe.IncrBy(ctx, hkey, int64(amount))
108123
expireCmd := pipe.Expire(ctx, hkey, c.windowLength*3)
109-
_, err := pipe.Exec(ctx)
124+
125+
_, err = pipe.Exec(ctx)
110126
if err != nil {
111127
return fmt.Errorf("httprateredis: redis transaction failed: %w", err)
112128
}
113-
114129
if err := incrCmd.Err(); err != nil {
115130
return fmt.Errorf("httprateredis: redis incr failed: %w", err)
116131
}
117-
118132
if err := expireCmd.Err(); err != nil {
119133
return fmt.Errorf("httprateredis: redis expire failed: %w", err)
120134
}
121135

122136
return nil
123137
}
124138

125-
func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time) (int, int, error) {
126-
ctx := context.Background()
127-
conn := c.client
139+
func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time) (curr int, prev int, err error) {
140+
if c.fallbackCounter != nil {
141+
if c.isRedisDown.Load() {
142+
return c.fallbackCounter.Get(key, currentWindow, previousWindow)
143+
}
144+
defer func() {
145+
if err != nil {
146+
// On redis network error, fallback to local in-memory counter.
147+
var netErr net.Error
148+
if errors.As(err, &netErr) || errors.Is(err, redis.ErrClosed) {
149+
go c.fallback()
150+
curr, prev, err = c.fallbackCounter.Get(key, currentWindow, previousWindow)
151+
}
152+
}
153+
}()
154+
}
155+
156+
ctx := context.Background() // Note: We use timeouts set up on the Redis client directly.
128157

129158
currKey := c.limitCounterKey(key, currentWindow)
130159
prevKey := c.limitCounterKey(key, previousWindow)
131160

132-
values, err := conn.MGet(ctx, currKey, prevKey).Result()
161+
values, err := c.client.MGet(ctx, currKey, prevKey).Result()
133162
if err != nil {
134163
return 0, 0, fmt.Errorf("httprateredis: redis mget failed: %w", err)
135164
} else if len(values) != 2 {
136165
return 0, 0, fmt.Errorf("httprateredis: redis mget returned wrong number of keys: %v, expected 2", len(values))
137166
}
138167

139-
var curr, prev int
140-
141168
// MGET always returns slice with nil or "string" values, even if the values
142169
// were created with the INCR command. Ignore error if we can't parse the number.
143170
if values[0] != nil {
@@ -152,6 +179,28 @@ func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time)
152179
return curr, prev, nil
153180
}
154181

182+
func (c *redisCounter) IsRedisDown() bool {
183+
return c.isRedisDown.Load()
184+
}
185+
186+
func (c *redisCounter) fallback() {
187+
// Fallback to in-memory counter.
188+
wasAlreadyDown := c.isRedisDown.Swap(true)
189+
if wasAlreadyDown {
190+
return
191+
}
192+
193+
// Try to re-connect to redis every 50ms.
194+
for {
195+
err := c.client.Ping(context.Background()).Err()
196+
if err == nil {
197+
c.isRedisDown.Store(false)
198+
return
199+
}
200+
time.Sleep(50 * time.Millisecond)
201+
}
202+
}
203+
155204
func (c *redisCounter) limitCounterKey(key string, window time.Time) string {
156205
return fmt.Sprintf("%s:%d", c.prefixKey, httprate.LimitCounterKey(key, window))
157206
}

0 commit comments

Comments
 (0)