Linen modules get detailed trace annotations by default. Can NNX modules also enable them? #4518
NiklasKappel asked this question in Q&A
The following MWE uses `jax.profiler.trace` to profile the inference time of the CNN modules (with the same architecture) from the Flax Linen and Flax NNX documentation, respectively.
When viewing the trace for the Linen model in TensorBoard, the relevant section displays the names of the called modules under "Framework Name Scope" (here `LinenCNN`, `Conv_*` and `Dense_*`).
In the trace for the NNX model, this information is missing.
Is there a built-in way to enable the same kind of annotations for NNX modules? The next best thing I can think of would be to decorate the `__call__` methods of all my NNX modules with `jax.named_scope`.