@@ -6,8 +6,9 @@
 import pytensor.tensor as pt
 from pytensor import as_jax_op, config, grad
 from pytensor.graph.fg import FunctionGraph
+from pytensor.link.jax.ops import JAXOp
 from pytensor.scalar import all_types
-from pytensor.tensor import tensor
+from pytensor.tensor import TensorType, tensor
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
@@ -19,18 +20,29 @@ def test_two_inputs_single_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
 
-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x + y)
 
-    out = f(x, y)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out), [x, y])
 
     fg = FunctionGraph([x, y], [out, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)
 
+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out = jax_op(x, y)
+    grad_out = grad(pt.sum(out), [x, y])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
 
 def test_two_inputs_tuple_output():
     rng = np.random.default_rng(2)
@@ -40,11 +52,11 @@ def test_two_inputs_tuple_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
 
-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x + y), y * 2
 
-    out1, out2 = f(x, y)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out1 + out2), [x, y])
 
     fg = FunctionGraph([x, y], [out1, out2, *grad_out])
@@ -54,6 +66,17 @@ def f(x, y):
         # inputs are not automatically transformed to jax.Array anymore
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)
 
+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out1, out2 = jax_op(x, y)
+    grad_out = grad(pt.sum(out1 + out2), [x, y])
+    fg = FunctionGraph([x, y], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
 
 def test_two_inputs_list_output_one_unused_output():
     # One output is unused, to test whether the wrapper can handle DisconnectedType
@@ -64,72 +87,119 @@ def test_two_inputs_list_output_one_unused_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
 
-    @as_jax_op
     def f(x, y):
         return [jax.nn.sigmoid(x + y), y * 2]
 
-    out, _ = f(x, y)
+    # Test with as_jax_op decorator
+    out, _ = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out), [x, y])
 
     fg = FunctionGraph([x, y], [out, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)
 
+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out, _ = jax_op(x, y)
+    grad_out = grad(pt.sum(out), [x, y])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
 
 def test_single_input_tuple_output():
     rng = np.random.default_rng(4)
     x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
 
-    @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x * 2
 
-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])
 
     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)
 
+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
 
 def test_scalar_input_tuple_output():
     rng = np.random.default_rng(5)
     x = tensor("x", shape=())
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
 
-    @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x
 
-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])
 
     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)
 
+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type],
+        [TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
 
 def test_single_input_list_output():
     rng = np.random.default_rng(6)
     x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
 
-    @as_jax_op
     def f(x):
         return [jax.nn.sigmoid(x), 2 * x]
 
-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])
 
     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)
 
+    # Test direct JAXOp usage, with unspecified output shapes
+    jax_op = JAXOp(
+        [x.type],
+        [
+            TensorType(config.floatX, shape=(None,)),
+            TensorType(config.floatX, shape=(None,)),
+        ],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
 
 def test_pytree_input_tuple_output():
     rng = np.random.default_rng(7)
@@ -140,11 +210,11 @@ def test_pytree_input_tuple_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
 
-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0]
 
-    out = f(x, y_tmp)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y_tmp)
     grad_out = grad(pt.sum(out[1]), [x, y])
 
     fg = FunctionGraph([x, y], [out[0], out[1], *grad_out])
@@ -163,11 +233,11 @@ def test_pytree_input_pytree_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
 
-    @as_jax_op
     def f(x, y):
         return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y)
 
-    out = f(x, y_tmp)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y_tmp)
     grad_out = grad(pt.sum(out[1]["b"][0]), [x, y])
 
     fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out])
@@ -186,7 +256,6 @@ def test_pytree_input_with_non_graph_args():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
 
-    @as_jax_op
     def f(x, y, depth, which_variable):
         if which_variable == "x":
             var = x
@@ -198,22 +267,23 @@ def f(x, y, depth, which_variable):
             var = jax.nn.sigmoid(var)
         return var
 
+    # Test with as_jax_op decorator
     # arguments depth and which_variable are not part of the graph
-    out = f(x, y_tmp, depth=3, which_variable="x")
+    out = as_jax_op(f)(x, y_tmp, depth=3, which_variable="x")
     grad_out = grad(pt.sum(out), [x])
     fg = FunctionGraph([x, y], [out[0], *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)
 
-    out = f(x, y_tmp, depth=7, which_variable="y")
+    out = as_jax_op(f)(x, y_tmp, depth=7, which_variable="y")
     grad_out = grad(pt.sum(out), [x])
     fg = FunctionGraph([x, y], [out[0], *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)
 
-    out = f(x, y_tmp, depth=10, which_variable="z")
+    out = as_jax_op(f)(x, y_tmp, depth=10, which_variable="z")
     assert out == "Unsupported argument"
 
 
@@ -228,11 +298,11 @@ def test_unused_matrix_product():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
 
-    @as_jax_op
     def f(x, y):
         return x[:, None] @ y[None], jnp.exp(x)
 
-    out = f(x, y)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out[1]), [x])
 
     fg = FunctionGraph([x, y], [out[1], *grad_out])
@@ -241,6 +311,20 @@ def f(x, y):
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)
 
+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [
+            TensorType(config.floatX, shape=(3, 3)),
+            TensorType(config.floatX, shape=(3,)),
+        ],
+        f,
+    )
+    out = jax_op(x, y)
+    grad_out = grad(pt.sum(out[1]), [x])
+    fg = FunctionGraph([x, y], [out[1], *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
 
 def test_unknown_static_shape():
     rng = np.random.default_rng(11)
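For orientation, the two usage patterns exercised by these tests can be sketched as a standalone script roughly as follows. This is an illustrative sketch, not part of the diff: it reuses only the as_jax_op and JAXOp calls shown above, the variable names are invented, and it assumes a JAX-enabled PyTensor build in which pytensor.function(..., mode="JAX") is available.

# Illustrative sketch (not part of the diff): wrapping a JAX function as a
# PyTensor Op via the decorator and via direct JAXOp construction.
import jax
import numpy as np

import pytensor.tensor as pt
from pytensor import as_jax_op, config, function, grad
from pytensor.link.jax.ops import JAXOp
from pytensor.tensor import TensorType, tensor


def f(x, y):
    # A plain JAX function to be wrapped as a PyTensor Op
    return jax.nn.sigmoid(x + y)


x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))

# Pattern 1: decorator-style wrapping; no explicit output types are passed.
out_decorated = as_jax_op(f)(x, y)

# Pattern 2: direct JAXOp construction with explicit input and output types.
jax_op = JAXOp([x.type, y.type], [TensorType(config.floatX, shape=(2,))], f)
out_direct = jax_op(x, y)

# Both variants take part in PyTensor's automatic differentiation.
grad_out = grad(pt.sum(out_decorated), [x, y])

# Compiling with mode="JAX" is assumed here; the tests above instead go
# through the compare_jax_and_py helper.
fn = function([x, y], [out_decorated, *grad_out], mode="JAX")
xv = np.ones(2, dtype=config.floatX)
yv = np.ones(2, dtype=config.floatX)
print(fn(xv, yv))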