Skip to content

Commit 61daf03

Browse files
authored
Implement OnError and OnFallbackChange callbacks (#18)
1 parent 8812af7 commit 61daf03

File tree

3 files changed

+53
-8
lines changed

3 files changed

+53
-8
lines changed

config.go

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ type Config struct {
1313
ClientName string `toml:"client_name"` // default: os.Args[0]
1414
PrefixKey string `toml:"prefix_key"` // default: "httprate"
1515

16+
// OnError lets you subscribe to all runtime Redis errors. Useful for logging/debugging.
17+
OnError func(err error)
18+
1619
// Disable the use of the local in-memory fallback mechanism. When enabled,
1720
// the system will return HTTP 428 for all requests when Redis is down.
1821
FallbackDisabled bool `toml:"fallback_disabled"` // default: false
@@ -22,6 +25,9 @@ type Config struct {
2225
// the system will use the local counter unless it is explicitly disabled.
2326
FallbackTimeout time.Duration `toml:"fallback_timeout"` // default: 100ms
2427

28+
// OnFallbackChange lets subscribe to local in-memory fallback changes.
29+
OnFallbackChange func(activated bool)
30+
2531
// Client if supplied will be used and the below fields will be ignored.
2632
//
2733
// NOTE: It's recommended to set short dial/read/write timeouts and disable

httprateredis.go

+16-1
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,18 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
4848
}
4949

5050
rc := &redisCounter{
51-
prefixKey: cfg.PrefixKey,
51+
prefixKey: cfg.PrefixKey,
52+
onError: func(err error) {},
53+
onFallback: func(activated bool) {},
54+
}
55+
if cfg.OnError != nil {
56+
rc.onError = cfg.OnError
5257
}
5358
if !cfg.FallbackDisabled {
5459
rc.fallbackCounter = httprate.NewLocalLimitCounter(cfg.WindowLength)
60+
if cfg.OnFallbackChange != nil {
61+
rc.onFallback = cfg.OnFallbackChange
62+
}
5563
}
5664

5765
if cfg.Client == nil {
@@ -89,6 +97,8 @@ type redisCounter struct {
8997
prefixKey string
9098
fallbackActivated atomic.Bool
9199
fallbackCounter httprate.LimitCounter
100+
onError func(err error)
101+
onFallback func(activated bool)
92102
}
93103

94104
var _ httprate.LimitCounter = (*redisCounter)(nil)
@@ -190,10 +200,12 @@ func (c *redisCounter) shouldFallback(err error) bool {
190200
if err == nil {
191201
return false
192202
}
203+
c.onError(err)
193204

194205
// Activate the local in-memory counter fallback, unless activated by some other goroutine.
195206
alreadyActivated := c.fallbackActivated.Swap(true)
196207
if !alreadyActivated {
208+
c.onFallback(true)
197209
go c.reconnect()
198210
}
199211

@@ -208,6 +220,9 @@ func (c *redisCounter) reconnect() {
208220
err := c.client.Ping(context.Background()).Err()
209221
if err == nil {
210222
c.fallbackActivated.Store(false)
223+
if c.onFallback != nil {
224+
c.onFallback(false)
225+
}
211226
return
212227
}
213228
}

local_fallback_test.go

+31-7
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@ func TestLocalFallback(t *testing.T) {
1616
redis, err := miniredis.Run()
1717
redisPort, _ := strconv.Atoi(redis.Port())
1818

19+
var onErrorCalled bool
20+
var onFallbackCalled bool
21+
1922
limitCounter, err := httprateredis.NewRedisLimitCounter(&httprateredis.Config{
20-
Host: redis.Host(),
21-
Port: uint16(redisPort),
22-
MaxIdle: 0,
23-
MaxActive: 1,
24-
ClientName: "httprateredis_test",
25-
PrefixKey: fmt.Sprintf("httprate:test:%v", rand.Int31n(100000)), // Unique Redis key for each test
26-
FallbackTimeout: 200 * time.Millisecond,
23+
Host: redis.Host(),
24+
Port: uint16(redisPort),
25+
MaxIdle: 0,
26+
MaxActive: 1,
27+
ClientName: "httprateredis_test",
28+
PrefixKey: fmt.Sprintf("httprate:test:%v", rand.Int31n(100000)), // Unique Redis key for each test
29+
FallbackTimeout: 200 * time.Millisecond,
30+
OnError: func(err error) { onErrorCalled = true },
31+
OnFallbackChange: func(fallbackActivated bool) { onFallbackCalled = true },
2732
})
2833
if err != nil {
2934
t.Fatalf("redis not available: %v", err)
@@ -37,6 +42,12 @@ func TestLocalFallback(t *testing.T) {
3742
if limitCounter.IsFallbackActivated() {
3843
t.Error("fallback should not be activated at the beginning")
3944
}
45+
if onErrorCalled {
46+
t.Error("onError() should not be called at the beginning")
47+
}
48+
if onFallbackCalled {
49+
t.Error("onFallback() should not be called before we simulate redis failure")
50+
}
4051

4152
err = limitCounter.IncrementBy("key:fallback", currentWindow, 1)
4253
if err != nil {
@@ -51,6 +62,12 @@ func TestLocalFallback(t *testing.T) {
5162
if limitCounter.IsFallbackActivated() {
5263
t.Error("fallback should not be activated before we simulate redis failure")
5364
}
65+
if onErrorCalled {
66+
t.Error("onError() should not be called before we simulate redis failure")
67+
}
68+
if onFallbackCalled {
69+
t.Error("onFallback() should not be called before we simulate redis failure")
70+
}
5471

5572
redis.Close()
5673

@@ -67,4 +84,11 @@ func TestLocalFallback(t *testing.T) {
6784
if !limitCounter.IsFallbackActivated() {
6885
t.Error("fallback should be activated after we simulate redis failure")
6986
}
87+
if !onErrorCalled {
88+
t.Error("onError() should be called after we simulate redis failure")
89+
}
90+
if !onFallbackCalled {
91+
t.Error("onFallback() should be called after we simulate redis failure")
92+
}
93+
7094
}

0 commit comments

Comments
 (0)