Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimizer: enable SROA of mutable φ-nodes #43505

Open
wants to merge 1 commit into
base: avi/multisroa
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
optimizer: enable SROA of mutable φ-nodes
This commit allows elimination of mutable φ-nodes (and of the mutable allocations that are their predecessors).
As a contrived example, it allows this `mutable_ϕ_elim(::String, ::Vector{String})`
to run without any allocations at all:
```julia
function mutable_ϕ_elim(x, xs)
    r = Ref(x)
    for x in xs
        r = Ref(x)
    end
    return r[]
end

let xs = String[string(gensym()) for _ in 1:100]
    mutable_ϕ_elim("init", xs)
    @test @allocated(mutable_ϕ_elim("init", xs)) == 0
end
```

This mutable ϕ-node elimination is still limited though.
Most notably, the current implementation doesn't work if a mutable
allocation forms multiple ϕ-nodes, since we check allocation eliminability
(i.e. escapability) by counting usage counts, and thus it's hard to
reason about multiple ϕ-nodes at a time.
For example, currently mutable allocations involved in cases like below
will still not be eliminated:
```julia
code_typed((Bool,String,String),) do cond, x, y
    if cond
        ϕ2 = ϕ1 = Ref(x)
    else
        ϕ2 = ϕ1 = Ref(y)
    end
    ϕ1[], ϕ2[]
end

# more realistic example
mutable struct Point{T}
    x::T
    y::T
end
add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y)
function compute(a::Point{ComplexF64}, b::Point{ComplexF64})
    for i in 0:(100000000-1)
        a = add(add(a, b), b)
    end
    a.x, a.y
end
```

I'd say this limitation should be addressed by first introducing a better
abstraction for reasoning about escape information. More specifically, I'd like
to introduce EscapeAnalysis.jl into Julia base first, and then gradually
adapt it to improve our SROA pass, since EA will allow us to reason about
all escape information imposed on whatever object more easily and should
help us get rid of the complexities of our current SROA implementation.

For now, I'd like to get in this enhancement even though it has the
limitation elaborated above, as long as this commit doesn't introduce
a latency problem (which is unlikely).
aviatesk committed Jan 8, 2022
commit bf97c297435104ff8a3ed265757874e2f63268e9
229 changes: 175 additions & 54 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
@@ -101,9 +101,9 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
end
end

# even when the allocation contains an uninitialized field, we try an extra effort to check
# if this load at `idx` have any "safe" `setfield!` calls that define the field
function has_safe_def(
# even when the allocation contains an uninitialized field, we try an extra effort to
# check if all loads have "safe" `setfield!` calls that define the uninitialized field
function has_safe_def_for_undef_field(
ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse,
newidx::Int, idx::Int)
def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx)
@@ -208,14 +208,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@nospecialize(typeconstraint))
callback = function (@nospecialize(pi), @nospecialize(idx))
if isa(pi, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
@nospecialize(typeconstraint), @nospecialize(callback = nothing))
newcallback = function (@nospecialize(x), @nospecialize(idx))
if isa(x, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(x.typ))
end
callback === nothing || callback(x, idx)
return false
end
def = simple_walk(compact, defssa, callback)
def = simple_walk(compact, defssa, newcallback)
return Pair{Any, Any}(def, typeconstraint)
end

@@ -225,7 +226,9 @@ end
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
(pruning those leaves ruled out by path conditions).
"""
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
function walk_to_defs(compact::IncrementalCompact,
@nospecialize(defssa), @nospecialize(typeconstraint),
@nospecialize(callback = nothing))
visited_phinodes = AnySSAValue[]
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
def = compact[defssa]
@@ -261,7 +264,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
val = OldSSAValue(val.id)
end
if isa(val, AnySSAValue)
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint)
new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback)
if isa(new_def, AnySSAValue)
if !haskey(visited_constraints, new_def)
push!(worklist_defs, new_def)
@@ -722,10 +725,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
continue
end
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
end
mid, defuse = get!(defuses, defidx) do
SPCSet(), SSADefUse()
mid, defuse, phidefs = get!(defuses, defidx) do
SPCSet(), SSADefUse(), PhiDefs(nothing)
end
push!(defuse.ccall_preserve_uses, idx)
union!(mid, intermediaries)
@@ -780,16 +783,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
# Mutable stuff here
isa(def, SSAValue) || continue
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}()
end
mid, defuse = get!(defuses, def.id) do
SPCSet(), SSADefUse()
mid, defuse, phidefs = get!(defuses, def.id) do
SPCSet(), SSADefUse(), PhiDefs(nothing)
end
if is_setfield
push!(defuse.defs, idx)
else
push!(defuse.uses, idx)
end
defval = compact[def]
if isa(defval, PhiNode)
phicallback = function (@nospecialize(x), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
defs, _ = walk_to_defs(compact, def, struct_typ, phicallback)
if _any(@nospecialize(d)->!isa(d, SSAValue), defs)
delete!(defuses, def.id)
continue
end
phidefs[] = Int[(def::SSAValue).id for def in defs]
end
union!(mid, intermediaries)
end
continue
@@ -849,43 +865,73 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
end
end

# TODO:
# - run mutable SROA on the same IR as when we collect information about mutable allocations
# - simplify and improve the eliminability check below using an escape analysis

const PhiDefs = RefValue{Union{Nothing,Vector{Int}}}

function sroa_mutables!(ir::IRCode,
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int},
nested_loads::NestedLoads)
domtree = nothing # initialization of domtree is delayed to avoid the expensive computation in many cases
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
any_eliminated = false
eliminable_defs = nothing # tracks eliminable "definitions" if initialized
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true)
intermediaries = collect(intermediaries)
phidefs = phidefs[]
# Check if there are any uses we did not account for. If so, the variable
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
# show up in the nuses_total count.
nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
nleaves = count_leaves(defuse)
if phidefs !== nothing
# if this defines ϕ, we also track leaves of all predecessors as well
# FIXME this doesn't work when any predecessor is used by another ϕ-node
for pidx in phidefs
haskey(defuses, pidx) || continue
pdefuse = defuses[pidx][2]
nleaves += count_leaves(pdefuse)
end
end
nuses = 0
for idx in intermediaries
nuses += used_ssas[idx]
end
nuses_total = used_ssas[idx] + nuses - length(intermediaries)
nuses -= length(intermediaries)
nuses_total = used_ssas[idx] + nuses
if phidefs !== nothing
for pidx in phidefs
# NOTE we don't need to account for intermediaries for this predecessor here,
# since they are already included in intermediaries of this ϕ-node
# FIXME this doesn't work when any predecessor is used by another ϕ-node
nuses_total += used_ssas[pidx] - 1 # substract usage count from ϕ-node itself
end
end
nleaves == nuses_total || continue
# Find the type for this allocation
defexpr = ir[SSAValue(idx)]
isa(defexpr, Expr) || continue
if !isexpr(defexpr, :new)
if is_known_call(defexpr, getfield, ir)
val = defexpr.args[2]
if isa(val, SSAValue)
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
if ismutabletype(struct_typ)
record_nested_load!(nested_mloads, idx)
end
if isa(defexpr, Expr)
@assert phidefs === nothing
if !isexpr(defexpr, :new)
maybe_record_nested_load!(nested_mloads, ir, idx)
continue
end
elseif isa(defexpr, PhiNode)
phidefs === nothing && continue
for pidx in phidefs
pexpr = ir[SSAValue(pidx)]
if !isexpr(pexpr, :new)
maybe_record_nested_load!(nested_mloads, ir, pidx)
@goto skip
end
end
else
continue
end
newidx = idx
typ = ir.stmts[newidx][:type]
typ = ir.stmts[idx][:type]
if isa(typ, UnionAll)
typ = unwrap_unionall(typ)
end
@@ -897,25 +943,29 @@ function sroa_mutables!(ir::IRCode,
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
all_forwarded = true
for use in defuse.uses
stmt = ir[SSAValue(use)] # == `getfield` call
# We may have discovered above that this use is dead
# after the getfield elim of immutables. In that case,
# it would have been deleted. That's fine, just ignore
# the use in that case.
if stmt === nothing
eliminable = check_use_eliminability!(fielddefuse, ir, use, typ)
if eliminable === nothing
# We may have discovered above that this use is dead
# after the getfield elim of immutables. In that case,
# it would have been deleted. That's fine, just ignore
# the use in that case.
all_forwarded = false
continue
elseif !eliminable
@goto skip
end
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
field === nothing && @goto skip
push!(fielddefuse[field].uses, use)
end
for def in defuse.defs
stmt = ir[SSAValue(def)]::Expr # == `setfield!` call
field = try_compute_fieldidx_stmt(ir, stmt, typ)
field === nothing && @goto skip
isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error
push!(fielddefuse[field].defs, def)
check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip
end
if phidefs !== nothing
for pidx in phidefs
haskey(defuses, pidx) || continue
pdefuse = defuses[pidx][2]
for pdef in pdefuse.defs
check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip
end
end
end
# Check that the defexpr has defined values for all the fields
# we're accessing. In the future, we may want to relax this,
@@ -926,7 +976,13 @@ function sroa_mutables!(ir::IRCode,
for fidx in 1:ndefuse
du = fielddefuse[fidx]
isempty(du.uses) && continue
push!(du.defs, newidx)
if phidefs === nothing
push!(du.defs, idx)
else
for pidx in phidefs
push!(du.defs, pidx)
end
end
ldu = compute_live_ins(ir.cfg, du)
if isempty(ldu.live_in_bbs)
phiblocks = Int[]
@@ -936,10 +992,24 @@ function sroa_mutables!(ir::IRCode,
end
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
blocks[fidx] = phiblocks, allblocks
if fidx + 1 > length(defexpr.args)
for use in du.uses
if phidefs !== nothing
# check if all predecessors have safe definitions
for pidx in phidefs
newexpr = ir[SSAValue(pidx)]::Expr # == new(...)
if fidx + 1 > length(newexpr.args) # this field can be undefined
domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
for use in du.uses
has_safe_def_for_undef_field(ir, domtree, allblocks, du, pidx, use) || @goto skip
end
end
end
else
newexpr = defexpr::Expr # == new(...)
if fidx + 1 > length(newexpr.args) # this field can be undefined
domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks))
has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip
for use in du.uses
has_safe_def_for_undef_field(ir, domtree, allblocks, du, idx, use) || @goto skip
end
end
end
end
@@ -984,28 +1054,79 @@ function sroa_mutables!(ir::IRCode,
end
end
end
for stmt in du.defs
stmt == newidx && continue
ir[SSAValue(stmt)] = nothing
eliminable_defs === nothing && (eliminable_defs = SPCSet())
for def in du.defs
push!(eliminable_defs, def)
end
if phidefs !== nothing
# record ϕ-node itself eliminable here, since we didn't include it in `du.defs`
# we also modify usage counts of its predecessors so that their SROA may work
# in succeeding iteration
push!(eliminable_defs, idx)
for pidx in phidefs
used_ssas[pidx] -= 1
end
end
end
preserve_uses === nothing && continue
if all_forwarded
# this means all ccall preserves have been replaced with forwarded loads
# so we can potentially eliminate the allocation, otherwise we must preserve
# the whole allocation.
push!(intermediaries, newidx)
push!(intermediaries, idx)
end
# Insert the new preserves
for (use, new_preserves) in preserve_uses
ir[SSAValue(use)] = form_new_preserves(ir[SSAValue(use)]::Expr, intermediaries, new_preserves)
end

@label skip
end
# now eliminate "definitions" (i.e. allocations, ϕ-nodes, and `setfield!` calls)
# that should have no usage at this moment
if eliminable_defs !== nothing
for idx in eliminable_defs
ir[SSAValue(idx)] = nothing
end
end
return any_eliminated ? sroa_pass!(compact!(ir), false) : ir
end

# Return the number of "leaf" statements recorded in `defuse`, i.e. the total of its
# recorded uses, definitions and `ccall` preserve uses.
function count_leaves(defuse::SSADefUse)
    nuses = length(defuse.uses)
    ndefs = length(defuse.defs)
    npreserves = length(defuse.ccall_preserve_uses)
    return nuses + ndefs + npreserves
end

# Record the statement at `idx` in `nested_mloads` when it is a known `getfield` call
# whose object argument is an `SSAValue` of a mutable type; otherwise do nothing.
function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int)
    loadstmt = ir[SSAValue(idx)]
    is_known_call(loadstmt, getfield, ir) || return nothing
    obj = loadstmt.args[2] # the object the field is loaded from
    isa(obj, SSAValue) || return nothing
    objtyp = unwrap_unionall(widenconst(argextype(obj, ir)))
    ismutabletype(objtyp) || return nothing
    return record_nested_load!(nested_mloads, idx)
end

# Classify the load at `useidx` against `struct_typ` and, when its field index can be
# computed, register it as a use of that field in `fielddefuse`.
# Returns:
# - `nothing` if the statement at `useidx` is `nothing` (nothing is recorded)
# - `false`   if the accessed field index cannot be computed
# - `true`    once the use has been pushed onto `fielddefuse[field].uses`
function check_use_eliminability!(fielddefuse::Vector{SSADefUse},
    ir::IRCode, useidx::Int, struct_typ::DataType)
    load = ir[SSAValue(useidx)] # == `getfield` call
    if load === nothing
        return nothing
    end
    fidx = try_compute_fieldidx_stmt(ir, load::Expr, struct_typ)
    if fidx === nothing
        return false
    end
    push!(fielddefuse[fidx].uses, useidx)
    return true
end

# Classify the store at `defidx` against `struct_typ` and, when it targets a computable,
# non-`const` field, register it as a definition of that field in `fielddefuse`.
# Returns `true` once the definition has been pushed onto `fielddefuse[field].defs`,
# and `false` (recording nothing) otherwise.
function check_def_eliminability!(fielddefuse::Vector{SSADefUse},
    ir::IRCode, defidx::Int, struct_typ::DataType)
    store = ir[SSAValue(defidx)]::Expr # == `setfield!` call
    fidx = try_compute_fieldidx_stmt(ir, store, struct_typ)
    if fidx === nothing
        return false
    end
    if isconst(struct_typ, fidx)
        # we discovered an attempt to mutate a const field, which must error
        return false
    end
    push!(fielddefuse[fidx].defs, defidx)
    return true
end

function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
newex = Expr(:foreigncall)
nccallargs = length(origex.args[3]::SimpleVector)
203 changes: 202 additions & 1 deletion test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
@@ -230,7 +230,7 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
end
end
# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until
# any nested mutable `getfield` calls become no longer eliminatable:
# any nested mutable `getfield` calls become no longer eliminable:
# it's probably not the most efficient option and we may want to introduce some sort of
# alias analysis and eliminates all the loads at once.
# mutable(immutable(...)) case
@@ -308,6 +308,207 @@ let # NOTE `sroa_mutables!` eliminate from innermost definitions, so that it sho
@test !any(isnew, src.code)
end

# ϕ-allocation elimination
# ------------------------
# Mutable single-field box used by the ϕ-allocation elimination tests below.
# The zero-argument constructor leaves `x` undefined.
mutable struct MutableSome
    x::Any
    function MutableSome(@nospecialize(x))
        return new(x)
    end
    function MutableSome()
        return new()
    end
end
# `s[]` loads the boxed value.
function Base.getindex(s::MutableSome)
    return s.x
end
# `s[] = x` stores `x` into the box and evaluates to `x`.
function Base.setindex!(s::MutableSome, @nospecialize(x))
    return s.x = x
end
# Checks on the optimized IR of small closures in which a `MutableSome` allocation
# flows through a ϕ-node. `isnew` / `iscall` / `code_typed1` / `some_escape` are helpers
# defined elsewhere in this test file.
@testset "mutable ϕ-allocation elimination" begin
# safe cases
# two branches, each allocating; the load after the merge should become a single
# ϕ-node over the arguments `x` and `y`, with no allocation left
let src = code_typed1((Bool,Any,Any)) do cond, x, y
if cond
ϕ = MutableSome(x)
else
ϕ = MutableSome(y)
end
ϕ[]
end
@test !any(isnew, src.code)
@test count(src.code) do @nospecialize x
isa(x, Core.PhiNode) &&
#=x=# Core.Argument(3) in x.values &&
#=y=# Core.Argument(4) in x.values
end == 1
end
# same as above but with three predecessors (`x`, `y` and `z`)
let src = code_typed1((Bool,Bool,Any,Any,Any)) do cond1, cond2, x, y, z
if cond1
ϕ = MutableSome(x)
elseif cond2
ϕ = MutableSome(y)
else
ϕ = MutableSome(z)
end
ϕ[]
end
@test !any(isnew, src.code)
@test count(src.code) do @nospecialize x
isa(x, Core.PhiNode) &&
#=x=# Core.Argument(4) in x.values &&
#=y=# Core.Argument(5) in x.values &&
#=z=# Core.Argument(6) in x.values
end == 1
end
# one predecessor is overwritten with `z` before the merge: the `setfield!` should be
# gone and the resulting ϕ-node should merge `x` and `z`
let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z
if cond
ϕ = MutableSome(x)
else
ϕ = MutableSome(y)
ϕ[] = z
end
ϕ[]
end
@test !any(isnew, src.code)
@test !any(iscall((src, setfield!)), src.code)
@test count(src.code) do @nospecialize x
isa(x, Core.PhiNode) &&
#=x=# Core.Argument(3) in x.values &&
#=z=# Core.Argument(5) in x.values
end == 1
end
# the store happens after the merge, so the load should be forwarded to `z` directly
# and `z` is what gets returned
let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z
if cond
ϕ = MutableSome(x)
else
ϕ = MutableSome(y)
end
ϕ[] = z
ϕ[]
end
@test !any(isnew, src.code)
@test !any(iscall((src, setfield!)), src.code)
@test count(src.code) do @nospecialize x
isa(x, Core.ReturnNode) &&
#=z=# Core.Argument(5) === x.val
end == 1
end
# loads both inside each branch (`out1`) and after the merge (`out2`): two ϕ-nodes
# over `x` and `y` are expected
let src = code_typed1((Bool,Any,Any,)) do cond, x, y
if cond
ϕ = MutableSome(x)
out1 = ϕ[]
else
ϕ = MutableSome(y)
out1 = ϕ[]
end
out2 = ϕ[]
out1, out2
end
@test !any(isnew, src.code)
@test count(src.code) do @nospecialize x
isa(x, Core.PhiNode) &&
#=x=# Core.Argument(3) in x.values &&
#=y=# Core.Argument(4) in x.values
end == 2
end
# as above, but the second branch stores `z` after its own load: one ϕ-node over
# `x`/`y` (for `out1`) and one over `x`/`z` (for `out2`) are expected
let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z
if cond
ϕ = MutableSome(x)
out1 = ϕ[]
else
ϕ = MutableSome(y)
out1 = ϕ[]
ϕ[] = z
end
out2 = ϕ[]
out1, out2
end
@test !any(isnew, src.code)
@test !any(iscall((src, setfield!)), src.code)
@test count(src.code) do @nospecialize x
isa(x, Core.PhiNode) &&
#=x=# Core.Argument(3) in x.values &&
#=y=# Core.Argument(4) in x.values
end == 1
@test count(src.code) do @nospecialize x
isa(x, Core.PhiNode) &&
#=x=# Core.Argument(3) in x.values &&
#=z=# Core.Argument(5) in x.values
end == 1
end

# unsafe cases
# the merged object is passed to `some_escape` (presumably an opaque call that lets
# its argument escape -- defined elsewhere); both allocations must remain
let src = code_typed1((Bool,Any,Any)) do cond, x, y
if cond
ϕ = MutableSome(x)
else
ϕ = MutableSome(y)
end
some_escape(ϕ)
ϕ[]
end
@test count(isnew, src.code) == 2
end
# only one predecessor is passed to `some_escape`; both allocations must still remain
let src = code_typed1((Bool,Any,Any)) do cond, x, y
if cond
ϕ = MutableSome(x)
some_escape(ϕ)
else
ϕ = MutableSome(y)
end
ϕ[]
end
@test count(isnew, src.code) == 2
end
# one predecessor leaves the field uninitialized (`MutableSome()`), so the load after
# the merge has no definition on that path; both allocations must remain
let src = code_typed1((Bool,Any,)) do cond, x
if cond
ϕ = MutableSome(x)
else
ϕ = MutableSome()
end
ϕ[]
end
@test count(isnew, src.code) == 2
end
# NOTE although listed under "unsafe cases", this one asserts full elimination: the
# uninitialized field is assigned `y` before the merge, so the load has a definition
# on every path
let src = code_typed1((Bool,Any,Any)) do cond, x, y
if cond
ϕ = MutableSome(x)
else
ϕ = MutableSome()
ϕ[] = y
end
ϕ[]
end
@test !any(isnew, src.code)
@test !any(iscall((src, setfield!)), src.code)
@test count(src.code) do @nospecialize x
isa(x, Core.PhiNode) &&
#=x=# Core.Argument(3) in x.values &&
#=y=# Core.Argument(4) in x.values
end == 1
end

# FIXME allocation forming multiple ϕ
# each allocation is bound to two variables (`ϕ1` and `ϕ2`) that are both merged;
# elimination is not expected to work here yet, hence `@test_broken`
let src = code_typed1((Bool,Any,Any)) do cond, x, y
if cond
ϕ2 = ϕ1 = MutableSome(x)
else
ϕ2 = ϕ1 = MutableSome(y)
end
ϕ1[], ϕ2[]
end
@test_broken !any(isnew, src.code)
@test_broken count(src.code) do @nospecialize x
isa(x, Core.PhiNode) &&
#=x=# Core.Argument(3) in x.values &&
#=y=# Core.Argument(4) in x.values
end == 1
end
end
# Box `x` in a `Ref`, re-box every element of `xs` in turn, and return the content of
# the last box created (i.e. the last element of `xs`, or `x` when `xs` is empty).
function mutable_ϕ_elim(x, xs)
    latest = Ref(x)
    for elem in xs
        latest = Ref(elem)
    end
    return latest[]
end
# `mutable_ϕ_elim` on `String` inputs is expected to run without allocating at all
let xs = String[string(gensym()) for _ in 1:100]
# call once before measuring -- presumably so that compilation is not counted by
# `@allocated` below
mutable_ϕ_elim("init", xs)
@test @allocated(mutable_ϕ_elim("init", xs)) == 0
end

# should work nicely with inlining to optimize away a complicated case
# adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B
struct Point