File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed
checkpoint/orbax/checkpoint/_src/path Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -528,7 +528,11 @@ def _find_all_with_single_host_load_and_broadcast(
528528
529529 def find_all (self , base_path : epath .PathLike ) -> Iterator [Metadata ]:
530530 """Returns metadata of all steps matching with name_format attributes."""
531- if multihost .process_count () > 1 and self .single_host_load_and_broadcast :
531+ # Note: the order of conjuncts is important here; we should not call
532+ # `multihost.process_count()` when `single_host_load_and_broadcast` is False
533+ # as this has the possible side effect of initializing the jax backend. See
534+ # b/454565916 for details.
535+ if self .single_host_load_and_broadcast and multihost .process_count () > 1 :
532536 return self ._find_all_with_single_host_load_and_broadcast (base_path )
533537
534538 # <step_prefix>_?<0 padding>?*
You can’t perform that action at this time.
0 commit comments