@@ -209,18 +209,20 @@ def _collect_datamap(axis, *subdatamaps, axes):
209
209
210
210
211
211
class AxisComponent (LabelledNodeComponent ):
212
- fields = LabelledNodeComponent .fields | {"size" , "unit" }
212
+ fields = LabelledNodeComponent .fields | {"size" , "unit" , "rank_equal" }
213
213
214
214
def __init__ (
215
215
self ,
216
216
size ,
217
217
label = None ,
218
218
* ,
219
219
unit = False ,
220
+ rank_equal = False ,
220
221
):
221
222
from pyop3 .array import HierarchicalArray
222
223
223
224
if isinstance (size , collections .abc .Iterable ):
225
+ assert not rank_equal # nasty
224
226
owned_count , count = map (int , size )
225
227
distributed = True
226
228
assert isinstance (owned_count , numbers .Integral ) and isinstance (count , numbers .Integral )
@@ -247,8 +249,8 @@ def __init__(
247
249
self .unit = unit
248
250
self .distributed = distributed
249
251
250
- # remove
251
- self .rank_equal = not distributed
252
+ # cleanup
253
+ self .rank_equal = rank_equal
252
254
253
255
# redone because otherwise getting a bizarre error (numpy types have confusing behaviour!)
254
256
# def __eq__(self, other):
@@ -370,7 +372,7 @@ def from_serial(cls, serial: Axis, sf):
370
372
# renumber the serial axis to store ghost entries at the end of the vector
371
373
component_sizes , numbering = partition_ghost_points (serial , sf )
372
374
components = [
373
- c .copy (size = (size , c .count )) for c , size in checked_zip (serial .components , component_sizes )
375
+ c .copy (size = (size , c .count ), rank_equal = False ) for c , size in checked_zip (serial .components , component_sizes )
374
376
]
375
377
return cls (components , serial .label , numbering = numbering , sf = sf )
376
378
@@ -811,14 +813,6 @@ def global_numbering(self):
811
813
812
814
return numbering [self ._buffer_indices_ghost ]
813
815
814
- @property
815
- def comm (self ):
816
- paraxes = [axis for axis in self .nodes if axis .sf is not None ]
817
- if not paraxes :
818
- return MPI .COMM_SELF
819
- else :
820
- return single_valued (ax .comm for ax in paraxes )
821
-
822
816
@cached_property
823
817
def leaf_target_paths (self ):
824
818
return tuple (
@@ -907,7 +901,7 @@ def _collect_owned_index_tree(self, axis=None):
907
901
slice_component = AffineSliceComponent (component .label )
908
902
slice_components .append (slice_component )
909
903
910
- slice_ = Slice (axis .label , slice_components )
904
+ slice_ = Slice (axis .label , slice_components , label = axis . label )
911
905
912
906
index_tree = IndexTree (slice_ )
913
907
for component , slice_component in checked_zip (axis .components , slice_components ):
@@ -1025,6 +1019,14 @@ def sf(self) -> StarForest:
1025
1019
iremote = np .concatenate (iremotes )
1026
1020
return StarForest .from_graph (self .size , nroots , ilocal , iremote , self .comm )
1027
1021
1022
+ @property
1023
+ def comm (self ):
1024
+ paraxes = [axis for axis in self .nodes if axis .sf is not None ]
1025
+ if not paraxes :
1026
+ return MPI .COMM_SELF
1027
+ else :
1028
+ return single_valued (ax .comm for ax in paraxes )
1029
+
1028
1030
@cached_property
1029
1031
def datamap (self ):
1030
1032
if self .is_empty :
@@ -1127,8 +1129,6 @@ def _buffer_indices_ghost(self):
1127
1129
return slice (None )
1128
1130
1129
1131
1130
- # are all of these necessary?
1131
- # class IndexedAxisTree(Indexed, BaseAxisTree):
1132
1132
class IndexedAxisTree (BaseAxisTree ):
1133
1133
def __init__ (
1134
1134
self ,
@@ -1156,6 +1156,10 @@ def __init__(
1156
1156
def unindexed (self ):
1157
1157
return self ._unindexed
1158
1158
1159
+ @property
1160
+ def comm (self ):
1161
+ return self .unindexed .comm
1162
+
1159
1163
@property
1160
1164
def target_paths (self ):
1161
1165
return self ._target_paths
0 commit comments