28
28
as_axis_tree ,
29
29
)
30
30
from pyop3 .axtree .layout import eval_offset
31
- from pyop3 .axtree .tree import Indexed , IndexedAxisTree , MultiArrayCollector
31
+ from pyop3 .axtree .tree import IndexedAxisTree , MultiArrayCollector
32
32
from pyop3 .buffer import Buffer , DistributedBuffer
33
33
from pyop3 .dtypes import IntType , ScalarType
34
34
from pyop3 .lang import KernelArgument , ReplaceAssignment
@@ -74,12 +74,6 @@ def __init__(self, array, indices, path=None):
74
74
def __getinitargs__ (self ):
75
75
return (self .array , self .indices , self .path )
76
76
77
- # def __str__(self) -> str:
78
- # return f"{self.array.name}[{{{', '.join(f'{i[0]}: {i[1]}' for i in self.indices.items())}}}]"
79
- #
80
- # def __repr__(self) -> str:
81
- # return f"MultiArrayVariable({self.array!r}, {self.indices!r})"
82
-
83
77
84
78
from pymbolic .mapper .stringifier import PREC_CALL , PREC_NONE , StringifyMapper
85
79
@@ -101,29 +95,11 @@ def stringify_array(self, array, enclosing_prec, *args, **kwargs):
101
95
CalledMapVariable = ArrayVar
102
96
103
97
104
- # does not belong here!
105
- # class CalledMapVariable(ArrayVar):
106
- # mapper_method = sys.intern("map_called_map_variable")
107
- #
108
- # def __init__(self, array, path, input_index_exprs, shape_index_exprs):
109
- # super().__init__(array, {**input_index_exprs, **shape_index_exprs}, path)
110
- # self.input_index_exprs = freeze(input_index_exprs)
111
- # self.shape_index_exprs = freeze(shape_index_exprs)
112
- #
113
- # def __getinitargs__(self):
114
- # return (
115
- # self.array,
116
- # self.target_path,
117
- # self.input_index_exprs,
118
- # self.shape_index_exprs,
119
- # )
120
-
121
-
122
98
class FancyIndexWriteException (Exception ):
123
99
pass
124
100
125
101
126
- class HierarchicalArray (Array , Indexed , ContextFree , KernelArgument ):
102
+ class HierarchicalArray (Array , ContextFree , KernelArgument ):
127
103
"""Multi-dimensional, hierarchical array.
128
104
129
105
Parameters
@@ -189,35 +165,6 @@ def __init__(
189
165
# TODO This attr really belongs to the buffer not the array
190
166
self .constant = constant
191
167
192
- # if some_but_not_all(x is None for x in [target_paths, index_exprs]):
193
- # raise ValueError
194
-
195
- # if target_paths is None:
196
- # target_paths = axes._default_target_paths()
197
- # if index_exprs is None:
198
- # index_exprs = axes._default_index_exprs()
199
- #
200
- # self._target_paths = freeze(target_paths)
201
- # self._index_exprs = freeze(index_exprs)
202
- # self._outer_loops = outer_loops or ()
203
- #
204
- # self._layouts = layouts if layouts is not None else axes.layouts
205
-
206
- @property
207
- @deprecated ()
208
- def target_paths (self ):
209
- return self .axes .target_paths
210
-
211
- @property
212
- @deprecated ()
213
- def index_exprs (self ):
214
- return self .axes .index_exprs
215
-
216
- @property
217
- @deprecated ()
218
- def layouts (self ):
219
- return self .axes .layouts
220
-
221
168
def __str__ (self ):
222
169
return self .name
223
170
@@ -253,14 +200,9 @@ def getitem(self, indices, *, strict=False):
253
200
# to be iterable (which it's not). This avoids some confusing behaviour.
254
201
__iter__ = None
255
202
256
- @property
257
- @deprecated ("buffer" )
258
- def array (self ):
259
- return self .buffer
260
-
261
203
@property
262
204
def dtype (self ):
263
- return self .array .dtype
205
+ return self .buffer .dtype
264
206
265
207
@property
266
208
def kernel_dtype (self ):
@@ -425,7 +367,7 @@ def outer_loops(self):
425
367
426
368
@property
427
369
def sf (self ):
428
- return self .array .sf
370
+ return self .buffer .sf
429
371
430
372
@property
431
373
def comm (self ):
@@ -436,11 +378,13 @@ def datamap(self):
436
378
datamap_ = {}
437
379
datamap_ .update (self .buffer .datamap )
438
380
datamap_ .update (self .axes .datamap )
439
- for index_exprs in self .index_exprs .values ():
381
+
382
+ # FIXME, deleting this breaks stuff...
383
+ for index_exprs in self .axes .index_exprs .values ():
440
384
for expr in index_exprs .values ():
441
385
for array in MultiArrayCollector ()(expr ):
442
386
datamap_ .update (array .datamap )
443
- for layout_expr in self .layouts .values ():
387
+ for layout_expr in self .axes . layouts .values ():
444
388
for array in MultiArrayCollector ()(layout_expr ):
445
389
datamap_ .update (array .datamap )
446
390
return freeze (datamap_ )
@@ -457,9 +401,9 @@ def assemble(self, update_leaves=False):
457
401
458
402
"""
459
403
if update_leaves :
460
- self .array ._reduce_then_broadcast ()
404
+ self .buffer ._reduce_then_broadcast ()
461
405
else :
462
- self .array ._reduce_leaves_to_roots ()
406
+ self .buffer ._reduce_leaves_to_roots ()
463
407
464
408
def materialize (self ) -> HierarchicalArray :
465
409
"""Return a new "unindexed" array with the same shape."""
@@ -491,7 +435,7 @@ def _with_axes(self, axes):
491
435
assert False , "do not use, it's wrong"
492
436
return type (self )(
493
437
axes ,
494
- data = self .array ,
438
+ data = self .buffer ,
495
439
max_value = self .max_value ,
496
440
name = self .name ,
497
441
)
@@ -518,7 +462,7 @@ def from_list(cls, data, axis_labels, name=None, dtype=ScalarType, inc=0):
518
462
if isinstance (count , Sequence ):
519
463
count = cls .from_list (count , axis_labels [:- 1 ], name , dtype , inc + 1 )
520
464
subaxis = Axis (count , axis_labels [- 1 ])
521
- axes = count .axes .add_subaxis (subaxis , count .axes .leaf )
465
+ axes = count .axes .add_axis (subaxis , count .axes .leaf )
522
466
else :
523
467
axes = AxisTree (Axis (count , axis_labels [- 1 ]))
524
468
@@ -547,15 +491,6 @@ def set_value(self, indices, value, path=None, *, loop_exprs=pmap()):
547
491
offset = self .axes .offset (indices , path , loop_exprs = loop_exprs )
548
492
self .buffer .data_wo [offset ] = value
549
493
550
- # def offset(self, indices, path=None, *, loop_exprs=pmap()):
551
- # return eval_offset(
552
- # self.axes,
553
- # self.subst_layouts,
554
- # indices,
555
- # path,
556
- # loop_exprs=loop_exprs,
557
- # )
558
-
559
494
def select_axes (self , indices ):
560
495
selected = []
561
496
current_axis = self .axes
@@ -657,10 +592,6 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray:
657
592
#
658
593
# return ContextSensitiveMultiArray(array_per_context)
659
594
660
- @property
661
- def array (self ):
662
- return self ._shared_attr ("array" )
663
-
664
595
@property
665
596
def buffer (self ):
666
597
return self ._shared_attr ("buffer" )
0 commit comments