Closed
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Prod version https://colab.research.google.com/ at time of submission
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.10.4, 0.4.33, 0.4.33
- Python version: 3.11.11
- GPU/TPU model and memory: Colab CPU and T4
- CUDA version (if applicable):
Problem you have encountered:
Calling flax.linen.scan raises an internal error: AttributeError: module 'jax.api_util' has no attribute 'debug_info'
What you expected to happen:
Execution without error.
Logs, error messages, etc:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-1-54809ea297a2> in <cell line: 0>()
24
25 model = SimpleScan()
---> 26 variables = model.init(key_3, init_carry, xs)
27 out_carry, out_val = model.apply(variables, init_carry, xs)
28
[... skipping hidden 9 frame]
1 frames
<ipython-input-1-54809ea297a2> in __call__(self, c, xs)
18 in_axes=1,
19 out_axes=1)
---> 20 return LSTM(out_feat)(c, xs)
21
22 xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
[... skipping hidden 3 frame]
/usr/local/lib/python3.11/dist-packages/flax/core/axes_scan.py in scan_fn(broadcast_in, init, *args)
157
158 in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
--> 159 debug_info = jax.api_util.debug_info("flax scan", broadcast_body,
160 (in_tree,), {})
161 f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
AttributeError: module 'jax.api_util' has no attribute 'debug_info'
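One plausible reading of this traceback is that the installed Flax expects a newer JAX than the one present in the environment. A minimal diagnostic sketch under that assumption follows; `installed_version` and `has_attribute` are hypothetical helper names used for illustration, not part of Flax or JAX.

```python
import importlib
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def has_attribute(module_name, attribute):
    """Return True/False for attribute presence, or None if the import fails."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attribute)


if __name__ == "__main__":
    # Report the versions listed under "System information" above.
    for package in ("flax", "jax", "jaxlib"):
        print(package, installed_version(package))
    # Check for the helper that flax/core/axes_scan.py tries to call.
    print("jax.api_util.debug_info present:",
          has_attribute("jax.api_util", "debug_info"))
```

If the last line prints False, aligning the Flax and JAX versions (upgrading JAX or pinning an older Flax) may be worth trying. That is an untested suggestion, not a confirmed fix.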
Steps to reproduce:
# Example from
# https://flax.readthedocs.io/en/v0.5.3/_autosummary/flax.linen.scan.html
# with minor fixes
import flax
import flax.linen as nn
from jax import random

seq_len, batch_size, in_feat, out_feat = 20, 16, 3, 5
key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)


class SimpleScan(nn.Module):
    @nn.compact
    def __call__(self, c, xs):
        LSTM = nn.scan(nn.LSTMCell,
                       variable_broadcast="params",
                       split_rngs={"params": False},
                       in_axes=1,
                       out_axes=1)
        return LSTM(out_feat)(c, xs)


xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
init_carry = nn.LSTMCell(out_feat).initialize_carry(key_2, (batch_size,))

model = SimpleScan()
variables = model.init(key_3, init_carry, xs)
out_carry, out_val = model.apply(variables, init_carry, xs)

assert out_val.shape == (batch_size, seq_len, out_feat)
Colab repro:
https://colab.research.google.com/drive/1OUJfQsjoOPxhwt3G0GFL1ex2-MbJt4dS?usp=sharing