Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
*.jl.cov
*.jl.*.cov
*.jl.mem
Manifest.toml
Manifest.toml
.vscode
.DS_Store
22 changes: 6 additions & 16 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
name = "TensorOperations"
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
authors = [
"Lukas Devos <[email protected]>",
"Maarten Van Damme <[email protected]>",
"Jutho Haegeman <[email protected]>",
]
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
version = "5.0.0"

[deps]
Expand All @@ -24,11 +20,13 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[extensions]
TensorOperationsBumperExt = "Bumper"
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
TensorOperationsOMEinsumContractionOrdersExt = "OMEinsumContractionOrders"
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]

[compat]
Expand All @@ -41,6 +39,7 @@ DynamicPolynomials = "0.5"
LRUCache = "1"
LinearAlgebra = "1.6"
Logging = "1.6"
OMEinsumContractionOrders = "0.9"
PackageExtensionCompat = "1"
PtrArrays = "1.2"
Random = "1"
Expand All @@ -59,19 +58,10 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = [
"Test",
"Random",
"DynamicPolynomials",
"ChainRulesTestUtils",
"CUDA",
"cuTENSOR",
"Aqua",
"Logging",
"Bumper",
]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "OMEinsumContractionOrders"]
110 changes: 110 additions & 0 deletions ext/TensorOperationsOMEinsumContractionOrdersExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
module TensorOperationsOMEinsumContractionOrdersExt

using TensorOperations
using TensorOperations: TensorOperations as TO
using TensorOperations: TreeOptimizer
using OMEinsumContractionOrders
using OMEinsumContractionOrders: EinCode, NestedEinsum, SlicedEinsum, isleaf, optimize_kahypar_auto

# Dispatch the `:GreedyMethod` tag to the OMEinsumContractionOrders greedy optimizer.
function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:GreedyMethod},
                        verbose::Bool) where {TDK,TDV}
    return optimize(network, optdata, GreedyMethod(), verbose)
end

# Dispatch the `:KaHyParBipartite` tag to the automatic KaHyPar-based bipartition optimizer.
function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:KaHyParBipartite},
                        verbose::Bool) where {TDK,TDV}
    return optimize_kahypar(network, optdata, verbose)
end

# Dispatch the `:TreeSA` tag to the tree simulated-annealing optimizer.
function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:TreeSA},
                        verbose::Bool) where {TDK,TDV}
    return optimize(network, optdata, TreeSA(), verbose)
end

# Dispatch the `:SABipartite` tag to the simulated-annealing bipartition optimizer.
function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:SABipartite},
                        verbose::Bool) where {TDK,TDV}
    return optimize(network, optdata, SABipartite(), verbose)
end

# Dispatch the `:ExactTreewidth` tag to the exact treewidth-based optimizer.
function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:ExactTreewidth},
                        verbose::Bool) where {TDK,TDV}
    return optimize(network, optdata, ExactTreewidth(), verbose)
end

"""
    optimize(network, optdata, ome_optimizer, verbose)

Optimize the contraction order of `network` with the given `ome_optimizer`
from OMEinsumContractionOrders.

Returns a tuple `(tree, cost)` where `tree` is the nested contraction tree
expressed in the tensor positions of `network`, and `cost` is `2.0^tc` with
`tc` the (log2) time complexity reported by
`OMEinsumContractionOrders.contraction_complexity`.

Throws an `ArgumentError` if the values of `optdata` (the index dimensions)
are not `<:Number`.
"""
function optimize(network, optdata::Dict{TDK,TDV}, ome_optimizer::CodeOptimizer,
                  verbose::Bool) where {TDK,TDV}
    # Validate input instead of using `@assert` (assertions may be disabled
    # at higher optimization levels) or `try`/`catch` control flow.
    TDV <: Number ||
        throw(ArgumentError("The values of the optdata dictionary must be of type Number"))

    # transform the network into an EinCode that OMEinsumContractionOrders understands
    code, size_dict = network2eincode(network, optdata)
    # optimize the contraction order, which gives a NestedEinsum
    optcode = optimize_code(code, size_dict, ome_optimizer)

    # transform the optimized contraction order back to a nested tree over the network
    optimaltree = eincode2contractiontree(optcode)

    # calculate the complexity of the contraction
    cc = OMEinsumContractionOrders.contraction_complexity(optcode, size_dict)
    if verbose
        println("Optimal contraction tree: ", optimaltree)
        println(cc)
    end
    return optimaltree, 2.0^(cc.tc)
end

"""
    optimize_kahypar(network, optdata, verbose)

Optimize the contraction order of `network` with the automatic KaHyPar-based
bipartition optimizer (`optimize_kahypar_auto`). Returns `(tree, cost)` in
the same format as [`optimize`](@ref).

Throws an `ArgumentError` if the values of `optdata` are not `<:Number`.
"""
function optimize_kahypar(network, optdata::Dict{TDK,TDV}, verbose::Bool) where {TDK,TDV}
    # Validate input instead of using `@assert` (assertions may be disabled
    # at higher optimization levels) or `try`/`catch` control flow.
    TDV <: Number ||
        throw(ArgumentError("The values of the optdata dictionary must be of type Number"))

    # transform the network into an EinCode that OMEinsumContractionOrders understands
    code, size_dict = network2eincode(network, optdata)
    # optimize the contraction order, which gives a NestedEinsum
    optcode = optimize_kahypar_auto(code, size_dict)

    # transform the optimized contraction order back to a nested tree over the network
    optimaltree = eincode2contractiontree(optcode)

    # calculate the complexity of the contraction
    cc = OMEinsumContractionOrders.contraction_complexity(optcode, size_dict)
    if verbose
        println("Optimal contraction tree: ", optimaltree)
        println(cc)
    end
    return optimaltree, 2.0^(cc.tc)
end

"""
    network2eincode(network, optdata)

Convert a TensorOperations `network` (a collection of index lists, one per
tensor) into an OMEinsumContractionOrders `EinCode`, relabelling every index
to a consecutive integer. An index that occurs in exactly one tensor is
treated as an open (output) index. Also returns `size_dict`, mapping the
relabelled indices to the dimensions given in `optdata`.
"""
function network2eincode(network, optdata)
    indices = unique(vcat(network...))
    # relabel every index to a consecutive integer
    new_indices = Dict(i => j for (j, i) in enumerate(indices))
    new_network = [Int[new_indices[i] for i in t] for t in network]
    # if an index appears only once in the network, it is an open index;
    # `count` avoids allocating a temporary Bool array per index
    open_edges = Int[new_indices[i] for i in indices if count(t -> i in t, network) == 1]
    size_dict = Dict(new_indices[i] => optdata[i] for i in keys(optdata))
    return EinCode(new_network, open_edges), size_dict
end

# Recursively convert a `NestedEinsum` into a nested contraction tree:
# a leaf becomes its tensor index, an internal node becomes the vector of
# its converted children.
function eincode2contractiontree(eincode::NestedEinsum)
    isleaf(eincode) && return eincode.tensorindex
    return [eincode2contractiontree(child) for child in eincode.args]
end

# TreeSA returns a `SlicedEinsum` with nslice = 0, so the slicing wrapper
# carries no extra information and we can directly unwrap the inner `eins`.
function eincode2contractiontree(eincode::SlicedEinsum)
    return eincode2contractiontree(eincode.eins)
end

end
13 changes: 12 additions & 1 deletion src/indexnotation/optimaltree.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
function optimaltree(network, optdata::Dict; verbose::Bool=false)
struct TreeOptimizer{T} end # T is a Symbol for the algorithm

"""
    optimaltree(network, optdata::Dict; optimizer=TreeOptimizer{:NCon}(), verbose=false)

Find a contraction tree for `network` using the selected `optimizer`,
forwarding to the positional `optimaltree(network, optdata, optimizer, verbose)`
method for the given optimizer tag.
"""
function optimaltree(network, optdata::Dict;
                     optimizer::TreeOptimizer=TreeOptimizer{:NCon}(), verbose::Bool=false)
    @debug "Using optimizer $(typeof(optimizer))"
    return optimaltree(network, optdata, optimizer, verbose)
end

# Fallback for optimizer tags that have no loaded implementation, e.g. when
# the corresponding package extension has not been activated.
function optimaltree(network, optdata::Dict, ::TreeOptimizer{T}, verbose::Bool) where {T}
    msg = "Unknown optimizer: $T. Hint: may need to load extensions, e.g. `using OMEinsumContractionOrders`"
    throw(ArgumentError(msg))
end

function optimaltree(network, optdata::Dict, ::TreeOptimizer{:NCon}, verbose::Bool)
numtensors = length(network)
allindices = unique(vcat(network...))
numindices = length(allindices)
Expand Down
9 changes: 8 additions & 1 deletion src/indexnotation/tensormacros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ function tensorparser(tensorexpr, kwargs...)
end
end
# now handle the remaining keyword arguments
optimizer = TreeOptimizer{:NCon}() # the default optimizer
for (name, val) in kwargs
if name == :order
isexpr(val, :tuple) ||
Expand All @@ -85,6 +86,12 @@ function tensorparser(tensorexpr, kwargs...)
val in (:warn, :cache) ||
throw(ArgumentError("Invalid use of `costcheck`, should be `costcheck=warn` or `costcheck=cache`"))
parser.contractioncostcheck = val
elseif name == :opt_algorithm
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here you will have to be a little careful, in principle there is no order to the keyword arguments.
If I am not mistaken, now if the user first supplies opt=(a = 2, b = 2, ...), and only afterwards opt_algorithm=..., the algorithm will be ignored.

My best guess is that you probably want to attempt to extract an optimizer and optdict, and only after all kwargs have been parsed, you can construct the contractiontreebuilder

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much for pointing that out, I did not notice that previously.
In the revised version, the contractiontreebuilder will be constructed after all other kwargs have been parsed.

if val isa Symbol
optimizer = TreeOptimizer{val}()
else
throw(ArgumentError("Invalid use of `opt_algorithm`, should be `opt_algorithm=NCon` or `opt_algorithm=NameOfAlgorithm`"))
end
elseif name == :opt
if val isa Bool && val
optdict = optdata(tensorexpr)
Expand All @@ -93,7 +100,7 @@ function tensorparser(tensorexpr, kwargs...)
else
throw(ArgumentError("Invalid use of `opt`, should be `opt=true` or `opt=OptExpr`"))
end
parser.contractiontreebuilder = network -> optimaltree(network, optdict)[1]
parser.contractiontreebuilder = network -> optimaltree(network, optdict, optimizer = optimizer)[1]
elseif !(name == :backend || name == :allocator) # these two have been handled
throw(ArgumentError("Unknown keyword argument `name`."))
end
Expand Down
16 changes: 16 additions & 0 deletions test/macro_kwargs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,19 @@ end
end
@test D1 ≈ D2 ≈ D3 ≈ D4 ≈ D5
end

@testset "opt_algorithm" begin
    # three-tensor contraction with internal indices e, f, g
    A = randn(5, 5, 5, 5)
    B = randn(5, 5, 5)
    C = randn(5, 5, 5)

    # default tree builder (no opt_algorithm given)
    @tensor opt = true begin
        D1[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end

    # explicitly selecting the built-in :NCon optimizer must give the same result
    @tensor opt = true opt_algorithm = NCon begin
        D2[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end

    @test D1 ≈ D2
end
112 changes: 112 additions & 0 deletions test/omeinsumcontractionordres.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Requesting an OMEinsumContractionOrders algorithm before the extension is
# loaded must raise an ArgumentError at macro-expansion time (via the
# fallback `optimaltree` method). This test must run BEFORE
# `using OMEinsumContractionOrders` below.
@testset "@tensor dependency check" begin
    @test_throws ArgumentError begin
        A = rand(2, 2)
        B = rand(2, 2)
        C = rand(2, 2)
        ex = :(@tensor opt=(i=>2, j=>2, k=>2) opt_algorithm=GreedyMethod S[] := A[i, j] * B[j, k] * C[i, k])
        macroexpand(Main, ex)
    end
end

using OMEinsumContractionOrders

@testset "OMEinsumContractionOrders optimization algorithms" begin
    # three-tensor contraction: every optimizer must reproduce the default result
    A = randn(5, 5, 5, 5)
    B = randn(5, 5, 5)
    C = randn(5, 5, 5)

    # reference result with the default tree builder
    @tensor begin
        D1[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end

    @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = GreedyMethod begin
        D2[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end

    @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = TreeSA begin
        D3[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end

    @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = KaHyParBipartite begin
        D4[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end

    @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = SABipartite begin
        D5[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end

    @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = ExactTreewidth begin
        D6[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b]
    end

    # integer index labels should work just like symbolic ones
    @tensor opt = (1 => 5, 2 => 5, 3 => 5, 4 => 5, 5 => 5, 6 => 5, 7 => 5) opt_algorithm = GreedyMethod begin
        D7[1, 2, 3, 4] := A[1, 5, 3, 6] * B[7, 4, 5] * C[7, 6, 2]
    end

    @test D1 ≈ D2 ≈ D3 ≈ D4 ≈ D5 ≈ D6 ≈ D7

    # fully contracted (scalar-valued) six-tensor network
    A = rand(2, 2)
    B = rand(2, 2, 2)
    C = rand(2, 2)
    D = rand(2, 2)
    E = rand(2, 2, 2)
    F = rand(2, 2)

    @tensor opt = true begin
        s1[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m]
    end

    @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = GreedyMethod begin
        s2[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m]
    end

    @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = TreeSA begin
        s3[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m]
    end

    @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = KaHyParBipartite begin
        s4[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m]
    end

    @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = SABipartite begin
        s5[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m]
    end

    @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = ExactTreewidth begin
        s6[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m]
    end

    @test s1 ≈ s2 ≈ s3 ≈ s4 ≈ s5 ≈ s6

    # expressions containing several contractions (sum of two products)
    A = randn(5, 5, 5)
    B = randn(5, 5, 5)
    C = randn(5, 5, 5)
    α = randn()

    @tensor opt = true begin
        D1[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m]
    end

    @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = GreedyMethod begin
        D2[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m]
    end

    @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = TreeSA begin
        D3[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m]
    end

    @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = KaHyParBipartite begin
        D4[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m]
    end

    @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = SABipartite begin
        D5[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m]
    end

    @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = ExactTreewidth begin
        D6[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m]
    end

    @test D1 ≈ D2 ≈ D3 ≈ D4 ≈ D5 ≈ D6
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ end
include("butensor.jl")
end

# note: OMEinsumContractionOrders should not be loaded before this point
# as there is a test which requires it to be loaded after
# NOTE(review): the included filename "omeinsumcontractionordres.jl" looks like
# a typo for "omeinsumcontractionorders.jl" — confirm it matches the file on disk
@testset "OMEinsumOptimizer extension" begin
    include("omeinsumcontractionordres.jl")
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the compat with OMEinsum only allows julia >= v1.9, so you should probably also restrict the tests to that version

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the lowest version of julia in ci to 1.9. Is there any reason that 1.8 needed to be supported?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I did not explain properly, what I meant is:

Suggested change
@testset "OMEinsumOptimizer extension" begin
include("omeinsumcontractionordres.jl")
end
if VERSION >= v"1.9"
@testset "OMEinsumOptimizer extension" begin
include("omeinsumcontractionordres.jl")
end
end

I don't think there is any particular reason to support 1.8, but since it works anyways, and until 1.10 becomes the new LTS, I would just keep this supported.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I add this in tests and change the lowest support version back to v1.8

Copy link
Author

@ArrogantGao ArrogantGao Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We managed to make OMEinsumContractionOrders.jl support Julia v1.8 in v0.9.2, I updated the compat so that CI for 1.8 should pass now.


@testset "Polynomials" begin
include("polynomials.jl")
end
Expand Down