Skip to content

Commit a002f4c

Browse files
committed
Merge branch 'main' into ray_actor
2 parents f27e8c7 + 46062f8 commit a002f4c

20 files changed

+852
-226
lines changed

README.md

+16
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,22 @@ The dependency options are listed below:
197197
| `.[tools]` | Install dependencies for dedicated tools, such as quality classifiers. |
198198
| `.[sandbox]` | Install all dependencies for sandbox. |
199199

200+
- Install dependencies for specific OPs
201+
202+
As the number of OPs grows, the dependencies of all OPs become very heavy. Instead of using the command `pip install -v -e .[sci]` to install all dependencies,
203+
we provide two alternative, lighter options:
204+
205+
- Automatic Minimal Dependency Installation: During the execution of Data-Juicer, minimal dependencies will be automatically installed. This allows for immediate execution, but may lead to dependency conflicts.
206+
207+
- Manual Minimal Dependency Installation: To manually install minimal dependencies tailored to a specific execution configuration, run the following command:
208+
```shell
209+
# only for installation from source
210+
python tools/dj_install.py --config path_to_your_data-juicer_config_file
211+
212+
# use command line tool
213+
dj-install --config path_to_your_data-juicer_config_file
214+
```
215+
200216
### Using pip
201217

202218
- Run the following command to install the latest released `data_juicer` using `pip`:

README_ZH.md

+15
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,21 @@ pip install -v -e .[tools] # 安装部分工具库的依赖
178178
| `.[tools]` | 安装专用工具库(如质量分类器)所需的依赖项 |
179179
| `.[sandbox]` | 安装沙盒实验室的基础依赖 |
180180

181+
* 只安装部分算子依赖
182+
183+
随着OP数量的增长,所有OP的依赖变得很重。为此,我们提供了两个替代的、更轻量的选项,作为使用命令`pip install -v -e .[sci]`安装所有依赖的替代:
184+
185+
* 自动最小依赖安装:在执行Data-Juicer的过程中,将自动安装最小依赖。也就是说你可以直接执行,但这种方式可能会导致一些依赖冲突。
186+
187+
* 手动最小依赖安装:可以通过如下指令手动安装适合特定执行配置的最小依赖:
188+
```shell
189+
# 适用于从源码安装
190+
python tools/dj_install.py --config path_to_your_data-juicer_config_file
191+
192+
# 使用命令行工具
193+
dj-install --config path_to_your_data-juicer_config_file
194+
```
195+
181196
### 使用 pip 安装
182197

183198
* 运行以下命令用 `pip` 安装 `data_juicer` 的最新发布版本:

configs/config_all.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,11 @@ process:
341341
horizontal_flip: false # flip frame image horizontally (left to right).
342342
vertical_flip: false # flip frame image vertically (top to bottom).
343343
mem_required: '20GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
344+
- video_extract_frames_mapper: # extract frames from video files according to specified methods
345+
frame_sampling_method: 'all_keyframes' # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former extracts all key frames and the latter extracts a specified number of frames uniformly from the video. Default: "all_keyframes".
346+
frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration.
347+
duration: 0 # The duration of each segment in seconds. If 0, frames are extracted from the entire video. If duration > 0, the video is segmented into multiple segments based on duration, and frames are extracted from each segment.
348+
frame_dir: None # Output directory to save extracted frames. If None, a default directory based on the video file path is used.
344349
- video_face_blur_mapper: # blur faces detected in videos
345350
cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'.
346351
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']

data_juicer/core/ray_data.py

+45-29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from functools import partial
23

34
import pyarrow as pa
45
from loguru import logger
@@ -13,28 +14,26 @@
1314
rd = LazyLoader('rd', 'ray.data')
1415

1516

16-
def is_valid_path(item, dataset_dir):
17-
full_path = os.path.abspath(os.path.join(dataset_dir, item))
18-
return os.path.exists(full_path)
17+
def get_abs_path(path, dataset_dir):
18+
full_path = os.path.abspath(os.path.join(dataset_dir, path))
19+
if os.path.exists(full_path):
20+
return full_path
21+
else:
22+
return path
1923

2024

21-
def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys):
25+
def convert_to_absolute_paths(samples, dataset_dir, path_keys):
26+
samples = samples.to_pydict()
2227
for key in path_keys:
23-
if key not in dict_with_paths:
24-
continue
25-
if isinstance(dict_with_paths[key], list):
26-
dict_with_paths[key] = [
27-
os.path.abspath(os.path.join(dataset_dir, item))
28-
if isinstance(item, str) and is_valid_path(dataset_dir, item)
29-
else item for item in dict_with_paths[key]
30-
]
31-
elif isinstance(dict_with_paths[key], str):
32-
dict_with_paths[key] = os.path.abspath(
33-
os.path.join(dataset_dir,
34-
dict_with_paths[key])) if is_valid_path(
35-
dict_with_paths[key],
36-
dataset_dir) else dict_with_paths[key]
37-
return dict_with_paths
28+
for idx in range(len(samples[key])):
29+
paths = samples[key][idx]
30+
if isinstance(paths, str):
31+
samples[key][idx] = get_abs_path(paths, dataset_dir)
32+
elif isinstance(paths, list):
33+
samples[key][idx] = [
34+
get_abs_path(item, dataset_dir) for item in paths
35+
]
36+
return pa.Table.from_pydict(samples)
3837

3938

4039
# TODO: check path for nestdataset
@@ -43,22 +42,26 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg):
4342
Set all the path in input data to absolute path.
4443
Checks dataset_dir and project_dir for valid paths.
4544
"""
46-
if not (cfg.video_key in dataset.columns() or cfg.image_key
47-
in dataset.columns() or cfg.audio_key in dataset.columns()):
48-
return dataset
49-
dataset_dir = os.path.dirname(dataset_path)
50-
dataset = dataset.map(lambda item: convert_to_absolute_paths(
51-
item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key]))
52-
logger.info(f"transfer {dataset.count()} sample's paths")
45+
path_keys = []
46+
columns = dataset.columns()
47+
for key in [cfg.video_key, cfg.image_key, cfg.audio_key]:
48+
if key in columns:
49+
path_keys.append(key)
50+
if len(path_keys) > 0:
51+
dataset_dir = os.path.dirname(dataset_path)
52+
dataset = dataset.map_batches(partial(convert_to_absolute_paths,
53+
dataset_dir=dataset_dir,
54+
path_keys=path_keys),
55+
batch_format='pyarrow',
56+
zero_copy_batch=True)
5357
return dataset
5458

5559

5660
def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
61+
columns = dataset.columns()
5762
if dataset_path:
5863
dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
59-
columns = dataset.columns()
6064
if Fields.stats not in columns:
61-
logger.info(f'columns {columns}')
6265

6366
def process_batch_arrow(table: pa.Table) -> pa.Table:
6467
new_column_data = [{} for _ in range(len(table))]
@@ -77,6 +80,11 @@ def get_num_gpus(op, op_proc):
7780
return 1.0 / proc_per_gpu
7881

7982

83+
def filter_batch(batch, filter_func):
84+
mask = pa.array(filter_func(batch.to_pydict()))
85+
return batch.filter(mask)
86+
87+
8088
class RayDataset(DJDataset):
8189

8290
def __init__(self,
@@ -162,7 +170,15 @@ def _run_single_op(self, op):
162170
if op.stats_export_path is not None:
163171
self.data.write_json(op.stats_export_path,
164172
force_ascii=False)
165-
self.data = self.data.filter(op.process)
173+
if op.is_batched_op():
174+
self.data = self.data.map_batches(partial(
175+
filter_batch, filter_func=op.process),
176+
batch_format='pyarrow',
177+
batch_size=batch_size,
178+
num_gpus=num_gpus,
179+
zero_copy_batch=True)
180+
else:
181+
self.data = self.data.filter(op.process)
166182
else:
167183
logger.error(
168184
'Ray executor only support Filter and Mapper OPs for now')

data_juicer/ops/base_op.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -267,11 +267,22 @@ def process_batched(self, samples, *args, **kwargs):
267267
keys = samples.keys()
268268
first_key = next(iter(keys))
269269
num_samples = len(samples[first_key])
270+
271+
new_keys = {}
270272
for i in range(num_samples):
271273
this_sample = {key: samples[key][i] for key in keys}
272274
res_sample = self.process_single(this_sample, *args, **kwargs)
273-
for key in keys:
274-
samples[key][i] = res_sample[key]
275+
res_keys = res_sample.keys()
276+
for key in res_keys:
277+
if key not in keys:
278+
if key not in new_keys:
279+
new_keys.update({key: []})
280+
new_keys[key].append(res_sample[key])
281+
else:
282+
samples[key][i] = res_sample[key]
283+
284+
for k, v in new_keys.items():
285+
samples[k] = v
275286

276287
return samples
277288

data_juicer/ops/filter/flagged_words_filter.py

+56-48
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class FlaggedWordFilter(Filter):
2424
"""Filter to keep samples with flagged-word ratio less than a specific max
2525
value."""
2626

27+
_batched_op = True
28+
2729
def __init__(self,
2830
lang: str = 'en',
2931
tokenization: bool = False,
@@ -72,53 +74,59 @@ def __init__(self,
7274
self.model_key = prepare_model(model_type='sentencepiece',
7375
lang=lang)
7476

75-
def compute_stats_single(self, sample, context=False):
77+
def compute_stats_batched(self, samples, context=False):
7678
# check if it's computed already
77-
if StatsKeys.flagged_words_ratio in sample[Fields.stats]:
78-
return sample
79-
80-
# try to get words from context
79+
samples_list = samples[self.text_key]
80+
samples_stats = samples[Fields.stats]
8181
words_key = f'{InterVars.words}-{self.model_key}'
82-
if context and words_key in sample[Fields.context]:
83-
words = sample[Fields.context][words_key]
84-
else:
85-
tokenizer = get_model(self.model_key)
86-
words = get_words_from_document(
87-
sample[self.text_key],
88-
token_func=tokenizer.encode_as_pieces if tokenizer else None)
89-
if context:
90-
sample[Fields.context][words_key] = words
91-
92-
# try to get refined words from context
93-
refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \
94-
f'{self.use_words_aug}-' \
95-
f'{self.words_aug_group_sizes}-' \
96-
f'{self.words_aug_join_char}'
97-
if context and refined_words_key in sample[Fields.context]:
98-
words = sample[Fields.context][refined_words_key]
99-
else:
100-
words = words_refinement(
101-
words,
102-
lower_case=True,
103-
strip_chars=SPECIAL_CHARACTERS,
104-
use_words_aug=self.use_words_aug,
105-
words_aug_group_sizes=self.words_aug_group_sizes,
106-
words_aug_join_char=self.words_aug_join_char)
107-
if context:
108-
sample[Fields.context][refined_words_key] = words
109-
110-
flagged_words_ratio = (len(
111-
[word
112-
for word in words if word in self.FLAGGED_WORDS[self.lang]]) /
113-
len(words)) if len(words) != 0 else 0.0
114-
115-
if flagged_words_ratio > 1.0:
116-
flagged_words_ratio = 1.0
117-
118-
sample[Fields.stats][
119-
StatsKeys.flagged_words_ratio] = flagged_words_ratio
120-
return sample
121-
122-
def process_single(self, sample):
123-
return sample[Fields.stats][
124-
StatsKeys.flagged_words_ratio] <= self.max_ratio
82+
tokenizer = get_model(self.model_key)
83+
for idx, stat in enumerate(samples_stats):
84+
if StatsKeys.flagged_words_ratio in stat:
85+
continue
86+
if context and words_key in samples[Fields.context][idx]:
87+
words = samples[Fields.context][idx][words_key]
88+
else:
89+
words = get_words_from_document(
90+
samples_list[idx],
91+
token_func=tokenizer.encode_as_pieces
92+
if tokenizer else None)
93+
if context:
94+
samples[Fields.context][idx][words_key] = words
95+
# try to get refined words from context
96+
refined_words_key = f'{InterVars.refined_words}' \
97+
'-True-SPECIAL_CHARS-' \
98+
f'{self.use_words_aug}-' \
99+
f'{self.words_aug_group_sizes}-' \
100+
f'{self.words_aug_join_char}'
101+
if context and refined_words_key in samples[Fields.context][idx]:
102+
words = samples[Fields.context][idx][refined_words_key]
103+
else:
104+
words = words_refinement(
105+
words,
106+
lower_case=True,
107+
strip_chars=SPECIAL_CHARACTERS,
108+
use_words_aug=self.use_words_aug,
109+
words_aug_group_sizes=self.words_aug_group_sizes,
110+
words_aug_join_char=self.words_aug_join_char)
111+
if context:
112+
samples[Fields.context][idx][refined_words_key] = words
113+
114+
flagged_words_ratio = (len([
115+
word for word in words if word in self.FLAGGED_WORDS[self.lang]
116+
]) / len(words)) if len(words) != 0 else 0.0
117+
118+
if flagged_words_ratio > 1.0:
119+
flagged_words_ratio = 1.0
120+
121+
samples_stats[idx][
122+
StatsKeys.flagged_words_ratio] = flagged_words_ratio
123+
124+
return samples
125+
126+
def process_batched(self, samples):
127+
return list(
128+
map(
129+
lambda stat: stat[StatsKeys.flagged_words_ratio] <= self.
130+
max_ratio,
131+
samples[Fields.stats],
132+
))

0 commit comments

Comments
 (0)