diff --git a/ads/jobs/builders/infrastructure/dsc_job.py b/ads/jobs/builders/infrastructure/dsc_job.py index 674d97ec1..cad652bbd 100644 --- a/ads/jobs/builders/infrastructure/dsc_job.py +++ b/ads/jobs/builders/infrastructure/dsc_job.py @@ -1,7 +1,6 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2021, 2024 Oracle and/or its affiliates. +# Copyright (c) 2021, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ from __future__ import annotations @@ -21,30 +20,33 @@ import oci import oci.data_science import oci.util as oci_util +import yaml +from oci.data_science import models from oci.data_science.models import JobInfrastructureConfigurationDetails from oci.exceptions import ServiceError -import yaml + from ads.common import utils +from ads.common.decorator.utils import class_or_instance_method +from ads.common.dsc_file_system import ( + DSCFileSystemManager, + OCIFileStorage, + OCIObjectStorage, +) from ads.common.oci_datascience import DSCNotebookSession, OCIDataScienceMixin from ads.common.oci_logging import OCILog from ads.common.oci_resource import ResourceNotFoundError from ads.jobs.builders.infrastructure.base import Infrastructure, RunInstance from ads.jobs.builders.infrastructure.dsc_job_runtime import ( + MULTI_NODE_JOB_SUPPORT, ContainerRuntimeHandler, DataScienceJobRuntimeManager, ) from ads.jobs.builders.infrastructure.utils import get_value from ads.jobs.builders.runtimes.artifact import Artifact +from ads.jobs.builders.runtimes.base import MultiNodeRuntime from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime -from ads.common.dsc_file_system import ( - OCIFileStorage, - DSCFileSystemManager, - OCIObjectStorage, -) -from ads.common.decorator.utils import class_or_instance_method - logger = logging.getLogger(__name__) SLEEP_INTERVAL = 3 @@ -52,6 +54,7 @@ MAXIMUM_MOUNT_COUNT = 5 FILE_STORAGE_TYPE = "FILE_STORAGE" OBJECT_STORAGE_TYPE = "OBJECT_STORAGE" +DEFAULT_NODE_GROUP_NAME = "node-group" class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job): @@ -284,11 +287,15 @@ def load_properties_from_env(self) -> None: def load_defaults(self) -> DSCJob: self.load_properties_from_env() + if getattr(self, "job_node_configuration_details", None): + return self + # Following are for single node job run only if not self.job_infrastructure_configuration_details: self.job_infrastructure_configuration_details = {} + # Convert the dict to JobInfrastructureConfigurationDetails object if isinstance(self.job_infrastructure_configuration_details, dict): - # Default networking + if not self.job_infrastructure_configuration_details.get( "jobInfrastructureType" ): @@ -352,6 +359,7 @@ def create(self) -> DSCJob: raise ValueError("Specify compartment ID for data science job.") if not self.project_id: raise ValueError("Specify project ID for data science job.") + self._create_with_oci_api() return self @@ -498,7 +506,9 @@ def run(self, **kwargs) -> DataScienceJobRun: keys = list(kwargs.keys()) for key in keys: if key in config_swagger_types: - config_kwargs[key] = kwargs.pop(key) + val = kwargs.pop(key) + if val is not None: + config_kwargs[key] = val elif key in env_config_swagger_types: value = kwargs.pop(key) if key in [ @@ -545,6 +555,25 @@ def run(self, **kwargs) -> DataScienceJobRun: env_config_override ) + if getattr(self, "job_node_configuration_details", None): + job_config_override = kwargs.pop("job_configuration_override_details", None) + env_config_override = kwargs.pop( + "job_environment_configuration_override_details", None + ) + if job_config_override or env_config_override: + node_config = { + "jobNodeType": "MULTI_NODE", + "jobNodeGroupConfigurationDetailsList": [ + { + # Node group name must match the node group name in the job. + "name": DEFAULT_NODE_GROUP_NAME, + "JobConfigurationDetails": job_config_override, + "JobEnvironmentConfigurationDetails": env_config_override, + } + ], + } + kwargs["job_node_configuration_override_details"] = node_config + wait = kwargs.pop("wait", False) run = DataScienceJobRun(**kwargs, **self.auth).create() if wait: @@ -756,13 +785,11 @@ def stop_condition(): return True # Stop only if time_finished is over 2 minute ago. # This is for the time delay between job run stopped and the logs appear in oci logging. - if ( + return ( datetime.datetime.now(self.time_finished.tzinfo) - datetime.timedelta(seconds=wait) > self.time_finished - ): - return True - return False + ) if not self.log_id and not self.log_group_id: print( @@ -1471,6 +1498,23 @@ def _update_from_dsc_model( } self.dsc_job = dsc_job + # Process multi-node infrastructure config + node_groups = get_value( + dsc_job, + "job_node_configuration_details.job_node_group_configuration_details_list", + ) + if node_groups and len(node_groups) == 1: + node_group = node_groups[0] + dsc_job.job_infrastructure_configuration_details = ( + node_group.job_infrastructure_configuration_details + ) + subnet_id = get_value( + dsc_job, + "job_node_configuration_details.job_network_configuration.subnet_id", + ) + if subnet_id: + self.set_spec(self.CONST_SUBNET_ID, subnet_id) + for infra_attr, dsc_attr in self.payload_attribute_map.items(): value = get_value(dsc_job, dsc_attr) if not value: @@ -1557,10 +1601,13 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob: if value: dsc_job.job_infrastructure_configuration_details[camel_attr] = value - if not dsc_job.job_infrastructure_configuration_details.get( - "shapeName", "" - ).endswith("Flex") and dsc_job.job_infrastructure_configuration_details.get( - "jobShapeConfigDetails" + shape = dsc_job.job_infrastructure_configuration_details.get("shapeName", "") + if ( + shape + and not str(shape).endswith("Flex") + and dsc_job.job_infrastructure_configuration_details.get( + "jobShapeConfigDetails" + ) ): raise ValueError( "Shape config is not required for non flex shape from user end." @@ -1583,7 +1630,6 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob: return self def build(self) -> DataScienceJob: - self.dsc_job.load_defaults() try: self.dsc_job.load_defaults() @@ -1611,6 +1657,48 @@ def init(self, **kwargs) -> DataScienceJob: ) ) + def _config_multi_node(self, runtime: MultiNodeRuntime): + """Configure the payload for multi-node job run.""" + infra_config: dict = self.dsc_job.job_infrastructure_configuration_details + job_config: models.DefaultJobConfigurationDetails = ( + self.dsc_job.job_configuration_details + ) + env_config = self.dsc_job.job_environment_configuration_details + # For multi-node jobs, + # the job_infrastructure_configuration_details and job_configuration_details + # should be the special EMPTY class. + # The job_environment_configuration_details should be None. + # The configs will be specified in each node group. + self.dsc_job.job_infrastructure_configuration_details = None + self.dsc_job.job_configuration_details = None + self.dsc_job.job_environment_configuration_details = None + + subnet_id = infra_config.pop("subnetId", None) + infra_config["jobInfrastructureType"] = ( + models.MultiNodeJobInfrastructureConfigurationDetails.JOB_INFRASTRUCTURE_TYPE_MULTI_NODE + ) + + if subnet_id: + network_config = models.JobCustomNetworkConfiguration(subnet_id=subnet_id) + else: + network_config = models.JobDefaultNetworkConfiguration() + + node_group_config: dict = { + "name": DEFAULT_NODE_GROUP_NAME, + "replicas": runtime.replica, + "minimumSuccessReplicas": runtime.replica, + "jobInfrastructureConfigurationDetails": infra_config, + "jobConfigurationDetails": job_config, + "jobEnvironmentConfigurationDetails": env_config, + } + + self.dsc_job.job_node_configuration_details = { + "jobNodeType": "MULTI_NODE", + "startupOrder": "IN_PARALLEL", + "jobNetworkConfiguration": network_config, + "jobNodeGroupConfigurationDetailsList": [node_group_config], + } + def create(self, runtime, **kwargs) -> DataScienceJob: """Creates a job with runtime. @@ -1635,9 +1723,7 @@ def create(self, runtime, **kwargs) -> DataScienceJob: if self.name: display_name = Template(self.name).safe_substitute(runtime.envs) - elif isinstance(runtime, GitPythonRuntime) or isinstance( - runtime, ContainerRuntime - ): + elif isinstance(runtime, (GitPythonRuntime, ContainerRuntime)): display_name = utils.get_random_name_for_resource() else: display_name = None @@ -1652,11 +1738,22 @@ def create(self, runtime, **kwargs) -> DataScienceJob: self.dsc_job = DSCJob(**payload, **self.auth) # Set Job infra to user values after DSCJob initialized the defaults self._update_job_infra(self.dsc_job) + if self.is_multi_node_job(runtime): + self._config_multi_node(runtime=runtime) self.dsc_job.create() # Update the model from infra after job creation. self._update_from_dsc_model(self.dsc_job) return self + @staticmethod + def is_multi_node_job(runtime): + """Check if the job is multi-node job.""" + return ( + MULTI_NODE_JOB_SUPPORT + and isinstance(runtime, MultiNodeRuntime) + and runtime.replica > 1 + ) + def run( self, name=None, diff --git a/ads/jobs/builders/infrastructure/dsc_job_runtime.py b/ads/jobs/builders/infrastructure/dsc_job_runtime.py index 8e13a3680..a5099f32e 100644 --- a/ads/jobs/builders/infrastructure/dsc_job_runtime.py +++ b/ads/jobs/builders/infrastructure/dsc_job_runtime.py @@ -1,7 +1,6 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2021, 2024 Oracle and/or its affiliates. +# Copyright (c) 2021, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ """Contains classes for conversion between ADS runtime and OCI Data Science Job implementation. This module is for ADS developers only. @@ -19,29 +18,37 @@ import shlex from typing import Optional from urllib import parse + +import oci + from ads.common.utils import extract_region +from ads.jobs.builders.infrastructure.utils import get_value +from ads.jobs.builders.runtimes.artifact import ( + GitPythonArtifact, + NotebookArtifact, + PythonArtifact, + ScriptArtifact, +) from ads.jobs.builders.runtimes.base import Runtime +from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime from ads.jobs.builders.runtimes.python_runtime import ( CondaRuntime, - ScriptRuntime, - PythonRuntime, - NotebookRuntime, GitPythonRuntime, + NotebookRuntime, + PythonRuntime, + ScriptRuntime, ) -from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime from ads.jobs.builders.runtimes.pytorch_runtime import ( - PyTorchDistributedRuntime, PyTorchDistributedArtifact, + PyTorchDistributedRuntime, ) -from ads.jobs.builders.runtimes.artifact import ( - ScriptArtifact, - NotebookArtifact, - PythonArtifact, - GitPythonArtifact, -) -from ads.opctl.distributed.common import cluster_config_helper -from ads.jobs.builders.infrastructure.utils import get_value from ads.jobs.templates import driver_utils +from ads.opctl.distributed.common import cluster_config_helper + +if hasattr(oci.data_science.models, "MultiNodeJobInfrastructureConfigurationDetails"): + MULTI_NODE_JOB_SUPPORT = True +else: + MULTI_NODE_JOB_SUPPORT = False class IncompatibleRuntime(Exception): # pragma: no cover @@ -77,6 +84,9 @@ class RuntimeHandler: # Defines the class of the runtime to be handled. RUNTIME_CLASS = Runtime + CONST_WORKER_COUNT = "OCI__WORKER_COUNT" + CONST_NODE_COUNT = "NODE_COUNT" + def __init__(self, data_science_job) -> None: """Initialize the runtime handler. @@ -285,7 +295,7 @@ def extract(self, dsc_job): * _extract_artifact() * _extract_runtime_minutes() Each of these method returns a dict for specifying the runtime. - The dictionaries are combined before initalizing the runtime. + The dictionaries are combined before initializing the runtime. A sub-class can modify one of more of these methods. Parameters @@ -349,6 +359,30 @@ def _extract_args(self, dsc_job) -> dict: return {Runtime.CONST_ARGS: shlex.split(args_string)} return {} + def _get_node_group(self, dsc_job): + """Gets the node group for multi-node job with single node group.""" + node_groups = get_value( + dsc_job, + "job_node_configuration_details.job_node_group_configuration_details_list", + ) + if node_groups and len(node_groups) == 1: + return node_groups[0] + return None + + def _get_replica(self, dsc_job, envs): + node_group = self._get_node_group(dsc_job) + if node_group: + replica = get_value(node_group, "replicas") + elif not envs: + replica = None + elif self.CONST_WORKER_COUNT in envs: + replica = int(envs.pop(self.CONST_WORKER_COUNT)) + 1 + elif self.CONST_NODE_COUNT in envs: + replica = int(envs.pop(self.CONST_NODE_COUNT)) + else: + replica = None + return replica + def _extract_envs(self, dsc_job): """Extract the environment variables from data science job. @@ -362,7 +396,12 @@ def _extract_envs(self, dsc_job): dict A runtime specification dictionary for initializing a runtime. """ - envs = get_value(dsc_job, "job_configuration_details.environment_variables") + env_attr = "job_configuration_details.environment_variables" + node_group = self._get_node_group(dsc_job) + if node_group: + envs = get_value(node_group, env_attr) + else: + envs = get_value(dsc_job, env_attr) if envs: return {Runtime.CONST_ENV_VAR: envs} return {} @@ -968,6 +1007,12 @@ def translate(self, runtime: Runtime) -> dict: payload["job_environment_configuration_details"] = job_env_config return payload + def _translate_env(self, runtime): + envs = super()._translate_env(runtime) + if runtime.replica: + envs[self.CONST_NODE_COUNT] = str(runtime.replica) + return envs + def _translate_artifact(self, runtime: ContainerRuntime): """Additional artifact for the container""" if runtime.artifact_uri: @@ -1049,6 +1094,10 @@ def _extract_envs(self, dsc_job): if envs: spec[ContainerRuntime.CONST_ENV_VAR] = envs + replica = self._get_replica(dsc_job=dsc_job, envs=envs) + if replica: + spec[ContainerRuntime.CONST_REPLICA] = replica + return spec def _extract_properties(self, dsc_job) -> dict: @@ -1081,7 +1130,6 @@ def _extract_properties(self, dsc_job) -> dict: class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler): RUNTIME_CLASS = PyTorchDistributedRuntime - CONST_WORKER_COUNT = "OCI__WORKER_COUNT" CONST_COMMAND = "OCI__LAUNCH_CMD" CONST_DEEPSPEED = "OCI__DEEPSPEED" @@ -1105,8 +1153,7 @@ def _translate_artifact(self, runtime: PyTorchDistributedRuntime): def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict: envs = super()._translate_env(runtime) replica = runtime.replica if runtime.replica else 1 - # WORKER_COUNT = REPLICA - 1 so that it will be same as distributed training - envs[self.CONST_WORKER_COUNT] = str(replica - 1) + envs[self.CONST_NODE_COUNT] = str(replica) envs[self.CONST_JOB_ENTRYPOINT] = PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT if runtime.inputs: envs[driver_utils.CONST_ENV_INPUT_MAPPINGS] = json.dumps(runtime.inputs) @@ -1131,12 +1178,12 @@ def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict: def _extract_envs(self, dsc_job) -> dict: spec = super()._extract_envs(dsc_job) envs = spec.pop(PythonRuntime.CONST_ENV_VAR, {}) - if self.CONST_WORKER_COUNT not in envs: + replica = self._get_replica(dsc_job, envs=envs) + + if not replica: raise IncompatibleRuntime() # Replicas - spec[PyTorchDistributedRuntime.CONST_REPLICA] = ( - int(envs.pop(self.CONST_WORKER_COUNT)) + 1 - ) + spec[PyTorchDistributedRuntime.CONST_REPLICA] = replica # Git if cluster_config_helper.OCI__RUNTIME_URI in envs: git_spec = {} diff --git a/ads/jobs/builders/runtimes/base.py b/ads/jobs/builders/runtimes/base.py index 5cfa9461c..e929f9150 100644 --- a/ads/jobs/builders/runtimes/base.py +++ b/ads/jobs/builders/runtimes/base.py @@ -1,17 +1,16 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2022, 2024 Oracle and/or its affiliates. +# Copyright (c) 2022, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ from __future__ import annotations + import re import time import traceback - from typing import Dict, TypeVar -from ads.jobs.builders.base import Builder -from ads.jobs import env_var_parser +from ads.jobs import env_var_parser +from ads.jobs.builders.base import Builder Self = TypeVar("Self", bound="Runtime") @@ -285,6 +284,9 @@ def replica(self) -> int: def run(self, dsc_job, **kwargs): """Starts the job runs""" + # For multi-node job, there is no need to create multiple job run. + if getattr(dsc_job, "job_node_configuration_details", None): + return dsc_job.run(**kwargs) replicas = self.replica if self.replica else 1 main_run = None job_runs = [] diff --git a/ads/jobs/builders/runtimes/pytorch_runtime.py b/ads/jobs/builders/runtimes/pytorch_runtime.py index 02037d052..679d45247 100644 --- a/ads/jobs/builders/runtimes/pytorch_runtime.py +++ b/ads/jobs/builders/runtimes/pytorch_runtime.py @@ -1,19 +1,19 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2023, 2024 Oracle and/or its affiliates. +# Copyright (c) 2023, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -from ads.jobs.builders.runtimes.artifact import PythonArtifact, GitPythonArtifact +from ads.jobs.builders.runtimes.artifact import GitPythonArtifact, PythonArtifact from ads.jobs.builders.runtimes.base import MultiNodeRuntime from ads.jobs.builders.runtimes.python_runtime import ( - PythonRuntime, GitPythonRuntime, + PythonRuntime, ) class PyTorchDistributedRuntime(PythonRuntime, MultiNodeRuntime): """Represents runtime supporting PyTorch Distributed training.""" + CONST_GIT = "git" CONST_INPUT = "inputs" CONST_DEP = "dependencies" @@ -169,13 +169,11 @@ def with_command(self, command: str, use_deepspeed=False): def command(self): """The command for launching the workload.""" return self.get_spec(self.CONST_COMMAND) - + @property def use_deepspeed(self): """Indicate whether whether to configure deepspeed for multi-node workload""" - if self.get_spec(self.CONST_DEEPSPEED): - return True - return False + return bool(self.get_spec(self.CONST_DEEPSPEED)) class PyTorchDistributedArtifact(PythonArtifact): diff --git a/ads/jobs/templates/driver_pytorch.py b/ads/jobs/templates/driver_pytorch.py index f1bd53043..2ec5e820b 100644 --- a/ads/jobs/templates/driver_pytorch.py +++ b/ads/jobs/templates/driver_pytorch.py @@ -1,26 +1,27 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2023, 2024 Oracle and/or its affiliates. +# Copyright (c) 2023, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -"""This module requires oracle-ads>=2.6.8 -""" +"""This module requires oracle-ads>=2.6.8 and python>=3.8""" import getpass import ipaddress +import json import logging import multiprocessing import os -import time import shlex import socket import sys +import time import traceback +import fsspec import oci import psutil import torch + from ads import set_auth -from ads.jobs import DataScienceJobRun +from ads.jobs import DataScienceJob, DataScienceJobRun from ads.jobs.builders.infrastructure.dsc_job_runtime import ( PythonRuntimeHandler, ) @@ -29,13 +30,13 @@ try: # This is used by ADS and testing from . import driver_utils - from .driver_oci import GitSSHKey, GitManager - from .oci_metrics import collect_metrics, METRIC_NAMESPACE + from .driver_oci import GitManager, GitSSHKey + from .oci_metrics import METRIC_NAMESPACE, collect_metrics except ImportError: # This is used when the script is in a job run. import driver_utils - from driver_oci import GitSSHKey, GitManager - from oci_metrics import collect_metrics, METRIC_NAMESPACE + from driver_oci import GitManager, GitSSHKey + from oci_metrics import METRIC_NAMESPACE, collect_metrics logger = logging.getLogger(__name__) logger = driver_utils.set_log_level(logger) @@ -50,21 +51,36 @@ CONST_ENV_NODE_COUNT = "NODE_COUNT" CONST_ENV_LAUNCH_CMD = "OCI__LAUNCH_CMD" CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED" +CONST_ENV_LOG_OUTPUT = "OCI__LOG_OUTPUT" # Envs set by this module CONST_ENV_WORLD_SIZE = "WORLD_SIZE" CONST_ENV_LD_PRELOAD = "LD_PRELOAD" # Envs for debugging only +CONST_ENV_SET_SOCKET_IFNAME = "SET_SOCKET_IFNAME" # OCI_ODSC_SERVICE_ENDPOINT is used for all processes in the job run CONST_ENV_ODSC_SERVICE_ENDPOINT = "OCI_ODSC_SERVICE_ENDPOINT" # OCI_DS_SERVICE_ENDPOINT is used only by the training process CONST_ENV_DS_SERVICE_ENDPOINT = "OCI_DS_SERVICE_ENDPOINT" +# DTv2 environment variables +CONST_ENV_INITIAL_CLUSTER_SIZE = "INITIAL_CLUSTER_SIZE" +CONST_ENV_META_FILE = "CLUSTER_NODES_METADATA_FILE" +# DTv2 metadata variables +CONST_IP_ADDRESS = "IPAddress" +CONST_RANK = "Rank" + + +CONST_ENCODING = "utf-8" + # Constants used in logs LOG_PREFIX_HOST_IP = "Distributed Training HOST IP: " LOG_PREFIX_NODE_IP = "Node IP: " LOG_PREFIX_PUBLIC_KEY = "HOST PUBLIC KEY: " +LOG_PREFIX_HOST_KEY_RSA = "NODE HOST KEY RSA: " +LOG_PREFIX_HOST_KEY_ECDSA = "NODE HOST KEY ECDSA: " # Other constants used within this script -# Other constants used within this script +HOST_KEY_PATH_RSA = "/etc/ssh/ssh_host_rsa_key.pub" +HOST_KEY_PATH_ECDSA = "/etc/ssh/ssh_host_ecdsa_key.pub" USER_HOME = os.environ.get("HOME", f"/home/{getpass.getuser()}") SSH_DIR = os.environ.get("OCI__SSH_DIR", os.path.join(USER_HOME, ".ssh")) DEFAULT_LAUNCHER = "torchrun" @@ -122,42 +138,78 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None: super().__init__(code_dir) self.launch_cmd = os.environ.get(CONST_ENV_LAUNCH_CMD, "") - self.ds_client = driver_utils.OCIHelper.init_oci_client( - oci.data_science.DataScienceClient - ) - self.ip = self.find_self_ip() - # IP address of other nodes as a list - self.node_ip_list = [] - # DataScienceJobRun objects of other nodes as a list - self.node_runs = [] - - if CONST_ENV_HOST_JOB_RUN_OCID in os.environ: - # Print the node IP address to logs so that it can be obtained by the host. - print(f"{LOG_PREFIX_NODE_IP}{self.ip}") - self.host_ocid = os.environ[CONST_ENV_HOST_JOB_RUN_OCID] - logger.debug("Host job run OCID: %s", self.host_ocid) - self.host_ip = None - self.is_host = False - else: - # Print the host IP address to logs so that it can be obtained by the nodes. - print(f"{LOG_PREFIX_HOST_IP}{self.ip}") - self.host_ocid = os.environ.get(CONST_ENV_JOB_RUN_OCID) - self.host_ip = self.ip - self.is_host = True + logger.debug(os.environ) - self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid) - self.entrypoint_env = PythonRuntimeHandler.CONST_CODE_ENTRYPOINT - # The total number of nodes is OCI__WORKER_COUNT + 1 - if CONST_ENV_NODE_COUNT in os.environ: + # Node count + if CONST_ENV_INITIAL_CLUSTER_SIZE in os.environ: + self.node_count = int(os.environ[CONST_ENV_INITIAL_CLUSTER_SIZE]) + elif CONST_ENV_NODE_COUNT in os.environ: self.node_count = int(os.environ[CONST_ENV_NODE_COUNT]) else: + # The total number of nodes is OCI__WORKER_COUNT + 1 self.node_count = int(os.environ.get(OCI__WORKER_COUNT, 0)) + 1 logger.debug("Node count: %s", self.node_count) + self.gpu_count = torch.cuda.device_count() logger.debug("GPU count on this node: %s", self.gpu_count) + if self.gpu_count > 0: + logger.debug("GPU name: %s", torch.cuda.get_device_name()) + + # IP address of other nodes as a list + self.node_ip_list = [] + # For DTv2, node_runs should not be used. + self.node_runs = None + self.host_ocid = None + self.host_job_run = None + + self.node_rank = int(os.environ.get(CONST_ENV_NODE_RANK, 0)) + + hostname = socket.gethostname() + logger.debug("Hostname: %s", hostname) + logger.debug( + "Get Host by Addr: %s", LazyEvaluate(socket.gethostbyaddr, hostname) + ) + logger.debug("FQDN: %s", LazyEvaluate(socket.getfqdn, hostname)) + + # Read metadata file for DTv2 + self.rank_to_ip = self.read_metadata() + if self.rank_to_ip: + logger.debug(self.rank_to_ip) + # DTv2 + self.ip = self.rank_to_ip[self.node_rank] + self.host_ip = self.rank_to_ip[0] + self.is_host = self.node_rank == 0 + self.node_ip_list = list(self.rank_to_ip.values()) + self._set_socket_interface(self._get_interface_name()) + # DeepSpeed worker will check job logs to determine the public SSH key. + self.host_ocid = os.environ.get(CONST_ENV_JOB_RUN_OCID) + else: + # DTv1 + self.ip = self.find_self_ip() + if CONST_ENV_HOST_JOB_RUN_OCID in os.environ: + # Print the node IP address to logs so that it can be obtained by the host. + print(f"{LOG_PREFIX_NODE_IP}{self.ip}", flush=True) + self.host_ocid = os.environ[CONST_ENV_HOST_JOB_RUN_OCID] + logger.debug("Host job run OCID: %s", self.host_ocid) + self.host_ip = None + self.is_host = False + else: + # Print the host IP address to logs so that it can be obtained by the nodes. + print(f"{LOG_PREFIX_HOST_IP}{self.ip}", flush=True) + self.host_ocid = os.environ.get(CONST_ENV_JOB_RUN_OCID) + self.host_ip = self.ip + self.is_host = True + + # host_job_run is needed for DTv1 to fetch the IP addresses from logs. + if self.host_ocid and self.node_count > 1: + self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid) + self.entrypoint_env = PythonRuntimeHandler.CONST_CODE_ENTRYPOINT logger.debug("Runner initialized.") + def is_dtv2(self): + return CONST_ENV_META_FILE in os.environ + def launch_cmd_contains(self, arg) -> bool: """Checks if the cmd for launching the training contains specific keyword argument.""" return f"--{arg}" in self.launch_cmd @@ -204,7 +256,7 @@ def wait_for_ip_address(self, job_run, timeout=15 * 60) -> str: logger.info("IP of %s: %s", job_run.id[-6:], ip_address) return ip_address - def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str: + def wait_for_log(self, job_run, log_prefix, timeout=15 * 60, limit=1) -> str: """Waits until a log message with specific prefix is found in the logs of a job run. Parameters @@ -223,27 +275,33 @@ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str: Raises ------ - TimeoutError + LoggingError Failed to obtain the log message within the specific timeout. """ logger.debug( "Waiting for logs with prefix '%s' from %s.", log_prefix, job_run.id ) second_started = time.time() - log = None - while not log: - log = self.check_job_run_logs(job_run=job_run, log_prefix=log_prefix) - if log: + logs = None + while True: + logs = self.check_job_run_logs(job_run=job_run, log_prefix=log_prefix) + if logs and len(logs) >= limit: + logs = logs[:limit] break if time.time() - second_started > timeout: - raise TimeoutError( - f"Failed to obtain log with prefix {log_prefix} for {job_run.id} in {timeout} seconds." + logs = job_run.logs() + last_log = logs[-1]["message"] if len(logs) > 0 else "" + raise Exception( + f"Failed to obtain log with prefix {log_prefix} for {job_run.id} in {timeout} seconds.\n" + f"Last log obtained: {last_log}" ) time.sleep(60) - return log + if limit == 1: + return logs[0] + return logs @staticmethod - def check_job_run_logs(job_run, log_prefix: str) -> str: + def check_job_run_logs(job_run, log_prefix: str) -> list: """Checks the logs of a specific job run and find the log message with specific prefix. Parameters @@ -260,45 +318,111 @@ def check_job_run_logs(job_run, log_prefix: str) -> str: """ logger.debug("Checking logs for job run %s", job_run.id) logs = job_run.logs() - for log in logs: - if log["message"].startswith(log_prefix): - return log["message"][len(log_prefix) :] - return None + logs = [ + log["message"][len(log_prefix) :] + for log in logs + if log["message"].startswith(log_prefix) + ] + return logs def find_self_ip(self): """ Identify IP address by finding which of the host IP intersects with the CIDR block of the subnet associated with the JOB_OCID """ - hostname = socket.gethostname() - logger.debug("Hostname: %s", hostname) - logger.debug( - "Get Host by Addr: %s", LazyEvaluate(socket.gethostbyaddr, hostname) - ) - logger.debug("FQDN: %s", LazyEvaluate(socket.getfqdn, hostname)) - if os.environ.get("JOB_OCID"): - subnet_id = self.ds_client.get_job( - os.environ["JOB_OCID"] - ).data.job_infrastructure_configuration_details.subnet_id + if os.environ.get("JOB_OCID") and self.node_count > 1: + subnet_id = DataScienceJob.from_id(os.environ["JOB_OCID"]).subnet_id core_client = driver_utils.OCIHelper.init_oci_client( oci.core.VirtualNetworkClient ) cidr = core_client.get_subnet(subnet_id).data.cidr_block + self_ip = None for interface, snics in psutil.net_if_addrs().items(): ip = snics[0].address + logger.debug("IFNAME: %s, IP: %s", interface, ip) if ipaddress.ip_address(ip) in ipaddress.ip_network(cidr): + self_ip = ip logger.info("Node IP address: %s", ip) - # Specify the network interface for NCCL/GLOO - os.environ["GLOO_SOCKET_IFNAME"] = interface - os.environ["NCCL_SOCKET_IFNAME"] = interface - return ip - raise EnvironmentError("Unable to determine node IP address.") + + self._set_socket_interface(interface) + if self_ip: + return self_ip + raise OSError("Unable to determine node IP address.") else: - ip = socket.gethostbyname(hostname) + ip = socket.gethostbyname(socket.gethostname()) logger.info("Node IP address: %s", ip) return ip + def _set_socket_interface(self, interface: str): + """Sets the socket interface environment variables, + NCCL_SOCKET_IFNAME and GLOO_SOCKET_IFNAME. + + When `SET_SOCKET_IFNAME` is found in env var and the value is not empty, + the value will be used and the `interface` argument will be ignored. + + NCCL/GLOO will match the interface using prefix. + https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-socket-ifname + + """ + # Specify the network interface for NCCL/GLOO + if os.environ.get(CONST_ENV_SET_SOCKET_IFNAME): + interface = os.environ[CONST_ENV_SET_SOCKET_IFNAME] + + # Set the env vars only if user has not set it already + if not os.environ.get("GLOO_SOCKET_IFNAME"): + logger.debug("Setting GLOO_SOCKET_IFNAME to %s", interface) + os.environ["GLOO_SOCKET_IFNAME"] = interface + if not os.environ.get("NCCL_SOCKET_IFNAME"): + logger.debug("Setting NCCL_SOCKET_IFNAME to %s", interface) + os.environ["NCCL_SOCKET_IFNAME"] = interface + + def _get_interface_name(self): + node_interface = None + for interface, snics in psutil.net_if_addrs().items(): + ip = snics[0].address + logger.debug("IFNAME: %s, IP: %s", interface, ip) + if ip == self.ip: + node_interface = interface + return node_interface + + def read_metadata(self): + """Reads the metadata for DTv2 to get the rank and IP address mapping.""" + if CONST_ENV_META_FILE not in os.environ: + return None + metadata_file = os.environ.get(CONST_ENV_META_FILE) + error_count = 0 + while True: + if not os.path.exists(metadata_file): + logger.debug("Waiting for file %s to be available...", metadata_file) + time.sleep(20) + continue + logger.debug("Reading %s...", metadata_file) + with open(metadata_file, encoding=CONST_ENCODING) as f: + try: + node_list = json.load(f) + except Exception as ex: + # log the content of the file for debugging purpose. + logger.debug("Error occurred when reading metadata file:") + f.seek(0) + logger.debug(f.read()) + error_count += 1 + node_list = [] + if error_count > 3: + raise ex + + if len(node_list) < self.node_count: + logger.debug( + "Waiting for nodes... found %s of %s", + len(node_list), + self.node_count, + ) + time.sleep(20) + continue + logger.debug("All nodes are found in metadata file.") + logger.debug(node_list) + return {int(meta[CONST_RANK]): meta[CONST_IP_ADDRESS] for meta in node_list} + def fetch_code(self): """Fetches source code from Git if repo uri is specified.""" if cluster_config_helper.OCI__RUNTIME_URI in os.environ: @@ -370,10 +494,7 @@ def prepare_cmd(self, launch_args: list = None, prefix=""): else: launch_args.append(self.get_cmd_with_entrypoint_and_args()) - if prefix: - launcher = f"{prefix} {self.LAUNCHER}" - else: - launcher = self.LAUNCHER + launcher = f"{prefix} {self.LAUNCHER}" if prefix else self.LAUNCHER return f"{launcher} {' '.join(launch_args)}" @@ -383,8 +504,16 @@ def time_cmd(self, cmd): self.run_command("pwd", level=logging.DEBUG) # Show all environment variables self.run_command("printenv", level=logging.DEBUG) + if CONST_ENV_DS_SERVICE_ENDPOINT in os.environ: + envs = { + CONST_ENV_ODSC_SERVICE_ENDPOINT: os.environ[ + CONST_ENV_DS_SERVICE_ENDPOINT + ] + } + else: + envs = None training_start_time = time.time() - self.run_command(cmd, conda_prefix=self.conda_prefix, check=True) + self.run_command(cmd, conda_prefix=self.conda_prefix, check=True, envs=envs) logger.info("Time: %s seconds.", time.time() - training_start_time) def run(self): @@ -397,6 +526,7 @@ class TorchRunner(Runner): def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None: super().__init__(code_dir) + logger.debug("Initializing Torch Runner...") self.build_c_library() def build_c_library(self): @@ -442,10 +572,7 @@ def get_rdzv_conf(self) -> str: return rdzv_conf def run(self): - if self.gpu_count > 0: - nproc_per_node = self.gpu_count - else: - nproc_per_node = 1 + nproc_per_node = self.gpu_count if self.gpu_count > 0 else 1 launch_args = [] # Add nnode, nproc_per_node and rdzv args only if they are not specified by the user. @@ -471,24 +598,119 @@ class DeepSpeedRunner(Runner): HOST_FILE = "/home/datascience/hostfile" ENV_FILE = os.path.expanduser("~/.deepspeed_env") LAUNCHER = "deepspeed" + TMPDIR = os.environ.get("TMPDIR") def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None: super().__init__(code_dir) - self.update_os() - - def update_os(self): - # Generate SSH host keys for SSH server - self.run_command("sudo ssh-keygen -A", level=logging.DEBUG, check=True) - # Install SSH server to accept SSH connections - # DeepSpeed uses "hostname -I" to determine the IP address - # pdsh is required for default multi node training - # torch cpp extension uses which command to find compiler - # DeepSpeed async_io requires libaio-devel + logger.debug("Initializing DeepSpeed Runner...") + # Setup DeepSpeed if it used. + if self.use_deepspeed(): + self.host_key = None + self.deepspeed_setup() + + def use_deepspeed(self): + """Indicate if DeepSpeed is used.""" + # Accelerate support using DeepSpeed by adding the "--use_deepspeed" argument. + return bool( + os.environ.get(CONST_ENV_DEEPSPEED) + or self.launch_cmd_contains("use_deepspeed") + or self.launch_cmd_contains("deepspeed") + ) + + def deepspeed_setup(self): + """Setup for DeepSpeed.""" + self.host_key = HOST_KEY_PATH_RSA if os.path.exists(HOST_KEY_PATH_RSA) else None + # Create the temp dir if one does not exist. + # This is needed for JIT + if self.TMPDIR and not os.path.isdir(self.TMPDIR): + logger.info("Creating temp directory: %s", self.TMPDIR) + os.makedirs(self.TMPDIR, exist_ok=True) + self.install_deepspeed_dependencies() + # host_job_run is needed for DeepSpeed to fetch the public SSH key from the logs. + if self.host_ocid and self.node_count > 1: + self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid) + + def install_epel(self): + """Installs oracle-epel-release.""" + for ol_version in ["8", "9"]: + if ( + self.run_command( + f'cat /etc/oracle-release | grep "release {ol_version}"', + level=logging.DEBUG, + ) + == 0 + ): + self.run_command( + f"sudo --preserve-env microdnf install -y oracle-epel-release-el{ol_version}" + ) + break + + def _print_host_key(self, host_key_path, prefix): + with open(host_key_path, encoding=CONST_ENCODING) as f: + public_key = f.read() + print(f"{prefix}{self.ip}-{public_key}") + + def _add_known_hosts_from_file(self, ip_addr, key_file): + if not os.path.exists(key_file): + logger.warning( + "Unable to add host key %s to known_hosts: key file not found.", + key_file, + ) + return self.run_command( - "sudo --preserve-env yum install -y openssh-server hostname pdsh which libaio-devel", + f"echo -n '{ip_addr} ' | " f"cat - {key_file} >> {SSH_DIR}/known_hosts", level=logging.DEBUG, check=True, ) + + def _add_known_hosts_from_log(self, job_run, prefix, ip_address=None): + ip_key = self.wait_for_log(job_run, f"{prefix}") + ip_addr, public_key = ip_key.split("-", 1) + if ip_address: + ip_addr = ip_address + with open(f"{SSH_DIR}/known_hosts", "a+", encoding=CONST_ENCODING) as f: + line = f"{ip_addr} {public_key}" + f.write(f"{line}\n") + logger.debug("Added host key: %s", line) + + def install_deepspeed_dependencies(self): + """Installs extra dependencies and start SSH service.""" + if self.node_count == 1: + logger.debug( + "Skipped installing extra dependencies for single node training." + ) + return + + # Check if host keys exist + if self.host_key: + logger.debug( + "Skipped SSH host key generation.\nHost keys found: %s", self.host_key + ) + else: + # Generate SSH host keys for SSH server + self.run_command("sudo ssh-keygen -A", level=logging.DEBUG, check=True) + self._print_host_key(HOST_KEY_PATH_RSA, LOG_PREFIX_HOST_KEY_RSA) + self._print_host_key(HOST_KEY_PATH_ECDSA, LOG_PREFIX_HOST_KEY_ECDSA) + + if self.run_command("which pdsh", level=logging.DEBUG) != 0: + # Install "openssh-server" to accept SSH connections + # DeepSpeed uses "hostname -I" to determine the IP address + # "pdsh" is required for default multi node training + # torch cpp extension uses "which" command to find compiler + # DeepSpeed async_io requires "libaio-devel" + if self.run_command("which microdnf", level=logging.DEBUG) == 0: + self.install_epel() + self.run_command( + "sudo --preserve-env microdnf install -y openssh-server hostname pdsh pdsh-rcmd-ssh libaio-devel", + level=logging.DEBUG, + check=True, + ) + elif self.run_command("which yum", level=logging.DEBUG) == 0: + self.run_command( + "sudo --preserve-env yum install -y openssh-server hostname pdsh which libaio-devel", + level=logging.DEBUG, + check=True, + ) # Start SSH service self.run_command("sudo /usr/sbin/sshd", level=logging.DEBUG, check=True) @@ -496,15 +718,13 @@ def generate_key_pair(self): self.run_command( "ssh-keygen -q -t rsa -N '' <<< $'\ny'", level=logging.DEBUG, check=True ) - with open(os.path.join(SSH_DIR, "id_rsa.pub"), "r", encoding="utf-8") as f: + with open(os.path.join(SSH_DIR, "id_rsa.pub"), encoding=CONST_ENCODING) as f: public_key = f.read() - print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}") - self.add_authoried_key(public_key) - self.run_command( - f"ssh-keyscan -H {self.host_ip} >> {SSH_DIR}/known_hosts", - level=logging.DEBUG, - check=True, - ) + print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}", flush=True) + self._add_authoried_key(public_key) + # Add host key to known hosts + self._add_known_hosts_from_file(self.host_ip, HOST_KEY_PATH_RSA) + self._add_known_hosts_from_file(self.host_ip, HOST_KEY_PATH_ECDSA) self.test_ssh_connection(self.host_ip) # Check DeepSpeed compatibility self.run_command( @@ -512,64 +732,70 @@ def generate_key_pair(self): ) return self - @staticmethod - def add_authoried_key(public_key): + def _add_authoried_key(self, public_key): auth_keys_file = os.path.join(SSH_DIR, "authorized_keys") os.makedirs(SSH_DIR, exist_ok=True) - with open(auth_keys_file, "a+", encoding="utf-8") as f: + with open(auth_keys_file, "a+", encoding=CONST_ENCODING) as f: f.write(public_key) f.write("\n") - logger.debug("Public key saved to %s", auth_keys_file) + logger.debug("Public key saved to %s:%s", self.ip, auth_keys_file) def fetch_host_public_key(self): public_key = self.wait_for_log(self.host_job_run, LOG_PREFIX_PUBLIC_KEY) - print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}") - # logger.debug("%s", LOG_PREFIX_PUBLIC_KEY + public_key) - self.add_authoried_key(public_key) + print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}", flush=True) + self._add_authoried_key(public_key) def generate_hostfile(self): - runs = self.host_job_run.job.run_list() - self.node_runs = [ - run - for run in runs - if run.status in ["ACCEPTED", "IN_PROGRESS"] and run.id != self.host_ocid - ] - self.node_ip_list = [self.wait_for_ip_address(run) for run in self.node_runs] + if not self.node_ip_list: + runs = self.host_job_run.job.run_list() + self.node_runs = [ + run + for run in runs + if run.status in ["ACCEPTED", "IN_PROGRESS"] + and run.id != self.host_ocid + ] + self.node_ip_list = [ + self.wait_for_ip_address(run) for run in self.node_runs + ] logger.info("Node IPs: %s", self.node_ip_list) # Hostfile logger.debug("Writing hostfile to %s", self.HOST_FILE) os.makedirs(os.path.dirname(self.HOST_FILE), exist_ok=True) - host_file_content = [f"{ip} slots={self.gpu_count}" for ip in self.node_ip_list] - with open(self.HOST_FILE, "w", encoding="utf-8") as f: - f.write(f"{self.host_ip} slots={self.gpu_count}\n") + host_file_content = [ + f"{ip} slots={self.gpu_count}\n" for ip in self.node_ip_list + ] + with open(self.HOST_FILE, "w", encoding=CONST_ENCODING) as f: + if self.host_ip not in self.node_ip_list: + f.write(f"{self.host_ip} slots={self.gpu_count}\n") f.writelines(host_file_content) self.run_command(f"cat {self.HOST_FILE}", level=logging.DEBUG) # SSH config ssh_config_path = os.path.join(SSH_DIR, "config") logger.debug("Writing SSH config to %s", ssh_config_path) - with open(ssh_config_path, "w", encoding="utf-8") as f: + with open(ssh_config_path, "w", encoding=CONST_ENCODING) as f: f.writelines( [ - "", - f"Host {self.host_ip}", - "IdentityFile /home/datascience/.ssh/id_rsa", - "User datascience", + "\n", + f"Host {self.host_ip}\n", + "KexAlgorithms diffie-hellman-group-exchange-sha256\n", ] ) for node_ip in self.node_ip_list: + if node_ip == self.host_ip: + continue f.writelines( [ - "", - f"Host {node_ip}", - "IdentityFile /home/datascience/.ssh/id_rsa", - "User datascience", + "\n", + f"Host {node_ip}\n", + "KexAlgorithms diffie-hellman-group-exchange-sha256\n", ] ) + self.run_command(f"cat {ssh_config_path}", level=logging.DEBUG) return self def test_ssh_connection(self, host): ret = self.run_command( - f"ssh -v -o PasswordAuthentication=no {host} hostname -I", + f"ssh -vvv -o PasswordAuthentication=no {host} hostname -I", level=logging.DEBUG, ) if ret == 0: @@ -582,9 +808,8 @@ def touch_file(self, filename): for node_ip in self.node_ip_list: logger.debug("Sending stop file to %s", node_ip) self.run_command( - f"ssh -v {node_ip} 'touch {filename}'", + f"ssh -v -o PasswordAuthentication=no {node_ip} 'touch {filename}'", level=logging.DEBUG, - check=True, ) def save_deepspeed_env(self): @@ -593,31 +818,55 @@ def save_deepspeed_env(self): the environment variables configured by the job runs are not propagated to the SSH session. DeepSpeed will load the environment variables from file for the SSH sessions. """ - with open(self.ENV_FILE, mode="w", encoding="utf-8") as f: + import deepspeed + + try: + version = deepspeed.__version__ + minor_version = int(version.split(".")[1]) + except Exception: + version = 0 + minor_version = 0 + + with open(self.ENV_FILE, mode="w", encoding=CONST_ENCODING) as f: for k, v in os.environ.items(): - # As of deepspeed==0.9.2, empty value or line break will cause parsing error, + # Empty value or line break may cause parsing error, # as the .deepspeed_env file is parsed line by line. if not v or "\n" in v: + logger.debug("Skipped saving %s as deepspeed env.", k) continue # Ignore variables that are node specific - # The network interface name for each job run is a unique string, e.g. ens300f0v1604 - if k in ["NCCL_SOCKET_IFNAME", "GLOO_SOCKET_IFNAME", "JOB_RUN_OCID"]: + # The network interface name for each job run could be a unique string, e.g. ens300f0v1604 + # Deepspeed will copy the SOCKET_IFNAME values to all nodes if they are set. + if k in [ + "NCCL_SOCKET_IFNAME", + "GLOO_SOCKET_IFNAME", + "JOB_RUN_OCID", + "NODE_RANK", + ]: + logger.debug("Skipped saving %s as deepspeed env.", k) continue - # Quote the value if it contains space - # Environment variable containing space may not be exported correctly when using pdsh - # https://github.com/microsoft/DeepSpeed/blob/v0.9.2/deepspeed/launcher/multinode_runner.py#L79 - if " " in v: + # For DeepSpeed < 0.15.2, no extra quotes are added by DeepSpeed + # shelex.quote() will make sure the variable is exported correctly. + if minor_version < 15 or version in ["0.15.1", "0.15.0"]: v = shlex.quote(v) + # As v0.16.4, DeepSpeed is wrapping the value with double quotes. + # Escape the double quotes so that they can be exported correctly. + # This logic may need to be updated with the future version of DeepSpeed. + # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/launcher/multinode_runner.py#L37 + # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/launcher/multinode_runner.py#L90 + # https://github.com/deepspeedai/DeepSpeed/pull/5878 + # https://github.com/deepspeedai/DeepSpeed/pull/7071 + elif '"' in v: + v = v.replace('"', '\\"') + f.write(f"{k}={v}\n") - # The following are required for specifying the network interface to be used by NCCL/GLOO - # The value should be the prefix of the expected network interface name - # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-socket-ifname - f.write("NCCL_SOCKET_IFNAME=ens\n") - f.write("GLOO_SOCKET_IFNAME=ens\n") logger.debug("Environment variables saved to %s", self.ENV_FILE) self.run_command(f"cat {self.ENV_FILE}") + def wait_for_nodes(self): + pass + def run_deepspeed_host(self, launch_args=None): """Prepares the host and launch the deepspeed training. @@ -633,15 +882,41 @@ def run_deepspeed_host(self, launch_args=None): self.generate_key_pair().generate_hostfile() self.save_deepspeed_env() # Wait for nodes to be ready - for run in self.node_runs: - self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY) - - for node_ip in self.node_ip_list: - self.run_command( - f"ssh-keyscan -H {node_ip} >> {SSH_DIR}/known_hosts", - level=logging.DEBUG, - check=True, + # For DTv2, self.node_runs will be None + if self.is_dtv2(): + self.wait_for_log( + self.host_job_run, LOG_PREFIX_PUBLIC_KEY, limit=self.node_count ) + else: + for run in self.node_runs: + self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY) + + if self.host_key: + # If host key exists, it should be the same for all nodes. + for node_ip in self.node_ip_list: + self._add_known_hosts_from_file(node_ip, HOST_KEY_PATH_RSA) + self._add_known_hosts_from_file(node_ip, HOST_KEY_PATH_ECDSA) + elif self.is_dtv2(): + # If host key did not exist, it it generated on the fly, + # Each node will have a different key. + # We will need to check the logs for the public key. + logger.debug("Adding node host keys to known_hosts...") + for node_ip in self.node_ip_list: + self._add_known_hosts_from_log( + self.host_job_run, + LOG_PREFIX_HOST_KEY_RSA + node_ip, + ip_address=node_ip, + ) + self._add_known_hosts_from_log( + self.host_job_run, + LOG_PREFIX_HOST_KEY_ECDSA + node_ip, + ip_address=node_ip, + ) + else: + logger.debug("Adding job run host keys to known_hosts...") + for run in self.node_runs: + self._add_known_hosts_from_log(run, LOG_PREFIX_HOST_KEY_RSA) + self._add_known_hosts_from_log(run, LOG_PREFIX_HOST_KEY_ECDSA) cmd = self.prepare_cmd(launch_args) # For DeepSpeed, we only need to run the cmd on the host @@ -663,6 +938,9 @@ def run_deepspeed_worker(self): if os.path.exists(self.ERROR_FILE): logger.error("There is an error in the host job run.") sys.exit(1) + # Check host job run only if it is not None + if self.host_job_run is None: + continue # Stop the node if the host job run is CANCELLED or in unexpected state. try: self.host_job_run.sync() @@ -693,23 +971,23 @@ def run(self): class GenericRunner(TorchRunner, DeepSpeedRunner): - """Runner for running command other than ``torchrun``, ``deepspeed`` or ``accelerate``.""" + """Runner for running command that may use ``torchrun`` or ``deepspeed``.""" LAUNCHER = "" - def use_deepspeed(self) -> bool: - """Indicate if DeepSpeed is used.""" - if os.environ.get(CONST_ENV_DEEPSPEED): - return True - return False + def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None: + super().__init__(code_dir) + logger.debug("Initializing Generic Runner...") def set_env_var(self): """Set default environment variables.""" defaults = { - "WORLD_SIZE": self.node_count * self.gpu_count, + CONST_ENV_WORLD_SIZE: self.node_count * self.gpu_count, "MASTER_ADDR": self.host_ip, "MASTER_PORT": self.RDZV_PORT, } + if self.node_count == 1: + defaults["RANK"] = 0 for k, v in defaults.items(): if k not in os.environ: os.environ[k] = str(v) @@ -734,7 +1012,7 @@ def run(self): self.time_cmd(cmd=self.prepare_cmd(prefix=self.env_ld_preload())) -class AccelerateRunner(TorchRunner, DeepSpeedRunner): +class AccelerateRunner(GenericRunner): """Runner for HuggingFace Accelerate.""" # accelerate launch will add main_process_port for deepspeed cmd even if it is not needed. @@ -750,14 +1028,18 @@ class AccelerateRunner(TorchRunner, DeepSpeedRunner): LAUNCHER = "accelerate launch" def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None: + # Here we need to call GenericRunner.__init__() explicitly + # to avoid calling the DeepSpeedRunner.__init__(). super().__init__(code_dir) + logger.debug("Initializing Accelerate Runner...") # For "accelerate launch", only one of the following options can be used at one time # `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`. # When a config file is not provided, # --multi_gpu will be set automatically if there is more than 1 GPU # self.multi_gpu = bool(self.node_count > 1 or self.gpu_count > 1) self.num_machines = self.node_count - self.machine_rank = os.environ["NODE_RANK"] + # Machine rank is needed for accelerate launch to work correctly + self.machine_rank = self.node_rank # Total number of processes across all nodes # Here we assume all nodes are having the same shape self.num_processes = (self.gpu_count if self.gpu_count else 1) * self.node_count @@ -766,15 +1048,6 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None: # Host IP is not ready at initialization self.main_process_ip = None - def use_deepspeed(self): - """Indicate if DeepSpeed is used.""" - # Accelerate support using DeepSpeed by adding the "--use_deepspeed" argument. - if os.environ.get(CONST_ENV_DEEPSPEED) or self.launch_cmd_contains( - "use_deepspeed" - ): - return True - return False - def accelerate_args(self): """Gets the default arguments for the accelerate command. The value of the default arguments are assigned in ``__init__()``. @@ -785,8 +1058,11 @@ def accelerate_args(self): logger.debug("%s=%s", arg, arg_val) if arg_val is True: args.append(f"--{arg}") - elif arg_val: + elif arg_val is not None: args.extend([f"--{arg}", str(arg_val)]) + # --use_deepspeed is needed for deepspeed to work on single GPU + if self.use_deepspeed() and not self.launch_cmd_contains("use_deepspeed"): + args.append("--use_deepspeed") return args def run_with_torchrun(self): @@ -822,6 +1098,23 @@ def run(self): def main(): + # Collect GPU metrics only if GPU is available and user defined METRIC_NAMESPACE + if METRIC_NAMESPACE and torch.cuda.device_count(): + p = multiprocessing.Process(target=collect_metrics) + p.daemon = True + p.start() + + # Merge the CLI Arguments with CMD specified in env var + if len(sys.argv) > 1: + # Expand the environment variables before shlex.join + # as it will quote the arg with single quotes. + argv = [os.path.expandvars(arg) for arg in sys.argv[1:]] + if os.environ.get(CONST_ENV_LAUNCH_CMD): + os.environ[CONST_ENV_LAUNCH_CMD] = ( + shlex.join(argv) + " " + os.environ.get(CONST_ENV_LAUNCH_CMD) + ) + else: + os.environ[CONST_ENV_LAUNCH_CMD] = shlex.join(argv) launch_cmd = os.environ.get(CONST_ENV_LAUNCH_CMD) if not launch_cmd or launch_cmd.startswith("torchrun "): # Use torchrun as default if launch cmd is not provided @@ -832,21 +1125,42 @@ def main(): runner_class = AccelerateRunner else: runner_class = GenericRunner - + logger.debug("Using %s", str(runner_class)) runner = runner_class() + runner: Runner runner.fetch_code().set_working_dir().setup_python_path().install_dependencies() driver_utils.OCIHelper.copy_inputs() - - runner.wait_for_host_ip_address().run() + if not runner.host_ip: + runner.wait_for_host_ip_address() + runner.run() driver_utils.OCIHelper.copy_outputs() + logger.info("Job finished with exit code 0") + sys.exit(0) + + +def save_job_run_logs(output_uri=os.environ.get(CONST_ENV_LOG_OUTPUT)): + """Saves the job run logs to a file in output_uri.""" + if not output_uri: + return + if CONST_ENV_HOST_JOB_RUN_OCID not in os.environ: + return + + job_run_ocid = os.environ[CONST_ENV_HOST_JOB_RUN_OCID] + log_uri = os.path.join(output_uri, job_run_ocid + ".log") + # Wait for the job logs to be available in logging service + logger.debug("Saving job run logs to %s", log_uri) + time.sleep(60) + try: + job_run = DataScienceJobRun.from_ocid(job_run_ocid) + with fsspec.open(log_uri, "w") as f: + for log in job_run.logs(): + f.write(f"{log.get('message', '')}\n") + except Exception: + logger.error("Failed to save the job run logs to %s", log_uri) + logger.debug(traceback.format_exc()) if __name__ == "__main__": - # Collect GPU metrics only if GPU is available and user defined METRIC_NAMESPACE - if METRIC_NAMESPACE and torch.cuda.device_count(): - p = multiprocessing.Process(target=collect_metrics) - p.daemon = True - p.start() main() diff --git a/ads/jobs/templates/driver_utils.py b/ads/jobs/templates/driver_utils.py index 785230cff..22fda97d4 100644 --- a/ads/jobs/templates/driver_utils.py +++ b/ads/jobs/templates/driver_utils.py @@ -1,13 +1,14 @@ #!/usr/bin/env python # -*- coding: utf-8; -*- -# Copyright (c) 2023, 2024 Oracle and/or its affiliates. +# Copyright (c) 2023, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import contextlib import importlib import json import logging import os +import re import runpy import shlex import stat @@ -20,10 +21,8 @@ from urllib import request from urllib.parse import urlparse - import oci - CONST_ENV_LOG_LEVEL = "OCI_LOG_LEVEL" CONST_ENV_WORKING_DIR = "WORKING_DIR" CONST_ENV_CODE_DIR = "CODE_DIR" @@ -185,8 +184,8 @@ def substitute_output_uri(output_uri): @staticmethod def copy_outputs( - output_dir: str = os.environ.get("OUTPUT_DIR"), - output_uri: str = os.environ.get("OUTPUT_URI"), + output_dir: str = os.environ.get(CONST_ENV_OUTPUT_DIR), + output_uri: str = os.environ.get(CONST_ENV_OUTPUT_URI), ) -> List[str]: """Copies the output files to remote URI. @@ -348,12 +347,18 @@ def __init__(self, code_dir: str = DEFAULT_CODE_DIR) -> None: The path to the directory containing the user code. """ logger.info("Job Run ID is: %s", os.environ.get(CONST_ENV_JOB_RUN_OCID)) + if "VM_ID" in os.environ: + logger.debug("VM_ID: %s", os.environ["VM_ID"]) self.code_dir = code_dir self.conda_prefix = sys.executable.split("/bin/python", 1)[0] @staticmethod def run_command( - command: str, conda_prefix: str = None, level: Optional[int] = None, check=False + command: str, + conda_prefix: str = None, + level: Optional[int] = None, + check: bool = False, + envs: Optional[dict] = None, ) -> int: """Runs a shell command and logs the outputs with specific log level. @@ -369,6 +374,7 @@ def run_command( If this is set to a log level from logging, e.g. logging.DEBUG, the command outputs will be logged with the level. If this is None, the command outputs will be printed. + check : bool Returns ------- @@ -389,11 +395,14 @@ def run_command( ) else: cmd = command + process_envs = os.environ.copy() + if envs: + process_envs.update(envs) process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - env=os.environ.copy(), + env=process_envs, shell=True, ) # Stream the outputs @@ -403,7 +412,7 @@ def run_command( if process.poll() is not None and output == b"": break if output: - msg = output.decode() + msg = output.decode(errors="replace") if level is None: # output already contains the line break print(msg, flush=True, end="") @@ -412,9 +421,15 @@ def run_command( # logging will add line break msg = msg.rstrip("\n") logger.log(level=level, msg=msg) - if "pdsh@" in msg and "ssh exited with exit code 1" in msg: - print("DeepSpeed Failed.") - sys.exit(1) + if "pdsh@" in msg and "ssh exited with exit code" in msg: + codes = re.findall(r"\d+", msg) + if codes and len(codes) > 0: + code = codes[-1] + logger.info("DeepSpeed Failed with exit code %s", code) + else: + code = 1 + logger.error("Deepspeed Failed.") + sys.exit(int(code)) # Add a small delay so that # outputs from the subsequent code will have different timestamp for oci logging time.sleep(0.02) @@ -422,6 +437,7 @@ def run_command( "subprocess %s returned exit code %s", process.pid, process.returncode ) if check and process.returncode != 0: + logger.error("Command %s exited with code %s.", cmd, process.returncode) # If there is an error, exit the main process with the same return code. sys.exit(process.returncode) return process.returncode diff --git a/tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py b/tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py index f33110136..838278a79 100644 --- a/tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py +++ b/tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright (c) 2023 Oracle and/or its affiliates. +# Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import json @@ -8,16 +8,17 @@ import unittest import zipfile from unittest import mock -from ads.jobs import PyTorchDistributedRuntime, DataScienceJob, DataScienceJobRun + +from ads.jobs import DataScienceJob, DataScienceJobRun, PyTorchDistributedRuntime from ads.jobs.builders.infrastructure.dsc_job_runtime import ( PyTorchDistributedRuntimeHandler as Handler, ) from ads.jobs.builders.runtimes.pytorch_runtime import ( - PyTorchDistributedArtifact, GitPythonArtifact, + PyTorchDistributedArtifact, ) -from ads.opctl.distributed.common import cluster_config_helper as cluster from ads.jobs.templates import driver_utils as utils +from ads.opctl.distributed.common import cluster_config_helper as cluster class PyTorchRuntimeHandlerTest(unittest.TestCase): @@ -77,7 +78,7 @@ def test_translate_env(self): """Tests setting up environment variables""" envs = Handler(DataScienceJob())._translate_env(self.init_runtime()) self.assertIsInstance(envs, dict) - self.assertEqual(envs[Handler.CONST_WORKER_COUNT], str(self.REPLICAS - 1)) + self.assertEqual(envs[Handler.CONST_NODE_COUNT], str(self.REPLICAS)) self.assertEqual( envs[Handler.CONST_JOB_ENTRYPOINT], PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT, diff --git a/tests/unitary/default_setup/jobs/test_multi_node_job.py b/tests/unitary/default_setup/jobs/test_multi_node_job.py new file mode 100644 index 000000000..f98a430d6 --- /dev/null +++ b/tests/unitary/default_setup/jobs/test_multi_node_job.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from unittest import TestCase, main, mock, skipUnless + +from oci import Response + +from ads.jobs import ContainerRuntime, DataScienceJob, Job, PyTorchDistributedRuntime +from ads.jobs.builders.infrastructure.dsc_job_runtime import MULTI_NODE_JOB_SUPPORT + +test_cases = {"torchrun": "torchrun test_torch_distributed.py"} + + +LOG_GROUP_ID = "ocid1.loggroup.oc1.iad.aaa" +LOG_ID = "ocid1.log.oc1.iad.aaa" +SUBNET_ID = "ocid1.subnet.oc1.iad.aaa" +SHAPE_NAME = "VM.GPU.A10.2" +CONDA_NAME = "pytorch24_p310_gpu_x86_64_v1" + +CONDA_ENV_VARS = { + "CONDA_ENV_SLUG": CONDA_NAME, + "CONDA_ENV_TYPE": "service", + "JOB_RUN_ENTRYPOINT": "driver_pytorch.py", + "NODE_COUNT": "2", + "OCI_LOG_LEVEL": "DEBUG", + "OCI__LAUNCH_CMD": "torchrun artifact.py", +} + +CONTAINER_ENV_VARS = { + "NODE_COUNT": "2", + "OCI_LOG_LEVEL": "DEBUG", +} + + +@skipUnless( + MULTI_NODE_JOB_SUPPORT, + "Multi-Node Job is not supported by the OCI Python SDK installed.", +) +class MultiNodeJobTest(TestCase): + + def init_job_infra(self): + return ( + DataScienceJob() + .with_compartment_id("ocid1.compartment.oc1..aaa") + .with_project_id("ocid1.datascienceproject.oc1.iad.aaa") + .with_log_group_id(LOG_GROUP_ID) + .with_log_id(LOG_ID) + .with_shape_name(SHAPE_NAME) + .with_block_storage_size(256) + ) + + def assert_create_job_details(self, create_job_details, envs): + # Check log config + log_config = create_job_details.job_log_configuration_details + self.assertEqual(log_config.log_id, LOG_ID) + self.assertEqual(log_config.log_group_id, LOG_GROUP_ID) + + # Check top level configs + self.assertIsNone(create_job_details.job_configuration_details) + self.assertIsNone(create_job_details.job_environment_configuration_details) + self.assertIsNone(create_job_details.job_infrastructure_configuration_details) + + job_node_configuration_details = ( + create_job_details.job_node_configuration_details + ) + self.assertIsNotNone(job_node_configuration_details) + # Check network config + self.assertEqual( + job_node_configuration_details.job_network_configuration.job_network_type, + "DEFAULT_NETWORK", + ) + # Check node group config + self.assertEqual( + len( + job_node_configuration_details.job_node_group_configuration_details_list + ), + 1, + ) + node_group_config = ( + job_node_configuration_details.job_node_group_configuration_details_list[0] + ) + self.assertEqual( + node_group_config.job_configuration_details.environment_variables, + envs, + ) + self.assertEqual(node_group_config.replicas, 2) + # Check infra config + infra_config = node_group_config.job_infrastructure_configuration_details + self.assertEqual(infra_config.shape_name, "VM.GPU.A10.2") + self.assertEqual(infra_config.block_storage_size_in_gbs, 256) + self.assertEqual(infra_config.job_infrastructure_type, "MULTI_NODE") + + def assert_create_job_run_details(self, create_job_run_details): + self.assertIsNone(create_job_run_details.job_configuration_override_details) + self.assertIsNone( + create_job_run_details.job_infrastructure_configuration_override_details + ) + self.assertIsNone(create_job_run_details.job_log_configuration_override_details) + self.assertIsNone( + create_job_run_details.job_node_configuration_override_details + ) + + @mock.patch( + "ads.jobs.builders.runtimes.pytorch_runtime.PyTorchDistributedArtifact.build" + ) + @mock.patch("ads.jobs.builders.infrastructure.dsc_job.DSCJob.upload_artifact") + @mock.patch("oci.data_science.DataScienceClient.create_job_run") + @mock.patch("oci.data_science.DataScienceClient.create_job") + def test_create_multi_node_job_with_conda(self, patched_create, patched_run, *args): + patched_create.return_value = Response( + status=200, headers=None, request=None, data=None + ) + + infra = self.init_job_infra() + runtime = ( + PyTorchDistributedRuntime() + # Specify the service conda environment by slug name. + .with_service_conda(CONDA_NAME) + .with_command("torchrun artifact.py") + .with_environment_variable(OCI_LOG_LEVEL="DEBUG") + .with_replica(2) + ) + job = Job(name="DT Test").with_infrastructure(infra).with_runtime(runtime) + job.create() + create_job_details = patched_create.call_args.args[0] + + self.assert_create_job_details( + create_job_details=create_job_details, + envs=CONDA_ENV_VARS, + ) + node_group_config = create_job_details.job_node_configuration_details.job_node_group_configuration_details_list[ + 0 + ] + self.assertIsNone(node_group_config.job_environment_configuration_details) + + # Create Job with subnet_id + patched_create.reset_mock() + infra.with_subnet_id(SUBNET_ID) + job = Job(name="DT Test").with_infrastructure(infra).with_runtime(runtime) + job.create() + create_job_details = patched_create.call_args.args[0] + job_node_configuration_details = ( + create_job_details.job_node_configuration_details + ) + self.assertEqual( + job_node_configuration_details.job_network_configuration.subnet_id, + SUBNET_ID, + ) + patched_run.return_value = Response( + status=200, headers=None, request=None, data=None + ) + + # Check the payload for creating a job run + job.run() + create_job_run_details = patched_run.call_args.args[0] + self.assert_create_job_run_details(create_job_run_details) + + @mock.patch("oci.data_science.DataScienceClient.create_job_run") + @mock.patch("oci.data_science.DataScienceClient.create_job") + def test_create_multi_node_job_with_container( + self, patched_create, patched_run, *args + ): + patched_create.return_value = Response( + status=200, headers=None, request=None, data=None + ) + + infra = self.init_job_infra() + runtime = ( + ContainerRuntime() + # Specify the service conda environment by slug name. + .with_image("container_image") + .with_environment_variable(OCI_LOG_LEVEL="DEBUG") + .with_replica(2) + ) + job = Job(name="DT Test").with_infrastructure(infra).with_runtime(runtime) + job.create() + create_job_details = patched_create.call_args.args[0] + self.assert_create_job_details( + create_job_details=create_job_details, + envs=CONTAINER_ENV_VARS, + ) + node_group_config = create_job_details.job_node_configuration_details.job_node_group_configuration_details_list[ + 0 + ] + container_config = node_group_config.job_environment_configuration_details + self.assertEqual(container_config.job_environment_type, "OCIR_CONTAINER") + self.assertEqual(container_config.image, "container_image") + + patched_run.return_value = Response( + status=200, headers=None, request=None, data=None + ) + + # Check the payload for creating a job run + job.run() + create_job_run_details = patched_run.call_args.args[0] + self.assert_create_job_run_details(create_job_run_details) + + +if __name__ == "__main__": + main() diff --git a/tests/unitary/with_extras/jobs/test_pytorch_ddp.py b/tests/unitary/with_extras/jobs/test_pytorch_ddp.py index e8bf568e5..b32d39eeb 100644 --- a/tests/unitary/with_extras/jobs/test_pytorch_ddp.py +++ b/tests/unitary/with_extras/jobs/test_pytorch_ddp.py @@ -1,17 +1,18 @@ #!/usr/bin/env python -# Copyright (c) 2023 Oracle and/or its affiliates. +# Copyright (c) 2023, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os import sys import unittest -from unittest import mock +from unittest import SkipTest, mock + from ads.jobs import DataScienceJobRun from ads.jobs.builders.infrastructure.dsc_job_runtime import ( PyTorchDistributedRuntimeHandler as Handler, ) -from ads.jobs.templates import driver_utils as utils from ads.jobs.templates import driver_pytorch as driver +from ads.jobs.templates import driver_utils as utils class PyTorchRunnerTest(unittest.TestCase): @@ -49,6 +50,8 @@ def test_wait_for_host_ip(self): {"message": f"{driver.LOG_PREFIX_HOST_IP} {self.TEST_HOST_IP}"} ] runner = self.init_torch_runner() + if not runner.host_job_run: + raise SkipTest("Test is skipped for DTv2.") self.assertEqual(runner.host_ip, None) runner.wait_for_host_ip_address() self.assertEqual(runner.host_ip, self.TEST_HOST_IP) @@ -147,7 +150,11 @@ def test_touch_file(self, run_command): runner.touch_file("stop") commasnds = [call_args.args[0] for call_args in run_command.call_args_list] self.assertEqual( - commasnds, ["ssh -v 10.0.0.2 'touch stop'", "ssh -v 10.0.0.3 'touch stop'"] + commasnds, + [ + "ssh -v -o PasswordAuthentication=no 10.0.0.2 'touch stop'", + "ssh -v -o PasswordAuthentication=no 10.0.0.3 'touch stop'", + ], ) @@ -161,6 +168,8 @@ def init_runner(self): "ads.jobs.DataScienceJobRun.from_ocid" ) as GetJobRun, mock.patch( "ads.jobs.templates.driver_utils.JobRunner.run_command" + ), mock.patch( + "ads.jobs.templates.driver_pytorch.DeepSpeedRunner._print_host_key" ): GetHostIP.return_value = self.TEST_IP GetJobRun.return_value = DataScienceJobRun(id="ocid.abcdefghijk") @@ -186,7 +195,7 @@ def test_run(self, time_cmd, run_command, run_deepspeed): self.assertTrue( time_cmd.call_args.kwargs["cmd"].endswith( "libhostname.so.1 OCI__HOSTNAME=10.0.0.1 " - "accelerate launch --num_processes 2 --num_machines 2 --machine_rank 0 --main_process_port 29400 " + "accelerate launch --num_processes 2 --num_machines 2 --machine_rank 0 --main_process_port 29400 --use_deepspeed " "train.py --data abc" ), time_cmd.call_args.kwargs["cmd"], @@ -206,6 +215,7 @@ def test_run(self, time_cmd, run_command, run_deepspeed): "10.0.0.1", "--main_process_port", "29400", + "--use_deepspeed", "--deepspeed_hostfile=/home/datascience/hostfile", ], )