Skip to content

Commit 7fda00b

Browse files
aviateskLilithHafner
authored andcommitted
optimizer: compute side-effect-freeness for array allocations (JuliaLang#43565)
This would be useful for Julia-level optimizations on arrays. Initially I want to have this in order to add array primitives support in EscapeAnalysis.jl, which should help us implement a variety of array optimizations including dead array allocation elimination, copy-elision from `Array` to `ImmutableArray` conversion (JuliaLang#42465), etc., but I found this might be already useful for us since this enables some DCE in very simple cases like: ```julia julia> function simple!(x::T) where T d = IdDict{T,T}() # dead alloc # ... computations that don't use `d` at all return nothing end simple! (generic function with 1 method) julia> @code_typed simple!("foo") CodeInfo( 1 ─ return Main.nothing ) => Nothing ``` This enhancement is super limited though, e.g. DCE won't happen when array allocation involves other primitive operations like `arrayset`: ```julia julia> code_typed() do a = Int[0,1,2] nothing end 1-element Vector{Any}: CodeInfo( 1 ─ %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Int64}, svec(Any, Int64), 0, :(:ccall), Vector{Int64}, 3, 3))::Vector{Int64} │ Base.arrayset(false, %1, 0, 1)::Vector{Int64} │ Base.arrayset(false, %1, 1, 2)::Vector{Int64} │ Base.arrayset(false, %1, 2, 3)::Vector{Int64} └── return Main.nothing ) => Nothing ``` Further enhancement o optimize cases like above will be based on top of incoming EA.jl (Julia-level escape analysis) or LLVM-level escape analysis.
1 parent 2c08297 commit 7fda00b

File tree

11 files changed

+205
-62
lines changed

11 files changed

+205
-62
lines changed

base/array.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -193,17 +193,17 @@ end
193193

194194

195195
"""
196-
Base.bitsunionsize(U::Union)
196+
Base.bitsunionsize(U::Union) -> Int
197197
198198
For a `Union` of [`isbitstype`](@ref) types, return the size of the largest type; assumes `Base.isbitsunion(U) == true`.
199199
200200
# Examples
201201
```jldoctest
202202
julia> Base.bitsunionsize(Union{Float64, UInt8})
203-
0x0000000000000008
203+
8
204204
205205
julia> Base.bitsunionsize(Union{Float64, UInt8, Int128})
206-
0x0000000000000010
206+
16
207207
```
208208
"""
209209
function bitsunionsize(u::Union)

base/boot.jl

-1
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,6 @@ unsafe_convert(::Type{T}, x::T) where {T} = x
444444

445445
const NTuple{N,T} = Tuple{Vararg{T,N}}
446446

447-
448447
## primitive Array constructors
449448
struct UndefInitializer end
450449
const undef = UndefInitializer()

base/compiler/optimize.jl

+72
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,8 @@ function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src::Union{IRC
236236
eT fT || return false
237237
end
238238
return true
239+
elseif head === :foreigncall
240+
return foreigncall_effect_free(stmt, rt, src)
239241
elseif head === :new_opaque_closure
240242
length(args) < 5 && return false
241243
typ = argextype(args[1], src)
@@ -260,6 +262,76 @@ function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src::Union{IRC
260262
return true
261263
end
262264

265+
function foreigncall_effect_free(stmt::Expr, @nospecialize(rt), src::Union{IRCode,IncrementalCompact})
266+
args = stmt.args
267+
name = args[1]
268+
isa(name, QuoteNode) && (name = name.value)
269+
isa(name, Symbol) || return false
270+
ndims = alloc_array_ndims(name)
271+
if ndims !== nothing
272+
if ndims == 0
273+
return new_array_no_throw(args, src)
274+
else
275+
return alloc_array_no_throw(args, ndims, src)
276+
end
277+
end
278+
return false
279+
end
280+
281+
function alloc_array_ndims(name::Symbol)
282+
if name === :jl_alloc_array_1d
283+
return 1
284+
elseif name === :jl_alloc_array_2d
285+
return 2
286+
elseif name === :jl_alloc_array_3d
287+
return 3
288+
elseif name === :jl_new_array
289+
return 0
290+
end
291+
return nothing
292+
end
293+
294+
function alloc_array_no_throw(args::Vector{Any}, ndims::Int, src::Union{IRCode,IncrementalCompact})
295+
length(args) ndims+6 || return false
296+
atype = widenconst(argextype(args[6], src))
297+
isType(atype) || return false
298+
atype = atype.parameters[1]
299+
dims = Csize_t[]
300+
for i in 1:ndims
301+
dim = argextype(args[i+6], src)
302+
isa(dim, Const) || return false
303+
dimval = dim.val
304+
isa(dimval, Int) || return false
305+
push!(dims, reinterpret(Csize_t, dimval))
306+
end
307+
return _new_array_no_throw(atype, ndims, dims)
308+
end
309+
310+
function new_array_no_throw(args::Vector{Any}, src::Union{IRCode,IncrementalCompact})
311+
length(args) 7 || return false
312+
atype = widenconst(argextype(args[6], src))
313+
isType(atype) || return false
314+
atype = atype.parameters[1]
315+
dims = argextype(args[7], src)
316+
isa(dims, Const) || return dims === Tuple{}
317+
dimsval = dims.val
318+
isa(dimsval, Tuple{Vararg{Int}}) || return false
319+
ndims = nfields(dimsval)
320+
isa(ndims, Int) || return false
321+
dims = Csize_t[reinterpret(Csize_t, dimval) for dimval in dimsval]
322+
return _new_array_no_throw(atype, ndims, dims)
323+
end
324+
325+
function _new_array_no_throw(@nospecialize(atype), ndims::Int, dims::Vector{Csize_t})
326+
isa(atype, DataType) || return false
327+
eltype = atype.parameters[1]
328+
iskindtype(typeof(eltype)) || return false
329+
elsz = aligned_sizeof(eltype)
330+
return ccall(:jl_array_validate_dims, Cint,
331+
(Ptr{Csize_t}, Ptr{Csize_t}, UInt32, Ptr{Csize_t}, Csize_t),
332+
#=nel=#RefValue{Csize_t}(), #=tot=#RefValue{Csize_t}(), ndims, dims, elsz) == 0
333+
end
334+
263335
"""
264336
argextype(x, src::Union{IRCode,IncrementalCompact}) -> t
265337
argextype(x, src::CodeInfo, sptypes::Vector{Any}) -> t

base/compiler/tfuncs.jl

+36-25
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ function sizeof_nothrow(@nospecialize(x))
341341
exact || return false # Could always be the type Bottom at runtime, for example, which throws
342342
t === DataType && return true # DataType itself has a size
343343
if isa(x, Union)
344-
isinline, sz, _ = uniontype_layout(x)
344+
isinline = uniontype_layout(x)[1]
345345
return isinline # even any subset of this union would have a size
346346
end
347347
isa(x, DataType) || return false
@@ -381,7 +381,7 @@ function sizeof_tfunc(@nospecialize(x),)
381381
# Normalize the query to ask about that type.
382382
x = unwrap_unionall(t)
383383
if exact && isa(x, Union)
384-
isinline, sz, _ = uniontype_layout(x)
384+
isinline = uniontype_layout(x)[1]
385385
return isinline ? Const(Int(Core.sizeof(x))) : Bottom
386386
end
387387
isa(x, DataType) || return Int
@@ -1470,30 +1470,27 @@ function _opaque_closure_tfunc(@nospecialize(arg), @nospecialize(isva),
14701470
end
14711471

14721472
# whether getindex for the elements can potentially throw UndefRef
1473-
function array_type_undefable(@nospecialize(a))
1474-
if isa(a, Union)
1475-
return array_type_undefable(a.a) || array_type_undefable(a.b)
1476-
elseif isa(a, UnionAll)
1473+
function array_type_undefable(@nospecialize(arytype))
1474+
if isa(arytype, Union)
1475+
return array_type_undefable(arytype.a) || array_type_undefable(arytype.b)
1476+
elseif isa(arytype, UnionAll)
14771477
return true
14781478
else
1479-
etype = (a::DataType).parameters[1]
1480-
return !(etype isa Type && (isbitstype(etype) || isbitsunion(etype)))
1479+
elmtype = (arytype::DataType).parameters[1]
1480+
return !(elmtype isa Type && (isbitstype(elmtype) || isbitsunion(elmtype)))
14811481
end
14821482
end
14831483

1484-
function array_builtin_common_nothrow(argtypes::Array{Any,1}, first_idx_idx::Int)
1484+
function array_builtin_common_nothrow(argtypes::Vector{Any}, first_idx_idx::Int)
14851485
length(argtypes) >= 4 || return false
1486-
atype = argtypes[2]
1487-
(argtypes[1] Bool && atype Array) || return false
1488-
for i = first_idx_idx:length(argtypes)
1489-
argtypes[i] Int || return false
1490-
end
1486+
boundcheck = argtypes[1]
1487+
arytype = argtypes[2]
1488+
array_builtin_common_typecheck(boundcheck, arytype, argtypes, first_idx_idx) || return false
14911489
# If we could potentially throw undef ref errors, bail out now.
1492-
atype = widenconst(atype)
1493-
array_type_undefable(atype) && return false
1490+
arytype = widenconst(arytype)
1491+
array_type_undefable(arytype) && return false
14941492
# If we have @inbounds (first argument is false), we're allowed to assume
14951493
# we don't throw bounds errors.
1496-
boundcheck = argtypes[1]
14971494
if isa(boundcheck, Const)
14981495
!(boundcheck.val::Bool) && return true
14991496
end
@@ -1503,19 +1500,33 @@ function array_builtin_common_nothrow(argtypes::Array{Any,1}, first_idx_idx::Int
15031500
return false
15041501
end
15051502

1503+
function array_builtin_common_typecheck(
1504+
@nospecialize(boundcheck), @nospecialize(arytype),
1505+
argtypes::Vector{Any}, first_idx_idx::Int)
1506+
(boundcheck Bool && arytype Array) || return false
1507+
for i = first_idx_idx:length(argtypes)
1508+
argtypes[i] Int || return false
1509+
end
1510+
return true
1511+
end
1512+
1513+
function arrayset_typecheck(@nospecialize(arytype), @nospecialize(elmtype))
1514+
# Check that we can determine the element type
1515+
arytype = widenconst(arytype)
1516+
isa(arytype, DataType) || return false
1517+
elmtype_expected = arytype.parameters[1]
1518+
isa(elmtype_expected, Type) || return false
1519+
# Check that the element type is compatible with the element we're assigning
1520+
elmtype elmtype_expected || return false
1521+
return true
1522+
end
1523+
15061524
# Query whether the given builtin is guaranteed not to throw given the argtypes
15071525
function _builtin_nothrow(@nospecialize(f), argtypes::Array{Any,1}, @nospecialize(rt))
15081526
if f === arrayset
15091527
array_builtin_common_nothrow(argtypes, 4) || return true
15101528
# Additionally check element type compatibility
1511-
a = widenconst(argtypes[2])
1512-
# Check that we can determine the element type
1513-
isa(a, DataType) || return false
1514-
ap1 = a.parameters[1]
1515-
isa(ap1, Type) || return false
1516-
# Check that the element type is compatible with the element we're assigning
1517-
argtypes[3] ap1 || return false
1518-
return true
1529+
return arrayset_typecheck(argtypes[2], argtypes[3])
15191530
elseif f === arrayref || f === const_arrayref
15201531
return array_builtin_common_nothrow(argtypes, 3)
15211532
elseif f === Core._expr

base/reflection.jl

+7-5
Original file line numberDiff line numberDiff line change
@@ -351,22 +351,24 @@ function datatype_alignment(dt::DataType)
351351
return Int(alignment)
352352
end
353353

354-
function uniontype_layout(T::Type)
354+
function uniontype_layout(@nospecialize T::Type)
355355
sz = RefValue{Csize_t}(0)
356356
algn = RefValue{Csize_t}(0)
357357
isinline = ccall(:jl_islayout_inline, Cint, (Any, Ptr{Csize_t}, Ptr{Csize_t}), T, sz, algn) != 0
358-
(isinline, sz[], algn[])
358+
(isinline, Int(sz[]), Int(algn[]))
359359
end
360360

361+
LLT_ALIGN(x, sz) = (x + sz - 1) & -sz
362+
361363
# amount of total space taken by T when stored in a container
362-
function aligned_sizeof(T::Type)
364+
function aligned_sizeof(@nospecialize T::Type)
363365
@_pure_meta
364366
if isbitsunion(T)
365367
_, sz, al = uniontype_layout(T)
366-
return (sz + al - 1) & -al
368+
return LLT_ALIGN(sz, al)
367369
elseif allocatedinline(T)
368370
al = datatype_alignment(T)
369-
return (Core.sizeof(T) + al - 1) & -al
371+
return LLT_ALIGN(Core.sizeof(T), al)
370372
else
371373
return Core.sizeof(Ptr{Cvoid})
372374
end

src/array.c

+28-21
Original file line numberDiff line numberDiff line change
@@ -76,27 +76,40 @@ typedef uint64_t wideint_t;
7676

7777
#define MAXINTVAL (((size_t)-1)>>1)
7878

79+
JL_DLLEXPORT int jl_array_validate_dims(size_t *nel, size_t *tot, uint32_t ndims, size_t *dims, size_t elsz)
80+
{
81+
size_t i;
82+
size_t _nel = 1;
83+
for(i=0; i < ndims; i++) {
84+
size_t di = dims[i];
85+
wideint_t prod = (wideint_t)_nel * (wideint_t)di;
86+
if (prod >= (wideint_t) MAXINTVAL || di >= MAXINTVAL)
87+
return 1;
88+
_nel = prod;
89+
}
90+
wideint_t prod = (wideint_t)elsz * (wideint_t)_nel;
91+
if (prod >= (wideint_t) MAXINTVAL)
92+
return 2;
93+
*nel = _nel;
94+
*tot = (size_t)prod;
95+
return 0;
96+
}
97+
7998
static jl_array_t *_new_array_(jl_value_t *atype, uint32_t ndims, size_t *dims,
80-
int8_t isunboxed, int8_t hasptr, int8_t isunion, int8_t zeroinit, int elsz)
99+
int8_t isunboxed, int8_t hasptr, int8_t isunion, int8_t zeroinit, size_t elsz)
81100
{
82101
jl_task_t *ct = jl_current_task;
83-
size_t i, tot, nel=1;
102+
size_t i, tot, nel;
84103
void *data;
85104
jl_array_t *a;
86-
87-
for(i=0; i < ndims; i++) {
88-
size_t di = dims[i];
89-
wideint_t prod = (wideint_t)nel * (wideint_t)di;
90-
if (prod > (wideint_t) MAXINTVAL || di > MAXINTVAL)
91-
jl_exceptionf(jl_argumenterror_type, "invalid Array dimensions");
92-
nel = prod;
93-
}
105+
assert(isunboxed || elsz == sizeof(void*));
94106
assert(atype == NULL || isunion == jl_is_uniontype(jl_tparam0(atype)));
107+
int validated = jl_array_validate_dims(&nel, &tot, ndims, dims, elsz);
108+
if (validated == 1)
109+
jl_exceptionf(jl_argumenterror_type, "invalid Array dimensions");
110+
else if (validated == 2)
111+
jl_error("invalid Array size");
95112
if (isunboxed) {
96-
wideint_t prod = (wideint_t)elsz * (wideint_t)nel;
97-
if (prod > (wideint_t) MAXINTVAL)
98-
jl_error("invalid Array size");
99-
tot = prod;
100113
if (elsz == 1 && !isunion) {
101114
// extra byte for all julia allocated byte arrays
102115
tot++;
@@ -106,12 +119,6 @@ static jl_array_t *_new_array_(jl_value_t *atype, uint32_t ndims, size_t *dims,
106119
tot += nel;
107120
}
108121
}
109-
else {
110-
wideint_t prod = (wideint_t)sizeof(void*) * (wideint_t)nel;
111-
if (prod > (wideint_t) MAXINTVAL)
112-
jl_error("invalid Array size");
113-
tot = prod;
114-
}
115122

116123
int ndimwords = jl_array_ndimwords(ndims);
117124
int tsz = sizeof(jl_array_t) + ndimwords*sizeof(size_t);
@@ -196,7 +203,7 @@ static inline jl_array_t *_new_array(jl_value_t *atype, uint32_t ndims, size_t *
196203
jl_array_t *jl_new_array_for_deserialization(jl_value_t *atype, uint32_t ndims, size_t *dims,
197204
int isunboxed, int hasptr, int isunion, int elsz)
198205
{
199-
return _new_array_(atype, ndims, dims, isunboxed, hasptr, isunion, 0, elsz);
206+
return _new_array_(atype, ndims, dims, isunboxed, hasptr, isunion, 0, (size_t)elsz);
200207
}
201208

202209
#ifndef JL_NDEBUG

src/jl_exported_funcs.inc

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
XX(jl_array_to_string) \
4646
XX(jl_array_typetagdata) \
4747
XX(jl_arrayunset) \
48+
XX(jl_array_validate_dims) \
4849
XX(jl_atexit_hook) \
4950
XX(jl_atomic_bool_cmpswap_bits) \
5051
XX(jl_atomic_cmpswap_bits) \
@@ -564,4 +565,3 @@
564565
YY(LLVMExtraAddGCInvariantVerifierPass) \
565566
YY(LLVMExtraAddDemoteFloat16Pass) \
566567
YY(LLVMExtraAddCPUFeaturesPass) \
567-

src/julia.h

+1
Original file line numberDiff line numberDiff line change
@@ -1530,6 +1530,7 @@ JL_DLLEXPORT void jl_array_sizehint(jl_array_t *a, size_t sz);
15301530
JL_DLLEXPORT void jl_array_ptr_1d_push(jl_array_t *a, jl_value_t *item);
15311531
JL_DLLEXPORT void jl_array_ptr_1d_append(jl_array_t *a, jl_array_t *a2);
15321532
JL_DLLEXPORT jl_value_t *jl_apply_array_type(jl_value_t *type, size_t dim);
1533+
JL_DLLEXPORT int jl_array_validate_dims(size_t *nel, size_t *tot, uint32_t ndims, size_t *dims, size_t elsz);
15331534
// property access
15341535
JL_DLLEXPORT void *jl_array_ptr(jl_array_t *a);
15351536
JL_DLLEXPORT void *jl_array_eltype(jl_value_t *a);

test/cmdlineargs.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,7 @@ let exename = `$(Base.julia_cmd()) --startup-file=no --color=no`
344344
rm(memfile)
345345
@test popfirst!(got) == " 0 g(x) = x + 123456"
346346
@test popfirst!(got) == " - function f(x)"
347-
if Sys.WORD_SIZE == 64
348-
@test popfirst!(got) == " 48 []"
349-
else
350-
@test popfirst!(got) == " 32 []"
351-
end
347+
@test popfirst!(got) == " - []"
352348
if Sys.WORD_SIZE == 64
353349
# P64 pools with 64 bit tags
354350
@test popfirst!(got) == " 16 Base.invokelatest(g, 0)"

test/compiler/inline.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,13 @@ end
151151
end
152152

153153
function fully_eliminated(f, args)
154+
@nospecialize f args
154155
let code = code_typed(f, args)[1][1].code
155156
return length(code) == 1 && isa(code[1], ReturnNode)
156157
end
157158
end
158-
159159
function fully_eliminated(f, args, retval)
160+
@nospecialize f args
160161
let code = code_typed(f, args)[1][1].code
161162
return length(code) == 1 && isa(code[1], ReturnNode) && code[1].val == retval
162163
end

0 commit comments

Comments
 (0)