Skip to content

Commit

Permalink
Merge pull request #4374 from 8bitmp3:add-nnx-eval_shap-docstring
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696587888
  • Loading branch information
Flax Authors committed Nov 14, 2024
2 parents ac3e85a + ddd7847 commit f265a5e
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion flax/nnx/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,14 @@ def _eval_shape_fn(*args, **kwargs):

out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
return extract.from_tree(out)

"""A "lifted" version of `jax.eval_shape <https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html#jax.eval_shape>`_
that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
/ graph nodes as arguments.
Similar to ``jax.eval_shape``, it computes the shape/dtype of a function `f` without
performing any floating point operations (FLOPs) which can be expensive. This can be
useful for performing shape inference, for example.
"""

# -------------------------------
# cond and switch
Expand Down

0 comments on commit f265a5e

Please sign in to comment.