Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
82fc26d
refactor: move/delete some methods in neighbors.py
yuejiaointel Oct 6, 2025
325753c
fix: try it again
yuejiaointel Oct 6, 2025
d17bb34
fix: try it again
yuejiaointel Oct 6, 2025
0e8b4c6
fix: try it again
yuejiaointel Oct 6, 2025
9dda937
fix: first round of refactor move preprocssing function to sklearnex
yuejiaointel Oct 6, 2025
8bd86c2
fix: fix shape
yuejiaointel Oct 7, 2025
debfcdf
rebase: rebase to main
yuejiaointel Oct 7, 2025
e9e7306
fix: add fit emthod logic in onedla
yuejiaointel Oct 7, 2025
02da9e9
fix: fix test
yuejiaointel Oct 7, 2025
62c8ddd
fix: fix tupleerror
yuejiaointel Oct 8, 2025
fc296b5
fix: fix tuple issue
yuejiaointel Oct 9, 2025
fe0abbb
print: print fit_x
yuejiaointel Oct 10, 2025
e202e65
fix: fixed tuple
yuejiaointel Oct 10, 2025
649fc5d
fix: fix tuple
yuejiaointel Oct 10, 2025
a1f95f1
print: print in save attributes
yuejiaointel Oct 10, 2025
939a4f6
fix: tuple handling
yuejiaointel Oct 10, 2025
a4b1351
print: add print
yuejiaointel Oct 10, 2025
39ae6c5
print: test print
yuejiaointel Oct 10, 2025
aa98829
test: test fix for typle
yuejiaointel Oct 13, 2025
2f834d0
fix: more print
yuejiaointel Oct 13, 2025
dcf5b43
fix: test fix for tuyple issue
yuejiaointel Oct 13, 2025
9c65647
fix: test fix for tuyple issue
yuejiaointel Oct 13, 2025
b33834d
fix: try add validation
yuejiaointel Oct 13, 2025
96762db
fix: try restore neighbors funcitons
yuejiaointel Oct 14, 2025
cc2293c
fix: test restore
yuejiaointel Oct 14, 2025
19fe8ce
fix: restore again
yuejiaointel Oct 14, 2025
0f37c1b
fix: restpore
yuejiaointel Oct 14, 2025
f984c42
fix: restore ad and add print
yuejiaointel Oct 14, 2025
f372bcb
fix: restore ad and add print
yuejiaointel Oct 14, 2025
169df26
fix: fix test as well
yuejiaointel Oct 14, 2025
2a2a800
fix: fix test
yuejiaointel Oct 14, 2025
4377198
fix: comment out validate data
yuejiaointel Oct 14, 2025
50f9b9d
fix: refactoredclassifier prepressing to sklearnex
yuejiaointel Oct 14, 2025
833f7ab
fix: add vlaidate data and see if it fix attributeerror
yuejiaointel Oct 14, 2025
a2af2ef
fix: fix onedal test
yuejiaointel Oct 14, 2025
0b601f9
fix: dpm
yuejiaointel Oct 14, 2025
97f9bd1
fix: refacto validate n classes
yuejiaointel Oct 14, 2025
e5300ca
fix: refacor kneighbors validation
yuejiaointel Oct 15, 2025
ae590e9
fix: add vlaidation data to rest of the functions
yuejiaointel Oct 15, 2025
0a2850e
fix: fix check n neighbors validation before check is fitted
yuejiaointel Oct 15, 2025
24bd02d
fix: fix when predict(none) is called by adding x is not none check
yuejiaointel Oct 15, 2025
2702322
fix: fix lof
yuejiaointel Oct 15, 2025
965389e
fix: add validation in kneihbors for lof
yuejiaointel Oct 15, 2025
5b8b091
fix: remove count valitation in onedal
yuejiaointel Oct 15, 2025
5e54b86
fix: refactor shape
yuejiaointel Oct 15, 2025
b16ecc8
refactor: neighbors processing logic to skleranex
yuejiaointel Oct 15, 2025
8c89422
fix: validationeighbors < samples after +1
yuejiaointel Oct 15, 2025
273a084
fix: fix assertion error
yuejiaointel Oct 16, 2025
35afada
fix: fix asswertion error by dispatch gpu/skl in sklearnex
yuejiaointel Oct 16, 2025
8cccb1d
refacor: onedal prediciton entirely to sklearnex
yuejiaointel Oct 16, 2025
5e01257
feature: array api in common.py
yuejiaointel Oct 16, 2025
8bec3dc
fix: assertion error
yuejiaointel Oct 17, 2025
bbab97a
feature: add array api support to knn skleranex files
yuejiaointel Oct 18, 2025
aab0100
fix: compatiibilty for array api
yuejiaointel Oct 20, 2025
7574ef5
fix: remove validate data tests from deseleted tests
yuejiaointel Oct 20, 2025
dd74a72
Merge branch 'main' into refactor_neighbor_array_api
yuejiaointel Oct 20, 2025
591eb56
fix: format
yuejiaointel Oct 20, 2025
342b838
fix: remove ensure finite and reformat
yuejiaointel Oct 20, 2025
a46cc59
fix: format
yuejiaointel Oct 20, 2025
43283cd
fix: fix patching type error
yuejiaointel Oct 20, 2025
d734e1f
fix: update doc
yuejiaointel Oct 20, 2025
8c9246d
fix: fix patching error
yuejiaointel Oct 20, 2025
4cb7ed3
fix: attribute error
yuejiaointel Oct 21, 2025
95fff21
fix: patchnig AttributeError
yuejiaointel Oct 21, 2025
b250c46
fix: remove print and commented code
yuejiaointel Oct 21, 2025
a05d284
fix: format
yuejiaointel Oct 21, 2025
cf1d44d
fix: fix conformance test
yuejiaointel Oct 21, 2025
c2104ac
fix: format
yuejiaointel Oct 21, 2025
503bf49
fix: clean up unneeded var
yuejiaointel Oct 21, 2025
b4e6423
fix: attributeerror
yuejiaointel Oct 21, 2025
f3c949b
fix: spmd also use skelarnex neighbors
yuejiaointel Oct 22, 2025
db8070d
test: test without classes_check in onedal neighbor
yuejiaointel Oct 22, 2025
65b160b
fix: spmd issue
yuejiaointel Oct 23, 2025
231eb32
fix: format
yuejiaointel Oct 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/sources/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ The following patched classes have support for array API inputs:
- :obj:`sklearn.linear_model.Ridge`
- :obj:`sklearnex.linear_model.IncrementalLinearRegression`
- :obj:`sklearnex.linear_model.IncrementalRidge`
- :obj:`sklearn.neighbors.KNeighborsClassifier`
- :obj:`sklearn.neighbors.KNeighborsRegressor`
- :obj:`sklearn.neighbors.NearestNeighbors`
- :obj:`sklearn.neighbors.LocalOutlierFactor`

.. note::
While full array API support is currently not implemented for all classes, :external+dpnp:doc:`dpnp.ndarray <reference/ndarray>`
Expand Down
398 changes: 51 additions & 347 deletions onedal/neighbors/neighbors.py

Large diffs are not rendered by default.

17 changes: 10 additions & 7 deletions onedal/neighbors/tests/test_knn_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@
from numpy.testing import assert_array_equal
from sklearn import datasets

from onedal.neighbors import KNeighborsClassifier
from onedal.tests.utils._device_selection import get_queues

# Classification processing now happens in sklearnex layer
from sklearnex.neighbors import KNeighborsClassifier


@pytest.mark.parametrize("queue", get_queues())
def test_iris(queue):
# queue parameter not used with sklearnex, but kept for test parametrization
iris = datasets.load_iris()
clf = KNeighborsClassifier(2).fit(iris.data, iris.target, queue=queue)
assert clf.score(iris.data, iris.target, queue=queue) > 0.9
clf = KNeighborsClassifier(2).fit(iris.data, iris.target)
score = clf.score(iris.data, iris.target)
assert score > 0.9
assert_array_equal(clf.classes_, np.sort(clf.classes_))


Expand All @@ -36,14 +40,13 @@ def test_pickle(queue):
if queue and queue.sycl_device.is_gpu:
pytest.skip("KNN classifier pickling for the GPU sycl_queue is buggy.")
iris = datasets.load_iris()
clf = KNeighborsClassifier(2).fit(iris.data, iris.target, queue=queue)
expected = clf.predict(iris.data, queue=queue)

clf = KNeighborsClassifier(2).fit(iris.data, iris.target)
expected = clf.predict(iris.data)
import pickle

dump = pickle.dumps(clf)
clf2 = pickle.loads(dump)

assert type(clf2) == clf.__class__
result = clf2.predict(iris.data, queue=queue)
result = clf2.predict(iris.data)
assert_array_equal(expected, result)
29 changes: 21 additions & 8 deletions sklearnex/neighbors/_lof.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from sklearnex.neighbors.knn_unsupervised import NearestNeighbors

from ..utils._array_api import get_namespace
from ..utils.validation import check_feature_names
from ..utils.validation import check_feature_names, validate_data


@control_n_jobs(decorated_methods=["fit", "kneighbors", "_kneighbors"])
Expand All @@ -56,6 +56,7 @@ def _onedal_fit(self, X, y, queue=None):
if sklearn_check_version("1.2"):
self._validate_params()

# Let _onedal_knn_fit (NearestNeighbors._onedal_fit) handle validation
self._onedal_knn_fit(X, y, queue=queue)

if self.contamination != "auto":
Expand All @@ -74,7 +75,6 @@ def _onedal_fit(self, X, y, queue=None):
% (self.n_neighbors, n_samples)
)
self.n_neighbors_ = max(1, min(self.n_neighbors, n_samples - 1))

(
self._distances_fit_X_,
_neighbors_indices_fit_X_,
Expand Down Expand Up @@ -108,11 +108,10 @@ def _onedal_fit(self, X, y, queue=None):
"Duplicate values are leading to incorrect results. "
"Increase the number of neighbors for more accurate results."
)

return self

def fit(self, X, y=None):
result = dispatch(
return dispatch(
self,
"fit",
{
Expand All @@ -122,7 +121,6 @@ def fit(self, X, y=None):
X,
None,
)
return result

def _predict(self, X=None):
check_is_fitted(self)
Expand All @@ -135,7 +133,6 @@ def _predict(self, X=None):
else:
is_inlier = np.ones(self.n_samples_fit_, dtype=int)
is_inlier[self.negative_outlier_factor_ < self.offset_] = -1

return is_inlier

# This had to be done because predict loses the queue when no
Expand All @@ -149,9 +146,15 @@ def fit_predict(self, X, y=None):
return self.fit(X)._predict()

def _kneighbors(self, X=None, n_neighbors=None, return_distance=True):
# Validate n_neighbors parameter first
if n_neighbors is not None:
self._validate_n_neighbors(n_neighbors)

check_is_fitted(self)
if X is not None:
check_feature_names(self, X, reset=False)

# Validate kneighbors parameters (inherited from KNeighborsDispatchingBase)
self._kneighbors_validation(X, n_neighbors)

return dispatch(
self,
"kneighbors",
Expand All @@ -172,6 +175,16 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True):
def score_samples(self, X):
check_is_fitted(self)

# Validate and convert X
xp, _ = get_namespace(X)
X = validate_data(
self,
X,
dtype=[xp.float64, xp.float32],
accept_sparse="csr",
reset=False,
)

distances_X, neighbors_indices_X = self._kneighbors(
X, n_neighbors=self.n_neighbors_
)
Expand Down
Loading
Loading