Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 26c301c

Browse files
committed
Add (broken) optimising code
Currently disabled. Needs a better language for this to work. I am still merging it because some of the infrastructure that I have added is sound.
1 parent 016095c commit 26c301c

File tree

5 files changed

+161
-22
lines changed

5 files changed

+161
-22
lines changed

pyop3/array/harray.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class IncompatibleShapeError(Exception):
5151
"""TODO, also bad name"""
5252

5353

54+
# TODO: not sure this is needed, can a Dat just be one of these?
5455
class ArrayVar(pym.primitives.AlgebraicLeaf):
5556
mapper_method = sys.intern("map_array")
5657

@@ -77,11 +78,11 @@ def __getinitargs__(self):
7778
# This was adapted from pymbolic's map_subscript
7879
def stringify_array(self, array, enclosing_prec, *args, **kwargs):
7980
index_str = self.join_rec(
80-
", ", array.index_exprs.values(), PREC_NONE, *args, **kwargs
81+
", ", array.indices.values(), PREC_NONE, *args, **kwargs
8182
)
8283

8384
return self.parenthesize_if_needed(
84-
self.format("%s[%s]", array.name, index_str), enclosing_prec, PREC_CALL
85+
self.format("%s[%s]", array.array.name, index_str), enclosing_prec, PREC_CALL
8586
)
8687

8788

@@ -398,17 +399,7 @@ def assemble(self, update_leaves=False):
398399

399400
def materialize(self) -> HierarchicalArray:
400401
"""Return a new "unindexed" array with the same shape."""
401-
# "unindexed" axis tree
402-
# strip parallel semantics (in a bad way)
403-
parent_to_children = collections.defaultdict(list)
404-
for p, cs in self.axes.parent_to_children.items():
405-
for c in cs:
406-
if c is not None and c.sf is not None:
407-
c = c.copy(sf=None)
408-
parent_to_children[p].append(c)
409-
410-
axes = AxisTree(parent_to_children)
411-
return type(self)(axes, dtype=self.dtype)
402+
return type(self)(self.axes.materialize(), dtype=self.dtype)
412403

413404
def iter_indices(self, outer_map):
414405
from pyop3.itree.tree import iter_axis_tree

pyop3/axtree/layout.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def size_requires_external_index(axes, axis, component, inner_loop_vars, path=pm
120120
leafpath = pmap()
121121
else:
122122
leafpath = just_one(count.axes.leaf_paths)
123-
layout = count.axes.subst_layouts[leafpath]
123+
layout = count.axes._subst_layouts_default[leafpath]
124124
required_loop_vars = LoopIndexCollector(linear=False)(layout)
125125
if not required_loop_vars.issubset(inner_loop_vars):
126126
return True

pyop3/axtree/tree.py

+148-5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pymbolic as pym
2121
import pyrsistent
2222
import pytools
23+
from cachetools import cachedmethod
2324
from mpi4py import MPI
2425
from petsc4py import PETSc
2526
from pyrsistent import freeze, pmap, thaw
@@ -191,7 +192,7 @@ def map_array(self, array_var):
191192
# layout_subst = array.axes.subst_layouts[array_var.path]
192193

193194
path, = array.axes.leaf_paths
194-
layout_subst = array.axes.subst_layouts[path]
195+
layout_subst = array.axes.subst_layouts()[path]
195196

196197
# offset = ExpressionEvaluator(indices, self._loop_exprs)(layout_subst)
197198
# offset = ExpressionEvaluator(self.context | indices, self._loop_exprs)(layout_subst)
@@ -210,6 +211,95 @@ def map_loop_index(self, expr):
210211
return self._loop_exprs[expr.id][expr.axis]
211212

212213

214+
class ExpressionFlatteningCollector(pym.mapper.Mapper):
215+
def map_array(self, expr):
216+
needs_flattening = False
217+
for index_expr in expr.indices.values():
218+
subexpr, _ = self.rec(index_expr)
219+
needs_flattening = needs_flattening or subexpr is not None
220+
return (expr, needs_flattening)
221+
222+
def map_axis_variable(self, var):
223+
return (None, False)
224+
225+
map_constant = map_axis_variable
226+
map_loop_index = map_axis_variable
227+
228+
def map_sum(self, expr):
229+
replace_expr = None
230+
needs_flattening = False
231+
for child in expr.children:
232+
subexpr, needs_flattening_ = self.rec(child)
233+
if subexpr is not None:
234+
if replace_expr is None:
235+
replace_expr = subexpr
236+
needs_flattening = needs_flattening_
237+
else:
238+
replace_expr = expr
239+
needs_flattening = needs_flattening or needs_flattening_
240+
241+
return (replace_expr, needs_flattening)
242+
243+
map_product = map_sum
244+
245+
246+
# TODO: This is not the right way to do this - pymbolic is not an adequate
247+
# symbolic language for pyop3.
248+
def eval_expr(expr):
249+
"""Convert an array expression into an array."""
250+
from pyop3 import HierarchicalArray
251+
252+
axes_iter, loop_index = axes_from_expr(expr)
253+
axes = AxisTree.from_iterable(axes_iter)
254+
255+
result = HierarchicalArray(axes, dtype=IntType)
256+
for ploop in loop_index.iter():
257+
for p in axes.iter({ploop}):
258+
evaluator = ExpressionEvaluator(p.source_exprs, loop_exprs=ploop.replace_map)
259+
num = evaluator(expr)
260+
breakpoint()
261+
result.set_value(p.source_exprs, num)
262+
breakpoint()
263+
return result
264+
265+
266+
# NOTE: This is a horrendous hack to rebuild structure from expressions. The
267+
# right way to do this is to have a pyop3 symbolic language where constructs
268+
# like Sum carries information about things like shape and dtype.
269+
class AxisBuilder(pym.mapper.Mapper):
270+
def map_constant(self, expr):
271+
return None, None
272+
273+
def map_array(self, expr):
274+
if len(expr.indices) == 1:
275+
return self.rec(just_one(expr.indices.values()))
276+
else:
277+
# For now limit ourselves to these cases - ultimately this should
278+
# all go.
279+
assert len(expr.indices) == 2
280+
281+
shape = expr.array.axes.leaf_component.count
282+
subresult = self.rec(just_one([i for i in expr.indices.values() if not isinstance(i, AxisVariable)]))
283+
return (subresult[0] + (shape,), subresult[1])
284+
285+
def map_loop_index(self, expr):
286+
assert expr.index.iterset.depth == 1 # for now
287+
return ((expr.index.iterset.materialize().root,), expr.index)
288+
289+
290+
def axes_from_expr(expr):
291+
return AxisBuilder()(expr)
292+
293+
294+
# NOTE: I have identical classes all over the place for this
295+
class ExpressionReplacer(pym.mapper.IdentityMapper):
296+
def __init__(self, replace_map):
297+
self._replace_map = replace_map
298+
299+
def map_variable(self, var):
300+
return self._replace_map.get(var, var)
301+
302+
213303
# This can just be replaced by component.datamap
214304
def _collect_datamap(axis, *subdatamaps, axes):
215305
datamap = {}
@@ -718,9 +808,31 @@ def outer_loops(self):
718808

719809
@property
720810
@abc.abstractmethod
721-
def subst_layouts(self):
811+
def _subst_layouts_default(self):
722812
pass
723813

814+
# NOTE: Shouldn't be a boolean here as there are different optimisation options.
815+
# In particular we can choose to compress multiple maps either only with non-increasing
816+
# arity (arity * 1), or not (which leads to a larger array: arity * arity).
817+
@cachedmethod(cache=lambda self: self._cache)
818+
def subst_layouts(self, optimize=False):
819+
if optimize:
820+
layouts_opt = {}
821+
collector = ExpressionFlatteningCollector()
822+
for key, layout in self._subst_layouts_default.items():
823+
replace_expr, needs_flattening = collector(layout)
824+
if needs_flattening:
825+
target_expr = eval_expr(replace_expr)
826+
replace_map = {replace_expr: target_expr}
827+
breakpoint()
828+
layout_opt = ExpressionReplacer(replace_map)(layout)
829+
else:
830+
layout_opt = layout
831+
layouts_opt[key] = layout_opt
832+
return freeze(layouts_opt)
833+
else:
834+
return self._subst_layouts_default
835+
724836
def index(self, *, include_ghost_points=False):
725837
from pyop3.itree.tree import ContextFreeLoopIndex, LoopIndex
726838

@@ -777,11 +889,25 @@ def datamap(self):
777889
def as_tree(self):
778890
return self
779891

892+
@abc.abstractmethod
893+
def materialize(self):
894+
"""Return a new "unindexed" axis tree with the same shape."""
895+
# "unindexed" axis tree
896+
# strip parallel semantics (in a bad way)
897+
parent_to_children = collections.defaultdict(list)
898+
for p, cs in self.axes.parent_to_children.items():
899+
for c in cs:
900+
if c is not None and c.sf is not None:
901+
c = c.copy(sf=None)
902+
parent_to_children[p].append(c)
903+
904+
axes = AxisTree(parent_to_children)
905+
780906
def offset(self, indices, path=None, *, loop_exprs=pmap()):
781907
from pyop3.axtree.layout import eval_offset
782908
return eval_offset(
783909
self,
784-
self.subst_layouts,
910+
self.subst_layouts(),
785911
indices,
786912
path,
787913
loop_exprs=loop_exprs,
@@ -1052,6 +1178,9 @@ def datamap(self):
10521178
dmap = postvisit(self, _collect_datamap, axes=self)
10531179
return freeze(dmap)
10541180

1181+
def materialize(self):
1182+
return self
1183+
10551184
def add_axis(self, axis, parent_axis, parent_component=None, *, uniquify=False):
10561185
parent_axis = self._as_node(parent_axis)
10571186
if parent_component is not None:
@@ -1134,7 +1263,7 @@ def layouts(self):
11341263
return freeze(layouts_)
11351264

11361265
@property
1137-
def subst_layouts(self):
1266+
def _subst_layouts_default(self):
11381267
return self.layouts
11391268

11401269
@cached_property
@@ -1261,8 +1390,22 @@ def outer_loop_bits(self):
12611390

12621391
return loop_axes, freeze(loop_vars)
12631392

1393+
def materialize(self):
1394+
"""Return a new "unindexed" axis tree with the same shape."""
1395+
# "unindexed" axis tree
1396+
# strip parallel semantics (in a bad way)
1397+
parent_to_children = collections.defaultdict(list)
1398+
for p, cs in self.node_map.items():
1399+
for c in cs:
1400+
if c is not None and c.sf is not None:
1401+
c = c.copy(sf=None)
1402+
parent_to_children[p].append(c)
1403+
1404+
return AxisTree(parent_to_children)
1405+
1406+
12641407
@cached_property
1265-
def subst_layouts(self):
1408+
def _subst_layouts_default(self):
12661409
return subst_layouts(self, self.target_paths, self.index_exprs, self.layouts)
12671410

12681411
@property

pyop3/ir/lower.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -821,10 +821,10 @@ def _(assignment, loop_indices, codegen_context):
821821
else:
822822
csize_var = csize
823823

824-
rlayouts = rmap.axes.subst_layouts[pmap()]
824+
rlayouts = rmap.axes.subst_layouts()[pmap()]
825825
roffset = JnameSubstitutor(loop_indices, codegen_context)(rlayouts)
826826

827-
clayouts = cmap.axes.subst_layouts[pmap()]
827+
clayouts = cmap.axes.subst_layouts()[pmap()]
828828
coffset = JnameSubstitutor(loop_indices, codegen_context)(clayouts)
829829

830830
irow = f"{rmap_name}[{roffset}]"
@@ -970,8 +970,12 @@ def add_leaf_assignment(
970970

971971

972972
def make_array_expr(array, path, inames, ctx):
973+
# TODO: This should be propagated as an option - we don't always want to optimise
974+
# TODO: Disabled optimising for now since I can't get it to work without a
975+
# symbolic language. That has to be future work.
973976
array_offset = make_offset_expr(
974-
array.axes.subst_layouts[path],
977+
# array.axes.subst_layouts(optimize=True)[path],
978+
array.axes.subst_layouts(optimize=False)[path],
975979
inames,
976980
ctx,
977981
)

pyop3/tree.py

+1
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def component_label(self):
248248
class LabelledTree(AbstractTree):
249249
def __init__(self, node_map=None):
250250
super().__init__(node_map=node_map)
251+
self._cache = {}
251252

252253
# post-init checks
253254
self._check_node_labels_unique_in_paths(self.node_map)

0 commit comments

Comments
 (0)