Internal error when using flax.linen.scan: AttributeError: module 'jax.api_util' has no attribute 'debug_info' #4603

Open
mitscha opened this issue Mar 5, 2025 · 1 comment

mitscha commented Mar 5, 2025

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Google Colab production runtime (https://colab.research.google.com/) at time of submission
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.10.4, 0.4.33, 0.4.33
  • Python version: 3.11.11
  • GPU/TPU model and memory: Colab CPU and T4
  • CUDA version (if applicable):
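The versions above can also be collected from inside the notebook instead of with `pip show`. Below is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical, and a package that is not installed is reported as `None`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: installed version string, or None if absent}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["flax", "jax", "jaxlib"]).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```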

Problem you have encountered:

Calling model.init on a module that uses flax.linen.scan raises an internal error: AttributeError: module 'jax.api_util' has no attribute 'debug_info'

What you expected to happen:

Execution without error.

Logs, error messages, etc:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-1-54809ea297a2> in <cell line: 0>()
     24 
     25 model = SimpleScan()
---> 26 variables = model.init(key_3, init_carry, xs)
     27 out_carry, out_val = model.apply(variables, init_carry, xs)
     28 

    [... skipping hidden 9 frame]

<ipython-input-1-54809ea297a2> in __call__(self, c, xs)
     18                    in_axes=1,
     19                    out_axes=1)
---> 20     return LSTM(out_feat)(c, xs)
     21 
     22 xs = random.uniform(key_1, (batch_size, seq_len, in_feat))

    [... skipping hidden 3 frame]

/usr/local/lib/python3.11/dist-packages/flax/core/axes_scan.py in scan_fn(broadcast_in, init, *args)
    157 
    158     in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
--> 159     debug_info = jax.api_util.debug_info("flax scan", broadcast_body,
    160                                          (in_tree,), {})
    161     f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(

AttributeError: module 'jax.api_util' has no attribute 'debug_info'
```
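According to the traceback, Flax 0.10.4's `flax/core/axes_scan.py` calls `jax.api_util.debug_info`, which the installed JAX 0.4.33 does not provide. The failure therefore points to a Flax/JAX version mismatch rather than to the model code. A quick way to confirm such a mismatch is to probe for the attribute directly. Below is a minimal stdlib-only sketch. The helper `has_module_attr` is hypothetical. It returns `None` when the module cannot be imported, so it also runs where JAX is absent.

```python
import importlib


def has_module_attr(module_name, attr):
    """Return True/False for whether `module_name` exposes `attr`,
    or None if the module itself cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attr)


if __name__ == "__main__":
    # Per the traceback, Flax's axes_scan.py needs jax.api_util.debug_info.
    status = has_module_attr("jax.api_util", "debug_info")
    if status is None:
        print("jax is not installed")
    elif status:
        print("jax.api_util.debug_info is available")
    else:
        print("jax.api_util.debug_info is missing: upgrade JAX")
```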

Steps to reproduce:

```python
# Example from
# https://flax.readthedocs.io/en/v0.5.3/_autosummary/flax.linen.scan.html
# with minor fixes

import flax
import flax.linen as nn
from jax import random

seq_len, batch_size, in_feat, out_feat = 20, 16, 3, 5
key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)

class SimpleScan(nn.Module):
  @nn.compact
  def __call__(self, c, xs):
    # Lift LSTMCell over the time axis (axis 1 of xs); "params" are
    # broadcast so every step shares the same weights.
    LSTM = nn.scan(nn.LSTMCell,
                   variable_broadcast="params",
                   split_rngs={"params": False},
                   in_axes=1,
                   out_axes=1)
    return LSTM(out_feat)(c, xs)

xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
# initialize_carry takes the shape of one step's input, (batch, in_feat),
# and returns a (c, h) carry of shape (batch, out_feat).
init_carry = nn.LSTMCell(out_feat).initialize_carry(key_2, (batch_size, in_feat))

model = SimpleScan()
variables = model.init(key_3, init_carry, xs)
out_carry, out_val = model.apply(variables, init_carry, xs)

assert out_val.shape == (batch_size, seq_len, out_feat)
```
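The carry initialization is worth double-checking separately from the JAX upgrade. As far as we can tell from recent Flax releases, `LSTMCell.initialize_carry(rng, input_shape)` expects the shape of a single step's input and treats every dimension except the last as a batch dimension. Passing only `(batch_size,)` would therefore produce an unbatched `(out_feat,)` carry. That carry would likely fail scan's carry-shape check, because each step returns a `(batch_size, out_feat)` carry. The pure-Python sketch below mirrors only that assumed shape rule. `lstm_carry_shape` is a hypothetical helper, not a Flax API.

```python
def lstm_carry_shape(input_shape, features):
    """Shape of each of (c, h) in an LSTM cell carry.

    Mirrors the assumed rule: batch dims are input_shape[:-1] (everything
    except the per-step feature axis), followed by the hidden size.
    """
    batch_dims = tuple(input_shape[:-1])
    return batch_dims + (features,)


if __name__ == "__main__":
    batch_size, in_feat, out_feat = 16, 3, 5
    # Shape of one scan step's input: (batch_size, in_feat).
    print(lstm_carry_shape((batch_size, in_feat), out_feat))  # (16, 5)
    # Passing only (batch_size,) drops the batch axis from the carry.
    print(lstm_carry_shape((batch_size,), out_feat))  # (5,)
```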

Colab repro:
https://colab.research.google.com/drive/1OUJfQsjoOPxhwt3G0GFL1ex2-MbJt4dS?usp=sharing

Collaborator

IvyZX commented Mar 5, 2025

You'd want to update JAX to the latest version; that should resolve this.

(We should update our minimum JAX requirement too. Gonna do that now.)
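A project can also guard against this kind of mismatch on its own side, for example by prompting users to run `pip install --upgrade jax jaxlib` when the installed JAX is too old. Below is a minimal stdlib-only sketch of a numeric version check. The helper names are hypothetical. The `MIN_JAX` value is an illustrative placeholder, not Flax's actual requirement, which should be read from Flax's own package metadata.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Illustrative placeholder only: use the minimum JAX version that your
# installed Flax release actually declares.
MIN_JAX = "0.5.0"


def release_tuple(ver):
    """Parse the leading numeric release, e.g. '0.4.33' -> (0, 4, 33).

    Pre-release/dev suffixes are ignored, so this is a rough check only.
    """
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """True if `installed` >= `minimum`, comparing release numbers numerically."""
    a, b = release_tuple(installed), release_tuple(minimum)
    # Pad with zeros so that '0.5' compares equal to '0.5.0'.
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


if __name__ == "__main__":
    try:
        jax_version = version("jax")
    except PackageNotFoundError:
        print("jax is not installed")
    else:
        if meets_minimum(jax_version, MIN_JAX):
            print(f"jax {jax_version} satisfies >= {MIN_JAX}")
        else:
            print(f"jax {jax_version} is older than {MIN_JAX}; "
                  "run: pip install --upgrade jax jaxlib")
```

Comparing integer tuples rather than raw strings matters here: as strings, "0.10.0" would sort before "0.9.9".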
