Commit 7a36c39

support extract frames by seconds
1 parent 33b99f6 commit 7a36c39

File tree: 3 files changed, +205 -39 lines changed


data_juicer/ops/mapper/video_extract_frames_mapper.py

+27 -13
@@ -6,10 +6,11 @@
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.file_utils import create_directory_if_not_exists
-from data_juicer.utils.mm_utils import (SpecialTokens, close_video,
-                                        extract_key_frames,
-                                        extract_video_frames_uniformly,
-                                        load_data_with_context, load_video)
+from data_juicer.utils.mm_utils import (
+    SpecialTokens, close_video, extract_key_frames,
+    extract_key_frames_by_seconds, extract_video_frames_uniformly,
+    extract_video_frames_uniformly_by_seconds, load_data_with_context,
+    load_video)
 
 from ..base_op import OPERATORS, Mapper
 from ..op_fusion import LOADED_VIDEOS
@@ -44,6 +45,7 @@ def __init__(
         self,
         frame_sampling_method: str = 'all_keyframes',
         frame_num: PositiveInt = 3,
+        duration: float = 0,
         frame_dir: str = None,
         frame_key=Fields.video_frames,
         *args,
@@ -57,13 +59,19 @@ def __init__(
             The former one extracts all key frames (the number
             of which depends on the duration of the video) and the latter
             one extract specified number of frames uniformly from the video.
+            If "duration" > 0, frame_sampling_method acts on every segment.
             Default: "all_keyframes".
         :param frame_num: the number of frames to be extracted uniformly from
             the video. Only works when frame_sampling_method is "uniform". If
             it's 1, only the middle frame will be extracted. If it's 2, only
             the first and the last frames will be extracted. If it's larger
             than 2, in addition to the first and the last frames, other frames
             will be extracted uniformly within the video duration.
+            If "duration" > 0, frame_num is the number of frames per segment.
+        :param duration: The duration of each segment in seconds.
+            If 0, frames are extracted from the entire video.
+            If duration > 0, the video is segmented into multiple segments
+            based on duration, and frames are extracted from each segment.
         :param frame_dir: Output directory to save extracted frames.
             If None, a default directory based on the video file path is used.
         :param frame_key: The name of field to save generated frames info.
@@ -82,6 +90,7 @@ def __init__(
         self.frame_dir = frame_dir
         self.frame_sampling_method = frame_sampling_method
         self.frame_num = frame_num
+        self.duration = duration
         self.frame_key = frame_key
         self.frame_fname_template = 'frame_{}.jpg'
 
@@ -109,7 +118,7 @@ def process_single(self, sample, context=False):
         loaded_video_keys = sample[self.video_key]
         sample, videos = load_data_with_context(sample, context,
                                                 loaded_video_keys, load_video)
-        video_to_frames = {}
+        video_to_frame_dir = {}
         text = sample[self.text_key]
         offset = 0
 
@@ -124,10 +133,18 @@ def process_single(self, sample, context=False):
                 video = videos[video_key]
                 # extract frame videos
                 if self.frame_sampling_method == 'all_keyframes':
-                    frames = extract_key_frames(video)
+                    if self.duration:
+                        frames = extract_key_frames_by_seconds(
+                            video, self.duration)
+                    else:
+                        frames = extract_key_frames(video)
                 elif self.frame_sampling_method == 'uniform':
-                    frames = extract_video_frames_uniformly(
-                        video, self.frame_num)
+                    if self.duration:
+                        frames = extract_video_frames_uniformly_by_seconds(
+                            video, self.frame_num, duration=self.duration)
+                    else:
+                        frames = extract_video_frames_uniformly(
+                            video, self.frame_num)
                 else:
                     raise ValueError(f'Not support sampling method \
                         `{self.frame_sampling_method}`.')
@@ -141,23 +158,20 @@ def process_single(self, sample, context=False):
                     # video path as frames directory
                     frame_dir = self._get_default_frame_dir(video_key)
                 os.makedirs(frame_dir, exist_ok=True)
+                video_to_frame_dir[video_key] = frame_dir
 
-                video_to_frames[video_key] = []
                 for i, frame in enumerate(frames):
                     frame_path = osp.join(
                         frame_dir, self.frame_fname_template.format(i))
                     if not os.path.exists(frame_path):
                         frame.save(frame_path)
 
-                    video_to_frames[video_key].append(frame_path)
-
             offset += video_count
 
         if not context:
             for vid_key in videos:
                 close_video(videos[vid_key])
 
-        sample[self.frame_key] = json.dumps(video_to_frames)
-        # sample[self.frame_key] = video_to_frames
+        sample[self.frame_key] = json.dumps(video_to_frame_dir)
 
         return sample

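Taken together, the mapper changes work like this: with duration=0 the op keeps its old behavior and samples over the whole video, while with duration > 0 it applies frame_sampling_method to every duration-second segment, so frame_num becomes a per-segment count. The frames for each video are written into a single directory, and frame_key now stores a JSON mapping from video path to that directory instead of an explicit list of frame paths. A minimal usage sketch, not part of the commit (the video path, caption, and frame_dir are placeholders, and it calls process_single directly instead of going through Dataset.map as the tests do):

import json

from data_juicer.ops.mapper.video_extract_frames_mapper import \
    VideoExtractFramesMapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens

# sample 3 frames from every 10-second segment of each video
op = VideoExtractFramesMapper(
    frame_sampling_method='uniform',
    frame_num=3,
    duration=10,
    frame_dir='./frames')  # placeholder output directory

sample = {
    'text': f'{SpecialTokens.video} a short caption',  # placeholder caption
    'videos': ['/path/to/video1.mp4'],                 # placeholder video path
}
result = op.process_single(sample)

# the frame field maps each video path to the directory holding its frames
print(json.loads(result[Fields.video_frames]))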
data_juicer/utils/mm_utils.py

+81 -1
@@ -1,5 +1,6 @@
 import base64
 import datetime
+import io
 import os
 import re
 import shutil
@@ -321,7 +322,11 @@ def cut_video_by_seconds(
         container = input_video
 
     # create the output video
-    output_container = load_video(output_video, 'w')
+    if output_video:
+        output_container = load_video(output_video, 'w')
+    else:
+        output_buffer = io.BytesIO()
+        output_container = av.open(output_buffer, mode='w', format='mp4')
 
     # add the video stream into the output video according to input video
     input_video_stream = container.streams.video[0]
@@ -390,6 +395,11 @@ def cut_video_by_seconds(
     if isinstance(input_video, str):
         close_video(container)
     close_video(output_container)
+
+    if not output_video:
+        output_buffer.seek(0)
+        return output_buffer
+
     if not os.path.exists(output_video):
         logger.warning(f'This video could not be successfully cut in '
                        f'[{start_seconds}, {end_seconds}] seconds. '
@@ -463,6 +473,39 @@ def process_each_frame(input_video: Union[str, av.container.InputContainer],
             if isinstance(input_video, str) else input_video.name)
 
 
+def extract_key_frames_by_seconds(
+        input_video: Union[str, av.container.InputContainer],
+        duration: float = 1):
+    """Extract key frames by seconds.
+    :param input_video: input video path or av.container.InputContainer.
+    :param duration: duration of each video split in seconds.
+    """
+    # load the input video
+    if isinstance(input_video, str):
+        container = load_video(input_video)
+    elif isinstance(input_video, av.container.InputContainer):
+        container = input_video
+    else:
+        raise ValueError(f'Unsupported type of input_video. Should be one of '
+                         f'[str, av.container.InputContainer], but given '
+                         f'[{type(input_video)}].')
+
+    video_duration = get_video_duration(container)
+    timestamps = np.arange(0, video_duration, duration).tolist()
+
+    all_key_frames = []
+    for i in range(1, len(timestamps)):
+        output_buffer = cut_video_by_seconds(container, None,
+                                             timestamps[i - 1], timestamps[i])
+        if output_buffer:
+            cut_inp_container = av.open(output_buffer, format='mp4', mode='r')
+            key_frames = extract_key_frames(cut_inp_container)
+            all_key_frames.extend(key_frames)
+            close_video(cut_inp_container)
+
+    return all_key_frames
+
+
 def extract_key_frames(input_video: Union[str, av.container.InputContainer]):
     """
     Extract key frames from the input video. If there is no keyframes in the
@@ -516,6 +559,43 @@ def get_key_frame_seconds(input_video: Union[str,
     return ts
 
 
+def extract_video_frames_uniformly_by_seconds(
+        input_video: Union[str, av.container.InputContainer],
+        frame_num: PositiveInt,
+        duration: float = 1):
+    """Extract video frames uniformly by seconds.
+    :param input_video: input video path or av.container.InputContainer.
+    :param frame_num: the number of frames to be extracted uniformly from
+        each video split by duration.
+    :param duration: duration of each video split in seconds.
+    """
+    # load the input video
+    if isinstance(input_video, str):
+        container = load_video(input_video)
+    elif isinstance(input_video, av.container.InputContainer):
+        container = input_video
+    else:
+        raise ValueError(f'Unsupported type of input_video. Should be one of '
+                         f'[str, av.container.InputContainer], but given '
+                         f'[{type(input_video)}].')
+
+    video_duration = get_video_duration(container)
+    timestamps = np.arange(0, video_duration, duration).tolist()
+
+    all_frames = []
+    for i in range(1, len(timestamps)):
+        output_buffer = cut_video_by_seconds(container, None,
+                                             timestamps[i - 1], timestamps[i])
+        if output_buffer:
+            cut_inp_container = av.open(output_buffer, format='mp4', mode='r')
+            key_frames = extract_video_frames_uniformly(cut_inp_container,
+                                                        frame_num=frame_num)
+            all_frames.extend(key_frames)
+            close_video(cut_inp_container)
+
+    return all_frames
+
+
 def extract_video_frames_uniformly(
         input_video: Union[str, av.container.InputContainer],
         frame_num: PositiveInt,

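Both new helpers in mm_utils.py follow the same pattern: cut_video_by_seconds can now be called with output_video=None, in which case it writes the cut segment into an in-memory io.BytesIO buffer and returns that buffer instead of writing a file; each segment buffer is then reopened with av.open and fed to the existing extractors. Note that the segment boundaries are consecutive pairs from np.arange(0, video_duration, duration), so the tail of the video beyond the last boundary appears to go unsampled. A minimal sketch of calling the new helpers directly, not part of the commit (the video path is a placeholder):

from data_juicer.utils.mm_utils import (
    extract_key_frames_by_seconds, extract_video_frames_uniformly_by_seconds)

video_path = '/path/to/video.mp4'  # placeholder path

# key frames, gathered independently within every 5-second segment
key_frames = extract_key_frames_by_seconds(video_path, duration=5)

# 2 uniformly spaced frames from every 5-second segment
uniform_frames = extract_video_frames_uniformly_by_seconds(
    video_path, frame_num=2, duration=5)

# both return flat lists of frame objects, in segment order
print(len(key_frames), len(uniform_frames))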
tests/ops/mapper/test_video_extract_frames_mapper.py

+97 -25
@@ -1,5 +1,6 @@
 import os
 import os.path as osp
+import re
 import copy
 import unittest
 import json
@@ -25,21 +26,63 @@ def tearDown(self):
         super().tearDown()
         shutil.rmtree(self.tmp_dir)
 
-    def _run_video_extract_frames_mapper(self,
-                                         op,
-                                         source_list,
-                                         target_list,
-                                         num_proc=1):
-        dataset = Dataset.from_list(source_list)
-        dataset = dataset.map(op.process, batch_size=2, num_proc=num_proc)
-        res_list = dataset.to_list()
-        self.assertEqual(res_list, target_list)
-
     def _get_frames_list(self, filepath, frame_dir, frame_num):
         frames_dir = osp.join(frame_dir, osp.splitext(osp.basename(filepath))[0])
         frames_list = [osp.join(frames_dir, f'frame_{i}.jpg') for i in range(frame_num)]
         return frames_list
 
+    def _get_frames_dir(self, filepath, frame_dir):
+        frames_dir = osp.join(frame_dir, osp.splitext(osp.basename(filepath))[0])
+        return frames_dir
+
+    def _sort_files(self, file_list):
+        return sorted(file_list, key=lambda x: int(re.search(r'(\d+)', x).group()))
+
+    def test_duration(self):
+        ds_list = [{
+            'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
+            'videos': [self.vid1_path]
+        }, {
+            'text':
+            f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
+            'videos': [self.vid2_path]
+        }, {
+            'text':
+            f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
+            'videos': [self.vid3_path]
+        }]
+
+        frame_num = 2
+        frame_dir=os.path.join(self.tmp_dir, 'test1')
+        vid1_frame_dir = self._get_frames_dir(self.vid1_path, frame_dir)
+        vid2_frame_dir = self._get_frames_dir(self.vid2_path, frame_dir)
+        vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir)
+
+        tgt_list = copy.deepcopy(ds_list)
+        tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})})
+        tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})})
+        tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})})
+
+        op = VideoExtractFramesMapper(
+            frame_sampling_method='uniform',
+            frame_num=frame_num,
+            duration=0,
+            frame_dir=frame_dir)
+
+        dataset = Dataset.from_list(ds_list)
+        dataset = dataset.map(op.process, batch_size=2, num_proc=1)
+        res_list = dataset.to_list()
+        self.assertEqual(res_list, tgt_list)
+        self.assertListEqual(
+            self._sort_files(os.listdir(vid1_frame_dir)),
+            [f'frame_{i}.jpg' for i in range(frame_num)])
+        self.assertListEqual(
+            self._sort_files(os.listdir(vid2_frame_dir)),
+            [f'frame_{i}.jpg' for i in range(frame_num)])
+        self.assertListEqual(
+            self._sort_files(os.listdir(vid3_frame_dir)),
+            [f'frame_{i}.jpg' for i in range(frame_num)])
+
     def test_uniform_sampling(self):
         ds_list = [{
             'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
@@ -55,22 +98,35 @@ def test_uniform_sampling(self):
         }]
         frame_num = 3
         frame_dir=os.path.join(self.tmp_dir, 'test1')
+        vid1_frame_dir = self._get_frames_dir(self.vid1_path, frame_dir)
+        vid2_frame_dir = self._get_frames_dir(self.vid2_path, frame_dir)
+        vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir)
 
         tgt_list = copy.deepcopy(ds_list)
-        tgt_list[0].update({Fields.video_frames:
-            json.dumps({self.vid1_path: self._get_frames_list(self.vid1_path, frame_dir, frame_num)})})
-        tgt_list[1].update({Fields.video_frames:
-            json.dumps({self.vid2_path: self._get_frames_list(self.vid2_path, frame_dir, frame_num)})})
-        tgt_list[2].update({Fields.video_frames:
-            json.dumps({self.vid3_path: self._get_frames_list(self.vid3_path, frame_dir, frame_num)})})
-
+        tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})})
+        tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})})
+        tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})})
+
         op = VideoExtractFramesMapper(
             frame_sampling_method='uniform',
             frame_num=frame_num,
+            duration=10,
             frame_dir=frame_dir)
-        self._run_video_extract_frames_mapper(op, ds_list, tgt_list)
 
-
+        dataset = Dataset.from_list(ds_list)
+        dataset = dataset.map(op.process, batch_size=2, num_proc=1)
+        res_list = dataset.to_list()
+        self.assertEqual(res_list, tgt_list)
+        self.assertListEqual(
+            self._sort_files(os.listdir(vid1_frame_dir)),
+            [f'frame_{i}.jpg' for i in range(3)])
+        self.assertListEqual(
+            self._sort_files(os.listdir(vid2_frame_dir)),
+            [f'frame_{i}.jpg' for i in range(6)])
+        self.assertListEqual(
+            self._sort_files(os.listdir(vid3_frame_dir)),
+            [f'frame_{i}.jpg' for i in range(12)])
+
     def test_all_keyframes_sampling(self):
         ds_list = [{
             'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
@@ -86,22 +142,38 @@ def test_all_keyframes_sampling(self):
             'videos': [self.vid3_path]
         }]
         frame_dir=os.path.join(self.tmp_dir, 'test2')
+        vid1_frame_dir = self._get_frames_dir(self.vid1_path, frame_dir)
+        vid2_frame_dir = self._get_frames_dir(self.vid2_path, frame_dir)
+        vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir)
 
         tgt_list = copy.deepcopy(ds_list)
         tgt_list[0].update({Fields.video_frames:
-            json.dumps({self.vid1_path: self._get_frames_list(self.vid1_path, frame_dir, 3)})})
+            json.dumps({self.vid1_path: vid1_frame_dir})})
         tgt_list[1].update({Fields.video_frames: json.dumps({
-            self.vid2_path: self._get_frames_list(self.vid2_path, frame_dir, 3),
-            self.vid3_path: self._get_frames_list(self.vid3_path, frame_dir, 6)
+            self.vid2_path: vid2_frame_dir,
+            self.vid3_path: vid3_frame_dir
         })})
         tgt_list[2].update({Fields.video_frames:
-            json.dumps({self.vid3_path: self._get_frames_list(self.vid3_path, frame_dir, 6)})})
+            json.dumps({self.vid3_path: vid3_frame_dir})})
 
         op = VideoExtractFramesMapper(
             frame_sampling_method='all_keyframes',
-            frame_dir=frame_dir)
-        self._run_video_extract_frames_mapper(op, ds_list, tgt_list)
+            frame_dir=frame_dir,
+            duration=5)
 
+        dataset = Dataset.from_list(ds_list)
+        dataset = dataset.map(op.process, batch_size=2, num_proc=2)
+        res_list = dataset.to_list()
+        self.assertEqual(res_list, tgt_list)
+        self.assertListEqual(
+            self._sort_files(os.listdir(vid1_frame_dir)),
+            [f'frame_{i}.jpg' for i in range(4)])
+        self.assertListEqual(
+            self._sort_files(os.listdir(vid2_frame_dir)),
+            [f'frame_{i}.jpg' for i in range(5)])
+        self.assertListEqual(
+            self._sort_files(os.listdir(vid3_frame_dir)),
+            [f'frame_{i}.jpg' for i in range(13)])
 
 
 if __name__ == '__main__':

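Because the output field now records a frame directory per video rather than an explicit list of frame paths, downstream code has to list that directory itself, which is what the updated assertions above do via _sort_files. A small sketch of recovering ordered frame paths from a processed sample, not part of the commit (frames_for_sample is a hypothetical helper; the numeric sort mirrors _sort_files):

import json
import os
import os.path as osp
import re

from data_juicer.utils.constant import Fields


def frames_for_sample(sample):
    """Return {video_path: [ordered frame paths]} for one processed sample."""
    video_to_frame_dir = json.loads(sample[Fields.video_frames])
    result = {}
    for video_path, frame_dir in video_to_frame_dir.items():
        # frames are saved as frame_0.jpg, frame_1.jpg, ...; sort numerically
        names = sorted(
            os.listdir(frame_dir),
            key=lambda name: int(re.search(r'(\d+)', name).group()))
        result[video_path] = [osp.join(frame_dir, name) for name in names]
    return result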