Skip to content

Commit 289cd74

Browse files
author
Hein
committed
feat(database): ✨ Enhance Preload and Join functionality
* Introduce skipAutoDetect flag to prevent circular calls in PreloadRelation. * Improve handling of long alias chains in PreloadRelation. * Ensure JoinRelation uses PreloadRelation without causing recursion. * Clear deferred preloads after execution to prevent re-execution. feat(recursive_crud): ✨ Filter valid fields in nested CUD processing * Add filterValidFields method to validate input data against model structure. * Use reflection to ensure only valid fields are processed. feat(reflection): ✨ Add utility to get valid JSON field names * Implement GetValidJSONFieldNames to retrieve valid JSON field names from model. * Enhance field validation during nested CUD operations. fix(handler): 🐛 Adjust recursive preload depth limit * Change recursive preload depth limit from 5 to 4 to prevent excessive recursion.
1 parent c75842e commit 289cd74

File tree

6 files changed

+220
-75
lines changed

6 files changed

+220
-75
lines changed

go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ require (
116116
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
117117
github.com/shopspring/decimal v1.4.0 // indirect
118118
github.com/sirupsen/logrus v1.9.3 // indirect
119-
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
120119
github.com/spf13/afero v1.15.0 // indirect
121120
github.com/spf13/cast v1.10.0 // indirect
122121
github.com/spf13/pflag v1.0.10 // indirect

go.sum

Lines changed: 3 additions & 56 deletions
Large diffs are not rendered by default.

pkg/common/adapters/database/bun.go

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ type BunSelectQuery struct {
211211
deferredPreloads []deferredPreload // Preloads to execute as separate queries
212212
inJoinContext bool // Track if we're in a JOIN relation context
213213
joinTableAlias string // Alias to use for JOIN conditions
214+
skipAutoDetect bool // Skip auto-detection to prevent circular calls
214215
}
215216

216217
// deferredPreload represents a preload that will be executed as a separate query
@@ -531,22 +532,25 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
531532
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
532533
// Auto-detect relationship type and choose optimal loading strategy
533534
// Get the model from the query if available
534-
model := b.query.GetModel()
535-
if model != nil && model.Value() != nil {
536-
relType := reflection.GetRelationType(model.Value(), relation)
535+
// Skip auto-detection if flag is set (prevents circular calls from JoinRelation)
536+
if !b.skipAutoDetect {
537+
model := b.query.GetModel()
538+
if model != nil && model.Value() != nil {
539+
relType := reflection.GetRelationType(model.Value(), relation)
537540

538-
// Log the detected relationship type
539-
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
541+
// Log the detected relationship type
542+
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
540543

541-
// If this is a belongs-to or has-one relation, use JOIN for better performance
542-
if relType.ShouldUseJoin() {
543-
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
544-
return b.JoinRelation(relation, apply...)
545-
}
544+
// If this is a belongs-to or has-one relation, use JOIN for better performance
545+
if relType.ShouldUseJoin() {
546+
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
547+
return b.JoinRelation(relation, apply...)
548+
}
546549

547-
// For has-many, many-to-many, or unknown: use separate query (safer default)
548-
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
549-
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
550+
// For has-many, many-to-many, or unknown: use separate query (safer default)
551+
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
552+
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
553+
}
550554
}
551555
}
552556

@@ -559,7 +563,7 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
559563
const safeAliasLimit = 35 // Leave room for column names
560564

561565
// If the alias chain is too long, defer this preload to be executed as a separate query
562-
if len(aliasChain) > safeAliasLimit {
566+
if len(relationParts) > 1 && len(aliasChain) > safeAliasLimit {
563567
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
564568
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
565569
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
@@ -683,6 +687,10 @@ func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.Sele
683687

684688
// Use PreloadRelation with the wrapped functions
685689
// Bun's Relation() will use JOIN for belongs-to and has-one relations
690+
// CRITICAL: Set skipAutoDetect flag to prevent circular call
691+
// (PreloadRelation would detect belongs-to and call JoinRelation again)
692+
b.skipAutoDetect = true
693+
defer func() { b.skipAutoDetect = false }()
686694
return b.PreloadRelation(relation, wrappedApply...)
687695
}
688696

@@ -742,6 +750,8 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
742750
logger.Warn("Failed to execute deferred preloads: %v", err)
743751
// Don't fail the whole query, just log the warning
744752
}
753+
// Clear deferred preloads to prevent re-execution
754+
b.deferredPreloads = nil
745755
}
746756

747757
return nil
@@ -810,6 +820,8 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
810820
logger.Warn("Failed to execute deferred preloads: %v", err)
811821
// Don't fail the whole query, just log the warning
812822
}
823+
// Clear deferred preloads to prevent re-execution
824+
b.deferredPreloads = nil
813825
}
814826

815827
return nil
@@ -898,13 +910,30 @@ func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record
898910
return nil
899911
}
900912

901-
// Get the interface value to pass to Bun
902-
parentValue := parentField.Interface()
913+
// Get a pointer to the parent field so Bun can modify it
914+
// CRITICAL: We need to pass a pointer, not a value, so that when Bun
915+
// loads the child records and appends them to the slice, the changes
916+
// are reflected in the original struct field.
917+
var parentPtr interface{}
918+
if parentField.Kind() == reflect.Ptr {
919+
// Field is already a pointer (e.g., Parent *Parent), use as-is
920+
parentPtr = parentField.Interface()
921+
} else {
922+
// Field is a value (e.g., Comments []Comment), get its address
923+
if parentField.CanAddr() {
924+
parentPtr = parentField.Addr().Interface()
925+
} else {
926+
return fmt.Errorf("cannot get address of field '%s'", parentRelation)
927+
}
928+
}
903929

904930
// Load the child relation on the parent record
905931
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
932+
// CRITICAL: Use WherePK() to ensure we only load children for THIS specific parent
933+
// record, not the first parent in the database table.
906934
return b.db.NewSelect().
907-
Model(parentValue).
935+
Model(parentPtr).
936+
WherePK().
908937
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
909938
// Apply any custom query modifications
910939
if len(apply) > 0 {

pkg/common/recursive_crud.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
9898
}
9999
}
100100

101+
// Filter regularData to only include fields that exist in the model
102+
// Use MapToStruct to validate and filter fields
103+
regularData = p.filterValidFields(regularData, model)
104+
101105
// Inject parent IDs for foreign key resolution
102106
p.injectForeignKeys(regularData, modelType, parentIDs)
103107

@@ -187,6 +191,115 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
187191
return ""
188192
}
189193

194+
// filterValidFields filters input data to only include fields that exist in the model
195+
// Uses reflection.MapToStruct to validate fields and extract only those that match the model
196+
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
197+
if len(data) == 0 {
198+
return data
199+
}
200+
201+
// Create a new instance of the model to use with MapToStruct
202+
modelType := reflect.TypeOf(model)
203+
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
204+
modelType = modelType.Elem()
205+
}
206+
207+
if modelType == nil || modelType.Kind() != reflect.Struct {
208+
return data
209+
}
210+
211+
// Create a new instance of the model
212+
tempModel := reflect.New(modelType).Interface()
213+
214+
// Use MapToStruct to map the data - this will only map valid fields
215+
err := reflection.MapToStruct(data, tempModel)
216+
if err != nil {
217+
logger.Debug("Error mapping data to model: %v", err)
218+
return data
219+
}
220+
221+
// Extract the mapped fields back into a map
222+
// This effectively filters out any fields that don't exist in the model
223+
filteredData := make(map[string]interface{})
224+
tempModelValue := reflect.ValueOf(tempModel).Elem()
225+
226+
for key, value := range data {
227+
// Check if the field was successfully mapped
228+
if fieldWasMapped(tempModelValue, modelType, key) {
229+
filteredData[key] = value
230+
} else {
231+
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
232+
}
233+
}
234+
235+
return filteredData
236+
}
237+
238+
// fieldWasMapped checks if a field with the given key was mapped to the model
239+
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool {
240+
// Look for the field by JSON tag or field name
241+
for i := 0; i < modelType.NumField(); i++ {
242+
field := modelType.Field(i)
243+
244+
// Skip unexported fields
245+
if !field.IsExported() {
246+
continue
247+
}
248+
249+
// Check JSON tag
250+
jsonTag := field.Tag.Get("json")
251+
if jsonTag != "" && jsonTag != "-" {
252+
parts := strings.Split(jsonTag, ",")
253+
if len(parts) > 0 && parts[0] == key {
254+
return true
255+
}
256+
}
257+
258+
// Check bun tag
259+
bunTag := field.Tag.Get("bun")
260+
if bunTag != "" && bunTag != "-" {
261+
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
262+
return true
263+
}
264+
}
265+
266+
// Check gorm tag
267+
gormTag := field.Tag.Get("gorm")
268+
if gormTag != "" && gormTag != "-" {
269+
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
270+
return true
271+
}
272+
}
273+
274+
// Check lowercase field name
275+
if strings.EqualFold(field.Name, key) {
276+
return true
277+
}
278+
279+
// Handle embedded structs recursively
280+
if field.Anonymous {
281+
fieldType := field.Type
282+
if fieldType.Kind() == reflect.Ptr {
283+
fieldType = fieldType.Elem()
284+
}
285+
if fieldType.Kind() == reflect.Struct {
286+
embeddedValue := modelValue.Field(i)
287+
if embeddedValue.Kind() == reflect.Ptr {
288+
if embeddedValue.IsNil() {
289+
continue
290+
}
291+
embeddedValue = embeddedValue.Elem()
292+
}
293+
if fieldWasMapped(embeddedValue, fieldType, key) {
294+
return true
295+
}
296+
}
297+
}
298+
}
299+
300+
return false
301+
}
302+
190303
// injectForeignKeys injects parent IDs into data for foreign key fields
191304
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
192305
if len(parentIDs) == 0 {

pkg/reflection/model_utils.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,6 +1370,63 @@ func convertToFloat64(value interface{}) (float64, bool) {
13701370
return 0, false
13711371
}
13721372

1373+
// GetValidJSONFieldNames returns a map of valid JSON field names for a model
1374+
// This can be used to validate input data against a model's structure
1375+
// The map keys are the JSON field names (from json tags) that exist in the model
1376+
func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
1377+
validFields := make(map[string]bool)
1378+
1379+
// Unwrap pointers to get to the base struct type
1380+
for modelType != nil && modelType.Kind() == reflect.Pointer {
1381+
modelType = modelType.Elem()
1382+
}
1383+
1384+
if modelType == nil || modelType.Kind() != reflect.Struct {
1385+
return validFields
1386+
}
1387+
1388+
collectValidFieldNames(modelType, validFields)
1389+
return validFields
1390+
}
1391+
1392+
// collectValidFieldNames recursively collects valid JSON field names from a struct type
1393+
func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) {
1394+
for i := 0; i < typ.NumField(); i++ {
1395+
field := typ.Field(i)
1396+
1397+
// Skip unexported fields
1398+
if !field.IsExported() {
1399+
continue
1400+
}
1401+
1402+
// Check for embedded structs
1403+
if field.Anonymous {
1404+
fieldType := field.Type
1405+
if fieldType.Kind() == reflect.Ptr {
1406+
fieldType = fieldType.Elem()
1407+
}
1408+
if fieldType.Kind() == reflect.Struct {
1409+
// Recursively add fields from embedded struct
1410+
collectValidFieldNames(fieldType, validFields)
1411+
continue
1412+
}
1413+
}
1414+
1415+
// Get the JSON tag name for this field (same logic as MapToStruct)
1416+
jsonTag := field.Tag.Get("json")
1417+
if jsonTag != "" && jsonTag != "-" {
1418+
// Extract the field name from the JSON tag (before any options like omitempty)
1419+
parts := strings.Split(jsonTag, ",")
1420+
if len(parts) > 0 && parts[0] != "" {
1421+
validFields[parts[0]] = true
1422+
}
1423+
} else {
1424+
// If no JSON tag, use the field name in lowercase as a fallback
1425+
validFields[strings.ToLower(field.Name)] = true
1426+
}
1427+
}
1428+
}
1429+
13731430
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
13741431
// This is a helper function used by GetRelationModel to handle one level at a time
13751432
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {

pkg/restheadspec/handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
883883
})
884884

885885
// Handle recursive preloading
886-
if preload.Recursive && depth < 5 {
886+
if preload.Recursive && depth < 4 {
887887
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
888888

889889
// For recursive relationships, we need to get the last part of the relation path

0 commit comments

Comments
 (0)