Skip to content

Commit 9e890a8

Browse files
committed
fix
1 parent 19c5ce5 commit 9e890a8

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

mason.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,14 +477,21 @@ def make_internal_command(command: list[str], args: argparse.Namespace, whoami:
477477
try:
478478
model_arg_idx = filtered_command.index("--model_name_or_path")
479479
model_name_idx = model_arg_idx + 1
480-
model_name_or_path = filtered_command[model_name_idx]
480+
model_name_or_path = filtered_command[model_name_idx].rstrip("/")
481481

482482
if model_name_or_path.startswith("gs://"):
483483
model_name_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
484-
local_cache_folder = f"{args.auto_output_dir_path}/{whoami}/tokenizer_{model_name_hash}"
484+
local_cache_folder = f"{args.auto_output_dir_path}/{whoami}/tokenizer_{model_name_hash}/"
485485

486486
if not os.path.exists(local_cache_folder):
487-
download_from_gs_bucket(f"{model_name_or_path}/tokenizer*", local_cache_folder)
487+
download_from_gs_bucket(
488+
[
489+
f"{model_name_or_path}/tokenizer.json",
490+
f"{model_name_or_path}/tokenizer_config.json",
491+
f"{model_name_or_path}/config.json",
492+
],
493+
local_cache_folder,
494+
)
488495

489496
filtered_command[model_name_idx] = local_cache_folder
490497
except ValueError:

open_instruct/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,7 +1064,8 @@ def download_from_hf(model_name_or_path: str, revision: str) -> None:
10641064
return output
10651065

10661066

1067-
def download_from_gs_bucket(src_path: str, dest_path: str) -> None:
1067+
def download_from_gs_bucket(src_paths: str | list[str], dest_path: str) -> None:
1068+
os.makedirs(dest_path, exist_ok=True)
10681069
cmd = [
10691070
"gsutil",
10701071
"-o",
@@ -1074,9 +1075,11 @@ def download_from_gs_bucket(src_path: str, dest_path: str) -> None:
10741075
"-m",
10751076
"cp",
10761077
"-r",
1077-
src_path,
1078-
dest_path,
10791078
]
1079+
if not isinstance(src_paths, list):
1080+
src_paths = [src_paths]
1081+
cmd.extend(src_paths)
1082+
cmd.append(dest_path)
10801083
print(f"Downloading from GS bucket with command: {cmd}")
10811084
live_subprocess_output(cmd)
10821085

0 commit comments

Comments
 (0)