1717// svs
1818#include " svs/core/data/simple.h"
1919#include " svs/core/recall.h"
20+ #include " svs/index/ivf/clustering.h"
21+ #include " svs/index/ivf/common.h"
22+ #include " svs/index/ivf/hierarchical_kmeans.h"
23+ #include " svs/lib/float16.h"
2024#include " svs/lib/timing.h"
2125#include " svs/orchestrators/ivf.h"
2226
@@ -99,6 +103,96 @@ void test_build(const Distance& distance, size_t num_inner_threads = 1) {
99103 }
100104}
101105
106+ template <typename T, typename Distance>
107+ void test_build_train_only (const Distance& distance, size_t num_inner_threads = 1 ) {
108+ const double epsilon = 0.06 ; // Wider tolerance for train_only workflow
109+ const auto queries = svs::data::SimpleData<float >::load (test_dataset::query_file ());
110+ CATCH_REQUIRE (svs_test::prepare_temp_directory ());
111+ size_t num_threads = 2 ;
112+
113+ auto expected_result = test_dataset::ivf::expected_build_results (
114+ svs::distance_type_v<Distance>, svsbenchmark::Uncompressed (svs::datatype_v<T>)
115+ );
116+
117+ // Load data
118+ auto data = svs::data::SimpleData<float >::load (test_dataset::data_svs_file ());
119+ auto threadpool = svs::threads::as_threadpool (num_threads);
120+ auto parameters = expected_result.build_parameters_ .value ();
121+
122+ // Step 1: Use train_only mode to get centroids
123+ svs::data::SimpleData<T> centroids_train;
124+ std::vector<std::vector<uint32_t >> clusters_train;
125+ fmt::print (
126+ " Starting Train-Only Mode Clustering with {} centroids\n " , parameters.num_centroids_
127+ );
128+
129+ if (parameters.is_hierarchical_ ) {
130+ fmt::print (" Using Hierarchical KMeans Clustering\n " );
131+ std::tie (centroids_train, clusters_train) =
132+ svs::index::ivf::hierarchical_kmeans_clustering<T>(
133+ parameters,
134+ data,
135+ distance,
136+ threadpool,
137+ svs::lib::Type<uint32_t >(),
138+ svs::logging::get (),
139+ true // train_only = true
140+ );
141+ } else {
142+ std::tie (centroids_train, clusters_train) = svs::index::ivf::kmeans_clustering<T>(
143+ parameters,
144+ data,
145+ distance,
146+ threadpool,
147+ svs::lib::Type<uint32_t >(),
148+ svs::logging::get (),
149+ true // train_only = true
150+ );
151+ }
152+
153+ fmt::print (" Train-Only Mode - Obtained {} centroids\n " , centroids_train.size ());
154+
155+ // Step 2: Assign data to clusters using cluster_assignment
156+ auto clusters = svs::index::ivf::cluster_assignment<T>(
157+ data,
158+ centroids_train,
159+ distance,
160+ threadpool,
161+ 10'000 , // minibatch_size
162+ svs::lib::Type<uint32_t >()
163+ );
164+
165+ // Step 3: Create clustering and assemble index
166+ svs::index::ivf::Clustering clustering (std::move (centroids_train), std::move (clusters));
167+
168+ auto index = svs::IVF::assemble_from_clustering<float >(
169+ std::move (clustering), std::move (data), distance, num_threads, num_inner_threads
170+ );
171+
172+ // Test the index with the same expected results
173+ auto groundtruth = test_dataset::load_groundtruth (svs::distance_type_v<Distance>);
174+ for (const auto & expected : expected_result.config_and_recall_ ) {
175+ auto these_queries = test_dataset::get_test_set (queries, expected.num_queries_ );
176+ auto these_groundtruth =
177+ test_dataset::get_test_set (groundtruth, expected.num_queries_ );
178+ index.set_search_parameters (expected.search_parameters_ );
179+ auto results = index.search (these_queries, expected.num_neighbors_ );
180+ double recall = svs::k_recall_at_n (
181+ these_groundtruth, results, expected.num_neighbors_ , expected.recall_k_
182+ );
183+
184+ fmt::print (
185+ " Train-Only Mode - n_probes: {}, Expected Recall: {}, Actual Recall: {}\n " ,
186+ index.get_search_parameters ().n_probes_ ,
187+ expected.recall_ ,
188+ recall
189+ );
190+ // Just check that recall is reasonable (within wider tolerance)
191+ CATCH_REQUIRE (recall > expected.recall_ - epsilon);
192+ CATCH_REQUIRE (recall < expected.recall_ + epsilon);
193+ }
194+ }
195+
102196} // namespace
103197
104198CATCH_TEST_CASE (" IVF Build/Clustering" , " [integration][build][ivf]" ) {
@@ -113,3 +207,11 @@ CATCH_TEST_CASE("IVF Build/Clustering", "[integration][build][ivf]") {
113207 // test_build<svs::BFloat16>(svs::DistanceL2(), 4);
114208 // test_build<svs::BFloat16>(svs::DistanceIP(), 4);
115209}
210+
211+ CATCH_TEST_CASE (" IVF Build/Clustering" , " [integration][build][ivf][train_only]" ) {
212+ test_build_train_only<float >(svs::DistanceL2 ());
213+ test_build_train_only<svs::Float16>(svs::DistanceIP ());
214+
215+ test_build_train_only<svs::BFloat16>(svs::DistanceL2 ());
216+ test_build_train_only<svs::BFloat16>(svs::DistanceIP ());
217+ }
0 commit comments