@@ -702,23 +702,23 @@ def connectivity(self):
702
702
class ContextFreeCalledMap (Index ):
703
703
# FIXME this is clumsy
704
704
# fields = Index.fields | {"map", "index", "leaf_target_paths"} - {"label", "component_labels"}
705
- fields = {"map" , "index " , "leaf_target_paths" , "label" , "id" }
705
+ fields = {"map" , "from_index " , "leaf_target_paths" , "label" , "id" }
706
706
707
- def __init__ (self , map , index , leaf_target_paths , * , id = None , label = None ):
707
+ def __init__ (self , map , from_index , leaf_target_paths , * , id = None , label = None ):
708
708
super ().__init__ (id = id , label = label )
709
709
self .map = map
710
710
# better to call it "input_index"?
711
- self .index = index
711
+ # self.index = index # clash with method!
712
712
self ._leaf_target_paths = leaf_target_paths
713
713
714
714
# alias for compat with ContextFreeCalledMap
715
- self .from_index = index
715
+ self .from_index = from_index
716
716
717
717
# TODO cleanup
718
718
def with_context (self , context , axes = None ):
719
719
# maybe this line isn't needed?
720
720
# cf_index = self.from_index.with_context(context, axes)
721
- cf_index = self .index
721
+ cf_index = self .from_index
722
722
leaf_target_paths = tuple (
723
723
freeze ({mcpt .target_axis : mcpt .target_component })
724
724
for path in cf_index .leaf_target_paths
@@ -737,6 +737,23 @@ def with_context(self, context, axes=None):
737
737
def name (self ) -> str :
738
738
return self .map .name
739
739
740
+ def index (self ) -> LoopIndex | ContextFreeLoopIndex :
741
+ index_forest = as_index_forest (self )
742
+ assert index_forest .keys () == {pmap ()}
743
+ index_tree = index_forest [pmap ()]
744
+ iterset = index_axes (index_tree , pmap ())
745
+
746
+ # The loop index from a context-free map can be context-sensitive if it
747
+ # has multiple components.
748
+ if len (iterset .leaves ) == 1 :
749
+ path = iterset .path (* iterset .leaf )
750
+ target_path = {}
751
+ for ax , cpt in iterset .path_with_nodes (* iterset .leaf ).items ():
752
+ target_path .update (iterset .target_paths .get ((ax .id , cpt ), {}))
753
+ return ContextFreeLoopIndex (iterset , path , target_path )
754
+ else :
755
+ return LoopIndex (iterset )
756
+
740
757
# is this ever used?
741
758
# @property
742
759
# def components(self):
@@ -1450,7 +1467,7 @@ def _(
1450
1467
outer_loops ,
1451
1468
prior_extra_index_exprs ,
1452
1469
) = collect_shape_index_callback (
1453
- called_map .index ,
1470
+ called_map .from_index ,
1454
1471
indices ,
1455
1472
prev_axes = prev_axes ,
1456
1473
** kwargs ,
@@ -1553,7 +1570,7 @@ def _make_leaf_axis_from_called_map(
1553
1570
1554
1571
all_skipped = False
1555
1572
if isinstance (map_cpt .arity , HierarchicalArray ):
1556
- arity = map_cpt .arity [called_map .index ]
1573
+ arity = map_cpt .arity [called_map .from_index ]
1557
1574
else :
1558
1575
arity = map_cpt .arity
1559
1576
cpt = AxisComponent (arity , label = map_cpt .label )
@@ -1658,7 +1675,8 @@ def index_axes(
1658
1675
prev_axes = axes ,
1659
1676
)
1660
1677
1661
- outer_loops += axes .outer_loops
1678
+ if axes is not None :
1679
+ outer_loops += axes .outer_loops
1662
1680
1663
1681
# drop duplicates
1664
1682
outer_loops_ = []
@@ -1686,7 +1704,7 @@ def index_axes(
1686
1704
1687
1705
return IndexedAxisTree (
1688
1706
indexed_axes .node_map ,
1689
- axes .unindexed ,
1707
+ axes .unindexed if axes else None ,
1690
1708
target_paths = mytpaths ,
1691
1709
index_exprs = myindex_expr_per_target ,
1692
1710
# layout_exprs=mylayout_expr_per_target,
@@ -2260,47 +2278,52 @@ def partition_iterset(index: LoopIndex, arrays):
2260
2278
# for p in index.iterset.iter():
2261
2279
# # hack because I wrote bad code and mix up loop indices and itersets
2262
2280
# p = dataclasses.replace(p, index=index)
2263
- for p in index .iter ():
2264
- parindex = p .source_exprs [paraxis .label ]
2265
- assert isinstance (parindex , numbers .Integral )
2266
-
2267
- for array in arrays :
2268
- # same nasty hack
2269
- if isinstance (array , (Mat , Sparsity )) or not hasattr (array , "buffer" ):
2270
- continue
2271
- # skip purely local arrays
2272
- if not array .buffer .is_distributed :
2273
- continue
2274
- if labels [parindex ] == IterationPointType .LEAF :
2275
- continue
2276
-
2277
- # loop over stencil
2278
- array = array .with_context ({index .id : (p .source_path , p .target_path )})
2279
-
2280
- for q in array .axes .iter ({p }):
2281
- # offset = array.axes.offset(q.target_exprs, q.target_path)
2282
- offset = array .axes .offset (q .source_exprs , q .source_path , loop_exprs = p .replace_map )
2283
-
2284
- point_label = is_root_or_leaf_per_array [array .name ][offset ]
2285
- if point_label == ArrayPointLabel .LEAF :
2286
- labels [parindex ] = IterationPointType .LEAF
2287
- break # no point doing more analysis
2288
- elif point_label == ArrayPointLabel .ROOT :
2289
- assert labels [parindex ] != IterationPointType .LEAF
2290
- labels [parindex ] = IterationPointType .ROOT
2291
- else :
2292
- assert point_label == ArrayPointLabel .CORE
2293
- pass
2281
+ # for p in index.iter():
2282
+ # parindex = p.source_exprs[paraxis.label]
2283
+ # assert isinstance(parindex, numbers.Integral)
2284
+ #
2285
+ # for array in arrays:
2286
+ # # same nasty hack
2287
+ # if isinstance(array, (Mat, Sparsity)) or not hasattr(array, "buffer"):
2288
+ # continue
2289
+ # # skip purely local arrays
2290
+ # if not array.buffer.is_distributed:
2291
+ # continue
2292
+ # if labels[parindex] == IterationPointType.LEAF:
2293
+ # continue
2294
+ #
2295
+ # # loop over stencil
2296
+ # array = array.with_context({index.id: (p.source_path, p.target_path)})
2297
+ #
2298
+ # for q in array.axes.iter({p}):
2299
+ # # offset = array.axes.offset(q.target_exprs, q.target_path)
2300
+ # offset = array.axes.offset(q.source_exprs, q.source_path, loop_exprs=p.replace_map)
2301
+ #
2302
+ # point_label = is_root_or_leaf_per_array[array.name][offset]
2303
+ # if point_label == ArrayPointLabel.LEAF:
2304
+ # labels[parindex] = IterationPointType.LEAF
2305
+ # break # no point doing more analysis
2306
+ # elif point_label == ArrayPointLabel.ROOT:
2307
+ # assert labels[parindex] != IterationPointType.LEAF
2308
+ # labels[parindex] = IterationPointType.ROOT
2309
+ # else:
2310
+ # assert point_label == ArrayPointLabel.CORE
2311
+ # pass
2294
2312
2295
2313
parcpt = just_one (paraxis .components ) # for now
2296
2314
2297
2315
# I don't think this is working - instead everything touches a leaf
2298
2316
# core = just_one(np.nonzero(labels == IterationPointType.CORE))
2299
2317
# root = just_one(np.nonzero(labels == IterationPointType.ROOT))
2300
2318
# leaf = just_one(np.nonzero(labels == IterationPointType.LEAF))
2301
- core = np .asarray ([], dtype = IntType )
2302
- root = np .asarray ([], dtype = IntType )
2303
- leaf = np .arange (paraxis .size , dtype = IntType )
2319
+ # core = np.asarray([], dtype=IntType)
2320
+ # root = np.asarray([], dtype=IntType)
2321
+ # leaf = np.arange(paraxis.size, dtype=IntType)
2322
+
2323
+ # hack to check things
2324
+ core = np .asarray ([0 ], dtype = IntType )
2325
+ root = np .asarray ([1 ], dtype = IntType )
2326
+ leaf = np .arange (2 , paraxis .size , dtype = IntType )
2304
2327
2305
2328
subsets = []
2306
2329
for data in [core , root , leaf ]:
@@ -2313,6 +2336,7 @@ def partition_iterset(index: LoopIndex, arrays):
2313
2336
)
2314
2337
subsets .append (subset )
2315
2338
subsets = tuple (subsets )
2339
+ return "not used" , subsets
2316
2340
2317
2341
# make a new iteration set over just these indices
2318
2342
# index with just core (arbitrary)
0 commit comments