@@ -24,7 +24,18 @@ class VideoExtractFramesMapperTest(DataJuicerTestCaseBase):
24
24
25
25
def tearDown (self ):
26
26
super ().tearDown ()
27
- shutil .rmtree (self .tmp_dir )
27
+ if osp .exists (self .tmp_dir ):
28
+ shutil .rmtree (self .tmp_dir )
29
+
30
+ default_frame_dir_prefix = self ._get_default_frame_dir_prefix ()
31
+ if osp .exists (default_frame_dir_prefix ):
32
+ shutil .rmtree (osp .dirname (default_frame_dir_prefix ))
33
+
34
+ def _get_default_frame_dir_prefix (self ):
35
+ from data_juicer .ops .mapper .video_extract_frames_mapper import OP_NAME
36
+ default_frame_dir_prefix = osp .abspath (osp .join (self .data_path ,
37
+ f'{ Fields .multimodal_data_output_dir } /{ OP_NAME } /' ))
38
+ return default_frame_dir_prefix
28
39
29
40
def _get_frames_list (self , filepath , frame_dir , frame_num ):
30
41
frames_dir = osp .join (frame_dir , osp .splitext (osp .basename (filepath ))[0 ])
@@ -175,6 +186,57 @@ def test_all_keyframes_sampling(self):
175
186
self ._sort_files (os .listdir (vid3_frame_dir )),
176
187
[f'frame_{ i } .jpg' for i in range (13 )])
177
188
189
+ def test_default_frame_dir (self ):
190
+ ds_list = [{
191
+ 'text' : f'{ SpecialTokens .video } 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' ,
192
+ 'videos' : [self .vid1_path ]
193
+ }, {
194
+ 'text' :
195
+ f'{ SpecialTokens .video } 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{ SpecialTokens .eoc } ' ,
196
+ 'videos' : [self .vid2_path ]
197
+ }, {
198
+ 'text' :
199
+ f'{ SpecialTokens .video } 两个长头发的女子正坐在一张圆桌前讲话互动。 { SpecialTokens .eoc } ' ,
200
+ 'videos' : [self .vid3_path ]
201
+ }]
202
+
203
+ frame_num = 2
204
+ op = VideoExtractFramesMapper (
205
+ frame_sampling_method = 'uniform' ,
206
+ frame_num = frame_num ,
207
+ duration = 5 ,
208
+ )
209
+
210
+ vid1_frame_dir = op ._get_default_frame_dir (self .vid1_path )
211
+ vid2_frame_dir = op ._get_default_frame_dir (self .vid2_path )
212
+ vid3_frame_dir = op ._get_default_frame_dir (self .vid3_path )
213
+
214
+ tgt_list = copy .deepcopy (ds_list )
215
+ tgt_list [0 ].update ({Fields .video_frames : json .dumps ({self .vid1_path : vid1_frame_dir })})
216
+ tgt_list [1 ].update ({Fields .video_frames : json .dumps ({self .vid2_path : vid2_frame_dir })})
217
+ tgt_list [2 ].update ({Fields .video_frames : json .dumps ({self .vid3_path : vid3_frame_dir })})
218
+
219
+ dataset = Dataset .from_list (ds_list )
220
+ dataset = dataset .map (op .process , batch_size = 2 , num_proc = 1 )
221
+ res_list = dataset .to_list ()
222
+
223
+ frame_dir_prefix = self ._get_default_frame_dir_prefix ()
224
+ self .assertIn (frame_dir_prefix , osp .abspath (vid1_frame_dir ))
225
+ self .assertIn (frame_dir_prefix , osp .abspath (vid2_frame_dir ))
226
+ self .assertIn (frame_dir_prefix , osp .abspath (vid3_frame_dir ))
227
+
228
+ self .assertEqual (res_list , tgt_list )
229
+
230
+ self .assertListEqual (
231
+ self ._sort_files (os .listdir (vid1_frame_dir )),
232
+ [f'frame_{ i } .jpg' for i in range (4 )])
233
+ self .assertListEqual (
234
+ self ._sort_files (os .listdir (vid2_frame_dir )),
235
+ [f'frame_{ i } .jpg' for i in range (8 )])
236
+ self .assertListEqual (
237
+ self ._sort_files (os .listdir (vid3_frame_dir )),
238
+ [f'frame_{ i } .jpg' for i in range (18 )])
239
+
178
240
179
241
if __name__ == '__main__' :
180
242
unittest .main ()
0 commit comments