Skip to content

Commit 1b8be01

Browse files
committed
optimizer: enable SROA of mutable φ-nodes
This commit allows elimination of mutable φ-node (and its predecessor mutables allocations). As an contrived example, it allows this `mutable_ϕ_elim(::String, ::Vector{String})` to run without any allocations at all: ```julia function mutable_ϕ_elim(x, xs) r = Ref(x) for x in xs r = Ref(x) end return r[] end let xs = String[string(gensym()) for _ in 1:100] mutable_ϕ_elim("init", xs) @test @allocated(mutable_ϕ_elim("init", xs)) == 0 end ``` This mutable ϕ-node elimination is still limited though. Most notably, the current implementation doesn't work if a mutable allocation forms multiple ϕ-nodes, since we check allocation eliminability (i.e. escapability) by counting usages counts and thus it's hard to reason about multiple ϕ-nodes at a time. For example, currently mutable allocations involved in cases like below will still not be eliminated: ```julia code_typed((Bool,String,String),) do cond, x, y if cond ϕ2 = ϕ1 = Ref(x) else ϕ2 = ϕ1 = Ref(y) end ϕ1[], ϕ2[] end \# more realistic example mutable struct Point{T} x::T y::T end add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y) function compute(a::Point{ComplexF64}, b::Point{ComplexF64}) for i in 0:(100000000-1) a = add(add(a, b), b) end a.x, a.y end ``` I'd say this limitation should be addressed by first introducing a better abstraction for reasoning escape information. More specifically, I'd like introduce EscapeAnalysis.jl into Julia base first, and then gradually adapt it to improve our SROA pass, since EA will allow us to reason about all escape information imposed on whatever object more easily and should help us get rid of the complexities of our current SROA implementation. For now, I'd like to get in this enhancement even though it has the limitation elaborated above, as far as this commit doesn't introduce latency problem (which is unlikely).
1 parent 738df81 commit 1b8be01

File tree

2 files changed

+348
-53
lines changed

2 files changed

+348
-53
lines changed

base/compiler/ssair/passes.jl

+165-52
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,22 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
100100
end
101101
end
102102

103-
# even when the allocation contains an uninitialized field, we try an extra effort to check
104-
# if this load at `idx` have any "safe" `setfield!` calls that define the field
105103
function has_safe_def(
104+
ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse,
105+
newidx::Int, fidx::Int)
106+
newexpr = ir[SSAValue(newidx)]::Expr
107+
108+
fidx + 1 length(newexpr.args) && return true # assured to have a safe definition for all usages
109+
110+
# even when the allocation contains an uninitialized field, we try an extra effort to
111+
# check if all loads have "safe" `setfield!` calls that define the uninitialized field
112+
for use in du.uses
113+
has_safe_def_for_uninitialized_field(ir, domtree, allblocks, du, newidx, use) || return false
114+
end
115+
return true # shuold be safe
116+
end
117+
118+
function has_safe_def_for_uninitialized_field(
106119
ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse,
107120
newidx::Int, idx::Int)
108121
def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx)
@@ -207,14 +220,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
207220
end
208221

209222
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
210-
@nospecialize(typeconstraint))
211-
callback = function (@nospecialize(pi), @nospecialize(idx))
212-
if isa(pi, PiNode)
213-
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
223+
@nospecialize(typeconstraint), @nospecialize(callback = nothing))
224+
newcallback = function (@nospecialize(x), @nospecialize(idx))
225+
if isa(x, PiNode)
226+
typeconstraint = typeintersect(typeconstraint, widenconst(x.typ))
214227
end
228+
callback === nothing || callback(x, idx)
215229
return false
216230
end
217-
def = simple_walk(compact, defssa, callback)
231+
def = simple_walk(compact, defssa, newcallback)
218232
return Pair{Any, Any}(def, typeconstraint)
219233
end
220234

@@ -224,7 +238,9 @@ end
224238
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
225239
(pruning those leaves rules out by path conditions).
226240
"""
227-
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
241+
function walk_to_defs(compact::IncrementalCompact,
242+
@nospecialize(defssa), @nospecialize(typeconstraint),
243+
@nospecialize(callback = nothing))
228244
visited_phinodes = AnySSAValue[]
229245
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
230246
def = compact[defssa]
@@ -260,7 +276,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
260276
val = OldSSAValue(val.id)
261277
end
262278
if isa(val, AnySSAValue)
263-
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint)
279+
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback)
264280
if isa(new_def, AnySSAValue)
265281
if !haskey(visited_constraints, new_def)
266282
push!(worklist_defs, new_def)
@@ -721,10 +737,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
721737
continue
722738
end
723739
if defuses === nothing
724-
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
740+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
725741
end
726-
mid, defuse = get!(defuses, defidx) do
727-
SPCSet(), SSADefUse()
742+
mid, defuse, phidefs = get!(defuses, defidx) do
743+
SPCSet(), SSADefUse(), PhiDefs(nothing)
728744
end
729745
push!(defuse.ccall_preserve_uses, idx)
730746
union!(mid, intermediaries)
@@ -779,16 +795,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
779795
# Mutable stuff here
780796
isa(def, SSAValue) || continue
781797
if defuses === nothing
782-
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
798+
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
783799
end
784-
mid, defuse = get!(defuses, def.id) do
785-
SPCSet(), SSADefUse()
800+
mid, defuse, phidefs = get!(defuses, def.id) do
801+
SPCSet(), SSADefUse(), PhiDefs(nothing)
786802
end
787803
if is_setfield
788804
push!(defuse.defs, idx)
789805
else
790806
push!(defuse.uses, idx)
791807
end
808+
defval = compact[def]
809+
if isa(defval, PhiNode)
810+
phicallback = function (@nospecialize(x), @nospecialize(ssa))
811+
push!(intermediaries, ssa.id)
812+
return false
813+
end
814+
defs, _ = walk_to_defs(compact, def, struct_typ, phicallback)
815+
if _any(@nospecialize(d)->!isa(d, SSAValue), defs)
816+
delete!(defuses, def.id)
817+
continue
818+
end
819+
phidefs[] = Int[(def::SSAValue).id for def in defs]
820+
end
792821
union!(mid, intermediaries)
793822
end
794823
continue
@@ -848,8 +877,14 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
848877
end
849878
end
850879

880+
# TODO:
881+
# - run mutable SROA on the same IR as when we collect information about mutable allocations
882+
# - simplify and improve the eliminability check below using an escape analysis
883+
884+
const PhiDefs = RefValue{Union{Nothing,Vector{Int}}}
885+
851886
function sroa_mutables!(ir::IRCode,
852-
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
887+
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int},
853888
nested_loads::NestedLoads)
854889
# Compute domtree, needed below, now that we have finished compacting the IR.
855890
# This needs to be after we iterate through the IR with `IncrementalCompact`
@@ -859,36 +894,58 @@ function sroa_mutables!(ir::IRCode,
859894
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
860895
local any_eliminated = false
861896
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
862-
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
897+
for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true)
863898
intermediaries = collect(intermediaries)
899+
phidefs = phidefs[]
864900
# Check if there are any uses we did not account for. If so, the variable
865901
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
866902
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
867903
# show up in the nuses_total count.
868-
nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
904+
nleaves = count_leaves(defuse)
905+
if phidefs !== nothing
906+
# if this defines ϕ, we also track leaves of all predecessors as well
907+
# FIXME this doesn't work when any predecessor is used by another ϕ-node
908+
for pidx in phidefs
909+
haskey(defuses, pidx) || continue
910+
pdefuse = defuses[pidx][2]
911+
nleaves += count_leaves(pdefuse)
912+
end
913+
end
869914
nuses = 0
870915
for idx in intermediaries
871916
nuses += used_ssas[idx]
872917
end
873-
nuses_total = used_ssas[idx] + nuses - length(intermediaries)
918+
nuses -= length(intermediaries)
919+
nuses_total = used_ssas[idx] + nuses
920+
if phidefs !== nothing
921+
for pidx in phidefs
922+
# NOTE we don't need to accout for intermediates for this predecessor here,
923+
# since they are already included in intermediates of this ϕ-node
924+
# FIXME this doesn't work when any predecessor is used by another ϕ-node
925+
nuses_total += used_ssas[pidx] - 1 # substract usage count from ϕ-node itself
926+
end
927+
end
874928
nleaves == nuses_total || continue
875929
# Find the type for this allocation
876930
defexpr = ir[SSAValue(idx)]
877-
isa(defexpr, Expr) || continue
878-
if !isexpr(defexpr, :new)
879-
if is_known_call(defexpr, getfield, ir)
880-
val = defexpr.args[2]
881-
if isa(val, SSAValue)
882-
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
883-
if ismutabletype(struct_typ)
884-
record_nested_load!(nested_mloads, idx)
885-
end
931+
if isa(defexpr, Expr)
932+
if !isexpr(defexpr, :new)
933+
maybe_record_nested_load!(nested_mloads, ir, idx)
934+
continue
935+
end
936+
elseif isa(defexpr, PhiNode)
937+
phidefs === nothing && continue
938+
for pidx in phidefs
939+
pexpr = ir[SSAValue(pidx)]
940+
if !isexpr(pexpr, :new)
941+
maybe_record_nested_load!(nested_mloads, ir, pidx)
942+
@goto skip
886943
end
887944
end
945+
else
888946
continue
889947
end
890-
newidx = idx
891-
typ = ir.stmts[newidx][:type]
948+
typ = ir.stmts[idx][:type]
892949
if isa(typ, UnionAll)
893950
typ = unwrap_unionall(typ)
894951
end
@@ -900,25 +957,29 @@ function sroa_mutables!(ir::IRCode,
900957
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
901958
all_forwarded = true
902959
for use in defuse.uses
903-
stmt = ir[SSAValue(use)] # == `getfield` call
904-
# We may have discovered above that this use is dead
905-
# after the getfield elim of immutables. In that case,
906-
# it would have been deleted. That's fine, just ignore
907-
# the use in that case.
908-
if stmt === nothing
960+
eliminable = check_use_eliminability!(fielddefuse, ir, use, typ)
961+
if eliminable === nothing
962+
# We may have discovered above that this use is dead
963+
# after the getfield elim of immutables. In that case,
964+
# it would have been deleted. That's fine, just ignore
965+
# the use in that case.
909966
all_forwarded = false
910967
continue
968+
elseif !eliminable
969+
@goto skip
911970
end
912-
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
913-
field === nothing && @goto skip
914-
push!(fielddefuse[field].uses, use)
915971
end
916972
for def in defuse.defs
917-
stmt = ir[SSAValue(def)]::Expr # == `setfield!` call
918-
field = try_compute_fieldidx_stmt(ir, stmt, typ)
919-
field === nothing && @goto skip
920-
isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error
921-
push!(fielddefuse[field].defs, def)
973+
check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip
974+
end
975+
if phidefs !== nothing
976+
for pidx in phidefs
977+
haskey(defuses, pidx) || continue
978+
pdefuse = defuses[pidx][2]
979+
for pdef in pdefuse.defs
980+
check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip
981+
end
982+
end
922983
end
923984
# Check that the defexpr has defined values for all the fields
924985
# we're accessing. In the future, we may want to relax this,
@@ -929,15 +990,24 @@ function sroa_mutables!(ir::IRCode,
929990
for fidx in 1:ndefuse
930991
du = fielddefuse[fidx]
931992
isempty(du.uses) && continue
932-
push!(du.defs, newidx)
993+
if phidefs === nothing
994+
push!(du.defs, idx)
995+
else
996+
for pidx in phidefs
997+
push!(du.defs, pidx)
998+
end
999+
end
9331000
ldu = compute_live_ins(ir.cfg, du)
9341001
phiblocks = isempty(ldu.live_in_bbs) ? Int[] : iterated_dominance_frontier(ir.cfg, ldu, domtree)
9351002
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
9361003
blocks[fidx] = phiblocks, allblocks
937-
if fidx + 1 > length(defexpr.args)
938-
for use in du.uses
939-
has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip
1004+
if phidefs !== nothing
1005+
# check if all predecessors have safe definitions
1006+
for pidx in phidefs
1007+
has_safe_def(ir, domtree, allblocks, du, pidx, fidx) || @goto skip
9401008
end
1009+
else
1010+
has_safe_def(ir, domtree, allblocks, du, idx, fidx) || @goto skip
9411011
end
9421012
end
9431013
# Everything accounted for. Go field by field and perform idf
@@ -977,17 +1047,24 @@ function sroa_mutables!(ir::IRCode,
9771047
end
9781048
end
9791049
end
980-
for stmt in du.defs
981-
stmt == newidx && continue
982-
ir[SSAValue(stmt)] = nothing
1050+
if isa(defexpr, PhiNode)
1051+
ir[SSAValue(idx)] = nothing
1052+
for pidx in phidefs::Vector{Int}
1053+
used_ssas[pidx] -= 1
1054+
end
1055+
else
1056+
for stmt in du.defs
1057+
stmt == idx && continue
1058+
ir[SSAValue(stmt)] = nothing
1059+
end
9831060
end
9841061
end
9851062
preserve_uses === nothing && continue
9861063
if all_forwarded
9871064
# this means all ccall preserves have been replaced with forwarded loads
9881065
# so we can potentially eliminate the allocation, otherwise we must preserve
9891066
# the whole allocation.
990-
push!(intermediaries, newidx)
1067+
push!(intermediaries, idx)
9911068
end
9921069
# Insert the new preserves
9931070
for (use, new_preserves) in preserve_uses
@@ -1003,6 +1080,42 @@ function sroa_mutables!(ir::IRCode,
10031080
end
10041081
end
10051082

1083+
count_leaves(defuse::SSADefUse) =
1084+
length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
1085+
1086+
function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int)
1087+
defexpr = ir[SSAValue(idx)]
1088+
if is_known_call(defexpr, getfield, ir)
1089+
val = defexpr.args[2]
1090+
if isa(val, SSAValue)
1091+
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
1092+
if ismutabletype(struct_typ)
1093+
record_nested_load!(nested_mloads, idx)
1094+
end
1095+
end
1096+
end
1097+
end
1098+
1099+
function check_use_eliminability!(fielddefuse::Vector{SSADefUse},
1100+
ir::IRCode, useidx::Int, struct_typ::DataType)
1101+
stmt = ir[SSAValue(useidx)] # == `getfield` call
1102+
stmt === nothing && return nothing
1103+
field = try_compute_fieldidx_stmt(ir, stmt::Expr, struct_typ)
1104+
field === nothing && return false
1105+
push!(fielddefuse[field].uses, useidx)
1106+
return true
1107+
end
1108+
1109+
function check_def_eliminability!(fielddefuse::Vector{SSADefUse},
1110+
ir::IRCode, defidx::Int, struct_typ::DataType)
1111+
stmt = ir[SSAValue(defidx)]::Expr # == `setfield!` call
1112+
field = try_compute_fieldidx_stmt(ir, stmt, struct_typ)
1113+
field === nothing && return false
1114+
isconst(struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error
1115+
push!(fielddefuse[field].defs, defidx)
1116+
return true
1117+
end
1118+
10061119
function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
10071120
newex = Expr(:foreigncall)
10081121
nccallargs = length(origex.args[3]::SimpleVector)

0 commit comments

Comments
 (0)