Skip to content

duplicate invoke expr during call method splitting #22481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 28, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 69 additions & 54 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3903,8 +3903,27 @@ function get_spec_lambda(atypes::ANY, sv, invoke_data::ANY)
end
end

function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY,
invoke_data::ANY)
function linearize_args!(args::Vector{Any}, atypes::Vector{Any}, stmts::Vector{Any}, sv::InferenceState)
# linearize the IR by moving the arguments to SSA position
na = length(args)
@assert length(atypes) == na
newargs = Vector{Any}(na)
for i = na:-1:1
aei = args[i]
ti = atypes[i]
if isa(aei, Expr) || isa(aei, GlobalRef)
newvar = newvar!(sv, ti)
unshift!(stmts, Expr(:(=), newvar, aei))
else
newvar = aei
end
newargs[i] = newvar
end
return newargs
end

function invoke_NF(argexprs, etype::ANY, atypes::Vector{Any}, sv::InferenceState,
atype_unlimited::ANY, invoke_data::ANY)
# converts a :call to :invoke
nu = countunionsplit(atypes)
nu > sv.params.MAX_UNION_SPLITTING && return NF
Expand All @@ -3918,42 +3937,37 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY,
end

if nu > 1
spec_hit = nothing
# linearize the IR by moving the arguments to SSA position
stmts = []

spec_miss = nothing
error_label = nothing
linfo_var = add_slot!(sv.src, MethodInstance, false)
ex = Expr(:call)
ex.args = copy(argexprs)
ex.typ = etype
stmts = []
arg_hoisted = false
for i = length(atypes):-1:1
if i == 1 && !(invoke_texpr === nothing)
unshift!(stmts, invoke_texpr)
arg_hoisted = true
end
ti = atypes[i]
if arg_hoisted || isa(ti, Union)
aei = ex.args[i]
if !effect_free(aei, sv.src, sv.mod, false)
arg_hoisted = true
newvar = newvar!(sv, ti)
unshift!(stmts, :($newvar = $aei))
ex.args[i] = newvar
end
end
end
ex.args = linearize_args!(argexprs, atypes, stmts, sv)
invoke_texpr === nothing || insert!(stmts, 2, invoke_texpr)
invoke_fexpr === nothing || unshift!(stmts, invoke_fexpr)

local ret_var, merge, invoke_ex, spec_hit
ret_var = add_slot!(sv.src, widenconst(etype), false)
merge = genlabel(sv)
invoke_ex = copy(ex)
invoke_ex.head = :invoke
unshift!(invoke_ex.args, nothing)
spec_hit = false

function splitunion(atypes::Vector{Any}, i::Int)
if i == 0
local sig = argtypes_to_type(atypes)
local li = get_spec_lambda(sig, sv, invoke_data)
li === nothing && return false
add_backedge!(li, sv)
local stmt = []
push!(stmt, Expr(:(=), linfo_var, li))
spec_hit === nothing && (spec_hit = genlabel(sv))
push!(stmt, GotoNode(spec_hit.label))
invoke_ex = copy(invoke_ex)
invoke_ex.args[1] = li
push!(stmt, Expr(:(=), ret_var, invoke_ex))
push!(stmt, GotoNode(merge.label))
spec_hit = true
return stmt
else
local ti = atypes[i]
Expand Down Expand Up @@ -3991,36 +4005,25 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY,
end
end
local match = splitunion(atypes, length(atypes))
if match !== false && spec_hit !== nothing
if match !== false && spec_hit
append!(stmts, match)
if error_label !== nothing
push!(stmts, error_label)
push!(stmts, Expr(:call, GlobalRef(_topmod(sv.mod), :error), "fatal error in type inference (type bound)"))
end
local ret_var, merge
if spec_miss !== nothing
ret_var = add_slot!(sv.src, widenconst(ex.typ), false)
merge = genlabel(sv)
push!(stmts, spec_miss)
push!(stmts, Expr(:(=), ret_var, ex))
push!(stmts, GotoNode(merge.label))
else
ret_var = newvar!(sv, ex.typ)
end
push!(stmts, spec_hit)
ex = copy(ex)
ex.head = :invoke
unshift!(ex.args, linfo_var)
push!(stmts, Expr(:(=), ret_var, ex))
if spec_miss !== nothing
push!(stmts, merge)
end
push!(stmts, merge)
return (ret_var, stmts)
end
else
local cache_linfo = get_spec_lambda(atype_unlimited, sv, invoke_data)
cache_linfo === nothing && return NF
add_backedge!(cache_linfo, sv)
argexprs = copy(argexprs)
unshift!(argexprs, cache_linfo)
ex = Expr(:invoke)
ex.args = argexprs
Expand Down Expand Up @@ -4207,6 +4210,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end

argexprs0 = argexprs
atypes0 = atypes
na = Int(method.nargs)
# check for vararg function
isva = false
Expand All @@ -4215,6 +4219,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
# construct tuple-forming expression for argument tail
vararg = mk_tuplecall(argexprs[na:end], sv)
argexprs = Any[argexprs[1:(na - 1)]..., vararg]
atypes = Any[atypes[1:(na - 1)]..., vararg.typ]
isva = true
elseif na != length(argexprs)
# we have a method match only because an earlier
Expand Down Expand Up @@ -4254,7 +4259,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference

# see if the method has been previously inferred (and cached)
linfo = code_for_method(method, metharg, methsp, sv.params.world, !force_infer) # Union{Void, MethodInstance}
isa(linfo, MethodInstance) || return invoke_NF(argexprs0, e.typ, atypes, sv,
isa(linfo, MethodInstance) || return invoke_NF(argexprs0, e.typ, atypes0, sv,
atype_unlimited, invoke_data)
linfo = linfo::MethodInstance
if linfo.jlcall_api == 2
Expand Down Expand Up @@ -4324,7 +4329,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
src_inlineable = ccall(:jl_ast_flag_inlineable, Bool, (Any,), inferred)
end
if !src_inferred || !src_inlineable
return invoke_NF(argexprs0, e.typ, atypes, sv, atype_unlimited,
return invoke_NF(argexprs0, e.typ, atypes0, sv, atype_unlimited,
invoke_data)
end

Expand All @@ -4351,7 +4356,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
if inline_worthy_stmts(current_stmts)
append!(current_stmts, ast)
if !inline_worthy_stmts(current_stmts)
return invoke_NF(argexprs0, e.typ, atypes, sv, atype_unlimited,
return invoke_NF(argexprs0, e.typ, atypes0, sv, atype_unlimited,
invoke_data)
end
end
Expand Down Expand Up @@ -4389,6 +4394,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
prelude_stmts = []
stmts_free = true # true = all entries of stmts are effect_free

argexprs = copy(argexprs)
for i = na:-1:1 # stmts_free needs to be calculated in reverse-argument order
#args_i = args[i]
aei = argexprs[i]
Expand Down Expand Up @@ -4479,19 +4485,18 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
local retval
multiret = false
lastexpr = pop!(body.args)
if isa(lastexpr,LabelNode)
if isa(lastexpr, LabelNode)
push!(body.args, lastexpr)
push!(body.args, Expr(:call, GlobalRef(topmod, :error), "fatal error in type inference (lowering)"))
lastexpr = nothing
elseif !(isa(lastexpr,Expr) && lastexpr.head === :return)
elseif !(isa(lastexpr, Expr) && lastexpr.head === :return)
# code sometimes ends with a meta node, e.g. inbounds pop
push!(body.args, lastexpr)
lastexpr = nothing
end
for a in body.args
push!(stmts, a)
if isa(a,Expr)
a = a::Expr
if isa(a, Expr)
if a.head === :return
if !multiret
# create slot first time
Expand Down Expand Up @@ -4581,7 +4586,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
end

if isa(expr,Expr)
if isa(expr, Expr)
old_t = e.typ
if old_t ⊑ expr.typ
# if we had better type information than the content being inlined,
Expand Down Expand Up @@ -4618,6 +4623,7 @@ function inline_worthy(body::Expr, cost::Integer=1000) # precondition: 0 < cost;
if !(isa(stmt, SSAValue) || inline_ignore(stmt))
nstmt += 1
end
isa(stmt, Expr) && stmt.head == :enter && return false # don't inline functions with try/catch
end
if nstmt < (symlim + 500) ÷ 1000
symlim *= 16
Expand Down Expand Up @@ -5456,6 +5462,8 @@ end
# TODO can probably be removed when we switch to a linear IR
function getfield_elim_pass!(sv::InferenceState)
body = sv.src.code
nssavalues = length(sv.src.ssavaluetypes)
sv.ssavalue_defs = find_ssavalue_defs(body, nssavalues)
for i = 1:length(body)
body[i] = _getfield_elim_pass!(body[i], sv)
end
Expand All @@ -5469,11 +5477,18 @@ function _getfield_elim_pass!(e::Expr, sv::InferenceState)
(isa(e.args[3],Int) || isa(e.args[3],QuoteNode))
e1 = e.args[2]
j = e.args[3]
if isa(e1,Expr)
alloc = is_allocation(e1, sv)
single_use = true
while isa(e1, SSAValue)
single_use = false
def = sv.ssavalue_defs[e1.id + 1]
stmt = sv.src.code[def]::Expr
e1 = stmt.args[2]
end
if isa(e1, Expr)
alloc = single_use && is_allocation(e1, sv)
if alloc !== false
flen, fnames = alloc
if isa(j,QuoteNode)
if isa(j, QuoteNode)
j = findfirst(fnames, j.value)
end
if 1 <= j <= flen
Expand All @@ -5489,17 +5504,17 @@ function _getfield_elim_pass!(e::Expr, sv::InferenceState)
end
end
end
elseif isa(e1, GlobalRef) || isa(e1, Symbol) || isa(e1, Slot) || isa(e1, SSAValue)
elseif isa(e1, GlobalRef) || isa(e1, Symbol) || isa(e1, Slot)
# non-self-quoting value
else
if isa(e1, QuoteNode)
e1 = e1.value
end
if isimmutable(e1) || isa(e1,SimpleVector)
if isimmutable(e1) || isa(e1, SimpleVector)
# SimpleVector length field is immutable
if isa(j, QuoteNode)
j = j.value
if !(isa(j,Int) || isa(j,Symbol))
if !(isa(j, Int) || isa(j, Symbol))
return e
end
end
Expand Down