@@ -167,7 +167,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
167
167
end
168
168
169
169
function simple_walk (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
170
- callback = (@nospecialize (pi ), @nospecialize (idx)) -> false )
170
+ callback = (@nospecialize (x ), @nospecialize (idx)) -> false )
171
171
while true
172
172
if isa (defssa, OldSSAValue)
173
173
if already_inserted (compact, defssa)
@@ -335,10 +335,29 @@ struct LiftedValue
335
335
end
336
336
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}
337
337
338
+ # NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
339
+ # which can be very large sometimes, and program counters in question are often very sparse
340
+ const SPCSet = IdSet{Int}
341
+
342
# Lazily allocated collection of program counters registered via `record_nested_load!`.
# `maybe` holds `nothing` until the first program counter is recorded, and an `SPCSet`
# afterwards, so an instance that never records anything never allocates a set.
mutable struct NestedLoads
    maybe::Union{Nothing,SPCSet}
    function NestedLoads()
        # start out empty; the backing set is created on first use
        return new(nothing)
    end
end
346
# Register program counter `pc` in `nested_loads`, allocating the backing `SPCSet`
# on the first call. Returns the backing set (the result of `push!`).
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    if pcs === nothing
        # nothing recorded so far: create the set and store it back into the tracker
        pcs = SPCSet()
        nested_loads.maybe = pcs
    end
    return push!(pcs::SPCSet, pc)
end
351
# Return `true` if program counter `pc` has been registered in `nested_loads`
# (see `record_nested_load!`), and `false` otherwise -- including when nothing
# has been recorded yet and no set has been allocated.
function is_nested_load(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    if pcs === nothing
        return false
    else
        return pc in (pcs::SPCSet)
    end
end
356
+
338
357
# try to compute lifted values that can replace `getfield(x, field)` call
339
358
# where `x` is an immutable struct that is defined at any of `leaves`
340
- function lift_leaves (compact:: IncrementalCompact ,
341
- @nospecialize (result_t), field:: Int , leaves :: Vector{Any} )
359
+ function lift_leaves! (compact:: IncrementalCompact , leaves :: Vector{Any} ,
360
+ @nospecialize (result_t), field:: Int , nested_loads :: NestedLoads )
342
361
# For every leaf, the lifted value
343
362
lifted_leaves = LiftedLeaves ()
344
363
maybe_undef = false
@@ -388,11 +407,19 @@ function lift_leaves(compact::IncrementalCompact,
388
407
ocleaf = simple_walk (compact, ocleaf)
389
408
end
390
409
ocdef, _ = walk_to_def (compact, ocleaf)
391
- if isexpr (ocdef, :new_opaque_closure ) && isa (field, Int) && 1 ≤ field ≤ length (ocdef. args)- 5
410
+ if isexpr (ocdef, :new_opaque_closure ) && 1 ≤ field ≤ length (ocdef. args)- 5
392
411
lift_arg! (compact, leaf, cache_key, ocdef, 5 + field, lifted_leaves)
393
412
continue
394
413
end
395
414
return nothing
415
+ elseif is_known_call (def, getfield, compact)
416
+ if isa (leaf, SSAValue)
417
+ struct_typ = unwrap_unionall (widenconst (argextype (def. args[2 ], compact)))
418
+ if ismutabletype (struct_typ)
419
+ record_nested_load! (nested_loads, leaf. id)
420
+ end
421
+ end
422
+ return nothing
396
423
else
397
424
typ = argextype (leaf, compact)
398
425
if ! isa (typ, Const)
@@ -586,7 +613,7 @@ function perform_lifting!(compact::IncrementalCompact,
586
613
end
587
614
val = lifted_val. x
588
615
if isa (val, AnySSAValue)
589
- callback = (@nospecialize (pi ), @nospecialize (idx)) -> true
616
+ callback = (@nospecialize (x ), @nospecialize (idx)) -> true
590
617
val = simple_walk (compact, val, callback)
591
618
end
592
619
push! (new_node. values, val)
@@ -617,10 +644,6 @@ function perform_lifting!(compact::IncrementalCompact,
617
644
return stmt_val # N.B. should never happen
618
645
end
619
646
620
- # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
621
- # which can be very large sometimes, and program counters in question are often very sparse
622
- const SPCSet = IdSet{Int}
623
-
624
647
"""
625
648
sroa_pass!(ir::IRCode) -> newir::IRCode
626
649
@@ -639,10 +662,11 @@ its argument).
639
662
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
640
663
a result of succeeding dead code elimination.
641
664
"""
642
- function sroa_pass! (ir:: IRCode )
665
+ function sroa_pass! (ir:: IRCode , optional_opts :: Bool = true )
643
666
compact = IncrementalCompact (ir)
644
667
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
645
668
lifting_cache = IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
669
+ nested_loads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
646
670
for ((_, idx), stmt) in compact
647
671
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
648
672
isa (stmt, Expr) || continue
@@ -670,7 +694,7 @@ function sroa_pass!(ir::IRCode)
670
694
preserved_arg = stmt. args[pidx]
671
695
isa (preserved_arg, SSAValue) || continue
672
696
let intermediaries = SPCSet ()
673
- callback = function (@nospecialize (pi ), @nospecialize (ssa))
697
+ callback = function (@nospecialize (x ), @nospecialize (ssa))
674
698
push! (intermediaries, ssa. id)
675
699
return false
676
700
end
@@ -698,7 +722,9 @@ function sroa_pass!(ir::IRCode)
698
722
if defuses === nothing
699
723
defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
700
724
end
701
- mid, defuse = get! (defuses, defidx, (SPCSet (), SSADefUse ()))
725
+ mid, defuse = get! (defuses, defidx) do
726
+ SPCSet (), SSADefUse ()
727
+ end
702
728
push! (defuse. ccall_preserve_uses, idx)
703
729
union! (mid, intermediaries)
704
730
end
@@ -708,16 +734,17 @@ function sroa_pass!(ir::IRCode)
708
734
compact[idx] = form_new_preserves (stmt, preserved, new_preserves)
709
735
end
710
736
continue
711
- # TODO : This isn't the best place to put these
712
- elseif is_known_call (stmt, typeassert, compact)
713
- canonicalize_typeassert! (compact, idx, stmt)
714
- continue
715
- elseif is_known_call (stmt, (=== ), compact)
716
- lift_comparison! (compact, idx, stmt, lifting_cache)
717
- continue
718
- # elseif is_known_call(stmt, isa, compact)
719
- # TODO do a similar optimization as `lift_comparison!` for `===`
720
737
else
738
+ if optional_opts
739
+ # TODO : This isn't the best place to put these
740
+ if is_known_call (stmt, typeassert, compact)
741
+ canonicalize_typeassert! (compact, idx, stmt)
742
+ elseif is_known_call (stmt, (=== ), compact)
743
+ lift_comparison! (compact, idx, stmt, lifting_cache)
744
+ # elseif is_known_call(stmt, isa, compact)
745
+ # TODO do a similar optimization as `lift_comparison!` for `===`
746
+ end
747
+ end
721
748
continue
722
749
end
723
750
@@ -743,7 +770,7 @@ function sroa_pass!(ir::IRCode)
743
770
if ismutabletype (struct_typ)
744
771
isa (val, SSAValue) || continue
745
772
let intermediaries = SPCSet ()
746
- callback = function (@nospecialize (pi ), @nospecialize (ssa))
773
+ callback = function (@nospecialize (x ), @nospecialize (ssa))
747
774
push! (intermediaries, ssa. id)
748
775
return false
749
776
end
@@ -753,7 +780,9 @@ function sroa_pass!(ir::IRCode)
753
780
if defuses === nothing
754
781
defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
755
782
end
756
- mid, defuse = get! (defuses, def. id, (SPCSet (), SSADefUse ()))
783
+ mid, defuse = get! (defuses, def. id) do
784
+ SPCSet (), SSADefUse ()
785
+ end
757
786
if is_setfield
758
787
push! (defuse. defs, idx)
759
788
else
@@ -775,7 +804,7 @@ function sroa_pass!(ir::IRCode)
775
804
isempty (leaves) && continue
776
805
777
806
result_t = argextype (SSAValue (idx), compact)
778
- lifted_result = lift_leaves (compact, result_t, field, leaves )
807
+ lifted_result = lift_leaves! (compact, leaves, result_t, field, nested_loads )
779
808
lifted_result === nothing && continue
780
809
lifted_leaves, any_undef = lifted_result
781
810
@@ -811,21 +840,25 @@ function sroa_pass!(ir::IRCode)
811
840
used_ssas = copy (compact. used_ssas)
812
841
simple_dce! (compact, (x:: SSAValue ) -> used_ssas[x. id] -= 1 )
813
842
ir = complete (compact)
814
- sroa_mutables! (ir, defuses, used_ssas)
815
- return ir
843
+ return sroa_mutables! (ir, defuses, used_ssas, nested_loads)
816
844
else
817
845
simple_dce! (compact)
818
846
return complete (compact)
819
847
end
820
848
end
821
849
822
- function sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} )
850
+ function sroa_mutables! (ir:: IRCode ,
851
+ defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} ,
852
+ nested_loads:: NestedLoads )
823
853
# Compute domtree, needed below, now that we have finished compacting the IR.
824
854
# This needs to be after we iterate through the IR with `IncrementalCompact`
825
855
# because removing dead blocks can invalidate the domtree.
826
856
@timeit " domtree 2" domtree = construct_domtree (ir. cfg. blocks)
827
857
828
- for (idx, (intermediaries, defuse)) in defuses
858
+ nested_mloads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
859
+ local any_eliminated = false
860
+ # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
861
+ for (idx, (intermediaries, defuse)) in sort! (collect (defuses); by= first, rev= true )
829
862
intermediaries = collect (intermediaries)
830
863
# Check if there are any uses we did not account for. If so, the variable
831
864
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
@@ -840,7 +873,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
840
873
nleaves == nuses_total || continue
841
874
# Find the type for this allocation
842
875
defexpr = ir[SSAValue (idx)]
843
- isexpr (defexpr, :new ) || continue
876
+ isa (defexpr, Expr) || continue
877
+ if ! isexpr (defexpr, :new )
878
+ if is_known_call (defexpr, getfield, ir)
879
+ val = defexpr. args[2 ]
880
+ if isa (val, SSAValue)
881
+ struct_typ = unwrap_unionall (widenconst (argextype (val, ir)))
882
+ if ismutabletype (struct_typ)
883
+ record_nested_load! (nested_mloads, idx)
884
+ end
885
+ end
886
+ end
887
+ continue
888
+ end
844
889
newidx = idx
845
890
typ = ir. stmts[newidx][:type ]
846
891
if isa (typ, UnionAll)
@@ -910,6 +955,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
910
955
# Now go through all uses and rewrite them
911
956
for stmt in du. uses
912
957
ir[SSAValue (stmt)] = compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
958
+ if ! any_eliminated
959
+ any_eliminated |= (is_nested_load (nested_loads, stmt) ||
960
+ is_nested_load (nested_mloads, stmt))
961
+ end
913
962
end
914
963
if ! isbitstype (ftyp)
915
964
if preserve_uses != = nothing
@@ -946,6 +995,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
946
995
947
996
@label skip
948
997
end
998
+ if any_eliminated
999
+ return sroa_pass! (compact! (ir), false )
1000
+ else
1001
+ return ir
1002
+ end
949
1003
end
950
1004
951
1005
function form_new_preserves (origex:: Expr , intermediates:: Vector{Int} , new_preserves:: Vector{Any} )
0 commit comments