Skip to content

Commit fb0dd93

Browse files
committed
feat(api,cli): users store
1 parent 5098d2f commit fb0dd93

18 files changed

+377
-132
lines changed

api/services/auth.go

+13-12
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"time"
1717

1818
"github.com/cnf/structhash"
19+
"github.com/shellhub-io/shellhub/api/store"
1920
"github.com/shellhub-io/shellhub/pkg/api/authorizer"
2021
"github.com/shellhub-io/shellhub/pkg/api/jwttoken"
2122
"github.com/shellhub-io/shellhub/pkg/api/requests"
@@ -187,15 +188,14 @@ func (s *service) AuthLocalUser(ctx context.Context, req *requests.AuthLocalUser
187188
return nil, 0, "", NewErrAuthMethodNotAllowed(models.UserAuthMethodLocal.String())
188189
}
189190

190-
var err error
191-
var user *models.User
192-
191+
ident := store.UserIdent("")
193192
if req.Identifier.IsEmail() {
194-
user, err = s.store.UserGetByEmail(ctx, strings.ToLower(string(req.Identifier)))
193+
ident = store.UserIdentEmail
195194
} else {
196-
user, err = s.store.UserGetByUsername(ctx, strings.ToLower(string(req.Identifier)))
195+
ident = store.UserIdentUsername
197196
}
198197

198+
user, err := s.store.UserGet(ctx, ident, strings.ToLower(string(req.Identifier)))
199199
if err != nil {
200200
return nil, 0, "", NewErrAuthUnathorized(nil)
201201
}
@@ -284,16 +284,16 @@ func (s *service) AuthLocalUser(ctx context.Context, req *requests.AuthLocalUser
284284
return nil, 0, "", NewErrTokenSigned(err)
285285
}
286286

287-
// Updates last_login and the hash algorithm to bcrypt if still using SHA256
288-
changes := &models.UserChanges{LastLogin: clock.Now(), PreferredNamespace: &tenantID}
289-
if !strings.HasPrefix(user.PasswordDigest, "$") {
287+
user.LastLogin = clock.Now()
288+
user.Preferences.PreferredNamespace = tenantID
289+
if !strings.HasPrefix(user.PasswordDigest, "$") { // Updates the hash algorithm to bcrypt only if still using SHA256
290290
if passwordDigest, _ := hash.Do(req.Password); passwordDigest != "" {
291-
changes.Password = passwordDigest
291+
user.PasswordDigest = passwordDigest
292292
}
293293
}
294294

295295
// TODO: evaluate make this update in a go routine.
296-
if err := s.store.UserUpdate(ctx, user.ID, changes); err != nil {
296+
if err := s.store.Save(ctx, user.ID); err != nil {
297297
return nil, 0, "", NewErrUserUpdate(user, err)
298298
}
299299

@@ -322,7 +322,7 @@ func (s *service) AuthLocalUser(ctx context.Context, req *requests.AuthLocalUser
322322
}
323323

324324
func (s *service) CreateUserToken(ctx context.Context, req *requests.CreateUserToken) (*models.UserAuthResponse, error) {
325-
user, _, err := s.store.UserGetByID(ctx, req.UserID, false)
325+
user, err := s.store.UserGet(ctx, store.UserIdentID, req.UserID)
326326
if err != nil {
327327
return nil, NewErrUserNotFound(req.UserID, err)
328328
}
@@ -366,7 +366,8 @@ func (s *service) CreateUserToken(ctx context.Context, req *requests.CreateUserT
366366
role = member.Role.String()
367367

368368
if user.Preferences.PreferredNamespace != namespace.TenantID {
369-
_ = s.store.UserUpdate(ctx, user.ID, &models.UserChanges{PreferredNamespace: &tenantID})
369+
user.Preferences.PreferredNamespace = tenantID
370+
_ = s.store.Save(ctx, user)
370371
}
371372
}
372373

api/services/member.go

+12-11
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (s *service) AddNamespaceMember(ctx context.Context, req *requests.Namespac
5353
return nil, NewErrNamespaceNotFound(req.TenantID, err)
5454
}
5555

56-
user, _, err := s.store.UserGetByID(ctx, req.UserID, false)
56+
user, err := s.store.UserGet(ctx, store.UserIdentID, req.UserID)
5757
if err != nil || user == nil {
5858
return nil, NewErrUserNotFound(req.UserID, err)
5959
}
@@ -71,7 +71,7 @@ func (s *service) AddNamespaceMember(ctx context.Context, req *requests.Namespac
7171
// In cloud instances, if the target user does not exist, we need to create a new user
7272
// with the specified email. We use the inserted ID to identify the user once they complete
7373
// the registration and accepts the invitation.
74-
passiveUser, err := s.store.UserGetByEmail(ctx, strings.ToLower(req.MemberEmail))
74+
passiveUser, err := s.store.UserGet(ctx, store.UserIdentEmail, strings.ToLower(req.MemberEmail))
7575
if err != nil {
7676
if !envs.IsCloud() || !errors.Is(err, store.ErrNoDocuments) {
7777
return nil, NewErrUserNotFound(req.MemberEmail, err)
@@ -161,7 +161,7 @@ func (s *service) UpdateNamespaceMember(ctx context.Context, req *requests.Names
161161
return NewErrNamespaceNotFound(req.TenantID, err)
162162
}
163163

164-
user, _, err := s.store.UserGetByID(ctx, req.UserID, false)
164+
user, err := s.store.UserGet(ctx, store.UserIdentID, req.UserID)
165165
if err != nil {
166166
return NewErrUserNotFound(req.UserID, err)
167167
}
@@ -198,7 +198,7 @@ func (s *service) RemoveNamespaceMember(ctx context.Context, req *requests.Names
198198
return nil, NewErrNamespaceNotFound(req.TenantID, err)
199199
}
200200

201-
user, _, err := s.store.UserGetByID(ctx, req.UserID, false)
201+
user, err := s.store.UserGet(ctx, store.UserIdentID, req.UserID)
202202
if err != nil {
203203
return nil, NewErrUserNotFound(req.UserID, err)
204204
}
@@ -251,13 +251,14 @@ func (s *service) LeaveNamespace(ctx context.Context, req *requests.LeaveNamespa
251251
return nil, nil
252252
}
253253

254-
emptyString := "" // just to be used as a pointer
255-
if err := s.store.UserUpdate(ctx, req.UserID, &models.UserChanges{PreferredNamespace: &emptyString}); err != nil {
256-
log.WithError(err).
257-
WithField("tenant_id", req.TenantID).
258-
WithField("user_id", req.UserID).
259-
Error("failed to reset user's preferred namespace")
260-
}
254+
// TODO: search for the user so we can use s.store.Save()
255+
// emptyString := "" // just to be used as a pointer
256+
// if err := s.store.UserUpdate(ctx, req.UserID, &models.UserChanges{PreferredNamespace: &emptyString}); err != nil {
257+
// log.WithError(err).
258+
// WithField("tenant_id", req.TenantID).
259+
// WithField("user_id", req.UserID).
260+
// Error("failed to reset user's preferred namespace")
261+
// }
261262

262263
if err := s.AuthUncacheToken(ctx, req.TenantID, req.UserID); err != nil {
263264
log.WithError(err).

api/services/namespace.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type NamespaceService interface {
2525

2626
// CreateNamespace creates a new namespace.
2727
func (s *service) CreateNamespace(ctx context.Context, req *requests.NamespaceCreate) (*models.Namespace, error) {
28-
user, _, err := s.store.UserGetByID(ctx, req.UserID, false)
28+
user, err := s.store.UserGet(ctx, store.UserIdentID, req.UserID)
2929
if err != nil || user == nil {
3030
return nil, NewErrUserNotFound(req.UserID, err)
3131
}

api/services/setup.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ func (s *service) Setup(ctx context.Context, req requests.Setup) error {
7979
},
8080
}
8181

82+
// TODO: use a transaction here
8283
if _, err = s.store.NamespaceCreate(ctx, namespace); err != nil {
83-
if err := s.store.UserDelete(ctx, insertedID); err != nil {
84+
if err := s.store.Delete(ctx, user); err != nil {
8485
return NewErrUserDelete(err)
8586
}
8687

api/services/user.go

+23-12
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ import (
44
"context"
55
"strings"
66

7+
"github.com/shellhub-io/shellhub/api/store"
78
"github.com/shellhub-io/shellhub/pkg/api/requests"
89
"github.com/shellhub-io/shellhub/pkg/hash"
910
"github.com/shellhub-io/shellhub/pkg/models"
11+
"golang.org/x/text/cases"
12+
"golang.org/x/text/language"
1013
)
1114

1215
type UserService interface {
@@ -23,7 +26,7 @@ type UserService interface {
2326
}
2427

2528
func (s *service) UpdateUser(ctx context.Context, req *requests.UpdateUser) ([]string, error) {
26-
user, _, err := s.store.UserGetByID(ctx, req.UserID, false)
29+
user, err := s.store.UserGet(ctx, store.UserIdentID, req.UserID)
2730
if err != nil {
2831
return []string{}, NewErrUserNotFound(req.UserID, nil)
2932
}
@@ -38,11 +41,20 @@ func (s *service) UpdateUser(ctx context.Context, req *requests.UpdateUser) ([]s
3841
return conflicts, NewErrUserDuplicated(conflicts, nil)
3942
}
4043

41-
changes := &models.UserChanges{
42-
Name: req.Name,
43-
Username: strings.ToLower(req.Username),
44-
Email: strings.ToLower(req.Email),
45-
RecoveryEmail: strings.ToLower(req.RecoveryEmail),
44+
if req.Name != "" {
45+
user.Name = cases.Title(language.AmericanEnglish).String(strings.ToLower(req.Name))
46+
}
47+
48+
if req.Username != "" {
49+
user.Username = strings.ToLower(req.Username)
50+
}
51+
52+
if req.Email != "" {
53+
user.Email = strings.ToLower(req.Email)
54+
}
55+
56+
if req.RecoveryEmail != "" {
57+
user.Preferences.SecurityEmail = strings.ToLower(req.RecoveryEmail)
4658
}
4759

4860
if req.Password != "" {
@@ -52,10 +64,10 @@ func (s *service) UpdateUser(ctx context.Context, req *requests.UpdateUser) ([]s
5264
}
5365

5466
passwordDigest, _ := hash.Do(req.Password)
55-
changes.Password = passwordDigest
67+
user.PasswordDigest = passwordDigest
5668
}
5769

58-
if err := s.store.UserUpdate(ctx, req.UserID, changes); err != nil {
70+
if err := s.store.Save(ctx, user); err != nil {
5971
return []string{}, NewErrUserUpdate(user, err)
6072
}
6173

@@ -66,7 +78,7 @@ func (s *service) UpdateUser(ctx context.Context, req *requests.UpdateUser) ([]s
6678
//
6779
// Deprecated, use [Service.UpdateUser] instead.
6880
func (s *service) UpdatePasswordUser(ctx context.Context, id, currentPassword, newPassword string) error {
69-
user, _, err := s.store.UserGetByID(ctx, id, false)
81+
user, err := s.store.UserGet(ctx, store.UserIdentID, id)
7082
if user == nil {
7183
return NewErrUserNotFound(id, err)
7284
}
@@ -75,12 +87,11 @@ func (s *service) UpdatePasswordUser(ctx context.Context, id, currentPassword, n
7587
return NewErrUserPasswordNotMatch(nil)
7688
}
7789

78-
passwordDigest, err := hash.Do(newPassword)
79-
if err != nil {
90+
if user.PasswordDigest, err = hash.Do(newPassword); err != nil {
8091
return NewErrUserPasswordInvalid(err)
8192
}
8293

83-
if err := s.store.UserUpdate(ctx, id, &models.UserChanges{Password: passwordDigest}); err != nil {
94+
if err := s.store.Save(ctx, user); err != nil {
8495
return NewErrUserUpdate(user, err)
8596
}
8697

api/store/pg/internal/dbtest/fixtures/.keep

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
- model: User
2+
rows:
3+
- id: 0195cefa-aa01-7efb-8098-c9c173056250
4+
created_at: 2025-01-15T10:30:00+00:00
5+
updated_at: 2025-01-15T10:30:00+00:00
6+
last_login: null
7+
status: confirmed
8+
origin: local
9+
external_id: ""
10+
name: Jonh Doe
11+
username: john_doe
12+
13+
security_email: [email protected]
14+
password_digest: "$2y$12$VVm2ETx7AvaGlfMYqNYK9uzU2M45YZ70YnT..O.s1o2zdE1pekhq6"
15+
auth_methods: [ local ]
16+
namespace_ownership_limit: -1
17+
email_marketing: true
18+
preferred_namespace_id: null

api/store/pg/internal/entity/user.go

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package entity
2+
3+
import (
4+
"github.com/shellhub-io/shellhub/pkg/models"
5+
"github.com/uptrace/bun"
6+
)
7+
8+
type User struct {
9+
bun.BaseModel `bun:"table:users"`
10+
models.User `bun:"embed:"`
11+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
BEGIN;
2+
3+
DROP TYPE IF EXISTS user_origin;
4+
DROP TYPE IF EXISTS user_status;
5+
DROP TYPE IF EXISTS user_auth_method;
6+
DROP TABLE users;
7+
8+
COMMIT;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
BEGIN;
2+
3+
DROP TYPE IF EXISTS user_origin;
4+
CREATE TYPE user_origin AS ENUM ('local', 'saml');
5+
6+
DROP TYPE IF EXISTS user_status;
7+
CREATE TYPE user_status AS ENUM ('invited', 'pending', 'confirmed');
8+
9+
DROP TYPE IF EXISTS user_auth_method;
10+
CREATE TYPE user_auth_method AS ENUM ('local', 'saml');
11+
12+
13+
CREATE TABLE IF NOT EXISTS users(
14+
id UUID PRIMARY KEY,
15+
16+
created_at TIMESTAMPTZ NOT NULL,
17+
updated_at TIMESTAMPTZ NOT NULL,
18+
last_login TIMESTAMPTZ,
19+
20+
origin user_origin NOT NULL,
21+
external_id VARCHAR,
22+
status user_status NOT NULL,
23+
name VARCHAR(64) NOT NULL,
24+
username VARCHAR(32) UNIQUE NOT NULL,
25+
email VARCHAR(320) UNIQUE NOT NULL,
26+
security_email VARCHAR(320),
27+
password_digest CHAR(72) NOT NULL,
28+
auth_methods user_auth_method[] NOT NULL,
29+
30+
namespace_ownership_limit INTEGER NOT NULL,
31+
email_marketing BOOLEAN NOT NULL,
32+
preferred_namespace_id UUID
33+
);
34+
35+
COMMIT;

api/store/pg/user.go

+33-19
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,25 @@ package pg
33
import (
44
"context"
55

6+
"github.com/shellhub-io/shellhub/api/store"
7+
"github.com/shellhub-io/shellhub/api/store/pg/internal/entity"
68
"github.com/shellhub-io/shellhub/pkg/api/query"
9+
"github.com/shellhub-io/shellhub/pkg/clock"
710
"github.com/shellhub-io/shellhub/pkg/models"
11+
"github.com/shellhub-io/shellhub/pkg/uuid"
12+
"github.com/uptrace/bun"
813
)
914

1015
func (pg *pg) UserCreate(ctx context.Context, user *models.User) (string, error) {
11-
return "", nil
16+
user.ID = uuid.Generate()
17+
user.CreatedAt = clock.Now()
18+
user.UpdatedAt = clock.Now()
19+
20+
if _, err := pg.driver.NewInsert().Model(&entity.User{User: *user}).Exec(ctx); err != nil {
21+
return "", err
22+
}
23+
24+
return user.ID, nil
1225
}
1326

1427
func (pg *pg) UserCreateInvited(ctx context.Context, email string) (string, error) {
@@ -17,34 +30,35 @@ func (pg *pg) UserCreateInvited(ctx context.Context, email string) (string, erro
1730
}
1831

1932
func (pg *pg) UserConflicts(ctx context.Context, target *models.UserConflicts) ([]string, bool, error) {
20-
return nil, false, nil
33+
users := make([]map[string]any, 0)
34+
if err := pg.driver.NewSelect().Model((*entity.User)(nil)).Column("email").Where("email = ?", target.Email).Scan(ctx, &users); err != nil {
35+
return nil, false, err
36+
}
37+
38+
conflicts := make([]string, 0)
39+
for _, user := range users {
40+
if user["email"] == target.Email {
41+
conflicts = append(conflicts, "email")
42+
}
43+
}
44+
45+
return conflicts, len(conflicts) > 0, nil
2146
}
2247

2348
func (pg *pg) UserList(ctx context.Context, paginator query.Paginator, filters query.Filters) ([]models.User, int, error) {
2449
return nil, 0, nil
2550
}
2651

27-
func (pg *pg) UserGetByID(ctx context.Context, id string, ns bool) (*models.User, int, error) {
28-
return nil, 0, nil
29-
}
52+
func (pg *pg) UserGet(ctx context.Context, ident store.UserIdent, val string) (*models.User, error) {
53+
u := new(entity.User)
54+
if err := pg.driver.NewSelect().Model(u).Where("? = ?", bun.Ident(ident), val).Scan(ctx); err != nil {
55+
return nil, fromSqlError(err)
56+
}
3057

31-
func (pg *pg) UserGetByUsername(ctx context.Context, username string) (*models.User, error) {
32-
return nil, nil
33-
}
34-
35-
func (pg *pg) UserGetByEmail(ctx context.Context, email string) (*models.User, error) {
36-
return nil, nil
58+
return &u.User, nil
3759
}
3860

3961
func (pg *pg) UserGetInfo(ctx context.Context, id string) (userInfo *models.UserInfo, err error) {
4062
// TODO: unify get methods
4163
return nil, nil
4264
}
43-
44-
func (pg *pg) UserUpdate(ctx context.Context, id string, changes *models.UserChanges) error {
45-
return nil
46-
}
47-
48-
func (pg *pg) UserDelete(ctx context.Context, id string) error {
49-
return nil
50-
}

0 commit comments

Comments
 (0)