Skip to content

Commit a9c6daf

Browse files
committed
optimizer: run SROA multiple times to handle more nested loads
1 parent dacd16f commit a9c6daf

File tree

3 files changed

+122
-42
lines changed

3 files changed

+122
-42
lines changed

NEWS.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Compiler/Runtime improvements
4747
* Julia-level SROA (Scalar Replacement of Aggregates) has been improved, i.e. allowing elimination of
4848
`getfield` call with constant global field ([#42355]), enabling elimination of mutable struct with
4949
uninitialized fields ([#43208]), improving performance ([#43232]), handling more nested `getfield`
50-
calls ([#43239]).
50+
calls ([#43239], [#43267]).
5151
* Abstract callsite can now be inlined or statically resolved as long as the callsite has a single
5252
matching method ([#43113]).
5353

base/compiler/ssair/passes.jl

+80-31
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
169169
end
170170

171171
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
172-
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
172+
callback = (@nospecialize(x), @nospecialize(idx)) -> false)
173173
while true
174174
if isa(defssa, OldSSAValue)
175175
if already_inserted(compact, defssa)
@@ -337,10 +337,29 @@ struct LiftedValue
337337
end
338338
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}
339339

340+
# NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
341+
# which can be very large sometimes, and program counters in question are often very sparse
342+
const SPCSet = IdSet{Int}
343+
344+
mutable struct NestedLoads
345+
maybe::Union{Nothing,SPCSet}
346+
NestedLoads() = new(nothing)
347+
end
348+
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
349+
maybe = nested_loads.maybe
350+
maybe === nothing && (maybe = nested_loads.maybe = SPCSet())
351+
push!(maybe::SPCSet, pc)
352+
end
353+
function is_nested_load(nested_loads::NestedLoads, pc::Int)
354+
maybe = nested_loads.maybe
355+
maybe === nothing && return false
356+
return pc in maybe::SPCSet
357+
end
358+
340359
# try to compute lifted values that can replace `getfield(x, field)` call
341360
# where `x` is an immutable struct that is defined at any of `leaves`
342-
function lift_leaves(compact::IncrementalCompact,
343-
@nospecialize(result_t), field::Int, leaves::Vector{Any})
361+
function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any},
362+
@nospecialize(result_t), field::Int, nested_loads::NestedLoads)
344363
# For every leaf, the lifted value
345364
lifted_leaves = LiftedLeaves()
346365
maybe_undef = false
@@ -390,11 +409,19 @@ function lift_leaves(compact::IncrementalCompact,
390409
ocleaf = simple_walk(compact, ocleaf)
391410
end
392411
ocdef, _ = walk_to_def(compact, ocleaf)
393-
if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 ≤ field ≤ length(ocdef.args)-5
412+
if isexpr(ocdef, :new_opaque_closure) && 1 ≤ field ≤ length(ocdef.args)-5
394413
lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves)
395414
continue
396415
end
397416
return nothing
417+
elseif is_known_call(def, getfield, compact)
418+
if isa(leaf, SSAValue)
419+
struct_typ = unwrap_unionall(widenconst(argextype(def.args[2], compact)))
420+
if ismutabletype(struct_typ)
421+
record_nested_load!(nested_loads, leaf.id)
422+
end
423+
end
424+
return nothing
398425
else
399426
typ = argextype(leaf, compact)
400427
if !isa(typ, Const)
@@ -588,7 +615,7 @@ function perform_lifting!(compact::IncrementalCompact,
588615
end
589616
val = lifted_val.x
590617
if isa(val, AnySSAValue)
591-
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
618+
callback = (@nospecialize(x), @nospecialize(idx)) -> true
592619
val = simple_walk(compact, val, callback)
593620
end
594621
push!(new_node.values, val)
@@ -619,10 +646,6 @@ function perform_lifting!(compact::IncrementalCompact,
619646
return stmt_val # N.B. should never happen
620647
end
621648

622-
# NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
623-
# which can be very large sometimes, and program counters in question are often very sparse
624-
const SPCSet = IdSet{Int}
625-
626649
"""
627650
sroa_pass!(ir::IRCode) -> newir::IRCode
628651
@@ -641,10 +664,11 @@ its argument).
641664
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
642665
a result of succeeding dead code elimination.
643666
"""
644-
function sroa_pass!(ir::IRCode)
667+
function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
645668
compact = IncrementalCompact(ir)
646669
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
647670
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
671+
nested_loads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
648672
for ((_, idx), stmt) in compact
649673
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
650674
isa(stmt, Expr) || continue
@@ -672,7 +696,7 @@ function sroa_pass!(ir::IRCode)
672696
preserved_arg = stmt.args[pidx]
673697
isa(preserved_arg, SSAValue) || continue
674698
let intermediaries = SPCSet()
675-
callback = function (@nospecialize(pi), @nospecialize(ssa))
699+
callback = function (@nospecialize(x), @nospecialize(ssa))
676700
push!(intermediaries, ssa.id)
677701
return false
678702
end
@@ -700,7 +724,9 @@ function sroa_pass!(ir::IRCode)
700724
if defuses === nothing
701725
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
702726
end
703-
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
727+
mid, defuse = get!(defuses, defidx) do
728+
SPCSet(), SSADefUse()
729+
end
704730
push!(defuse.ccall_preserve_uses, idx)
705731
union!(mid, intermediaries)
706732
end
@@ -710,16 +736,17 @@ function sroa_pass!(ir::IRCode)
710736
compact[idx] = form_new_preserves(stmt, preserved, new_preserves)
711737
end
712738
continue
713-
# TODO: This isn't the best place to put these
714-
elseif is_known_call(stmt, typeassert, compact)
715-
canonicalize_typeassert!(compact, idx, stmt)
716-
continue
717-
elseif is_known_call(stmt, (===), compact)
718-
lift_comparison!(compact, idx, stmt, lifting_cache)
719-
continue
720-
# elseif is_known_call(stmt, isa, compact)
721-
# TODO do a similar optimization as `lift_comparison!` for `===`
722739
else
740+
if optional_opts
741+
# TODO: This isn't the best place to put these
742+
if is_known_call(stmt, typeassert, compact)
743+
canonicalize_typeassert!(compact, idx, stmt)
744+
elseif is_known_call(stmt, (===), compact)
745+
lift_comparison!(compact, idx, stmt, lifting_cache)
746+
# elseif is_known_call(stmt, isa, compact)
747+
# TODO do a similar optimization as `lift_comparison!` for `===`
748+
end
749+
end
723750
continue
724751
end
725752

@@ -745,7 +772,7 @@ function sroa_pass!(ir::IRCode)
745772
if ismutabletype(struct_typ)
746773
isa(val, SSAValue) || continue
747774
let intermediaries = SPCSet()
748-
callback = function (@nospecialize(pi), @nospecialize(ssa))
775+
callback = function (@nospecialize(x), @nospecialize(ssa))
749776
push!(intermediaries, ssa.id)
750777
return false
751778
end
@@ -755,7 +782,9 @@ function sroa_pass!(ir::IRCode)
755782
if defuses === nothing
756783
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
757784
end
758-
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
785+
mid, defuse = get!(defuses, def.id) do
786+
SPCSet(), SSADefUse()
787+
end
759788
if is_setfield
760789
push!(defuse.defs, idx)
761790
else
@@ -777,7 +806,7 @@ function sroa_pass!(ir::IRCode)
777806
isempty(leaves) && continue
778807

779808
result_t = argextype(SSAValue(idx), compact)
780-
lifted_result = lift_leaves(compact, result_t, field, leaves)
809+
lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads)
781810
lifted_result === nothing && continue
782811
lifted_leaves, any_undef = lifted_result
783812

@@ -813,18 +842,21 @@ function sroa_pass!(ir::IRCode)
813842
used_ssas = copy(compact.used_ssas)
814843
simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1)
815844
ir = complete(compact)
816-
sroa_mutables!(ir, defuses, used_ssas)
817-
return ir
845+
return sroa_mutables!(ir, defuses, used_ssas, nested_loads)
818846
else
819847
simple_dce!(compact)
820848
return complete(compact)
821849
end
822850
end
823851

824-
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
825-
# initialization of domtree is delayed to avoid the expensive computation in many cases
826-
local domtree = nothing
827-
for (idx, (intermediaries, defuse)) in defuses
852+
function sroa_mutables!(ir::IRCode,
853+
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
854+
nested_loads::NestedLoads)
855+
domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
856+
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
857+
any_eliminated = false
858+
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
859+
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
828860
intermediaries = collect(intermediaries)
829861
# Check if there are any uses we did not account for. If so, the variable
830862
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
@@ -839,7 +871,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
839871
nleaves == nuses_total || continue
840872
# Find the type for this allocation
841873
defexpr = ir[SSAValue(idx)]
842-
isexpr(defexpr, :new) || continue
874+
isa(defexpr, Expr) || continue
875+
if !isexpr(defexpr, :new)
876+
if is_known_call(defexpr, getfield, ir)
877+
val = defexpr.args[2]
878+
if isa(val, SSAValue)
879+
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
880+
if ismutabletype(struct_typ)
881+
record_nested_load!(nested_mloads, idx)
882+
end
883+
end
884+
end
885+
continue
886+
end
843887
newidx = idx
844888
typ = ir.stmts[newidx][:type]
845889
if isa(typ, UnionAll)
@@ -919,6 +963,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
919963
# Now go through all uses and rewrite them
920964
for stmt in du.uses
921965
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
966+
if !any_eliminated
967+
any_eliminated |= (is_nested_load(nested_loads, stmt) ||
968+
is_nested_load(nested_mloads, stmt))
969+
end
922970
end
923971
if !isbitstype(ftyp)
924972
if preserve_uses !== nothing
@@ -955,6 +1003,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
9551003

9561004
@label skip
9571005
end
1006+
return any_eliminated ? sroa_pass!(compact!(ir), false) : ir
9581007
end
9591008

9601009
function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})

test/compiler/irpasses.jl

+41-10
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[
9090
struct ImmutableXYZ; x; y; z; end
9191
mutable struct MutableXYZ; x; y; z; end
9292

93+
struct ImmutableOuter{T}; x::T; y::T; z::T; end
94+
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
95+
9396
# should optimize away very basic cases
9497
let src = code_typed1((Any,Any,Any)) do x, y, z
9598
xyz = ImmutableXYZ(x, y, z)
@@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y
198201
@test any(isnew, src.code)
199202
end
200203

201-
# should include a simple alias analysis
202-
struct ImmutableOuter{T}; x::T; y::T; z::T; end
203-
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
204+
# alias analysis
205+
# --------------
204206
let src = code_typed1((Any,Any,Any)) do x, y, z
205207
xyz = ImmutableXYZ(x, y, z)
206208
outer = ImmutableOuter(xyz, xyz, xyz)
@@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
227229
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
228230
end
229231
end
230-
231-
# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well
232-
# OK: mutable(immutable(...)) case
232+
# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until
233+
# any nested mutable `getfield` calls become no longer eliminatable:
234+
# it's probably not the most efficient option and we may want to introduce some sort of
235+
# alias analysis and eliminate all the loads at once.
236+
# mutable(immutable(...)) case
233237
let src = code_typed1((Any,Any,Any)) do x, y, z
234238
xyz = MutableXYZ(x, y, z)
235239
t = (xyz,)
@@ -260,21 +264,48 @@ let # this is a simple end to end test case, which demonstrates allocation elimi
260264
# compiled code for `simple_sroa`, otherwise everything can be folded even without SROA
261265
@test @allocated(simple_sroa(s)) == 0
262266
end
263-
# FIXME: immutable(mutable(...)) case
267+
# immutable(mutable(...)) case
264268
let src = code_typed1((Any,Any,Any)) do x, y, z
265269
xyz = ImmutableXYZ(x, y, z)
266270
outer = MutableOuter(xyz, xyz, xyz)
267271
outer.x.x, outer.y.y, outer.z.z
268272
end
269-
@test_broken !any(isnew, src.code)
273+
@test !any(isnew, src.code)
274+
@test any(src.code) do @nospecialize x
275+
iscall((src, tuple), x) &&
276+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
277+
end
270278
end
271-
# FIXME: mutable(mutable(...)) case
279+
# mutable(mutable(...)) case
272280
let src = code_typed1((Any,Any,Any)) do x, y, z
273281
xyz = MutableXYZ(x, y, z)
274282
outer = MutableOuter(xyz, xyz, xyz)
275283
outer.x.x, outer.y.y, outer.z.z
276284
end
277-
@test_broken !any(isnew, src.code)
285+
@test !any(isnew, src.code)
286+
@test any(src.code) do @nospecialize x
287+
iscall((src, tuple), x) &&
288+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
289+
end
290+
end
291+
let src = code_typed1((Any,Any,Any)) do x, y, z
292+
xyz = MutableXYZ(x, y, z)
293+
inner = MutableOuter(xyz, xyz, xyz)
294+
outer = MutableOuter(inner, inner, inner)
295+
outer.x.x.x, outer.y.y.y, outer.z.z.z
296+
end
297+
@test !any(isnew, src.code)
298+
@test any(src.code) do @nospecialize x
299+
iscall((src, tuple), x) &&
300+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
301+
end
302+
end
303+
let # NOTE `sroa_mutables!` eliminates from innermost definitions, so it should be able
304+
# to fully eliminate this insanely nested example
305+
src = code_typed1((Int,)) do x
306+
(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][]
307+
end
308+
@test !any(isnew, src.code)
278309
end
279310

280311
# should work nicely with inlining to optimize away a complicated case

0 commit comments

Comments
 (0)