
Commit 08d3cae

feat: Generic data processor (#117)
* data_processor mvp
* add more typings
* one output missing typing
* remove redundant fit_transform method
* typecheck simplifications
* add cols positional index support to DataProcessor
* add base processor
1 parent 97b855a commit 08d3cae

File tree: 3 files changed (+132, -0 lines)


requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ tensorflow==2.4.*
 easydict==1.9
 pmlb==1.0.*
 tqdm<5.0
+typeguard==2.13.*
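
The new typeguard pin backs the @typechecked decorator used by both processors below: the declared annotations are enforced when a call is made, rather than serving as documentation only. A minimal illustrative sketch of that behaviour (the function is made up, not part of the commit):

from typeguard import typechecked

@typechecked
def halve(value: float) -> float:
    # The annotation is checked at call time.
    return value / 2

halve(3.0)    # OK
halve("3.0")  # raises TypeError immediately instead of failing later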
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
from typing import List, Union

from numpy import concatenate, ndarray, split, zeros
from pandas import concat, DataFrame
from sklearn.base import BaseEstimator, TransformerMixin
from typeguard import typechecked


@typechecked
class BaseProcessor(BaseEstimator, TransformerMixin):
    """
    Base class for Data Preprocessing. It is a base version and should not be instantiated directly.
    It works like any other transformer in scikit-learn, with the methods fit, transform and inverse_transform.
    Args:
        num_cols (list of strings/list of ints):
            List of names of numerical columns or positional indexes (if pos_idx was set to True).
        cat_cols (list of strings/list of ints):
            List of names of categorical columns or positional indexes (if pos_idx was set to True).
        pos_idx (bool):
            Specifies if the passed col IDs are names or positional indexes (column numbers).
    """
    def __init__(self, *, num_cols: Union[List[str], List[int]] = None, cat_cols: Union[List[str], List[int]] = None,
                 pos_idx: bool = False):
        self.num_cols = [] if num_cols is None else num_cols
        self.cat_cols = [] if cat_cols is None else cat_cols

        self.num_col_idx_ = None
        self.cat_col_idx_ = None

        self.num_pipeline = None  # To be overridden by child processors
        self.cat_pipeline = None  # To be overridden by child processors

        self._types = None
        self.col_order_ = None
        self.pos_idx = pos_idx

    def fit(self, X: DataFrame):
        """Fits the DataProcessor to a passed DataFrame.
        Args:
            X (DataFrame):
                DataFrame used to fit the processor parameters.
                Should be aligned with the num/cat columns defined in initialization.
        """
        if self.pos_idx:
            # Resolve positional indexes into column names.
            self.num_cols = list(X.columns[self.num_cols])
            self.cat_cols = list(X.columns[self.cat_cols])
        self.col_order_ = [c for c in X.columns if c in self.num_cols + self.cat_cols]
        self._types = X.dtypes

        # Each pipeline is fit only if its column list is non-empty.
        self.num_pipeline.fit(X[self.num_cols]) if self.num_cols else zeros([len(X), 0])
        self.cat_pipeline.fit(X[self.cat_cols]) if self.cat_cols else zeros([len(X), 0])

        return self

    def transform(self, X: DataFrame) -> ndarray:
        """Transforms the passed DataFrame with the fit DataProcessor.
        Args:
            X (DataFrame):
                DataFrame to be transformed.
                Should be aligned with the num/cat columns defined in initialization.
        Returns:
            transformed (ndarray):
                Processed version of the passed DataFrame.
        """
        num_data = self.num_pipeline.transform(X[self.num_cols]) if self.num_cols else zeros([len(X), 0])
        cat_data = self.cat_pipeline.transform(X[self.cat_cols]) if self.cat_cols else zeros([len(X), 0])

        transformed = concatenate([num_data, cat_data], axis=1)

        # Record block widths so inverse_transform knows where to split.
        self.num_col_idx_ = num_data.shape[1]
        self.cat_col_idx_ = self.num_col_idx_ + cat_data.shape[1]

        return transformed

    def inverse_transform(self, X: ndarray) -> DataFrame:
        """Inverts the data transformation pipelines on a passed array.
        Args:
            X (ndarray):
                Numpy array to be brought back to the original data format.
                Should share the schema of data transformed by this DataProcessor.
                Can be used to revert transformations of training data or of synthetic samples.
        Returns:
            result (DataFrame):
                DataFrame with the transformations inverted, in the original column order and dtypes.
        """
        num_data, cat_data, _ = split(X, [self.num_col_idx_, self.cat_col_idx_], axis=1)

        num_data = self.num_pipeline.inverse_transform(num_data) if self.num_cols else zeros([len(X), 0])
        cat_data = self.cat_pipeline.inverse_transform(cat_data) if self.cat_cols else zeros([len(X), 0])

        result = concat([DataFrame(num_data, columns=self.num_cols),
                         DataFrame(cat_data, columns=self.cat_cols)], axis=1)

        # Restore the original column order and dtypes.
        result = result.loc[:, self.col_order_]
        for col in result.columns:
            result[col] = result[col].astype(self._types[col])

        return result
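
A note on the index bookkeeping above: transform records the width of the numerical block in num_col_idx_ and the combined width in cat_col_idx_, and inverse_transform reuses them as split points for numpy.split. A standalone sketch of that split, with a made-up array and block widths:

import numpy as np

# Pretend transform produced 2 scaled numerical columns and 3 one-hot columns.
transformed = np.arange(20, dtype=float).reshape(4, 5)

num_col_idx_ = 2                  # width of the numerical block
cat_col_idx_ = num_col_idx_ + 3   # numerical width + categorical width

num_data, cat_data, rest = np.split(transformed, [num_col_idx_, cat_col_idx_], axis=1)
print(num_data.shape)  # (4, 2) -> goes to num_pipeline.inverse_transform
print(cat_data.shape)  # (4, 3) -> goes to cat_pipeline.inverse_transform
print(rest.shape)      # (4, 0) -> empty remainder, discarded as _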
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
from typing import List, Union

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from typeguard import typechecked

from ydata_synthetic.preprocessing.base_processor import BaseProcessor


@typechecked
class RegularDataProcessor(BaseProcessor):
    """
    Main class for Regular/Tabular Data Preprocessing.
    It works like any other transformer in scikit-learn, with the methods fit, transform and inverse_transform.
    Args:
        num_cols (list of strings/list of ints):
            List of names of numerical columns or positional indexes (if pos_idx was set to True).
        cat_cols (list of strings/list of ints):
            List of names of categorical columns or positional indexes (if pos_idx was set to True).
        pos_idx (bool):
            Specifies if the passed col IDs are names or positional indexes (column numbers).
    """
    def __init__(self, *, num_cols: Union[List[str], List[int]] = None, cat_cols: Union[List[str], List[int]] = None,
                 pos_idx: bool = False):
        super().__init__(num_cols=num_cols, cat_cols=cat_cols, pos_idx=pos_idx)

        # Numerical columns are scaled to [0, 1]; categorical columns are one-hot encoded.
        self.num_pipeline = Pipeline([
            ("scaler", MinMaxScaler()),
        ])

        self.cat_pipeline = Pipeline([
            ("encoder", OneHotEncoder(sparse=False, handle_unknown='ignore'))
        ])
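
Putting the two files together, the intended call sequence is fit, then transform to a single ndarray, then inverse_transform back to a DataFrame in the original column order and dtypes. A minimal usage sketch, assuming RegularDataProcessor from the diff above is importable (its module path is not shown in this commit) and using a made-up DataFrame:

import pandas as pd

df = pd.DataFrame({
    "age": [23.0, 35.0, 41.0, 29.0],
    "income": [40000.0, 52000.0, 61000.0, 45000.0],
    "city": ["porto", "lisbon", "porto", "braga"],
})

processor = RegularDataProcessor(num_cols=["age", "income"], cat_cols=["city"])
processor.fit(df)

transformed = processor.transform(df)  # ndarray: min-max scaled nums + one-hot cats
print(transformed.shape)               # (4, 2 + number of distinct cities)

recovered = processor.inverse_transform(transformed)
print(recovered.columns.tolist())      # original column order restored
print(recovered.dtypes)                # original dtypes restored

# Columns can also be passed by position:
# RegularDataProcessor(num_cols=[0, 1], cat_cols=[2], pos_idx=True).fit(df)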
