
Commit 7eabb65

Add to some tests a direct call to JAXOp

1 parent 48fbf0a
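
For context, a minimal sketch of the pattern these tests now exercise, reconstructed from the diff below (not part of the committed file): calling a JAX function through the as_jax_op decorator, and constructing a JAXOp directly with explicit input and output types. The shape (2,) on x and y is an assumption chosen to match the output TensorType used in the first test.

import jax
import pytensor.tensor as pt
from pytensor import as_jax_op, config, grad
from pytensor.link.jax.ops import JAXOp
from pytensor.tensor import TensorType, tensor


def f(x, y):
    return jax.nn.sigmoid(x + y)


# Symbolic inputs; shape (2,) is assumed here to match the output type below.
x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))

# Decorator route: output types are inferred by as_jax_op.
out_decorated = as_jax_op(f)(x, y)

# Direct route: input and output TensorTypes are spelled out explicitly.
jax_op = JAXOp(
    [x.type, y.type],
    [TensorType(config.floatX, shape=(2,))],
    f,
)
out_direct = jax_op(x, y)
grad_out = grad(pt.sum(out_direct), [x, y])

The direct route is what the added test sections check: the caller supplies the input and output TensorTypes (including shape=(None,) when the static shape is unknown) instead of relying on the decorator to infer them.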

1 file changed: tests/link/jax/test_as_jax_op.py (+107 lines, −23 lines)
@@ -6,8 +6,9 @@
 import pytensor.tensor as pt
 from pytensor import as_jax_op, config, grad
 from pytensor.graph.fg import FunctionGraph
+from pytensor.link.jax.ops import JAXOp
 from pytensor.scalar import all_types
-from pytensor.tensor import tensor
+from pytensor.tensor import TensorType, tensor
 from tests.link.jax.test_basic import compare_jax_and_py


@@ -19,18 +20,29 @@ def test_two_inputs_single_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x + y)

-    out = f(x, y)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out), [x, y])

     fg = FunctionGraph([x, y], [out, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out = jax_op(x, y)
+    grad_out = grad(pt.sum(out), [x, y])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_two_inputs_tuple_output():
     rng = np.random.default_rng(2)
@@ -40,11 +52,11 @@ def test_two_inputs_tuple_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x + y), y * 2

-    out1, out2 = f(x, y)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out1 + out2), [x, y])

     fg = FunctionGraph([x, y], [out1, out2, *grad_out])
@@ -54,6 +66,17 @@ def f(x, y):
     # inputs are not automatically transformed to jax.Array anymore
     fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out1, out2 = jax_op(x, y)
+    grad_out = grad(pt.sum(out1 + out2), [x, y])
+    fg = FunctionGraph([x, y], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_two_inputs_list_output_one_unused_output():
     # One output is unused, to test whether the wrapper can handle DisconnectedType
@@ -64,72 +87,119 @@ def test_two_inputs_list_output_one_unused_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return [jax.nn.sigmoid(x + y), y * 2]

-    out, _ = f(x, y)
+    # Test with as_jax_op decorator
+    out, _ = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out), [x, y])

     fg = FunctionGraph([x, y], [out, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out, _ = jax_op(x, y)
+    grad_out = grad(pt.sum(out), [x, y])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_single_input_tuple_output():
     rng = np.random.default_rng(4)
     x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

-    @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x * 2

-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])

     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_scalar_input_tuple_output():
     rng = np.random.default_rng(5)
     x = tensor("x", shape=())
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

-    @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x

-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])

     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type],
+        [TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_single_input_list_output():
     rng = np.random.default_rng(6)
     x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

-    @as_jax_op
     def f(x):
         return [jax.nn.sigmoid(x), 2 * x]

-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])

     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage, with unspecified output shapes
+    jax_op = JAXOp(
+        [x.type],
+        [
+            TensorType(config.floatX, shape=(None,)),
+            TensorType(config.floatX, shape=(None,)),
+        ],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_pytree_input_tuple_output():
     rng = np.random.default_rng(7)
@@ -140,11 +210,11 @@ def test_pytree_input_tuple_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0]

-    out = f(x, y_tmp)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y_tmp)
     grad_out = grad(pt.sum(out[1]), [x, y])

     fg = FunctionGraph([x, y], [out[0], out[1], *grad_out])
@@ -163,11 +233,11 @@ def test_pytree_input_pytree_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y)

-    out = f(x, y_tmp)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y_tmp)
     grad_out = grad(pt.sum(out[1]["b"][0]), [x, y])

     fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out])
@@ -186,7 +256,6 @@ def test_pytree_input_with_non_graph_args():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y, depth, which_variable):
         if which_variable == "x":
             var = x
@@ -198,22 +267,23 @@ def f(x, y, depth, which_variable):
             var = jax.nn.sigmoid(var)
         return var

+    # Test with as_jax_op decorator
     # arguments depth and which_variable are not part of the graph
-    out = f(x, y_tmp, depth=3, which_variable="x")
+    out = as_jax_op(f)(x, y_tmp, depth=3, which_variable="x")
     grad_out = grad(pt.sum(out), [x])
     fg = FunctionGraph([x, y], [out[0], *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

-    out = f(x, y_tmp, depth=7, which_variable="y")
+    out = as_jax_op(f)(x, y_tmp, depth=7, which_variable="y")
     grad_out = grad(pt.sum(out), [x])
     fg = FunctionGraph([x, y], [out[0], *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

-    out = f(x, y_tmp, depth=10, which_variable="z")
+    out = as_jax_op(f)(x, y_tmp, depth=10, which_variable="z")
     assert out == "Unsupported argument"


@@ -228,11 +298,11 @@ def test_unused_matrix_product():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return x[:, None] @ y[None], jnp.exp(x)

-    out = f(x, y)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out[1]), [x])

     fg = FunctionGraph([x, y], [out[1], *grad_out])
@@ -241,6 +311,20 @@ def f(x, y):
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [
+            TensorType(config.floatX, shape=(3, 3)),
+            TensorType(config.floatX, shape=(3,)),
+        ],
+        f,
+    )
+    out = jax_op(x, y)
+    grad_out = grad(pt.sum(out[1]), [x])
+    fg = FunctionGraph([x, y], [out[1], *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_unknown_static_shape():
     rng = np.random.default_rng(11)
