Skip to content

Commit bf97c29

Browse files
committed
optimizer: enable SROA of mutable φ-nodes
This commit allows elimination of a mutable φ-node (and its predecessor mutable allocations). As a contrived example, it allows this `mutable_ϕ_elim(::String, ::Vector{String})` to run without any allocations at all:

```julia
function mutable_ϕ_elim(x, xs)
    r = Ref(x)
    for x in xs
        r = Ref(x)
    end
    return r[]
end

let xs = String[string(gensym()) for _ in 1:100]
    mutable_ϕ_elim("init", xs)
    @test @allocated(mutable_ϕ_elim("init", xs)) == 0
end
```

This mutable ϕ-node elimination is still limited, though. Most notably, the current implementation doesn't work if a mutable allocation forms multiple ϕ-nodes, since we check allocation eliminability (i.e. escapability) by counting uses, and thus it's hard to reason about multiple ϕ-nodes at a time. For example, mutable allocations involved in cases like the ones below will currently still not be eliminated:

```julia
code_typed((Bool,String,String),) do cond, x, y
    if cond
        ϕ2 = ϕ1 = Ref(x)
    else
        ϕ2 = ϕ1 = Ref(y)
    end
    ϕ1[], ϕ2[]
end

# more realistic example
mutable struct Point{T}
    x::T
    y::T
end
add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y)
function compute(a::Point{ComplexF64}, b::Point{ComplexF64})
    for i in 0:(100000000-1)
        a = add(add(a, b), b)
    end
    a.x, a.y
end
```

I'd say this limitation should be addressed by first introducing a better abstraction for reasoning about escape information. More specifically, I'd like to introduce EscapeAnalysis.jl into Julia base first, and then gradually adapt it to improve our SROA pass, since EA will allow us to reason about all escape information imposed on whatever object more easily and should help us get rid of the complexities of our current SROA implementation. For now, I'd like to get this enhancement in even though it has the limitation elaborated above, as long as this commit doesn't introduce a latency problem (which is unlikely).
1 parent a9c6daf commit bf97c29

File tree

2 files changed

+377
-55
lines changed

2 files changed

+377
-55
lines changed

base/compiler/ssair/passes.jl

+175-54
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
101101
end
102102
end
103103

104-
# even when the allocation contains an uninitialized field, we try an extra effort to check
105-
# if this load at `idx` have any "safe" `setfield!` calls that define the field
106-
function has_safe_def(
104+
# even when the allocation contains an uninitialized field, we try an extra effort to
105+
# check if all loads have "safe" `setfield!` calls that define the uninitialized field
106+
function has_safe_def_for_undef_field(
107107
ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse,
108108
newidx::Int, idx::Int)
109109
def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx)
@@ -208,14 +208,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
208208
end
209209

210210
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
211-
@nospecialize(typeconstraint))
212-
callback = function (@nospecialize(pi), @nospecialize(idx))
213-
if isa(pi, PiNode)
214-
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
211+
@nospecialize(typeconstraint), @nospecialize(callback = nothing))
212+
newcallback = function (@nospecialize(x), @nospecialize(idx))
213+
if isa(x, PiNode)
214+
typeconstraint = typeintersect(typeconstraint, widenconst(x.typ))
215215
end
216+
callback === nothing || callback(x, idx)
216217
return false
217218
end
218-
def = simple_walk(compact, defssa, callback)
219+
def = simple_walk(compact, defssa, newcallback)
219220
return Pair{Any, Any}(def, typeconstraint)
220221
end
221222

@@ -225,7 +226,9 @@ end
225226
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
226227
(pruning those leaves ruled out by path conditions).
227228
"""
228-
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
229+
function walk_to_defs(compact::IncrementalCompact,
230+
@nospecialize(defssa), @nospecialize(typeconstraint),
231+
@nospecialize(callback = nothing))
229232
visited_phinodes = AnySSAValue[]
230233
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
231234
def = compact[defssa]
@@ -261,7 +264,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
261264
val = OldSSAValue(val.id)
262265
end
263266
if isa(val, AnySSAValue)
264-
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint)
267+
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback)
265268
if isa(new_def, AnySSAValue)
266269
if !haskey(visited_constraints, new_def)
267270
push!(worklist_defs, new_def)
@@ -722,10 +725,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
722725
continue
723726
end
724727
if defuses === nothing
725-
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
728+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
726729
end
727-
mid, defuse = get!(defuses, defidx) do
728-
SPCSet(), SSADefUse()
730+
mid, defuse, phidefs = get!(defuses, defidx) do
731+
SPCSet(), SSADefUse(), PhiDefs(nothing)
729732
end
730733
push!(defuse.ccall_preserve_uses, idx)
731734
union!(mid, intermediaries)
@@ -780,16 +783,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
780783
# Mutable stuff here
781784
isa(def, SSAValue) || continue
782785
if defuses === nothing
783-
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
786+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
784787
end
785-
mid, defuse = get!(defuses, def.id) do
786-
SPCSet(), SSADefUse()
788+
mid, defuse, phidefs = get!(defuses, def.id) do
789+
SPCSet(), SSADefUse(), PhiDefs(nothing)
787790
end
788791
if is_setfield
789792
push!(defuse.defs, idx)
790793
else
791794
push!(defuse.uses, idx)
792795
end
796+
defval = compact[def]
797+
if isa(defval, PhiNode)
798+
phicallback = function (@nospecialize(x), @nospecialize(ssa))
799+
push!(intermediaries, ssa.id)
800+
return false
801+
end
802+
defs, _ = walk_to_defs(compact, def, struct_typ, phicallback)
803+
if _any(@nospecialize(d)->!isa(d, SSAValue), defs)
804+
delete!(defuses, def.id)
805+
continue
806+
end
807+
phidefs[] = Int[(def::SSAValue).id for def in defs]
808+
end
793809
union!(mid, intermediaries)
794810
end
795811
continue
@@ -849,43 +865,73 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
849865
end
850866
end
851867

868+
# TODO:
869+
# - run mutable SROA on the same IR as when we collect information about mutable allocations
870+
# - simplify and improve the eliminability check below using an escape analysis
871+
872+
const PhiDefs = RefValue{Union{Nothing,Vector{Int}}}
873+
852874
function sroa_mutables!(ir::IRCode,
853-
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
875+
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int},
854876
nested_loads::NestedLoads)
855877
domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
856878
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
857879
any_eliminated = false
880+
eliminable_defs = nothing # tracks eliminable "definitions" if initialized
858881
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
859-
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
882+
for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true)
860883
intermediaries = collect(intermediaries)
884+
phidefs = phidefs[]
861885
# Check if there are any uses we did not account for. If so, the variable
862886
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
863887
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
864888
# show up in the nuses_total count.
865-
nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
889+
nleaves = count_leaves(defuse)
890+
if phidefs !== nothing
891+
# if this defines ϕ, we also track leaves of all predecessors as well
892+
# FIXME this doesn't work when any predecessor is used by another ϕ-node
893+
for pidx in phidefs
894+
haskey(defuses, pidx) || continue
895+
pdefuse = defuses[pidx][2]
896+
nleaves += count_leaves(pdefuse)
897+
end
898+
end
866899
nuses = 0
867900
for idx in intermediaries
868901
nuses += used_ssas[idx]
869902
end
870-
nuses_total = used_ssas[idx] + nuses - length(intermediaries)
903+
nuses -= length(intermediaries)
904+
nuses_total = used_ssas[idx] + nuses
905+
if phidefs !== nothing
906+
for pidx in phidefs
907+
# NOTE we don't need to accout for intermediates for this predecessor here,
908+
# since they are already included in intermediates of this ϕ-node
909+
# FIXME this doesn't work when any predecessor is used by another ϕ-node
910+
nuses_total += used_ssas[pidx] - 1 # substract usage count from ϕ-node itself
911+
end
912+
end
871913
nleaves == nuses_total || continue
872914
# Find the type for this allocation
873915
defexpr = ir[SSAValue(idx)]
874-
isa(defexpr, Expr) || continue
875-
if !isexpr(defexpr, :new)
876-
if is_known_call(defexpr, getfield, ir)
877-
val = defexpr.args[2]
878-
if isa(val, SSAValue)
879-
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
880-
if ismutabletype(struct_typ)
881-
record_nested_load!(nested_mloads, idx)
882-
end
916+
if isa(defexpr, Expr)
917+
@assert phidefs === nothing
918+
if !isexpr(defexpr, :new)
919+
maybe_record_nested_load!(nested_mloads, ir, idx)
920+
continue
921+
end
922+
elseif isa(defexpr, PhiNode)
923+
phidefs === nothing && continue
924+
for pidx in phidefs
925+
pexpr = ir[SSAValue(pidx)]
926+
if !isexpr(pexpr, :new)
927+
maybe_record_nested_load!(nested_mloads, ir, pidx)
928+
@goto skip
883929
end
884930
end
931+
else
885932
continue
886933
end
887-
newidx = idx
888-
typ = ir.stmts[newidx][:type]
934+
typ = ir.stmts[idx][:type]
889935
if isa(typ, UnionAll)
890936
typ = unwrap_unionall(typ)
891937
end
@@ -897,25 +943,29 @@ function sroa_mutables!(ir::IRCode,
897943
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
898944
all_forwarded = true
899945
for use in defuse.uses
900-
stmt = ir[SSAValue(use)] # == `getfield` call
901-
# We may have discovered above that this use is dead
902-
# after the getfield elim of immutables. In that case,
903-
# it would have been deleted. That's fine, just ignore
904-
# the use in that case.
905-
if stmt === nothing
946+
eliminable = check_use_eliminability!(fielddefuse, ir, use, typ)
947+
if eliminable === nothing
948+
# We may have discovered above that this use is dead
949+
# after the getfield elim of immutables. In that case,
950+
# it would have been deleted. That's fine, just ignore
951+
# the use in that case.
906952
all_forwarded = false
907953
continue
954+
elseif !eliminable
955+
@goto skip
908956
end
909-
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
910-
field === nothing && @goto skip
911-
push!(fielddefuse[field].uses, use)
912957
end
913958
for def in defuse.defs
914-
stmt = ir[SSAValue(def)]::Expr # == `setfield!` call
915-
field = try_compute_fieldidx_stmt(ir, stmt, typ)
916-
field === nothing && @goto skip
917-
isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error
918-
push!(fielddefuse[field].defs, def)
959+
check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip
960+
end
961+
if phidefs !== nothing
962+
for pidx in phidefs
963+
haskey(defuses, pidx) || continue
964+
pdefuse = defuses[pidx][2]
965+
for pdef in pdefuse.defs
966+
check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip
967+
end
968+
end
919969
end
920970
# Check that the defexpr has defined values for all the fields
921971
# we're accessing. In the future, we may want to relax this,
@@ -926,7 +976,13 @@ function sroa_mutables!(ir::IRCode,
926976
for fidx in 1:ndefuse
927977
du = fielddefuse[fidx]
928978
isempty(du.uses) && continue
929-
push!(du.defs, newidx)
979+
if phidefs === nothing
980+
push!(du.defs, idx)
981+
else
982+
for pidx in phidefs
983+
push!(du.defs, pidx)
984+
end
985+
end
930986
ldu = compute_live_ins(ir.cfg, du)
931987
if isempty(ldu.live_in_bbs)
932988
phiblocks = Int[]
@@ -936,10 +992,24 @@ function sroa_mutables!(ir::IRCode,
936992
end
937993
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
938994
blocks[fidx] = phiblocks, allblocks
939-
if fidx + 1 > length(defexpr.args)
940-
for use in du.uses
995+
if phidefs !== nothing
996+
# check if all predecessors have safe definitions
997+
for pidx in phidefs
998+
newexpr = ir[SSAValue(pidx)]::Expr # == new(...)
999+
if fidx + 1 > length(newexpr.args) # this field can be undefined
1000+
domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
1001+
for use in du.uses
1002+
has_safe_def_for_undef_field(ir, domtree, allblocks, du, pidx, use) || @goto skip
1003+
end
1004+
end
1005+
end
1006+
else
1007+
newexpr = defexpr::Expr # == new(...)
1008+
if fidx + 1 > length(newexpr.args) # this field can be undefined
9411009
domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
942-
has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip
1010+
for use in du.uses
1011+
has_safe_def_for_undef_field(ir, domtree, allblocks, du, idx, use) || @goto skip
1012+
end
9431013
end
9441014
end
9451015
end
@@ -984,28 +1054,79 @@ function sroa_mutables!(ir::IRCode,
9841054
end
9851055
end
9861056
end
987-
for stmt in du.defs
988-
stmt == newidx && continue
989-
ir[SSAValue(stmt)] = nothing
1057+
eliminable_defs === nothing && (eliminable_defs = SPCSet())
1058+
for def in du.defs
1059+
push!(eliminable_defs, def)
1060+
end
1061+
if phidefs !== nothing
1062+
# record ϕ-node itself eliminable here, since we didn't include it in `du.defs`
1063+
# we also modify usage counts of its predecessors so that their SROA may work
1064+
# in succeeding iteration
1065+
push!(eliminable_defs, idx)
1066+
for pidx in phidefs
1067+
used_ssas[pidx] -= 1
1068+
end
9901069
end
9911070
end
9921071
preserve_uses === nothing && continue
9931072
if all_forwarded
9941073
# this means all ccall preserves have been replaced with forwarded loads
9951074
# so we can potentially eliminate the allocation, otherwise we must preserve
9961075
# the whole allocation.
997-
push!(intermediaries, newidx)
1076+
push!(intermediaries, idx)
9981077
end
9991078
# Insert the new preserves
10001079
for (use, new_preserves) in preserve_uses
10011080
ir[SSAValue(use)] = form_new_preserves(ir[SSAValue(use)]::Expr, intermediaries, new_preserves)
10021081
end
1003-
10041082
@label skip
10051083
end
1084+
# now eliminate "definitions" (i.e. allocations, ϕ-nodes, and `setfield!` calls)
1085+
# that should have no usage at this moment
1086+
if eliminable_defs !== nothing
1087+
for idx in eliminable_defs
1088+
ir[SSAValue(idx)] = nothing
1089+
end
1090+
end
10061091
return any_eliminated ? sroa_pass!(compact!(ir), false) : ir
10071092
end
10081093

1094+
count_leaves(defuse::SSADefUse) =
1095+
length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
1096+
1097+
function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int)
1098+
defexpr = ir[SSAValue(idx)]
1099+
if is_known_call(defexpr, getfield, ir)
1100+
val = defexpr.args[2]
1101+
if isa(val, SSAValue)
1102+
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
1103+
if ismutabletype(struct_typ)
1104+
record_nested_load!(nested_mloads, idx)
1105+
end
1106+
end
1107+
end
1108+
end
1109+
1110+
function check_use_eliminability!(fielddefuse::Vector{SSADefUse},
1111+
ir::IRCode, useidx::Int, struct_typ::DataType)
1112+
stmt = ir[SSAValue(useidx)] # == `getfield` call
1113+
stmt === nothing && return nothing
1114+
field = try_compute_fieldidx_stmt(ir, stmt::Expr, struct_typ)
1115+
field === nothing && return false
1116+
push!(fielddefuse[field].uses, useidx)
1117+
return true
1118+
end
1119+
1120+
function check_def_eliminability!(fielddefuse::Vector{SSADefUse},
1121+
ir::IRCode, defidx::Int, struct_typ::DataType)
1122+
stmt = ir[SSAValue(defidx)]::Expr # == `setfield!` call
1123+
field = try_compute_fieldidx_stmt(ir, stmt, struct_typ)
1124+
field === nothing && return false
1125+
isconst(struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error
1126+
push!(fielddefuse[field].defs, defidx)
1127+
return true
1128+
end
1129+
10091130
function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
10101131
newex = Expr(:foreigncall)
10111132
nccallargs = length(origex.args[3]::SimpleVector)

0 commit comments

Comments
 (0)