Skip to content

Commit bae4afb

Browse files
authored
Add NearestNeighbors SPMD API (#2557)
* Add NearestNeighbors SPMD API
* black format
* extend gold data to have multiple rows per rank
* formatting
* raw inputs support for kneighbors
* Reduce rows of synthetic large test
* update search size and only use _spmd_assert_allclose
* support empty kneighbors()
* Update sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py
* address comments
1 parent 09755e9 commit bae4afb

File tree

5 files changed

+179
-21
lines changed

5 files changed

+179
-21
lines changed

onedal/neighbors/neighbors.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def _fit(self, X, y):
296296
return result
297297

298298
def _kneighbors(self, X=None, n_neighbors=None, return_distance=True):
299+
use_raw_input = _get_config().get("use_raw_input", False) is True
299300
n_features = getattr(self, "n_features_in_", None)
300301
shape = getattr(X, "shape", None)
301302
if n_features and shape and len(shape) > 1 and shape[1] != n_features:
@@ -322,7 +323,8 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True):
322323

323324
if X is not None:
324325
query_is_train = False
325-
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
326+
if not use_raw_input:
327+
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
326328
else:
327329
query_is_train = True
328330
X = self._fit_X
@@ -730,7 +732,6 @@ def __init__(
730732
self,
731733
n_neighbors=5,
732734
*,
733-
weights="uniform",
734735
algorithm="auto",
735736
p=2,
736737
metric="minkowski",
@@ -745,7 +746,7 @@ def __init__(
745746
metric_params=metric_params,
746747
**kwargs,
747748
)
748-
self.weights = weights
749+
self.requires_y = False
749750

750751
@bind_default_backend("neighbors.search")
751752
def train(self, *args, **kwargs): ...
@@ -792,7 +793,7 @@ def _onedal_predict(self, model, X, params):
792793
return self.infer(params, model, X)
793794

794795
@supports_queue
795-
def fit(self, X, y, queue=None):
796+
def fit(self, X, y=None, queue=None):
796797
return self._fit(X, y)
797798

798799
@supports_queue

onedal/spmd/neighbors/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
from .neighbors import KNeighborsClassifier, KNeighborsRegressor
17+
from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors
1818

19-
__all__ = ["KNeighborsClassifier", "KNeighborsRegressor"]
19+
__all__ = ["KNeighborsClassifier", "KNeighborsRegressor", "NearestNeighbors"]

onedal/spmd/neighbors/neighbors.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ...common._backend import bind_spmd_backend
1919
from ...neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch
2020
from ...neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch
21+
from ...neighbors import NearestNeighbors as NearestNeighbors_Batch
2122

2223

2324
class KNeighborsClassifier(KNeighborsClassifier_Batch):
@@ -30,6 +31,8 @@ def infer(self, *args, **kwargs): ...
3031

3132
@support_input_format
3233
def fit(self, X, y, queue=None):
34+
# Store the queue so kneighbors() can reuse it when called with X=None and no explicit queue
35+
self.spmd_queue_ = queue
3336
return super().fit(X, y, queue=queue)
3437

3538
@support_input_format
@@ -42,6 +45,8 @@ def predict_proba(self, X, queue=None):
4245

4346
@support_input_format
4447
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
48+
if X is None and queue is None:
49+
queue = getattr(self, "spmd_queue_", None)
4550
return super().kneighbors(X, n_neighbors, return_distance, queue=queue)
4651

4752

@@ -62,6 +67,8 @@ def infer(self, *args, **kwargs): ...
6267
@support_input_format
6368
@supports_queue
6469
def fit(self, X, y, queue=None):
70+
# Store the queue so kneighbors() can reuse it when called with X=None and no explicit queue
71+
self.spmd_queue_ = queue
6572
if queue is not None and queue.sycl_device.is_gpu:
6673
return self._fit(X, y)
6774
else:
@@ -72,6 +79,8 @@ def fit(self, X, y, queue=None):
7279

7380
@support_input_format
7481
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
82+
if X is None and queue is None:
83+
queue = getattr(self, "spmd_queue_", None)
7584
return super().kneighbors(X, n_neighbors, return_distance, queue=queue)
7685

7786
@support_input_format
@@ -84,3 +93,24 @@ def _get_onedal_params(self, X, y=None):
8493
if "responses" not in params["result_option"]:
8594
params["result_option"] += "|responses"
8695
return params
96+
97+
98+
class NearestNeighbors(NearestNeighbors_Batch):
99+
100+
@bind_spmd_backend("neighbors.search")
101+
def train(self, *args, **kwargs): ...
102+
103+
@bind_spmd_backend("neighbors.search")
104+
def infer(self, *args, **kwargs): ...
105+
106+
@support_input_format
107+
def fit(self, X, y=None, queue=None):
108+
# Store the queue so kneighbors() can reuse it when called with X=None and no explicit queue
109+
self.spmd_queue_ = queue
110+
return super().fit(X, y, queue=queue)
111+
112+
@support_input_format
113+
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
114+
if X is None and queue is None:
115+
queue = getattr(self, "spmd_queue_", None)
116+
return super().kneighbors(X, n_neighbors, return_distance, queue=queue)

sklearnex/spmd/neighbors/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
from onedal.spmd.neighbors import KNeighborsClassifier, KNeighborsRegressor
17+
from onedal.spmd.neighbors import (
18+
KNeighborsClassifier,
19+
KNeighborsRegressor,
20+
NearestNeighbors,
21+
)
1822

19-
__all__ = ["KNeighborsClassifier", "KNeighborsRegressor"]
23+
__all__ = ["KNeighborsClassifier", "KNeighborsRegressor", "NearestNeighbors"]

sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py

Lines changed: 136 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
)
2525
from sklearnex import config_context
2626
from sklearnex.tests.utils.spmd import (
27-
_assert_unordered_allclose,
2827
_generate_classification_data,
2928
_generate_regression_data,
29+
_generate_statistic_data,
3030
_get_local_tensor,
3131
_mpi_libs_and_gpu_available,
3232
_spmd_assert_allclose,
@@ -94,8 +94,8 @@ def test_knncls_spmd_gold(dataframe, queue):
9494
spmd_result = spmd_model.predict(local_dpt_X_test)
9595
batch_result = batch_model.predict(X_test)
9696

97-
_assert_unordered_allclose(spmd_indcs, batch_indcs, localize=True)
98-
_assert_unordered_allclose(spmd_dists, batch_dists, localize=True)
97+
_spmd_assert_allclose(spmd_indcs, batch_indcs)
98+
_spmd_assert_allclose(spmd_dists, batch_dists)
9999
_spmd_assert_allclose(spmd_result, batch_result)
100100

101101

@@ -164,10 +164,8 @@ def test_knncls_spmd_synthetic(
164164

165165
tol = 1e-4
166166
if dtype == np.float64:
167-
_assert_unordered_allclose(spmd_indcs, batch_indcs, localize=True)
168-
_assert_unordered_allclose(
169-
spmd_dists, batch_dists, localize=True, rtol=tol, atol=tol
170-
)
167+
_spmd_assert_allclose(spmd_indcs, batch_indcs)
168+
_spmd_assert_allclose(spmd_dists, batch_dists, rtol=tol, atol=tol)
171169
_spmd_assert_allclose(spmd_result, batch_result)
172170

173171

@@ -231,8 +229,8 @@ def test_knnreg_spmd_gold(dataframe, queue):
231229
spmd_result = spmd_model.predict(local_dpt_X_test)
232230
batch_result = batch_model.predict(X_test)
233231

234-
_assert_unordered_allclose(spmd_indcs, batch_indcs, localize=True)
235-
_assert_unordered_allclose(spmd_dists, batch_dists, localize=True)
232+
_spmd_assert_allclose(spmd_indcs, batch_indcs)
233+
_spmd_assert_allclose(spmd_dists, batch_dists)
236234
_spmd_assert_allclose(spmd_result, batch_result)
237235

238236

@@ -303,8 +301,133 @@ def test_knnreg_spmd_synthetic(
303301

304302
tol = 0.005 if dtype == np.float32 else 1e-4
305303
if dtype == np.float64:
306-
_assert_unordered_allclose(spmd_indcs, batch_indcs, localize=True)
307-
_assert_unordered_allclose(
308-
spmd_dists, batch_dists, localize=True, rtol=tol, atol=tol
309-
)
304+
_spmd_assert_allclose(spmd_indcs, batch_indcs)
305+
_spmd_assert_allclose(spmd_dists, batch_dists, rtol=tol, atol=tol)
310306
_spmd_assert_allclose(spmd_result, batch_result, rtol=tol, atol=tol)
307+
308+
309+
@pytest.mark.skipif(
310+
not _mpi_libs_and_gpu_available,
311+
reason="GPU device and MPI libs required for test",
312+
)
313+
@pytest.mark.parametrize(
314+
"dataframe,queue",
315+
get_dataframes_and_queues(dataframe_filter_="dpnp,dpctl", device_filter_="gpu"),
316+
)
317+
@pytest.mark.mpi
318+
def test_knnsearch_spmd_gold(dataframe, queue):
319+
# Import spmd and batch algo
320+
from sklearnex.neighbors import NearestNeighbors as NearestNeighbors_Batch
321+
from sklearnex.spmd.neighbors import NearestNeighbors as NearestNeighbors_SPMD
322+
323+
# Create gold data and convert to dataframe
324+
X_train = np.array(
325+
[[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2], [10, 10], [9, 9]]
326+
)
327+
local_dpt_X_train = _convert_to_dataframe(
328+
_get_local_tensor(X_train), sycl_queue=queue, target_df=dataframe
329+
)
330+
331+
# Ensure search results of batch algo match spmd
332+
spmd_model = NearestNeighbors_SPMD(n_neighbors=2, algorithm="brute").fit(
333+
local_dpt_X_train
334+
)
335+
batch_model = NearestNeighbors_Batch(n_neighbors=2, algorithm="brute").fit(X_train)
336+
spmd_dists, spmd_indcs = spmd_model.kneighbors(local_dpt_X_train)
337+
batch_dists, batch_indcs = batch_model.kneighbors(X_train)
338+
339+
_spmd_assert_allclose(spmd_indcs, batch_indcs)
340+
_spmd_assert_allclose(spmd_dists, batch_dists)
341+
342+
343+
@pytest.mark.skipif(
344+
not _mpi_libs_and_gpu_available,
345+
reason="GPU device and MPI libs required for test",
346+
)
347+
@pytest.mark.parametrize(
348+
"dimensions", [{"n": 100, "m": 10, "k": 2}, {"n": 100000, "m": 100, "k": 100}]
349+
)
350+
@pytest.mark.parametrize(
351+
"dataframe,queue",
352+
get_dataframes_and_queues(dataframe_filter_="dpnp,dpctl", device_filter_="gpu"),
353+
)
354+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
355+
@pytest.mark.mpi
356+
def test_knnsearch_spmd_synthetic(
357+
dimensions,
358+
dataframe,
359+
queue,
360+
dtype,
361+
):
362+
if dimensions["n"] > 10000 and dtype == np.float32:
363+
pytest.skip("Skipping large float32 test due to expected precision issues")
364+
365+
# Import spmd and batch algo
366+
from sklearnex.neighbors import NearestNeighbors as NearestNeighbors_Batch
367+
from sklearnex.spmd.neighbors import NearestNeighbors as NearestNeighbors_SPMD
368+
369+
# Generate data and convert to dataframe
370+
X_train = _generate_statistic_data(dimensions["n"], dimensions["m"], dtype=dtype)
371+
372+
local_dpt_X_train = _convert_to_dataframe(
373+
_get_local_tensor(X_train), sycl_queue=queue, target_df=dataframe
374+
)
375+
376+
# Ensure search results of batch algo match spmd
377+
spmd_model = NearestNeighbors_SPMD(
378+
n_neighbors=dimensions["k"], algorithm="brute"
379+
).fit(local_dpt_X_train)
380+
batch_model = NearestNeighbors_Batch(
381+
n_neighbors=dimensions["k"], algorithm="brute"
382+
).fit(X_train)
383+
spmd_dists, spmd_indcs = spmd_model.kneighbors(local_dpt_X_train)
384+
batch_dists, batch_indcs = batch_model.kneighbors(X_train)
385+
386+
tol = 0.005 if dtype == np.float32 else 1e-6
387+
_spmd_assert_allclose(spmd_indcs, batch_indcs)
388+
_spmd_assert_allclose(spmd_dists, batch_dists, rtol=tol, atol=tol)
389+
390+
391+
@pytest.mark.skipif(
392+
not _mpi_libs_and_gpu_available,
393+
reason="GPU device and MPI libs required for test",
394+
)
395+
@pytest.mark.parametrize(
396+
"dataframe,queue",
397+
get_dataframes_and_queues(dataframe_filter_="dpnp,dpctl", device_filter_="gpu"),
398+
)
399+
@pytest.mark.mpi
400+
def test_knn_spmd_empty_kneighbors(dataframe, queue):
401+
# Import spmd and batch algo
402+
from sklearnex.neighbors import NearestNeighbors as NearestNeighbors_Batch
403+
from sklearnex.spmd.neighbors import (
404+
KNeighborsClassifier,
405+
KNeighborsRegressor,
406+
NearestNeighbors,
407+
)
408+
409+
# Create gold data and convert to dataframe
410+
X_train = np.array(
411+
[[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2], [10, 10], [9, 9]]
412+
)
413+
y_train = np.array([0, 1, 0, 1, 0, 1, 0, 1])
414+
local_dpt_X_train = _convert_to_dataframe(
415+
_get_local_tensor(X_train), sycl_queue=queue, target_df=dataframe
416+
)
417+
local_dpt_y_train = _convert_to_dataframe(
418+
_get_local_tensor(y_train), sycl_queue=queue, target_df=dataframe
419+
)
420+
421+
# Run each estimator without an input to kneighbors() and ensure functionality and equivalence
422+
for CurrentEstimator in [KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors]:
423+
spmd_model = CurrentEstimator(n_neighbors=1, algorithm="brute").fit(
424+
local_dpt_X_train, local_dpt_y_train
425+
)
426+
batch_model = NearestNeighbors_Batch(n_neighbors=1, algorithm="brute").fit(
427+
X_train, y_train
428+
)
429+
spmd_dists, spmd_indcs = spmd_model.kneighbors()
430+
batch_dists, batch_indcs = batch_model.kneighbors()
431+
432+
_spmd_assert_allclose(spmd_indcs, batch_indcs)
433+
_spmd_assert_allclose(spmd_dists, batch_dists)

0 commit comments

Comments
 (0)