Fix generate_metrics_and_upload_to_big_query #2405
base: main
Conversation
generate_metrics_and_upload_to_big_query force-pushed from 7f29fd5 to 87e8861
absl-py
aqtp
array-record
benchmark_db_writer@git+https://github.com/CIeNET-International/aotc.git@c0bef62eac87c99152ff2e9fd48da1f7d9f3cc04#subdirectory=src/aotc/benchmark_db_writer
Why are we depending on a specific commit from a forked repo? Can we not upstream that?
We forked the repo from https://github.com/AI-Hypercomputer/aotc/tree/main/src/aotc/benchmark_db_writer and made some fixes there, since we got no response on the original repo's issue AI-Hypercomputer/aotc#1. We talked with @SujeethJinesh, and he is okay with using a forked repo for now.
As for pinning to a specific commit: the forked repo does not have strict merge rules, so we would like to pin to a known-good commit in case the latest commit picks up new bugs while new features are being implemented. @SujeethJinesh WDYT?
Looking at the change now, I think we should definitely make the fix in the main aotc repo or at least depend on a branch off the aotc repo rather than a fork of it under different ownership. Would it be possible to do that instead?
Please create a bug for this internally and I can follow up with the aotc folks about making appropriate fixes there instead of in a forked repo.
Seems like it should be simple enough to actually do so since I don't think the changes you needed to make were very large.
@SujeethJinesh Created b/450288198 for this issue
I tried to push a branch to the aotc repo but got a permission error.
cloud-tpu-diagnostics
datasets
flax
flax==0.11.1
Why are we pinning to this version?
Only this version works with benchmark_db_writer; otherwise training fails with the following error:
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/deps/MaxText/train.py", line 761, in <module>
app.run(main)
File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
_run_main(main, args)
File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/deps/MaxText/train.py", line 757, in main
run(config, recorder, diagnostic_config)
File "/deps/MaxText/train.py", line 752, in run
train_loop(config, recorder)
File "/deps/MaxText/train.py", line 618, in train_loop
) = setup_train_loop(config, recorder)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/deps/MaxText/train.py", line 554, in setup_train_loop
state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/deps/MaxText/maxtext_utils.py", line 942, in setup_training_state
return setup_initial_state(
^^^^^^^^^^^^^^^^^^^^
File "/deps/MaxText/maxtext_utils.py", line 981, in setup_initial_state
unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state(
^^^^^^^^^^^^^^^^^^^
File "/deps/MaxText/maxtext_utils.py", line 1038, in get_abstract_state
abstract_state = jax.eval_shape(init_state_partial)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/deps/MaxText/maxtext_utils.py", line 892, in init_initial_state
model_vars = model.init(
^^^^^^^^^^^
File "/deps/MaxText/layers/models.py", line 126, in __call__
logits, hidden_state = self.decoder(
^^^^^^^^^^^^^
File "/deps/MaxText/layers/decoders.py", line 610, in __call__
y = self._apply_embedding(
^^^^^^^^^^^^^^^^^^^^^^
File "/deps/MaxText/layers/decoders.py", line 505, in _apply_embedding
y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/deps/MaxText/layers/nnx_wrappers.py", line 426, in __call__
self._update_variables(module)
File "/deps/MaxText/layers/nnx_wrappers.py", line 491, in _update_variables
collection_state = jax.tree.map(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/jax/_src/tree.py", line 155, in map
return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/deps/MaxText/layers/nnx_wrappers.py", line 485, in _to_linen_var
return self.metadata_fn(x) # pylint: disable=too-many-function-args
^^^^^^^^^^^^^^^^^^^
File "/deps/MaxText/layers/initializers.py", line 56, in variable_to_logically_partitioned
variable.sharding, # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/flax/nnx/variablelib.py", line 281, in __getattr__
return getattr(self.raw_value, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: The 'sharding' attribute is not available on traced array with shape float32[128256,4096].
The error occurred while tracing the function init_initial_state at /deps/MaxText/maxtext_utils.py:882 for jit. This value became a tracer due to JAX operations on these lines:
operation a:key<urbg>[] = random_wrap[impl=urbg] b
from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)
operation a:key<urbg>[] = random_fold_in b 3279144704:u32[]
from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)
operation a:key<urbg>[] = random_wrap[impl=urbg] b
from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)
operation a:key<urbg>[] = random_fold_in b 3279144704:u32[]
from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)
operation a:key<urbg>[] = random_wrap[impl=urbg] b
from line /deps/MaxText/layers/nnx_wrappers.py:293:10 (linen_rngs_dict)
(Additional originating lines are not shown.)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The incompatibility is caused by installing the dependencies of benchmark_db_writer, which pulls in a flax version higher than 0.11.1.
Please check b/441984274 for context.
The error seems to be coming from JAX because it thinks the variable is a tracer: https://github.com/jax-ml/jax/blob/5dbbfc38c99b193f43c5273b02263d91cd04a560/jax/_src/core.py#L1047
This may need to be a separate bug fix in MaxText. Specifically, we may want to add a check like the following here:
if isinstance(variable.value, jax.core.Tracer):
  return variable.value
This should help avoid pinning flax.
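For illustration, here is a minimal sketch of where such a guard could sit, assuming it goes at the top of variable_to_logically_partitioned in MaxText/layers/initializers.py; the rest of the function body is elided, so this is a sketch of the idea rather than the actual MaxText source:

import jax

def variable_to_logically_partitioned(variable):
  # Under tracing (e.g. jax.eval_shape during init), variable.value is a Tracer
  # and exposes no `.sharding` metadata, which is what raises the AttributeError above.
  if isinstance(variable.value, jax.core.Tracer):
    return variable.value
  ...  # existing wrapping logic (elided): attach the variable's logical sharding names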
@SujeethJinesh I tried adding
if isinstance(variable.value, jax.core.Tracer):
  return variable.value
and unpinned flax; it resolved to flax 0.12 and failed with the following issue:
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/deps/src/MaxText/train.py", line 510, in <module>
app.run(main)
File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 316, in run
_run_main(main, args)
File "/usr/local/lib/python3.12/site-packages/absl/app.py", line 261, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/deps/src/MaxText/train.py", line 506, in main
run(config, recorder, diagnostic_config)
File "/deps/src/MaxText/train.py", line 501, in run
train_loop(config, recorder)
File "/deps/src/MaxText/train.py", line 364, in train_loop
) = train_utils.setup_train_loop(config, recorder)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/deps/src/MaxText/train_utils.py", line 204, in setup_train_loop
maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)
File "/deps/src/MaxText/maxtext_utils.py", line 805, in assert_params_sufficiently_sharded
_raise_if_unsharded_exceeds_tolerance(
File "/deps/src/MaxText/maxtext_utils.py", line 773, in _raise_if_unsharded_exceeds_tolerance
raise AssertionError("\n".join(error_msg_lines))
AssertionError: Unsharded parameter percentage (25.00%)exceeds tolerance (2.00%).
The following large tensors are replicated (unsharded) but could be sharded on at least one of the available axes:
- Name: ['params']['decoder']['layers']['mlp']['wi_0']['kernel'](Size: 1879048192, Shape: PartitionSpec(), Spec: PartitionSpec()) is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
- Name: ['params']['decoder']['layers']['mlp']['wi_1']['kernel'](Size: 1879048192, Shape: PartitionSpec(), Spec: PartitionSpec()) is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
- Name: ['params']['decoder']['layers']['mlp']['wo']['kernel'](Size: 1879048192, Shape: PartitionSpec(), Spec: PartitionSpec()) is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
- Name: ['params']['decoder']['layers']['self_attention']['out']['kernel'](Size: 536870912, Shape: PartitionSpec(), Spec: PartitionSpec()) is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
- Name: ['params']['decoder']['layers']['self_attention']['query']['kernel'](Size: 536870912, Shape: PartitionSpec(), Spec: PartitionSpec()) is unsharded on axis: ['fsdp'] could be sharded on: ['fsdp']
Please check PR #2502 for another solution that avoids pinning flax.
Force-pushed from bdd517d to a298159
Thanks!
Force-pushed from a298159 to 5fc3364
Description
Fix the generate_metrics_and_upload_to_big_query flag and related dependencies.
FIXES: b/441984274, b/446097400
Tests
Project preparation:
- benchmarks/recipes/user_configs.py change: add the following flags (see the sketch after this list):
  - Set bq_enable to True.
  - Set bq_db_project and bq_db_dataset to the user's target project and the dataset in that project.
- Run the command python3 -m benchmarks.recipes.pw_mcjax_benchmark_recipe; after 20 steps, one BigQuery record will be written to the run_summary table of the target dataset.
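For reference, a minimal sketch of what that user_configs.py change might look like, assuming simple module-level flags; the actual structure of benchmarks/recipes/user_configs.py may differ, and the project/dataset values below are placeholders:

# Hypothetical sketch of the flags described above (not the exact file contents).
bq_enable = True                          # enable writing run metrics to BigQuery
bq_db_project = "your-gcp-project"        # placeholder: the user's target GCP project
bq_db_dataset = "your_benchmark_dataset"  # placeholder: dataset inside that project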
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.