diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index b74dd18c3..03eb91c4e 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -138,7 +138,14 @@ def _eval_shape_fn(*args, **kwargs): out = jax.eval_shape(_eval_shape_fn, *args, **kwargs) return extract.from_tree(out) - + """A "lifted" version of `jax.eval_shape `_ + that can handle `flax.nnx.Module `_ + / graph nodes as arguments. + + Similar to ``jax.eval_shape``, it computes the shape/dtype of a function `f` without + performing any floating point operations (FLOPs) which can be expensive. This can be + useful for performing shape inference, for example. + """ # ------------------------------- # cond and switch