-
Notifications
You must be signed in to change notification settings - Fork 66
Adding OMEinsumContractionOrders.jl as a backend of TensorOperations.jl for finding the optimal contraction order #185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 8 commits
e4b102d
d6bbe6f
6d61103
329952d
c14fa7c
76698c1
81d64af
3846242
5f442e2
70301a7
61d1cce
f55c1b2
92fe983
4adc9cf
a120620
d39c6ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| *.jl.cov | ||
| *.jl.*.cov | ||
| *.jl.mem | ||
| Manifest.toml | ||
| Manifest.toml | ||
| .vscode | ||
| .DS_Store |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,6 @@ | ||
| name = "TensorOperations" | ||
| uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" | ||
| authors = [ | ||
| "Lukas Devos <[email protected]>", | ||
| "Maarten Van Damme <[email protected]>", | ||
| "Jutho Haegeman <[email protected]>", | ||
| ] | ||
| authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"] | ||
| version = "5.0.0" | ||
|
|
||
| [deps] | ||
|
|
@@ -24,11 +20,13 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" | |
| Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" | ||
| CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
| ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
| OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715" | ||
| cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" | ||
|
|
||
| [extensions] | ||
| TensorOperationsBumperExt = "Bumper" | ||
| TensorOperationsChainRulesCoreExt = "ChainRulesCore" | ||
| TensorOperationsOMEinsumContractionOrdersExt = "OMEinsumContractionOrders" | ||
| TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"] | ||
|
|
||
| [compat] | ||
|
|
@@ -41,6 +39,7 @@ DynamicPolynomials = "0.5" | |
| LRUCache = "1" | ||
| LinearAlgebra = "1.6" | ||
| Logging = "1.6" | ||
| OMEinsumContractionOrders = "0.9" | ||
| PackageExtensionCompat = "1" | ||
| PtrArrays = "1.2" | ||
| Random = "1" | ||
|
|
@@ -59,19 +58,10 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | |
| ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" | ||
| DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" | ||
| Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" | ||
| OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715" | ||
| Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
| Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
| cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" | ||
|
|
||
| [targets] | ||
| test = [ | ||
| "Test", | ||
| "Random", | ||
| "DynamicPolynomials", | ||
| "ChainRulesTestUtils", | ||
| "CUDA", | ||
| "cuTENSOR", | ||
| "Aqua", | ||
| "Logging", | ||
| "Bumper", | ||
| ] | ||
| test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "OMEinsumContractionOrders"] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| module TensorOperationsOMEinsumContractionOrdersExt | ||
|
|
||
| using TensorOperations | ||
| using TensorOperations: TensorOperations as TO | ||
| using TensorOperations: TreeOptimizer | ||
| using OMEinsumContractionOrders | ||
| using OMEinsumContractionOrders: EinCode, NestedEinsum, SlicedEinsum, isleaf, optimize_kahypar_auto | ||
|
|
||
| function TO.optimaltree(network, optdata::Dict{TDK, TDV}, ::TreeOptimizer{:GreedyMethod}, verbose::Bool) where{TDK, TDV} | ||
| ome_optimizer = GreedyMethod() | ||
| return optimize(network, optdata, ome_optimizer, verbose) | ||
| end | ||
|
|
||
| function TO.optimaltree(network, optdata::Dict{TDK, TDV}, ::TreeOptimizer{:KaHyParBipartite}, verbose::Bool) where{TDK, TDV} | ||
|
|
||
| return optimize_kahypar(network, optdata, verbose) | ||
| end | ||
|
|
||
| function TO.optimaltree(network, optdata::Dict{TDK, TDV}, ::TreeOptimizer{:TreeSA}, verbose::Bool) where{TDK, TDV} | ||
| ome_optimizer = TreeSA() | ||
| return optimize(network, optdata, ome_optimizer, verbose) | ||
| end | ||
|
|
||
| function TO.optimaltree(network, optdata::Dict{TDK, TDV}, ::TreeOptimizer{:SABipartite}, verbose::Bool) where{TDK, TDV} | ||
| ome_optimizer = SABipartite() | ||
| return optimize(network, optdata, ome_optimizer, verbose) | ||
| end | ||
|
|
||
| function TO.optimaltree(network, optdata::Dict{TDK, TDV}, ::TreeOptimizer{:ExactTreewidth}, verbose::Bool) where{TDK, TDV} | ||
| ome_optimizer = ExactTreewidth() | ||
| return optimize(network, optdata, ome_optimizer, verbose) | ||
| end | ||
|
|
||
| function optimize(network, optdata::Dict{TDK, TDV}, ome_optimizer::CodeOptimizer, verbose::Bool) where{TDK, TDV} | ||
| try | ||
| @assert TDV <: Number | ||
| catch | ||
| throw(ArgumentError("The values of the optdata dictionary must be of type Number")) | ||
| end | ||
|
|
||
| # transform the network as EinCode | ||
| code, size_dict = network2eincode(network, optdata) | ||
| # optimize the contraction order using OMEinsumContractionOrders, which gives a NestedEinsum | ||
| optcode = optimize_code(code, size_dict, ome_optimizer) | ||
|
|
||
| # transform the optimized contraction order back to the network | ||
| optimaltree = eincode2contractiontree(optcode) | ||
|
|
||
| # calculate the complexity of the contraction | ||
| cc = OMEinsumContractionOrders.contraction_complexity(optcode, size_dict) | ||
| if verbose | ||
| println("Optimal contraction tree: ", optimaltree) | ||
| println(cc) | ||
| end | ||
| return optimaltree, 2.0^(cc.tc) | ||
| end | ||
|
|
||
| function optimize_kahypar(network, optdata::Dict{TDK, TDV}, verbose::Bool) where{TDK, TDV} | ||
| try | ||
| @assert TDV <: Number | ||
| catch | ||
| throw(ArgumentError("The values of the optdata dictionary must be of type Number")) | ||
| end | ||
|
|
||
| # transform the network as EinCode | ||
| code, size_dict = network2eincode(network, optdata) | ||
| # optimize the contraction order using OMEinsumContractionOrders, which gives a NestedEinsum | ||
| optcode = optimize_kahypar_auto(code, size_dict) | ||
|
|
||
| # transform the optimized contraction order back to the network | ||
| optimaltree = eincode2contractiontree(optcode) | ||
|
|
||
| # calculate the complexity of the contraction | ||
| cc = OMEinsumContractionOrders.contraction_complexity(optcode, size_dict) | ||
| if verbose | ||
| println("Optimal contraction tree: ", optimaltree) | ||
| println(cc) | ||
| end | ||
| return optimaltree, 2.0^(cc.tc) | ||
| end | ||
|
|
||
| function network2eincode(network, optdata) | ||
| indices = unique(vcat(network...)) | ||
| new_indices = Dict([i => j for (j, i) in enumerate(indices)]) | ||
| new_network = [Int[new_indices[i] for i in t] for t in network] | ||
| open_edges = Int[] | ||
| # if an index appears only once, it is an open index | ||
| for i in indices | ||
| if sum([i in t for t in network]) == 1 | ||
| push!(open_edges, new_indices[i]) | ||
| end | ||
| end | ||
| size_dict = Dict([new_indices[i] => optdata[i] for i in keys(optdata)]) | ||
| return EinCode(new_network, open_edges), size_dict | ||
| end | ||
|
|
||
| function eincode2contractiontree(eincode::NestedEinsum) | ||
| if isleaf(eincode) | ||
| return eincode.tensorindex | ||
| else | ||
| return [eincode2contractiontree(arg) for arg in eincode.args] | ||
| end | ||
| end | ||
|
|
||
| # TreeSA returns a SlicedEinsum, with nslice = 0, so directly using the eins | ||
| function eincode2contractiontree(eincode::SlicedEinsum) | ||
| return eincode2contractiontree(eincode.eins) | ||
| end | ||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,6 +71,7 @@ function tensorparser(tensorexpr, kwargs...) | |
| end | ||
| end | ||
| # now handle the remaining keyword arguments | ||
| optimizer = TreeOptimizer{:NCon}() # the default optimizer | ||
| for (name, val) in kwargs | ||
| if name == :order | ||
| isexpr(val, :tuple) || | ||
|
|
@@ -85,6 +86,12 @@ function tensorparser(tensorexpr, kwargs...) | |
| val in (:warn, :cache) || | ||
| throw(ArgumentError("Invalid use of `costcheck`, should be `costcheck=warn` or `costcheck=cache`")) | ||
| parser.contractioncostcheck = val | ||
| elseif name == :opt_algorithm | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think here you will have to be a little careful; in principle there is no order to the keyword arguments. My best guess is that you probably want to attempt to extract an optimizer and optdict, and only after all kwargs have been parsed, you can construct the
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thank you very much for pointing that out, I did not notice that previously. |
||
| if val isa Symbol | ||
| optimizer = TreeOptimizer{val}() | ||
| else | ||
| throw(ArgumentError("Invalid use of `opt_algorithm`, should be `opt_algorithm=NCon` or `opt_algorithm=NameOfAlgorithm`")) | ||
| end | ||
| elseif name == :opt | ||
| if val isa Bool && val | ||
| optdict = optdata(tensorexpr) | ||
|
|
@@ -93,7 +100,7 @@ function tensorparser(tensorexpr, kwargs...) | |
| else | ||
| throw(ArgumentError("Invalid use of `opt`, should be `opt=true` or `opt=OptExpr`")) | ||
| end | ||
| parser.contractiontreebuilder = network -> optimaltree(network, optdict)[1] | ||
| parser.contractiontreebuilder = network -> optimaltree(network, optdict, optimizer = optimizer)[1] | ||
| elseif !(name == :backend || name == :allocator) # these two have been handled | ||
| throw(ArgumentError("Unknown keyword argument `name`.")) | ||
| end | ||
|
|
||
lkdvos marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| @testset "@tensor dependency check" begin | ||
| @test_throws ArgumentError begin | ||
| A = rand(2, 2) | ||
| B = rand(2, 2) | ||
| C = rand(2, 2) | ||
| ex = :(@tensor opt=(i=>2, j=>2, k=>2) opt_algorithm=GreedyMethod S[] := A[i, j] * B[j, k] * C[i, k]) | ||
| macroexpand(Main, ex) | ||
| end | ||
| end | ||
|
|
||
| using OMEinsumContractionOrders | ||
|
|
||
| @testset "OMEinsumContractionOrders optimization algorithms" begin | ||
| A = randn(5, 5, 5, 5) | ||
| B = randn(5, 5, 5) | ||
| C = randn(5, 5, 5) | ||
|
|
||
| @tensor begin | ||
| D1[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b] | ||
| end | ||
|
|
||
| @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = GreedyMethod begin | ||
| D2[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b] | ||
| end | ||
|
|
||
| @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = TreeSA begin | ||
| D3[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b] | ||
| end | ||
|
|
||
| @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = KaHyParBipartite begin | ||
| D4[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b] | ||
| end | ||
|
|
||
| @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = SABipartite begin | ||
| D5[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b] | ||
| end | ||
|
|
||
| @tensor opt = (a => 5, b => 5, c => 5, d => 5, e => 5, f => 5, g => 5) opt_algorithm = ExactTreewidth begin | ||
| D6[a, b, c, d] := A[a, e, c, f] * B[g, d, e] * C[g, f, b] | ||
| end | ||
|
|
||
| @tensor opt = (1 => 5, 2 => 5, 3 => 5, 4 => 5, 5 => 5, 6 => 5, 7 => 5) opt_algorithm = GreedyMethod begin | ||
| D7[1, 2, 3, 4] := A[1, 5, 3, 6] * B[7, 4, 5] * C[7, 6, 2] | ||
| end | ||
|
||
|
|
||
| @test D1 ≈ D2 ≈ D3 ≈ D4 ≈ D5 ≈ D6 ≈ D7 | ||
|
|
||
|
|
||
| A = rand(2, 2) | ||
| B = rand(2, 2, 2) | ||
| C = rand(2, 2) | ||
| D = rand(2, 2) | ||
| E = rand(2, 2, 2) | ||
| F = rand(2, 2) | ||
|
|
||
| @tensor opt = true begin | ||
| s1[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = GreedyMethod begin | ||
| s2[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = TreeSA begin | ||
| s3[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = KaHyParBipartite begin | ||
| s4[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = SABipartite begin | ||
| s5[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 2, j => 2, k => 2, l => 2, m => 2, n => 2, o => 2) opt_algorithm = ExactTreewidth begin | ||
| s6[] := A[i, k] * B[i, j, l] * C[j, m] * D[k, n] * E[n, l, o] * F[o, m] | ||
| end | ||
|
|
||
| @test s1 ≈ s2 ≈ s3 ≈ s4 ≈ s5 ≈ s6 | ||
|
|
||
| A = randn(5, 5, 5) | ||
| B = randn(5, 5, 5) | ||
| C = randn(5, 5, 5) | ||
| α = randn() | ||
|
|
||
| @tensor opt = true begin | ||
| D1[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = GreedyMethod begin | ||
| D2[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = TreeSA begin | ||
| D3[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = KaHyParBipartite begin | ||
| D4[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = SABipartite begin | ||
| D5[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m] | ||
| end | ||
|
|
||
| @tensor opt = (i => 5, j => 5, k => 5, l => 5, m => 5) opt_algorithm = ExactTreewidth begin | ||
| D6[m] := A[i, j, k] * B[j, k, l] * C[i, l, m] + α * A[i, j, k] * B[j, k, l] * C[i, l, m] | ||
| end | ||
|
|
||
| @test D1 ≈ D2 ≈ D3 ≈ D4 ≈ D5 ≈ D6 | ||
| end | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -39,6 +39,12 @@ end | |||||||||||||||||
| include("butensor.jl") | ||||||||||||||||||
| end | ||||||||||||||||||
|
|
||||||||||||||||||
| # note: OMEinsumContractionOrders should not be loaded before this point | ||||||||||||||||||
| # as there is a test which requires it to be loaded after | ||||||||||||||||||
| @testset "OMEinsumOptimizer extension" begin | ||||||||||||||||||
| include("omeinsumcontractionordres.jl") | ||||||||||||||||||
| end | ||||||||||||||||||
|
||||||||||||||||||
| @testset "OMEinsumOptimizer extension" begin | |
| include("omeinsumcontractionordres.jl") | |
| end | |
| if VERSION >= v"1.9" | |
| @testset "OMEinsumOptimizer extension" begin | |
| include("omeinsumcontractionordres.jl") | |
| end | |
| end |
I don't think there is any particular reason to support 1.8, but since it works anyways, and until 1.10 becomes the new LTS, I would just keep this supported.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I add this in tests and change the lowest support version back to v1.8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We managed to make OMEinsumContractionOrders.jl support Julia v1.8 in v0.9.2, I updated the compat so that CI for 1.8 should pass now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit of a strange construction,
`try ... catch ... end` typically means you would handle the error it throws, not rethrow a new error. You can add messages to the assertion error too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the mistake, now I simply use `@assert` instead of `try...catch...`.