Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions sklearnex/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,112 +77,113 @@
class BaseForest(oneDALEstimator, ABC):
_onedal_factory = None

def _onedal_fit(self, X, y, sample_weight=None, queue=None):
    """Fit the forest through the oneDAL backend estimator.

    Unless the ``use_raw_input`` config flag is set, ``X`` and ``y`` are
    validated and class weights are folded into ``sample_weight``. With
    raw input, validation is skipped and ``classes_`` / ``n_classes_`` are
    derived directly from the unique values of ``y``. In both cases a
    backend estimator is built from ``self._onedal_factory`` using this
    estimator's hyperparameters and fitted.

    Parameters
    ----------
    X : array-like
        Training data. ``X.shape[1]`` is stored as ``n_features_in_``.
    y : array-like
        Targets, 1d or 2d. A 1d ``y`` is reshaped to a single column
        before its shape is recorded.
    sample_weight : array-like, default=None
        Optional per-sample weights. On the validated path they are
        checked against ``X`` and multiplied by the expanded class
        weights when those are available.
    queue : object, default=None
        Forwarded unchanged to the backend estimator's ``fit``.

    Returns
    -------
    self
        The fitted estimator.
    """
    use_raw_input = get_config().get("use_raw_input", False) is True
    xp, _ = get_namespace(X)

    if not use_raw_input:
        X, y = validate_data(
            self,
            X,
            y,
            multi_output=True,
            accept_sparse=False,
            dtype=[np.float64, np.float32],
            ensure_all_finite=False,
            ensure_2d=True,
        )

        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X)

    # NOTE(review): the nesting of the two shape checks below was
    # reconstructed from a flattened source; they are kept on the common
    # path because ``y.shape`` is unpacked into two values right after,
    # for raw input as well -- confirm against the repository.
    if y.ndim == 2 and y.shape[1] == 1:
        warnings.warn(
            "A column-vector y was passed when a 1d array was"
            " expected. Please change the shape of y to "
            "(n_samples,), for example using ravel().",
            DataConversionWarning,
            stacklevel=2,
        )

    if y.ndim == 1:
        # reshape is necessary to preserve the data contiguity, unlike
        # [:, np.newaxis] which does not.
        y = xp.reshape(y, (-1, 1))

    self._n_samples, self.n_outputs_ = y.shape

    if not use_raw_input:
        y, expanded_class_weight = self._validate_y_class_weight(y)

        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight
        if sample_weight is not None:
            sample_weight = [sample_weight]
    else:
        # try/except needed for raw inputs + array API data where, unlike
        # numpy, the way to yield unique values is via `unique_values`.
        # This should be removed when refactored for gpu zero-copy.
        # The unique values are computed exactly once; `classes_` is only
        # assigned after conversion below.
        try:
            classes = xp.unique(y)
        except AttributeError:
            classes = xp.unique_values(y)

        # Convert to numpy for compatibility with later operations.
        if hasattr(classes, "asnumpy"):
            self.classes_ = classes.asnumpy()
        elif hasattr(xp, "to_numpy"):
            self.classes_ = xp.to_numpy(classes)
        else:
            self.classes_ = np.asarray(classes)
        self.n_classes_ = len(self.classes_)

    self.n_features_in_ = X.shape[1]

    onedal_params = {
        "n_estimators": self.n_estimators,
        "criterion": self.criterion,
        "max_depth": self.max_depth,
        "min_samples_split": self.min_samples_split,
        "min_samples_leaf": self.min_samples_leaf,
        "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
        "max_features": self._to_absolute_max_features(
            self.max_features, self.n_features_in_
        ),
        "max_leaf_nodes": self.max_leaf_nodes,
        "min_impurity_decrease": self.min_impurity_decrease,
        "bootstrap": self.bootstrap,
        "oob_score": self.oob_score,
        "n_jobs": self.n_jobs,
        "random_state": self.random_state,
        "verbose": self.verbose,
        "warm_start": self.warm_start,
        # The error metric is only requested when OOB scoring is enabled.
        "error_metric_mode": self._err if self.oob_score else "none",
        "variable_importance_mode": "mdi",
        "class_weight": self.class_weight,
        "max_bins": self.max_bins,
        "min_bin_size": self.min_bin_size,
        "max_samples": self.max_samples,
        "min_impurity_split": None,
    }

    # Lazy evaluation of estimators_
    self._cached_estimators_ = None

    # Compute
    self._onedal_estimator = self._onedal_factory(**onedal_params)
    self._onedal_estimator.fit(X, xp.reshape(y, (-1,)), sample_weight, queue=queue)

    self._save_attributes()

    # Decapsulate classes_ attributes: for a single output, unwrap the
    # per-output containers into their first element.
    if hasattr(self, "classes_") and self.n_outputs_ == 1:
        self.n_classes_ = (
            self.n_classes_[0]
            if isinstance(self.n_classes_, Iterable)
            else self.n_classes_
        )
        self.classes_ = (
            self.classes_[0]
            if isinstance(self.classes_[0], Iterable)
            else self.classes_
        )

    return self

Check notice on line 186 in sklearnex/ensemble/_forest.py

View check run for this annotation

codefactor.io / CodeFactor

sklearnex/ensemble/_forest.py#L80-L186

Complex Method

def _save_attributes(self):
if self.oob_score:
Expand Down
Loading