Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 4ac3536

Browse files
committedFeb 2, 2022
optimizer: simple array SROA
Implements a simple Julia-level array allocation elimination on top of #43888. ```julia julia> code_typed((String,String)) do s, t a = Vector{Base.RefValue{String}}(undef, 2) a[1] = Ref(s) a[2] = Ref(t) return a[1][] end ``` ```diff diff --git a/master b/pr index 9c8da14380..5b63d08190 100644 --- a/master +++ b/pr @@ -1,11 +1,4 @@ 1-element Vector{Any}: CodeInfo( -1 ─ %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Base.RefValue{String}}, svec(Any, Int64), 0, :(:ccall), Vector{Base.RefValue{String}}, 2, 2))::Vector{Base.RefValue{String}} -│ %2 = %new(Base.RefValue{String}, s)::Base.RefValue{String} -│ Base.arrayset(true, %1, %2, 1)::Vector{Base.RefValue{String}} -│ %4 = %new(Base.RefValue{String}, t)::Base.RefValue{String} -│ Base.arrayset(true, %1, %4, 2)::Vector{Base.RefValue{String}} -│ %6 = Base.arrayref(true, %1, 1)::Base.RefValue{String} -│ %7 = Base.getfield(%6, :x)::String -└── return %7 +1 ─ return s ) => String ``` Still this array SROA handle is very limited and able to handle only trivial examples (though I confirmed this version already eliminates few array allocations during sysimg build). For those who interested, I added some discussions on array optimization [here](https://aviatesk.github.io/EscapeAnalysis.jl/dev/#EA-Array-Analysis).
1 parent a19cdce commit 4ac3536

File tree

4 files changed

+505
-166
lines changed

4 files changed

+505
-166
lines changed
 

‎base/compiler/optimize.jl

+12-2
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,7 @@ end
269269

270270
function foreigncall_effect_free(stmt::Expr, src::Union{IRCode,IncrementalCompact})
271271
args = stmt.args
272-
name = args[1]
273-
isa(name, QuoteNode) && (name = name.value)
272+
name = normalize(args[1])
274273
isa(name, Symbol) || return false
275274
ndims = alloc_array_ndims(name)
276275
if ndims !== nothing
@@ -296,6 +295,17 @@ function alloc_array_ndims(name::Symbol)
296295
return nothing
297296
end
298297

298+
normalize(@nospecialize x) = isa(x, QuoteNode) ? x.value : x
299+
300+
function is_array_alloc(@nospecialize stmt)
301+
isa(stmt, Expr) || return false
302+
if isexpr(stmt, :foreigncall)
303+
name = normalize(stmt.args[1])
304+
return isa(name, Symbol) && alloc_array_ndims(name) !== nothing
305+
end
306+
return false
307+
end
308+
299309
function alloc_array_no_throw(args::Vector{Any}, ndims::Int, src::Union{IRCode,IncrementalCompact})
300310
length(args) ndims+6 || return false
301311
atype = instanceof_tfunc(argextype(args[6], src))[1]

‎base/compiler/ssair/passes.jl

+355-154
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,9 @@ function sroa_pass!(ir::IRCode, nargs::Int, mi_cache::MICache) where MICache
560560
anymutability = true
561561
end
562562
continue
563+
elseif is_array_alloc(stmt)
564+
anymutability = true
565+
continue
563566
# elseif is_known_call(stmt, setfield!, compact)
564567
# 4 <= length(stmt.args) <= 5 || continue
565568
# if length(stmt.args) == 5
@@ -695,7 +698,8 @@ function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preser
695698
return newex
696699
end
697700

698-
import .EscapeAnalysis: EscapeInfo, IndexableFields, LivenessSet, getaliases
701+
import .EscapeAnalysis:
702+
EscapeInfo, IndexableFields, IndexableElements, LivenessSet, ArrayInfo, getaliases
699703

700704
function sroa_mutables!(ir::IRCode, nargs::Int, mi_cache::MICache) where MICache
701705
# Compute domtree now, needed below, now that we have finished compacting the IR.
@@ -709,12 +713,12 @@ function sroa_mutables!(ir::IRCode, nargs::Int, mi_cache::MICache) where MICache
709713
eliminated = BitSet()
710714
revisit = Tuple{#=related=#Vector{SSAValue}, #=Liveness=#LivenessSet}[]
711715
all_preserved = true
712-
newpreserves = nothing
716+
newpreserves = IdDict{Int,Vector{Any}}()
713717
while !isempty(wset)
714718
idx = pop!(wset)
715719
ssa = SSAValue(idx)
716720
stmt = ir[ssa][:inst]
717-
isexpr(stmt, :new) || continue
721+
isexpr(stmt, :new) || is_array_alloc(stmt) || continue
718722
einfo = estate[ssa]
719723
is_load_forwardable(einfo) || continue
720724
aliases = getaliases(ssa, estate)
@@ -728,151 +732,48 @@ function sroa_mutables!(ir::IRCode, nargs::Int, mi_cache::MICache) where MICache
728732
delete!(wset, alias.id)
729733
end
730734
end
731-
finfos = (einfo.AliasInfo::IndexableFields).infos
732-
nfields = length(finfos)
733-
734-
# Partition defuses by field
735-
fdefuses = Vector{FieldDefUse}(undef, nfields)
736-
for i = 1:nfields
737-
finfo = finfos[i]
738-
fdu = FieldDefUse()
739-
for pc in finfo
740-
if pc > 0
741-
push!(fdu.uses, GetfieldLoad(pc)) # use (getfield call)
742-
else
743-
push!(fdu.defs, -pc) # def (setfield! call or :new expression)
744-
end
745-
end
746-
fdefuses[i] = fdu
747-
end
748-
749-
Liveness = einfo.Liveness
750-
for livepc in Liveness
751-
livestmt = ir[SSAValue(livepc)][:inst]
752-
if is_known_call(livestmt, Core.ifelse, ir)
753-
# the succeeding domination analysis doesn't account for conditional branching
754-
# by ifelse branching at this moment
755-
@goto next_itr
756-
elseif is_known_call(livestmt, isdefined, ir)
757-
args = livestmt.args
758-
length(args) 3 || continue
759-
obj = args[2]
760-
isa(obj, SSAValue) || continue
761-
obj in related || continue
762-
fld = args[3]
763-
fldval = try_compute_field(ir, fld)
764-
fldval === nothing && continue
765-
typ = unwrap_unionall(widenconst(argextype(obj, ir)))
766-
isa(typ, DataType) || continue
767-
fldidx = try_compute_fieldidx(typ, fldval)
768-
fldidx === nothing && continue
769-
push!(fdefuses[fldidx].uses, IsdefinedUse(livepc))
770-
elseif isexpr(livestmt, :foreigncall)
771-
# we shouldn't eliminate this use if it's used as a direct argument
772-
args = livestmt.args
773-
nccallargs = length(args[3]::SimpleVector)
774-
for i = 6:(5+nccallargs)
775-
arg = args[i]
776-
isa(arg, SSAValue) && arg in related && @goto next_liveness
777-
end
778-
# this use is preserve, and may be eliminable
779-
for fidx in 1:nfields
780-
push!(fdefuses[fidx].uses, PreserveUse(livepc))
781-
end
782-
end
783-
@label next_liveness
784-
end
785735

786-
for fidx in 1:nfields
787-
fdu = fdefuses[fidx]
788-
isempty(fdu.uses) && @goto next_use
789-
# check if all uses have safe definitions first, otherwise we should bail out
790-
# since then we may fail to form new ϕ-nodes
791-
ldu = compute_live_ins(ir.cfg, fdu)
792-
if isempty(ldu.live_in_bbs)
793-
phiblocks = Int[]
794-
else
795-
phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree)
796-
end
797-
allblocks = sort!(vcat(phiblocks, ldu.def_bbs))
798-
for use in fdu.uses
799-
isa(use, IsdefinedUse) && continue
800-
if isa(use, PreserveUse) && isempty(fdu.defs)
801-
# nothing to preserve, just ignore this use (may happen when there are unintialized fields)
802-
continue
803-
end
804-
if !has_safe_def(ir, domtree, allblocks, fdu, getuseidx(use))
805-
all_preserved = false
806-
@goto next_use
807-
end
808-
end
809-
phinodes = IdDict{Int, SSAValue}()
810-
for b in phiblocks
811-
phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
812-
NewInstruction(PhiNode(), Any))
813-
end
814-
# Now go through all uses and rewrite them
815-
for use in fdu.uses
816-
if isa(use, GetfieldLoad)
817-
use = getuseidx(use)
818-
ir[SSAValue(use)][:inst] = compute_value_for_use(
819-
ir, domtree, allblocks, fdu, phinodes, fidx, use)
820-
push!(eliminated, use)
821-
elseif all_preserved && isa(use, PreserveUse)
822-
if newpreserves === nothing
823-
newpreserves = IdDict{Int,Vector{Any}}()
824-
end
825-
# record this `use` as replaceable no matter if we preserve new value or not
826-
use = getuseidx(use)
827-
newvalues = get!(()->Any[], newpreserves, use)
828-
isempty(fdu.defs) && continue # nothing to preserve (may happen when there are unintialized fields)
829-
newval = compute_value_for_use(
830-
ir, domtree, allblocks, fdu, phinodes, fidx, use)
831-
if !isbitstype(widenconst(argextype(newval, ir)))
832-
push!(newvalues, newval)
833-
end
834-
elseif isa(use, IsdefinedUse)
835-
use = getuseidx(use)
836-
if has_safe_def(ir, domtree, allblocks, fdu, use)
837-
ir[SSAValue(use)][:inst] = true
838-
push!(eliminated, use)
839-
end
840-
else
841-
throw("unexpected use")
842-
end
843-
end
844-
for b in phiblocks
845-
ϕssa = phinodes[b]
846-
n = ir[ϕssa][:inst]::PhiNode
847-
t = Bottom
848-
for p in ir.cfg.blocks[b].preds
849-
push!(n.edges, p)
850-
v = compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, p)
851-
push!(n.values, v)
852-
if t !== Any
853-
t = tmerge(t, argextype(v, ir))
854-
end
855-
end
856-
ir[ϕssa][:type] = t
857-
end
858-
@label next_use
736+
AliasInfo = einfo.AliasInfo
737+
if isa(AliasInfo, IndexableFields)
738+
@assert isexpr(stmt, :new) "invalid escape analysis"
739+
all_preserved &= load_forward_object!(ir, domtree,
740+
eliminated, revisit,
741+
newpreserves, related,
742+
AliasInfo, einfo.Liveness)
743+
else
744+
@assert is_array_alloc(stmt) "invalid escape analysis"
745+
arrayinfo = estate.arrayinfo
746+
@assert isa(arrayinfo, ArrayInfo) && haskey(arrayinfo, idx) "invalid escape analysis"
747+
dims = arrayinfo[idx]
748+
all_preserved &= load_forward_array!(ir, domtree,
749+
eliminated, revisit,
750+
newpreserves, related,
751+
AliasInfo::IndexableElements, einfo.Liveness, dims)
859752
end
860-
push!(revisit, (related, Liveness))
861-
@label next_itr
862753
end
863754

864755
# remove dead setfield! and :new allocs
865756
deadssas = IdSet{SSAValue}()
866-
if all_preserved && newpreserves !== nothing
757+
if all_preserved
867758
preserved = keys(newpreserves)
868759
else
869760
preserved = EMPTY_PRESERVED_SSAS
870761
end
871762
mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved)
872763
for ssa in deadssas
764+
# stmt = ir[ssa][:inst]
765+
# if is_known_call(stmt, setfield!, ir)
766+
# println("[SROA] eliminated setfield!: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt)
767+
# elseif isexpr(stmt, :new)
768+
# println("[SROA] eliminated object alloc: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt)
769+
# elseif is_known_call(stmt, arrayset, ir)
770+
# println("[SROA] eliminated arrayset: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt)
771+
# elseif is_array_alloc(stmt)
772+
# println("[SROA] eliminated array alloc: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt)
773+
# end
873774
ir[ssa][:inst] = nothing
874775
end
875-
if all_preserved && newpreserves !== nothing
776+
if all_preserved
876777
deadssas = Int[ssa.id for ssa in deadssas]
877778
for (idx, newuses) in newpreserves
878779
ir[SSAValue(idx)][:inst] = form_new_preserves(
@@ -883,20 +784,289 @@ function sroa_mutables!(ir::IRCode, nargs::Int, mi_cache::MICache) where MICache
883784
return ir
884785
end
885786

787+
function load_forward_object!(ir::IRCode, domtree::DomTree,
788+
eliminated::BitSet, revisit::Vector{Tuple{Vector{SSAValue}, LivenessSet}},
789+
newpreserves::IdDict{Int,Vector{Any}}, related::Vector{SSAValue},
790+
AliasInfo::IndexableFields, Liveness::LivenessSet)
791+
finfos = AliasInfo.infos
792+
nfields = length(finfos)
793+
794+
# Partition defuses by field
795+
all_preserved = true
796+
fdefuses = Vector{IndexedDefUse}(undef, nfields)
797+
for i = 1:nfields
798+
finfo = finfos[i]
799+
idu = IndexedDefUse()
800+
for pc in finfo
801+
if pc > 0
802+
push!(idu.uses, LoadUse(pc)) # use (getfield call)
803+
else
804+
push!(idu.defs, -pc) # def (setfield! call or :new expression)
805+
end
806+
end
807+
fdefuses[i] = idu
808+
end
809+
810+
for livepc in Liveness
811+
livestmt = ir[SSAValue(livepc)][:inst]
812+
if is_known_call(livestmt, Core.ifelse, ir)
813+
# the succeeding domination analysis doesn't account for conditional branching
814+
# by ifelse branching at this moment
815+
return false
816+
elseif is_known_call(livestmt, isdefined, ir)
817+
args = livestmt.args
818+
length(args) 3 || continue
819+
obj = args[2]
820+
isa(obj, SSAValue) || continue
821+
obj in related || continue
822+
fld = args[3]
823+
fldval = try_compute_field(ir, fld)
824+
fldval === nothing && continue
825+
typ = unwrap_unionall(widenconst(argextype(obj, ir)))
826+
isa(typ, DataType) || continue
827+
fldidx = try_compute_fieldidx(typ, fldval)
828+
fldidx === nothing && continue
829+
push!(fdefuses[fldidx].uses, IsdefinedUse(livepc))
830+
elseif isexpr(livestmt, :foreigncall)
831+
# we shouldn't eliminate this use if it's used as a direct argument
832+
args = livestmt.args
833+
nccallargs = length(args[3]::SimpleVector)
834+
for i = 6:(5+nccallargs)
835+
arg = args[i]
836+
isa(arg, SSAValue) && arg in related && @goto next_liveness
837+
end
838+
# this use is preserve, and may be eliminable
839+
for fidx in 1:nfields
840+
push!(fdefuses[fidx].uses, PreserveUse(livepc))
841+
end
842+
end
843+
@label next_liveness
844+
end
845+
846+
for fidx in 1:nfields
847+
idu = fdefuses[fidx]
848+
isempty(idu.uses) && @goto next_use
849+
# check if all uses have safe definitions first, otherwise we should bail out
850+
# since then we may fail to form new ϕ-nodes
851+
ldu = compute_live_ins(ir.cfg, idu)
852+
if isempty(ldu.live_in_bbs)
853+
phiblocks = Int[]
854+
else
855+
phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree)
856+
end
857+
allblocks = sort!(vcat(phiblocks, ldu.def_bbs))
858+
for use in idu.uses
859+
isa(use, IsdefinedUse) && continue
860+
if isa(use, PreserveUse) && isempty(idu.defs)
861+
# nothing to preserve, just ignore this use (may happen when there are unintialized fields)
862+
continue
863+
end
864+
if !has_safe_def(ir, domtree, allblocks, idu, getuseidx(use))
865+
all_preserved = false
866+
@goto next_use
867+
end
868+
end
869+
phinodes = IdDict{Int, SSAValue}()
870+
for b in phiblocks
871+
phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
872+
NewInstruction(PhiNode(), Any))
873+
end
874+
# Now go through all uses and rewrite them
875+
for use in idu.uses
876+
if isa(use, LoadUse)
877+
use = getuseidx(use)
878+
ir[SSAValue(use)][:inst] = compute_value_for_use(
879+
ir, domtree, allblocks, idu, phinodes, fidx, use)
880+
push!(eliminated, use)
881+
elseif isa(use, PreserveUse)
882+
all_preserved || continue
883+
# record this `use` as replaceable no matter if we preserve new value or not
884+
use = getuseidx(use)
885+
newvalues = get!(()->Any[], newpreserves, use)
886+
isempty(idu.defs) && continue # nothing to preserve (may happen when there are unintialized fields)
887+
newval = compute_value_for_use(
888+
ir, domtree, allblocks, idu, phinodes, fidx, use)
889+
if !isbitstype(widenconst(argextype(newval, ir)))
890+
push!(newvalues, newval)
891+
end
892+
elseif isa(use, IsdefinedUse)
893+
use = getuseidx(use)
894+
if has_safe_def(ir, domtree, allblocks, idu, use)
895+
ir[SSAValue(use)][:inst] = true
896+
push!(eliminated, use)
897+
end
898+
else
899+
throw("load_forward_object!: unexpected use")
900+
end
901+
end
902+
for b in phiblocks
903+
ϕssa = phinodes[b]
904+
n = ir[ϕssa][:inst]::PhiNode
905+
t = Bottom
906+
for p in ir.cfg.blocks[b].preds
907+
push!(n.edges, p)
908+
v = compute_value_for_block(ir, domtree, allblocks, idu, phinodes, fidx, p)
909+
push!(n.values, v)
910+
if t !== Any
911+
t = tmerge(t, argextype(v, ir))
912+
end
913+
end
914+
ir[ϕssa][:type] = t
915+
end
916+
@label next_use
917+
end
918+
push!(revisit, (related, Liveness))
919+
920+
return all_preserved
921+
end
922+
923+
# TODO is_array_isassigned folding?
924+
function load_forward_array!(ir::IRCode, domtree::DomTree,
925+
eliminated::BitSet, revisit::Vector{Tuple{Vector{SSAValue}, LivenessSet}},
926+
newpreserves::IdDict{Int,Vector{Any}}, related::Vector{SSAValue},
927+
AliasInfo::IndexableElements, Liveness::LivenessSet, dims::Vector{Int})
928+
elminfos = AliasInfo.infos
929+
elmkeys = keys(elminfos)
930+
931+
# Partition defuses by index
932+
all_preserved = true
933+
edefuses = IdDict{Int,IndexedDefUse}()
934+
for eidx in elmkeys
935+
einfo = elminfos[eidx]
936+
idu = IndexedDefUse()
937+
for pc in einfo
938+
if pc > 0
939+
push!(idu.uses, LoadUse(pc)) # use (arrayref call)
940+
else
941+
push!(idu.defs, -pc) # def (arrayset call)
942+
end
943+
end
944+
edefuses[eidx] = idu
945+
end
946+
947+
for livepc in Liveness
948+
ssa = SSAValue(livepc)
949+
livestmt = ir[ssa][:inst]
950+
if is_known_call(livestmt, Core.ifelse, ir)
951+
# the succeeding domination analysis doesn't account for conditional branching
952+
# by ifelse branching at this moment
953+
return false
954+
elseif is_known_call(livestmt, arraylen, ir)
955+
len = 1
956+
for dim in dims
957+
len *= dim
958+
end
959+
ir[ssa][:inst] = len
960+
push!(eliminated, livepc)
961+
elseif is_known_call(livestmt, arraysize, ir)
962+
length(livestmt.args) 3 || continue
963+
dim = argextype(livestmt.args[3], ir)
964+
isa(dim, Const) || continue
965+
dim = dim.val
966+
isa(dim, Int) || continue
967+
checkbounds(Bool, dims, dim) || continue
968+
ir[ssa][:inst] = dims[dim]
969+
push!(eliminated, livepc)
970+
elseif isexpr(livestmt, :foreigncall)
971+
# we shouldn't eliminate this use if it's used as a direct argument
972+
args = livestmt.args
973+
nccallargs = length(args[3]::SimpleVector)
974+
for i = 6:(5+nccallargs)
975+
arg = args[i]
976+
isa(arg, SSAValue) && arg in related && @goto next_liveness
977+
end
978+
# this use is preserve, and may be eliminable
979+
for eidx in elmkeys
980+
push!(edefuses[eidx].uses, PreserveUse(livepc))
981+
end
982+
end
983+
@label next_liveness
984+
end
985+
986+
for eidx in elmkeys
987+
idu = edefuses[eidx]
988+
isempty(idu.uses) && @goto next_use
989+
# check if all uses have safe definitions first, otherwise we should bail out
990+
# since then we may fail to form new ϕ-nodes
991+
ldu = compute_live_ins(ir.cfg, idu)
992+
if isempty(ldu.live_in_bbs)
993+
phiblocks = Int[]
994+
else
995+
phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree)
996+
end
997+
allblocks = sort!(vcat(phiblocks, ldu.def_bbs))
998+
for use in idu.uses
999+
if isa(use, PreserveUse) && isempty(idu.defs)
1000+
# nothing to preserve, just ignore this use (may happen when there are unintialized fields)
1001+
continue
1002+
end
1003+
if !has_safe_def(ir, domtree, allblocks, idu, getuseidx(use))
1004+
all_preserved = false
1005+
@goto next_use
1006+
end
1007+
end
1008+
phinodes = IdDict{Int, SSAValue}()
1009+
for b in phiblocks
1010+
phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
1011+
NewInstruction(PhiNode(), Any))
1012+
end
1013+
# Now go through all uses and rewrite them
1014+
for use in idu.uses
1015+
if isa(use, LoadUse)
1016+
use = getuseidx(use)
1017+
ir[SSAValue(use)][:inst] = compute_value_for_use(
1018+
ir, domtree, allblocks, idu, phinodes, eidx, use)
1019+
push!(eliminated, use)
1020+
elseif isa(use, PreserveUse)
1021+
all_preserved || continue
1022+
# record this `use` as replaceable no matter if we preserve new value or not
1023+
use = getuseidx(use)
1024+
newvalues = get!(()->Any[], newpreserves, use)
1025+
isempty(idu.defs) && continue # nothing to preserve (may happen when there are unintialized fields)
1026+
newval = compute_value_for_use(
1027+
ir, domtree, allblocks, idu, phinodes, eidx, use)
1028+
if !isbitstype(widenconst(argextype(newval, ir)))
1029+
push!(newvalues, newval)
1030+
end
1031+
else
1032+
throw("load_forward_array!: unexpected use")
1033+
end
1034+
end
1035+
for b in phiblocks
1036+
ϕssa = phinodes[b]
1037+
n = ir[ϕssa][:inst]::PhiNode
1038+
t = Bottom
1039+
for p in ir.cfg.blocks[b].preds
1040+
push!(n.edges, p)
1041+
v = compute_value_for_block(ir, domtree, allblocks, idu, phinodes, eidx, p)
1042+
push!(n.values, v)
1043+
if t !== Any
1044+
t = tmerge(t, argextype(v, ir))
1045+
end
1046+
end
1047+
ir[ϕssa][:type] = t
1048+
end
1049+
@label next_use
1050+
end
1051+
push!(revisit, (related, Liveness))
1052+
1053+
return all_preserved
1054+
end
1055+
8861056
const EMPTY_PRESERVED_SSAS = keys(IdDict{Int,Vector{Any}}())
8871057
const PreservedSets = typeof(EMPTY_PRESERVED_SSAS)
8881058

8891059
function is_load_forwardable(x::EscapeInfo)
8901060
AliasInfo = x.AliasInfo
891-
return isa(AliasInfo, IndexableFields)
1061+
return isa(AliasInfo, IndexableFields) || isa(AliasInfo, IndexableElements)
8921062
end
8931063

894-
struct FieldDefUse
1064+
struct IndexedDefUse
8951065
uses::Vector{Any}
8961066
defs::Vector{Int}
8971067
end
898-
FieldDefUse() = FieldDefUse(Any[], Int[])
899-
struct GetfieldLoad
1068+
IndexedDefUse() = IndexedDefUse(Any[], Int[])
1069+
struct LoadUse
9001070
idx::Int
9011071
end
9021072
struct PreserveUse
@@ -906,7 +1076,7 @@ struct IsdefinedUse
9061076
idx::Int
9071077
end
9081078
function getuseidx(@nospecialize use)
909-
if isa(use, GetfieldLoad)
1079+
if isa(use, LoadUse)
9101080
return use.idx
9111081
elseif isa(use, PreserveUse)
9121082
return use.idx
@@ -916,21 +1086,21 @@ function getuseidx(@nospecialize use)
9161086
throw("getuseidx: unexpected use")
9171087
end
9181088

919-
function compute_live_ins(cfg::CFG, fdu::FieldDefUse)
1089+
function compute_live_ins(cfg::CFG, idu::IndexedDefUse)
9201090
uses = Int[]
921-
for use in fdu.uses
1091+
for use in idu.uses
9221092
isa(use, IsdefinedUse) && continue
9231093
push!(uses, getuseidx(use))
9241094
end
925-
return compute_live_ins(cfg, fdu.defs, uses)
1095+
return compute_live_ins(cfg, idu.defs, uses)
9261096
end
9271097

9281098
# even when the allocation contains an uninitialized field, we try an extra effort to check
9291099
# if this load at `idx` have any "safe" `setfield!` calls that define the field
9301100
# try to find
9311101
function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
932-
fdu::FieldDefUse, use::Int)
933-
dfu = find_def_for_use(ir, domtree, allblocks, fdu, use)
1102+
idu::IndexedDefUse, use::Int)
1103+
dfu = find_def_for_use(ir, domtree, allblocks, idu, use)
9341104
dfu === nothing && return false
9351105
def = dfu[1]
9361106
def 0 && return true # found a "safe" definition
@@ -946,7 +1116,7 @@ function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
9461116
pred in seen && return false
9471117
use = last(ir.cfg.blocks[pred].stmts)
9481118
# NOTE this `use` isn't a load, and so the inclusive condition can be used
949-
dfu = find_def_for_use(ir, domtree, allblocks, fdu, use, true)
1119+
dfu = find_def_for_use(ir, domtree, allblocks, idu, use, true)
9501120
dfu === nothing && return false
9511121
def = dfu[1]
9521122
push!(seen, pred)
@@ -961,12 +1131,12 @@ end
9611131

9621132
# find the first dominating def for the given use
9631133
function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
964-
fdu::FieldDefUse, use::Int, inclusive::Bool=false)
1134+
idu::IndexedDefUse, use::Int, inclusive::Bool=false)
9651135
useblock = block_for_inst(ir.cfg, use)
9661136
curblock = find_curblock(domtree, allblocks, useblock)
9671137
curblock === nothing && return nothing
9681138
local def = 0
969-
for idx in fdu.defs
1139+
for idx in idu.defs
9701140
if block_for_inst(ir.cfg, idx) == curblock
9711141
if curblock != useblock
9721142
# Find the last def in this block
@@ -995,15 +1165,15 @@ function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int)
9951165
end
9961166

9971167
function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
998-
fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int)
999-
dfu = find_def_for_use(ir, domtree, allblocks, fdu, use)
1168+
idu::IndexedDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int)
1169+
dfu = find_def_for_use(ir, domtree, allblocks, idu, use)
10001170
@assert dfu !== nothing "has_safe_def condition unsatisfied"
10011171
def, useblock, curblock = dfu
10021172
if def == 0
10031173
if !haskey(phinodes, curblock)
10041174
# If this happens, we need to search the predecessors for defs. Which
10051175
# one doesn't matter - if it did, we'd have had a phinode
1006-
return compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds))
1176+
return compute_value_for_block(ir, domtree, allblocks, idu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds))
10071177
end
10081178
# The use is the phinode
10091179
return phinodes[curblock]
@@ -1013,11 +1183,11 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
10131183
end
10141184

10151185
function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
1016-
fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int)
1186+
idu::IndexedDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int)
10171187
curblock = find_curblock(domtree, allblocks, curblock)
10181188
@assert curblock !== nothing "has_safe_def condition unsatisfied"
10191189
def = 0
1020-
for stmt in fdu.defs
1190+
for stmt in idu.defs
10211191
if block_for_inst(ir.cfg, stmt) == curblock
10221192
def = max(def, stmt)
10231193
end
@@ -1029,9 +1199,12 @@ function val_for_def_expr(ir::IRCode, def::Int, fidx::Int)
10291199
ex = ir[SSAValue(def)][:inst]
10301200
if isexpr(ex, :new)
10311201
return ex.args[1+fidx]
1032-
else
1033-
@assert is_known_call(ex, setfield!, ir) "invalid load forwarding"
1202+
elseif is_known_call(ex, setfield!, ir)
10341203
return ex.args[4]
1204+
elseif is_known_call(ex, arrayset, ir)
1205+
return ex.args[4]
1206+
else
1207+
throw("invalid load forwarding")
10351208
end
10361209
end
10371210

@@ -1100,6 +1273,34 @@ function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue},
11001273
end
11011274
end
11021275
return false
1276+
elseif is_known_call(stmt, arrayset, ir)
1277+
@assert length(stmt.args) 4 "invalid escape analysis"
1278+
ary = stmt.args[3]
1279+
val = stmt.args[4]
1280+
if isa(ary, SSAValue)
1281+
if ary in related
1282+
push!(eliminable, ssa)
1283+
@goto next_live
1284+
end
1285+
if isa(val, SSAValue) && val in related
1286+
if ary in deadssas
1287+
push!(eliminable, ssa)
1288+
@goto next_live
1289+
end
1290+
for new_revisit_idx in wset
1291+
if ary in revisit[new_revisit_idx][1]
1292+
delete!(wset, new_revisit_idx)
1293+
if mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved, wset, new_revisit_idx)
1294+
push!(eliminable, ssa)
1295+
@goto next_live
1296+
else
1297+
return false
1298+
end
1299+
end
1300+
end
1301+
end
1302+
end
1303+
return false
11031304
elseif isexpr(stmt, :foreigncall)
11041305
livepc in preserved && @goto next_live
11051306
return false

‎test/compiler/codegen.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -548,27 +548,27 @@ end
548548
# main use case
549549
function f1(cond)
550550
val = [1]
551-
GC.@preserve val begin end
551+
GC.@preserve val begin val end
552552
end
553553
@test occursin("llvm.julia.gc_preserve_begin", get_llvm(f1, Tuple{Bool}, true, false, false))
554554

555555
# stack allocated objects (JuliaLang/julia#34241)
556556
function f3(cond)
557557
val = ([1],)
558-
GC.@preserve val begin end
558+
GC.@preserve val begin val end
559559
end
560560
@test occursin("llvm.julia.gc_preserve_begin", get_llvm(f3, Tuple{Bool}, true, false, false))
561561

562562
# unions of immutables (JuliaLang/julia#39501)
563563
function f2(cond)
564564
val = cond ? 1 : 1f0
565-
GC.@preserve val begin end
565+
GC.@preserve val begin val end
566566
end
567567
@test !occursin("llvm.julia.gc_preserve_begin", get_llvm(f2, Tuple{Bool}, true, false, false))
568568
# make sure the fix for the above doesn't regress #34241
569569
function f4(cond)
570570
val = cond ? ([1],) : ([1f0],)
571-
GC.@preserve val begin end
571+
GC.@preserve val begin val end
572572
end
573573
@test occursin("llvm.julia.gc_preserve_begin", get_llvm(f4, Tuple{Bool}, true, false, false))
574574
end

‎test/compiler/irpasses.jl

+134-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Core:
99
# utilities
1010
# =========
1111

12-
import Core.Compiler: argextype, singleton_type, widenconst
12+
import Core.Compiler: argextype, singleton_type, widenconst, is_array_alloc
1313

1414
argextype(@nospecialize args...) = argextype(args..., Any[])
1515
code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::CodeInfo
@@ -100,15 +100,24 @@ end
100100
# SROA
101101
# ====
102102

103-
is_load_forwarded(src::CodeInfo) = !any(iscall((src, getfield)), src.code)
104-
is_scalar_replaced(src::CodeInfo) =
105-
is_load_forwarded(src) && !any(iscall((src, setfield!)), src.code) && !any(isnew, src.code)
103+
is_load_forwarded(src::CodeInfo) =
104+
!any(iscall((src, getfield)), src.code) && !any(iscall((src, Core.arrayref)), src.code)
105+
function is_scalar_replaced(src::CodeInfo)
106+
is_load_forwarded(src) || return false
107+
any(iscall((src, setfield!)), src.code) && return false
108+
any(isnew, src.code) && return false
109+
any(iscall((src, Core.arrayset)), src.code) && return false
110+
any(is_array_alloc, src.code) && return false
111+
return true
112+
end
106113

107114
function is_load_forwarded(@nospecialize(T), src::CodeInfo)
108115
for i in 1:length(src.code)
109116
x = src.code[i]
110117
if iscall((src, getfield), x)
111118
widenconst(argextype(x.args[1], src)) <: T && return false
119+
elseif iscall((src, Core.arrayref), x)
120+
widenconst(argextype(x.args[1], src)) <: T && return false
112121
end
113122
end
114123
return true
@@ -121,6 +130,10 @@ function is_scalar_replaced(@nospecialize(T), src::CodeInfo)
121130
widenconst(argextype(x.args[1], src)) <: T && return false
122131
elseif isnew(x)
123132
widenconst(argextype(SSAValue(i), src)) <: T && return false
133+
elseif iscall((src, Core.arrayset), x)
134+
widenconst(argextype(x.args[1], src)) <: T && return false
135+
elseif is_array_alloc(x)
136+
widenconst(argextype(SSAValue(i), src)) <: T && return false
124137
end
125138
end
126139
return true
@@ -736,7 +749,7 @@ function mutable_ϕ_elim(x, xs)
736749
return r[]
737750
end
738751
let src = code_typed1(mutable_ϕ_elim, (String, Vector{String}))
739-
@test is_scalar_replaced(src)
752+
@test is_scalar_replaced(Ref{String}, src)
740753

741754
xs = String[string(gensym()) for _ in 1:100]
742755
mutable_ϕ_elim("init", xs)
@@ -875,7 +888,7 @@ function isdefined_elim()
875888
return arr
876889
end
877890
let src = code_typed1(isdefined_elim)
878-
@test is_scalar_replaced(src)
891+
@test count(isnew, src.code) == 0 # eliminates closure constructs
879892
end
880893
@test isdefined_elim() == Any[]
881894

@@ -930,6 +943,121 @@ let # immutable case
930943
@test count(isnew, src.code) == 0
931944
end
932945

946+
# array SROA
947+
# ----------
948+
949+
let src = code_typed1((Any,)) do s
950+
a = Vector{Any}(undef, 1)
951+
a[1] = s
952+
return a[1]
953+
end
954+
@test is_scalar_replaced(src)
955+
end
956+
let src = code_typed1((Any,)) do s
957+
a = Any[nothing]
958+
a[1] = s
959+
return a[1]
960+
end
961+
@test is_scalar_replaced(src)
962+
end
963+
let src = code_typed1((String,String)) do s, t
964+
a = Vector{Any}(undef, 2)
965+
a[1] = Ref(s)
966+
a[2] = Ref(t)
967+
return a[1]
968+
end
969+
@test count(isnew, src.code) == 1
970+
end
971+
let src = code_typed1((String,)) do s
972+
a = Vector{Base.RefValue{String}}(undef, 1)
973+
a[1] = Ref(s)
974+
return a[1][]
975+
end
976+
@test is_scalar_replaced(src)
977+
end
978+
let src = code_typed1((String,String)) do s, t
979+
a = Vector{Base.RefValue{String}}(undef, 2)
980+
a[1] = Ref(s)
981+
a[2] = Ref(t)
982+
return a[1][]
983+
end
984+
@test is_scalar_replaced(src)
985+
end
986+
let src = code_typed1((Any,)) do s
987+
a = Vector{Any}[Any[nothing]]
988+
a[1][1] = s
989+
return a[1][1]
990+
end
991+
@test_broken is_scalar_replaced(src)
992+
end
993+
let src = code_typed1((Bool,Any,Any)) do c, s, t
994+
a = Any[nothing]
995+
if c
996+
a[1] = s
997+
else
998+
a[1] = t
999+
end
1000+
return a[1]
1001+
end
1002+
@test is_scalar_replaced(src)
1003+
end
1004+
let src = code_typed1((Bool,Any,Any,Any,Any,)) do c, s1, s2, t1, t2
1005+
if c
1006+
a = Vector{Any}(undef, 2)
1007+
a[1] = s1
1008+
a[2] = s2
1009+
else
1010+
a = Vector{Any}(undef, 2)
1011+
a[1] = t1
1012+
a[2] = t2
1013+
end
1014+
return a[1]
1015+
end
1016+
@test is_scalar_replaced(src)
1017+
end
1018+
let src = code_typed1((Bool,Any,Any)) do c, s, t
1019+
# XXX this implicitly forms tuple to getfield chains
1020+
# and SROA on it produces complicated control flow
1021+
if c
1022+
a = Any[s]
1023+
else
1024+
a = Any[t]
1025+
end
1026+
return a[1]
1027+
end
1028+
@test_broken is_scalar_replaced(src)
1029+
end
1030+
1031+
# arraylen / arraysize elimination
1032+
let src = code_typed1((Any,)) do s
1033+
a = Vector{Any}(undef, 1)
1034+
a[1] = s
1035+
return a[1], length(a)
1036+
end
1037+
@test is_scalar_replaced(src)
1038+
end
1039+
let src = code_typed1((Any,)) do s
1040+
a = Matrix{Any}(undef, 2, 2)
1041+
a[1, 1] = s
1042+
return a[1, 1], length(a)
1043+
end
1044+
@test is_scalar_replaced(src)
1045+
end
1046+
let src = code_typed1((Any,)) do s
1047+
a = Vector{Any}(undef, 1)
1048+
a[1] = s
1049+
return a[1], size(a, 1)
1050+
end
1051+
@test is_scalar_replaced(src)
1052+
end
1053+
let src = code_typed1((Any,)) do s
1054+
a = Matrix{Any}(undef, 2, 2)
1055+
a[1, 1] = s
1056+
return a[1, 1], size(a)
1057+
end
1058+
@test is_scalar_replaced(src)
1059+
end
1060+
9331061
# comparison lifting
9341062
# ==================
9351063

0 commit comments

Comments
 (0)
Please sign in to comment.