Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit 86cbfc4

Browse files
committed
Codegen fixups
1 parent 26c301c commit 86cbfc4

File tree

4 files changed

+120
-40
lines changed

4 files changed

+120
-40
lines changed

pyop3/array/petsc.py

+2
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,9 @@ def maps(self):
451451
dropped_ckeys = set()
452452

453453
# TODO: are dropped_rkeys and dropped_ckeys still needed?
454+
# FIXME: this whole thing falls apart if we have multiple loop contexts
454455
loop_index = just_one(self.block_raxes.outer_loops)
456+
455457
iterset = AxisTree(loop_index.iterset.node_map)
456458

457459
rmap_axes = iterset.add_subtree(self.block_raxes, *iterset.leaf)

pyop3/ir/lower.py

+50-26
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
ContextAwareLoop,
4444
DummyKernelArgument,
4545
Loop,
46+
LoopList,
4647
PetscMatAdd,
4748
PetscMatInstruction,
4849
PetscMatLoad,
@@ -307,6 +308,7 @@ def set_temporary_shapes(self, shapes):
307308

308309
class CodegenResult:
309310
def __init__(self, expr, ir, arg_replace_map, *, compiler_parameters):
311+
# NOTE: should this be iterable?
310312
self.expr = as_tuple(expr)
311313
self.ir = ir
312314
self.arg_replace_map = arg_replace_map
@@ -315,7 +317,7 @@ def __init__(self, expr, ir, arg_replace_map, *, compiler_parameters):
315317

316318
@cached_property
317319
def datamap(self):
318-
return merge_dicts(e.datamap for e in self.expr)
320+
return merge_dicts(e.preprocessed.datamap for e in self.expr)
319321

320322
def __call__(self, **kwargs):
321323
data_args = []
@@ -399,35 +401,41 @@ def parse_compiler_parameters(compiler_parameters) -> CompilerParameters:
399401
def compile(expr: Instruction, compiler_parameters=None):
400402
compiler_parameters = parse_compiler_parameters(compiler_parameters)
401403

402-
# preprocess expr before lowering
403-
from pyop3.transform import expand_implicit_pack_unpack, expand_loop_contexts
404-
405404
function_name = expr.name
406405

407-
cs_expr = expand_loop_contexts(expr)
406+
if isinstance(expr, LoopList):
407+
cs_expr = expr.loops
408+
else:
409+
assert isinstance(expr, Loop), "other types not handled yet"
410+
cs_expr = (expr,)
411+
408412
ctx = LoopyCodegenContext()
409-
for context, ex in cs_expr:
410-
ex = expand_implicit_pack_unpack(ex)
413+
# NOTE: so I think LoopCollection is a better abstraction here - don't want to be
414+
# explicitly dealing with contexts at this point. Can always sniff them out again.
415+
# for context, ex in cs_expr:
416+
for ex in cs_expr:
417+
# ex = expand_implicit_pack_unpack(ex)
411418

412419
# add external loop indices as kernel arguments
420+
# FIXME: removed because cs_expr needs to sniff the context now
413421
loop_indices = {}
414-
for index, (path, _) in context.items():
415-
if len(path) > 1:
416-
raise NotImplementedError("needs to be sorted")
417-
418-
# dummy = HierarchicalArray(index.iterset, data=NullBuffer(IntType))
419-
dummy = HierarchicalArray(Axis(1), dtype=IntType)
420-
# this is dreadful, pass an integer array instead
421-
ctx.add_argument(dummy)
422-
myname = ctx.actual_to_kernel_rename_map[dummy.name]
423-
replace_map = {
424-
axis: pym.subscript(pym.var(myname), (i,))
425-
for i, axis in enumerate(path.keys())
426-
}
427-
# FIXME currently assume that source and target exprs are the same, they are not!
428-
loop_indices[index] = (replace_map, replace_map)
429-
430-
for e in as_tuple(ex):
422+
# for index, (path, _) in context.items():
423+
# if len(path) > 1:
424+
# raise NotImplementedError("needs to be sorted")
425+
#
426+
# # dummy = HierarchicalArray(index.iterset, data=NullBuffer(IntType))
427+
# dummy = HierarchicalArray(Axis(1), dtype=IntType)
428+
# # this is dreadful, pass an integer array instead
429+
# ctx.add_argument(dummy)
430+
# myname = ctx.actual_to_kernel_rename_map[dummy.name]
431+
# replace_map = {
432+
# axis: pym.subscript(pym.var(myname), (i,))
433+
# for i, axis in enumerate(path.keys())
434+
# }
435+
# # FIXME currently assume that source and target exprs are the same, they are not!
436+
# loop_indices[index] = (replace_map, replace_map)
437+
438+
for e in as_tuple(ex): # TODO: get rid of this loop
431439
# context manager?
432440
ctx.set_temporary_shapes(_collect_temporary_shapes(e))
433441
_compile(e, loop_indices, ctx)
@@ -497,6 +505,7 @@ def _collect_temporary_shapes(expr):
497505
raise TypeError(f"No handler defined for {type(expr).__name__}")
498506

499507

508+
# TODO: get rid of this type
500509
@_collect_temporary_shapes.register
501510
def _(expr: ContextAwareLoop):
502511
shapes = {}
@@ -512,6 +521,20 @@ def _(expr: ContextAwareLoop):
512521
return shapes
513522

514523

524+
@_collect_temporary_shapes.register
525+
def _(expr: Loop):
526+
shapes = {}
527+
for stmt in expr.statements:
528+
for temp, shape in _collect_temporary_shapes(stmt).items():
529+
if shape is None:
530+
continue
531+
if temp in shapes:
532+
assert shapes[temp] == shape
533+
else:
534+
shapes[temp] = shape
535+
return shapes
536+
537+
515538
@_collect_temporary_shapes.register
516539
def _(expr: Assignment):
517540
return pmap()
@@ -539,9 +562,10 @@ def _compile(expr: Any, loop_indices, ctx: LoopyCodegenContext) -> None:
539562
raise TypeError(f"No handler defined for {type(expr).__name__}")
540563

541564

542-
@_compile.register
565+
@_compile.register(ContextAwareLoop) # remove
566+
@_compile.register(Loop)
543567
def _(
544-
loop: ContextAwareLoop,
568+
loop,
545569
loop_indices,
546570
codegen_context: LoopyCodegenContext,
547571
) -> None:

pyop3/lang.py

+55-10
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ class Intent(enum.Enum):
6262
NA = Intent.NA
6363

6464

65+
class ExpressionState(enum.Enum):
66+
"""Enum indicating the state of an expression (preprocessed or not)."""
67+
INITIAL = "initial"
68+
PREPROCESSED = "preprocessed"
69+
70+
6571
# TODO: This exception is not actually ever raised. We should check the
6672
# intents of the kernel arguments and complain if something illegal is
6773
# happening.
@@ -101,7 +107,32 @@ def kernel_dtype(self):
101107

102108

103109
class Instruction(UniqueRecord, abc.ABC):
104-
pass
110+
fields = UniqueRecord.fields | {"state"}
111+
112+
@cached_property
113+
def preprocessed(self):
114+
from pyop3.transform import expand_implicit_pack_unpack, expand_loop_contexts
115+
116+
if self.state == ExpressionState.PREPROCESSED:
117+
return self
118+
else:
119+
insn = self
120+
insn = expand_loop_contexts(insn)
121+
insn = expand_implicit_pack_unpack(insn)
122+
# TODO: should make marking things as preprocessed be an extra stage, currently do in expand_loop_contexts which is a bug
123+
return insn
124+
125+
@property
126+
def is_preprocessed(self):
127+
return self.state == ExpressionState.PREPROCESSED
128+
129+
@cached_property
130+
def loopy_code(self):
131+
from pyop3.ir.lower import compile
132+
133+
return compile(self.preprocessed)
134+
135+
105136

106137

107138
class ContextAwareInstruction(Instruction):
@@ -132,22 +163,25 @@ def __init__(
132163
*,
133164
name: str = _DEFAULT_LOOP_NAME,
134165
compiler_parameters=None,
166+
state: ExpressionState = ExpressionState.INITIAL,
135167
**kwargs,
136168
):
137169
super().__init__(**kwargs)
138170
self.index = index
139171
self.statements = as_tuple(statements)
140172
self.name = name
141173
self.compiler_parameters = compiler_parameters
174+
self.state = state
142175

143176
def __call__(self, **kwargs):
144177
# TODO just parse into ContextAwareLoop and call that
145178
from pyop3.ir.lower import compile
146179
from pyop3.itree.tree import partition_iterset
147180

148-
code = compile(self, compiler_parameters=self.compiler_parameters)
181+
code = compile(self.preprocessed, compiler_parameters=self.compiler_parameters)
149182

150-
if self.is_parallel:
183+
if False:
184+
# if self.is_parallel:
151185
# FIXME: The partitioning code does not seem to always run properly
152186
# so for now do all the transfers in advance.
153187
# interleave computation and communication
@@ -205,12 +239,6 @@ def __call__(self, **kwargs):
205239
with PETSc.Log.Event(f"compute_{self.name}_serial"):
206240
code(**kwargs)
207241

208-
@cached_property
209-
def loopy_code(self):
210-
from pyop3.ir.lower import compile
211-
212-
return compile(self)
213-
214242
@cached_property
215243
def is_parallel(self):
216244
from pyop3.buffer import DistributedBuffer
@@ -355,7 +383,10 @@ def _init_nil():
355383

356384
@cached_property
357385
def datamap(self):
358-
return self.index.datamap | merge_dicts(stmt.datamap for stmt in self.statements)
386+
if self.is_preprocessed:
387+
return self.index.datamap | merge_dicts(stmt.datamap for stmt in self.statements)
388+
else:
389+
return self.preprocessed.datamap
359390

360391

361392
class ContextAwareLoop(ContextAwareInstruction):
@@ -379,6 +410,20 @@ def loopy_code(self):
379410
return compile(self)
380411

381412

413+
class LoopList(Instruction):
414+
fields = Instruction.fields | {"loops"}
415+
416+
def __init__(self, loops, *, name=_DEFAULT_LOOP_NAME, state=ExpressionState.INITIAL, **kwargs):
417+
super().__init__(**kwargs)
418+
self.loops = loops
419+
self.name = name
420+
self.state = ExpressionState(state)
421+
422+
@cached_property
423+
def datamap(self):
424+
return merge_dicts(l.datamap for l in self.loops)
425+
426+
382427
# TODO singledispatch
383428
# TODO perhaps this is simply "has non unit stride"?
384429
def _has_nontrivial_stencil(array):

pyop3/transform.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DummyKernelArgument,
2626
Instruction,
2727
Loop,
28+
LoopList,
2829
Pack,
2930
PetscMatAdd,
3031
PetscMatInstruction,
@@ -33,7 +34,7 @@
3334
ReplaceAssignment,
3435
Terminal,
3536
)
36-
from pyop3.utils import UniqueNameGenerator, checked_zip, just_one
37+
from pyop3.utils import UniqueNameGenerator, checked_zip, just_one, single_valued
3738

3839

3940
# TODO Is this generic for other parsers/transformers? Esp. lower.py
@@ -127,12 +128,16 @@ def _(self, loop: Loop, *, context):
127128
statements[source_path].append(mystmt)
128129

129130
# FIXME this does not propagate inner outer contexts
130-
loop = ContextAwareLoop(
131+
# NOTE: also I think this is redundant, just use a Loop!!!
132+
csloop = ContextAwareLoop(
131133
loop.index.copy(iterset=cf_iterset),
132134
statements,
133135
)
134-
loops.append((octx, loop))
135-
return tuple(loops)
136+
# NOTE: outer context now needs sniffing out, makes the objects nicer
137+
# loops.append((octx, loop))
138+
loops.append(csloop)
139+
140+
return LoopList(loops, name=loop.name, state="preprocessed")
136141

137142
@_apply.register
138143
def _(self, terminal: CalledFunction, *, context):
@@ -247,6 +252,10 @@ def _(self, loop: ContextAwareLoop):
247252
),
248253
)
249254

255+
@_apply.register
256+
def _(self, loop_list: LoopList):
257+
return loop_list.copy(loops=[loop_ for loop in loop_list.loops for loop_ in self._apply(loop)])
258+
250259
# TODO: Should be the same as Assignment
251260
@_apply.register
252261
def _(self, assignment: PetscMatInstruction):

0 commit comments

Comments
 (0)