@@ -59,28 +59,30 @@ def _handle_missing_values(
59
59
@beartype
def _compute_pca(
    feature_matrix: np.ndarray, number_of_components: int, scaler_type: str
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Scale the feature matrix and fit a PCA on the scaled data.

    Args:
        feature_matrix: Numeric input data handed to the scaler's `fit_transform`.
        number_of_components: Number of components, forwarded to PCA as `n_components`.
        scaler_type: Key used to look up the scaler class in `SCALERS`.

    Returns:
        The transformed data (output of `pca.fit_transform`), the principal components
        (`pca.components_`), the explained variances (`pca.explained_variance_`) and the
        explained variance ratios (`pca.explained_variance_ratio_`).
    """
    # NOTE(review): SCALERS and PCA are defined outside this block; PCA is presumably
    # sklearn.decomposition.PCA -- confirm against the module imports.
    scaler = SCALERS[scaler_type]()
    scaled_data = scaler.fit_transform(feature_matrix)

    pca = PCA(n_components=number_of_components)

    # The projected samples and the component vectors are different things:
    # fit_transform returns the data expressed in the new basis, while
    # components_ holds the basis itself.
    transformed_data = pca.fit_transform(scaled_data)
    principal_components = pca.components_
    explained_variances = pca.explained_variance_
    explained_variance_ratios = pca.explained_variance_ratio_

    return transformed_data, principal_components, explained_variances, explained_variance_ratios
71
73
72
74
73
75
@beartype
74
76
def compute_pca (
75
77
data : Union [np .ndarray , pd .DataFrame , gpd .GeoDataFrame ],
76
- number_of_components : int ,
78
+ number_of_components : Optional [ int ] = None ,
77
79
columns : Optional [Sequence [str ]] = None ,
78
80
scaler_type : Literal ["standard" , "min_max" , "robust" ] = "standard" ,
79
81
nodata_handling : Literal ["remove" , "replace" ] = "remove" ,
80
82
nodata : Optional [Number ] = None ,
81
- ) -> Tuple [Union [np .ndarray , pd .DataFrame , gpd .GeoDataFrame ], np .ndarray ]:
83
+ ) -> Tuple [Union [np .ndarray , pd .DataFrame , gpd .GeoDataFrame ], np .ndarray , np . ndarray , np . ndarray ]:
82
84
"""
83
- Compute defined number of principal components for numeric input data.
85
+ Compute defined number of principal components for numeric input data and transform the data .
84
86
85
87
Before computation, data is scaled according to specified scaler and NaN values removed or replaced.
86
88
Optionally, a nodata value can be given to handle similarly as NaN values.
@@ -93,7 +95,8 @@ def compute_pca(
93
95
Args:
94
96
data: Input data for PCA.
95
97
number_of_components: The number of principal components to compute. Should be >= 1 and at most
96
- the number of numeric columns if input is (Geo)Dataframe.
98
+ the number of features found in input data. If not defined, will be the same as number of
99
+ features in data. Defaults to None.
97
100
columns: Select columns used for the PCA. Other columns are excluded from PCA, but added back
98
101
to the result Dataframe intact. Only relevant if input is (Geo)Dataframe. Defaults to None.
99
102
scaler_type: Transform data according to a specified Sklearn scaler.
@@ -103,8 +106,8 @@ def compute_pca(
103
106
nodata: Define a nodata value to remove. Defaults to None.
104
107
105
108
Returns:
106
- The computed principal components in corresponding format as the input data and the
107
- explained variance ratios for each component.
109
+ The transformed data in same format as input data, computed principal components, explained variances
110
+ and explained variance ratios for each component.
108
111
109
112
Raises:
110
113
EmptyDataException: The input is empty.
@@ -116,7 +119,7 @@ def compute_pca(
116
119
if scaler_type not in SCALERS :
117
120
raise InvalidParameterValueException (f"Invalid scaler. Choose from: { list (SCALERS .keys ())} " )
118
121
119
- if number_of_components < 1 :
122
+ if number_of_components is not None and number_of_components < 1 :
120
123
raise InvalidParameterValueException ("The number of principal components should be >= 1." )
121
124
122
125
# Get feature matrix (Numpy array) from various input types
@@ -158,40 +161,50 @@ def compute_pca(
158
161
feature_matrix = feature_matrix .astype (float )
159
162
feature_matrix , nan_mask = _handle_missing_values (feature_matrix , nodata_handling , nodata )
160
163
164
+ # Default number of components to number of features in data if not defined
165
+ if number_of_components is None :
166
+ number_of_components = feature_matrix .shape [1 ]
167
+
161
168
if number_of_components > feature_matrix .shape [1 ]:
162
- raise InvalidParameterValueException ("The number of principal components is too high for the given input data." )
169
+ raise InvalidParameterValueException (
170
+ "The number of principal components is too high for the given input data "
171
+ + f"({ number_of_components } > { feature_matrix .shape [1 ]} )."
172
+ )
173
+
163
174
# Core PCA computation
164
- principal_components , explained_variances = _compute_pca (feature_matrix , number_of_components , scaler_type )
175
+ transformed_data , principal_components , explained_variances , explained_variance_ratios = _compute_pca (
176
+ feature_matrix , number_of_components , scaler_type
177
+ )
165
178
166
179
if nodata_handling == "remove" and nan_mask is not None :
167
- principal_components_with_nans = np .full ((nan_mask .size , principal_components .shape [1 ]), np .nan )
168
- principal_components_with_nans [~ nan_mask , :] = principal_components
169
- principal_components = principal_components_with_nans
180
+ transformed_data_with_nans = np .full ((nan_mask .size , transformed_data .shape [1 ]), np .nan )
181
+ transformed_data_with_nans [~ nan_mask , :] = transformed_data
182
+ transformed_data = transformed_data_with_nans
170
183
171
184
# Convert PCA output to proper format
172
185
if isinstance (data , np .ndarray ):
173
186
if data .ndim == 3 :
174
- result_data = principal_components .reshape (rows , cols , - 1 ).transpose (2 , 0 , 1 )
187
+ transformed_data_out = transformed_data .reshape (rows , cols , - 1 ).transpose (2 , 0 , 1 )
175
188
else :
176
- result_data = principal_components
189
+ transformed_data_out = transformed_data
177
190
178
191
elif isinstance (data , pd .DataFrame ):
179
192
component_names = [f"principal_component_{ i + 1 } " for i in range (number_of_components )]
180
- result_data = pd .DataFrame (data = principal_components , columns = component_names )
193
+ transformed_data_out = pd .DataFrame (data = transformed_data , columns = component_names )
181
194
if columns is not None :
182
195
old_columns = [column for column in data .columns if column not in columns ]
183
196
for column in old_columns :
184
- result_data [column ] = data [column ]
197
+ transformed_data_out [column ] = data [column ]
185
198
if isinstance (data , gpd .GeoDataFrame ):
186
- result_data = gpd .GeoDataFrame (result_data , geometry = geometries , crs = crs )
199
+ transformed_data_out = gpd .GeoDataFrame (transformed_data_out , geometry = geometries , crs = crs )
187
200
188
- return result_data , explained_variances
201
+ return transformed_data_out , principal_components , explained_variances , explained_variance_ratios
189
202
190
203
191
204
@beartype
192
205
def plot_pca (
193
206
pca_df : pd .DataFrame ,
194
- explained_variances : Optional [np .ndarray ] = None ,
207
+ explained_variance_ratios : Optional [np .ndarray ] = None ,
195
208
color_column_name : Optional [str ] = None ,
196
209
save_path : Optional [str ] = None ,
197
210
) -> sns .PairGrid :
@@ -203,7 +216,7 @@ def plot_pca(
203
216
204
217
Args:
205
218
pca_df: A DataFrame containing computed principal components.
206
- explained_variances : The explained variance ratios for each principal component. Used for labeling
219
+ explained_variance_ratios : The explained variance ratios for each principal component. Used for labeling
207
220
axes in the plot. Optional parameter. Defaults to None.
208
221
color_column_name: Name of the column that will be used for color-coding data points. Typically a
209
222
categorical variable in the original data. Optional parameter, no colors if not provided.
@@ -226,8 +239,8 @@ def plot_pca(
226
239
pair_grid = sns .pairplot (filtered_df , hue = color_column_name )
227
240
228
241
# Add explained variances to axis labels if provided
229
- if explained_variances is not None :
230
- labels = [f"PC { i + 1 } ({ var :.1f} %)" for i , var in enumerate (explained_variances * 100 )]
242
+ if explained_variance_ratios is not None :
243
+ labels = [f"PC { i + 1 } ({ var :.1f} %)" for i , var in enumerate (explained_variance_ratios * 100 )]
231
244
else :
232
245
labels = [f"PC { i + 1 } " for i in range (len (pair_grid .axes ))]
233
246
0 commit comments