Skip to content

Commit a89b590

Browse files
committed
implement Jameson's suggestion
1 parent abd56dd commit a89b590

File tree

3 files changed

+76
-34
lines changed

3 files changed

+76
-34
lines changed

Compiler/src/optimize.jl

+34-9
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,39 @@ struct InliningState{Interp<:AbstractInterpreter}
143143
edges::Vector{Any}
144144
world::UInt
145145
interp::Interp
146+
opt_cache::IdDict{MethodInstance,CodeInstance}
146147
end
147-
function InliningState(sv::InferenceState, interp::AbstractInterpreter)
148-
return InliningState(sv.edges, frame_world(sv), interp)
148+
function InliningState(sv::InferenceState, interp::AbstractInterpreter,
149+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
150+
return InliningState(sv.edges, frame_world(sv), interp, opt_cache)
149151
end
150-
function InliningState(interp::AbstractInterpreter)
151-
return InliningState(Any[], get_inference_world(interp), interp)
152+
function InliningState(interp::AbstractInterpreter,
153+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
154+
return InliningState(Any[], get_inference_world(interp), interp, opt_cache)
155+
end
156+
157+
struct OptimizerCache{CodeCache}
158+
code_cache::CodeCache
159+
opt_cache::IdDict{MethodInstance,CodeInstance}
160+
end
161+
function get((; code_cache, opt_cache)::OptimizerCache{WorldView{InternalCodeCache}}, mi::MethodInstance, default)
162+
if haskey(opt_cache, mi)
163+
codeinst = opt_cache[mi]
164+
if (codeinst.min_world code_cache.worlds.min_world &&
165+
code_cache.worlds.max_world codeinst.max_world &&
166+
codeinst.owner === code_cache.cache.owner)
167+
@assert isdefined(codeinst, :inferred) && codeinst.inferred === nothing
168+
return codeinst
169+
end
170+
end
171+
return get(code_cache, mi, default)
152172
end
153173

154174
# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
155-
code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world)
175+
function code_cache(state::InliningState)
176+
cache = WorldView(code_cache(state.interp), state.world)
177+
return OptimizerCache(cache, state.opt_cache)
178+
end
156179

157180
mutable struct OptimizationState{Interp<:AbstractInterpreter}
158181
linfo::MethodInstance
@@ -168,13 +191,15 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
168191
bb_vartables::Vector{Union{Nothing,VarTable}}
169192
insert_coverage::Bool
170193
end
171-
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter)
172-
inlining = InliningState(sv, interp)
194+
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter,
195+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
196+
inlining = InliningState(sv, interp, opt_cache)
173197
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod,
174198
sv.sptypes, sv.slottypes, inlining, sv.cfg,
175199
sv.unreachable, sv.bb_vartables, sv.insert_coverage)
176200
end
177-
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter)
201+
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter,
202+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
178203
# prepare src for running optimization passes if it isn't already
179204
nssavalues = src.ssavaluetypes
180205
if nssavalues isa Int
@@ -194,7 +219,7 @@ function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractIn
194219
mod = isa(def, Method) ? def.module : def
195220
# Allow using the global MI cache, but don't track edges.
196221
# This method is mostly used for unit testing the optimizer
197-
inlining = InliningState(interp)
222+
inlining = InliningState(interp, opt_cache)
198223
cfg = compute_basic_blocks(src.code)
199224
unreachable = BitSet()
200225
bb_vartables = Union{VarTable,Nothing}[]

Compiler/src/typeinfer.jl

+36-25
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
106106
#@assert last(result.valid_worlds) <= get_world_counter() || isempty(caller.edges)
107107
if isdefined(result, :ci)
108108
ci = result.ci
109+
mi = result.linfo
109110
# if we aren't cached, we don't need this edge
110111
# but our caller might, so let's just make it anyways
111112
if last(result.valid_worlds) >= get_world_counter()
@@ -132,15 +133,15 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
132133
end
133134
di = inferred_result.debuginfo
134135
uncompressed = inferred_result
135-
inferred_result = maybe_compress_codeinfo(interp, result.linfo, inferred_result)
136+
inferred_result = maybe_compress_codeinfo(interp, mi, inferred_result)
136137
result.is_src_volatile = false
137138
elseif ci.owner === nothing
138139
# The global cache can only handle objects that codegen understands
139140
inferred_result = nothing
140141
end
141142
end
142143
if !@isdefined di
143-
di = DebugInfo(result.linfo)
144+
di = DebugInfo(mi)
144145
end
145146
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
146147
ipo_effects = encode_effects(result.ipo_effects)
@@ -149,6 +150,9 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
149150
UInt32, Any, Any, Any),
150151
ci, inferred_result, rettype, exctype, rettype_const, const_flags, min_world, max_world,
151152
ipo_effects, result.analysis_results, di, edges)
153+
if is_cached(caller) # CACHE_MODE_GLOBAL
154+
cache_result!(interp, mi, ci)
155+
end
152156
engine_reject(interp, ci)
153157
if !discard_src && isdefined(interp, :codegen) && uncompressed isa CodeInfo
154158
# record that the caller could use this result to generate code when required, if desired, to avoid repeating n^2 work
@@ -157,7 +161,6 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
157161
# This is necessary to get decent bootstrapping performance
158162
# when compiling the compiler to inject everything eagerly
159163
# where codegen can start finding and using it right away
160-
mi = result.linfo
161164
if mi.def isa Method && isa_compileable_sig(mi)
162165
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), ci, uncompressed)
163166
end
@@ -167,6 +170,10 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
167170
return nothing
168171
end
169172

173+
function cache_result!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstance)
174+
code_cache(interp)[mi] = ci
175+
end
176+
170177
function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstance, src::CodeInfo)
171178
user_edges = src.edges
172179
edges = user_edges isa SimpleVector ? user_edges : user_edges === nothing ? Core.svec() : Core.svec(user_edges...)
@@ -200,11 +207,13 @@ function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstan
200207
end
201208

202209
function finish_nocycle(::AbstractInterpreter, frame::InferenceState)
203-
finishinfer!(frame, frame.interp)
210+
opt_cache = IdDict{MethodInstance,CodeInstance}()
211+
finishinfer!(frame, frame.interp, opt_cache)
204212
opt = frame.result.src
205213
if opt isa OptimizationState # implies `may_optimize(caller.interp) === true`
206214
optimize(frame.interp, opt, frame.result)
207215
end
216+
empty!(opt_cache)
208217
finish!(frame.interp, frame)
209218
if frame.cycleid != 0
210219
frames = frame.callstack::Vector{AbsIntState}
@@ -227,10 +236,11 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
227236
cycle_valid_worlds = intersect(cycle_valid_worlds, caller.world.valid_worlds)
228237
cycle_valid_effects = merge_effects(cycle_valid_effects, caller.ipo_effects)
229238
end
239+
opt_cache = IdDict{MethodInstance,CodeInstance}()
230240
for frameid = cycleid:length(frames)
231241
caller = frames[frameid]::InferenceState
232242
adjust_cycle_frame!(caller, cycle_valid_worlds, cycle_valid_effects)
233-
finishinfer!(caller, caller.interp)
243+
finishinfer!(caller, caller.interp, opt_cache)
234244
end
235245
for frameid = cycleid:length(frames)
236246
caller = frames[frameid]::InferenceState
@@ -239,6 +249,7 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
239249
optimize(caller.interp, opt, caller.result)
240250
end
241251
end
252+
empty!(opt_cache)
242253
for frameid = cycleid:length(frames)
243254
caller = frames[frameid]::InferenceState
244255
finish!(caller.interp, caller)
@@ -285,22 +296,6 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, mi::MethodInstance
285296
end
286297
end
287298

288-
function cache_result!(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
289-
@assert isdefined(ci, :inferred)
290-
# check if the existing linfo metadata is also sufficient to describe the current inference result
291-
# to decide if it is worth caching this right now
292-
mi = result.linfo
293-
cache = WorldView(code_cache(interp), result.valid_worlds)
294-
if haskey(cache, mi)
295-
ci = cache[mi]
296-
# n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later
297-
@assert isdefined(ci, :inferred)
298-
return false
299-
end
300-
code_cache(interp)[mi] = ci
301-
return true
302-
end
303-
304299
function cycle_fix_limited(@nospecialize(typ), sv::InferenceState)
305300
if typ isa LimitedAccuracy
306301
if sv.parentid === 0
@@ -428,7 +423,8 @@ const empty_edges = Core.svec()
428423

429424
# inference completed on `me`
430425
# update the MethodInstance
431-
function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
426+
function finishinfer!(me::InferenceState, interp::AbstractInterpreter,
427+
opt_cache::IdDict{MethodInstance, CodeInstance})
432428
# prepare to run optimization passes on fulltree
433429
@assert isempty(me.ip)
434430
# inspect whether our inference had a limited result accuracy,
@@ -481,7 +477,7 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
481477
# disable optimization if we've already obtained very accurate result
482478
!result_is_constabi(interp, result)
483479
if doopt
484-
result.src = OptimizationState(me, interp)
480+
result.src = OptimizationState(me, interp, opt_cache)
485481
else
486482
result.src = me.src # for reflection etc.
487483
end
@@ -502,9 +498,11 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
502498
ci, rettype, exctype, rettype_const, const_flags, min_world, max_world,
503499
encode_effects(result.ipo_effects), result.analysis_results, di, edges)
504500
if is_cached(me) # CACHE_MODE_GLOBAL
505-
cached_result = cache_result!(me.interp, result, ci)
506-
if !cached_result
501+
already_cached = is_already_cached(me.interp, result, ci)
502+
if already_cached
507503
me.cache_mode = CACHE_MODE_VOLATILE
504+
else
505+
opt_cache[result.linfo] = ci
508506
end
509507
end
510508
end
@@ -551,6 +549,19 @@ end
551549
return ResultForCache(rettype, exctype, rettype_const, const_flags)
552550
end
553551

552+
function is_already_cached(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
553+
# check if the existing linfo metadata is also sufficient to describe the current inference result
554+
# to decide if it is worth caching this right now
555+
mi = result.linfo
556+
cache = WorldView(code_cache(interp), result.valid_worlds)
557+
if haskey(cache, mi)
558+
# n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later
559+
@assert isdefined(cache[mi], :inferred)
560+
return true
561+
end
562+
return false
563+
end
564+
554565
# record the backedges
555566
function store_backedges(caller::CodeInstance, edges::SimpleVector)
556567
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance

Compiler/test/newinterp.jl

+6
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ macro newinterp(InterpName, ephemeral_cache::Bool=false)
5858
$Compiler.getindex(wvc::$Compiler.WorldView{$InterpCacheName}, mi::$C.MethodInstance) = getindex(wvc.cache.dict, mi)
5959
$Compiler.haskey(wvc::$Compiler.WorldView{$InterpCacheName}, mi::$C.MethodInstance) = haskey(wvc.cache.dict, mi)
6060
$Compiler.setindex!(wvc::$Compiler.WorldView{$InterpCacheName}, ci::$C.CodeInstance, mi::$C.MethodInstance) = setindex!(wvc.cache.dict, ci, mi)
61+
function $Compiler.get(cache::$Compiler.OptimizerCache{$Compiler.WorldView{$InterpCacheName}}, mi::$C.MethodInstance, default)
62+
if haskey(cache.opt_cache, mi)
63+
return cache.opt_cache[mi]
64+
end
65+
return get(cache.code_cache, mi, default)
66+
end
6167
end)
6268
end
6369
end

0 commit comments

Comments
 (0)