Skip to content

Commit ea27359

Browse files
Merge pull request #369 from Dekken/binacox_squash
adding features relatively to the binacox model
2 parents eec1ac5 + 3ed7ba1 commit ea27359

File tree

6 files changed

+419
-29
lines changed

6 files changed

+419
-29
lines changed

tick/preprocessing/features_binarizer.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class FeaturesBinarizer(Base, BaseEstimator, TransformerMixin):
2929
method : "quantile" or "linspace", default="quantile"
3030
* If ``"quantile"`` quantile-based cuts are used.
3131
* If ``"linspace"`` linearly spaced cuts are used.
32+
* If ``"given"`` bins_boundaries needs to be provided.
3233
3334
detect_column_type : "auto" or "column_names", default="auto"
3435
* If ``"auto"`` feature type detection done automatically.
@@ -40,6 +41,9 @@ class FeaturesBinarizer(Base, BaseEstimator, TransformerMixin):
4041
If `True`, first column of each binarized continuous feature block is
4142
removed.
4243
44+
bins_boundaries : `list`, default="none"
45+
Bins boundaries for continuous features.
46+
4347
Attributes
4448
----------
4549
one_hot_encoder : `OneHotEncoder`
@@ -119,21 +123,37 @@ class FeaturesBinarizer(Base, BaseEstimator, TransformerMixin):
119123
}
120124

121125
def __init__(self, method="quantile", n_cuts=10, detect_column_type="auto",
122-
remove_first=False):
126+
remove_first=False, bins_boundaries=None):
123127
Base.__init__(self)
124128

125129
self.method = method
126130
self.n_cuts = n_cuts
127131
self.detect_column_type = detect_column_type
128132
self.remove_first = remove_first
133+
self.bins_boundaries = bins_boundaries
129134
self.reset()
130135

131136
def reset(self):
132137
self._set("one_hot_encoder", OneHotEncoder(sparse=True))
133-
self._set("bins_boundaries", {})
134138
self._set("mapper", {})
135139
self._set("feature_type", {})
136140
self._set("_fitted", False)
141+
if self.method != "given":
142+
self._set("bins_boundaries", {})
143+
144+
@property
145+
def boundaries(self):
146+
"""Get bins boundaries for all features.
147+
148+
Returns
149+
-------
150+
output : `dict`
151+
The bins boundaries for each feature.
152+
"""
153+
if not self._fitted:
154+
raise ValueError("cannot get bins_boundaries if object has not "
155+
"been fitted")
156+
return self.bins_boundaries
137157

138158
@property
139159
def blocks_start(self):
@@ -440,13 +460,20 @@ def _get_boundaries(self, feature_name, feature, fit=False):
440460
the actual number of distinct boundaries for this feature.
441461
"""
442462
if fit:
443-
boundaries = FeaturesBinarizer._detect_boundaries(
444-
feature, self.n_cuts, self.method)
445-
self.bins_boundaries[feature_name] = boundaries
446-
463+
if self.method == 'given':
464+
if self.bins_boundaries is None:
465+
raise ValueError("bins_boundaries required when `method` "
466+
"equals 'given'")
467+
468+
if not isinstance(self.bins_boundaries[feature_name], np.ndarray):
469+
raise ValueError("feature %s not found in bins_boundaries" % feature_name)
470+
boundaries = self.bins_boundaries[feature_name]
471+
else:
472+
boundaries = FeaturesBinarizer._detect_boundaries(
473+
feature, self.n_cuts, self.method)
474+
self.bins_boundaries[feature_name] = boundaries
447475
elif self._fitted:
448476
boundaries = self.bins_boundaries[feature_name]
449-
450477
else:
451478
raise ValueError("cannot call method with fit=True as object has "
452479
"not been fit")
@@ -518,7 +545,7 @@ def _assign_interval(self, feature_name, feature, fit=False):
518545
if feature.dtype != float:
519546
feature = feature.astype(float)
520547

521-
# Compute bins boundaries for the feature
548+
# Get bins boundaries for the feature
522549
boundaries = self._get_boundaries(feature_name, feature, fit)
523550

524551
# Discretize feature

tick/simulation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88
__all__ = [
99
"features_normal_cov_uniform", "features_normal_cov_toeplitz",
1010
"weights_sparse_exp", "weights_sparse_gauss"
11-
]
11+
]

tick/survival/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .model_coxreg_partial_lik import ModelCoxRegPartialLik
1010
from .model_sccs import ModelSCCS
1111

12-
from .simu_coxreg import SimuCoxReg
12+
from .simu_coxreg import SimuCoxReg, SimuCoxRegWithCutPoints
1313
from .simu_sccs import SimuSCCS
1414
from .convolutional_sccs import ConvSCCS
1515

tick/survival/cox_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _construct_model_obj(self):
109109

110110
def _all_safe(self, features: np.ndarray, times: np.array,
111111
censoring: np.array):
112-
if not np.array_equal(np.unique(censoring), [0, 1]):
112+
if not set(np.unique(censoring)).issubset({0, 1}):
113113
raise ValueError('``censoring`` must only have values in {0, 1}')
114114
# All times must be positive
115115
if not np.all(times >= 0):

0 commit comments

Comments
 (0)