Skip to content

Commit 6dff5dc

Browse files
Christopher Schinnerlkwypchlo
Christopher Schinnerl
authored andcommitted
Merge pull request #146 from SkynetLabs/ivo/fix_limits
Fix /user/limits quotaExceeded
1 parent e385fc9 commit 6dff5dc

10 files changed

+353
-194
lines changed

api/cache.go

+13-21
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ type (
2222
// userTierCacheEntry allows us to cache some basic information about the
2323
// user, so we don't need to hit the DB to fetch data that rarely changes.
2424
userTierCacheEntry struct {
25-
Tier int
26-
ExpiresAt time.Time
25+
Tier int
26+
QuotaExceeded bool
27+
ExpiresAt time.Time
2728
}
2829
)
2930

@@ -34,34 +35,25 @@ func newUserTierCache() *userTierCache {
3435
}
3536
}
3637

37-
// Get returns the user's tier and an OK indicator which is true when the cache
38-
// entry exists and hasn't expired, yet.
39-
func (utc *userTierCache) Get(sub string) (int, bool) {
38+
// Get returns the user's tier, a quota exceeded flag, and an OK indicator
39+
// which is true when the cache entry exists and hasn't expired, yet.
40+
func (utc *userTierCache) Get(sub string) (int, bool, bool) {
4041
utc.mu.Lock()
4142
ce, exists := utc.cache[sub]
4243
utc.mu.Unlock()
4344
if !exists || ce.ExpiresAt.Before(time.Now().UTC()) {
44-
return database.TierAnonymous, false
45+
return database.TierAnonymous, false, false
4546
}
46-
return ce.Tier, true
47+
return ce.Tier, ce.QuotaExceeded, true
4748
}
4849

4950
// Set stores the user's tier in the cache.
5051
func (utc *userTierCache) Set(u *database.User) {
51-
var ce userTierCacheEntry
52-
now := time.Now().UTC()
53-
if u.QuotaExceeded {
54-
ce = userTierCacheEntry{
55-
Tier: database.TierAnonymous,
56-
ExpiresAt: now.Add(userTierCacheTTL),
57-
}
58-
} else {
59-
ce = userTierCacheEntry{
60-
Tier: u.Tier,
61-
ExpiresAt: now.Add(userTierCacheTTL),
62-
}
63-
}
6452
utc.mu.Lock()
65-
utc.cache[u.Sub] = ce
53+
utc.cache[u.Sub] = userTierCacheEntry{
54+
Tier: u.Tier,
55+
QuotaExceeded: u.QuotaExceeded,
56+
ExpiresAt: time.Now().UTC().Add(userTierCacheTTL),
57+
}
6658
utc.mu.Unlock()
6759
}

api/cache_test.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,29 @@ func TestUserTierCache(t *testing.T) {
1717
QuotaExceeded: false,
1818
}
1919
// Get the user from the empty cache.
20-
tier, ok := cache.Get(u.Sub)
20+
tier, _, ok := cache.Get(u.Sub)
2121
if ok || tier != database.TierAnonymous {
2222
t.Fatalf("Expected to get tier %d and %t, got %d and %t.", database.TierAnonymous, false, tier, ok)
2323
}
24-
// Set the use in the cache.
24+
// Set the user in the cache.
2525
cache.Set(u)
2626
// Check again.
27-
tier, ok = cache.Get(u.Sub)
27+
tier, qe, ok := cache.Get(u.Sub)
2828
if !ok || tier != u.Tier {
2929
t.Fatalf("Expected to get tier %d and %t, got %d and %t.", u.Tier, true, tier, ok)
3030
}
31+
if qe != u.QuotaExceeded {
32+
t.Fatal("Quota exceeded flag doesn't match.")
33+
}
34+
u.QuotaExceeded = true
35+
cache.Set(u)
36+
tier, qe, ok = cache.Get(u.Sub)
37+
if !ok || tier != u.Tier {
38+
t.Fatalf("Expected to get tier %d and %t, got %d and %t.", u.Tier, true, tier, ok)
39+
}
40+
if qe != u.QuotaExceeded {
41+
t.Fatal("Quota exceeded flag doesn't match.")
42+
}
3143
ce, exists := cache.cache[u.Sub]
3244
if !exists {
3345
t.Fatal("Expected the entry to exist.")

api/handlers.go

+50-17
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,14 @@ type (
8080
}
8181
// UserLimitsGET is response of GET /user/limits
8282
UserLimitsGET struct {
83-
TierID int `json:"tierID"`
84-
database.TierLimits
83+
TierID int `json:"tierID"`
84+
TierName string `json:"tierName"`
85+
UploadBandwidth int `json:"upload"` // bytes per second
86+
DownloadBandwidth int `json:"download"` // bytes per second
87+
MaxUploadSize int64 `json:"maxUploadSize"` // the max size of a single upload in bytes
88+
MaxNumberUploads int `json:"-"`
89+
RegistryDelay int `json:"registry"` // ms delay
90+
Storage int64 `json:"-"`
8591
}
8692

8793
// accountRecoveryPOST defines the payload we expect when a user is trying
@@ -398,20 +404,23 @@ func (api *API) userGET(u *database.User, w http.ResponseWriter, _ *http.Request
398404
func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) {
399405
// First check for an API key.
400406
ak, err := apiKeyFromRequest(req)
401-
respAnon := UserLimitsGET{
402-
TierID: database.TierAnonymous,
403-
TierLimits: database.UserLimits[database.TierAnonymous],
404-
}
407+
respAnon := userLimitsGetFromTier(database.TierAnonymous)
405408
if err == nil {
406409
u, err := api.staticDB.UserByAPIKey(req.Context(), ak)
407410
if err != nil {
408411
api.staticLogger.Traceln("Error while fetching user by API key:", err)
409412
api.WriteJSON(w, respAnon)
410413
return
411414
}
412-
resp := UserLimitsGET{
413-
TierID: u.Tier,
414-
TierLimits: database.UserLimits[u.Tier],
415+
resp := userLimitsGetFromTier(u.Tier)
416+
// If the quota is exceeded we should keep the user's tier but report
417+
// anonymous-level speeds.
418+
if u.QuotaExceeded {
419+
// Report the speeds for tier anonymous.
420+
resp = userLimitsGetFromTier(database.TierAnonymous)
421+
// But keep reporting the user's actual tier and it's name.
422+
resp.TierID = u.Tier
423+
resp.TierName = database.UserLimits[u.Tier].TierName
415424
}
416425
api.WriteJSON(w, resp)
417426
return
@@ -430,7 +439,7 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http
430439
sub := s.(string)
431440
// If the user is not cached, or they were cached too long ago we'll fetch
432441
// their data from the DB.
433-
tier, ok := api.staticUserTierCache.Get(sub)
442+
tier, qe, ok := api.staticUserTierCache.Get(sub)
434443
if !ok {
435444
u, err := api.staticDB.UserBySub(req.Context(), sub)
436445
if err != nil {
@@ -439,14 +448,22 @@ func (api *API) userLimitsGET(_ *database.User, w http.ResponseWriter, req *http
439448
return
440449
}
441450
api.staticUserTierCache.Set(u)
451+
// Populate the tier and qe values, while simultaneously making sure
452+
// that we can read the record from the cache.
453+
tier, qe, ok = api.staticUserTierCache.Get(sub)
454+
if !ok {
455+
build.Critical("Failed to fetch user from UserTierCache right after setting it.")
456+
}
442457
}
443-
tier, ok = api.staticUserTierCache.Get(sub)
444-
if !ok {
445-
build.Critical("Failed to fetch user from UserTierCache right after setting it.")
446-
}
447-
resp := UserLimitsGET{
448-
TierID: tier,
449-
TierLimits: database.UserLimits[tier],
458+
resp := userLimitsGetFromTier(tier)
459+
// If the quota is exceeded we should keep the user's tier but report
460+
// anonymous-level speeds.
461+
if qe {
462+
// Report anonymous speeds.
463+
resp = userLimitsGetFromTier(database.TierAnonymous)
464+
// Keep reporting the user's actual tier and tier name.
465+
resp.TierID = tier
466+
resp.TierName = database.UserLimits[tier].TierName
450467
}
451468
api.WriteJSON(w, resp)
452469
}
@@ -1177,3 +1194,19 @@ func fetchPageSize(form url.Values) (int, error) {
11771194
func parseRequestBodyJSON(body io.ReadCloser, maxBodySize int64, objRef interface{}) error {
11781195
return json.NewDecoder(io.LimitReader(body, maxBodySize)).Decode(&objRef)
11791196
}
1197+
1198+
// userLimitsGetFromTier is a helper that lets us succinctly translate
1199+
// from the database DTO to the API DTO.
1200+
func userLimitsGetFromTier(tier int) *UserLimitsGET {
1201+
t := database.UserLimits[tier]
1202+
return &UserLimitsGET{
1203+
TierID: tier,
1204+
TierName: t.TierName,
1205+
UploadBandwidth: t.UploadBandwidth,
1206+
DownloadBandwidth: t.DownloadBandwidth,
1207+
MaxUploadSize: t.MaxUploadSize,
1208+
MaxNumberUploads: t.MaxNumberUploads,
1209+
RegistryDelay: t.RegistryDelay,
1210+
Storage: t.Storage,
1211+
}
1212+
}

database/challenge.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ type (
9494
// NewChallenge creates a new challenge with the given type and pubKey.
9595
func (db *DB) NewChallenge(ctx context.Context, pubKey PubKey, cType string) (*Challenge, error) {
9696
if cType != ChallengeTypeLogin && cType != ChallengeTypeRegister && cType != ChallengeTypeUpdate {
97-
return nil, errors.New(fmt.Sprintf("invalid challenge type '%s'", cType))
97+
return nil, fmt.Errorf("invalid challenge type '%s'", cType)
9898
}
9999
ch := &Challenge{
100100
Challenge: hex.EncodeToString(fastrand.Bytes(ChallengeSize)),

test/api/api_test.go

+46-42
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import (
44
"bytes"
55
"context"
66
"encoding/hex"
7-
"encoding/json"
7+
"fmt"
88
"net/http"
99
"net/url"
1010
"testing"
@@ -14,6 +14,7 @@ import (
1414
"github.com/SkynetLabs/skynet-accounts/database"
1515
"github.com/SkynetLabs/skynet-accounts/test"
1616
"gitlab.com/NebulousLabs/fastrand"
17+
"go.sia.tech/siad/build"
1718

1819
"github.com/julienschmidt/httprouter"
1920
"github.com/sirupsen/logrus"
@@ -198,15 +199,9 @@ func TestUserTierCache(t *testing.T) {
198199
if err != nil {
199200
t.Fatal(err)
200201
}
201-
at.Cookie = test.ExtractCookie(r)
202-
// Get the user's limit. Since they are on a Pro account but their
203-
// SubscribedUntil is set in the past, we expect to get TierFree.
204-
_, b, err := at.Get("/user/limits", nil)
205-
if err != nil {
206-
t.Fatal(err)
207-
}
208-
var ul api.UserLimitsGET
209-
err = json.Unmarshal(b, &ul)
202+
at.SetCookie(test.ExtractCookie(r))
203+
// Get the user's limit.
204+
ul, _, err := at.UserLimits()
210205
if err != nil {
211206
t.Fatal(err)
212207
}
@@ -216,12 +211,11 @@ func TestUserTierCache(t *testing.T) {
216211
if ul.TierID != database.TierPremium20 {
217212
t.Fatalf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID)
218213
}
219-
// Now set their SubscribedUntil in the future, so their subscription tier
220-
// is active.
221-
u.SubscribedUntil = time.Now().UTC().Add(365 * 24 * time.Hour)
222-
err = at.DB.UserSave(at.Ctx, u.User)
223-
if err != nil {
224-
t.Fatal(err)
214+
if ul.TierName != database.UserLimits[database.TierPremium20].TierName {
215+
t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName)
216+
}
217+
if ul.UploadBandwidth != database.UserLimits[database.TierPremium20].UploadBandwidth {
218+
t.Fatalf("Expected upload bandwidth '%d', got '%d'", database.UserLimits[database.TierPremium20].UploadBandwidth, ul.UploadBandwidth)
225219
}
226220
// Register a test upload that exceeds the user's allowed storage, so their
227221
// QuotaExceeded flag will get raised.
@@ -232,45 +226,55 @@ func TestUserTierCache(t *testing.T) {
232226
// Make a specific call to trackUploadPOST in order to trigger the
233227
// checkUserQuotas method. This wil register the upload a second time but
234228
// that doesn't affect the test.
235-
_, _, err = at.Post("/track/upload/"+sl.Skylink, nil, nil)
236-
if err != nil {
237-
t.Fatal(err)
238-
}
239-
// Sleep for a short time in order to make sure that the background
240-
// goroutine that updates user's quotas has had time to run.
241-
time.Sleep(2 * time.Second)
242-
// We expect to get TierAnonymous.
243-
_, b, err = at.Get("/user/limits", nil)
229+
_, err = at.TrackUpload(sl.Skylink)
244230
if err != nil {
245231
t.Fatal(err)
246232
}
247-
err = json.Unmarshal(b, &ul)
233+
// We need to try this several times because we'll only get the right result
234+
// after the background goroutine that updates user's quotas has had time to
235+
// run.
236+
err = build.Retry(10, 200*time.Millisecond, func() error {
237+
// We expect to get tier with name and id matching TierPremium20 but with
238+
// speeds matching TierAnonymous.
239+
ul, _, err = at.UserLimits()
240+
if err != nil {
241+
t.Fatal(err)
242+
}
243+
if ul.TierID != database.TierPremium20 {
244+
return fmt.Errorf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID)
245+
}
246+
if ul.TierName != database.UserLimits[database.TierPremium20].TierName {
247+
return fmt.Errorf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName)
248+
}
249+
if ul.UploadBandwidth != database.UserLimits[database.TierAnonymous].UploadBandwidth {
250+
return fmt.Errorf("Expected upload bandwidth '%d', got '%d'", database.UserLimits[database.TierAnonymous].UploadBandwidth, ul.UploadBandwidth)
251+
}
252+
return nil
253+
})
248254
if err != nil {
249255
t.Fatal(err)
250256
}
251-
if ul.TierID != database.TierAnonymous {
252-
t.Fatalf("Expected tier id '%d', got '%d'", database.TierAnonymous, ul.TierID)
253-
}
254-
if ul.TierName != database.UserLimits[database.TierAnonymous].TierName {
255-
t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierAnonymous].TierName, ul.TierName)
256-
}
257257
// Delete the uploaded file, so the user's quota recovers.
258258
// This call should invalidate the tier cache.
259259
_, _, err = at.Delete("/user/uploads/"+sl.Skylink, nil)
260-
time.Sleep(2 * time.Second)
261-
// We expect to get TierPremium20.
262-
_, b, err = at.Get("/user/limits", nil)
263260
if err != nil {
264261
t.Fatal(err)
265262
}
266-
err = json.Unmarshal(b, &ul)
263+
err = build.Retry(10, 200*time.Millisecond, func() error {
264+
// We expect to get TierPremium20.
265+
ul, _, err = at.UserLimits()
266+
if err != nil {
267+
return errors.AddContext(err, "failed to call /user/limits")
268+
}
269+
if ul.TierID != database.TierPremium20 {
270+
return fmt.Errorf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID)
271+
}
272+
if ul.TierName != database.UserLimits[database.TierPremium20].TierName {
273+
return fmt.Errorf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName)
274+
}
275+
return nil
276+
})
267277
if err != nil {
268278
t.Fatal(err)
269279
}
270-
if ul.TierID != database.TierPremium20 {
271-
t.Fatalf("Expected tier id '%d', got '%d'", database.TierPremium20, ul.TierID)
272-
}
273-
if ul.TierName != database.UserLimits[database.TierPremium20].TierName {
274-
t.Fatalf("Expected tier name '%s', got '%s'", database.UserLimits[database.TierPremium20].TierName, ul.TierName)
275-
}
276280
}

test/api/apikeys_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func testAPIKeysFlow(t *testing.T, at *test.AccountsTester) {
2020
if err != nil {
2121
t.Fatal(err, string(body))
2222
}
23-
at.Cookie = test.ExtractCookie(r)
23+
at.SetCookie(test.ExtractCookie(r))
2424

2525
aks := make([]database.APIKeyRecord, 0)
2626

@@ -115,7 +115,7 @@ func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
115115
if err != nil {
116116
t.Fatal(err)
117117
}
118-
at.Cookie = test.ExtractCookie(r)
118+
at.SetCookie(test.ExtractCookie(r))
119119
// Get the user and create a test upload, so the stats won't be all zeros.
120120
u, err := at.DB.UserByEmail(at.Ctx, email)
121121
if err != nil {
@@ -132,7 +132,7 @@ func testAPIKeysUsage(t *testing.T, at *test.AccountsTester) {
132132
t.Fatal(err)
133133
}
134134
// Stop using the cookie, so we can test the API key.
135-
at.Cookie = nil
135+
at.ClearCredentials()
136136
// We use a custom struct and not the APIKeyRecord one because that one does
137137
// not render the key in JSON form and therefore it won't unmarshal it,
138138
// either.

0 commit comments

Comments
 (0)