Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit bf97e57

Browse files
committed
Create a dummy initial carry value when outputs_info is empty
1 parent 011f417 commit bf97e57

File tree

3 files changed

+28
-43
lines changed

3 files changed

+28
-43
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def scan(*outer_inputs):
1515
outer_in_mit_mot = list(op.outer_mitmot(outer_inputs))
1616
outer_in_mit_sot = list(op.outer_mitsot(outer_inputs))
1717
outer_in_sit_sot = list(op.outer_sitsot(outer_inputs))
18-
outer_in_nit_sot = list(op.outer_nitsot(outer_inputs))
18+
# outer_in_nit_sot = list(op.outer_nitsot(outer_inputs))
1919
outer_in_shared = list(op.outer_shared(outer_inputs))
2020
outer_in_non_seqs = list(op.outer_non_seqs(outer_inputs))
2121
if len(outer_in_mit_mot):
@@ -29,10 +29,18 @@ def scan(*outer_inputs):
2929
if len(outer_in_non_seqs):
3030
raise NotImplementedError("non sequence are not supported")
3131

32-
init_carry = outer_in_nit_sot
32+
# If `output_infos` is empty we need to create an empty initial carry
33+
# value with the output's shape and dtype
34+
is_outputs_info_empty = len(outer_in_sit_sot) == 0
35+
if is_outputs_info_empty:
36+
dtype = node.outputs[0].type.dtype
37+
shape = tuple(s for s in node.outputs[0].type.shape if s is not None)
38+
outer_in_sit_sot = [jax.numpy.empty(dtype=dtype, shape=shape)]
39+
40+
init_carry = outer_in_sit_sot
3341
sequences = outer_in_seqs
3442

35-
def scan_inner_in_args(carry, x, is_dummy_sit_sot=True):
43+
def scan_inner_in_args(carry, x):
3644
"""Create an inner-input expression.
3745
3846
Inner-inputs are ordered as follows:
@@ -45,10 +53,12 @@ def scan_inner_in_args(carry, x, is_dummy_sit_sot=True):
4553
"""
4654

4755
inner_in_seqs = x
48-
if is_dummy_sit_sot:
56+
57+
if is_outputs_info_empty:
4958
inner_in_sit_sot = []
5059
else:
5160
inner_in_sit_sot = carry
61+
5262
return sum([inner_in_seqs, inner_in_sit_sot], [])
5363

5464
def scan_new_carry(inner_outputs):
@@ -58,7 +68,7 @@ def scan_new_carry(inner_outputs):
5868
- mit-mot-outputs
5969
- mit-sot-outputs
6070
- sit-sot-outputs
61-
- nit-sots
71+
- nit-sots (no carry)
6272
- shared-outputs
6373
[+ while-condition]
6474

aesara/link/jax/dispatch/subtensor.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import jax
2-
import jax.numpy as jnp
31

42
from aesara.link.jax.dispatch.basic import jax_funcify
5-
from aesara.scalar.basic import ScalarConstant
63
from aesara.tensor.subtensor import (
74
AdvancedIncSubtensor,
85
AdvancedIncSubtensor1,
@@ -15,55 +12,22 @@
1512
from aesara.tensor.type_other import MakeSlice
1613

1714

18-
def jax_dynamic_slice(
19-
array_to_slice: jnp.DeviceArray, start: jnp.DeviceArray, stop: jnp.DeviceArray
20-
) -> jnp.DeviceArray:
21-
"""Slice a JAX array with traced indices with jax.lax.dynamic_slice.
22-
23-
TODO: This currently assumes that arrays are 1-dimensional.
24-
"""
25-
if stop is None:
26-
slice_size = array_to_slice.shape[0] - start
27-
if start is None:
28-
start = 0
29-
slice_size = stop
30-
return jax.lax.dynamic_slice(array_to_slice, (start,), (slice_size,))
31-
32-
3315
@jax_funcify.register(Subtensor)
3416
@jax_funcify.register(AdvancedSubtensor)
3517
@jax_funcify.register(AdvancedSubtensor1)
3618
def jax_funcify_Subtensor(op, node, **kwargs):
3719

3820
idx_list = getattr(op, "idx_list", None)
3921

40-
# JAX does not provide an easy way to slice an array dynamically with
41-
# a step value, so we raise an exception.
42-
#
43-
# However, Aesara rewrites sometimes introduce a constant step with value
44-
# `1` in which case it is safe to convert the Op, ignoring the step value.
45-
if len(node.inputs) == 4:
46-
if not isinstance(node.inputs[-1], ScalarConstant):
47-
raise NotImplementedError(
48-
"Dynamic slicing operations with a step value are not supported in JAX"
49-
)
50-
else:
51-
if node.inputs[-1].value != 1:
52-
raise NotImplementedError(
53-
"Dynamic slicing operations with a step value are not supported in JAX"
54-
)
5522

5623
def subtensor(x, *ilists):
24+
5725
indices = indices_from_subtensor(ilists, idx_list)
5826

5927
if len(indices) == 1:
6028
indices = indices[0]
6129

62-
if isinstance(indices, slice):
63-
return jax_dynamic_slice(x, indices.start, indices.stop)
64-
else:
65-
index = indices
66-
return x.__getitem__(index)
30+
return x.__getitem__(indices)
6731

6832
return subtensor
6933

tests/link/jax/test_scan.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@
3737
None,
3838
lambda op: op.info.n_seqs > 0,
3939
),
40+
# nit-sot
41+
(
42+
lambda: at.as_tensor(2.0),
43+
[],
44+
[{}],
45+
[],
46+
3,
47+
[],
48+
None,
49+
lambda op: op.info.n_nit_sot > 0,
50+
),
4051
],
4152
)
4253
def test_xit_xot_types(

0 commit comments

Comments
 (0)