Skip to content

Commit dbef61e

Browse files
committed
Add integration for train only scenario
1 parent 7505c68 commit dbef61e

File tree

3 files changed

+116
-10
lines changed

3 files changed

+116
-10
lines changed

include/svs/index/ivf/common.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,9 @@ void centroid_assignment(
404404
data.dimensions()
405405
);
406406
}
407-
if constexpr (std::is_same_v<Distance, distance::DistanceIP>) {
407+
if constexpr (std::is_same_v<
408+
std::remove_cvref_t<Distance>,
409+
distance::DistanceIP>) {
408410
for (auto i : indices) {
409411
auto nearest =
410412
type_traits::sentinel_v<Neighbor<size_t>, std::greater<>>;
@@ -414,7 +416,9 @@ void centroid_assignment(
414416
}
415417
assignments[batch_range.start() + i] = nearest.id();
416418
}
417-
} else if constexpr (std::is_same_v<Distance, distance::DistanceL2>) {
419+
} else if constexpr (std::is_same_v<
420+
std::remove_cvref_t<Distance>,
421+
distance::DistanceL2>) {
418422
for (auto i : indices) {
419423
auto nearest = type_traits::sentinel_v<Neighbor<size_t>, std::less<>>;
420424
auto dists = matmul_results.get_datum(i);
@@ -563,13 +567,13 @@ auto kmeans_training(
563567
auto training_timer = timer.push_back("Kmeans training");
564568
data::SimpleData<float> centroids_fp32 = convert_data<float>(centroids, threadpool);
565569

566-
if constexpr (std::is_same_v<Distance, distance::DistanceIP>) {
570+
if constexpr (std::is_same_v<std::remove_cvref_t<Distance>, distance::DistanceIP>) {
567571
normalize_centroids(centroids_fp32, threadpool, timer);
568572
}
569573

570574
auto assignments = std::vector<size_t>(data.size());
571575
std::vector<float> data_norm;
572-
if constexpr (std::is_same_v<Distance, distance::DistanceL2>) {
576+
if constexpr (std::is_same_v<std::remove_cvref_t<Distance>, distance::DistanceL2>) {
573577
generate_norms(data, data_norm, threadpool);
574578
}
575579
std::vector<float> centroids_norm;
@@ -578,7 +582,7 @@ auto kmeans_training(
578582
auto iter_timer = timer.push_back("iteration");
579583
auto batchsize = parameters.minibatch_size_;
580584
auto num_batches = lib::div_round_up(data.size(), batchsize);
581-
if constexpr (std::is_same_v<Distance, distance::DistanceL2>) {
585+
if constexpr (std::is_same_v<std::remove_cvref_t<Distance>, distance::DistanceL2>) {
582586
generate_norms(centroids_fp32, centroids_norm, threadpool);
583587
}
584588

@@ -611,7 +615,7 @@ auto kmeans_training(
611615

612616
centroid_split(data, centroids_fp32, counts, rng, threadpool, timer);
613617

614-
if constexpr (std::is_same_v<Distance, distance::DistanceIP>) {
618+
if constexpr (std::is_same_v<std::remove_cvref_t<Distance>, distance::DistanceIP>) {
615619
normalize_centroids(centroids_fp32, threadpool, timer);
616620
}
617621
}
@@ -723,7 +727,7 @@ data::SimpleData<BuildType> init_centroids(
723727
template <typename Distance, typename Data, threads::ThreadPool Pool>
724728
std::vector<float> maybe_compute_norms(const Data& data, Pool& threadpool) {
725729
std::vector<float> norms;
726-
if constexpr (std::is_same_v<Distance, distance::DistanceL2>) {
730+
if constexpr (std::is_same_v<std::remove_cvref_t<Distance>, distance::DistanceL2>) {
727731
generate_norms(data, norms, threadpool);
728732
}
729733
return norms;
@@ -849,15 +853,15 @@ void search_centroids(
849853
) {
850854
unsigned int count = 0;
851855
buffer.clear();
852-
if constexpr (std::is_same_v<Dist, distance::DistanceIP>) {
856+
if constexpr (std::is_same_v<std::remove_cvref_t<Dist>, distance::DistanceIP>) {
853857
for (size_t j = 0; j < num_threads; j++) {
854858
auto distance = matmul_results[j].get_datum(query_id);
855859
for (size_t k = 0; k < distance.size(); k++) {
856860
buffer.insert({count, distance[k]});
857861
count++;
858862
}
859863
}
860-
} else if constexpr (std::is_same_v<Dist, distance::DistanceL2>) {
864+
} else if constexpr (std::is_same_v<std::remove_cvref_t<Dist>, distance::DistanceL2>) {
861865
float query_norm = distance::norm_square(query);
862866
for (size_t j = 0; j < num_threads; j++) {
863867
auto distance = matmul_results[j].get_datum(query_id);

include/svs/index/ivf/index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ class IVFIndex {
320320

321321
void initialize_distance_metadata() {
322322
// Precalculate centroid norms for L2 distance
323-
if constexpr (std::is_same_v<Dist, distance::DistanceL2>) {
323+
if constexpr (std::is_same_v<std::remove_cvref_t<Dist>, distance::DistanceL2>) {
324324
centroids_norm_.reserve(centroids_.size());
325325
for (size_t i = 0; i < centroids_.size(); i++) {
326326
centroids_norm_.push_back(distance::norm_square(centroids_.get_datum(i)));

tests/integration/ivf/index_build.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
// svs
1818
#include "svs/core/data/simple.h"
1919
#include "svs/core/recall.h"
20+
#include "svs/index/ivf/clustering.h"
21+
#include "svs/index/ivf/common.h"
22+
#include "svs/index/ivf/hierarchical_kmeans.h"
23+
#include "svs/lib/float16.h"
2024
#include "svs/lib/timing.h"
2125
#include "svs/orchestrators/ivf.h"
2226

@@ -99,6 +103,96 @@ void test_build(const Distance& distance, size_t num_inner_threads = 1) {
99103
}
100104
}
101105

106+
template <typename T, typename Distance>
107+
void test_build_train_only(const Distance& distance, size_t num_inner_threads = 1) {
108+
const double epsilon = 0.06; // Wider tolerance for train_only workflow
109+
const auto queries = svs::data::SimpleData<float>::load(test_dataset::query_file());
110+
CATCH_REQUIRE(svs_test::prepare_temp_directory());
111+
size_t num_threads = 2;
112+
113+
auto expected_result = test_dataset::ivf::expected_build_results(
114+
svs::distance_type_v<Distance>, svsbenchmark::Uncompressed(svs::datatype_v<T>)
115+
);
116+
117+
// Load data
118+
auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
119+
auto threadpool = svs::threads::as_threadpool(num_threads);
120+
auto parameters = expected_result.build_parameters_.value();
121+
122+
// Step 1: Use train_only mode to get centroids
123+
svs::data::SimpleData<T> centroids_train;
124+
std::vector<std::vector<uint32_t>> clusters_train;
125+
fmt::print(
126+
"Starting Train-Only Mode Clustering with {} centroids\n", parameters.num_centroids_
127+
);
128+
129+
if (parameters.is_hierarchical_) {
130+
fmt::print("Using Hierarchical KMeans Clustering\n");
131+
std::tie(centroids_train, clusters_train) =
132+
svs::index::ivf::hierarchical_kmeans_clustering<T>(
133+
parameters,
134+
data,
135+
distance,
136+
threadpool,
137+
svs::lib::Type<uint32_t>(),
138+
svs::logging::get(),
139+
true // train_only = true
140+
);
141+
} else {
142+
std::tie(centroids_train, clusters_train) = svs::index::ivf::kmeans_clustering<T>(
143+
parameters,
144+
data,
145+
distance,
146+
threadpool,
147+
svs::lib::Type<uint32_t>(),
148+
svs::logging::get(),
149+
true // train_only = true
150+
);
151+
}
152+
153+
fmt::print("Train-Only Mode - Obtained {} centroids\n", centroids_train.size());
154+
155+
// Step 2: Assign data to clusters using cluster_assignment
156+
auto clusters = svs::index::ivf::cluster_assignment<T>(
157+
data,
158+
centroids_train,
159+
distance,
160+
threadpool,
161+
10'000, // minibatch_size
162+
svs::lib::Type<uint32_t>()
163+
);
164+
165+
// Step 3: Create clustering and assemble index
166+
svs::index::ivf::Clustering clustering(std::move(centroids_train), std::move(clusters));
167+
168+
auto index = svs::IVF::assemble_from_clustering<float>(
169+
std::move(clustering), std::move(data), distance, num_threads, num_inner_threads
170+
);
171+
172+
// Test the index with the same expected results
173+
auto groundtruth = test_dataset::load_groundtruth(svs::distance_type_v<Distance>);
174+
for (const auto& expected : expected_result.config_and_recall_) {
175+
auto these_queries = test_dataset::get_test_set(queries, expected.num_queries_);
176+
auto these_groundtruth =
177+
test_dataset::get_test_set(groundtruth, expected.num_queries_);
178+
index.set_search_parameters(expected.search_parameters_);
179+
auto results = index.search(these_queries, expected.num_neighbors_);
180+
double recall = svs::k_recall_at_n(
181+
these_groundtruth, results, expected.num_neighbors_, expected.recall_k_
182+
);
183+
184+
fmt::print(
185+
"Train-Only Mode - n_probes: {}, Expected Recall: {}, Actual Recall: {}\n",
186+
index.get_search_parameters().n_probes_,
187+
expected.recall_,
188+
recall
189+
);
190+
// Just check that recall is reasonable (within wider tolerance)
191+
CATCH_REQUIRE(recall > expected.recall_ - epsilon);
192+
CATCH_REQUIRE(recall < expected.recall_ + epsilon);
193+
}
194+
}
195+
102196
} // namespace
103197

104198
CATCH_TEST_CASE("IVF Build/Clustering", "[integration][build][ivf]") {
@@ -113,3 +207,11 @@ CATCH_TEST_CASE("IVF Build/Clustering", "[integration][build][ivf]") {
113207
// test_build<svs::BFloat16>(svs::DistanceL2(), 4);
114208
// test_build<svs::BFloat16>(svs::DistanceIP(), 4);
115209
}
210+
211+
CATCH_TEST_CASE("IVF Build/Clustering", "[integration][build][ivf][train_only]") {
212+
test_build_train_only<float>(svs::DistanceL2());
213+
test_build_train_only<svs::Float16>(svs::DistanceIP());
214+
215+
test_build_train_only<svs::BFloat16>(svs::DistanceL2());
216+
test_build_train_only<svs::BFloat16>(svs::DistanceIP());
217+
}

0 commit comments

Comments
 (0)