@@ -78,17 +78,22 @@ func IsTrivialCondition(cond string) bool {
7878 return false
7979}
8080
81- // SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
81+ // SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
8282// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
8383//
8484// Parameters:
8585// - where: The WHERE clause string to sanitize
86- // - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
86+ // - tableName: The correct table/relation name to use when fixing incorrect prefixes
87+ // - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
8788//
8889// Returns:
89- // - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
90+ // - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
9091// - An empty string if all conditions were trivial or the input was empty
91- func SanitizeWhereClause (where string , tableName string ) string {
92+ //
93+ // Note: This function will NOT add prefixes to unprefixed columns. It will only fix
94+ // incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
95+ // prefix matches a preloaded relation name, in which case it's left unchanged.
96+ func SanitizeWhereClause (where string , tableName string , options ... * RequestOptions ) string {
9297 if where == "" {
9398 return ""
9499 }
@@ -104,6 +109,22 @@ func SanitizeWhereClause(where string, tableName string) string {
104109 validColumns = getValidColumnsForTable (tableName )
105110 }
106111
112+ // Build a set of allowed table prefixes (main table + preloaded relations)
113+ allowedPrefixes := make (map [string ]bool )
114+ if tableName != "" {
115+ allowedPrefixes [tableName ] = true
116+ }
117+
118+ // Add preload relation names as allowed prefixes
119+ if len (options ) > 0 && options [0 ] != nil {
120+ for pi := range options [0 ].Preload {
121+ if options [0 ].Preload [pi ].Relation != "" {
122+ allowedPrefixes [options [0 ].Preload [pi ].Relation ] = true
123+ logger .Debug ("Added preload relation '%s' as allowed table prefix" , options [0 ].Preload [pi ].Relation )
124+ }
125+ }
126+ }
127+
107128 // Split by AND to handle multiple conditions
108129 conditions := splitByAND (where )
109130
@@ -124,22 +145,23 @@ func SanitizeWhereClause(where string, tableName string) string {
124145 continue
125146 }
126147
127- // If tableName is provided and the condition doesn't already have a table prefix,
128- // attempt to add it
129- if tableName != "" && ! hasTablePrefix (condToCheck ) {
130- // Check if this is a SQL expression/literal that shouldn't be prefixed
131- if ! IsSQLExpression (strings .ToLower (condToCheck )) {
132- // Extract the column name and prefix it
133- columnName := ExtractColumnName (condToCheck )
134- if columnName != "" {
135- // Only prefix if this is a valid column in the model
136- // If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
148+ // If tableName is provided and the condition HAS a table prefix, check if it's correct
149+ if tableName != "" && hasTablePrefix (condToCheck ) {
150+ // Extract the current prefix and column name
151+ currentPrefix , columnName := extractTableAndColumn (condToCheck )
152+
153+ if currentPrefix != "" && columnName != "" {
154+ // Check if the prefix is allowed (main table or preload relation)
155+ if ! allowedPrefixes [currentPrefix ] {
156+ // Prefix is not in the allowed list - only fix if it's a valid column in the main table
137157 if validColumns == nil || isValidColumn (columnName , validColumns ) {
138- // Replace in the original condition (without stripped parens)
139- cond = strings .Replace (cond , columnName , tableName + "." + columnName , 1 )
140- logger .Debug ("Prefixed column in condition: '%s'" , cond )
158+ // Replace the incorrect prefix with the correct main table name
159+ oldRef := currentPrefix + "." + columnName
160+ newRef := tableName + "." + columnName
161+ cond = strings .Replace (cond , oldRef , newRef , 1 )
162+ logger .Debug ("Fixed incorrect table prefix in condition: '%s' -> '%s'" , oldRef , newRef )
141163 } else {
142- logger .Debug ("Skipping prefix for '%s' - not a valid column in model" , columnName )
164+ logger .Debug ("Skipping prefix fix for '%s.%s ' - not a valid column in main table (might be preload relation)" , currentPrefix , columnName )
143165 }
144166 }
145167 }
@@ -288,6 +310,53 @@ func getValidColumnsForTable(tableName string) map[string]bool {
288310 return columnMap
289311}
290312
313+ // extractTableAndColumn extracts the table prefix and column name from a qualified reference
314+ // For example: "users.status = 'active'" returns ("users", "status")
315+ // Returns empty strings if no table prefix is found
316+ func extractTableAndColumn (cond string ) (table string , column string ) {
317+ // Common SQL operators to find the column reference
318+ operators := []string {" = " , " != " , " <> " , " > " , " >= " , " < " , " <= " , " LIKE " , " like " , " IN " , " in " , " IS " , " is " }
319+
320+ var columnRef string
321+
322+ // Find the column reference (left side of the operator)
323+ for _ , op := range operators {
324+ if idx := strings .Index (cond , op ); idx > 0 {
325+ columnRef = strings .TrimSpace (cond [:idx ])
326+ break
327+ }
328+ }
329+
330+ // If no operator found, the whole condition might be the column reference
331+ if columnRef == "" {
332+ parts := strings .Fields (cond )
333+ if len (parts ) > 0 {
334+ columnRef = parts [0 ]
335+ }
336+ }
337+
338+ if columnRef == "" {
339+ return "" , ""
340+ }
341+
342+ // Remove any quotes
343+ columnRef = strings .Trim (columnRef , "`\" '" )
344+
345+ // Check if it contains a dot (qualified reference)
346+ if dotIdx := strings .LastIndex (columnRef , "." ); dotIdx > 0 {
347+ table = columnRef [:dotIdx ]
348+ column = columnRef [dotIdx + 1 :]
349+
350+ // Remove quotes from table and column if present
351+ table = strings .Trim (table , "`\" '" )
352+ column = strings .Trim (column , "`\" '" )
353+
354+ return table , column
355+ }
356+
357+ return "" , ""
358+ }
359+
291360// isValidColumn checks if a column name exists in the valid columns map
292361// Handles case-insensitive comparison
293362func isValidColumn (columnName string , validColumns map [string ]bool ) bool {
0 commit comments