Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 1ab4c69

Browse files
committed
Pass updates to the JAX and Py functions
1 parent af8821f commit 1ab4c69

File tree

1 file changed

+16
-22
lines changed

1 file changed

+16
-22
lines changed

tests/link/jax/test_basic.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def set_aesara_flags():
3434

3535
def compare_jax_and_py(
3636
fgraph: FunctionGraph,
37-
test_inputs: Iterable,
37+
inputs: Iterable,
3838
assert_fn: Optional[Callable] = None,
39-
must_be_device_array: bool = True,
39+
updates=None,
4040
):
4141
"""Function to compare python graph output and jax compiled output for testing equality
4242
@@ -53,34 +53,28 @@ def compare_jax_and_py(
5353
assert_fn: func, opt
5454
Assert function used to check for equality between python and jax. If not
5555
provided uses np.testing.assert_allclose
56-
must_be_device_array: Bool
57-
Checks for instance of jax.interpreters.xla.DeviceArray. For testing purposes
58-
if this device array is found it indicates if the result was computed by jax
59-
60-
Returns
61-
-------
62-
jax_res
56+
updates
57+
Updates to be passed to `aesara.function`.
6358
6459
"""
6560
if assert_fn is None:
6661
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
6762

68-
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
69-
aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
70-
jax_res = aesara_jax_fn(*test_inputs)
63+
if isinstance(fgraph, tuple):
64+
fn_inputs, fn_outputs = fgraph
65+
else:
66+
fn_inputs = fgraph.inputs
67+
fn_outputs = fgraph.outputs
68+
69+
fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
7170

72-
if must_be_device_array:
73-
if isinstance(jax_res, list):
74-
assert all(
75-
isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res
76-
)
77-
else:
78-
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
71+
aesara_py_fn = function(fn_inputs, fn_outputs, mode=py_mode, updates=updates)
72+
py_res = aesara_py_fn(*inputs)
7973

80-
aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
81-
py_res = aesara_py_fn(*test_inputs)
74+
aesara_jax_fn = function(fn_inputs, fn_outputs, mode=jax_mode, updates=updates)
75+
jax_res = aesara_jax_fn(*inputs)
8276

83-
if len(fgraph.outputs) > 1:
77+
if len(fn_outputs) > 1:
8478
for j, p in zip(jax_res, py_res):
8579
assert_fn(j, p)
8680
else:

0 commit comments

Comments
 (0)