Skip to content

Commit e1abd5e

Browse files
author
Hein
committed
Enhanced the SanitizeWhereClause function
1 parent ca4e539 commit e1abd5e

File tree

4 files changed

+306
-45
lines changed

4 files changed

+306
-45
lines changed

pkg/common/sql_helpers.go

Lines changed: 87 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,22 @@ func IsTrivialCondition(cond string) bool {
7878
return false
7979
}
8080

81-
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
81+
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
8282
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
8383
//
8484
// Parameters:
8585
// - where: The WHERE clause string to sanitize
86-
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
86+
// - tableName: The correct table/relation name to use when fixing incorrect prefixes
87+
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
8788
//
8889
// Returns:
89-
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
90+
// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
9091
// - An empty string if all conditions were trivial or the input was empty
91-
func SanitizeWhereClause(where string, tableName string) string {
92+
//
93+
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
94+
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
95+
// prefix matches a preloaded relation name, in which case it's left unchanged.
96+
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
9297
if where == "" {
9398
return ""
9499
}
@@ -104,6 +109,22 @@ func SanitizeWhereClause(where string, tableName string) string {
104109
validColumns = getValidColumnsForTable(tableName)
105110
}
106111

112+
// Build a set of allowed table prefixes (main table + preloaded relations)
113+
allowedPrefixes := make(map[string]bool)
114+
if tableName != "" {
115+
allowedPrefixes[tableName] = true
116+
}
117+
118+
// Add preload relation names as allowed prefixes
119+
if len(options) > 0 && options[0] != nil {
120+
for pi := range options[0].Preload {
121+
if options[0].Preload[pi].Relation != "" {
122+
allowedPrefixes[options[0].Preload[pi].Relation] = true
123+
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
124+
}
125+
}
126+
}
127+
107128
// Split by AND to handle multiple conditions
108129
conditions := splitByAND(where)
109130

@@ -124,22 +145,23 @@ func SanitizeWhereClause(where string, tableName string) string {
124145
continue
125146
}
126147

127-
// If tableName is provided and the condition doesn't already have a table prefix,
128-
// attempt to add it
129-
if tableName != "" && !hasTablePrefix(condToCheck) {
130-
// Check if this is a SQL expression/literal that shouldn't be prefixed
131-
if !IsSQLExpression(strings.ToLower(condToCheck)) {
132-
// Extract the column name and prefix it
133-
columnName := ExtractColumnName(condToCheck)
134-
if columnName != "" {
135-
// Only prefix if this is a valid column in the model
136-
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
148+
// If tableName is provided and the condition HAS a table prefix, check if it's correct
149+
if tableName != "" && hasTablePrefix(condToCheck) {
150+
// Extract the current prefix and column name
151+
currentPrefix, columnName := extractTableAndColumn(condToCheck)
152+
153+
if currentPrefix != "" && columnName != "" {
154+
// Check if the prefix is allowed (main table or preload relation)
155+
if !allowedPrefixes[currentPrefix] {
156+
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
137157
if validColumns == nil || isValidColumn(columnName, validColumns) {
138-
// Replace in the original condition (without stripped parens)
139-
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
140-
logger.Debug("Prefixed column in condition: '%s'", cond)
158+
// Replace the incorrect prefix with the correct main table name
159+
oldRef := currentPrefix + "." + columnName
160+
newRef := tableName + "." + columnName
161+
cond = strings.Replace(cond, oldRef, newRef, 1)
162+
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
141163
} else {
142-
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
164+
logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
143165
}
144166
}
145167
}
@@ -288,6 +310,53 @@ func getValidColumnsForTable(tableName string) map[string]bool {
288310
return columnMap
289311
}
290312

313+
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
314+
// For example: "users.status = 'active'" returns ("users", "status")
315+
// Returns empty strings if no table prefix is found
316+
func extractTableAndColumn(cond string) (table string, column string) {
317+
// Common SQL operators to find the column reference
318+
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
319+
320+
var columnRef string
321+
322+
// Find the column reference (left side of the operator)
323+
for _, op := range operators {
324+
if idx := strings.Index(cond, op); idx > 0 {
325+
columnRef = strings.TrimSpace(cond[:idx])
326+
break
327+
}
328+
}
329+
330+
// If no operator found, the whole condition might be the column reference
331+
if columnRef == "" {
332+
parts := strings.Fields(cond)
333+
if len(parts) > 0 {
334+
columnRef = parts[0]
335+
}
336+
}
337+
338+
if columnRef == "" {
339+
return "", ""
340+
}
341+
342+
// Remove any quotes
343+
columnRef = strings.Trim(columnRef, "`\"'")
344+
345+
// Check if it contains a dot (qualified reference)
346+
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
347+
table = columnRef[:dotIdx]
348+
column = columnRef[dotIdx+1:]
349+
350+
// Remove quotes from table and column if present
351+
table = strings.Trim(table, "`\"'")
352+
column = strings.Trim(column, "`\"'")
353+
354+
return table, column
355+
}
356+
357+
return "", ""
358+
}
359+
291360
// isValidColumn checks if a column name exists in the valid columns map
292361
// Handles case-insensitive comparison
293362
func isValidColumn(columnName string, validColumns map[string]bool) bool {

0 commit comments

Comments
 (0)