Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OnError and OnFallbackChange callbacks #18

Merged
merged 1 commit into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ type Config struct {
ClientName string `toml:"client_name"` // default: os.Args[0]
PrefixKey string `toml:"prefix_key"` // default: "httprate"

// OnError lets you subscribe to all runtime Redis errors. Useful for logging/debugging.
OnError func(err error)

// Disable the use of the local in-memory fallback mechanism. When enabled,
// the system will return HTTP 428 for all requests when Redis is down.
FallbackDisabled bool `toml:"fallback_disabled"` // default: false
Expand All @@ -22,6 +25,9 @@ type Config struct {
// the system will use the local counter unless it is explicitly disabled.
FallbackTimeout time.Duration `toml:"fallback_timeout"` // default: 100ms

// OnFallbackChange lets subscribe to local in-memory fallback changes.
OnFallbackChange func(activated bool)

// Client if supplied will be used and the below fields will be ignored.
//
// NOTE: It's recommended to set short dial/read/write timeouts and disable
Expand Down
17 changes: 16 additions & 1 deletion httprateredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,18 @@ func NewRedisLimitCounter(cfg *Config) (*redisCounter, error) {
}

rc := &redisCounter{
prefixKey: cfg.PrefixKey,
prefixKey: cfg.PrefixKey,
onError: func(err error) {},
onFallback: func(activated bool) {},
}
if cfg.OnError != nil {
rc.onError = cfg.OnError
}
if !cfg.FallbackDisabled {
rc.fallbackCounter = httprate.NewLocalLimitCounter(cfg.WindowLength)
if cfg.OnFallbackChange != nil {
rc.onFallback = cfg.OnFallbackChange
}
}

if cfg.Client == nil {
Expand Down Expand Up @@ -89,6 +97,8 @@ type redisCounter struct {
prefixKey string
fallbackActivated atomic.Bool
fallbackCounter httprate.LimitCounter
onError func(err error)
onFallback func(activated bool)
}

var _ httprate.LimitCounter = (*redisCounter)(nil)
Expand Down Expand Up @@ -190,10 +200,12 @@ func (c *redisCounter) shouldFallback(err error) bool {
if err == nil {
return false
}
c.onError(err)

// Activate the local in-memory counter fallback, unless activated by some other goroutine.
alreadyActivated := c.fallbackActivated.Swap(true)
if !alreadyActivated {
c.onFallback(true)
go c.reconnect()
}

Expand All @@ -208,6 +220,9 @@ func (c *redisCounter) reconnect() {
err := c.client.Ping(context.Background()).Err()
if err == nil {
c.fallbackActivated.Store(false)
if c.onFallback != nil {
c.onFallback(false)
}
return
}
}
Expand Down
38 changes: 31 additions & 7 deletions local_fallback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@ func TestLocalFallback(t *testing.T) {
redis, err := miniredis.Run()
redisPort, _ := strconv.Atoi(redis.Port())

var onErrorCalled bool
var onFallbackCalled bool

limitCounter, err := httprateredis.NewRedisLimitCounter(&httprateredis.Config{
Host: redis.Host(),
Port: uint16(redisPort),
MaxIdle: 0,
MaxActive: 1,
ClientName: "httprateredis_test",
PrefixKey: fmt.Sprintf("httprate:test:%v", rand.Int31n(100000)), // Unique Redis key for each test
FallbackTimeout: 200 * time.Millisecond,
Host: redis.Host(),
Port: uint16(redisPort),
MaxIdle: 0,
MaxActive: 1,
ClientName: "httprateredis_test",
PrefixKey: fmt.Sprintf("httprate:test:%v", rand.Int31n(100000)), // Unique Redis key for each test
FallbackTimeout: 200 * time.Millisecond,
OnError: func(err error) { onErrorCalled = true },
OnFallbackChange: func(fallbackActivated bool) { onFallbackCalled = true },
})
if err != nil {
t.Fatalf("redis not available: %v", err)
Expand All @@ -37,6 +42,12 @@ func TestLocalFallback(t *testing.T) {
if limitCounter.IsFallbackActivated() {
t.Error("fallback should not be activated at the beginning")
}
if onErrorCalled {
t.Error("onError() should not be called at the beginning")
}
if onFallbackCalled {
t.Error("onFallback() should not be called before we simulate redis failure")
}

err = limitCounter.IncrementBy("key:fallback", currentWindow, 1)
if err != nil {
Expand All @@ -51,6 +62,12 @@ func TestLocalFallback(t *testing.T) {
if limitCounter.IsFallbackActivated() {
t.Error("fallback should not be activated before we simulate redis failure")
}
if onErrorCalled {
t.Error("onError() should not be called before we simulate redis failure")
}
if onFallbackCalled {
t.Error("onFallback() should not be called before we simulate redis failure")
}

redis.Close()

Expand All @@ -67,4 +84,11 @@ func TestLocalFallback(t *testing.T) {
if !limitCounter.IsFallbackActivated() {
t.Error("fallback should be activated after we simulate redis failure")
}
if !onErrorCalled {
t.Error("onError() should be called after we simulate redis failure")
}
if !onFallbackCalled {
t.Error("onFallback() should be called after we simulate redis failure")
}

}
Loading