Skip to content

Commit 8bc909d

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Pallas:MGPU] Fix plgpu.inline_mgpu support for nested blocks.
The current implementation of `_inline_block` only performs a shallow copy of the operations in a block. PiperOrigin-RevId: 832321954
1 parent bbec947 commit 8bc909d

File tree

2 files changed

+36
-62
lines changed

2 files changed

+36
-62
lines changed

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 31 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from collections.abc import Callable, Hashable, Sequence
19+
from collections.abc import Callable, Hashable, Iterator, Sequence
2020
import contextlib
2121
import dataclasses
2222
import functools
@@ -2678,50 +2678,23 @@ def _shape_dtype_struct_to_type_and_layout(
26782678
return vector_type, layout
26792679

26802680

2681-
# TODO(allanrenucci): This function is most likely broken. We need to review the
2682-
# `inline_mgpu` lowering logic and clean it up.
2683-
# It was moved from MGPU dialect lowering where it is not used anymore. The
2684-
# rewrite in the dialect lowering addressed bugs in this code.
2685-
def _inline_block(
2686-
block: ir.Block,
2687-
args: Sequence[ir.Value],
2688-
mapper: dict[ir.Value, ir.Value],
2689-
) -> list[ir.Value]:
2690-
"""Inlines the given block at the current insertion point.
2681+
def _replace_uses_in_block(old: ir.Value, new: ir.Value, block: ir.Block):
2682+
"""Replaces all uses of the `old` value with the `new` value in `block`."""
26912683

2692-
The block args are replaced with the provided `args`. If the input mapper is
2693-
not empty, it could further be used to replace captured values with an
2694-
alternative.
2684+
def is_contained_within_block(op: ir.OpView, block: ir.Block) -> bool:
2685+
def parent_blocks(op: ir.OpView) -> Iterator[ir.Block]:
2686+
current_op = op
2687+
while current_op.parent is not None:
2688+
yield current_op.operation.block
2689+
current_op = current_op.parent
26952690

2696-
The operands of the terminator are returned as results.
2697-
"""
2698-
for arg, val in zip(block.arguments, args, strict=True):
2699-
mapper[arg] = val
2700-
return_op = None
2701-
for op in block.operations:
2702-
if isinstance(op.opview, mgpu.dialect.ReturnOp):
2703-
assert return_op is None
2704-
return_op = op.opview
2705-
2706-
# Operands not in the mapper are captured from the context.
2707-
new_operands = [mapper[o] if o in mapper else o for o in op.operands]
2708-
new_attributes = {
2709-
named_attr: op.attributes[named_attr] for named_attr in op.attributes
2710-
}
2711-
new_op = ir.Operation.create(
2712-
name=op.name,
2713-
results=[res.type for res in op.results],
2714-
operands=new_operands,
2715-
attributes=new_attributes,
2716-
)
2717-
for old_result, new_result in zip(op.results, new_op.results):
2718-
mapper[old_result] = new_result
2719-
2720-
if return_op is None:
2721-
raise ValueError("A custom return op must terminate the block.")
2722-
2723-
inlined_return_values = [mapper[o] for o in return_op.operands]
2724-
return inlined_return_values
2691+
return block in parent_blocks(op)
2692+
2693+
exceptions = []
2694+
for use in old.uses:
2695+
if not is_contained_within_block(use.owner, block):
2696+
exceptions.append(use.owner.operation)
2697+
old.replace_all_uses_except(new, exceptions)
27252698

27262699

27272700
def _clone_custom_op_with_extra_args(
@@ -2759,21 +2732,19 @@ def _clone_custom_op_with_extra_args(
27592732
out_layouts=custom_op.out_layouts,
27602733
)
27612734
new_block = new_op.body.blocks.append(*new_in_types)
2762-
2763-
# Clone the old block, by inlining it into the new one.
2735+
ip = ir.InsertionPoint(new_block)
2736+
for op in old_block.operations:
2737+
op.detach_from_parent()
2738+
ip.insert(op)
2739+
for old_arg, new_arg in zip(old_block.arguments, new_block.arguments):
2740+
old_arg.replace_all_uses_with(new_arg)
27642741
num_old_args = len(old_block.arguments)
2765-
with ir.InsertionPoint.at_block_begin(new_block):
2766-
_inline_block(
2767-
old_block,
2768-
list(new_block.arguments)[:num_old_args],
2769-
mapper=dict(
2770-
zip(
2771-
extra_args,
2772-
list(new_block.arguments)[num_old_args:],
2773-
strict=True,
2774-
)
2775-
),
2776-
)
2742+
for extra_arg, new_arg in zip(
2743+
extra_args,
2744+
list(new_block.arguments)[num_old_args:],
2745+
strict=True,
2746+
):
2747+
_replace_uses_in_block(extra_arg, new_arg, new_block)
27772748

27782749
return new_op
27792750

@@ -3000,10 +2971,8 @@ def _inline_mgpu_lowering_rule_wg_semantics(
30002971
)
30012972

30022973
# We need to ensure that the block doesn't capture any values from the context
3003-
# and uses args for everything instead. At least one thing the block is likely
3004-
# to capture is the SMEM scratch buffer which could have been created outside
3005-
# of the block during the execution of the provided mgpu_fn, if it calls
3006-
# `async_copy`.
2974+
# and uses args for everything instead. E.g. `LaunchContext.tma_descriptors`
2975+
# will be capture when calling `ctx.async_copy`.
30072976
captured = _closed_over_values(block)
30082977
if captured:
30092978
old_custom_op = custom_op
@@ -3021,7 +2990,7 @@ def _inline_mgpu_lowering_rule_wg_semantics(
30212990
x.type
30222991
)
30232992
return _inline_mgpu_flat_results(
3024-
ctx, ret, pytree_ret_ty, flat_ret_ty, is_leaf=is_leaf
2993+
ctx, ret, pytree_ret_ty, flat_ret_ty, is_leaf
30252994
)
30262995

30272996

tests/pallas/mosaic_gpu_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,11 @@ def store(ctx, arr, smem_ref, o_ref):
547547
)
548548
ctx.await_async_copy(0)
549549

550+
# A dummy if statement to make sure we inline nested blocks correctly.
551+
is_leader_thread = mgpu.utils.single_thread_predicate()
552+
with mgpu.utils.when(is_leader_thread):
553+
pass
554+
550555
# This time we slice inside the inline_mgpu body.
551556
store(arr, smem_ref, o_ref)
552557

0 commit comments

Comments
 (0)