Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 15e870d

Browse files
authored
Fix ghost points (#26)
1 parent e69d946 commit 15e870d

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

pyop3/axtree/tree.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,8 @@ def ghost_count_per_component(self):
436436
def owned(self):
437437
return self._tree.owned.root
438438

439-
def index(self):
440-
return self._tree.index()
439+
def index(self, *, include_ghost_points=False):
440+
return self._tree.index(include_ghost_points=include_ghost_points)
441441

442442
def iter(self, *, include_ghost_points=False):
443443
return self._tree.iter(include_ghost_points=include_ghost_points)
@@ -709,10 +709,10 @@ def outer_loops(self):
709709
def subst_layouts(self):
710710
pass
711711

712-
def index(self, ghost=False):
712+
def index(self, *, include_ghost_points=False):
713713
from pyop3.itree.tree import ContextFreeLoopIndex, LoopIndex
714714

715-
iterset = self if ghost else self.owned
715+
iterset = self if include_ghost_points else self.owned
716716
# If the iterset is linear (single-component for every axis) then we
717717
# can consider the loop to be "context-free".
718718
if len(iterset.leaves) == 1:
@@ -1294,7 +1294,7 @@ def tabulated_offsets(self):
12941294
rmap_axes = iterset.add_subtree(self, *iterset.leaf)
12951295
rmap = HierarchicalArray(rmap_axes, dtype=IntType)
12961296
rmap = rmap[loop_index.local_index]
1297-
for idx in loop_index.iter():
1297+
for idx in loop_index.iter(include_ghost_points=True):
12981298
target_indices = idx.replace_map
12991299
# for p in self.iter(idxs):
13001300
for p in self.iter([idx], include_ghost_points=True): # seems to fix thing

pyop3/itree/tree.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -429,10 +429,11 @@ def layout_exprs(self):
429429
def datamap(self):
430430
return self.iterset.datamap
431431

432-
def iter(self, stuff=pmap()):
432+
def iter(self, stuff=pmap(), *, include_ghost_points=False):
433+
iterset = self.iterset if include_ghost_points else self.iterset.owned
433434
return iter_axis_tree(
434435
self,
435-
self.iterset,
436+
iterset,
436437
self.iterset.target_paths,
437438
self.iterset.index_exprs,
438439
stuff,

0 commit comments

Comments
 (0)