Skip to content

Commit 4f65b09

Browse files
committed
add ray unittest
1 parent bba1f38 commit 4f65b09

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

data_juicer/config/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def init_setup_from_cfg(cfg: Namespace):
429429
cfg.np = sys_cpu_count
430430
logger.warning(
431431
f'Number of processes `np` is not set, '
432-
f'Set it to cpu count [{sys_cpu_count}] as default value.')
432+
f'set it to cpu count [{sys_cpu_count}] as default value.')
433433
if cfg.np > sys_cpu_count:
434434
logger.warning(f'Number of processes `np` is set as [{cfg.np}], which '
435435
f'is larger than the cpu count [{sys_cpu_count}]. Due '

tests/tools/test_process_data.py

+97
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,102 @@ def test_status_code_1(self):
6666
self.assertFalse(osp.exists(tmp_out_path))
6767

6868

69+
class ProcessDataRayTest(DataJuicerTestCaseBase):
70+
71+
def setUp(self):
72+
super().setUp()
73+
74+
self._auto_create_ray_cluster()
75+
76+
self.tmp_dir = tempfile.TemporaryDirectory().name
77+
if not osp.exists(self.tmp_dir):
78+
os.makedirs(self.tmp_dir)
79+
80+
def _auto_create_ray_cluster(self):
81+
if not subprocess.call('ray status', shell=True):
82+
# ray cluster already exists, return
83+
self.tmp_ray_cluster = False
84+
return
85+
86+
self.tmp_ray_cluster = True
87+
head_port = '6379'
88+
head_addr = '127.0.0.1'
89+
rank = int(os.environ.get('RANK', 0))
90+
91+
if rank == 0:
92+
cmd = f"ray start --head --port={head_port} --node-ip-address={head_addr}"
93+
else:
94+
cmd = f"ray start --address={head_addr}:{head_port}"
95+
96+
print(f"current rank: {rank}; execute cmd: {cmd}")
97+
98+
result = subprocess.call(cmd, shell=True)
99+
if result != 0:
100+
raise subprocess.CalledProcessError(result, cmd)
101+
102+
def _close_ray_cluster(self):
103+
subprocess.call('ray stop', shell=True)
104+
105+
def tearDown(self):
106+
super().tearDown()
107+
108+
if osp.exists(self.tmp_dir):
109+
shutil.rmtree(self.tmp_dir)
110+
111+
import ray
112+
ray.shutdown()
113+
114+
if self.tmp_ray_cluster:
115+
self._close_ray_cluster()
116+
117+
def test_ray_image(self):
118+
tmp_yaml_file = osp.join(self.tmp_dir, 'config_0.yaml')
119+
tmp_out_path = osp.join(self.tmp_dir, 'output_0.json')
120+
text_keys = 'text'
121+
122+
data_path = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))),
123+
'demos', 'data', 'demo-dataset-images.jsonl')
124+
yaml_config = {
125+
'dataset_path': data_path,
126+
'executor_type': 'ray',
127+
'ray_address': 'auto',
128+
'text_keys': text_keys,
129+
'image_key': 'images',
130+
'export_path': tmp_out_path,
131+
'process': [
132+
{
133+
'image_nsfw_filter': {
134+
'hf_nsfw_model': 'Falconsai/nsfw_image_detection',
135+
'trust_remote_code': True,
136+
'score_threshold': 0.5,
137+
'any_or_all': 'any',
138+
'mem_required': '8GB'
139+
},
140+
'image_aspect_ratio_filter':{
141+
'min_ratio': 0.5,
142+
'max_ratio': 2.0
143+
}
144+
}
145+
]
146+
}
147+
148+
with open(tmp_yaml_file, 'w') as file:
149+
yaml.dump(yaml_config, file)
150+
151+
status_code = subprocess.call(
152+
f'python tools/process_data.py --config {tmp_yaml_file}', shell=True)
153+
154+
self.assertEqual(status_code, 0)
155+
self.assertTrue(osp.exists(tmp_out_path))
156+
157+
import ray
158+
res_ds = ray.data.read_json(tmp_out_path)
159+
res_ds = res_ds.to_pandas().to_dict(orient='records')
160+
161+
self.assertEqual(len(res_ds), 3)
162+
for item in res_ds:
163+
self.assertIn('aspect_ratios', item['__dj__stats__'])
164+
165+
69166
if __name__ == '__main__':
70167
unittest.main()

0 commit comments

Comments
 (0)