Skip to content

Commit 354c56d

Browse files
authored
Add missing functionality like tagging and priming (#4)
1 parent 0fd1ea2 commit 354c56d

File tree

5 files changed

+166
-15
lines changed

5 files changed

+166
-15
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "ITensorBase"
22
uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
88
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1112
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
1213
UnallocatedArrays = "43c9e47c-e622-40fb-bf18-a09fc8c466b6"
@@ -16,6 +17,7 @@ UnspecifiedTypes = "42b3faec-625b-4613-8ddc-352bf9672b8d"
1617
Accessors = "0.1.39"
1718
DerivableInterfaces = "0.3.7"
1819
FillArrays = "1.13.0"
20+
LinearAlgebra = "1.10"
1921
MapBroadcast = "0.1.5"
2022
NamedDimsArrays = "0.3.0"
2123
UnallocatedArrays = "0.1.1"

README.md

+1-6
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,8 @@ julia> Pkg.add("ITensorBase")
1919

2020
````julia
2121
using ITensorBase: ITensorBase, ITensor, Index
22-
````
23-
24-
TODO: This should be `TensorAlgebra.qr`.
25-
26-
````julia
2722
using LinearAlgebra: qr
28-
using NamedDimsArrays: NamedDimsArray, aligndims, dimnames, name, unname
23+
using NamedDimsArrays: aligndims, unname
2924
using Test: @test
3025
i = Index(2)
3126
j = Index(2)

examples/README.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ julia> Pkg.add("ITensorBase")
2020
# ## Examples
2121

2222
using ITensorBase: ITensorBase, ITensor, Index
23-
# TODO: This should be `TensorAlgebra.qr`.
2423
using LinearAlgebra: qr
25-
using NamedDimsArrays: NamedDimsArray, aligndims, dimnames, name, unname
24+
using NamedDimsArrays: aligndims, unname
2625
using Test: @test
2726
i = Index(2)
2827
j = Index(2)

src/ITensorBase.jl

+130-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module ITensorBase
22

3+
export ITensor, Index
4+
5+
using Accessors: @set
36
using MapBroadcast: Mapped
47
using NamedDimsArrays:
58
NamedDimsArrays,
@@ -14,16 +17,53 @@ using NamedDimsArrays:
1417
name,
1518
named,
1619
nameddimsindices,
20+
setname,
21+
setnameddimsindices,
1722
unname
1823

24+
const Tag = String
25+
const TagSet = Set{Tag}
26+
27+
tagset(tags::String) = Set(filter(!isempty, String.(strip.(split(tags, ",")))))
28+
tagset(tags::TagSet) = tags
29+
30+
function tagsstring(tags::TagSet)
31+
str = ""
32+
length(tags) == 0 && return str
33+
tags_vec = collect(tags)
34+
for n in 1:(length(tags_vec) - 1)
35+
str *= "$(tags_vec[n]),"
36+
end
37+
str *= "$(tags_vec[end])"
38+
return str
39+
end
40+
1941
@kwdef struct IndexName <: AbstractName
2042
id::UInt64 = rand(UInt64)
43+
tags::TagSet = TagSet()
2144
plev::Int = 0
22-
tags::Set{String} = Set{String}()
23-
namedtags::Dict{Symbol,String} = Dict{Symbol,String}()
2445
end
2546
NamedDimsArrays.randname(n::IndexName) = IndexName()
2647

48+
id(n::IndexName) = n.id
49+
tags(n::IndexName) = n.tags
50+
plev(n::IndexName) = n.plev
51+
52+
settags(n::IndexName, tags) = @set n.tags = tags
53+
addtags(n::IndexName, ts) = settags(n, tags(n) tagset(ts))
54+
55+
setprime(n::IndexName, plev) = @set n.plev = plev
56+
prime(n::IndexName) = setprime(n, plev(n) + 1)
57+
58+
function Base.show(io::IO, i::IndexName)
59+
idstr = "id=$(id(i) % 1000)"
60+
tagsstr = !isempty(tags(i)) ? "|\"$(tagsstring(tags(i)))\"" : ""
61+
primestr = primestring(plev(i))
62+
str = "IndexName($(idstr)$(tagsstr))$(primestr)"
63+
print(io, str)
64+
return nothing
65+
end
66+
2767
struct IndexVal{Value<:Integer} <: AbstractNamedInteger{Value,IndexName}
2868
value::Value
2969
name::IndexName
@@ -41,7 +81,22 @@ struct Index{T,Value<:AbstractUnitRange{T}} <: AbstractNamedUnitRange{T,Value,In
4181
name::IndexName
4282
end
4383

44-
Index(length::Int) = Index(Base.OneTo(length), IndexName())
84+
function Index(length::Int; tags=TagSet(), kwargs...)
85+
return Index(Base.OneTo(length), IndexName(; tags=tagset(tags), kwargs...))
86+
end
87+
function Index(length::Int, tags::String; kwargs...)
88+
return Index(Base.OneTo(length), IndexName(; kwargs..., tags=tagset(tags)))
89+
end
90+
91+
# TODO: Define for `NamedDimsArrays.NamedViewIndex`.
92+
id(i::Index) = id(name(i))
93+
tags(i::Index) = tags(name(i))
94+
plev(i::Index) = plev(name(i))
95+
96+
# TODO: Define for `NamedDimsArrays.NamedViewIndex`.
97+
addtags(i::Index, tags) = setname(i, addtags(name(i), tags))
98+
prime(i::Index) = setname(i, prime(name(i)))
99+
Base.adjoint(i::Index) = prime(i)
45100

46101
# Interface
47102
# TODO: Overload `Base.parent` instead.
@@ -51,6 +106,29 @@ NamedDimsArrays.name(i::Index) = i.name
51106
# Constructor
52107
NamedDimsArrays.named(i::AbstractUnitRange, name::IndexName) = Index(i, name)
53108

109+
function primestring(plev)
110+
if plev < 0
111+
return " (warning: prime level $plev is less than 0)"
112+
end
113+
if plev == 0
114+
return ""
115+
elseif plev > 3
116+
return "'$plev"
117+
else
118+
return "'"^plev
119+
end
120+
end
121+
122+
function Base.show(io::IO, i::Index)
123+
lenstr = "length=$(dename(length(i)))"
124+
idstr = "|id=$(id(i) % 1000)"
125+
tagsstr = !isempty(tags(i)) ? "|\"$(tagsstring(tags(i)))\"" : ""
126+
primestr = primestring(plev(i))
127+
str = "Index($(lenstr)$(idstr)$(tagsstr))$(primestr)"
128+
print(io, str)
129+
return nothing
130+
end
131+
54132
struct NoncontiguousIndex{T,Value<:AbstractVector{T}} <:
55133
AbstractNamedVector{T,Value,IndexName}
56134
value::Value
@@ -103,17 +181,30 @@ struct AllocatableArrayInterface <: AbstractAllocatableArrayInterface end
103181

104182
unallocatable(a::AbstractITensor) = NamedDimsArray(a)
105183

106-
@interface ::AbstractAllocatableArrayInterface function Base.setindex!(
107-
a::AbstractArray, value, I::Int...
108-
)
184+
function setindex_allocatable!(a::AbstractArray, value, I...)
109185
allocate!(specify_eltype!(a, typeof(value)))
110186
# TODO: Maybe use `@interface interface(a) a[I...] = value`?
111187
unallocatable(a)[I...] = value
112188
return a
113189
end
114190

191+
# TODO: Combine these by using `Base.to_indices`.
192+
@interface ::AbstractAllocatableArrayInterface function Base.setindex!(
193+
a::AbstractArray, value, I::Int...
194+
)
195+
setindex_allocatable!(a, value, I...)
196+
return a
197+
end
198+
@interface ::AbstractAllocatableArrayInterface function Base.setindex!(
199+
a::AbstractArray, value, I::AbstractNamedInteger...
200+
)
201+
setindex_allocatable!(a, value, I...)
202+
return a
203+
end
204+
115205
@derive AllocatableArrayInterface() (T=AbstractITensor,) begin
116206
Base.setindex!(::T, ::Any, ::Int...)
207+
Base.setindex!(::T, ::Any, ::AbstractNamedInteger...)
117208
end
118209

119210
mutable struct ITensor <: AbstractITensor
@@ -127,9 +218,42 @@ using Accessors: @set
127218
setdenamed(a::ITensor, denamed) = (@set a.parent = denamed)
128219
setdenamed!(a::ITensor, denamed) = (a.parent = denamed)
129220

221+
function ITensor(elt::Type, I1::Index, I_rest::Index...)
222+
I = (I1, I_rest...)
223+
# TODO: Use `FillArrays.Zeros`.
224+
return ITensor(zeros(elt, length.(dename.(I))...), I)
225+
end
226+
130227
function ITensor(I1::Index, I_rest::Index...)
131228
I = (I1, I_rest...)
132229
return ITensor(Zeros{UnspecifiedZero}(length.(dename.(I))...), I)
133230
end
134231

232+
function ITensor()
233+
return ITensor(Zeros{UnspecifiedZero}(), ())
234+
end
235+
236+
inds(a::AbstractITensor) = nameddimsindices(a)
237+
setinds(a::AbstractITensor, inds) = setnameddimsindices(a, inds)
238+
239+
function uniqueinds(a1::AbstractITensor, a_rest::AbstractITensor...)
240+
return setdiff(inds(a1), inds.(a_rest)...)
241+
end
242+
function uniqueind(a1::AbstractITensor, a_rest::AbstractITensor...)
243+
return only(uniqueinds(a1, a_rest...))
244+
end
245+
246+
function commoninds(a1::AbstractITensor, a_rest::AbstractITensor...)
247+
return intersect(inds(a1), inds.(a_rest)...)
248+
end
249+
function commonind(a1::AbstractITensor, a_rest::AbstractITensor...)
250+
return only(commoninds(a1, a_rest...))
251+
end
252+
253+
# TODO: Use `replaceinds`/`mapinds`, based on
254+
# `replacenameddimsindices`/`mapnameddimsindices`.
255+
prime(a::AbstractITensor) = setinds(a, prime.(inds(a)))
256+
257+
include("quirks.jl")
258+
135259
end

src/quirks.jl

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# TODO: Define this properly.
2+
dag(i::Index) = i
3+
# TODO: Deprecate.
4+
dim(i::Index) = dename(length(i))
5+
# TODO: Define this properly.
6+
hasqns(i::Index) = false
7+
# TODO: Deprecate.
8+
itensor(parent::AbstractArray, nameddimsindices) = ITensor(parent, nameddimsindices)
9+
function itensor(parent::AbstractArray, i1::Index, i_rest::Index...)
10+
return ITensor(parent, (i1, i_rest...))
11+
end
12+
13+
# This seems to be needed to get broadcasting working.
14+
# TODO: Investigate this and see if we can get rid of it.
15+
Base.Broadcast.extrude(a::AbstractITensor) = a
16+
17+
# TODO: Generalize this.
18+
# Maybe define it as `oneelement`, and base it on
19+
# `FillArrays.OneElement` (https://juliaarrays.github.io/FillArrays.jl/stable/#FillArrays.OneElement).
20+
function onehot(iv::Pair{<:Index,<:Int})
21+
a = ITensor(first(iv))
22+
a[last(iv)] = one(Bool)
23+
return a
24+
end
25+
26+
using LinearAlgebra: svd
27+
# TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
28+
function factorize(a::AbstractITensor, args...; kwargs...)
29+
U, S, V = svd(a, args...; kwargs...)
30+
return U, S * V
31+
end

0 commit comments

Comments
 (0)