Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 9436215

Browse files
committed
Tests passing, about to try something new on another branch
1 parent ae0d3b6 commit 9436215

File tree

5 files changed

+30
-6
lines changed

5 files changed

+30
-6
lines changed

pyop3/array/harray.py

+4
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def __init__(
123123
):
124124
super().__init__(name=name, prefix=prefix)
125125

126+
# debug
127+
# if self.name == "t_0":
128+
# breakpoint()
129+
126130
axes = as_axis_tree(axes)
127131

128132
if isinstance(data, Buffer):

pyop3/axtree/layout.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def collect_externally_indexed_axes(axes, axis=None, component=None, path=pmap()
160160
}
161161
)
162162
else:
163+
path_ = path | {axis.label: component.label}
163164
csize = component.count
164165
if isinstance(csize, HierarchicalArray):
165166
if csize.axes.is_empty:
@@ -168,14 +169,14 @@ def collect_externally_indexed_axes(axes, axis=None, component=None, path=pmap()
168169
# is the path sufficient? i.e. do we have enough externally provided indices
169170
# to correctly index the axis?
170171
for caxis, ccpt in csize.axes.path_with_nodes(*csize.axes.leaf).items():
171-
if caxis.label in path:
172-
assert path[caxis.label] == ccpt, "Paths do not match"
172+
if caxis.label in path_:
173+
assert path_[caxis.label] == ccpt, "Paths do not match"
173174
else:
175+
# also return an expr?
174176
external_axes[caxis.label] = caxis
175177
else:
176178
assert isinstance(csize, numbers.Integral)
177179
if subaxis := axes.child(axis, component):
178-
path_ = path | {axis.label: component.label}
179180
for subcpt in subaxis.components:
180181
external_axes.update(
181182
{

pyop3/ir/lower.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -828,10 +828,23 @@ def parse_assignment_properly_this_time(
828828
(axis.id, component.label), pmap()
829829
)
830830

831+
# TODO move to register_extent
832+
if isinstance(component.count, HierarchicalArray):
833+
count_axes = component.count.axes
834+
count_exprs = {}
835+
for count_axis, count_cpt in count_axes.path_with_nodes(
836+
*count_axes.leaf
837+
).items():
838+
count_exprs.update(
839+
component.count.index_exprs.get((count_axis.id, count_cpt), {})
840+
)
841+
else:
842+
count_exprs = {}
843+
831844
extent_var = register_extent(
832845
component.count,
833-
index_exprs[assignment.assignee] | loop_indices | domain_index_exprs,
834-
iname_replace_map,
846+
index_exprs[assignment.assignee] | count_exprs | domain_index_exprs,
847+
iname_replace_map | loop_indices,
835848
codegen_context,
836849
)
837850
codegen_context.add_domain(iname, extent_var)

pyop3/lang.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -200,14 +200,19 @@ def kernel_arguments(self):
200200
@cached_property
201201
def _distarray_args(self):
202202
from pyop3.array import ContextSensitiveMultiArray, HierarchicalArray
203+
from pyop3.buffer import DistributedBuffer
203204

204205
arrays = {}
205206
for arg, intent in self.kernel_arguments:
206207
if isinstance(arg, ContextSensitiveMultiArray):
207208
# take first
208209
arg, *_ = arg.context_map.values()
209210

210-
if not isinstance(arg, HierarchicalArray) or not arg.buffer.is_distributed:
211+
if (
212+
not isinstance(arg, HierarchicalArray)
213+
or not isinstance(arg.buffer, DistributedBuffer)
214+
or arg.buffer.sf is None
215+
):
211216
continue
212217

213218
if arg.array not in arrays:

pyop3/transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def _(self, assignment: Assignment):
175175
data=NullBuffer(arg.dtype), # does this need a size?
176176
name=self._name_generator("t"),
177177
)
178+
breakpoint()
178179

179180
if intent == READ:
180181
gathers.append(PetscMatLoad(arg, new_arg))

0 commit comments

Comments
 (0)