
Compute gradients wrt inputs instead of params #2057

Answered by andsteing
andsteing asked this question in General

(answer originally provided by @levskaya)

setup code

import jax
from jax import lax, numpy as jnp, random
import flax
from flax import linen as nn

defining a layer_fn(x, variables)

layer = nn.Dense(32)

k0, k1, k2, k3 = random.split(random.PRNGKey(0), 4)
x = random.uniform(k0, (2, 16))
variables = layer.init(k1, x)
# zero tangents/cotangents for the params, so only the input contributes
zeros_like_variables = jax.tree_map(jnp.zeros_like, variables)

# a pure function of (x, variables) that jax transforms can differentiate
layer_fn = lambda x, variables: layer.apply(variables, x)
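
As a quick sanity check, layer_fn now behaves like a plain function of (x, variables), so with the shapes above it should return a (2, 32) output:

y = layer_fn(x, variables)
print(y.shape)  # (2, 32): batch of 2, Dense(32) features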

jvp

We have tangents we'd like to push forward.

tangent = random.uniform(k2, (2, 16))   # or "dx" to map forwards

y, tangents_out = jax.jvp(layer_fn,
                          (x, variables),                   # inputs (primals)
                          (tangent, zeros_like_variables))  # tangents: zeros for the params
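
vjp

For gradients proper, the reverse-mode counterpart is jax.vjp: it returns the output together with a pullback that maps a cotangent of the output back to cotangents of (x, variables); the first element is the gradient signal with respect to the input. A minimal sketch under the same setup (the cotangent here is just a random array of the output's shape, standing in for "dy", and the variable names are only illustrative):

y, pullback_fn = jax.vjp(layer_fn, x, variables)

cotangent = random.uniform(k3, y.shape)        # or "dy" to pull backwards
x_bar, variables_bar = pullback_fn(cotangent)  # cotangents wrt (x, variables)

For a scalar loss you can also skip the explicit vjp and ask jax.grad to differentiate with respect to the input by pointing argnums at x, e.g. jax.grad(lambda x, v: jnp.sum(layer_fn(x, v)), argnums=0)(x, variables).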
