Skip to content

Commit 55dda4f

Browse files
committed
Update
1 parent 0d74462 commit 55dda4f

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

scratchpad/tn_api/tn.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import math
33
from dataclasses import dataclass
44
from typing import TypeVar, Generic, Iterable
5-
from qtree import np_framework
65

76
class Array(np.ndarray):
87
shape: tuple
@@ -59,14 +58,23 @@ def __init__(self, *args, **kwargs):
5958
self._tensors = []
6059
self._edges = tuple()
6160
self.shape = tuple()
62-
self.buckets = []
6361
self.data_dict = {}
6462

6563
# slice not inplace
6664
def slice(self, slice_dict: dict) -> 'TensorNetwork':
6765
tn = self.copy()
68-
sliced_buckets = np_framework.get_sliced_np_buckets(self.buckets, self.data_dict, slice_dict)
69-
tn.buckets = sliced_buckets
66+
sliced_tns = []
67+
for tensor in tn._tensors:
68+
slice_bounds = []
69+
for idx in range(tensor.ndim):
70+
try:
71+
slice_bounds.append(slice_dict[idx])
72+
except KeyError:
73+
slice_bounds.append(slice(None))
74+
75+
sliced_tns.append(tensor[tuple(slice_bounds)])
76+
77+
tn._tensors = sliced_tns
7078
return tn
7179

7280
def copy(self):
@@ -161,4 +169,6 @@ def __repr__(self):
161169

162170
if __name__ == "__main__":
163171
tn = TensorNetwork.new_random_cpu(2, 3, 4)
172+
slice_dict = {0: slice(0, 2), 1: slice(1, 3)}
173+
sliced_tn = tn.slice(slice_dict)
164174
import pdb; pdb.set_trace()

0 commit comments

Comments
 (0)