Skip to content

Commit 949d534

Browse files
committed
optimize unittest
1 parent 4f65b09 commit 949d534

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

tests/tools/test_process_data.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,23 @@
99
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
1010

1111

12+
def run_in_subprocess(cmd):
13+
result = subprocess.run(
14+
cmd,
15+
shell=True,
16+
capture_output=True,
17+
text=True
18+
)
19+
20+
if result.returncode != 0:
21+
print(f"Command failed with return code {result.returncode}")
22+
print(f"Standard Output: {result.stdout}")
23+
print(f"Standard Error: {result.stderr}")
24+
raise subprocess.CalledProcessError(result, cmd)
25+
26+
return result
27+
28+
1229
class ProcessDataTest(DataJuicerTestCaseBase):
1330

1431
def setUp(self):
@@ -78,10 +95,13 @@ def setUp(self):
7895
os.makedirs(self.tmp_dir)
7996

8097
def _auto_create_ray_cluster(self):
81-
if not subprocess.call('ray status', shell=True):
98+
try:
8299
# ray cluster already exists, return
100+
run_in_subprocess('ray status')
83101
self.tmp_ray_cluster = False
84102
return
103+
except:
104+
pass
85105

86106
self.tmp_ray_cluster = True
87107
head_port = '6379'
@@ -95,12 +115,10 @@ def _auto_create_ray_cluster(self):
95115

96116
print(f"current rank: {rank}; execute cmd: {cmd}")
97117

98-
result = subprocess.call(cmd, shell=True)
99-
if result != 0:
100-
raise subprocess.CalledProcessError(result, cmd)
118+
run_in_subprocess(cmd)
101119

102120
def _close_ray_cluster(self):
103-
subprocess.call('ray stop', shell=True)
121+
run_in_subprocess('ray stop')
104122

105123
def tearDown(self):
106124
super().tearDown()
@@ -148,10 +166,8 @@ def test_ray_image(self):
148166
with open(tmp_yaml_file, 'w') as file:
149167
yaml.dump(yaml_config, file)
150168

151-
status_code = subprocess.call(
152-
f'python tools/process_data.py --config {tmp_yaml_file}', shell=True)
169+
run_in_subprocess(f'python tools/process_data.py --config {tmp_yaml_file}')
153170

154-
self.assertEqual(status_code, 0)
155171
self.assertTrue(osp.exists(tmp_out_path))
156172

157173
import ray

0 commit comments

Comments
 (0)