Skip to content

Commit c590192

Browse files
committed
Fallback to local in-memory counter if Redis is unavailable
1 parent 71c932d commit c590192

File tree

1 file changed

+101
-42
lines changed

1 file changed

+101
-42
lines changed

httprateredis.go

+101-42
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@ package httprateredis
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"log"
8+
"net"
69
"os"
710
"path/filepath"
811
"strconv"
12+
"sync/atomic"
913
"time"
1014

1115
"github.com/go-chi/httprate"
@@ -16,14 +20,11 @@ func WithRedisLimitCounter(cfg *Config) httprate.Option {
1620
if cfg.Disabled {
1721
return httprate.WithNoop()
1822
}
19-
rc, err := NewRedisLimitCounter(cfg)
20-
if err != nil {
21-
panic(err)
22-
}
23+
rc, _ := NewRedisLimitCounter(cfg)
2324
return httprate.WithLimitCounter(rc)
2425
}
2526

26-
func NewRedisLimitCounter(cfg *Config) (httprate.LimitCounter, error) {
27+
func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
2728
if cfg == nil {
2829
cfg = &Config{}
2930
}
@@ -39,20 +40,15 @@ func NewRedisLimitCounter(cfg *Config) (httprate.LimitCounter, error) {
3940
if cfg.PrefixKey == "" {
4041
cfg.PrefixKey = "httprate"
4142
}
42-
43-
c, err := newClient(cfg)
44-
if err != nil {
45-
return nil, err
43+
if cfg.CommandTimeout == 0 {
44+
cfg.CommandTimeout = 50 * time.Millisecond
4645
}
47-
return &redisCounter{
48-
client: c,
49-
prefixKey: cfg.PrefixKey,
50-
}, nil
51-
}
5246

53-
func newClient(cfg *Config) (*redis.Client, error) {
54-
if cfg.Client != nil {
55-
return cfg.Client, nil
47+
rc := &redisCounter{
48+
prefixKey: cfg.PrefixKey,
49+
}
50+
if !cfg.FallbackDisabled {
51+
rc.fallbackCounter = httprate.NewLocalLimitCounter(cfg.WindowLength)
5652
}
5753

5854
var maxIdle, maxActive = cfg.MaxIdle, cfg.MaxActive
@@ -64,80 +60,118 @@ func newClient(cfg *Config) (*redis.Client, error) {
6460
}
6561

6662
address := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
67-
c := redis.NewClient(&redis.Options{
63+
rc.client = redis.NewClient(&redis.Options{
6864
Addr: address,
6965
Password: cfg.Password,
7066
DB: cfg.DBIndex,
7167
PoolSize: maxActive,
7268
MaxIdleConns: maxIdle,
7369
ClientName: cfg.ClientName,
74-
})
7570

76-
status := c.Ping(context.Background())
77-
if status == nil || status.Err() != nil {
78-
return nil, fmt.Errorf("httprateredis: unable to dial redis host %v", address)
79-
}
71+
DialTimeout: cfg.CommandTimeout,
72+
ReadTimeout: cfg.CommandTimeout,
73+
WriteTimeout: cfg.CommandTimeout,
74+
OnConnect: func(ctx context.Context, cn *redis.Conn) error {
75+
log.Println("redis UP")
76+
rc.isRedisDown.Store(false)
77+
return nil
78+
},
79+
MinIdleConns: 1,
80+
MaxRetries: 5,
81+
//MaxRetryBackoff: time.Duration,
82+
})
8083

81-
return c, nil
84+
return rc, nil
8285
}
8386

8487
type redisCounter struct {
85-
client *redis.Client
86-
windowLength time.Duration
87-
prefixKey string
88+
client *redis.Client
89+
windowLength time.Duration
90+
prefixKey string
91+
isRedisDown atomic.Bool
92+
fallbackCounter httprate.LimitCounter
8893
}
8994

90-
var _ httprate.LimitCounter = &redisCounter{}
95+
var _ httprate.LimitCounter = (*redisCounter)(nil)
9196

9297
func (c *redisCounter) Config(requestLimit int, windowLength time.Duration) {
9398
c.windowLength = windowLength
99+
if c.fallbackCounter != nil {
100+
c.fallbackCounter.Config(requestLimit, windowLength)
101+
}
94102
}
95103

96104
func (c *redisCounter) Increment(key string, currentWindow time.Time) error {
97105
return c.IncrementBy(key, currentWindow, 1)
98106
}
99107

100-
func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount int) error {
101-
ctx := context.Background()
102-
conn := c.client
103-
108+
func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount int) (err error) {
109+
if c.fallbackCounter != nil {
110+
if c.isRedisDown.Load() {
111+
return c.fallbackCounter.IncrementBy(key, currentWindow, amount)
112+
}
113+
defer func() {
114+
if err != nil {
115+
// On redis network error, fallback to local in-memory counter.
116+
var netErr net.Error
117+
if errors.As(err, &netErr) || errors.Is(err, redis.ErrClosed) {
118+
go c.fallback()
119+
err = c.fallbackCounter.IncrementBy(key, currentWindow, amount)
120+
}
121+
}
122+
}()
123+
}
124+
125+
ctx := context.Background() // Note: We use timeouts set up on the Redis client directly.
104126
hkey := c.limitCounterKey(key, currentWindow)
105127

106-
pipe := conn.TxPipeline()
128+
pipe := c.client.TxPipeline()
107129
incrCmd := pipe.IncrBy(ctx, hkey, int64(amount))
108130
expireCmd := pipe.Expire(ctx, hkey, c.windowLength*3)
109-
_, err := pipe.Exec(ctx)
131+
132+
_, err = pipe.Exec(ctx)
110133
if err != nil {
111134
return fmt.Errorf("httprateredis: redis transaction failed: %w", err)
112135
}
113-
114136
if err := incrCmd.Err(); err != nil {
115137
return fmt.Errorf("httprateredis: redis incr failed: %w", err)
116138
}
117-
118139
if err := expireCmd.Err(); err != nil {
119140
return fmt.Errorf("httprateredis: redis expire failed: %w", err)
120141
}
121142

122143
return nil
123144
}
124145

125-
func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time) (int, int, error) {
126-
ctx := context.Background()
127-
conn := c.client
146+
func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time) (curr int, prev int, err error) {
147+
if c.fallbackCounter != nil {
148+
if c.isRedisDown.Load() {
149+
return c.fallbackCounter.Get(key, currentWindow, previousWindow)
150+
}
151+
defer func() {
152+
if err != nil {
153+
// On redis network error, fallback to local in-memory counter.
154+
var netErr net.Error
155+
if errors.As(err, &netErr) || errors.Is(err, redis.ErrClosed) {
156+
go c.fallback()
157+
curr, prev, err = c.fallbackCounter.Get(key, currentWindow, previousWindow)
158+
}
159+
}
160+
}()
161+
}
162+
163+
ctx := context.Background() // Note: We use timeouts set up on the Redis client directly.
128164

129165
currKey := c.limitCounterKey(key, currentWindow)
130166
prevKey := c.limitCounterKey(key, previousWindow)
131167

132-
values, err := conn.MGet(ctx, currKey, prevKey).Result()
168+
values, err := c.client.MGet(ctx, currKey, prevKey).Result()
133169
if err != nil {
134170
return 0, 0, fmt.Errorf("httprateredis: redis mget failed: %w", err)
135171
} else if len(values) != 2 {
136172
return 0, 0, fmt.Errorf("httprateredis: redis mget returned wrong number of keys: %v, expected 2", len(values))
137173
}
138174

139-
var curr, prev int
140-
141175
// MGET always returns slice with nil or "string" values, even if the values
142176
// were created with the INCR command. Ignore error if we can't parse the number.
143177
if values[0] != nil {
@@ -152,6 +186,31 @@ func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time)
152186
return curr, prev, nil
153187
}
154188

189+
func (c *redisCounter) IsRedisDown() bool {
190+
return c.isRedisDown.Load()
191+
}
192+
193+
func (c *redisCounter) fallback() {
194+
log.Println("redis DOWN")
195+
// Fallback to in-memory counter.
196+
wasAlreadyDown := c.isRedisDown.Swap(true)
197+
if wasAlreadyDown {
198+
return
199+
}
200+
201+
// Try to re-connect to redis.
202+
for {
203+
log.Println("redis PING...")
204+
err := c.client.Ping(context.Background()).Err()
205+
if err == nil {
206+
c.isRedisDown.Store(false)
207+
return
208+
}
209+
//time.Sleep(10 * time.Millisecond)
210+
time.Sleep(time.Second)
211+
}
212+
}
213+
155214
func (c *redisCounter) limitCounterKey(key string, window time.Time) string {
156215
return fmt.Sprintf("%s:%d", c.prefixKey, httprate.LimitCounterKey(key, window))
157216
}

0 commit comments

Comments
 (0)