Skip to content

Commit 28b97e0

Browse files
committed
first pass at fixing empty partition failures
1 parent 1e811ce commit 28b97e0

File tree

2 files changed

+107
-31
lines changed

2 files changed

+107
-31
lines changed

dask_ml/wrappers.py

Lines changed: 88 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -231,23 +231,22 @@ def transform(self, X):
231231
"""
232232
self._check_method("transform")
233233
X = self._check_array(X)
234-
meta = self.transform_meta
234+
output_meta = self.transform_meta
235235

236236
if isinstance(X, da.Array):
237-
if meta is None:
238-
meta = _get_output_dask_ar_meta_for_estimator(
237+
if output_meta is None:
238+
output_meta = _get_output_dask_ar_meta_for_estimator(
239239
_transform, self._postfit_estimator, X
240240
)
241241
return X.map_blocks(
242-
_transform, estimator=self._postfit_estimator, meta=meta
242+
_transform, estimator=self._postfit_estimator, meta=output_meta
243243
)
244244
elif isinstance(X, dd._Frame):
245-
if meta is None:
246-
# dask-dataframe relies on dd.core.no_default
247-
# for infering meta
248-
meta = dd.core.no_default
249-
return X.map_partitions(
250-
_transform, estimator=self._postfit_estimator, meta=meta
245+
return _get_output_df_for_estimator(
246+
model_fn=_transform,
247+
X=X,
248+
output_meta=output_meta,
249+
estimator=self._postfit_estimator,
251250
)
252251
else:
253252
return _transform(X, estimator=self._postfit_estimator)
@@ -311,25 +310,30 @@ def predict(self, X):
311310
"""
312311
self._check_method("predict")
313312
X = self._check_array(X)
314-
meta = self.predict_meta
313+
output_meta = self.predict_meta
315314

316315
if isinstance(X, da.Array):
317-
if meta is None:
318-
meta = _get_output_dask_ar_meta_for_estimator(
316+
if output_meta is None:
317+
output_meta = _get_output_dask_ar_meta_for_estimator(
319318
_predict, self._postfit_estimator, X
320319
)
321320

322321
result = X.map_blocks(
323-
_predict, estimator=self._postfit_estimator, drop_axis=1, meta=meta
322+
_predict,
323+
estimator=self._postfit_estimator,
324+
drop_axis=1,
325+
meta=output_meta,
324326
)
325327
return result
326328

327329
elif isinstance(X, dd._Frame):
328-
if meta is None:
329-
meta = dd.core.no_default
330-
return X.map_partitions(
331-
_predict, estimator=self._postfit_estimator, meta=meta
330+
return _get_output_df_for_estimator(
331+
model_fn=_predict,
332+
X=X,
333+
output_meta=output_meta,
334+
estimator=self._postfit_estimator,
332335
)
336+
333337
else:
334338
return _predict(X, estimator=self._postfit_estimator)
335339

@@ -355,25 +359,26 @@ def predict_proba(self, X):
355359

356360
self._check_method("predict_proba")
357361

358-
meta = self.predict_proba_meta
362+
output_meta = self.predict_proba_meta
359363

360364
if isinstance(X, da.Array):
361-
if meta is None:
362-
meta = _get_output_dask_ar_meta_for_estimator(
365+
if output_meta is None:
366+
output_meta = _get_output_dask_ar_meta_for_estimator(
363367
_predict_proba, self._postfit_estimator, X
364368
)
365369
# XXX: multiclass
366370
return X.map_blocks(
367371
_predict_proba,
368372
estimator=self._postfit_estimator,
369-
meta=meta,
373+
meta=output_meta,
370374
chunks=(X.chunks[0], len(self._postfit_estimator.classes_)),
371375
)
372376
elif isinstance(X, dd._Frame):
373-
if meta is None:
374-
meta = dd.core.no_default
375-
return X.map_partitions(
376-
_predict_proba, estimator=self._postfit_estimator, meta=meta
377+
return _get_output_df_for_estimator(
378+
model_fn=_predict_proba,
379+
X=X,
380+
output_meta=output_meta,
381+
estimator=self._postfit_estimator,
377382
)
378383
else:
379384
return _predict_proba(X, estimator=self._postfit_estimator)
@@ -626,18 +631,63 @@ def _first_block(dask_object):
626631
return dask_object
627632

628633

629-
def _predict(part, estimator):
634+
def _predict(part, estimator, output_meta=None):
635+
if part.shape[0] == 0 and output_meta is not None:
636+
empty_output = handle_empty_partitions(output_meta)
637+
if empty_output is not None:
638+
return empty_output
630639
return estimator.predict(part)
631640

632641

633-
def _predict_proba(part, estimator):
642+
def _predict_proba(part, estimator, output_meta=None):
643+
if part.shape[0] == 0 and output_meta is not None:
644+
empty_output = handle_empty_partitions(output_meta)
645+
if empty_output is not None:
646+
return empty_output
647+
634648
return estimator.predict_proba(part)
635649

636650

637-
def _transform(part, estimator):
651+
def _transform(part, estimator, output_meta=None):
652+
if part.shape[0] == 0 and output_meta is not None:
653+
empty_output = handle_empty_partitions(output_meta)
654+
if empty_output is not None:
655+
return empty_output
656+
638657
return estimator.transform(part)
639658

640659

660+
def handle_empty_partitions(output_meta):
661+
if hasattr(output_meta, "__array_function__"):
662+
if len(output_meta.shape) == 1:
663+
shape = 0
664+
else:
665+
shape = list(output_meta.shape)
666+
shape[0] = 0
667+
ar = np.zeros(
668+
shape=shape,
669+
dtype=output_meta.dtype,
670+
like=output_meta,
671+
)
672+
return ar
673+
elif "scipy.sparse" in type(output_meta).__module__:
674+
# sparse matrices dont support
675+
# `like` due to non implimented __array_function__
676+
# Refer https://github.com/scipy/scipy/issues/10362
677+
# Note below works for both cupy and scipy sparse matrices
678+
# TODO: REMOVE code duplication
679+
if len(ar.shape) == 1:
680+
shape = 0
681+
else:
682+
shape = list(ar.shape)
683+
shape[0] = 0
684+
685+
ar = type(output_meta)(shape, dtype=output_meta.dtype)
686+
return ar
687+
elif hasattr(output_meta, "iloc"):
688+
return output_meta.iloc[:0, :]
689+
690+
641691
def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):
642692
"""
643693
Returns the output metadata array
@@ -692,3 +742,12 @@ def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):
692742
warnings.warn(msg)
693743
ar = np.zeros(shape=(1, input_dask_ar.shape[1]), dtype=input_dask_ar.dtype)
694744
return model_fn(ar, estimator)
745+
746+
747+
def _get_output_df_for_estimator(model_fn, X, output_meta, estimator):
748+
if output_meta is None:
749+
# dask-dataframe relies on dd.core.no_default
750+
# for infering meta
751+
output_meta = model_fn(X._meta_nonempty, estimator)
752+
753+
return X.map_partitions(model_fn, estimator, output_meta, meta=output_meta)

tests/test_parallel_post_fit.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def test_predict_meta_override():
6666
# Failure when not proving predict_meta
6767
# because of value dependent model
6868
wrap = ParallelPostFit(base)
69-
with pytest.raises(ValueError):
69+
# TODO: Fix
70+
with pytest.raises(IndexError):
7071
wrap.predict(dd_X)
7172

7273
# Success when providing meta over-ride
@@ -89,7 +90,8 @@ def test_predict_proba_meta_override():
8990
# Failure when not proving predict_proba_meta
9091
# because of value dependent model
9192
wrap = ParallelPostFit(base)
92-
with pytest.raises(ValueError):
93+
# TODO: Fix below
94+
with pytest.raises(IndexError):
9395
wrap.predict_proba(dd_X)
9496

9597
# Success when providing meta over-ride
@@ -289,3 +291,18 @@ def shape(self):
289291
match="provide explicit `predict_proba_meta` to the dask_ml.wrapper",
290292
):
291293
clf.predict_proba(fake_dask_ar)
294+
295+
296+
def test_predict_empty_partitions():
297+
df = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8], "y": [True, False] * 4})
298+
ddf = dd.from_pandas(df, npartitions=4)
299+
300+
clf = ParallelPostFit(LogisticRegression())
301+
clf = clf.fit(df[["x"]], df["y"])
302+
303+
ddf_with_empty_part = ddf[ddf.x < 5][["x"]]
304+
result = clf.predict(ddf_with_empty_part).compute()
305+
306+
expected = clf.estimator.predict(ddf_with_empty_part.compute())
307+
308+
assert_eq_ar(result, expected)

0 commit comments

Comments
 (0)