Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 785b938

Browse files
committed
Parallel bits
1 parent 10c618c commit 785b938

File tree

3 files changed

+38
-20
lines changed

3 files changed

+38
-20
lines changed

pyop3/array/petsc.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,6 @@ def _iter_nest_labels(
308308
@cached_property
309309
@PETSc.Log.EventDecorator()
310310
def maps(self):
311-
print("HIT!")
312311
from pyop3.axtree.layout import my_product
313312

314313
# TODO: Don't think these need to be lists here.
@@ -380,7 +379,8 @@ def maps(self):
380379
# target_indices = {idx.index.id: idx.target_exprs for idx in idxs}
381380
target_indices = merge_dicts([idx.replace_map for idx in idxs])
382381

383-
for p in self.raxes.iter(idxs):
382+
# for p in self.raxes.iter(idxs):
383+
for p in self.raxes.iter(idxs, include_ghost_points=True): # seems to fix things
384384
target_path = p.target_path
385385
target_exprs = p.target_exprs
386386
for key in dropped_rkeys:
@@ -390,6 +390,7 @@ def maps(self):
390390
offset = orig_raxes.offset(
391391
target_exprs, target_path, loop_exprs=target_indices
392392
)
393+
393394
rmap.set_value(
394395
p.source_exprs,
395396
offset,
@@ -402,7 +403,8 @@ def maps(self):
402403
# target_indices = {idx.index.id: idx.target_exprs for idx in idxs}
403404
target_indices = merge_dicts([idx.replace_map for idx in idxs])
404405

405-
for p in self.caxes.iter(idxs):
406+
# for p in self.caxes.iter(idxs):
407+
for p in self.caxes.iter(idxs, include_ghost_points=True): # seems to fix things
406408
target_path = p.target_path
407409
target_exprs = p.target_exprs
408410
for key in dropped_ckeys:
@@ -419,6 +421,8 @@ def maps(self):
419421
loop_exprs=target_indices,
420422
)
421423

424+
# breakpoint()
425+
422426
return (rmap, cmap)
423427

424428
@property

pyop3/axtree/tree.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -209,18 +209,20 @@ def _collect_datamap(axis, *subdatamaps, axes):
209209

210210

211211
class AxisComponent(LabelledNodeComponent):
212-
fields = LabelledNodeComponent.fields | {"size", "unit"}
212+
fields = LabelledNodeComponent.fields | {"size", "unit", "rank_equal"}
213213

214214
def __init__(
215215
self,
216216
size,
217217
label=None,
218218
*,
219219
unit=False,
220+
rank_equal=False,
220221
):
221222
from pyop3.array import HierarchicalArray
222223

223224
if isinstance(size, collections.abc.Iterable):
225+
assert not rank_equal # nasty
224226
owned_count, count = map(int, size)
225227
distributed = True
226228
assert isinstance(owned_count, numbers.Integral) and isinstance(count, numbers.Integral)
@@ -247,8 +249,8 @@ def __init__(
247249
self.unit = unit
248250
self.distributed = distributed
249251

250-
# remove
251-
self.rank_equal = not distributed
252+
# cleanup
253+
self.rank_equal = rank_equal
252254

253255
# redone because otherwise getting a bizarre error (numpy types have confusing behaviour!)
254256
# def __eq__(self, other):
@@ -370,7 +372,7 @@ def from_serial(cls, serial: Axis, sf):
370372
# renumber the serial axis to store ghost entries at the end of the vector
371373
component_sizes, numbering = partition_ghost_points(serial, sf)
372374
components = [
373-
c.copy(size=(size, c.count)) for c, size in checked_zip(serial.components, component_sizes)
375+
c.copy(size=(size, c.count), rank_equal=False) for c, size in checked_zip(serial.components, component_sizes)
374376
]
375377
return cls(components, serial.label, numbering=numbering, sf=sf)
376378

@@ -811,14 +813,6 @@ def global_numbering(self):
811813

812814
return numbering[self._buffer_indices_ghost]
813815

814-
@property
815-
def comm(self):
816-
paraxes = [axis for axis in self.nodes if axis.sf is not None]
817-
if not paraxes:
818-
return MPI.COMM_SELF
819-
else:
820-
return single_valued(ax.comm for ax in paraxes)
821-
822816
@cached_property
823817
def leaf_target_paths(self):
824818
return tuple(
@@ -907,7 +901,7 @@ def _collect_owned_index_tree(self, axis=None):
907901
slice_component = AffineSliceComponent(component.label)
908902
slice_components.append(slice_component)
909903

910-
slice_ = Slice(axis.label, slice_components)
904+
slice_ = Slice(axis.label, slice_components, label=axis.label)
911905

912906
index_tree = IndexTree(slice_)
913907
for component, slice_component in checked_zip(axis.components, slice_components):
@@ -1025,6 +1019,14 @@ def sf(self) -> StarForest:
10251019
iremote = np.concatenate(iremotes)
10261020
return StarForest.from_graph(self.size, nroots, ilocal, iremote, self.comm)
10271021

1022+
@property
1023+
def comm(self):
1024+
paraxes = [axis for axis in self.nodes if axis.sf is not None]
1025+
if not paraxes:
1026+
return MPI.COMM_SELF
1027+
else:
1028+
return single_valued(ax.comm for ax in paraxes)
1029+
10281030
@cached_property
10291031
def datamap(self):
10301032
if self.is_empty:
@@ -1127,8 +1129,6 @@ def _buffer_indices_ghost(self):
11271129
return slice(None)
11281130

11291131

1130-
# are all of these necessary?
1131-
# class IndexedAxisTree(Indexed, BaseAxisTree):
11321132
class IndexedAxisTree(BaseAxisTree):
11331133
def __init__(
11341134
self,
@@ -1156,6 +1156,10 @@ def __init__(
11561156
def unindexed(self):
11571157
return self._unindexed
11581158

1159+
@property
1160+
def comm(self):
1161+
return self.unindexed.comm
1162+
11591163
@property
11601164
def target_paths(self):
11611165
return self._target_paths

pyop3/itree/tree.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,7 @@ def _(slice_: Slice, indices, *, target_path_acc, prev_axes, **kwargs):
13011301
size = target_cpt.count
13021302
else:
13031303
size = target_cpt.count[indices]
1304+
rank_equal = True
13041305
else:
13051306
if subslice.stop is None:
13061307
stop = target_cpt.count
@@ -1313,22 +1314,31 @@ def _(slice_: Slice, indices, *, target_path_acc, prev_axes, **kwargs):
13131314

13141315
owned_count = min(target_cpt.owned_count, stop)
13151316
count = stop
1316-
size = (owned_count, count)
1317+
rank_equal = False
1318+
1319+
if owned_count == count:
1320+
size = count
1321+
else:
1322+
size = (owned_count, count)
13171323
else:
13181324
size = math.ceil((stop - subslice.start) / subslice.step)
1325+
rank_equal = True
13191326

13201327
else:
13211328
assert isinstance(subslice, Subset)
13221329
size = subslice.array.axes.leaf_component.count
13231330

1331+
# kind of misleading, the values may differ I think
1332+
rank_equal = True
1333+
13241334
if target_cpt.distributed:
13251335
raise NotImplementedError
13261336

13271337
if is_full_slice and subslice.label_was_none:
13281338
mylabel = subslice.component
13291339
else:
13301340
mylabel = subslice.label
1331-
cpt = AxisComponent(size, label=mylabel, unit=target_cpt.unit)
1341+
cpt = AxisComponent(size, label=mylabel, unit=target_cpt.unit, rank_equal=rank_equal)
13321342
components.append(cpt)
13331343

13341344
target_path_per_subslice.append(pmap({slice_.axis: subslice.component}))

0 commit comments

Comments
 (0)