Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 10c618c

Browse files
committed
Various changes
1 parent 203e218 commit 10c618c

13 files changed

+624
-625
lines changed

pyop3/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
Subset,
3535
TabulatedMapComponent,
3636
)
37-
from pyop3.itree.tree import ScalarIndex
37+
from pyop3.itree.tree import ScalarIndex, as_index_forest
3838
from pyop3.lang import ( # noqa: F401
3939
INC,
4040
MAX_RW,

pyop3/array/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import abc
22

3-
from pyop3.lang import KernelArgument, ReplaceAssignment
3+
from pyop3.lang import FunctionArgument, ReplaceAssignment
44
from pyop3.utils import UniqueNameGenerator
55

66

7-
class Array(KernelArgument, abc.ABC):
7+
class Array(FunctionArgument, abc.ABC):
88
_prefix = "array"
99
_name_generator = UniqueNameGenerator()
1010

pyop3/array/harray.py

+37-61
Original file line numberDiff line numberDiff line change
@@ -139,19 +139,23 @@ def __init__(
139139

140140
if data is not None:
141141
data = np.asarray(data, dtype=dtype)
142-
shape = data.shape
143-
else:
144-
shape = axes.global_size
142+
143+
# always deal with flattened data
144+
if len(data.shape) > 1:
145+
data = data.flatten()
146+
if data.size != axes.unindexed.global_size:
147+
raise ValueError("Data shape does not match axes")
145148

146149
# IndexedAxisTrees do not currently have SFs, so create a dummy one here
147150
if isinstance(axes, AxisTree):
148151
sf = axes.sf
149152
else:
150153
assert isinstance(axes, IndexedAxisTree)
151-
sf = serial_forest(axes.global_size)
154+
# not sure this is the right thing to do
155+
sf = serial_forest(axes.unindexed.global_size)
152156

153157
data = DistributedBuffer(
154-
shape,
158+
axes.unindexed.global_size, # not a useful property anymore
155159
sf,
156160
dtype,
157161
name=self.name,
@@ -165,6 +169,8 @@ def __init__(
165169
# TODO This attr really belongs to the buffer not the array
166170
self.constant = constant
167171

172+
# self._cache = {}
173+
168174
def __str__(self):
169175
return self.name
170176

@@ -177,14 +183,20 @@ def getitem(self, indices, *, strict=False):
177183
if indices is Ellipsis:
178184
return self
179185

186+
# key = (indices, strict)
187+
# if key in self._cache:
188+
# return self._cache[key]
189+
180190
index_forest = as_index_forest(indices, axes=self.axes, strict=strict)
181191
if index_forest.keys() == {pmap()}:
182192
index_tree = index_forest[pmap()]
183193
indexed_axes = index_axes(index_tree, pmap(), self.axes)
184194
axes = compose_axes(indexed_axes, self.axes)
185-
return HierarchicalArray(
195+
dat = HierarchicalArray(
186196
axes, data=self.buffer, max_value=self.max_value, name=self.name
187197
)
198+
# self._cache[key] = dat
199+
return dat
188200

189201
array_per_context = {}
190202
for loop_context, index_tree in index_forest.items():
@@ -194,7 +206,9 @@ def getitem(self, indices, *, strict=False):
194206
axes, data=self.buffer, name=self.name, max_value=self.max_value
195207
)
196208

197-
return ContextSensitiveMultiArray(array_per_context)
209+
dat = ContextSensitiveMultiArray(array_per_context)
210+
# self._cache[key] = dat
211+
return dat
198212

199213
# Since __getitem__ is implemented, this class is implicitly considered
200214
# to be iterable (which it's not). This avoids some confusing behaviour.
@@ -218,16 +232,16 @@ def data(self):
218232
@property
219233
def data_rw(self):
220234
self._check_no_copy_access()
221-
return self.buffer.data_rw[self._buffer_indices]
235+
return self.buffer.data_rw[self.axes._buffer_indices]
222236

223237
@property
224238
def data_ro(self):
225-
if not isinstance(self._buffer_indices, slice):
239+
if not isinstance(self.axes._buffer_indices, slice):
226240
warning(
227241
"Read-only access to the array is provided with a copy, "
228242
"consider avoiding if possible."
229243
)
230-
return self.buffer.data_ro[self._buffer_indices]
244+
return self.buffer.data_ro[self.axes._buffer_indices]
231245

232246
@property
233247
def data_wo(self):
@@ -239,7 +253,7 @@ def data_wo(self):
239253
can be dropped.
240254
"""
241255
self._check_no_copy_access()
242-
return self.buffer.data_wo[self._buffer_indices]
256+
return self.buffer.data_wo[self.axes._buffer_indices]
243257

244258
@property
245259
@deprecated(".data_rw_with_halos")
@@ -249,16 +263,16 @@ def data_with_halos(self):
249263
@property
250264
def data_rw_with_halos(self):
251265
self._check_no_copy_access()
252-
return self.buffer.data_rw[self._buffer_indices_ghost]
266+
return self.buffer.data_rw[self.axes._buffer_indices_ghost]
253267

254268
@property
255269
def data_ro_with_halos(self):
256-
if not isinstance(self._buffer_indices_ghost, slice):
270+
if not isinstance(self.axes._buffer_indices_ghost, slice):
257271
warning(
258272
"Read-only access to the array is provided with a copy, "
259273
"consider avoiding if possible."
260274
)
261-
return self.buffer.data_ro[self._buffer_indices_ghost]
275+
return self.buffer.data_ro[self.axes._buffer_indices_ghost]
262276

263277
@property
264278
def data_wo_with_halos(self):
@@ -270,54 +284,10 @@ def data_wo_with_halos(self):
270284
can be dropped.
271285
"""
272286
self._check_no_copy_access()
273-
return self.buffer.data_wo[self._buffer_indices_ghost]
274-
275-
@cached_property
276-
def _buffer_indices(self):
277-
return self._collect_buffer_indices(ghost=False)
278-
279-
@cached_property
280-
def _buffer_indices_ghost(self):
281-
return self._collect_buffer_indices(ghost=True)
282-
283-
def _collect_buffer_indices(self, *, ghost: bool):
284-
# TODO: This method is inefficient as for affine things we still tabulate
285-
# everything first. It would be best to inspect index_exprs to determine
286-
# if a slice is sufficient, but this is hard.
287-
# TODO: This should be more widely cached, don't want to tabulate more often
288-
# than required.
289-
290-
size = self.axes.size if ghost else self.axes.owned.size
291-
assert size > 0
292-
293-
indices = np.full(size, -1, dtype=IntType)
294-
# TODO: Handle any outer loops.
295-
# TODO: Generate code for this.
296-
for i, p in enumerate(self.axes.iter()):
297-
indices[i] = self.axes.offset(p.source_exprs, p.source_path)
298-
debug_assert(lambda: (indices >= 0).all())
299-
300-
# The packed indices are collected component-by-component so, for
301-
# numbered multi-component axes, they are not in ascending order.
302-
# We sort them so we can test for "affine-ness".
303-
indices.sort()
304-
305-
# See if we can represent these indices as a slice. This is important
306-
# because slices enable no-copy access to the array.
307-
steps = np.unique(indices[1:] - indices[:-1])
308-
if len(steps) == 0:
309-
start = just_one(indices)
310-
return slice(start, start + 1, 1)
311-
elif len(steps) == 1:
312-
start = indices[0]
313-
stop = indices[-1] + 1
314-
(step,) = steps
315-
return slice(start, stop, step)
316-
else:
317-
return indices
287+
return self.buffer.data_wo[self.axes._buffer_indices_ghost]
318288

319289
def _check_no_copy_access(self):
320-
if not isinstance(self._buffer_indices, slice):
290+
if not isinstance(self.axes._buffer_indices, slice):
321291
raise FancyIndexWriteException(
322292
"Writing to the array directly is not supported for "
323293
"non-trivially indexed (i.e. sliced) arrays."
@@ -541,7 +511,8 @@ def __init__(self, *args, **kwargs):
541511
super().__init__(*args, **kwargs)
542512

543513

544-
# Now ContextSensitiveDat
514+
# NOTE: I think I can probably get rid of this class and wrap the
515+
# context-sensitivity inside the axis tree.
545516
class ContextSensitiveMultiArray(Array, ContextSensitive):
546517
def __init__(self, arrays):
547518
name = single_valued(a.name for a in arrays.values())
@@ -596,6 +567,11 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray:
596567
def buffer(self):
597568
return self._shared_attr("buffer")
598569

570+
# this is really nasty, but need to know if wrapping a Mat
571+
@property
572+
def mat(self):
573+
return self._shared_attr("mat")
574+
599575
@property
600576
def dtype(self):
601577
return self._shared_attr("dtype")

pyop3/array/petsc.py

+37-33
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,7 @@ def __init__(
104104
self.mat_type = mat_type
105105
self.mat = mat
106106

107-
# TODO: delete
108-
# self.rtarget_paths = rtarget_paths
109-
# self.rindex_exprs = rindex_exprs
110-
# self.orig_raxes = orig_raxes
111-
# self.router_loops = router_loops
112-
# self.ctarget_paths = ctarget_paths
113-
# self.cindex_exprs = cindex_exprs
114-
# self.orig_caxes = orig_caxes
115-
# self.couter_loops = couter_loops
107+
# self._cache = {}
116108

117109
def __getitem__(self, indices):
118110
return self.getitem(indices, strict=False)
@@ -122,6 +114,11 @@ def __getitem__(self, indices):
122114
__iter__ = None
123115

124116
def getitem(self, indices, *, strict=False):
117+
# does not work as indices may not be hashable, parse first?
118+
# cache_key = (indices, strict)
119+
# if cache_key in self._cache:
120+
# return self._cache[cache_key]
121+
125122
if len(indices) != 2:
126123
raise ValueError
127124

@@ -174,13 +171,15 @@ def getitem(self, indices, *, strict=False):
174171
indexed_caxes = index_axes(ctree, pmap(), self.caxes)
175172
caxes = compose_axes(indexed_caxes, self.caxes)
176173

177-
return type(self)(
174+
mat = type(self)(
178175
raxes,
179176
caxes,
180177
mat_type=self.mat_type,
181178
mat=self.mat,
182179
name=self.name,
183180
)
181+
# self._cache[cache_key] = mat
182+
return mat
184183

185184
# Otherwise we are context-sensitive
186185
arrays = {}
@@ -202,7 +201,9 @@ def getitem(self, indices, *, strict=False):
202201
name=self.name,
203202
)
204203
# But this is now a PetscMat...
205-
return ContextSensitiveMultiArray(arrays)
204+
mat = ContextSensitiveMultiArray(arrays)
205+
# self._cache[cache_key] = mat
206+
return mat
206207

207208
# like Dat, bad name? handle?
208209
@property
@@ -305,7 +306,9 @@ def _iter_nest_labels(
305306
yield (rlabel_acc_, clabel_acc_)
306307

307308
@cached_property
309+
@PETSc.Log.EventDecorator()
308310
def maps(self):
311+
print("HIT!")
309312
from pyop3.axtree.layout import my_product
310313

311314
# TODO: Don't think these need to be lists here.
@@ -342,11 +345,11 @@ def maps(self):
342345
].items()
343346
if ax == cfield_axis.label
344347
)
345-
orig_caxes = AxisTree(self.orig_caxes[cfield].node_map)
348+
orig_caxes = AxisTree(self.caxes.unindexed[cfield].node_map)
346349
orig_caxess = [orig_caxes]
347350
dropped_ckeys = {cfield_axis.label}
348351
else:
349-
orig_caxess = [self.orig_caxes]
352+
orig_caxess = [self.caxes.unindexed]
350353
dropped_ckeys = set()
351354
else:
352355
orig_raxess = [self.raxes.unindexed]
@@ -426,6 +429,22 @@ def rmap(self):
426429
def cmap(self):
427430
return self.maps[1]
428431

432+
@cached_property
433+
def row_lgmap_dat(self):
434+
if self.nested or self.mat_type == "baij":
435+
raise NotImplementedError("Use a smaller set of axes here")
436+
return HierarchicalArray(self.raxes, data=self.raxes.unindexed.global_numbering)
437+
438+
@cached_property
439+
def column_lgmap_dat(self):
440+
if self.nested or self.mat_type == "baij":
441+
raise NotImplementedError("Use a smaller set of axes here")
442+
return HierarchicalArray(self.caxes, data=self.caxes.unindexed.global_numbering)
443+
444+
@cached_property
445+
def comm(self):
446+
return single_valued([self.raxes.comm, self.caxes.comm])
447+
429448
@property
430449
def shape(self):
431450
return (self.raxes.size, self.caxes.size)
@@ -447,15 +466,6 @@ def axes(self):
447466
@classmethod
448467
def _make_mat(cls, raxes, caxes, mat_type):
449468
if isinstance(mat_type, collections.abc.Mapping):
450-
# if strictly_all(c.unit for c in raxes.root.components):
451-
# riter = tuple((c.label, raxes[c.label]) for c in raxes.root.components)
452-
# else:
453-
# riter = [(None, raxes)]
454-
# if strictly_all(c.unit for c in caxes.root.components):
455-
# citer = tuple((c.label, caxes[c.label]) for c in caxes.root.components)
456-
# else:
457-
# citer = [(None, caxes)]
458-
459469
# TODO: This is very ugly
460470
rsize = max(x or 0 for x, _ in mat_type.keys()) + 1
461471
csize = max(y or 0 for _, y in mat_type.keys()) + 1
@@ -482,10 +492,6 @@ def kernel_dtype(self):
482492

483493

484494
class Sparsity(AbstractMat):
485-
# def __init__(self, *args, **kwargs):
486-
# super().__init__(*args, **kwargs)
487-
# self._lazy_template = None
488-
489495
def materialize(self) -> PETSc.Mat:
490496
if not hasattr(self, "_lazy_template"):
491497
self.assemble()
@@ -495,6 +501,7 @@ def materialize(self) -> PETSc.Mat:
495501
# template.preallocateWithMatPreallocator(self.mat)
496502
# We can safely set these options since by using a sparsity we
497503
# are asserting that we know where the non-zeros are going.
504+
# NOTE: These may already get set by PETSc.
498505
template.setOption(PETSc.Mat.Option.NEW_NONZERO_LOCATION_ERR, True)
499506
template.setOption(PETSc.Mat.Option.IGNORE_ZERO_ENTRIES, True)
500507

@@ -537,15 +544,12 @@ def _make_monolithic_mat(cls, raxes, caxes, mat_type: str):
537544
mat = PETSc.Mat().create(comm)
538545
mat.setType(PETSc.Mat.Type.PREALLOCATOR)
539546

540-
# breakpoint()
541-
# else:
542547
# None is for the global size, PETSc will figure it out for us
543548
sizes = ((raxes.owned.size, None), (caxes.owned.size, None))
544-
545549
mat.setSizes(sizes)
546550

547-
rlgmap = PETSc.LGMap().create(raxes.global_numbering(), comm=comm)
548-
clgmap = PETSc.LGMap().create(caxes.global_numbering(), comm=comm)
551+
rlgmap = PETSc.LGMap().create(raxes.global_numbering, comm=comm)
552+
clgmap = PETSc.LGMap().create(caxes.global_numbering, comm=comm)
549553
mat.setLGMap(rlgmap, clgmap)
550554

551555
mat.setUp()
@@ -613,8 +617,8 @@ def _make_monolithic_mat(cls, raxes, caxes, mat_type: str):
613617
sizes = ((raxes.owned.size, None), (caxes.owned.size, None))
614618
mat.setSizes(sizes)
615619

616-
rlgmap = PETSc.LGMap().create(raxes.global_numbering(), comm=comm)
617-
clgmap = PETSc.LGMap().create(caxes.global_numbering(), comm=comm)
620+
rlgmap = PETSc.LGMap().create(raxes.global_numbering, comm=comm)
621+
clgmap = PETSc.LGMap().create(caxes.global_numbering, comm=comm)
618622
mat.setLGMap(rlgmap, clgmap)
619623

620624
mat.setUp()

0 commit comments

Comments
 (0)