Skip to content

Commit 0dcf2b3

Browse files
committed
Introduce an Email type that handles capitalization.
1 parent d1c3514 commit 0dcf2b3

17 files changed

+252
-108
lines changed

api/auth_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/SkynetLabs/skynet-accounts/database"
1212
"github.com/SkynetLabs/skynet-accounts/jwt"
13+
"github.com/SkynetLabs/skynet-accounts/types"
1314
"github.com/sirupsen/logrus"
1415
"gitlab.com/NebulousLabs/errors"
1516
"gitlab.com/NebulousLabs/fastrand"
@@ -47,7 +48,7 @@ func TestTokenFromRequest(t *testing.T) {
4748
if err != nil {
4849
t.Fatal(err)
4950
}
50-
tk, err := jwt.TokenForUser(t.Name()+"@siasky.net", t.Name()+"_sub")
51+
tk, err := jwt.TokenForUser(types.NewEmail(t.Name()+"@siasky.net"), t.Name()+"_sub")
5152
if err != nil {
5253
t.Fatal(err)
5354
}
@@ -97,7 +98,7 @@ func TestTokenFromRequest(t *testing.T) {
9798

9899
// Token from request with a header and a cookie. Expect the header to take
99100
// precedence.
100-
tk2, err := jwt.TokenForUser(t.Name()+"[email protected]", t.Name()+"2_sub")
101+
tk2, err := jwt.TokenForUser(types.NewEmail(t.Name()+"[email protected]"), t.Name()+"2_sub")
101102
if err != nil {
102103
t.Fatal(err)
103104
}

api/handlers.go

+17-16
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/SkynetLabs/skynet-accounts/lib"
2020
"github.com/SkynetLabs/skynet-accounts/metafetcher"
2121
"github.com/SkynetLabs/skynet-accounts/skynet"
22+
"github.com/SkynetLabs/skynet-accounts/types"
2223
"github.com/julienschmidt/httprouter"
2324
jwt2 "github.com/lestrrat-go/jwx/jwt"
2425
"gitlab.com/NebulousLabs/errors"
@@ -125,16 +126,16 @@ type (
125126

126127
// credentialsPOST defines the standard credentials package we expect.
127128
credentialsPOST struct {
128-
Email string `json:"email"`
129-
Password string `json:"password"`
129+
Email types.Email `json:"email"`
130+
Password string `json:"password"`
130131
}
131132

132133
// userUpdatePUT defines the fields of the User record that can be changed
133134
// externally, e.g. by calling `PUT /user`.
134135
userUpdatePUT struct {
135-
Email string `json:"email,omitempty"`
136-
Password string `json:"password,omitempty"`
137-
StripeID string `json:"stripeCustomerId,omitempty"`
136+
Email types.Email `json:"email,omitempty"`
137+
Password string `json:"password,omitempty"`
138+
StripeID string `json:"stripeCustomerId,omitempty"`
138139
}
139140
)
140141

@@ -231,7 +232,7 @@ func (api *API) loginPOSTChallengeResponse(w http.ResponseWriter, req *http.Requ
231232
}
232233

233234
// loginPOSTCredentials is a helper that handles logins with credentials.
234-
func (api *API) loginPOSTCredentials(w http.ResponseWriter, req *http.Request, email, password string) {
235+
func (api *API) loginPOSTCredentials(w http.ResponseWriter, req *http.Request, email types.Email, password string) {
235236
// Fetch the user with that email, if they exist.
236237
u, err := api.staticDB.UserByEmail(req.Context(), email)
237238
if err != nil {
@@ -388,8 +389,8 @@ func (api *API) registerPOST(_ *database.User, w http.ResponseWriter, req *http.
388389
api.WriteError(w, errors.AddContext(err, "failed to parse request body"), http.StatusBadRequest)
389390
return
390391
}
391-
parsed, err := mail.ParseAddress(payload.Email)
392-
if err != nil || payload.Email != parsed.Address {
392+
parsed, err := mail.ParseAddress(payload.Email.String())
393+
if err != nil || payload.Email.String() != parsed.Address {
393394
api.WriteError(w, errors.New("invalid email provided"), http.StatusBadRequest)
394395
return
395396
}
@@ -616,8 +617,8 @@ func (api *API) userPOST(_ *database.User, w http.ResponseWriter, req *http.Requ
616617
api.WriteError(w, errors.New("email is required"), http.StatusBadRequest)
617618
return
618619
}
619-
parsed, err := mail.ParseAddress(payload.Email)
620-
if err != nil || payload.Email != parsed.Address {
620+
parsed, err := mail.ParseAddress(payload.Email.String())
621+
if err != nil || payload.Email.String() != parsed.Address {
621622
api.WriteError(w, errors.New("invalid email provided"), http.StatusBadRequest)
622623
return
623624
}
@@ -714,8 +715,8 @@ func (api *API) userPUT(u *database.User, w http.ResponseWriter, req *http.Reque
714715

715716
var changedEmail bool
716717
if payload.Email != "" {
717-
parsed, err := mail.ParseAddress(payload.Email)
718-
if err != nil || payload.Email != parsed.Address {
718+
parsed, err := mail.ParseAddress(payload.Email.String())
719+
if err != nil || payload.Email.String() != parsed.Address {
719720
api.WriteError(w, errors.New("invalid email provided"), http.StatusBadRequest)
720721
return
721722
}
@@ -995,10 +996,10 @@ func (api *API) userRecoverRequestPOST(_ *database.User, w http.ResponseWriter,
995996
return
996997
}
997998

998-
// Read and parse the request body.
999-
var payload struct {
1000-
Email string `json:"email"`
1001-
}
999+
// Read and parse the request body. We do not expect a password but we want
1000+
// to use the same email parsing approach in all cases where we get an email
1001+
// address from the user.
1002+
var payload credentialsPOST
10021003
err = parseRequestBodyJSON(req.Body, LimitBodySizeSmall, &payload)
10031004
if err != nil {
10041005
err = errors.AddContext(err, "failed to parse request body")

api/upload.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"time"
66

77
"github.com/SkynetLabs/skynet-accounts/database"
8+
"github.com/SkynetLabs/skynet-accounts/types"
89
"github.com/julienschmidt/httprouter"
910
"gitlab.com/NebulousLabs/errors"
1011
"go.mongodb.org/mongo-driver/bson/primitive"
@@ -14,7 +15,7 @@ type (
1415
// UploaderInfo gives information about a user who created an upload.
1516
UploaderInfo struct {
1617
UserID primitive.ObjectID
17-
Email string
18+
Email types.Email
1819
Sub string
1920
StripeID string
2021
}

database/user.go

+11-10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/SkynetLabs/skynet-accounts/hash"
1212
"github.com/SkynetLabs/skynet-accounts/lib"
1313
"github.com/SkynetLabs/skynet-accounts/skynet"
14+
"github.com/SkynetLabs/skynet-accounts/types"
1415
"gitlab.com/NebulousLabs/errors"
1516
"gitlab.com/SkynetLabs/skyd/build"
1617
"go.mongodb.org/mongo-driver/bson"
@@ -110,7 +111,7 @@ type (
110111
// ID is auto-generated by Mongo on insert. We will usually use it in
111112
// its ID.Hex() form.
112113
ID primitive.ObjectID `bson:"_id,omitempty" json:"-"`
113-
Email string `bson:"email" json:"email"`
114+
Email types.Email `bson:"email" json:"email"`
114115
EmailConfirmationToken string `bson:"email_confirmation_token,omitempty" json:"-"`
115116
EmailConfirmationTokenExpiration time.Time `bson:"email_confirmation_token_expiration,omitempty" json:"-"`
116117
PasswordHash string `bson:"password_hash" json:"-"`
@@ -155,8 +156,8 @@ type (
155156
)
156157

157158
// UserByEmail returns the user with the given username.
158-
func (db *DB) UserByEmail(ctx context.Context, email string) (*User, error) {
159-
users, err := db.managedUsersByField(ctx, "email", email)
159+
func (db *DB) UserByEmail(ctx context.Context, email types.Email) (*User, error) {
160+
users, err := db.managedUsersByField(ctx, "email", email.String())
160161
if err != nil {
161162
return nil, err
162163
}
@@ -278,14 +279,14 @@ func (db *DB) UserConfirmEmail(ctx context.Context, token string) (*User, error)
278279
//
279280
// The new user is created as "unconfirmed" and a confirmation email is sent to
280281
// the address they provided.
281-
func (db *DB) UserCreate(ctx context.Context, emailAddr, pass, sub string, tier int) (*User, error) {
282+
func (db *DB) UserCreate(ctx context.Context, emailAddr types.Email, pass, sub string, tier int) (*User, error) {
282283
// Ensure the email is valid if it's passed. We allow empty emails.
283284
if emailAddr != "" {
284-
addr, err := mail.ParseAddress(emailAddr)
285+
addr, err := mail.ParseAddress(emailAddr.String())
285286
if err != nil {
286287
return nil, errors.AddContext(err, "invalid email address")
287288
}
288-
emailAddr = addr.Address
289+
emailAddr = types.NewEmail(addr.Address)
289290
}
290291
if sub == "" {
291292
return nil, errors.New("empty sub is not allowed")
@@ -382,14 +383,14 @@ func (db *DB) UserCreateEmailConfirmation(ctx context.Context, uID primitive.Obj
382383
//
383384
// The new user is created as "unconfirmed" and a confirmation email is sent to
384385
// the address they provided.
385-
func (db *DB) UserCreatePK(ctx context.Context, emailAddr, pass, sub string, pk PubKey, tier int) (*User, error) {
386+
func (db *DB) UserCreatePK(ctx context.Context, emailAddr types.Email, pass, sub string, pk PubKey, tier int) (*User, error) {
386387
// Validate the email.
387-
parsed, err := mail.ParseAddress(emailAddr)
388-
if err != nil || parsed.Address != emailAddr {
388+
parsed, err := mail.ParseAddress(emailAddr.String())
389+
if err != nil || parsed.Address != emailAddr.String() {
389390
return nil, errors.AddContext(err, "invalid email address")
390391
}
391392
// Check for an existing user with this email.
392-
users, err := db.managedUsersByField(ctx, "email", emailAddr)
393+
users, err := db.managedUsersByField(ctx, "email", emailAddr.String())
393394
if err != nil && !errors.Contains(err, ErrUserNotFound) {
394395
return nil, errors.AddContext(err, "failed to query DB")
395396
}

email/mailer.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55

66
"github.com/SkynetLabs/skynet-accounts/database"
7+
"github.com/SkynetLabs/skynet-accounts/types"
78
)
89

910
/**
@@ -35,15 +36,15 @@ func (em Mailer) Send(ctx context.Context, m database.EmailMessage) error {
3536

3637
// SendAddressConfirmationEmail sends a new email to the given email address
3738
// with a link to confirm the ownership of the address.
38-
func (em Mailer) SendAddressConfirmationEmail(ctx context.Context, email, token string) error {
39-
m := confirmEmailEmail(email, token)
39+
func (em Mailer) SendAddressConfirmationEmail(ctx context.Context, email types.Email, token string) error {
40+
m := confirmEmailEmail(email.String(), token)
4041
return em.Send(ctx, *m)
4142
}
4243

4344
// SendRecoverAccountEmail sends a new email to the given email address
4445
// with a link to recover the account.
45-
func (em Mailer) SendRecoverAccountEmail(ctx context.Context, email, token string) error {
46-
m := recoverAccountEmail(email, token)
46+
func (em Mailer) SendRecoverAccountEmail(ctx context.Context, email types.Email, token string) error {
47+
m := recoverAccountEmail(email.String(), token)
4748
return em.Send(ctx, *m)
4849
}
4950

@@ -52,7 +53,7 @@ func (em Mailer) SendRecoverAccountEmail(ctx context.Context, email, token strin
5253
// recover a Skynet account but their email is not in our system. The main
5354
// reason to do that is because the user might have forgotten which email they
5455
// used for signing up.
55-
func (em Mailer) SendAccountAccessAttemptedEmail(ctx context.Context, email string) error {
56-
m := accountAccessAttemptedEmail(email)
56+
func (em Mailer) SendAccountAccessAttemptedEmail(ctx context.Context, email types.Email) error {
57+
m := accountAccessAttemptedEmail(email.String())
5758
return em.Send(ctx, *m)
5859
}

jwt/jwt.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"io/ioutil"
77
"time"
88

9+
"github.com/SkynetLabs/skynet-accounts/types"
910
"github.com/lestrrat-go/jwx/jwa"
1011
"github.com/lestrrat-go/jwx/jwk"
1112
"github.com/lestrrat-go/jwx/jwt"
@@ -74,7 +75,7 @@ func ContextWithToken(ctx context.Context, token jwt.Token) context.Context {
7475
//
7576
// The tokens generated by this function are a slimmed down version of the ones
7677
// described in ValidateToken's docstring.
77-
func TokenForUser(email, sub string) (jwt.Token, error) {
78+
func TokenForUser(email types.Email, sub string) (jwt.Token, error) {
7879
sigAlgo, key, err := signatureAlgoAndKey()
7980
if err != nil {
8081
return nil, err
@@ -252,15 +253,15 @@ func signatureAlgoAndKey() (jwa.SignatureAlgorithm, jwk.Key, error) {
252253

253254
// tokenForUser is a helper method that puts together an unsigned token based
254255
// on the provided values.
255-
func tokenForUser(emailAddr, sub string) (jwt.Token, error) {
256+
func tokenForUser(emailAddr types.Email, sub string) (jwt.Token, error) {
256257
if emailAddr == "" || sub == "" {
257258
return nil, errors.New("email and sub cannot be empty")
258259
}
259260
session := tokenSession{
260261
Active: true,
261262
Identity: tokenIdentity{
262263
Traits: tokenTraits{
263-
Email: emailAddr,
264+
Email: emailAddr.String(),
264265
},
265266
},
266267
}

jwt/jwt_test.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"testing"
77
"time"
88

9+
"github.com/SkynetLabs/skynet-accounts/types"
910
"github.com/lestrrat-go/jwx/jwa"
1011
"github.com/lestrrat-go/jwx/jwt"
1112
"github.com/sirupsen/logrus"
@@ -19,7 +20,7 @@ func TestJWT(t *testing.T) {
1920
if err != nil {
2021
t.Fatal(err)
2122
}
22-
email := t.Name() + "@siasky.net"
23+
email := types.NewEmail(t.Name() + "@siasky.net")
2324
sub := "this is a sub"
2425
fakeSub := "fake sub"
2526
tk, err := TokenForUser(email, sub)
@@ -59,7 +60,7 @@ func TestValidateToken_Expired(t *testing.T) {
5960
if err != nil {
6061
t.Fatal(err)
6162
}
62-
email := t.Name() + "@siasky.net"
63+
email := types.NewEmail(t.Name() + "@siasky.net")
6364
sub := "this is a sub"
6465
// Fetch the tools we need in order to craft a custom token.
6566
key, found := AccountsJWKS.Get(0)
@@ -81,7 +82,7 @@ func TestValidateToken_Expired(t *testing.T) {
8182
Active: true,
8283
Identity: tokenIdentity{
8384
Traits: tokenTraits{
84-
Email: email,
85+
Email: email.String(),
8586
},
8687
},
8788
}

test/api/api_test.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/SkynetLabs/skynet-accounts/api"
1212
"github.com/SkynetLabs/skynet-accounts/database"
1313
"github.com/SkynetLabs/skynet-accounts/test"
14+
"github.com/SkynetLabs/skynet-accounts/types"
1415
"gitlab.com/NebulousLabs/fastrand"
1516
"go.sia.tech/siad/build"
1617

@@ -33,9 +34,9 @@ func TestWithDBSession(t *testing.T) {
3334
t.Fatal("Failed to instantiate API.", err)
3435
}
3536

36-
emailSuccess := t.Name() + "[email protected]"
37-
emailSuccessJSON := t.Name() + "[email protected]"
38-
emailFailure := t.Name() + "[email protected]"
37+
emailSuccess := types.NewEmail(t.Name() + "[email protected]")
38+
emailSuccessJSON := types.NewEmail(t.Name() + "[email protected]")
39+
emailFailure := types.NewEmail(t.Name() + "[email protected]")
3940

4041
// This handler successfully creates a user in the DB and exits with
4142
// a success status code. We expect the user to exist in the DB after
@@ -52,7 +53,7 @@ func TestWithDBSession(t *testing.T) {
5253
t.Fatal("Failed to fetch user from DB.", err)
5354
}
5455
if u.Email != emailSuccess {
55-
t.Fatalf("Expected email %s, got %s.", emailSuccess, u.Email)
56+
t.Fatalf("Expected email '%v', got '%v'.", emailSuccess, u.Email)
5657
}
5758
testAPI.WriteSuccess(w)
5859
}
@@ -147,7 +148,7 @@ func TestUserTierCache(t *testing.T) {
147148
}
148149
}()
149150

150-
emailAddr := test.DBNameForTest(t.Name()) + "@siasky.net"
151+
emailAddr := types.NewEmail(test.DBNameForTest(t.Name()) + "@siasky.net")
151152
password := hex.EncodeToString(fastrand.Bytes(16))
152153
u, err := test.CreateUser(at, emailAddr, password)
153154
if err != nil {
@@ -165,7 +166,7 @@ func TestUserTierCache(t *testing.T) {
165166
if err != nil {
166167
t.Fatal(err)
167168
}
168-
r, _, err := at.LoginCredentialsPOST(emailAddr, password)
169+
r, _, err := at.LoginCredentialsPOST(emailAddr.String(), password)
169170
if err != nil {
170171
t.Fatal(err)
171172
}

0 commit comments

Comments
 (0)