def maybe_download_tokenizer_from_gs_bucket(
    filtered_command: list[str], auto_output_dir_path: str, whoami: str
) -> list[str]:
    """Swap a ``gs://`` model path for a local tokenizer copy in the caching command.

    Lives in ``mason.py``. ``make_internal_command`` calls it on the filtered
    command (with ``args.auto_output_dir`` and ``whoami``) right before it
    builds the ``--cache_dataset_only`` caching command, so dataset
    preprocessing can load the tokenizer from local disk when the model only
    exists in a GCS bucket. ``download_from_gs_bucket`` is imported from
    ``open_instruct.utils``.

    Args:
        filtered_command: Command tokens, e.g.
            ``["script.py", "--model_name_or_path", "gs://bucket/model"]``.
            Modified in place when a rewrite happens.
        auto_output_dir_path: Root directory under which the per-user
            tokenizer cache folder is created.
        whoami: User name used as a sub-directory of ``auto_output_dir_path``.

    Returns:
        The same list object. The value following ``--model_name_or_path`` is
        replaced by the local cache folder if (and only if) it started with
        ``gs://``; otherwise the command is returned untouched.
    """
    model_flag = "--model_name_or_path"
    if model_flag not in filtered_command:
        return filtered_command

    model_name_idx = filtered_command.index(model_flag) + 1
    if model_name_idx >= len(filtered_command):
        # The flag is the last token and has no value: nothing to rewrite, and
        # indexing past the end would raise IndexError.
        return filtered_command

    model_name_or_path = filtered_command[model_name_idx].rstrip("/")
    if not model_name_or_path.startswith("gs://"):
        return filtered_command

    # A short hash of the bucket path keeps caches for different models apart.
    model_name_hash = hashlib.md5(model_name_or_path.encode("utf-8")).hexdigest()[:8]
    local_cache_folder = f"{auto_output_dir_path}/{whoami}/tokenizer_{model_name_hash}/"

    tokenizer_files = ("tokenizer.json", "tokenizer_config.json", "config.json")
    # Check for the files themselves, not just the folder: the download helper
    # creates the folder before running gsutil, so a failed or interrupted
    # download would otherwise leave an empty folder that looks like a valid
    # cache on every later run.
    cache_is_complete = all(
        os.path.isfile(os.path.join(local_cache_folder, file_name)) for file_name in tokenizer_files
    )
    if not cache_is_complete:
        download_from_gs_bucket(
            [f"{model_name_or_path}/{file_name}" for file_name in tokenizer_files],
            local_cache_folder,
        )

    filtered_command[model_name_idx] = local_cache_folder
    return filtered_command
# --- open_instruct/test_utils.py ---
class TestDownloadFromGsBucket(unittest.TestCase):
    """Checks the gsutil command assembled by ``utils.download_from_gs_bucket``."""

    def test_download_from_gs_bucket(self):
        """Several sources are forwarded in order and the destination folder is created."""
        src_paths = ["gs://bucket/data1", "gs://bucket/data2"]

        with tempfile.TemporaryDirectory() as tmp_dir:
            dest_path = pathlib.Path(tmp_dir) / "downloads"
            captured_cmd: dict[str, list[str]] = {}

            def mock_live_subprocess_output(cmd):
                # Record the command instead of launching gsutil.
                captured_cmd["cmd"] = cmd

            with mock.patch.object(utils, "live_subprocess_output", side_effect=mock_live_subprocess_output):
                utils.download_from_gs_bucket(src_paths=src_paths, dest_path=str(dest_path))

            expected_cmd = [
                "gsutil",
                "-o",
                "GSUtil:parallel_thread_count=1",
                "-o",
                "GSUtil:sliced_object_download_threshold=150",
                "-m",
                "cp",
                "-r",
                *src_paths,
                str(dest_path),
            ]

            self.assertEqual(captured_cmd["cmd"], expected_cmd)
            self.assertTrue(dest_path.exists())


# --- open_instruct/utils.py ---
def download_from_gs_bucket(src_paths: list[str], dest_path: str) -> None:
    """Copy one or more GCS objects/prefixes into a local directory with ``gsutil cp -r``.

    Args:
        src_paths: ``gs://`` URLs to copy. A single string is also accepted and
            treated as a one-element list, so callers written against the
            earlier single-path signature keep working.
        dest_path: Local destination directory; created if it does not exist.

    Raises:
        ValueError: If no source path is given.
    """
    # Wrap a bare string: ``cmd.extend("gs://...")`` would otherwise add every
    # character of the URL as a separate gsutil argument.
    sources = [src_paths] if isinstance(src_paths, str) else list(src_paths)
    if not sources:
        # Fail before touching the filesystem; ``gsutil cp`` with only a
        # destination is never a valid invocation.
        raise ValueError("download_from_gs_bucket requires at least one source path")

    # gsutil needs an existing directory when copying several sources into it.
    os.makedirs(dest_path, exist_ok=True)
    cmd = [
        "gsutil",
        "-o",
        "GSUtil:parallel_thread_count=1",
        "-o",
        "GSUtil:sliced_object_download_threshold=150",
        "-m",
        "cp",
        "-r",
        *sources,
        dest_path,
    ]
    print(f"Downloading from GS bucket with command: {cmd}")
    live_subprocess_output(cmd)