(question originally asked by Bogdan Mazoure) Hi, if I have a network, e.g.
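(answer originally provided by @levskaya)

setup code

```python
import jax
from jax import lax, numpy as jnp, random
import flax
from flax import linen as nn
```

defining a `layer_fn(x, variables)`

```python
layer = nn.Dense(32)
k0, k1, k2, k3 = random.split(random.PRNGKey(0), 4)
x = random.uniform(k0, (2, 16))
variables = layer.init(k1, x)
# jax.tree_util.tree_map replaces the deprecated (and since removed) jax.tree_map alias.
zeros_like_variables = jax.tree_util.tree_map(jnp.zeros_like, variables)
# Put the input first and the variables second so jax transformations can treat both as arguments.
layer_fn = lambda x, variables: layer.apply(variables, x)
```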
jvp
We have tangents we'd like to push forward.