Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Fix ghost points #26

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pyop3/axtree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,8 @@ def ghost_count_per_component(self):
def owned(self):
return self._tree.owned.root

def index(self):
return self._tree.index()
def index(self, *, include_ghost_points=False):
return self._tree.index(include_ghost_points=include_ghost_points)

def iter(self, *, include_ghost_points=False):
return self._tree.iter(include_ghost_points=include_ghost_points)
Expand Down Expand Up @@ -709,10 +709,10 @@ def outer_loops(self):
def subst_layouts(self):
pass

def index(self, ghost=False):
def index(self, *, include_ghost_points=False):
from pyop3.itree.tree import ContextFreeLoopIndex, LoopIndex

iterset = self if ghost else self.owned
iterset = self if include_ghost_points else self.owned
# If the iterset is linear (single-component for every axis) then we
# can consider the loop to be "context-free".
if len(iterset.leaves) == 1:
Expand Down Expand Up @@ -1294,7 +1294,7 @@ def tabulated_offsets(self):
rmap_axes = iterset.add_subtree(self, *iterset.leaf)
rmap = HierarchicalArray(rmap_axes, dtype=IntType)
rmap = rmap[loop_index.local_index]
for idx in loop_index.iter():
for idx in loop_index.iter(include_ghost_points=True):
target_indices = idx.replace_map
# for p in self.iter(idxs):
for p in self.iter([idx], include_ghost_points=True): # seems to fix thing
Expand Down
5 changes: 3 additions & 2 deletions pyop3/itree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,10 +429,11 @@ def layout_exprs(self):
def datamap(self):
return self.iterset.datamap

def iter(self, stuff=pmap()):
def iter(self, stuff=pmap(), *, include_ghost_points=False):
iterset = self.iterset if include_ghost_points else self.iterset.owned
return iter_axis_tree(
self,
self.iterset,
iterset,
self.iterset.target_paths,
self.iterset.index_exprs,
stuff,
Expand Down
Loading