Lower latency associative scan option #10599
Replies: 2 comments
Looks like a lot of people would be interested (myself included). Are you still interested in creating such a PR?
I've implemented the work-inefficient scan in lax.associative_scan and it passes all tests. However, on reflection, I don't think it would be accepted as a PR: I don't have a clear enough set of examples where it is much faster, and it leaves more for the JAX team to maintain. The standalone code below fully supports all lax.associative_scan inputs. It is a simpler implementation than the previous one and covers edge cases and pytree inputs.

from collections.abc import Callable
from jax import lax
from jax.tree import map as tree_map, flatten as tree_flatten, unflatten as tree_unflatten

slicing = lax  # alias so that slicing.slice_in_dim resolves to lax.slice_in_dim


def work_inefficient_associative_scan(
        fn: Callable, elems, reverse: bool = False, axis: int = 0):
    """Performs a scan with an associative binary operation, in parallel.

    Uses the work-inefficient implementation, which can be faster on small
    problem sizes.

    While both implementations take O(log N) steps, the smaller constant
    prefactor in the work-inefficient algorithm can make it faster on highly
    parallel hardware, at the cost of O(N log N) total work instead of O(N).

    For an introduction to associative scans, see [BLE1990]_.

    Args:
      fn: A Python callable implementing an associative binary operation with
        signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
        must satisfy the equation
        ``fn(a, fn(b, c)) == fn(fn(a, b), c)``.

        The inputs and result are (possibly nested Python tree structures of)
        array(s) matching ``elems``. Each array has a dimension in place
        of the ``axis`` dimension. `fn` should be applied elementwise over
        the ``axis`` dimension (for example, by using :func:`jax.vmap` over the
        elementwise function.)

        The result ``r`` has the same shape (and structure) as the two inputs
        ``a`` and ``b``.
      elems: A (possibly nested Python tree structure of) array(s), each with
        an ``axis`` dimension of size ``num_elems``.
      reverse: A boolean stating if the scan should be reversed with respect to
        the ``axis`` dimension.
      axis: an integer identifying the axis over which the scan should occur.

    Returns:
      A (possibly nested Python tree structure of) array(s) of the same shape
      and structure as ``elems``, in which the ``k``'th element of ``axis`` is the
      result of recursively applying ``fn`` to combine the first ``k`` elements
      of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``,
      the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.

      If ``elems = [..., x, y, z]`` and ``reverse`` is true, the result is
      ``[..., f(f(z, y), x), f(z, y), z]``.

    Example 1: partial sums of an array of numbers:

    >>> work_inefficient_associative_scan(jnp.add, jnp.arange(0, 4))
    Array([0, 1, 3, 6], dtype=int32)

    Example 2: partial products of an array of matrices

    >>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
    >>> partial_prods = work_inefficient_associative_scan(jnp.matmul, mats)
    >>> partial_prods.shape
    (4, 2, 2)

    Example 3: reversed partial sums of an array of numbers

    >>> work_inefficient_associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
    Array([6, 6, 5, 3], dtype=int32)

    .. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
      Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
      University.
    """
    if not callable(fn):
        raise TypeError("lax.associative_scan: fn argument should be callable.")
    elems_flat, tree = tree_flatten(elems)

    if reverse:
        elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat]

    def combine(a_flat, b_flat):
        # Lower `fn` to operate on flattened sequences of elems.
        a = tree_unflatten(tree, a_flat)
        b = tree_unflatten(tree, b_flat)
        c = fn(a, b)
        c_flat, _ = tree_flatten(c)
        return c_flat

    # Check that all inputs have a consistent leading dimension `num_elems`.
    if axis < 0:
        axis += elems_flat[0].ndim

    # if not core.is_constant_dim(elems_flat[0].shape[axis]):
    #   raise NotImplementedError("associative scan over axis "
    #       f"of non-constant size: {elems_flat[0].shape[axis]}. You may be "
    #       "able to avoid this on TPU. See b/274176030.")
    num_elems = int(elems_flat[0].shape[axis])
    if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
        raise ValueError('Array inputs to associative_scan must have the same '
                         'first dimension. (saw: {})'
                         .format([elem.shape for elem in elems_flat]))

    def _naive_scan(elems):
        # Helpers that slice or concatenate every flattened leaf along `axis`.
        _slice_to = lambda elems, i: [
            slicing.slice_in_dim(elem, 0, i, axis=axis) for elem in elems]
        _slice_from = lambda elems, i: [
            slicing.slice_in_dim(elem, i, None, axis=axis) for elem in elems]
        _length = lambda elems: elems[0].shape[axis]
        _concat = lambda *elems: tree_map(
            lambda *xs: lax.concatenate(xs, dimension=axis), *elems)

        if _length(elems) < 2:
            return elems

        w = 1
        l = _slice_to(elems, _length(elems) - w)
        r = _slice_from(elems, w)
        while (2 * w) < _length(elems):
            # Hillis, W. D. and Steele, G. L. (1986). Data parallel algorithms.
            # Communications of the ACM, 29(12), 1170–1183
            # log_2{n} steps
            # at the end of each loop l[:2*w] is fully computed
            updated_r = combine(l, r)
            r = _slice_from(updated_r, w)
            l = _concat(
                _slice_to(l, w),
                _slice_to(updated_r, max(_length(r) - w, w)))
            w *= 2

        # Final step: combine the finished prefixes with the remaining tail.
        updated_r = combine(
            _slice_to(l, _length(r)),
            r)
        return _concat(l, updated_r)

    scans = _naive_scan(elems_flat)

    if reverse:
        scans = [lax.rev(scanned, [axis]) for scanned in scans]

    return tree_unflatten(tree, scans)
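
A quick way to sanity-check a scan like this is to compare it against lax.associative_scan on a non-commutative operation such as the cumulative matmul mentioned in this thread. The sketch below is self-contained, so it does not call the function above; instead it uses a stripped-down, single-array Hillis-Steele scan along axis 0. The helper name `hillis_steele_scan`, the input shape `(7, 2, 2)` and the tolerances are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp
from jax import lax


def hillis_steele_scan(fn, xs):
    """Inclusive scan of one array along axis 0 (hypothetical helper)."""
    n = xs.shape[0]
    shift = 1
    while shift < n:
        # Combine each element with the one `shift` positions to its left.
        # The first `shift` entries are already final and pass through.
        combined = fn(xs[:-shift], xs[shift:])
        xs = jnp.concatenate([xs[:shift], combined], axis=0)
        shift *= 2
    return xs


# Seven 2x2 matrices: a length that is not a power of two, on purpose.
mats = jax.random.uniform(jax.random.key(0), (7, 2, 2))
ours = hillis_steele_scan(jnp.matmul, mats)
reference = lax.associative_scan(jnp.matmul, mats)

# The two scans group the products differently, so compare with a tolerance.
matches = bool(jnp.allclose(ours, reference, rtol=1e-4, atol=1e-5))
print(matches)
```

The same comparison should apply to the full function above by swapping it in for the helper, including with `reverse=True` and pytree inputs.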
In one of my problems the implementation was bottlenecked by a cumulative matmul. JAX has a handy implementation of a work-efficient associative scan for this, lax.associative_scan. This reduces the procedure from N steps to 2 log_2(N) - 2 steps. There is also a work-inefficient implementation that reduces this further to log_2(N) steps, which is faster for small problem sizes where the GPU is not saturated. Would anyone else be interested in having a work-inefficient option in lax.associative_scan? If so, I can put together a pull request.
See https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda, http://www.cs.cmu.edu/~guyb/papers/Ble93.pdf, https://en.wikipedia.org/wiki/Prefix_sum
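
To make the step counts above concrete, here is a minimal pure-Python sketch of the work-inefficient (Hillis-Steele) scan that counts parallel rounds and calls to the operator. It only illustrates the trade-off and is not the JAX implementation; the names `hillis_steele`, `rounds` and `calls` are made up for this example, and N = 8 is an arbitrary power of two.

```python
import math
from itertools import accumulate
from operator import add


def hillis_steele(op, xs):
    """Inclusive scan; returns (result, parallel rounds, operator calls)."""
    xs = list(xs)
    shift, rounds, calls = 1, 0, 0
    while shift < len(xs):
        # Every position i >= shift is updated independently of the others,
        # so on parallel hardware this whole comprehension is one step.
        xs = xs[:shift] + [op(xs[i - shift], xs[i])
                           for i in range(shift, len(xs))]
        calls += len(xs) - shift
        shift *= 2
        rounds += 1
    return xs, rounds, calls


n = 8
result, rounds, calls = hillis_steele(add, range(n))
sequential = list(accumulate(range(n), add))

# Step count quoted above for the work-efficient scan (N a power of two).
work_efficient_steps = 2 * int(math.log2(n)) - 2

print(result == sequential, rounds, work_efficient_steps, calls)
# prints: True 3 4 17
```

For N = 8 the sketch needs 3 rounds instead of 4, but makes 17 operator calls where a sequential scan makes 7. That extra work is why the approach only pays off while the hardware still has idle parallel capacity.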