33import numpy as np
44import scipy .sparse as sps
55from tick .preprocessing .base import LongitudinalPreprocessor
6- from .build .preprocessing import LongitudinalFeaturesLagger \
6+ from tick . preprocessing .build .preprocessing import LongitudinalFeaturesLagger \
77 as _LongitudinalFeaturesLagger
8- from .utils import check_longitudinal_features_consistency ,\
8+ from tick . preprocessing .utils import check_longitudinal_features_consistency ,\
99 check_censoring_consistency
1010from multiprocessing .pool import Pool
11+ from copy import deepcopy
12+ from functools import partial , partialmethod
1113
1214
1315class LongitudinalFeaturesLagger (LongitudinalPreprocessor ):
@@ -76,15 +78,12 @@ class LongitudinalFeaturesLagger(LongitudinalPreprocessor):
7678 "_n_intervals" : {
7779 "writable" : False
7880 },
79- "_cpp_preprocessor" : {
80- "writable" : False
81- },
8281 "_fitted" : {
8382 "writable" : False
8483 }
8584 }
8685
87- def __init__ (self , n_lags , n_jobs = 1 ):
86+ def __init__ (self , n_lags , n_jobs = - 1 ):
8887 LongitudinalPreprocessor .__init__ (self , n_jobs = n_jobs )
8988 if not isinstance (n_lags , np .ndarray ) or n_lags .dtype != 'uint64' :
9089 raise ValueError (
@@ -93,15 +92,13 @@ def __init__(self, n_lags, n_jobs=1):
9392 self ._n_init_features = None
9493 self ._n_output_features = None
9594 self ._n_intervals = None
96- self ._cpp_preprocessor = None
9795 self ._fitted = False
9896
9997 def _reset (self ):
10098 """Resets the object its initial construction state."""
10199 self ._set ("_n_init_features" , None )
102100 self ._set ("_n_output_features" , None )
103101 self ._set ("_n_intervals" , None )
104- self ._set ("_cpp_preprocessor" , None )
105102 self ._set ("_fitted" , False )
106103
107104 def fit (self , features , labels = None , censoring = None ):
@@ -138,10 +135,7 @@ def fit(self, features, labels=None, censoring=None):
138135 self ._set ("_n_init_features" , n_init_features )
139136 self ._set ("_n_intervals" , n_intervals )
140137 self ._set ("_n_output_features" , int ((self .n_lags + 1 ).sum ()))
141- self ._set ("_cpp_preprocessor" ,
142- _LongitudinalFeaturesLagger (features , self .n_lags ))
143138 self ._set ("_fitted" , True )
144-
145139 return self
146140
147141 def transform (self , features , labels = None , censoring = None ):
@@ -175,49 +169,58 @@ def transform(self, features, labels=None, censoring=None):
175169 base_shape = (self ._n_intervals , self ._n_init_features )
176170 features = check_longitudinal_features_consistency (
177171 features , base_shape , "float64" )
178- if sps .issparse (features [0 ]):
179- if self .n_jobs > 1 :
180- with Pool (self .n_jobs ) as pool :
181- X_with_lags = pool .starmap (self ._sparse_lagger , zip (features , censoring ))
182- pool .start ()
183- pool .join ()
184- else :
185- X_with_lags = [
186- self ._sparse_lagger (x , int (censoring [i ]))
187- for i , x in enumerate (features )
188- ]
189- # TODO: Don't get why int() is required here as censoring_i is uint64
190- else :
191- if self .n_jobs > 1 :
192- with Pool (self .n_jobs ) as pool :
193- X_with_lags = pool .starmap (self ._dense_lagger , zip (features , censoring ))
194- pool .start ()
195- pool .join ()
196- else :
197- X_with_lags = [
198- self ._dense_lagger (x , int (censoring [i ]))
199- for i , x in enumerate (features )
200- ]
172+
173+ initializer = partial (self ._inject_cpp_object ,
174+ n_intervals = self ._n_intervals , n_lags = self .n_lags )
175+ callback = self ._sparse_lagger if sps .issparse (features [0 ]) \
176+ else self ._dense_lagger
177+ callback = partial (callback , n_intervals = self ._n_intervals ,
178+ n_output_features = self ._n_output_features ,
179+ n_lags = self .n_lags )
180+
181+ with Pool (self .n_jobs , initializer = initializer ) as pool :
182+ X_with_lags = pool .starmap (callback , zip (features , censoring ))
201183
202184 return X_with_lags , labels , censoring
203185
204- def _dense_lagger (self , feature_matrix , censoring_i ):
205- output = np .zeros ((self ._n_intervals , self ._n_output_features ),
186+ @staticmethod
187+ def _inject_cpp_object (n_intervals , n_lags ):
188+ """Creates a global instance of the CPP preprocessor object.
189+
190+ WARNING: to be used only as a multiprocessing.Pool initializer.
191+ In multiprocessing context, each process has its own namespace, so using
192+ global is not as bad as it seems. Still, it requires to proceed with
193+ caution.
194+ """
195+ global _cpp_preprocessor
196+ _cpp_preprocessor = _LongitudinalFeaturesLagger (n_intervals , n_lags )
197+
198+ @staticmethod
199+ def _dense_lagger (feature_matrix , censoring_i , n_intervals ,
200+ n_output_features , n_lags ):
201+ """Creates a lagged version of a dense matrixrepresenting longitudinal
202+ features."""
203+ global _cpp_preprocessor
204+ output = np .zeros ((n_intervals , n_output_features ),
206205 dtype = "float64" )
207- self . _cpp_preprocessor .dense_lag_preprocessor (feature_matrix , output ,
208- censoring_i )
206+ _cpp_preprocessor .dense_lag_preprocessor (feature_matrix , output ,
207+ int ( censoring_i ) )
209208 return output
210209
211- def _sparse_lagger (self , feature_matrix , censoring_i ):
212- pp = self ._cpp_preprocessor
210+ @staticmethod
211+ def _sparse_lagger (feature_matrix , censoring_i , n_intervals ,
212+ n_output_features , n_lags ):
213+ """Creates a lagged version of a sparse matrix representing longitudinal
214+ features."""
215+ global _cpp_preprocessor
213216 coo = feature_matrix .tocoo ()
214- estimated_nnz = coo .nnz * int ((self . n_lags + 1 ).sum ())
217+ estimated_nnz = coo .nnz * int ((n_lags + 1 ).sum ())
215218 out_row = np .zeros ((estimated_nnz ,), dtype = "uint64" )
216219 out_col = np .zeros ((estimated_nnz ,), dtype = "uint64" )
217220 out_data = np .zeros ((estimated_nnz ,), dtype = "float64" )
218- pp .sparse_lag_preprocessor (
221+ _cpp_preprocessor .sparse_lag_preprocessor (
219222 coo .row .astype ("uint64" ), coo .col .astype ("uint64" ), coo .data ,
220223 out_row , out_col , out_data , int (censoring_i ))
221224 return sps .csr_matrix ((out_data , (out_row , out_col )),
222- shape = (self . _n_intervals ,
223- self . _n_output_features ))
225+ shape = (n_intervals ,
226+ n_output_features ))
0 commit comments