Skip to content

Commit b448286

Browse files
committed
implement Jameson's suggestion
1 parent 67286e4 commit b448286

File tree

4 files changed

+81
-35
lines changed

4 files changed

+81
-35
lines changed

Compiler/src/optimize.jl

+43-9
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,48 @@ struct InliningState{Interp<:AbstractInterpreter}
143143
edges::Vector{Any}
144144
world::UInt
145145
interp::Interp
146+
opt_cache::IdDict{MethodInstance,CodeInstance}
146147
end
147-
function InliningState(sv::InferenceState, interp::AbstractInterpreter)
148-
return InliningState(sv.edges, frame_world(sv), interp)
148+
function InliningState(sv::InferenceState, interp::AbstractInterpreter,
149+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
150+
return InliningState(sv.edges, frame_world(sv), interp, opt_cache)
149151
end
150-
function InliningState(interp::AbstractInterpreter)
151-
return InliningState(Any[], get_inference_world(interp), interp)
152+
function InliningState(interp::AbstractInterpreter,
153+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
154+
return InliningState(Any[], get_inference_world(interp), interp, opt_cache)
155+
end
156+
157+
struct OptimizerCache{CodeCache}
158+
wvc::WorldView{CodeCache}
159+
owner
160+
opt_cache::IdDict{MethodInstance,CodeInstance}
161+
function OptimizerCache(
162+
wvc::WorldView{CodeCache},
163+
owner,
164+
opt_cache::IdDict{MethodInstance,CodeInstance}) where CodeCache
165+
@nospecialize owner
166+
new{CodeCache}(wvc, owner, opt_cache)
167+
end
168+
end
169+
function get((; wvc, owner, opt_cache)::OptimizerCache, mi::MethodInstance, default)
170+
if haskey(opt_cache, mi)
171+
codeinst = opt_cache[mi]
172+
if (codeinst.min_world ≤ wvc.worlds.min_world &&
173+
wvc.worlds.max_world ≤ codeinst.max_world &&
174+
codeinst.owner === owner)
175+
@assert isdefined(codeinst, :inferred) && codeinst.inferred === nothing
176+
return codeinst
177+
end
178+
end
179+
return get(wvc, mi, default)
152180
end
153181

154182
# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
155-
code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world)
183+
function code_cache(state::InliningState)
184+
cache = WorldView(code_cache(state.interp), state.world)
185+
owner = cache_owner(state.interp)
186+
return OptimizerCache(cache, owner, state.opt_cache)
187+
end
156188

157189
mutable struct OptimizationState{Interp<:AbstractInterpreter}
158190
linfo::MethodInstance
@@ -168,13 +200,15 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
168200
bb_vartables::Vector{Union{Nothing,VarTable}}
169201
insert_coverage::Bool
170202
end
171-
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter)
172-
inlining = InliningState(sv, interp)
203+
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter,
204+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
205+
inlining = InliningState(sv, interp, opt_cache)
173206
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod,
174207
sv.sptypes, sv.slottypes, inlining, sv.cfg,
175208
sv.unreachable, sv.bb_vartables, sv.insert_coverage)
176209
end
177-
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter)
210+
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter,
211+
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
178212
# prepare src for running optimization passes if it isn't already
179213
nssavalues = src.ssavaluetypes
180214
if nssavalues isa Int
@@ -194,7 +228,7 @@ function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractIn
194228
mod = isa(def, Method) ? def.module : def
195229
# Allow using the global MI cache, but don't track edges.
196230
# This method is mostly used for unit testing the optimizer
197-
inlining = InliningState(interp)
231+
inlining = InliningState(interp, opt_cache)
198232
cfg = compute_basic_blocks(src.code)
199233
unreachable = BitSet()
200234
bb_vartables = Union{VarTable,Nothing}[]

Compiler/src/typeinfer.jl

+36-25
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
106106
#@assert last(result.valid_worlds) <= get_world_counter() || isempty(caller.edges)
107107
if isdefined(result, :ci)
108108
ci = result.ci
109+
mi = result.linfo
109110
# if we aren't cached, we don't need this edge
110111
# but our caller might, so let's just make it anyways
111112
if last(result.valid_worlds) >= get_world_counter()
@@ -132,15 +133,15 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
132133
end
133134
di = inferred_result.debuginfo
134135
uncompressed = inferred_result
135-
inferred_result = maybe_compress_codeinfo(interp, result.linfo, inferred_result)
136+
inferred_result = maybe_compress_codeinfo(interp, mi, inferred_result)
136137
result.is_src_volatile = false
137138
elseif ci.owner === nothing
138139
# The global cache can only handle objects that codegen understands
139140
inferred_result = nothing
140141
end
141142
end
142143
if !@isdefined di
143-
di = DebugInfo(result.linfo)
144+
di = DebugInfo(mi)
144145
end
145146
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
146147
ipo_effects = encode_effects(result.ipo_effects)
@@ -149,6 +150,9 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
149150
UInt32, Any, Any, Any),
150151
ci, inferred_result, rettype, exctype, rettype_const, const_flags, min_world, max_world,
151152
ipo_effects, result.analysis_results, di, edges)
153+
if is_cached(caller) # CACHE_MODE_GLOBAL
154+
cache_result!(interp, mi, ci)
155+
end
152156
engine_reject(interp, ci)
153157
if !discard_src && isdefined(interp, :codegen) && uncompressed isa CodeInfo
154158
# record that the caller could use this result to generate code when required, if desired, to avoid repeating n^2 work
@@ -157,7 +161,6 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
157161
# This is necessary to get decent bootstrapping performance
158162
# when compiling the compiler to inject everything eagerly
159163
# where codegen can start finding and using it right away
160-
mi = result.linfo
161164
if mi.def isa Method && isa_compileable_sig(mi)
162165
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), ci, uncompressed)
163166
end
@@ -167,6 +170,10 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
167170
return nothing
168171
end
169172

173+
function cache_result!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstance)
174+
code_cache(interp)[mi] = ci
175+
end
176+
170177
function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstance, src::CodeInfo)
171178
user_edges = src.edges
172179
edges = user_edges isa SimpleVector ? user_edges : user_edges === nothing ? Core.svec() : Core.svec(user_edges...)
@@ -200,11 +207,13 @@ function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstan
200207
end
201208

202209
function finish_nocycle(::AbstractInterpreter, frame::InferenceState)
203-
finishinfer!(frame, frame.interp)
210+
opt_cache = IdDict{MethodInstance,CodeInstance}()
211+
finishinfer!(frame, frame.interp, opt_cache)
204212
opt = frame.result.src
205213
if opt isa OptimizationState # implies `may_optimize(caller.interp) === true`
206214
optimize(frame.interp, opt, frame.result)
207215
end
216+
empty!(opt_cache)
208217
finish!(frame.interp, frame)
209218
if frame.cycleid != 0
210219
frames = frame.callstack::Vector{AbsIntState}
@@ -227,10 +236,11 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
227236
cycle_valid_worlds = intersect(cycle_valid_worlds, caller.world.valid_worlds)
228237
cycle_valid_effects = merge_effects(cycle_valid_effects, caller.ipo_effects)
229238
end
239+
opt_cache = IdDict{MethodInstance,CodeInstance}()
230240
for frameid = cycleid:length(frames)
231241
caller = frames[frameid]::InferenceState
232242
adjust_cycle_frame!(caller, cycle_valid_worlds, cycle_valid_effects)
233-
finishinfer!(caller, caller.interp)
243+
finishinfer!(caller, caller.interp, opt_cache)
234244
end
235245
for frameid = cycleid:length(frames)
236246
caller = frames[frameid]::InferenceState
@@ -239,6 +249,7 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
239249
optimize(caller.interp, opt, caller.result)
240250
end
241251
end
252+
empty!(opt_cache)
242253
for frameid = cycleid:length(frames)
243254
caller = frames[frameid]::InferenceState
244255
finish!(caller.interp, caller)
@@ -285,22 +296,6 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, mi::MethodInstance
285296
end
286297
end
287298

288-
function cache_result!(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
289-
@assert isdefined(ci, :inferred)
290-
# check if the existing linfo metadata is also sufficient to describe the current inference result
291-
# to decide if it is worth caching this right now
292-
mi = result.linfo
293-
cache = WorldView(code_cache(interp), result.valid_worlds)
294-
if haskey(cache, mi)
295-
ci = cache[mi]
296-
# n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later
297-
@assert isdefined(ci, :inferred)
298-
return false
299-
end
300-
code_cache(interp)[mi] = ci
301-
return true
302-
end
303-
304299
function cycle_fix_limited(@nospecialize(typ), sv::InferenceState)
305300
if typ isa LimitedAccuracy
306301
if sv.parentid === 0
@@ -428,7 +423,8 @@ const empty_edges = Core.svec()
428423

429424
# inference completed on `me`
430425
# update the MethodInstance
431-
function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
426+
function finishinfer!(me::InferenceState, interp::AbstractInterpreter,
427+
opt_cache::IdDict{MethodInstance, CodeInstance})
432428
# prepare to run optimization passes on fulltree
433429
@assert isempty(me.ip)
434430
# inspect whether our inference had a limited result accuracy,
@@ -481,7 +477,7 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
481477
# disable optimization if we've already obtained very accurate result
482478
!result_is_constabi(interp, result)
483479
if doopt
484-
result.src = OptimizationState(me, interp)
480+
result.src = OptimizationState(me, interp, opt_cache)
485481
else
486482
result.src = me.src # for reflection etc.
487483
end
@@ -502,9 +498,11 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
502498
ci, rettype, exctype, rettype_const, const_flags, min_world, max_world,
503499
encode_effects(result.ipo_effects), result.analysis_results, di, edges)
504500
if is_cached(me) # CACHE_MODE_GLOBAL
505-
cached_result = cache_result!(me.interp, result, ci)
506-
if !cached_result
501+
already_cached = is_already_cached(me.interp, result, ci)
502+
if already_cached
507503
me.cache_mode = CACHE_MODE_VOLATILE
504+
else
505+
opt_cache[result.linfo] = ci
508506
end
509507
end
510508
end
@@ -551,6 +549,19 @@ end
551549
return ResultForCache(rettype, exctype, rettype_const, const_flags)
552550
end
553551

552+
function is_already_cached(interp::AbstractInterpreter, result::InferenceResult, ci::CodeInstance)
553+
# check if the existing linfo metadata is also sufficient to describe the current inference result
554+
# to decide if it is worth caching this right now
555+
mi = result.linfo
556+
cache = WorldView(code_cache(interp), result.valid_worlds)
557+
if haskey(cache, mi)
558+
# n.b.: accurate edge representation might cause the CodeInstance for this to be constructed later
559+
@assert isdefined(cache[mi], :inferred)
560+
return true
561+
end
562+
return false
563+
end
564+
554565
# record the backedges
555566
function store_backedges(caller::CodeInstance, edges::SimpleVector)
556567
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance

Compiler/src/types.jl

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
2-
#
32

43
const WorkThunk = Any
54
# #@eval struct WorkThunk

test/compileall.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
13
# This test builds a full system image, so it can take a little while.
24
# We make it a separate test target here, so that it can run in parallel
35
# with the rest of the tests.

0 commit comments

Comments
 (0)