Skip to content

Commit 73fbd3d

Browse files
committed
optimizer: enable SROA of mutable φ-nodes
This commit allows elimination of mutable φ-nodes (and their predecessor mutable allocations). As a contrived example, it allows this `mutable_ϕ_elim(::String, ::Vector{String})` to run without any allocations at all: ```julia function mutable_ϕ_elim(x, xs) r = Ref(x) for x in xs r = Ref(x) end return r[] end let xs = String[string(gensym()) for _ in 1:100] mutable_ϕ_elim("init", xs) @test @allocated(mutable_ϕ_elim("init", xs)) == 0 end ``` This mutable ϕ-node elimination is still limited though. Most notably, the current implementation doesn't work if a mutable allocation forms multiple ϕ-nodes, since we check allocation eliminability (i.e. escapability) by counting usage counts and thus it's hard to reason about multiple ϕ-nodes at a time. For example, currently mutable allocations involved in cases like the ones below will still not be eliminated: ```julia code_typed((Bool,String,String),) do cond, x, y if cond ϕ2 = ϕ1 = Ref(x) else ϕ2 = ϕ1 = Ref(y) end ϕ1[], ϕ2[] end \# more realistic example mutable struct Point{T} x::T y::T end add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y) function compute(a::Point{ComplexF64}, b::Point{ComplexF64}) for i in 0:(100000000-1) a = add(add(a, b), b) end a.x, a.y end ``` I'd say this limitation should be addressed by first introducing a better abstraction for reasoning about escape information. More specifically, I'd like to introduce EscapeAnalysis.jl into Julia base first, and then gradually adapt it to improve our SROA pass, since EA will allow us to reason about all escape information imposed on whatever object more easily and should help us get rid of the complexities of our current SROA implementation. For now, I'd like to get in this enhancement even though it has the limitation elaborated above, as long as this commit doesn't introduce a latency problem (which is unlikely).
1 parent dec65e1 commit 73fbd3d

File tree

2 files changed

+348
-53
lines changed

2 files changed

+348
-53
lines changed

base/compiler/ssair/passes.jl

+165-52
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,22 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
9999
end
100100
end
101101

102-
# even when the allocation contains an uninitialized field, we try an extra effort to check
103-
# if this load at `idx` have any "safe" `setfield!` calls that define the field
104102
function has_safe_def(
103+
ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse,
104+
newidx::Int, fidx::Int)
105+
newexpr = ir[SSAValue(newidx)]::Expr
106+
107+
fidx + 1 <= length(newexpr.args) && return true # assured to have a safe definition for all usages
108+
109+
# even when the allocation contains an uninitialized field, we make an extra effort to
110+
# check if all loads have "safe" `setfield!` calls that define the uninitialized field
111+
for use in du.uses
112+
has_safe_def_for_uninitialized_field(ir, domtree, allblocks, du, newidx, use) || return false
113+
end
114+
return true # should be safe
115+
end
116+
117+
function has_safe_def_for_uninitialized_field(
105118
ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse,
106119
newidx::Int, idx::Int)
107120
def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx)
@@ -206,14 +219,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
206219
end
207220

208221
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
209-
@nospecialize(typeconstraint))
210-
callback = function (@nospecialize(pi), @nospecialize(idx))
211-
if isa(pi, PiNode)
212-
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
222+
@nospecialize(typeconstraint), @nospecialize(callback = nothing))
223+
newcallback = function (@nospecialize(x), @nospecialize(idx))
224+
if isa(x, PiNode)
225+
typeconstraint = typeintersect(typeconstraint, widenconst(x.typ))
213226
end
227+
callback === nothing || callback(x, idx)
214228
return false
215229
end
216-
def = simple_walk(compact, defssa, callback)
230+
def = simple_walk(compact, defssa, newcallback)
217231
return Pair{Any, Any}(def, typeconstraint)
218232
end
219233

@@ -223,7 +237,9 @@ end
223237
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
224238
(pruning those leaves ruled out by path conditions).
225239
"""
226-
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
240+
function walk_to_defs(compact::IncrementalCompact,
241+
@nospecialize(defssa), @nospecialize(typeconstraint),
242+
@nospecialize(callback = nothing))
227243
visited_phinodes = AnySSAValue[]
228244
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
229245
def = compact[defssa]
@@ -259,7 +275,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
259275
val = OldSSAValue(val.id)
260276
end
261277
if isa(val, AnySSAValue)
262-
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint)
278+
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback)
263279
if isa(new_def, AnySSAValue)
264280
if !haskey(visited_constraints, new_def)
265281
push!(worklist_defs, new_def)
@@ -720,10 +736,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
720736
continue
721737
end
722738
if defuses === nothing
723-
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
739+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
724740
end
725-
mid, defuse = get!(defuses, defidx) do
726-
SPCSet(), SSADefUse()
741+
mid, defuse, phidefs = get!(defuses, defidx) do
742+
SPCSet(), SSADefUse(), PhiDefs(nothing)
727743
end
728744
push!(defuse.ccall_preserve_uses, idx)
729745
union!(mid, intermediaries)
@@ -778,16 +794,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
778794
# Mutable stuff here
779795
isa(def, SSAValue) || continue
780796
if defuses === nothing
781-
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
797+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
782798
end
783-
mid, defuse = get!(defuses, def.id) do
784-
SPCSet(), SSADefUse()
799+
mid, defuse, phidefs = get!(defuses, def.id) do
800+
SPCSet(), SSADefUse(), PhiDefs(nothing)
785801
end
786802
if is_setfield
787803
push!(defuse.defs, idx)
788804
else
789805
push!(defuse.uses, idx)
790806
end
807+
defval = compact[def]
808+
if isa(defval, PhiNode)
809+
phicallback = function (@nospecialize(x), @nospecialize(ssa))
810+
push!(intermediaries, ssa.id)
811+
return false
812+
end
813+
defs, _ = walk_to_defs(compact, def, struct_typ, phicallback)
814+
if _any(@nospecialize(d)->!isa(d, SSAValue), defs)
815+
delete!(defuses, def.id)
816+
continue
817+
end
818+
phidefs[] = Int[(def::SSAValue).id for def in defs]
819+
end
791820
union!(mid, intermediaries)
792821
end
793822
continue
@@ -847,8 +876,14 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
847876
end
848877
end
849878

879+
# TODO:
880+
# - run mutable SROA on the same IR as when we collect information about mutable allocations
881+
# - simplify and improve the eliminability check below using an escape analysis
882+
883+
const PhiDefs = RefValue{Union{Nothing,Vector{Int}}}
884+
850885
function sroa_mutables!(ir::IRCode,
851-
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
886+
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int},
852887
nested_loads::NestedLoads)
853888
# Compute domtree, needed below, now that we have finished compacting the IR.
854889
# This needs to be after we iterate through the IR with `IncrementalCompact`
@@ -858,36 +893,58 @@ function sroa_mutables!(ir::IRCode,
858893
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
859894
local any_eliminated = false
860895
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
861-
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
896+
for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true)
862897
intermediaries = collect(intermediaries)
898+
phidefs = phidefs[]
863899
# Check if there are any uses we did not account for. If so, the variable
864900
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
865901
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
866902
# show up in the nuses_total count.
867-
nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
903+
nleaves = count_leaves(defuse)
904+
if phidefs !== nothing
905+
# if this defines ϕ, we also track leaves of all predecessors as well
906+
# FIXME this doesn't work when any predecessor is used by another ϕ-node
907+
for pidx in phidefs
908+
haskey(defuses, pidx) || continue
909+
pdefuse = defuses[pidx][2]
910+
nleaves += count_leaves(pdefuse)
911+
end
912+
end
868913
nuses = 0
869914
for idx in intermediaries
870915
nuses += used_ssas[idx]
871916
end
872-
nuses_total = used_ssas[idx] + nuses - length(intermediaries)
917+
nuses -= length(intermediaries)
918+
nuses_total = used_ssas[idx] + nuses
919+
if phidefs !== nothing
920+
for pidx in phidefs
921+
# NOTE we don't need to account for intermediates for this predecessor here,
922+
# since they are already included in intermediates of this ϕ-node
923+
# FIXME this doesn't work when any predecessor is used by another ϕ-node
924+
nuses_total += used_ssas[pidx] - 1 # subtract usage count from ϕ-node itself
925+
end
926+
end
873927
nleaves == nuses_total || continue
874928
# Find the type for this allocation
875929
defexpr = ir[SSAValue(idx)]
876-
isa(defexpr, Expr) || continue
877-
if !isexpr(defexpr, :new)
878-
if is_known_call(defexpr, getfield, ir)
879-
val = defexpr.args[2]
880-
if isa(val, SSAValue)
881-
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
882-
if ismutabletype(struct_typ)
883-
record_nested_load!(nested_mloads, idx)
884-
end
930+
if isa(defexpr, Expr)
931+
if !isexpr(defexpr, :new)
932+
maybe_record_nested_load!(nested_mloads, ir, idx)
933+
continue
934+
end
935+
elseif isa(defexpr, PhiNode)
936+
phidefs === nothing && continue
937+
for pidx in phidefs
938+
pexpr = ir[SSAValue(pidx)]
939+
if !isexpr(pexpr, :new)
940+
maybe_record_nested_load!(nested_mloads, ir, pidx)
941+
@goto skip
885942
end
886943
end
944+
else
887945
continue
888946
end
889-
newidx = idx
890-
typ = ir.stmts[newidx][:type]
947+
typ = ir.stmts[idx][:type]
891948
if isa(typ, UnionAll)
892949
typ = unwrap_unionall(typ)
893950
end
@@ -899,25 +956,29 @@ function sroa_mutables!(ir::IRCode,
899956
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
900957
all_forwarded = true
901958
for use in defuse.uses
902-
stmt = ir[SSAValue(use)] # == `getfield` call
903-
# We may have discovered above that this use is dead
904-
# after the getfield elim of immutables. In that case,
905-
# it would have been deleted. That's fine, just ignore
906-
# the use in that case.
907-
if stmt === nothing
959+
eliminable = check_use_eliminability!(fielddefuse, ir, use, typ)
960+
if eliminable === nothing
961+
# We may have discovered above that this use is dead
962+
# after the getfield elim of immutables. In that case,
963+
# it would have been deleted. That's fine, just ignore
964+
# the use in that case.
908965
all_forwarded = false
909966
continue
967+
elseif !eliminable
968+
@goto skip
910969
end
911-
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
912-
field === nothing && @goto skip
913-
push!(fielddefuse[field].uses, use)
914970
end
915971
for def in defuse.defs
916-
stmt = ir[SSAValue(def)]::Expr # == `setfield!` call
917-
field = try_compute_fieldidx_stmt(ir, stmt, typ)
918-
field === nothing && @goto skip
919-
isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error
920-
push!(fielddefuse[field].defs, def)
972+
check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip
973+
end
974+
if phidefs !== nothing
975+
for pidx in phidefs
976+
haskey(defuses, pidx) || continue
977+
pdefuse = defuses[pidx][2]
978+
for pdef in pdefuse.defs
979+
check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip
980+
end
981+
end
921982
end
922983
# Check that the defexpr has defined values for all the fields
923984
# we're accessing. In the future, we may want to relax this,
@@ -928,15 +989,24 @@ function sroa_mutables!(ir::IRCode,
928989
for fidx in 1:ndefuse
929990
du = fielddefuse[fidx]
930991
isempty(du.uses) && continue
931-
push!(du.defs, newidx)
992+
if phidefs === nothing
993+
push!(du.defs, idx)
994+
else
995+
for pidx in phidefs
996+
push!(du.defs, pidx)
997+
end
998+
end
932999
ldu = compute_live_ins(ir.cfg, du)
9331000
phiblocks = isempty(ldu.live_in_bbs) ? Int[] : iterated_dominance_frontier(ir.cfg, ldu, domtree)
9341001
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
9351002
blocks[fidx] = phiblocks, allblocks
936-
if fidx + 1 > length(defexpr.args)
937-
for use in du.uses
938-
has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip
1003+
if phidefs !== nothing
1004+
# check if all predecessors have safe definitions
1005+
for pidx in phidefs
1006+
has_safe_def(ir, domtree, allblocks, du, pidx, fidx) || @goto skip
9391007
end
1008+
else
1009+
has_safe_def(ir, domtree, allblocks, du, idx, fidx) || @goto skip
9401010
end
9411011
end
9421012
# Everything accounted for. Go field by field and perform idf
@@ -976,17 +1046,24 @@ function sroa_mutables!(ir::IRCode,
9761046
end
9771047
end
9781048
end
979-
for stmt in du.defs
980-
stmt == newidx && continue
981-
ir[SSAValue(stmt)] = nothing
1049+
if isa(defexpr, PhiNode)
1050+
ir[SSAValue(idx)] = nothing
1051+
for pidx in phidefs::Vector{Int}
1052+
used_ssas[pidx] -= 1
1053+
end
1054+
else
1055+
for stmt in du.defs
1056+
stmt == idx && continue
1057+
ir[SSAValue(stmt)] = nothing
1058+
end
9821059
end
9831060
end
9841061
preserve_uses === nothing && continue
9851062
if all_forwarded
9861063
# this means all ccall preserves have been replaced with forwarded loads
9871064
# so we can potentially eliminate the allocation, otherwise we must preserve
9881065
# the whole allocation.
989-
push!(intermediaries, newidx)
1066+
push!(intermediaries, idx)
9901067
end
9911068
# Insert the new preserves
9921069
for (use, new_preserves) in preserve_uses
@@ -1002,6 +1079,42 @@ function sroa_mutables!(ir::IRCode,
10021079
end
10031080
end
10041081

1082+
count_leaves(defuse::SSADefUse) =
1083+
length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
1084+
1085+
function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int)
1086+
defexpr = ir[SSAValue(idx)]
1087+
if is_known_call(defexpr, getfield, ir)
1088+
val = defexpr.args[2]
1089+
if isa(val, SSAValue)
1090+
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
1091+
if ismutabletype(struct_typ)
1092+
record_nested_load!(nested_mloads, idx)
1093+
end
1094+
end
1095+
end
1096+
end
1097+
1098+
function check_use_eliminability!(fielddefuse::Vector{SSADefUse},
1099+
ir::IRCode, useidx::Int, struct_typ::DataType)
1100+
stmt = ir[SSAValue(useidx)] # == `getfield` call
1101+
stmt === nothing && return nothing
1102+
field = try_compute_fieldidx_stmt(ir, stmt::Expr, struct_typ)
1103+
field === nothing && return false
1104+
push!(fielddefuse[field].uses, useidx)
1105+
return true
1106+
end
1107+
1108+
function check_def_eliminability!(fielddefuse::Vector{SSADefUse},
1109+
ir::IRCode, defidx::Int, struct_typ::DataType)
1110+
stmt = ir[SSAValue(defidx)]::Expr # == `setfield!` call
1111+
field = try_compute_fieldidx_stmt(ir, stmt, struct_typ)
1112+
field === nothing && return false
1113+
isconst(struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error
1114+
push!(fielddefuse[field].defs, defidx)
1115+
return true
1116+
end
1117+
10051118
function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
10061119
newex = Expr(:foreigncall)
10071120
nccallargs = length(origex.args[3]::SimpleVector)

0 commit comments

Comments
 (0)