Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions benchmarks/benchmark_db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def write_run(

from benchmark_db_writer import bq_writer_utils
from benchmark_db_writer import dataclass_bigquery_writer
from benchmark_db_writer.run_summary_writer import sample_run_summary_writer
from benchmark_db_writer.run_summary_writer import run_summary_writer
from benchmark_db_writer.schema.workload_benchmark_v2 import workload_benchmark_v2_schema

def get_db_client(
Expand All @@ -168,9 +168,9 @@ def get_db_client(
print(options.model_id)

if (
sample_run_summary_writer.validate_model_id(options.model_id, options.is_test)
and sample_run_summary_writer.validate_hardware_id(options.hardware_id, options.is_test)
and sample_run_summary_writer.validate_software_id(options.software_id, options.is_test)
run_summary_writer.validate_model_id(options.model_id, options.is_test)
and run_summary_writer.validate_hardware_id(options.hardware_id, options.is_test)
and run_summary_writer.validate_software_id(options.software_id, options.is_test)
):
summary = workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema(
run_id=f"run-{uuid.uuid4()}",
Expand All @@ -179,6 +179,7 @@ def get_db_client(
hardware_id=options.hardware_id,
hardware_num_chips=number_of_chips,
hardware_num_nodes=number_of_nodes,
hardware_num_slices=options.hardware_num_slices,
result_success=run_success,
configs_framework=framework_config_in_json,
configs_env=env_variables,
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os.path

# This is the MaxText root: with "max_utils.py"; &etc. TODO: Replace `os.path.basename` with `os.path.abspath`
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "MaxText")
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "src/MaxText")

# This is the maxtext repo root: with ".git" folder; "README.md"; "pyproject.toml"; &etc.
MAXTEXT_REPO_ROOT = os.environ.get(
Expand Down
7 changes: 6 additions & 1 deletion benchmarks/maxtext_xpk_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def __post_init__(self):
else:
self.num_devices_per_slice = int(self.device_type.split("-")[1]) / 2
self.topology = ""
self.hardware_id = self.device_type.split("-")[0]
if self.hardware_id == "v5litepod":
self.hardware_id = "v5e"


def wait_for_xpk_workload_completion(cluster_config: XpkClusterConfig, workload_name, xpk_path) -> int:
Expand Down Expand Up @@ -341,6 +344,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict:
"model_id": wl_config.model.model_type,
"hardware_id": wl_config.hardware_id,
"software_id": "jax_maxtext",
"hardware_num_slices": wl_config.num_slices,
"number_of_chips": wl_config.num_devices_per_slice * wl_config.num_slices,
"container_image_name": wl_config.base_docker_image,
"global_batch_size": per_device_batch_size * wl_config.num_devices_per_slice * wl_config.num_slices,
Expand Down Expand Up @@ -445,7 +449,8 @@ def build_user_command(
f"base_output_directory={wl_config.base_output_directory}",
f"{vertex_tensorboard}",
f"{run_name_command}",
f"{enable_metrics_cmd}" f"{upload_hlo_dump}",
f"{enable_metrics_cmd}",
f"{upload_hlo_dump}",
]
)
return command
Expand Down
6 changes: 6 additions & 0 deletions benchmarks/recipes/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def generate_and_run_workloads(user_config, num_slices_list, num_steps, priority
num_slices_list: A list of the number of slices to be executed.
num_steps: The number of steps for each workload.
"""
if user_config.bq_enable and (not user_config.bq_db_project or not user_config.bq_db_dataset):
logging.error("Validation FAILED: BQ is enabled, but project or dataset is missing.")
return 1
xpk_workload_cmds = []
xpk_workload_names = []

Expand Down Expand Up @@ -65,6 +68,9 @@ def generate_and_run_workloads(user_config, num_slices_list, num_steps, priority
xpk_path=user_config.xpk_path,
num_steps=num_steps,
priority=priority,
generate_metrics_and_upload_to_big_query=user_config.bq_enable,
db_project=user_config.bq_db_project,
db_dataset=user_config.bq_db_dataset,
)

# Generate XPK command
Expand Down
5 changes: 5 additions & 0 deletions benchmarks/recipes/user_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ class UserConfig:
selected_model_names: list[str] = dataclasses.field(default_factory=lambda: ["llama3_1_8b_8192"])
num_slices_list: list[int] = dataclasses.field(default_factory=lambda: [2])

# BigQuery configuration
bq_enable: bool = False
bq_db_project: str = ""
bq_db_dataset: str = ""

# other configuration
xpk_path: str = "~/xpk"
max_restarts: int = 0
Expand Down
6 changes: 6 additions & 0 deletions benchmarks/upload_metrics_to_bq.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,12 @@ def add_parser_arguments(parser: argparse.ArgumentParser):
default=True,
help="Whether to use the testing project or production project",
)
parser.add_argument(
"--hardware_num_slices",
type=int,
required=False,
help="hardware slice number",
)


def download_metrics_file_locally(metrics_gcs_file: str, local_file: str) -> int:
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
absl-py
aqtp
array-record
benchmark_db_writer@git+https://github.com/CIeNET-International/aotc.git@c0bef62eac87c99152ff2e9fd48da1f7d9f3cc04#subdirectory=src/aotc/benchmark_db_writer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we depending on a specific commit from a forked repo? Can we not upstream that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We forked the repo from https://github.com/AI-Hypercomputer/aotc/tree/main/src/aotc/benchmark_db_writer and made some fixes there, since we got no response on the original repo's issue AI-Hypercomputer/aotc#1. We talked with @SujeethJinesh, and he is okay with using a forked repo for now.
About pinning to a specific commit: since the forked repo does not have strict merge rules, we would like to pin a specific commit in case the latest commit introduces new bugs while new features are being implemented. @SujeethJinesh WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the change now, I think we should definitely make the fix in the main aotc repo or at least depend on a branch off the aotc repo rather than a fork of it under different ownership. Would it be possible to do that instead?

Please create a bug for this internally and I can follow up with the aotc folks about making appropriate fixes there instead of in a forked repo.

Seems like it should be simple enough to actually do so since I don't think the changes you needed to make were very large.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SujeethJinesh Created b/450288198 for this issue

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to push a branch to the aotc repo but got a permission-denied error.

cloud-accelerator-diagnostics
cloud-tpu-diagnostics
datasets
flax
flax==0.11.1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we pinning to this version?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only this version works with benchmark_db_writer; with any other version we encounter the following error:

Traceback (most recent call last):

  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/deps/MaxText/train.py", line 761, in <module>

    app.run(main)

  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run

    _run_main(main, args)

  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main

    sys.exit(main(argv))

             ^^^^^^^^^^

  File "/deps/MaxText/train.py", line 757, in main

    run(config, recorder, diagnostic_config)

  File "/deps/MaxText/train.py", line 752, in run

    train_loop(config, recorder)

  File "/deps/MaxText/train.py", line 618, in train_loop

    ) = setup_train_loop(config, recorder)

        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/deps/MaxText/train.py", line 554, in setup_train_loop

    state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(

                                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/deps/MaxText/maxtext_utils.py", line 942, in setup_training_state

    return setup_initial_state(

           ^^^^^^^^^^^^^^^^^^^^

  File "/deps/MaxText/maxtext_utils.py", line 981, in setup_initial_state

    unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state(

                                                                           ^^^^^^^^^^^^^^^^^^^

  File "/deps/MaxText/maxtext_utils.py", line 1038, in get_abstract_state

    abstract_state = jax.eval_shape(init_state_partial)

                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/deps/MaxText/maxtext_utils.py", line 892, in init_initial_state

    model_vars = model.init(

                 ^^^^^^^^^^^

  File "/deps/MaxText/layers/models.py", line 126, in __call__

    logits, hidden_state = self.decoder(

                           ^^^^^^^^^^^^^

  File "/deps/MaxText/layers/decoders.py", line 610, in __call__

    y = self._apply_embedding(

        ^^^^^^^^^^^^^^^^^^^^^^

  File "/deps/MaxText/layers/decoders.py", line 505, in _apply_embedding

    y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)

        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/deps/MaxText/layers/nnx_wrappers.py", line 426, in __call__

    self._update_variables(module)

  File "/deps/MaxText/layers/nnx_wrappers.py", line 491, in _update_variables

    collection_state = jax.tree.map(

                       ^^^^^^^^^^^^^

  File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map

    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)

           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/deps/MaxText/layers/nnx_wrappers.py", line 485, in _to_linen_var

    return self.metadata_fn(x) # pylint: disable=too-many-function-args

           ^^^^^^^^^^^^^^^^^^^

  File "/deps/MaxText/layers/initializers.py", line 56, in variable_to_logically_partitioned

    variable.sharding,  # type: ignore[arg-type]

    ^^^^^^^^^^^^^^^^^

  File "/usr/local/lib/python3.12/site-packages/flax/nnx/variablelib.py", line 281, in __getattr__

    return getattr(self.raw_value, name)

           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

AttributeError: The 'sharding' attribute is not available on traced array with shape float32[128256,4096].

The error occurred while tracing the function init_initial_state at /deps/MaxText/maxtext_utils.py:882 for jit. This value became a tracer due to JAX operations on these lines:



  operation a:key<urbg>[] = random_wrap[impl=urbg] b

    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)



  operation a:key<urbg>[] = random_fold_in b 3279144704:u32[]

    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)



  operation a:key<urbg>[] = random_wrap[impl=urbg] b

    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)



  operation a:key<urbg>[] = random_fold_in b 3279144704:u32[]

    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)



  operation a:key<urbg>[] = random_wrap[impl=urbg] b

    from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)



(Additional originating lines are not shown.)

--------------------

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Copy link
Collaborator Author

@ycchenzheng ycchenzheng Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The incompatibility was caused by installing the dependencies of benchmark_db_writer, which pulls in a flax version higher than 0.11.1.
Please check b/441984274 for context

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error seems to be coming from JAX because it thinks the variable is a tracer: https://github.com/jax-ml/jax/blob/5dbbfc38c99b193f43c5273b02263d91cd04a560/jax/_src/core.py#L1047

This may need to be a separate bug fix in MaxText. Specifically, we may want to add this line here

  if isinstance(variable.value, jax.core.Tracer):
    return variable.value

This should help avoid pinning flax.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SujeethJinesh I tried

  if isinstance(variable.value, jax.core.Tracer):
    return variable.value

and unpinned flax, it used flax 0.12 and got the following issue:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/deps/src/MaxText/train.py", line 510, in <module>
    app.run(main)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/deps/src/MaxText/train.py", line 506, in main
    run(config, recorder, diagnostic_config)
  File "/deps/src/MaxText/train.py", line 501, in run
    train_loop(config, recorder)
  File "/deps/src/MaxText/train.py", line 364, in train_loop
    ) = train_utils.setup_train_loop(config, recorder)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/src/MaxText/train_utils.py", line 204, in setup_train_loop
    maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)
  File "/deps/src/MaxText/maxtext_utils.py", line 805, in assert_params_sufficiently_sharded
    _raise_if_unsharded_exceeds_tolerance(
  File "/deps/src/MaxText/maxtext_utils.py", line 773, in _raise_if_unsharded_exceeds_tolerance
    raise AssertionError("\n".join(error_msg_lines))
AssertionError: Unsharded parameter percentage (25.00%)exceeds tolerance (2.00%).
The following large tensors are replicated (unsharded) but could be sharded on at least one of the available axes:
 - Name: ['params']['decoder']['layers']['mlp']['wi_0']['kernel'](Size: 1879048192, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
 - Name: ['params']['decoder']['layers']['mlp']['wi_1']['kernel'](Size: 1879048192, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
 - Name: ['params']['decoder']['layers']['mlp']['wo']['kernel'](Size: 1879048192, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
 - Name: ['params']['decoder']['layers']['self_attention']['out']['kernel'](Size: 536870912, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
 - Name: ['params']['decoder']['layers']['self_attention']['query']['kernel'](Size: 536870912, Shape: PartitionSpec(), Spec: PartitionSpec())  is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check PR #2502 for another solution to avoid pinning flax

gcsfs
google-api-python-client
google-cloud-aiplatform
Expand Down
Loading