1
1
import os
2
2
import os .path as osp
3
+ import re
3
4
import copy
4
5
import unittest
5
6
import json
@@ -25,21 +26,63 @@ def tearDown(self):
25
26
super ().tearDown ()
26
27
shutil .rmtree (self .tmp_dir )
27
28
28
- def _run_video_extract_frames_mapper (self ,
29
- op ,
30
- source_list ,
31
- target_list ,
32
- num_proc = 1 ):
33
- dataset = Dataset .from_list (source_list )
34
- dataset = dataset .map (op .process , batch_size = 2 , num_proc = num_proc )
35
- res_list = dataset .to_list ()
36
- self .assertEqual (res_list , target_list )
37
-
38
29
def _get_frames_list (self , filepath , frame_dir , frame_num ):
39
30
frames_dir = osp .join (frame_dir , osp .splitext (osp .basename (filepath ))[0 ])
40
31
frames_list = [osp .join (frames_dir , f'frame_{ i } .jpg' ) for i in range (frame_num )]
41
32
return frames_list
42
33
34
+ def _get_frames_dir (self , filepath , frame_dir ):
35
+ frames_dir = osp .join (frame_dir , osp .splitext (osp .basename (filepath ))[0 ])
36
+ return frames_dir
37
+
38
+ def _sort_files (self , file_list ):
39
+ return sorted (file_list , key = lambda x : int (re .search (r'(\d+)' , x ).group ()))
40
+
41
+ def test_duration (self ):
42
+ ds_list = [{
43
+ 'text' : f'{ SpecialTokens .video } 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' ,
44
+ 'videos' : [self .vid1_path ]
45
+ }, {
46
+ 'text' :
47
+ f'{ SpecialTokens .video } 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{ SpecialTokens .eoc } ' ,
48
+ 'videos' : [self .vid2_path ]
49
+ }, {
50
+ 'text' :
51
+ f'{ SpecialTokens .video } 两个长头发的女子正坐在一张圆桌前讲话互动。 { SpecialTokens .eoc } ' ,
52
+ 'videos' : [self .vid3_path ]
53
+ }]
54
+
55
+ frame_num = 2
56
+ frame_dir = os .path .join (self .tmp_dir , 'test1' )
57
+ vid1_frame_dir = self ._get_frames_dir (self .vid1_path , frame_dir )
58
+ vid2_frame_dir = self ._get_frames_dir (self .vid2_path , frame_dir )
59
+ vid3_frame_dir = self ._get_frames_dir (self .vid3_path , frame_dir )
60
+
61
+ tgt_list = copy .deepcopy (ds_list )
62
+ tgt_list [0 ].update ({Fields .video_frames : json .dumps ({self .vid1_path : vid1_frame_dir })})
63
+ tgt_list [1 ].update ({Fields .video_frames : json .dumps ({self .vid2_path : vid2_frame_dir })})
64
+ tgt_list [2 ].update ({Fields .video_frames : json .dumps ({self .vid3_path : vid3_frame_dir })})
65
+
66
+ op = VideoExtractFramesMapper (
67
+ frame_sampling_method = 'uniform' ,
68
+ frame_num = frame_num ,
69
+ duration = 0 ,
70
+ frame_dir = frame_dir )
71
+
72
+ dataset = Dataset .from_list (ds_list )
73
+ dataset = dataset .map (op .process , batch_size = 2 , num_proc = 1 )
74
+ res_list = dataset .to_list ()
75
+ self .assertEqual (res_list , tgt_list )
76
+ self .assertListEqual (
77
+ self ._sort_files (os .listdir (vid1_frame_dir )),
78
+ [f'frame_{ i } .jpg' for i in range (frame_num )])
79
+ self .assertListEqual (
80
+ self ._sort_files (os .listdir (vid2_frame_dir )),
81
+ [f'frame_{ i } .jpg' for i in range (frame_num )])
82
+ self .assertListEqual (
83
+ self ._sort_files (os .listdir (vid3_frame_dir )),
84
+ [f'frame_{ i } .jpg' for i in range (frame_num )])
85
+
43
86
def test_uniform_sampling (self ):
44
87
ds_list = [{
45
88
'text' : f'{ SpecialTokens .video } 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' ,
@@ -55,22 +98,35 @@ def test_uniform_sampling(self):
55
98
}]
56
99
frame_num = 3
57
100
frame_dir = os .path .join (self .tmp_dir , 'test1' )
101
+ vid1_frame_dir = self ._get_frames_dir (self .vid1_path , frame_dir )
102
+ vid2_frame_dir = self ._get_frames_dir (self .vid2_path , frame_dir )
103
+ vid3_frame_dir = self ._get_frames_dir (self .vid3_path , frame_dir )
58
104
59
105
tgt_list = copy .deepcopy (ds_list )
60
- tgt_list [0 ].update ({Fields .video_frames :
61
- json .dumps ({self .vid1_path : self ._get_frames_list (self .vid1_path , frame_dir , frame_num )})})
62
- tgt_list [1 ].update ({Fields .video_frames :
63
- json .dumps ({self .vid2_path : self ._get_frames_list (self .vid2_path , frame_dir , frame_num )})})
64
- tgt_list [2 ].update ({Fields .video_frames :
65
- json .dumps ({self .vid3_path : self ._get_frames_list (self .vid3_path , frame_dir , frame_num )})})
66
-
106
+ tgt_list [0 ].update ({Fields .video_frames : json .dumps ({self .vid1_path : vid1_frame_dir })})
107
+ tgt_list [1 ].update ({Fields .video_frames : json .dumps ({self .vid2_path : vid2_frame_dir })})
108
+ tgt_list [2 ].update ({Fields .video_frames : json .dumps ({self .vid3_path : vid3_frame_dir })})
109
+
67
110
op = VideoExtractFramesMapper (
68
111
frame_sampling_method = 'uniform' ,
69
112
frame_num = frame_num ,
113
+ duration = 10 ,
70
114
frame_dir = frame_dir )
71
- self ._run_video_extract_frames_mapper (op , ds_list , tgt_list )
72
115
73
-
116
+ dataset = Dataset .from_list (ds_list )
117
+ dataset = dataset .map (op .process , batch_size = 2 , num_proc = 1 )
118
+ res_list = dataset .to_list ()
119
+ self .assertEqual (res_list , tgt_list )
120
+ self .assertListEqual (
121
+ self ._sort_files (os .listdir (vid1_frame_dir )),
122
+ [f'frame_{ i } .jpg' for i in range (3 )])
123
+ self .assertListEqual (
124
+ self ._sort_files (os .listdir (vid2_frame_dir )),
125
+ [f'frame_{ i } .jpg' for i in range (6 )])
126
+ self .assertListEqual (
127
+ self ._sort_files (os .listdir (vid3_frame_dir )),
128
+ [f'frame_{ i } .jpg' for i in range (12 )])
129
+
74
130
def test_all_keyframes_sampling (self ):
75
131
ds_list = [{
76
132
'text' : f'{ SpecialTokens .video } 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' ,
@@ -86,22 +142,38 @@ def test_all_keyframes_sampling(self):
86
142
'videos' : [self .vid3_path ]
87
143
}]
88
144
frame_dir = os .path .join (self .tmp_dir , 'test2' )
145
+ vid1_frame_dir = self ._get_frames_dir (self .vid1_path , frame_dir )
146
+ vid2_frame_dir = self ._get_frames_dir (self .vid2_path , frame_dir )
147
+ vid3_frame_dir = self ._get_frames_dir (self .vid3_path , frame_dir )
89
148
90
149
tgt_list = copy .deepcopy (ds_list )
91
150
tgt_list [0 ].update ({Fields .video_frames :
92
- json .dumps ({self .vid1_path : self . _get_frames_list ( self . vid1_path , frame_dir , 3 ) })})
151
+ json .dumps ({self .vid1_path : vid1_frame_dir })})
93
152
tgt_list [1 ].update ({Fields .video_frames : json .dumps ({
94
- self .vid2_path : self . _get_frames_list ( self . vid2_path , frame_dir , 3 ) ,
95
- self .vid3_path : self . _get_frames_list ( self . vid3_path , frame_dir , 6 )
153
+ self .vid2_path : vid2_frame_dir ,
154
+ self .vid3_path : vid3_frame_dir
96
155
})})
97
156
tgt_list [2 ].update ({Fields .video_frames :
98
- json .dumps ({self .vid3_path : self . _get_frames_list ( self . vid3_path , frame_dir , 6 ) })})
157
+ json .dumps ({self .vid3_path : vid3_frame_dir })})
99
158
100
159
op = VideoExtractFramesMapper (
101
160
frame_sampling_method = 'all_keyframes' ,
102
- frame_dir = frame_dir )
103
- self . _run_video_extract_frames_mapper ( op , ds_list , tgt_list )
161
+ frame_dir = frame_dir ,
162
+ duration = 5 )
104
163
164
+ dataset = Dataset .from_list (ds_list )
165
+ dataset = dataset .map (op .process , batch_size = 2 , num_proc = 2 )
166
+ res_list = dataset .to_list ()
167
+ self .assertEqual (res_list , tgt_list )
168
+ self .assertListEqual (
169
+ self ._sort_files (os .listdir (vid1_frame_dir )),
170
+ [f'frame_{ i } .jpg' for i in range (4 )])
171
+ self .assertListEqual (
172
+ self ._sort_files (os .listdir (vid2_frame_dir )),
173
+ [f'frame_{ i } .jpg' for i in range (5 )])
174
+ self .assertListEqual (
175
+ self ._sort_files (os .listdir (vid3_frame_dir )),
176
+ [f'frame_{ i } .jpg' for i in range (13 )])
105
177
106
178
107
179
if __name__ == '__main__' :
0 commit comments