Skip to content

Commit 09f2256

Browse files
author
Hein
committed
feat(sql): ✨ Enhance SQL clause handling with parentheses
* Add EnsureOuterParentheses function to wrap clauses in parentheses. * Implement logic to preserve outer parentheses for OR conditions. * Update SanitizeWhereClause to utilize new function for better query safety. * Introduce tests for EnsureOuterParentheses and containsTopLevelOR functions. * Refactor filter application in handler to group OR filters correctly.
1 parent c12c045 commit 09f2256

File tree

4 files changed

+420
-6
lines changed

4 files changed

+420
-6
lines changed

pkg/common/sql_helpers.go

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ func validateWhereClauseSecurity(where string) error {
130130
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
131131
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
132132
// prefix matches a preloaded relation name, in which case it's left unchanged.
133+
//
134+
// IMPORTANT: Outer parentheses are preserved if the clause contains top-level OR operators
135+
// to prevent OR logic from escaping and affecting the entire query incorrectly.
133136
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
134137
if where == "" {
135138
return ""
@@ -143,8 +146,19 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
143146
return ""
144147
}
145148

146-
// Strip outer parentheses and re-trim
147-
where = stripOuterParentheses(where)
149+
// Check if the original clause has outer parentheses and contains OR operators
150+
// If so, we need to preserve the outer parentheses to prevent OR logic from escaping
151+
hasOuterParens := false
152+
if len(where) > 0 && where[0] == '(' && where[len(where)-1] == ')' {
153+
_, hasOuterParens = stripOneMatchingOuterParen(where)
154+
}
155+
156+
// Strip outer parentheses and re-trim for processing
157+
whereWithoutParens := stripOuterParentheses(where)
158+
shouldPreserveParens := hasOuterParens && containsTopLevelOR(whereWithoutParens)
159+
160+
// Use the stripped version for processing
161+
where = whereWithoutParens
148162

149163
// Get valid columns from the model if tableName is provided
150164
var validColumns map[string]bool
@@ -229,7 +243,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
229243

230244
result := strings.Join(validConditions, " AND ")
231245

232-
if result != where {
246+
// If the original clause had outer parentheses and contains OR operators,
247+
// restore the outer parentheses to prevent OR logic from escaping
248+
if shouldPreserveParens {
249+
result = "(" + result + ")"
250+
logger.Debug("Preserved outer parentheses for OR conditions: '%s'", result)
251+
}
252+
253+
if result != where && !shouldPreserveParens {
233254
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
234255
}
235256

@@ -290,6 +311,93 @@ func stripOneMatchingOuterParen(s string) (string, bool) {
290311
return strings.TrimSpace(s[1 : len(s)-1]), true
291312
}
292313

314+
// EnsureOuterParentheses ensures that a SQL clause is wrapped in parentheses
315+
// to prevent OR logic from escaping. It checks if the clause already has
316+
// matching outer parentheses and only adds them if they don't exist.
317+
//
318+
// This is particularly important for OR conditions and complex filters where
319+
// the absence of parentheses could cause the logic to escape and affect
320+
// the entire query incorrectly.
321+
//
322+
// Parameters:
323+
// - clause: The SQL clause to check and potentially wrap
324+
//
325+
// Returns:
326+
// - The clause with guaranteed outer parentheses, or empty string if input is empty
327+
func EnsureOuterParentheses(clause string) string {
328+
if clause == "" {
329+
return ""
330+
}
331+
332+
clause = strings.TrimSpace(clause)
333+
if clause == "" {
334+
return ""
335+
}
336+
337+
// Check if the clause already has matching outer parentheses
338+
_, hasOuterParens := stripOneMatchingOuterParen(clause)
339+
340+
// If it already has matching outer parentheses, return as-is
341+
if hasOuterParens {
342+
return clause
343+
}
344+
345+
// Otherwise, wrap it in parentheses
346+
return "(" + clause + ")"
347+
}
348+
349+
// containsTopLevelOR checks if a SQL clause contains OR operators at the top level
350+
// (i.e., not inside parentheses or subqueries). This is used to determine if
351+
// outer parentheses should be preserved to prevent OR logic from escaping.
352+
func containsTopLevelOR(clause string) bool {
353+
if clause == "" {
354+
return false
355+
}
356+
357+
depth := 0
358+
inSingleQuote := false
359+
inDoubleQuote := false
360+
lowerClause := strings.ToLower(clause)
361+
362+
for i := 0; i < len(clause); i++ {
363+
ch := clause[i]
364+
365+
// Track quote state
366+
if ch == '\'' && !inDoubleQuote {
367+
inSingleQuote = !inSingleQuote
368+
continue
369+
}
370+
if ch == '"' && !inSingleQuote {
371+
inDoubleQuote = !inDoubleQuote
372+
continue
373+
}
374+
375+
// Skip if inside quotes
376+
if inSingleQuote || inDoubleQuote {
377+
continue
378+
}
379+
380+
// Track parenthesis depth
381+
switch ch {
382+
case '(':
383+
depth++
384+
case ')':
385+
depth--
386+
}
387+
388+
// Only check for OR at depth 0 (not inside parentheses)
389+
if depth == 0 && i+4 <= len(clause) {
390+
// Check for " OR " (case-insensitive)
391+
substring := lowerClause[i : i+4]
392+
if substring == " or " {
393+
return true
394+
}
395+
}
396+
}
397+
398+
return false
399+
}
400+
293401
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
294402
// This is parenthesis-aware and won't split on AND operators inside subqueries
295403
func splitByAND(where string) []string {

pkg/common/sql_helpers_test.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,179 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
659659
}
660660
}
661661

662+
func TestEnsureOuterParentheses(t *testing.T) {
663+
tests := []struct {
664+
name string
665+
input string
666+
expected string
667+
}{
668+
{
669+
name: "no parentheses",
670+
input: "status = 'active'",
671+
expected: "(status = 'active')",
672+
},
673+
{
674+
name: "already has outer parentheses",
675+
input: "(status = 'active')",
676+
expected: "(status = 'active')",
677+
},
678+
{
679+
name: "OR condition without parentheses",
680+
input: "status = 'active' OR status = 'pending'",
681+
expected: "(status = 'active' OR status = 'pending')",
682+
},
683+
{
684+
name: "OR condition with parentheses",
685+
input: "(status = 'active' OR status = 'pending')",
686+
expected: "(status = 'active' OR status = 'pending')",
687+
},
688+
{
689+
name: "complex condition with nested parentheses",
690+
input: "(status = 'active' OR status = 'pending') AND (age > 18)",
691+
expected: "((status = 'active' OR status = 'pending') AND (age > 18))",
692+
},
693+
{
694+
name: "empty string",
695+
input: "",
696+
expected: "",
697+
},
698+
{
699+
name: "whitespace only",
700+
input: " ",
701+
expected: "",
702+
},
703+
{
704+
name: "mismatched parentheses - adds outer ones",
705+
input: "(status = 'active' OR status = 'pending'",
706+
expected: "((status = 'active' OR status = 'pending')",
707+
},
708+
}
709+
710+
for _, tt := range tests {
711+
t.Run(tt.name, func(t *testing.T) {
712+
result := EnsureOuterParentheses(tt.input)
713+
if result != tt.expected {
714+
t.Errorf("EnsureOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
715+
}
716+
})
717+
}
718+
}
719+
720+
func TestContainsTopLevelOR(t *testing.T) {
721+
tests := []struct {
722+
name string
723+
input string
724+
expected bool
725+
}{
726+
{
727+
name: "no OR operator",
728+
input: "status = 'active' AND age > 18",
729+
expected: false,
730+
},
731+
{
732+
name: "top-level OR",
733+
input: "status = 'active' OR status = 'pending'",
734+
expected: true,
735+
},
736+
{
737+
name: "OR inside parentheses",
738+
input: "age > 18 AND (status = 'active' OR status = 'pending')",
739+
expected: false,
740+
},
741+
{
742+
name: "OR in subquery",
743+
input: "id IN (SELECT id FROM users WHERE status = 'active' OR status = 'pending')",
744+
expected: false,
745+
},
746+
{
747+
name: "OR inside quotes",
748+
input: "comment = 'this OR that'",
749+
expected: false,
750+
},
751+
{
752+
name: "mixed - top-level OR and nested OR",
753+
input: "name = 'test' OR (status = 'active' OR status = 'pending')",
754+
expected: true,
755+
},
756+
{
757+
name: "empty string",
758+
input: "",
759+
expected: false,
760+
},
761+
{
762+
name: "lowercase or",
763+
input: "status = 'active' or status = 'pending'",
764+
expected: true,
765+
},
766+
{
767+
name: "uppercase OR",
768+
input: "status = 'active' OR status = 'pending'",
769+
expected: true,
770+
},
771+
}
772+
773+
for _, tt := range tests {
774+
t.Run(tt.name, func(t *testing.T) {
775+
result := containsTopLevelOR(tt.input)
776+
if result != tt.expected {
777+
t.Errorf("containsTopLevelOR(%q) = %v; want %v", tt.input, result, tt.expected)
778+
}
779+
})
780+
}
781+
}
782+
783+
func TestSanitizeWhereClause_PreservesParenthesesWithOR(t *testing.T) {
784+
tests := []struct {
785+
name string
786+
where string
787+
tableName string
788+
expected string
789+
}{
790+
{
791+
name: "OR condition with outer parentheses - preserved",
792+
where: "(status = 'active' OR status = 'pending')",
793+
tableName: "users",
794+
expected: "(users.status = 'active' OR users.status = 'pending')",
795+
},
796+
{
797+
name: "AND condition with outer parentheses - stripped (no OR)",
798+
where: "(status = 'active' AND age > 18)",
799+
tableName: "users",
800+
expected: "users.status = 'active' AND users.age > 18",
801+
},
802+
{
803+
name: "complex OR with nested conditions",
804+
where: "((status = 'active' OR status = 'pending') AND age > 18)",
805+
tableName: "users",
806+
// Outer parens are stripped, but inner parens with OR are preserved
807+
expected: "(users.status = 'active' OR users.status = 'pending') AND users.age > 18",
808+
},
809+
{
810+
name: "OR without outer parentheses - no parentheses added by SanitizeWhereClause",
811+
where: "status = 'active' OR status = 'pending'",
812+
tableName: "users",
813+
expected: "users.status = 'active' OR users.status = 'pending'",
814+
},
815+
{
816+
name: "simple OR with parentheses - preserved",
817+
where: "(users.status = 'active' OR users.status = 'pending')",
818+
tableName: "users",
819+
// Already has correct prefixes, parentheses preserved
820+
expected: "(users.status = 'active' OR users.status = 'pending')",
821+
},
822+
}
823+
824+
for _, tt := range tests {
825+
t.Run(tt.name, func(t *testing.T) {
826+
prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName)
827+
result := SanitizeWhereClause(prefixedWhere, tt.tableName)
828+
if result != tt.expected {
829+
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
830+
}
831+
})
832+
}
833+
}
834+
662835
func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) {
663836
tests := []struct {
664837
name string

pkg/resolvespec/handler.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
318318
if cursorFilter != "" {
319319
logger.Debug("Applying cursor filter: %s", cursorFilter)
320320
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
321+
// Ensure outer parentheses to prevent OR logic from escaping
322+
sanitizedCursor = common.EnsureOuterParentheses(sanitizedCursor)
321323
if sanitizedCursor != "" {
322324
query = query.Where(sanitizedCursor)
323325
}
@@ -1656,6 +1658,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
16561658
// Build RequestOptions with all preloads to allow references to sibling relations
16571659
preloadOpts := &common.RequestOptions{Preload: preloads}
16581660
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
1661+
// Ensure outer parentheses to prevent OR logic from escaping
1662+
sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
16591663
if len(sanitizedWhere) > 0 {
16601664
sq = sq.Where(sanitizedWhere)
16611665
}

0 commit comments

Comments
 (0)