Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 8c4790c

Browse files
committed
WIP
* About 8 tests currently failing, but they are non-essential for Firedrake. * Substantially improve PetscMat implementation. Row and column axes must now be distinctly labelled. This avoids confusing failures. * Add `Pack` class because sometimes `__getitem__` isn't enough (for example if you want a specific DoF layout).
1 parent 5cb6a32 commit 8c4790c

File tree

7 files changed

+142
-42
lines changed

7 files changed

+142
-42
lines changed

pyop3/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
Function,
4646
Loop,
4747
OpaqueKernelArgument,
48+
Pack,
4849
ReplaceAssignment,
4950
do_loop,
5051
loop,

pyop3/array/petsc.py

+40-8
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ def __init__(self, raxes, caxes, *, name=None):
105105
self.raxes = raxes
106106
self.caxes = caxes
107107

108+
axes = PartialAxisTree(raxes.parent_to_children)
109+
for leaf_axis, leaf_cpt in raxes.leaves:
110+
# do *not* uniquify, it makes indexing very complicated. Instead assert
111+
# that external indices and axes must be sufficiently unique.
112+
axes = axes.add_subtree(caxes, leaf_axis, leaf_cpt, uniquify_ids=True)
113+
self.axes = AxisTree(axes.parent_to_children)
114+
108115
def __getitem__(self, indices):
109116
# TODO also support context-free (see MultiArray.__getitem__)
110117
if len(indices) != 2:
@@ -151,6 +158,11 @@ def __getitem__(self, indices):
151158

152159
arrays = {}
153160
for ctx, (rtree, ctree) in rcforest.items():
161+
tree = rtree
162+
for rleaf, clabel in rtree.leaves:
163+
tree = tree.add_subtree(ctree, rleaf, clabel, uniquify_ids=True)
164+
indexed_axes = _index_axes(tree, ctx, self.axes)
165+
154166
indexed_raxes = _index_axes(rtree, ctx, self.raxes)
155167
indexed_caxes = _index_axes(ctree, ctx, self.caxes)
156168

@@ -201,21 +213,41 @@ def __getitem__(self, indices):
201213
# breakpoint()
202214
packed = PackedPetscMat(self, rmap, cmap, shape)
203215

204-
indexed_axes = PartialAxisTree(indexed_raxes.parent_to_children)
205-
for leaf_axis, leaf_cpt in indexed_raxes.leaves:
206-
indexed_axes = indexed_axes.add_subtree(
207-
indexed_caxes, leaf_axis, leaf_cpt, uniquify=True
208-
)
209-
indexed_axes = indexed_axes.set_up()
216+
# indexed_axes = PartialAxisTree(indexed_raxes.parent_to_children)
217+
# for leaf_axis, leaf_cpt in indexed_raxes.leaves:
218+
# indexed_axes = indexed_axes.add_subtree(
219+
# indexed_caxes, leaf_axis, leaf_cpt, uniquify=True
220+
# )
221+
# indexed_axes = indexed_axes.set_up()
222+
# node_map = dict(indexed_raxes.parent_to_children)
223+
# target_paths = dict(indexed_raxes.target_paths)
224+
# index_exprs = dict(indexed_raxes.index_exprs)
225+
# for leaf_axis, leaf_cpt in indexed_raxes.leaves:
226+
# for caxis in indexed_caxes.nodes:
227+
# if caxis.id not in indexed_raxes.parent_to_children:
228+
# cid = caxis.id
229+
# else:
230+
# cid = XXX
231+
#
232+
# for ccpt in caxis.components:
233+
# node_map.update(...)
234+
# indexed_axes = AxisTree(node_map, target_paths=???, index_exprs=???)
235+
# can I make indexed_axes simply???
236+
# breakpoint()
237+
238+
outer_loops = list(router_loops)
239+
all_ids = [l.id for l in router_loops]
240+
for ol in couter_loops:
241+
if ol.id not in all_ids:
242+
outer_loops.append(ol)
210243

211244
arrays[ctx] = HierarchicalArray(
212245
indexed_axes,
213246
data=packed,
214247
target_paths=indexed_axes.target_paths,
215248
index_exprs=indexed_axes.index_exprs,
216249
# TODO ordered set?
217-
outer_loops=router_loops
218-
+ tuple(filter(lambda l: l not in router_loops, couter_loops)),
250+
outer_loops=outer_loops,
219251
name=self.name,
220252
)
221253
return ContextSensitiveMultiArray(arrays)

pyop3/ir/lower.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def add_argument(self, array):
213213
# Temporaries can have variable size, hence we allocate space for the
214214
# largest possible array
215215
# shape = (array.alloc_size,)
216-
shape = self._temporary_shapes[array.name]
216+
shape = self._temporary_shapes.get(array.name, (array.alloc_size,))
217217

218218
# could rename array like the rest
219219
temp = lp.TemporaryVariable(array.name, dtype=array.dtype, shape=shape)
@@ -488,6 +488,11 @@ def _(expr: Assignment):
488488
return pmap()
489489

490490

491+
@_collect_temporary_shapes.register
492+
def _(expr: PetscMatInstruction):
493+
return pmap()
494+
495+
491496
@_collect_temporary_shapes.register
492497
def _(call: CalledFunction):
493498
return freeze(
@@ -496,7 +501,6 @@ def _(call: CalledFunction):
496501
for lp_arg, arg in checked_zip(
497502
call.function.code.default_entrypoint.args, call.arguments
498503
)
499-
if lp_arg.shape is not None
500504
}
501505
)
502506

pyop3/itree/tree.py

+34-11
Original file line numberDiff line numberDiff line change
@@ -324,11 +324,14 @@ def datamap(self):
324324
# FIXME class hierarchy is very confusing
325325
class ContextFreeLoopIndex(ContextFreeIndex):
326326
def __init__(self, iterset: AxisTree, source_path, path, *, id=None):
327-
super().__init__(id=id)
327+
super().__init__(id=id, label=id, component_labels=("XXX",))
328328
self.iterset = iterset
329329
self.source_path = freeze(source_path)
330330
self.path = freeze(path)
331331

332+
# if self.label == "_label_ContextFreeLoopIndex_15":
333+
# breakpoint()
334+
332335
def with_context(self, context, *args):
333336
return self
334337

@@ -437,7 +440,8 @@ class Slice(ContextFreeIndex):
437440
438441
"""
439442

440-
fields = Index.fields | {"axis", "slices", "numbering"} - {"label"}
443+
# fields = Index.fields | {"axis", "slices", "numbering"} - {"label", "component_labels"}
444+
fields = {"axis", "slices", "numbering"}
441445

442446
def __init__(self, axis, slices, *, numbering=None, id=None):
443447
super().__init__(label=axis, id=id)
@@ -496,9 +500,10 @@ def datamap(self):
496500
return pmap(data)
497501

498502

499-
class CalledMap(Identified, LoopIterable):
500-
def __init__(self, map, from_index, *, id=None):
503+
class CalledMap(Identified, Labelled, LoopIterable):
504+
def __init__(self, map, from_index, *, id=None, label=None):
501505
Identified.__init__(self, id=id)
506+
Labelled.__init__(self, label=label)
502507
self.map = map
503508
self.from_index = from_index
504509

@@ -596,7 +601,9 @@ def with_context(self, context, axes=None):
596601
)
597602
if len(leaf_target_paths) == 0:
598603
raise RuntimeError
599-
return ContextFreeCalledMap(self.map, cf_index, leaf_target_paths, id=self.id)
604+
return ContextFreeCalledMap(
605+
self.map, cf_index, leaf_target_paths, id=self.id, label=self.label
606+
)
600607

601608
@property
602609
def name(self):
@@ -609,8 +616,12 @@ def connectivity(self):
609616

610617
# class ContextFreeCalledMap(Index, ContextFree):
611618
class ContextFreeCalledMap(Index):
612-
def __init__(self, map, index, leaf_target_paths, *, id=None):
613-
super().__init__(id=id)
619+
# FIXME this is clumsy
620+
# fields = Index.fields | {"map", "index", "leaf_target_paths"} - {"label", "component_labels"}
621+
fields = {"map", "index", "leaf_target_paths", "label", "id"}
622+
623+
def __init__(self, map, index, leaf_target_paths, *, id=None, label=None):
624+
super().__init__(id=id, label=label)
614625
self.map = map
615626
# better to call it "input_index"?
616627
self.index = index
@@ -1274,7 +1285,7 @@ def _make_leaf_axis_from_called_map(
12741285
{map_cpt.target_axis: map_cpt.target_component}
12751286
)
12761287

1277-
axisvar = AxisVariable(called_map.id)
1288+
axisvar = AxisVariable(called_map.label)
12781289

12791290
if not isinstance(map_cpt, TabulatedMapComponent):
12801291
raise NotImplementedError("Currently we assume only arrays here")
@@ -1297,7 +1308,7 @@ def _make_leaf_axis_from_called_map(
12971308
# a replacement
12981309
map_leaf_axis, map_leaf_component = map_axes.leaf
12991310
old_inner_index_expr = map_array.index_exprs[
1300-
map_leaf_axis.id, map_leaf_component.label
1311+
map_leaf_axis.id, map_leaf_component
13011312
]
13021313

13031314
my_index_exprs = {}
@@ -1322,7 +1333,10 @@ def _make_leaf_axis_from_called_map(
13221333
raise RuntimeError("map does not target any relevant axes")
13231334

13241335
axis = Axis(
1325-
components, label=called_map.id, id=axis_id, numbering=called_map.map.numbering
1336+
components,
1337+
label=called_map.label,
1338+
id=axis_id,
1339+
numbering=called_map.map.numbering,
13261340
)
13271341

13281342
return (
@@ -1358,9 +1372,18 @@ def _index_axes(
13581372
debug=debug,
13591373
)
13601374

1361-
# index trees should track outer loops, I think?
13621375
outer_loops += indices.outer_loops
13631376

1377+
# drop duplicates
1378+
outer_loops_ = []
1379+
allids = set()
1380+
for ol in outer_loops:
1381+
if ol.id in allids:
1382+
continue
1383+
outer_loops_.append(ol)
1384+
allids.add(ol.id)
1385+
outer_loops = tuple(outer_loops_)
1386+
13641387
# check that slices etc have not been missed
13651388
assert not include_loop_index_shape, "old option"
13661389
if axes is not None:

pyop3/lang.py

+16
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
checked_zip,
3232
just_one,
3333
merge_dicts,
34+
single_valued,
3435
unique,
3536
)
3637

@@ -80,6 +81,21 @@ def kernel_dtype(self):
8081
pass
8182

8283

84+
# this is an expression, like passing an array through to a kernel
85+
# but it is transformed first.
86+
class Pack(KernelArgument, ContextFree):
87+
def __init__(self, big, small):
88+
self.big = big
89+
self.small = small
90+
91+
@property
92+
def kernel_dtype(self):
93+
try:
94+
return single_valued([self.big.dtype, self.small.dtype])
95+
except ValueError:
96+
raise ValueError("dtypes must match")
97+
98+
8399
class Instruction(UniqueRecord, abc.ABC):
84100
pass
85101

pyop3/transform.py

+31-17
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
DummyKernelArgument,
2525
Instruction,
2626
Loop,
27+
Pack,
2728
PetscMatAdd,
2829
PetscMatLoad,
2930
PetscMatStore,
@@ -276,7 +277,6 @@ def _(self, assignment: Assignment):
276277
axes = AxisTree(arg.axes.parent_to_children)
277278
new_arg = HierarchicalArray(
278279
axes,
279-
layouts=arg.layouts,
280280
data=NullBuffer(arg.dtype), # does this need a size?
281281
name=self._name_generator("t"),
282282
)
@@ -319,42 +319,56 @@ def _(self, terminal: CalledFunction):
319319
# this is a separate stage to the assignment operations because one
320320
# can index a packed mat. E.g. mat[p, q][::2] would decompose into
321321
# two calls, one to pack t0 <- mat[p, q] and another to pack t1 <- t0[::2]
322-
if isinstance(arg.buffer, PackedBuffer):
322+
if (
323+
isinstance(arg, Pack)
324+
and isinstance(arg.big.buffer, PackedBuffer)
325+
or not isinstance(arg, Pack)
326+
and isinstance(arg.buffer, PackedBuffer)
327+
):
328+
if isinstance(arg, Pack):
329+
myarg = arg.big
330+
else:
331+
myarg = arg
332+
323333
# TODO add PackedPetscMat as a subclass of buffer?
324-
if not isinstance(arg.buffer.array, PetscMat):
334+
if not isinstance(myarg.buffer.array, PetscMat):
325335
raise NotImplementedError("Only handle Mat at the moment")
326336

327-
axes = AxisTree(arg.axes.parent_to_children)
337+
axes = AxisTree(myarg.axes.parent_to_children)
328338
new_arg = HierarchicalArray(
329339
axes,
330-
data=NullBuffer(arg.dtype), # does this need a size?
340+
data=NullBuffer(myarg.dtype), # does this need a size?
331341
name=self._name_generator("t"),
332342
)
333343

334344
if intent == READ:
335-
gathers.append(PetscMatLoad(arg, new_arg))
345+
gathers.append(PetscMatLoad(myarg, new_arg))
336346
elif intent == WRITE:
337-
scatters.insert(0, PetscMatStore(arg, new_arg))
347+
scatters.insert(0, PetscMatStore(myarg, new_arg))
338348
elif intent == RW:
339-
gathers.append(PetscMatLoad(arg, new_arg))
340-
scatters.insert(0, PetscMatStore(arg, new_arg))
349+
gathers.append(PetscMatLoad(myarg, new_arg))
350+
scatters.insert(0, PetscMatStore(myarg, new_arg))
341351
else:
342352
assert intent == INC
343353
gathers.append(ReplaceAssignment(new_arg, 0))
344-
scatters.insert(0, PetscMatAdd(arg, new_arg))
354+
scatters.insert(0, PetscMatAdd(myarg, new_arg))
345355

346356
# the rest of the packing code is now dealing with the result of this
347357
# function call
348358
arg = new_arg
349359

350360
# unpick pack/unpack instructions
351361
if intent != NA and _requires_pack_unpack(arg):
352-
axes = AxisTree(arg.axes.parent_to_children)
353-
temporary = HierarchicalArray(
354-
axes,
355-
data=NullBuffer(arg.dtype), # does this need a size?
356-
name=self._name_generator("t"),
357-
)
362+
if isinstance(arg, Pack):
363+
temporary = arg.small
364+
arg = arg.big
365+
else:
366+
axes = AxisTree(arg.axes.parent_to_children)
367+
temporary = HierarchicalArray(
368+
axes,
369+
data=NullBuffer(arg.dtype), # does this need a size?
370+
name=self._name_generator("t"),
371+
)
358372

359373
if intent == READ:
360374
gathers.append(ReplaceAssignment(temporary, arg))
@@ -426,7 +440,7 @@ def _requires_pack_unpack(arg):
426440
# however, it is overly restrictive since we could pass something like dat[i0, :] directly
427441
# to a local kernel
428442
# return isinstance(arg, HierarchicalArray) and arg.subst_layouts != arg.layouts
429-
return isinstance(arg, HierarchicalArray)
443+
return isinstance(arg, HierarchicalArray) or isinstance(arg, Pack)
430444

431445

432446
# *below is old untested code*

0 commit comments

Comments
 (0)