|
| 1 | +import jax |
| 2 | +import jax.numpy as jnp |
1 | 3 |
|
2 | 4 | from aesara.link.jax.dispatch.basic import jax_funcify |
| 5 | +from aesara.scalar.basic import ScalarConstant |
3 | 6 | from aesara.tensor.subtensor import ( |
4 | 7 | AdvancedIncSubtensor, |
5 | 8 | AdvancedIncSubtensor1, |
|
12 | 15 | from aesara.tensor.type_other import MakeSlice |
13 | 16 |
|
14 | 17 |
|
| 18 | +def jax_dynamic_slice( |
| 19 | + array_to_slice: jnp.DeviceArray, start: jnp.DeviceArray, stop: jnp.DeviceArray |
| 20 | +) -> jnp.DeviceArray: |
| 21 | + """Slice a JAX array with traced indices with jax.lax.dynamic_slice. |
| 22 | +
|
| 23 | + TODO: This currently assumes that arrays are 1-dimensional. |
| 24 | + """ |
| 25 | + if stop is None: |
| 26 | + slice_size = array_to_slice.shape[0] - start |
| 27 | + if start is None: |
| 28 | + start = 0 |
| 29 | + slice_size = stop |
| 30 | + return jax.lax.dynamic_slice(array_to_slice, (start,), (slice_size,)) |
| 31 | + |
| 32 | + |
15 | 33 | @jax_funcify.register(Subtensor) |
16 | 34 | @jax_funcify.register(AdvancedSubtensor) |
17 | 35 | @jax_funcify.register(AdvancedSubtensor1) |
18 | | -def jax_funcify_Subtensor(op, **kwargs): |
| 36 | +def jax_funcify_Subtensor(op, node, **kwargs): |
19 | 37 |
|
20 | 38 | idx_list = getattr(op, "idx_list", None) |
21 | 39 |
|
22 | | - def subtensor(x, *ilists): |
| 40 | + # JAX does not provide an easy way to slice an array dynamically with |
| 41 | + # a step value, so we raise an exception. |
| 42 | + # |
| 43 | + # However, Aesara rewrites sometimes introduce a constant step with value |
| 44 | + # `1` in which case it is safe to convert the Op, ignoring the step value. |
| 45 | + if len(node.inputs) == 4: |
| 46 | + if not isinstance(node.inputs[-1], ScalarConstant): |
| 47 | + raise NotImplementedError( |
| 48 | + "Dynamic slicing operations with a step value are not supported in JAX" |
| 49 | + ) |
| 50 | + else: |
| 51 | + if node.inputs[-1].value != 1: |
| 52 | + raise NotImplementedError( |
| 53 | + "Dynamic slicing operations with a step value are not supported in JAX" |
| 54 | + ) |
23 | 55 |
|
| 56 | + def subtensor(x, *ilists): |
24 | 57 | indices = indices_from_subtensor(ilists, idx_list) |
25 | 58 |
|
26 | 59 | if len(indices) == 1: |
27 | 60 | indices = indices[0] |
28 | 61 |
|
29 | | - return x.__getitem__(indices) |
| 62 | + if isinstance(indices, slice): |
| 63 | + return jax_dynamic_slice(x, indices.start, indices.stop) |
| 64 | + else: |
| 65 | + index = indices |
| 66 | + return x.__getitem__(index) |
30 | 67 |
|
31 | 68 | return subtensor |
32 | 69 |
|
|
0 commit comments