Skip to content

Commit 187945a

Browse files
committed
add Tensor.swap
1 parent bb4a6f8 commit 187945a

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

discopy/tensor.py

+9
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,15 @@ def cups(left, right):
163163
def caps(left, right):
164164
return Tensor.cups(left, right).dagger()
165165

166+
@staticmethod
167+
def swap(left, right):
168+
array = Id(left @ right).array
169+
source = range(len(left @ right), 2 * len(left @ right))
170+
target = [i + len(right) if i < len(left @ right @ left)
171+
else i - len(left) for i in source]
172+
return Tensor(left @ right, right @ left,
173+
np.moveaxis(array, source, target))
174+
166175

167176
class Id(Tensor):
168177
""" Implements the identity tensor for a given dimension.

test/test_tensor.py

+11
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,17 @@ def test_Tensor_tensor():
7474
assert F(f) @ F(g) == F(f @ g)
7575

7676

77+
def test_tensor_swap():
78+
f = Tensor(Dim(2), Dim(2), [1, 0, 0, 1])
79+
g = Tensor(Dim(3), Dim(3), list(range(9)))
80+
swap = Tensor.swap(Dim(2), Dim(3))
81+
assert f @ g >> swap == swap >> g @ f
82+
83+
swaps = Tensor.swap(Dim(2), Dim(3, 3))
84+
assert swaps == swap @ Id(3) >> Id(3) @ swap
85+
assert swaps >> swaps.dagger() == Id(Dim(2, 3, 3))
86+
87+
7788
def test_TensorFunctor():
7889
assert repr(TensorFunctor({Ty('x'): 1}, {})) ==\
7990
"TensorFunctor(ob={Ty('x'): 1}, ar={})"

0 commit comments

Comments
 (0)