Skip to content

Commit 44b7d1e

Browse files
committed
perform inference using optimizer-derived type information
In certain cases, the optimizer can introduce new type information. This is particularly evident in SROA, where load forwarding can reveal type information that was not visible during abstract interpretation. In such cases, re-running abstract interpretation using this new type information can be highly valuable, however, currently, this only occurs when semi-concrete interpretation happens to be triggered. This commit introduces a new "post-optimization inference" phase at the end of the optimizer pipeline. When the optimizer derives new type information, this phase performs IR abstract interpretation to further optimize the IR.
1 parent 3b629f1 commit 44b7d1e

File tree

9 files changed

+192
-70
lines changed

9 files changed

+192
-70
lines changed

Compiler/src/inferencestate.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,7 @@ mutable struct IRInterpretationState
825825
callstack #::Vector{AbsIntState}
826826
frameid::Int
827827
parentid::Int
828+
new_call_inferred::Bool
828829

829830
function IRInterpretationState(interp::AbstractInterpreter,
830831
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
@@ -850,7 +851,7 @@ mutable struct IRInterpretationState
850851
edges = Any[]
851852
callstack = AbsIntState[]
852853
return new(spec_info, ir, mi, WorldWithRange(world, valid_worlds), curridx, argtypes_refined, ir.sptypes, tpdum,
853-
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0)
854+
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0, #=new_call_inferred=#false)
854855
end
855856
end
856857

Compiler/src/optimize.jl

+52-12
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,7 @@ end
999999

10001000
# run the optimization work
10011001
function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
1002-
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt)
1002+
@timeit "optimizer" ir = run_passes_ipo_safe(interp, opt, caller)
10031003
ipo_dataflow_analysis!(interp, opt, ir, caller)
10041004
return finish(interp, opt, ir, caller)
10051005
end
@@ -1019,27 +1019,25 @@ matchpass(optimize_until::Int, stage, _) = optimize_until == stage
10191019
matchpass(optimize_until::String, _, name) = optimize_until == name
10201020
matchpass(::Nothing, _, _) = false
10211021

1022-
function run_passes_ipo_safe(
1023-
ci::CodeInfo,
1024-
sv::OptimizationState,
1025-
optimize_until = nothing, # run all passes by default
1026-
)
1022+
function run_passes_ipo_safe(interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult;
1023+
optimize_until = nothing) # run all passes by default
1024+
ci = sv.src
10271025
__stage__ = 0 # used by @pass
10281026
# NOTE: The pass name MUST be unique for `optimize_until::String` to work
10291027
@pass "convert" ir = convert_to_ircode(ci, sv)
10301028
@pass "slot2reg" ir = slot2reg(ir, ci, sv)
10311029
# TODO: Domsorting can produce an updated domtree - no need to recompute here
10321030
@pass "compact 1" ir = compact!(ir)
10331031
@pass "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
1034-
# @timeit "verify 2" verify_ir(ir)
10351032
@pass "compact 2" ir = compact!(ir)
10361033
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
1037-
@pass "ADCE" (ir, made_changes) = adce_pass!(ir, sv.inlining)
1038-
if made_changes
1039-
@pass "compact 3" ir = compact!(ir, true)
1040-
end
1034+
@pass "ADCE" ir, changed = adce_pass!(ir, sv.inlining)
1035+
@pass "compact 3" changed && (
1036+
ir = compact!(ir, true))
1037+
@pass "optinf" optinf_worthwhile(ir) && (
1038+
ir = optinf!(ir, interp, sv, result))
10411039
if is_asserts()
1042-
@timeit "verify 3" begin
1040+
@timeit "verify" begin
10431041
verify_ir(ir, true, false, optimizer_lattice(sv.inlining.interp), sv.linfo)
10441042
verify_linetable(ir.debuginfo, length(ir.stmts))
10451043
end
@@ -1048,6 +1046,48 @@ function run_passes_ipo_safe(
10481046
return ir
10491047
end
10501048

1049+
# If the optimizer derives new type information (as implied by `IR_FLAG_REFINED`),
1050+
# and this new type information is available for the arguments of a call expression,
1051+
# further optimizations may be possible by performing irinterp on the optimized IR.
1052+
function optinf_worthwhile(ir::IRCode)
1053+
@assert isempty(ir.new_nodes) "expected compacted IRCode"
1054+
for i = 1:length(ir.stmts)
1055+
inst = ir[SSAValue(i)]
1056+
if has_flag(inst, IR_FLAG_REFINED)
1057+
if isexpr(inst[:stmt], :call)
1058+
return true
1059+
end
1060+
end
1061+
end
1062+
return false
1063+
end
1064+
1065+
function optinf!(ir::IRCode, interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult)
1066+
ci = sv.src
1067+
spec_info = SpecInfo(ci)
1068+
world = get_inference_world(interp)
1069+
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
1070+
irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, ir.argtypes,
1071+
world, min_world, max_world)
1072+
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
1073+
if irsv.new_call_inferred
1074+
ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
1075+
ir = compact!(ir)
1076+
effects = result.effects
1077+
if nothrow
1078+
effects = Effects(effects; nothrow=true)
1079+
end
1080+
if noub
1081+
effects = Effects(effects; noub=ALWAYS_TRUE)
1082+
end
1083+
result.effects = effects
1084+
result.exc_result = refine_exception_type(result.exc_result, effects)
1085+
= strictneqpartialorder(ipo_lattice(interp))
1086+
result.result = rt result.result ? rt : result.result
1087+
end
1088+
return ir
1089+
end
1090+
10511091
function strip_trailing_junk!(code::Vector{Any}, ssavaluetypes::Vector{Any}, ssaflags::Vector, debuginfo::DebugInfoStream, cfg::CFG, info::Vector{CallInfo})
10521092
# Remove `nothing`s at the end, we don't handle them well
10531093
# (we expect the last instruction to be a terminator)

Compiler/src/ssair/EscapeAnalysis.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ using Base: # Base definitions
2424
isempty, length, max, min, missing, println, push!, pushfirst!,
2525
!, !==, &, *, +, -, :, <, <<, >, |, , , , , , , ,
2626
using ..Compiler: # Compiler specific definitions
27-
AbstractLattice, Compiler, IRCode, IR_FLAG_NOTHROW,
27+
@show, AbstractLattice, Compiler, IRCode, IR_FLAG_NOTHROW,
2828
argextype, fieldcount_noerror, has_flag, intrinsic_nothrow, is_meta_expr_head,
2929
is_identity_free_argtype, isexpr, setfield!_nothrow, singleton_type, try_compute_field,
30-
try_compute_fieldidx, widenconst
30+
try_compute_fieldidx, widenconst,
3131

3232
function include(x::String)
3333
if !isdefined(Base, :end_base_include)

Compiler/src/ssair/ir.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,8 @@ function reprocess_phi_node!(𝕃ₒ::AbstractLattice, compact::IncrementalCompa
17091709

17101710
# There's only one predecessor left - just replace it
17111711
v = phi.values[1]
1712-
if !(𝕃ₒ, compact[compact.ssa_rename[old_idx]][:type], argextype(v, compact))
1712+
= strictneqpartialorder(𝕃ₒ)
1713+
if argextype(v, compact) compact[compact.ssa_rename[old_idx]][:type]
17131714
v = Refined(v)
17141715
end
17151716
compact.ssa_rename[old_idx] = v

Compiler/src/ssair/irinterp.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sstate::St
5858
call = abstract_call(interp, arginfo, si, irsv)::Future
5959
Future{Any}(call, interp, irsv) do call, interp, irsv
6060
irsv.ir.stmts[irsv.curridx][:info] = call.info
61+
irsv.new_call_inferred |= true
6162
nothing
6263
end
6364
return call
@@ -204,7 +205,8 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
204205
# Handled at the very end
205206
return false
206207
elseif isa(stmt, PiNode)
207-
rt = tmeet(typeinf_lattice(interp), argextype(stmt.val, ir), widenconst(stmt.typ))
208+
= join(typeinf_lattice(interp))
209+
rt = argextype(stmt.val, ir) widenconst(stmt.typ)
208210
elseif stmt === nothing
209211
return false
210212
elseif isa(stmt, GlobalRef)
@@ -226,7 +228,9 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
226228
inst[:stmt] = quoted(rt.val)
227229
end
228230
return true
229-
elseif !(typeinf_lattice(interp), inst[:type], rt)
231+
end
232+
= strictneqpartialorder(typeinf_lattice(interp))
233+
if rt inst[:type]
230234
inst[:type] = rt
231235
return true
232236
end

Compiler/src/ssair/passes.jl

+26-10
Original file line numberDiff line numberDiff line change
@@ -989,9 +989,10 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
989989
lifted_leaves === nothing && return
990990

991991
result_t = Union{}
992+
= join(𝕃ₒ)
992993
for v in values(lifted_leaves)
993994
v === nothing && return
994-
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
995+
result_t = result_t argextype(v.val, compact)
995996
end
996997

997998
(lifted_val, nest) = perform_lifting!(compact,
@@ -1001,8 +1002,12 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
10011002
compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
10021003
finish_phi_nest!(compact, nest)
10031004
if lifted_val !== nothing
1004-
if !(𝕃ₒ, compact[SSAValue(idx)][:type], tuple_tfunc(𝕃ₒ, Any[result_t]))
1005-
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
1005+
stmttype = tuple_tfunc(𝕃ₒ, Any[result_t])
1006+
inst = compact[SSAValue(idx)]
1007+
= strictneqpartialorder(𝕃ₒ)
1008+
if stmttype inst[:type]
1009+
inst[:type] = stmttype
1010+
add_flag!(inst, IR_FLAG_REFINED)
10061011
end
10071012
end
10081013

@@ -1440,19 +1445,23 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
14401445
lifted_leaves, any_undef = lifted_result
14411446

14421447
result_t = Union{}
1448+
= join(𝕃ₒ)
14431449
for v in values(lifted_leaves)
14441450
v === nothing && continue
1445-
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
1451+
result_t = result_t argextype(v.val, compact)
14461452
end
14471453

14481454
(lifted_val, nest) = perform_lifting!(compact,
14491455
visited_philikes, field, result_t, lifted_leaves, val, lazydomtree)
14501456

14511457
should_delete_node = false
1452-
line = compact[SSAValue(idx)][:line]
1453-
if lifted_val !== nothing && !(𝕃ₒ, compact[SSAValue(idx)][:type], result_t)
1458+
inst = compact[SSAValue(idx)]
1459+
line = inst[:line]
1460+
= strictneqpartialorder(𝕃ₒ)
1461+
if lifted_val !== nothing && result_t inst[:type]
14541462
compact[idx] = lifted_val === nothing ? nothing : lifted_val.val
1455-
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
1463+
inst[:type] = result_t
1464+
add_flag!(inst, IR_FLAG_REFINED)
14561465
elseif lifted_val === nothing || isa(lifted_val.val, AnySSAValue)
14571466
# Save some work in a later compaction, by inserting this into the renamer now,
14581467
# but only do this if we didn't set the REFINED flag, to save work for irinterp
@@ -1855,9 +1864,15 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
18551864
for use in du.uses
18561865
if use.kind === :getfield
18571866
inst = ir[SSAValue(use.idx)]
1858-
inst[:stmt] = compute_value_for_use(ir, domtree, allblocks,
1867+
newvalue = compute_value_for_use(ir, domtree, allblocks,
18591868
du, phinodes, fidx, use.idx)
1860-
add_flag!(inst, IR_FLAG_REFINED)
1869+
inst[:stmt] = newvalue
1870+
newvaluetyp = argextype(newvalue, ir)
1871+
= strictneqpartialorder(𝕃ₒ)
1872+
if newvaluetyp inst[:type]
1873+
inst[:type] = newvaluetyp
1874+
add_flag!(inst, IR_FLAG_REFINED)
1875+
end
18611876
elseif use.kind === :isdefined
18621877
continue # already rewritten if possible
18631878
elseif use.kind === :nopreserve
@@ -1878,11 +1893,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
18781893
for b in phiblocks
18791894
n = ir[phinodes[b]][:stmt]::PhiNode
18801895
result_t = Bottom
1896+
= join(𝕃ₒ)
18811897
for p in ir.cfg.blocks[b].preds
18821898
push!(n.edges, p)
18831899
v = compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, p)
18841900
push!(n.values, v)
1885-
result_t = tmerge(𝕃ₒ, result_t, argextype(v, ir))
1901+
result_t = result_t argextype(v, ir)
18861902
end
18871903
ir[phinodes[b]][:type] = result_t
18881904
end

Compiler/src/typeinfer.jl

+58-36
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,9 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
118118
# we can now widen our applicability in the global cache too
119119
store_backedges(ci, edges)
120120
end
121-
inferred_result = nothing
122-
uncompressed = inferred_result
123-
const_flag = is_result_constabi_eligible(result)
124-
discard_src = caller.cache_mode === CACHE_MODE_NULL || const_flag
121+
uncompressed = inferred_result = nothing
122+
(; rettype, exctype, rettype_const, const_flags) = ResultForCache(result)
123+
discard_src = caller.cache_mode === CACHE_MODE_NULL || is_result_constabi_eligible(result)
125124
if !discard_src
126125
inferred_result = transform_result_for_cache(interp, result)
127126
# TODO: do we want to augment edges here with any :invoke targets that we got from inlining (such that we didn't have a direct edge to it already)?
@@ -143,9 +142,13 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
143142
if !@isdefined di
144143
di = DebugInfo(result.linfo)
145144
end
146-
ccall(:jl_update_codeinst, Cvoid, (Any, Any, Int32, UInt, UInt, UInt32, Any, Any, Any),
147-
ci, inferred_result, const_flag, first(result.valid_worlds), last(result.valid_worlds), encode_effects(result.ipo_effects),
148-
result.analysis_results, di, edges)
145+
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
146+
ipo_effects = encode_effects(result.ipo_effects)
147+
ccall(:jl_update_codeinst, Cvoid, (
148+
Any, Any, Any, Any, Any, Int32, UInt, UInt,
149+
UInt32, Any, Any, Any),
150+
ci, inferred_result, rettype, exctype, rettype_const, const_flags, min_world, max_world,
151+
ipo_effects, result.analysis_results, di, edges)
149152
engine_reject(interp, ci)
150153
if !discard_src && isdefined(interp, :codegen) && uncompressed isa CodeInfo
151154
# record that the caller could use this result to generate code when required, if desired, to avoid repeating n^2 work
@@ -488,37 +491,15 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
488491

489492
# finish populating inference results into the CodeInstance if possible, and maybe cache that globally for use elsewhere
490493
if isdefined(result, :ci)
491-
result_type = result.result
492-
result_type isa LimitedAccuracy && (result_type = result_type.typ)
493-
@assert !(result_type === nothing)
494-
if isa(result_type, Const)
495-
rettype_const = result_type.val
496-
const_flags = is_result_constabi_eligible(result) ? 0x3 : 0x2
497-
elseif isa(result_type, PartialOpaque)
498-
rettype_const = result_type
499-
const_flags = 0x2
500-
elseif isconstType(result_type)
501-
rettype_const = result_type.parameters[1]
502-
const_flags = 0x2
503-
elseif isa(result_type, PartialStruct)
504-
rettype_const = result_type.fields
505-
const_flags = 0x2
506-
elseif isa(result_type, InterConditional)
507-
rettype_const = result_type
508-
const_flags = 0x2
509-
elseif isa(result_type, InterMustAlias)
510-
rettype_const = result_type
511-
const_flags = 0x2
512-
else
513-
rettype_const = nothing
514-
const_flags = 0x0
515-
end
494+
(; rettype, exctype, rettype_const, const_flags) = ResultForCache(result)
516495
di = nothing
517496
edges = empty_edges # `edges` will be updated within `finish!`
518497
ci = result.ci
519-
ccall(:jl_fill_codeinst, Cvoid, (Any, Any, Any, Any, Int32, UInt, UInt, UInt32, Any, Any, Any),
520-
ci, widenconst(result_type), widenconst(result.exc_result), rettype_const, const_flags,
521-
first(result.valid_worlds), last(result.valid_worlds),
498+
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
499+
ccall(:jl_fill_codeinst, Cvoid, (
500+
Any, Any, Any, Any, Int32, UInt, UInt,
501+
UInt32, Any, Any, Any),
502+
ci, rettype, exctype, rettype_const, const_flags, min_world, max_world,
522503
encode_effects(result.ipo_effects), result.analysis_results, di, edges)
523504
if is_cached(me) # CACHE_MODE_GLOBAL
524505
cached_result = cache_result!(me.interp, result, ci)
@@ -530,6 +511,46 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
530511
nothing
531512
end
532513

514+
struct ResultForCache
515+
rettype
516+
exctype
517+
rettype_const
518+
const_flags::UInt8
519+
ResultForCache(rettype, exctype, rettype_const, const_flags::UInt8) = (
520+
@nospecialize rettype exctype rettype_const;
521+
new(rettype, exctype, rettype_const, const_flags))
522+
end
523+
@inline function ResultForCache(result::InferenceResult)
524+
result_type = result.result
525+
result_type isa LimitedAccuracy && (result_type = result_type.typ)
526+
@assert !(result_type === nothing)
527+
rettype = widenconst(result_type)
528+
exctype = widenconst(result.exc_result)
529+
if isa(result_type, Const)
530+
rettype_const = result_type.val
531+
const_flags = is_result_constabi_eligible(result) ? 0x3 : 0x2
532+
elseif isa(result_type, PartialOpaque)
533+
rettype_const = result_type
534+
const_flags = 0x2
535+
elseif isconstType(result_type)
536+
rettype_const = result_type.parameters[1]
537+
const_flags = 0x2
538+
elseif isa(result_type, PartialStruct)
539+
rettype_const = result_type.fields
540+
const_flags = 0x2
541+
elseif isa(result_type, InterConditional)
542+
rettype_const = result_type
543+
const_flags = 0x2
544+
elseif isa(result_type, InterMustAlias)
545+
rettype_const = result_type
546+
const_flags = 0x2
547+
else
548+
rettype_const = nothing
549+
const_flags = 0x0
550+
end
551+
return ResultForCache(rettype, exctype, rettype_const, const_flags)
552+
end
553+
533554
# record the backedges
534555
function store_backedges(caller::CodeInstance, edges::SimpleVector)
535556
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance
@@ -1025,7 +1046,7 @@ function typeinf_ircode(interp::AbstractInterpreter, mi::MethodInstance,
10251046
end
10261047
(; result) = frame
10271048
opt = OptimizationState(frame, interp)
1028-
ir = run_passes_ipo_safe(opt.src, opt, optimize_until)
1049+
ir = run_passes_ipo_safe(interp, opt, result; optimize_until)
10291050
rt = widenconst(ignorelimited(result.result))
10301051
return ir, rt
10311052
end
@@ -1050,6 +1071,7 @@ function typeinf_frame(interp::AbstractInterpreter, mi::MethodInstance, run_opti
10501071
opt = OptimizationState(frame, interp)
10511072
optimize(interp, opt, frame.result)
10521073
src = ir_to_codeinf!(opt)
1074+
src.rettype = widenconst(result.result)
10531075
end
10541076
result.src = frame.src = src
10551077
end

0 commit comments

Comments
 (0)