Skip to content

Commit 73c7316

Browse files
committed
update unittest
1 parent 0a2ddd6 commit 73c7316

File tree

1 file changed

+5
-35
lines changed

1 file changed

+5
-35
lines changed

tests/tools/test_process_data.py

+5-35
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import subprocess
55
import tempfile
66
import unittest
7+
import uuid
78
import yaml
89

910
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -45,8 +46,7 @@ def setUp(self):
4546
super().setUp()
4647

4748
self.tmp_dir = tempfile.TemporaryDirectory().name
48-
if not osp.exists(self.tmp_dir):
49-
os.makedirs(self.tmp_dir)
49+
os.makedirs(self.tmp_dir, exist_ok=True)
5050

5151
def tearDown(self):
5252
super().tearDown()
@@ -101,36 +101,9 @@ class ProcessDataRayTest(DataJuicerTestCaseBase):
101101
def setUp(self):
102102
super().setUp()
103103

104-
# self._auto_create_ray_cluster()
105-
self.tmp_dir = f'/workspace/tmp/{self.__class__.__name__}'
106-
if not osp.exists(self.tmp_dir):
107-
os.makedirs(self.tmp_dir)
108-
109-
def _auto_create_ray_cluster(self):
110-
try:
111-
# ray cluster already exists, return
112-
run_in_subprocess('ray status')
113-
self.tmp_ray_cluster = False
114-
return
115-
except:
116-
pass
117-
118-
self.tmp_ray_cluster = True
119-
head_port = '6379'
120-
head_addr = '127.0.0.1'
121-
rank = int(os.environ.get('RANK', 0))
122-
123-
if rank == 0:
124-
cmd = f"ray start --head --port={head_port} --node-ip-address={head_addr}"
125-
else:
126-
cmd = f"ray start --address={head_addr}:{head_port}"
127-
128-
print(f"current rank: {rank}; execute cmd: {cmd}")
129-
130-
run_in_subprocess(cmd)
131-
132-
def _close_ray_cluster(self):
133-
run_in_subprocess('ray stop')
104+
cur_dir = osp.dirname(osp.abspath(__file__))
105+
self.tmp_dir = osp.join(cur_dir, f'tmp_{uuid.uuid4().hex}')
106+
os.makedirs(self.tmp_dir, exist_ok=True)
134107

135108
def tearDown(self):
136109
super().tearDown()
@@ -141,9 +114,6 @@ def tearDown(self):
141114
import ray
142115
ray.shutdown()
143116

144-
# if self.tmp_ray_cluster:
145-
# self._close_ray_cluster()
146-
147117
def test_ray_image(self):
148118
tmp_yaml_file = osp.join(self.tmp_dir, 'config_0.yaml')
149119
tmp_out_path = osp.join(self.tmp_dir, 'output_0.json')

0 commit comments

Comments
 (0)