2424"""
2525
2626
27+ from typing import TYPE_CHECKING
28+
2729from loopy .kernel import LoopKernel
2830from loopy .kernel .instruction import BarrierInstruction
29- from loopy .match import parse_match
31+ from loopy .match import ToMatchConvertible , parse_match
3032from loopy .transform .instruction import add_dependency
3133from loopy .translation_unit import for_each_kernel
3234
3335
36+ if TYPE_CHECKING :
37+ from pytools .tag import Tag
38+
39+
3440__doc__ = """
3541.. currentmodule:: loopy
3642
4147# {{{ add_barrier
4248
4349@for_each_kernel
44- def add_barrier (kernel , insn_before = "" , insn_after = "" , id_based_on = None ,
45- tags = None , synchronization_kind = "global" , mem_kind = None ,
46- within_inames = None ):
47- """Takes in a kernel that needs to be added a barrier and returns a kernel
48- which has a barrier inserted into it. It takes input of 2 instructions and
49- then adds a barrier in between those 2 instructions. The expressions can
50- be any inputs that are understood by :func:`loopy.match.parse_match`.
51-
52- :arg insn_before: String expression that specifies the instruction(s)
53- before the barrier which is to be added. If None, no dependencies will
54- be added to barrier.
55- :arg insn_after: String expression that specifies the instruction(s) after
56- the barrier which is to be added. If None, no dependencies on the barrier
57- will be added.
58- :arg id: String on which the id of the barrier would be based on.
50+ def add_barrier (
51+ kernel : LoopKernel ,
52+ insn_before : ToMatchConvertible ,
53+ insn_after : ToMatchConvertible ,
54+ id_based_on : str | None = None ,
55+ tags : frozenset [Tag ] | None = None ,
56+ synchronization_kind : str = "global" ,
57+ mem_kind : str | None = None ,
58+ within_inames : frozenset [str ] | None = None ,
59+ ) -> LoopKernel :
60+ """
61+ Returns a transformed version of *kernel* with an additional
62+ :class:`loopy.BarrierInstruction` inserted.
63+
64+ :arg insn_before: Match expression that specifies the instruction(s)
65+ that the barrier instruction depends on.
66+ :arg insn_after: Match expression that specifies the instruction(s)
67+ that depend on the barrier instruction.
68+ :arg id_based_on: Prefix for the barrier instructions' ID.
5969 :arg tags: The tag of the group to which the barrier must be added
6070 :arg synchronization_kind: Kind of barrier to be added. May be "global" or
6171 "local"
62- :arg kind: Type of memory to be synchronized. May be "global" or "local". Ignored
63- for "global" barriers. If not supplied, defaults to *synchronization_kind*
72+ :arg mem_kind: Type of memory to be synchronized. May be "global" or
73+ "local". Ignored for "global" barriers. If not supplied, defaults to
74+ *synchronization_kind*
6475 :arg within_inames: A :class:`frozenset` of inames identifying the loops
6576 within which the barrier will be executed.
66-
6777 """
6878
6979 assert isinstance (kernel , LoopKernel )
@@ -77,14 +87,11 @@ def add_barrier(kernel, insn_before="", insn_after="", id_based_on=None,
7787 else :
7888 id = kernel .make_unique_instruction_id (based_on = id_based_on )
7989
80- if insn_before is not None :
81- match = parse_match (insn_before )
82- insns_before = frozenset (
83- [insn .id for insn in kernel .instructions if match (kernel , insn )])
84- else :
85- insns_before = None
90+ match = parse_match (insn_before )
91+ depends_on = frozenset (
92+ [insn .id for insn in kernel .instructions if match (kernel , insn )])
8693
87- barrier_to_add = BarrierInstruction (depends_on = insns_before ,
94+ barrier_to_add = BarrierInstruction (depends_on = depends_on ,
8895 depends_on_is_final = True ,
8996 id = id ,
9097 within_inames = within_inames ,
@@ -93,10 +100,9 @@ def add_barrier(kernel, insn_before="", insn_after="", id_based_on=None,
93100 mem_kind = mem_kind )
94101
95102 new_kernel = kernel .copy (instructions = [* kernel .instructions , barrier_to_add ])
96- if insn_after is not None :
97- new_kernel = add_dependency (new_kernel ,
98- insn_match = insn_after ,
99- depends_on = "id:" + id )
103+ new_kernel = add_dependency (
104+ new_kernel , insn_match = insn_after , depends_on = "id:" + id
105+ )
100106
101107 return new_kernel
102108
0 commit comments