Skip to content

Commit 565c97d

Browse files
committed
Add support for context counter
1 parent 9719634 commit 565c97d

File tree

2 files changed

+61
-7
lines changed

2 files changed

+61
-7
lines changed

httprate.go

+21
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package httprate
22

33
import (
4+
"context"
45
"net"
56
"net/http"
67
"strings"
@@ -64,7 +65,27 @@ func WithLimitHandler(h http.HandlerFunc) Option {
6465
}
6566
}
6667

68+
// noContextLimitCounter wraps the context-less LimitCounter to implement the ContextLimitCounter interface.
69+
// Exists to maintain compatiblity.
70+
type noContextLimitCounter struct {
71+
LimitCounter
72+
}
73+
74+
func (l *noContextLimitCounter) Increment(_ context.Context, key string, currentWindow time.Time) error {
75+
return l.LimitCounter.Increment(key, currentWindow)
76+
}
77+
78+
func (l *noContextLimitCounter) Get(_ context.Context, key string, previousWindow, currentWindow time.Time) (int, int, error) {
79+
return l.LimitCounter.Get(key, previousWindow, currentWindow)
80+
}
81+
6782
func WithLimitCounter(c LimitCounter) Option {
83+
return func(rl *rateLimiter) {
84+
rl.limitCounter = &noContextLimitCounter{LimitCounter: c}
85+
}
86+
}
87+
88+
func WithContextLimitCounter(c ContextLimitCounter) Option {
6889
return func(rl *rateLimiter) {
6990
rl.limitCounter = c
7091
}

limiter.go

+40-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package httprate
22

33
import (
4+
"context"
45
"fmt"
56
"math"
67
"net/http"
@@ -15,6 +16,11 @@ type LimitCounter interface {
1516
Get(key string, previousWindow, currentWindow time.Time) (int, int, error)
1617
}
1718

19+
type ContextLimitCounter interface {
20+
Increment(ctx context.Context, key string, currentWindow time.Time) error
21+
Get(ctx context.Context, key string, previousWindow, currentWindow time.Time) (int, int, error)
22+
}
23+
1824
func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter {
1925
return newRateLimiter(requestLimit, windowLength, options...)
2026
}
@@ -58,24 +64,42 @@ func LimitCounterKey(key string, window time.Time) uint64 {
5864
return h.Sum64()
5965
}
6066

67+
// limitCounterWrap implements the LimitCounter interface without context.
68+
// Calls ContextLimitCounter with context.Background(), exists to maintain compatibility.
69+
type limitCounterWrap struct {
70+
ContextLimitCounter
71+
}
72+
73+
func (l *limitCounterWrap) Increment(key string, currentWindow time.Time) error {
74+
return l.ContextLimitCounter.Increment(context.Background(), key, currentWindow)
75+
}
76+
77+
func (l *limitCounterWrap) Get(key string, previousWindow, currentWindow time.Time) (int, int, error) {
78+
return l.ContextLimitCounter.Get(context.Background(), key, previousWindow, currentWindow)
79+
}
80+
6181
type rateLimiter struct {
6282
requestLimit int
6383
windowLength time.Duration
6484
keyFn KeyFunc
65-
limitCounter LimitCounter
85+
limitCounter ContextLimitCounter
6686
onRequestLimit http.HandlerFunc
6787
}
6888

6989
func (r *rateLimiter) Counter() LimitCounter {
90+
return &limitCounterWrap{ContextLimitCounter: r.limitCounter}
91+
}
92+
93+
func (r *rateLimiter) ContextCounter() ContextLimitCounter {
7094
return r.limitCounter
7195
}
7296

73-
func (r *rateLimiter) Status(key string) (bool, float64, error) {
97+
func (r *rateLimiter) ContextStatus(ctx context.Context, key string) (bool, float64, error) {
7498
t := time.Now().UTC()
7599
currentWindow := t.Truncate(r.windowLength)
76100
previousWindow := currentWindow.Add(-r.windowLength)
77101

78-
currCount, prevCount, err := r.limitCounter.Get(key, currentWindow, previousWindow)
102+
currCount, prevCount, err := r.limitCounter.Get(ctx, key, currentWindow, previousWindow)
79103
if err != nil {
80104
return false, 0, err
81105
}
@@ -89,8 +113,14 @@ func (r *rateLimiter) Status(key string) (bool, float64, error) {
89113
return true, rate, nil
90114
}
91115

116+
func (r *rateLimiter) Status(key string) (bool, float64, error) {
117+
return r.ContextStatus(context.Background(), key)
118+
}
119+
92120
func (l *rateLimiter) Handler(next http.Handler) http.Handler {
93121
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
122+
ctx := r.Context()
123+
94124
key, err := l.keyFn(r)
95125
if err != nil {
96126
http.Error(w, err.Error(), http.StatusPreconditionRequired)
@@ -120,7 +150,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
120150
return
121151
}
122152

123-
err = l.limitCounter.Increment(key, currentWindow)
153+
err = l.limitCounter.Increment(ctx, key, currentWindow)
124154
if err != nil {
125155
http.Error(w, err.Error(), http.StatusInternalServerError)
126156
return
@@ -137,14 +167,17 @@ type localCounter struct {
137167
mu sync.Mutex
138168
}
139169

140-
var _ LimitCounter = &localCounter{}
170+
var (
171+
_ LimitCounter = &limitCounterWrap{ContextLimitCounter: &localCounter{}}
172+
_ ContextLimitCounter = &localCounter{}
173+
)
141174

142175
type count struct {
143176
value int
144177
updatedAt time.Time
145178
}
146179

147-
func (c *localCounter) Increment(key string, currentWindow time.Time) error {
180+
func (c *localCounter) Increment(_ context.Context, key string, currentWindow time.Time) error {
148181
c.evict()
149182

150183
c.mu.Lock()
@@ -163,7 +196,7 @@ func (c *localCounter) Increment(key string, currentWindow time.Time) error {
163196
return nil
164197
}
165198

166-
func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) (int, int, error) {
199+
func (c *localCounter) Get(_ context.Context, key string, currentWindow, previousWindow time.Time) (int, int, error) {
167200
c.mu.Lock()
168201
defer c.mu.Unlock()
169202

0 commit comments

Comments
 (0)