@@ -139,19 +139,23 @@ def __init__(

        if data is not None:
            data = np.asarray(data, dtype=dtype)
-            shape = data.shape
-        else:
-            shape = axes.global_size
+
+            # always deal with flattened data
+            if len(data.shape) > 1:
+                data = data.flatten()
+            if data.size != axes.unindexed.global_size:
+                raise ValueError("Data shape does not match axes")

        # IndexedAxisTrees do not currently have SFs, so create a dummy one here
        if isinstance(axes, AxisTree):
            sf = axes.sf
        else:
            assert isinstance(axes, IndexedAxisTree)
-            sf = serial_forest(axes.global_size)
+            # not sure this is the right thing to do
+            sf = serial_forest(axes.unindexed.global_size)

        data = DistributedBuffer(
-            shape,
+            axes.unindexed.global_size,  # not a useful property anymore
            sf,
            dtype,
            name=self.name,
@@ -165,6 +169,8 @@ def __init__(
        # TODO This attr really belongs to the buffer not the array
        self.constant = constant

+        # self._cache = {}
+
    def __str__(self):
        return self.name

@@ -177,14 +183,20 @@ def getitem(self, indices, *, strict=False):
        if indices is Ellipsis:
            return self

+        # key = (indices, strict)
+        # if key in self._cache:
+        #     return self._cache[key]
+
        index_forest = as_index_forest(indices, axes=self.axes, strict=strict)
        if index_forest.keys() == {pmap()}:
            index_tree = index_forest[pmap()]
            indexed_axes = index_axes(index_tree, pmap(), self.axes)
            axes = compose_axes(indexed_axes, self.axes)
-            return HierarchicalArray(
+            dat = HierarchicalArray(
                axes, data=self.buffer, max_value=self.max_value, name=self.name
            )
+            # self._cache[key] = dat
+            return dat

        array_per_context = {}
        for loop_context, index_tree in index_forest.items():
@@ -194,7 +206,9 @@ def getitem(self, indices, *, strict=False):
                axes, data=self.buffer, name=self.name, max_value=self.max_value
            )

-        return ContextSensitiveMultiArray(array_per_context)
+        dat = ContextSensitiveMultiArray(array_per_context)
+        # self._cache[key] = dat
+        return dat

    # Since __getitem__ is implemented, this class is implicitly considered
    # to be iterable (which it's not). This avoids some confusing behaviour.
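The commented-out `self._cache` lines above sketch memoizing `getitem` on the key `(indices, strict)`, presumably left disabled until the index objects are reliably hashable. A minimal standalone illustration of that pattern, under that assumption (class and helper names here are illustrative, not part of the library):

```python
class CachedGetItem:
    """Toy stand-in showing the memoization pattern from the commented-out code."""

    def __init__(self):
        self._cache = {}

    def getitem(self, indices, *, strict=False):
        key = (indices, strict)  # requires `indices` to be hashable
        if key in self._cache:
            return self._cache[key]
        result = self._build(indices, strict)  # stands in for the expensive indexing work
        self._cache[key] = result
        return result

    def _build(self, indices, strict):
        return ("indexed", indices, strict)
```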
@@ -218,16 +232,16 @@ def data(self):
    @property
    def data_rw(self):
        self._check_no_copy_access()
-        return self.buffer.data_rw[self._buffer_indices]
+        return self.buffer.data_rw[self.axes._buffer_indices]

    @property
    def data_ro(self):
-        if not isinstance(self._buffer_indices, slice):
+        if not isinstance(self.axes._buffer_indices, slice):
            warning(
                "Read-only access to the array is provided with a copy, "
                "consider avoiding if possible."
            )
-        return self.buffer.data_ro[self._buffer_indices]
+        return self.buffer.data_ro[self.axes._buffer_indices]

    @property
    def data_wo(self):
@@ -239,7 +253,7 @@ def data_wo(self):
        can be dropped.
        """
        self._check_no_copy_access()
-        return self.buffer.data_wo[self._buffer_indices]
+        return self.buffer.data_wo[self.axes._buffer_indices]

    @property
    @deprecated(".data_rw_with_halos")
@@ -249,16 +263,16 @@ def data_with_halos(self):
    @property
    def data_rw_with_halos(self):
        self._check_no_copy_access()
-        return self.buffer.data_rw[self._buffer_indices_ghost]
+        return self.buffer.data_rw[self.axes._buffer_indices_ghost]

    @property
    def data_ro_with_halos(self):
-        if not isinstance(self._buffer_indices_ghost, slice):
+        if not isinstance(self.axes._buffer_indices_ghost, slice):
            warning(
                "Read-only access to the array is provided with a copy, "
                "consider avoiding if possible."
            )
-        return self.buffer.data_ro[self._buffer_indices_ghost]
+        return self.buffer.data_ro[self.axes._buffer_indices_ghost]

    @property
    def data_wo_with_halos(self):
@@ -270,54 +284,10 @@ def data_wo_with_halos(self):
        can be dropped.
        """
        self._check_no_copy_access()
-        return self.buffer.data_wo[self._buffer_indices_ghost]
-
-    @cached_property
-    def _buffer_indices(self):
-        return self._collect_buffer_indices(ghost=False)
-
-    @cached_property
-    def _buffer_indices_ghost(self):
-        return self._collect_buffer_indices(ghost=True)
-
-    def _collect_buffer_indices(self, *, ghost: bool):
-        # TODO: This method is inefficient as for affine things we still tabulate
-        # everything first. It would be best to inspect index_exprs to determine
-        # if a slice is sufficient, but this is hard.
-        # TODO: This should be more widely cached, don't want to tabulate more often
-        # than required.
-
-        size = self.axes.size if ghost else self.axes.owned.size
-        assert size > 0
-
-        indices = np.full(size, -1, dtype=IntType)
-        # TODO: Handle any outer loops.
-        # TODO: Generate code for this.
-        for i, p in enumerate(self.axes.iter()):
-            indices[i] = self.axes.offset(p.source_exprs, p.source_path)
-        debug_assert(lambda: (indices >= 0).all())
-
-        # The packed indices are collected component-by-component so, for
-        # numbered multi-component axes, they are not in ascending order.
-        # We sort them so we can test for "affine-ness".
-        indices.sort()
-
-        # See if we can represent these indices as a slice. This is important
-        # because slices enable no-copy access to the array.
-        steps = np.unique(indices[1:] - indices[:-1])
-        if len(steps) == 0:
-            start = just_one(indices)
-            return slice(start, start + 1, 1)
-        elif len(steps) == 1:
-            start = indices[0]
-            stop = indices[-1] + 1
-            (step,) = steps
-            return slice(start, stop, step)
-        else:
-            return indices
+        return self.buffer.data_wo[self.axes._buffer_indices_ghost]

    def _check_no_copy_access(self):
-        if not isinstance(self._buffer_indices, slice):
+        if not isinstance(self.axes._buffer_indices, slice):
            raise FancyIndexWriteException(
                "Writing to the array directly is not supported for "
                "non-trivially indexed (i.e. sliced) arrays."
@@ -541,7 +511,8 @@ def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


-# Now ContextSensitiveDat
+# NOTE: I think I can probably get rid of this class and wrap the
+# context-sensitivity inside the axis tree.
class ContextSensitiveMultiArray(Array, ContextSensitive):
    def __init__(self, arrays):
        name = single_valued(a.name for a in arrays.values())
@@ -596,6 +567,11 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray:
    def buffer(self):
        return self._shared_attr("buffer")

+    # this is really nasty, but need to know if wrapping a Mat
+    @property
+    def mat(self):
+        return self._shared_attr("mat")
+
    @property
    def dtype(self):
        return self._shared_attr("dtype")