Skip to content

Commit dec65e1

Browse files
committed
optimizer: run SROA multiple times to handle more nested loads
1 parent 7cd1da3 commit dec65e1

File tree

2 files changed

+124
-39
lines changed

2 files changed

+124
-39
lines changed

base/compiler/ssair/passes.jl

+83-29
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
167167
end
168168

169169
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
170-
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
170+
callback = (@nospecialize(x), @nospecialize(idx)) -> false)
171171
while true
172172
if isa(defssa, OldSSAValue)
173173
if already_inserted(compact, defssa)
@@ -335,10 +335,29 @@ struct LiftedValue
335335
end
336336
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}
337337

338+
# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
339+
# which can be very large sometimes, and program counters in question are often very sparse
340+
const SPCSet = IdSet{Int}
341+
342+
mutable struct NestedLoads
343+
maybe::Union{Nothing,SPCSet}
344+
NestedLoads() = new(nothing)
345+
end
346+
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
347+
maybe = nested_loads.maybe
348+
maybe === nothing && (maybe = nested_loads.maybe = SPCSet())
349+
push!(maybe::SPCSet, pc)
350+
end
351+
function is_nested_load(nested_loads::NestedLoads, pc::Int)
352+
maybe = nested_loads.maybe
353+
maybe === nothing && return false
354+
return pc in maybe::SPCSet
355+
end
356+
338357
# try to compute lifted values that can replace `getfield(x, field)` call
339358
# where `x` is an immutable struct that are defined at any of `leaves`
340-
function lift_leaves(compact::IncrementalCompact,
341-
@nospecialize(result_t), field::Int, leaves::Vector{Any})
359+
function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any},
360+
@nospecialize(result_t), field::Int, nested_loads::NestedLoads)
342361
# For every leaf, the lifted value
343362
lifted_leaves = LiftedLeaves()
344363
maybe_undef = false
@@ -388,11 +407,19 @@ function lift_leaves(compact::IncrementalCompact,
388407
ocleaf = simple_walk(compact, ocleaf)
389408
end
390409
ocdef, _ = walk_to_def(compact, ocleaf)
391-
if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 field length(ocdef.args)-5
410+
if isexpr(ocdef, :new_opaque_closure) && 1 field length(ocdef.args)-5
392411
lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves)
393412
continue
394413
end
395414
return nothing
415+
elseif is_known_call(def, getfield, compact)
416+
if isa(leaf, SSAValue)
417+
struct_typ = unwrap_unionall(widenconst(argextype(def.args[2], compact)))
418+
if ismutabletype(struct_typ)
419+
record_nested_load!(nested_loads, leaf.id)
420+
end
421+
end
422+
return nothing
396423
else
397424
typ = argextype(leaf, compact)
398425
if !isa(typ, Const)
@@ -586,7 +613,7 @@ function perform_lifting!(compact::IncrementalCompact,
586613
end
587614
val = lifted_val.x
588615
if isa(val, AnySSAValue)
589-
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
616+
callback = (@nospecialize(x), @nospecialize(idx)) -> true
590617
val = simple_walk(compact, val, callback)
591618
end
592619
push!(new_node.values, val)
@@ -617,10 +644,6 @@ function perform_lifting!(compact::IncrementalCompact,
617644
return stmt_val # N.B. should never happen
618645
end
619646

620-
# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
621-
# which can be very large sometimes, and program counters in question are often very sparse
622-
const SPCSet = IdSet{Int}
623-
624647
"""
625648
sroa_pass!(ir::IRCode) -> newir::IRCode
626649
@@ -639,10 +662,11 @@ its argument).
639662
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
640663
a result of succeeding dead code elimination.
641664
"""
642-
function sroa_pass!(ir::IRCode)
665+
function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
643666
compact = IncrementalCompact(ir)
644667
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
645668
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
669+
nested_loads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
646670
for ((_, idx), stmt) in compact
647671
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
648672
isa(stmt, Expr) || continue
@@ -670,7 +694,7 @@ function sroa_pass!(ir::IRCode)
670694
preserved_arg = stmt.args[pidx]
671695
isa(preserved_arg, SSAValue) || continue
672696
let intermediaries = SPCSet()
673-
callback = function (@nospecialize(pi), @nospecialize(ssa))
697+
callback = function (@nospecialize(x), @nospecialize(ssa))
674698
push!(intermediaries, ssa.id)
675699
return false
676700
end
@@ -698,7 +722,9 @@ function sroa_pass!(ir::IRCode)
698722
if defuses === nothing
699723
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
700724
end
701-
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
725+
mid, defuse = get!(defuses, defidx) do
726+
SPCSet(), SSADefUse()
727+
end
702728
push!(defuse.ccall_preserve_uses, idx)
703729
union!(mid, intermediaries)
704730
end
@@ -708,16 +734,17 @@ function sroa_pass!(ir::IRCode)
708734
compact[idx] = form_new_preserves(stmt, preserved, new_preserves)
709735
end
710736
continue
711-
# TODO: This isn't the best place to put these
712-
elseif is_known_call(stmt, typeassert, compact)
713-
canonicalize_typeassert!(compact, idx, stmt)
714-
continue
715-
elseif is_known_call(stmt, (===), compact)
716-
lift_comparison!(compact, idx, stmt, lifting_cache)
717-
continue
718-
# elseif is_known_call(stmt, isa, compact)
719-
# TODO do a similar optimization as `lift_comparison!` for `===`
720737
else
738+
if optional_opts
739+
# TODO: This isn't the best place to put these
740+
if is_known_call(stmt, typeassert, compact)
741+
canonicalize_typeassert!(compact, idx, stmt)
742+
elseif is_known_call(stmt, (===), compact)
743+
lift_comparison!(compact, idx, stmt, lifting_cache)
744+
# elseif is_known_call(stmt, isa, compact)
745+
# TODO do a similar optimization as `lift_comparison!` for `===`
746+
end
747+
end
721748
continue
722749
end
723750

@@ -743,7 +770,7 @@ function sroa_pass!(ir::IRCode)
743770
if ismutabletype(struct_typ)
744771
isa(val, SSAValue) || continue
745772
let intermediaries = SPCSet()
746-
callback = function (@nospecialize(pi), @nospecialize(ssa))
773+
callback = function (@nospecialize(x), @nospecialize(ssa))
747774
push!(intermediaries, ssa.id)
748775
return false
749776
end
@@ -753,7 +780,9 @@ function sroa_pass!(ir::IRCode)
753780
if defuses === nothing
754781
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
755782
end
756-
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
783+
mid, defuse = get!(defuses, def.id) do
784+
SPCSet(), SSADefUse()
785+
end
757786
if is_setfield
758787
push!(defuse.defs, idx)
759788
else
@@ -775,7 +804,7 @@ function sroa_pass!(ir::IRCode)
775804
isempty(leaves) && continue
776805

777806
result_t = argextype(SSAValue(idx), compact)
778-
lifted_result = lift_leaves(compact, result_t, field, leaves)
807+
lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads)
779808
lifted_result === nothing && continue
780809
lifted_leaves, any_undef = lifted_result
781810

@@ -811,21 +840,25 @@ function sroa_pass!(ir::IRCode)
811840
used_ssas = copy(compact.used_ssas)
812841
simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1)
813842
ir = complete(compact)
814-
sroa_mutables!(ir, defuses, used_ssas)
815-
return ir
843+
return sroa_mutables!(ir, defuses, used_ssas, nested_loads)
816844
else
817845
simple_dce!(compact)
818846
return complete(compact)
819847
end
820848
end
821849

822-
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
850+
function sroa_mutables!(ir::IRCode,
851+
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
852+
nested_loads::NestedLoads)
823853
# Compute domtree, needed below, now that we have finished compacting the IR.
824854
# This needs to be after we iterate through the IR with `IncrementalCompact`
825855
# because removing dead blocks can invalidate the domtree.
826856
@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)
827857

828-
for (idx, (intermediaries, defuse)) in defuses
858+
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
859+
local any_eliminated = false
860+
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
861+
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
829862
intermediaries = collect(intermediaries)
830863
# Check if there are any uses we did not account for. If so, the variable
831864
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
@@ -840,7 +873,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
840873
nleaves == nuses_total || continue
841874
# Find the type for this allocation
842875
defexpr = ir[SSAValue(idx)]
843-
isexpr(defexpr, :new) || continue
876+
isa(defexpr, Expr) || continue
877+
if !isexpr(defexpr, :new)
878+
if is_known_call(defexpr, getfield, ir)
879+
val = defexpr.args[2]
880+
if isa(val, SSAValue)
881+
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
882+
if ismutabletype(struct_typ)
883+
record_nested_load!(nested_mloads, idx)
884+
end
885+
end
886+
end
887+
continue
888+
end
844889
newidx = idx
845890
typ = ir.stmts[newidx][:type]
846891
if isa(typ, UnionAll)
@@ -910,6 +955,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
910955
# Now go through all uses and rewrite them
911956
for stmt in du.uses
912957
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
958+
if !any_eliminated
959+
any_eliminated |= (is_nested_load(nested_loads, stmt) ||
960+
is_nested_load(nested_mloads, stmt))
961+
end
913962
end
914963
if !isbitstype(ftyp)
915964
if preserve_uses !== nothing
@@ -946,6 +995,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
946995

947996
@label skip
948997
end
998+
if any_eliminated
999+
return sroa_pass!(compact!(ir), false)
1000+
else
1001+
return ir
1002+
end
9491003
end
9501004

9511005
function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})

test/compiler/irpasses.jl

+41-10
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[
9090
struct ImmutableXYZ; x; y; z; end
9191
mutable struct MutableXYZ; x; y; z; end
9292

93+
struct ImmutableOuter{T}; x::T; y::T; z::T; end
94+
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
95+
9396
# should optimize away very basic cases
9497
let src = code_typed1((Any,Any,Any)) do x, y, z
9598
xyz = ImmutableXYZ(x, y, z)
@@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y
198201
@test any(isnew, src.code)
199202
end
200203

201-
# should include a simple alias analysis
202-
struct ImmutableOuter{T}; x::T; y::T; z::T; end
203-
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
204+
# alias analysis
205+
# --------------
204206
let src = code_typed1((Any,Any,Any)) do x, y, z
205207
xyz = ImmutableXYZ(x, y, z)
206208
outer = ImmutableOuter(xyz, xyz, xyz)
@@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
227229
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
228230
end
229231
end
230-
231-
# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well
232-
# OK: mutable(immutable(...)) case
232+
# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until
233+
# any nested mutable `getfield` calls become no longer eliminatable:
234+
# it's probably not the most efficient option and we may want to introduce some sort of
235+
# alias analysis and eliminates all the loads at once.
236+
# mutable(immutable(...)) case
233237
let src = code_typed1((Any,Any,Any)) do x, y, z
234238
xyz = MutableXYZ(x, y, z)
235239
t = (xyz,)
@@ -260,21 +264,48 @@ let # this is a simple end to end test case, which demonstrates allocation elimi
260264
# compiled code for `simple_sroa`, otherwise everything can be folded even without SROA
261265
@test @allocated(simple_sroa(s)) == 0
262266
end
263-
# FIXME: immutable(mutable(...)) case
267+
# immutable(mutable(...)) case
264268
let src = code_typed1((Any,Any,Any)) do x, y, z
265269
xyz = ImmutableXYZ(x, y, z)
266270
outer = MutableOuter(xyz, xyz, xyz)
267271
outer.x.x, outer.y.y, outer.z.z
268272
end
269-
@test_broken !any(isnew, src.code)
273+
@test !any(isnew, src.code)
274+
@test any(src.code) do @nospecialize x
275+
iscall((src, tuple), x) &&
276+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
277+
end
270278
end
271-
# FIXME: mutable(mutable(...)) case
279+
# mutable(mutable(...)) case
272280
let src = code_typed1((Any,Any,Any)) do x, y, z
273281
xyz = MutableXYZ(x, y, z)
274282
outer = MutableOuter(xyz, xyz, xyz)
275283
outer.x.x, outer.y.y, outer.z.z
276284
end
277-
@test_broken !any(isnew, src.code)
285+
@test !any(isnew, src.code)
286+
@test any(src.code) do @nospecialize x
287+
iscall((src, tuple), x) &&
288+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
289+
end
290+
end
291+
let src = code_typed1((Any,Any,Any)) do x, y, z
292+
xyz = MutableXYZ(x, y, z)
293+
inner = MutableOuter(xyz, xyz, xyz)
294+
outer = MutableOuter(inner, inner, inner)
295+
outer.x.x.x, outer.y.y.y, outer.z.z.z
296+
end
297+
@test !any(isnew, src.code)
298+
@test any(src.code) do @nospecialize x
299+
iscall((src, tuple), x) &&
300+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
301+
end
302+
end
303+
let # NOTE `sroa_mutables!` eliminate from innermost definitions, so that it should be able
304+
# to fully eliminate this insanely nested example
305+
src = code_typed1((Int,)) do x
306+
(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][]
307+
end
308+
@test !any(isnew, src.code)
278309
end
279310

280311
# should work nicely with inlining to optimize away a complicated case

0 commit comments

Comments
 (0)