@@ -231,23 +231,22 @@ def transform(self, X):
231
231
"""
232
232
self ._check_method ("transform" )
233
233
X = self ._check_array (X )
234
- meta = self .transform_meta
234
+ output_meta = self .transform_meta
235
235
236
236
if isinstance (X , da .Array ):
237
- if meta is None :
238
- meta = _get_output_dask_ar_meta_for_estimator (
237
+ if output_meta is None :
238
+ output_meta = _get_output_dask_ar_meta_for_estimator (
239
239
_transform , self ._postfit_estimator , X
240
240
)
241
241
return X .map_blocks (
242
- _transform , estimator = self ._postfit_estimator , meta = meta
242
+ _transform , estimator = self ._postfit_estimator , meta = output_meta
243
243
)
244
244
elif isinstance (X , dd ._Frame ):
245
- if meta is None :
246
- # dask-dataframe relies on dd.core.no_default
247
- # for infering meta
248
- meta = dd .core .no_default
249
- return X .map_partitions (
250
- _transform , estimator = self ._postfit_estimator , meta = meta
245
+ return _get_output_df_for_estimator (
246
+ model_fn = _transform ,
247
+ X = X ,
248
+ output_meta = output_meta ,
249
+ estimator = self ._postfit_estimator ,
251
250
)
252
251
else :
253
252
return _transform (X , estimator = self ._postfit_estimator )
@@ -311,25 +310,30 @@ def predict(self, X):
311
310
"""
312
311
self ._check_method ("predict" )
313
312
X = self ._check_array (X )
314
- meta = self .predict_meta
313
+ output_meta = self .predict_meta
315
314
316
315
if isinstance (X , da .Array ):
317
- if meta is None :
318
- meta = _get_output_dask_ar_meta_for_estimator (
316
+ if output_meta is None :
317
+ output_meta = _get_output_dask_ar_meta_for_estimator (
319
318
_predict , self ._postfit_estimator , X
320
319
)
321
320
322
321
result = X .map_blocks (
323
- _predict , estimator = self ._postfit_estimator , drop_axis = 1 , meta = meta
322
+ _predict ,
323
+ estimator = self ._postfit_estimator ,
324
+ drop_axis = 1 ,
325
+ meta = output_meta ,
324
326
)
325
327
return result
326
328
327
329
elif isinstance (X , dd ._Frame ):
328
- if meta is None :
329
- meta = dd .core .no_default
330
- return X .map_partitions (
331
- _predict , estimator = self ._postfit_estimator , meta = meta
330
+ return _get_output_df_for_estimator (
331
+ model_fn = _predict ,
332
+ X = X ,
333
+ output_meta = output_meta ,
334
+ estimator = self ._postfit_estimator ,
332
335
)
336
+
333
337
else :
334
338
return _predict (X , estimator = self ._postfit_estimator )
335
339
@@ -355,25 +359,26 @@ def predict_proba(self, X):
355
359
356
360
self ._check_method ("predict_proba" )
357
361
358
- meta = self .predict_proba_meta
362
+ output_meta = self .predict_proba_meta
359
363
360
364
if isinstance (X , da .Array ):
361
- if meta is None :
362
- meta = _get_output_dask_ar_meta_for_estimator (
365
+ if output_meta is None :
366
+ output_meta = _get_output_dask_ar_meta_for_estimator (
363
367
_predict_proba , self ._postfit_estimator , X
364
368
)
365
369
# XXX: multiclass
366
370
return X .map_blocks (
367
371
_predict_proba ,
368
372
estimator = self ._postfit_estimator ,
369
- meta = meta ,
373
+ meta = output_meta ,
370
374
chunks = (X .chunks [0 ], len (self ._postfit_estimator .classes_ )),
371
375
)
372
376
elif isinstance (X , dd ._Frame ):
373
- if meta is None :
374
- meta = dd .core .no_default
375
- return X .map_partitions (
376
- _predict_proba , estimator = self ._postfit_estimator , meta = meta
377
+ return _get_output_df_for_estimator (
378
+ model_fn = _predict_proba ,
379
+ X = X ,
380
+ output_meta = output_meta ,
381
+ estimator = self ._postfit_estimator ,
377
382
)
378
383
else :
379
384
return _predict_proba (X , estimator = self ._postfit_estimator )
@@ -626,18 +631,63 @@ def _first_block(dask_object):
626
631
return dask_object
627
632
628
633
629
- def _predict (part , estimator ):
634
+ def _predict (part , estimator , output_meta = None ):
635
+ if part .shape [0 ] == 0 and output_meta is not None :
636
+ empty_output = handle_empty_partitions (output_meta )
637
+ if empty_output is not None :
638
+ return empty_output
630
639
return estimator .predict (part )
631
640
632
641
633
- def _predict_proba (part , estimator ):
642
+ def _predict_proba (part , estimator , output_meta = None ):
643
+ if part .shape [0 ] == 0 and output_meta is not None :
644
+ empty_output = handle_empty_partitions (output_meta )
645
+ if empty_output is not None :
646
+ return empty_output
647
+
634
648
return estimator .predict_proba (part )
635
649
636
650
637
- def _transform (part , estimator ):
651
+ def _transform (part , estimator , output_meta = None ):
652
+ if part .shape [0 ] == 0 and output_meta is not None :
653
+ empty_output = handle_empty_partitions (output_meta )
654
+ if empty_output is not None :
655
+ return empty_output
656
+
638
657
return estimator .transform (part )
639
658
640
659
660
+ def handle_empty_partitions (output_meta ):
661
+ if hasattr (output_meta , "__array_function__" ):
662
+ if len (output_meta .shape ) == 1 :
663
+ shape = 0
664
+ else :
665
+ shape = list (output_meta .shape )
666
+ shape [0 ] = 0
667
+ ar = np .zeros (
668
+ shape = shape ,
669
+ dtype = output_meta .dtype ,
670
+ like = output_meta ,
671
+ )
672
+ return ar
673
+ elif "scipy.sparse" in type (output_meta ).__module__ :
674
+ # sparse matrices dont support
675
+ # `like` due to non implimented __array_function__
676
+ # Refer https://github.com/scipy/scipy/issues/10362
677
+ # Note below works for both cupy and scipy sparse matrices
678
+ # TODO: REMOVE code duplication
679
+ if len (ar .shape ) == 1 :
680
+ shape = 0
681
+ else :
682
+ shape = list (ar .shape )
683
+ shape [0 ] = 0
684
+
685
+ ar = type (output_meta )(shape , dtype = output_meta .dtype )
686
+ return ar
687
+ elif hasattr (output_meta , "iloc" ):
688
+ return output_meta .iloc [:0 , :]
689
+
690
+
641
691
def _get_output_dask_ar_meta_for_estimator (model_fn , estimator , input_dask_ar ):
642
692
"""
643
693
Returns the output metadata array
@@ -692,3 +742,12 @@ def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):
692
742
warnings .warn (msg )
693
743
ar = np .zeros (shape = (1 , input_dask_ar .shape [1 ]), dtype = input_dask_ar .dtype )
694
744
return model_fn (ar , estimator )
745
+
746
+
747
+ def _get_output_df_for_estimator (model_fn , X , output_meta , estimator ):
748
+ if output_meta is None :
749
+ # dask-dataframe relies on dd.core.no_default
750
+ # for infering meta
751
+ output_meta = model_fn (X ._meta_nonempty , estimator )
752
+
753
+ return X .map_partitions (model_fn , estimator , output_meta , meta = output_meta )
0 commit comments