Skip to content

Commit b30cd2c

Browse files
committed
perform inference using optimizer-derived type information
In certain cases, the optimizer can introduce new type information. This is particularly evident in SROA, where load forwarding can reveal type information that was not visible during abstract interpretation. In such cases, re-running abstract interpretation with the new type information can be highly valuable; currently, however, this only happens when semi-concrete interpretation happens to be triggered. This commit introduces a new "post-optimization inference" phase at the end of the optimizer pipeline: when the optimizer derives new type information, this phase performs IR abstract interpretation to further optimize the IR.
1 parent d604057 commit b30cd2c

File tree

9 files changed

+191
-69
lines changed

9 files changed

+191
-69
lines changed

Compiler/src/inferencestate.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,7 @@ mutable struct IRInterpretationState
825825
callstack #::Vector{AbsIntState}
826826
frameid::Int
827827
parentid::Int
828+
new_call_inferred::Bool
828829

829830
function IRInterpretationState(interp::AbstractInterpreter,
830831
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
@@ -850,7 +851,7 @@ mutable struct IRInterpretationState
850851
edges = Any[]
851852
callstack = AbsIntState[]
852853
return new(spec_info, ir, mi, WorldWithRange(world, valid_worlds), curridx, argtypes_refined, ir.sptypes, tpdum,
853-
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0)
854+
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0, #=new_call_inferred=#false)
854855
end
855856
end
856857

Compiler/src/optimize.jl

+52-12
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,7 @@ end
999999

10001000
# run the optimization work
10011001
function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
1002-
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt)
1002+
@timeit "optimizer" ir = run_passes_ipo_safe(interp, opt, caller)
10031003
ipo_dataflow_analysis!(interp, opt, ir, caller)
10041004
return finish(interp, opt, ir, caller)
10051005
end
@@ -1019,27 +1019,25 @@ matchpass(optimize_until::Int, stage, _) = optimize_until == stage
10191019
matchpass(optimize_until::String, _, name) = optimize_until == name
10201020
matchpass(::Nothing, _, _) = false
10211021

1022-
function run_passes_ipo_safe(
1023-
ci::CodeInfo,
1024-
sv::OptimizationState,
1025-
optimize_until = nothing, # run all passes by default
1026-
)
1022+
function run_passes_ipo_safe(interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult;
1023+
optimize_until = nothing) # run all passes by default
1024+
ci = sv.src
10271025
__stage__ = 0 # used by @pass
10281026
# NOTE: The pass name MUST be unique for `optimize_until::String` to work
10291027
@pass "convert" ir = convert_to_ircode(ci, sv)
10301028
@pass "slot2reg" ir = slot2reg(ir, ci, sv)
10311029
# TODO: Domsorting can produce an updated domtree - no need to recompute here
10321030
@pass "compact 1" ir = compact!(ir)
10331031
@pass "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
1034-
# @timeit "verify 2" verify_ir(ir)
10351032
@pass "compact 2" ir = compact!(ir)
10361033
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
1037-
@pass "ADCE" (ir, made_changes) = adce_pass!(ir, sv.inlining)
1038-
if made_changes
1039-
@pass "compact 3" ir = compact!(ir, true)
1040-
end
1034+
@pass "ADCE" ir, changed = adce_pass!(ir, sv.inlining)
1035+
@pass "compact 3" changed && (
1036+
ir = compact!(ir, true))
1037+
@pass "optinf" optinf_worthwhile(ir) && (
1038+
ir = optinf!(ir, interp, sv, result))
10411039
if is_asserts()
1042-
@timeit "verify 3" begin
1040+
@timeit "verify" begin
10431041
verify_ir(ir, true, false, optimizer_lattice(sv.inlining.interp), sv.linfo)
10441042
verify_linetable(ir.debuginfo, length(ir.stmts))
10451043
end
@@ -1048,6 +1046,48 @@ function run_passes_ipo_safe(
10481046
return ir
10491047
end
10501048

1049+
# If the optimizer derives new type information (as implied by `IR_FLAG_REFINED`),
1050+
# and this new type information is available for the arguments of a call expression,
1051+
# further optimizations may be possible by performing irinterp on the optimized IR.
1052+
# Decide whether running `optinf!` on `ir` is worth attempting.
# `ir` must already be compacted (no pending new nodes). Returns `true` as soon as a
# statement is found that both carries `IR_FLAG_REFINED` and is a `:call` expression;
# returns `false` when no statement satisfies both conditions.
function optinf_worthwhile(ir::IRCode)
    @assert isempty(ir.new_nodes) "expected compacted IRCode"
    nstmts = length(ir.stmts)
    for idx in 1:nstmts
        refined_inst = ir[SSAValue(idx)]
        # statements without the refined flag are never a trigger
        has_flag(refined_inst, IR_FLAG_REFINED) || continue
        # a refined statement only counts when it is a call expression
        isexpr(refined_inst[:stmt], :call) && return true
    end
    return false
end
1064+
1065+
# Post-optimization inference.
#
# Builds an `IRInterpretationState` for the already optimized `ir` and runs
# `ir_abstract_constant_propagation` on it. If that run inferred at least one call
# (`irsv.new_call_inferred`), the IR is inlined and compacted once more, and the
# refined information is written back into `result`:
#   - `result.effects` gains `nothrow` / `noub` when irinterp proved them,
#   - `result.exc_result` is refined against the updated effects,
#   - `result.result` is replaced by `rt` only when `rt` is strictly more precise
#     under the IPO lattice.
# Returns the (possibly re-inlined and re-compacted) `ir`.
function optinf!(ir::IRCode, interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult)
    ci = sv.src
    spec_info = SpecInfo(ci)
    world = get_inference_world(interp)
    min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
    irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, ir.argtypes,
                                 world, min_world, max_world)
    rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
    if irsv.new_call_inferred
        # newly inferred calls may have become inlineable
        ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
        ir = compact!(ir)
        effects = result.effects
        if nothrow
            effects = Effects(effects; nothrow=true)
        end
        if noub
            effects = Effects(effects; noub=ALWAYS_TRUE)
        end
        result.effects = effects
        result.exc_result = refine_exception_type(result.exc_result, effects)
        # only overwrite the cached return type when `rt` is a strict refinement
        ⋤ = strictneqpartialorder(ipo_lattice(interp))
        if rt ⋤ result.result
            result.result = rt
        end
    end
    return ir
end
1090+
10511091
function strip_trailing_junk!(code::Vector{Any}, ssavaluetypes::Vector{Any}, ssaflags::Vector, debuginfo::DebugInfoStream, cfg::CFG, info::Vector{CallInfo})
10521092
# Remove `nothing`s at the end, we don't handle them well
10531093
# (we expect the last instruction to be a terminator)

Compiler/src/ssair/EscapeAnalysis.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ using Base: # Base definitions
2424
isempty, length, max, min, missing, println, push!, pushfirst!,
2525
!, !==, &, *, +, -, :, <, <<, >, |, , , , , , , ,
2626
using ..Compiler: # Compiler specific definitions
27-
AbstractLattice, Compiler, IRCode, IR_FLAG_NOTHROW,
27+
@show, AbstractLattice, Compiler, IRCode, IR_FLAG_NOTHROW,
2828
argextype, fieldcount_noerror, has_flag, intrinsic_nothrow, is_meta_expr_head,
2929
is_identity_free_argtype, isexpr, setfield!_nothrow, singleton_type, try_compute_field,
30-
try_compute_fieldidx, widenconst
30+
try_compute_fieldidx, widenconst,
3131

3232
function include(x::String)
3333
if !isdefined(Base, :end_base_include)

Compiler/src/ssair/ir.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,8 @@ function reprocess_phi_node!(𝕃ₒ::AbstractLattice, compact::IncrementalCompa
17091709

17101710
# There's only one predecessor left - just replace it
17111711
v = phi.values[1]
1712-
if !⊑(𝕃ₒ, compact[compact.ssa_rename[old_idx]][:type], argextype(v, compact))
1712+
⋤ = strictneqpartialorder(𝕃ₒ)
1713+
if argextype(v, compact) ⋤ compact[compact.ssa_rename[old_idx]][:type]
17131714
v = Refined(v)
17141715
end
17151716
compact.ssa_rename[old_idx] = v

Compiler/src/ssair/irinterp.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sstate::St
5858
call = abstract_call(interp, arginfo, si, irsv)::Future
5959
Future{Any}(call, interp, irsv) do call, interp, irsv
6060
irsv.ir.stmts[irsv.curridx][:info] = call.info
61+
irsv.new_call_inferred |= true
6162
nothing
6263
end
6364
return call
@@ -204,7 +205,8 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
204205
# Handled at the very end
205206
return false
206207
elseif isa(stmt, PiNode)
207-
rt = tmeet(typeinf_lattice(interp), argextype(stmt.val, ir), widenconst(stmt.typ))
208+
= join(typeinf_lattice(interp))
209+
rt = argextype(stmt.val, ir) widenconst(stmt.typ)
208210
elseif stmt === nothing
209211
return false
210212
elseif isa(stmt, GlobalRef)
@@ -226,7 +228,9 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
226228
inst[:stmt] = quoted(rt.val)
227229
end
228230
return true
229-
elseif !⊑(typeinf_lattice(interp), inst[:type], rt)
231+
end
232+
⋤ = strictneqpartialorder(typeinf_lattice(interp))
233+
if rt ⋤ inst[:type]
230234
inst[:type] = rt
231235
return true
232236
end

Compiler/src/ssair/passes.jl

+26-10
Original file line numberDiff line numberDiff line change
@@ -989,9 +989,10 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
989989
lifted_leaves === nothing && return
990990

991991
result_t = Union{}
992+
⊔ = join(𝕃ₒ)
992993
for v in values(lifted_leaves)
993994
v === nothing && return
994-
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
995+
result_t = result_t ⊔ argextype(v.val, compact)
995996
end
996997

997998
(lifted_val, nest) = perform_lifting!(compact,
@@ -1001,8 +1002,12 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
10011002
compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
10021003
finish_phi_nest!(compact, nest)
10031004
if lifted_val !== nothing
1004-
if !⊑(𝕃ₒ, compact[SSAValue(idx)][:type], tuple_tfunc(𝕃ₒ, Any[result_t]))
1005-
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
1005+
stmttype = tuple_tfunc(𝕃ₒ, Any[result_t])
1006+
inst = compact[SSAValue(idx)]
1007+
⋤ = strictneqpartialorder(𝕃ₒ)
1008+
if stmttype ⋤ inst[:type]
1009+
inst[:type] = stmttype
1010+
add_flag!(inst, IR_FLAG_REFINED)
10061011
end
10071012
end
10081013

@@ -1440,19 +1445,23 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
14401445
lifted_leaves, any_undef = lifted_result
14411446

14421447
result_t = Union{}
1448+
⊔ = join(𝕃ₒ)
14431449
for v in values(lifted_leaves)
14441450
v === nothing && continue
1445-
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
1451+
result_t = result_t ⊔ argextype(v.val, compact)
14461452
end
14471453

14481454
(lifted_val, nest) = perform_lifting!(compact,
14491455
visited_philikes, field, result_t, lifted_leaves, val, lazydomtree)
14501456

14511457
should_delete_node = false
1452-
line = compact[SSAValue(idx)][:line]
1453-
if lifted_val !== nothing && !⊑(𝕃ₒ, compact[SSAValue(idx)][:type], result_t)
1458+
inst = compact[SSAValue(idx)]
1459+
line = inst[:line]
1460+
⋤ = strictneqpartialorder(𝕃ₒ)
1461+
if lifted_val !== nothing && result_t ⋤ inst[:type]
14541462
compact[idx] = lifted_val === nothing ? nothing : lifted_val.val
1455-
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
1463+
inst[:type] = result_t
1464+
add_flag!(inst, IR_FLAG_REFINED)
14561465
elseif lifted_val === nothing || isa(lifted_val.val, AnySSAValue)
14571466
# Save some work in a later compaction, by inserting this into the renamer now,
14581467
# but only do this if we didn't set the REFINED flag, to save work for irinterp
@@ -1855,9 +1864,15 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
18551864
for use in du.uses
18561865
if use.kind === :getfield
18571866
inst = ir[SSAValue(use.idx)]
1858-
inst[:stmt] = compute_value_for_use(ir, domtree, allblocks,
1867+
newvalue = compute_value_for_use(ir, domtree, allblocks,
18591868
du, phinodes, fidx, use.idx)
1860-
add_flag!(inst, IR_FLAG_REFINED)
1869+
inst[:stmt] = newvalue
1870+
newvaluetyp = argextype(newvalue, ir)
1871+
⋤ = strictneqpartialorder(𝕃ₒ)
1872+
if newvaluetyp ⋤ inst[:type]
1873+
inst[:type] = newvaluetyp
1874+
add_flag!(inst, IR_FLAG_REFINED)
1875+
end
18611876
elseif use.kind === :isdefined
18621877
continue # already rewritten if possible
18631878
elseif use.kind === :nopreserve
@@ -1878,11 +1893,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
18781893
for b in phiblocks
18791894
n = ir[phinodes[b]][:stmt]::PhiNode
18801895
result_t = Bottom
1896+
⊔ = join(𝕃ₒ)
18811897
for p in ir.cfg.blocks[b].preds
18821898
push!(n.edges, p)
18831899
v = compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, p)
18841900
push!(n.values, v)
1885-
result_t = tmerge(𝕃ₒ, result_t, argextype(v, ir))
1901+
result_t = result_t ⊔ argextype(v, ir)
18861902
end
18871903
ir[phinodes[b]][:type] = result_t
18881904
end

Compiler/src/typeinfer.jl

+57-35
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState;
121121
end
122122
inferred_result = nothing
123123
relocatability = 0x1
124-
const_flag = is_result_constabi_eligible(result)
125-
if !can_discard_trees || (is_cached(caller) && !const_flag)
124+
(; rettype, exctype, rettype_const, const_flags) = ResultForCache(result)
125+
if !can_discard_trees || (is_cached(caller) && iszero(const_flags & 0x1))
126126
inferred_result = transform_result_for_cache(interp, result)
127127
# TODO: do we want to augment edges here with any :invoke targets that we got from inlining (such that we didn't have a direct edge to it already)?
128128
relocatability = 0x0
@@ -145,9 +145,13 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState;
145145
if !@isdefined di
146146
di = DebugInfo(result.linfo)
147147
end
148-
ccall(:jl_update_codeinst, Cvoid, (Any, Any, Int32, UInt, UInt, UInt32, Any, UInt8, Any, Any),
149-
ci, inferred_result, const_flag, first(result.valid_worlds), last(result.valid_worlds), encode_effects(result.ipo_effects),
150-
result.analysis_results, relocatability, di, edges)
148+
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
149+
ipo_effects = encode_effects(result.ipo_effects)
150+
ccall(:jl_update_codeinst, Cvoid, (
151+
Any, Any, Any, Any, Any, Int32, UInt, UInt,
152+
UInt32, Any, UInt8, Any, Any),
153+
ci, inferred_result, rettype, exctype, rettype_const, const_flags, min_world, max_world,
154+
ipo_effects, result.analysis_results, relocatability, di, edges)
151155
engine_reject(interp, ci)
152156
end
153157
return nothing
@@ -451,38 +455,15 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
451455

452456
# finish populating inference results into the CodeInstance if possible, and maybe cache that globally for use elsewhere
453457
if isdefined(result, :ci)
454-
result_type = result.result
455-
result_type isa LimitedAccuracy && (result_type = result_type.typ)
456-
@assert !(result_type === nothing)
457-
if isa(result_type, Const)
458-
rettype_const = result_type.val
459-
const_flags = is_result_constabi_eligible(result) ? 0x3 : 0x2
460-
elseif isa(result_type, PartialOpaque)
461-
rettype_const = result_type
462-
const_flags = 0x2
463-
elseif isconstType(result_type)
464-
rettype_const = result_type.parameters[1]
465-
const_flags = 0x2
466-
elseif isa(result_type, PartialStruct)
467-
rettype_const = result_type.fields
468-
const_flags = 0x2
469-
elseif isa(result_type, InterConditional)
470-
rettype_const = result_type
471-
const_flags = 0x2
472-
elseif isa(result_type, InterMustAlias)
473-
rettype_const = result_type
474-
const_flags = 0x2
475-
else
476-
rettype_const = nothing
477-
const_flags = 0x0
478-
end
479-
relocatability = 0x0
458+
(; rettype, exctype, rettype_const, const_flags) = ResultForCache(result)
480459
di = nothing
481460
edges = empty_edges # `edges` will be updated within `finish!`
482461
ci = result.ci
483-
ccall(:jl_fill_codeinst, Cvoid, (Any, Any, Any, Any, Int32, UInt, UInt, UInt32, Any, Any, Any),
484-
ci, widenconst(result_type), widenconst(result.exc_result), rettype_const, const_flags,
485-
first(result.valid_worlds), last(result.valid_worlds),
462+
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
463+
ccall(:jl_fill_codeinst, Cvoid, (
464+
Any, Any, Any, Any, Int32, UInt, UInt,
465+
UInt32, Any, Any, Any),
466+
ci, rettype, exctype, rettype_const, const_flags, min_world, max_world,
486467
encode_effects(result.ipo_effects), result.analysis_results, di, edges)
487468
if is_cached(me)
488469
cached_result = cache_result!(me.interp, result, ci)
@@ -494,6 +475,46 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
494475
nothing
495476
end
496477

478+
# Plain record of the values extracted from an inference result before they are
# passed on to the code-instance cache.
#   rettype       -- return type
#   exctype       -- exception type
#   rettype_const -- constant payload associated with the return type, or `nothing`
#   const_flags   -- `UInt8` bit flags describing how `rettype_const` may be used
struct ResultForCache
    rettype
    exctype
    rettype_const
    const_flags::UInt8
    # The three untyped fields can hold arbitrary values, so the constructor is
    # marked `@nospecialize` on them to avoid compiling one instance per argument type.
    function ResultForCache(@nospecialize(rettype), @nospecialize(exctype),
                            @nospecialize(rettype_const), const_flags::UInt8)
        return new(rettype, exctype, rettype_const, const_flags)
    end
end
487+
# Build a `ResultForCache` from an `InferenceResult`.
# The inferred result is first unwrapped from `LimitedAccuracy` (if present); the
# return and exception types are then widened with `widenconst`, and the constant
# payload plus its flag bits are chosen according to the kind of lattice element:
#   Const                      -> its value, flags 0x3 if const-ABI eligible else 0x2
#   PartialOpaque              -> the element itself, flags 0x2
#   constant `Type{T}`         -> `T`, flags 0x2
#   PartialStruct              -> its `fields`, flags 0x2
#   InterConditional/MustAlias -> the element itself, flags 0x2
#   anything else              -> `nothing`, flags 0x0
@inline function ResultForCache(result::InferenceResult)
    result_type = result.result
    if result_type isa LimitedAccuracy
        result_type = result_type.typ
    end
    @assert !(result_type === nothing)
    rettype = widenconst(result_type)
    exctype = widenconst(result.exc_result)
    rettype_const, const_flags = if isa(result_type, Const)
        (result_type.val, is_result_constabi_eligible(result) ? 0x3 : 0x2)
    elseif isa(result_type, PartialOpaque)
        (result_type, 0x2)
    elseif isconstType(result_type)
        (result_type.parameters[1], 0x2)
    elseif isa(result_type, PartialStruct)
        (result_type.fields, 0x2)
    elseif isa(result_type, InterConditional)
        (result_type, 0x2)
    elseif isa(result_type, InterMustAlias)
        (result_type, 0x2)
    else
        (nothing, 0x0)
    end
    return ResultForCache(rettype, exctype, rettype_const, const_flags)
end
517+
497518
# record the backedges
498519
function store_backedges(caller::CodeInstance, edges::SimpleVector)
499520
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance
@@ -1000,7 +1021,7 @@ function typeinf_ircode(interp::AbstractInterpreter, mi::MethodInstance,
10001021
end
10011022
(; result) = frame
10021023
opt = OptimizationState(frame, interp)
1003-
ir = run_passes_ipo_safe(opt.src, opt, optimize_until)
1024+
ir = run_passes_ipo_safe(interp, opt, result; optimize_until)
10041025
rt = widenconst(ignorelimited(result.result))
10051026
return ir, rt
10061027
end
@@ -1025,6 +1046,7 @@ function typeinf_frame(interp::AbstractInterpreter, mi::MethodInstance, run_opti
10251046
opt = OptimizationState(frame, interp)
10261047
optimize(interp, opt, frame.result)
10271048
src = ir_to_codeinf!(opt)
1049+
src.rettype = widenconst(result.result)
10281050
end
10291051
result.src = frame.src = src
10301052
end

0 commit comments

Comments
 (0)