Skip to content

Commit 830cbce

Browse files
authored
Adding direct KubeRay compatibility to the SDK (#358)
* Added component generation * Added multi-resource YAML support * Cluster.up on ray cluster object * Basic status and down for RayCluster * Finished up/down and added unit tests * Remove unused utils import * Applied review feedback * Changed naming of internal funcs * Review feedback applied, auto-select * OAuth conflict resolution
1 parent 2441f4f commit 830cbce

File tree

7 files changed

+450
-84
lines changed

7 files changed

+450
-84
lines changed

Diff for: src/codeflare_sdk/cluster/cluster.py

+161-61
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(self, config: ClusterConfiguration):
7070
self.config = config
7171
self.app_wrapper_yaml = self.create_app_wrapper()
7272
self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0]
73-
self._client = None
73+
self._job_submission_client = None
7474

7575
@property
7676
def _client_headers(self):
@@ -86,23 +86,25 @@ def _client_verify_tls(self):
8686
return not self.config.openshift_oauth
8787

8888
@property
89-
def client(self):
90-
if self._client:
91-
return self._client
89+
def job_client(self):
90+
if self._job_submission_client:
91+
return self._job_submission_client
9292
if self.config.openshift_oauth:
9393
print(
9494
api_config_handler().configuration.get_api_key_with_prefix(
9595
"authorization"
9696
)
9797
)
98-
self._client = JobSubmissionClient(
98+
self._job_submission_client = JobSubmissionClient(
9999
self.cluster_dashboard_uri(),
100100
headers=self._client_headers,
101101
verify=self._client_verify_tls,
102102
)
103103
else:
104-
self._client = JobSubmissionClient(self.cluster_dashboard_uri())
105-
return self._client
104+
self._job_submission_client = JobSubmissionClient(
105+
self.cluster_dashboard_uri()
106+
)
107+
return self._job_submission_client
106108

107109
def evaluate_dispatch_priority(self):
108110
priority_class = self.config.dispatch_priority
@@ -141,6 +143,10 @@ def create_app_wrapper(self):
141143

142144
# Before attempting to create the cluster AW, let's evaluate the ClusterConfig
143145
if self.config.dispatch_priority:
146+
if not self.config.mcad:
147+
raise ValueError(
148+
"Invalid Cluster Configuration, cannot have dispatch priority without MCAD"
149+
)
144150
priority_val = self.evaluate_dispatch_priority()
145151
if priority_val == None:
146152
raise ValueError(
@@ -163,6 +169,7 @@ def create_app_wrapper(self):
163169
template = self.config.template
164170
image = self.config.image
165171
instascale = self.config.instascale
172+
mcad = self.config.mcad
166173
instance_types = self.config.machine_types
167174
env = self.config.envs
168175
local_interactive = self.config.local_interactive
@@ -183,6 +190,7 @@ def create_app_wrapper(self):
183190
template=template,
184191
image=image,
185192
instascale=instascale,
193+
mcad=mcad,
186194
instance_types=instance_types,
187195
env=env,
188196
local_interactive=local_interactive,
@@ -207,15 +215,18 @@ def up(self):
207215
try:
208216
config_check()
209217
api_instance = client.CustomObjectsApi(api_config_handler())
210-
with open(self.app_wrapper_yaml) as f:
211-
aw = yaml.load(f, Loader=yaml.FullLoader)
212-
api_instance.create_namespaced_custom_object(
213-
group="workload.codeflare.dev",
214-
version="v1beta1",
215-
namespace=namespace,
216-
plural="appwrappers",
217-
body=aw,
218-
)
218+
if self.config.mcad:
219+
with open(self.app_wrapper_yaml) as f:
220+
aw = yaml.load(f, Loader=yaml.FullLoader)
221+
api_instance.create_namespaced_custom_object(
222+
group="workload.codeflare.dev",
223+
version="v1beta1",
224+
namespace=namespace,
225+
plural="appwrappers",
226+
body=aw,
227+
)
228+
else:
229+
self._component_resources_up(namespace, api_instance)
219230
except Exception as e: # pragma: no cover
220231
return _kube_api_error_handling(e)
221232

@@ -228,13 +239,16 @@ def down(self):
228239
try:
229240
config_check()
230241
api_instance = client.CustomObjectsApi(api_config_handler())
231-
api_instance.delete_namespaced_custom_object(
232-
group="workload.codeflare.dev",
233-
version="v1beta1",
234-
namespace=namespace,
235-
plural="appwrappers",
236-
name=self.app_wrapper_name,
237-
)
242+
if self.config.mcad:
243+
api_instance.delete_namespaced_custom_object(
244+
group="workload.codeflare.dev",
245+
version="v1beta1",
246+
namespace=namespace,
247+
plural="appwrappers",
248+
name=self.app_wrapper_name,
249+
)
250+
else:
251+
self._component_resources_down(namespace, api_instance)
238252
except Exception as e: # pragma: no cover
239253
return _kube_api_error_handling(e)
240254

@@ -252,42 +266,46 @@ def status(
252266
"""
253267
ready = False
254268
status = CodeFlareClusterStatus.UNKNOWN
255-
# check the app wrapper status
256-
appwrapper = _app_wrapper_status(self.config.name, self.config.namespace)
257-
if appwrapper:
258-
if appwrapper.status in [
259-
AppWrapperStatus.RUNNING,
260-
AppWrapperStatus.COMPLETED,
261-
AppWrapperStatus.RUNNING_HOLD_COMPLETION,
262-
]:
263-
ready = False
264-
status = CodeFlareClusterStatus.STARTING
265-
elif appwrapper.status in [
266-
AppWrapperStatus.FAILED,
267-
AppWrapperStatus.DELETED,
268-
]:
269-
ready = False
270-
status = CodeFlareClusterStatus.FAILED # should deleted be separate
271-
return status, ready # exit early, no need to check ray status
272-
elif appwrapper.status in [
273-
AppWrapperStatus.PENDING,
274-
AppWrapperStatus.QUEUEING,
275-
]:
276-
ready = False
277-
if appwrapper.status == AppWrapperStatus.PENDING:
278-
status = CodeFlareClusterStatus.QUEUED
279-
else:
280-
status = CodeFlareClusterStatus.QUEUEING
281-
if print_to_console:
282-
pretty_print.print_app_wrappers_status([appwrapper])
283-
return (
284-
status,
285-
ready,
286-
) # no need to check the ray status since still in queue
269+
if self.config.mcad:
270+
# check the app wrapper status
271+
appwrapper = _app_wrapper_status(self.config.name, self.config.namespace)
272+
if appwrapper:
273+
if appwrapper.status in [
274+
AppWrapperStatus.RUNNING,
275+
AppWrapperStatus.COMPLETED,
276+
AppWrapperStatus.RUNNING_HOLD_COMPLETION,
277+
]:
278+
ready = False
279+
status = CodeFlareClusterStatus.STARTING
280+
elif appwrapper.status in [
281+
AppWrapperStatus.FAILED,
282+
AppWrapperStatus.DELETED,
283+
]:
284+
ready = False
285+
status = CodeFlareClusterStatus.FAILED # should deleted be separate
286+
return status, ready # exit early, no need to check ray status
287+
elif appwrapper.status in [
288+
AppWrapperStatus.PENDING,
289+
AppWrapperStatus.QUEUEING,
290+
]:
291+
ready = False
292+
if appwrapper.status == AppWrapperStatus.PENDING:
293+
status = CodeFlareClusterStatus.QUEUED
294+
else:
295+
status = CodeFlareClusterStatus.QUEUEING
296+
if print_to_console:
297+
pretty_print.print_app_wrappers_status([appwrapper])
298+
return (
299+
status,
300+
ready,
301+
) # no need to check the ray status since still in queue
287302

288303
# check the ray cluster status
289304
cluster = _ray_cluster_status(self.config.name, self.config.namespace)
290-
if cluster and not cluster.status == RayClusterStatus.UNKNOWN:
305+
if cluster:
306+
if cluster.status == RayClusterStatus.UNKNOWN:
307+
ready = False
308+
status = CodeFlareClusterStatus.STARTING
291309
if cluster.status == RayClusterStatus.READY:
292310
ready = True
293311
status = CodeFlareClusterStatus.READY
@@ -407,19 +425,19 @@ def list_jobs(self) -> List:
407425
"""
408426
This method accesses the head ray node in your cluster and lists the running jobs.
409427
"""
410-
return self.client.list_jobs()
428+
return self.job_client.list_jobs()
411429

412430
def job_status(self, job_id: str) -> str:
413431
"""
414432
This method accesses the head ray node in your cluster and returns the job status for the provided job id.
415433
"""
416-
return self.client.get_job_status(job_id)
434+
return self.job_client.get_job_status(job_id)
417435

418436
def job_logs(self, job_id: str) -> str:
419437
"""
420438
This method accesses the head ray node in your cluster and returns the logs for the provided job id.
421439
"""
422-
return self.client.get_job_logs(job_id)
440+
return self.job_client.get_job_logs(job_id)
423441

424442
def torchx_config(
425443
self, working_dir: str = None, requirements: str = None
@@ -435,7 +453,7 @@ def torchx_config(
435453
to_return["requirements"] = requirements
436454
return to_return
437455

438-
def from_k8_cluster_object(rc):
456+
def from_k8_cluster_object(rc, mcad=True):
439457
machine_types = (
440458
rc["metadata"]["labels"]["orderedinstance"].split("_")
441459
if "orderedinstance" in rc["metadata"]["labels"]
@@ -474,6 +492,7 @@ def from_k8_cluster_object(rc):
474492
0
475493
]["image"],
476494
local_interactive=local_interactive,
495+
mcad=mcad,
477496
)
478497
return Cluster(cluster_config)
479498

@@ -484,6 +503,66 @@ def local_client_url(self):
484503
else:
485504
return "None"
486505

506+
def _component_resources_up(
507+
self, namespace: str, api_instance: client.CustomObjectsApi
508+
):
509+
with open(self.app_wrapper_yaml) as f:
510+
yamls = yaml.load_all(f, Loader=yaml.FullLoader)
511+
for resource in yamls:
512+
if resource["kind"] == "RayCluster":
513+
api_instance.create_namespaced_custom_object(
514+
group="ray.io",
515+
version="v1alpha1",
516+
namespace=namespace,
517+
plural="rayclusters",
518+
body=resource,
519+
)
520+
elif resource["kind"] == "Route":
521+
api_instance.create_namespaced_custom_object(
522+
group="route.openshift.io",
523+
version="v1",
524+
namespace=namespace,
525+
plural="routes",
526+
body=resource,
527+
)
528+
elif resource["kind"] == "Secret":
529+
secret_instance = client.CoreV1Api(api_config_handler())
530+
secret_instance.create_namespaced_secret(
531+
namespace=namespace,
532+
body=resource,
533+
)
534+
535+
def _component_resources_down(
536+
self, namespace: str, api_instance: client.CustomObjectsApi
537+
):
538+
with open(self.app_wrapper_yaml) as f:
539+
yamls = yaml.load_all(f, Loader=yaml.FullLoader)
540+
for resource in yamls:
541+
if resource["kind"] == "RayCluster":
542+
api_instance.delete_namespaced_custom_object(
543+
group="ray.io",
544+
version="v1alpha1",
545+
namespace=namespace,
546+
plural="rayclusters",
547+
name=self.app_wrapper_name,
548+
)
549+
elif resource["kind"] == "Route":
550+
name = resource["metadata"]["name"]
551+
api_instance.delete_namespaced_custom_object(
552+
group="route.openshift.io",
553+
version="v1",
554+
namespace=namespace,
555+
plural="routes",
556+
name=name,
557+
)
558+
elif resource["kind"] == "Secret":
559+
name = resource["metadata"]["name"]
560+
secret_instance = client.CoreV1Api(api_config_handler())
561+
secret_instance.delete_namespaced_secret(
562+
namespace=namespace,
563+
name=name,
564+
)
565+
487566

488567
def list_all_clusters(namespace: str, print_to_console: bool = True):
489568
"""
@@ -549,13 +628,33 @@ def get_cluster(cluster_name: str, namespace: str = "default"):
549628

550629
for rc in rcs["items"]:
551630
if rc["metadata"]["name"] == cluster_name:
552-
return Cluster.from_k8_cluster_object(rc)
631+
mcad = _check_aw_exists(cluster_name, namespace)
632+
return Cluster.from_k8_cluster_object(rc, mcad=mcad)
553633
raise FileNotFoundError(
554634
f"Cluster {cluster_name} is not found in {namespace} namespace"
555635
)
556636

557637

558638
# private methods
639+
def _check_aw_exists(name: str, namespace: str) -> bool:
640+
try:
641+
config_check()
642+
api_instance = client.CustomObjectsApi(api_config_handler())
643+
aws = api_instance.list_namespaced_custom_object(
644+
group="workload.codeflare.dev",
645+
version="v1beta1",
646+
namespace=namespace,
647+
plural="appwrappers",
648+
)
649+
except Exception as e: # pragma: no cover
650+
return _kube_api_error_handling(e, print_error=False)
651+
652+
for aw in aws["items"]:
653+
if aw["metadata"]["name"] == name:
654+
return True
655+
return False
656+
657+
559658
def _get_ingress_domain():
560659
try:
561660
config_check()
@@ -660,6 +759,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
660759

661760
config_check()
662761
api_instance = client.CustomObjectsApi(api_config_handler())
762+
# UPDATE THIS
663763
routes = api_instance.list_namespaced_custom_object(
664764
group="route.openshift.io",
665765
version="v1",

Diff for: src/codeflare_sdk/cluster/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class ClusterConfiguration:
4646
num_gpus: int = 0
4747
template: str = f"{dir}/templates/base-template.yaml"
4848
instascale: bool = False
49+
mcad: bool = True
4950
envs: dict = field(default_factory=dict)
5051
image: str = "quay.io/project-codeflare/ray:latest-py39-cu118"
5152
local_interactive: bool = False

Diff for: src/codeflare_sdk/job/jobs.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@
2222
from torchx.schedulers.ray_scheduler import RayScheduler
2323
from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo
2424

25-
from ray.job_submission import JobSubmissionClient
26-
27-
import openshift as oc
2825

2926
if TYPE_CHECKING:
3027
from ..cluster.cluster import Cluster
@@ -96,9 +93,9 @@ def __init__(
9693

9794
def _dry_run(self, cluster: "Cluster"):
9895
j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}" # # of proc. = # of gpus
99-
runner = get_runner(ray_client=cluster.client)
96+
runner = get_runner(ray_client=cluster.job_client)
10097
runner._scheduler_instances["ray"] = RayScheduler(
101-
session_name=runner._name, ray_client=cluster.client
98+
session_name=runner._name, ray_client=cluster.job_client
10299
)
103100
return (
104101
runner.dryrun(

0 commit comments

Comments
 (0)