diff --git a/tests/vec_inf/client/test_slurm_script_generator.py b/tests/vec_inf/client/test_slurm_script_generator.py index 3d141af8..a81a962d 100644 --- a/tests/vec_inf/client/test_slurm_script_generator.py +++ b/tests/vec_inf/client/test_slurm_script_generator.py @@ -12,6 +12,14 @@ ) +@pytest.fixture(autouse=True) +def patch_model_weights_exists(monkeypatch): + """Ensure model weights directory existence checks default to True.""" + monkeypatch.setattr( + "vec_inf.client._slurm_script_generator.Path.exists", lambda self: True + ) + + class TestSlurmScriptGenerator: """Tests for SlurmScriptGenerator class.""" @@ -168,6 +176,21 @@ def test_generate_server_setup_singularity(self, singularity_params): "module load " in setup ) # Remove module name since it's inconsistent between clusters + def test_generate_server_setup_singularity_no_weights( + self, singularity_params, monkeypatch + ): + """Test server setup when model weights don't exist.""" + monkeypatch.setattr( + "vec_inf.client._slurm_script_generator.Path.exists", + lambda self: False, + ) + + generator = SlurmScriptGenerator(singularity_params) + setup = generator._generate_server_setup() + + assert "ray stop" in setup + assert "/path/to/model_weights/test-model" not in setup + def test_generate_launch_cmd_venv(self, basic_params): """Test launch command generation with virtual environment.""" generator = SlurmScriptGenerator(basic_params) @@ -187,6 +210,22 @@ def test_generate_launch_cmd_singularity(self, singularity_params): assert "apptainer exec --nv" in launch_cmd assert "source" not in launch_cmd + def test_generate_launch_cmd_singularity_no_local_weights( + self, singularity_params, monkeypatch + ): + """Test container launch when model weights directory is missing.""" + monkeypatch.setattr( + "vec_inf.client._slurm_script_generator.Path.exists", + lambda self: False, + ) + + generator = SlurmScriptGenerator(singularity_params) + launch_cmd = generator._generate_launch_cmd() + + assert "exec --nv" in launch_cmd + assert "--bind /path/to/model_weights/test-model" not in launch_cmd + assert "vllm serve test-model" in launch_cmd + def test_generate_launch_cmd_boolean_args(self, basic_params): """Test launch command with boolean vLLM arguments.""" params = basic_params.copy() @@ -391,6 +430,24 @@ def test_generate_model_launch_script_singularity( mock_touch.assert_called_once() mock_write_text.assert_called_once() + @patch("pathlib.Path.touch") + @patch("pathlib.Path.write_text") + def test_generate_model_launch_script_singularity_no_weights( + self, mock_write_text, mock_touch, batch_singularity_params, monkeypatch + ): + """Test batch model launch script when model weights don't exist.""" + monkeypatch.setattr( + "vec_inf.client._slurm_script_generator.Path.exists", + lambda self: False, + ) + + generator = BatchSlurmScriptGenerator(batch_singularity_params) + script_path = generator._generate_model_launch_script("model1") + + assert script_path.name == "launch_model1.sh" + call_args = mock_write_text.call_args[0][0] + assert "/path/to/model_weights/model1" not in call_args + @patch("vec_inf.client._slurm_script_generator.datetime") @patch("pathlib.Path.touch") @patch("pathlib.Path.write_text") diff --git a/vec_inf/client/_slurm_script_generator.py b/vec_inf/client/_slurm_script_generator.py index 15571715..01b786ea 100644 --- a/vec_inf/client/_slurm_script_generator.py +++ b/vec_inf/client/_slurm_script_generator.py @@ -37,8 +37,18 @@ def __init__(self, params: dict[str, Any]): self.additional_binds = ( f",{self.params['bind']}" if self.params.get("bind") else "" ) - self.model_weights_path = str( - Path(self.params["model_weights_parent_dir"], self.params["model_name"]) + model_weights_path = Path( + self.params["model_weights_parent_dir"], self.params["model_name"] + ) + self.model_weights_exists = model_weights_path.exists() + self.model_weights_path = str(model_weights_path) + self.model_source = ( + self.model_weights_path + if self.model_weights_exists + else self.params["model_name"] + ) + self.model_bind_option = ( + f",{self.model_weights_path}" if self.model_weights_exists else "" ) self.env_str = self._generate_env_str() @@ -111,7 +121,9 @@ def _generate_server_setup(self) -> str: server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"])) server_script.append( SLURM_SCRIPT_TEMPLATE["bind_path"].format( - model_weights_path=self.model_weights_path, + model_weights_path=self.model_weights_path + if self.model_weights_exists + else "", additional_binds=self.additional_binds, ) ) @@ -131,7 +143,6 @@ def _generate_server_setup(self) -> str: server_setup_str = server_setup_str.replace( "CONTAINER_PLACEHOLDER", SLURM_SCRIPT_TEMPLATE["container_command"].format( - model_weights_path=self.model_weights_path, env_str=self.env_str, ), ) @@ -165,22 +176,27 @@ def _generate_launch_cmd(self) -> str: Server launch command. """ launcher_script = ["\n"] + + vllm_args_copy = self.params["vllm_args"].copy() + model_source = self.model_source + if "--model" in vllm_args_copy: + model_source = vllm_args_copy.pop("--model") + if self.use_container: launcher_script.append( SLURM_SCRIPT_TEMPLATE["container_command"].format( - model_weights_path=self.model_weights_path, env_str=self.env_str, ) ) launcher_script.append( "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format( - model_weights_path=self.model_weights_path, + model_source=model_source, model_name=self.params["model_name"], ) ) - for arg, value in self.params["vllm_args"].items(): + for arg, value in vllm_args_copy.items(): if isinstance(value, bool): launcher_script.append(f" {arg} \\") else: @@ -225,11 +241,20 @@ def __init__(self, params: dict[str, Any]): if self.params["models"][model_name].get("bind") else "" ) - self.params["models"][model_name]["model_weights_path"] = str( - Path( - self.params["models"][model_name]["model_weights_parent_dir"], - model_name, - ) + model_weights_path = Path( + self.params["models"][model_name]["model_weights_parent_dir"], + model_name, + ) + model_weights_exists = model_weights_path.exists() + model_weights_path_str = str(model_weights_path) + self.params["models"][model_name]["model_weights_path"] = ( + model_weights_path_str + ) + self.params["models"][model_name]["model_weights_exists"] = ( + model_weights_exists + ) + self.params["models"][model_name]["model_source"] = ( + model_weights_path_str if model_weights_exists else model_name ) def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path: @@ -266,7 +291,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path: script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"]) script_content.append( BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format( - model_weights_path=model_params["model_weights_path"], + model_weights_path=model_params["model_weights_path"] + if model_params.get("model_weights_exists", True) + else "", additional_binds=model_params["additional_binds"], ) ) @@ -283,19 +310,25 @@ def _generate_model_launch_script(self, model_name: str) -> Path: model_name=model_name, ) ) + vllm_args_copy = model_params["vllm_args"].copy() + model_source = model_params.get( + "model_source", model_params["model_weights_path"] + ) + if "--model" in vllm_args_copy: + model_source = vllm_args_copy.pop("--model") + if self.use_container: script_content.append( - BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format( - model_weights_path=model_params["model_weights_path"], - ) + BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format() ) script_content.append( "\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format( - model_weights_path=model_params["model_weights_path"], + model_source=model_source, model_name=model_name, ) ) - for arg, value in model_params["vllm_args"].items(): + + for arg, value in vllm_args_copy.items(): if isinstance(value, bool): script_content.append(f" {arg} \\") else: diff --git a/vec_inf/client/_slurm_templates.py b/vec_inf/client/_slurm_templates.py index 43d91f61..ab607fa1 100644 --- a/vec_inf/client/_slurm_templates.py +++ b/vec_inf/client/_slurm_templates.py @@ -98,7 +98,7 @@ class SlurmScriptTemplate(TypedDict): f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop", ], "imports": "source {src_dir}/find_port.sh", - "bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}", + "bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,$(echo /dev/infiniband* | sed -e 's/ /,/g'),/dev,/tmp{{model_weights_path}}{{additional_binds}}", "container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --containall {IMAGE_PATH} \\", "activate_venv": "source {venv}/bin/activate", "server_setup": { @@ -164,7 +164,7 @@ class SlurmScriptTemplate(TypedDict): ' && mv temp.json "$json_path"', ], "launch_cmd": [ - "vllm serve {model_weights_path} \\", + "vllm serve {model_source} \\", " --served-model-name {model_name} \\", ' --host "0.0.0.0" \\', " --port $vllm_port_number \\", @@ -255,7 +255,7 @@ class BatchModelLaunchScriptTemplate(TypedDict): ], "container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {IMAGE_PATH} \\", "launch_cmd": [ - "vllm serve {model_weights_path} \\", + "vllm serve {model_source} \\", " --served-model-name {model_name} \\", ' --host "0.0.0.0" \\', " --port $vllm_port_number \\",