Skip to content

Commit 8da655f

Browse files
committed
Add dimension parameter to HasShape
Needed to determine the shape of indices in a type-stable way.
1 parent b72d9eb commit 8da655f

File tree

9 files changed

+26
-20
lines changed

9 files changed

+26
-20
lines changed

base/array.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,9 @@ _similar_for(c, T, itr, isz) = similar(c, T)
456456
collect(collection)
457457
458458
Return an `Array` of all items in a collection or iterator. For dictionaries, returns
459-
`Pair{KeyType, ValType}`. If the argument is array-like or is an iterator with the `HasShape()`
460-
trait, the result will have the same shape and number of dimensions as the argument.
459+
`Pair{KeyType, ValType}`. If the argument is array-like or is an iterator with the
460+
[`HasShape`](@ref IteratorSize) trait, the result will have the same shape
461+
and number of dimensions as the argument.
461462
462463
# Examples
463464
```jldoctest

base/asyncmap.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function verify_ntasks(iterable, ntasks)
125125

126126
if ntasks == 0
127127
chklen = IteratorSize(iterable)
128-
if (chklen == HasLength()) || (chklen == HasShape())
128+
if (chklen isa HasLength) || (chklen isa HasShape)
129129
ntasks = max(1,min(100, length(iterable)))
130130
else
131131
ntasks = 100

base/generator.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ end
5353
abstract type IteratorSize end
5454
struct SizeUnknown <: IteratorSize end
5555
struct HasLength <: IteratorSize end
56-
struct HasShape <: IteratorSize end
56+
struct HasShape{N} <: IteratorSize end
5757
struct IsInfinite <: IteratorSize end
5858

5959
"""
@@ -63,8 +63,9 @@ Given the type of an iterator, return one of the following values:
6363
6464
* `SizeUnknown()` if the length (number of elements) cannot be determined in advance.
6565
* `HasLength()` if there is a fixed, finite length.
66-
* `HasShape()` if there is a known length plus a notion of multidimensional shape (as for an array).
67-
In this case the [`size`](@ref) function is valid for the iterator.
66+
* `HasShape{N}()` if there is a known length plus a notion of multidimensional shape (as for an array).
67+
In this case `N` should give the number of dimensions, and the [`size`](@ref) function is valid
68+
for the iterator.
6869
* `IsInfinite()` if the iterator yields values forever.
6970
7071
The default value (for iterators that do not define this function) is `HasLength()`.
@@ -75,7 +76,7 @@ result, and algorithms that resize their result incrementally.
7576
7677
```jldoctest
7778
julia> Base.IteratorSize(1:5)
78-
Base.HasShape()
79+
Base.HasShape{1}()
7980
8081
julia> Base.IteratorSize((2,3))
8182
Base.HasLength()
@@ -110,7 +111,7 @@ Base.HasEltype()
110111
IteratorEltype(x) = IteratorEltype(typeof(x))
111112
IteratorEltype(::Type) = HasEltype() # HasEltype is the default
112113

113-
IteratorSize(::Type{<:AbstractArray}) = HasShape()
114+
IteratorSize(::Type{<:AbstractArray{<:Any,N}}) where {N} = HasShape{N}()
114115
IteratorSize(::Type{Generator{I,F}}) where {I,F} = IteratorSize(I)
115116
length(g::Generator) = length(g.iter)
116117
size(g::Generator) = size(g.iter)

base/iterators.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -705,11 +705,15 @@ julia> collect(Iterators.product(1:2,3:5))
705705
"""
706706
product(iters...) = ProductIterator(iters)
707707

708-
IteratorSize(::Type{ProductIterator{Tuple{}}}) = HasShape()
708+
IteratorSize(::Type{ProductIterator{Tuple{}}}) = HasShape{0}()
709709
IteratorSize(::Type{ProductIterator{T}}) where {T<:Tuple} =
710710
prod_iteratorsize( IteratorSize(tuple_type_head(T)), IteratorSize(ProductIterator{tuple_type_tail(T)}) )
711711

712-
prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasShape()
712+
prod_iteratorsize(::HasLength, ::HasLength) = HasShape{2}()
713+
prod_iteratorsize(::HasLength, ::HasShape{N}) where {N} = HasShape{N+1}()
714+
prod_iteratorsize(::HasShape{N}, ::HasLength) where {N} = HasShape{N+1}()
715+
prod_iteratorsize(::HasShape{M}, ::HasShape{N}) where {M,N} = HasShape{M+N}()
716+
713717
# products can have an infinite iterator
714718
prod_iteratorsize(::IsInfinite, ::IsInfinite) = IsInfinite()
715719
prod_iteratorsize(a, ::IsInfinite) = IsInfinite()

base/multidimensional.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ module IteratorsMD
275275
eltype(R::CartesianIndices) = eltype(typeof(R))
276276
eltype(::Type{CartesianIndices{N}}) where {N} = CartesianIndex{N}
277277
eltype(::Type{CartesianIndices{N,TT}}) where {N,TT} = CartesianIndex{N}
278-
IteratorSize(::Type{<:CartesianIndices}) = Base.HasShape()
278+
IteratorSize(::Type{<:CartesianIndices{N}}) where {N} = Base.HasShape{N}()
279279

280280
@inline function start(iter::CartesianIndices)
281281
iterfirst, iterlast = first(iter), last(iter)

base/number.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ ndims(x::Number) = 0
5353
ndims(::Type{<:Number}) = 0
5454
length(x::Number) = 1
5555
endof(x::Number) = 1
56-
IteratorSize(::Type{<:Number}) = HasShape()
56+
IteratorSize(::Type{<:Number}) = HasShape{0}()
5757
keys(::Number) = OneTo(1)
5858

5959
getindex(x::Number) = x

doc/src/manual/interfaces.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ to generically build upon those behaviors.
1313
| `next(iter, state)` |   | Returns the current item and the next state |
1414
| `done(iter, state)` |   | Tests if there are any items remaining |
1515
| **Important optional methods** | **Default definition** | **Brief description** |
16-
| `IteratorSize(IterType)` | `HasLength()` | One of `HasLength()`, `HasShape()`, `IsInfinite()`, or `SizeUnknown()` as appropriate |
16+
| `IteratorSize(IterType)` | `HasLength()` | One of `HasLength()`, `HasShape{N}()`, `IsInfinite()`, or `SizeUnknown()` as appropriate |
1717
| `IteratorEltype(IterType)` | `HasEltype()` | Either `EltypeUnknown()` or `HasEltype()` as appropriate |
1818
| `eltype(IterType)` | `Any` | The type of the items returned by `next()` |
1919
| `length(iter)` | (*undefined*) | The number of items, if known |
@@ -22,7 +22,7 @@ to generically build upon those behaviors.
2222
| Value returned by `IteratorSize(IterType)` | Required Methods |
2323
|:------------------------------------------ |:------------------------------------------ |
2424
| `HasLength()` | `length(iter)` |
25-
| `HasShape()` | `length(iter)` and `size(iter, [dim...])` |
25+
| `HasShape{N}()` | `length(iter)` and `size(iter, [dim...])` |
2626
| `IsInfinite()` | (*none*) |
2727
| `SizeUnknown()` | (*none*) |
2828

test/generic_map_tests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ function testmap_equivalence(mapf, f, c...)
6161
x1 = mapf(f,c...)
6262
x2 = map(f,c...)
6363

64-
if Base.IteratorSize == Base.HasShape()
64+
if Base.IteratorSize isa Base.HasShape
6565
@test size(x1) == size(x2)
6666
else
6767
@test length(x1) == length(x2)

test/iterators.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -318,11 +318,11 @@ end
318318
@test Base.IteratorSize(product(1:2, countfrom(1))) == Base.IsInfinite()
319319
@test Base.IteratorSize(product(countfrom(2), countfrom(1))) == Base.IsInfinite()
320320
@test Base.IteratorSize(product(countfrom(1), 1:2)) == Base.IsInfinite()
321-
@test Base.IteratorSize(product(1:2)) == Base.HasShape()
322-
@test Base.IteratorSize(product(1:2, 1:2)) == Base.HasShape()
323-
@test Base.IteratorSize(product(take(1:2, 1), take(1:2, 1))) == Base.HasShape()
324-
@test Base.IteratorSize(product(take(1:2, 2))) == Base.HasShape()
325-
@test Base.IteratorSize(product([1 2; 3 4])) == Base.HasShape()
321+
@test Base.IteratorSize(product(1:2)) == Base.HasShape{1}()
322+
@test Base.IteratorSize(product(1:2, 1:2)) == Base.HasShape{2}()
323+
@test Base.IteratorSize(product(take(1:2, 1), take(1:2, 1))) == Base.HasShape{2}()
324+
@test Base.IteratorSize(product(take(1:2, 2))) == Base.HasShape{1}()
325+
@test Base.IteratorSize(product([1 2; 3 4])) == Base.HasShape{2}()
326326

327327
# IteratorEltype trait business
328328
let f1 = Iterators.filter(i->i>0, 1:10)

0 commit comments

Comments
 (0)