Skip to content

Commit 5d75224

Browse files
Review changes & added more unit tests
1 parent 4792573 commit 5d75224

File tree

2 files changed

+87
-4
lines changed

2 files changed

+87
-4
lines changed

src/codeflare_sdk/job/ray_jobs.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,14 @@ def delete_job(self, job_id: str) -> bool:
7272
"""
7373
Method for deleting jobs with the job id being a mandatory field.
7474
"""
75-
return self.rayJobClient.delete_job(job_id=job_id)
75+
deletion_status = self.rayJobClient.delete_job(job_id=job_id)
76+
77+
if deletion_status:
78+
print(f"Successfully deleted Job {job_id}")
79+
return deletion_status
80+
else:
81+
print(f"Failed to delete Job {job_id}")
82+
return deletion_status
7683

7784
def get_address(self) -> str:
7885
"""
@@ -108,7 +115,12 @@ def stop_job(self, job_id: str) -> bool:
108115
"""
109116
Method for stopping a job with the job id being a mandatory field.
110117
"""
111-
return self.rayJobClient.stop_job(job_id=job_id)
118+
stop_job_status = self.rayJobClient.stop_job(job_id=job_id)
119+
if stop_job_status:
120+
print(f"Successfully stopped Job {job_id}")
121+
else:
122+
print(f"Failed to stop Job, {job_id} could have already completed.")
123+
return stop_job_status
112124

113125
def tail_job_logs(self, job_id: str) -> Iterator[str]:
114126
"""

tests/unit_test.py

+73-2
Original file line numberDiff line numberDiff line change
@@ -2883,14 +2883,44 @@ def test_rjc_submit_job(ray_job_client, mocker):
28832883

28842884

28852885
def test_rjc_delete_job(ray_job_client, mocker):
2886-
mocked_delete_job = mocker.patch.object(
2886+
# Case return True
2887+
mocked_delete_job_True = mocker.patch.object(
28872888
JobSubmissionClient, "delete_job", return_value=True
28882889
)
28892890
result = ray_job_client.delete_job(job_id="mocked_job_id")
28902891

2891-
mocked_delete_job.assert_called_once_with(job_id="mocked_job_id")
2892+
mocked_delete_job_True.assert_called_once_with(job_id="mocked_job_id")
28922893
assert result is True
28932894

2895+
# Case return False
2896+
mocked_delete_job_False = mocker.patch.object(
2897+
JobSubmissionClient, "delete_job", return_value=False
2898+
)
2899+
result = ray_job_client.delete_job(job_id="mocked_job_id")
2900+
2901+
mocked_delete_job_False.assert_called_once_with(job_id="mocked_job_id")
2902+
assert result is False
2903+
2904+
2905+
def test_rjc_stop_job(ray_job_client, mocker):
2906+
# Case return True
2907+
mocked_stop_job_True = mocker.patch.object(
2908+
JobSubmissionClient, "stop_job", return_value=True
2909+
)
2910+
result = ray_job_client.stop_job(job_id="mocked_job_id")
2911+
2912+
mocked_stop_job_True.assert_called_once_with(job_id="mocked_job_id")
2913+
assert result is True
2914+
2915+
# Case return False
2916+
mocked_stop_job_False = mocker.patch.object(
2917+
JobSubmissionClient, "stop_job", return_value=False
2918+
)
2919+
result = ray_job_client.stop_job(job_id="mocked_job_id")
2920+
2921+
mocked_stop_job_False.assert_called_once_with(job_id="mocked_job_id")
2922+
assert result is False
2923+
28942924

28952925
def test_rjc_address(ray_job_client, mocker):
28962926
mocked_rjc_address = mocker.patch.object(
@@ -2928,6 +2958,47 @@ def test_rjc_get_job_info(ray_job_client, mocker):
29282958
assert job_details == job_details_example
29292959

29302960

2961+
def test_rjc_get_job_status(ray_job_client, mocker):
2962+
job_status_example = "<JobStatus.PENDING: 'PENDING'>"
2963+
mocked_rjc_get_job_status = mocker.patch.object(
2964+
JobSubmissionClient, "get_job_status", return_value=job_status_example
2965+
)
2966+
job_status = ray_job_client.get_job_status(job_id="mocked_job_id")
2967+
2968+
mocked_rjc_get_job_status.assert_called_once_with(job_id="mocked_job_id")
2969+
assert job_status == job_status_example
2970+
2971+
2972+
def test_rjc_tail_job_logs(ray_job_client, mocker):
2973+
logs_example = [
2974+
"Job started...",
2975+
"Processing input data...",
2976+
"Finalizing results...",
2977+
"Job completed successfully.",
2978+
]
2979+
mocked_rjc_tail_job_logs = mocker.patch.object(
2980+
JobSubmissionClient, "tail_job_logs", return_value=logs_example
2981+
)
2982+
job_tail_job_logs = ray_job_client.tail_job_logs(job_id="mocked_job_id")
2983+
2984+
mocked_rjc_tail_job_logs.assert_called_once_with(job_id="mocked_job_id")
2985+
assert job_tail_job_logs == logs_example
2986+
2987+
2988+
def test_rjc_list_jobs(ray_job_client, mocker):
2989+
jobs_list = [
2990+
"JobDetails(type=<JobType.SUBMISSION: 'SUBMISSION'>, job_id=None, submission_id='raysubmit_4k2NYS1YbRXYPZCM', driver_info=None, status=<JobStatus.SUCCEEDED: 'SUCCEEDED'>, entrypoint='python mnist.py', message='Job finished successfully.', error_type=None, start_time=1701352132585, end_time=1701352192002, metadata={}, runtime_env={'working_dir': 'gcs://_ray_pkg_6200b93a110e8033.zip', 'pip': {'packages': ['pytorch_lightning==1.5.10', 'ray_lightning', 'torchmetrics==0.9.1', 'torchvision==0.12.0'], 'pip_check': False}, '_ray_commit': 'b4bba4717f5ba04ee25580fe8f88eed63ef0c5dc'}, driver_agent_http_address='http://10.131.0.18:52365', driver_node_id='9fb515995f5fb13ad4db239ceea378333bebf0a2d45b6aa09d02e691')",
2991+
"JobDetails(type=<JobType.SUBMISSION: 'SUBMISSION'>, job_id=None, submission_id='raysubmit_iRuwU8vdkbUZZGvT', driver_info=None, status=<JobStatus.STOPPED: 'STOPPED'>, entrypoint='python mnist.py', message='Job was intentionally stopped.', error_type=None, start_time=1701353096163, end_time=1701353097733, metadata={}, runtime_env={'working_dir': 'gcs://_ray_pkg_6200b93a110e8033.zip', 'pip': {'packages': ['pytorch_lightning==1.5.10', 'ray_lightning', 'torchmetrics==0.9.1', 'torchvision==0.12.0'], 'pip_check': False}, '_ray_commit': 'b4bba4717f5ba04ee25580fe8f88eed63ef0c5dc'}, driver_agent_http_address='http://10.131.0.18:52365', driver_node_id='9fb515995f5fb13ad4db239ceea378333bebf0a2d45b6aa09d02e691')",
2992+
]
2993+
mocked_rjc_list_jobs = mocker.patch.object(
2994+
JobSubmissionClient, "list_jobs", return_value=jobs_list
2995+
)
2996+
job_list_jobs = ray_job_client.list_jobs()
2997+
2998+
mocked_rjc_list_jobs.assert_called_once()
2999+
assert job_list_jobs == jobs_list
3000+
3001+
29313002
# Make sure to always keep this function last
29323003
def test_cleanup():
29333004
os.remove(f"{aw_dir}unit-test-cluster.yaml")

0 commit comments

Comments
 (0)