diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index cda4b9f539..aa7a842efc 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -314,6 +314,9 @@ def annotate_scope(self, scope: Scope) -> None: elif isinstance(source.expression, exp.Unnest): self._set_type(col, source.expression.type) + if col.type and col.type.args.get("nullable") is False: + col.meta["nonnull"] = True + if isinstance(self.schema, MappingSchema): for table_column in scope.table_columns: source = scope.sources.get(table_column.name) @@ -446,6 +449,11 @@ def _annotate_binary(self, expression: B) -> B: else: self._set_type(expression, self._maybe_coerce(left_type, right_type)) + if isinstance(expression, exp.Is) or ( + left.meta.get("nonnull") is True and right.meta.get("nonnull") is True + ): + expression.meta["nonnull"] = True + return expression def _annotate_unary(self, expression: E) -> E: @@ -456,6 +464,9 @@ def _annotate_unary(self, expression: E) -> E: else: self._set_type(expression, expression.this.type) + if expression.this.meta.get("nonnull") is True: + expression.meta["nonnull"] = True + return expression def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: @@ -466,6 +477,8 @@ def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: else: self._set_type(expression, exp.DataType.Type.DOUBLE) + expression.meta["nonnull"] = True + return expression def _annotate_with_type( diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 6520866151..38b2963987 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -395,14 +395,15 @@ def remove_complements(expression, root=True): """ Removing complements. - A AND NOT A -> FALSE - A OR NOT A -> TRUE + A AND NOT A -> FALSE (only for non-NULL A) + A OR NOT A -> TRUE (only for non-NULL A) """ if isinstance(expression, AND_OR) and (root or not expression.same_parent): ops = set(expression.flatten()) for op in ops: if isinstance(op, exp.Not) and op.this in ops: - return exp.false() if isinstance(expression, exp.And) else exp.true() + if expression.meta.get("nonnull") is True: + return exp.false() if isinstance(expression, exp.And) else exp.true() return expression diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 8857d89490..76d2f552cc 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -8,10 +8,10 @@ y OR y; y; x AND NOT x; -FALSE; +NOT x AND x; x OR NOT x; -TRUE; +NOT x OR x; 1 AND TRUE; TRUE; @@ -299,7 +299,7 @@ A XOR D XOR B XOR E XOR F XOR G XOR C; A XOR B XOR C XOR D XOR E XOR F XOR G; A AND NOT B AND C AND B; -FALSE; +A AND B AND C AND NOT B; (a AND b AND c AND d) AND (d AND c AND b AND a); a AND b AND c AND d; @@ -892,7 +892,7 @@ COALESCE(x, 1) = 1; x = 1 OR x IS NULL; COALESCE(x, 1) IS NULL; -FALSE; +NOT x IS NULL AND x IS NULL; COALESCE(ROW() OVER (), 1) = 1; ROW() OVER () = 1 OR ROW() OVER () IS NULL; @@ -1344,3 +1344,30 @@ WITH t0 AS (SELECT 1 AS a, 'foo' AS p) SELECT NOT NOT CASE WHEN t0.a > 1 THEN t0 # dialect: sqlite WITH t0 AS (SELECT 1 AS a, 'foo' AS p) SELECT NOT (NOT(CASE WHEN t0.a > 1 THEN t0.a ELSE t0.p END)) AS res FROM t0; WITH t0 AS (SELECT 1 AS a, 'foo' AS p) SELECT NOT NOT CASE WHEN t0.a > 1 THEN t0.a ELSE t0.p END AS res FROM t0; + +-------------------------------------- +-- Simplify complements +-------------------------------------- +TRUE OR NOT TRUE; +TRUE; + +TRUE AND NOT TRUE; +FALSE; + +'a' OR NOT 'a'; +TRUE; + +'a' AND NOT 'a'; +FALSE; + +100 OR NOT 100; +TRUE; + +100 AND NOT 100; +FALSE; + +NULL OR NOT NULL; +NULL; + +NULL AND NOT NULL; +NULL; diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index d5cf860329..c0a426cb62 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1699,3 +1699,58 @@ def test_annotate_object_construct(self): self.assertEqual( annotated.selects[0].type.sql("snowflake"), 'OBJECT("foo" VARCHAR, "a b" VARCHAR)' ) + + def test_nonnull_annotation(self): + for literal_sql in ("1", "'foo'", "2.5"): + with self.subTest(f"Test NULL annotation for literal: {literal_sql}"): + sql = f"SELECT {literal_sql}" + query = parse_one(sql) + annotated = annotate_types(query) + assert annotated.selects[0].meta.get("nonnull") is True + + schema = {"foo": {"id": "INT"}} + + operand_pairs = ( + ("1", "1", True), + ("foo.id", "foo.id", None), + ("1", "foo.id", None), + ("foo.id", "1", None), + ) + + for predicate in (">", "<", ">=", "<=", "=", "!=", "<>", "LIKE", "NOT LIKE"): + for operand1, operand2, nonnull in operand_pairs: + sql_predicate = f"{operand1} {predicate} {operand2}" + with self.subTest(f"Test NULL propagation for predicate: {predicate}"): + sql = f"SELECT {sql_predicate} FROM foo" + query = parse_one(sql) + annotated = annotate_types(query, schema=schema) + assert annotated.selects[0].meta.get("nonnull") is nonnull + + for predicate in ("IS NULL", "IS NOT NULL"): + sql_predicate = f"foo.id {predicate}" + with self.subTest(f"Test NULL propagation for predicate: {predicate}"): + sql = f"SELECT {sql_predicate} FROM foo" + query = parse_one(sql) + annotated = annotate_types(query, schema=schema) + assert annotated.selects[0].meta.get("nonnull") is True + + for connector in ("AND", "OR"): + for predicate in (">", "<", ">=", "<=", "=", "!=", "<>", "LIKE", "NOT LIKE"): + for operand1, operand2, nonnull in operand_pairs: + sql_predicate = f"({operand1} {predicate} {operand2})" + sql_connector = f"{sql_predicate} {connector} {sql_predicate}" + with self.subTest( + f"Test NULL propagation for connector: {connector} with predicates: {predicate}" + ): + sql = f"SELECT {sql_connector} FROM foo" + query = parse_one(sql) + annotated = annotate_types(query, schema=schema) + assert annotated.selects[0].meta.get("nonnull") is nonnull + + for unary in ("NOT", "-"): + for value, nonnull in (("1", True), ("foo.id", None)): + with self.subTest(f"Test NULL propagation for unary: {unary} with value: {value}"): + sql = f"SELECT {unary} {value} FROM foo" + query = parse_one(sql) + annotated = annotate_types(query, schema=schema) + assert annotated.selects[0].meta.get("nonnull") is nonnull