Skip to content

Commit e1e502e

Browse files
committed
optimizer: run SROA multiple times to handle more nested loads
1 parent 98b485e commit e1e502e

File tree

3 files changed

+128
-47
lines changed

3 files changed

+128
-47
lines changed

NEWS.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Compiler/Runtime improvements
4747
* Julia-level SROA (Scalar Replacement of Aggregates) has been improved, i.e. allowing elimination of
4848
`getfield` call with constant global field ([#42355]), enabling elimination of mutable struct with
4949
uninitialized fields ([#43208]), improving performance ([#43232]), handling more nested `getfield`
50-
calls ([#43239]).
50+
calls ([#43239], [#43267]).
5151
* Abstract callsite can now be inlined or statically resolved as long as the callsite has a single
5252
matching method ([#43113]).
5353

base/compiler/ssair/passes.jl

+86-36
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])
2929

3030
compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses)
3131

32-
function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr)
33-
field = stmt.args[3]
32+
try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr) =
33+
try_compute_field(ir, stmt.args[3])
34+
35+
function try_compute_field(ir::Union{IncrementalCompact,IRCode}, @nospecialize(field))
3436
# fields are usually literals, handle them manually
3537
if isa(field, QuoteNode)
3638
field = field.value
37-
elseif isa(field, Int)
39+
elseif isa(field, Int) || isa(field, Symbol)
3840
# try to resolve other constants, e.g. global reference
3941
else
4042
field = argextype(field, ir)
@@ -44,8 +46,7 @@ function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr
4446
return nothing
4547
end
4648
end
47-
isa(field, Union{Int, Symbol}) || return nothing
48-
return field
49+
return isa(field, Union{Int, Symbol}) ? field : nothing
4950
end
5051

5152
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType)
@@ -167,7 +168,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
167168
end
168169

169170
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
170-
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
171+
callback = (@nospecialize(x), @nospecialize(idx)) -> false)
171172
while true
172173
if isa(defssa, OldSSAValue)
173174
if already_inserted(compact, defssa)
@@ -335,10 +336,29 @@ struct LiftedValue
335336
end
336337
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}
337338

339+
# NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
340+
# which can be very large sometimes, and program counters in question are often very sparse
341+
const SPCSet = IdSet{Int}
342+
343+
mutable struct NestedLoads
344+
maybe::Union{Nothing,SPCSet}
345+
NestedLoads() = new(nothing)
346+
end
347+
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
348+
maybe = nested_loads.maybe
349+
maybe === nothing && (maybe = nested_loads.maybe = SPCSet())
350+
push!(maybe::SPCSet, pc)
351+
end
352+
function is_nested_load(nested_loads::NestedLoads, pc::Int)
353+
maybe = nested_loads.maybe
354+
maybe === nothing && return false
355+
return pc in maybe::SPCSet
356+
end
357+
338358
# try to compute lifted values that can replace `getfield(x, field)` call
339359
# where `x` is an immutable struct that is defined at any of `leaves`
340-
function lift_leaves(compact::IncrementalCompact,
341-
@nospecialize(result_t), field::Int, leaves::Vector{Any})
360+
function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any},
361+
@nospecialize(result_t), field::Int, nested_loads::NestedLoads)
342362
# For every leaf, the lifted value
343363
lifted_leaves = LiftedLeaves()
344364
maybe_undef = false
@@ -388,11 +408,19 @@ function lift_leaves(compact::IncrementalCompact,
388408
ocleaf = simple_walk(compact, ocleaf)
389409
end
390410
ocdef, _ = walk_to_def(compact, ocleaf)
391-
if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 ≤ field ≤ length(ocdef.args)-5
411+
if isexpr(ocdef, :new_opaque_closure) && 1 ≤ field ≤ length(ocdef.args)-5
392412
lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves)
393413
continue
394414
end
395415
return nothing
416+
elseif is_known_call(def, getfield, compact)
417+
if isa(leaf, SSAValue)
418+
struct_typ = unwrap_unionall(widenconst(argextype(def.args[2], compact)))
419+
if ismutabletype(struct_typ)
420+
record_nested_load!(nested_loads, leaf.id)
421+
end
422+
end
423+
return nothing
396424
else
397425
typ = argextype(leaf, compact)
398426
if !isa(typ, Const)
@@ -586,7 +614,7 @@ function perform_lifting!(compact::IncrementalCompact,
586614
end
587615
val = lifted_val.x
588616
if isa(val, AnySSAValue)
589-
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
617+
callback = (@nospecialize(x), @nospecialize(idx)) -> true
590618
val = simple_walk(compact, val, callback)
591619
end
592620
push!(new_node.values, val)
@@ -617,10 +645,6 @@ function perform_lifting!(compact::IncrementalCompact,
617645
return stmt_val # N.B. should never happen
618646
end
619647

620-
# NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
621-
# which can be very large sometimes, and program counters in question are often very sparse
622-
const SPCSet = IdSet{Int}
623-
624648
"""
625649
sroa_pass!(ir::IRCode) -> newir::IRCode
626650
@@ -639,10 +663,11 @@ its argument).
639663
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
640664
a result of succeeding dead code elimination.
641665
"""
642-
function sroa_pass!(ir::IRCode)
666+
function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
643667
compact = IncrementalCompact(ir)
644668
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
645669
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
670+
nested_loads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
646671
for ((_, idx), stmt) in compact
647672
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
648673
isa(stmt, Expr) || continue
@@ -670,7 +695,7 @@ function sroa_pass!(ir::IRCode)
670695
preserved_arg = stmt.args[pidx]
671696
isa(preserved_arg, SSAValue) || continue
672697
let intermediaries = SPCSet()
673-
callback = function (@nospecialize(pi), @nospecialize(ssa))
698+
callback = function (@nospecialize(x), @nospecialize(ssa))
674699
push!(intermediaries, ssa.id)
675700
return false
676701
end
@@ -698,7 +723,9 @@ function sroa_pass!(ir::IRCode)
698723
if defuses === nothing
699724
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
700725
end
701-
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
726+
mid, defuse = get!(defuses, defidx) do
727+
SPCSet(), SSADefUse()
728+
end
702729
push!(defuse.ccall_preserve_uses, idx)
703730
union!(mid, intermediaries)
704731
end
@@ -708,16 +735,17 @@ function sroa_pass!(ir::IRCode)
708735
compact[idx] = form_new_preserves(stmt, preserved, new_preserves)
709736
end
710737
continue
711-
# TODO: This isn't the best place to put these
712-
elseif is_known_call(stmt, typeassert, compact)
713-
canonicalize_typeassert!(compact, idx, stmt)
714-
continue
715-
elseif is_known_call(stmt, (===), compact)
716-
lift_comparison!(compact, idx, stmt, lifting_cache)
717-
continue
718-
# elseif is_known_call(stmt, isa, compact)
719-
# TODO do a similar optimization as `lift_comparison!` for `===`
720738
else
739+
if optional_opts
740+
# TODO: This isn't the best place to put these
741+
if is_known_call(stmt, typeassert, compact)
742+
canonicalize_typeassert!(compact, idx, stmt)
743+
elseif is_known_call(stmt, (===), compact)
744+
lift_comparison!(compact, idx, stmt, lifting_cache)
745+
# elseif is_known_call(stmt, isa, compact)
746+
# TODO do a similar optimization as `lift_comparison!` for `===`
747+
end
748+
end
721749
continue
722750
end
723751

@@ -743,7 +771,7 @@ function sroa_pass!(ir::IRCode)
743771
if ismutabletype(struct_typ)
744772
isa(val, SSAValue) || continue
745773
let intermediaries = SPCSet()
746-
callback = function (@nospecialize(pi), @nospecialize(ssa))
774+
callback = function (@nospecialize(x), @nospecialize(ssa))
747775
push!(intermediaries, ssa.id)
748776
return false
749777
end
@@ -753,7 +781,9 @@ function sroa_pass!(ir::IRCode)
753781
if defuses === nothing
754782
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
755783
end
756-
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
784+
mid, defuse = get!(defuses, def.id) do
785+
SPCSet(), SSADefUse()
786+
end
757787
if is_setfield
758788
push!(defuse.defs, idx)
759789
else
@@ -775,7 +805,7 @@ function sroa_pass!(ir::IRCode)
775805
isempty(leaves) && continue
776806

777807
result_t = argextype(SSAValue(idx), compact)
778-
lifted_result = lift_leaves(compact, result_t, field, leaves)
808+
lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads)
779809
lifted_result === nothing && continue
780810
lifted_leaves, any_undef = lifted_result
781811

@@ -811,18 +841,21 @@ function sroa_pass!(ir::IRCode)
811841
used_ssas = copy(compact.used_ssas)
812842
simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1)
813843
ir = complete(compact)
814-
sroa_mutables!(ir, defuses, used_ssas)
815-
return ir
844+
return sroa_mutables!(ir, defuses, used_ssas, nested_loads)
816845
else
817846
simple_dce!(compact)
818847
return complete(compact)
819848
end
820849
end
821850

822-
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
823-
# initialization of domtree is delayed to avoid the expensive computation in many cases
824-
local domtree = nothing
825-
for (idx, (intermediaries, defuse)) in defuses
851+
function sroa_mutables!(ir::IRCode,
852+
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
853+
nested_loads::NestedLoads)
854+
domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
855+
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
856+
any_eliminated = false
857+
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
858+
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
826859
intermediaries = collect(intermediaries)
827860
# Check if there are any uses we did not account for. If so, the variable
828861
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
@@ -837,7 +870,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
837870
nleaves == nuses_total || continue
838871
# Find the type for this allocation
839872
defexpr = ir[SSAValue(idx)]
840-
isexpr(defexpr, :new) || continue
873+
isa(defexpr, Expr) || continue
874+
if !isexpr(defexpr, :new)
875+
if is_known_call(defexpr, getfield, ir)
876+
val = defexpr.args[2]
877+
if isa(val, SSAValue)
878+
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
879+
if ismutabletype(struct_typ)
880+
record_nested_load!(nested_mloads, idx)
881+
end
882+
end
883+
end
884+
continue
885+
end
841886
newidx = idx
842887
typ = ir.stmts[newidx][:type]
843888
if isa(typ, UnionAll)
@@ -917,6 +962,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
917962
# Now go through all uses and rewrite them
918963
for stmt in du.uses
919964
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
965+
if !any_eliminated
966+
any_eliminated |= (is_nested_load(nested_loads, stmt) ||
967+
is_nested_load(nested_mloads, stmt))
968+
end
920969
end
921970
if !isbitstype(ftyp)
922971
if preserve_uses !== nothing
@@ -953,6 +1002,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
9531002

9541003
@label skip
9551004
end
1005+
return any_eliminated ? sroa_pass!(compact!(ir), false) : ir
9561006
end
9571007

9581008
function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})

test/compiler/irpasses.jl

+41-10
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[
9090
struct ImmutableXYZ; x; y; z; end
9191
mutable struct MutableXYZ; x; y; z; end
9292

93+
struct ImmutableOuter{T}; x::T; y::T; z::T; end
94+
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
95+
9396
# should optimize away very basic cases
9497
let src = code_typed1((Any,Any,Any)) do x, y, z
9598
xyz = ImmutableXYZ(x, y, z)
@@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y
198201
@test any(isnew, src.code)
199202
end
200203

201-
# should include a simple alias analysis
202-
struct ImmutableOuter{T}; x::T; y::T; z::T; end
203-
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
204+
# alias analysis
205+
# --------------
204206
let src = code_typed1((Any,Any,Any)) do x, y, z
205207
xyz = ImmutableXYZ(x, y, z)
206208
outer = ImmutableOuter(xyz, xyz, xyz)
@@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
227229
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
228230
end
229231
end
230-
231-
# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well
232-
# OK: mutable(immutable(...)) case
232+
# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until
233+
# no more nested mutable `getfield` calls can be eliminated:
234+
# it's probably not the most efficient option and we may want to introduce some sort of
235+
# alias analysis that eliminates all the loads at once.
236+
# mutable(immutable(...)) case
233237
let src = code_typed1((Any,Any,Any)) do x, y, z
234238
xyz = MutableXYZ(x, y, z)
235239
t = (xyz,)
@@ -260,21 +264,48 @@ let # this is a simple end to end test case, which demonstrates allocation elimi
260264
# compiled code for `simple_sroa`, otherwise everything can be folded even without SROA
261265
@test @allocated(simple_sroa(s)) == 0
262266
end
263-
# FIXME: immutable(mutable(...)) case
267+
# immutable(mutable(...)) case
264268
let src = code_typed1((Any,Any,Any)) do x, y, z
265269
xyz = ImmutableXYZ(x, y, z)
266270
outer = MutableOuter(xyz, xyz, xyz)
267271
outer.x.x, outer.y.y, outer.z.z
268272
end
269-
@test_broken !any(isnew, src.code)
273+
@test !any(isnew, src.code)
274+
@test any(src.code) do @nospecialize x
275+
iscall((src, tuple), x) &&
276+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
277+
end
270278
end
271-
# FIXME: mutable(mutable(...)) case
279+
# mutable(mutable(...)) case
272280
let src = code_typed1((Any,Any,Any)) do x, y, z
273281
xyz = MutableXYZ(x, y, z)
274282
outer = MutableOuter(xyz, xyz, xyz)
275283
outer.x.x, outer.y.y, outer.z.z
276284
end
277-
@test_broken !any(isnew, src.code)
285+
@test !any(isnew, src.code)
286+
@test any(src.code) do @nospecialize x
287+
iscall((src, tuple), x) &&
288+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
289+
end
290+
end
291+
let src = code_typed1((Any,Any,Any)) do x, y, z
292+
xyz = MutableXYZ(x, y, z)
293+
inner = MutableOuter(xyz, xyz, xyz)
294+
outer = MutableOuter(inner, inner, inner)
295+
outer.x.x.x, outer.y.y.y, outer.z.z.z
296+
end
297+
@test !any(isnew, src.code)
298+
@test any(src.code) do @nospecialize x
299+
iscall((src, tuple), x) &&
300+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
301+
end
302+
end
303+
let # NOTE `sroa_mutables!` eliminates from innermost definitions, so it should be able
304+
# to fully eliminate this insanely nested example
305+
src = code_typed1((Int,)) do x
306+
(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][]
307+
end
308+
@test !any(isnew, src.code)
278309
end
279310

280311
# should work nicely with inlining to optimize away a complicated case

0 commit comments

Comments
 (0)