Skip to content

Commit 74d6623

Browse files
authored
Merge pull request #228 from SkynetLabs/ivo/stripe_promote_on_checkout
Stripe promote on checkout
2 parents 59e344d + 6a97405 commit 74d6623

File tree

7 files changed

+502
-40
lines changed

7 files changed

+502
-40
lines changed

api/routes.go

+1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ func (api *API) buildHTTPRoutes() {
8181
// `POST /stripe/billing` is deprecated. Please use `GET /stripe/billing`.
8282
api.staticRouter.POST("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingHANDLER, false)))
8383
api.staticRouter.POST("/stripe/checkout", api.WithDBSession(api.withAuth(api.stripeCheckoutPOST, false)))
84+
api.staticRouter.GET("/stripe/checkout/:checkout_id", api.WithDBSession(api.withAuth(api.stripeCheckoutIDGET, false)))
8485
api.staticRouter.GET("/stripe/prices", api.noAuth(api.stripePricesGET))
8586
api.staticRouter.POST("/stripe/webhook", api.WithDBSession(api.noAuth(api.stripeWebhookPOST)))
8687

api/stripe.go

+182-10
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/stripe/stripe-go/v72/sub"
2222
"github.com/stripe/stripe-go/v72/webhook"
2323
"gitlab.com/NebulousLabs/errors"
24+
"gitlab.com/SkynetLabs/skyd/build"
2425
)
2526

2627
const (
@@ -34,8 +35,22 @@ var (
3435
// `https://account.` prepended to it).
3536
DashboardURL = "https://account.siasky.net"
3637

37-
// True is a helper for when we need to pass a *bool to Stripe.
38-
True = true
38+
// ErrCheckoutWithoutCustomer is the error returned when a checkout session
39+
// doesn't have an associated customer
40+
ErrCheckoutWithoutCustomer = errors.New("this checkout session does not have an associated customer")
41+
// ErrCheckoutWithoutSub is the error returned when a checkout session doesn't
42+
// have an associated subscription
43+
ErrCheckoutWithoutSub = errors.New("this checkout session does not have an associated subscription")
44+
// ErrCheckoutDoesNotBelongToUser is returned when the given checkout
45+
// session does not belong to the current user. This might be a mistake or
46+
// might be an attempt for fraud.
47+
ErrCheckoutDoesNotBelongToUser = errors.New("checkout session does not belong to current user")
48+
// ErrSubNotActive is returned when the given subscription is not active, so
49+
// we cannot do anything based on it.
50+
ErrSubNotActive = errors.New("subscription not active")
51+
// ErrSubWithoutPrice is returned when the subscription doesn't have a
52+
// price, so we cannot determine the user's tier based on it.
53+
ErrSubWithoutPrice = errors.New("subscription does not have a price")
3954

4055
// stripePageSize defines the number of records we are going to request from
4156
// endpoints that support pagination.
@@ -62,7 +77,7 @@ var (
6277
)
6378

6479
type (
65-
// StripePrice ...
80+
// StripePrice describes a Stripe price item.
6681
StripePrice struct {
6782
ID string `json:"id"`
6883
Name string `json:"name"`
@@ -74,6 +89,42 @@ type (
7489
ProductID string `json:"productId"`
7590
LiveMode bool `json:"livemode"`
7691
}
92+
// SubscriptionGET describes a Stripe subscription for our front end needs.
93+
SubscriptionGET struct {
94+
Created int64 `json:"created"`
95+
CurrentPeriodStart int64 `json:"currentPeriodStart"`
96+
Discount *SubscriptionDiscountGET `json:"discount"`
97+
ID string `json:"id"`
98+
Plan *SubscriptionPlanGET `json:"plan"`
99+
StartDate int64 `json:"startDate"`
100+
Status string `json:"status"`
101+
}
102+
// SubscriptionDiscountGET describes a Stripe subscription discount for our
103+
// front end needs.
104+
SubscriptionDiscountGET struct {
105+
AmountOff int64 `json:"amountOff"`
106+
Currency string `json:"currency"`
107+
Duration string `json:"duration"`
108+
DurationInMonths int64 `json:"durationInMonths"`
109+
Name string `json:"name"`
110+
PercentOff float64 `json:"percentOff"`
111+
}
112+
// SubscriptionPlanGET describes a Stripe subscription plan for our front
113+
// end needs.
114+
SubscriptionPlanGET struct {
115+
Amount int64 `json:"amount"`
116+
Currency string `json:"currency"`
117+
Interval string `json:"interval"`
118+
IntervalCount int64 `json:"intervalCount"`
119+
Price string `json:"price"`
120+
Product *SubscriptionProductGET `json:"product"`
121+
}
122+
// SubscriptionProductGET describes a Stripe subscription product for our
123+
// front end needs.
124+
SubscriptionProductGET struct {
125+
Description string `json:"description"`
126+
Name string `json:"name"`
127+
}
77128
)
78129

79130
// processStripeSub reads the information about the user's subscription and
@@ -120,8 +171,8 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er
120171
}
121172
// Cancel all subs aside from the latest one.
122173
p := stripe.SubscriptionCancelParams{
123-
InvoiceNow: &True,
124-
Prorate: &True,
174+
InvoiceNow: stripe.Bool(true),
175+
Prorate: stripe.Bool(true),
125176
}
126177
for _, subsc := range subs {
127178
if subsc == nil || (mostRecentSub != nil && subsc.ID == mostRecentSub.ID) {
@@ -199,7 +250,7 @@ func (api *API) stripeCheckoutPOST(u *database.User, w http.ResponseWriter, req
199250
cancelURL := DashboardURL + "/payments"
200251
successURL := DashboardURL + "/payments?session_id={CHECKOUT_SESSION_ID}"
201252
params := stripe.CheckoutSessionParams{
202-
AllowPromotionCodes: &True,
253+
AllowPromotionCodes: stripe.Bool(true),
203254
CancelURL: &cancelURL,
204255
ClientReferenceID: &u.Sub,
205256
Customer: &u.StripeID,
@@ -226,26 +277,147 @@ func (api *API) stripeCheckoutPOST(u *database.User, w http.ResponseWriter, req
226277
api.WriteJSON(w, response)
227278
}
228279

280+
// stripeCheckoutIDGET checks the status of a checkout session. If the checkout
281+
// is successful and results in a higher tier sub than the current one, we
282+
// upgrade the user to the new tier.
283+
func (api *API) stripeCheckoutIDGET(u *database.User, w http.ResponseWriter, req *http.Request, ps httprouter.Params) {
284+
checkoutSessionID := ps.ByName("checkout_id")
285+
subStr := "subscription"
286+
subDiscountStr := "subscription.discount"
287+
subPlanProductStr := "subscription.plan.product"
288+
params := &stripe.CheckoutSessionParams{
289+
Params: stripe.Params{
290+
Expand: []*string{&subStr, &subDiscountStr, &subPlanProductStr},
291+
},
292+
}
293+
cos, err := cosession.Get(checkoutSessionID, params)
294+
if err != nil {
295+
api.WriteError(w, err, http.StatusInternalServerError)
296+
return
297+
}
298+
if cos.Customer == nil {
299+
api.WriteError(w, ErrCheckoutWithoutCustomer, http.StatusBadRequest)
300+
return
301+
}
302+
if cos.Customer.ID != u.StripeID {
303+
api.WriteError(w, ErrCheckoutDoesNotBelongToUser, http.StatusForbidden)
304+
return
305+
}
306+
coSub := cos.Subscription
307+
if coSub == nil {
308+
api.WriteError(w, ErrCheckoutWithoutSub, http.StatusBadRequest)
309+
return
310+
}
311+
if coSub.Status != stripe.SubscriptionStatusActive {
312+
api.WriteError(w, ErrSubNotActive, http.StatusBadRequest)
313+
return
314+
}
315+
// Get the subscription price.
316+
if coSub.Items == nil || len(coSub.Items.Data) == 0 || coSub.Items.Data[0].Price == nil {
317+
api.WriteError(w, ErrSubWithoutPrice, http.StatusBadRequest)
318+
return
319+
}
320+
coSubPrice := coSub.Items.Data[0].Price
321+
tier, exists := StripePrices()[coSubPrice.ID]
322+
if !exists {
323+
err = fmt.Errorf("invalid price id '%s'", coSubPrice.ID)
324+
api.WriteError(w, err, http.StatusInternalServerError)
325+
build.Critical(errors.AddContext(err, "We somehow received an invalid price ID from Stripe. This might be caused by mismatched test/prod tokens or a breakdown in our Stripe setup."))
326+
return
327+
}
328+
// Promote the user, if needed.
329+
if tier > u.Tier {
330+
err = api.staticDB.UserSetTier(req.Context(), u, tier)
331+
if err != nil {
332+
api.WriteError(w, errors.AddContext(err, "failed to promote user"), http.StatusInternalServerError)
333+
return
334+
}
335+
}
336+
// Build the response DTO.
337+
var discountInfo *SubscriptionDiscountGET
338+
if coSub.Discount != nil {
339+
var coupon *stripe.Coupon
340+
// We can potentially fetch the discount coupon from two places - the
341+
// discount itself or its promotional code. We'll check them in order.
342+
if coSub.Discount.Coupon != nil {
343+
coupon = coSub.Discount.Coupon
344+
} else if coSub.Discount.PromotionCode != nil && coSub.Discount.PromotionCode.Coupon != nil {
345+
coupon = coSub.Discount.PromotionCode.Coupon
346+
}
347+
if coupon != nil {
348+
discountInfo = &SubscriptionDiscountGET{
349+
AmountOff: coupon.AmountOff,
350+
Currency: string(coupon.Currency),
351+
Duration: string(coupon.Duration),
352+
DurationInMonths: coupon.DurationInMonths,
353+
Name: coupon.Name,
354+
PercentOff: coupon.PercentOff,
355+
}
356+
}
357+
}
358+
var planInfo *SubscriptionPlanGET
359+
if coSub.Plan != nil {
360+
var productInfo *SubscriptionProductGET
361+
if coSub.Plan.Product != nil {
362+
productInfo = &SubscriptionProductGET{
363+
Description: coSub.Plan.Product.Description,
364+
Name: coSub.Plan.Product.Name,
365+
}
366+
}
367+
planInfo = &SubscriptionPlanGET{
368+
Amount: coSub.Plan.Amount,
369+
Currency: string(coSub.Plan.Currency),
370+
Interval: string(coSub.Plan.Interval),
371+
IntervalCount: coSub.Plan.IntervalCount,
372+
Price: coSub.Plan.ID, // plan ID and price ID are the same
373+
Product: productInfo,
374+
}
375+
}
376+
377+
subInfo := SubscriptionGET{
378+
Created: coSub.Created,
379+
CurrentPeriodStart: coSub.CurrentPeriodStart,
380+
Discount: discountInfo,
381+
ID: coSub.ID,
382+
Plan: planInfo,
383+
StartDate: coSub.StartDate,
384+
Status: string(coSub.Status),
385+
}
386+
api.WriteJSON(w, subInfo)
387+
}
388+
229389
// stripeCreateCustomer creates a Stripe customer record for this user and
230390
// updates the user in the database.
231391
func (api *API) stripeCreateCustomer(ctx context.Context, u *database.User) (string, error) {
232392
cus, err := customer.New(&stripe.CustomerParams{})
233393
if err != nil {
234394
return "", errors.AddContext(err, "failed to create Stripe customer")
235395
}
236-
stripeID := cus.ID
237-
err = api.staticDB.UserSetStripeID(ctx, u, stripeID)
396+
// We'll try to update the customer with the user's email and sub. We only
397+
// do this as an optional step, so we can match Stripe customers to local
398+
// users more easily. We do not care if this step fails - it's entirely
399+
// optional. It requires an additional round-trip to Stripe and we don't
400+
// need to wait for it to finish, so we'll do it in a separate goroutine.
401+
go func() {
402+
email := u.Email.String()
403+
updateParams := stripe.CustomerParams{
404+
Description: &u.Sub,
405+
Email: &email,
406+
}
407+
_, _ = customer.Update(cus.ID, &updateParams)
408+
}()
409+
err = api.staticDB.UserSetStripeID(ctx, u, cus.ID)
238410
if err != nil {
239411
return "", errors.AddContext(err, "failed to save user's StripeID")
240412
}
241-
return stripeID, nil
413+
return cus.ID, nil
242414
}
243415

244416
// stripePricesGET returns a list of plans and prices.
245417
func (api *API) stripePricesGET(_ *database.User, w http.ResponseWriter, _ *http.Request, _ httprouter.Params) {
246418
var sPrices []StripePrice
247419
params := &stripe.PriceListParams{
248-
Active: &True,
420+
Active: stripe.Bool(true),
249421
ListParams: stripe.ListParams{
250422
Limit: &stripePageSize,
251423
},

go.mod

+10-8
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ require (
88
github.com/julienschmidt/httprouter v1.3.0
99
github.com/lestrrat-go/jwx v1.2.25
1010
github.com/sirupsen/logrus v1.8.1
11-
github.com/stripe/stripe-go/v72 v72.115.0
11+
github.com/stripe/stripe-go/v72 v72.117.0
1212
gitlab.com/NebulousLabs/errors v0.0.0-20200929122200-06c536cf6975
1313
gitlab.com/NebulousLabs/fastrand v0.0.0-20181126182046-603482d69e40
14-
gitlab.com/SkynetLabs/skyd v1.5.10
14+
gitlab.com/SkynetLabs/skyd v1.6.0
1515
go.mongodb.org/mongo-driver v1.9.1
16-
go.sia.tech/siad v1.5.8
16+
go.sia.tech/siad v1.5.9-rc1
1717
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d
18+
gopkg.in/h2non/gock.v1 v1.1.2
1819
gopkg.in/mail.v2 v2.3.1
1920
)
2021

@@ -24,10 +25,11 @@ require (
2425
github.com/dchest/threefish v0.0.0-20120919164726-3ecf4c494abf // indirect
2526
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect
2627
github.com/go-stack/stack v1.8.1 // indirect
27-
github.com/goccy/go-json v0.9.7 // indirect
28+
github.com/goccy/go-json v0.9.8 // indirect
2829
github.com/golang/snappy v0.0.4 // indirect
2930
github.com/gorilla/websocket v1.5.0 // indirect
30-
github.com/klauspost/compress v1.15.6 // indirect
31+
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
32+
github.com/klauspost/compress v1.15.7 // indirect
3133
github.com/klauspost/cpuid/v2 v2.0.14 // indirect
3234
github.com/klauspost/reedsolomon v1.10.0 // indirect
3335
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
@@ -50,11 +52,11 @@ require (
5052
gitlab.com/NebulousLabs/merkletree v0.0.0-20200118113624-07fbf710afc4 // indirect
5153
gitlab.com/NebulousLabs/persist v0.0.0-20200605115618-007e5e23d877 // indirect
5254
gitlab.com/NebulousLabs/ratelimit v0.0.0-20200811080431-99b8f0768b2e // indirect
53-
gitlab.com/NebulousLabs/siamux v0.0.0-20220616144115-9831ef867730 // indirect
55+
gitlab.com/NebulousLabs/siamux v0.0.2-0.20220630142132-142a1443a259 // indirect
5456
gitlab.com/NebulousLabs/threadgroup v0.0.0-20200608151952-38921fbef213 // indirect
55-
golang.org/x/net v0.0.0-20220622184535-263ec571b305 // indirect
57+
golang.org/x/net v0.0.0-20220706163947-c90051bbdb60 // indirect
5658
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
57-
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 // indirect
59+
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect
5860
golang.org/x/text v0.3.7 // indirect
5961
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
6062
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect

0 commit comments

Comments
 (0)