@@ -140,7 +140,7 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange)
140
140
end
141
141
142
142
function Base. copy (a:: AbstractNamedDimsArray )
143
- return nameddimsarraytype (a )(copy (dename (a)), nameddimsindices (a))
143
+ return constructorof ( typeof (a) )(copy (dename (a)), nameddimsindices (a))
144
144
end
145
145
146
146
const NamedDimsIndices = Union{
@@ -181,9 +181,11 @@ Base.values(s::NaiveOrderedSet) = s.values
181
181
Base. Tuple (s:: NaiveOrderedSet ) = Tuple (values (s))
182
182
Base. length (s:: NaiveOrderedSet ) = length (values (s))
183
183
Base. axes (s:: NaiveOrderedSet ) = axes (values (s))
184
+ Base. keys (s:: NaiveOrderedSet ) = Base. OneTo (length (s))
184
185
Base.:(== )(s1:: NaiveOrderedSet , s2:: NaiveOrderedSet ) = issetequal (values (s1), values (s2))
185
186
Base. iterate (s:: NaiveOrderedSet , args... ) = iterate (values (s), args... )
186
187
Base. getindex (s:: NaiveOrderedSet , I:: Int ) = values (s)[I]
188
+ Base. get (s:: NaiveOrderedSet , I:: Integer , default) = get (values (s), I, default)
187
189
Base. invperm (s:: NaiveOrderedSet ) = NaiveOrderedSet (invperm (values (s)))
188
190
Base. Broadcast. _axes (:: Broadcasted , axes:: NaiveOrderedSet ) = axes
189
191
Base. Broadcast. BroadcastStyle (:: Type{<:NaiveOrderedSet} ) = Style {NaiveOrderedSet} ()
@@ -210,6 +212,10 @@ function Base.size(a::AbstractNamedDimsArray)
210
212
return NaiveOrderedSet (map (named, size (dename (a)), nameddimsindices (a)))
211
213
end
212
214
215
+ function Base. length (a:: AbstractNamedDimsArray )
216
+ return prod (size (a); init= 1 )
217
+ end
218
+
213
219
# Circumvent issue when ndims isn't known at compile time.
214
220
function Base. axes (a:: AbstractNamedDimsArray , d)
215
221
return d <= ndims (a) ? axes (a)[d] : OneTo (1 )
@@ -233,17 +239,20 @@ to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims)
233
239
to_nameddimsaxis (ax:: NamedDimsAxis ) = ax
234
240
to_nameddimsaxis (I:: NamedDimsIndices ) = named (dename (only (axes (I))), I)
235
241
236
- nameddimsarraytype (a:: AbstractNamedDimsArray ) = nameddimsarraytype (typeof (a))
237
- nameddimsarraytype (a:: Type{<:AbstractNamedDimsArray} ) = unspecify_type_parameters (a)
242
+ # Interface inspired by [ConstructionBase.constructorof](https://github.com/JuliaObjects/ConstructionBase.jl).
243
+ constructorof (type:: Type{<:AbstractArray} ) = unspecify_type_parameters (type)
244
+
245
+ constructorof_nameddims (type:: Type{<:AbstractNamedDimsArray} ) = constructorof (type)
246
+ constructorof_nameddims (type:: Type{<:AbstractArray} ) = NamedDimsArray
238
247
239
248
function similar_nameddims (a:: AbstractNamedDimsArray , elt:: Type , inds)
240
249
ax = to_nameddimsaxes (inds)
241
- return nameddimsarraytype (a )(similar (dename (a), elt, dename .(Tuple (ax))), name .(ax))
250
+ return constructorof ( typeof (a) )(similar (dename (a), elt, dename .(Tuple (ax))), name .(ax))
242
251
end
243
252
244
253
function similar_nameddims (a:: AbstractArray , elt:: Type , inds)
245
254
ax = to_nameddimsaxes (inds)
246
- return nameddims (similar (a, elt, dename .(Tuple (ax))), name .(ax))
255
+ return constructorof_nameddims ( typeof (a)) (similar (a, elt, dename .(Tuple (ax))), name .(ax))
247
256
end
248
257
249
258
# Base.similar gets the eltype at compile time.
@@ -262,7 +271,7 @@ function Base.similar(a::AbstractArray, elt::Type, inds::NaiveOrderedSet)
262
271
end
263
272
264
273
function setnameddimsindices (a:: AbstractNamedDimsArray , nameddimsindices)
265
- return nameddimsarraytype (a )(dename (a), nameddimsindices)
274
+ return constructorof ( typeof (a) )(dename (a), nameddimsindices)
266
275
end
267
276
function replacenameddimsindices (f, a:: AbstractNamedDimsArray )
268
277
return setnameddimsindices (a, replace (f, nameddimsindices (a)))
@@ -419,10 +428,18 @@ function Base.setindex!(a::AbstractNamedDimsArray, value, I::CartesianIndex)
419
428
setindex! (a, value, to_indices (a, (I,))... )
420
429
return a
421
430
end
431
+
432
+ function flatten_namedinteger (i:: AbstractNamedInteger )
433
+ if name (i) isa Union{AbstractNamedUnitRange,AbstractNamedArray}
434
+ return name (i)[dename (i)]
435
+ end
436
+ return i
437
+ end
438
+
422
439
function Base. setindex! (
423
440
a:: AbstractNamedDimsArray , value, I1:: AbstractNamedInteger , Irest:: AbstractNamedInteger...
424
441
)
425
- I = ( I1, Irest... )
442
+ I = flatten_namedinteger .(( I1, Irest... ) )
426
443
# TODO : Check if this permuation should be inverted.
427
444
perm = getperm (name .(nameddimsindices (a)), name .(I))
428
445
# TODO : Throw a `NameMismatch` error.
@@ -510,7 +527,9 @@ function Base.view(a::AbstractNamedDimsArray, I1::NamedViewIndex, Irest::NamedVi
510
527
subinds = map (nameddimsindices (a), I) do dimname, i
511
528
return checked_indexin (dename (i), dename (dimname))
512
529
end
513
- return nameddims (view (dename (a), subinds... ), sub_nameddimsindices)
530
+ return constructorof_nameddims (typeof (a))(
531
+ view (dename (a), subinds... ), sub_nameddimsindices
532
+ )
514
533
end
515
534
516
535
function Base. getindex (
@@ -522,22 +541,22 @@ end
522
541
# Repeated definition of `Base.ViewIndex`.
523
542
const ViewIndex = Union{Real,AbstractArray}
524
543
525
- function nameddims_view (a:: AbstractArray , I... )
544
+ function view_nameddims (a:: AbstractArray , I... )
526
545
sub_dims = filter (dim -> ! (I[dim] isa Real), ntuple (identity, ndims (a)))
527
546
sub_nameddimsindices = map (dim -> nameddimsindices (a, dim)[I[dim]], sub_dims)
528
- return nameddims (view (dename (a), I... ), sub_nameddimsindices)
547
+ return constructorof ( typeof (a)) (view (dename (a), I... ), sub_nameddimsindices)
529
548
end
530
549
531
550
function Base. view (a:: AbstractNamedDimsArray , I:: ViewIndex... )
532
- return nameddims_view (a, I... )
551
+ return view_nameddims (a, I... )
533
552
end
534
553
535
- function nameddims_getindex (a:: AbstractArray , I... )
554
+ function getindex_nameddims (a:: AbstractArray , I... )
536
555
return copy (view (a, I... ))
537
556
end
538
557
539
558
function Base. getindex (a:: AbstractNamedDimsArray , I:: ViewIndex... )
540
- return nameddims_getindex (a, I... )
559
+ return getindex_nameddims (a, I... )
541
560
end
542
561
543
562
function Base. setindex! (
@@ -556,7 +575,7 @@ function Base.setindex!(
556
575
Irest:: NamedViewIndex... ,
557
576
)
558
577
I = (I1, Irest... )
559
- setindex! (a, nameddimsarraytype (a )(value, I), I... )
578
+ setindex! (a, constructorof ( typeof (a) )(value, I), I... )
560
579
return a
561
580
end
562
581
function Base. setindex! (
@@ -580,13 +599,13 @@ end
580
599
function aligndims (a:: AbstractArray , dims)
581
600
new_nameddimsindices = to_nameddimsindices (a, dims)
582
601
# TODO : Check this permutation is correct (it may be the inverse of what we want).
583
- perm = getperm (nameddimsindices (a), new_nameddimsindices)
602
+ perm = Tuple ( getperm (nameddimsindices (a), new_nameddimsindices) )
584
603
isperm (perm) || throw (
585
604
NameMismatch (
586
605
" Dimension name mismatch $(nameddimsindices (a)) , $(new_nameddimsindices) ."
587
606
),
588
607
)
589
- return nameddimsarraytype (a )(permutedims (dename (a), perm), new_nameddimsindices)
608
+ return constructorof ( typeof (a) )(permutedims (dename (a), perm), new_nameddimsindices)
590
609
end
591
610
592
611
function aligneddims (a:: AbstractArray , dims)
@@ -598,7 +617,9 @@ function aligneddims(a::AbstractArray, dims)
598
617
" Dimension name mismatch $(nameddimsindices (a)) , $(new_nameddimsindices) ."
599
618
),
600
619
)
601
- return nameddimsarraytype (a)(PermutedDimsArray (dename (a), perm), new_nameddimsindices)
620
+ return constructorof_nameddims (typeof (a))(
621
+ PermutedDimsArray (dename (a), perm), new_nameddimsindices
622
+ )
602
623
end
603
624
604
625
# Convenient constructors
@@ -711,16 +732,17 @@ using Base.Broadcast:
711
732
broadcasted,
712
733
check_broadcast_shape,
713
734
combine_axes
714
- using MapBroadcast: Mapped, mapped
735
+ using MapBroadcast: MapBroadcast, Mapped, mapped, tile
715
736
716
737
abstract type AbstractNamedDimsArrayStyle{N} <: AbstractArrayStyle{N} end
717
738
718
- struct NamedDimsArrayStyle{N} <: AbstractNamedDimsArrayStyle{N} end
719
- NamedDimsArrayStyle (:: Val{N} ) where {N} = NamedDimsArrayStyle {N} ()
720
- NamedDimsArrayStyle {M} (:: Val{N} ) where {M,N} = NamedDimsArrayStyle {N} ()
739
+ struct NamedDimsArrayStyle{N,NDA} <: AbstractNamedDimsArrayStyle{N} end
740
+ NamedDimsArrayStyle (:: Val{N} ) where {N} = NamedDimsArrayStyle {N,NamedDimsArray} ()
741
+ NamedDimsArrayStyle {M} (:: Val{N} ) where {M,N} = NamedDimsArrayStyle {N,NamedDimsArray} ()
742
+ NamedDimsArrayStyle {M,NDA} (:: Val{N} ) where {M,N,NDA} = NamedDimsArrayStyle {N,NDA} ()
721
743
722
744
function Broadcast. BroadcastStyle (arraytype:: Type{<:AbstractNamedDimsArray} )
723
- return NamedDimsArrayStyle {ndims(arraytype)} ()
745
+ return NamedDimsArrayStyle {ndims(arraytype),constructorof(arraytype) } ()
724
746
end
725
747
726
748
function Broadcast. combine_axes (
@@ -762,6 +784,24 @@ function set_promote_shape(
762
784
return named .(ax_promoted, name .(ax1))
763
785
end
764
786
787
+ # Handle operations like `ITensor() + ITensor(i, j)`.
788
+ # TODO : Decide if this should be a general definition for `AbstractNamedDimsArray`,
789
+ # or just for `AbstractITensor`.
790
+ function set_promote_shape (
791
+ ax1:: Tuple{} , ax2:: Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}}
792
+ )
793
+ return ax2
794
+ end
795
+
796
+ # Handle operations like `ITensor(i, j) + ITensor()`.
797
+ # TODO : Decide if this should be a general definition for `AbstractNamedDimsArray`,
798
+ # or just for `AbstractITensor`.
799
+ function set_promote_shape (
800
+ ax1:: Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}} , ax2:: Tuple{}
801
+ )
802
+ return ax1
803
+ end
804
+
765
805
function Broadcast. check_broadcast_shape (ax1:: NaiveOrderedSet , ax2:: NaiveOrderedSet )
766
806
return set_check_broadcast_shape (Tuple (ax1), Tuple (ax2))
767
807
end
@@ -775,6 +815,7 @@ function set_check_broadcast_shape(
775
815
check_broadcast_shape (dename .(ax1), dename .(ax2_aligned))
776
816
return nothing
777
817
end
818
+ set_check_broadcast_shape (ax1:: Tuple{} , ax2:: Tuple{} ) = nothing
778
819
779
820
# Dename and lazily permute the arguments using the reference
780
821
# dimension names.
@@ -783,19 +824,33 @@ function denamed(m::Mapped, nameddimsindices)
783
824
return mapped (m. f, map (arg -> denamed (arg, nameddimsindices), m. args)... )
784
825
end
785
826
827
+ function nameddimsarraytype (style:: NamedDimsArrayStyle{<:Any,NDA} ) where {NDA}
828
+ return NDA
829
+ end
830
+
831
+ using FillArrays: Fill
832
+
833
+ function MapBroadcast. tile (a:: AbstractNamedDimsArray , ax)
834
+ axes (a) == ax && return a
835
+ if iszero (ndims (a))
836
+ return constructorof (typeof (a))(Fill (a[], dename .(Tuple (ax))), name .(ax))
837
+ end
838
+ return error (" Not implemented." )
839
+ end
840
+
786
841
function Base. similar (bc:: Broadcasted{<:AbstractNamedDimsArrayStyle} , elt:: Type , ax)
787
842
nameddimsindices = name .(ax)
788
843
m′ = denamed (Mapped (bc), nameddimsindices)
789
844
# TODO : Store the wrapper type in `AbstractNamedDimsArrayStyle` and use that
790
845
# wrapper type rather than the generic `nameddims` constructor, which
791
846
# can lose information.
792
847
# Call it as `nameddimsarraytype(bc.style)`.
793
- return nameddims (similar (m′, elt, dename .(Tuple (ax))), nameddimsindices)
848
+ return nameddimsarraytype (bc. style)(
849
+ similar (m′, elt, dename .(Tuple (ax))), nameddimsindices
850
+ )
794
851
end
795
852
796
- function Base. copyto! (
797
- dest:: AbstractArray{<:Any,N} , bc:: Broadcasted{<:AbstractNamedDimsArrayStyle{N}}
798
- ) where {N}
853
+ function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:AbstractNamedDimsArrayStyle} )
799
854
return copyto! (dest, Mapped (bc))
800
855
end
801
856
0 commit comments