Skip to content

Commit 1bf7837

Browse files
committed
perform inference using optimizer-derived type information
In certain cases, the optimizer can introduce new type information. This is particularly evident in SROA, where load forwarding can reveal type information that was not visible during abstract interpretation. In such cases, re-running abstract interpretation using this new type information can be highly valuable, however, currently, this only occurs when semi-concrete interpretation happens to be triggered. This commit introduces a new "post-optimization inference" phase at the end of the optimizer pipeline. When the optimizer derives new type information, this phase performs IR abstract interpretation to further optimize the IR.
1 parent e3f90c4 commit 1bf7837

File tree

8 files changed

+108
-30
lines changed

8 files changed

+108
-30
lines changed

Compiler/src/inferencestate.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,7 @@ mutable struct IRInterpretationState
825825
callstack #::Vector{AbsIntState}
826826
frameid::Int
827827
parentid::Int
828+
new_call_inferred::Bool
828829

829830
function IRInterpretationState(interp::AbstractInterpreter,
830831
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
@@ -850,7 +851,7 @@ mutable struct IRInterpretationState
850851
edges = Any[]
851852
callstack = AbsIntState[]
852853
return new(spec_info, ir, mi, WorldWithRange(world, valid_worlds), curridx, argtypes_refined, ir.sptypes, tpdum,
853-
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0)
854+
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0, #=new_call_inferred=#false)
854855
end
855856
end
856857

Compiler/src/optimize.jl

+52-12
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,7 @@ end
999999

10001000
# run the optimization work
10011001
function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
1002-
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt)
1002+
@timeit "optimizer" ir = run_passes_ipo_safe(interp, opt, caller)
10031003
ipo_dataflow_analysis!(interp, opt, ir, caller)
10041004
return finish(interp, opt, ir, caller)
10051005
end
@@ -1019,27 +1019,25 @@ matchpass(optimize_until::Int, stage, _) = optimize_until == stage
10191019
matchpass(optimize_until::String, _, name) = optimize_until == name
10201020
matchpass(::Nothing, _, _) = false
10211021

1022-
function run_passes_ipo_safe(
1023-
ci::CodeInfo,
1024-
sv::OptimizationState,
1025-
optimize_until = nothing, # run all passes by default
1026-
)
1022+
function run_passes_ipo_safe(interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult;
1023+
optimize_until = nothing) # run all passes by default
1024+
ci = sv.src
10271025
__stage__ = 0 # used by @pass
10281026
# NOTE: The pass name MUST be unique for `optimize_until::AbstractString` to work
10291027
@pass "convert" ir = convert_to_ircode(ci, sv)
10301028
@pass "slot2reg" ir = slot2reg(ir, ci, sv)
10311029
# TODO: Domsorting can produce an updated domtree - no need to recompute here
10321030
@pass "compact 1" ir = compact!(ir)
10331031
@pass "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
1034-
# @timeit "verify 2" verify_ir(ir)
10351032
@pass "compact 2" ir = compact!(ir)
10361033
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
1037-
@pass "ADCE" (ir, made_changes) = adce_pass!(ir, sv.inlining)
1038-
if made_changes
1039-
@pass "compact 3" ir = compact!(ir, true)
1040-
end
1034+
@pass "ADCE" ir, changed = adce_pass!(ir, sv.inlining)
1035+
@pass "compact 3" changed && (
1036+
ir = compact!(ir, true))
1037+
@pass "optinf" optinf_worthwhile(ir) && (
1038+
ir = optinf!(ir, interp, sv, result))
10411039
if is_asserts()
1042-
@timeit "verify 3" begin
1040+
@timeit "verify" begin
10431041
verify_ir(ir, true, false, optimizer_lattice(sv.inlining.interp), sv.linfo)
10441042
verify_linetable(ir.debuginfo, length(ir.stmts))
10451043
end
@@ -1048,6 +1046,48 @@ function run_passes_ipo_safe(
10481046
return ir
10491047
end
10501048

1049+
# If the optimizer derives new type information (as implied by `IR_FLAG_REFINED`),
1050+
# and this new type information is available for the arguments of a call expression,
1051+
# further optimizations may be possible by performing irinterp on the optimized IR.
1052+
function optinf_worthwhile(ir::IRCode)
1053+
@assert isempty(ir.new_nodes) "expected compacted IRCode"
1054+
for i = 1:length(ir.stmts)
1055+
if has_flag(ir[SSAValue(i)], IR_FLAG_REFINED)
1056+
stmt = ir[SSAValue(i)][:stmt]
1057+
if isexpr(stmt, :call)
1058+
return true
1059+
end
1060+
end
1061+
end
1062+
return false
1063+
end
1064+
1065+
function optinf!(ir::IRCode, interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult)
1066+
ci = sv.src
1067+
spec_info = SpecInfo(ci)
1068+
world = get_inference_world(interp)
1069+
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
1070+
irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, ir.argtypes,
1071+
world, min_world, max_world)
1072+
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
1073+
if irsv.new_call_inferred
1074+
ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
1075+
ir = compact!(ir)
1076+
effects = result.effects
1077+
if nothrow
1078+
effects = Effects(effects; nothrow=true)
1079+
end
1080+
if noub
1081+
effects = Effects(effects; noub=ALWAYS_TRUE)
1082+
end
1083+
result.effects = effects
1084+
result.exc_result = refine_exception_type(result.exc_result, effects)
1085+
= strictneqpartialorder(ipo_lattice(interp))
1086+
result.result = rt result.result ? rt : result.result
1087+
end
1088+
return ir
1089+
end
1090+
10511091
function strip_trailing_junk!(code::Vector{Any}, ssavaluetypes::Vector{Any}, ssaflags::Vector, debuginfo::DebugInfoStream, cfg::CFG, info::Vector{CallInfo})
10521092
# Remove `nothing`s at the end, we don't handle them well
10531093
# (we expect the last instruction to be a terminator)

Compiler/src/ssair/EscapeAnalysis.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ using Base: # Base definitions
2424
isempty, length, max, min, missing, println, push!, pushfirst!,
2525
!, !==, &, *, +, -, :, <, <<, >, |, , , , , , , ,
2626
using ..Compiler: # Compiler specific definitions
27-
AbstractLattice, Compiler, IRCode, IR_FLAG_NOTHROW,
27+
@show, AbstractLattice, Compiler, IRCode, IR_FLAG_NOTHROW,
2828
argextype, fieldcount_noerror, has_flag, intrinsic_nothrow, is_meta_expr_head,
2929
is_identity_free_argtype, isexpr, setfield!_nothrow, singleton_type, try_compute_field,
30-
try_compute_fieldidx, widenconst
30+
try_compute_fieldidx, widenconst,
3131

3232
function include(x::String)
3333
if !isdefined(Base, :end_base_include)

Compiler/src/ssair/ir.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,8 @@ function reprocess_phi_node!(𝕃ₒ::AbstractLattice, compact::IncrementalCompa
17091709

17101710
# There's only one predecessor left - just replace it
17111711
v = phi.values[1]
1712-
if !(𝕃ₒ, compact[compact.ssa_rename[old_idx]][:type], argextype(v, compact))
1712+
= strictneqpartialorder(𝕃ₒ)
1713+
if argextype(v, compact) compact[compact.ssa_rename[old_idx]][:type]
17131714
v = Refined(v)
17141715
end
17151716
compact.ssa_rename[old_idx] = v

Compiler/src/ssair/irinterp.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sstate::St
5858
call = abstract_call(interp, arginfo, si, irsv)::Future
5959
Future{Any}(call, interp, irsv) do call, interp, irsv
6060
irsv.ir.stmts[irsv.curridx][:info] = call.info
61+
irsv.new_call_inferred |= true
6162
nothing
6263
end
6364
return call
@@ -204,7 +205,8 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
204205
# Handled at the very end
205206
return false
206207
elseif isa(stmt, PiNode)
207-
rt = tmeet(typeinf_lattice(interp), argextype(stmt.val, ir), widenconst(stmt.typ))
208+
= join(typeinf_lattice(interp))
209+
rt = argextype(stmt.val, ir) widenconst(stmt.typ)
208210
elseif stmt === nothing
209211
return false
210212
elseif isa(stmt, GlobalRef)
@@ -226,7 +228,9 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
226228
inst[:stmt] = quoted(rt.val)
227229
end
228230
return true
229-
elseif !(typeinf_lattice(interp), inst[:type], rt)
231+
end
232+
= strictneqpartialorder(typeinf_lattice(interp))
233+
if rt inst[:type]
230234
inst[:type] = rt
231235
return true
232236
end

Compiler/src/ssair/passes.jl

+26-10
Original file line numberDiff line numberDiff line change
@@ -989,9 +989,10 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
989989
lifted_leaves === nothing && return
990990

991991
result_t = Union{}
992+
= join(𝕃ₒ)
992993
for v in values(lifted_leaves)
993994
v === nothing && return
994-
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
995+
result_t = result_t argextype(v.val, compact)
995996
end
996997

997998
(lifted_val, nest) = perform_lifting!(compact,
@@ -1001,8 +1002,12 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
10011002
compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
10021003
finish_phi_nest!(compact, nest)
10031004
if lifted_val !== nothing
1004-
if !(𝕃ₒ, compact[SSAValue(idx)][:type], tuple_tfunc(𝕃ₒ, Any[result_t]))
1005-
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
1005+
stmttype = tuple_tfunc(𝕃ₒ, Any[result_t])
1006+
inst = compact[SSAValue(idx)]
1007+
= strictneqpartialorder(𝕃ₒ)
1008+
if stmttype inst[:type]
1009+
inst[:type] = stmttype
1010+
add_flag!(inst, IR_FLAG_REFINED)
10061011
end
10071012
end
10081013

@@ -1440,19 +1445,23 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
14401445
lifted_leaves, any_undef = lifted_result
14411446

14421447
result_t = Union{}
1448+
= join(𝕃ₒ)
14431449
for v in values(lifted_leaves)
14441450
v === nothing && continue
1445-
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
1451+
result_t = result_t argextype(v.val, compact)
14461452
end
14471453

14481454
(lifted_val, nest) = perform_lifting!(compact,
14491455
visited_philikes, field, result_t, lifted_leaves, val, lazydomtree)
14501456

14511457
should_delete_node = false
1452-
line = compact[SSAValue(idx)][:line]
1453-
if lifted_val !== nothing && !(𝕃ₒ, compact[SSAValue(idx)][:type], result_t)
1458+
inst = compact[SSAValue(idx)]
1459+
line = inst[:line]
1460+
= strictneqpartialorder(𝕃ₒ)
1461+
if lifted_val !== nothing && result_t inst[:type]
14541462
compact[idx] = lifted_val === nothing ? nothing : lifted_val.val
1455-
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
1463+
inst[:type] = result_t
1464+
add_flag!(inst, IR_FLAG_REFINED)
14561465
elseif lifted_val === nothing || isa(lifted_val.val, AnySSAValue)
14571466
# Save some work in a later compaction, by inserting this into the renamer now,
14581467
# but only do this if we didn't set the REFINED flag, to save work for irinterp
@@ -1855,9 +1864,15 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
18551864
for use in du.uses
18561865
if use.kind === :getfield
18571866
inst = ir[SSAValue(use.idx)]
1858-
inst[:stmt] = compute_value_for_use(ir, domtree, allblocks,
1867+
newvalue = compute_value_for_use(ir, domtree, allblocks,
18591868
du, phinodes, fidx, use.idx)
1860-
add_flag!(inst, IR_FLAG_REFINED)
1869+
inst[:stmt] = newvalue
1870+
newvaluetyp = argextype(newvalue, ir)
1871+
= strictneqpartialorder(𝕃ₒ)
1872+
if newvaluetyp inst[:type]
1873+
inst[:type] = newvaluetyp
1874+
add_flag!(inst, IR_FLAG_REFINED)
1875+
end
18611876
elseif use.kind === :isdefined
18621877
continue # already rewritten if possible
18631878
elseif use.kind === :nopreserve
@@ -1878,11 +1893,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
18781893
for b in phiblocks
18791894
n = ir[phinodes[b]][:stmt]::PhiNode
18801895
result_t = Bottom
1896+
= join(𝕃ₒ)
18811897
for p in ir.cfg.blocks[b].preds
18821898
push!(n.edges, p)
18831899
v = compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, p)
18841900
push!(n.values, v)
1885-
result_t = tmerge(𝕃ₒ, result_t, argextype(v, ir))
1901+
result_t = result_t argextype(v, ir)
18861902
end
18871903
ir[phinodes[b]][:type] = result_t
18881904
end

Compiler/src/typeinfer.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,7 @@ function typeinf_ircode(interp::AbstractInterpreter, mi::MethodInstance,
10001000
end
10011001
(; result) = frame
10021002
opt = OptimizationState(frame, interp)
1003-
ir = run_passes_ipo_safe(opt.src, opt, optimize_until)
1003+
ir = run_passes_ipo_safe(interp, opt, result; optimize_until)
10041004
rt = widenconst(ignorelimited(result.result))
10051005
return ir, rt
10061006
end
@@ -1025,6 +1025,7 @@ function typeinf_frame(interp::AbstractInterpreter, mi::MethodInstance, run_opti
10251025
opt = OptimizationState(frame, interp)
10261026
optimize(interp, opt, frame.result)
10271027
src = ir_to_codeinf!(opt)
1028+
src.rettype = widenconst(result.result)
10281029
end
10291030
result.src = frame.src = src
10301031
end

Compiler/test/inference.jl

+16-1
Original file line numberDiff line numberDiff line change
@@ -3494,7 +3494,7 @@ f31974(n::Int) = f31974(1:n)
34943494
@test code_typed(f31974, Tuple{Int}) !== nothing
34953495

34963496
f_overly_abstract_complex() = Complex(Ref{Number}(1)[])
3497-
@test Base.return_types(f_overly_abstract_complex, Tuple{}) == [Complex]
3497+
@test Base.infer_return_type(f_overly_abstract_complex, Tuple{}) == Complex{Int}
34983498

34993499
# Issue 26724
35003500
const IntRange = AbstractUnitRange{<:Integer}
@@ -6155,3 +6155,18 @@ end <: Any
61556155
end
61566156
return out
61576157
end == Union{Float64,DomainError}
6158+
6159+
# opt inf
6160+
@test Base.infer_return_type((Vector{Any},)) do argtypes
6161+
box = Core.Box()
6162+
box.contents = argtypes
6163+
return length(box.contents)
6164+
end == Int
6165+
@test Base.infer_return_type((Vector{Any},)) do argtypes
6166+
local argtypesi
6167+
function cls()
6168+
argtypesi = @noinline copy(argtypes)
6169+
return length(argtypesi)
6170+
end
6171+
return @inline cls()
6172+
end == Int

0 commit comments

Comments
 (0)