Commit 2eb2ab0

support py_volcano platform
1 parent b8b0b8a commit 2eb2ab0

File tree: 8 files changed, +224 -2 lines changed

dlrover/python/common/constants.py

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,7 @@ class PlatformType(object):
     RAY = "ray"
     PY_KUBERNETES = "pyk8s"
     LOCAL = "local"
+    PY_VOLCANO = "py_volcano"


 class CommunicationType(object):
@@ -293,6 +294,8 @@ class NodeEnv(object):
     NODE_ID = "NODE_ID"
     NODE_NUM = "NODE_NUM"
     NODE_RANK = "NODE_RANK"
+    MAX_NODE = "MAX_NODE"
+    MIN_NODE = "MIN_NODE"

     # Deprecated env vars.
     WORKER_TYPE = "WORKER_TYPE"
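
Taken together with the factory changes below, wiring the new platform through DLRover might look like the following sketch; the job name and namespace are hypothetical, only the factory functions and the PY_VOLCANO constant come from this commit:

    from dlrover.python.common.constants import PlatformType
    from dlrover.python.master.scaler.factory import new_job_scaler
    from dlrover.python.master.watcher.factory import new_node_watcher

    # Hypothetical job name and namespace, for illustration only.
    scaler = new_job_scaler(PlatformType.PY_VOLCANO, "demo-job", "default")
    watcher = new_node_watcher(PlatformType.PY_VOLCANO, "demo-job", "default")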

dlrover/python/master/scaler/factory.py

Lines changed: 4 additions & 0 deletions
@@ -29,3 +29,7 @@ def new_job_scaler(platform, job_name, namespace):
         return ActorScaler(job_name, namespace)
     elif platform == PlatformType.LOCAL:
         return None
+    elif platform == PlatformType.PY_VOLCANO:
+        from dlrover.python.master.scaler.volcano_scaler import VolcanoScaler
+
+        return VolcanoScaler(job_name, namespace)
dlrover/python/master/scaler/volcano_scaler.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+# Copyright 2022 The DLRover Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dlrover.python.master.scaler.base_scaler import ScalePlan, Scaler
+
+
+# The VolcanoScaler does nothing; Volcano itself handles scaling.
+class VolcanoScaler(Scaler):
+    def __init__(self, job_name, namespace):
+        super(VolcanoScaler, self).__init__(job_name)
+        self._namespace = namespace
+
+    def start(self):
+        pass
+
+    def scale(self, plan: ScalePlan):
+        pass
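
Because Volcano itself owns the pod lifecycle, the scaler deliberately ignores every plan. A minimal usage sketch, with hypothetical job names and assuming ScalePlan can be constructed with defaults:

    from dlrover.python.master.scaler.base_scaler import ScalePlan
    from dlrover.python.master.scaler.volcano_scaler import VolcanoScaler

    # Hypothetical names; assumes ScalePlan() builds with default fields.
    scaler = VolcanoScaler("demo-job", "default")
    scaler.start()
    scaler.scale(ScalePlan())  # intentionally a no-op; Volcano scales the job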

dlrover/python/master/watcher/factory.py

Lines changed: 11 additions & 1 deletion
@@ -17,7 +17,11 @@

 def new_node_watcher(platform, job_name, namespace):
     logger.info("New %s Node Watcher", platform)
-    if platform in (PlatformType.KUBERNETES, PlatformType.PY_KUBERNETES):
+    if platform in (
+        PlatformType.KUBERNETES,
+        PlatformType.PY_KUBERNETES,
+        PlatformType.PY_VOLCANO,
+    ):
         from dlrover.python.master.watcher.k8s_watcher import PodWatcher

         return PodWatcher(job_name, namespace)
@@ -43,5 +47,11 @@ def new_scale_plan_watcher(platform, job_name, namespace, job_uuid):
         )

         return RayScalePlanWatcher(job_name, namespace, job_uuid)
+    elif platform == PlatformType.PY_VOLCANO:
+        from dlrover.python.master.watcher.volcano_watcher import (
+            VolcanoScalePlanWatcher,
+        )
+
+        return VolcanoScalePlanWatcher(job_name, namespace, job_uuid)
     else:
         raise ValueError("Not support engine %s", platform)
dlrover/python/master/watcher/volcano_watcher.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+# Copyright 2022 The DLRover Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+
+
+# The VolcanoScalePlanWatcher does nothing; Volcano itself handles scaling.
+class VolcanoScalePlanWatcher:
+    def __init__(self, job_name, namespace, job_uuid):
+        self.job_name = job_name
+        self.namespace = namespace
+        self.job_uuid = job_uuid
+
+    def watch(self):
+        while True:
+            time.sleep(1000)
+            yield None
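
The watcher satisfies the scale-plan interface with a generator that sleeps and only ever yields None, so the consuming loop effectively stays parked. A consumer sketch, with hypothetical names:

    from dlrover.python.master.watcher.volcano_watcher import (
        VolcanoScalePlanWatcher,
    )

    # Hypothetical job name, namespace, and UUID.
    watcher = VolcanoScalePlanWatcher("demo-job", "default", "demo-uuid")
    for plan in watcher.watch():
        assert plan is None  # Volcano performs the scaling itself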

dlrover/python/scheduler/factory.py

Lines changed: 8 additions & 0 deletions
@@ -26,6 +26,10 @@ def new_elastic_job(platform, job_name, namespace):
         from dlrover.python.scheduler.ray import RayElasticJob

         return RayElasticJob(job_name, namespace)
+    elif platform == PlatformType.PY_VOLCANO:
+        from dlrover.python.scheduler.volcano import VolcanoElasticJob
+
+        return VolcanoElasticJob(job_name, namespace)
     else:
         raise ValueError("Not support engine %s", platform)

@@ -42,5 +46,9 @@ def new_job_args(platform, job_name, namespace):
         return RayJobArgs(platform, namespace, job_name)
     elif platform == PlatformType.LOCAL:
         return LocalJobArgs(platform, namespace, job_name)
+    elif platform == PlatformType.PY_VOLCANO:
+        from dlrover.python.scheduler.volcano import VolcanoJobArgs
+
+        return VolcanoJobArgs(platform, namespace, job_name)
     else:
         raise ValueError("Not support platform %s", platform)
dlrover/python/scheduler/volcano.py

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
+# Copyright 2022 The DLRover Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+
+from dlrover.python.common.constants import (
+    DistributionStrategy,
+    NodeEnv,
+    NodeType,
+    OptimizeMode,
+)
+from dlrover.python.common.log import default_logger as logger
+from dlrover.python.common.node import NodeGroupResource, NodeResource
+from dlrover.python.scheduler.job import ElasticJob, JobArgs, NodeArgs
+from dlrover.python.scheduler.kubernetes import (
+    _dlrover_context,
+    convert_cpu_to_decimal,
+    convert_memory_to_mb,
+    k8sClient,
+)
+
+CONFIGMAP_SUFFIX = "-dlrover-conf"
+
+
+class VolcanoElasticJob(ElasticJob):
+    def __init__(self, job_name, namespace):
+        self._namespace = namespace
+        self._job_name = job_name
+
+    def get_node_name(self, type, id):
+        return "pod-name"
+
+    def get_node_service_addr(self, type, id):
+        return ""
+
+
+class VolcanoJobArgs(JobArgs):
+    def __init__(self, platform, namespace, job_name):
+        super(VolcanoJobArgs, self).__init__(platform, namespace, job_name)
+
+    def initilize(self):
+        self.user = os.getenv("USER", "")
+        k8s_client = k8sClient.singleton_instance(self.namespace)
+
+        # Get parameters from the configmap.
+        configmap = self._retry_to_get_configmap(k8s_client)
+        self.job_uuid = os.getenv(NodeEnv.JOB_UID, "")
+        self.distribution_strategy = configmap.data.get(
+            "distributionStrategy", DistributionStrategy.ALLREDUCE
+        )
+        self.optimizeMode = configmap.data.get(
+            "optimizeMode", OptimizeMode.SINGLE_JOB
+        )
+
+        # Get parameters from the volcano job.
+        vcjob = self._retry_to_get_vcjob(k8s_client)
+        for task in vcjob["spec"]["tasks"]:
+            if task["name"] == NodeType.WORKER:
+                restart_policy = task["template"]["spec"].get(
+                    "restartPolicy", ""
+                )
+                self.relaunch_always = restart_policy == "Always"
+
+            num = int(task.get("replicas", 0))
+            assert len(task["template"]["spec"]["containers"]) == 1
+            container = task["template"]["spec"]["containers"][0]
+            resources = container.get("resources", {})
+            requests = resources.get("requests", {})
+            cpu = convert_cpu_to_decimal(requests.get("cpu", 0))
+            if "memory" in requests:
+                memory = convert_memory_to_mb(requests["memory"])
+            else:
+                memory = 0
+            gpu_type = None
+            gpu_num = 0
+            for k, v in requests.items():
+                if "nvidia.com" in k:
+                    gpu_type = k
+                    gpu_num = int(v)
+            group_resource = NodeGroupResource(
+                num,
+                NodeResource(
+                    cpu=cpu,
+                    memory=memory,
+                    gpu_type=gpu_type,
+                    gpu_num=gpu_num,
+                ),
+            )
+            self.node_args[task["name"]] = NodeArgs(
+                group_resource,
+                process_timeout=_dlrover_context.seconds_to_timeout_task_process,
+            )
+        logger.info("Job args = %s", self.__dict__)
+
+    def _retry_to_get_configmap(self, k8s_client: k8sClient):
+        for _ in range(3):
+            configmap = k8s_client.get_configmap(
+                name=self.job_name + CONFIGMAP_SUFFIX,
+            )
+            if configmap:
+                return configmap
+            else:
+                time.sleep(5)
+        raise ValueError("Cannot get the configmap %s" % self.job_name)
+
+    def _retry_to_get_vcjob(self, k8s_client: k8sClient):
+        for _ in range(3):
+            vcjob = k8s_client.get_custom_resource(
+                name=self.job_name,
+                group="batch.volcano.sh",
+                version="v1alpha1",
+                plural="jobs",
+            )
+            if vcjob:
+                return vcjob
+            else:
+                time.sleep(5)
+        raise ValueError(
+            "Cannot get the training volcano job %s" % self.job_name
+        )
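
For reference, here is a minimal sketch of the custom-resource dict shape that initilize() walks above. The keys mirror the batch.volcano.sh/v1alpha1 Job spec as read by the parsing code; every concrete value is hypothetical:

    vcjob = {
        "spec": {
            "tasks": [
                {
                    "name": "worker",  # matched against NodeType.WORKER
                    "replicas": 2,
                    "template": {
                        "spec": {
                            "restartPolicy": "Always",
                            "containers": [
                                {
                                    "resources": {
                                        "requests": {
                                            "cpu": "4",
                                            "memory": "8Gi",
                                            "nvidia.com/gpu": "1",
                                        }
                                    }
                                }
                            ],
                        }
                    },
                }
            ]
        }
    }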

dlrover/trainer/torch/elastic_run.py

Lines changed: 14 additions & 1 deletion
@@ -211,7 +211,20 @@ def parse_args(args):
         action=check_env,
         help="Whether to test the communication performance.",
     )
-    return parser.parse_args(args)
+    args = parser.parse_args(args)
+
+    # Reconfigure nnodes from the environment when both bounds are set.
+    max_node = os.getenv(NodeEnv.MAX_NODE, None)
+    min_node = os.getenv(NodeEnv.MIN_NODE, None)
+    if max_node is not None and min_node is not None:
+        new_nnodes = "{}:{}".format(
+            int(min_node), int(max_node)
+        )
+        logger.info(
+            f"The nnodes will be reconfigured from {args.nnodes} to {new_nnodes}"
+        )
+        args.nnodes = new_nnodes
+    return args


 class ElasticLaunch:
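
With the MIN_NODE and MAX_NODE variables exported by the platform, the override behaves as in this sketch; the values are hypothetical:

    import os

    os.environ["MIN_NODE"] = "1"
    os.environ["MAX_NODE"] = "4"
    # parse_args(...) now rewrites args.nnodes to "1:4",
    # regardless of the --nnodes value passed on the command line.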
