Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit af8821f

Browse files
committed
Use jax.lax.dynamic_slice in the Subtensor JAX dispatcher
1 parent 0ac65bb commit af8821f

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

aesara/link/jax/dispatch/subtensor.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import jax
2+
import jax.numpy as jnp
13

24
from aesara.link.jax.dispatch.basic import jax_funcify
5+
from aesara.scalar.basic import ScalarConstant
36
from aesara.tensor.subtensor import (
47
AdvancedIncSubtensor,
58
AdvancedIncSubtensor1,
@@ -12,21 +15,55 @@
1215
from aesara.tensor.type_other import MakeSlice
1316

1417

18+
def jax_dynamic_slice(
19+
array_to_slice: jnp.DeviceArray, start: jnp.DeviceArray, stop: jnp.DeviceArray
20+
) -> jnp.DeviceArray:
21+
"""Slice a JAX array with traced indices with jax.lax.dynamic_slice.
22+
23+
TODO: This currently assumes that arrays are 1-dimensional.
24+
"""
25+
if stop is None:
26+
slice_size = array_to_slice.shape[0] - start
27+
if start is None:
28+
start = 0
29+
slice_size = stop
30+
return jax.lax.dynamic_slice(array_to_slice, (start,), (slice_size,))
31+
32+
1533
@jax_funcify.register(Subtensor)
1634
@jax_funcify.register(AdvancedSubtensor)
1735
@jax_funcify.register(AdvancedSubtensor1)
18-
def jax_funcify_Subtensor(op, **kwargs):
36+
def jax_funcify_Subtensor(op, node, **kwargs):
1937

2038
idx_list = getattr(op, "idx_list", None)
2139

22-
def subtensor(x, *ilists):
40+
# JAX does not provide an easy way to slice an array dynamically with
41+
# a step value, so we raise an exception.
42+
#
43+
# However, Aesara rewrites sometimes introduce a constant step with value
44+
# `1` in which case it is safe to convert the Op, ignoring the step value.
45+
if len(node.inputs) == 4:
46+
if not isinstance(node.inputs[-1], ScalarConstant):
47+
raise NotImplementedError(
48+
"Dynamic slicing operations with a step value are not supported in JAX"
49+
)
50+
else:
51+
if node.inputs[-1].value != 1:
52+
raise NotImplementedError(
53+
"Dynamic slicing operations with a step value are not supported in JAX"
54+
)
2355

56+
def subtensor(x, *ilists):
2457
indices = indices_from_subtensor(ilists, idx_list)
2558

2659
if len(indices) == 1:
2760
indices = indices[0]
2861

29-
return x.__getitem__(indices)
62+
if isinstance(indices, slice):
63+
return jax_dynamic_slice(x, indices.start, indices.stop)
64+
else:
65+
index = indices
66+
return x.__getitem__(index)
3067

3168
return subtensor
3269

0 commit comments

Comments
 (0)