Skip to content

Commit 1cd04b7

Browse files
author
Hein
committed
Better where clause handling for preloads
1 parent 0d49090 commit 1cd04b7

File tree

3 files changed

+138
-128
lines changed

3 files changed

+138
-128
lines changed

pkg/common/sql_helpers.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package common
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"github.com/bitechdev/ResolveSpec/pkg/logger"
8+
)
9+
10+
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
11+
// the relation prefix (alias). If not present, it attempts to add it to column references.
12+
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
13+
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
14+
if where == "" {
15+
return where, nil
16+
}
17+
18+
// Check if the relation name is already present in the WHERE clause
19+
lowerWhere := strings.ToLower(where)
20+
lowerRelation := strings.ToLower(relationName)
21+
22+
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
23+
if strings.Contains(lowerWhere, lowerRelation+".") ||
24+
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
25+
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
26+
// Relation prefix is already present
27+
return where, nil
28+
}
29+
30+
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
31+
// we can't safely auto-fix it - require explicit prefix
32+
if strings.Contains(lowerWhere, " or ") ||
33+
strings.Contains(where, "(") ||
34+
strings.Contains(where, ")") {
35+
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
36+
}
37+
38+
// Try to add the relation prefix to simple column references
39+
// This handles basic cases like "column = value" or "column = value AND other_column = value"
40+
// Split by AND to handle multiple conditions (case-insensitive)
41+
originalConditions := strings.Split(where, " AND ")
42+
43+
// If uppercase split didn't work, try lowercase
44+
if len(originalConditions) == 1 {
45+
originalConditions = strings.Split(where, " and ")
46+
}
47+
48+
fixedConditions := make([]string, 0, len(originalConditions))
49+
50+
for _, cond := range originalConditions {
51+
cond = strings.TrimSpace(cond)
52+
if cond == "" {
53+
continue
54+
}
55+
56+
// Check if this condition already has a table prefix (contains a dot)
57+
if strings.Contains(cond, ".") {
58+
fixedConditions = append(fixedConditions, cond)
59+
continue
60+
}
61+
62+
// Check if this is a SQL expression/literal that shouldn't be prefixed
63+
lowerCond := strings.ToLower(strings.TrimSpace(cond))
64+
if IsSQLExpression(lowerCond) {
65+
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
66+
fixedConditions = append(fixedConditions, cond)
67+
continue
68+
}
69+
70+
// Extract the column name (first identifier before operator)
71+
columnName := ExtractColumnName(cond)
72+
if columnName == "" {
73+
// Can't identify column name, require explicit prefix
74+
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
75+
}
76+
77+
// Add relation prefix to the column name only
78+
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
79+
fixedConditions = append(fixedConditions, fixedCond)
80+
}
81+
82+
fixedWhere := strings.Join(fixedConditions, " AND ")
83+
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
84+
return fixedWhere, nil
85+
}
86+
87+
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
88+
func IsSQLExpression(cond string) bool {
89+
// Common SQL literals and expressions
90+
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
91+
for _, literal := range sqlLiterals {
92+
if cond == literal {
93+
return true
94+
}
95+
}
96+
return false
97+
}
98+
99+
// ExtractColumnName extracts the column name from a WHERE condition
100+
// For example: "status = 'active'" returns "status"
101+
func ExtractColumnName(cond string) string {
102+
// Common SQL operators
103+
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
104+
105+
for _, op := range operators {
106+
if idx := strings.Index(cond, op); idx > 0 {
107+
columnName := strings.TrimSpace(cond[:idx])
108+
// Remove quotes if present
109+
columnName = strings.Trim(columnName, "`\"'")
110+
return columnName
111+
}
112+
}
113+
114+
// If no operator found, check if it's a simple identifier (for boolean columns)
115+
parts := strings.Fields(cond)
116+
if len(parts) > 0 {
117+
columnName := strings.Trim(parts[0], "`\"'")
118+
// Check if it's a valid identifier (not a SQL keyword)
119+
if !IsSQLKeyword(strings.ToLower(columnName)) {
120+
return columnName
121+
}
122+
}
123+
124+
return ""
125+
}
126+
127+
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
128+
func IsSQLKeyword(word string) bool {
129+
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
130+
for _, kw := range keywords {
131+
if word == kw {
132+
return true
133+
}
134+
}
135+
return false
136+
}

pkg/resolvespec/handler.go

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,69 +1105,6 @@ type relationshipInfo struct {
11051105
relatedModel interface{}
11061106
}
11071107

1108-
// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains
1109-
// the relation prefix (alias). If not present, it attempts to add it to column references.
1110-
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
1111-
func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) {
1112-
if where == "" {
1113-
return where, nil
1114-
}
1115-
1116-
// Check if the relation name is already present in the WHERE clause
1117-
lowerWhere := strings.ToLower(where)
1118-
lowerRelation := strings.ToLower(relationName)
1119-
1120-
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
1121-
if strings.Contains(lowerWhere, lowerRelation+".") ||
1122-
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
1123-
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
1124-
// Relation prefix is already present
1125-
return where, nil
1126-
}
1127-
1128-
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
1129-
// we can't safely auto-fix it - require explicit prefix
1130-
if strings.Contains(lowerWhere, " or ") ||
1131-
strings.Contains(where, "(") ||
1132-
strings.Contains(where, ")") {
1133-
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
1134-
}
1135-
1136-
// Try to add the relation prefix to simple column references
1137-
// This handles basic cases like "column = value" or "column = value AND other_column = value"
1138-
// Split by AND to handle multiple conditions (case-insensitive)
1139-
originalConditions := strings.Split(where, " AND ")
1140-
1141-
// If uppercase split didn't work, try lowercase
1142-
if len(originalConditions) == 1 {
1143-
originalConditions = strings.Split(where, " and ")
1144-
}
1145-
1146-
fixedConditions := make([]string, 0, len(originalConditions))
1147-
1148-
for _, cond := range originalConditions {
1149-
cond = strings.TrimSpace(cond)
1150-
if cond == "" {
1151-
continue
1152-
}
1153-
1154-
// Check if this condition already has a table prefix (contains a dot)
1155-
if strings.Contains(cond, ".") {
1156-
fixedConditions = append(fixedConditions, cond)
1157-
continue
1158-
}
1159-
1160-
// Add relation prefix to the column name
1161-
// This prefixes the entire condition with "relationName."
1162-
fixedCond := fmt.Sprintf("%s.%s", relationName, cond)
1163-
fixedConditions = append(fixedConditions, fixedCond)
1164-
}
1165-
1166-
fixedWhere := strings.Join(fixedConditions, " AND ")
1167-
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
1168-
return fixedWhere, nil
1169-
}
1170-
11711108
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
11721109
modelType := reflect.TypeOf(model)
11731110

@@ -1197,7 +1134,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
11971134

11981135
// Validate and fix WHERE clause to ensure it contains the relation prefix
11991136
if len(preload.Where) > 0 {
1200-
fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, relationFieldName)
1137+
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
12011138
if err != nil {
12021139
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
12031140
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))

pkg/restheadspec/handler.go

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -200,69 +200,6 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
200200

201201
// parseOptionsFromHeaders is now implemented in headers.go
202202

203-
// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains
204-
// the relation prefix (alias). If not present, it attempts to add it to column references.
205-
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
206-
func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) {
207-
if where == "" {
208-
return where, nil
209-
}
210-
211-
// Check if the relation name is already present in the WHERE clause
212-
lowerWhere := strings.ToLower(where)
213-
lowerRelation := strings.ToLower(relationName)
214-
215-
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
216-
if strings.Contains(lowerWhere, lowerRelation+".") ||
217-
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
218-
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
219-
// Relation prefix is already present
220-
return where, nil
221-
}
222-
223-
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
224-
// we can't safely auto-fix it - require explicit prefix
225-
if strings.Contains(lowerWhere, " or ") ||
226-
strings.Contains(where, "(") ||
227-
strings.Contains(where, ")") {
228-
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
229-
}
230-
231-
// Try to add the relation prefix to simple column references
232-
// This handles basic cases like "column = value" or "column = value AND other_column = value"
233-
// Split by AND to handle multiple conditions (case-insensitive)
234-
originalConditions := strings.Split(where, " AND ")
235-
236-
// If uppercase split didn't work, try lowercase
237-
if len(originalConditions) == 1 {
238-
originalConditions = strings.Split(where, " and ")
239-
}
240-
241-
fixedConditions := make([]string, 0, len(originalConditions))
242-
243-
for _, cond := range originalConditions {
244-
cond = strings.TrimSpace(cond)
245-
if cond == "" {
246-
continue
247-
}
248-
249-
// Check if this condition already has a table prefix (contains a dot)
250-
if strings.Contains(cond, ".") {
251-
fixedConditions = append(fixedConditions, cond)
252-
continue
253-
}
254-
255-
// Add relation prefix to the column name
256-
// This prefixes the entire condition with "relationName."
257-
fixedCond := fmt.Sprintf("%s.%s", relationName, cond)
258-
fixedConditions = append(fixedConditions, fixedCond)
259-
}
260-
261-
fixedWhere := strings.Join(fixedConditions, " AND ")
262-
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
263-
return fixedWhere, nil
264-
}
265-
266203
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
267204
// Capture panics and return error response
268205
defer func() {
@@ -410,7 +347,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
410347

411348
// Validate and fix WHERE clause to ensure it contains the relation prefix
412349
if len(preload.Where) > 0 {
413-
fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, preload.Relation)
350+
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
414351
if err != nil {
415352
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
416353
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",

0 commit comments

Comments
 (0)