Skip to content

Commit fe827e4

Browse files
Create WithTag
1 parent 8a2b06b commit fe827e4

File tree

4 files changed

+87
-2
lines changed

4 files changed

+87
-2
lines changed

loopy/kernel/creation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,6 +1348,9 @@ def __init__(self, add_assignment):
13481348
self.expr_to_var = {}
13491349
super().__init__()
13501350

1351+
def map_with_tag(self, expr, additional_inames):
1352+
return super().map_with_tag(expr)
1353+
13511354
def map_reduction(self, expr, additional_inames):
13521355
additional_inames = additional_inames | frozenset(expr.inames)
13531356

loopy/statistics.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,10 +636,14 @@ class Op(ImmutableRecord):
636636
637637
A :class:`str` representing the kernel name where the operation occurred.
638638
639+
.. attribute:: tags
640+
641+
A :class:`frozenset` of tags to the operation.
642+
639643
"""
640644

641645
def __init__(self, dtype=None, name=None, count_granularity=None,
642-
kernel_name=None):
646+
kernel_name=None, tags=None):
643647
if count_granularity not in CountGranularity.ALL+[None]:
644648
raise ValueError("Op.__init__: count_granularity '%s' is "
645649
"not allowed. count_granularity options: %s"
@@ -651,7 +655,8 @@ def __init__(self, dtype=None, name=None, count_granularity=None,
651655

652656
super().__init__(dtype=dtype, name=name,
653657
count_granularity=count_granularity,
654-
kernel_name=kernel_name)
658+
kernel_name=kernel_name,
659+
tags=tags)
655660

656661
def __repr__(self):
657662
# Record.__repr__ overridden for consistent ordering and conciseness

loopy/symbolic.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@
114114
# {{{ mappers with support for loopy-specific primitives
115115

116116
class IdentityMapperMixin:
117+
def map_with_tag(self, expr, *args, **kwargs):
118+
new_expr = self.rec(expr.expr, *args, **kwargs)
119+
return WithTag(expr.tags, new_expr)
120+
117121
def map_literal(self, expr, *args, **kwargs):
118122
return expr
119123

@@ -207,6 +211,12 @@ def map_common_subexpression_uncached(self, expr):
207211

208212

209213
class WalkMapperMixin:
214+
def map_with_tag(self, expr, *args, **kwargs):
215+
if not self.visit(expr, *args, **kwargs):
216+
return
217+
218+
self.rec(expr.expr, *args, **kwargs)
219+
210220
def map_literal(self, expr, *args, **kwargs):
211221
self.visit(expr, *args, **kwargs)
212222

@@ -273,6 +283,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase):
273283

274284

275285
class CombineMapper(CombineMapperBase):
286+
def map_with_tag(self, expr, *args, **kwargs):
287+
return self.rec(expr.expr, *args, **kwargs)
288+
276289
def map_reduction(self, expr, *args, **kwargs):
277290
return self.rec(expr.expr, *args, **kwargs)
278291

@@ -298,6 +311,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase,
298311

299312

300313
class StringifyMapper(StringifyMapperBase):
314+
def map_with_tag(self, expr, *args):
315+
from pymbolic.mapper.stringifier import PREC_NONE
316+
return f"WithTag({expr.tags}, {self.rec(expr.expr, PREC_NONE)}"
317+
301318
def map_literal(self, expr, *args):
302319
return expr.s
303320

@@ -440,6 +457,10 @@ def map_tagged_variable(self, expr, *args, **kwargs):
440457
def map_loopy_function_identifier(self, expr, *args, **kwargs):
441458
return set()
442459

460+
def map_with_tag(self, expr, *args, **kwargs):
461+
deps = self.rec(expr.expr, *args, **kwargs)
462+
return deps
463+
443464
def map_sub_array_ref(self, expr, *args, **kwargs):
444465
deps = self.rec(expr.subscript, *args, **kwargs)
445466
return deps - set(expr.swept_inames)
@@ -712,6 +733,31 @@ def copy(self, *, name=None, tags=None):
712733
mapper_method = intern("map_tagged_variable")
713734

714735

736+
class WithTag(LoopyExpressionBase):
737+
"""
738+
Represents a frozenset of tags attached to an :attr:`expr`.
739+
"""
740+
741+
init_arg_names = ("tags", "expr")
742+
743+
def __init__(self, tags, expr):
744+
self.tags = tags
745+
self.expr = expr
746+
747+
def __getinitargs__(self):
748+
return (self.tags, self.expr)
749+
750+
def get_hash(self):
751+
return hash((self.__class__, self.tags, self.expr))
752+
753+
def is_equal(self, other):
754+
return (other.__class__ == self.__class__
755+
and other.tags == self.tags
756+
and other.expr == self.expr)
757+
758+
mapper_method = intern("map_with_tag")
759+
760+
715761
class Reduction(LoopyExpressionBase):
716762
"""
717763
Represents a reduction operation on :attr:`expr` across :attr:`inames`.

test/test_statistics.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,37 @@ def test_no_loop_ops():
15311531
assert f64_mul == 1
15321532

15331533

1534+
from pytools.tag import Tag
1535+
1536+
1537+
class MyCostTag(Tag):
1538+
pass
1539+
1540+
1541+
class MyCostTag2(Tag):
1542+
pass
1543+
1544+
1545+
def test_op_with_tag():
1546+
from loopy.symbolic import WithTag
1547+
from pymbolic.primitives import Subscript, Variable, Sum
1548+
1549+
knl = lp.make_kernel(
1550+
"{[i]: 0<=i<n}",
1551+
[
1552+
lp.Assignment("c[i]",
1553+
Sum(
1554+
(WithTag(frozenset((MyCostTag(),)),
1555+
Subscript(Variable("a"), Variable("i"))),
1556+
WithTag(frozenset((MyCostTag2(),)),
1557+
Subscript(Variable("b"), Variable("i"))))))
1558+
])
1559+
1560+
knl = lp.add_dtypes(knl, {"a": np.float64, "b": np.float64})
1561+
1562+
_op_map = lp.get_op_map(knl, subgroup_size=32)
1563+
1564+
15341565
if __name__ == "__main__":
15351566
if len(sys.argv) > 1:
15361567
exec(sys.argv[1])

0 commit comments

Comments
 (0)