
Internal error when using flax.linen.scan: AttributeError: module 'jax.api_util' has no attribute 'debug_info' #4603

Closed
@mitscha

Description



System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): production version of https://colab.research.google.com/ at time of submission
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.10.4, 0.4.33, 0.4.33
  • Python version: 3.11.11
  • GPU/TPU model and memory: Colab CPU and T4
  • CUDA version (if applicable):
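The same version information can also be collected from inside the notebook. A minimal sketch using only the standard library: `importlib.metadata.version` reads installed package metadata without importing the packages. `report_versions` is a hypothetical helper name, not part of Flax or JAX.

```python
import importlib.metadata


def report_versions(packages=("flax", "jax", "jaxlib")):
    """Map each package name to its installed version, or None if not installed.

    Only distribution metadata is read, so none of the packages is imported.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print(report_versions())
```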

Problem you have encountered:

An internal error is raised when using flax.linen.scan: AttributeError: module 'jax.api_util' has no attribute 'debug_info'

What you expected to happen:

Execution without error.

Logs, error messages, etc:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-1-54809ea297a2> in <cell line: 0>()
     24 
     25 model = SimpleScan()
---> 26 variables = model.init(key_3, init_carry, xs)
     27 out_carry, out_val = model.apply(variables, init_carry, xs)
     28 

    [... skipping hidden 9 frame]

<ipython-input-1-54809ea297a2> in __call__(self, c, xs)
     18                    in_axes=1,
     19                    out_axes=1)
---> 20     return LSTM(out_feat)(c, xs)
     21 
     22 xs = random.uniform(key_1, (batch_size, seq_len, in_feat))

    [... skipping hidden 3 frame]

/usr/local/lib/python3.11/dist-packages/flax/core/axes_scan.py in scan_fn(broadcast_in, init, *args)
    157 
    158     in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
--> 159     debug_info = jax.api_util.debug_info("flax scan", broadcast_body,
    160                                          (in_tree,), {})
    161     f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(

AttributeError: module 'jax.api_util' has no attribute 'debug_info'
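The traceback shows flax/core/axes_scan.py calling jax.api_util.debug_info, which the installed jax does not provide. A small sketch for checking whether a given environment exposes that attribute; `has_api_util_debug_info` is a hypothetical helper name, and it imports `jax.api_util` lazily so it also runs where jax is not installed.

```python
import importlib


def has_api_util_debug_info():
    """True/False for whether jax.api_util.debug_info exists; None if jax is missing."""
    try:
        api_util = importlib.import_module("jax.api_util")
    except ImportError:
        return None  # jax itself is not importable in this environment
    return hasattr(api_util, "debug_info")


print(has_api_util_debug_info())
```

Given the AttributeError above, this would return False on the reported setup.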

Steps to reproduce:


# Example from
# https://flax.readthedocs.io/en/v0.5.3/_autosummary/flax.linen.scan.html
# with minor fixes

import flax
import flax.linen as nn
from jax import random

seq_len, batch_size, in_feat, out_feat = 20, 16, 3, 5
key_1, key_2, key_3 = random.split(random.PRNGKey(0), 3)

class SimpleScan(nn.Module):
  @nn.compact
  def __call__(self, c, xs):
    LSTM = nn.scan(nn.LSTMCell,
                   variable_broadcast="params",
                   split_rngs={"params": False},
                   in_axes=1,
                   out_axes=1)
    return LSTM(out_feat)(c, xs)

xs = random.uniform(key_1, (batch_size, seq_len, in_feat))
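# initialize_carry takes the shape of one input slice, (batch, features);
# the carry gets shape input_shape[:-1] + (out_feat,) = (batch_size, out_feat)
init_carry = nn.LSTMCell(out_feat).initialize_carry(key_2, (batch_size, in_feat))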

model = SimpleScan()
variables = model.init(key_3, init_carry, xs)
out_carry, out_val = model.apply(variables, init_carry, xs)

assert out_val.shape == (batch_size, seq_len, out_feat)

Colab repro:
https://colab.research.google.com/drive/1OUJfQsjoOPxhwt3G0GFL1ex2-MbJt4dS?usp=sharing

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests