Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 124 additions & 40 deletions python/cuml/cuml/common/sparsefuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ def extract_knn_graph(knn_graph):
Converts KNN graph from CSR, COO and CSC formats into separate
distance and indice arrays. Input can be a cupy sparse graph (device)
or a numpy sparse graph (host).

Returns
-------
tuple or None
(knn_indices, knn_dists, n_samples) where indices and dists are flattened
arrays and n_samples is the number of rows in the graph, or None if
the format is not supported.
"""
if isinstance(knn_graph, (csc_matrix, cp_csc_matrix)):
knn_graph = cupyx.scipy.sparse.csr_matrix(knn_graph)
Expand All @@ -208,12 +215,14 @@ def extract_knn_graph(knn_graph):
knn_indices = None
if isinstance(knn_graph, (csr_matrix, cp_csr_matrix)):
knn_indices = knn_graph.indices
n_samples = knn_graph.shape[0]
elif isinstance(knn_graph, (coo_matrix, cp_coo_matrix)):
knn_indices = knn_graph.col
n_samples = knn_graph.shape[0]

if knn_indices is not None:
knn_dists = knn_graph.data
return knn_indices, knn_dists
return knn_indices, knn_dists, n_samples
else:
return None

Expand Down Expand Up @@ -243,6 +252,31 @@ def extract_pairwise_dists(pw_dists, n_neighbors):
return knn_indices, knn_dists


def _determine_k_from_arrays(
    knn_indices_arr, n_neighbors, n_samples_hint=None
):
    """Determine k (neighbors per sample) from array shape.

    Parameters
    ----------
    knn_indices_arr : object exposing a ``shape`` tuple
        KNN indices. A 2D input is read as one row per sample; any other
        input is treated as a flattened 1D array whose length is
        ``shape[0]``.
    n_neighbors : int
        Requested number of neighbors. For 1D input without a hint it is
        used to infer the number of samples.
    n_samples_hint : int, optional
        Number of samples, when the caller already knows it. Takes
        precedence over the value inferred from ``n_neighbors``.

    Returns
    -------
    int
        Number of neighbors stored per sample.

    Raises
    ------
    ValueError
        If the number of samples cannot be established (non-positive
        ``n_neighbors``, non-positive sample count) or if the element
        count is not evenly divisible by the number of samples.
    """
    if len(knn_indices_arr.shape) == 2:
        return knn_indices_arr.shape[1]

    # 1D flattened array - infer n_samples and k
    total_elements = knn_indices_arr.shape[0]

    if n_samples_hint is not None:
        n_samples = n_samples_hint
    else:
        # Guard the floor division below: n_neighbors == 0 would divide by
        # zero and a negative value would yield a meaningless sample count.
        if n_neighbors <= 0:
            raise ValueError(
                f"n_neighbors must be a positive integer to infer the "
                f"number of samples, got n_neighbors={n_neighbors}."
            )
        n_samples = total_elements // n_neighbors

    # A zero sample count (empty input, fewer elements than n_neighbors, or
    # a zero hint) would otherwise surface as a bare ZeroDivisionError from
    # the modulo below; report it as invalid input instead.
    if n_samples <= 0:
        raise ValueError(
            f"Cannot determine the number of neighbors per sample: "
            f"precomputed KNN data has {total_elements} total elements for "
            f"{n_samples} samples (n_neighbors={n_neighbors})."
        )

    if total_elements % n_samples != 0:
        raise ValueError(
            f"Precomputed KNN data has {total_elements} total elements which is not evenly "
            f"divisible by {n_samples} samples. Expected {n_samples * n_neighbors} elements "
            f"for n_neighbors={n_neighbors}."
        )

    return total_elements // n_samples


@with_cupy_rmm
def extract_knn_infos(knn_info, n_neighbors):
"""
Expand All @@ -260,52 +294,102 @@ def extract_knn_infos(knn_info, n_neighbors):
n_neighbors: number of nearest neighbors
"""
if knn_info is None:
# no KNN was provided
return None

# Extract indices, distances, and optional n_samples hint
deepcopy = False
n_samples_hint = None

if isinstance(knn_info, tuple):
# dists and indices provided as a tuple
results = knn_info
knn_indices, knn_dists = knn_info
elif isinstance(
knn_info,
(
csr_matrix,
coo_matrix,
csc_matrix,
cp_csr_matrix,
cp_coo_matrix,
cp_csc_matrix,
),
):
# Sparse matrix
result = extract_knn_graph(knn_info)
if result is None:
return None
knn_indices, knn_dists, n_samples_hint = result
deepcopy = True
else:
isaKNNGraph = isinstance(
knn_info,
(
csr_matrix,
coo_matrix,
csc_matrix,
cp_csr_matrix,
cp_coo_matrix,
cp_csc_matrix,
),
# Dense pairwise distance matrix
result = extract_pairwise_dists(knn_info, n_neighbors)
if result is None:
return None
knn_indices, knn_dists = result

# Validate the extracted data
knn_indices_arr = (
knn_indices
if hasattr(knn_indices, "shape")
else np.asarray(knn_indices)
)
knn_dists_arr = (
knn_dists if hasattr(knn_dists, "shape") else np.asarray(knn_dists)
)

if knn_indices_arr.shape != knn_dists_arr.shape:
raise ValueError(
f"Precomputed KNN indices and distances must have the same shape. "
f"Got indices shape {knn_indices_arr.shape} and distances shape {knn_dists_arr.shape}."
)
if isaKNNGraph:
# extract dists and indices from a KNN graph
deepcopy = True
results = extract_knn_graph(knn_info)
else:
# extract dists and indices from a pairwise distance matrix
results = extract_pairwise_dists(knn_info, n_neighbors)

if results is not None:
knn_indices, knn_dists = results

knn_indices_m, _, _, _ = input_to_cuml_array(
knn_indices.flatten(),
order="C",
deepcopy=deepcopy,
check_dtype=np.int64,
convert_to_dtype=np.int64,

if len(knn_indices_arr.shape) not in (1, 2):
raise ValueError(
f"Precomputed KNN indices must be 1D or 2D array, got shape {knn_indices_arr.shape}"
)

knn_dists_m, _, _, _ = input_to_cuml_array(
knn_dists.flatten(),
order="C",
deepcopy=deepcopy,
check_dtype=np.float32,
convert_to_dtype=np.float32,
# Determine actual k and validate against expected n_neighbors
k_provided = _determine_k_from_arrays(
knn_indices_arr, n_neighbors, n_samples_hint
)

if k_provided < n_neighbors:
raise ValueError(
f"Precomputed KNN data has {k_provided} neighbors per sample, "
f"but n_neighbors={n_neighbors} was specified. "
f"Cannot use fewer neighbors than requested. "
f"Please provide KNN data with at least {n_neighbors} neighbors per sample."
)
elif k_provided > n_neighbors:
# Trim excess neighbors
if len(knn_indices_arr.shape) == 2:
# 2D array case: trim columns
knn_indices = knn_indices[:, :n_neighbors]
knn_dists = knn_dists[:, :n_neighbors]
else:
# 1D flattened array case: reshape, trim, and flatten
n_samples = knn_indices_arr.shape[0] // k_provided
knn_indices = knn_indices.reshape((n_samples, k_provided))[
:, :n_neighbors
]
knn_dists = knn_dists.reshape((n_samples, k_provided))[
:, :n_neighbors
]

# Convert to CumlArray
knn_indices_m, _, _, _ = input_to_cuml_array(
knn_indices.flatten(),
order="C",
deepcopy=deepcopy,
check_dtype=np.int64,
convert_to_dtype=np.int64,
)

return knn_indices_m, knn_dists_m
else:
return None
knn_dists_m, _, _, _ = input_to_cuml_array(
knn_dists.flatten(),
order="C",
deepcopy=deepcopy,
check_dtype=np.float32,
convert_to_dtype=np.float32,
)

return knn_indices_m, knn_dists_m
76 changes: 76 additions & 0 deletions python/cuml/tests/test_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,3 +986,79 @@ def test_umap_outliers(n_neighbors, n_components):
(gpu_umap_embeddings >= lower_bound)
& (gpu_umap_embeddings <= upper_bound)
)


@pytest.mark.parametrize("precomputed_type", ["tuple", "knn_graph"])
@pytest.mark.parametrize("k_provided,k_requested", [(15, 10), (20, 8)])
def test_umap_precomputed_knn_trimming(
    precomputed_type, k_provided, k_requested
):
    """
    Precomputed KNN data holding more neighbors than requested must be
    accepted (the surplus is expected to be trimmed) and must still
    produce a valid, trustworthy embedding.
    """
    X, _ = make_blobs(n_samples=500, n_features=10, centers=5, random_state=0)
    X = X.astype(np.float32)

    # Fit a neighbor search with k_provided neighbors, i.e. more than the
    # k_requested passed to UMAP below.
    knn_model = NearestNeighbors(n_neighbors=k_provided)
    knn_model.fit(X)

    if precomputed_type == "knn_graph":
        precomputed_knn = knn_model.kneighbors_graph(X, mode="distance")
    elif precomputed_type == "tuple":
        dists, inds = knn_model.kneighbors(X, return_distance=True)
        precomputed_knn = (inds, dists)

    # Fitting is expected to succeed despite the extra neighbors.
    umap_kwargs = dict(
        n_neighbors=k_requested,
        precomputed_knn=precomputed_knn,
        random_state=42,
        init="random",
    )
    embedding = cuUMAP(**umap_kwargs).fit_transform(X)

    # The embedding must have one 2D point per sample and contain no NaNs.
    assert embedding.shape == (X.shape[0], 2)
    assert not np.isnan(embedding).any()

    # Quality is measured with the requested neighbor count.
    assert trustworthiness(X, embedding, n_neighbors=k_requested) >= 0.85


@pytest.mark.parametrize("precomputed_type", ["tuple", "knn_graph"])
def test_umap_precomputed_knn_insufficient_neighbors(precomputed_type):
    """
    Precomputed KNN data holding fewer neighbors than requested must be
    rejected with a ValueError mentioning "fewer neighbors".
    """
    k_provided, k_requested = 5, 10

    X, _ = make_blobs(n_samples=500, n_features=10, centers=5, random_state=0)
    X = X.astype(np.float32)

    # Fit a neighbor search with k_provided neighbors, i.e. fewer than the
    # k_requested passed to UMAP below.
    knn_model = NearestNeighbors(n_neighbors=k_provided)
    knn_model.fit(X)

    if precomputed_type == "knn_graph":
        precomputed_knn = knn_model.kneighbors_graph(X, mode="distance")
    elif precomputed_type == "tuple":
        dists, inds = knn_model.kneighbors(X, return_distance=True)
        precomputed_knn = (inds, dists)

    # Only the constructor is called: building the estimator is what is
    # expected to raise.
    umap_kwargs = dict(
        n_neighbors=k_requested,
        precomputed_knn=precomputed_knn,
        random_state=42,
        init="random",
    )
    with pytest.raises(ValueError, match=".*fewer neighbors.*"):
        cuUMAP(**umap_kwargs)