Skip to content

Commit e015bd0

Browse files
authored
Add association operation support to generics Set API and enable conditional bulk association updates (#7581)
* Implement association operations in clause package with English comments (WIP) - Add clause/association.go with AssociationOperation struct and related types - Implement helper functions for creating association operations - Add support for association operations in generics Set method - Use English comments throughout the implementation - Maintain compatibility with existing GORM functionality * Add tests for clause.Association - Add comprehensive tests for Association struct - Test Assigner and AssociationAssigner interface implementations - Test different association operation types - Test AssociationAssignments method - Ensure compatibility with existing clause package functionality * Add integration tests for clause.Association - Add comprehensive integration tests for clause.Association in generics API - Test Association struct creation and interface implementations - Test different association operation types - Test AssociationAssignments method - Ensure compatibility with existing generics functionality * refactor code * Add association generics test * Add association tests * Manual Code Review@1 * Manual Code Review@2 * Manual Code Review@3 * Manual Code Review@4 * Update * generics: fix belongs-to Set(OpUpdate/OpDelete), refactor handler, add tests * refactor code * refactor code * Add more association test cases * refactor code * Refactor code
1 parent faee391 commit e015bd0

File tree

12 files changed

+1576
-43
lines changed

12 files changed

+1576
-43
lines changed

association.go

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ type Association struct {
1919
}
2020

2121
func (db *DB) Association(column string) *Association {
22-
association := &Association{DB: db}
22+
association := &Association{DB: db, Unscope: db.Statement.Unscoped}
2323
table := db.Statement.Table
2424

25-
if err := db.Statement.Parse(db.Statement.Model); err == nil {
25+
if association.Error = db.Statement.Parse(db.Statement.Model); association.Error == nil {
2626
db.Statement.Table = table
2727
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
2828

@@ -34,8 +34,6 @@ func (db *DB) Association(column string) *Association {
3434
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
3535
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
3636
}
37-
} else {
38-
association.Error = err
3937
}
4038

4139
return association
@@ -58,6 +56,8 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro
5856
}
5957

6058
func (association *Association) Append(values ...interface{}) error {
59+
values = expandValues(values)
60+
6161
if association.Error == nil {
6262
switch association.Relationship.Type {
6363
case schema.HasOne, schema.BelongsTo:
@@ -73,6 +73,8 @@ func (association *Association) Append(values ...interface{}) error {
7373
}
7474

7575
func (association *Association) Replace(values ...interface{}) error {
76+
values = expandValues(values)
77+
7678
if association.Error == nil {
7779
reflectValue := association.DB.Statement.ReflectValue
7880
rel := association.Relationship
@@ -195,6 +197,8 @@ func (association *Association) Replace(values ...interface{}) error {
195197
}
196198

197199
func (association *Association) Delete(values ...interface{}) error {
200+
values = expandValues(values)
201+
198202
if association.Error == nil {
199203
var (
200204
reflectValue = association.DB.Statement.ReflectValue
@@ -431,10 +435,49 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
431435
}
432436
}
433437

438+
processMap := func(mapv reflect.Value) {
439+
child := reflect.New(association.Relationship.FieldSchema.ModelType)
440+
441+
switch association.Relationship.Type {
442+
case schema.HasMany:
443+
for _, ref := range association.Relationship.References {
444+
key := reflect.ValueOf(ref.ForeignKey.DBName)
445+
if ref.OwnPrimaryKey {
446+
v := ref.PrimaryKey.ReflectValueOf(association.DB.Statement.Context, source)
447+
mapv.SetMapIndex(key, v)
448+
} else if ref.PrimaryValue != "" {
449+
mapv.SetMapIndex(key, reflect.ValueOf(ref.PrimaryValue))
450+
}
451+
}
452+
association.Error = association.DB.Session(&Session{
453+
NewDB: true,
454+
}).Model(child.Interface()).Create(mapv.Interface()).Error
455+
case schema.Many2Many:
456+
association.Error = association.DB.Session(&Session{
457+
NewDB: true,
458+
}).Model(child.Interface()).Create(mapv.Interface()).Error
459+
460+
for _, key := range mapv.MapKeys() {
461+
k := strings.ToLower(key.String())
462+
if f, ok := association.Relationship.FieldSchema.FieldsByDBName[k]; ok {
463+
_ = f.Set(association.DB.Statement.Context, child, mapv.MapIndex(key).Interface())
464+
}
465+
}
466+
appendToFieldValues(child)
467+
}
468+
}
469+
434470
switch rv.Kind() {
471+
case reflect.Map:
472+
processMap(rv)
435473
case reflect.Slice, reflect.Array:
436474
for i := 0; i < rv.Len(); i++ {
437-
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
475+
elem := reflect.Indirect(rv.Index(i))
476+
if elem.Kind() == reflect.Map {
477+
processMap(elem)
478+
continue
479+
}
480+
appendToFieldValues(elem.Addr())
438481
}
439482
case reflect.Struct:
440483
if !rv.CanAddr() {
@@ -591,3 +634,32 @@ func (association *Association) buildCondition() *DB {
591634

592635
return tx
593636
}
637+
638+
func expandValues(values ...any) (results []any) {
639+
appendToResult := func(rv reflect.Value) {
640+
// unwrap interface
641+
if rv.IsValid() && rv.Kind() == reflect.Interface {
642+
rv = rv.Elem()
643+
}
644+
if rv.IsValid() && rv.Kind() == reflect.Struct {
645+
p := reflect.New(rv.Type())
646+
p.Elem().Set(rv)
647+
results = append(results, p.Interface())
648+
} else if rv.IsValid() {
649+
results = append(results, rv.Interface())
650+
}
651+
}
652+
653+
// Process each argument; if an argument is a slice/array, expand its elements
654+
for _, value := range values {
655+
rv := reflect.ValueOf(value)
656+
if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array {
657+
for i := 0; i < rv.Len(); i++ {
658+
appendToResult(rv.Index(i))
659+
}
660+
} else {
661+
appendToResult(rv)
662+
}
663+
}
664+
return
665+
}

clause/association.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package clause
2+
3+
// AssociationOpType represents association operation types
4+
type AssociationOpType int
5+
6+
const (
7+
OpUnlink AssociationOpType = iota // Unlink association
8+
OpDelete // Delete association records
9+
OpUpdate // Update association records
10+
OpCreate // Create association records with assignments
11+
OpCreateValues // Create association records with model object
12+
)
13+
14+
// Association represents an association operation
15+
type Association struct {
16+
Association string // Association name
17+
Type AssociationOpType // Operation type
18+
Conditions []Expression // Filter conditions
19+
Set []Assignment // Assignment operations (for Update and Create)
20+
Model interface{} // Model object (for Create object)
21+
Values []interface{} // Values for Create operation
22+
}
23+
24+
// AssociationAssigner is an interface for association operation providers
25+
type AssociationAssigner interface {
26+
AssociationAssignments() []Association
27+
}
28+
29+
// Assignments implements the Assigner interface so that AssociationOperation can be used as a Set method parameter
30+
func (ao Association) Assignments() []Assignment {
31+
return []Assignment{}
32+
}
33+
34+
// AssociationAssignments implements the AssociationAssigner interface
35+
func (ao Association) AssociationAssignments() []Association {
36+
return []Association{ao}
37+
}

0 commit comments

Comments
 (0)