Skip to content

Commit 9233385

Browse files
add WithTag
1 parent 5954415 commit 9233385

File tree

4 files changed

+123
-7
lines changed

4 files changed

+123
-7
lines changed

doc/tutorial.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,7 +1566,7 @@ information provided. Now we will count the operations:
15661566

15671567
>>> op_map = lp.get_op_map(knl, subgroup_size=32)
15681568
>>> print(op_map)
1569-
Op(np:dtype('float32'), add, subgroup, "stats_knl"): ...
1569+
Op(np:dtype('float32'), add, subgroup, "stats_knl", None): ...
15701570

15711571
Each line of output will look roughly like::
15721572

@@ -1628,7 +1628,7 @@ together into keys containing only the specified fields:
16281628

16291629
>>> op_map_dtype = op_map.group_by('dtype')
16301630
>>> print(op_map_dtype)
1631-
Op(np:dtype('float32'), None, None): ...
1631+
Op(np:dtype('float32'), None, None, None): ...
16321632
<BLANKLINE>
16331633
>>> f32op_count = op_map_dtype[lp.Op(dtype=np.float32)
16341634
... ].eval_with_dict(param_dict)

loopy/statistics.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -636,10 +636,14 @@ class Op(ImmutableRecord):
636636
637637
A :class:`str` representing the kernel name where the operation occurred.
638638
639+
.. attribute:: tags
640+
641+
A :class:`frozenset` of tags attached to the operation.
642+
639643
"""
640644

641645
def __init__(self, dtype=None, name=None, count_granularity=None,
642-
kernel_name=None):
646+
kernel_name=None, tags=None):
643647
if count_granularity not in CountGranularity.ALL+[None]:
644648
raise ValueError("Op.__init__: count_granularity '%s' is "
645649
"not allowed. count_granularity options: %s"
@@ -651,15 +655,17 @@ def __init__(self, dtype=None, name=None, count_granularity=None,
651655

652656
super().__init__(dtype=dtype, name=name,
653657
count_granularity=count_granularity,
654-
kernel_name=kernel_name)
658+
kernel_name=kernel_name,
659+
tags=tags)
655660

656661
def __repr__(self):
657662
# Record.__repr__ overridden for consistent ordering and conciseness
658663
if self.kernel_name is not None:
659664
return (f"Op({self.dtype}, {self.name}, {self.count_granularity},"
660-
f' "{self.kernel_name}")')
665+
f' "{self.kernel_name}", {self.tags})')
661666
else:
662-
return f"Op({self.dtype}, {self.name}, {self.count_granularity})"
667+
return f"Op({self.dtype}, {self.name}, " + \
668+
f"{self.count_granularity}, {self.tags})"
663669

664670
# }}}
665671

@@ -724,7 +730,7 @@ class MemAccess(ImmutableRecord):
724730
work-group executes on a single compute unit with all work-items within
725731
the work-group sharing local memory. A sub-group is an
726732
implementation-dependent grouping of work-items within a work-group,
727-
analagous to an NVIDIA CUDA warp.
733+
analogous to an NVIDIA CUDA warp.
728734
729735
.. attribute:: kernel_name
730736
@@ -922,6 +928,12 @@ def map_constant(self, expr):
922928
map_tagged_variable = map_constant
923929
map_variable = map_constant
924930

931+
def map_with_tag(self, expr):
    """Count the operations of a tagged expression and label them.

    The wrapped expression ``expr.expr`` is counted via ``self.rec``;
    afterwards every op found in the resulting ``count_map`` has its
    ``tags`` attribute set to ``expr.tags``. Any tags already present on
    those ops are overwritten.

    :arg expr: a node exposing ``tags`` and ``expr`` attributes.
    :returns: the same map object that ``self.rec`` produced.
    """
    counted = self.rec(expr.expr)
    # NOTE(review): this assigns to the objects iterated from count_map
    # in place -- confirm that is safe if they are used as hashed keys.
    for counted_op in counted.count_map:
        counted_op.tags = expr.tags
    return counted
936+
925937
def map_call(self, expr):
926938
from loopy.symbolic import ResolvedFunction
927939
assert isinstance(expr.function, ResolvedFunction)

loopy/symbolic.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@
114114
# {{{ mappers with support for loopy-specific primitives
115115

116116
class IdentityMapperMixin:
117+
def map_with_tag(self, expr, *args, **kwargs):
    """Rebuild a ``WithTag`` node around the mapped wrapped expression.

    The wrapped expression is mapped with ``self.rec`` (extra arguments
    are forwarded unchanged) and re-wrapped with the original tags.

    :returns: a new :class:`WithTag` carrying ``expr.tags``.
    """
    mapped_child = self.rec(expr.expr, *args, **kwargs)
    return WithTag(expr.tags, mapped_child)
120+
117121
def map_literal(self, expr, *args, **kwargs):
118122
return expr
119123

@@ -207,6 +211,12 @@ def map_common_subexpression_uncached(self, expr):
207211

208212

209213
class WalkMapperMixin:
214+
def map_with_tag(self, expr, *args, **kwargs):
    """Visit a ``WithTag`` node and, if allowed, walk into its expression.

    ``self.visit`` is called on the node itself; the wrapped expression
    ``expr.expr`` is only traversed when that call returns a true value.
    Extra arguments are forwarded to both calls. Nothing is returned.
    """
    if self.visit(expr, *args, **kwargs):
        self.rec(expr.expr, *args, **kwargs)
219+
210220
def map_literal(self, expr, *args, **kwargs):
211221
self.visit(expr, *args, **kwargs)
212222

@@ -273,6 +283,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase):
273283

274284

275285
class CombineMapper(CombineMapperBase):
286+
def map_with_tag(self, expr, *args, **kwargs):
    """Combine through a ``WithTag`` node via its wrapped expression.

    The tags do not contribute to the result: whatever ``self.rec``
    yields for ``expr.expr`` (with extra arguments forwarded) is returned.
    """
    wrapped = expr.expr
    return self.rec(wrapped, *args, **kwargs)
288+
276289
def map_reduction(self, expr, *args, **kwargs):
277290
return self.rec(expr.expr, *args, **kwargs)
278291

@@ -298,6 +311,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase,
298311

299312

300313
class StringifyMapper(StringifyMapperBase):
314+
def map_with_tag(self, expr, *args):
    """Render a ``WithTag`` node as ``WithTag(<tags>, <expr>)``.

    :arg expr: the tagged node; its ``tags`` and ``expr`` attributes are
        read.
    :arg args: accepted for signature compatibility and otherwise
        ignored; the wrapped expression is always rendered with
        ``PREC_NONE``.
    :returns: a :class:`str` with balanced parentheses.
    """
    from pymbolic.mapper.stringifier import PREC_NONE

    # The closing parenthesis used to be missing here, which produced
    # unbalanced output of the form "WithTag(tags, expr".
    return f"WithTag({expr.tags}, {self.rec(expr.expr, PREC_NONE)})"
317+
301318
def map_literal(self, expr, *args):
302319
return expr.s
303320

@@ -440,6 +457,10 @@ def map_tagged_variable(self, expr, *args, **kwargs):
440457
def map_loopy_function_identifier(self, expr, *args, **kwargs):
441458
return set()
442459

460+
def map_with_tag(self, expr, *args, **kwargs):
    """Return the result of mapping the expression wrapped by ``expr``.

    The tags are not inspected; the value of ``self.rec`` on
    ``expr.expr`` (extra arguments forwarded) is passed through as-is.
    """
    return self.rec(expr.expr, *args, **kwargs)
463+
443464
def map_sub_array_ref(self, expr, *args, **kwargs):
444465
deps = self.rec(expr.subscript, *args, **kwargs)
445466
return deps - set(expr.swept_inames)
@@ -712,6 +733,31 @@ def copy(self, *, name=None, tags=None):
712733
mapper_method = intern("map_tagged_variable")
713734

714735

736+
class WithTag(LoopyExpressionBase):
    """
    An expression node that attaches a frozenset of tags to :attr:`expr`.

    .. attribute:: tags

        The tags attached to the wrapped expression.

    .. attribute:: expr

        The wrapped expression.
    """

    init_arg_names = ("tags", "expr")
    mapper_method = intern("map_with_tag")

    def __init__(self, tags, expr):
        self.expr = expr
        self.tags = tags

    def __getinitargs__(self):
        # Same order as init_arg_names.
        return self.tags, self.expr

    def get_hash(self):
        # The class is part of the key, alongside both attributes.
        key = (self.__class__, self.tags, self.expr)
        return hash(key)

    def is_equal(self, other):
        # Nodes of a different class never compare equal.
        if other.__class__ != self.__class__:
            return False
        return other.tags == self.tags and other.expr == self.expr
759+
760+
715761
class Reduction(LoopyExpressionBase):
716762
"""
717763
Represents a reduction operation on :attr:`expr` across :attr:`inames`.

test/test_statistics.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,64 @@ def test_no_loop_ops():
15311531
assert f64_mul == 1
15321532

15331533

1534+
from pytools.tag import Tag
1535+
1536+
1537+
class MyCostTag1(Tag):
    """Marker tag that ``test_op_with_tag`` attaches to the ``a[i]`` operand."""
1539+
1540+
1541+
class MyCostTag2(Tag):
    """Marker tag that ``test_op_with_tag`` attaches to the ``b[i]`` operand."""
1543+
1544+
1545+
class MyCostTagSum(Tag):
    """Marker tag that ``test_op_with_tag`` attaches to the enclosing sum."""
1547+
1548+
1549+
def test_op_with_tag():
    """Check how counted ops are attributed to ``WithTag`` tags.

    The kernel computes ``c[i] = a[i] + b[i]`` where the sum and each
    operand carry their own tag set. The float64 ops must all be found
    under the sum's tag set, and none under either operand's tag set or
    under a combined tag set.
    """
    from loopy.symbolic import WithTag
    from pymbolic.primitives import Subscript, Variable, Sum

    n = 500

    tagged_a = WithTag(
        frozenset((MyCostTag1(),)),
        Subscript(Variable("a"), Variable("i")))
    tagged_b = WithTag(
        frozenset((MyCostTag2(),)),
        Subscript(Variable("b"), Variable("i")))
    tagged_sum = WithTag(
        frozenset((MyCostTagSum(),)),
        Sum((tagged_a, tagged_b)))

    knl = lp.make_kernel(
        "{[i]: 0<=i<n}",
        [lp.Assignment("c[i]", tagged_sum)])
    knl = lp.add_dtypes(knl, {"a": np.float64, "b": np.float64})

    op_map = lp.get_op_map(knl, subgroup_size=32)
    params = {"n": n}

    def count_ops(**filters):
        # Evaluate the filtered op count for the concrete problem size.
        return op_map.filter_by(**filters).eval_and_sum(params)

    f64_add = count_ops(dtype=[np.float64])
    assert f64_add == n

    # Fresh tag instances are built here on purpose, as in the kernel above.
    f64_add = count_ops(tags=[frozenset((MyCostTagSum(),))])
    assert f64_add == n

    unmatched_tag_sets = (
        frozenset((MyCostTag1(),)),
        frozenset((MyCostTag2(),)),
        frozenset((MyCostTag2(), MyCostTagSum())),
    )
    for tag_set in unmatched_tag_sets:
        f64_add = count_ops(tags=[tag_set])
        assert f64_add == 0
1590+
1591+
15341592
if __name__ == "__main__":
15351593
if len(sys.argv) > 1:
15361594
exec(sys.argv[1])

0 commit comments

Comments
 (0)