Skip to content

Commit ca5de7c

Browse files
authored
[Ready] motion_score_raft (modelscope#478)
* add raft op * add docs * update docs * refine
1 parent 898372d commit ca5de7c

9 files changed

+372
-45
lines changed

configs/config_all.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,16 @@ process:
465465
sampling_fps: 2 # the sampling rate of frames_per_second to compute optical flow
466466
size: null # resize frames along the smaller edge before computing optical flow, or a sequence like (h, w)
467467
max_size: null # maximum allowed for the longer edge of resized frames
468+
divisible: 1 # The number that the dimensions must be divisible by.
469+
relative: false # whether to normalize the optical flow magnitude to [0, 1], relative to the frame's diagonal length
470+
any_or_all: any # keep this sample when any/all videos meet the filter condition
471+
- video_motion_score_raft_filter: # Keep samples with video motion scores (based on RAFT model) within a specific range.
472+
min_score: 1.0 # the minimum motion score to keep samples
473+
max_score: 10000.0 # the maximum motion score to keep samples
474+
sampling_fps: 2 # the sampling rate of frames_per_second to compute optical flow
475+
size: null # resize frames along the smaller edge before computing optical flow, or a sequence like (h, w)
476+
max_size: null # maximum allowed for the longer edge of resized frames
477+
divisible: 8 # The number that the dimensions must be divisible by.
468478
relative: false # whether to normalize the optical flow magnitude to [0, 1], relative to the frame's diagonal length
469479
any_or_all: any # keep this sample when any/all videos meet the filter condition
470480
- video_nsfw_filter: # filter samples according to the nsfw scores of videos in them

data_juicer/ops/filter/__init__.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .video_frames_text_similarity_filter import \
3636
VideoFramesTextSimilarityFilter
3737
from .video_motion_score_filter import VideoMotionScoreFilter
38+
from .video_motion_score_raft_filter import VideoMotionScoreRaftFilter
3839
from .video_nsfw_filter import VideoNSFWFilter
3940
from .video_ocr_area_ratio_filter import VideoOcrAreaRatioFilter
4041
from .video_resolution_filter import VideoResolutionFilter
@@ -57,7 +58,8 @@
5758
'TextActionFilter', 'TextEntityDependencyFilter', 'TextLengthFilter',
5859
'TokenNumFilter', 'VideoAestheticsFilter', 'VideoAspectRatioFilter',
5960
'VideoDurationFilter', 'VideoFramesTextSimilarityFilter',
60-
'VideoMotionScoreFilter', 'VideoNSFWFilter', 'VideoOcrAreaRatioFilter',
61-
'VideoResolutionFilter', 'VideoTaggingFromFramesFilter',
62-
'VideoWatermarkFilter', 'WordRepetitionFilter', 'WordsNumFilter'
61+
'VideoMotionScoreFilter', 'VideoMotionScoreRaftFilter', 'VideoNSFWFilter',
62+
'VideoOcrAreaRatioFilter', 'VideoResolutionFilter',
63+
'VideoTaggingFromFramesFilter', 'VideoWatermarkFilter',
64+
'WordRepetitionFilter', 'WordsNumFilter'
6365
]

data_juicer/ops/filter/video_motion_score_filter.py

+31-36
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from data_juicer.utils.constant import Fields, StatsKeys
99
from data_juicer.utils.lazy_loader import LazyLoader
10+
from data_juicer.utils.mm_utils import calculate_resized_dimensions
1011

1112
from ..base_op import OPERATORS, UNFORKABLE, Filter
1213

@@ -48,6 +49,7 @@ def __init__(self,
4849
size: Union[PositiveInt, Tuple[PositiveInt],
4950
Tuple[PositiveInt, PositiveInt], None] = None,
5051
max_size: Optional[PositiveInt] = None,
52+
divisible: PositiveInt = 1,
5153
relative: bool = False,
5254
any_or_all: str = 'any',
5355
*args,
@@ -69,6 +71,7 @@ def __init__(self,
6971
being resized according to size, size will be overruled so that the
7072
longer edge is equal to max_size. As a result, the smaller edge may
7173
be shorter than size. This is only supported if size is an int.
74+
:param divisible: The number that the dimensions must be divisible by.
7275
:param relative: If `True`, the optical flow magnitude is normalized to
7376
a [0, 1] range, relative to the frame's diagonal length.
7477
:param any_or_all: keep this sample with 'any' or 'all' strategy of
@@ -92,6 +95,7 @@ def __init__(self,
9295
size = (size, )
9396
self.size = size
9497
self.max_size = max_size
98+
self.divisible = divisible
9599
self.relative = relative
96100

97101
self.extra_kwargs = self._default_kwargs
@@ -104,7 +108,21 @@ def __init__(self,
104108
f'Can only be one of ["any", "all"].')
105109
self.any = (any_or_all == 'any')
106110

107-
def compute_stats_single(self, sample, context=False):
111+
def setup_model(self, rank=None):
112+
self.model = cv2.calcOpticalFlowFarneback
113+
114+
def compute_flow(self, prev_frame, curr_frame):
115+
curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
116+
if prev_frame is None:
117+
flow = None
118+
else:
119+
flow = self.model(prev_frame, curr_frame, None,
120+
**self.extra_kwargs)
121+
return flow, curr_frame
122+
123+
def compute_stats_single(self, sample, rank=None, context=False):
124+
self.rank = rank
125+
108126
# check if it's computed already
109127
if StatsKeys.video_motion_score in sample[Fields.stats]:
110128
return sample
@@ -115,6 +133,8 @@ def compute_stats_single(self, sample, context=False):
115133
[], dtype=np.float64)
116134
return sample
117135

136+
self.setup_model(rank)
137+
118138
# load videos
119139
loaded_video_keys = sample[self.video_key]
120140
unique_motion_scores = {}
@@ -133,6 +153,11 @@ def compute_stats_single(self, sample, context=False):
133153
# at least two frames for computing optical flow
134154
sampling_step = max(min(sampling_step, total_frames - 1),
135155
1)
156+
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
157+
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
158+
new_size = calculate_resized_dimensions(
159+
(height, width), self.size, self.max_size,
160+
self.divisible)
136161

137162
prev_frame = None
138163
frame_count = 0
@@ -143,27 +168,21 @@ def compute_stats_single(self, sample, context=False):
143168
# a corrupt frame or reaching the end of the video.
144169
break
145170

146-
height, width, _ = frame.shape
147-
new_size = _compute_resized_output_size(
148-
(height, width), self.size, self.max_size)
149171
if new_size != (height, width):
150172
frame = cv2.resize(frame,
151173
new_size,
152174
interpolation=cv2.INTER_AREA)
153175

154-
gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
155-
if prev_frame is None:
156-
prev_frame = gray_frame
176+
# compute_flow returns the flow of shape (H, W, 2), or None for the
177+
# first frame, plus the transformed frame reused as the next prev_frame
178+
flow, prev_frame = self.compute_flow(prev_frame, frame)
179+
if flow is None:
157180
continue
158-
159-
flow = cv2.calcOpticalFlowFarneback(
160-
prev_frame, gray_frame, None, **self.extra_kwargs)
161181
mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
162182
frame_motion_score = np.mean(mag)
163183
if self.relative:
164-
frame_motion_score /= np.hypot(*flow.shape[:2])
184+
frame_motion_score /= np.hypot(*frame.shape[:2])
165185
video_motion_scores.append(frame_motion_score)
166-
prev_frame = gray_frame
167186

168187
# quickly skip frames
169188
frame_count += sampling_step
@@ -197,27 +216,3 @@ def process_single(self, sample):
197216
return keep_bools.any()
198217
else:
199218
return keep_bools.all()
200-
201-
202-
def _compute_resized_output_size(
203-
frame_size: Tuple[int, int],
204-
size: Union[Tuple[PositiveInt], Tuple[PositiveInt, PositiveInt]],
205-
max_size: Optional[int] = None,
206-
) -> Tuple[int, int]:
207-
h, w = frame_size
208-
short, long = (w, h) if w <= h else (h, w)
209-
210-
if size is None: # no change
211-
new_short, new_long = short, long
212-
elif len(size) == 1: # specified size only for the smallest edge
213-
new_short = size[0]
214-
new_long = int(new_short * long / short)
215-
else: # specified both h and w
216-
new_short, new_long = min(size), max(size)
217-
218-
if max_size is not None and new_long > max_size:
219-
new_short = int(max_size * new_short / new_long)
220-
new_long = max_size
221-
222-
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
223-
return new_h, new_w
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import sys
2+
from typing import Optional, Tuple, Union
3+
4+
from pydantic import PositiveFloat, PositiveInt
5+
6+
from data_juicer import cuda_device_count
7+
from data_juicer.ops.filter.video_motion_score_filter import \
8+
VideoMotionScoreFilter
9+
from data_juicer.utils.lazy_loader import LazyLoader
10+
11+
from ..base_op import OPERATORS, UNFORKABLE
12+
13+
torch = LazyLoader('torch', 'torch')
14+
tvm = LazyLoader('tvm', 'torchvision.models')
15+
tvt = LazyLoader('tvt', 'torchvision.transforms')
16+
17+
OP_NAME = 'video_motion_score_raft_filter'
18+
19+
20+
@UNFORKABLE.register_module(OP_NAME)
21+
@OPERATORS.register_module(OP_NAME)
22+
class VideoMotionScoreRaftFilter(VideoMotionScoreFilter):
23+
"""Filter to keep samples with video motion scores within a specified range.
24+
This operator utilizes the RAFT (Recurrent All-Pairs Field Transforms)
25+
model from torchvision to predict optical flow between video frames.
26+
27+
For further details, refer to the official torchvision documentation:
28+
https://pytorch.org/vision/main/models/raft.html
29+
30+
The original paper on RAFT is available here:
31+
https://arxiv.org/abs/2003.12039
32+
"""
33+
34+
_accelerator = 'cuda'
35+
_default_kwargs = {}
36+
37+
def __init__(self,
38+
min_score: float = 1.0,
39+
max_score: float = sys.float_info.max,
40+
sampling_fps: PositiveFloat = 2,
41+
size: Union[PositiveInt, Tuple[PositiveInt],
42+
Tuple[PositiveInt, PositiveInt], None] = None,
43+
max_size: Optional[PositiveInt] = None,
44+
divisible: PositiveInt = 8,
45+
relative: bool = False,
46+
any_or_all: str = 'any',
47+
*args,
48+
**kwargs):
49+
super().__init__(min_score, max_score, sampling_fps, size, max_size,
50+
divisible, relative, any_or_all, *args, **kwargs)
51+
52+
def setup_model(self, rank=None):
53+
self.model = tvm.optical_flow.raft_large(
54+
weights=tvm.optical_flow.Raft_Large_Weights.DEFAULT,
55+
progress=False)
56+
if self.use_cuda():
57+
rank = rank if rank is not None else 0
58+
rank = rank % cuda_device_count()
59+
self.device = f'cuda:{rank}'
60+
else:
61+
self.device = 'cpu'
62+
self.model.to(self.device)
63+
self.model.eval()
64+
65+
self.transforms = tvt.Compose([
66+
tvt.ToTensor(),
67+
tvt.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
68+
tvt.Lambda(lambda img: img.flip(-3).unsqueeze(0)), # BGR to RGB
69+
])
70+
71+
def compute_flow(self, prev_frame, curr_frame):
72+
curr_frame = self.transforms(curr_frame).to(self.device)
73+
if prev_frame is None:
74+
flow = None
75+
else:
76+
with torch.inference_mode():
77+
flows = self.model(prev_frame, curr_frame)
78+
flow = flows[-1][0].cpu().numpy().transpose(
79+
(1, 2, 0)) # 2, H, W -> H, W, 2
80+
return flow, curr_frame

data_juicer/utils/mm_utils.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import re
55
import shutil
6-
from typing import List, Optional, Union
6+
from typing import List, Optional, Tuple, Union
77

88
import av
99
import numpy as np
@@ -164,6 +164,58 @@ def iou(box1, box2):
164164
return 1.0 * intersection / union
165165

166166

167+
def calculate_resized_dimensions(
168+
original_size: Tuple[PositiveInt, PositiveInt],
169+
target_size: Union[PositiveInt, Tuple[PositiveInt, PositiveInt]],
170+
max_length: Optional[int] = None,
171+
divisible: PositiveInt = 1) -> Tuple[int, int]:
172+
"""
173+
Resize dimensions based on specified constraints.
174+
175+
:param original_size: The original dimensions as (height, width).
176+
:param target_size: Desired target size; can be a single integer
177+
(short edge) or a tuple (height, width).
178+
:param max_length: Maximum allowed length for the longer edge.
179+
:param divisible: The number that the dimensions must be divisible by.
180+
:return: Resized dimensions as (width, height); note the order differs from ``original_size``.
181+
"""
182+
183+
height, width = original_size
184+
short_edge, long_edge = sorted((width, height))
185+
186+
# Normalize target_size to a tuple
187+
if isinstance(target_size, int):
188+
target_size = (target_size, )
189+
190+
# Initialize new dimensions
191+
if target_size:
192+
if len(target_size) == 1: # Only the smaller edge is specified
193+
new_short_edge = target_size[0]
194+
new_long_edge = int(new_short_edge * long_edge / short_edge)
195+
else: # Both dimensions are specified
196+
new_short_edge = min(target_size)
197+
new_long_edge = max(target_size)
198+
else: # No change
199+
new_short_edge, new_long_edge = short_edge, long_edge
200+
201+
# Enforce maximum length constraint
202+
if max_length is not None and new_long_edge > max_length:
203+
scaling_factor = max_length / new_long_edge
204+
new_short_edge = int(new_short_edge * scaling_factor)
205+
new_long_edge = max_length
206+
207+
# Determine final dimensions based on original orientation
208+
resized_dimensions = ((new_short_edge,
209+
new_long_edge) if width <= height else
210+
(new_long_edge, new_short_edge))
211+
212+
# Ensure final dimensions are divisible by the specified value
213+
resized_dimensions = tuple(
214+
int(dim / divisible) * divisible for dim in resized_dimensions)
215+
216+
return resized_dimensions
217+
218+
167219
# Audios
168220
def load_audios(paths):
169221
return [load_audio(path) for path in paths]

docs/Operators.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The operators in Data-Juicer are categorized into 5 types.
1212
|-----------------------------------|:------:|-------------------------------------------------|
1313
| [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data |
1414
| [ Mapper ]( #mapper ) | 52 | Edits and transforms samples |
15-
| [ Filter ]( #filter ) | 43 | Filters out low-quality samples |
15+
| [ Filter ]( #filter ) | 44 | Filters out low-quality samples |
1616
| [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples |
1717
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |
1818

@@ -149,6 +149,7 @@ All the specific operators are listed below, each featured with several capabili
149149
| video_duration_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keep data samples whose videos' durations are within a specified range | [code](../data_juicer/ops/filter/video_duration_filter.py) | [tests](../tests/ops/filter/test_video_duration_filter.py) |
150150
| video_frames_text_similarity_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Keep data samples whose similarities between sampled video frame images and text are within a specific range | [code](../data_juicer/ops/filter/video_frames_text_similarity_filter.py) | [tests](../tests/ops/filter/test_video_frames_text_similarity_filter.py) |
151151
| video_motion_score_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keep samples with video motion scores within a specific range | [code](../data_juicer/ops/filter/video_motion_score_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_filter.py) |
152+
| video_motion_score_raft_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keep samples with video motion scores (based on RAFT model) within a specific range | [code](../data_juicer/ops/filter/video_motion_score_raft_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_raft_filter.py) |
152153
| video_nsfw_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Keeps samples containing videos with NSFW scores below the threshold | [code](../data_juicer/ops/filter/video_nsfw_filter.py) | [tests](../tests/ops/filter/test_video_nsfw_filter.py) |
153154
| video_ocr_area_ratio_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Keep data samples whose detected text area ratios for specified frames in the video are within a specified range | [code](../data_juicer/ops/filter/video_ocr_area_ratio_filter.py) | [tests](../tests/ops/filter/test_video_ocr_area_ratio_filter.py) |
154155
| video_resolution_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keeps samples containing videos with horizontal and vertical resolutions within the specified range | [code](../data_juicer/ops/filter/video_resolution_filter.py) | [tests](../tests/ops/filter/test_video_resolution_filter.py) |

0 commit comments

Comments
 (0)