Skip to content

Commit 15e0428

Browse files
committed
Fixups
1 parent f04b1cf commit 15e0428

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

sqlglot/optimizer/annotate_types.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def annotate_types(
5353
Args:
5454
expression: Expression to annotate.
5555
schema: Database schema.
56-
annotators: Maps expression type to corresponding annotation function.
56+
expression_metadata: Maps expression type to corresponding annotation function.
5757
coerces_to: Maps expression type to set of types that it can be coerced into.
5858
overwrite_types: Re-annotate the existing AST types.
5959
Returns:
@@ -64,7 +64,7 @@ def annotate_types(
6464

6565
return TypeAnnotator(
6666
schema=schema,
67-
annotators=annotators,
67+
expression_metadata=expression_metadata,
6868
coerces_to=coerces_to,
6969
overwrite_types=overwrite_types,
7070
).annotate(expression)
@@ -182,7 +182,6 @@ def __init__(
182182
self,
183183
schema: Schema,
184184
expression_metadata: t.Optional[ExpressionMetadataType] = None,
185-
annotators: t.Optional[AnnotatorsType] = None,
186185
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
187186
binary_coercions: t.Optional[BinaryCoercions] = None,
188187
overwrite_types: bool = True,
@@ -213,7 +212,7 @@ def __init__(
213212
# When set to False, this enables partial annotation by skipping already-annotated nodes
214213
self._overwrite_types = overwrite_types
215214

216-
def clean(self) -> None:
215+
def clear(self) -> None:
217216
self._visited.clear()
218217
self._null_expressions.clear()
219218
self._setop_column_types.clear()
@@ -236,6 +235,8 @@ def _set_type(
236235
self._null_expressions.pop(expression_id, None)
237236

238237
def annotate(self, expression: E, annotate_scope: bool = True) -> E:
238+
# This flag is used to avoid costly scope traversals when we only care about annotating
239+
# non-column expressions (partial type inference), e.g., when simplifying in the optimizer
239240
if annotate_scope:
240241
for scope in traverse_scope(expression):
241242
self.annotate_scope(scope)

sqlglot/optimizer/simplify.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
if t.TYPE_CHECKING:
1919
from sqlglot.dialects.dialect import DialectType
2020

21+
DateRange = t.Tuple[datetime.date, datetime.date]
2122
DateTruncBinaryTransform = t.Callable[
2223
[exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
2324
]
2425

25-
DateRange = t.Tuple[datetime.date, datetime.date]
2626

2727
logger = logging.getLogger("sqlglot")
2828

@@ -86,7 +86,7 @@ def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.E
8686
return new_expression
8787

8888
if new_expression != expression:
89-
self._annotator.clean()
89+
self._annotator.clear()
9090
new_expression = self._annotator.annotate(
9191
expression=new_expression, annotate_scope=False
9292
)
@@ -493,8 +493,6 @@ def __init__(
493493

494494
CONCATS = (exp.Concat, exp.DPipe)
495495

496-
DateRange = t.Tuple[datetime.date, datetime.date]
497-
498496
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
499497
exp.LT: lambda l, dt, u, d, t: l
500498
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),

0 commit comments

Comments
 (0)