Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 4261d7f

Browse files
committed
Fix CalledMap.index() method
Allows for slope limiter code in thesis.
1 parent 8338278 commit 4261d7f

File tree

3 files changed

+74
-46
lines changed

3 files changed

+74
-46
lines changed

pyop3/axtree/tree.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -766,9 +766,12 @@ def datamap(self):
766766
for expr in index_exprs.values():
767767
for array in MultiArrayCollector()(expr):
768768
dmap.update(array.datamap)
769-
for layout_expr in self.layouts.values():
770-
for array in MultiArrayCollector()(layout_expr):
771-
dmap.update(array.datamap)
769+
770+
# TODO: cleanup, indexed axis trees (from map.index()) do not have layouts
771+
if not isinstance(self, IndexedAxisTree) or self.unindexed is not None:
772+
for layout_expr in self.layouts.values():
773+
for array in MultiArrayCollector()(layout_expr):
774+
dmap.update(array.datamap)
772775
return pmap(dmap)
773776

774777
def as_tree(self):

pyop3/ir/lower.py

+1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def add_assignment(self, assignee, expression, prefix="insn"):
179179
id=self._name_generator(prefix),
180180
within_inames=frozenset(self._within_inames),
181181
depends_on=self._depends_on,
182+
depends_on_is_final=True,
182183
)
183184
self._add_instruction(insn)
184185

pyop3/itree/tree.py

+67-43
Original file line numberDiff line numberDiff line change
@@ -702,23 +702,23 @@ def connectivity(self):
702702
class ContextFreeCalledMap(Index):
703703
# FIXME this is clumsy
704704
# fields = Index.fields | {"map", "index", "leaf_target_paths"} - {"label", "component_labels"}
705-
fields = {"map", "index", "leaf_target_paths", "label", "id"}
705+
fields = {"map", "from_index", "leaf_target_paths", "label", "id"}
706706

707-
def __init__(self, map, index, leaf_target_paths, *, id=None, label=None):
707+
def __init__(self, map, from_index, leaf_target_paths, *, id=None, label=None):
708708
super().__init__(id=id, label=label)
709709
self.map = map
710710
# better to call it "input_index"?
711-
self.index = index
711+
# self.index = index # clash with method!
712712
self._leaf_target_paths = leaf_target_paths
713713

714714
# alias for compat with ContextFreeCalledMap
715-
self.from_index = index
715+
self.from_index = from_index
716716

717717
# TODO cleanup
718718
def with_context(self, context, axes=None):
719719
# maybe this line isn't needed?
720720
# cf_index = self.from_index.with_context(context, axes)
721-
cf_index = self.index
721+
cf_index = self.from_index
722722
leaf_target_paths = tuple(
723723
freeze({mcpt.target_axis: mcpt.target_component})
724724
for path in cf_index.leaf_target_paths
@@ -737,6 +737,23 @@ def with_context(self, context, axes=None):
737737
def name(self) -> str:
738738
return self.map.name
739739

740+
def index(self) -> LoopIndex | ContextFreeLoopIndex:
741+
index_forest = as_index_forest(self)
742+
assert index_forest.keys() == {pmap()}
743+
index_tree = index_forest[pmap()]
744+
iterset = index_axes(index_tree, pmap())
745+
746+
# The loop index from a context-free map can be context-sensitive if it
747+
# has multiple components.
748+
if len(iterset.leaves) == 1:
749+
path = iterset.path(*iterset.leaf)
750+
target_path = {}
751+
for ax, cpt in iterset.path_with_nodes(*iterset.leaf).items():
752+
target_path.update(iterset.target_paths.get((ax.id, cpt), {}))
753+
return ContextFreeLoopIndex(iterset, path, target_path)
754+
else:
755+
return LoopIndex(iterset)
756+
740757
# is this ever used?
741758
# @property
742759
# def components(self):
@@ -1450,7 +1467,7 @@ def _(
14501467
outer_loops,
14511468
prior_extra_index_exprs,
14521469
) = collect_shape_index_callback(
1453-
called_map.index,
1470+
called_map.from_index,
14541471
indices,
14551472
prev_axes=prev_axes,
14561473
**kwargs,
@@ -1553,7 +1570,7 @@ def _make_leaf_axis_from_called_map(
15531570

15541571
all_skipped = False
15551572
if isinstance(map_cpt.arity, HierarchicalArray):
1556-
arity = map_cpt.arity[called_map.index]
1573+
arity = map_cpt.arity[called_map.from_index]
15571574
else:
15581575
arity = map_cpt.arity
15591576
cpt = AxisComponent(arity, label=map_cpt.label)
@@ -1658,7 +1675,8 @@ def index_axes(
16581675
prev_axes=axes,
16591676
)
16601677

1661-
outer_loops += axes.outer_loops
1678+
if axes is not None:
1679+
outer_loops += axes.outer_loops
16621680

16631681
# drop duplicates
16641682
outer_loops_ = []
@@ -1686,7 +1704,7 @@ def index_axes(
16861704

16871705
return IndexedAxisTree(
16881706
indexed_axes.node_map,
1689-
axes.unindexed,
1707+
axes.unindexed if axes else None,
16901708
target_paths=mytpaths,
16911709
index_exprs=myindex_expr_per_target,
16921710
# layout_exprs=mylayout_expr_per_target,
@@ -2260,47 +2278,52 @@ def partition_iterset(index: LoopIndex, arrays):
22602278
# for p in index.iterset.iter():
22612279
# # hack because I wrote bad code and mix up loop indices and itersets
22622280
# p = dataclasses.replace(p, index=index)
2263-
for p in index.iter():
2264-
parindex = p.source_exprs[paraxis.label]
2265-
assert isinstance(parindex, numbers.Integral)
2266-
2267-
for array in arrays:
2268-
# same nasty hack
2269-
if isinstance(array, (Mat, Sparsity)) or not hasattr(array, "buffer"):
2270-
continue
2271-
# skip purely local arrays
2272-
if not array.buffer.is_distributed:
2273-
continue
2274-
if labels[parindex] == IterationPointType.LEAF:
2275-
continue
2276-
2277-
# loop over stencil
2278-
array = array.with_context({index.id: (p.source_path, p.target_path)})
2279-
2280-
for q in array.axes.iter({p}):
2281-
# offset = array.axes.offset(q.target_exprs, q.target_path)
2282-
offset = array.axes.offset(q.source_exprs, q.source_path, loop_exprs=p.replace_map)
2283-
2284-
point_label = is_root_or_leaf_per_array[array.name][offset]
2285-
if point_label == ArrayPointLabel.LEAF:
2286-
labels[parindex] = IterationPointType.LEAF
2287-
break # no point doing more analysis
2288-
elif point_label == ArrayPointLabel.ROOT:
2289-
assert labels[parindex] != IterationPointType.LEAF
2290-
labels[parindex] = IterationPointType.ROOT
2291-
else:
2292-
assert point_label == ArrayPointLabel.CORE
2293-
pass
2281+
# for p in index.iter():
2282+
# parindex = p.source_exprs[paraxis.label]
2283+
# assert isinstance(parindex, numbers.Integral)
2284+
#
2285+
# for array in arrays:
2286+
# # same nasty hack
2287+
# if isinstance(array, (Mat, Sparsity)) or not hasattr(array, "buffer"):
2288+
# continue
2289+
# # skip purely local arrays
2290+
# if not array.buffer.is_distributed:
2291+
# continue
2292+
# if labels[parindex] == IterationPointType.LEAF:
2293+
# continue
2294+
#
2295+
# # loop over stencil
2296+
# array = array.with_context({index.id: (p.source_path, p.target_path)})
2297+
#
2298+
# for q in array.axes.iter({p}):
2299+
# # offset = array.axes.offset(q.target_exprs, q.target_path)
2300+
# offset = array.axes.offset(q.source_exprs, q.source_path, loop_exprs=p.replace_map)
2301+
#
2302+
# point_label = is_root_or_leaf_per_array[array.name][offset]
2303+
# if point_label == ArrayPointLabel.LEAF:
2304+
# labels[parindex] = IterationPointType.LEAF
2305+
# break # no point doing more analysis
2306+
# elif point_label == ArrayPointLabel.ROOT:
2307+
# assert labels[parindex] != IterationPointType.LEAF
2308+
# labels[parindex] = IterationPointType.ROOT
2309+
# else:
2310+
# assert point_label == ArrayPointLabel.CORE
2311+
# pass
22942312

22952313
parcpt = just_one(paraxis.components) # for now
22962314

22972315
# I don't think this is working - instead everything touches a leaf
22982316
# core = just_one(np.nonzero(labels == IterationPointType.CORE))
22992317
# root = just_one(np.nonzero(labels == IterationPointType.ROOT))
23002318
# leaf = just_one(np.nonzero(labels == IterationPointType.LEAF))
2301-
core = np.asarray([], dtype=IntType)
2302-
root = np.asarray([], dtype=IntType)
2303-
leaf = np.arange(paraxis.size, dtype=IntType)
2319+
# core = np.asarray([], dtype=IntType)
2320+
# root = np.asarray([], dtype=IntType)
2321+
# leaf = np.arange(paraxis.size, dtype=IntType)
2322+
2323+
# hack to check things
2324+
core = np.asarray([0], dtype=IntType)
2325+
root = np.asarray([1], dtype=IntType)
2326+
leaf = np.arange(2, paraxis.size, dtype=IntType)
23042327

23052328
subsets = []
23062329
for data in [core, root, leaf]:
@@ -2313,6 +2336,7 @@ def partition_iterset(index: LoopIndex, arrays):
23132336
)
23142337
subsets.append(subset)
23152338
subsets = tuple(subsets)
2339+
return "not used", subsets
23162340

23172341
# make a new iteration set over just these indices
23182342
# index with just core (arbitrary)

0 commit comments

Comments
 (0)