11 changes: 9 additions & 2 deletions dlrover/python/elastic_agent/torch/training.py
@@ -218,6 +218,8 @@ class ElasticLaunchConfig(LaunchConfig):
training_log_file: str = ""
failure_node_errors: str = ""
numa_affinity: bool = False
connect_master_timeout: int = 300
Collaborator comment on this line:
It's unnecessary to define these two parameters here. You can refer to the timeout mechanism inside the master server/client and leverage existing parameters (by exposing them via config).

Additionally, note that these parameters should not be passed through dlrover-run; instead, they should be directly obtained by the master through job args.
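For illustration only, a minimal sketch of that suggested flow, mirroring the --dead_node_timeout wiring added later in this diff (dlrover/python/master/args.py and main.py); the flag name and job_args attribute below are hypothetical, not existing DLRover parameters:

    # In _build_master_args_parser(), expose the timeout as a master argument
    # (hypothetical flag, shown only to illustrate the suggested direction).
    parser.add_argument(
        "--connect_master_timeout",
        default=300,
        type=int,
        help="Timeout in seconds for workers connecting to the master.",
    )

    # In run(args), forward it to the job args just like dead_node_timeout below.
    job_args.connect_master_timeout = args.connect_master_timeout
    update_context(job_args)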

connect_master_max_retry: int = 5

def set_node_unit(self, node_unit):
"""Set the number unit of nodes."""
@@ -316,6 +318,7 @@ def __init__(
node_rank,
rdzv_params: RendezvousParameters,
local_world_size,
connect_master_timeout=300,
):
self._name = name
self._node_rank = node_rank
@@ -328,7 +331,7 @@
)
self.pend_timeout = float(rdzv_params.get("pend_timeout", "inf"))
self._client = MasterClient.singleton_instance()
self._store = MasterKVStore(self._name, timedelta(seconds=300))
self._store = MasterKVStore(self._name, timedelta(seconds=connect_master_timeout))
lastcall_timeout = int(rdzv_params.get("lastcall_timeout", 60))
node_unit = int(rdzv_params.get("node_unit", "1"))
self._client.report_rdzv_params(
@@ -521,6 +524,7 @@ def __init__(
training_log_file: str = "",
failure_node_errors: str = "",
with_diagnostician: bool = True,
connect_master_max_retry=5,
):
if version_less_than_230():
super().__init__(
@@ -557,6 +561,7 @@ def __init__(
node_rank=node_rank,
local_world_size=config.nproc_per_node,
)
self.connect_master_max_retry = connect_master_max_retry
self._agent_context = get_agent_context()
self._rank_cpu_affinity = {}
if self._config.numa_affinity:
@@ -722,7 +727,7 @@ def _get_master_addr_port(self, store: Store) -> Tuple[str, int]:
return master_addr, master_port

def _safe_get_master_addr_port(self, store: Store) -> Tuple[str, int]:
for _ in range(5):
for _ in range(self.connect_master_max_retry):
try:
return self._get_master_addr_port(store)
except Exception as e:
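The hunk above cuts off inside the retry loop; the following is a plausible reading of the pattern it makes configurable, assuming time is imported at module level (the sleep/raise handling is an assumption, not the file's actual body after the except line):

    def _safe_get_master_addr_port(self, store: Store) -> Tuple[str, int]:
        # Bounded retries instead of the previously hard-coded 5 attempts;
        # re-raise once retries are exhausted.
        last_err = None
        for _ in range(self.connect_master_max_retry):
            try:
                return self._get_master_addr_port(store)
            except Exception as e:
                last_err = e
                time.sleep(1)
        raise RuntimeError(
            f"Cannot reach the master after {self.connect_master_max_retry} retries"
        ) from last_err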
@@ -1414,6 +1419,7 @@ def launch_agent(
training_log_file=config.training_log_file,
failure_node_errors=config.failure_node_errors,
exit_barrier_timeout=900,
connect_master_max_retry=config.connect_master_max_retry,
)

shutdown_rdzv = True
@@ -1512,6 +1518,7 @@ def _create_worker_spec(
node_rank,
rdzv_parameters,
local_world_size=config.nproc_per_node,
connect_master_timeout=config.connect_master_timeout,
)
spec = WorkerSpec(
role=config.role,
6 changes: 6 additions & 0 deletions dlrover/python/master/args.py
@@ -104,6 +104,12 @@ def _build_master_args_parser():
type=pos_int,
help="The timeout value of worker task process(For PS type job).",
)
parser.add_argument(
"--dead_node_timeout",
default=600,
type=int,
help="dead node timeout in seconds",
)
return parser


1 change: 1 addition & 0 deletions dlrover/python/master/main.py
@@ -81,6 +81,7 @@ def run(args):
else:
from dlrover.python.master.dist_master import DistributedJobMaster

job_args.dead_node_timeout = args.dead_node_timeout
update_context(job_args)
master = DistributedJobMaster(_dlrover_context.master_port, job_args)
master.prepare()
2 changes: 1 addition & 1 deletion dlrover/python/master/node/dist_job_manager.py
@@ -475,7 +475,7 @@ def _monitor_nodes(self):
def _monitor_node_heart_beat(self):
with self._lock:
try:
events = self._get_dead_node_event()
events = self._get_dead_node_event(window_interval=self._job_args.dead_node_timeout)
except Exception as e:
logger.warning(e)
events = []
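The new window_interval presumably bounds how stale a node's heartbeat may be before it is reported as dead. A self-contained sketch of that kind of check, with illustrative names rather than the actual _get_dead_node_event implementation:

    import time

    def find_dead_nodes(last_heartbeats, window_interval=600):
        # last_heartbeats maps node id -> last heartbeat timestamp (epoch seconds).
        now = time.time()
        return [
            node_id
            for node_id, last_beat in last_heartbeats.items()
            if now - last_beat > window_interval
        ]

    # With the 600s default, only the node that has been silent for 900s is reported.
    beats = {0: time.time() - 30, 1: time.time() - 200, 2: time.time() - 900}
    print(find_dead_nodes(beats))  # -> [2]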
1 change: 1 addition & 0 deletions dlrover/python/scheduler/job.py
@@ -107,6 +107,7 @@ def __init__(self, platform, namespace, job_name):
self.cordon_fault_node = False
self.xpu_type: Accelerators = Accelerators.GENERIC_CPU
self.enable_suspended = False
self.dead_node_timeout = 600

@abstractmethod
def initilize(self):
18 changes: 18 additions & 0 deletions dlrover/trainer/torch/elastic_run.py
@@ -214,6 +214,22 @@ def parse_args(args):
action=check_env,
help="Whether to test the communication performance.",
)
parser.add_argument(
"--connect-master-timeout",
"--connect_master_timeout",
type=int,
action=env,
default=30,
help="Connect master timeout in seconds.",
)
parser.add_argument(
"--connect-master-max-retry",
"--connect_master_max_retry",
type=int,
action=env,
default=2,
help="Connect master max retry times.",
)
return parser.parse_args(args)


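An illustrative call to the parse_args shown above, assuming the parser needs only the usual positional training script, no other required arguments, and no overriding environment variables:

    # Values are arbitrary; "train.py" stands in for the training-script positional.
    args = parse_args(
        [
            "--connect-master-timeout", "60",
            "--connect-master-max-retry", "3",
            "train.py",
        ]
    )
    print(args.connect_master_timeout, args.connect_master_max_retry)  # 60 3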
@@ -402,6 +418,8 @@ def _elastic_config_from_args(
elastic_config.rdzv_endpoint = ""
join_timeout = elastic_config.rdzv_configs.get("join_timeout", 600)
elastic_config.rdzv_configs["timeout"] = join_timeout
elastic_config.connect_master_timeout = args.connect_master_timeout
elastic_config.connect_master_max_retry = args.connect_master_max_retry
return elastic_config, cmd, cmd_args


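Putting the agent-side pieces together, a minimal sketch of wiring the new knobs through the config object; the constructor arguments come from torch.distributed's LaunchConfig, and this assumes ElasticLaunchConfig adds no other required fields:

    from dlrover.python.elastic_agent.torch.training import ElasticLaunchConfig

    config = ElasticLaunchConfig(min_nodes=1, max_nodes=2, nproc_per_node=8)
    config.connect_master_timeout = 60      # seconds for MasterKVStore operations
    config.connect_master_max_retry = 3     # attempts in _safe_get_master_addr_port

On the dlrover-run command line, the same values map to --connect-master-timeout and --connect-master-max-retry (or their underscore spellings).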