Skip to content

Commit 58b1abc

Browse files
committed
Add transformations specific to sum-reduction
1 parent 7a72ba9 commit 58b1abc

File tree

3 files changed

+305
-0
lines changed

3 files changed

+305
-0
lines changed

doc/ref_transform.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ Manipulating Instructions
8080

8181
.. autofunction:: add_barrier
8282

83+
Manipulating Reductions
84+
-----------------------
85+
86+
.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction
87+
88+
.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
89+
8390
Registering Library Routines
8491
----------------------------
8592

loopy/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@
120120
from loopy.transform.parameter import assume, fix_parameters
121121
from loopy.transform.save import save_and_reload_temporaries
122122
from loopy.transform.add_barrier import add_barrier
123+
from loopy.transform.reduction import (
124+
hoist_invariant_multiplicative_terms_in_sum_reduction,
125+
extract_multiplicative_terms_in_sum_reduction_as_subst)
123126
from loopy.transform.callable import (register_callable,
124127
merge, inline_callable_kernel, rename_callable)
125128
from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call
@@ -247,6 +250,9 @@
247250

248251
"add_barrier",
249252

253+
"hoist_invariant_multiplicative_terms_in_sum_reduction",
254+
"extract_multiplicative_terms_in_sum_reduction_as_subst",
255+
250256
"register_callable",
251257
"merge",
252258

loopy/transform/reduction.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
"""
2+
.. currentmodule:: loopy
3+
4+
.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction
5+
6+
.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
7+
"""
8+
9+
__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
10+
11+
__license__ = """
12+
Permission is hereby granted, free of charge, to any person obtaining a copy
13+
of this software and associated documentation files (the "Software"), to deal
14+
in the Software without restriction, including without limitation the rights
15+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16+
copies of the Software, and to permit persons to whom the Software is
17+
furnished to do so, subject to the following conditions:
18+
19+
The above copyright notice and this permission notice shall be included in
20+
all copies or substantial portions of the Software.
21+
22+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
28+
THE SOFTWARE.
29+
"""
30+
31+
import pymbolic.primitives as p
32+
33+
from typing import (FrozenSet, TypeVar, Callable, List, Tuple, Iterable, Union, Any,
34+
Optional, Sequence)
35+
from loopy.symbolic import IdentityMapper, Reduction, CombineMapper
36+
from loopy.kernel import LoopKernel
37+
from loopy.kernel.data import SubstitutionRule
38+
from loopy.diagnostic import LoopyError
39+
40+
41+
# {{{ partition (copied from more-itertools)
42+
43+
Tpart = TypeVar("Tpart")
44+
45+
46+
def partition(pred: Callable[[Tpart], bool],
47+
iterable: Iterable[Tpart]) -> Tuple[List[Tpart],
48+
List[Tpart]]:
49+
"""
50+
Use a predicate to partition entries into false entries and true
51+
entries
52+
"""
53+
# Inspired from https://docs.python.org/3/library/itertools.html
54+
# partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
55+
from itertools import tee, filterfalse
56+
t1, t2 = tee(iterable)
57+
return list(filterfalse(pred, t1)), list(filter(pred, t2))
58+
59+
# }}}
60+
61+
62+
# {{{ hoist_reduction_invariant_terms
63+
64+
class EinsumTermsHoister(IdentityMapper):
65+
"""
66+
Mapper to hoist products out of a sum-reduction.
67+
68+
.. attribute:: reduction_inames
69+
70+
Inames of the reduction expressions to perform the hoisting.
71+
"""
72+
def __init__(self, reduction_inames: FrozenSet[str]):
73+
super().__init__()
74+
self.reduction_inames = reduction_inames
75+
76+
# type-ignore-reason: super-class.map_reduction returns 'Any'
77+
def map_reduction(self, expr: Reduction # type: ignore[override]
78+
) -> p.Expression:
79+
if frozenset(expr.inames) != self.reduction_inames:
80+
return super().map_reduction(expr)
81+
82+
from loopy.library.reduction import SumReductionOperation
83+
from loopy.symbolic import get_dependencies
84+
if isinstance(expr.operation, SumReductionOperation):
85+
if isinstance(expr.expr, p.Product):
86+
from pymbolic.primitives import flattened_product
87+
multiplicative_terms = (flattened_product(self.rec(expr.expr)
88+
.children)
89+
.children)
90+
else:
91+
multiplicative_terms = (expr.expr,)
92+
93+
invariants, variants = partition(lambda x: (get_dependencies(x)
94+
& self.reduction_inames),
95+
multiplicative_terms)
96+
if not variants:
97+
# -> everything is invariant
98+
return self.rec(expr.expr) * Reduction(
99+
expr.operation,
100+
inames=expr.inames,
101+
expr=1, # FIXME: invalid dtype (not sure how?)
102+
allow_simultaneous=expr.allow_simultaneous)
103+
if not invariants:
104+
# -> nothing to hoist
105+
return Reduction(
106+
expr.operation,
107+
inames=expr.inames,
108+
expr=self.rec(expr.expr),
109+
allow_simultaneous=expr.allow_simultaneous)
110+
111+
return p.Product(tuple(invariants)) * Reduction(
112+
expr.operation,
113+
inames=expr.inames,
114+
expr=p.Product(tuple(variants)),
115+
allow_simultaneous=expr.allow_simultaneous)
116+
else:
117+
return super().map_reduction(expr)
118+
119+
120+
def hoist_invariant_multiplicative_terms_in_sum_reduction(
121+
kernel: LoopKernel,
122+
reduction_inames: Union[str, FrozenSet[str]],
123+
within: Any = None
124+
) -> LoopKernel:
125+
"""
126+
Hoists loop-invariant multiplicative terms in a sum-reduction expression.
127+
128+
:arg reduction_inames: The inames over which reduction is performed that defines
129+
the reduction expression that is to be transformed.
130+
:arg within: A match expression understood by :func:`loopy.match.parse_match`
131+
that specifies the instructions over which the transformation is to be
132+
performed.
133+
"""
134+
from loopy.transform.instruction import map_instructions
135+
if isinstance(reduction_inames, str):
136+
reduction_inames = frozenset([reduction_inames])
137+
138+
if not (reduction_inames <= kernel.all_inames()):
139+
raise ValueError(f"Some inames in '{reduction_inames}' not a part of"
140+
" the kernel.")
141+
142+
term_hoister = EinsumTermsHoister(reduction_inames)
143+
144+
return map_instructions(kernel,
145+
insn_match=within,
146+
f=lambda x: x.with_transformed_expressions(term_hoister)
147+
)
148+
149+
# }}}
150+
151+
152+
# {{{ extract_multiplicative_terms_in_sum_reduction_as_subst
153+
154+
class ContainsSumReduction(CombineMapper):
155+
"""
156+
Returns *True* only if the mapper maps over an expression containing a
157+
SumReduction operation.
158+
"""
159+
def combine(self, values: Iterable[bool]) -> bool:
160+
return any(values)
161+
162+
# type-ignore-reason: super-class.map_reduction returns 'Any'
163+
def map_reduction(self, expr: Reduction) -> bool: # type: ignore[override]
164+
from loopy.library.reduction import SumReductionOperation
165+
return (isinstance(expr.operation, SumReductionOperation)
166+
or self.rec(expr.expr))
167+
168+
def map_variable(self, expr: p.Variable) -> bool:
169+
return False
170+
171+
def map_algebraic_leaf(self, expr: Any) -> bool:
172+
return False
173+
174+
175+
class MultiplicativeTermReplacer(IdentityMapper):
176+
"""
177+
Primary mapper of
178+
:func:`extract_multiplicative_terms_in_sum_reduction_as_subst`.
179+
"""
180+
def __init__(self,
181+
*,
182+
terms_filter: Callable[[p.Expression], bool],
183+
subst_name: str,
184+
subst_arguments: Tuple[str, ...]) -> None:
185+
self.subst_name = subst_name
186+
self.subst_arguments = subst_arguments
187+
self.terms_filter = terms_filter
188+
super().__init__()
189+
190+
# mutable state to record the expression collected by the terms_filter
191+
self.collected_subst_rule: Optional[SubstitutionRule] = None
192+
193+
# type-ignore-reason: super-class.map_reduction returns 'Any'
194+
def map_reduction(self, expr: Reduction) -> Reduction: # type: ignore[override]
195+
from loopy.library.reduction import SumReductionOperation
196+
from loopy.symbolic import SubstitutionMapper
197+
if isinstance(expr.operation, SumReductionOperation):
198+
if self.collected_subst_rule is not None:
199+
# => there was already a sum-reduction operation -> raise
200+
raise ValueError("Multiple sum reduction expressions found -> not"
201+
" allowed.")
202+
203+
if isinstance(expr.expr, p.Product):
204+
from pymbolic.primitives import flattened_product
205+
terms = flattened_product(expr.expr.children).children
206+
else:
207+
terms = (expr.expr,)
208+
209+
unfiltered_terms, filtered_terms = partition(self.terms_filter, terms)
210+
submap = SubstitutionMapper({
211+
argument_expr: p.Variable(f"arg{i}")
212+
for i, argument_expr in enumerate(self.subst_arguments)}.get)
213+
self.collected_subst_rule = SubstitutionRule(
214+
name=self.subst_name,
215+
arguments=tuple(f"arg{i}" for i in range(len(self.subst_arguments))),
216+
expression=submap(p.Product(tuple(filtered_terms))
217+
if filtered_terms
218+
else 1)
219+
)
220+
return Reduction(
221+
expr.operation,
222+
expr.inames,
223+
p.Product((p.Variable(self.subst_name)(*self.subst_arguments),
224+
*unfiltered_terms)),
225+
expr.allow_simultaneous)
226+
else:
227+
return super().map_reduction(expr)
228+
229+
230+
def extract_multiplicative_terms_in_sum_reduction_as_subst(
231+
kernel: LoopKernel,
232+
within: Any,
233+
subst_name: str,
234+
arguments: Sequence[p.Expression],
235+
terms_filter: Callable[[p.Expression], bool],
236+
) -> LoopKernel:
237+
"""
238+
Returns a copy of *kernel* with a new substitution named *subst_name* and
239+
*arguments* as arguments for the aggregated multiplicative terms in a
240+
sum-reduction expression.
241+
242+
:arg within: A match expression understood by :func:`loopy.match.parse_match`
243+
to specify the instructions over which the transformation is to be
244+
performed.
245+
:arg terms_filter: A callable to filter which terms of the sum-reduction
246+
comprise the body of substitution rule.
247+
:arg arguments: The sub-expressions of the product of the filtered terms that
248+
form the arguments of the extract substitution rule in the same order.
249+
250+
.. note::
251+
252+
A ``LoopyError`` is raised if none or more than 1 sum-reduction expression
253+
appear in *within*.
254+
"""
255+
from loopy.match import parse_match
256+
within = parse_match(within)
257+
258+
matched_insns = [
259+
insn
260+
for insn in kernel.instructions
261+
if within(kernel, insn) and ContainsSumReduction()((insn.expression,
262+
tuple(insn.predicates)))
263+
]
264+
265+
if len(matched_insns) == 0:
266+
raise LoopyError(f"No instructions found matching '{within}'"
267+
" with sum-reductions found.")
268+
if len(matched_insns) > 1:
269+
raise LoopyError(f"More than one instruction found matching '{within}'"
270+
" with sum-reductions found -> not allowed.")
271+
272+
insn, = matched_insns
273+
replacer = MultiplicativeTermReplacer(subst_name=subst_name,
274+
subst_arguments=tuple(arguments),
275+
terms_filter=terms_filter)
276+
new_insn = insn.with_transformed_expressions(replacer)
277+
new_rule = replacer.collected_subst_rule
278+
new_substitutions = dict(kernel.substitutions).copy()
279+
if subst_name in new_substitutions:
280+
raise LoopyError(f"Kernel '{kernel.name}' already contains a substitution"
281+
" rule named '{subst_name}'.")
282+
assert new_rule is not None
283+
new_substitutions[subst_name] = new_rule
284+
285+
return kernel.copy(instructions=[new_insn if insn.id == new_insn.id else insn
286+
for insn in kernel.instructions],
287+
substitutions=new_substitutions)
288+
289+
# }}}
290+
291+
292+
# vim: foldmethod=marker

0 commit comments

Comments
 (0)