Skip to content

Add support to use multi-node job run (DTv2) API for distributed training #1165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
916f70f
Use multi-node job run APIs to create job when available.
qiuosier Jan 27, 2025
c88a7c4
Update utility functions for ADS jobs.
qiuosier Jan 27, 2025
1621bae
Update PyTorch driver.
qiuosier Feb 3, 2025
ddf786d
Update imports in dsc_job.py.
qiuosier Feb 3, 2025
ddf0e93
Update PyTorch driver.
qiuosier Feb 3, 2025
67db3ab
Update PyTorch driver.
qiuosier Feb 3, 2025
87a75b3
Fix bug for DeepSpeed.
qiuosier Feb 4, 2025
634ce96
Update PyTorch driver.
qiuosier Feb 4, 2025
b7218a1
Update logic to start job run for multi-node job.
qiuosier Feb 4, 2025
a310ae5
Fix bug for configuring subnet.
qiuosier Feb 5, 2025
6861892
Set network interface name.
qiuosier Feb 5, 2025
8aea639
Update YAML parsing.
qiuosier Feb 5, 2025
4c6350b
Update replica processing.
qiuosier Feb 5, 2025
ae816b6
Allow host_job_run to be None for torchrun in DTv2.
qiuosier Feb 6, 2025
c675dc0
Skip installing deepspeed dependencies for HF accelerate when deepspe…
qiuosier Feb 6, 2025
23be0e0
Fix errors.
qiuosier Feb 6, 2025
6d55882
Set NODE_COUNT in job run and show replica in YAML.
qiuosier Feb 6, 2025
dced541
Add tests.
qiuosier Feb 7, 2025
5b20089
Remove the use of exceptions module.
qiuosier Feb 7, 2025
d3e453c
Merge branch 'main' into feature/dtv2
qiuosier Apr 1, 2025
921c582
Update PyTorch driver.
qiuosier Apr 1, 2025
33eea54
Update PyTorch driver to fix SSH error.
qiuosier Apr 10, 2025
9b44b85
Update PyTorch driver for IFNAME
qiuosier Apr 10, 2025
7446c4b
Merge remote-tracking branch 'origin/main' into feature/dtv2
qiuosier Apr 21, 2025
70e3ce1
Update copyright info and sort imports.
qiuosier Apr 23, 2025
ba5659f
Update test.
qiuosier Apr 23, 2025
be174f8
Update test.
qiuosier Apr 23, 2025
6b92889
Merge branch 'main' into feature/dtv2
qiuosier Apr 23, 2025
a829ace
Update pytorch runner tests.
qiuosier Apr 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 121 additions & 24 deletions ads/jobs/builders/infrastructure/dsc_job.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from __future__ import annotations

Expand All @@ -21,37 +20,41 @@
import oci
import oci.data_science
import oci.util as oci_util
import yaml
from oci.data_science import models
from oci.data_science.models import JobInfrastructureConfigurationDetails
from oci.exceptions import ServiceError
import yaml

from ads.common import utils
from ads.common.decorator.utils import class_or_instance_method
from ads.common.dsc_file_system import (
DSCFileSystemManager,
OCIFileStorage,
OCIObjectStorage,
)
from ads.common.oci_datascience import DSCNotebookSession, OCIDataScienceMixin
from ads.common.oci_logging import OCILog
from ads.common.oci_resource import ResourceNotFoundError
from ads.jobs.builders.infrastructure.base import Infrastructure, RunInstance
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
MULTI_NODE_JOB_SUPPORT,
ContainerRuntimeHandler,
DataScienceJobRuntimeManager,
)
from ads.jobs.builders.infrastructure.utils import get_value
from ads.jobs.builders.runtimes.artifact import Artifact
from ads.jobs.builders.runtimes.base import MultiNodeRuntime
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime

from ads.common.dsc_file_system import (
OCIFileStorage,
DSCFileSystemManager,
OCIObjectStorage,
)
from ads.common.decorator.utils import class_or_instance_method

logger = logging.getLogger(__name__)

SLEEP_INTERVAL = 3
WAIT_SECONDS_AFTER_FINISHED = 90
MAXIMUM_MOUNT_COUNT = 5
FILE_STORAGE_TYPE = "FILE_STORAGE"
OBJECT_STORAGE_TYPE = "OBJECT_STORAGE"
DEFAULT_NODE_GROUP_NAME = "node-group"


class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
Expand Down Expand Up @@ -284,11 +287,15 @@ def load_properties_from_env(self) -> None:

def load_defaults(self) -> DSCJob:
self.load_properties_from_env()
if getattr(self, "job_node_configuration_details", None):
return self
# Following are for single node job run only
if not self.job_infrastructure_configuration_details:
self.job_infrastructure_configuration_details = {}

# Convert the dict to JobInfrastructureConfigurationDetails object
if isinstance(self.job_infrastructure_configuration_details, dict):
# Default networking

if not self.job_infrastructure_configuration_details.get(
"jobInfrastructureType"
):
Expand Down Expand Up @@ -352,6 +359,7 @@ def create(self) -> DSCJob:
raise ValueError("Specify compartment ID for data science job.")
if not self.project_id:
raise ValueError("Specify project ID for data science job.")

self._create_with_oci_api()
return self

Expand Down Expand Up @@ -498,7 +506,9 @@ def run(self, **kwargs) -> DataScienceJobRun:
keys = list(kwargs.keys())
for key in keys:
if key in config_swagger_types:
config_kwargs[key] = kwargs.pop(key)
val = kwargs.pop(key)
if val is not None:
config_kwargs[key] = val
elif key in env_config_swagger_types:
value = kwargs.pop(key)
if key in [
Expand Down Expand Up @@ -545,6 +555,25 @@ def run(self, **kwargs) -> DataScienceJobRun:
env_config_override
)

if getattr(self, "job_node_configuration_details", None):
job_config_override = kwargs.pop("job_configuration_override_details", None)
env_config_override = kwargs.pop(
"job_environment_configuration_override_details", None
)
if job_config_override or env_config_override:
node_config = {
"jobNodeType": "MULTI_NODE",
"jobNodeGroupConfigurationDetailsList": [
{
# Node group name must match the node group name in the job.
"name": DEFAULT_NODE_GROUP_NAME,
"JobConfigurationDetails": job_config_override,
"JobEnvironmentConfigurationDetails": env_config_override,
}
],
}
kwargs["job_node_configuration_override_details"] = node_config

wait = kwargs.pop("wait", False)
run = DataScienceJobRun(**kwargs, **self.auth).create()
if wait:
Expand Down Expand Up @@ -756,13 +785,11 @@ def stop_condition():
return True
# Stop only if time_finished is over 2 minute ago.
# This is for the time delay between job run stopped and the logs appear in oci logging.
if (
return (
datetime.datetime.now(self.time_finished.tzinfo)
- datetime.timedelta(seconds=wait)
> self.time_finished
):
return True
return False
)

if not self.log_id and not self.log_group_id:
print(
Expand Down Expand Up @@ -1471,6 +1498,23 @@ def _update_from_dsc_model(
}
self.dsc_job = dsc_job

# Process multi-node infrastructure config
node_groups = get_value(
dsc_job,
"job_node_configuration_details.job_node_group_configuration_details_list",
)
if node_groups and len(node_groups) == 1:
node_group = node_groups[0]
dsc_job.job_infrastructure_configuration_details = (
node_group.job_infrastructure_configuration_details
)
subnet_id = get_value(
dsc_job,
"job_node_configuration_details.job_network_configuration.subnet_id",
)
if subnet_id:
self.set_spec(self.CONST_SUBNET_ID, subnet_id)

for infra_attr, dsc_attr in self.payload_attribute_map.items():
value = get_value(dsc_job, dsc_attr)
if not value:
Expand Down Expand Up @@ -1557,10 +1601,13 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
if value:
dsc_job.job_infrastructure_configuration_details[camel_attr] = value

if not dsc_job.job_infrastructure_configuration_details.get(
"shapeName", ""
).endswith("Flex") and dsc_job.job_infrastructure_configuration_details.get(
"jobShapeConfigDetails"
shape = dsc_job.job_infrastructure_configuration_details.get("shapeName", "")
if (
shape
and not str(shape).endswith("Flex")
and dsc_job.job_infrastructure_configuration_details.get(
"jobShapeConfigDetails"
)
):
raise ValueError(
"Shape config is not required for non flex shape from user end."
Expand All @@ -1583,7 +1630,6 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
return self

def build(self) -> DataScienceJob:
self.dsc_job.load_defaults()

try:
self.dsc_job.load_defaults()
Expand Down Expand Up @@ -1611,6 +1657,48 @@ def init(self, **kwargs) -> DataScienceJob:
)
)

def _config_multi_node(self, runtime: MultiNodeRuntime):
    """Rewrite the job payload into the multi-node (DTv2) shape.

    The job-level infrastructure, job, and environment configurations are
    moved into a single node group, and the corresponding job-level
    attributes are cleared, as the multi-node API expects per-node-group
    configuration instead.

    Parameters
    ----------
    runtime : MultiNodeRuntime
        Runtime whose ``replica`` count sizes the node group.
    """
    dsc_job = self.dsc_job
    # Capture the existing job-level configurations before clearing them.
    infra_config: dict = dsc_job.job_infrastructure_configuration_details
    job_config: models.DefaultJobConfigurationDetails = (
        dsc_job.job_configuration_details
    )
    env_config = dsc_job.job_environment_configuration_details
    # For multi-node jobs these settings live in each node group, so the
    # job-level attributes are cleared here.
    # NOTE(review): an earlier comment referred to a "special EMPTY class"
    # for these fields, but None is what gets assigned — confirm intent.
    dsc_job.job_infrastructure_configuration_details = None
    dsc_job.job_configuration_details = None
    dsc_job.job_environment_configuration_details = None

    # The subnet moves out of the infrastructure config and into the
    # job-level network configuration.
    subnet_id = infra_config.pop("subnetId", None)
    infra_config["jobInfrastructureType"] = (
        models.MultiNodeJobInfrastructureConfigurationDetails.JOB_INFRASTRUCTURE_TYPE_MULTI_NODE
    )

    if subnet_id:
        network_config = models.JobCustomNetworkConfiguration(subnet_id=subnet_id)
    else:
        network_config = models.JobDefaultNetworkConfiguration()

    # One node group carries all replicas; its name must match the name
    # used when overriding configs at run time.
    node_group: dict = {
        "name": DEFAULT_NODE_GROUP_NAME,
        "replicas": runtime.replica,
        "minimumSuccessReplicas": runtime.replica,
        "jobInfrastructureConfigurationDetails": infra_config,
        "jobConfigurationDetails": job_config,
        "jobEnvironmentConfigurationDetails": env_config,
    }

    dsc_job.job_node_configuration_details = {
        "jobNodeType": "MULTI_NODE",
        "startupOrder": "IN_PARALLEL",
        "jobNetworkConfiguration": network_config,
        "jobNodeGroupConfigurationDetailsList": [node_group],
    }

def create(self, runtime, **kwargs) -> DataScienceJob:
"""Creates a job with runtime.

Expand All @@ -1635,9 +1723,7 @@ def create(self, runtime, **kwargs) -> DataScienceJob:

if self.name:
display_name = Template(self.name).safe_substitute(runtime.envs)
elif isinstance(runtime, GitPythonRuntime) or isinstance(
runtime, ContainerRuntime
):
elif isinstance(runtime, (GitPythonRuntime, ContainerRuntime)):
display_name = utils.get_random_name_for_resource()
else:
display_name = None
Expand All @@ -1652,11 +1738,22 @@ def create(self, runtime, **kwargs) -> DataScienceJob:
self.dsc_job = DSCJob(**payload, **self.auth)
# Set Job infra to user values after DSCJob initialized the defaults
self._update_job_infra(self.dsc_job)
if self.is_multi_node_job(runtime):
self._config_multi_node(runtime=runtime)
self.dsc_job.create()
# Update the model from infra after job creation.
self._update_from_dsc_model(self.dsc_job)
return self

@staticmethod
def is_multi_node_job(runtime):
    """Return True if the runtime describes a multi-node (DTv2) job.

    A job is treated as multi-node only when the installed SDK supports
    the multi-node API, the runtime is a MultiNodeRuntime, and more than
    one replica is requested.
    """
    if not MULTI_NODE_JOB_SUPPORT:
        return False
    if not isinstance(runtime, MultiNodeRuntime):
        return False
    return runtime.replica > 1

def run(
self,
name=None,
Expand Down
Loading