Skip to content

Commit e974c53

Browse files
committed
Add transformations specific to sum-reduction
1 parent 7a72ba9 commit e974c53

File tree

3 files changed

+283
-0
lines changed

3 files changed

+283
-0
lines changed

doc/ref_transform.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ Manipulating Instructions
8080

8181
.. autofunction:: add_barrier
8282

83+
Manipulating Reductions
84+
-----------------------
85+
86+
.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction
87+
88+
.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
89+
8390
Registering Library Routines
8491
----------------------------
8592

loopy/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@
120120
from loopy.transform.parameter import assume, fix_parameters
121121
from loopy.transform.save import save_and_reload_temporaries
122122
from loopy.transform.add_barrier import add_barrier
123+
from loopy.transform.reduction import (
124+
hoist_invariant_multiplicative_terms_in_sum_reduction,
125+
extract_multiplicative_terms_in_sum_reduction_as_subst)
123126
from loopy.transform.callable import (register_callable,
124127
merge, inline_callable_kernel, rename_callable)
125128
from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call
@@ -247,6 +250,9 @@
247250

248251
"add_barrier",
249252

253+
"hoist_invariant_multiplicative_terms_in_sum_reduction",
254+
"extract_multiplicative_terms_in_sum_reduction_as_subst",
255+
250256
"register_callable",
251257
"merge",
252258

loopy/transform/reduction.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""
2+
.. currentmodule:: loopy
3+
4+
.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction
5+
6+
.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
7+
"""
8+
9+
__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
10+
11+
__license__ = """
12+
Permission is hereby granted, free of charge, to any person obtaining a copy
13+
of this software and associated documentation files (the "Software"), to deal
14+
in the Software without restriction, including without limitation the rights
15+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16+
copies of the Software, and to permit persons to whom the Software is
17+
furnished to do so, subject to the following conditions:
18+
19+
The above copyright notice and this permission notice shall be included in
20+
all copies or substantial portions of the Software.
21+
22+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
28+
THE SOFTWARE.
29+
"""
30+
31+
import pymbolic.primitives as p
32+
33+
from typing import (FrozenSet, TypeVar, Callable, List, Tuple, Iterable, Union, Any,
34+
Optional, Sequence)
35+
from loopy.symbolic import IdentityMapper, Reduction, CombineMapper
36+
from loopy.kernel import LoopKernel
37+
from loopy.kernel.data import SubstitutionRule
38+
from loopy.diagnostic import LoopyError
39+
40+
41+
# {{{ partition (copied from more-itertools)
42+
43+
Tpart = TypeVar("Tpart")


def partition(pred: Callable[[Tpart], bool],
              iterable: Iterable[Tpart]) -> Tuple[List[Tpart],
                                                  List[Tpart]]:
    """
    Use a predicate to partition entries into false entries and true
    entries.

    :arg pred: A callable evaluated exactly once per entry of *iterable*; its
        result is interpreted by truthiness.
    :arg iterable: The entries to be partitioned. It is traversed once.
    :returns: A tuple ``(false_entries, true_entries)``, each preserving the
        order in which the entries appeared in *iterable*.
    """
    # Inspired from https://docs.python.org/3/library/itertools.html
    # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
    false_entries: List[Tpart] = []
    true_entries: List[Tpart] = []

    for entry in iterable:
        # A single evaluation of *pred* per entry guarantees that every entry
        # lands in exactly one of the two lists, even if *pred* is stateful.
        if pred(entry):
            true_entries.append(entry)
        else:
            false_entries.append(entry)

    return false_entries, true_entries
58+
59+
# }}}
60+
61+
62+
# {{{ hoist_reduction_invariant_terms
63+
64+
# type-ignore-reason: cannot subclass from IdentityMapper (inferred as type Any)
class EinsumTermsHoister(IdentityMapper):  # type: ignore[misc]
    """
    Mapper to hoist products out of a sum-reduction.

    Only reductions whose set of inames equals :attr:`reduction_inames` and
    whose operation is a sum-reduction are rewritten. Every other reduction
    is traversed by the base-class implementation.

    .. attribute:: reduction_inames

        Inames of the reduction expressions to perform the hoisting.
    """
    def __init__(self, reduction_inames: FrozenSet[str]):
        super().__init__()
        self.reduction_inames = reduction_inames

    def map_reduction(self, expr: Reduction) -> p.Expression:
        """
        Returns *expr* with the multiplicative terms of its body that do not
        depend on :attr:`reduction_inames` moved in front of the reduction.
        """
        from loopy.library.reduction import SumReductionOperation
        from loopy.symbolic import get_dependencies

        if (frozenset(expr.inames) != self.reduction_inames
                or not isinstance(expr.operation, SumReductionOperation)):
            # Either not the reduction we were asked to transform, or not a
            # sum: pulling a factor out is only done for sums, so leave the
            # node to the default traversal instead of failing.
            return super().map_reduction(expr)

        body = self.rec(expr.expr)

        if isinstance(body, p.Product):
            from pymbolic.primitives import flattened_product
            flat_body = flattened_product(body.children)
            # flattened_product may collapse to a non-product (e.g. when only
            # a single factor survives), hence the guard before '.children'.
            multiplicative_terms = (flat_body.children
                                    if isinstance(flat_body, p.Product)
                                    else (flat_body,))
        else:
            # a non-product body is a product of exactly one term
            multiplicative_terms = (body,)

        # partition() returns (predicate-false, predicate-true) entries, i.e.
        # (terms independent of the reduction inames, terms depending on them)
        invariants, variants = partition(
            lambda term: bool(get_dependencies(term) & self.reduction_inames),
            multiplicative_terms)

        if not invariants:
            # -> nothing to hoist
            return Reduction(
                expr.operation,
                inames=expr.inames,
                expr=body,
                allow_simultaneous=expr.allow_simultaneous)

        if not variants:
            # -> everything is invariant: what remains inside the reduction
            # is the constant 1.
            # NOTE(review): the literal 1 carries no dtype information;
            # confirm that downstream type inference accepts it.
            return body * Reduction(
                expr.operation,
                inames=expr.inames,
                expr=1,
                allow_simultaneous=expr.allow_simultaneous)

        return p.Product(tuple(invariants)) * Reduction(
            expr.operation,
            inames=expr.inames,
            expr=p.Product(tuple(variants)),
            allow_simultaneous=expr.allow_simultaneous)
99+
100+
101+
def hoist_invariant_multiplicative_terms_in_sum_reduction(
    kernel: LoopKernel,
    reduction_inames: Union[str, FrozenSet[str]],
    within: Any = None
) -> LoopKernel:
    """
    Hoists loop-invariant multiplicative terms in a sum-reduction expression.

    :arg reduction_inames: The iname(s) being reduced over. They identify
        the reduction expressions that are to be transformed. A single
        iname may be passed as a :class:`str`.
    :arg within: A match expression understood by :func:`loopy.match.parse_match`
        to specify the instructions over which the transformation is to be
        performed.
    :returns: The kernel obtained by applying the hoisting to the expressions
        of every matched instruction.
    :raises ValueError: If any of *reduction_inames* is not an iname of
        *kernel*.
    """
    from loopy.transform.instruction import map_instructions

    # normalize the single-iname shorthand to a set
    if isinstance(reduction_inames, str):
        reduction_inames = frozenset([reduction_inames])

    all_inames_known = reduction_inames <= kernel.all_inames()
    if not all_inames_known:
        raise ValueError(f"Some inames in '{reduction_inames}' not a part of"
                         " the kernel.")

    hoister = EinsumTermsHoister(reduction_inames)

    def hoist_in_insn(insn):
        # apply the hoisting mapper to each expression held by the instruction
        return insn.with_transformed_expressions(hoister)

    return map_instructions(kernel, insn_match=within, f=hoist_in_insn)
129+
130+
# }}}
131+
132+
133+
# {{{ extract_multiplicative_terms_in_sum_reduction_as_subst
134+
135+
class ContainsSumReduction(CombineMapper):
    """
    Returns *True* only if the mapper maps over an expression containing a
    SumReduction operation.
    """
    def combine(self, values: Iterable[bool]) -> bool:
        # an expression contains a sum-reduction if any sub-expression does
        return any(values)

    def map_reduction(self, expr: Reduction) -> bool:
        from loopy.library.reduction import SumReductionOperation
        # Either this node is itself a sum-reduction, or one is nested in the
        # body of a reduction of another kind.
        return (isinstance(expr.operation, SumReductionOperation)
                or self.rec(expr.expr))

    def map_constant(self, expr: Any) -> bool:
        # Constants are leaves that cannot hold a reduction. Defined
        # explicitly so that expressions such as ``2*sum(...)`` do not depend
        # on a base-class default for constants.
        # NOTE(review): presumably numeric literals are dispatched here by
        # the mapper base class -- confirm against pymbolic.
        return False

    def map_variable(self, expr: p.Variable) -> bool:
        return False

    def map_algebraic_leaf(self, expr: Any) -> bool:
        return False
153+
154+
155+
class MultiplicativeTermReplacer(IdentityMapper):
    """
    Primary mapper of
    :func:`extract_multiplicative_terms_in_sum_reduction_as_subst`.

    Replaces the multiplicative terms of a sum-reduction that are accepted
    by *terms_filter* with a call to the substitution rule *subst_name* and
    records the corresponding rule in :attr:`collected_subst_rule`.

    .. attribute:: collected_subst_rule

        *None* until a sum-reduction has been mapped, the extracted
        :class:`~loopy.kernel.data.SubstitutionRule` afterwards.
    """
    def __init__(self,
                 *,
                 terms_filter: Callable[[p.Expression], bool],
                 subst_name: str,
                 subst_arguments: Tuple[p.Expression, ...]) -> None:
        self.subst_name = subst_name
        self.subst_arguments = subst_arguments
        self.terms_filter = terms_filter
        super().__init__()

        # mutable state to record the expression collected by the terms_filter
        self.collected_subst_rule: Optional[SubstitutionRule] = None

    def map_reduction(self, expr: Reduction) -> Reduction:
        """
        :raises ValueError: If a sum-reduction has already been processed by
            this mapper instance.
        """
        from loopy.library.reduction import SumReductionOperation
        from loopy.symbolic import SubstitutionMapper

        if not isinstance(expr.operation, SumReductionOperation):
            return super().map_reduction(expr)

        if self.collected_subst_rule is not None:
            # => there was already a sum-reduction operation -> raise
            raise ValueError("Multiple sum reduction expressions found -> not"
                             " allowed.")

        if isinstance(expr.expr, p.Product):
            from pymbolic.primitives import flattened_product
            flat_body = flattened_product(expr.expr.children)
            # flattened_product may collapse to a non-product, hence the
            # guard before accessing '.children'.
            terms = (flat_body.children
                     if isinstance(flat_body, p.Product)
                     else (flat_body,))
        else:
            # a non-product body is a product of exactly one term
            terms = (expr.expr,)

        # partition() returns (predicate-false, predicate-true) entries
        unfiltered_terms, filtered_terms = partition(self.terms_filter, terms)

        # Inside the rule body, the i-th argument expression is referred to by
        # the formal parameter 'arg<i>'.
        formal_names = tuple(f"arg{i}"
                             for i in range(len(self.subst_arguments)))
        submap = SubstitutionMapper({
            argument_expr: p.Variable(formal_name)
            for formal_name, argument_expr in zip(formal_names,
                                                  self.subst_arguments)}.get)

        self.collected_subst_rule = SubstitutionRule(
            name=self.subst_name,
            arguments=formal_names,
            # no term accepted by the filter => the rule evaluates to 1
            expression=submap(p.Product(tuple(filtered_terms))
                              if filtered_terms
                              else 1)
        )

        return Reduction(
            expr.operation,
            inames=expr.inames,
            expr=p.Product((p.Variable(self.subst_name)(*self.subst_arguments),
                            *unfiltered_terms)),
            allow_simultaneous=expr.allow_simultaneous)
207+
208+
209+
def extract_multiplicative_terms_in_sum_reduction_as_subst(
    kernel: LoopKernel,
    within: Any,
    subst_name: str,
    arguments: Sequence[p.Expression],
    terms_filter: Callable[[p.Expression], bool],
) -> LoopKernel:
    """
    Returns a copy of *kernel* with a new substitution named *subst_name* and
    *arguments* as arguments for the aggregated multiplicative terms in a
    sum-reduction expression.

    :arg within: A match expression understood by :func:`loopy.match.parse_match`
        to specify the instructions over which the transformation is to be
        performed.
    :arg subst_name: Name of the substitution rule to be added to the kernel.
    :arg terms_filter: A callable to filter which terms of the sum-reduction
        comprise the body of substitution rule.
    :arg arguments: The sub-expressions of the product of the filtered terms that
        form the arguments of the extract substitution rule in the same order.

    .. note::

        A :class:`~loopy.diagnostic.LoopyError` is raised if *kernel* already
        has a substitution rule named *subst_name*, or if the number of
        instructions matched by *within* that contain a sum-reduction is not
        exactly 1. A :class:`ValueError` is raised if the matched instruction
        contains more than 1 sum-reduction expression.
    """
    from loopy.match import parse_match
    within = parse_match(within)

    # fail early, before any expression is transformed
    if subst_name in kernel.substitutions:
        raise LoopyError(f"Kernel '{kernel.name}' already contains a substitution"
                         f" rule named '{subst_name}'.")

    matched_insns = [
        insn
        for insn in kernel.instructions
        if within(kernel, insn) and ContainsSumReduction()((insn.expression,
                                                           tuple(insn.predicates)))
    ]

    if len(matched_insns) == 0:
        raise LoopyError(f"No instructions found matching '{within}'"
                         " with sum-reductions found.")
    if len(matched_insns) > 1:
        raise LoopyError(f"More than one instruction found matching '{within}'"
                         " with sum-reductions found -> not allowed.")

    insn, = matched_insns
    replacer = MultiplicativeTermReplacer(subst_name=subst_name,
                                          subst_arguments=tuple(arguments),
                                          terms_filter=terms_filter)
    new_insn = insn.with_transformed_expressions(replacer)

    # the replacer records the extracted rule as a side effect of the mapping
    new_substitutions = kernel.substitutions.copy()
    new_substitutions[subst_name] = replacer.collected_subst_rule

    # swap in the rewritten instruction, keeping the instruction order
    return kernel.copy(instructions=[new_insn
                                     if old_insn.id == new_insn.id
                                     else old_insn
                                     for old_insn in kernel.instructions],
                       substitutions=new_substitutions)
266+
267+
# }}}
268+
269+
270+
# vim: foldmethod=marker

0 commit comments

Comments
 (0)