Understanding "Interoperate with purely functional JAX code" #1313
Hi, I'm new to Flax. However, when I started with Flax, I noticed that Linen modules are defined as classes. Thanks,
The solution in the tutorial you linked to does use a class, but that class doesn't maintain any state. The reason is that stateful classes often don't work well under JAX transformations (such as `jit` in the tutorial). The tutorial gives a nice example of this as well.

In Linen, exactly the same principle applies: Modules are indeed classes, but they are stateless. You explicitly pass the state (params, rngs) when you initialize or apply a Module, and the new state is returned. I think it is useful to take a look at the Linen Introduction Colab, which explains these ideas in more detail.
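To make the `jit` point concrete, here is a minimal sketch with a hypothetical counter (written for this reply, not copied from the linked tutorial). The stateful version mutates `self.n`, but that Python side effect only runs while `jit` traces the method, so the result gets baked in as a constant. The stateless version takes the count as an input and returns the new count, which is the pattern Linen follows.

```python
import jax
import jax.numpy as jnp


class StatefulCounter:
    """Hypothetical counter that keeps its count in a Python attribute."""

    def __init__(self):
        self.n = 0

    def count(self):
        # This increment only executes while jit traces the function.
        self.n += 1
        return self.n


class StatelessCounter:
    """Hypothetical counter that holds no state; the count is passed explicitly."""

    def init(self):
        return jnp.int32(0)

    def count(self, n):
        new_n = n + 1
        return new_n, new_n  # (output, new state)


# Stateful: the first call traces (n goes 0 -> 1), later calls reuse the
# compiled function, which simply returns the traced constant 1.
stateful = StatefulCounter()
fast_count = jax.jit(stateful.count)
stateful_results = [int(fast_count()) for _ in range(3)]  # [1, 1, 1]

# Stateless: the state goes in and comes out, so jit behaves as expected.
stateless = StatelessCounter()
fast_count_v2 = jax.jit(stateless.count)
state = stateless.init()
stateless_results = []
for _ in range(3):
    value, state = fast_count_v2(state)
    stateless_results.append(int(value))  # ends as [1, 2, 3]
```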