Skip to content

Commit d15e4dc

Browse files
authored
Merge pull request #221 from SkynetLabs/ivo/fix_stripe_subs_cancelling
Fix Stripe duplicate sub canceling.
2 parents 3b240ee + e5fad5d commit d15e4dc

File tree

5 files changed

+54
-23
lines changed

5 files changed

+54
-23
lines changed

api/stripe.go

+14-12
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,6 @@ var (
3434
// `https://account.` prepended to it).
3535
DashboardURL = "https://account.siasky.net"
3636

37-
// StripeTestMode tells us whether to use Stripe's test mode or prod mode
38-
// plan and price ids. This depends on what kind of key is stored in the
39-
// STRIPE_API_KEY environment variable.
40-
StripeTestMode = false
41-
4237
// True is a helper for when we need to pass a *bool to Stripe.
4338
True = true
4439

@@ -96,8 +91,11 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er
9691
Customer: s.Customer.ID,
9792
Status: string(stripe.SubscriptionStatusActive),
9893
})
99-
// Pick the latest active plan and set the user's tier based on that.
10094
subs := it.SubscriptionList().Data
95+
if len(subs) > 1 {
96+
api.staticLogger.Tracef("More than one active subscription detected: %+v", subs)
97+
}
98+
// Pick the latest active plan and set the user's tier based on that.
10199
var mostRecentSub *stripe.Subscription
102100
for _, subsc := range subs {
103101
if mostRecentSub == nil || subsc.Created > mostRecentSub.Created {
@@ -122,9 +120,6 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er
122120
}
123121
// Cancel all subs aside from the latest one.
124122
p := stripe.SubscriptionCancelParams{
125-
Params: stripe.Params{
126-
StripeAccount: &s.Customer.ID,
127-
},
128123
InvoiceNow: &True,
129124
Prorate: &True,
130125
}
@@ -136,9 +131,10 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er
136131
api.staticLogger.Warnf("Empty subscription ID! User ID '%s', Stripe ID '%s', subscription object '%+v'", u.ID.Hex(), u.StripeID, subs)
137132
continue
138133
}
139-
subsc, err = sub.Cancel(subsc.ID, &p)
134+
cs, err := sub.Cancel(subsc.ID, &p)
140135
if err != nil {
141136
api.staticLogger.Warnf("Failed to cancel sub with id '%s' for user '%s' with Stripe customer id '%s'. Error: '%s'", subsc.ID, u.ID.Hex(), s.Customer.ID, err.Error())
137+
api.staticLogger.Tracef("Sub information returned by Stripe: %+v", cs)
142138
} else {
143139
api.staticLogger.Tracef("Successfully cancelled sub with id '%s' for user '%s' with Stripe customer id '%s'.", subsc.ID, u.ID.Hex(), s.Customer.ID)
144140
}
@@ -329,7 +325,8 @@ func (api *API) stripeWebhookPOST(_ *database.User, w http.ResponseWriter, req *
329325
return
330326
}
331327
// Check the details about this subscription:
332-
s, err := sub.Get(hasSub.Sub, nil)
328+
var s *stripe.Subscription
329+
s, err = sub.Get(hasSub.Sub, nil)
333330
if err != nil {
334331
api.staticLogger.Debugln("Webhook: Failed to fetch sub:", err)
335332
api.WriteError(w, err, http.StatusInternalServerError)
@@ -365,8 +362,13 @@ func readStripeEvent(w http.ResponseWriter, req *http.Request) (*stripe.Event, i
365362

366363
// StripePrices returns a mapping of Stripe price ids to Skynet tiers.
367364
func StripePrices() map[string]int {
368-
if StripeTestMode {
365+
if StripeTestMode() {
369366
return stripePricesTest
370367
}
371368
return stripePricesProd
372369
}
370+
371+
// StripeTestMode tells us whether we're using a test key or a live key.
372+
func StripeTestMode() bool {
373+
return strings.HasPrefix(stripe.Key, "sk_test_")
374+
}

api/stripe_test.go

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package api
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
7+
"github.com/stripe/stripe-go/v72"
8+
)
9+
10+
// TestStripePrices ensures that we work with the correct set of prices.
11+
func TestStripePrices(t *testing.T) {
12+
// Set the Stripe key to a live key.
13+
stripe.Key = "sk_live_FAKE_LIVE_KEY"
14+
// Make sure we got the prod prices we expect.
15+
if !reflect.DeepEqual(StripePrices(), stripePricesProd) {
16+
t.Fatal("Expected prod prices, got something else.")
17+
}
18+
// Set the Stripe key to a test key.
19+
stripe.Key = "sk_test_FAKE_TEST_KEY"
20+
// Make sure we got the prod prices we expect.
21+
if !reflect.DeepEqual(StripePrices(), stripePricesTest) {
22+
t.Fatal("Expected test prices, got something else.")
23+
}
24+
}
25+
26+
// TestStripeTestMode ensures that we detect test mode accurately.
27+
func TestStripeTestMode(t *testing.T) {
28+
// Set the Stripe key to a live key.
29+
stripe.Key = "sk_live_FAKE_LIVE_KEY"
30+
// Expect test mode to be off.
31+
if StripeTestMode() {
32+
t.Fatal("Expected live mode, got test mode.")
33+
}
34+
// Set the Stripe key to a test key.
35+
stripe.Key = "sk_test_FAKE_TEST_KEY"
36+
// Expect test mode to be on.
37+
if !StripeTestMode() {
38+
t.Fatal("Expected test mode, got live mode.")
39+
}
40+
}

main.go

-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"net/url"
88
"os"
99
"strconv"
10-
"strings"
1110

1211
"github.com/SkynetLabs/skynet-accounts/api"
1312
"github.com/SkynetLabs/skynet-accounts/build"
@@ -70,7 +69,6 @@ type (
7069
PortalAddressAccounts string
7170
ServerLockID string
7271
StripeKey string
73-
StripeTestMode bool
7472
JWKSFile string
7573
JWTTTL int
7674
EmailURI string
@@ -145,7 +143,6 @@ func parseConfiguration(logger *logrus.Logger) (ServiceConfig, error) {
145143

146144
if sk := os.Getenv(envStripeAPIKey); sk != "" {
147145
config.StripeKey = sk
148-
config.StripeTestMode = !strings.HasPrefix(sk, "sk_live_")
149146
}
150147
if jwks := os.Getenv(envAccountsJWKSFile); jwks != "" {
151148
config.JWKSFile = jwks
@@ -231,7 +228,6 @@ func main() {
231228
api.DashboardURL = config.PortalAddressAccounts
232229
email.ServerLockID = config.ServerLockID
233230
stripe.Key = config.StripeKey
234-
api.StripeTestMode = config.StripeTestMode
235231
jwt.AccountsJWKSFile = config.JWKSFile
236232
jwt.TTL = config.JWTTTL
237233
email.From = config.EmailFrom

main_test.go

-6
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,6 @@ func TestParseConfiguration(t *testing.T) {
190190
if config.StripeKey != sk {
191191
t.Fatalf("Expected %s, got %s", sk, config.StripeKey)
192192
}
193-
if config.StripeTestMode {
194-
t.Fatal("Expected live mode.")
195-
}
196193
if config.ServerLockID != serverDomain {
197194
t.Fatalf("Expected %s, got %s", serverDomain, config.ServerLockID)
198195
}
@@ -227,9 +224,6 @@ func TestParseConfiguration(t *testing.T) {
227224
if config.StripeKey != sk {
228225
t.Fatalf("Expected %s, got %s", sk, config.StripeKey)
229226
}
230-
if !config.StripeTestMode {
231-
t.Fatal("Expected test mode.")
232-
}
233227
if config.MaxAPIKeys != maxKeys {
234228
t.Fatalf("Expected %d, got %d", maxKeys, config.MaxAPIKeys)
235229
}

test/api/stripe_test.go

-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ func TestStripe(t *testing.T) {
3535
"Expected STRIPE_API_KEY that starts with '%s', got '%s'", t.Name(), "sk_test_", key[:8])
3636
}
3737
stripe.Key = key
38-
api.StripeTestMode = true
3938

4039
tests := map[string]func(t *testing.T, at *test.AccountsTester){
4140
"get billing": testStripeBillingGET,

0 commit comments

Comments
 (0)