It's definitely an issue with initialization. For example, using setup, this doesn't work:
But this workaround I came up with does:
The error message is not very informative here, but it comes from JAX, so we cannot fix that on our side. In general, shapes must be static, so you should do any arithmetic on shapes with plain Python or NumPy operations rather than jax.numpy (note the jnp -> np change):
Now the range will have a constant size, and the rest should work.
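A minimal sketch of why the jnp -> np change matters, in pure JAX rather than Flax (the function names below are hypothetical, not taken from the original code). Inside jax.jit, x.shape[-1] is still a plain Python int, but passing it through a jnp op produces a traced value, so turning that result into a loop count works eagerly and fails under jit. Doing the same arithmetic with NumPy keeps it concrete.

```python
import jax
import jax.numpy as jnp
import numpy as np

def num_doublings_jnp(n):
    # jnp ops are traced inside jax.jit, so this value is an abstract tracer
    # there, and int() on it raises ConcretizationTypeError.
    return int(jnp.round(jnp.log2(n)))

def num_doublings_np(n):
    # np.log2 runs in plain NumPy on a Python int, so the result is always
    # concrete, even while JAX is tracing the surrounding function.
    return int(np.round(np.log2(n)))

def forward_jnp(x):
    # x.shape[-1] is static even under jit; the problem is the jnp math on it.
    return x * num_doublings_jnp(x.shape[-1])

def forward_np(x):
    return x * num_doublings_np(x.shape[-1])

x = jnp.ones(8)

print(forward_jnp(x))          # eager (no jit): works
print(jax.jit(forward_np)(x))  # jitted, NumPy shape math: works

try:
    jax.jit(forward_jnp)(x)    # jitted, jnp shape math: fails while tracing
except jax.errors.ConcretizationTypeError as err:
    print("raised", type(err).__name__)
```

The same reasoning applies to a range(...) over a jnp-computed value: the trip count of a Python loop has to be known while JAX traces the function.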
I'm working on implementing Hadamard transforms. I use the following function to build a Hadamard matrix:
I cannot rewrite this as a fori_loop or while_loop, because the shape changes every iteration. I then use it in a simple dense layer:
I can initialize this model and my parameters look fine (i.e. it's a pytree containing a device array):
However, when I want to do the forward pass with
I get a ConcretizationTypeError on the for loop in the hadamard function: 'ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.'
I'm not sure why this is happening: I can initialize the model, so I would expect the forward pass to work too. Does jitting the apply function also jit the hadamard function?
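On the last question: yes, jax.jit traces every Python function called from the jitted function, so hadamard runs with abstract tracers during apply, while a non-jitted init runs it eagerly on concrete values. The original hadamard function and dense layer are not preserved in this copy of the thread, so the sketch below is an assumption: it uses the standard Sylvester construction and a hypothetical pure-JAX hadamard_dense layer in place of the Flax module. The plain Python loop is fine under jit (it is unrolled at trace time, so the shape may change every step) as long as its trip count is a concrete int computed with NumPy.

```python
import jax
import jax.numpy as jnp
import numpy as np

def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n a power of two).

    n must be a concrete Python int. The loop is plain Python, so under
    jax.jit it is unrolled at trace time and each step can change the shape.
    """
    k = int(np.round(np.log2(n)))  # NumPy, not jnp: stays concrete under jit
    if 2 ** k != n:
        raise ValueError(f"n must be a power of two, got {n}")
    h = jnp.ones((1, 1), dtype=jnp.float32)
    for _ in range(k):
        # H_2m = [[H_m, H_m], [H_m, -H_m]]; the shape doubles every step.
        h = jnp.block([[h, h], [h, -h]])
    return h

@jax.jit
def hadamard_dense(w, b, x):
    # Hypothetical layer: mix features with a fixed Hadamard matrix, then
    # apply a learned weight matrix and bias.
    h = hadamard(x.shape[-1])  # x.shape[-1] is a static Python int under jit
    return (x @ h) @ w + b

H = hadamard(4)
print(H @ H.T)  # 4 times the 4 x 4 identity

key = jax.random.PRNGKey(0)
x = jnp.ones((2, 8))
w = jax.random.normal(key, (8, 3))
b = jnp.zeros(3)
print(hadamard_dense(w, b, x).shape)  # (2, 3)
```

Because the Hadamard matrix depends only on the static feature size, it is rebuilt only when jit retraces for a new input shape, not on every call.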