@@ -2,10 +2,14 @@ package httprateredis
2
2
3
3
import (
4
4
"context"
5
+ "errors"
5
6
"fmt"
7
+ "log"
8
+ "net"
6
9
"os"
7
10
"path/filepath"
8
11
"strconv"
12
+ "sync/atomic"
9
13
"time"
10
14
11
15
"github.com/go-chi/httprate"
@@ -16,14 +20,11 @@ func WithRedisLimitCounter(cfg *Config) httprate.Option {
16
20
if cfg .Disabled {
17
21
return httprate .WithNoop ()
18
22
}
19
- rc , err := NewRedisLimitCounter (cfg )
20
- if err != nil {
21
- panic (err )
22
- }
23
+ rc , _ := NewRedisLimitCounter (cfg )
23
24
return httprate .WithLimitCounter (rc )
24
25
}
25
26
26
- func NewRedisLimitCounter (cfg * Config ) (httprate. LimitCounter , error ) {
27
+ func NewRedisLimitCounter (cfg * Config ) (* redisCounter , error ) {
27
28
if cfg == nil {
28
29
cfg = & Config {}
29
30
}
@@ -39,20 +40,15 @@ func NewRedisLimitCounter(cfg *Config) (httprate.LimitCounter, error) {
39
40
if cfg .PrefixKey == "" {
40
41
cfg .PrefixKey = "httprate"
41
42
}
42
-
43
- c , err := newClient (cfg )
44
- if err != nil {
45
- return nil , err
43
+ if cfg .CommandTimeout == 0 {
44
+ cfg .CommandTimeout = 50 * time .Millisecond
46
45
}
47
- return & redisCounter {
48
- client : c ,
49
- prefixKey : cfg .PrefixKey ,
50
- }, nil
51
- }
52
46
53
- func newClient (cfg * Config ) (* redis.Client , error ) {
54
- if cfg .Client != nil {
55
- return cfg .Client , nil
47
+ rc := & redisCounter {
48
+ prefixKey : cfg .PrefixKey ,
49
+ }
50
+ if ! cfg .FallbackDisabled {
51
+ rc .fallbackCounter = httprate .NewLocalLimitCounter (cfg .WindowLength )
56
52
}
57
53
58
54
var maxIdle , maxActive = cfg .MaxIdle , cfg .MaxActive
@@ -64,80 +60,118 @@ func newClient(cfg *Config) (*redis.Client, error) {
64
60
}
65
61
66
62
address := fmt .Sprintf ("%s:%d" , cfg .Host , cfg .Port )
67
- c : = redis .NewClient (& redis.Options {
63
+ rc . client = redis .NewClient (& redis.Options {
68
64
Addr : address ,
69
65
Password : cfg .Password ,
70
66
DB : cfg .DBIndex ,
71
67
PoolSize : maxActive ,
72
68
MaxIdleConns : maxIdle ,
73
69
ClientName : cfg .ClientName ,
74
- })
75
70
76
- status := c .Ping (context .Background ())
77
- if status == nil || status .Err () != nil {
78
- return nil , fmt .Errorf ("httprateredis: unable to dial redis host %v" , address )
79
- }
71
+ DialTimeout : cfg .CommandTimeout ,
72
+ ReadTimeout : cfg .CommandTimeout ,
73
+ WriteTimeout : cfg .CommandTimeout ,
74
+ OnConnect : func (ctx context.Context , cn * redis.Conn ) error {
75
+ log .Println ("redis UP" )
76
+ rc .isRedisDown .Store (false )
77
+ return nil
78
+ },
79
+ MinIdleConns : 1 ,
80
+ MaxRetries : 5 ,
81
+ //MaxRetryBackoff: time.Duration,
82
+ })
80
83
81
- return c , nil
84
+ return rc , nil
82
85
}
83
86
84
87
type redisCounter struct {
85
- client * redis.Client
86
- windowLength time.Duration
87
- prefixKey string
88
+ client * redis.Client
89
+ windowLength time.Duration
90
+ prefixKey string
91
+ isRedisDown atomic.Bool
92
+ fallbackCounter httprate.LimitCounter
88
93
}
89
94
90
- var _ httprate.LimitCounter = & redisCounter {}
95
+ var _ httprate.LimitCounter = ( * redisCounter )( nil )
91
96
92
97
func (c * redisCounter ) Config (requestLimit int , windowLength time.Duration ) {
93
98
c .windowLength = windowLength
99
+ if c .fallbackCounter != nil {
100
+ c .fallbackCounter .Config (requestLimit , windowLength )
101
+ }
94
102
}
95
103
96
104
func (c * redisCounter ) Increment (key string , currentWindow time.Time ) error {
97
105
return c .IncrementBy (key , currentWindow , 1 )
98
106
}
99
107
100
- func (c * redisCounter ) IncrementBy (key string , currentWindow time.Time , amount int ) error {
101
- ctx := context .Background ()
102
- conn := c .client
103
-
108
+ func (c * redisCounter ) IncrementBy (key string , currentWindow time.Time , amount int ) (err error ) {
109
+ if c .fallbackCounter != nil {
110
+ if c .isRedisDown .Load () {
111
+ return c .fallbackCounter .IncrementBy (key , currentWindow , amount )
112
+ }
113
+ defer func () {
114
+ if err != nil {
115
+ // On redis network error, fallback to local in-memory counter.
116
+ var netErr net.Error
117
+ if errors .As (err , & netErr ) || errors .Is (err , redis .ErrClosed ) {
118
+ go c .fallback ()
119
+ err = c .fallbackCounter .IncrementBy (key , currentWindow , amount )
120
+ }
121
+ }
122
+ }()
123
+ }
124
+
125
+ ctx := context .Background () // Note: We use timeouts set up on the Redis client directly.
104
126
hkey := c .limitCounterKey (key , currentWindow )
105
127
106
- pipe := conn .TxPipeline ()
128
+ pipe := c . client .TxPipeline ()
107
129
incrCmd := pipe .IncrBy (ctx , hkey , int64 (amount ))
108
130
expireCmd := pipe .Expire (ctx , hkey , c .windowLength * 3 )
109
- _ , err := pipe .Exec (ctx )
131
+
132
+ _ , err = pipe .Exec (ctx )
110
133
if err != nil {
111
134
return fmt .Errorf ("httprateredis: redis transaction failed: %w" , err )
112
135
}
113
-
114
136
if err := incrCmd .Err (); err != nil {
115
137
return fmt .Errorf ("httprateredis: redis incr failed: %w" , err )
116
138
}
117
-
118
139
if err := expireCmd .Err (); err != nil {
119
140
return fmt .Errorf ("httprateredis: redis expire failed: %w" , err )
120
141
}
121
142
122
143
return nil
123
144
}
124
145
125
- func (c * redisCounter ) Get (key string , currentWindow , previousWindow time.Time ) (int , int , error ) {
126
- ctx := context .Background ()
127
- conn := c .client
146
+ func (c * redisCounter ) Get (key string , currentWindow , previousWindow time.Time ) (curr int , prev int , err error ) {
147
+ if c .fallbackCounter != nil {
148
+ if c .isRedisDown .Load () {
149
+ return c .fallbackCounter .Get (key , currentWindow , previousWindow )
150
+ }
151
+ defer func () {
152
+ if err != nil {
153
+ // On redis network error, fallback to local in-memory counter.
154
+ var netErr net.Error
155
+ if errors .As (err , & netErr ) || errors .Is (err , redis .ErrClosed ) {
156
+ go c .fallback ()
157
+ curr , prev , err = c .fallbackCounter .Get (key , currentWindow , previousWindow )
158
+ }
159
+ }
160
+ }()
161
+ }
162
+
163
+ ctx := context .Background () // Note: We use timeouts set up on the Redis client directly.
128
164
129
165
currKey := c .limitCounterKey (key , currentWindow )
130
166
prevKey := c .limitCounterKey (key , previousWindow )
131
167
132
- values , err := conn .MGet (ctx , currKey , prevKey ).Result ()
168
+ values , err := c . client .MGet (ctx , currKey , prevKey ).Result ()
133
169
if err != nil {
134
170
return 0 , 0 , fmt .Errorf ("httprateredis: redis mget failed: %w" , err )
135
171
} else if len (values ) != 2 {
136
172
return 0 , 0 , fmt .Errorf ("httprateredis: redis mget returned wrong number of keys: %v, expected 2" , len (values ))
137
173
}
138
174
139
- var curr , prev int
140
-
141
175
// MGET always returns slice with nil or "string" values, even if the values
142
176
// were created with the INCR command. Ignore error if we can't parse the number.
143
177
if values [0 ] != nil {
@@ -152,6 +186,31 @@ func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time)
152
186
return curr , prev , nil
153
187
}
154
188
189
+ func (c * redisCounter ) IsRedisDown () bool {
190
+ return c .isRedisDown .Load ()
191
+ }
192
+
193
+ func (c * redisCounter ) fallback () {
194
+ log .Println ("redis DOWN" )
195
+ // Fallback to in-memory counter.
196
+ wasAlreadyDown := c .isRedisDown .Swap (true )
197
+ if wasAlreadyDown {
198
+ return
199
+ }
200
+
201
+ // Try to re-connect to redis.
202
+ for {
203
+ log .Println ("redis PING..." )
204
+ err := c .client .Ping (context .Background ()).Err ()
205
+ if err == nil {
206
+ c .isRedisDown .Store (false )
207
+ return
208
+ }
209
+ //time.Sleep(10 * time.Millisecond)
210
+ time .Sleep (time .Second )
211
+ }
212
+ }
213
+
155
214
func (c * redisCounter ) limitCounterKey (key string , window time.Time ) string {
156
215
return fmt .Sprintf ("%s:%d" , c .prefixKey , httprate .LimitCounterKey (key , window ))
157
216
}
0 commit comments