Skip to content

Commit cb46068

Browse files
kaushikcfdinducer
authored andcommitted
Type lp.map_instructions.
1 parent d03e69a commit cb46068

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

loopy/transform/instruction.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040

4141
if TYPE_CHECKING:
42-
from collections.abc import Mapping, Sequence
42+
from collections.abc import Callable, Mapping, Sequence
4343

4444
from pymbolic import ArithmeticExpression
4545
from pymbolic.primitives import Subscript
@@ -84,7 +84,8 @@ def find_instructions(
8484

8585
# {{{ map_instructions
8686

87-
def map_instructions(kernel, insn_match, f):
87+
def map_instructions(kernel: LoopKernel, insn_match: ToMatchConvertible,
88+
f: Callable[[InstructionBase], InstructionBase]) -> LoopKernel:
8889
from loopy.match import parse_match
8990
match = parse_match(insn_match)
9091

@@ -104,14 +105,16 @@ def map_instructions(kernel, insn_match, f):
104105
# {{{ set_instruction_priority
105106

106107
@for_each_kernel
107-
def set_instruction_priority(kernel, insn_match, priority):
108+
def set_instruction_priority(
109+
kernel: LoopKernel, insn_match: ToMatchConvertible, priority: int
110+
) -> LoopKernel:
108111
"""Set the priority of instructions matching *insn_match* to *priority*.
109112
110113
*insn_match* may be any instruction id match understood by
111114
:func:`loopy.match.parse_match`.
112115
"""
113116

114-
def set_prio(insn):
117+
def set_prio(insn: InstructionBase):
115118
return insn.copy(priority=priority)
116119

117120
return map_instructions(kernel, insn_match, set_prio)
@@ -122,7 +125,9 @@ def set_prio(insn):
122125
# {{{ add_dependency
123126

124127
@for_each_kernel
125-
def add_dependency(kernel, insn_match, depends_on):
128+
def add_dependency(
129+
kernel: LoopKernel, insn_match: ToMatchConvertible, depends_on: ToMatchConvertible
130+
) -> LoopKernel:
126131
"""Add the instruction dependency *dependency* to the instructions matched
127132
by *insn_match*.
128133
@@ -148,7 +153,7 @@ def add_dependency(kernel, insn_match, depends_on):
148153

149154
matched = [False]
150155

151-
def add_dep(insn):
156+
def add_dep(insn: InstructionBase):
152157
new_deps = insn.depends_on
153158
matched[0] = True
154159
new_deps = added_deps if new_deps is None else new_deps | added_deps

0 commit comments

Comments
 (0)