@@ -34,9 +34,9 @@ def set_aesara_flags():
3434
3535def compare_jax_and_py (
3636 fgraph : FunctionGraph ,
37- test_inputs : Iterable ,
37+ inputs : Iterable ,
3838 assert_fn : Optional [Callable ] = None ,
39- must_be_device_array : bool = True ,
39+ updates = None ,
4040):
4141 """Function to compare python graph output and jax compiled output for testing equality
4242
@@ -53,34 +53,28 @@ def compare_jax_and_py(
5353 assert_fn: func, opt
5454 Assert function used to check for equality between python and jax. If not
5555 provided uses np.testing.assert_allclose
56- must_be_device_array: Bool
57- Checks for instance of jax.interpreters.xla.DeviceArray. For testing purposes
58- if this device array is found it indicates if the result was computed by jax
59-
60- Returns
61- -------
62- jax_res
56+ updates
57+ Updates to be passed to `aesara.function`.
6358
6459 """
6560 if assert_fn is None :
6661 assert_fn = partial (np .testing .assert_allclose , rtol = 1e-4 )
6762
68- fn_inputs = [i for i in fgraph .inputs if not isinstance (i , SharedVariable )]
69- aesara_jax_fn = function (fn_inputs , fgraph .outputs , mode = jax_mode )
70- jax_res = aesara_jax_fn (* test_inputs )
63+ if isinstance (fgraph , tuple ):
64+ fn_inputs , fn_outputs = fgraph
65+ else :
66+ fn_inputs = fgraph .inputs
67+ fn_outputs = fgraph .outputs
68+
69+ fn_inputs = [i for i in fn_inputs if not isinstance (i , SharedVariable )]
7170
72- if must_be_device_array :
73- if isinstance (jax_res , list ):
74- assert all (
75- isinstance (res , jax .interpreters .xla .DeviceArray ) for res in jax_res
76- )
77- else :
78- assert isinstance (jax_res , jax .interpreters .xla .DeviceArray )
71+ aesara_py_fn = function (fn_inputs , fn_outputs , mode = py_mode , updates = updates )
72+ py_res = aesara_py_fn (* inputs )
7973
80- aesara_py_fn = function (fn_inputs , fgraph . outputs , mode = py_mode )
81- py_res = aesara_py_fn ( * test_inputs )
74+ aesara_jax_fn = function (fn_inputs , fn_outputs , mode = jax_mode , updates = updates )
75+ jax_res = aesara_jax_fn ( * inputs )
8276
83- if len (fgraph . outputs ) > 1 :
77+ if len (fn_outputs ) > 1 :
8478 for j , p in zip (jax_res , py_res ):
8579 assert_fn (j , p )
8680 else :
0 commit comments