Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 203e218

Browse files
committed
Remove dead code, disable halo logic and remove deprecated bits
Halo code is still wrong as DG advection is losing mass in parallel, but it is at least running.
1 parent 73505c8 commit 203e218

File tree

9 files changed

+172
-463
lines changed

9 files changed

+172
-463
lines changed

pyop3/array/harray.py

+12-81
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
as_axis_tree,
2929
)
3030
from pyop3.axtree.layout import eval_offset
31-
from pyop3.axtree.tree import Indexed, IndexedAxisTree, MultiArrayCollector
31+
from pyop3.axtree.tree import IndexedAxisTree, MultiArrayCollector
3232
from pyop3.buffer import Buffer, DistributedBuffer
3333
from pyop3.dtypes import IntType, ScalarType
3434
from pyop3.lang import KernelArgument, ReplaceAssignment
@@ -74,12 +74,6 @@ def __init__(self, array, indices, path=None):
7474
def __getinitargs__(self):
7575
return (self.array, self.indices, self.path)
7676

77-
# def __str__(self) -> str:
78-
# return f"{self.array.name}[{{{', '.join(f'{i[0]}: {i[1]}' for i in self.indices.items())}}}]"
79-
#
80-
# def __repr__(self) -> str:
81-
# return f"MultiArrayVariable({self.array!r}, {self.indices!r})"
82-
8377

8478
from pymbolic.mapper.stringifier import PREC_CALL, PREC_NONE, StringifyMapper
8579

@@ -101,29 +95,11 @@ def stringify_array(self, array, enclosing_prec, *args, **kwargs):
10195
CalledMapVariable = ArrayVar
10296

10397

104-
# does not belong here!
105-
# class CalledMapVariable(ArrayVar):
106-
# mapper_method = sys.intern("map_called_map_variable")
107-
#
108-
# def __init__(self, array, path, input_index_exprs, shape_index_exprs):
109-
# super().__init__(array, {**input_index_exprs, **shape_index_exprs}, path)
110-
# self.input_index_exprs = freeze(input_index_exprs)
111-
# self.shape_index_exprs = freeze(shape_index_exprs)
112-
#
113-
# def __getinitargs__(self):
114-
# return (
115-
# self.array,
116-
# self.target_path,
117-
# self.input_index_exprs,
118-
# self.shape_index_exprs,
119-
# )
120-
121-
12298
class FancyIndexWriteException(Exception):
12399
pass
124100

125101

126-
class HierarchicalArray(Array, Indexed, ContextFree, KernelArgument):
102+
class HierarchicalArray(Array, ContextFree, KernelArgument):
127103
"""Multi-dimensional, hierarchical array.
128104
129105
Parameters
@@ -189,35 +165,6 @@ def __init__(
189165
# TODO This attr really belongs to the buffer not the array
190166
self.constant = constant
191167

192-
# if some_but_not_all(x is None for x in [target_paths, index_exprs]):
193-
# raise ValueError
194-
195-
# if target_paths is None:
196-
# target_paths = axes._default_target_paths()
197-
# if index_exprs is None:
198-
# index_exprs = axes._default_index_exprs()
199-
#
200-
# self._target_paths = freeze(target_paths)
201-
# self._index_exprs = freeze(index_exprs)
202-
# self._outer_loops = outer_loops or ()
203-
#
204-
# self._layouts = layouts if layouts is not None else axes.layouts
205-
206-
@property
207-
@deprecated()
208-
def target_paths(self):
209-
return self.axes.target_paths
210-
211-
@property
212-
@deprecated()
213-
def index_exprs(self):
214-
return self.axes.index_exprs
215-
216-
@property
217-
@deprecated()
218-
def layouts(self):
219-
return self.axes.layouts
220-
221168
def __str__(self):
222169
return self.name
223170

@@ -253,14 +200,9 @@ def getitem(self, indices, *, strict=False):
253200
# to be iterable (which it's not). This avoids some confusing behaviour.
254201
__iter__ = None
255202

256-
@property
257-
@deprecated("buffer")
258-
def array(self):
259-
return self.buffer
260-
261203
@property
262204
def dtype(self):
263-
return self.array.dtype
205+
return self.buffer.dtype
264206

265207
@property
266208
def kernel_dtype(self):
@@ -425,7 +367,7 @@ def outer_loops(self):
425367

426368
@property
427369
def sf(self):
428-
return self.array.sf
370+
return self.buffer.sf
429371

430372
@property
431373
def comm(self):
@@ -436,11 +378,13 @@ def datamap(self):
436378
datamap_ = {}
437379
datamap_.update(self.buffer.datamap)
438380
datamap_.update(self.axes.datamap)
439-
for index_exprs in self.index_exprs.values():
381+
382+
# FIXME, deleting this breaks stuff...
383+
for index_exprs in self.axes.index_exprs.values():
440384
for expr in index_exprs.values():
441385
for array in MultiArrayCollector()(expr):
442386
datamap_.update(array.datamap)
443-
for layout_expr in self.layouts.values():
387+
for layout_expr in self.axes.layouts.values():
444388
for array in MultiArrayCollector()(layout_expr):
445389
datamap_.update(array.datamap)
446390
return freeze(datamap_)
@@ -457,9 +401,9 @@ def assemble(self, update_leaves=False):
457401
458402
"""
459403
if update_leaves:
460-
self.array._reduce_then_broadcast()
404+
self.buffer._reduce_then_broadcast()
461405
else:
462-
self.array._reduce_leaves_to_roots()
406+
self.buffer._reduce_leaves_to_roots()
463407

464408
def materialize(self) -> HierarchicalArray:
465409
"""Return a new "unindexed" array with the same shape."""
@@ -491,7 +435,7 @@ def _with_axes(self, axes):
491435
assert False, "do not use, it's wrong"
492436
return type(self)(
493437
axes,
494-
data=self.array,
438+
data=self.buffer,
495439
max_value=self.max_value,
496440
name=self.name,
497441
)
@@ -518,7 +462,7 @@ def from_list(cls, data, axis_labels, name=None, dtype=ScalarType, inc=0):
518462
if isinstance(count, Sequence):
519463
count = cls.from_list(count, axis_labels[:-1], name, dtype, inc + 1)
520464
subaxis = Axis(count, axis_labels[-1])
521-
axes = count.axes.add_subaxis(subaxis, count.axes.leaf)
465+
axes = count.axes.add_axis(subaxis, count.axes.leaf)
522466
else:
523467
axes = AxisTree(Axis(count, axis_labels[-1]))
524468

@@ -547,15 +491,6 @@ def set_value(self, indices, value, path=None, *, loop_exprs=pmap()):
547491
offset = self.axes.offset(indices, path, loop_exprs=loop_exprs)
548492
self.buffer.data_wo[offset] = value
549493

550-
# def offset(self, indices, path=None, *, loop_exprs=pmap()):
551-
# return eval_offset(
552-
# self.axes,
553-
# self.subst_layouts,
554-
# indices,
555-
# path,
556-
# loop_exprs=loop_exprs,
557-
# )
558-
559494
def select_axes(self, indices):
560495
selected = []
561496
current_axis = self.axes
@@ -657,10 +592,6 @@ def __getitem__(self, indices) -> ContextSensitiveMultiArray:
657592
#
658593
# return ContextSensitiveMultiArray(array_per_context)
659594

660-
@property
661-
def array(self):
662-
return self._shared_attr("array")
663-
664595
@property
665596
def buffer(self):
666597
return self._shared_attr("buffer")

pyop3/array/petsc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def maps(self):
342342
].items()
343343
if ax == cfield_axis.label
344344
)
345-
orig_caxes = AxisTree(self.orig_caxes[cfield].parent_to_children)
345+
orig_caxes = AxisTree(self.orig_caxes[cfield].node_map)
346346
orig_caxess = [orig_caxes]
347347
dropped_ckeys = {cfield_axis.label}
348348
else:

0 commit comments

Comments
 (0)