Skip to content

Commit f27e8c7

Browse files
committed
support ray actor
1 parent 4b8b436 commit f27e8c7

File tree

4 files changed

+69
-8
lines changed

4 files changed

+69
-8
lines changed

data_juicer/core/ray_data.py

+48-8
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,55 @@ def _run_single_op(self, op):
110110
batch_size = getattr(op, 'batch_size',
111111
1) if op.is_batched_op() else 1
112112
if isinstance(op, Mapper):
113-
self.data = self.data.map_batches(op.process,
114-
batch_size=batch_size,
115-
batch_format='pyarrow',
116-
num_gpus=num_gpus)
113+
if op.use_ray_actor():
114+
# TODO: calculate concurrency automatically instead of falling back to 1
115+
concurrency = getattr(op, 'concurrency', 1)
116+
117+
init_params = op._init_parameters
118+
op_args = init_params.pop('args', ())
119+
op_kwargs = init_params.pop('kwargs', {})
120+
op_kwargs.update(init_params)
121+
122+
self.data = self.data.map_batches(
123+
op.__class__,
124+
fn_args=None,
125+
fn_kwargs=None,
126+
fn_constructor_args=op_args,
127+
fn_constructor_kwargs=op_kwargs,
128+
batch_size=batch_size,
129+
num_gpus=num_gpus,
130+
concurrency=concurrency,
131+
batch_format='pyarrow')
132+
else:
133+
self.data = self.data.map_batches(op.process,
134+
batch_size=batch_size,
135+
batch_format='pyarrow',
136+
num_gpus=num_gpus)
117137
elif isinstance(op, Filter):
118-
self.data = self.data.map_batches(op.compute_stats,
119-
batch_size=batch_size,
120-
batch_format='pyarrow',
121-
num_gpus=num_gpus)
138+
if op.use_ray_actor():
139+
# TODO: calculate concurrency automatically instead of falling back to 1
140+
concurrency = getattr(op, 'concurrency', 1)
141+
142+
init_params = op._init_parameters
143+
op_args = init_params.pop('args', ())
144+
op_kwargs = init_params.pop('kwargs', {})
145+
op_kwargs.update(init_params)
146+
147+
self.data = self.data.map_batches(
148+
op.__class__,
149+
fn_args=None,
150+
fn_kwargs=None,
151+
fn_constructor_args=op_args,
152+
fn_constructor_kwargs=op_kwargs,
153+
batch_size=batch_size,
154+
num_gpus=num_gpus,
155+
concurrency=concurrency,
156+
batch_format='pyarrow')
157+
else:
158+
self.data = self.data.map_batches(op.compute_stats,
159+
batch_size=batch_size,
160+
batch_format='pyarrow',
161+
num_gpus=num_gpus)
122162
if op.stats_export_path is not None:
123163
self.data.write_json(op.stats_export_path,
124164
force_ascii=False)

data_juicer/ops/base_op.py

+11
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class OP:
118118

119119
_accelerator = 'cpu'
120120
_batched_op = False
121+
_ray_mode = 'task'
121122

122123
def __init__(self, *args, **kwargs):
123124
"""
@@ -143,6 +144,7 @@ def __init__(self, *args, **kwargs):
143144
self.history_key = kwargs.get('history_key', 'history')
144145

145146
self.batch_size = kwargs.get('batch_size', 1000)
147+
self.concurrency = kwargs.get('concurrency', 1)
146148

147149
# whether the model can be accelerated using cuda
148150
_accelerator = kwargs.get('accelerator', None)
@@ -172,6 +174,9 @@ def __init__(self, *args, **kwargs):
172174
def is_batched_op(self):
173175
return self._batched_op
174176

177+
def use_ray_actor(self):
178+
return self._ray_mode == 'actor'
179+
175180
def process(self, *args, **kwargs):
176181
raise NotImplementedError
177182

@@ -255,6 +260,9 @@ def __init_subclass__(cls, **kwargs):
255260
f'{cls.__name__}. Please implement {method_name}_single '
256261
f'or {method_name}_batched.')
257262

263+
def __call__(self, *args, **kwargs):
264+
return self.process(*args, **kwargs)
265+
258266
def process_batched(self, samples, *args, **kwargs):
259267
keys = samples.keys()
260268
first_key = next(iter(keys))
@@ -330,6 +338,9 @@ def __init_subclass__(cls, **kwargs):
330338
f'{cls.__name__}. Please implement {method_name}_single '
331339
f'or {method_name}_batched.')
332340

341+
def __call__(self, *args, **kwargs):
342+
return self.compute_stats(*args, **kwargs)
343+
333344
def compute_stats_batched(self, samples, *args, **kwargs):
334345
keys = samples.keys()
335346
num_samples = len(samples[Fields.stats])

data_juicer/ops/filter/image_nsfw_filter.py

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class ImageNSFWFilter(Filter):
1919
"""Filter to keep samples whose images have low nsfw scores."""
2020

2121
_accelerator = 'cuda'
22+
_ray_mode = 'actor'
2223

2324
def __init__(self,
2425
hf_nsfw_model: str = 'Falconsai/nsfw_image_detection',
@@ -42,6 +43,8 @@ def __init__(self,
4243
:param kwargs: extra args
4344
"""
4445
super().__init__(*args, **kwargs)
46+
self._init_parameters = self.remove_extra_parameters(locals())
47+
4548
self.score_threshold = score_threshold
4649
if any_or_all not in ['any', 'all']:
4750
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '

data_juicer/ops/mapper/image_captioning_mapper.py

+7
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class ImageCaptioningMapper(Mapper):
3030

3131
_accelerator = 'cuda'
3232
_batched_op = True
33+
_ray_mode = 'actor'
3334

3435
def __init__(self,
3536
hf_img2seq: str = 'Salesforce/blip2-opt-2.7b',
@@ -82,6 +83,7 @@ def __init__(self,
8283
:param kwargs: extra args
8384
"""
8485
super().__init__(*args, **kwargs)
86+
self._init_parameters = self.remove_extra_parameters(locals())
8587

8688
if keep_candidate_mode not in [
8789
'random_any', 'similar_one_simhash', 'all'
@@ -282,6 +284,11 @@ def process_batched(self, samples, rank=None):
282284
:param samples:
283285
:return:
284286
"""
287+
import pyarrow as pa
288+
289+
if isinstance(samples, pa.Table):
290+
samples = samples.to_pydict()
291+
285292
# reconstruct samples from "dict of lists" to "list of dicts"
286293
reconstructed_samples = []
287294
for i in range(len(samples[self.text_key])):

0 commit comments

Comments
 (0)