Skip to content

Commit 0b64165

Browse files
authored
Implement Redis backend for cache (#54)
1 parent 0880481 commit 0b64165

File tree

8 files changed

+277
-61
lines changed

8 files changed

+277
-61
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ Available flags:
5858
- `-verbose`
5959
- Allows to put the app into verbose mode and print out additional logs to stdout
6060
- Default value: none, no additional output is produced
61+
- `-redis`
62+
- Specifies Redis server's address
63+
- Type: string
64+
- Default value: none
6165

6266
You can use them like this:
6367
```bash

cache/cache.go

Lines changed: 132 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package cache
22

33
import (
4+
"context"
5+
"encoding/json"
46
"strings"
57
"time"
68

79
goCache "github.com/patrickmn/go-cache"
10+
"github.com/redis/go-redis/v9"
811
"github.com/spf13/viper"
9-
messagebus "github.com/vardius/message-bus"
1012
"golang.org/x/exp/maps"
1113

1214
"bdo-rest-api/models"
@@ -19,60 +21,65 @@ type CacheEntry[T any] struct {
1921
Status int `json:"status"`
2022
}
2123

22-
type cache[T any] struct {
23-
Bus messagebus.MessageBus
24-
internalCache *goCache.Cache
24+
type Cache[T any] interface {
25+
AddRecord(keys []string, data T, status int, taskId string) (date string, expires string)
26+
GetRecord(keys []string) (data T, status int, date string, expires string, found bool)
27+
GetItemCount() int
28+
GetKeys() []string
29+
GetValues() []CacheEntry[T]
2530
}
2631

2732
func joinKeys(keys []string) string {
2833
return strings.Join(keys, ",")
2934
}
3035

31-
func newCache[T any]() *cache[T] {
32-
cacheTTL := viper.GetDuration("cachettl")
36+
type memoryCache[T any] struct {
37+
internalCache *goCache.Cache
38+
ttl time.Duration
39+
}
40+
41+
func newMemoryCache[T any]() *memoryCache[T] {
42+
ttl := viper.GetDuration("cachettl")
3343

34-
return &cache[T]{
35-
Bus: messagebus.New(100), // Idk what buffer size is optimal
36-
internalCache: goCache.New(cacheTTL, min(time.Hour, cacheTTL)),
44+
return &memoryCache[T]{
45+
internalCache: goCache.New(ttl, min(time.Hour, ttl)),
46+
ttl: ttl,
3747
}
3848
}
3949

40-
func (c *cache[T]) AddRecord(keys []string, data T, status int, taskId string) (date string, expires string) {
41-
cacheTTL := viper.GetDuration("cachettl")
50+
func (c *memoryCache[T]) AddRecord(keys []string, data T, status int, taskId string) (date, expires string) {
4251
entry := CacheEntry[T]{
4352
Data: data,
4453
Date: time.Now(),
4554
Status: status,
4655
}
4756

48-
c.internalCache.Add(joinKeys(keys), entry, cacheTTL)
49-
c.Bus.Publish(taskId, entry)
57+
c.internalCache.Add(joinKeys(keys), entry, c.ttl)
5058

51-
return utils.FormatDateForHeaders(entry.Date), utils.FormatDateForHeaders(entry.Date.Add(cacheTTL))
59+
return utils.FormatDateForHeaders(entry.Date), utils.FormatDateForHeaders(entry.Date.Add(c.ttl))
5260
}
5361

54-
func (c *cache[T]) GetRecord(keys []string) (data T, status int, date string, expires string, found bool) {
55-
cacheTTL := viper.GetDuration("cachettl")
56-
anyEntry, found := c.internalCache.Get(joinKeys(keys))
62+
func (c *memoryCache[T]) GetRecord(keys []string) (data T, status int, date, expires string, found bool) {
63+
anyEntry, exp, found := c.internalCache.GetWithExpiration(joinKeys(keys))
5764

5865
if !found {
5966
return
6067
}
6168

6269
entry := anyEntry.(CacheEntry[T])
6370

64-
return entry.Data, entry.Status, utils.FormatDateForHeaders(entry.Date), utils.FormatDateForHeaders(entry.Date.Add(cacheTTL)), found
71+
return entry.Data, entry.Status, utils.FormatDateForHeaders(entry.Date), utils.FormatDateForHeaders(exp), true
6572
}
6673

67-
func (c *cache[T]) GetItemCount() int {
74+
func (c *memoryCache[T]) GetItemCount() int {
6875
return c.internalCache.ItemCount()
6976
}
7077

71-
func (c *cache[T]) GetKeys() []string {
78+
func (c *memoryCache[T]) GetKeys() []string {
7279
return maps.Keys(c.internalCache.Items())
7380
}
7481

75-
func (c *cache[T]) GetValues() []CacheEntry[T] {
82+
func (c *memoryCache[T]) GetValues() []CacheEntry[T] {
7683
items := c.internalCache.Items()
7784
result := make([]CacheEntry[T], 0, len(items))
7885

@@ -83,7 +90,107 @@ func (c *cache[T]) GetValues() []CacheEntry[T] {
8390
return result
8491
}
8592

86-
var GuildProfiles = newCache[models.GuildProfile]()
87-
var GuildSearch = newCache[[]models.GuildProfile]()
88-
var Profiles = newCache[models.Profile]()
89-
var ProfileSearch = newCache[[]models.Profile]()
93+
type redisCache[T any] struct {
94+
client *redis.Client
95+
ctx context.Context
96+
namespace string
97+
ttl time.Duration
98+
}
99+
100+
func newRedisCache[T any](client *redis.Client, namespace string) *redisCache[T] {
101+
return &redisCache[T]{
102+
client: client,
103+
ctx: context.Background(),
104+
namespace: namespace + ":",
105+
ttl: viper.GetDuration("cachettl"),
106+
}
107+
}
108+
109+
func (c *redisCache[T]) AddRecord(keys []string, data T, status int, taskId string) (date, expires string) {
110+
entry := CacheEntry[T]{
111+
Data: data,
112+
Date: time.Now(),
113+
Status: status,
114+
}
115+
116+
b, _ := json.Marshal(entry)
117+
c.client.Set(c.ctx, c.namespace+joinKeys(keys), b, c.ttl)
118+
119+
return utils.FormatDateForHeaders(entry.Date), utils.FormatDateForHeaders(entry.Date.Add(c.ttl))
120+
}
121+
122+
func (c *redisCache[T]) GetRecord(keys []string) (data T, status int, date string, expires string, found bool) {
123+
val, err := c.client.Get(c.ctx, c.namespace+joinKeys(keys)).Bytes()
124+
if err != nil {
125+
return
126+
}
127+
128+
var entry CacheEntry[T]
129+
if err := json.Unmarshal(val, &entry); err != nil {
130+
return
131+
}
132+
133+
ttl := c.client.TTL(c.ctx, c.namespace+joinKeys(keys)).Val()
134+
135+
return entry.Data, entry.Status, utils.FormatDateForHeaders(entry.Date), utils.FormatDateForHeaders(time.Now().Add(ttl)), true
136+
}
137+
138+
func (c *redisCache[T]) GetItemCount() int {
139+
keys, err := c.client.Keys(c.ctx, c.namespace+"*").Result()
140+
if err != nil {
141+
return 0
142+
}
143+
return len(keys)
144+
}
145+
146+
func (c *redisCache[T]) GetKeys() []string {
147+
keys, _ := c.client.Keys(c.ctx, c.namespace+"*").Result()
148+
149+
// Remove namespace from keys
150+
for i, k := range keys {
151+
keys[i] = strings.TrimPrefix(k, c.namespace)
152+
}
153+
154+
return keys
155+
}
156+
157+
func (c *redisCache[T]) GetValues() []CacheEntry[T] {
158+
keys, _ := c.client.Keys(c.ctx, c.namespace+"*").Result()
159+
result := make([]CacheEntry[T], 0, len(keys))
160+
161+
for _, k := range keys {
162+
val, err := c.client.Get(c.ctx, k).Bytes()
163+
if err != nil {
164+
continue
165+
}
166+
167+
var entry CacheEntry[T]
168+
if err := json.Unmarshal(val, &entry); err != nil {
169+
continue
170+
}
171+
result = append(result, entry)
172+
}
173+
174+
return result
175+
}
176+
177+
var (
178+
GuildProfiles Cache[models.GuildProfile]
179+
GuildSearch Cache[[]models.GuildProfile]
180+
Profiles Cache[models.Profile]
181+
ProfileSearch Cache[[]models.Profile]
182+
)
183+
184+
func InitCache() {
185+
if redisClient, err := newRedisClient(viper.GetString("redis")); err == nil {
186+
GuildProfiles = newRedisCache[models.GuildProfile](redisClient, "gpc")
187+
GuildSearch = newRedisCache[[]models.GuildProfile](redisClient, "gsc")
188+
Profiles = newRedisCache[models.Profile](redisClient, "pc")
189+
ProfileSearch = newRedisCache[[]models.Profile](redisClient, "psc")
190+
} else {
191+
GuildProfiles = newMemoryCache[models.GuildProfile]()
192+
GuildSearch = newMemoryCache[[]models.GuildProfile]()
193+
Profiles = newMemoryCache[models.Profile]()
194+
ProfileSearch = newMemoryCache[[]models.Profile]()
195+
}
196+
}

cache/cache_test.go

Lines changed: 98 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,58 +7,124 @@ import (
77
"github.com/spf13/viper"
88
)
99

10+
// Simple helper type for testing
11+
type testStruct struct {
12+
Value string
13+
}
14+
1015
func init() {
11-
viper.Set("cachettl", time.Second)
16+
// Ensure TTL is something predictable
17+
viper.Set("cachettl", time.Second*10)
18+
}
19+
20+
func TestJoinKeys(t *testing.T) {
21+
keys := []string{"a", "b", "c"}
22+
expected := "a,b,c"
23+
24+
if got := joinKeys(keys); got != expected {
25+
t.Fatalf("joinKeys() = %s; want %s", got, expected)
26+
}
1227
}
1328

14-
func TestCache(t *testing.T) {
15-
// Create a cache instance for testing
16-
testCache := newCache[string]()
29+
func TestMemoryCacheAddAndGetRecord(t *testing.T) {
30+
c := newMemoryCache[testStruct]()
1731

18-
// Test AddRecord and GetRecord
32+
data := testStruct{Value: "hello"}
1933
keys := []string{"key1", "key2"}
20-
data := "test data"
2134
status := 200
22-
taskId := "task-id"
2335

24-
date, expires := testCache.AddRecord(keys, data, status, taskId)
36+
_, _ = c.AddRecord(keys, data, status, "task123")
37+
38+
gotData, gotStatus, _, _, found := c.GetRecord(keys)
39+
if !found {
40+
t.Fatal("Expected record to be found, got not found")
41+
}
42+
43+
if gotData.Value != "hello" {
44+
t.Fatalf("Expected data 'hello', got '%s'", gotData.Value)
45+
}
2546

26-
// Validate AddRecord results
27-
if date == "" || expires == "" {
28-
t.Error("AddRecord should return non-empty date and expires values")
47+
if gotStatus != status {
48+
t.Fatalf("Expected status %d, got %d", status, gotStatus)
2949
}
50+
}
3051

31-
// Test GetRecord for an existing record
32-
returnedData, returnedStatus, returnedDate, returnedExpires, found := testCache.GetRecord(keys)
52+
func TestMemoryCacheMissingRecord(t *testing.T) {
53+
c := newMemoryCache[testStruct]()
3354

34-
if !found {
35-
t.Error("GetRecord should find the record")
55+
_, _, _, _, found := c.GetRecord([]string{"does", "not", "exist"})
56+
if found {
57+
t.Fatal("Expected record NOT to be found")
58+
}
59+
}
60+
61+
func TestMemoryCacheItemCount(t *testing.T) {
62+
c := newMemoryCache[testStruct]()
63+
64+
if c.GetItemCount() != 0 {
65+
t.Fatal("Expected empty cache")
66+
}
67+
68+
c.AddRecord([]string{"a"}, testStruct{"x"}, 200, "task1")
69+
c.AddRecord([]string{"b"}, testStruct{"y"}, 200, "task2")
70+
71+
if c.GetItemCount() != 2 {
72+
t.Fatalf("Expected 2 items, got %d", c.GetItemCount())
73+
}
74+
}
75+
76+
func TestMemoryCacheGetKeys(t *testing.T) {
77+
c := newMemoryCache[testStruct]()
78+
79+
c.AddRecord([]string{"k1"}, testStruct{"v1"}, 200, "task1")
80+
c.AddRecord([]string{"k2"}, testStruct{"v2"}, 200, "task2")
81+
82+
keys := c.GetKeys()
83+
84+
if len(keys) != 2 {
85+
t.Fatalf("Expected 2 keys, got %d", len(keys))
3686
}
3787

38-
// Validate GetRecord results
39-
if returnedData != data || returnedStatus != status || returnedDate == "" || returnedExpires == "" {
40-
t.Error("GetRecord returned unexpected values")
88+
// NOTE: key order in map is not stable
89+
found1, found2 := false, false
90+
for _, k := range keys {
91+
if k == "k1" {
92+
found1 = true
93+
}
94+
if k == "k2" {
95+
found2 = true
96+
}
4197
}
4298

43-
// Test GetItemCount
44-
itemCount := testCache.GetItemCount()
45-
if itemCount != 1 {
46-
t.Errorf("GetItemCount should return 1, but got %d", itemCount)
99+
if !found1 || !found2 {
100+
t.Fatalf("Expected keys k1 and k2, got %v", keys)
47101
}
102+
}
48103

49-
// Sleep for a while to allow the cache entry to expire
50-
time.Sleep(2 * time.Second)
104+
func TestMemoryCacheGetValues(t *testing.T) {
105+
c := newMemoryCache[testStruct]()
51106

52-
// Test GetRecord for an expired record
53-
_, _, _, _, found = testCache.GetRecord(keys)
107+
c.AddRecord([]string{"a"}, testStruct{"aaa"}, 200, "task1")
108+
c.AddRecord([]string{"b"}, testStruct{"bbb"}, 200, "task2")
54109

55-
if found {
56-
t.Error("GetRecord should not find an expired record")
110+
values := c.GetValues()
111+
112+
if len(values) != 2 {
113+
t.Fatalf("Expected 2 values, got %d", len(values))
114+
}
115+
116+
// Verify content
117+
foundA, foundB := false, false
118+
for _, v := range values {
119+
if v.Data.Value == "aaa" {
120+
foundA = true
121+
}
122+
if v.Data.Value == "bbb" {
123+
foundB = true
124+
}
57125
}
58126

59-
// Test GetItemCount after expiration
60-
itemCount = testCache.GetItemCount()
61-
if itemCount != 0 {
62-
t.Errorf("GetItemCount should return 0 after expiration, but got %d", itemCount)
127+
if !foundA || !foundB {
128+
t.Fatalf("Missing expected values, got %v", values)
63129
}
64130
}

0 commit comments

Comments
 (0)