@@ -2,10 +2,13 @@ package httprateredis
2
2
3
3
import (
4
4
"context"
5
+ "errors"
5
6
"fmt"
7
+ "net"
6
8
"os"
7
9
"path/filepath"
8
10
"strconv"
11
+ "sync/atomic"
9
12
"time"
10
13
11
14
"github.com/go-chi/httprate"
@@ -16,14 +19,11 @@ func WithRedisLimitCounter(cfg *Config) httprate.Option {
16
19
if cfg .Disabled {
17
20
return httprate .WithNoop ()
18
21
}
19
- rc , err := NewRedisLimitCounter (cfg )
20
- if err != nil {
21
- panic (err )
22
- }
22
+ rc , _ := NewRedisLimitCounter (cfg )
23
23
return httprate .WithLimitCounter (rc )
24
24
}
25
25
26
- func NewRedisLimitCounter (cfg * Config ) (httprate. LimitCounter , error ) {
26
+ func NewRedisLimitCounter (cfg * Config ) (* redisCounter , error ) {
27
27
if cfg == nil {
28
28
cfg = & Config {}
29
29
}
@@ -39,20 +39,15 @@ func NewRedisLimitCounter(cfg *Config) (httprate.LimitCounter, error) {
39
39
if cfg .PrefixKey == "" {
40
40
cfg .PrefixKey = "httprate"
41
41
}
42
-
43
- c , err := newClient (cfg )
44
- if err != nil {
45
- return nil , err
42
+ if cfg .CommandTimeout == 0 {
43
+ cfg .CommandTimeout = 50 * time .Millisecond
46
44
}
47
- return & redisCounter {
48
- client : c ,
49
- prefixKey : cfg .PrefixKey ,
50
- }, nil
51
- }
52
45
53
- func newClient (cfg * Config ) (* redis.Client , error ) {
54
- if cfg .Client != nil {
55
- return cfg .Client , nil
46
+ rc := & redisCounter {
47
+ prefixKey : cfg .PrefixKey ,
48
+ }
49
+ if ! cfg .FallbackDisabled {
50
+ rc .fallbackCounter = httprate .NewLocalLimitCounter (cfg .WindowLength )
56
51
}
57
52
58
53
var maxIdle , maxActive = cfg .MaxIdle , cfg .MaxActive
@@ -64,80 +59,112 @@ func newClient(cfg *Config) (*redis.Client, error) {
64
59
}
65
60
66
61
address := fmt .Sprintf ("%s:%d" , cfg .Host , cfg .Port )
67
- c : = redis .NewClient (& redis.Options {
62
+ rc . client = redis .NewClient (& redis.Options {
68
63
Addr : address ,
69
64
Password : cfg .Password ,
70
65
DB : cfg .DBIndex ,
71
66
PoolSize : maxActive ,
72
67
MaxIdleConns : maxIdle ,
73
68
ClientName : cfg .ClientName ,
74
- })
75
69
76
- status := c .Ping (context .Background ())
77
- if status == nil || status .Err () != nil {
78
- return nil , fmt .Errorf ("httprateredis: unable to dial redis host %v" , address )
79
- }
70
+ DialTimeout : cfg .CommandTimeout ,
71
+ ReadTimeout : cfg .CommandTimeout ,
72
+ WriteTimeout : cfg .CommandTimeout ,
73
+ MinIdleConns : 1 ,
74
+ MaxRetries : - 1 ,
75
+ })
80
76
81
- return c , nil
77
+ return rc , nil
82
78
}
83
79
84
80
type redisCounter struct {
85
- client * redis.Client
86
- windowLength time.Duration
87
- prefixKey string
81
+ client * redis.Client
82
+ windowLength time.Duration
83
+ prefixKey string
84
+ isRedisDown atomic.Bool
85
+ fallbackCounter httprate.LimitCounter
88
86
}
89
87
90
- var _ httprate.LimitCounter = & redisCounter {}
88
+ var _ httprate.LimitCounter = ( * redisCounter )( nil )
91
89
92
90
func (c * redisCounter ) Config (requestLimit int , windowLength time.Duration ) {
93
91
c .windowLength = windowLength
92
+ if c .fallbackCounter != nil {
93
+ c .fallbackCounter .Config (requestLimit , windowLength )
94
+ }
94
95
}
95
96
96
97
func (c * redisCounter ) Increment (key string , currentWindow time.Time ) error {
97
98
return c .IncrementBy (key , currentWindow , 1 )
98
99
}
99
100
100
- func (c * redisCounter ) IncrementBy (key string , currentWindow time.Time , amount int ) error {
101
- ctx := context .Background ()
102
- conn := c .client
103
-
101
+ func (c * redisCounter ) IncrementBy (key string , currentWindow time.Time , amount int ) (err error ) {
102
+ if c .fallbackCounter != nil {
103
+ if c .isRedisDown .Load () {
104
+ return c .fallbackCounter .IncrementBy (key , currentWindow , amount )
105
+ }
106
+ defer func () {
107
+ if err != nil {
108
+ // On redis network error, fallback to local in-memory counter.
109
+ var netErr net.Error
110
+ if errors .As (err , & netErr ) || errors .Is (err , redis .ErrClosed ) {
111
+ go c .fallback ()
112
+ err = c .fallbackCounter .IncrementBy (key , currentWindow , amount )
113
+ }
114
+ }
115
+ }()
116
+ }
117
+
118
+ ctx := context .Background () // Note: We use timeouts set up on the Redis client directly.
104
119
hkey := c .limitCounterKey (key , currentWindow )
105
120
106
- pipe := conn .TxPipeline ()
121
+ pipe := c . client .TxPipeline ()
107
122
incrCmd := pipe .IncrBy (ctx , hkey , int64 (amount ))
108
123
expireCmd := pipe .Expire (ctx , hkey , c .windowLength * 3 )
109
- _ , err := pipe .Exec (ctx )
124
+
125
+ _ , err = pipe .Exec (ctx )
110
126
if err != nil {
111
127
return fmt .Errorf ("httprateredis: redis transaction failed: %w" , err )
112
128
}
113
-
114
129
if err := incrCmd .Err (); err != nil {
115
130
return fmt .Errorf ("httprateredis: redis incr failed: %w" , err )
116
131
}
117
-
118
132
if err := expireCmd .Err (); err != nil {
119
133
return fmt .Errorf ("httprateredis: redis expire failed: %w" , err )
120
134
}
121
135
122
136
return nil
123
137
}
124
138
125
- func (c * redisCounter ) Get (key string , currentWindow , previousWindow time.Time ) (int , int , error ) {
126
- ctx := context .Background ()
127
- conn := c .client
139
+ func (c * redisCounter ) Get (key string , currentWindow , previousWindow time.Time ) (curr int , prev int , err error ) {
140
+ if c .fallbackCounter != nil {
141
+ if c .isRedisDown .Load () {
142
+ return c .fallbackCounter .Get (key , currentWindow , previousWindow )
143
+ }
144
+ defer func () {
145
+ if err != nil {
146
+ // On redis network error, fallback to local in-memory counter.
147
+ var netErr net.Error
148
+ if errors .As (err , & netErr ) || errors .Is (err , redis .ErrClosed ) {
149
+ go c .fallback ()
150
+ curr , prev , err = c .fallbackCounter .Get (key , currentWindow , previousWindow )
151
+ }
152
+ }
153
+ }()
154
+ }
155
+
156
+ ctx := context .Background () // Note: We use timeouts set up on the Redis client directly.
128
157
129
158
currKey := c .limitCounterKey (key , currentWindow )
130
159
prevKey := c .limitCounterKey (key , previousWindow )
131
160
132
- values , err := conn .MGet (ctx , currKey , prevKey ).Result ()
161
+ values , err := c . client .MGet (ctx , currKey , prevKey ).Result ()
133
162
if err != nil {
134
163
return 0 , 0 , fmt .Errorf ("httprateredis: redis mget failed: %w" , err )
135
164
} else if len (values ) != 2 {
136
165
return 0 , 0 , fmt .Errorf ("httprateredis: redis mget returned wrong number of keys: %v, expected 2" , len (values ))
137
166
}
138
167
139
- var curr , prev int
140
-
141
168
// MGET always returns slice with nil or "string" values, even if the values
142
169
// were created with the INCR command. Ignore error if we can't parse the number.
143
170
if values [0 ] != nil {
@@ -152,6 +179,28 @@ func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time)
152
179
return curr , prev , nil
153
180
}
154
181
182
+ func (c * redisCounter ) IsRedisDown () bool {
183
+ return c .isRedisDown .Load ()
184
+ }
185
+
186
+ func (c * redisCounter ) fallback () {
187
+ // Fallback to in-memory counter.
188
+ wasAlreadyDown := c .isRedisDown .Swap (true )
189
+ if wasAlreadyDown {
190
+ return
191
+ }
192
+
193
+ // Try to re-connect to redis every 50ms.
194
+ for {
195
+ err := c .client .Ping (context .Background ()).Err ()
196
+ if err == nil {
197
+ c .isRedisDown .Store (false )
198
+ return
199
+ }
200
+ time .Sleep (50 * time .Millisecond )
201
+ }
202
+ }
203
+
155
204
func (c * redisCounter ) limitCounterKey (key string , window time.Time ) string {
156
205
return fmt .Sprintf ("%s:%d" , c .prefixKey , httprate .LimitCounterKey (key , window ))
157
206
}
0 commit comments