@@ -100,9 +100,22 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
100
100
end
101
101
end
102
102
103
- # even when the allocation contains an uninitialized field, we try an extra effort to check
104
- # if this load at `idx` have any "safe" `setfield!` calls that define the field
105
103
function has_safe_def (
104
+ ir:: IRCode , domtree:: DomTree , allblocks:: Vector{Int} , du:: SSADefUse ,
105
+ newidx:: Int , fidx:: Int )
106
+ newexpr = ir[SSAValue (newidx)]:: Expr
107
+
108
+ fidx + 1 ≤ length (newexpr. args) && return true # assured to have a safe definition for all usages
109
+
110
+ # even when the allocation contains an uninitialized field, we try an extra effort to
111
+ # check if all loads have "safe" `setfield!` calls that define the uninitialized field
112
+ for use in du. uses
113
+ has_safe_def_for_uninitialized_field (ir, domtree, allblocks, du, newidx, use) || return false
114
+ end
115
+ return true # shuold be safe
116
+ end
117
+
118
+ function has_safe_def_for_uninitialized_field (
106
119
ir:: IRCode , domtree:: DomTree , allblocks:: Vector{Int} , du:: SSADefUse ,
107
120
newidx:: Int , idx:: Int )
108
121
def, _, _ = find_def_for_use (ir, domtree, allblocks, du, idx)
@@ -207,14 +220,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
207
220
end
208
221
209
222
function simple_walk_constraint (compact:: IncrementalCompact , @nospecialize (defssa#= ::AnySSAValue=# ),
210
- @nospecialize (typeconstraint))
211
- callback = function (@nospecialize (pi ), @nospecialize (idx))
212
- if isa (pi , PiNode)
213
- typeconstraint = typeintersect (typeconstraint, widenconst (pi . typ))
223
+ @nospecialize (typeconstraint), @nospecialize (callback = nothing ) )
224
+ newcallback = function (@nospecialize (x ), @nospecialize (idx))
225
+ if isa (x , PiNode)
226
+ typeconstraint = typeintersect (typeconstraint, widenconst (x . typ))
214
227
end
228
+ callback === nothing || callback (x, idx)
215
229
return false
216
230
end
217
- def = simple_walk (compact, defssa, callback )
231
+ def = simple_walk (compact, defssa, newcallback )
218
232
return Pair {Any, Any} (def, typeconstraint)
219
233
end
220
234
224
238
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
225
239
(pruning those leaves rules out by path conditions).
226
240
"""
227
- function walk_to_defs (compact:: IncrementalCompact , @nospecialize (defssa), @nospecialize (typeconstraint))
241
+ function walk_to_defs (compact:: IncrementalCompact ,
242
+ @nospecialize (defssa), @nospecialize (typeconstraint),
243
+ @nospecialize (callback = nothing ))
228
244
visited_phinodes = AnySSAValue[]
229
245
isa (defssa, AnySSAValue) || return Any[defssa], visited_phinodes
230
246
def = compact[defssa]
@@ -260,7 +276,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
260
276
val = OldSSAValue (val. id)
261
277
end
262
278
if isa (val, AnySSAValue)
263
- new_def, new_constraint = simple_walk_constraint (compact, val, typeconstraint)
279
+ new_def, new_constraint = simple_walk_constraint (compact, val, typeconstraint, callback )
264
280
if isa (new_def, AnySSAValue)
265
281
if ! haskey (visited_constraints, new_def)
266
282
push! (worklist_defs, new_def)
@@ -721,10 +737,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
721
737
continue
722
738
end
723
739
if defuses === nothing
724
- defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
740
+ defuses = IdDict {Int, Tuple{SPCSet, SSADefUse, PhiDefs }} ()
725
741
end
726
- mid, defuse = get! (defuses, defidx) do
727
- SPCSet (), SSADefUse ()
742
+ mid, defuse, phidefs = get! (defuses, defidx) do
743
+ SPCSet (), SSADefUse (), PhiDefs ( nothing )
728
744
end
729
745
push! (defuse. ccall_preserve_uses, idx)
730
746
union! (mid, intermediaries)
@@ -779,16 +795,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
779
795
# Mutable stuff here
780
796
isa (def, SSAValue) || continue
781
797
if defuses === nothing
782
- defuses = IdDict {Int, Tuple{SPCSet, SSADefUse}} ()
798
+ defuses = IdDict {Int, Tuple{SPCSet, SSADefUse, PhiDefs }} ()
783
799
end
784
- mid, defuse = get! (defuses, def. id) do
785
- SPCSet (), SSADefUse ()
800
+ mid, defuse, phidefs = get! (defuses, def. id) do
801
+ SPCSet (), SSADefUse (), PhiDefs ( nothing )
786
802
end
787
803
if is_setfield
788
804
push! (defuse. defs, idx)
789
805
else
790
806
push! (defuse. uses, idx)
791
807
end
808
+ defval = compact[def]
809
+ if isa (defval, PhiNode)
810
+ phicallback = function (@nospecialize (x), @nospecialize (ssa))
811
+ push! (intermediaries, ssa. id)
812
+ return false
813
+ end
814
+ defs, _ = walk_to_defs (compact, def, struct_typ, phicallback)
815
+ if _any (@nospecialize (d)-> ! isa (d, SSAValue), defs)
816
+ delete! (defuses, def. id)
817
+ continue
818
+ end
819
+ phidefs[] = Int[(def:: SSAValue ). id for def in defs]
820
+ end
792
821
union! (mid, intermediaries)
793
822
end
794
823
continue
@@ -848,8 +877,14 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
848
877
end
849
878
end
850
879
880
+ # TODO :
881
+ # - run mutable SROA on the same IR as when we collect information about mutable allocations
882
+ # - simplify and improve the eliminability check below using an escape analysis
883
+
884
+ const PhiDefs = RefValue{Union{Nothing,Vector{Int}}}
885
+
851
886
function sroa_mutables! (ir:: IRCode ,
852
- defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse}} , used_ssas:: Vector{Int} ,
887
+ defuses:: IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs }} , used_ssas:: Vector{Int} ,
853
888
nested_loads:: NestedLoads )
854
889
# Compute domtree, needed below, now that we have finished compacting the IR.
855
890
# This needs to be after we iterate through the IR with `IncrementalCompact`
@@ -859,36 +894,58 @@ function sroa_mutables!(ir::IRCode,
859
894
nested_mloads = NestedLoads () # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
860
895
local any_eliminated = false
861
896
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
862
- for (idx, (intermediaries, defuse)) in sort! (collect (defuses); by= first, rev= true )
897
+ for (idx, (intermediaries, defuse, phidefs )) in sort! (collect (defuses); by= first, rev= true )
863
898
intermediaries = collect (intermediaries)
899
+ phidefs = phidefs[]
864
900
# Check if there are any uses we did not account for. If so, the variable
865
901
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
866
902
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
867
903
# show up in the nuses_total count.
868
- nleaves = length (defuse. uses) + length (defuse. defs) + length (defuse. ccall_preserve_uses)
904
+ nleaves = count_leaves (defuse)
905
+ if phidefs != = nothing
906
+ # if this defines ϕ, we also track leaves of all predecessors as well
907
+ # FIXME this doesn't work when any predecessor is used by another ϕ-node
908
+ for pidx in phidefs
909
+ haskey (defuses, pidx) || continue
910
+ pdefuse = defuses[pidx][2 ]
911
+ nleaves += count_leaves (pdefuse)
912
+ end
913
+ end
869
914
nuses = 0
870
915
for idx in intermediaries
871
916
nuses += used_ssas[idx]
872
917
end
873
- nuses_total = used_ssas[idx] + nuses - length (intermediaries)
918
+ nuses -= length (intermediaries)
919
+ nuses_total = used_ssas[idx] + nuses
920
+ if phidefs != = nothing
921
+ for pidx in phidefs
922
+ # NOTE we don't need to accout for intermediates for this predecessor here,
923
+ # since they are already included in intermediates of this ϕ-node
924
+ # FIXME this doesn't work when any predecessor is used by another ϕ-node
925
+ nuses_total += used_ssas[pidx] - 1 # substract usage count from ϕ-node itself
926
+ end
927
+ end
874
928
nleaves == nuses_total || continue
875
929
# Find the type for this allocation
876
930
defexpr = ir[SSAValue (idx)]
877
- isa (defexpr, Expr) || continue
878
- if ! isexpr (defexpr, :new )
879
- if is_known_call (defexpr, getfield, ir)
880
- val = defexpr. args[2 ]
881
- if isa (val, SSAValue)
882
- struct_typ = unwrap_unionall (widenconst (argextype (val, ir)))
883
- if ismutabletype (struct_typ)
884
- record_nested_load! (nested_mloads, idx)
885
- end
931
+ if isa (defexpr, Expr)
932
+ if ! isexpr (defexpr, :new )
933
+ maybe_record_nested_load! (nested_mloads, ir, idx)
934
+ continue
935
+ end
936
+ elseif isa (defexpr, PhiNode)
937
+ phidefs === nothing && continue
938
+ for pidx in phidefs
939
+ pexpr = ir[SSAValue (pidx)]
940
+ if ! isexpr (pexpr, :new )
941
+ maybe_record_nested_load! (nested_mloads, ir, pidx)
942
+ @goto skip
886
943
end
887
944
end
945
+ else
888
946
continue
889
947
end
890
- newidx = idx
891
- typ = ir. stmts[newidx][:type ]
948
+ typ = ir. stmts[idx][:type ]
892
949
if isa (typ, UnionAll)
893
950
typ = unwrap_unionall (typ)
894
951
end
@@ -900,25 +957,29 @@ function sroa_mutables!(ir::IRCode,
900
957
fielddefuse = SSADefUse[SSADefUse () for _ = 1 : fieldcount (typ)]
901
958
all_forwarded = true
902
959
for use in defuse. uses
903
- stmt = ir[ SSAValue (use)] # == `getfield` call
904
- # We may have discovered above that this use is dead
905
- # after the getfield elim of immutables. In that case,
906
- # it would have been deleted. That's fine, just ignore
907
- # the use in that case.
908
- if stmt === nothing
960
+ eliminable = check_use_eliminability! (fielddefuse, ir, use, typ)
961
+ if eliminable === nothing
962
+ # We may have discovered above that this use is dead
963
+ # after the getfield elim of immutables. In that case,
964
+ # it would have been deleted. That's fine, just ignore
965
+ # the use in that case.
909
966
all_forwarded = false
910
967
continue
968
+ elseif ! eliminable
969
+ @goto skip
911
970
end
912
- field = try_compute_fieldidx_stmt (ir, stmt:: Expr , typ)
913
- field === nothing && @goto skip
914
- push! (fielddefuse[field]. uses, use)
915
971
end
916
972
for def in defuse. defs
917
- stmt = ir[SSAValue (def)]:: Expr # == `setfield!` call
918
- field = try_compute_fieldidx_stmt (ir, stmt, typ)
919
- field === nothing && @goto skip
920
- isconst (typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error
921
- push! (fielddefuse[field]. defs, def)
973
+ check_def_eliminability! (fielddefuse, ir, def, typ) || @goto skip
974
+ end
975
+ if phidefs != = nothing
976
+ for pidx in phidefs
977
+ haskey (defuses, pidx) || continue
978
+ pdefuse = defuses[pidx][2 ]
979
+ for pdef in pdefuse. defs
980
+ check_def_eliminability! (fielddefuse, ir, pdef, typ) || @goto skip
981
+ end
982
+ end
922
983
end
923
984
# Check that the defexpr has defined values for all the fields
924
985
# we're accessing. In the future, we may want to relax this,
@@ -929,15 +990,24 @@ function sroa_mutables!(ir::IRCode,
929
990
for fidx in 1 : ndefuse
930
991
du = fielddefuse[fidx]
931
992
isempty (du. uses) && continue
932
- push! (du. defs, newidx)
993
+ if phidefs === nothing
994
+ push! (du. defs, idx)
995
+ else
996
+ for pidx in phidefs
997
+ push! (du. defs, pidx)
998
+ end
999
+ end
933
1000
ldu = compute_live_ins (ir. cfg, du)
934
1001
phiblocks = isempty (ldu. live_in_bbs) ? Int[] : iterated_dominance_frontier (ir. cfg, ldu, domtree)
935
1002
allblocks = sort (vcat (phiblocks, ldu. def_bbs))
936
1003
blocks[fidx] = phiblocks, allblocks
937
- if fidx + 1 > length (defexpr. args)
938
- for use in du. uses
939
- has_safe_def (ir, domtree, allblocks, du, newidx, use) || @goto skip
1004
+ if phidefs != = nothing
1005
+ # check if all predecessors have safe definitions
1006
+ for pidx in phidefs
1007
+ has_safe_def (ir, domtree, allblocks, du, pidx, fidx) || @goto skip
940
1008
end
1009
+ else
1010
+ has_safe_def (ir, domtree, allblocks, du, idx, fidx) || @goto skip
941
1011
end
942
1012
end
943
1013
# Everything accounted for. Go field by field and perform idf
@@ -977,17 +1047,24 @@ function sroa_mutables!(ir::IRCode,
977
1047
end
978
1048
end
979
1049
end
980
- for stmt in du. defs
981
- stmt == newidx && continue
982
- ir[SSAValue (stmt)] = nothing
1050
+ if isa (defexpr, PhiNode)
1051
+ ir[SSAValue (idx)] = nothing
1052
+ for pidx in phidefs:: Vector{Int}
1053
+ used_ssas[pidx] -= 1
1054
+ end
1055
+ else
1056
+ for stmt in du. defs
1057
+ stmt == idx && continue
1058
+ ir[SSAValue (stmt)] = nothing
1059
+ end
983
1060
end
984
1061
end
985
1062
preserve_uses === nothing && continue
986
1063
if all_forwarded
987
1064
# this means all ccall preserves have been replaced with forwarded loads
988
1065
# so we can potentially eliminate the allocation, otherwise we must preserve
989
1066
# the whole allocation.
990
- push! (intermediaries, newidx )
1067
+ push! (intermediaries, idx )
991
1068
end
992
1069
# Insert the new preserves
993
1070
for (use, new_preserves) in preserve_uses
@@ -1003,6 +1080,42 @@ function sroa_mutables!(ir::IRCode,
1003
1080
end
1004
1081
end
1005
1082
1083
+ count_leaves (defuse:: SSADefUse ) =
1084
+ length (defuse. uses) + length (defuse. defs) + length (defuse. ccall_preserve_uses)
1085
+
1086
+ function maybe_record_nested_load! (nested_mloads:: NestedLoads , ir:: IRCode , idx:: Int )
1087
+ defexpr = ir[SSAValue (idx)]
1088
+ if is_known_call (defexpr, getfield, ir)
1089
+ val = defexpr. args[2 ]
1090
+ if isa (val, SSAValue)
1091
+ struct_typ = unwrap_unionall (widenconst (argextype (val, ir)))
1092
+ if ismutabletype (struct_typ)
1093
+ record_nested_load! (nested_mloads, idx)
1094
+ end
1095
+ end
1096
+ end
1097
+ end
1098
+
1099
+ function check_use_eliminability! (fielddefuse:: Vector{SSADefUse} ,
1100
+ ir:: IRCode , useidx:: Int , struct_typ:: DataType )
1101
+ stmt = ir[SSAValue (useidx)] # == `getfield` call
1102
+ stmt === nothing && return nothing
1103
+ field = try_compute_fieldidx_stmt (ir, stmt:: Expr , struct_typ)
1104
+ field === nothing && return false
1105
+ push! (fielddefuse[field]. uses, useidx)
1106
+ return true
1107
+ end
1108
+
1109
+ function check_def_eliminability! (fielddefuse:: Vector{SSADefUse} ,
1110
+ ir:: IRCode , defidx:: Int , struct_typ:: DataType )
1111
+ stmt = ir[SSAValue (defidx)]:: Expr # == `setfield!` call
1112
+ field = try_compute_fieldidx_stmt (ir, stmt, struct_typ)
1113
+ field === nothing && return false
1114
+ isconst (struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error
1115
+ push! (fielddefuse[field]. defs, defidx)
1116
+ return true
1117
+ end
1118
+
1006
1119
function form_new_preserves (origex:: Expr , intermediates:: Vector{Int} , new_preserves:: Vector{Any} )
1007
1120
newex = Expr (:foreigncall )
1008
1121
nccallargs = length (origex. args[3 ]:: SimpleVector )
0 commit comments