Skip to content

Commit 29fd2ac

Browse files
committed
optimizer: enable SROA of mutable φ-nodes
This commit allows elimination of mutable φ-nodes (and their predecessor mutable allocations). As a contrived example, it allows this `mutable_ϕ_elim(::String, ::Vector{String})` to run without any allocations at all:
```julia
function mutable_ϕ_elim(x, xs)
    r = Ref(x)
    for x in xs
        r = Ref(x)
    end
    return r[]
end

let xs = String[string(gensym()) for _ in 1:100]
    mutable_ϕ_elim("init", xs)
    @test @allocated(mutable_ϕ_elim("init", xs)) == 0
end
```
This mutable ϕ-node elimination is still limited though. Most notably, the current implementation doesn't work if a mutable allocation forms multiple ϕ-nodes, since we check allocation eliminability (i.e. escapability) by counting usage counts and thus it's hard to reason about multiple ϕ-nodes at a time. For example, currently mutable allocations involved in cases like below will still not be eliminated:
```julia
code_typed((Bool,String,String),) do cond, x, y
    if cond
        ϕ2 = ϕ1 = Ref(x)
    else
        ϕ2 = ϕ1 = Ref(y)
    end
    ϕ1[], ϕ2[]
end

# more realistic example
mutable struct Point{T}
    x::T
    y::T
end
add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y)
function compute(a::Point{ComplexF64}, b::Point{ComplexF64})
    for i in 0:(100000000-1)
        a = add(add(a, b), b)
    end
    a.x, a.y
end
```
I'd say this limitation should be addressed by first introducing a better abstraction for reasoning about escape information. More specifically, I'd like to introduce EscapeAnalysis.jl into Julia base first, and then gradually adapt it to improve our SROA pass, since EA will allow us to reason about all escape information imposed on whatever object more easily and should help us get rid of the complexities of our current SROA implementation. For now, I'd like to get in this enhancement even though it has the limitation elaborated above, as long as this commit doesn't introduce a latency problem (which is unlikely).
1 parent e1e502e commit 29fd2ac

File tree

2 files changed

+377
-55
lines changed

2 files changed

+377
-55
lines changed

base/compiler/ssair/passes.jl

+175-54
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
100100
end
101101
end
102102

103-
# even when the allocation contains an uninitialized field, we try an extra effort to check
104-
# if this load at `idx` have any "safe" `setfield!` calls that define the field
105-
function has_safe_def(
103+
# even when the allocation contains an uninitialized field, we make an extra effort to
104+
# check if all loads have "safe" `setfield!` calls that define the uninitialized field
105+
function has_safe_def_for_undef_field(
106106
ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse,
107107
newidx::Int, idx::Int)
108108
def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx)
@@ -207,14 +207,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
207207
end
208208

209209
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
210-
@nospecialize(typeconstraint))
211-
callback = function (@nospecialize(pi), @nospecialize(idx))
212-
if isa(pi, PiNode)
213-
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
210+
@nospecialize(typeconstraint), @nospecialize(callback = nothing))
211+
newcallback = function (@nospecialize(x), @nospecialize(idx))
212+
if isa(x, PiNode)
213+
typeconstraint = typeintersect(typeconstraint, widenconst(x.typ))
214214
end
215+
callback === nothing || callback(x, idx)
215216
return false
216217
end
217-
def = simple_walk(compact, defssa, callback)
218+
def = simple_walk(compact, defssa, newcallback)
218219
return Pair{Any, Any}(def, typeconstraint)
219220
end
220221

@@ -224,7 +225,9 @@ end
224225
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
225226
(pruning those leaves ruled out by path conditions).
226227
"""
227-
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
228+
function walk_to_defs(compact::IncrementalCompact,
229+
@nospecialize(defssa), @nospecialize(typeconstraint),
230+
@nospecialize(callback = nothing))
228231
visited_phinodes = AnySSAValue[]
229232
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
230233
def = compact[defssa]
@@ -260,7 +263,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
260263
val = OldSSAValue(val.id)
261264
end
262265
if isa(val, AnySSAValue)
263-
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint)
266+
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback)
264267
if isa(new_def, AnySSAValue)
265268
if !haskey(visited_constraints, new_def)
266269
push!(worklist_defs, new_def)
@@ -721,10 +724,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
721724
continue
722725
end
723726
if defuses === nothing
724-
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
727+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
725728
end
726-
mid, defuse = get!(defuses, defidx) do
727-
SPCSet(), SSADefUse()
729+
mid, defuse, phidefs = get!(defuses, defidx) do
730+
SPCSet(), SSADefUse(), PhiDefs(nothing)
728731
end
729732
push!(defuse.ccall_preserve_uses, idx)
730733
union!(mid, intermediaries)
@@ -779,16 +782,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
779782
# Mutable stuff here
780783
isa(def, SSAValue) || continue
781784
if defuses === nothing
782-
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
785+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
783786
end
784-
mid, defuse = get!(defuses, def.id) do
785-
SPCSet(), SSADefUse()
787+
mid, defuse, phidefs = get!(defuses, def.id) do
788+
SPCSet(), SSADefUse(), PhiDefs(nothing)
786789
end
787790
if is_setfield
788791
push!(defuse.defs, idx)
789792
else
790793
push!(defuse.uses, idx)
791794
end
795+
defval = compact[def]
796+
if isa(defval, PhiNode)
797+
phicallback = function (@nospecialize(x), @nospecialize(ssa))
798+
push!(intermediaries, ssa.id)
799+
return false
800+
end
801+
defs, _ = walk_to_defs(compact, def, struct_typ, phicallback)
802+
if _any(@nospecialize(d)->!isa(d, SSAValue), defs)
803+
delete!(defuses, def.id)
804+
continue
805+
end
806+
phidefs[] = Int[(def::SSAValue).id for def in defs]
807+
end
792808
union!(mid, intermediaries)
793809
end
794810
continue
@@ -848,43 +864,73 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
848864
end
849865
end
850866

867+
# TODO:
868+
# - run mutable SROA on the same IR as when we collect information about mutable allocations
869+
# - simplify and improve the eliminability check below using an escape analysis
870+
871+
const PhiDefs = RefValue{Union{Nothing,Vector{Int}}}
872+
851873
function sroa_mutables!(ir::IRCode,
852-
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
874+
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int},
853875
nested_loads::NestedLoads)
854876
domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
855877
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
856878
any_eliminated = false
879+
eliminable_defs = nothing # tracks eliminable "definitions" if initialized
857880
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
858-
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
881+
for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true)
859882
intermediaries = collect(intermediaries)
883+
phidefs = phidefs[]
860884
# Check if there are any uses we did not account for. If so, the variable
861885
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
862886
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
863887
# show up in the nuses_total count.
864-
nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
888+
nleaves = count_leaves(defuse)
889+
if phidefs !== nothing
890+
# if this defines a ϕ-node, we also track the leaves of all its predecessors
891+
# FIXME this doesn't work when any predecessor is used by another ϕ-node
892+
for pidx in phidefs
893+
haskey(defuses, pidx) || continue
894+
pdefuse = defuses[pidx][2]
895+
nleaves += count_leaves(pdefuse)
896+
end
897+
end
865898
nuses = 0
866899
for idx in intermediaries
867900
nuses += used_ssas[idx]
868901
end
869-
nuses_total = used_ssas[idx] + nuses - length(intermediaries)
902+
nuses -= length(intermediaries)
903+
nuses_total = used_ssas[idx] + nuses
904+
if phidefs !== nothing
905+
for pidx in phidefs
906+
# NOTE we don't need to account for intermediaries of this predecessor here,
907+
# since they are already included in the intermediaries of this ϕ-node
908+
# FIXME this doesn't work when any predecessor is used by another ϕ-node
909+
nuses_total += used_ssas[pidx] - 1 # subtract the usage count from the ϕ-node itself
910+
end
911+
end
870912
nleaves == nuses_total || continue
871913
# Find the type for this allocation
872914
defexpr = ir[SSAValue(idx)]
873-
isa(defexpr, Expr) || continue
874-
if !isexpr(defexpr, :new)
875-
if is_known_call(defexpr, getfield, ir)
876-
val = defexpr.args[2]
877-
if isa(val, SSAValue)
878-
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
879-
if ismutabletype(struct_typ)
880-
record_nested_load!(nested_mloads, idx)
881-
end
915+
if isa(defexpr, Expr)
916+
@assert phidefs === nothing
917+
if !isexpr(defexpr, :new)
918+
maybe_record_nested_load!(nested_mloads, ir, idx)
919+
continue
920+
end
921+
elseif isa(defexpr, PhiNode)
922+
phidefs === nothing && continue
923+
for pidx in phidefs
924+
pexpr = ir[SSAValue(pidx)]
925+
if !isexpr(pexpr, :new)
926+
maybe_record_nested_load!(nested_mloads, ir, pidx)
927+
@goto skip
882928
end
883929
end
930+
else
884931
continue
885932
end
886-
newidx = idx
887-
typ = ir.stmts[newidx][:type]
933+
typ = ir.stmts[idx][:type]
888934
if isa(typ, UnionAll)
889935
typ = unwrap_unionall(typ)
890936
end
@@ -896,25 +942,29 @@ function sroa_mutables!(ir::IRCode,
896942
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
897943
all_forwarded = true
898944
for use in defuse.uses
899-
stmt = ir[SSAValue(use)] # == `getfield` call
900-
# We may have discovered above that this use is dead
901-
# after the getfield elim of immutables. In that case,
902-
# it would have been deleted. That's fine, just ignore
903-
# the use in that case.
904-
if stmt === nothing
945+
eliminable = check_use_eliminability!(fielddefuse, ir, use, typ)
946+
if eliminable === nothing
947+
# We may have discovered above that this use is dead
948+
# after the getfield elim of immutables. In that case,
949+
# it would have been deleted. That's fine, just ignore
950+
# the use in that case.
905951
all_forwarded = false
906952
continue
953+
elseif !eliminable
954+
@goto skip
907955
end
908-
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
909-
field === nothing && @goto skip
910-
push!(fielddefuse[field].uses, use)
911956
end
912957
for def in defuse.defs
913-
stmt = ir[SSAValue(def)]::Expr # == `setfield!` call
914-
field = try_compute_fieldidx_stmt(ir, stmt, typ)
915-
field === nothing && @goto skip
916-
isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error
917-
push!(fielddefuse[field].defs, def)
958+
check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip
959+
end
960+
if phidefs !== nothing
961+
for pidx in phidefs
962+
haskey(defuses, pidx) || continue
963+
pdefuse = defuses[pidx][2]
964+
for pdef in pdefuse.defs
965+
check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip
966+
end
967+
end
918968
end
919969
# Check that the defexpr has defined values for all the fields
920970
# we're accessing. In the future, we may want to relax this,
@@ -925,7 +975,13 @@ function sroa_mutables!(ir::IRCode,
925975
for fidx in 1:ndefuse
926976
du = fielddefuse[fidx]
927977
isempty(du.uses) && continue
928-
push!(du.defs, newidx)
978+
if phidefs === nothing
979+
push!(du.defs, idx)
980+
else
981+
for pidx in phidefs
982+
push!(du.defs, pidx)
983+
end
984+
end
929985
ldu = compute_live_ins(ir.cfg, du)
930986
if isempty(ldu.live_in_bbs)
931987
phiblocks = Int[]
@@ -935,10 +991,24 @@ function sroa_mutables!(ir::IRCode,
935991
end
936992
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
937993
blocks[fidx] = phiblocks, allblocks
938-
if fidx + 1 > length(defexpr.args)
939-
for use in du.uses
994+
if phidefs !== nothing
995+
# check if all predecessors have safe definitions
996+
for pidx in phidefs
997+
newexpr = ir[SSAValue(pidx)]::Expr # == new(...)
998+
if fidx + 1 > length(newexpr.args) # this field can be undefined
999+
domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
1000+
for use in du.uses
1001+
has_safe_def_for_undef_field(ir, domtree, allblocks, du, pidx, use) || @goto skip
1002+
end
1003+
end
1004+
end
1005+
else
1006+
newexpr = defexpr::Expr # == new(...)
1007+
if fidx + 1 > length(newexpr.args) # this field can be undefined
9401008
domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
941-
has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip
1009+
for use in du.uses
1010+
has_safe_def_for_undef_field(ir, domtree, allblocks, du, idx, use) || @goto skip
1011+
end
9421012
end
9431013
end
9441014
end
@@ -983,28 +1053,79 @@ function sroa_mutables!(ir::IRCode,
9831053
end
9841054
end
9851055
end
986-
for stmt in du.defs
987-
stmt == newidx && continue
988-
ir[SSAValue(stmt)] = nothing
1056+
eliminable_defs === nothing && (eliminable_defs = SPCSet())
1057+
for def in du.defs
1058+
push!(eliminable_defs, def)
1059+
end
1060+
if phidefs !== nothing
1061+
# record the ϕ-node itself as eliminable here, since we didn't include it in `du.defs`
1062+
# we also modify the usage counts of its predecessors so that their SROA may work
1063+
# in a succeeding iteration
1064+
push!(eliminable_defs, idx)
1065+
for pidx in phidefs
1066+
used_ssas[pidx] -= 1
1067+
end
9891068
end
9901069
end
9911070
preserve_uses === nothing && continue
9921071
if all_forwarded
9931072
# this means all ccall preserves have been replaced with forwarded loads
9941073
# so we can potentially eliminate the allocation, otherwise we must preserve
9951074
# the whole allocation.
996-
push!(intermediaries, newidx)
1075+
push!(intermediaries, idx)
9971076
end
9981077
# Insert the new preserves
9991078
for (use, new_preserves) in preserve_uses
10001079
ir[SSAValue(use)] = form_new_preserves(ir[SSAValue(use)]::Expr, intermediaries, new_preserves)
10011080
end
1002-
10031081
@label skip
10041082
end
1083+
# now eliminate "definitions" (i.e. allocations, ϕ-nodes, and `setfield!` calls)
1084+
# that should have no usage at this moment
1085+
if eliminable_defs !== nothing
1086+
for idx in eliminable_defs
1087+
ir[SSAValue(idx)] = nothing
1088+
end
1089+
end
10051090
return any_eliminated ? sroa_pass!(compact!(ir), false) : ir
10061091
end
10071092

1093+
count_leaves(defuse::SSADefUse) =
1094+
length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
1095+
1096+
function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int)
1097+
defexpr = ir[SSAValue(idx)]
1098+
if is_known_call(defexpr, getfield, ir)
1099+
val = defexpr.args[2]
1100+
if isa(val, SSAValue)
1101+
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
1102+
if ismutabletype(struct_typ)
1103+
record_nested_load!(nested_mloads, idx)
1104+
end
1105+
end
1106+
end
1107+
end
1108+
1109+
function check_use_eliminability!(fielddefuse::Vector{SSADefUse},
1110+
ir::IRCode, useidx::Int, struct_typ::DataType)
1111+
stmt = ir[SSAValue(useidx)] # == `getfield` call
1112+
stmt === nothing && return nothing
1113+
field = try_compute_fieldidx_stmt(ir, stmt::Expr, struct_typ)
1114+
field === nothing && return false
1115+
push!(fielddefuse[field].uses, useidx)
1116+
return true
1117+
end
1118+
1119+
function check_def_eliminability!(fielddefuse::Vector{SSADefUse},
1120+
ir::IRCode, defidx::Int, struct_typ::DataType)
1121+
stmt = ir[SSAValue(defidx)]::Expr # == `setfield!` call
1122+
field = try_compute_fieldidx_stmt(ir, stmt, struct_typ)
1123+
field === nothing && return false
1124+
isconst(struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error
1125+
push!(fielddefuse[field].defs, defidx)
1126+
return true
1127+
end
1128+
10081129
function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
10091130
newex = Expr(:foreigncall)
10101131
nccallargs = length(origex.args[3]::SimpleVector)

0 commit comments

Comments
 (0)