Skip to content

Commit ddd7847

Browse files
authored
Add flax.nnx.eval_shape docstring
1 parent 6b5b300 commit ddd7847

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

flax/nnx/transforms/transforms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,12 @@ def _eval_shape_fn(*args, **kwargs):
139139
out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
140140
return extract.from_tree(out)
141141
"""A "lifted" version of `jax.eval_shape <https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html#jax.eval_shape>`_
142-
that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
143-
/ graph nodes as arguments.
142+
that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
143+
/ graph nodes as arguments.
144144
145145
Similar to ``jax.eval_shape``, it computes the shape/dtype of a function `f` without
146-
performing any floating point operations (FLOPs) which can be expensive. This can be
147-
useful for performing shape inference, for example.
146+
performing any floating point operations (FLOPs) which can be expensive. This can be
147+
useful for performing shape inference, for example.
148148
"""
149149

150150
# -------------------------------

0 commit comments

Comments
 (0)