20
20
import pymbolic as pym
21
21
import pyrsistent
22
22
import pytools
23
+ from cachetools import cachedmethod
23
24
from mpi4py import MPI
24
25
from petsc4py import PETSc
25
26
from pyrsistent import freeze , pmap , thaw
@@ -191,7 +192,7 @@ def map_array(self, array_var):
191
192
# layout_subst = array.axes.subst_layouts[array_var.path]
192
193
193
194
path , = array .axes .leaf_paths
194
- layout_subst = array .axes .subst_layouts [path ]
195
+ layout_subst = array .axes .subst_layouts () [path ]
195
196
196
197
# offset = ExpressionEvaluator(indices, self._loop_exprs)(layout_subst)
197
198
# offset = ExpressionEvaluator(self.context | indices, self._loop_exprs)(layout_subst)
@@ -210,6 +211,95 @@ def map_loop_index(self, expr):
210
211
return self ._loop_exprs [expr .id ][expr .axis ]
211
212
212
213
214
class ExpressionFlatteningCollector(pym.mapper.Mapper):
    """Mapper that looks for array sub-expressions inside an expression.

    Every ``map_*`` method returns a pair ``(candidate, needs_flattening)``.
    ``candidate`` is the sub-expression that was found, or ``None`` when the
    visited node contains none. ``needs_flattening`` is a boolean that is
    ``True`` once an array is found whose own index expressions yield a
    candidate.
    """

    def map_array(self, expr):
        """Return ``(expr, flag)``; ``flag`` is set if any index has a candidate."""
        # Visit every index expression (no short-circuiting) and keep only the
        # candidate part of each result.
        candidates = [self.rec(index)[0] for index in expr.indices.values()]
        nested = any(candidate is not None for candidate in candidates)
        return (expr, nested)

    def map_axis_variable(self, var):
        """Leaf nodes contribute no candidate and never need flattening."""
        return (None, False)

    map_constant = map_axis_variable
    map_loop_index = map_axis_variable

    def map_sum(self, expr):
        """Combine the results of the children of a sum (or product)."""
        candidate = None
        flatten = False
        for child in expr.children:
            found, child_flatten = self.rec(child)
            if found is None:
                continue
            if candidate is None:
                # First child with a candidate: adopt its result as-is.
                candidate, flatten = found, child_flatten
            else:
                # More than one child has a candidate, so the whole
                # expression becomes the candidate.
                candidate = expr
                flatten = flatten or child_flatten

        return (candidate, flatten)

    map_product = map_sum
244
+
245
+
246
+ # TODO: This is not the right way to do this - pymbolic is not an adequate
247
+ # symbolic language for pyop3.
248
def eval_expr(expr):
    """Convert an array expression into an array.

    The axes and the loop index are recovered from ``expr`` via
    ``axes_from_expr``. A ``HierarchicalArray`` with dtype ``IntType`` is
    created over those axes, and ``expr`` is evaluated once per point of the
    axes for every iteration of the loop index.

    Parameters
    ----------
    expr
        The expression to evaluate.

    Returns
    -------
    HierarchicalArray
        The array holding the evaluated values.
    """
    from pyop3 import HierarchicalArray

    axes_iter, loop_index = axes_from_expr(expr)
    axes = AxisTree.from_iterable(axes_iter)

    result = HierarchicalArray(axes, dtype=IntType)
    for ploop in loop_index.iter():
        for p in axes.iter({ploop}):
            evaluator = ExpressionEvaluator(p.source_exprs, loop_exprs=ploop.replace_map)
            # NOTE: the leftover ``breakpoint()`` calls were removed here; they
            # dropped into the debugger on every iteration.
            result.set_value(p.source_exprs, evaluator(expr))
    return result
264
+
265
+
266
+ # NOTE: This is a horrendous hack to rebuild structure from expressions. The
267
+ # right way to do this is to have a pyop3 symbolic language where constructs
268
+ # like Sum carry information about things like shape and dtype.
269
class AxisBuilder(pym.mapper.Mapper):
    """Mapper that rebuilds axis information from an expression.

    Each ``map_*`` method returns a pair ``(axes, loop_index)``. Constants
    return ``(None, None)``.
    """

    def map_constant(self, expr):
        """Constants carry neither axes nor a loop index."""
        return None, None

    def map_array(self, expr):
        """Recurse into the index expressions of an array expression."""
        index_exprs = expr.indices.values()
        if len(expr.indices) == 1:
            return self.rec(just_one(index_exprs))

        # For now limit ourselves to these cases - ultimately this should
        # all go.
        assert len(expr.indices) == 2

        shape = expr.array.axes.leaf_component.count
        # Exactly one of the two indices is expected not to be an AxisVariable
        # (``just_one`` is given the filtered list).
        non_axis_vars = [i for i in index_exprs if not isinstance(i, AxisVariable)]
        inner = self.rec(just_one(non_axis_vars))
        return (inner[0] + (shape,), inner[1])

    def map_loop_index(self, expr):
        """Return the root of the materialized iterset and the loop index."""
        loop_index = expr.index
        assert loop_index.iterset.depth == 1  # for now
        root = loop_index.iterset.materialize().root
        return ((root,), loop_index)
288
+
289
+
290
def axes_from_expr(expr):
    """Apply a fresh ``AxisBuilder`` to ``expr`` and return its result."""
    builder = AxisBuilder()
    return builder(expr)
292
+
293
+
294
+ # NOTE: I have identical classes all over the place for this
295
class ExpressionReplacer(pym.mapper.IdentityMapper):
    """Identity mapper that substitutes variables found in a replacement map."""

    def __init__(self, replace_map):
        # Mapping from variable to the expression that should replace it.
        self._replace_map = replace_map

    def map_variable(self, var):
        """Return the replacement for ``var``, or ``var`` itself if unmapped."""
        substitutions = self._replace_map
        return substitutions.get(var, var)
301
+
302
+
213
303
# This can just be replaced by component.datamap
214
304
def _collect_datamap (axis , * subdatamaps , axes ):
215
305
datamap = {}
@@ -718,9 +808,31 @@ def outer_loops(self):
718
808
719
809
@property
@abc.abstractmethod
def _subst_layouts_default(self):
    """Return the unoptimised substituted layouts (subclasses implement this)."""
723
813
814
+ # NOTE: Shouldn't be a boolean here as there are different optimisation options.
815
+ # In particular we can choose to compress multiple maps either only with non-increasing
816
+ # arity (arity * 1), or not (which leads to a larger array: arity * arity).
817
@cachedmethod(cache=lambda self: self._cache)
def subst_layouts(self, optimize=False):
    """Return the substituted layouts, optionally with flattened expressions.

    Parameters
    ----------
    optimize : bool
        If ``False`` (the default), return ``self._subst_layouts_default``
        unchanged. If ``True``, each layout is inspected with
        ``ExpressionFlatteningCollector``. Where flattening is reported as
        needed, the collected sub-expression is evaluated into an array with
        ``eval_expr`` and substituted into the layout.

    Returns
    -------
    The default layouts, or a frozen mapping of the optimised layouts. The
    result is cached in ``self._cache`` via ``cachedmethod``.
    """
    if not optimize:
        return self._subst_layouts_default

    layouts_opt = {}
    collector = ExpressionFlatteningCollector()
    for key, layout in self._subst_layouts_default.items():
        replace_expr, needs_flattening = collector(layout)
        if needs_flattening:
            target_expr = eval_expr(replace_expr)
            # NOTE: a leftover ``breakpoint()`` was removed here; it stopped
            # every optimised call in the debugger.
            # NOTE(review): ExpressionReplacer only overrides ``map_variable``;
            # confirm that ``replace_expr`` is dispatched through that method.
            layout = ExpressionReplacer({replace_expr: target_expr})(layout)
        layouts_opt[key] = layout
    return freeze(layouts_opt)
835
+
724
836
def index (self , * , include_ghost_points = False ):
725
837
from pyop3 .itree .tree import ContextFreeLoopIndex , LoopIndex
726
838
@@ -777,11 +889,25 @@ def datamap(self):
777
889
def as_tree(self):
    """Return this object itself; it already is a tree."""
    tree = self
    return tree
779
891
892
@abc.abstractmethod
def materialize(self):
    """Return a new "unindexed" axis tree with the same shape.

    Subclasses must override this method. The body serves as a shared default
    reachable through ``super().materialize()``: it copies
    ``self.axes.parent_to_children``, replacing every child that has a
    non-``None`` ``sf`` by a copy with ``sf=None``.
    """
    # "unindexed" axis tree
    # strip parallel semantics (in a bad way)
    parent_to_children = collections.defaultdict(list)
    for p, cs in self.axes.parent_to_children.items():
        for c in cs:
            if c is not None and c.sf is not None:
                c = c.copy(sf=None)
            parent_to_children[p].append(c)

    # The tree was previously built into a local and dropped, so the method
    # returned None despite its documented contract.
    return AxisTree(parent_to_children)
905
+
780
906
def offset (self , indices , path = None , * , loop_exprs = pmap ()):
781
907
from pyop3 .axtree .layout import eval_offset
782
908
return eval_offset (
783
909
self ,
784
- self .subst_layouts ,
910
+ self .subst_layouts () ,
785
911
indices ,
786
912
path ,
787
913
loop_exprs = loop_exprs ,
@@ -1052,6 +1178,9 @@ def datamap(self):
1052
1178
dmap = postvisit (self , _collect_datamap , axes = self )
1053
1179
return freeze (dmap )
1054
1180
1181
def materialize(self):
    """Return this tree itself; no new tree is built."""
    tree = self
    return tree
1183
+
1055
1184
def add_axis (self , axis , parent_axis , parent_component = None , * , uniquify = False ):
1056
1185
parent_axis = self ._as_node (parent_axis )
1057
1186
if parent_component is not None :
@@ -1134,7 +1263,7 @@ def layouts(self):
1134
1263
return freeze (layouts_ )
1135
1264
1136
1265
@property
def _subst_layouts_default(self):
    """Return ``self.layouts`` as the default substituted layouts."""
    layouts = self.layouts
    return layouts
1139
1268
1140
1269
@cached_property
@@ -1261,8 +1390,22 @@ def outer_loop_bits(self):
1261
1390
1262
1391
return loop_axes , freeze (loop_vars )
1263
1392
1393
def materialize(self):
    """Return a new "unindexed" axis tree with the same shape.

    The node map is copied parent by parent; any child with a non-``None``
    ``sf`` is replaced by a copy created with ``sf=None``.
    """
    # strip parallel semantics (in a bad way)
    stripped = collections.defaultdict(list)
    for parent, children in self.node_map.items():
        for child in children:
            has_sf = child is not None and child.sf is not None
            stripped[parent].append(child.copy(sf=None) if has_sf else child)

    return AxisTree(stripped)
1405
+
1406
+
1264
1407
@cached_property
def _subst_layouts_default(self):
    """Compute (once) the layouts via the module-level ``subst_layouts``."""
    return subst_layouts(
        self,
        self.target_paths,
        self.index_exprs,
        self.layouts,
    )
1267
1410
1268
1411
@property
0 commit comments