1
1
package httprate
2
2
3
3
import (
4
+ "context"
4
5
"fmt"
5
6
"math"
6
7
"net/http"
@@ -15,6 +16,11 @@ type LimitCounter interface {
15
16
Get (key string , previousWindow , currentWindow time.Time ) (int , int , error )
16
17
}
17
18
19
+ type ContextLimitCounter interface {
20
+ Increment (ctx context.Context , key string , currentWindow time.Time ) error
21
+ Get (ctx context.Context , key string , previousWindow , currentWindow time.Time ) (int , int , error )
22
+ }
23
+
18
24
func NewRateLimiter (requestLimit int , windowLength time.Duration , options ... Option ) * rateLimiter {
19
25
return newRateLimiter (requestLimit , windowLength , options ... )
20
26
}
@@ -58,24 +64,42 @@ func LimitCounterKey(key string, window time.Time) uint64 {
58
64
return h .Sum64 ()
59
65
}
60
66
67
+ // limitCounterWrap implements the LimitCounter interface without context.
68
+ // Calls ContextLimitCounter with context.Background(), exists to maintain compatibility.
69
+ type limitCounterWrap struct {
70
+ ContextLimitCounter
71
+ }
72
+
73
+ func (l * limitCounterWrap ) Increment (key string , currentWindow time.Time ) error {
74
+ return l .ContextLimitCounter .Increment (context .Background (), key , currentWindow )
75
+ }
76
+
77
+ func (l * limitCounterWrap ) Get (key string , previousWindow , currentWindow time.Time ) (int , int , error ) {
78
+ return l .ContextLimitCounter .Get (context .Background (), key , previousWindow , currentWindow )
79
+ }
80
+
61
81
type rateLimiter struct {
62
82
requestLimit int
63
83
windowLength time.Duration
64
84
keyFn KeyFunc
65
- limitCounter LimitCounter
85
+ limitCounter ContextLimitCounter
66
86
onRequestLimit http.HandlerFunc
67
87
}
68
88
69
89
func (r * rateLimiter ) Counter () LimitCounter {
90
+ return & limitCounterWrap {ContextLimitCounter : r .limitCounter }
91
+ }
92
+
93
+ func (r * rateLimiter ) ContextCounter () ContextLimitCounter {
70
94
return r .limitCounter
71
95
}
72
96
73
- func (r * rateLimiter ) Status ( key string ) (bool , float64 , error ) {
97
+ func (r * rateLimiter ) ContextStatus ( ctx context. Context , key string ) (bool , float64 , error ) {
74
98
t := time .Now ().UTC ()
75
99
currentWindow := t .Truncate (r .windowLength )
76
100
previousWindow := currentWindow .Add (- r .windowLength )
77
101
78
- currCount , prevCount , err := r .limitCounter .Get (key , currentWindow , previousWindow )
102
+ currCount , prevCount , err := r .limitCounter .Get (ctx , key , currentWindow , previousWindow )
79
103
if err != nil {
80
104
return false , 0 , err
81
105
}
@@ -89,8 +113,14 @@ func (r *rateLimiter) Status(key string) (bool, float64, error) {
89
113
return true , rate , nil
90
114
}
91
115
116
+ func (r * rateLimiter ) Status (key string ) (bool , float64 , error ) {
117
+ return r .ContextStatus (context .Background (), key )
118
+ }
119
+
92
120
func (l * rateLimiter ) Handler (next http.Handler ) http.Handler {
93
121
return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
122
+ ctx := r .Context ()
123
+
94
124
key , err := l .keyFn (r )
95
125
if err != nil {
96
126
http .Error (w , err .Error (), http .StatusPreconditionRequired )
@@ -120,7 +150,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
120
150
return
121
151
}
122
152
123
- err = l .limitCounter .Increment (key , currentWindow )
153
+ err = l .limitCounter .Increment (ctx , key , currentWindow )
124
154
if err != nil {
125
155
http .Error (w , err .Error (), http .StatusInternalServerError )
126
156
return
@@ -137,14 +167,17 @@ type localCounter struct {
137
167
mu sync.Mutex
138
168
}
139
169
140
- var _ LimitCounter = & localCounter {}
170
+ var (
171
+ _ LimitCounter = & limitCounterWrap {ContextLimitCounter : & localCounter {}}
172
+ _ ContextLimitCounter = & localCounter {}
173
+ )
141
174
142
175
type count struct {
143
176
value int
144
177
updatedAt time.Time
145
178
}
146
179
147
- func (c * localCounter ) Increment (key string , currentWindow time.Time ) error {
180
+ func (c * localCounter ) Increment (_ context. Context , key string , currentWindow time.Time ) error {
148
181
c .evict ()
149
182
150
183
c .mu .Lock ()
@@ -163,7 +196,7 @@ func (c *localCounter) Increment(key string, currentWindow time.Time) error {
163
196
return nil
164
197
}
165
198
166
- func (c * localCounter ) Get (key string , currentWindow , previousWindow time.Time ) (int , int , error ) {
199
+ func (c * localCounter ) Get (_ context. Context , key string , currentWindow , previousWindow time.Time ) (int , int , error ) {
167
200
c .mu .Lock ()
168
201
defer c .mu .Unlock ()
169
202
0 commit comments