Conversation

Collaborator

@ycchenzheng ycchenzheng commented Sep 26, 2025

Description

Fix the generate_metrics_and_upload_to_big_query flag and its related dependencies.

FIXES: b/441984274, b/446097400

Tests

Project preparation:

  • Create a dataset in BigQuery or use an existing one in the target project

benchmarks/recipes/user_configs.py change:
Add the following flags to the USER_CONFIG definition, i.e. the block that starts with:

# Define the required configuration here
USER_CONFIG = UserConfig(

  • Set bq_enable to True
  • Set bq_db_project and bq_db_dataset to the user's target project and the dataset in that project (see the completed sketch below)

Then run python3 -m benchmarks.recipes.pw_mcjax_benchmark_recipe. After 20 steps, one BigQuery record will be written to the run_summary table of the target dataset.
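
For concreteness, a minimal sketch of the resulting block (bq_enable, bq_db_project, and bq_db_dataset are the flags described above; the project and dataset values are placeholders, and any other UserConfig fields are left at their defaults):

# Define the required configuration here
USER_CONFIG = UserConfig(
    bq_enable=True,  # enable metric upload to BigQuery
    bq_db_project="my-gcp-project",  # placeholder: the user's target project
    bq_db_dataset="my_benchmark_dataset",  # placeholder: dataset in that project
)

To confirm the upload independently of the recipe's logs, one option (not part of the PR's test steps) is a quick row count with the google-cloud-bigquery client, using the same placeholder names:

from google.cloud import bigquery

client = bigquery.Client(project="my-gcp-project")
rows = client.query(
    "SELECT COUNT(*) AS n FROM `my-gcp-project.my_benchmark_dataset.run_summary`"
).result()
print(list(rows))  # expect the count to grow by one after the 20-step run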

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@ycchenzheng ycchenzheng changed the title from Pw/user/chzheng/bq fix to Fix generate_metrics_and_upload_to_big_query on Sep 26, 2025
@ycchenzheng ycchenzheng force-pushed the pw/user/chzheng/bq_fix branch from 7f29fd5 to 87e8861 on September 29, 2025 17:10
@ycchenzheng ycchenzheng marked this pull request as ready for review September 29, 2025 17:16
@ycchenzheng ycchenzheng self-assigned this Sep 30, 2025
absl-py
aqtp
array-record
benchmark_db_writer@git+https://github.com/CIeNET-International/aotc.git@c0bef62eac87c99152ff2e9fd48da1f7d9f3cc04#subdirectory=src/aotc/benchmark_db_writer
Collaborator

Why are we depending on a specific commit from a forked repo? Can we not upstream that?

Collaborator Author

We forked the repo from https://github.com/AI-Hypercomputer/aotc/tree/main/src/aotc/benchmark_db_writer and made some fixes, since we got no response on the original repo's issue AI-Hypercomputer/aotc#1. We talked with @SujeethJinesh, and he is okay with using the forked repo for now.
As for pinning to a specific commit: since the forked repo does not have strict merge rules, we would like to pin to a known-good commit in case the latest commit introduces new bugs while new features are being implemented. @SujeethJinesh WDYT?

Collaborator

Looking at the change now, I think we should definitely make the fix in the main aotc repo or at least depend on a branch off the aotc repo rather than a fork of it under different ownership. Would it be possible to do that instead?

Please create a bug for this internally and I can follow up with the aotc folks about making appropriate fixes there instead of in a forked repo.

Seems like it should be simple enough to actually do so since I don't think the changes you needed to make were very large.

Collaborator Author

@SujeethJinesh Created b/450288198 for this issue

Collaborator Author

I tried to push a branch to the aotc repo but got a permission-denied error.

cloud-tpu-diagnostics
datasets
flax
flax==0.11.1
Collaborator

Why are we pinning to this version?

Collaborator Author

Only this version works with benchmark_db_writer; otherwise training fails with the following error:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/MaxText/train.py", line 761, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/MaxText/train.py", line 757, in main
    run(config, recorder, diagnostic_config)
  File "/deps/MaxText/train.py", line 752, in run
    train_loop(config, recorder)
  File "/deps/MaxText/train.py", line 618, in train_loop
    ) = setup_train_loop(config, recorder)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/MaxText/train.py", line 554, in setup_train_loop
    state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
                                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/MaxText/maxtext_utils.py", line 942, in setup_training_state
    return setup_initial_state(
           ^^^^^^^^^^^^^^^^^^^^
  File "/deps/MaxText/maxtext_utils.py", line 981, in setup_initial_state
    unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state(
                                                                           ^^^^^^^^^^^^^^^^^^^
  File "/deps/MaxText/maxtext_utils.py", line 1038, in get_abstract_state
    abstract_state = jax.eval_shape(init_state_partial)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/MaxText/maxtext_utils.py", line 892, in init_initial_state
    model_vars = model.init(
                 ^^^^^^^^^^^
  File "/deps/MaxText/layers/models.py", line 126, in __call__
    logits, hidden_state = self.decoder(
                           ^^^^^^^^^^^^^
  File "/deps/MaxText/layers/decoders.py", line 610, in __call__
    y = self._apply_embedding(
        ^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/MaxText/layers/decoders.py", line 505, in _apply_embedding
    y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/MaxText/layers/nnx_wrappers.py", line 426, in __call__
    self._update_variables(module)
  File "/deps/MaxText/layers/nnx_wrappers.py", line 491, in _update_variables
    collection_state = jax.tree.map(
                       ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/MaxText/layers/nnx_wrappers.py", line 485, in _to_linen_var
    return self.metadata_fn(x) # pylint: disable=too-many-function-args
           ^^^^^^^^^^^^^^^^^^^
  File "/deps/MaxText/layers/initializers.py", line 56, in variable_to_logically_partitioned
    variable.sharding,  # type: ignore[arg-type]
    ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/flax/nnx/variablelib.py", line 281, in __getattr__
    return getattr(self.raw_value, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: The 'sharding' attribute is not available on traced array with shape float32[128256,4096].

The error occurred while tracing the function init_initial_state at /deps/MaxText/maxtext_utils.py:882 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:key<urbg>[] = random_wrap[impl=urbg] b
    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)

  operation a:key<urbg>[] = random_fold_in b 3279144704:u32[]
    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)

  operation a:key<urbg>[] = random_wrap[impl=urbg] b
    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)

  operation a:key<urbg>[] = random_fold_in b 3279144704:u32[]
    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)

  operation a:key<urbg>[] = random_wrap[impl=urbg] b
    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)

(Additional originating lines are not shown.)

--------------------

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Collaborator Author

@ycchenzheng ycchenzheng Oct 1, 2025

The incompatibility was caused by installing benchmark_db_writer's dependencies, which pull in a flax version higher than 0.11.1.
Please check b/441984274 for context.
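
As a quick diagnostic (a sketch, not part of the fix), the flax version that pip actually resolved can be checked after installing benchmark_db_writer:

from importlib.metadata import version

# Reports the transitively installed flax version: 0.11.1 with the pin,
# a higher version without it.
print(version("flax"))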

Collaborator

The error seems to be coming from JAX because it thinks the variable is a tracer: https://github.com/jax-ml/jax/blob/5dbbfc38c99b193f43c5273b02263d91cd04a560/jax/_src/core.py#L1047

This may need to be a separate bug fix in MaxText. Specifically, we may want to add a check like this here:

  if isinstance(variable.value, jax.core.Tracer):
    return variable.value

This should help avoid pinning flax.
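
For illustration, a sketch of how that guard could sit inside variable_to_logically_partitioned (the function name comes from the traceback above; the surrounding body is assumed, not MaxText's exact code):

import jax
import flax.linen as nn
from flax import nnx

def variable_to_logically_partitioned(variable: nnx.VariableState):
  # Under jax.eval_shape the wrapped value is an abstract Tracer that
  # carries no sharding metadata, so return it untouched.
  if isinstance(variable.value, jax.core.Tracer):
    return variable.value
  # Assumed wrapping step, per the traceback's reference to
  # variable.sharding at initializers.py line 56.
  return nn.LogicallyPartitioned(
      variable.value,
      variable.sharding,  # type: ignore[arg-type]
  )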

Collaborator Author

@SujeethJinesh I tried

  if isinstance(variable.value, jax.core.Tracer):
    return variable.value

and unpinned flax; the run picked up flax 0.12 and failed with the following error:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/MaxText/train.py", line 510, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/MaxText/train.py", line 506, in main
    run(config, recorder, diagnostic_config)
  File "/deps/src/MaxText/train.py", line 501, in run
    train_loop(config, recorder)
  File "/deps/src/MaxText/train.py", line 364, in train_loop
    ) = train_utils.setup_train_loop(config, recorder)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/MaxText/train_utils.py", line 204, in setup_train_loop
    maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)
  File "/deps/src/MaxText/maxtext_utils.py", line 805, in assert_params_sufficiently_sharded
    _raise_if_unsharded_exceeds_tolerance(
  File "/deps/src/MaxText/maxtext_utils.py", line 773, in _raise_if_unsharded_exceeds_tolerance
    raise AssertionError("\n".join(error_msg_lines))
AssertionError: Unsharded parameter percentage (25.00%)exceeds tolerance (2.00%).
The following large tensors are replicated (unsharded) but could be sharded on at least one of the available axes:
 - Name: ['params']['decoder']['layers']['mlp']['wi_0']['kernel'](Size: 1879048192, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
 - Name: ['params']['decoder']['layers']['mlp']['wi_1']['kernel'](Size: 1879048192, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
 - Name: ['params']['decoder']['layers']['mlp']['wo']['kernel'](Size: 1879048192, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
 - Name: ['params']['decoder']['layers']['self_attention']['out']['kernel'](Size: 536870912, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
 - Name: ['params']['decoder']['layers']['self_attention']['query']['kernel'](Size: 536870912, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']

Collaborator Author

Please check PR #2502 for another solution that avoids pinning flax.

@ycchenzheng ycchenzheng force-pushed the pw/user/chzheng/bq_fix branch 4 times, most recently from bdd517d to a298159 on October 2, 2025 16:56
Collaborator

@SujeethJinesh SujeethJinesh left a comment

Thanks!
