Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 01cfcd0

Browse files
committed
Add context sensitive dat and hashable trees
1 parent f8625bf commit 01cfcd0

File tree

5 files changed

+88
-30
lines changed

5 files changed

+88
-30
lines changed

pyop3/array/harray.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pyop3.array.base import Array
2020
from pyop3.axtree import (
2121
Axis,
22+
ContextSensitive,
2223
AxisTree,
2324
ContextFree,
2425
as_axis_tree,
@@ -138,7 +139,7 @@ def __init__(
138139
# always deal with flattened data
139140
if len(data.shape) > 1:
140141
data = data.flatten()
141-
if data.size != axes.alloc_size:
142+
if data.size != axes.unindexed.global_size:
142143
raise ValueError("Data shape does not match axes")
143144

144145
# IndexedAxisTrees do not currently have SFs, so create a dummy one here
@@ -147,10 +148,10 @@ def __init__(
147148
else:
148149
assert isinstance(axes, (ContextSensitiveAxisTree, IndexedAxisTree))
149150
# not sure this is the right thing to do
150-
sf = serial_forest(axes.alloc_size)
151+
sf = serial_forest(axes.unindexed.global_size)
151152

152153
data = DistributedBuffer(
153-
axes.alloc_size, # not a useful property anymore
154+
axes.unindexed.global_size, # not a useful property anymore
154155
sf,
155156
dtype,
156157
name=self.name,
@@ -528,3 +529,16 @@ class MultiArray(HierarchicalArray):
528529
@deprecated("HierarchicalArray")
529530
def __init__(self, *args, **kwargs):
530531
super().__init__(*args, **kwargs)
532+
533+
534+
class ContextSensitiveDat(ContextSensitive):
535+
"""Class for describing arrays that are different within different loop contexts.
536+
537+
This is useful for the case where one wants to pass a small array through as
538+
part of a context-sensitive assignment.
539+
540+
"""
541+
542+
@property
543+
def dtype(self):
544+
return self._shared_attr("dtype")

pyop3/array/petsc.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pyrsistent import freeze, pmap
1313

1414
from pyop3.array.base import Array
15-
from pyop3.array.harray import HierarchicalArray
15+
from pyop3.array.harray import HierarchicalArray, ContextSensitiveDat
1616
from pyop3.axtree.tree import (
1717
AxisTree,
1818
ContextSensitiveAxisTree,
@@ -261,11 +261,22 @@ def assign(self, other, *, eager=True):
261261
# TODO: Check axes match between self and other
262262
expr = PetscMatStore(self, other)
263263
elif isinstance(other, numbers.Number):
264-
static = HierarchicalArray(
265-
self.axes,
266-
data=np.full(self.axes.alloc_size, other, dtype=self.dtype),
267-
constant=True,
268-
)
264+
if isinstance(self.axes, ContextSensitiveAxisTree):
265+
cs_dats = {}
266+
for context, axes in self.axes.context_map.items():
267+
cs_dat = HierarchicalArray(
268+
axes,
269+
data=np.full(axes.size, other, dtype=self.dtype),
270+
constant=True,
271+
)
272+
cs_dats[context] = cs_dat
273+
static = ContextSensitiveDat(cs_dats)
274+
else:
275+
static = HierarchicalArray(
276+
self.axes,
277+
data=np.full(self.axes.alloc_size, other, dtype=self.dtype),
278+
constant=True,
279+
)
269280
expr = PetscMatStore(self, static)
270281
else:
271282
raise NotImplementedError

pyop3/axtree/tree.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def filter_context(self, context):
114114
key.update({loop_index: freeze(path)})
115115
return freeze(key)
116116

117+
def _shared_attr(self, attr: str):
118+
return single_valued(getattr(a, attr) for a in self.context_map.values())
117119

118120
# this is basically just syntactic sugar, might not be needed
119121
# avoids the need for
@@ -1164,9 +1166,9 @@ def __init__(
11641166
self._layout_exprs = pmap(layout_exprs)
11651167
self._outer_loops = tuple(outer_loops)
11661168

1167-
# @cached_property
1168-
# def _hash_key(self):
1169-
# return super()._hash_key + (self.unindexed, self.target_paths, self.index_exprs, self.layout_exprs, self.outer_loops)
1169+
@cached_property
1170+
def _hash_key(self):
1171+
return super()._hash_key + (self.unindexed, self.target_paths, self.index_exprs, self.layout_exprs, self.outer_loops)
11701172

11711173
@property
11721174
def unindexed(self):
@@ -1354,11 +1356,9 @@ def datamap(self):
13541356
def sf(self):
13551357
return single_valued([ax.sf for ax in self.context_map.values()])
13561358

1357-
# @cached_property
1358-
# def unindexed(self):
1359-
# this does not work because unindexed may have different IDs, so just return
1360-
# the first one.
1361-
# return single_valued([ax.unindexed for ax in self.context_map.values()])
1359+
@cached_property
1360+
def unindexed(self):
1361+
return single_valued([ax.unindexed for ax in self.context_map.values()])
13621362

13631363
@cached_property
13641364
def context_free(self):

pyop3/lang.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def _array_updates(self):
242242
243243
"""
244244
from pyop3 import DistributedBuffer, HierarchicalArray, Mat
245+
from pyop3.array.harray import ContextSensitiveDat
245246
from pyop3.array.petsc import Sparsity
246247

247248
initializers = []
@@ -260,6 +261,9 @@ def _array_updates(self):
260261
initializers.extend(inits)
261262
reductions.extend(reds)
262263
broadcasts.extend(bcasts)
264+
elif isinstance(arg, ContextSensitiveDat):
265+
# assumed to not be distributed
266+
pass
263267
else:
264268
assert isinstance(arg, (Mat, Sparsity))
265269
# just in case
@@ -589,7 +593,12 @@ def __init__(self, mat_arg, array_arg):
589593

590594
@property
591595
def kernel_arguments(self):
592-
return (self.mat_arg.mat, self.array_arg.buffer)
596+
args = (self.mat_arg.mat,)
597+
if isinstance(self.array_arg, ContextSensitive):
598+
args += tuple(dat.buffer for dat in self.array_arg.context_map.values())
599+
else:
600+
args += (self.array_arg.buffer,)
601+
return args
593602

594603
@property
595604
def datamap(self):

pyop3/tree.py

+36-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import abc
44
import collections
55
import functools
6+
import itertools
67
import operator
78
from collections import defaultdict
89
from collections.abc import Hashable, Sequence
@@ -251,18 +252,41 @@ def __init__(self, node_map=None):
251252
# post-init checks
252253
self._check_node_labels_unique_in_paths(self.node_map)
253254

254-
# This is arguably over-specific. Otherwise equivalent trees will currently
255-
# fail this check if their IDs do not match. One way to resolve this would be
256-
# to re-ID all of the nodes with a pre-order traversal. This is not a priority.
257-
# def __eq__(self, other):
258-
# return type(other) is type(self) and other._hash_key == self._hash_key
259-
#
260-
# def __hash__(self):
261-
# return hash(self._hash_key)
262-
#
263-
# @cached_property
264-
# def _hash_key(self):
265-
# return (self.node_map,)
255+
def __eq__(self, other):
256+
return type(other) is type(self) and other._hash_key == self._hash_key
257+
258+
def __hash__(self):
259+
return hash(self._hash_key)
260+
261+
@cached_property
262+
def _hash_key(self):
263+
return (self._hash_node_map,)
264+
265+
@cached_property
266+
def _hash_node_map(self):
267+
if self.is_empty:
268+
return pmap()
269+
270+
counter = itertools.count()
271+
return self._collect_hash_node_map(None, None, counter)
272+
273+
def _collect_hash_node_map(self, old_parent_id, new_parent_id, counter):
274+
if old_parent_id not in self.node_map:
275+
return pmap()
276+
277+
nodes = []
278+
node_map = {}
279+
for old_node in self.node_map[old_parent_id]:
280+
if old_node is not None:
281+
new_node = old_node.copy(id=f"id_{next(counter)}")
282+
node_map.update(self._collect_hash_node_map(old_node.id, new_node.id, counter))
283+
else:
284+
new_node = None
285+
286+
nodes.append(new_node)
287+
288+
node_map[new_parent_id] = freeze(nodes)
289+
return freeze(node_map)
266290

267291
@classmethod
268292
def _check_node_labels_unique_in_paths(

0 commit comments

Comments
 (0)