Skip to content

Commit 6fdfe4c

Browse files
committed
update unittest
1 parent 293993e commit 6fdfe4c

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

tests/ops/mapper/test_video_extract_frames_mapper.py

+63-1
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,18 @@ class VideoExtractFramesMapperTest(DataJuicerTestCaseBase):
2424

2525
def tearDown(self):
    """Delete the directories this test case may have left on disk.

    Two locations are cleaned up after the base-class teardown runs:

    * ``self.tmp_dir`` -- removed only if it exists.
    * the directory that contains the path returned by
      ``self._get_default_frame_dir_prefix()`` -- removed only if that
      path itself exists.
    """
    super().tearDown()

    # Guarded removal: an unconditional rmtree would raise
    # FileNotFoundError when the directory is absent.
    if osp.exists(self.tmp_dir):
        shutil.rmtree(self.tmp_dir)

    default_frame_dir_prefix = self._get_default_frame_dir_prefix()
    if osp.exists(default_frame_dir_prefix):
        # Note the dirname(): the *parent* of the prefix directory is
        # deleted, not only the prefix directory itself.
        shutil.rmtree(osp.dirname(default_frame_dir_prefix))
33+
34+
def _get_default_frame_dir_prefix(self):
    """Return the absolute path of the op's directory under ``self.data_path``.

    The path is ``<self.data_path>/<Fields.multimodal_data_output_dir>/<OP_NAME>``,
    passed through ``osp.abspath``.
    """
    # Imported here rather than at module level, as in the original code.
    from data_juicer.ops.mapper.video_extract_frames_mapper import OP_NAME

    relative_dir = f'{Fields.multimodal_data_output_dir}/{OP_NAME}/'
    return osp.abspath(osp.join(self.data_path, relative_dir))
2839

2940
def _get_frames_list(self, filepath, frame_dir, frame_num):
3041
frames_dir = osp.join(frame_dir, osp.splitext(osp.basename(filepath))[0])
@@ -175,6 +186,57 @@ def test_all_keyframes_sampling(self):
175186
self._sort_files(os.listdir(vid3_frame_dir)),
176187
[f'frame_{i}.jpg' for i in range(13)])
177188

189+
def test_default_frame_dir(self):
    """Run the mapper without passing a frame directory and check its output.

    The test verifies that:

    * each directory returned by ``op._get_default_frame_dir`` contains the
      path from ``self._get_default_frame_dir_prefix()``;
    * every processed sample gains a ``Fields.video_frames`` entry holding
      a JSON mapping of video path to that directory;
    * each directory holds the expected ``frame_<i>.jpg`` files.
    """
    video_paths = [self.vid1_path, self.vid2_path, self.vid3_path]
    # Number of frame files expected per video, in the same order.
    expected_frame_counts = [4, 8, 18]

    samples = [{
        'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
        'videos': [video_paths[0]]
    }, {
        'text':
        f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
        'videos': [video_paths[1]]
    }, {
        'text':
        f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
        'videos': [video_paths[2]]
    }]

    frame_num = 2
    op = VideoExtractFramesMapper(
        frame_sampling_method='uniform',
        frame_num=frame_num,
        duration=5,
    )

    frame_dirs = [op._get_default_frame_dir(path) for path in video_paths]

    # Expected output: the input samples plus the serialized
    # {video path: frame directory} mapping.
    expected = copy.deepcopy(samples)
    for sample, path, frame_dir in zip(expected, video_paths, frame_dirs):
        sample.update({Fields.video_frames: json.dumps({path: frame_dir})})

    dataset = Dataset.from_list(samples)
    dataset = dataset.map(op.process, batch_size=2, num_proc=1)
    results = dataset.to_list()

    prefix = self._get_default_frame_dir_prefix()
    for frame_dir in frame_dirs:
        self.assertIn(prefix, osp.abspath(frame_dir))

    self.assertEqual(results, expected)

    for frame_dir, count in zip(frame_dirs, expected_frame_counts):
        self.assertListEqual(
            self._sort_files(os.listdir(frame_dir)),
            [f'frame_{i}.jpg' for i in range(count)])
239+
178240

179241
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)