@@ -66,5 +66,102 @@ def test_status_code_1(self):
66
66
self .assertFalse (osp .exists (tmp_out_path ))
67
67
68
68
69
+ class ProcessDataRayTest (DataJuicerTestCaseBase ):
70
+
71
+ def setUp (self ):
72
+ super ().setUp ()
73
+
74
+ self ._auto_create_ray_cluster ()
75
+
76
+ self .tmp_dir = tempfile .TemporaryDirectory ().name
77
+ if not osp .exists (self .tmp_dir ):
78
+ os .makedirs (self .tmp_dir )
79
+
80
+ def _auto_create_ray_cluster (self ):
81
+ if not subprocess .call ('ray status' , shell = True ):
82
+ # ray cluster already exists, return
83
+ self .tmp_ray_cluster = False
84
+ return
85
+
86
+ self .tmp_ray_cluster = True
87
+ head_port = '6379'
88
+ head_addr = '127.0.0.1'
89
+ rank = int (os .environ .get ('RANK' , 0 ))
90
+
91
+ if rank == 0 :
92
+ cmd = f"ray start --head --port={ head_port } --node-ip-address={ head_addr } "
93
+ else :
94
+ cmd = f"ray start --address={ head_addr } :{ head_port } "
95
+
96
+ print (f"current rank: { rank } ; execute cmd: { cmd } " )
97
+
98
+ result = subprocess .call (cmd , shell = True )
99
+ if result != 0 :
100
+ raise subprocess .CalledProcessError (result , cmd )
101
+
102
+ def _close_ray_cluster (self ):
103
+ subprocess .call ('ray stop' , shell = True )
104
+
105
+ def tearDown (self ):
106
+ super ().tearDown ()
107
+
108
+ if osp .exists (self .tmp_dir ):
109
+ shutil .rmtree (self .tmp_dir )
110
+
111
+ import ray
112
+ ray .shutdown ()
113
+
114
+ if self .tmp_ray_cluster :
115
+ self ._close_ray_cluster ()
116
+
117
+ def test_ray_image (self ):
118
+ tmp_yaml_file = osp .join (self .tmp_dir , 'config_0.yaml' )
119
+ tmp_out_path = osp .join (self .tmp_dir , 'output_0.json' )
120
+ text_keys = 'text'
121
+
122
+ data_path = osp .join (osp .dirname (osp .dirname (osp .dirname (osp .realpath (__file__ )))),
123
+ 'demos' , 'data' , 'demo-dataset-images.jsonl' )
124
+ yaml_config = {
125
+ 'dataset_path' : data_path ,
126
+ 'executor_type' : 'ray' ,
127
+ 'ray_address' : 'auto' ,
128
+ 'text_keys' : text_keys ,
129
+ 'image_key' : 'images' ,
130
+ 'export_path' : tmp_out_path ,
131
+ 'process' : [
132
+ {
133
+ 'image_nsfw_filter' : {
134
+ 'hf_nsfw_model' : 'Falconsai/nsfw_image_detection' ,
135
+ 'trust_remote_code' : True ,
136
+ 'score_threshold' : 0.5 ,
137
+ 'any_or_all' : 'any' ,
138
+ 'mem_required' : '8GB'
139
+ },
140
+ 'image_aspect_ratio_filter' :{
141
+ 'min_ratio' : 0.5 ,
142
+ 'max_ratio' : 2.0
143
+ }
144
+ }
145
+ ]
146
+ }
147
+
148
+ with open (tmp_yaml_file , 'w' ) as file :
149
+ yaml .dump (yaml_config , file )
150
+
151
+ status_code = subprocess .call (
152
+ f'python tools/process_data.py --config { tmp_yaml_file } ' , shell = True )
153
+
154
+ self .assertEqual (status_code , 0 )
155
+ self .assertTrue (osp .exists (tmp_out_path ))
156
+
157
+ import ray
158
+ res_ds = ray .data .read_json (tmp_out_path )
159
+ res_ds = res_ds .to_pandas ().to_dict (orient = 'records' )
160
+
161
+ self .assertEqual (len (res_ds ), 3 )
162
+ for item in res_ds :
163
+ self .assertIn ('aspect_ratios' , item ['__dj__stats__' ])
164
+
165
+
69
166
if __name__ == '__main__' :
70
167
unittest .main ()
0 commit comments