@@ -585,12 +585,12 @@ def accept_loop(
585
585
else_body : Statement | None = None ,
586
586
* ,
587
587
exit_condition : Expression | None = None ,
588
+ on_enter_body : Callable [[], None ] | None = None ,
588
589
) -> None :
589
590
"""Repeatedly type check a loop body until the frame doesn't change."""
590
591
591
592
# The outer frame accumulates the results of all iterations:
592
593
with self .binder .frame_context (can_skip = False , conditional_frame = True ):
593
-
594
594
# Check for potential decreases in the number of partial types so as not to stop the
595
595
# iteration too early:
596
596
partials_old = sum (len (pts .map ) for pts in self .partial_types )
@@ -603,6 +603,9 @@ def accept_loop(
603
603
604
604
while True :
605
605
with self .binder .frame_context (can_skip = True , break_frame = 2 , continue_frame = 1 ):
606
+ if on_enter_body is not None :
607
+ on_enter_body ()
608
+
606
609
self .accept (body )
607
610
partials_new = sum (len (pts .map ) for pts in self .partial_types )
608
611
if (partials_new == partials_old ) and not self .binder .last_pop_changed :
@@ -615,6 +618,9 @@ def accept_loop(
615
618
self .options .enabled_error_codes .add (codes .REDUNDANT_EXPR )
616
619
if warn_unreachable or warn_redundant :
617
620
with self .binder .frame_context (can_skip = True , break_frame = 2 , continue_frame = 1 ):
621
+ if on_enter_body is not None :
622
+ on_enter_body ()
623
+
618
624
self .accept (body )
619
625
620
626
# If exit_condition is set, assume it must be False on exit from the loop:
@@ -5126,8 +5132,14 @@ def visit_for_stmt(self, s: ForStmt) -> None:
5126
5132
iterator_type , item_type = self .analyze_iterable_item_type (s .expr )
5127
5133
s .inferred_item_type = item_type
5128
5134
s .inferred_iterator_type = iterator_type
5129
- self .analyze_index_variables (s .index , item_type , s .index_type is None , s )
5130
- self .accept_loop (s .body , s .else_body )
5135
+
5136
+ self .accept_loop (
5137
+ s .body ,
5138
+ s .else_body ,
5139
+ on_enter_body = lambda : self .analyze_index_variables (
5140
+ s .index , item_type , s .index_type is None , s
5141
+ ),
5142
+ )
5131
5143
5132
5144
def analyze_async_iterable_item_type (self , expr : Expression ) -> tuple [Type , Type ]:
5133
5145
"""Analyse async iterable expression and return iterator and iterator item types."""
0 commit comments