Skip to content

Commit

Permalink
Add flax.nnx.eval_shape docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 authored Nov 11, 2024
1 parent 6b5b300 commit ddd7847
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions flax/nnx/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,12 @@ def _eval_shape_fn(*args, **kwargs):
out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
return extract.from_tree(out)
"""A "lifted" version of `jax.eval_shape <https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html#jax.eval_shape>`_
that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
/ graph nodes as arguments.
that can handle `flax.nnx.Module <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
/ graph nodes as arguments.
Similar to ``jax.eval_shape``, it computes the shape/dtype of a function `f` without
performing any floating point operations (FLOPs) which can be expensive. This can be
useful for performing shape inference, for example.
performing any floating point operations (FLOPs) which can be expensive. This can be
useful for performing shape inference, for example.
"""

# -------------------------------
Expand Down

0 comments on commit ddd7847

Please sign in to comment.