@@ -21,6 +21,7 @@ import (
21
21
"github.com/stripe/stripe-go/v72/sub"
22
22
"github.com/stripe/stripe-go/v72/webhook"
23
23
"gitlab.com/NebulousLabs/errors"
24
+ "gitlab.com/SkynetLabs/skyd/build"
24
25
)
25
26
26
27
const (
34
35
// `https://account.` prepended to it).
35
36
DashboardURL = "https://account.siasky.net"
36
37
37
- // True is a helper for when we need to pass a *bool to Stripe.
38
- True = true
38
+ // ErrCheckoutWithoutCustomer is the error returned when a checkout session
39
+ // doesn't have an associated customer
40
+ ErrCheckoutWithoutCustomer = errors .New ("this checkout session does not have an associated customer" )
41
+ // ErrCheckoutWithoutSub is the error returned when a checkout session doesn't
42
+ // have an associated subscription
43
+ ErrCheckoutWithoutSub = errors .New ("this checkout session does not have an associated subscription" )
44
+ // ErrCheckoutDoesNotBelongToUser is returned when the given checkout
45
+ // session does not belong to the current user. This might be a mistake or
46
+ // might be an attempt for fraud.
47
+ ErrCheckoutDoesNotBelongToUser = errors .New ("checkout session does not belong to current user" )
48
+ // ErrSubNotActive is returned when the given subscription is not active, so
49
+ // we cannot do anything based on it.
50
+ ErrSubNotActive = errors .New ("subscription not active" )
51
+ // ErrSubWithoutPrice is returned when the subscription doesn't have a
52
+ // price, so we cannot determine the user's tier based on it.
53
+ ErrSubWithoutPrice = errors .New ("subscription does not have a price" )
39
54
40
55
// stripePageSize defines the number of records we are going to request from
41
56
// endpoints that support pagination.
62
77
)
63
78
64
79
type (
65
- // StripePrice .. .
80
+ // StripePrice describes a Stripe price item .
66
81
StripePrice struct {
67
82
ID string `json:"id"`
68
83
Name string `json:"name"`
@@ -74,6 +89,42 @@ type (
74
89
ProductID string `json:"productId"`
75
90
LiveMode bool `json:"livemode"`
76
91
}
92
+ // SubscriptionGET describes a Stripe subscription for our front end needs.
93
+ SubscriptionGET struct {
94
+ Created int64 `json:"created"`
95
+ CurrentPeriodStart int64 `json:"currentPeriodStart"`
96
+ Discount * SubscriptionDiscountGET `json:"discount"`
97
+ ID string `json:"id"`
98
+ Plan * SubscriptionPlanGET `json:"plan"`
99
+ StartDate int64 `json:"startDate"`
100
+ Status string `json:"status"`
101
+ }
102
+ // SubscriptionDiscountGET describes a Stripe subscription discount for our
103
+ // front end needs.
104
+ SubscriptionDiscountGET struct {
105
+ AmountOff int64 `json:"amountOff"`
106
+ Currency string `json:"currency"`
107
+ Duration string `json:"duration"`
108
+ DurationInMonths int64 `json:"durationInMonths"`
109
+ Name string `json:"name"`
110
+ PercentOff float64 `json:"percentOff"`
111
+ }
112
+ // SubscriptionPlanGET describes a Stripe subscription plan for our front
113
+ // end needs.
114
+ SubscriptionPlanGET struct {
115
+ Amount int64 `json:"amount"`
116
+ Currency string `json:"currency"`
117
+ Interval string `json:"interval"`
118
+ IntervalCount int64 `json:"intervalCount"`
119
+ Price string `json:"price"`
120
+ Product * SubscriptionProductGET `json:"product"`
121
+ }
122
+ // SubscriptionProductGET describes a Stripe subscription product for our
123
+ // front end needs.
124
+ SubscriptionProductGET struct {
125
+ Description string `json:"description"`
126
+ Name string `json:"name"`
127
+ }
77
128
)
78
129
79
130
// processStripeSub reads the information about the user's subscription and
@@ -120,8 +171,8 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er
120
171
}
121
172
// Cancel all subs aside from the latest one.
122
173
p := stripe.SubscriptionCancelParams {
123
- InvoiceNow : & True ,
124
- Prorate : & True ,
174
+ InvoiceNow : stripe . Bool ( true ) ,
175
+ Prorate : stripe . Bool ( true ) ,
125
176
}
126
177
for _ , subsc := range subs {
127
178
if subsc == nil || (mostRecentSub != nil && subsc .ID == mostRecentSub .ID ) {
@@ -199,7 +250,7 @@ func (api *API) stripeCheckoutPOST(u *database.User, w http.ResponseWriter, req
199
250
cancelURL := DashboardURL + "/payments"
200
251
successURL := DashboardURL + "/payments?session_id={CHECKOUT_SESSION_ID}"
201
252
params := stripe.CheckoutSessionParams {
202
- AllowPromotionCodes : & True ,
253
+ AllowPromotionCodes : stripe . Bool ( true ) ,
203
254
CancelURL : & cancelURL ,
204
255
ClientReferenceID : & u .Sub ,
205
256
Customer : & u .StripeID ,
@@ -226,26 +277,147 @@ func (api *API) stripeCheckoutPOST(u *database.User, w http.ResponseWriter, req
226
277
api .WriteJSON (w , response )
227
278
}
228
279
280
+ // stripeCheckoutIDGET checks the status of a checkout session. If the checkout
281
+ // is successful and results in a higher tier sub than the current one, we
282
+ // upgrade the user to the new tier.
283
+ func (api * API ) stripeCheckoutIDGET (u * database.User , w http.ResponseWriter , req * http.Request , ps httprouter.Params ) {
284
+ checkoutSessionID := ps .ByName ("checkout_id" )
285
+ subStr := "subscription"
286
+ subDiscountStr := "subscription.discount"
287
+ subPlanProductStr := "subscription.plan.product"
288
+ params := & stripe.CheckoutSessionParams {
289
+ Params : stripe.Params {
290
+ Expand : []* string {& subStr , & subDiscountStr , & subPlanProductStr },
291
+ },
292
+ }
293
+ cos , err := cosession .Get (checkoutSessionID , params )
294
+ if err != nil {
295
+ api .WriteError (w , err , http .StatusInternalServerError )
296
+ return
297
+ }
298
+ if cos .Customer == nil {
299
+ api .WriteError (w , ErrCheckoutWithoutCustomer , http .StatusBadRequest )
300
+ return
301
+ }
302
+ if cos .Customer .ID != u .StripeID {
303
+ api .WriteError (w , ErrCheckoutDoesNotBelongToUser , http .StatusForbidden )
304
+ return
305
+ }
306
+ coSub := cos .Subscription
307
+ if coSub == nil {
308
+ api .WriteError (w , ErrCheckoutWithoutSub , http .StatusBadRequest )
309
+ return
310
+ }
311
+ if coSub .Status != stripe .SubscriptionStatusActive {
312
+ api .WriteError (w , ErrSubNotActive , http .StatusBadRequest )
313
+ return
314
+ }
315
+ // Get the subscription price.
316
+ if coSub .Items == nil || len (coSub .Items .Data ) == 0 || coSub .Items .Data [0 ].Price == nil {
317
+ api .WriteError (w , ErrSubWithoutPrice , http .StatusBadRequest )
318
+ return
319
+ }
320
+ coSubPrice := coSub .Items .Data [0 ].Price
321
+ tier , exists := StripePrices ()[coSubPrice .ID ]
322
+ if ! exists {
323
+ err = fmt .Errorf ("invalid price id '%s'" , coSubPrice .ID )
324
+ api .WriteError (w , err , http .StatusInternalServerError )
325
+ build .Critical (errors .AddContext (err , "We somehow received an invalid price ID from Stripe. This might be caused by mismatched test/prod tokens or a breakdown in our Stripe setup." ))
326
+ return
327
+ }
328
+ // Promote the user, if needed.
329
+ if tier > u .Tier {
330
+ err = api .staticDB .UserSetTier (req .Context (), u , tier )
331
+ if err != nil {
332
+ api .WriteError (w , errors .AddContext (err , "failed to promote user" ), http .StatusInternalServerError )
333
+ return
334
+ }
335
+ }
336
+ // Build the response DTO.
337
+ var discountInfo * SubscriptionDiscountGET
338
+ if coSub .Discount != nil {
339
+ var coupon * stripe.Coupon
340
+ // We can potentially fetch the discount coupon from two places - the
341
+ // discount itself or its promotional code. We'll check them in order.
342
+ if coSub .Discount .Coupon != nil {
343
+ coupon = coSub .Discount .Coupon
344
+ } else if coSub .Discount .PromotionCode != nil && coSub .Discount .PromotionCode .Coupon != nil {
345
+ coupon = coSub .Discount .PromotionCode .Coupon
346
+ }
347
+ if coupon != nil {
348
+ discountInfo = & SubscriptionDiscountGET {
349
+ AmountOff : coupon .AmountOff ,
350
+ Currency : string (coupon .Currency ),
351
+ Duration : string (coupon .Duration ),
352
+ DurationInMonths : coupon .DurationInMonths ,
353
+ Name : coupon .Name ,
354
+ PercentOff : coupon .PercentOff ,
355
+ }
356
+ }
357
+ }
358
+ var planInfo * SubscriptionPlanGET
359
+ if coSub .Plan != nil {
360
+ var productInfo * SubscriptionProductGET
361
+ if coSub .Plan .Product != nil {
362
+ productInfo = & SubscriptionProductGET {
363
+ Description : coSub .Plan .Product .Description ,
364
+ Name : coSub .Plan .Product .Name ,
365
+ }
366
+ }
367
+ planInfo = & SubscriptionPlanGET {
368
+ Amount : coSub .Plan .Amount ,
369
+ Currency : string (coSub .Plan .Currency ),
370
+ Interval : string (coSub .Plan .Interval ),
371
+ IntervalCount : coSub .Plan .IntervalCount ,
372
+ Price : coSub .Plan .ID , // plan ID and price ID are the same
373
+ Product : productInfo ,
374
+ }
375
+ }
376
+
377
+ subInfo := SubscriptionGET {
378
+ Created : coSub .Created ,
379
+ CurrentPeriodStart : coSub .CurrentPeriodStart ,
380
+ Discount : discountInfo ,
381
+ ID : coSub .ID ,
382
+ Plan : planInfo ,
383
+ StartDate : coSub .StartDate ,
384
+ Status : string (coSub .Status ),
385
+ }
386
+ api .WriteJSON (w , subInfo )
387
+ }
388
+
229
389
// stripeCreateCustomer creates a Stripe customer record for this user and
230
390
// updates the user in the database.
231
391
func (api * API ) stripeCreateCustomer (ctx context.Context , u * database.User ) (string , error ) {
232
392
cus , err := customer .New (& stripe.CustomerParams {})
233
393
if err != nil {
234
394
return "" , errors .AddContext (err , "failed to create Stripe customer" )
235
395
}
236
- stripeID := cus .ID
237
- err = api .staticDB .UserSetStripeID (ctx , u , stripeID )
396
+ // We'll try to update the customer with the user's email and sub. We only
397
+ // do this as an optional step, so we can match Stripe customers to local
398
+ // users more easily. We do not care if this step fails - it's entirely
399
+ // optional. It requires an additional round-trip to Stripe and we don't
400
+ // need to wait for it to finish, so we'll do it in a separate goroutine.
401
+ go func () {
402
+ email := u .Email .String ()
403
+ updateParams := stripe.CustomerParams {
404
+ Description : & u .Sub ,
405
+ Email : & email ,
406
+ }
407
+ _ , _ = customer .Update (cus .ID , & updateParams )
408
+ }()
409
+ err = api .staticDB .UserSetStripeID (ctx , u , cus .ID )
238
410
if err != nil {
239
411
return "" , errors .AddContext (err , "failed to save user's StripeID" )
240
412
}
241
- return stripeID , nil
413
+ return cus . ID , nil
242
414
}
243
415
244
416
// stripePricesGET returns a list of plans and prices.
245
417
func (api * API ) stripePricesGET (_ * database.User , w http.ResponseWriter , _ * http.Request , _ httprouter.Params ) {
246
418
var sPrices []StripePrice
247
419
params := & stripe.PriceListParams {
248
- Active : & True ,
420
+ Active : stripe . Bool ( true ) ,
249
421
ListParams : stripe.ListParams {
250
422
Limit : & stripePageSize ,
251
423
},
0 commit comments