5
5
"fmt"
6
6
"os"
7
7
"path/filepath"
8
+ "strconv"
8
9
"time"
9
10
10
11
"github.com/go-chi/httprate"
@@ -102,20 +103,20 @@ func (c *redisCounter) IncrementBy(key string, currentWindow time.Time, amount i
102
103
103
104
hkey := c .limitCounterKey (key , currentWindow )
104
105
105
- cmd := conn .Do ( ctx , "INCRBY" , hkey , amount )
106
- if cmd == nil {
107
- return fmt . Errorf ( "httprateredis: redis incr failed" )
108
- }
109
- if err := cmd . Err (); err != nil {
110
- return err
106
+ pipe := conn .TxPipeline ( )
107
+ incrCmd := pipe . IncrBy ( ctx , hkey , int64 ( amount ))
108
+ expireCmd := pipe . Expire ( ctx , hkey , c . windowLength * 3 )
109
+ _ , err := pipe . Exec ( ctx )
110
+ if err != nil {
111
+ return fmt . Errorf ( "httprateredis: redis transaction failed: %w" , err )
111
112
}
112
113
113
- cmd = conn .Do (ctx , "EXPIRE" , hkey , c .windowLength .Seconds ()* 3 )
114
- if cmd == nil {
115
- return fmt .Errorf ("httprateredis: redis expire failed" )
114
+ if err := incrCmd .Err (); err != nil {
115
+ return fmt .Errorf ("httprateredis: redis incr failed: %w" , err )
116
116
}
117
- if err := cmd .Err (); err != nil {
118
- return err
117
+
118
+ if err := expireCmd .Err (); err != nil {
119
+ return fmt .Errorf ("httprateredis: redis expire failed: %w" , err )
119
120
}
120
121
121
122
return nil
@@ -125,32 +126,27 @@ func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time)
125
126
ctx := context .Background ()
126
127
conn := c .client
127
128
128
- cmd := conn .Do (ctx , "GET" , c .limitCounterKey (key , currentWindow ))
129
- if cmd == nil {
130
- return 0 , 0 , fmt .Errorf ("httprateredis: redis get curr failed" )
131
- }
132
- if err := cmd .Err (); err != nil && err != redis .Nil {
133
- return 0 , 0 , fmt .Errorf ("httprateredis: redis get curr failed: %w" , err )
134
- }
129
+ currKey := c .limitCounterKey (key , currentWindow )
130
+ prevKey := c .limitCounterKey (key , previousWindow )
135
131
136
- curr , err := cmd .Int ()
137
- if err != nil && err != redis .Nil {
138
- return 0 , 0 , fmt .Errorf ("httprateredis: redis int curr value: %w" , err )
132
+ values , err := conn .MGet (ctx , currKey , prevKey ).Result ()
133
+ if err != nil {
134
+ return 0 , 0 , fmt .Errorf ("httprateredis: redis mget failed: %w" , err )
135
+ } else if len (values ) != 2 {
136
+ return 0 , 0 , fmt .Errorf ("httprateredis: redis mget returned wrong number of keys: %v, expected 2" , len (values ))
139
137
}
140
138
141
- cmd = conn .Do (ctx , "GET" , c .limitCounterKey (key , previousWindow ))
142
- if cmd == nil {
143
- return 0 , 0 , fmt .Errorf ("httprateredis: redis get prev failed" )
144
- }
139
+ var curr , prev int
145
140
146
- if err := cmd .Err (); err != nil && err != redis .Nil {
147
- return 0 , 0 , fmt .Errorf ("httprateredis: redis get prev failed: %w" , err )
141
+ // MGET always returns slice with nil or "string" values, even if the values
142
+ // were created with the INCR command. Ignore error if we can't parse the number.
143
+ if values [0 ] != nil {
144
+ v , _ := values [0 ].(string )
145
+ curr , _ = strconv .Atoi (v )
148
146
}
149
-
150
- var prev int
151
- prev , err = cmd .Int ()
152
- if err != nil && err != redis .Nil {
153
- return 0 , 0 , fmt .Errorf ("httprateredis: redis int prev value: %w" , err )
147
+ if values [1 ] != nil {
148
+ v , _ := values [1 ].(string )
149
+ prev , _ = strconv .Atoi (v )
154
150
}
155
151
156
152
return curr , prev , nil
0 commit comments