@@ -169,7 +169,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
169
169
end
170
170
171
171
function simple_walk (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
172
- callback = (@nospecialize (pi ), @nospecialize (idx)) -> false )
172
+ callback = (@nospecialize (x ), @nospecialize (idx)) -> false )
173
173
while true
174
174
if isa (defssa, OldSSAValue)
175
175
if already_inserted (compact, defssa)
@@ -337,10 +337,29 @@ struct LiftedValue
337
337
end
338
338
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}
339
339
340
+ # NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
341
+ # which can be very large sometimes, and program counters in question are often very sparse
342
+ const SPCSet = IdSet{Int}
343
+
344
# Lazily-allocated set of program counters (statement indices).
# `maybe` stays `nothing` until the first program counter is recorded, so that
# constructing a `NestedLoads` does not allocate a set up front.
mutable struct NestedLoads
    maybe::Union{Nothing,SPCSet}
    function NestedLoads()
        return new(nothing)
    end
end
348
# Record program counter `pc` in `nested_loads`, allocating the backing set on
# first use. Returns the backing set (the result of `push!`).
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    if pcs === nothing
        # first recorded entry: create the set and store it back on the tracker
        pcs = SPCSet()
        nested_loads.maybe = pcs
    end
    return push!(pcs::SPCSet, pc)
end
353
# Return `true` if program counter `pc` has been recorded in `nested_loads`;
# a tracker whose set was never allocated contains nothing, so it yields `false`.
function is_nested_load(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    return pcs !== nothing && pc in pcs::SPCSet
end
358
+
340
359
# try to compute lifted values that can replace `getfield(x, field)` call
341
360
# where `x` is an immutable struct that is defined at any of `leaves`
342
- function lift_leaves (compact:: IncrementalCompact ,
343
- @nospecialize (result_t), field:: Int , leaves :: Vector{Any} )
361
+ function lift_leaves! (compact:: IncrementalCompact , leaves :: Vector{Any} ,
362
+ @nospecialize (result_t), field:: Int , nested_loads :: NestedLoads )
344
363
# For every leaf, the lifted value
345
364
lifted_leaves = LiftedLeaves ()
346
365
maybe_undef = false
@@ -390,11 +409,19 @@ function lift_leaves(compact::IncrementalCompact,
390
409
ocleaf = simple_walk (compact, ocleaf)
391
410
end
392
411
ocdef, _ = walk_to_def (compact, ocleaf)
393
- if isexpr (ocdef, :new_opaque_closure ) && isa (field, Int) && 1 ≤ field ≤ length (ocdef. args)- 5
412
+ if isexpr (ocdef, :new_opaque_closure ) && 1 ≤ field ≤ length (ocdef. args)- 5
394
413
lift_arg! (compact, leaf, cache_key, ocdef, 5 + field, lifted_leaves)
395
414
continue
396
415
end
397
416
return nothing
417
+ elseif is_known_call (def, getfield, compact)
418
+ if isa (leaf, SSAValue)
419
+ struct_typ = unwrap_unionall (widenconst (argextype (def. args[2 ], compact)))
420
+ if ismutabletype (struct_typ)
421
+ record_nested_load! (nested_loads, leaf. id)
422
+ end
423
+ end
424
+ return nothing
398
425
else
399
426
typ = argextype (leaf, compact)
400
427
if ! isa (typ, Const)
@@ -588,7 +615,7 @@ function perform_lifting!(compact::IncrementalCompact,
588
615
end
589
616
val = lifted_val. x
590
617
if isa (val, AnySSAValue)
591
- callback = (@nospecialize (pi ), @nospecialize (idx)) -> true
618
+ callback = (@nospecialize (x ), @nospecialize (idx)) -> true
592
619
val = simple_walk (compact, val, callback)
593
620
end
594
621
push! (new_node. values, val)
@@ -619,10 +646,6 @@ function perform_lifting!(compact::IncrementalCompact,
619
646
return stmt_val # N.B. should never happen
620
647
end
621
648
622
- # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
623
- # which can be very large sometimes, and program counters in question are often very sparse
624
- const SPCSet = IdSet{Int}
625
-
626
649
"""
627
650
sroa_pass!(ir::IRCode) -> newir::IRCode
628
651
@@ -641,10 +664,11 @@ its argument).
641
664
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
642
665
a result of succeeding dead code elimination.
643
666
"""
644
- function sroa_pass! (ir:: IRCode )
667
+ function sroa_pass! (ir:: IRCode , optional_opts :: Bool = true )
645
668
compact = IncrementalCompact (ir)
646
669
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
647
670
lifting_cache = IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
671
+ nested_loads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
648
672
for ((_, idx), stmt) in compact
649
673
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
650
674
isa (stmt, Expr) || continue
@@ -672,7 +696,7 @@ function sroa_pass!(ir::IRCode)
672
696
preserved_arg = stmt. args[pidx]
673
697
isa (preserved_arg, SSAValue) || continue
674
698
let intermediaries = SPCSet ()
675
- callback = function (@nospecialize (pi ), @nospecialize (ssa))
699
+ callback = function (@nospecialize (x ), @nospecialize (ssa))
676
700
push! (intermediaries, ssa. id)
677
701
return false
678
702
end
@@ -700,7 +724,9 @@ function sroa_pass!(ir::IRCode)
700
724
if defuses === nothing
701
725
defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
702
726
end
703
- mid, defuse = get! (defuses, defidx, (SPCSet (), SSADefUse ()))
727
+ mid, defuse = get! (defuses, defidx) do
728
+ SPCSet (), SSADefUse ()
729
+ end
704
730
push! (defuse. ccall_preserve_uses, idx)
705
731
union! (mid, intermediaries)
706
732
end
@@ -710,16 +736,17 @@ function sroa_pass!(ir::IRCode)
710
736
compact[idx] = form_new_preserves (stmt, preserved, new_preserves)
711
737
end
712
738
continue
713
- # TODO : This isn't the best place to put these
714
- elseif is_known_call (stmt, typeassert, compact)
715
- canonicalize_typeassert! (compact, idx, stmt)
716
- continue
717
- elseif is_known_call (stmt, (=== ), compact)
718
- lift_comparison! (compact, idx, stmt, lifting_cache)
719
- continue
720
- # elseif is_known_call(stmt, isa, compact)
721
- # TODO do a similar optimization as `lift_comparison!` for `===`
722
739
else
740
+ if optional_opts
741
+ # TODO : This isn't the best place to put these
742
+ if is_known_call (stmt, typeassert, compact)
743
+ canonicalize_typeassert! (compact, idx, stmt)
744
+ elseif is_known_call (stmt, (=== ), compact)
745
+ lift_comparison! (compact, idx, stmt, lifting_cache)
746
+ # elseif is_known_call(stmt, isa, compact)
747
+ # TODO do a similar optimization as `lift_comparison!` for `===`
748
+ end
749
+ end
723
750
continue
724
751
end
725
752
@@ -745,7 +772,7 @@ function sroa_pass!(ir::IRCode)
745
772
if ismutabletype (struct_typ)
746
773
isa (val, SSAValue) || continue
747
774
let intermediaries = SPCSet ()
748
- callback = function (@nospecialize (pi ), @nospecialize (ssa))
775
+ callback = function (@nospecialize (x ), @nospecialize (ssa))
749
776
push! (intermediaries, ssa. id)
750
777
return false
751
778
end
@@ -755,7 +782,9 @@ function sroa_pass!(ir::IRCode)
755
782
if defuses === nothing
756
783
defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
757
784
end
758
- mid, defuse = get! (defuses, def. id, (SPCSet (), SSADefUse ()))
785
+ mid, defuse = get! (defuses, def. id) do
786
+ SPCSet (), SSADefUse ()
787
+ end
759
788
if is_setfield
760
789
push! (defuse. defs, idx)
761
790
else
@@ -777,7 +806,7 @@ function sroa_pass!(ir::IRCode)
777
806
isempty (leaves) && continue
778
807
779
808
result_t = argextype (SSAValue (idx), compact)
780
- lifted_result = lift_leaves (compact, result_t, field, leaves )
809
+ lifted_result = lift_leaves! (compact, leaves, result_t, field, nested_loads )
781
810
lifted_result === nothing && continue
782
811
lifted_leaves, any_undef = lifted_result
783
812
@@ -813,18 +842,21 @@ function sroa_pass!(ir::IRCode)
813
842
used_ssas = copy (compact. used_ssas)
814
843
simple_dce! (compact, (x:: SSAValue ) -> used_ssas[x. id] -= 1 )
815
844
ir = complete (compact)
816
- sroa_mutables! (ir, defuses, used_ssas)
817
- return ir
845
+ return sroa_mutables! (ir, defuses, used_ssas, nested_loads)
818
846
else
819
847
simple_dce! (compact)
820
848
return complete (compact)
821
849
end
822
850
end
823
851
824
- function sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} )
825
- # initialization of domtree is delayed to avoid the expensive computation in many cases
826
- local domtree = nothing
827
- for (idx, (intermediaries, defuse)) in defuses
852
+ function sroa_mutables! (ir:: IRCode ,
853
+ defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} ,
854
+ nested_loads:: NestedLoads )
855
+ domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
856
+ nested_mloads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
857
+ any_eliminated = false
858
+ # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
859
+ for (idx, (intermediaries, defuse)) in sort! (collect (defuses); by= first, rev= true )
828
860
intermediaries = collect (intermediaries)
829
861
# Check if there are any uses we did not account for. If so, the variable
830
862
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
@@ -839,7 +871,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
839
871
nleaves == nuses_total || continue
840
872
# Find the type for this allocation
841
873
defexpr = ir[SSAValue (idx)]
842
- isexpr (defexpr, :new ) || continue
874
+ isa (defexpr, Expr) || continue
875
+ if ! isexpr (defexpr, :new )
876
+ if is_known_call (defexpr, getfield, ir)
877
+ val = defexpr. args[2 ]
878
+ if isa (val, SSAValue)
879
+ struct_typ = unwrap_unionall (widenconst (argextype (val, ir)))
880
+ if ismutabletype (struct_typ)
881
+ record_nested_load! (nested_mloads, idx)
882
+ end
883
+ end
884
+ end
885
+ continue
886
+ end
843
887
newidx = idx
844
888
typ = ir. stmts[newidx][:type ]
845
889
if isa (typ, UnionAll)
@@ -919,6 +963,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
919
963
# Now go through all uses and rewrite them
920
964
for stmt in du. uses
921
965
ir[SSAValue (stmt)] = compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
966
+ if ! any_eliminated
967
+ any_eliminated |= (is_nested_load (nested_loads, stmt) ||
968
+ is_nested_load (nested_mloads, stmt))
969
+ end
922
970
end
923
971
if ! isbitstype (ftyp)
924
972
if preserve_uses != = nothing
@@ -955,6 +1003,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
955
1003
956
1004
@label skip
957
1005
end
1006
+ return any_eliminated ? sroa_pass! (compact! (ir), false ) : ir
958
1007
end
959
1008
960
1009
function form_new_preserves (origex:: Expr , intermediates:: Vector{Int} , new_preserves:: Vector{Any} )
0 commit comments