File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -139,12 +139,12 @@ def _eval_shape_fn(*args, **kwargs):
139
139
out = jax .eval_shape (_eval_shape_fn , * args , ** kwargs )
140
140
return extract .from_tree (out )
141
141
"""A "lifted" version of `jax.eval_shape <https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html#jax.eval_shape>`_
142
- that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
143
- / graph nodes as arguments.
142
+ that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
143
+ / graph nodes as arguments.
144
144
145
145
Similar to ``jax.eval_shape``, it computes the shape/dtype of a function `f` without
146
- performing any floating point operations (FLOPs) which can be expensive. This can be
147
- useful for performing shape inference, for example.
146
+ performing any floating point operations (FLOPs) which can be expensive. This can be
147
+ useful for performing shape inference, for example.
148
148
"""
149
149
150
150
# -------------------------------
You can’t perform that action at this time.
0 commit comments