scan instead of for + if block #1395
Hello, guys! Here is the deal: I have recently read that using scan instead of for loops is generally more efficient. Here is an example on how it can work:

from jax import numpy as jnp, random
from flax import linen as nn

class Iteration(nn.Module):
    def __call__(self, carry, iterable):
        # unpack carry
        value, i = carry
        # add the iteration number to value at each iteration
        value += i
        # repack carry
        carry = (value, i+1)  # go to next i
        return carry, iterable

class ScanUsedAsFor(nn.Module):
    def setup(self):
        self.scan = nn.scan(Iteration,
                            variable_broadcast=('params'),
                            split_rngs={'params': False},
                            in_axes=0,
                            out_axes=0)()

    def __call__(self, start_value, n_iter):
        iterable = jnp.arange(n_iter)
        carry = start_value, 0
        carry, out = self.scan(carry, iterable)
        value, i = carry
        return value

model = ScanUsedAsFor()
key = random.PRNGKey(0)
params = model.init(key, 0, 0)
result = model.apply(params, 0, 5)
print(result)

The module ScanUsedAsFor above adds each number from 0 to n_iter - 1 to start_value. To achieve this, I have to write a module as a body function.

Now, if I have to check whether the value of i is higher than 2 before summing:

class Iteration(nn.Module):
    def __call__(self, carry, iterable):
        # unpack carry
        value, i = carry
        # doing an important check
        if i > 2:
            value += i
        # repack carry
        carry = (value, i+1)  # go to next i
        return carry, iterable

this is the error I get:
If I cannot materialize the iteration number like this, how can I do this kind of filtering in my loop? I tried many things today, but I still could not fully understand my problem from the docs. I tried using variables, attributes... but consistently failed. Thanks for helping me out. Tips and criticism are welcome.
Replies: 1 comment 1 reply
If you want to do a conditional on traced values (that is, values within JAX transformations), then you have to use lax.cond rather than a Python if.
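For illustration, here is a minimal sketch of that suggestion. It assumes plain jax.lax.scan instead of the Flax nn.scan wrapper from the question, and the names body and sum_above_two are made up for this example. The point is only that the branch on the traced counter i goes through lax.cond instead of a Python if.

```python
from jax import lax, numpy as jnp


def body(carry, _):
    # Both entries of the carry are traced values inside scan.
    value, i = carry
    # lax.cond picks a branch at run time, so it accepts the traced predicate i > 2.
    value = lax.cond(i > 2, lambda v: v + i, lambda v: v, value)
    # Repack the carry and move on to the next i; nothing is emitted per step.
    return (value, i + 1), None


def sum_above_two(start_value, n_iter):
    # Assumption: n_iter is a plain Python int, since it fixes the scan length.
    carry = (jnp.asarray(start_value, dtype=jnp.int32),
             jnp.asarray(0, dtype=jnp.int32))
    (value, _), _ = lax.scan(body, carry, xs=None, length=n_iter)
    return value


result = sum_above_two(0, 5)
print(result)  # 7, because only i = 3 and i = 4 pass the check
```

The same lax.cond call should be usable inside the __call__ of the Iteration module, because the carry is traced there as well. For a cheap elementwise choice like this one, jnp.where(i > 2, value + i, value) is another option.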