Skip to content

Commit 009ba04

Browse files
committed
fix regression
1 parent 333604b commit 009ba04

File tree

4 files changed

+36
-19
lines changed

4 files changed

+36
-19
lines changed

Compiler/src/inferencestate.jl

+15-11
Original file line numberDiff line numberDiff line change
@@ -828,21 +828,25 @@ mutable struct IRInterpretationState
828828
new_call_inferred::Bool
829829

830830
function IRInterpretationState(interp::AbstractInterpreter,
831-
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
831+
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Union{Nothing,Vector{Any}},
832832
world::UInt, min_world::UInt, max_world::UInt)
833833
curridx = 1
834-
given_argtypes = Vector{Any}(undef, length(argtypes))
835-
for i = 1:length(given_argtypes)
836-
given_argtypes[i] = widenslotwrapper(argtypes[i])
837-
end
838-
if isa(mi.def, Method)
839-
argtypes_refined = Bool[!(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i])
840-
for i = 1:length(given_argtypes)]
834+
if argtypes !== nothing
835+
given_argtypes = Vector{Any}(undef, length(argtypes))
836+
for i = 1:length(given_argtypes)
837+
given_argtypes[i] = widenslotwrapper(argtypes[i])
838+
end
839+
if isa(mi.def, Method)
840+
argtypes_refined = Bool[!(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i])
841+
for i = 1:length(given_argtypes)]
842+
else
843+
argtypes_refined = Bool[false for _ = 1:length(given_argtypes)]
844+
end
845+
empty!(ir.argtypes)
846+
append!(ir.argtypes, given_argtypes)
841847
else
842-
argtypes_refined = Bool[false for i = 1:length(given_argtypes)]
848+
argtypes_refined = Bool[false for _ = 1:length(ir.argtypes)]
843849
end
844-
empty!(ir.argtypes)
845-
append!(ir.argtypes, given_argtypes)
846850
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
847851
ssa_refined = BitSet()
848852
lazyreachability = LazyCFGReachability(ir)

Compiler/src/optimize.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -1067,20 +1067,26 @@ function optinf!(ir::IRCode, interp::AbstractInterpreter, sv::OptimizationState,
10671067
spec_info = SpecInfo(ci)
10681068
world = get_inference_world(interp)
10691069
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
1070-
irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, ir.argtypes,
1070+
irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, #=argtypes=#nothing,
10711071
world, min_world, max_world)
1072-
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
1072+
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv;
1073+
# While `optinf!` itself performs reanalysis based on `IR_FLAG_REFINED`, since
1074+
# `IR_FLAG_REFINED` is also useful for subsequent `semi_concrete_eval`s that may
1075+
# occur on this `IRCode`, it is necessary to ensure that `optinf!` does not
1076+
# subtract `IR_FLAG_REFINED` (otherwise there might cases where the expected
1077+
# constant propagation information is not obtained through `irinterp`)
1078+
sub_ir_flag_refined=false)
10731079
if irsv.new_call_inferred
10741080
ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
10751081
ir = compact!(ir)
1076-
effects = result.effects
1082+
effects = result.ipo_effects
10771083
if nothrow
10781084
effects = Effects(effects; nothrow=true)
10791085
end
10801086
if noub
10811087
effects = Effects(effects; noub=ALWAYS_TRUE)
10821088
end
1083-
result.effects = effects
1089+
result.ipo_effects = effects
10841090
result.exc_result = refine_exception_type(result.exc_result, effects)
10851091
= strictneqpartialorder(ipo_lattice(interp))
10861092
result.result = rt result.result ? rt : result.result

Compiler/src/ssair/irinterp.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,8 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
229229
end
230230
return true
231231
end
232-
= strictneqpartialorder(typeinf_lattice(interp))
233-
if rt inst[:type]
232+
= strictpartialorder(typeinf_lattice(interp))
233+
if rt inst[:type]
234234
inst[:type] = rt
235235
return true
236236
end
@@ -319,6 +319,7 @@ function is_all_const_call(@nospecialize(stmt), interp::AbstractInterpreter, irs
319319
end
320320

321321
function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState;
322+
sub_ir_flag_refined::Bool = true,
322323
externally_refined::Union{Nothing,BitSet} = nothing)
323324
(; ir, tpdum, ssa_refined) = irsv
324325

@@ -341,7 +342,7 @@ function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRI
341342
any_refined = false
342343
if has_flag(flag, IR_FLAG_REFINED)
343344
any_refined = true
344-
sub_flag!(inst, IR_FLAG_REFINED)
345+
sub_ir_flag_refined && sub_flag!(inst, IR_FLAG_REFINED)
345346
elseif is_all_const_call(stmt, interp, irsv)
346347
# force reinference on calls with all constant arguments
347348
any_refined = true
@@ -394,7 +395,7 @@ function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRI
394395
stmt = inst[:stmt]
395396
flag = inst[:flag]
396397
if has_flag(flag, IR_FLAG_REFINED)
397-
sub_flag!(inst, IR_FLAG_REFINED)
398+
sub_ir_flag_refined && sub_flag!(inst, IR_FLAG_REFINED)
398399
push!(stmt_ip, idx)
399400
end
400401
check_ret!(stmt, idx)

Compiler/test/inference.jl

+6
Original file line numberDiff line numberDiff line change
@@ -6190,3 +6190,9 @@ let mi = only(methods(func_opt_inf, ())).specializations
61906190
ci = mi.cache
61916191
@test ci.rettype_const == sin(1.0)
61926192
end
6193+
@test fully_eliminated((BitSet,)) do b
6194+
iterate((pairs((b,))))[1][1]
6195+
end
6196+
@test fully_eliminated((BitSet,)) do b
6197+
iterate((pairs((b,))))[2]
6198+
end

0 commit comments

Comments
 (0)