Skip to content

Commit bba1f38

Browse files
committed
add _init_parameters for cuda op
1 parent 20fc9de commit bba1f38

25 files changed

+46
-3
lines changed

data_juicer/config/config.py

+5
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,11 @@ def init_setup_from_cfg(cfg: Namespace):
425425

426426
# check number of processes np
427427
sys_cpu_count = os.cpu_count()
428+
if not cfg.np:
429+
cfg.np = sys_cpu_count
430+
logger.warning(
431+
f'Number of processes `np` is not set, '
432+
f'Set it to cpu count [{sys_cpu_count}] as default value.')
428433
if cfg.np > sys_cpu_count:
429434
logger.warning(f'Number of processes `np` is set as [{cfg.np}], which '
430435
f'is larger than the cpu count [{sys_cpu_count}]. Due '

data_juicer/core/ray_data.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,14 @@ def _run_single_op(self, op):
119119
1) if op.is_batched_op() else 1
120120
if isinstance(op, Mapper):
121121
if op.use_cuda():
122-
init_params = op._init_parameters
122+
try:
123+
init_params = op._init_parameters
124+
except AttributeError:
125+
raise ValueError(
126+
f'This Op[{op._name}] enables CUDA, you should add'
127+
' `_init_parameters` attribute to the Op class by '
128+
'add `self._init_parameters = self.remove_extra_parameters(locals())`' # noqa: E501
129+
' after super().__init__().')
123130
op_args = init_params.pop('args', ())
124131
op_kwargs = init_params.pop('kwargs', {})
125132
op_kwargs.update(init_params)
@@ -141,7 +148,14 @@ def _run_single_op(self, op):
141148
num_gpus=num_gpus)
142149
elif isinstance(op, Filter):
143150
if op.use_cuda():
144-
init_params = op._init_parameters
151+
try:
152+
init_params = op._init_parameters
153+
except AttributeError:
154+
raise ValueError(
155+
f'This Op[{op._name}] enables CUDA, you should add'
156+
' `_init_parameters` attribute to the Op class by '
157+
'add `self._init_parameters = self.remove_extra_parameters(locals())`' # noqa: E501
158+
' after super().__init__().')
145159
op_args = init_params.pop('args', ())
146160
op_kwargs = init_params.pop('kwargs', {})
147161
op_kwargs.update(init_params)

data_juicer/ops/filter/image_aesthetics_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(self,
4848
"""
4949

5050
super().__init__(*args, **kwargs)
51+
self._init_parameters = self.remove_extra_parameters(locals())
5152
if hf_scorer_model == '':
5253
hf_scorer_model = \
5354
'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE'

data_juicer/ops/filter/image_pair_similarity_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self,
4444
:param kwargs: extra args
4545
"""
4646
super().__init__(*args, **kwargs)
47+
self._init_parameters = self.remove_extra_parameters(locals())
4748
self.min_score = min_score
4849
self.max_score = max_score
4950
if any_or_all not in ['any', 'all']:

data_juicer/ops/filter/image_text_matching_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(self,
5353
:param kwargs: extra args
5454
"""
5555
super().__init__(*args, **kwargs)
56+
self._init_parameters = self.remove_extra_parameters(locals())
5657
self.min_score = min_score
5758
self.max_score = max_score
5859
if reduce_mode not in ['avg', 'max', 'min']:

data_juicer/ops/filter/image_text_similarity_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self,
5454
:param kwargs: extra args
5555
"""
5656
super().__init__(*args, **kwargs)
57+
self._init_parameters = self.remove_extra_parameters(locals())
5758
self.min_score = min_score
5859
self.max_score = max_score
5960
if reduce_mode not in ['avg', 'max', 'min']:

data_juicer/ops/filter/image_watermark_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(self,
4646
:param kwargs: extra args
4747
"""
4848
super().__init__(*args, **kwargs)
49+
self._init_parameters = self.remove_extra_parameters(locals())
4950
self.prob_threshold = prob_threshold
5051
if any_or_all not in ['any', 'all']:
5152
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '

data_juicer/ops/filter/phrase_grounding_recall_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(self,
115115
:param kwargs: extra args
116116
"""
117117
super().__init__(*args, **kwargs)
118+
self._init_parameters = self.remove_extra_parameters(locals())
118119
self.min_recall = min_recall
119120
self.max_recall = max_recall
120121
if reduce_mode not in ['avg', 'max', 'min']:

data_juicer/ops/filter/video_aesthetics_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(self,
7575
"""
7676

7777
super().__init__(*args, **kwargs)
78+
self._init_parameters = self.remove_extra_parameters(locals())
7879
if hf_scorer_model == '':
7980
hf_scorer_model = \
8081
'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE'

data_juicer/ops/filter/video_frames_text_similarity_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(self,
7575
:param kwargs: extra args
7676
"""
7777
super().__init__(*args, **kwargs)
78+
self._init_parameters = self.remove_extra_parameters(locals())
7879
self.min_score = min_score
7980
self.max_score = max_score
8081
if frame_sampling_method not in ['all_keyframes', 'uniform']:

data_juicer/ops/filter/video_motion_score_filter.py

+2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(self,
8282
:param kwargs: extra args
8383
"""
8484
super().__init__(*args, **kwargs)
85+
self._init_parameters = self.remove_extra_parameters(locals())
86+
8587
self.min_score = min_score
8688
self.max_score = max_score
8789
self.sampling_fps = sampling_fps

data_juicer/ops/filter/video_nsfw_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(self,
6666
:param kwargs: extra args
6767
"""
6868
super().__init__(*args, **kwargs)
69+
self._init_parameters = self.remove_extra_parameters(locals())
6970
self.score_threshold = score_threshold
7071
if frame_sampling_method not in ['all_keyframes', 'uniform']:
7172
raise ValueError(

data_juicer/ops/filter/video_ocr_area_ratio_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(self,
7272
:param kwargs: extra args
7373
"""
7474
super().__init__(*args, **kwargs)
75+
self._init_parameters = self.remove_extra_parameters(locals())
7576
self.min_area_ratio = min_area_ratio
7677
self.max_area_ratio = max_area_ratio
7778
self.frame_sample_num = frame_sample_num

data_juicer/ops/filter/video_tagging_from_frames_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(self,
6262
:param kwargs: extra args
6363
"""
6464
super().__init__(*args, **kwargs)
65+
self._init_parameters = self.remove_extra_parameters(locals())
6566
if contain not in ['any', 'all']:
6667
raise ValueError(f'the containing type [{contain}] is not '
6768
f'supported. Can only be one of ["any", "all"].')

data_juicer/ops/filter/video_watermark_filter.py

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(self,
7070
:param kwargs: extra args
7171
"""
7272
super().__init__(*args, **kwargs)
73+
self._init_parameters = self.remove_extra_parameters(locals())
7374
self.prob_threshold = prob_threshold
7475
if frame_sampling_method not in ['all_keyframes', 'uniform']:
7576
raise ValueError(

data_juicer/ops/mapper/generate_qa_from_examples_mapper.py

+1
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(self,
9696
:param kwargs: Extra keyword arguments.
9797
"""
9898
super().__init__(**kwargs)
99+
self._init_parameters = self.remove_extra_parameters(locals())
99100

100101
if not seed_file:
101102
raise ValueError(

data_juicer/ops/mapper/generate_qa_from_text_mapper.py

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(self,
6969
"""
7070

7171
super().__init__(**kwargs)
72+
self._init_parameters = self.remove_extra_parameters(locals())
7273

7374
if output_pattern is None:
7475
self.output_pattern = r'Human:(.*?)Assistant:(.*?)(?=Human|$)' # noqa: E501

data_juicer/ops/mapper/image_tagging_mapper.py

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(self,
3737
:param kwargs: extra args
3838
"""
3939
super().__init__(*args, **kwargs)
40+
self._init_parameters = self.remove_extra_parameters(locals())
4041
self.model_key = prepare_model(
4142
model_type='recognizeAnything',
4243
pretrained_model_name_or_path='ram_plus_swin_large_14m.pth',

data_juicer/ops/mapper/optimize_qa_mapper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self,
6666
:param kwargs: Extra keyword arguments.
6767
"""
6868
super().__init__(**kwargs)
69-
69+
self._init_parameters = self.remove_extra_parameters(locals())
7070
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
7171
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
7272
self.qa_pair_template = qa_pair_template or \

data_juicer/ops/mapper/video_captioning_from_audio_mapper.py

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(self, keep_original_sample: bool = True, *args, **kwargs):
3333
:param kwargs: extra args
3434
"""
3535
super().__init__(*args, **kwargs)
36+
self._init_parameters = self.remove_extra_parameters(locals())
3637
AUTOINSTALL.check([
3738
'transformers', 'transformers_stream_generator', 'einops',
3839
'accelerate', 'tiktoken'

data_juicer/ops/mapper/video_captioning_from_frames_mapper.py

+1
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def __init__(
109109
:param kwargs: extra args
110110
"""
111111
super().__init__(*args, **kwargs)
112+
self._init_parameters = self.remove_extra_parameters(locals())
112113

113114
if keep_candidate_mode not in [
114115
'random_any', 'similar_one_simhash', 'all'

data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def __init__(self,
8282
:param kwargs: extra args
8383
"""
8484
super().__init__(*args, **kwargs)
85+
self._init_parameters = self.remove_extra_parameters(locals())
8586
AUTOINSTALL.check([
8687
'torch',
8788
'transformers',

data_juicer/ops/mapper/video_captioning_from_video_mapper.py

+1
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def __init__(
109109
:param kwargs: extra args
110110
"""
111111
super().__init__(*args, **kwargs)
112+
self._init_parameters = self.remove_extra_parameters(locals())
112113

113114
if keep_candidate_mode not in [
114115
'random_any', 'similar_one_simhash', 'all'

data_juicer/ops/mapper/video_tagging_from_audio_mapper.py

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(self,
3838
:param kwargs: extra args
3939
"""
4040
super().__init__(*args, **kwargs)
41+
self._init_parameters = self.remove_extra_parameters(locals())
4142
AUTOINSTALL.check(['torchaudio'])
4243
self.model_key = prepare_model(model_type='huggingface',
4344
pretrained_model_name_or_path=hf_ast,

data_juicer/ops/mapper/video_tagging_from_frames_mapper.py

+2
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def __init__(self,
5656
:param kwargs: extra args
5757
"""
5858
super().__init__(*args, **kwargs)
59+
self._init_parameters = self.remove_extra_parameters(locals())
60+
5961
if frame_sampling_method not in ['all_keyframes', 'uniform']:
6062
raise ValueError(
6163
f'Frame sampling method [{frame_sampling_method}] is not '

0 commit comments

Comments
 (0)