Skip to content

Commit 0cc1840

Browse files
kaushikcfdinducer
authored andcommitted
Type lp.add_barrier, enhace API.
1 parent cb46068 commit 0cc1840

File tree

2 files changed

+38
-30
lines changed

2 files changed

+38
-30
lines changed

loopy/kernel/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,8 @@ def make_unique_instruction_id(self, insns=None, based_on="insn",
282282
if id_str not in used_ids:
283283
return intern(id_str)
284284

285+
raise RuntimeError("Unreachable.")
286+
285287
def all_group_names(self):
286288
result = set()
287289
for insn in self.instructions:

loopy/transform/add_barrier.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,19 @@
2424
"""
2525

2626

27+
from typing import TYPE_CHECKING
28+
2729
from loopy.kernel import LoopKernel
2830
from loopy.kernel.instruction import BarrierInstruction
29-
from loopy.match import parse_match
31+
from loopy.match import ToMatchConvertible, parse_match
3032
from loopy.transform.instruction import add_dependency
3133
from loopy.translation_unit import for_each_kernel
3234

3335

36+
if TYPE_CHECKING:
37+
from pytools.tag import Tag
38+
39+
3440
__doc__ = """
3541
.. currentmodule:: loopy
3642
@@ -41,29 +47,33 @@
4147
# {{{ add_barrier
4248

4349
@for_each_kernel
44-
def add_barrier(kernel, insn_before="", insn_after="", id_based_on=None,
45-
tags=None, synchronization_kind="global", mem_kind=None,
46-
within_inames=None):
47-
"""Takes in a kernel that needs to be added a barrier and returns a kernel
48-
which has a barrier inserted into it. It takes input of 2 instructions and
49-
then adds a barrier in between those 2 instructions. The expressions can
50-
be any inputs that are understood by :func:`loopy.match.parse_match`.
51-
52-
:arg insn_before: String expression that specifies the instruction(s)
53-
before the barrier which is to be added. If None, no dependencies will
54-
be added to barrier.
55-
:arg insn_after: String expression that specifies the instruction(s) after
56-
the barrier which is to be added. If None, no dependencies on the barrier
57-
will be added.
58-
:arg id: String on which the id of the barrier would be based on.
50+
def add_barrier(
51+
kernel: LoopKernel,
52+
insn_before: ToMatchConvertible,
53+
insn_after: ToMatchConvertible,
54+
id_based_on: str | None = None,
55+
tags: frozenset[Tag] | None = None,
56+
synchronization_kind: str = "global",
57+
mem_kind: str | None = None,
58+
within_inames: frozenset[str] | None = None,
59+
) -> LoopKernel:
60+
"""
61+
Returns a transformed version of *kernel* with an additional
62+
:class:`loopy.BarrierInstruction` inserted.
63+
64+
:arg insn_before: Match expression that specifies the instruction(s)
65+
that the barrier instruction depends on.
66+
:arg insn_after: Match expression that specifies the instruction(s)
67+
that depend on the barrier instruction.
68+
:arg id_based_on: Prefix for the barrier instructions' ID.
5969
:arg tags: The tag of the group to which the barrier must be added
6070
:arg synchronization_kind: Kind of barrier to be added. May be "global" or
6171
"local"
62-
:arg kind: Type of memory to be synchronized. May be "global" or "local". Ignored
63-
for "global" barriers. If not supplied, defaults to *synchronization_kind*
72+
:arg mem_kind: Type of memory to be synchronized. May be "global" or
73+
"local". Ignored for "global" barriers. If not supplied, defaults to
74+
*synchronization_kind*
6475
:arg within_inames: A :class:`frozenset` of inames identifying the loops
6576
within which the barrier will be executed.
66-
6777
"""
6878

6979
assert isinstance(kernel, LoopKernel)
@@ -77,14 +87,11 @@ def add_barrier(kernel, insn_before="", insn_after="", id_based_on=None,
7787
else:
7888
id = kernel.make_unique_instruction_id(based_on=id_based_on)
7989

80-
if insn_before is not None:
81-
match = parse_match(insn_before)
82-
insns_before = frozenset(
83-
[insn.id for insn in kernel.instructions if match(kernel, insn)])
84-
else:
85-
insns_before = None
90+
match = parse_match(insn_before)
91+
depends_on = frozenset(
92+
[insn.id for insn in kernel.instructions if match(kernel, insn)])
8693

87-
barrier_to_add = BarrierInstruction(depends_on=insns_before,
94+
barrier_to_add = BarrierInstruction(depends_on=depends_on,
8895
depends_on_is_final=True,
8996
id=id,
9097
within_inames=within_inames,
@@ -93,10 +100,9 @@ def add_barrier(kernel, insn_before="", insn_after="", id_based_on=None,
93100
mem_kind=mem_kind)
94101

95102
new_kernel = kernel.copy(instructions=[*kernel.instructions, barrier_to_add])
96-
if insn_after is not None:
97-
new_kernel = add_dependency(new_kernel,
98-
insn_match=insn_after,
99-
depends_on="id:"+id)
103+
new_kernel = add_dependency(
104+
new_kernel, insn_match=insn_after, depends_on="id:" + id
105+
)
100106

101107
return new_kernel
102108

0 commit comments

Comments
 (0)