Skip to content

Commit 37a89e5

Browse files
Louquinzeeddiebergman
authored andcommitted
implement a new attribute allow_string_features (#1420)
1 parent 1a28632 commit 37a89e5

File tree

7 files changed

+88
-22
lines changed

7 files changed

+88
-22
lines changed

autosklearn/automl.py

+3
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def __init__(
210210
scoring_functions=None,
211211
get_trials_callback=None,
212212
dataset_compression: Union[bool, Mapping[str, Any]] = True,
213+
allow_string_features: bool = True,
213214
):
214215
super(AutoML, self).__init__()
215216
self.configuration_space = None
@@ -281,6 +282,7 @@ def __init__(
281282
self._dataset_compression = validate_dataset_compression_arg(
282283
dataset_compression, memory_limit=self._memory_limit
283284
)
285+
self.allow_string_features = allow_string_features
284286

285287
self._datamanager = None
286288
self._dataset_name = None
@@ -687,6 +689,7 @@ def fit(
687689
is_classification=is_classification,
688690
feat_type=feat_type,
689691
logger_port=self._logger_port,
692+
allow_string_features=self.allow_string_features,
690693
)
691694
self.InputValidator.fit(X_train=X, y_train=y, X_test=X_test, y_test=y_test)
692695
X, y = self.InputValidator.transform(X, y)

autosklearn/data/feature_validator.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,19 @@ def __init__(
4242
self,
4343
feat_type: Optional[List[str]] = None,
4444
logger: Optional[PickableLoggerAdapter] = None,
45+
allow_string_features: bool = True,
4546
) -> None:
4647
# If a dataframe was provided, we populate
47-
# this attribute with a mapping from column to {numerical | categorical}
48+
# this attribute with a mapping from column to
49+
# {numerical | categorical | string}
4850
self.feat_type: Optional[Dict[Union[str, int], str]] = None
4951
if feat_type is not None:
5052
if isinstance(feat_type, dict):
5153
self.feat_type = feat_type
5254
elif not isinstance(feat_type, List):
5355
raise ValueError(
5456
"Auto-Sklearn expects a list of categorical/"
55-
"numerical feature types, yet a"
57+
"numerical/string feature types, yet a"
5658
" {} was provided".format(type(feat_type))
5759
)
5860
else:
@@ -68,6 +70,7 @@ def __init__(
6870
self.logger = logger if logger is not None else logging.getLogger(__name__)
6971

7072
self._is_fitted = False
73+
self.allow_string_features = allow_string_features
7174

7275
def fit(
7376
self,
@@ -300,7 +303,14 @@ def get_feat_type_from_columns(
300303
elif X[column].dtype.name in ["category", "bool"]:
301304
feat_type[column] = "categorical"
302305
elif X[column].dtype.name == "string":
303-
feat_type[column] = "string"
306+
if self.allow_string_features:
307+
feat_type[column] = "string"
308+
else:
309+
feat_type[column] = "categorical"
310+
warnings.warn(
311+
f"you disabled text encoding column {column} will be "
312+
f"encoded as category"
313+
)
304314
# Move away from np.issubdtype as it causes
305315
# TypeError: data type not understood in certain pandas types
306316
elif not is_numeric_dtype(X[column]):
@@ -311,7 +321,14 @@ def get_feat_type_from_columns(
311321
f"Please ensure that this setting is suitable for your task.",
312322
UserWarning,
313323
)
314-
feat_type[column] = "string"
324+
if self.allow_string_features:
325+
feat_type[column] = "string"
326+
else:
327+
feat_type[column] = "categorical"
328+
warnings.warn(
329+
f"you disabled text encoding column {column} will be"
330+
f"encoded as category"
331+
)
315332
elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype(
316333
X[column].dtype
317334
):

autosklearn/data/validation.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
8080
feat_type: Optional[List[str]] = None,
8181
is_classification: bool = False,
8282
logger_port: Optional[int] = None,
83+
allow_string_features: bool = True,
8384
) -> None:
8485
self.feat_type = feat_type
8586
self.is_classification = is_classification
@@ -92,8 +93,11 @@ def __init__(
9293
else:
9394
self.logger = logging.getLogger("Validation")
9495

96+
self.allow_string_features = allow_string_features
9597
self.feature_validator = FeatureValidator(
96-
feat_type=self.feat_type, logger=self.logger
98+
feat_type=self.feat_type,
99+
logger=self.logger,
100+
allow_string_features=self.allow_string_features,
97101
)
98102
self.target_validator = TargetValidator(
99103
is_classification=self.is_classification, logger=self.logger

autosklearn/estimators.py

+7
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
load_models: bool = True,
5252
get_trials_callback=None,
5353
dataset_compression: Union[bool, Mapping[str, Any]] = True,
54+
allow_string_features: bool = True,
5455
):
5556
"""
5657
Parameters
@@ -322,6 +323,10 @@ def __init__(
322323
accordingly. We guarantee that at least one occurrence of each
323324
label is included in the sampled set.
324325
326+
allow_string_features: bool = True
327+
Whether autosklearn should process string features. By default the
328+
textpreprocessing is enabled.
329+
325330
Attributes
326331
----------
327332
cv_results_ : dict of numpy (masked) ndarrays
@@ -367,6 +372,7 @@ def __init__(
367372
self.load_models = load_models
368373
self.get_trials_callback = get_trials_callback
369374
self.dataset_compression = dataset_compression
375+
self.allow_string_features = allow_string_features
370376

371377
self.automl_ = None # type: Optional[AutoML]
372378

@@ -415,6 +421,7 @@ def build_automl(self):
415421
scoring_functions=self.scoring_functions,
416422
get_trials_callback=self.get_trials_callback,
417423
dataset_compression=self.dataset_compression,
424+
allow_string_features=self.allow_string_features,
418425
)
419426

420427
return automl

autosklearn/experimental/askl2.py

+2
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def __init__(
206206
scoring_functions: Optional[List[Scorer]] = None,
207207
load_models: bool = True,
208208
dataset_compression: Union[bool, Mapping[str, Any]] = True,
209+
allow_string_features: bool = True,
209210
):
210211

211212
"""
@@ -363,6 +364,7 @@ def __init__(
363364
metric=metric,
364365
scoring_functions=scoring_functions,
365366
load_models=load_models,
367+
allow_string_features=allow_string_features,
366368
)
367369

368370
def fit(

doc/manual.rst

+21-14
Original file line numberDiff line numberDiff line change
@@ -301,23 +301,29 @@ Other
301301
Supported formats for these training and testing pairs are: np.ndarray,
302302
pd.DataFrame, scipy.sparse.csr_matrix and python lists.
303303

304-
If your data contains categorical values (in the features or targets), autosklearn will automatically encode your
305-
data using a `sklearn.preprocessing.LabelEncoder <https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html>`_
306-
for unidimensional data and a `sklearn.preprocessing.OrdinalEncoder <https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OrdinalEncoder.html>`_
307-
for multidimensional data.
308-
309-
Regarding the features, there are two methods to guide *auto-sklearn* to properly encode categorical columns:
304+
Regarding the features, there are multiple things to consider:
310305

311306
* Providing a X_train/X_test numpy array with the optional flag feat_type. For further details, you
312307
can check the Example :ref:`sphx_glr_examples_40_advanced_example_feature_types.py`.
313308
* You can provide a pandas DataFrame, with properly formatted columns. If a column has numerical
314-
dtype, *auto-sklearn* will not encode it and it will be passed directly to scikit-learn. If the
315-
column has a categorical/boolean class, it will be encoded. If the column is of any other type
316-
(Object or Timeseries), an error will be raised. For further details on how to properly encode
317-
your data, you can check the Pandas Example
318-
`Working with categorical data <https://pandas.pydata.org/pandas-docs/stable/user_guide/categorical.html>`_).
319-
If you are working with time series, it is recommended that you follow this approach
309+
dtype, *auto-sklearn* will not encode it and it will be passed directly to scikit-learn. *auto-sklearn*
310+
supports both categorical or string as column type. Please ensure that you are using the correct
311+
dtype for your task. By default *auto-sklearn* treats object and string columns as strings and
312+
encodes the data using `sklearn.feature_extraction.text.CountVectorizer <https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html>`_
313+
* If your data contains categorical values (in the features or targets), ensure that you explicitly label them as categorical.
314+
data labeled as categorical is encoded by using a `sklearn.preprocessing.LabelEncoder <https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html>`_
315+
for unidimensional data and a `sklearn.preprodcessing.OrdinalEncoder <https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OrdinalEncoder.html>`_ for multidimensional data.
316+
* For further details on how to properly encode your data, you can check the Pandas Example
317+
`Working with categorical data <https://pandas.pydata.org/pandas-docs/stable/user_guide/categorical.html>`_). If you are working with time series, it is recommended that you follow this approach
320318
`Working with time data <https://stats.stackexchange.com/questions/311494/>`_.
319+
* If you prefer not using the string option at all you can disable this option. In this case
320+
objects, strings and categorical columns are encoded as categorical.
321+
322+
.. code:: python
323+
324+
import autosklearn.classification
325+
automl = autosklearn.classification.AutoSklearnClassifier(allow_string_features=False)
326+
automl.fit(X_train, y_train)
321327
322328
Regarding the targets (y_train/y_test), if the task involves a classification problem, such features will be
323329
automatically encoded. It is recommended to provide both y_train and y_test during fit, so that a common encoding
@@ -336,14 +342,15 @@ Other
336342

337343
In order to obtain *vanilla auto-sklearn* as used in `Efficient and Robust Automated Machine Learning
338344
<https://papers.nips.cc/paper/5872-efficient-and-robust-automated-machine -learning>`_
339-
set ``ensemble_size=1`` and ``initial_configurations_via_metalearning=0``:
345+
set ``ensemble_size=1``, ``initial_configurations_via_metalearning=0`` and ``allow_string_features=False``:
340346

341347
.. code:: python
342348
343349
import autosklearn.classification
344350
automl = autosklearn.classification.AutoSklearnClassifier(
345351
ensemble_size=1,
346-
initial_configurations_via_metalearning=0
352+
initial_configurations_via_metalearning=0,
353+
allow_string_features=False,
347354
)
348355
349356
An ensemble of size one will result in always choosing the current best model

test/test_data/test_feature_validator.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -524,15 +524,15 @@ def dummy_func(self):
524524
dummy_object = Dummy(1)
525525
lst = [1, 2, 3]
526526
array = np.array([1, 2, 3])
527-
dummy_stirng = "dummy string"
527+
dummy_string = "dummy string"
528528

529529
df = pd.DataFrame(
530530
{
531531
"dummy_object": [dummy_object] * 4,
532532
"dummy_lst": [lst] * 4,
533533
"dummy_array": [array] * 4,
534-
"dummy_string": [dummy_stirng] * 4,
535-
"type_mix_column": [dummy_stirng, dummy_object, array, lst],
534+
"dummy_string": [dummy_string] * 4,
535+
"type_mix_column": [dummy_string, dummy_object, array, lst],
536536
"cat_column": ["a", "b", "a", "b"],
537537
}
538538
)
@@ -560,3 +560,29 @@ def dummy_func(self):
560560
}
561561

562562
assert feat_type == column_types
563+
564+
565+
def test_allow_string_feature():
566+
df = pd.DataFrame({"Text": ["Hello", "how are you?"]})
567+
with pytest.warns(
568+
UserWarning,
569+
match=r"Input Column Text has generic type object. "
570+
r"Autosklearn will treat this column as string. "
571+
r"Please ensure that this setting is suitable for your task.",
572+
):
573+
validator = FeatureValidator(allow_string_features=False)
574+
feat_type = validator.get_feat_type_from_columns(df)
575+
576+
column_types = {"Text": "categorical"}
577+
assert feat_type == column_types
578+
579+
df["Text"] = df["Text"].astype("string")
580+
with pytest.warns(
581+
UserWarning,
582+
match=r"you disabled text encoding column Text will be " r"encoded as category",
583+
):
584+
validator = FeatureValidator(allow_string_features=False)
585+
feat_type = validator.get_feat_type_from_columns(df)
586+
587+
column_types = {"Text": "categorical"}
588+
assert feat_type == column_types

0 commit comments

Comments
 (0)