Skip to content

Commit 1895479

Browse files
committed
inference: avoid inferring unreachable code methods (#51317)
(cherry picked from commit 0a82b71)
1 parent 9103718 commit 1895479

10 files changed

+137
-120
lines changed

base/compiler/abstractinterpretation.jl

+35-22
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,10 @@ function abstract_call_method(interp::AbstractInterpreter,
506506
return MethodCallResult(Any, false, false, nothing, Effects())
507507
end
508508
sigtuple = unwrap_unionall(sig)
509-
sigtuple isa DataType || return MethodCallResult(Any, false, false, nothing, Effects())
509+
sigtuple isa DataType ||
510+
return MethodCallResult(Any, false, false, nothing, Effects())
511+
all(@nospecialize(x) -> valid_as_lattice(unwrapva(x), true), sigtuple.parameters) ||
512+
return MethodCallResult(Union{}, false, false, nothing, EFFECTS_THROWS) # catch bad type intersections early
510513

511514
if is_nospecializeinfer(method)
512515
sig = get_nospecializeinfer_sig(method, sig, sparams)
@@ -1385,25 +1388,35 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
13851388
end
13861389
if isa(tti, Union)
13871390
utis = uniontypes(tti)
1388-
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
1389-
return AbstractIterationResult(Any[Vararg{Any}], nothing, Effects())
1390-
end
1391-
ltp = length((utis[1]::DataType).parameters)
1392-
for t in utis
1393-
if length((t::DataType).parameters) != ltp
1394-
return AbstractIterationResult(Any[Vararg{Any}], nothing)
1391+
# refine the Union to remove elements that are not valid tags for objects
1392+
filter!(@nospecialize(x) -> valid_as_lattice(x, true), utis)
1393+
if length(utis) == 0
1394+
return AbstractIterationResult(Any[], nothing) # oops, this statement was actually unreachable
1395+
elseif length(utis) == 1
1396+
tti = utis[1]
1397+
tti0 = rewrap_unionall(tti, tti0)
1398+
else
1399+
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
1400+
return AbstractIterationResult(Any[Vararg{Any}], nothing, Effects())
13951401
end
1396-
end
1397-
result = Any[ Union{} for _ in 1:ltp ]
1398-
for t in utis
1399-
tps = (t::DataType).parameters
1400-
_all(valid_as_lattice, tps) || continue
1401-
for j in 1:ltp
1402-
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
1402+
ltp = length((utis[1]::DataType).parameters)
1403+
for t in utis
1404+
if length((t::DataType).parameters) != ltp
1405+
return AbstractIterationResult(Any[Vararg{Any}], nothing)
1406+
end
1407+
end
1408+
result = Any[ Union{} for _ in 1:ltp ]
1409+
for t in utis
1410+
tps = (t::DataType).parameters
1411+
for j in 1:ltp
1412+
@assert valid_as_lattice(tps[j], true)
1413+
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
1414+
end
14031415
end
1416+
return AbstractIterationResult(result, nothing)
14041417
end
1405-
return AbstractIterationResult(result, nothing)
1406-
elseif tti0 <: Tuple
1418+
end
1419+
if tti0 <: Tuple
14071420
if isa(tti0, DataType)
14081421
return AbstractIterationResult(Any[ p for p in tti0.parameters ], nothing)
14091422
elseif !isa(tti, DataType)
@@ -1667,7 +1680,7 @@ end
16671680
return isa_condition(xt, ty, max_union_splitting)
16681681
end
16691682
@inline function isa_condition(@nospecialize(xt), @nospecialize(ty), max_union_splitting::Int)
1670-
tty_ub, isexact_tty = instanceof_tfunc(ty)
1683+
tty_ub, isexact_tty = instanceof_tfunc(ty, true)
16711684
tty = widenconst(xt)
16721685
if isexact_tty && !isa(tty_ub, TypeVar)
16731686
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
@@ -1677,7 +1690,7 @@ end
16771690
# `typeintersect` may be unable narrow down `Type`-type
16781691
thentype = tty_ub
16791692
end
1680-
valid_as_lattice(thentype) || (thentype = Bottom)
1693+
valid_as_lattice(thentype, true) || (thentype = Bottom)
16811694
elsetype = typesubtract(tty, tty_lb, max_union_splitting)
16821695
return ConditionalTypes(thentype, elsetype)
16831696
end
@@ -1923,7 +1936,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
19231936
ft′ = argtype_by_index(argtypes, 2)
19241937
ft = widenconst(ft′)
19251938
ft === Bottom && return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo())
1926-
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3))
1939+
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3), false)
19271940
isexact || return CallMeta(Any, Effects(), NoCallInfo())
19281941
unwrapped = unwrap_unionall(types)
19291942
if types === Bottom || !(unwrapped isa DataType) || unwrapped.name !== Tuple.name
@@ -2380,7 +2393,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
23802393
(; rt, effects) = abstract_eval_call(interp, e, vtypes, sv)
23812394
t = rt
23822395
elseif ehead === :new
2383-
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))
2396+
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv), true)
23842397
ut = unwrap_unionall(t)
23852398
consistent = ALWAYS_FALSE
23862399
nothrow = false
@@ -2444,7 +2457,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
24442457
end
24452458
effects = Effects(EFFECTS_TOTAL; consistent, nothrow)
24462459
elseif ehead === :splatnew
2447-
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))
2460+
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv), true)
24482461
nothrow = false # TODO: More precision
24492462
if length(e.args) == 2 && isconcretedispatch(t) && !ismutabletype(t)
24502463
at = abstract_eval_value(interp, e.args[2], vtypes, sv)

base/compiler/abstractlattice.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,18 @@ is_valid_lattice_norec(::InferenceLattice, @nospecialize(elem)) = isa(elem, Limi
9898
"""
9999
tmeet(𝕃::AbstractLattice, a, b::Type)
100100
101-
Compute the lattice meet of lattice elements `a` and `b` over the lattice `𝕃`.
102-
If `𝕃` is `JLTypeLattice`, this is equivalent to type intersection.
101+
Compute the lattice meet of lattice elements `a` and `b` over the lattice `𝕃`,
102+
dropping any results that will not be inhabited at runtime.
103+
If `𝕃` is `JLTypeLattice`, this is equivalent to type intersection plus the
104+
elimination of results that have no concrete subtypes.
103105
Note that currently `b` is restricted to being a type
104106
(interpreted as a lattice element in the `JLTypeLattice` sub-lattice of `𝕃`).
105107
"""
106108
function tmeet end
107109

108110
function tmeet(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type))
109111
ti = typeintersect(a, b)
110-
valid_as_lattice(ti) || return Bottom
112+
valid_as_lattice(ti, true) || return Bottom
111113
return ti
112114
end
113115

base/compiler/optimize.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
305305
elseif head === :new_opaque_closure
306306
length(args) < 4 && return (false, false, false)
307307
typ = argextype(args[1], src)
308-
typ, isexact = instanceof_tfunc(typ)
308+
typ, isexact = instanceof_tfunc(typ, true)
309309
isexact || return (false, false, false)
310310
(𝕃ₒ, typ, Tuple) || return (false, false, false)
311311
rt_lb = argextype(args[2], src)

base/compiler/ssair/inlining.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,7 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}},
12021202
end
12031203

12041204
function invoke_signature(argtypes::Vector{Any})
1205-
ft, argtyps = widenconst(argtypes[2]), instanceof_tfunc(widenconst(argtypes[3]))[1]
1205+
ft, argtyps = widenconst(argtypes[2]), instanceof_tfunc(widenconst(argtypes[3]), false)[1]
12061206
return rewrap_unionall(Tuple{ft, unwrap_unionall(argtyps).parameters...}, argtyps)
12071207
end
12081208

base/compiler/ssair/passes.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1705,7 +1705,7 @@ function adce_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
17051705
else
17061706
if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3
17071707
# nullify safe `typeassert` calls
1708-
ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact))
1708+
ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact), true)
17091709
if isexact && (𝕃ₒ, argextype(stmt.args[2], compact), ty)
17101710
compact[idx] = nothing
17111711
continue

0 commit comments

Comments
 (0)