|
5 | 5 | "fmt"
|
6 | 6 | "os"
|
7 | 7 | "path/filepath"
|
| 8 | + "strconv" |
8 | 9 | "time"
|
9 | 10 |
|
10 | 11 | "github.com/go-chi/httprate"
|
@@ -125,32 +126,27 @@ func (c *redisCounter) Get(key string, currentWindow, previousWindow time.Time)
|
125 | 126 | ctx := context.Background()
|
126 | 127 | conn := c.client
|
127 | 128 |
|
128 |
| - cmd := conn.Do(ctx, "GET", c.limitCounterKey(key, currentWindow)) |
129 |
| - if cmd == nil { |
130 |
| - return 0, 0, fmt.Errorf("httprateredis: redis get curr failed") |
131 |
| - } |
132 |
| - if err := cmd.Err(); err != nil && err != redis.Nil { |
133 |
| - return 0, 0, fmt.Errorf("httprateredis: redis get curr failed: %w", err) |
134 |
| - } |
| 129 | + currKey := c.limitCounterKey(key, currentWindow) |
| 130 | + prevKey := c.limitCounterKey(key, previousWindow) |
135 | 131 |
|
136 |
| - curr, err := cmd.Int() |
137 |
| - if err != nil && err != redis.Nil { |
138 |
| - return 0, 0, fmt.Errorf("httprateredis: redis int curr value: %w", err) |
| 132 | + values, err := conn.MGet(ctx, currKey, prevKey).Result() |
| 133 | + if err != nil { |
| 134 | + return 0, 0, fmt.Errorf("httprateredis: redis mget failed: %w", err) |
| 135 | + } else if len(values) != 2 { |
| 136 | + return 0, 0, fmt.Errorf("httprateredis: redis mget returned wrong number of keys: %v, expected 2", len(values)) |
139 | 137 | }
|
140 | 138 |
|
141 |
| - cmd = conn.Do(ctx, "GET", c.limitCounterKey(key, previousWindow)) |
142 |
| - if cmd == nil { |
143 |
| - return 0, 0, fmt.Errorf("httprateredis: redis get prev failed") |
144 |
| - } |
| 139 | + var curr, prev int |
145 | 140 |
|
146 |
| - if err := cmd.Err(); err != nil && err != redis.Nil { |
147 |
| - return 0, 0, fmt.Errorf("httprateredis: redis get prev failed: %w", err) |
| 141 | + // MGET always returns slice with nil or "string" values, even if the values |
| 142 | + // were created with the INCR command. Ignore error if we can't parse the number. |
| 143 | + if values[0] != nil { |
| 144 | + v, _ := values[0].(string) |
| 145 | + curr, _ = strconv.Atoi(v) |
148 | 146 | }
|
149 |
| - |
150 |
| - var prev int |
151 |
| - prev, err = cmd.Int() |
152 |
| - if err != nil && err != redis.Nil { |
153 |
| - return 0, 0, fmt.Errorf("httprateredis: redis int prev value: %w", err) |
| 147 | + if values[1] != nil { |
| 148 | + v, _ := values[1].(string) |
| 149 | + prev, _ = strconv.Atoi(v) |
154 | 150 | }
|
155 | 151 |
|
156 | 152 | return curr, prev, nil
|
|
0 commit comments