Skip to content

Commit 688e8ea

Browse files
committed
Set accepts Assigner for Generics API
1 parent 1901911 commit 688e8ea

File tree

5 files changed

+86
-63
lines changed

5 files changed

+86
-63
lines changed

clause/set.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ type Assignment struct {
99
Value interface{}
1010
}
1111

12+
// Assigner assignments provider interface
13+
type Assigner interface {
14+
Assignments() []Assignment
15+
}
16+
1217
func (set Set) Name() string {
1318
return "SET"
1419
}
@@ -37,6 +42,9 @@ func (set Set) MergeClause(clause *Clause) {
3742
clause.Expression = Set(copiedAssignments)
3843
}
3944

45+
// Assignments implements Assigner for Set.
46+
func (set Set) Assignments() []Assignment { return []Assignment(set) }
47+
4048
func Assignments(values map[string]interface{}) Set {
4149
keys := make([]string, 0, len(values))
4250
for key := range values {
@@ -58,3 +66,6 @@ func AssignmentColumns(values []string) Set {
5866
}
5967
return assignments
6068
}
69+
70+
// Assignments implements Assigner for a single Assignment.
71+
func (a Assignment) Assignments() []Assignment { return []Assignment{a} }

clause/set_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ import (
99
"gorm.io/gorm/clause"
1010
)
1111

12+
// Compile-time assertions that types implement clause.Assigner
13+
var (
14+
_ clause.Assigner = clause.Assignment{}
15+
_ clause.Assigner = clause.Set{}
16+
)
17+
1218
func TestSet(t *testing.T) {
1319
results := []struct {
1420
Clauses []clause.Interface

generics.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ type CreateInterface[T any] interface {
6161
Table(name string, args ...interface{}) CreateInterface[T]
6262
Create(ctx context.Context, r *T) error
6363
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
64-
Set(assignments ...clause.Assignment) SetCreateOrUpdateInterface[T]
64+
Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T]
6565
}
6666

6767
type ChainInterface[T any] interface {
@@ -81,7 +81,7 @@ type ChainInterface[T any] interface {
8181
Group(name string) ChainInterface[T]
8282
Having(query interface{}, args ...interface{}) ChainInterface[T]
8383
Order(value interface{}) ChainInterface[T]
84-
Set(assignments ...clause.Assignment) SetUpdateOnlyInterface[T]
84+
Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T]
8585

8686
Build(builder clause.Builder)
8787

@@ -199,10 +199,8 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
199199
})}
200200
}
201201

202-
func (c createG[T]) Set(assignments ...clause.Assignment) SetCreateOrUpdateInterface[T] {
203-
assigns := make([]clause.Assignment, len(assignments))
204-
copy(assigns, assignments)
205-
return setCreateOrUpdateG[T]{c: c.chainG, assigns: assigns}
202+
func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] {
203+
return setCreateOrUpdateG[T]{c: c.chainG, assigns: toAssignments(assignments...)}
206204
}
207205

208206
func (c createG[T]) Create(ctx context.Context, r *T) error {
@@ -432,10 +430,8 @@ func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
432430
})
433431
}
434432

435-
func (c chainG[T]) Set(assignments ...clause.Assignment) SetUpdateOnlyInterface[T] {
436-
assigns := make([]clause.Assignment, len(assignments))
437-
copy(assigns, assignments)
438-
return setCreateOrUpdateG[T]{c: c, assigns: assigns}
433+
func (c chainG[T]) Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] {
434+
return setCreateOrUpdateG[T]{c: c, assigns: toAssignments(assignments...)}
439435
}
440436

441437
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
@@ -610,6 +606,16 @@ type setCreateOrUpdateG[T any] struct {
610606
assigns []clause.Assignment
611607
}
612608

609+
// toAssignments converts various supported types into []clause.Assignment.
610+
// Supported inputs implement clause.Assigner.
611+
func toAssignments(items ...clause.Assigner) []clause.Assignment {
612+
out := make([]clause.Assignment, 0, len(items))
613+
for _, it := range items {
614+
out = append(out, it.Assignments()...)
615+
}
616+
return out
617+
}
618+
613619
func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) {
614620
var r T
615621
res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{})

tests/generics_test.go

Lines changed: 52 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -668,61 +668,61 @@ func TestGenericsDistinct(t *testing.T) {
668668
}
669669

670670
func TestGenericsSetCreate(t *testing.T) {
671-
ctx := context.Background()
672-
673-
name := "GenericsSetCreate"
674-
age := uint(21)
675-
676-
err := gorm.G[User](DB).Set(
677-
clause.Assignment{Column: clause.Column{Name: "name"}, Value: name},
678-
clause.Assignment{Column: clause.Column{Name: "age"}, Value: age},
679-
).Create(ctx)
680-
if err != nil {
681-
t.Fatalf("Set Create failed: %v", err)
682-
}
683-
684-
u, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
685-
if err != nil {
686-
t.Fatalf("failed to find created user: %v", err)
687-
}
688-
if u.ID == 0 || u.Name != name || u.Age != age {
689-
t.Fatalf("created user mismatch, got %+v", u)
690-
}
671+
ctx := context.Background()
672+
673+
name := "GenericsSetCreate"
674+
age := uint(21)
675+
676+
err := gorm.G[User](DB).Set(
677+
clause.Assignment{Column: clause.Column{Name: "name"}, Value: name},
678+
clause.Assignment{Column: clause.Column{Name: "age"}, Value: age},
679+
).Create(ctx)
680+
if err != nil {
681+
t.Fatalf("Set Create failed: %v", err)
682+
}
683+
684+
u, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
685+
if err != nil {
686+
t.Fatalf("failed to find created user: %v", err)
687+
}
688+
if u.ID == 0 || u.Name != name || u.Age != age {
689+
t.Fatalf("created user mismatch, got %+v", u)
690+
}
691691
}
692692

693693
func TestGenericsSetUpdate(t *testing.T) {
694-
ctx := context.Background()
695-
696-
// prepare
697-
u := User{Name: "GenericsSetUpdate_Before", Age: 30}
698-
if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
699-
t.Fatalf("prepare user failed: %v", err)
700-
}
701-
702-
// update with Set after chain
703-
newName := "GenericsSetUpdate_After"
704-
newAge := uint(31)
705-
rows, err := gorm.G[User](DB).
706-
Where("id = ?", u.ID).
707-
Set(
708-
clause.Assignment{Column: clause.Column{Name: "name"}, Value: newName},
709-
clause.Assignment{Column: clause.Column{Name: "age"}, Value: newAge},
710-
).
711-
Update(ctx)
712-
if err != nil {
713-
t.Fatalf("Set Update failed: %v", err)
714-
}
715-
if rows != 1 {
716-
t.Fatalf("expected 1 row affected, got %d", rows)
717-
}
718-
719-
nu, err := gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
720-
if err != nil {
721-
t.Fatalf("failed to query updated user: %v", err)
722-
}
723-
if nu.Name != newName || nu.Age != newAge {
724-
t.Fatalf("updated user mismatch, got %+v", nu)
725-
}
694+
ctx := context.Background()
695+
696+
// prepare
697+
u := User{Name: "GenericsSetUpdate_Before", Age: 30}
698+
if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
699+
t.Fatalf("prepare user failed: %v", err)
700+
}
701+
702+
// update with Set after chain
703+
newName := "GenericsSetUpdate_After"
704+
newAge := uint(31)
705+
rows, err := gorm.G[User](DB).
706+
Where("id = ?", u.ID).
707+
Set(
708+
clause.Assignment{Column: clause.Column{Name: "name"}, Value: newName},
709+
clause.Assignment{Column: clause.Column{Name: "age"}, Value: newAge},
710+
).
711+
Update(ctx)
712+
if err != nil {
713+
t.Fatalf("Set Update failed: %v", err)
714+
}
715+
if rows != 1 {
716+
t.Fatalf("expected 1 row affected, got %d", rows)
717+
}
718+
719+
nu, err := gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
720+
if err != nil {
721+
t.Fatalf("failed to query updated user: %v", err)
722+
}
723+
if nu.Name != newName || nu.Age != newAge {
724+
t.Fatalf("updated user mismatch, got %+v", nu)
725+
}
726726
}
727727

728728
func TestGenericsGroupHaving(t *testing.T) {

tests/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ require (
1212
gorm.io/driver/postgres v1.6.0
1313
gorm.io/driver/sqlite v1.6.0
1414
gorm.io/driver/sqlserver v1.6.1
15-
gorm.io/gorm v1.30.3
15+
gorm.io/gorm v1.30.4
1616
)
1717

1818
require (

0 commit comments

Comments
 (0)