1616
1717from __future__ import annotations
1818
19- from collections .abc import Callable , Hashable , Sequence
19+ from collections .abc import Callable , Hashable , Iterator , Sequence
2020import contextlib
2121import dataclasses
2222import functools
@@ -2678,50 +2678,23 @@ def _shape_dtype_struct_to_type_and_layout(
26782678 return vector_type , layout
26792679
26802680
2681- # TODO(allanrenucci): This function is most likely broken. We need to review the
2682- # `inline_mgpu` lowering logic and clean it up.
2683- # It was moved from MGPU dialect lowering where it is not used anymore. The
2684- # rewrite in the dialect lowering addressed bugs in this code.
2685- def _inline_block (
2686- block : ir .Block ,
2687- args : Sequence [ir .Value ],
2688- mapper : dict [ir .Value , ir .Value ],
2689- ) -> list [ir .Value ]:
2690- """Inlines the given block at the current insertion point.
2681+ def _replace_uses_in_block (old : ir .Value , new : ir .Value , block : ir .Block ):
2682+ """Replaces all uses of the `old` value with the `new` value in `block`."""
26912683
2692- The block args are replaced with the provided `args`. If the input mapper is
2693- not empty, it could further be used to replace captured values with an
2694- alternative.
2684+ def is_contained_within_block (op : ir .OpView , block : ir .Block ) -> bool :
2685+ def parent_blocks (op : ir .OpView ) -> Iterator [ir .Block ]:
2686+ current_op = op
2687+ while current_op .parent is not None :
2688+ yield current_op .operation .block
2689+ current_op = current_op .parent
26952690
2696- The operands of the terminator are returned as results.
2697- """
2698- for arg , val in zip (block .arguments , args , strict = True ):
2699- mapper [arg ] = val
2700- return_op = None
2701- for op in block .operations :
2702- if isinstance (op .opview , mgpu .dialect .ReturnOp ):
2703- assert return_op is None
2704- return_op = op .opview
2705-
2706- # Operands not in the mapper are captured from the context.
2707- new_operands = [mapper [o ] if o in mapper else o for o in op .operands ]
2708- new_attributes = {
2709- named_attr : op .attributes [named_attr ] for named_attr in op .attributes
2710- }
2711- new_op = ir .Operation .create (
2712- name = op .name ,
2713- results = [res .type for res in op .results ],
2714- operands = new_operands ,
2715- attributes = new_attributes ,
2716- )
2717- for old_result , new_result in zip (op .results , new_op .results ):
2718- mapper [old_result ] = new_result
2719-
2720- if return_op is None :
2721- raise ValueError ("A custom return op must terminate the block." )
2722-
2723- inlined_return_values = [mapper [o ] for o in return_op .operands ]
2724- return inlined_return_values
2691+ return block in parent_blocks (op )
2692+
2693+ exceptions = []
2694+ for use in old .uses :
2695+ if not is_contained_within_block (use .owner , block ):
2696+ exceptions .append (use .owner .operation )
2697+ old .replace_all_uses_except (new , exceptions )
27252698
27262699
27272700def _clone_custom_op_with_extra_args (
@@ -2759,21 +2732,19 @@ def _clone_custom_op_with_extra_args(
27592732 out_layouts = custom_op .out_layouts ,
27602733 )
27612734 new_block = new_op .body .blocks .append (* new_in_types )
2762-
2763- # Clone the old block, by inlining it into the new one.
2735+ ip = ir .InsertionPoint (new_block )
2736+ for op in old_block .operations :
2737+ op .detach_from_parent ()
2738+ ip .insert (op )
2739+ for old_arg , new_arg in zip (old_block .arguments , new_block .arguments ):
2740+ old_arg .replace_all_uses_with (new_arg )
27642741 num_old_args = len (old_block .arguments )
2765- with ir .InsertionPoint .at_block_begin (new_block ):
2766- _inline_block (
2767- old_block ,
2768- list (new_block .arguments )[:num_old_args ],
2769- mapper = dict (
2770- zip (
2771- extra_args ,
2772- list (new_block .arguments )[num_old_args :],
2773- strict = True ,
2774- )
2775- ),
2776- )
2742+ for extra_arg , new_arg in zip (
2743+ extra_args ,
2744+ list (new_block .arguments )[num_old_args :],
2745+ strict = True ,
2746+ ):
2747+ _replace_uses_in_block (extra_arg , new_arg , new_block )
27772748
27782749 return new_op
27792750
@@ -3000,10 +2971,8 @@ def _inline_mgpu_lowering_rule_wg_semantics(
30002971 )
30012972
30022973 # We need to ensure that the block doesn't capture any values from the context
3003- # and uses args for everything instead. At least one thing the block is likely
3004- # to capture is the SMEM scratch buffer which could have been created outside
3005- # of the block during the execution of the provided mgpu_fn, if it calls
3006- # `async_copy`.
2974+ # and uses args for everything instead. E.g. `LaunchContext.tma_descriptors`
2975+ # will be captured when calling `ctx.async_copy`.
30072976 captured = _closed_over_values (block )
30082977 if captured :
30092978 old_custom_op = custom_op
@@ -3021,7 +2990,7 @@ def _inline_mgpu_lowering_rule_wg_semantics(
30212990 x .type
30222991 )
30232992 return _inline_mgpu_flat_results (
3024- ctx , ret , pytree_ret_ty , flat_ret_ty , is_leaf = is_leaf
2993+ ctx , ret , pytree_ret_ty , flat_ret_ty , is_leaf
30252994 )
30262995
30272996
0 commit comments