Skip to content

Commit 5191cf9

Browse files
committed
add op video_extract_frames_mapper
1 parent 4ab426e commit 5191cf9

File tree

5 files changed

+292
-8
lines changed

5 files changed

+292
-8
lines changed

data_juicer/ops/base_op.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,22 @@ def process_batched(self, samples, *args, **kwargs):
257257
keys = samples.keys()
258258
first_key = next(iter(keys))
259259
num_samples = len(samples[first_key])
260+
261+
new_keys = {}
260262
for i in range(num_samples):
261263
this_sample = {key: samples[key][i] for key in keys}
262264
res_sample = self.process_single(this_sample, *args, **kwargs)
263-
for key in keys:
264-
samples[key][i] = res_sample[key]
265+
res_keys = res_sample.keys()
266+
for key in res_keys:
267+
if key not in keys:
268+
if key not in new_keys:
269+
new_keys.update({key: []})
270+
new_keys[key].append(res_sample[key])
271+
else:
272+
samples[key][i] = res_sample[key]
273+
274+
for k, v in new_keys.items():
275+
samples[k] = v
265276

266277
return samples
267278

data_juicer/ops/mapper/__init__.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from .video_captioning_from_summarizer_mapper import \
5151
VideoCaptioningFromSummarizerMapper
5252
from .video_captioning_from_video_mapper import VideoCaptioningFromVideoMapper
53+
from .video_extract_frames_mapper import VideoExtractFramesMapper
5354
from .video_face_blur_mapper import VideoFaceBlurMapper
5455
from .video_ffmpeg_wrapped_mapper import VideoFFmpegWrappedMapper
5556
from .video_remove_watermark_mapper import VideoRemoveWatermarkMapper
@@ -82,10 +83,10 @@
8283
'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper',
8384
'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper',
8485
'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper',
85-
'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper',
86-
'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper',
87-
'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper',
88-
'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper',
89-
'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper',
90-
'WhitespaceNormalizationMapper'
86+
'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper',
87+
'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper',
88+
'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper',
89+
'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper',
90+
'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper',
91+
'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper'
9192
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import json
2+
import os
3+
import os.path as osp
4+
5+
from pydantic import PositiveInt
6+
7+
from data_juicer.utils.constant import Fields
8+
from data_juicer.utils.file_utils import create_directory_if_not_exists
9+
from data_juicer.utils.mm_utils import (SpecialTokens, close_video,
10+
extract_key_frames,
11+
extract_video_frames_uniformly,
12+
load_data_with_context, load_video)
13+
14+
from ..base_op import OPERATORS, Mapper
15+
from ..op_fusion import LOADED_VIDEOS
16+
17+
OP_NAME = 'video_extract_frames_mapper'
18+
19+
20+
@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoExtractFramesMapper(Mapper):
    """Mapper to extract frames from video files according to specified
    methods.

    Extracted Frames Data Format:
        The extracted-frames info is stored in the sample under
        ``frame_key`` as a JSON string (``json.dumps``) encoding a
        dictionary that maps each video key to the list of file paths
        where its extracted frames are saved. Frame files are numbered
        from 0. The decoded dictionary follows the structure:
        {
            "video_key_1": [
                "/${frame_dir}/video_key_1_filename/frame_0.jpg",
                "/${frame_dir}/video_key_1_filename/frame_1.jpg",
                ...],
            "video_key_2": [
                "/${frame_dir}/video_key_2_filename/frame_0.jpg",
                "/${frame_dir}/video_key_2_filename/frame_1.jpg",
                ...],
            ...
        }
    """

    _batched_op = True

    def __init__(
        self,
        frame_sampling_method: str = 'all_keyframes',
        frame_num: PositiveInt = 3,
        frame_dir: str = None,
        frame_key=Fields.video_frames,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param frame_sampling_method: sampling method used to extract
            frames from the videos. Should be one of
            ["all_keyframes", "uniform"].
            The former one extracts all key frames (the number
            of which depends on the duration of the video) and the latter
            one extracts a specified number of frames uniformly from the
            video. Default: "all_keyframes".
        :param frame_num: the number of frames to be extracted uniformly from
            the video. Only works when frame_sampling_method is "uniform". If
            it's 1, only the middle frame will be extracted. If it's 2, only
            the first and the last frames will be extracted. If it's larger
            than 2, in addition to the first and the last frames, other frames
            will be extracted uniformly within the video duration.
        :param frame_dir: Output directory to save extracted frames.
            If None, a default directory based on the video file path is used.
        :param frame_key: The name of field to save generated frames info.
        :param args: extra args
        :param kwargs: extra args
        :raises ValueError: if frame_sampling_method is not supported.
        """
        super().__init__(*args, **kwargs)
        # NOTE: must run before any extra local is defined, since it
        # snapshots locals() as the op's init parameters.
        self._init_parameters = self.remove_extra_parameters(locals())

        if frame_sampling_method not in ['all_keyframes', 'uniform']:
            raise ValueError(
                f'Frame sampling method '
                f'[{frame_sampling_method}] is not supported. '
                f'Can only be one of ["all_keyframes", "uniform"].')

        self.frame_dir = frame_dir
        self.frame_sampling_method = frame_sampling_method
        self.frame_num = frame_num
        self.frame_key = frame_key
        # frame index (0-based) is substituted into this template
        self.frame_fname_template = 'frame_{}.jpg'

    def _get_default_frame_dir(self, original_filepath):
        """Build the default frame directory for a video.

        The directory is
        ``<video dir>/<multimodal_data_output_dir>/<OP_NAME>/<video stem>``.
        If the video already lives below a multimodal output directory,
        the part before that directory is used as the base so that output
        directories are not nested. The parent directory is created here;
        the returned per-video directory is not.

        :param original_filepath: path of the source video.
        :return: path of the directory where frames should be stored.
        """
        original_dir = os.path.dirname(original_filepath)
        dir_token = f'/{Fields.multimodal_data_output_dir}/'
        if dir_token in original_dir:
            original_dir = original_dir.split(dir_token)[0]
        new_dir = os.path.join(
            original_dir, f'{Fields.multimodal_data_output_dir}/{OP_NAME}')
        create_directory_if_not_exists(new_dir)
        return osp.join(new_dir,
                        osp.splitext(osp.basename(original_filepath))[0])

    def _extract_frames(self, video):
        """Extract frames from a loaded video with the configured method.

        :param video: a video object as returned by ``load_video``.
        :return: list of images obtained via ``frame.to_image()``.
        :raises ValueError: if the sampling method is not supported.
        """
        if self.frame_sampling_method == 'all_keyframes':
            frames = extract_key_frames(video)
        elif self.frame_sampling_method == 'uniform':
            frames = extract_video_frames_uniformly(video, self.frame_num)
        else:
            # only reachable if the attribute was changed after __init__
            raise ValueError(
                f'Not support sampling method '
                f'`{self.frame_sampling_method}`.')
        return [frame.to_image() for frame in frames]

    def _resolve_frame_dir(self, video_key):
        """Return (and create) the directory the frames of a video go to.

        :param video_key: path/key of the video in the sample.
        :return: path of the existing per-video frame directory.
        """
        if self.frame_dir:
            frame_dir = osp.join(self.frame_dir,
                                 osp.splitext(osp.basename(video_key))[0])
        else:
            # video path as frames directory
            frame_dir = self._get_default_frame_dir(video_key)
        os.makedirs(frame_dir, exist_ok=True)
        return frame_dir

    def _save_frames(self, frames, frame_dir):
        """Save frames into frame_dir and return their paths in order.

        Files that already exist are kept and not overwritten.

        :param frames: images exposing a ``save(path)`` method.
        :param frame_dir: existing directory to write into.
        :return: list of frame file paths, one per frame.
        """
        frame_paths = []
        for i, frame in enumerate(frames):
            frame_path = osp.join(frame_dir,
                                  self.frame_fname_template.format(i))
            if not os.path.exists(frame_path):
                frame.save(frame_path)
            frame_paths.append(frame_path)
        return frame_paths

    def process_single(self, sample, context=False):
        """Extract and save the frames of every video in one sample.

        :param sample: a single sample dict; reads ``self.video_key`` and
            ``self.text_key`` and writes ``self.frame_key``.
        :param context: forwarded to ``load_data_with_context``; when it
            is falsy the loaded videos are closed before returning.
        :return: the sample, with ``self.frame_key`` set to a JSON string
            mapping each processed video key to its frame paths.
        """
        # check if it's generated already
        if self.frame_key in sample:
            return sample

        # there is no videos in this sample: still return the sample (not
        # a bare list) and record an empty mapping, so batched processing
        # gets a dict with the same keys for every sample.
        if self.video_key not in sample or not sample[self.video_key]:
            sample[self.frame_key] = json.dumps({})
            return sample

        # load videos
        loaded_video_keys = sample[self.video_key]
        sample, videos = load_data_with_context(sample, context,
                                                loaded_video_keys, load_video)
        video_to_frames = {}
        text = sample[self.text_key]
        offset = 0

        try:
            for chunk in text.split(SpecialTokens.eoc):
                video_count = chunk.count(SpecialTokens.video)
                # no video in this chunk (covers the empty chunk as well)
                if video_count == 0:
                    continue

                # the videos of this chunk are the next `video_count`
                # entries of the sample's video list
                chunk_video_keys = loaded_video_keys[offset:offset +
                                                     video_count]
                for video_key in chunk_video_keys:
                    frames = self._extract_frames(videos[video_key])
                    frame_dir = self._resolve_frame_dir(video_key)
                    video_to_frames[video_key] = self._save_frames(
                        frames, frame_dir)

                offset += video_count
        finally:
            # release video handles even if extraction or saving failed
            if not context:
                for vid_key in videos:
                    close_video(videos[vid_key])

        sample[self.frame_key] = json.dumps(video_to_frames)

        return sample

data_juicer/utils/constant.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class Fields(object):
1616
context = DEFAULT_PREFIX + 'context__'
1717
suffix = DEFAULT_PREFIX + 'suffix__'
1818

19+
video_frames = DEFAULT_PREFIX + 'video_frames__'
1920
# video_frame_tags
2021
video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__'
2122
video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
import os.path as osp
3+
import copy
4+
import unittest
5+
import json
6+
import tempfile
7+
import shutil
8+
from data_juicer.core.data import NestedDataset as Dataset
9+
from data_juicer.ops.mapper.video_extract_frames_mapper import \
10+
VideoExtractFramesMapper
11+
from data_juicer.utils.constant import Fields
12+
from data_juicer.utils.mm_utils import SpecialTokens
13+
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
14+
15+
16+
class VideoExtractFramesMapperTest(DataJuicerTestCaseBase):
    """Tests for VideoExtractFramesMapper with both sampling methods."""

    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             '..', 'data')
    vid1_path = os.path.join(data_path, 'video1.mp4')
    vid2_path = os.path.join(data_path, 'video2.mp4')
    vid3_path = os.path.join(data_path, 'video3.mp4')

    def setUp(self):
        super().setUp()
        # A fresh directory per test. mkdtemp() leaves cleanup to us, so the
        # directory really exists until tearDown removes it (a discarded
        # TemporaryDirectory object would delete it right away).
        self.tmp_dir = tempfile.mkdtemp()

    def tearDown(self):
        super().tearDown()
        # ignore_errors: never let cleanup mask the real test failure
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    def _run_video_extract_frames_mapper(self,
                                         op,
                                         source_list,
                                         target_list,
                                         num_proc=1):
        """Run the op over source_list and compare with target_list."""
        dataset = Dataset.from_list(source_list)
        dataset = dataset.map(op.process, batch_size=2, num_proc=num_proc)
        res_list = dataset.to_list()
        self.assertEqual(res_list, target_list)

    def _get_frames_list(self, filepath, frame_dir, frame_num):
        """Return the frame paths expected for a video, numbered from 0."""
        frames_dir = osp.join(frame_dir,
                              osp.splitext(osp.basename(filepath))[0])
        frames_list = [
            osp.join(frames_dir, f'frame_{i}.jpg') for i in range(frame_num)
        ]
        return frames_list

    def _expected_frames_field(self, frame_dir, video_frame_nums):
        """Build the expected JSON value of the frames field.

        :param frame_dir: root directory passed to the op.
        :param video_frame_nums: list of (video_path, frame_num) pairs in
            the order the videos appear in the sample.
        """
        return json.dumps({
            path: self._get_frames_list(path, frame_dir, num)
            for path, num in video_frame_nums
        })

    def test_uniform_sampling(self):
        ds_list = [{
            'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
            'videos': [self.vid1_path]
        }, {
            'text':
            f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
            'videos': [self.vid2_path]
        }, {
            'text':
            f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
            'videos': [self.vid3_path]
        }]
        frame_num = 3
        frame_dir = os.path.join(self.tmp_dir, 'test1')

        tgt_list = copy.deepcopy(ds_list)
        tgt_list[0].update({
            Fields.video_frames:
            self._expected_frames_field(frame_dir,
                                        [(self.vid1_path, frame_num)])
        })
        tgt_list[1].update({
            Fields.video_frames:
            self._expected_frames_field(frame_dir,
                                        [(self.vid2_path, frame_num)])
        })
        tgt_list[2].update({
            Fields.video_frames:
            self._expected_frames_field(frame_dir,
                                        [(self.vid3_path, frame_num)])
        })

        op = VideoExtractFramesMapper(
            frame_sampling_method='uniform',
            frame_num=frame_num,
            frame_dir=frame_dir)
        self._run_video_extract_frames_mapper(op, ds_list, tgt_list)

    def test_all_keyframes_sampling(self):
        ds_list = [{
            'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
            'videos': [self.vid1_path]
        }, {
            'text':
            (f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}'
             f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}'),
            'videos': [self.vid2_path, self.vid3_path]
        }, {
            'text':
            f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
            'videos': [self.vid3_path]
        }]
        frame_dir = os.path.join(self.tmp_dir, 'test2')

        # expected key-frame counts of the test videos: 3, 3 and 6
        tgt_list = copy.deepcopy(ds_list)
        tgt_list[0].update({
            Fields.video_frames:
            self._expected_frames_field(frame_dir, [(self.vid1_path, 3)])
        })
        tgt_list[1].update({
            Fields.video_frames:
            self._expected_frames_field(frame_dir, [(self.vid2_path, 3),
                                                    (self.vid3_path, 6)])
        })
        tgt_list[2].update({
            Fields.video_frames:
            self._expected_frames_field(frame_dir, [(self.vid3_path, 6)])
        })

        op = VideoExtractFramesMapper(
            frame_sampling_method='all_keyframes',
            frame_dir=frame_dir)
        self._run_video_extract_frames_mapper(op, ds_list, tgt_list)
104+
105+
106+
107+
if __name__ == '__main__':
108+
unittest.main()

0 commit comments

Comments
 (0)