Skip to content

Commit e70bab9

Browse files
author
Hein
committed
feat(tests): 🎉 More test for preload fixes.
* Implement tests for SanitizeWhereClause and AddTablePrefixToColumns. * Ensure correct handling of table prefixes in WHERE clauses. * Validate that unqualified columns are prefixed correctly when necessary. * Add tests for XFiles processing to verify table name handling. * Introduce tests for recursive preloads and their related keys.
1 parent fc8f44e commit e70bab9

File tree

8 files changed

+483
-324
lines changed

8 files changed

+483
-324
lines changed

pkg/common/adapters/database/bun.go

Lines changed: 11 additions & 261 deletions
Original file line numberDiff line numberDiff line change
@@ -202,23 +202,15 @@ func (b *BunAdapter) GetUnderlyingDB() interface{} {
202202

203203
// BunSelectQuery implements SelectQuery for Bun
204204
type BunSelectQuery struct {
205-
query *bun.SelectQuery
206-
db bun.IDB // Store DB connection for count queries
207-
hasModel bool // Track if Model() was called
208-
schema string // Separated schema name
209-
tableName string // Just the table name, without schema
210-
tableAlias string
211-
deferredPreloads []deferredPreload // Preloads to execute as separate queries
212-
inJoinContext bool // Track if we're in a JOIN relation context
213-
joinTableAlias string // Alias to use for JOIN conditions
214-
skipAutoDetect bool // Skip auto-detection to prevent circular calls
215-
}
216-
217-
// deferredPreload represents a preload that will be executed as a separate query
218-
// to avoid PostgreSQL identifier length limits
219-
type deferredPreload struct {
220-
relation string
221-
apply []func(common.SelectQuery) common.SelectQuery
205+
query *bun.SelectQuery
206+
db bun.IDB // Store DB connection for count queries
207+
hasModel bool // Track if Model() was called
208+
schema string // Separated schema name
209+
tableName string // Just the table name, without schema
210+
tableAlias string
211+
inJoinContext bool // Track if we're in a JOIN relation context
212+
joinTableAlias string // Alias to use for JOIN conditions
213+
skipAutoDetect bool // Skip auto-detection to prevent circular calls
222214
}
223215

224216
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -487,51 +479,8 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
487479
return b
488480
}
489481

490-
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
491-
// // when combined with typical column names
492-
// func shortenAliasForPostgres(relationPath string) (string, bool) {
493-
// // Convert relation path to the alias format Bun uses: dots become double underscores
494-
// // Also convert to lowercase and use snake_case as Bun does
495-
// parts := strings.Split(relationPath, ".")
496-
// alias := strings.ToLower(strings.Join(parts, "__"))
497-
498-
// // PostgreSQL truncates identifiers to 63 chars
499-
// // If the alias + typical column name would exceed this, we need to shorten
500-
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
501-
// const maxAliasLength = 30
502-
503-
// if len(alias) > maxAliasLength {
504-
// // Create a shortened alias using a hash of the original
505-
// hash := md5.Sum([]byte(alias))
506-
// hashStr := hex.EncodeToString(hash[:])[:8]
507-
508-
// // Keep first few chars of original for readability + hash
509-
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
510-
// if prefixLen > len(alias) {
511-
// prefixLen = len(alias)
512-
// }
513-
514-
// shortened := alias[:prefixLen] + "_" + hashStr
515-
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
516-
// alias, len(alias), shortened, len(shortened))
517-
// return shortened, true
518-
// }
519-
520-
// return alias, false
521-
// }
522-
523-
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
524-
// // Bun creates aliases like: relationChain__columnName
525-
// func estimateColumnAliasLength(relationPath string, columnName string) int {
526-
// relationParts := strings.Split(relationPath, ".")
527-
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
528-
// // Bun adds "__" between alias and column name
529-
// return len(aliasChain) + 2 + len(columnName)
530-
// }
531-
532482
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
533483
// Auto-detect relationship type and choose optimal loading strategy
534-
// Get the model from the query if available
535484
// Skip auto-detection if flag is set (prevents circular calls from JoinRelation)
536485
if !b.skipAutoDetect {
537486
model := b.query.GetModel()
@@ -554,49 +503,7 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
554503
}
555504
}
556505

557-
// Check if this relation chain would create problematic long aliases
558-
relationParts := strings.Split(relation, ".")
559-
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
560-
561-
// PostgreSQL's identifier limit is 63 characters
562-
const postgresIdentifierLimit = 63
563-
const safeAliasLimit = 35 // Leave room for column names
564-
565-
// If the alias chain is too long, defer this preload to be executed as a separate query
566-
if len(relationParts) > 1 && len(aliasChain) > safeAliasLimit {
567-
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
568-
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
569-
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
570-
571-
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
572-
// This avoids the long concatenated alias
573-
if len(relationParts) > 1 {
574-
// Load first level normally: "Parent"
575-
firstLevel := relationParts[0]
576-
remainingPath := strings.Join(relationParts[1:], ".")
577-
578-
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
579-
firstLevel, remainingPath)
580-
581-
// Apply the first level preload normally
582-
b.query = b.query.Relation(firstLevel)
583-
584-
// Store the remaining nested preload to be executed after the main query
585-
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
586-
relation: relation,
587-
apply: apply,
588-
})
589-
590-
return b
591-
}
592-
593-
// Single level but still too long - just warn and continue
594-
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
595-
"Consider renaming the field to avoid potential issues.",
596-
relation, len(aliasChain))
597-
}
598-
599-
// Normal preload handling
506+
// Use Bun's native Relation() for preloading
600507
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
601508
defer func() {
602509
if r := recover(); r != nil {
@@ -629,12 +536,7 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
629536
// Extract table alias if model implements TableAliasProvider
630537
if provider, ok := modelValue.(common.TableAliasProvider); ok {
631538
wrapper.tableAlias = provider.TableAlias()
632-
// Apply the alias to the Bun query so conditions can reference it
633-
if wrapper.tableAlias != "" {
634-
// Note: Bun's Relation() already sets up the table, but we can add
635-
// the alias explicitly if needed
636-
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
637-
}
539+
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
638540
}
639541
}
640542

@@ -644,7 +546,6 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
644546
// Apply each function in sequence
645547
for _, fn := range apply {
646548
if fn != nil {
647-
// Pass &current (pointer to interface variable), fn modifies and returns new interface value
648549
modified := fn(current)
649550
current = modified
650551
}
@@ -734,7 +635,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
734635
return fmt.Errorf("destination cannot be nil")
735636
}
736637

737-
// Execute the main query first
738638
err = b.query.Scan(ctx, dest)
739639
if err != nil {
740640
// Log SQL string for debugging
@@ -743,17 +643,6 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
743643
return err
744644
}
745645

746-
// Execute any deferred preloads
747-
if len(b.deferredPreloads) > 0 {
748-
err = b.executeDeferredPreloads(ctx, dest)
749-
if err != nil {
750-
logger.Warn("Failed to execute deferred preloads: %v", err)
751-
// Don't fail the whole query, just log the warning
752-
}
753-
// Clear deferred preloads to prevent re-execution
754-
b.deferredPreloads = nil
755-
}
756-
757646
return nil
758647
}
759648

@@ -803,7 +692,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
803692
}
804693
}
805694

806-
// Execute the main query first
807695
err = b.query.Scan(ctx)
808696
if err != nil {
809697
// Log SQL string for debugging
@@ -812,147 +700,9 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
812700
return err
813701
}
814702

815-
// Execute any deferred preloads
816-
if len(b.deferredPreloads) > 0 {
817-
model := b.query.GetModel()
818-
err = b.executeDeferredPreloads(ctx, model.Value())
819-
if err != nil {
820-
logger.Warn("Failed to execute deferred preloads: %v", err)
821-
// Don't fail the whole query, just log the warning
822-
}
823-
// Clear deferred preloads to prevent re-execution
824-
b.deferredPreloads = nil
825-
}
826-
827-
return nil
828-
}
829-
830-
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
831-
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
832-
if len(b.deferredPreloads) == 0 {
833-
return nil
834-
}
835-
836-
for _, dp := range b.deferredPreloads {
837-
err := b.executeSingleDeferredPreload(ctx, dest, dp)
838-
if err != nil {
839-
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
840-
}
841-
}
842-
843703
return nil
844704
}
845705

846-
// executeSingleDeferredPreload executes a single deferred preload
847-
// For a relation like "Parent.Child", it:
848-
// 1. Finds all loaded Parent records in dest
849-
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
850-
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
851-
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
852-
relationParts := strings.Split(dp.relation, ".")
853-
if len(relationParts) < 2 {
854-
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
855-
}
856-
857-
// The parent relation that was already loaded
858-
parentRelation := relationParts[0]
859-
// The child relation we need to load
860-
childRelation := strings.Join(relationParts[1:], ".")
861-
862-
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
863-
864-
// Use reflection to access the parent relation field(s) in the loaded records
865-
// Then load the child relation for those parent records
866-
destValue := reflect.ValueOf(dest)
867-
if destValue.Kind() == reflect.Ptr {
868-
destValue = destValue.Elem()
869-
}
870-
871-
// Handle both slice and single record
872-
if destValue.Kind() == reflect.Slice {
873-
// Iterate through each record in the slice
874-
for i := 0; i < destValue.Len(); i++ {
875-
record := destValue.Index(i)
876-
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
877-
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
878-
// Continue with other records
879-
}
880-
}
881-
} else {
882-
// Single record
883-
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
884-
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
885-
}
886-
}
887-
888-
return nil
889-
}
890-
891-
// loadChildRelationForRecord loads a child relation for a single parent record
892-
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
893-
// Ensure we're working with the actual struct value, not a pointer
894-
if record.Kind() == reflect.Ptr {
895-
record = record.Elem()
896-
}
897-
898-
// Get the parent relation field
899-
parentField := record.FieldByName(parentRelation)
900-
if !parentField.IsValid() {
901-
// Parent relation field doesn't exist
902-
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
903-
return nil
904-
}
905-
906-
// Check if the parent field is nil (for pointer fields)
907-
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
908-
// Parent relation not loaded or nil, skip
909-
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
910-
return nil
911-
}
912-
913-
// Get a pointer to the parent field so Bun can modify it
914-
// CRITICAL: We need to pass a pointer, not a value, so that when Bun
915-
// loads the child records and appends them to the slice, the changes
916-
// are reflected in the original struct field.
917-
var parentPtr interface{}
918-
if parentField.Kind() == reflect.Ptr {
919-
// Field is already a pointer (e.g., Parent *Parent), use as-is
920-
parentPtr = parentField.Interface()
921-
} else {
922-
// Field is a value (e.g., Comments []Comment), get its address
923-
if parentField.CanAddr() {
924-
parentPtr = parentField.Addr().Interface()
925-
} else {
926-
return fmt.Errorf("cannot get address of field '%s'", parentRelation)
927-
}
928-
}
929-
930-
// Load the child relation on the parent record
931-
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
932-
// CRITICAL: Use WherePK() to ensure we only load children for THIS specific parent
933-
// record, not the first parent in the database table.
934-
return b.db.NewSelect().
935-
Model(parentPtr).
936-
WherePK().
937-
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
938-
// Apply any custom query modifications
939-
if len(apply) > 0 {
940-
wrapper := &BunSelectQuery{query: sq, db: b.db}
941-
current := common.SelectQuery(wrapper)
942-
for _, fn := range apply {
943-
if fn != nil {
944-
current = fn(current)
945-
}
946-
}
947-
if finalBun, ok := current.(*BunSelectQuery); ok {
948-
return finalBun.query
949-
}
950-
}
951-
return sq
952-
}).
953-
Scan(ctx)
954-
}
955-
956706
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
957707
defer func() {
958708
if r := recover(); r != nil {

0 commit comments

Comments
 (0)