@@ -90,7 +90,11 @@ impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
90
90
} ;
91
91
92
92
for bb in body. basic_blocks . indices ( ) {
93
- finder. start_from_switch ( bb) ;
93
+ let old_len = finder. opportunities . len ( ) ;
94
+ // If we have any const-eval errors discard any opportunities found
95
+ if finder. start_from_switch ( bb) . is_none ( ) {
96
+ finder. opportunities . truncate ( old_len) ;
97
+ }
94
98
}
95
99
96
100
let opportunities = finder. opportunities ;
@@ -150,14 +154,6 @@ impl Condition {
150
154
fn matches ( & self , value : ScalarInt ) -> bool {
151
155
( self . value == value) == ( self . polarity == Polarity :: Eq )
152
156
}
153
-
154
- fn inv ( mut self ) -> Self {
155
- self . polarity = match self . polarity {
156
- Polarity :: Eq => Polarity :: Ne ,
157
- Polarity :: Ne => Polarity :: Eq ,
158
- } ;
159
- self
160
- }
161
157
}
162
158
163
159
#[ derive( Copy , Clone , Debug ) ]
@@ -180,8 +176,21 @@ impl<'a> ConditionSet<'a> {
180
176
self . iter ( ) . filter ( move |c| c. matches ( value) )
181
177
}
182
178
183
- fn map ( self , arena : & ' a DroplessArena , f : impl Fn ( Condition ) -> Condition ) -> ConditionSet < ' a > {
184
- ConditionSet ( arena. alloc_from_iter ( self . iter ( ) . map ( f) ) )
179
+ fn map (
180
+ self ,
181
+ arena : & ' a DroplessArena ,
182
+ f : impl Fn ( Condition ) -> Option < Condition > ,
183
+ ) -> Option < ConditionSet < ' a > > {
184
+ let mut all_ok = true ;
185
+ let set = arena. alloc_from_iter ( self . iter ( ) . map_while ( |c| {
186
+ if let Some ( c) = f ( c) {
187
+ Some ( c)
188
+ } else {
189
+ all_ok = false ;
190
+ None
191
+ }
192
+ } ) ) ;
193
+ all_ok. then_some ( ConditionSet ( set) )
185
194
}
186
195
}
187
196
@@ -192,28 +201,28 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
192
201
193
202
/// Recursion entry point to find threading opportunities.
194
203
#[ instrument( level = "trace" , skip( self ) ) ]
195
- fn start_from_switch ( & mut self , bb : BasicBlock ) {
204
+ fn start_from_switch ( & mut self , bb : BasicBlock ) -> Option < ( ) > {
196
205
let bbdata = & self . body [ bb] ;
197
206
if bbdata. is_cleanup || self . loop_headers . contains ( bb) {
198
- return ;
207
+ return Some ( ( ) ) ;
199
208
}
200
- let Some ( ( discr, targets) ) = bbdata. terminator ( ) . kind . as_switch ( ) else { return } ;
201
- let Some ( discr) = discr. place ( ) else { return } ;
209
+ let Some ( ( discr, targets) ) = bbdata. terminator ( ) . kind . as_switch ( ) else { return Some ( ( ) ) } ;
210
+ let Some ( discr) = discr. place ( ) else { return Some ( ( ) ) } ;
202
211
debug ! ( ?discr, ?bb) ;
203
212
204
213
let discr_ty = discr. ty ( self . body , self . tcx ) . ty ;
205
214
let Ok ( discr_layout) = self . ecx . layout_of ( discr_ty) else {
206
- return ;
215
+ return Some ( ( ) ) ;
207
216
} ;
208
217
209
- let Some ( discr) = self . map . find ( discr. as_ref ( ) ) else { return } ;
218
+ let Some ( discr) = self . map . find ( discr. as_ref ( ) ) else { return Some ( ( ) ) } ;
210
219
debug ! ( ?discr) ;
211
220
212
221
let cost = CostChecker :: new ( self . tcx , self . typing_env , None , self . body ) ;
213
222
let mut state = State :: new_reachable ( ) ;
214
223
215
224
let conds = if let Some ( ( value, then, else_) ) = targets. as_static_if ( ) {
216
- let Some ( value) = ScalarInt :: try_from_uint ( value, discr_layout. size ) else { return } ;
225
+ let value = ScalarInt :: try_from_uint ( value, discr_layout. size ) ? ;
217
226
self . arena . alloc_from_iter ( [
218
227
Condition { value, polarity : Polarity :: Eq , target : then } ,
219
228
Condition { value, polarity : Polarity :: Ne , target : else_ } ,
@@ -227,7 +236,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
227
236
let conds = ConditionSet ( conds) ;
228
237
state. insert_value_idx ( discr, conds, & self . map ) ;
229
238
230
- self . find_opportunity ( bb, state, cost, 0 ) ;
239
+ self . find_opportunity ( bb, state, cost, 0 )
231
240
}
232
241
233
242
/// Recursively walk statements backwards from this bb's terminator to find threading
@@ -239,27 +248,27 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
239
248
mut state : State < ConditionSet < ' a > > ,
240
249
mut cost : CostChecker < ' _ , ' tcx > ,
241
250
depth : usize ,
242
- ) {
251
+ ) -> Option < ( ) > {
243
252
// Do not thread through loop headers.
244
253
if self . loop_headers . contains ( bb) {
245
- return ;
254
+ return Some ( ( ) ) ;
246
255
}
247
256
248
257
debug ! ( cost = ?cost. cost( ) ) ;
249
258
for ( statement_index, stmt) in
250
259
self . body . basic_blocks [ bb] . statements . iter ( ) . enumerate ( ) . rev ( )
251
260
{
252
261
if self . is_empty ( & state) {
253
- return ;
262
+ return Some ( ( ) ) ;
254
263
}
255
264
256
265
cost. visit_statement ( stmt, Location { block : bb, statement_index } ) ;
257
266
if cost. cost ( ) > MAX_COST {
258
- return ;
267
+ return Some ( ( ) ) ;
259
268
}
260
269
261
270
// Attempt to turn the `current_condition` on `lhs` into a condition on another place.
262
- self . process_statement ( bb, stmt, & mut state) ;
271
+ self . process_statement ( bb, stmt, & mut state) ? ;
263
272
264
273
// When a statement mutates a place, assignments to that place that happen
265
274
// above the mutation cannot fulfill a condition.
@@ -271,7 +280,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
271
280
}
272
281
273
282
if self . is_empty ( & state) || depth >= MAX_BACKTRACK {
274
- return ;
283
+ return Some ( ( ) ) ;
275
284
}
276
285
277
286
let last_non_rec = self . opportunities . len ( ) ;
@@ -284,9 +293,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
284
293
match term. kind {
285
294
TerminatorKind :: SwitchInt { ref discr, ref targets } => {
286
295
self . process_switch_int ( discr, targets, bb, & mut state) ;
287
- self . find_opportunity ( pred, state, cost, depth + 1 ) ;
296
+ self . find_opportunity ( pred, state, cost, depth + 1 ) ? ;
288
297
}
289
- _ => self . recurse_through_terminator ( pred, || state, & cost, depth) ,
298
+ _ => self . recurse_through_terminator ( pred, || state, & cost, depth) ? ,
290
299
}
291
300
} else if let & [ ref predecessors @ .., last_pred] = & predecessors[ ..] {
292
301
for & pred in predecessors {
@@ -311,12 +320,13 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
311
320
let first = & mut new_tos[ 0 ] ;
312
321
* first = ThreadingOpportunity { chain : vec ! [ bb] , target : first. target } ;
313
322
self . opportunities . truncate ( last_non_rec + 1 ) ;
314
- return ;
323
+ return Some ( ( ) ) ;
315
324
}
316
325
317
326
for op in self . opportunities [ last_non_rec..] . iter_mut ( ) {
318
327
op. chain . push ( bb) ;
319
328
}
329
+ Some ( ( ) )
320
330
}
321
331
322
332
/// Extract the mutated place from a statement.
@@ -430,23 +440,23 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
430
440
lhs : PlaceIndex ,
431
441
rhs : & Operand < ' tcx > ,
432
442
state : & mut State < ConditionSet < ' a > > ,
433
- ) {
443
+ ) -> Option < ( ) > {
434
444
match rhs {
435
445
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
436
446
Operand :: Constant ( constant) => {
437
- let Some ( constant) =
438
- self . ecx . eval_mir_constant ( & constant. const_ , constant. span , None ) . discard_err ( )
439
- else {
440
- return ;
441
- } ;
447
+ let constant = self
448
+ . ecx
449
+ . eval_mir_constant ( & constant. const_ , constant. span , None )
450
+ . discard_err ( ) ?;
442
451
self . process_constant ( bb, lhs, constant, state) ;
443
452
}
444
453
// Transfer the conditions on the copied rhs.
445
454
Operand :: Move ( rhs) | Operand :: Copy ( rhs) => {
446
- let Some ( rhs) = self . map . find ( rhs. as_ref ( ) ) else { return } ;
455
+ let Some ( rhs) = self . map . find ( rhs. as_ref ( ) ) else { return Some ( ( ) ) } ;
447
456
state. insert_place_idx ( rhs, lhs, & self . map ) ;
448
457
}
449
458
}
459
+ Some ( ( ) )
450
460
}
451
461
452
462
#[ instrument( level = "trace" , skip( self ) ) ]
@@ -456,22 +466,26 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
456
466
lhs_place : & Place < ' tcx > ,
457
467
rhs : & Rvalue < ' tcx > ,
458
468
state : & mut State < ConditionSet < ' a > > ,
459
- ) {
460
- let Some ( lhs) = self . map . find ( lhs_place. as_ref ( ) ) else { return } ;
469
+ ) -> Option < ( ) > {
470
+ let Some ( lhs) = self . map . find ( lhs_place. as_ref ( ) ) else {
471
+ return Some ( ( ) ) ;
472
+ } ;
461
473
match rhs {
462
- Rvalue :: Use ( operand) => self . process_operand ( bb, lhs, operand, state) ,
474
+ Rvalue :: Use ( operand) => self . process_operand ( bb, lhs, operand, state) ? ,
463
475
// Transfer the conditions on the copy rhs.
464
- Rvalue :: CopyForDeref ( rhs) => self . process_operand ( bb, lhs, & Operand :: Copy ( * rhs) , state) ,
476
+ Rvalue :: CopyForDeref ( rhs) => {
477
+ self . process_operand ( bb, lhs, & Operand :: Copy ( * rhs) , state) ?
478
+ }
465
479
Rvalue :: Discriminant ( rhs) => {
466
- let Some ( rhs) = self . map . find_discr ( rhs. as_ref ( ) ) else { return } ;
480
+ let Some ( rhs) = self . map . find_discr ( rhs. as_ref ( ) ) else { return Some ( ( ) ) } ;
467
481
state. insert_place_idx ( rhs, lhs, & self . map ) ;
468
482
}
469
483
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
470
484
Rvalue :: Aggregate ( box kind, operands) => {
471
485
let agg_ty = lhs_place. ty ( self . body , self . tcx ) . ty ;
472
486
let lhs = match kind {
473
487
// Do not support unions.
474
- AggregateKind :: Adt ( .., Some ( _) ) => return ,
488
+ AggregateKind :: Adt ( .., Some ( _) ) => return Some ( ( ) ) ,
475
489
AggregateKind :: Adt ( _, variant_index, ..) if agg_ty. is_enum ( ) => {
476
490
if let Some ( discr_target) = self . map . apply ( lhs, TrackElem :: Discriminant )
477
491
&& let Some ( discr_value) = self
@@ -484,30 +498,31 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
484
498
if let Some ( idx) = self . map . apply ( lhs, TrackElem :: Variant ( * variant_index) ) {
485
499
idx
486
500
} else {
487
- return ;
501
+ return Some ( ( ) ) ;
488
502
}
489
503
}
490
504
_ => lhs,
491
505
} ;
492
506
for ( field_index, operand) in operands. iter_enumerated ( ) {
493
507
if let Some ( field) = self . map . apply ( lhs, TrackElem :: Field ( field_index) ) {
494
- self . process_operand ( bb, field, operand, state) ;
508
+ self . process_operand ( bb, field, operand, state) ? ;
495
509
}
496
510
}
497
511
}
498
- // Transfer the conditions on the copy rhs, after inversing polarity .
512
+ // Transfer the conditions on the copy rhs, after inverting the value of the condition .
499
513
Rvalue :: UnaryOp ( UnOp :: Not , Operand :: Move ( place) | Operand :: Copy ( place) ) => {
500
- if !place. ty ( self . body , self . tcx ) . ty . is_bool ( ) {
501
- // Constructing the conditions by inverting the polarity
502
- // of equality is only correct for bools. That is to say,
503
- // `!a == b` is not `a != b` for integers greater than 1 bit.
504
- return ;
505
- }
506
- let Some ( conditions) = state. try_get_idx ( lhs, & self . map ) else { return } ;
507
- let Some ( place) = self . map . find ( place. as_ref ( ) ) else { return } ;
508
- // FIXME: I think This could be generalized to not bool if we
509
- // actually perform a logical not on the condition's value.
510
- let conds = conditions. map ( self . arena , Condition :: inv) ;
514
+ let layout = self . ecx . layout_of ( place. ty ( self . body , self . tcx ) . ty ) . unwrap ( ) ;
515
+ let Some ( conditions) = state. try_get_idx ( lhs, & self . map ) else { return Some ( ( ) ) } ;
516
+ let Some ( place) = self . map . find ( place. as_ref ( ) ) else { return Some ( ( ) ) } ;
517
+ let conds = conditions. map ( self . arena , |mut cond| {
518
+ cond. value = self
519
+ . ecx
520
+ . unary_op ( UnOp :: Not , & ImmTy :: from_scalar_int ( cond. value , layout) )
521
+ . discard_err ( ) ?
522
+ . to_scalar_int ( )
523
+ . discard_err ( ) ?;
524
+ Some ( cond)
525
+ } ) ?;
511
526
state. insert_value_idx ( place, conds, & self . map ) ;
512
527
}
513
528
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
@@ -517,34 +532,34 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
517
532
box ( Operand :: Move ( place) | Operand :: Copy ( place) , Operand :: Constant ( value) )
518
533
| box ( Operand :: Constant ( value) , Operand :: Move ( place) | Operand :: Copy ( place) ) ,
519
534
) => {
520
- let Some ( conditions) = state. try_get_idx ( lhs, & self . map ) else { return } ;
521
- let Some ( place) = self . map . find ( place. as_ref ( ) ) else { return } ;
535
+ let Some ( conditions) = state. try_get_idx ( lhs, & self . map ) else { return Some ( ( ) ) } ;
536
+ let Some ( place) = self . map . find ( place. as_ref ( ) ) else { return Some ( ( ) ) } ;
522
537
let equals = match op {
523
538
BinOp :: Eq => ScalarInt :: TRUE ,
524
539
BinOp :: Ne => ScalarInt :: FALSE ,
525
- _ => return ,
540
+ _ => return Some ( ( ) ) ,
526
541
} ;
527
542
if value. const_ . ty ( ) . is_floating_point ( ) {
528
543
// Floating point equality does not follow bit-patterns.
529
544
// -0.0 and NaN both have special rules for equality,
530
545
// and therefore we cannot use integer comparisons for them.
531
546
// Avoid handling them, though this could be extended in the future.
532
- return ;
547
+ return Some ( ( ) ) ;
533
548
}
534
- let Some ( value) = value. const_ . try_eval_scalar_int ( self . tcx , self . typing_env )
535
- else {
536
- return ;
537
- } ;
538
- let conds = conditions. map ( self . arena , |c| Condition {
539
- value,
540
- polarity : if c. matches ( equals) { Polarity :: Eq } else { Polarity :: Ne } ,
541
- ..c
542
- } ) ;
549
+ let value = value. const_ . try_eval_scalar_int ( self . tcx , self . typing_env ) ?;
550
+ let conds = conditions. map ( self . arena , |c| {
551
+ Some ( Condition {
552
+ value,
553
+ polarity : if c. matches ( equals) { Polarity :: Eq } else { Polarity :: Ne } ,
554
+ ..c
555
+ } )
556
+ } ) ?;
543
557
state. insert_value_idx ( place, conds, & self . map ) ;
544
558
}
545
559
546
560
_ => { }
547
561
}
562
+ Some ( ( ) )
548
563
}
549
564
550
565
#[ instrument( level = "trace" , skip( self ) ) ]
@@ -553,7 +568,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
553
568
bb : BasicBlock ,
554
569
stmt : & Statement < ' tcx > ,
555
570
state : & mut State < ConditionSet < ' a > > ,
556
- ) {
571
+ ) -> Option < ( ) > {
557
572
let register_opportunity = |c : Condition | {
558
573
debug ! ( ?bb, ?c. target, "register" ) ;
559
574
self . opportunities . push ( ThreadingOpportunity { chain : vec ! [ bb] , target : c. target } )
@@ -566,30 +581,32 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
566
581
// If we expect `discriminant(place) ?= A`,
567
582
// we have an opportunity if `variant_index ?= A`.
568
583
StatementKind :: SetDiscriminant { box place, variant_index } => {
569
- let Some ( discr_target) = self . map . find_discr ( place. as_ref ( ) ) else { return } ;
584
+ let Some ( discr_target) = self . map . find_discr ( place. as_ref ( ) ) else {
585
+ return Some ( ( ) ) ;
586
+ } ;
570
587
let enum_ty = place. ty ( self . body , self . tcx ) . ty ;
571
588
// `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
572
589
// Even if the discriminant write does nothing due to niches, it is UB to set the
573
590
// discriminant when the data does not encode the desired discriminant.
574
- let Some ( discr) =
575
- self . ecx . discriminant_for_variant ( enum_ty, * variant_index) . discard_err ( )
576
- else {
577
- return ;
578
- } ;
591
+ let discr =
592
+ self . ecx . discriminant_for_variant ( enum_ty, * variant_index) . discard_err ( ) ?;
579
593
self . process_immediate ( bb, discr_target, discr, state) ;
580
594
}
581
595
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
582
596
StatementKind :: Intrinsic ( box NonDivergingIntrinsic :: Assume (
583
597
Operand :: Copy ( place) | Operand :: Move ( place) ,
584
598
) ) => {
585
- let Some ( conditions) = state. try_get ( place. as_ref ( ) , & self . map ) else { return } ;
599
+ let Some ( conditions) = state. try_get ( place. as_ref ( ) , & self . map ) else {
600
+ return Some ( ( ) ) ;
601
+ } ;
586
602
conditions. iter_matches ( ScalarInt :: TRUE ) . for_each ( register_opportunity) ;
587
603
}
588
604
StatementKind :: Assign ( box ( lhs_place, rhs) ) => {
589
- self . process_assign ( bb, lhs_place, rhs, state) ;
605
+ self . process_assign ( bb, lhs_place, rhs, state) ? ;
590
606
}
591
607
_ => { }
592
608
}
609
+ Some ( ( ) )
593
610
}
594
611
595
612
#[ instrument( level = "trace" , skip( self , state, cost) ) ]
@@ -600,7 +617,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
600
617
state : impl FnOnce ( ) -> State < ConditionSet < ' a > > ,
601
618
cost : & CostChecker < ' _ , ' tcx > ,
602
619
depth : usize ,
603
- ) {
620
+ ) -> Option < ( ) > {
604
621
let term = self . body . basic_blocks [ bb] . terminator ( ) ;
605
622
let place_to_flood = match term. kind {
606
623
// We come from a target, so those are not possible.
@@ -615,9 +632,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
615
632
| TerminatorKind :: FalseUnwind { .. }
616
633
| TerminatorKind :: Yield { .. } => bug ! ( "{term:?} invalid" ) ,
617
634
// Cannot reason about inline asm.
618
- TerminatorKind :: InlineAsm { .. } => return ,
635
+ TerminatorKind :: InlineAsm { .. } => return Some ( ( ) ) ,
619
636
// `SwitchInt` is handled specially.
620
- TerminatorKind :: SwitchInt { .. } => return ,
637
+ TerminatorKind :: SwitchInt { .. } => return Some ( ( ) ) ,
621
638
// We can recurse, no thing particular to do.
622
639
TerminatorKind :: Goto { .. } => None ,
623
640
// Flood the overwritten place, and progress through.
@@ -632,7 +649,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
632
649
if let Some ( place_to_flood) = place_to_flood {
633
650
state. flood_with ( place_to_flood. as_ref ( ) , & self . map , ConditionSet :: BOTTOM ) ;
634
651
}
635
- self . find_opportunity ( bb, state, cost. clone ( ) , depth + 1 ) ;
652
+ self . find_opportunity ( bb, state, cost. clone ( ) , depth + 1 )
636
653
}
637
654
638
655
#[ instrument( level = "trace" , skip( self ) ) ]
0 commit comments