Skip to content

Commit 3e94518

Browse files
authored
Merge pull request #26306 from JuliaLang/kf/domsort
[NewOptimizer] Domsort basic blocks
2 parents 15f84b3 + 2dd211c commit 3e94518

17 files changed

+310
-32
lines changed

base/compiler/optimize.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -4194,8 +4194,7 @@ function copy_duplicated_expr_pass!(sv::OptimizationState)
41944194
end
41954195

41964196
# fix label numbers to always equal the statement index of the label
4197-
function reindex_labels!(sv::OptimizationState)
4198-
body = sv.src.code
4197+
function reindex_labels!(body::Vector{Any})
41994198
mapping = get_label_map(body)
42004199
for i = 1:length(body)
42014200
el = body[i]
@@ -4235,6 +4234,11 @@ function reindex_labels!(sv::OptimizationState)
42354234
end
42364235
end
42374236

4237+
4238+
function reindex_labels!(sv::OptimizationState)
4239+
reindex_labels!(sv.src.code)
4240+
end
4241+
42384242
function return_type(@nospecialize(f), @nospecialize(t))
42394243
params = Params(ccall(:jl_get_tls_world_age, UInt, ()))
42404244
rt = Union{}

base/compiler/ssair/domtree.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,4 @@ function construct_domtree(cfg)
9494
# Recursively set level
9595
update_level!(domtree, 1, 1)
9696
DomTree(idoms, domtree)
97-
end
97+
end

base/compiler/ssair/driver.jl

+21-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ function normalize(@nospecialize(stmt), meta::Vector{Any}, table::Vector{LineInf
6565
elseif stmt.head === :gotoifnot
6666
return GotoIfNot(stmt.args...)
6767
elseif stmt.head === :return
68-
return ReturnNode{Any}(stmt.args...)
68+
return ReturnNode((length(stmt.args) == 0 ? (nothing,) : stmt.args)...)
69+
elseif stmt.head === :unreachable
70+
return ReturnNode()
6971
end
7072
elseif isa(stmt, LabelNode)
7173
return nothing
@@ -89,6 +91,23 @@ end
8991
function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode})
9092
mod = linetable[1].mod
9193
ci.code = copy(ci.code)
94+
# Go through and add an unreachable node after every
95+
# Union{} call. Then reindex labels.
96+
idx = 1
97+
while idx <= length(ci.code)
98+
stmt = ci.code[idx]
99+
if isexpr(stmt, :(=))
100+
stmt = stmt.args[2]
101+
end
102+
if isa(stmt, Expr) && stmt.typ === Union{}
103+
if !(idx < length(ci.code) && isexpr(ci.code[idx+1], :unreachable))
104+
insert!(ci.code, idx + 1, ReturnNode())
105+
idx += 1
106+
end
107+
end
108+
idx += 1
109+
end
110+
reindex_labels!(ci.code)
92111
meta = Any[]
93112
lines = fill(0, length(ci.code))
94113
let loc = RefValue(1)
@@ -110,6 +129,7 @@ function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode})
110129
IRCode(code, lines, cfg, argtypes, mod, meta)
111130
end
112131
ir = construct_ssa!(ci, ir, domtree, defuse_insts, nargs)
132+
domtree = construct_domtree(ir.cfg)
113133
ir = compact!(ir)
114134
verify_ir(ir)
115135
ir = type_lift_pass!(ir)

base/compiler/ssair/ir.jl

+19-10
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ struct GotoIfNot
1111
GotoIfNot(@nospecialize(cond), dest::Int) = new(cond, dest)
1212
end
1313

14-
struct ReturnNode{T}
15-
val::T
16-
ReturnNode{T}(@nospecialize(val)) where {T} = new{T}(val::T)
17-
ReturnNode{T}() where {T} = new{T}()
14+
struct ReturnNode
15+
val
16+
ReturnNode(@nospecialize(val)) = new(val)
17+
# unassigned val indicates unreachable
18+
ReturnNode() = new()
1819
end
1920

2021
"""
@@ -31,6 +32,8 @@ start(r::StmtRange) = 0
3132
done(r::StmtRange, state) = r.last - r.first < state
3233
next(r::StmtRange, state) = (r.first + state, state + 1)
3334

35+
StmtRange(range::UnitRange{Int}) = StmtRange(first(range), last(range))
36+
3437
struct BasicBlock
3538
stmts::StmtRange
3639
preds::Vector{Int}
@@ -264,7 +267,7 @@ function done(it::UseRefIterator, use)
264267
false
265268
end
266269

267-
function scan_ssa_use!(used::IdSet{Int64}, @nospecialize(stmt))
270+
function scan_ssa_use!(used, @nospecialize(stmt))
268271
if isa(stmt, SSAValue)
269272
push!(used, stmt.id)
270273
end
@@ -340,9 +343,9 @@ mutable struct IncrementalCompact
340343
end
341344

342345
struct TypesView
343-
compact::IncrementalCompact
346+
ir::Union{IRCode, IncrementalCompact}
344347
end
345-
types(compact::IncrementalCompact) = TypesView(compact)
348+
types(ir::Union{IRCode, IncrementalCompact}) = TypesView(ir)
346349

347350
function getindex(compact::IncrementalCompact, idx)
348351
if idx < compact.result_idx
@@ -368,10 +371,16 @@ function setindex!(compact::IncrementalCompact, v, idx)
368371
end
369372

370373
function getindex(view::TypesView, idx)
371-
if idx < view.compact.result_idx
374+
isa(idx, SSAValue) && (idx = idx.id)
375+
if isa(view.ir, IncrementalCompact) && idx < view.compact.result_idx
372376
return view.compact.result_types[idx]
373377
else
374-
return view.compact.ir.types[idx]
378+
ir = isa(view.ir, IncrementalCompact) ? view.ir.ir : view.ir
379+
if idx <= length(ir.types)
380+
return ir.types[idx]
381+
else
382+
return ir.new_nodes[idx - length(ir.types)][2]
383+
end
375384
end
376385
end
377386

@@ -457,7 +466,7 @@ function next(compact::IncrementalCompact, (idx, active_bb, old_result_idx)::Tup
457466
compact.result_types[old_result_idx] = typ
458467
compact.result_lines[old_result_idx] = new_line
459468
result_idx = process_node!(compact, old_result_idx, new_node, new_idx, idx)
460-
(old_result_idx == result_idx) && return next(compact, (idx, result_idx))
469+
(old_result_idx == result_idx) && return next(compact, (idx, active_bb, result_idx))
461470
compact.result_idx = result_idx
462471
return (old_result_idx, compact.result[old_result_idx]), (compact.idx, active_bb, compact.result_idx)
463472
end

base/compiler/ssair/legacy.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ function replace_code!(ci::CodeInfo, code::IRCode, nargs::Int, linetable::Vector
159159
new_stmt = Expr(:return, rename(stmt.val))
160160
else
161161
# Unreachable, so no issue with this
162-
new_stmt = nothing
162+
new_stmt = Expr(:unreachable)
163163
end
164164
elseif isa(stmt, SSAValue)
165165
new_stmt = rename(stmt)

base/compiler/ssair/passes.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function type_lift_pass!(ir::IRCode)
3535
def = ir.stmts[item]
3636
edges = copy(def.edges)
3737
values = Vector{Any}(uninitialized, length(edges))
38-
new_phi = insert_node!(ir, item, Bool, PhiNode(edges, values))
38+
new_phi = length(values) == 0 ? false : insert_node!(ir, item, Bool, PhiNode(edges, values))
3939
processed[item] = new_phi
4040
if first
4141
lifted_undef[stmt_id] = new_phi

base/compiler/ssair/show.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ function Base.show(io::IO, code::IRCode)
8080
bbrange = cfg.blocks[bb_idx].stmts
8181
bbrange = bbrange.first:bbrange.last
8282
bb_pad = max_bb_idx_size - length(string(bb_idx))
83+
bb_start_str = string("$(bb_idx) ",length(cfg.blocks[bb_idx].preds) <= 1 ? "" : "", ""^(bb_pad)," ")
8384
if idx != last(bbrange)
8485
if idx == first(bbrange)
85-
print(io, "$(bb_idx) ",""^(1+bb_pad)," ")
86+
print(io, bb_start_str)
8687
else
8788
print(io, ""," "^max_bb_idx_size)
8889
end
@@ -98,7 +99,7 @@ function Base.show(io::IO, code::IRCode)
9899
node_idx += length(code.stmts)
99100
if print_sep
100101
if floop
101-
print(io, "$(bb_idx) ",""^(1+bb_pad)," ")
102+
print(io, bb_start_str)
102103
else
103104
print(io, ""," "^max_bb_idx_size)
104105
end
@@ -117,7 +118,7 @@ function Base.show(io::IO, code::IRCode)
117118
end
118119
if print_sep
119120
if idx == first(bbrange) && floop
120-
print(io, "$(bb_idx) ",""^(1+bb_pad)," ")
121+
print(io, bb_start_str)
121122
else
122123
print(io, idx == last(bbrange) ? string("", ""^(1+max_bb_idx_size), " ") :
123124
string("", " "^max_bb_idx_size))

0 commit comments

Comments
 (0)