
Understanding "Interoperate with purely functional JAX code" #1313

Answered by marcvanzee
sh0416 asked this question in Q&A

I followed the tutorial in the JAX documentation and have made a point of not using stateful constructs such as classes. link

The solution in the tutorial you linked to actually does use a class, but that class doesn't maintain a state. The reason is that stateful classes often don't work well under JAX transformations (such as jit in the tutorial). The tutorial gives a nice example as well.
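To make the distinction concrete, here is a minimal sketch loosely modeled on the counter example in the JAX tutorial. The class and variable names (`StatefulCounter`, `StatelessCounter`, `stateful_results`, `stateless_results`) are illustrative assumptions, not code from this thread.

```python
import jax
import jax.numpy as jnp


class StatefulCounter:
    """Keeps the count on the instance, which is a hidden side effect."""

    def __init__(self):
        self.n = 0

    def count(self):
        self.n += 1  # executed only while jit traces the function
        return self.n


class StatelessCounter:
    """Still a class, but it holds no state: the count goes in and comes out."""

    def count(self, n):
        return n + 1, n + 1  # (output, new state)


# Stateful version: the increment is baked in at trace time.
stateful = StatefulCounter()
fast_stateful = jax.jit(stateful.count)
stateful_results = [int(fast_stateful()) for _ in range(3)]

# Stateless version: the caller threads the state through every call.
stateless = StatelessCounter()
fast_stateless = jax.jit(stateless.count)
state = jnp.asarray(0)
stateless_results = []
for _ in range(3):
    value, state = fast_stateless(state)
    stateless_results.append(int(value))

print(stateful_results)   # [1, 1, 1]
print(stateless_results)  # [1, 2, 3]
```

Under `jit`, the stateful counter returns 1 on every call because the Python-level increment runs once during tracing and the result is cached as a constant. The stateless counter behaves as expected because its state is an explicit argument and return value.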

In Linen, exactly the same principle is applied: Modules are indeed classes, but they are stateless. You explicitly pass the state (params, rngs) when you initialize or apply a Module, and the new state is returned. I think it is useful to take a look at the Linen Introduction Colab, which explains these ideas…



Answer selected by marcvanzee