Implement @as_jax_op to wrap a JAX function for use in PyTensor #1120
Conversation
I have a question: where should I put @as_jax_op? Currently, it is in a new file.
We can put it in the package __init__ as long as imports work in a way that JAX stays optional for PyTensor users (obviously calling the decorator can raise if it's not installed, hopefully with an informative message).
This looks sweet. I'll do a more careful review later, just skimmed through and annotated some thoughts
pytensor/link/jax/ops.py (outdated)

self.num_inputs = len(inputs)

# Define our output variables
outputs = [pt.as_tensor_variable(type()) for type in self.output_types]
Would it be possible to use JAX machinery to infer the output types from the input types? Can we create TracedArrays (or whatever they're called) and pass them through the function?
Scrap that, JAX doesn't let you trace arrays with unknown shape.
I trace the shape through the JAX function in line 119 of the file. It won't work for unknown shapes. But if one specifies the shape at the beginning of a graph, e.g. x = pm.Normal("x", shape=(3,)), and it loses static shape information afterwards, for instance because of a pt.cumsum, line 99 (pytensor.compile.builders.infer_shape) will still be able to infer the shape. But that is a good comment; I will raise an error if pytensor.compile.builders.infer_shape isn't able to infer the shape. I think it makes sense to only use this wrapper if the shape information is known.
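For reference, a minimal sketch of how static shapes and dtypes can be traced through a JAX function without concrete values, using jax.eval_shape (the PR's actual mechanism may differ):

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return jnp.cumsum(x) + y


# Abstract inputs carrying only shape and dtype, no data
x_spec = jax.ShapeDtypeStruct((3,), jnp.float32)
y_spec = jax.ShapeDtypeStruct((3,), jnp.float32)

# eval_shape traces f abstractly and reports the output shape/dtype
out_spec = jax.eval_shape(f, x_spec, y_spec)
print(out_spec.shape, out_spec.dtype)  # (3,) float32
```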
Okay, I see a point where it will lead to problems: if there is an input x = pm.Data("x", shape=(None,), value=np.array([0., 0])), the first run will work, as pytensor.compile.builders.infer_shape will infer the shape as (2,). But if one then changes the shape of x with x.set_value(np.array([0., 0, 0])), it will lead to an error, as the PyTensor Op has been created with an explicit shape. I could simply add a parameter to as_jax_op to force all output shapes to None; then it should work.
I will write more tests, then it will be clearer what I mean
Yeah, we can't use shape unless it's static. Ideally it shouldn't fail for unknown shapes, but then the user has to tell us the output types.
We could allow the user to specify a make_node callable? That way it can be made to work with different dtypes/ndims if the JAX function handles those fine.
pytensor/link/jax/ops.py (outdated)

return (result,)  # Pytensor requires a tuple here

# vector-jacobian product Op
class VJPSolOp(Op):
A nice follow-up would be to also create a "ValueAndGrad" version of the Op that gets introduced in rewrites when both the Op and the VJP of the Op (or JVP) are in the final graph.
This need not be a blocker for this PR.
I don't see exactly what you mean. Is ValueAndGrad used by PyTensor? I searched the codebase but didn't find a mention of it. Does it have to do with implementing L_op? I haven't really understood the difference between it and grad.
JAX has the value and grad concept to more optimally compute both together. PyTensor doesn't have that concept because everything is lazy but we can exploit it during the rewrite phase.
If a user compiles a function that includes both forward and gradient of the same wrapped JAX Op, we could replace it by a third Op whose perform implementation requests jax to compute both.
This is not relevant when the autodiff is done in JAX, but it's relevant when it's done in PyTensor
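For context, this is JAX's version of the concept (illustrative only; a hypothetical fused Op's perform would call something like this on the JAX side):

```python
import jax
import jax.numpy as jnp


def loss(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# JAX evaluates the forward value and the gradient together,
# sharing the forward computation between the two
value, grad_x = jax.value_and_grad(loss)(x)
print(value)   # 14.0
print(grad_x)  # [2. 4. 6.]
```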
pytensor/link/jax/ops.py (outdated)

    jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax,
)

@jax_funcify.register(SolOp)
I guess we can dispatch on the base class just once?
What do you mean? This jax_funcify is registered once for SolOp and once for VJPSolOp. Do you mean one could include the gradient calculation in SolOp?
I mean you can define the SolOp class outside the decorator and dispatch on that. Then the decorator can return a subclass of it, and you don't need to bother dispatching because the base-class dispatch will cover it.
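A minimal sketch of the suggested pattern; the class name and the jitted_func attribute are illustrative, not the PR's actual code:

```python
from pytensor.graph.op import Op
from pytensor.link.jax.dispatch import jax_funcify


class WrappedJAXOp(Op):
    """Hypothetical base class defined at module level."""


# Dispatch is registered once, on the base class
@jax_funcify.register(WrappedJAXOp)
def jax_funcify_WrappedJAXOp(op, **kwargs):
    # `jitted_func` is an assumed attribute set up by the decorator
    return op.jitted_func


# The decorator can then return a dynamically created subclass; no extra
# registration is needed because singledispatch resolves subclasses
# through the base-class registration above.
def make_op_subclass(name):
    return type(name, (WrappedJAXOp,), {})
```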
Good idea, I didn't think of that. Have a look at whether I implemented it the way you envisioned.
Big-picture question: what's going on with the flattening of inputs, and why is it needed?
To be able to wrap JAX functions that accept pytrees as input.
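For readers unfamiliar with pytrees, this is roughly what the flattening buys (an illustrative sketch, not the wrapper's actual code):

```python
import jax
import jax.numpy as jnp

# A pytree input: an arbitrary nested structure of arrays
params = {"w": jnp.ones((2, 3)), "b": [jnp.zeros(3), jnp.ones(3)]}

# Flatten into a flat list of array leaves plus a treedef describing the
# structure; the flat leaves are what become individual PyTensor inputs.
leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves))  # 3

# The original structure is rebuilt before calling the wrapped JAX function
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```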
And if I have a function with a matrix input, will this work, or will it expect a vector instead?
It will work, it doesn't change
I would begin, in parallel, writing an example notebook. I opened an issue here.
Some comments, mostly related to tests and a couple of questions regarding PR scope.
Some of the advanced behaviors are a bit opaque from the outside, and I can't tell whether they relate to functionality that is actually needed (but is perhaps easier to test like this) or whether we could do without them for the sake of a simpler implementation.
I also have to try this locally; I'm curious how it behaves without static shapes on the inputs.
Overall, this is still looking great and very promising.
tests/link/jax/test_as_jax_op.py (outdated)

    rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]

x = pt.cumsum(x)  # Now x has an unknown shape
This is currently an implementation detail; better to have x = tensor(..., shape=(None,)).
How does this work, by the way? What is out.type.shape?
I see your comment above, so I guess you are using PyTensor's infer_shape machinery to figure out the output shape even though cumsum did not provide it at write time.
However, it will still not work if a root input has no static shape. I would suggest allowing users to define the make_node of a JAX Op, which exists exactly for this purpose. JAX doesn't have a concept of f(vector) -> vector of unknown shape (because shapes are always concrete during tracing), but PyTensor is perfectly happy with this.
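A rough sketch of what a user-supplied make_node could look like for outputs whose static shape cannot be inferred; the hook name and how it would be passed to the wrapper are hypothetical:

```python
import pytensor.tensor as pt
from pytensor.graph.basic import Apply


def my_make_node(op, *inputs):
    # Hypothetical hook: declare a single float64 vector output of unknown
    # length, so nothing has to be inferred from static input shapes.
    inputs = [pt.as_tensor_variable(inp) for inp in inputs]
    outputs = [pt.tensor(dtype="float64", shape=(None,))]
    return Apply(op, inputs, outputs)
```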
tests/link/jax/test_as_jax_op.py (outdated)

@as_jax_op
def f(x, y, message):
    return x * jnp.ones(3), "Success: " + message
What happens here? Is this output just ignored? Do we need to support this sort of functionality?
I extended this test; the output can be used, but not by PyTensor. We don't need to support this functionality, but it doesn't hurt much.
It is replaced by the test "test_pytree_input_with_non_graph_args"
tests/link/jax/test_as_jax_op.py (outdated)

fn, _ = compare_jax_and_py(fg, test_values)


def test_as_jax_op13():
Do we need this functionality?
Regarding the functionality and whether one should remove some of it for the sake of simplicity: my goal was that
One additional remark: removing the functionality of points 2 and 3 would also remove the additional dependency on equinox; I don't know how relevant that is for the decision.
There's already something like that: I suggested allowing the user to specify
My issue with non-numerical outputs is that, from reading the tests, they seem to be arbitrarily truncated? As in that test where a JAX function has a string output. PyTensor is rather flexible in what types of variables it can accommodate; for instance, we have string types implemented here: https://github.com/pymc-devs/pymc/blob/e0e751199319e68f376656e2477c1543606c49c7/pymc/pytensorf.py#L1101-L1116 PyTensor itself has sparse matrices, homogeneous lists, slices, None, scalars... As such it seems odd to me to support some extra types only in this JAX wrapper Op helper. If those types are deemed useful enough for this wrapper to handle them, then the case could be made that we should add them as regular PyTensor types, and not special-case JAX.
I guess I'm just not clear on what the wrapper is doing with these special inputs (I'm assuming the outputs are just being ignored, as I wrote above). For the inputs, is it creating a partial function in the perform method? Then it sounds like they should also be implemented as
Also, this seems somewhat similar to PyTensor Ops with inner functions (ScalarLoop, OpFromGraph, Scan), which compile inner PyTensor functions (or dispatched equivalents on backends like JAX). I guess the common theme is that this PR may be reinventing several things that PyTensor already does (I could be wrong), and there may be room to reuse existing functionality, or to expand it so that it's not restricted to the JAX backend, and more specifically this wrapper. Let me know if any of this makes sense.
I added a ToDo list in the first post, so you can check the progress. I refactored the code with the help of Cursor AI, and now JAXOp can also be called directly, which can be used to specify undetermined output shapes. I also think it would be useful to have a Zoom meeting to get a better idea of which direction to go; you could, for example, write to me via the PyMC Discourse. Mondays and Tuesdays are quite full for me, but otherwise I am generally available.
Thank you for the great work! I would love to see this feature implemented!
I strongly support keeping time-dependent variables, since this is a feature that sunode does not support, to my knowledge. I'm currently exploring the use of PyMC in favor of NumPyro for inference of ODEs, and having a module that translates existing ODE models directly to PyTensor would be fantastic.
Looks much cleaner! Left some comments. I'm still hesitant about some of the automatic symbolic/constant splitting. I'll try to reread the discussion, but if you don't mind making the point again about why this is needed (and needed at the PyTensor level), I would appreciate it.
My naive intuition is that this could all be handled by the user doing the extra work themselves on the JAX function they provide, without PyTensor having to know anything about it?
@@ -24,4 +24,4 @@ dependencies:
   - pip
   - pip:
     - sphinx_sitemap
-    - -e ..
+    - -e ..[jax]
Should this be reverted?
@@ -63,6 +63,13 @@ Convert to Variable

.. autofunction:: pytensor.as_symbolic(...)

Wrap JAX functions
Reminds me, we may want to add something about defining JAX/Numba/PyTorch Ops to that page.
try:
    import pytensor.link.jax.ops
    from pytensor.link.jax.ops import as_jax_op
Do not force an eager import of JAX. Can you make it happen inside as_jax_op? We put some effort into reducing the import times of the library.
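A minimal sketch of deferring the import until the decorator is actually called, keeping plain `import pytensor` cheap (error message wording is illustrative):

```python
def as_jax_op(jax_func):
    # Import lazily so JAX stays an optional dependency
    try:
        import jax  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "as_jax_op requires JAX; install it with `pip install jax`."
        ) from err
    ...  # build and return the wrapped Op here
```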
    self.__class__.__qualname__.split(".")[:-1] + [name]
)

def make_node(self, *inputs):
If you have the input types you can also convert the inputs, so it works if, say, users pass numpy arrays and so on. I don't remember the syntax exactly; something like type.filter_variable(inp)?
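Something along these lines, assuming the Op stores its expected input types in self.input_types (a hedged sketch of the method body, not the PR's code):

```python
import pytensor.tensor as pt
from pytensor.graph.basic import Apply


def make_node(self, *inputs):
    # Coerce raw inputs (numpy arrays, Python scalars, ...) into symbolic
    # variables of the declared types; filter_variable raises a clear error
    # if an input is incompatible with the expected type.
    inputs = [
        inp_type.filter_variable(pt.as_tensor_variable(inp))
        for inp_type, inp in zip(self.input_types, inputs)
    ]
    outputs = [out_type() for out_type in self.output_types]
    return Apply(self, inputs, outputs)
```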
zero_shape = (
    self.output_types[i].shape
    if None not in self.output_types[i].shape
    else ()
)
Are we requiring all outputs to have a static shape? Otherwise you may want to do something like in 84c7802.
grad_out = grad(pt.sum(out), [x, y])

fg = FunctionGraph([x, y], [out, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)
We merged a change in main: compare_jax_and_py no longer expects a FunctionGraph as input, just inputs, outputs, and test_values.
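So the call sites in this PR would presumably become something like the following (based on the description above; the exact argument order is assumed):

```python
# Before: fn, _ = compare_jax_and_py(FunctionGraph([x, y], [out, *grad_out]), test_values)
# After the change on main:
fn, _ = compare_jax_and_py([x, y], [out, *grad_out], test_values)
```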
fn, _ = compare_jax_and_py(fg, test_values)


def test_two_inputs_list_output_one_unused_output():
Perhaps merge this with the previous test? Here is how we did it recently to test the grad of multi-output Ops:
pytensor/tests/tensor/test_nlinalg.py, lines 260 to 278 in 19dafe4:

def svd_fn(A, case=0):
    U, s, V = svd(A, compute_uv=compute_uv, full_matrices=full_matrices)
    if case == 0:
        return U.sum()
    elif case == 1:
        return s.sum()
    elif case == 2:
        return V.sum()
    elif case == 3:
        return U.sum() + s.sum()
    elif case == 4:
        return s.sum() + V.sum()
    elif case == 5:
        return U.sum() + V.sum()
    elif case == 6:
        return U.sum() + s.sum() + V.sum()
    elif case == 7:
        # All inputs disconnected
        return as_tensor_variable(3.0)
out = jax_op(x, y)
grad_out = grad(pt.sum(out), [x, y])
fg = FunctionGraph([x, y], [out, *grad_out])
fn, _ = compare_jax_and_py(fg, test_values)
Since the py implementation is also new to this PR (and therefore not a reference), I would suggest also evaluating fn explicitly and asserting it is the expected value (for all tests)
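For example, something like the following on top of the comparison; expected_out and expected_grad_x are placeholders for values computed by hand per test:

```python
import numpy as np

fn, _ = compare_jax_and_py([x, y], [out, *grad_out], test_values)

# Also compare against independently computed reference values, since both
# the JAX and the Python execution paths are introduced by this PR.
results = fn(*test_values)
np.testing.assert_allclose(results[0], expected_out, rtol=1e-6)
np.testing.assert_allclose(results[1], expected_grad_x, rtol=1e-6)
```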
fn, _ = compare_jax_and_py(fg, test_values)


def test_unknown_static_shape():
Is there a test with an input whose static shape simply cannot be resolved (as in tensor(shape=(None,)))?
    rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
]

x_cumsum = pt.cumsum(x)  # Now x_cumsum has an unknown shape
Just to be robust, create a new test Op that explicitly does not have a static output shape; cumsum could be updated to output a static shape in the future.
Description
Add a decorator that transforms a JAX function such that it can be used in PyTensor. Shape and dtype inference works automatically, and inputs and outputs can be any nested Python structure (e.g. pytrees). Furthermore, using a transformed function as an argument for another transformed function should also work.
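A hypothetical usage sketch of the decorator described here (function and variable names are illustrative; see the tests in this PR for the canonical examples):

```python
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.link.jax.ops import as_jax_op  # module location used in this PR


@as_jax_op
def jax_mse(x, y):
    # An ordinary JAX function; output shape and dtype are inferred automatically
    return jnp.mean((x - y) ** 2)


x = pt.vector("x", shape=(3,))
y = pt.vector("y", shape=(3,))
out = jax_mse(x, y)          # a regular PyTensor variable
f = function([x, y], [out])  # compiled like any other PyTensor graph
print(f(np.zeros(3), np.ones(3)))
```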
Related Issue
@as_jax_op to wrap JAX functions in PyTensor #537

Checklist
Type of change
ToDos
- Op.__props__
- make_node can be specified by the user, to support non-inferrable shapes
- JAXOp is now directly usable by the user
📚 Documentation preview 📚: https://pytensor--1120.org.readthedocs.build/en/1120/