Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit f8625bf

Browse files
committed
Expunge ContextSensitiveMultiArray
Firedrake tests appear to be passing.
1 parent 40f3f9e commit f8625bf

File tree

7 files changed

+121
-45
lines changed

7 files changed

+121
-45
lines changed

pyop3/array/base.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import abc
22

3+
from pyop3.axtree import ContextAware
34
from pyop3.lang import FunctionArgument, ReplaceAssignment
45
from pyop3.utils import UniqueNameGenerator
56

67

7-
class Array(FunctionArgument, abc.ABC):
8+
class Array(ContextAware, FunctionArgument, abc.ABC):
89
_prefix = "array"
910
_name_generator = UniqueNameGenerator()
1011

@@ -16,3 +17,12 @@ def __init__(self, name=None, *, prefix=None) -> None:
1617
def assign(self, other, eager=True):
1718
expr = ReplaceAssignment(self, other)
1819
return expr() if eager else expr
20+
21+
@abc.abstractmethod
22+
def with_context(self):
23+
pass
24+
25+
@property
26+
@abc.abstractmethod
27+
def context_free(self):
28+
pass

pyop3/array/harray.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class FancyIndexWriteException(Exception):
9494
pass
9595

9696

97-
class HierarchicalArray(Array, ContextFree, KernelArgument):
97+
class HierarchicalArray(Array, KernelArgument):
9898
"""Multi-dimensional, hierarchical array.
9999
100100
Parameters
@@ -138,7 +138,7 @@ def __init__(
138138
# always deal with flattened data
139139
if len(data.shape) > 1:
140140
data = data.flatten()
141-
if data.size != axes.unindexed.global_size:
141+
if data.size != axes.alloc_size:
142142
raise ValueError("Data shape does not match axes")
143143

144144
# IndexedAxisTrees do not currently have SFs, so create a dummy one here
@@ -147,10 +147,10 @@ def __init__(
147147
else:
148148
assert isinstance(axes, (ContextSensitiveAxisTree, IndexedAxisTree))
149149
# not sure this is the right thing to do
150-
sf = serial_forest(axes.unindexed.global_size)
150+
sf = serial_forest(axes.alloc_size)
151151

152152
data = DistributedBuffer(
153-
axes.unindexed.global_size, # not a useful property anymore
153+
axes.alloc_size, # not a useful property anymore
154154
sf,
155155
dtype,
156156
name=self.name,
@@ -210,6 +210,25 @@ def getitem(self, indices, *, strict=False):
210210
# to be iterable (which it's not). This avoids some confusing behaviour.
211211
__iter__ = None
212212

213+
def with_context(self, context):
214+
return type(self)(
215+
self.axes.with_context(context),
216+
name=self.name,
217+
data=self.buffer,
218+
max_value=self.max_value,
219+
constant=self.constant,
220+
)
221+
222+
@property
223+
def context_free(self, context):
224+
return type(self)(
225+
self.axes.context_free,
226+
name=self.name,
227+
data=self.buffer,
228+
max_value=self.max_value,
229+
constant=self.constant,
230+
)
231+
213232
@property
214233
def dtype(self):
215234
return self.buffer.dtype

pyop3/array/petsc.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ class PetscVecNest(PetscVec):
6464
...
6565

6666

67-
class AbstractMat(Array, ContextFree):
67+
# BaseMat?
68+
class AbstractMat(Array):
6869
DEFAULT_MAT_TYPE = PETSc.Mat.Type.AIJ
6970

7071
prefix = "mat"
@@ -216,6 +217,29 @@ def getitem(self, indices, *, strict=False):
216217
# self._cache[cache_key] = mat
217218
return mat
218219

220+
def with_context(self, context):
221+
row_axes = self.raxes.with_context(context)
222+
col_axes = self.caxes.with_context(context)
223+
return type(self)(
224+
row_axes,
225+
col_axes,
226+
name=self.name,
227+
mat_type=self.mat_type,
228+
mat=self.mat,
229+
)
230+
231+
@property
232+
def context_free(self):
233+
row_axes = self.raxes.context_free
234+
col_axes = self.caxes.context_free
235+
return type(self)(
236+
row_axes,
237+
col_axes,
238+
name=self.name,
239+
mat_type=self.mat_type,
240+
mat=self.mat,
241+
)
242+
219243
# like Dat, bad name? handle?
220244
@property
221245
def array(self):

pyop3/axtree/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .tree import (
22
Axis,
3+
ContextMismatchException,
34
AxisComponent,
45
AxisTree,
56
AxisVariable,

pyop3/axtree/tree.py

+27-9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pyrsistent import freeze, pmap, thaw
2626

2727
from pyop3.axtree.parallel import partition_ghost_points
28+
from pyop3.exceptions import Pyop3Exception
2829
from pyop3.dtypes import IntType
2930
from pyop3.sf import StarForest, serial_forest
3031
from pyop3.tree import (
@@ -51,10 +52,14 @@
5152
)
5253

5354

54-
class ExpectedLinearAxisTreeException(Exception):
55+
class ExpectedLinearAxisTreeException(Pyop3Exception):
5556
...
5657

5758

59+
class ContextMismatchException(Pyop3Exception):
60+
pass
61+
62+
5863
class ContextAware(abc.ABC):
5964
@abc.abstractmethod
6065
def with_context(self, context):
@@ -97,7 +102,10 @@ def keys(self):
97102
return frozenset(indices)
98103

99104
def with_context(self, context):
100-
return self.context_map[self.filter_context(context)]
105+
try:
106+
return self.context_map[self.filter_context(context)]
107+
except KeyError:
108+
raise ContextMismatchException
101109

102110
def filter_context(self, context):
103111
key = {}
@@ -1144,18 +1152,22 @@ def __init__(
11441152
layout_exprs,
11451153
outer_loops,
11461154
):
1155+
if layout_exprs is None:
1156+
layout_exprs = pmap()
11471157
if outer_loops is None:
11481158
outer_loops = ()
1149-
else:
1150-
assert isinstance(outer_loops, tuple)
11511159

11521160
super().__init__(node_map)
11531161
self._unindexed = unindexed
1154-
self._target_paths = target_paths
1155-
self._index_exprs = index_exprs
1156-
self._layout_exprs = layout_exprs
1162+
self._target_paths = pmap(target_paths)
1163+
self._index_exprs = pmap(index_exprs)
1164+
self._layout_exprs = pmap(layout_exprs)
11571165
self._outer_loops = tuple(outer_loops)
11581166

1167+
# @cached_property
1168+
# def _hash_key(self):
1169+
# return super()._hash_key + (self.unindexed, self.target_paths, self.index_exprs, self.layout_exprs, self.outer_loops)
1170+
11591171
@property
11601172
def unindexed(self):
11611173
return self._unindexed
@@ -1342,9 +1354,15 @@ def datamap(self):
13421354
def sf(self):
13431355
return single_valued([ax.sf for ax in self.context_map.values()])
13441356

1357+
# @cached_property
1358+
# def unindexed(self):
1359+
# this does not work because unindexed may have different IDs, so just return
1360+
# the first one.
1361+
# return single_valued([ax.unindexed for ax in self.context_map.values()])
1362+
13451363
@cached_property
1346-
def unindexed(self):
1347-
return single_valued([ax.unindexed for ax in self.context_map.values()])
1364+
def context_free(self):
1365+
return just_one(self.context_map.values())
13481366

13491367

13501368
@functools.singledispatch

pyop3/transform.py

+21-30
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from pyop3.array import HierarchicalArray
1111
from pyop3.array.petsc import AbstractMat
12-
from pyop3.axtree import Axis, AxisTree, ContextFree, ContextSensitive
12+
from pyop3.axtree import Axis, AxisTree, ContextFree, ContextSensitive, ContextMismatchException, ContextAware
1313
from pyop3.buffer import DistributedBuffer, NullBuffer, PackedBuffer
1414
from pyop3.itree import Map, TabulatedMapComponent
1515
from pyop3.lang import (
@@ -162,7 +162,7 @@ def _(self, terminal: CalledFunction, *, context):
162162
outer_context.update(
163163
{
164164
index: paths
165-
for ctx in arg.context_map.keys()
165+
for ctx in arg.axes.context_map.keys()
166166
for index, paths in ctx.items()
167167
if index not in context
168168
}
@@ -179,38 +179,29 @@ def _(self, terminal: Assignment, *, context):
179179
# FIXME for now we assume an outer context of {}. In other words anything
180180
# context sensitive in the assignment is completely handled by the existing
181181
# outer loops.
182+
# This is meaningful if the kernel accepts a loop index as an argument.
182183

183-
valid = True
184184
cf_args = []
185185
for arg in terminal.arguments:
186-
try:
187-
cf_arg = (
188-
arg.with_context(context)
189-
if isinstance(arg.axes, ContextSensitive)
190-
else arg
191-
)
192-
# FIXME We will hit issues here when we are missing outer context I think
193-
except KeyError:
194-
# assignment is not valid in this context, do nothing
195-
valid = False
196-
break
197-
cf_args.append(cf_arg)
198-
199-
if valid:
200-
return ((pmap(), terminal.with_arguments(cf_args)),)
201-
else:
202-
return ((pmap(), None),)
186+
if isinstance(arg, ContextAware):
187+
try:
188+
cf_args.append(arg.with_context(context))
189+
except ContextMismatchException:
190+
# assignment is not valid in this context, do nothing
191+
return ((pmap(), None),)
192+
else:
193+
cf_args.append(arg)
194+
return ((pmap(), terminal.with_arguments(cf_args)),)
203195

204196
# TODO: this is just an assignment, fix inheritance
205197
@_apply.register
206198
def _(self, terminal: PetscMatInstruction, *, context):
207-
if any(
208-
isinstance(a.axes, ContextSensitive)
209-
for a in {terminal.mat_arg, terminal.array_arg}
210-
):
211-
raise NotImplementedError
212-
213-
return ((pmap(), terminal),)
199+
try:
200+
mat = terminal.mat_arg.with_context(context)
201+
array = terminal.array_arg.with_context(context)
202+
return ((pmap(), terminal.copy(mat_arg=mat, array_arg=array)),)
203+
except ContextMismatchException:
204+
return ((pmap(), None),)
214205

215206

216207
def expand_loop_contexts(expr: Instruction):
@@ -320,9 +311,9 @@ def _(self, terminal: CalledFunction):
320311
for (arg, intent), shape in checked_zip(
321312
terminal.function_arguments, terminal.argument_shapes
322313
):
323-
assert isinstance(
324-
arg, ContextFree
325-
), "Loop contexts should already be expanded"
314+
# assert isinstance(
315+
# arg, ContextFree
316+
# ), "Loop contexts should already be expanded"
326317

327318
if isinstance(arg, DummyKernelArgument):
328319
arguments.append(arg)

pyop3/tree.py

+13
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,19 @@ def __init__(self, node_map=None):
251251
# post-init checks
252252
self._check_node_labels_unique_in_paths(self.node_map)
253253

254+
# This is arguably over-specific. Otherwise equivalent trees will currently
255+
# fail this check if their IDs do not match. One way to resolve this would be
256+
# to re-ID all of the nodes with a pre-order traversal. This is not a priority.
257+
# def __eq__(self, other):
258+
# return type(other) is type(self) and other._hash_key == self._hash_key
259+
#
260+
# def __hash__(self):
261+
# return hash(self._hash_key)
262+
#
263+
# @cached_property
264+
# def _hash_key(self):
265+
# return (self.node_map,)
266+
254267
@classmethod
255268
def _check_node_labels_unique_in_paths(
256269
cls, node_map, node=None, seen_labels=frozenset()

0 commit comments

Comments
 (0)