Skip to content

Commit 2dd211c

Browse files
committed
[NewOptimizer] Domsort basic blocks
This sorts basic blocks into dominance order right after SSA construction. Doing so has the benefit that linear traversals of the function automatically become reverse post order traversals, or in other words, we're guaranteed that during a linear traversal, we see any SSA definitions before uses of those SSA values. This property is a prerequisite for the incremental compactor to work correctly without back tracking. It should also allow for simplifying codegen later, though not until the new IR is the only thing supported by codegen. Fixes core/asyncmap tests with the new optimizer enabled. I believe those were the last tests that caused invalid IR generation (there's still a few test failures due to the disabled alloc elim pass).
1 parent 8fff818 commit 2dd211c

File tree

17 files changed

+310
-32
lines changed

17 files changed

+310
-32
lines changed

base/compiler/optimize.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4194,8 +4194,7 @@ function copy_duplicated_expr_pass!(sv::OptimizationState)
41944194
end
41954195

41964196
# fix label numbers to always equal the statement index of the label
4197-
function reindex_labels!(sv::OptimizationState)
4198-
body = sv.src.code
4197+
function reindex_labels!(body::Vector{Any})
41994198
mapping = get_label_map(body)
42004199
for i = 1:length(body)
42014200
el = body[i]
@@ -4235,6 +4234,11 @@ function reindex_labels!(sv::OptimizationState)
42354234
end
42364235
end
42374236

4237+
4238+
function reindex_labels!(sv::OptimizationState)
4239+
reindex_labels!(sv.src.code)
4240+
end
4241+
42384242
function return_type(@nospecialize(f), @nospecialize(t))
42394243
params = Params(ccall(:jl_get_tls_world_age, UInt, ()))
42404244
rt = Union{}

base/compiler/ssair/domtree.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,4 @@ function construct_domtree(cfg)
9494
# Recursively set level
9595
update_level!(domtree, 1, 1)
9696
DomTree(idoms, domtree)
97-
end
97+
end

base/compiler/ssair/driver.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ function normalize(@nospecialize(stmt), meta::Vector{Any}, table::Vector{LineInf
6565
elseif stmt.head === :gotoifnot
6666
return GotoIfNot(stmt.args...)
6767
elseif stmt.head === :return
68-
return ReturnNode{Any}(stmt.args...)
68+
return ReturnNode((length(stmt.args) == 0 ? (nothing,) : stmt.args)...)
69+
elseif stmt.head === :unreachable
70+
return ReturnNode()
6971
end
7072
elseif isa(stmt, LabelNode)
7173
return nothing
@@ -89,6 +91,23 @@ end
8991
function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode})
9092
mod = linetable[1].mod
9193
ci.code = copy(ci.code)
94+
# Go through and add an unreachable node after every
95+
# Union{} call. Then reindex labels.
96+
idx = 1
97+
while idx <= length(ci.code)
98+
stmt = ci.code[idx]
99+
if isexpr(stmt, :(=))
100+
stmt = stmt.args[2]
101+
end
102+
if isa(stmt, Expr) && stmt.typ === Union{}
103+
if !(idx < length(ci.code) && isexpr(ci.code[idx+1], :unreachable))
104+
insert!(ci.code, idx + 1, ReturnNode())
105+
idx += 1
106+
end
107+
end
108+
idx += 1
109+
end
110+
reindex_labels!(ci.code)
92111
meta = Any[]
93112
lines = fill(0, length(ci.code))
94113
let loc = RefValue(1)
@@ -110,6 +129,7 @@ function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode})
110129
IRCode(code, lines, cfg, argtypes, mod, meta)
111130
end
112131
ir = construct_ssa!(ci, ir, domtree, defuse_insts, nargs)
132+
domtree = construct_domtree(ir.cfg)
113133
ir = compact!(ir)
114134
verify_ir(ir)
115135
ir = type_lift_pass!(ir)

base/compiler/ssair/ir.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ struct GotoIfNot
1111
GotoIfNot(@nospecialize(cond), dest::Int) = new(cond, dest)
1212
end
1313

14-
struct ReturnNode{T}
15-
val::T
16-
ReturnNode{T}(@nospecialize(val)) where {T} = new{T}(val::T)
17-
ReturnNode{T}() where {T} = new{T}()
14+
struct ReturnNode
15+
val
16+
ReturnNode(@nospecialize(val)) = new(val)
17+
# unassigned val indicates unreachable
18+
ReturnNode() = new()
1819
end
1920

2021
"""
@@ -31,6 +32,8 @@ start(r::StmtRange) = 0
3132
done(r::StmtRange, state) = r.last - r.first < state
3233
next(r::StmtRange, state) = (r.first + state, state + 1)
3334

35+
StmtRange(range::UnitRange{Int}) = StmtRange(first(range), last(range))
36+
3437
struct BasicBlock
3538
stmts::StmtRange
3639
preds::Vector{Int}
@@ -264,7 +267,7 @@ function done(it::UseRefIterator, use)
264267
false
265268
end
266269

267-
function scan_ssa_use!(used::IdSet{Int64}, @nospecialize(stmt))
270+
function scan_ssa_use!(used, @nospecialize(stmt))
268271
if isa(stmt, SSAValue)
269272
push!(used, stmt.id)
270273
end
@@ -340,9 +343,9 @@ mutable struct IncrementalCompact
340343
end
341344

342345
struct TypesView
343-
compact::IncrementalCompact
346+
ir::Union{IRCode, IncrementalCompact}
344347
end
345-
types(compact::IncrementalCompact) = TypesView(compact)
348+
types(ir::Union{IRCode, IncrementalCompact}) = TypesView(ir)
346349

347350
function getindex(compact::IncrementalCompact, idx)
348351
if idx < compact.result_idx
@@ -368,10 +371,16 @@ function setindex!(compact::IncrementalCompact, v, idx)
368371
end
369372

370373
function getindex(view::TypesView, idx)
371-
if idx < view.compact.result_idx
374+
isa(idx, SSAValue) && (idx = idx.id)
375+
if isa(view.ir, IncrementalCompact) && idx < view.compact.result_idx
372376
return view.compact.result_types[idx]
373377
else
374-
return view.compact.ir.types[idx]
378+
ir = isa(view.ir, IncrementalCompact) ? view.ir.ir : view.ir
379+
if idx <= length(ir.types)
380+
return ir.types[idx]
381+
else
382+
return ir.new_nodes[idx - length(ir.types)][2]
383+
end
375384
end
376385
end
377386

@@ -457,7 +466,7 @@ function next(compact::IncrementalCompact, (idx, active_bb, old_result_idx)::Tup
457466
compact.result_types[old_result_idx] = typ
458467
compact.result_lines[old_result_idx] = new_line
459468
result_idx = process_node!(compact, old_result_idx, new_node, new_idx, idx)
460-
(old_result_idx == result_idx) && return next(compact, (idx, result_idx))
469+
(old_result_idx == result_idx) && return next(compact, (idx, active_bb, result_idx))
461470
compact.result_idx = result_idx
462471
return (old_result_idx, compact.result[old_result_idx]), (compact.idx, active_bb, compact.result_idx)
463472
end

base/compiler/ssair/legacy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ function replace_code!(ci::CodeInfo, code::IRCode, nargs::Int, linetable::Vector
159159
new_stmt = Expr(:return, rename(stmt.val))
160160
else
161161
# Unreachable, so no issue with this
162-
new_stmt = nothing
162+
new_stmt = Expr(:unreachable)
163163
end
164164
elseif isa(stmt, SSAValue)
165165
new_stmt = rename(stmt)

base/compiler/ssair/passes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function type_lift_pass!(ir::IRCode)
3535
def = ir.stmts[item]
3636
edges = copy(def.edges)
3737
values = Vector{Any}(uninitialized, length(edges))
38-
new_phi = insert_node!(ir, item, Bool, PhiNode(edges, values))
38+
new_phi = length(values) == 0 ? false : insert_node!(ir, item, Bool, PhiNode(edges, values))
3939
processed[item] = new_phi
4040
if first
4141
lifted_undef[stmt_id] = new_phi

base/compiler/ssair/show.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ function Base.show(io::IO, code::IRCode)
8080
bbrange = cfg.blocks[bb_idx].stmts
8181
bbrange = bbrange.first:bbrange.last
8282
bb_pad = max_bb_idx_size - length(string(bb_idx))
83+
bb_start_str = string("$(bb_idx) ",length(cfg.blocks[bb_idx].preds) <= 1 ? "" : "", ""^(bb_pad)," ")
8384
if idx != last(bbrange)
8485
if idx == first(bbrange)
85-
print(io, "$(bb_idx) ",""^(1+bb_pad)," ")
86+
print(io, bb_start_str)
8687
else
8788
print(io, ""," "^max_bb_idx_size)
8889
end
@@ -98,7 +99,7 @@ function Base.show(io::IO, code::IRCode)
9899
node_idx += length(code.stmts)
99100
if print_sep
100101
if floop
101-
print(io, "$(bb_idx) ",""^(1+bb_pad)," ")
102+
print(io, bb_start_str)
102103
else
103104
print(io, ""," "^max_bb_idx_size)
104105
end
@@ -117,7 +118,7 @@ function Base.show(io::IO, code::IRCode)
117118
end
118119
if print_sep
119120
if idx == first(bbrange) && floop
120-
print(io, "$(bb_idx) ",""^(1+bb_pad)," ")
121+
print(io, bb_start_str)
121122
else
122123
print(io, idx == last(bbrange) ? string("", ""^(1+max_bb_idx_size), " ") :
123124
string("", " "^max_bb_idx_size))

0 commit comments

Comments
 (0)