Skip to content

Commit 5257c76

Browse files
committed
Unify static_<first/step/last> with Static
This extends the `Static` definitions by specializing specifically on array types.
1 parent a29940f commit 5257c76

File tree

1 file changed

+40
-35
lines changed

1 file changed

+40
-35
lines changed

src/StaticArrayInterface.jl

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,35 @@ using PrecompileTools
1919
@recompile_invalidations begin
2020
using ArrayInterface
2121
import ArrayInterface: allowed_getindex, allowed_setindex!, aos_to_soa, buffer,
22-
parent_type, fast_matrix_colors, findstructralnz,
23-
has_sparsestruct,
24-
issingular, isstructured, matrix_colors, restructure,
25-
lu_instance,
26-
safevec, zeromatrix, undefmatrix, ColoringAlgorithm,
27-
fast_scalar_indexing, parameterless_type,
28-
is_forwarding_wrapper,
29-
map_tuple_type, flatten_tuples, GetIndex, SetIndex!,
30-
defines_strides, ndims_index, ndims_shape,
31-
stride_preserving_index
22+
parent_type, fast_matrix_colors, findstructralnz,
23+
has_sparsestruct,
24+
issingular, isstructured, matrix_colors, restructure,
25+
lu_instance,
26+
safevec, zeromatrix, undefmatrix, ColoringAlgorithm,
27+
fast_scalar_indexing, parameterless_type,
28+
is_forwarding_wrapper,
29+
map_tuple_type, flatten_tuples, GetIndex, SetIndex!,
30+
defines_strides, ndims_index, ndims_shape,
31+
stride_preserving_index
3232

3333
# ArrayIndex subtypes and methods
3434
import ArrayInterface: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex,
35-
TridiagonalIndex
35+
TridiagonalIndex
3636
# managing immutables
3737
import ArrayInterface: ismutable, can_change_size, can_setindex
3838
# constants
3939
import ArrayInterface: MatAdjTrans, VecAdjTrans, UpTri, LoTri
4040
# device pieces
4141
import ArrayInterface: AbstractDevice, AbstractCPU, CPUPointer, CPUTuple, CheckParent,
42-
CPUIndex, GPU, can_avx, device
42+
CPUIndex, GPU, can_avx, device
4343

4444
using Static
4545
using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
46-
permute, invariant_permutation, field_type, reduce_tup, find_first_eq,
47-
OptionallyStaticUnitRange, OptionallyStaticStepRange, OptionallyStaticRange,
48-
IntType,
49-
SOneTo, SUnitRange
46+
permute, invariant_permutation, field_type, reduce_tup, find_first_eq,
47+
OptionallyStaticUnitRange, OptionallyStaticStepRange,
48+
OptionallyStaticRange,
49+
IntType,
50+
SOneTo, SUnitRange, static_first, static_step, static_last
5051

5152
using IfElse
5253

@@ -77,26 +78,26 @@ known_offset1, known_offsets, known_size, known_step, known_strides
7778
Subtype of `ArrayIndex` that transforms and index using stride layout information
7879
derived from `x`.
7980
"""
80-
struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N}
81+
struct StrideIndex{N, R, C, S, O} <: ArrayIndex{N}
8182
strides::S
8283
offsets::O
83-
@inline function StrideIndex{N,R,C}(s::S, o::O) where {N,R,C,S,O}
84-
return new{N,R::NTuple{N,Int},C,S,O}(s, o)
84+
@inline function StrideIndex{N, R, C}(s::S, o::O) where {N, R, C, S, O}
85+
return new{N, R::NTuple{N, Int}, C, S, O}(s, o)
8586
end
8687
end
8788

8889
"""
8990
LazyAxis{N}(parent::AbstractArray)
9091
A lazy representation of `axes(parent, N)`.
9192
"""
92-
struct LazyAxis{N,P} <: AbstractUnitRange{Int}
93+
struct LazyAxis{N, P} <: AbstractUnitRange{Int}
9394
parent::P
9495

95-
function LazyAxis{N}(parent::P) where {N,P}
96-
N > 0 && return new{N::Int,P}(parent)
96+
function LazyAxis{N}(parent::P) where {N, P}
97+
N > 0 && return new{N::Int, P}(parent)
9798
throw_dim_error(parent, N)
9899
end
99-
@inline LazyAxis{:}(parent::P) where {P} = new{ifelse(ndims(P) === 1, 1, :),P}(parent)
100+
@inline LazyAxis{:}(parent::P) where {P} = new{ifelse(ndims(P) === 1, 1, :), P}(parent)
100101
end
101102

102103
function throw_dim_error(@nospecialize(x), @nospecialize(dim))
@@ -111,9 +112,10 @@ An abstract trait that is used to determine how axes are combined when calling `
111112
"""
112113
abstract type BroadcastAxis end
113114

114-
@assume_effects :total function _find_first_true(isi::Tuple{Vararg{Union{Bool, Static.StaticBool}, N}}) where {N}
115+
@assume_effects :total function _find_first_true(isi::Tuple{Vararg{
116+
Union{Bool, Static.StaticBool}, N}}) where {N}
115117
for i in 1:N
116-
x = getfield(isi, i)
118+
x = getfield(isi, i)
117119
if (x isa Bool && x === true) || x isa Static.True
118120
return i
119121
end
@@ -235,7 +237,8 @@ struct IndicesInfo{Np, pdims, cdims, Nc}
235237
IndicesInfo(x::SubArray) = IndicesInfo{ndims(parent(x))}(typeof(x.indices))
236238
end
237239

238-
@inline function _map_splats(nsplat::Int, splat_index::Int, dims::Tuple{Vararg{Union{Int,StaticInt}}})
240+
@inline function _map_splats(nsplat::Int, splat_index::Int, dims::Tuple{Vararg{Union{
241+
Int, StaticInt}}})
239242
ntuple(length(dims)) do i
240243
i === splat_index ? (nsplat * getfield(dims, i)) : getfield(dims, i)
241244
end
@@ -323,9 +326,9 @@ end
323326
return setindex!(A, val; kwargs...)
324327
end
325328

326-
@inline static_first(x) = Static.maybe_static(known_first, first, x)
327-
@inline static_last(x) = Static.maybe_static(known_last, last, x)
328-
@inline static_step(x) = Static.maybe_static(known_step, step, x)
329+
@inline Static.static_first(x::AbstractArray) = Static.maybe_static(known_first, first, x)
330+
@inline Static.static_last(x::AbstractArray) = Static.maybe_static(known_last, last, x)
331+
@inline Static.static_step(x::AbstractArray) = Static.maybe_static(known_step, step, x)
329332

330333
@inline function _to_cartesian(a, i::IntType)
331334
@inbounds(CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i])
@@ -434,7 +437,7 @@ Base.@propagate_inbounds function deleteat(collection::AbstractVector, index)
434437
return unsafe_deleteat(collection, index)
435438
end
436439
Base.@propagate_inbounds function deleteat(collection::Tuple{Vararg{Any, N}},
437-
index) where {N}
440+
index) where {N}
438441
@boundscheck if !checkindex(Bool, StaticInt{1}():StaticInt{N}(), index)
439442
throw(BoundsError(collection, index))
440443
end
@@ -504,11 +507,13 @@ include("broadcast.jl")
504507
# Putting some things in `setup` can reduce the size of the
505508
# precompile file and potentially make loading faster.
506509
arrays = [rand(4), Base.oneto(5)]
507-
@compile_workload begin for x in arrays
508-
known_first(x)
509-
known_step(x)
510-
known_last(x)
511-
end end
510+
@compile_workload begin
511+
for x in arrays
512+
known_first(x)
513+
known_step(x)
514+
known_last(x)
515+
end
516+
end
512517
end
513518

514519
end

0 commit comments

Comments
 (0)