Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,10 @@ def _annotate_binary(self, expression: B) -> B:

if isinstance(expression, (exp.Connector, exp.Predicate)):
self._set_type(expression, exp.DataType.Type.BOOLEAN)
if isinstance(expression, exp.Is) or (
left.meta.get("nullable") is False and right.meta.get("nullable") is False
):
expression.meta["nullable"] = False
elif (left_type, right_type) in self.binary_coercions:
self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
else:
Expand All @@ -456,6 +460,9 @@ def _annotate_unary(self, expression: E) -> E:
else:
self._set_type(expression, expression.this.type)

if expression.this.meta.get("nullable") is False:
expression.meta["nullable"] = False

return expression

def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
Expand All @@ -466,6 +473,8 @@ def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
else:
self._set_type(expression, exp.DataType.Type.DOUBLE)

expression.meta["nullable"] = False

return expression

def _annotate_with_type(
Expand Down
7 changes: 4 additions & 3 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,14 +395,15 @@ def remove_complements(expression, root=True):
"""
Removing complements.

A AND NOT A -> FALSE
A OR NOT A -> TRUE
A AND NOT A -> FALSE (only for non-NULL literals)
A OR NOT A -> TRUE (only for non-NULL literals)
"""
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
ops = set(expression.flatten())
for op in ops:
if isinstance(op, exp.Not) and op.this in ops:
return exp.false() if isinstance(expression, exp.And) else exp.true()
if expression.meta.get("nullable") is False:
return exp.false() if isinstance(expression, exp.And) else exp.true()

return expression

Expand Down
35 changes: 31 additions & 4 deletions tests/fixtures/optimizer/simplify.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ y OR y;
y;

x AND NOT x;
FALSE;
NOT x AND x;

x OR NOT x;
TRUE;
NOT x OR x;

1 AND TRUE;
TRUE;
Expand Down Expand Up @@ -299,7 +299,7 @@ A XOR D XOR B XOR E XOR F XOR G XOR C;
A XOR B XOR C XOR D XOR E XOR F XOR G;

A AND NOT B AND C AND B;
FALSE;
A AND B AND C AND NOT B;

(a AND b AND c AND d) AND (d AND c AND b AND a);
a AND b AND c AND d;
Expand Down Expand Up @@ -892,7 +892,7 @@ COALESCE(x, 1) = 1;
x = 1 OR x IS NULL;

COALESCE(x, 1) IS NULL;
FALSE;
NOT x IS NULL AND x IS NULL;

COALESCE(ROW() OVER (), 1) = 1;
ROW() OVER () = 1 OR ROW() OVER () IS NULL;
Expand Down Expand Up @@ -1344,3 +1344,30 @@ WITH t0 AS (SELECT 1 AS a, 'foo' AS p) SELECT NOT NOT CASE WHEN t0.a > 1 THEN t0
# dialect: sqlite
WITH t0 AS (SELECT 1 AS a, 'foo' AS p) SELECT NOT (NOT(CASE WHEN t0.a > 1 THEN t0.a ELSE t0.p END)) AS res FROM t0;
WITH t0 AS (SELECT 1 AS a, 'foo' AS p) SELECT NOT NOT CASE WHEN t0.a > 1 THEN t0.a ELSE t0.p END AS res FROM t0;

--------------------------------------
-- Simplify complements
--------------------------------------
TRUE OR NOT TRUE;
TRUE;

TRUE AND NOT TRUE;
FALSE;

'a' OR NOT 'a';
TRUE;

'a' AND NOT 'a';
FALSE;

100 OR NOT 100;
TRUE;

100 AND NOT 100;
FALSE;

NULL OR NOT NULL;
NULL;

NULL AND NOT NULL;
NULL;
48 changes: 48 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,3 +1699,51 @@ def test_annotate_object_construct(self):
self.assertEqual(
annotated.selects[0].type.sql("snowflake"), 'OBJECT("foo" VARCHAR, "a b" VARCHAR)'
)

def test_nullable_annotation(self):
for literal_sql in ("1", "'foo'", "2.5"):
with self.subTest(f"Test NULL annotation for literal: {literal_sql}"):
sql = f"SELECT {literal_sql}"
query = parse_one(sql)
annotated = annotate_types(query)
assert annotated.selects[0].meta.get("nullable") is False

schema = {"foo": {"id": "INT"}}

for predicate in (">", "<", ">=", "<=", "=", "!=", "<>", "LIKE", "NOT LIKE"):
for operand, nullable in (("1", False), ("foo.id", None)):
sql_predicate = f"{operand} {predicate} {operand}"
with self.subTest(f"Test NULL propagation for predicate: {predicate}"):
sql = f"SELECT {sql_predicate} FROM foo"
query = parse_one(sql)
annotated = annotate_types(query, schema=schema)
assert annotated.selects[0].meta.get("nullable") is nullable

for predicate in ("IS NULL", "IS NOT NULL"):
sql_predicate = f"foo.id {predicate}"
with self.subTest(f"Test NULL propagation for predicate: {predicate}"):
sql = f"SELECT {sql_predicate} FROM foo"
query = parse_one(sql)
annotated = annotate_types(query, schema=schema)
assert annotated.selects[0].meta.get("nullable") is False

for connector in ("AND", "OR"):
for predicate in (">", "<", ">=", "<=", "=", "!=", "<>", "LIKE", "NOT LIKE"):
for operand, nullable in (("1", False), ("foo.id", None)):
sql_predicate = f"({operand} {predicate} {operand})"
sql_connector = f"{sql_predicate} {connector} {sql_predicate}"
with self.subTest(
f"Test NULL propagation for connector: {connector} with predicates: {predicate}"
):
sql = f"SELECT {sql_connector} FROM foo"
query = parse_one(sql)
annotated = annotate_types(query, schema=schema)
assert annotated.selects[0].meta.get("nullable") is nullable

for unary in ("NOT", "-"):
for value, nullable in (("1", False), ("foo.id", None)):
with self.subTest(f"Test NULL propagation for unary: {unary} with value: {value}"):
sql = f"SELECT {unary} {value} FROM foo"
query = parse_one(sql)
annotated = annotate_types(query, schema=schema)
assert annotated.selects[0].meta.get("nullable") is nullable