@@ -29,12 +29,14 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
29
29
30
30
# Convenience method: forwards the definition and use lists stored in `du`
# to the three-argument `compute_live_ins` method.
function compute_live_ins(cfg::CFG, du::SSADefUse)
    return compute_live_ins(cfg, du.defs, du.uses)
end
31
31
32
- function try_compute_field_stmt (ir:: Union{IncrementalCompact,IRCode} , stmt:: Expr )
33
- field = stmt. args[3 ]
32
# Resolve the field operand of `stmt`, which is taken from the statement's third
# argument, by delegating to `try_compute_field`.
function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr)
    field_arg = stmt.args[3]
    return try_compute_field(ir, field_arg)
end
34
+
35
+ function try_compute_field (ir:: Union{IncrementalCompact,IRCode} , @nospecialize (field))
34
36
# fields are usually literals, handle them manually
35
37
if isa (field, QuoteNode)
36
38
field = field. value
37
- elseif isa (field, Int)
39
+ elseif isa (field, Int) || isa (field, Symbol)
38
40
# try to resolve other constants, e.g. global reference
39
41
else
40
42
field = argextype (field, ir)
@@ -44,8 +46,7 @@ function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr
44
46
return nothing
45
47
end
46
48
end
47
- isa (field, Union{Int, Symbol}) || return nothing
48
- return field
49
+ return isa (field, Union{Int, Symbol}) ? field : nothing
49
50
end
50
51
51
52
function try_compute_fieldidx_stmt (ir:: Union{IncrementalCompact,IRCode} , stmt:: Expr , typ:: DataType )
@@ -167,7 +168,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
167
168
end
168
169
169
170
function simple_walk (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
170
- callback = (@nospecialize (pi ), @nospecialize (idx)) -> false )
171
+ callback = (@nospecialize (x ), @nospecialize (idx)) -> false )
171
172
while true
172
173
if isa (defssa, OldSSAValue)
173
174
if already_inserted (compact, defssa)
@@ -335,10 +336,29 @@ struct LiftedValue
335
336
end
336
337
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}
337
338
339
+ # NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
340
+ # which can be very large sometimes, and program counters in question are often very sparse
341
+ const SPCSet = IdSet{Int}
342
+
343
# Lazily populated collection of program counters.
# `maybe` starts out as `nothing`, so constructing a `NestedLoads` does not
# allocate a set; it holds an `SPCSet` once one has been assigned.
mutable struct NestedLoads
    maybe::Union{Nothing,SPCSet}
    function NestedLoads()
        return new(nothing)
    end
end
347
# Add `pc` to the set held by `nested_loads`, creating the set on first use.
# Returns the result of `push!`, i.e. the underlying set.
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    if pcs === nothing
        # first recorded load: allocate the set and store it back
        pcs = SPCSet()
        nested_loads.maybe = pcs
    end
    return push!(pcs::SPCSet, pc)
end
352
# Return `true` if `pc` has been recorded in `nested_loads`.
# When nothing has been recorded yet (`maybe === nothing`) the answer is `false`.
function is_nested_load(nested_loads::NestedLoads, pc::Int)
    pcs = nested_loads.maybe
    return pcs !== nothing && in(pc, pcs::SPCSet)
end
357
+
338
358
# try to compute lifted values that can replace `getfield(x, field)` call
339
359
# where `x` is an immutable struct that are defined at any of `leaves`
340
- function lift_leaves (compact:: IncrementalCompact ,
341
- @nospecialize (result_t), field:: Int , leaves :: Vector{Any} )
360
+ function lift_leaves! (compact:: IncrementalCompact , leaves :: Vector{Any} ,
361
+ @nospecialize (result_t), field:: Int , nested_loads :: NestedLoads )
342
362
# For every leaf, the lifted value
343
363
lifted_leaves = LiftedLeaves ()
344
364
maybe_undef = false
@@ -388,11 +408,19 @@ function lift_leaves(compact::IncrementalCompact,
388
408
ocleaf = simple_walk (compact, ocleaf)
389
409
end
390
410
ocdef, _ = walk_to_def (compact, ocleaf)
391
- if isexpr (ocdef, :new_opaque_closure ) && isa (field, Int) && 1 ≤ field ≤ length (ocdef. args)- 5
411
+ if isexpr (ocdef, :new_opaque_closure ) && 1 ≤ field ≤ length (ocdef. args)- 5
392
412
lift_arg! (compact, leaf, cache_key, ocdef, 5 + field, lifted_leaves)
393
413
continue
394
414
end
395
415
return nothing
416
+ elseif is_known_call (def, getfield, compact)
417
+ if isa (leaf, SSAValue)
418
+ struct_typ = unwrap_unionall (widenconst (argextype (def. args[2 ], compact)))
419
+ if ismutabletype (struct_typ)
420
+ record_nested_load! (nested_loads, leaf. id)
421
+ end
422
+ end
423
+ return nothing
396
424
else
397
425
typ = argextype (leaf, compact)
398
426
if ! isa (typ, Const)
@@ -586,7 +614,7 @@ function perform_lifting!(compact::IncrementalCompact,
586
614
end
587
615
val = lifted_val. x
588
616
if isa (val, AnySSAValue)
589
- callback = (@nospecialize (pi ), @nospecialize (idx)) -> true
617
+ callback = (@nospecialize (x ), @nospecialize (idx)) -> true
590
618
val = simple_walk (compact, val, callback)
591
619
end
592
620
push! (new_node. values, val)
@@ -617,10 +645,6 @@ function perform_lifting!(compact::IncrementalCompact,
617
645
return stmt_val # N.B. should never happen
618
646
end
619
647
620
- # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
621
- # which can be very large sometimes, and program counters in question are often very sparse
622
- const SPCSet = IdSet{Int}
623
-
624
648
"""
625
649
sroa_pass!(ir::IRCode) -> newir::IRCode
626
650
@@ -639,10 +663,11 @@ its argument).
639
663
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
640
664
a result of succeeding dead code elimination.
641
665
"""
642
- function sroa_pass! (ir:: IRCode )
666
+ function sroa_pass! (ir:: IRCode , optional_opts :: Bool = true )
643
667
compact = IncrementalCompact (ir)
644
668
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
645
669
lifting_cache = IdDict {Pair{AnySSAValue, Any}, AnySSAValue} ()
670
+ nested_loads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
646
671
for ((_, idx), stmt) in compact
647
672
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
648
673
isa (stmt, Expr) || continue
@@ -670,7 +695,7 @@ function sroa_pass!(ir::IRCode)
670
695
preserved_arg = stmt. args[pidx]
671
696
isa (preserved_arg, SSAValue) || continue
672
697
let intermediaries = SPCSet ()
673
- callback = function (@nospecialize (pi ), @nospecialize (ssa))
698
+ callback = function (@nospecialize (x ), @nospecialize (ssa))
674
699
push! (intermediaries, ssa. id)
675
700
return false
676
701
end
@@ -698,7 +723,9 @@ function sroa_pass!(ir::IRCode)
698
723
if defuses === nothing
699
724
defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
700
725
end
701
- mid, defuse = get! (defuses, defidx, (SPCSet (), SSADefUse ()))
726
+ mid, defuse = get! (defuses, defidx) do
727
+ SPCSet (), SSADefUse ()
728
+ end
702
729
push! (defuse. ccall_preserve_uses, idx)
703
730
union! (mid, intermediaries)
704
731
end
@@ -708,16 +735,17 @@ function sroa_pass!(ir::IRCode)
708
735
compact[idx] = form_new_preserves (stmt, preserved, new_preserves)
709
736
end
710
737
continue
711
- # TODO : This isn't the best place to put these
712
- elseif is_known_call (stmt, typeassert, compact)
713
- canonicalize_typeassert! (compact, idx, stmt)
714
- continue
715
- elseif is_known_call (stmt, (=== ), compact)
716
- lift_comparison! (compact, idx, stmt, lifting_cache)
717
- continue
718
- # elseif is_known_call(stmt, isa, compact)
719
- # TODO do a similar optimization as `lift_comparison!` for `===`
720
738
else
739
+ if optional_opts
740
+ # TODO : This isn't the best place to put these
741
+ if is_known_call (stmt, typeassert, compact)
742
+ canonicalize_typeassert! (compact, idx, stmt)
743
+ elseif is_known_call (stmt, (=== ), compact)
744
+ lift_comparison! (compact, idx, stmt, lifting_cache)
745
+ # elseif is_known_call(stmt, isa, compact)
746
+ # TODO do a similar optimization as `lift_comparison!` for `===`
747
+ end
748
+ end
721
749
continue
722
750
end
723
751
@@ -743,7 +771,7 @@ function sroa_pass!(ir::IRCode)
743
771
if ismutabletype (struct_typ)
744
772
isa (val, SSAValue) || continue
745
773
let intermediaries = SPCSet ()
746
- callback = function (@nospecialize (pi ), @nospecialize (ssa))
774
+ callback = function (@nospecialize (x ), @nospecialize (ssa))
747
775
push! (intermediaries, ssa. id)
748
776
return false
749
777
end
@@ -753,7 +781,9 @@ function sroa_pass!(ir::IRCode)
753
781
if defuses === nothing
754
782
defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
755
783
end
756
- mid, defuse = get! (defuses, def. id, (SPCSet (), SSADefUse ()))
784
+ mid, defuse = get! (defuses, def. id) do
785
+ SPCSet (), SSADefUse ()
786
+ end
757
787
if is_setfield
758
788
push! (defuse. defs, idx)
759
789
else
@@ -775,7 +805,7 @@ function sroa_pass!(ir::IRCode)
775
805
isempty (leaves) && continue
776
806
777
807
result_t = argextype (SSAValue (idx), compact)
778
- lifted_result = lift_leaves (compact, result_t, field, leaves )
808
+ lifted_result = lift_leaves! (compact, leaves, result_t, field, nested_loads )
779
809
lifted_result === nothing && continue
780
810
lifted_leaves, any_undef = lifted_result
781
811
@@ -811,21 +841,25 @@ function sroa_pass!(ir::IRCode)
811
841
used_ssas = copy (compact. used_ssas)
812
842
simple_dce! (compact, (x:: SSAValue ) -> used_ssas[x. id] -= 1 )
813
843
ir = complete (compact)
814
- sroa_mutables! (ir, defuses, used_ssas)
815
- return ir
844
+ return sroa_mutables! (ir, defuses, used_ssas, nested_loads)
816
845
else
817
846
simple_dce! (compact)
818
847
return complete (compact)
819
848
end
820
849
end
821
850
822
- function sroa_mutables! (ir:: IRCode , defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} )
851
+ function sroa_mutables! (ir:: IRCode ,
852
+ defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} ,
853
+ nested_loads:: NestedLoads )
823
854
# Compute domtree, needed below, now that we have finished compacting the IR.
824
855
# This needs to be after we iterate through the IR with `IncrementalCompact`
825
856
# because removing dead blocks can invalidate the domtree.
826
857
@timeit " domtree 2" domtree = construct_domtree (ir. cfg. blocks)
827
858
828
- for (idx, (intermediaries, defuse)) in defuses
859
+ nested_mloads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
860
+ local any_eliminated = false
861
+ # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
862
+ for (idx, (intermediaries, defuse)) in sort! (collect (defuses); by= first, rev= true )
829
863
intermediaries = collect (intermediaries)
830
864
# Check if there are any uses we did not account for. If so, the variable
831
865
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
@@ -840,7 +874,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
840
874
nleaves == nuses_total || continue
841
875
# Find the type for this allocation
842
876
defexpr = ir[SSAValue (idx)]
843
- isexpr (defexpr, :new ) || continue
877
+ isa (defexpr, Expr) || continue
878
+ if ! isexpr (defexpr, :new )
879
+ if is_known_call (defexpr, getfield, ir)
880
+ val = defexpr. args[2 ]
881
+ if isa (val, SSAValue)
882
+ struct_typ = unwrap_unionall (widenconst (argextype (val, ir)))
883
+ if ismutabletype (struct_typ)
884
+ record_nested_load! (nested_mloads, idx)
885
+ end
886
+ end
887
+ end
888
+ continue
889
+ end
844
890
newidx = idx
845
891
typ = ir. stmts[newidx][:type ]
846
892
if isa (typ, UnionAll)
@@ -910,6 +956,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
910
956
# Now go through all uses and rewrite them
911
957
for stmt in du. uses
912
958
ir[SSAValue (stmt)] = compute_value_for_use (ir, domtree, allblocks, du, phinodes, fidx, stmt)
959
+ if ! any_eliminated
960
+ any_eliminated |= (is_nested_load (nested_loads, stmt) ||
961
+ is_nested_load (nested_mloads, stmt))
962
+ end
913
963
end
914
964
if ! isbitstype (ftyp)
915
965
if preserve_uses != = nothing
@@ -946,6 +996,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
946
996
947
997
@label skip
948
998
end
999
+ if any_eliminated
1000
+ return sroa_pass! (compact! (ir), false )
1001
+ else
1002
+ return ir
1003
+ end
949
1004
end
950
1005
951
1006
function form_new_preserves (origex:: Expr , intermediates:: Vector{Int} , new_preserves:: Vector{Any} )
0 commit comments