Skip to content

Commit 555adcb

Browse files
authored
Random fixes (#16)
1 parent 4605dfc commit 555adcb

File tree

6 files changed

+104
-47
lines changed

6 files changed

+104
-47
lines changed

Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.5"
4+
version = "0.3.6"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
99
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
10+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1213
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -25,8 +26,9 @@ Adapt = "4.1.1"
2526
ArrayLayouts = "1.11.0"
2627
BlockArrays = "1.3.0"
2728
DerivableInterfaces = "0.3.7"
29+
FillArrays = "1.13.0"
2830
LinearAlgebra = "1.10"
29-
MapBroadcast = "0.1.5"
31+
MapBroadcast = "0.1.6"
3032
Random = "1.10"
3133
SimpleTraits = "0.9.4"
3234
TensorAlgebra = "0.1"

ext/NamedDimsArraysBlockArraysExt/NamedDimsArraysBlockArraysExt.jl

+6-10
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,33 @@ module NamedDimsArraysBlockArraysExt
22
using ArrayLayouts: ArrayLayouts
33
using BlockArrays: Block, BlockRange
44
using NamedDimsArrays:
5-
AbstractNamedDimsArray,
6-
AbstractNamedUnitRange,
7-
named_getindex,
8-
nameddims_getindex,
9-
nameddims_view
5+
AbstractNamedDimsArray, AbstractNamedUnitRange, getindex_named, view_nameddims
106

117
function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::Block{1})
128
# TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead.
13-
return named_getindex(r, I)
9+
return getindex_named(r, I)
1410
end
1511

1612
function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::BlockRange{1})
1713
# TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead.
18-
return named_getindex(r, I)
14+
return getindex_named(r, I)
1915
end
2016

2117
const BlockIndex{N} = Union{Block{N},BlockRange{N},AbstractVector{<:Block{N}}}
2218

2319
function Base.view(a::AbstractNamedDimsArray, I1::Block{1}, Irest::BlockIndex{1}...)
2420
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
25-
return nameddims_view(a, I1, Irest...)
21+
return view_nameddims(a, I1, Irest...)
2622
end
2723

2824
function Base.view(a::AbstractNamedDimsArray, I::Block)
2925
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
30-
return nameddims_view(a, Tuple(I)...)
26+
return view_nameddims(a, Tuple(I)...)
3127
end
3228

3329
function Base.view(a::AbstractNamedDimsArray, I1::BlockIndex{1}, Irest::BlockIndex{1}...)
3430
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
35-
return nameddims_view(a, I1, Irest...)
31+
return view_nameddims(a, I1, Irest...)
3632
end
3733

3834
# Fix ambiguity error.

src/abstractnamedarray.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ function Base.hash(a::AbstractNamedArray, h::UInt)
3838
return hash(name(a), h)
3939
end
4040

41-
named_getindex(a::AbstractArray, I...) = named(getindex(dename(a), I...), name(a))
41+
getindex_named(a::AbstractArray, I...) = named(getindex(dename(a), I...), name(a))
4242

4343
# Array funcionality.
4444
Base.size(a::AbstractNamedArray) = map(s -> named(s, name(a)), size(dename(a)))
4545
Base.axes(a::AbstractNamedArray) = map(s -> named(s, name(a)), axes(dename(a)))
4646
Base.eachindex(a::AbstractNamedArray) = eachindex(dename(a))
4747
function Base.getindex(a::AbstractNamedArray{<:Any,N}, I::Vararg{Int,N}) where {N}
48-
return named_getindex(a, I...)
48+
return getindex_named(a, I...)
4949
end
5050
function Base.getindex(a::AbstractNamedArray, I::Int)
51-
return named_getindex(a, I)
51+
return getindex_named(a, I)
5252
end
5353
Base.isempty(a::AbstractNamedArray) = isempty(dename(a))
5454

src/abstractnameddimsarray.jl

+81-26
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange)
140140
end
141141

142142
function Base.copy(a::AbstractNamedDimsArray)
143-
return nameddimsarraytype(a)(copy(dename(a)), nameddimsindices(a))
143+
return constructorof(typeof(a))(copy(dename(a)), nameddimsindices(a))
144144
end
145145

146146
const NamedDimsIndices = Union{
@@ -181,9 +181,11 @@ Base.values(s::NaiveOrderedSet) = s.values
181181
Base.Tuple(s::NaiveOrderedSet) = Tuple(values(s))
182182
Base.length(s::NaiveOrderedSet) = length(values(s))
183183
Base.axes(s::NaiveOrderedSet) = axes(values(s))
184+
Base.keys(s::NaiveOrderedSet) = Base.OneTo(length(s))
184185
Base.:(==)(s1::NaiveOrderedSet, s2::NaiveOrderedSet) = issetequal(values(s1), values(s2))
185186
Base.iterate(s::NaiveOrderedSet, args...) = iterate(values(s), args...)
186187
Base.getindex(s::NaiveOrderedSet, I::Int) = values(s)[I]
188+
Base.get(s::NaiveOrderedSet, I::Integer, default) = get(values(s), I, default)
187189
Base.invperm(s::NaiveOrderedSet) = NaiveOrderedSet(invperm(values(s)))
188190
Base.Broadcast._axes(::Broadcasted, axes::NaiveOrderedSet) = axes
189191
Base.Broadcast.BroadcastStyle(::Type{<:NaiveOrderedSet}) = Style{NaiveOrderedSet}()
@@ -210,6 +212,10 @@ function Base.size(a::AbstractNamedDimsArray)
210212
return NaiveOrderedSet(map(named, size(dename(a)), nameddimsindices(a)))
211213
end
212214

215+
function Base.length(a::AbstractNamedDimsArray)
216+
return prod(size(a); init=1)
217+
end
218+
213219
# Circumvent issue when ndims isn't known at compile time.
214220
function Base.axes(a::AbstractNamedDimsArray, d)
215221
return d <= ndims(a) ? axes(a)[d] : OneTo(1)
@@ -233,17 +239,20 @@ to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims)
233239
to_nameddimsaxis(ax::NamedDimsAxis) = ax
234240
to_nameddimsaxis(I::NamedDimsIndices) = named(dename(only(axes(I))), I)
235241

236-
nameddimsarraytype(a::AbstractNamedDimsArray) = nameddimsarraytype(typeof(a))
237-
nameddimsarraytype(a::Type{<:AbstractNamedDimsArray}) = unspecify_type_parameters(a)
242+
# Interface inspired by [ConstructionBase.constructorof](https://github.com/JuliaObjects/ConstructionBase.jl).
243+
constructorof(type::Type{<:AbstractArray}) = unspecify_type_parameters(type)
244+
245+
constructorof_nameddims(type::Type{<:AbstractNamedDimsArray}) = constructorof(type)
246+
constructorof_nameddims(type::Type{<:AbstractArray}) = NamedDimsArray
238247

239248
function similar_nameddims(a::AbstractNamedDimsArray, elt::Type, inds)
240249
ax = to_nameddimsaxes(inds)
241-
return nameddimsarraytype(a)(similar(dename(a), elt, dename.(Tuple(ax))), name.(ax))
250+
return constructorof(typeof(a))(similar(dename(a), elt, dename.(Tuple(ax))), name.(ax))
242251
end
243252

244253
function similar_nameddims(a::AbstractArray, elt::Type, inds)
245254
ax = to_nameddimsaxes(inds)
246-
return nameddims(similar(a, elt, dename.(Tuple(ax))), name.(ax))
255+
return constructorof_nameddims(typeof(a))(similar(a, elt, dename.(Tuple(ax))), name.(ax))
247256
end
248257

249258
# Base.similar gets the eltype at compile time.
@@ -262,7 +271,7 @@ function Base.similar(a::AbstractArray, elt::Type, inds::NaiveOrderedSet)
262271
end
263272

264273
function setnameddimsindices(a::AbstractNamedDimsArray, nameddimsindices)
265-
return nameddimsarraytype(a)(dename(a), nameddimsindices)
274+
return constructorof(typeof(a))(dename(a), nameddimsindices)
266275
end
267276
function replacenameddimsindices(f, a::AbstractNamedDimsArray)
268277
return setnameddimsindices(a, replace(f, nameddimsindices(a)))
@@ -419,10 +428,18 @@ function Base.setindex!(a::AbstractNamedDimsArray, value, I::CartesianIndex)
419428
setindex!(a, value, to_indices(a, (I,))...)
420429
return a
421430
end
431+
432+
function flatten_namedinteger(i::AbstractNamedInteger)
433+
if name(i) isa Union{AbstractNamedUnitRange,AbstractNamedArray}
434+
return name(i)[dename(i)]
435+
end
436+
return i
437+
end
438+
422439
function Base.setindex!(
423440
a::AbstractNamedDimsArray, value, I1::AbstractNamedInteger, Irest::AbstractNamedInteger...
424441
)
425-
I = (I1, Irest...)
442+
I = flatten_namedinteger.((I1, Irest...))
426443
# TODO: Check if this permuation should be inverted.
427444
perm = getperm(name.(nameddimsindices(a)), name.(I))
428445
# TODO: Throw a `NameMismatch` error.
@@ -510,7 +527,9 @@ function Base.view(a::AbstractNamedDimsArray, I1::NamedViewIndex, Irest::NamedVi
510527
subinds = map(nameddimsindices(a), I) do dimname, i
511528
return checked_indexin(dename(i), dename(dimname))
512529
end
513-
return nameddims(view(dename(a), subinds...), sub_nameddimsindices)
530+
return constructorof_nameddims(typeof(a))(
531+
view(dename(a), subinds...), sub_nameddimsindices
532+
)
514533
end
515534

516535
function Base.getindex(
@@ -522,22 +541,22 @@ end
522541
# Repeated definition of `Base.ViewIndex`.
523542
const ViewIndex = Union{Real,AbstractArray}
524543

525-
function nameddims_view(a::AbstractArray, I...)
544+
function view_nameddims(a::AbstractArray, I...)
526545
sub_dims = filter(dim -> !(I[dim] isa Real), ntuple(identity, ndims(a)))
527546
sub_nameddimsindices = map(dim -> nameddimsindices(a, dim)[I[dim]], sub_dims)
528-
return nameddims(view(dename(a), I...), sub_nameddimsindices)
547+
return constructorof(typeof(a))(view(dename(a), I...), sub_nameddimsindices)
529548
end
530549

531550
function Base.view(a::AbstractNamedDimsArray, I::ViewIndex...)
532-
return nameddims_view(a, I...)
551+
return view_nameddims(a, I...)
533552
end
534553

535-
function nameddims_getindex(a::AbstractArray, I...)
554+
function getindex_nameddims(a::AbstractArray, I...)
536555
return copy(view(a, I...))
537556
end
538557

539558
function Base.getindex(a::AbstractNamedDimsArray, I::ViewIndex...)
540-
return nameddims_getindex(a, I...)
559+
return getindex_nameddims(a, I...)
541560
end
542561

543562
function Base.setindex!(
@@ -556,7 +575,7 @@ function Base.setindex!(
556575
Irest::NamedViewIndex...,
557576
)
558577
I = (I1, Irest...)
559-
setindex!(a, nameddimsarraytype(a)(value, I), I...)
578+
setindex!(a, constructorof(typeof(a))(value, I), I...)
560579
return a
561580
end
562581
function Base.setindex!(
@@ -580,13 +599,13 @@ end
580599
function aligndims(a::AbstractArray, dims)
581600
new_nameddimsindices = to_nameddimsindices(a, dims)
582601
# TODO: Check this permutation is correct (it may be the inverse of what we want).
583-
perm = getperm(nameddimsindices(a), new_nameddimsindices)
602+
perm = Tuple(getperm(nameddimsindices(a), new_nameddimsindices))
584603
isperm(perm) || throw(
585604
NameMismatch(
586605
"Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)."
587606
),
588607
)
589-
return nameddimsarraytype(a)(permutedims(dename(a), perm), new_nameddimsindices)
608+
return constructorof(typeof(a))(permutedims(dename(a), perm), new_nameddimsindices)
590609
end
591610

592611
function aligneddims(a::AbstractArray, dims)
@@ -598,7 +617,9 @@ function aligneddims(a::AbstractArray, dims)
598617
"Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)."
599618
),
600619
)
601-
return nameddimsarraytype(a)(PermutedDimsArray(dename(a), perm), new_nameddimsindices)
620+
return constructorof_nameddims(typeof(a))(
621+
PermutedDimsArray(dename(a), perm), new_nameddimsindices
622+
)
602623
end
603624

604625
# Convenient constructors
@@ -711,16 +732,17 @@ using Base.Broadcast:
711732
broadcasted,
712733
check_broadcast_shape,
713734
combine_axes
714-
using MapBroadcast: Mapped, mapped
735+
using MapBroadcast: MapBroadcast, Mapped, mapped, tile
715736

716737
abstract type AbstractNamedDimsArrayStyle{N} <: AbstractArrayStyle{N} end
717738

718-
struct NamedDimsArrayStyle{N} <: AbstractNamedDimsArrayStyle{N} end
719-
NamedDimsArrayStyle(::Val{N}) where {N} = NamedDimsArrayStyle{N}()
720-
NamedDimsArrayStyle{M}(::Val{N}) where {M,N} = NamedDimsArrayStyle{N}()
739+
struct NamedDimsArrayStyle{N,NDA} <: AbstractNamedDimsArrayStyle{N} end
740+
NamedDimsArrayStyle(::Val{N}) where {N} = NamedDimsArrayStyle{N,NamedDimsArray}()
741+
NamedDimsArrayStyle{M}(::Val{N}) where {M,N} = NamedDimsArrayStyle{N,NamedDimsArray}()
742+
NamedDimsArrayStyle{M,NDA}(::Val{N}) where {M,N,NDA} = NamedDimsArrayStyle{N,NDA}()
721743

722744
function Broadcast.BroadcastStyle(arraytype::Type{<:AbstractNamedDimsArray})
723-
return NamedDimsArrayStyle{ndims(arraytype)}()
745+
return NamedDimsArrayStyle{ndims(arraytype),constructorof(arraytype)}()
724746
end
725747

726748
function Broadcast.combine_axes(
@@ -762,6 +784,24 @@ function set_promote_shape(
762784
return named.(ax_promoted, name.(ax1))
763785
end
764786

787+
# Handle operations like `ITensor() + ITensor(i, j)`.
788+
# TODO: Decide if this should be a general definition for `AbstractNamedDimsArray`,
789+
# or just for `AbstractITensor`.
790+
function set_promote_shape(
791+
ax1::Tuple{}, ax2::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}}
792+
)
793+
return ax2
794+
end
795+
796+
# Handle operations like `ITensor(i, j) + ITensor()`.
797+
# TODO: Decide if this should be a general definition for `AbstractNamedDimsArray`,
798+
# or just for `AbstractITensor`.
799+
function set_promote_shape(
800+
ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}}, ax2::Tuple{}
801+
)
802+
return ax1
803+
end
804+
765805
function Broadcast.check_broadcast_shape(ax1::NaiveOrderedSet, ax2::NaiveOrderedSet)
766806
return set_check_broadcast_shape(Tuple(ax1), Tuple(ax2))
767807
end
@@ -775,6 +815,7 @@ function set_check_broadcast_shape(
775815
check_broadcast_shape(dename.(ax1), dename.(ax2_aligned))
776816
return nothing
777817
end
818+
set_check_broadcast_shape(ax1::Tuple{}, ax2::Tuple{}) = nothing
778819

779820
# Dename and lazily permute the arguments using the reference
780821
# dimension names.
@@ -783,19 +824,33 @@ function denamed(m::Mapped, nameddimsindices)
783824
return mapped(m.f, map(arg -> denamed(arg, nameddimsindices), m.args)...)
784825
end
785826

827+
function nameddimsarraytype(style::NamedDimsArrayStyle{<:Any,NDA}) where {NDA}
828+
return NDA
829+
end
830+
831+
using FillArrays: Fill
832+
833+
function MapBroadcast.tile(a::AbstractNamedDimsArray, ax)
834+
axes(a) == ax && return a
835+
if iszero(ndims(a))
836+
return constructorof(typeof(a))(Fill(a[], dename.(Tuple(ax))), name.(ax))
837+
end
838+
return error("Not implemented.")
839+
end
840+
786841
function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, ax)
787842
nameddimsindices = name.(ax)
788843
m′ = denamed(Mapped(bc), nameddimsindices)
789844
# TODO: Store the wrapper type in `AbstractNamedDimsArrayStyle` and use that
790845
# wrapper type rather than the generic `nameddims` constructor, which
791846
# can lose information.
792847
# Call it as `nameddimsarraytype(bc.style)`.
793-
return nameddims(similar(m′, elt, dename.(Tuple(ax))), nameddimsindices)
848+
return nameddimsarraytype(bc.style)(
849+
similar(m′, elt, dename.(Tuple(ax))), nameddimsindices
850+
)
794851
end
795852

796-
function Base.copyto!(
797-
dest::AbstractArray{<:Any,N}, bc::Broadcasted{<:AbstractNamedDimsArrayStyle{N}}
798-
) where {N}
853+
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractNamedDimsArrayStyle})
799854
return copyto!(dest, Mapped(bc))
800855
end
801856

src/abstractnamedinteger.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ struct FusedNames{Names} <: AbstractName
5757
names::Names
5858
end
5959
fusednames(name1, name2) = FusedNames((name1, name2))
60-
fusednames(name1::FusedNames, name2::FusedNames) = FusedNames(generic_vcat(name1, name2))
60+
function fusednames(name1::FusedNames, name2::FusedNames)
61+
return FusedNames(generic_vcat(name1.names, name2.names))
62+
end
6163
fusednames(name1, name2::FusedNames) = fusednames(FusedNames((name1,)), name2)
6264
fusednames(name1::FusedNames, name2) = fusednames(name1, FusedNames((name2,)))
6365

@@ -86,6 +88,8 @@ Base.:-(i::AbstractNamedInteger) = setvalue(i, -dename(i))
8688
# TODO: See if we can delete this.
8789
Base.:+(i1::Int, i2::AbstractNamedInteger) = i1 + dename(i2)
8890

91+
Base.:*(i1::Int, i2::AbstractNamedInteger) = named(i1 * dename(i2), name(i2))
92+
8993
Base.zero(i::AbstractNamedInteger) = setvalue(i, zero(dename(i)))
9094
Base.one(i::AbstractNamedInteger) = setvalue(i, one(dename(i)))
9195
Base.signbit(i::AbstractNamedInteger) = signbit(dename(i))

src/abstractnamedunitrange.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ named(r::AbstractUnitRange, name) = namedunitrange(r, name)
1616

1717
# Derived interface.
1818
# TODO: Use `Accessors.@set`?
19-
setname(r::AbstractNamedUnitRange, name) = namedunitrange(dename(r), name)
19+
setname(r::AbstractNamedUnitRange, name) = named(dename(r), name)
2020

2121
# TODO: Use `TypeParameterAccessors`.
2222
denametype(::Type{<:AbstractNamedUnitRange{<:Any,Value}}) where {Value} = Value
@@ -43,17 +43,17 @@ Base.length(r::AbstractNamedUnitRange) = named(length(dename(r)), name(r))
4343
Base.size(r::AbstractNamedUnitRange) = (named(length(dename(r)), name(r)),)
4444
Base.axes(r::AbstractNamedUnitRange) = (named(only(axes(dename(r))), name(r)),)
4545
Base.step(r::AbstractNamedUnitRange) = named(step(dename(r)), name(r))
46-
Base.getindex(r::AbstractNamedUnitRange, I::Int) = named_getindex(r, I)
46+
Base.getindex(r::AbstractNamedUnitRange, I::Int) = getindex_named(r, I)
4747
# Fix ambiguity error.
4848
function Base.getindex(r::AbstractNamedUnitRange, I::AbstractUnitRange{<:Integer})
49-
return named_getindex(r, I)
49+
return getindex_named(r, I)
5050
end
5151
# Fix ambiguity error.
5252
function Base.getindex(r::AbstractNamedUnitRange, I::Colon)
53-
return named_getindex(r, I)
53+
return getindex_named(r, I)
5454
end
5555
function Base.getindex(r::AbstractNamedUnitRange, I)
56-
return named_getindex(r, I)
56+
return getindex_named(r, I)
5757
end
5858
Base.isempty(r::AbstractNamedUnitRange) = isempty(dename(r))
5959

0 commit comments

Comments
 (0)