Skip to content

Commit b3558f0

Browse files
authored
Add element_size function for Dataset and SQDataset elements (#191)
Add element_size function for Dataset and SQDataset elements. Add 'Check element_size' test cases with full datatypes coverage. Related-to issue #104
1 parent cc1f139 commit b3558f0

File tree

4 files changed

+118
-0
lines changed

4 files changed

+118
-0
lines changed

include/svs/core/data/simple.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ class SimpleData {
292292
/// Return the number of dimensions for each entry in the dataset.
293293
size_t dimensions() const { return getsize<1>(data_); }
294294

295+
/// Return the size in bytes of one vector: sizeof(element_type) * dimensions()
296+
size_t element_size() const { return sizeof(element_type) * dimensions(); }
297+
295298
///
296299
/// @brief Return a constant handle to vector stored as position ``i``.
297300
///
@@ -720,6 +723,8 @@ class SimpleData<T, Extent, Blocked<Alloc>> {
720723
}
721724
}
722725

726+
size_t element_size() const { return sizeof(element_type) * dimensions(); }
727+
723728
const_value_type get_datum(size_t i) const {
724729
auto [block_id, data_id] = resolve(i);
725730
return getindex(blocks_, block_id).slice(data_id);

include/svs/quantization/scalar/scalar.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ class SQDataset {
390390

391391
size_t size() const { return data_.size(); }
392392
size_t dimensions() const { return data_.dimensions(); }
393+
size_t element_size() const { return sizeof(element_type) * dimensions(); }
393394

394395
float get_scale() const { return scale_; }
395396
float get_bias() const { return bias_; }

tests/svs/core/data.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,57 @@ CATCH_TEST_CASE("Data Loading/Saving", "[core][data]") {
8888
CATCH_REQUIRE(w == z);
8989
}
9090
}
91+
92+
CATCH_TEST_CASE("Element Size", "[core][data]") {
93+
CATCH_SECTION("Check element_size()") {
94+
// Test with float, dynamic dimensions
95+
auto float_data = svs::data::SimpleData<float, svs::Dynamic>(5, 10);
96+
CATCH_REQUIRE(float_data.element_size() == sizeof(float) * 10);
97+
98+
// Test with double, dynamic dimensions
99+
auto double_data = svs::data::SimpleData<double, svs::Dynamic>(3, 16);
100+
CATCH_REQUIRE(double_data.element_size() == sizeof(double) * 16);
101+
102+
// Test with int8_t, fixed dimensions
103+
auto int8_data = svs::data::SimpleData<int8_t, 32>(10, 32);
104+
CATCH_REQUIRE(int8_data.element_size() == sizeof(int8_t) * 32);
105+
106+
// Test with int16_t, dynamic dimensions
107+
auto int16_data = svs::data::SimpleData<int16_t, svs::Dynamic>(8, 64);
108+
CATCH_REQUIRE(int16_data.element_size() == sizeof(int16_t) * 64);
109+
110+
// Test with int32_t, fixed dimensions
111+
auto int32_data = svs::data::SimpleData<int32_t, 128>(5, 128);
112+
CATCH_REQUIRE(int32_data.element_size() == sizeof(int32_t) * 128);
113+
114+
// Test with uint8_t, dynamic dimensions
115+
auto uint8_data = svs::data::SimpleData<uint8_t, svs::Dynamic>(12, 256);
116+
CATCH_REQUIRE(uint8_data.element_size() == sizeof(uint8_t) * 256);
117+
118+
// Test with uint16_t, fixed dimensions
119+
auto uint16_data = svs::data::SimpleData<uint16_t, 48>(7, 48);
120+
CATCH_REQUIRE(uint16_data.element_size() == sizeof(uint16_t) * 48);
121+
122+
// Test with uint32_t, dynamic dimensions
123+
auto uint32_data = svs::data::SimpleData<uint32_t, svs::Dynamic>(6, 96);
124+
CATCH_REQUIRE(uint32_data.element_size() == sizeof(uint32_t) * 96);
125+
126+
// Test fixed dimensions with blocked storage
127+
auto blocked_fixed = svs::data::BlockedData<int32_t, 64>(25, 64);
128+
CATCH_REQUIRE(blocked_fixed.element_size() == sizeof(int32_t) * 64);
129+
130+
// Test element_size consistency across different instances
131+
auto data1 = svs::data::SimpleData<float, svs::Dynamic>(10, 20);
132+
// Different size, same dims
133+
auto data2 = svs::data::SimpleData<float, svs::Dynamic>(50, 20);
134+
CATCH_REQUIRE(data1.element_size() == data2.element_size());
135+
136+
// Test consistency across different data types with same dimensions
137+
auto float_128 = svs::data::SimpleData<float, svs::Dynamic>(5, 128);
138+
auto double_128 = svs::data::SimpleData<double, svs::Dynamic>(5, 128);
139+
CATCH_REQUIRE(float_128.element_size() == sizeof(float) * 128);
140+
CATCH_REQUIRE(double_128.element_size() == sizeof(double) * 128);
141+
// double is 2x float
142+
CATCH_REQUIRE(double_128.element_size() == 2 * float_128.element_size());
143+
}
144+
}

tests/svs/quantization/scalar/scalar.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,64 @@ template <typename T, typename Distance> void test_distance() {
170170
test_distance_single<std::int8_t, Distance, N>(-10, 1);
171171
}
172172

173+
CATCH_TEST_CASE("SQDataset Element Size", "[quantization][scalar]") {
174+
CATCH_SECTION("Check element_size()") {
175+
// Test with int8_t, dynamic dimensions
176+
auto sq_int8_dynamic = scalar::SQDataset<std::int8_t>(10, 128);
177+
CATCH_REQUIRE(sq_int8_dynamic.element_size() == sizeof(std::int8_t) * 128);
178+
179+
// Test with int8_t, fixed dimensions
180+
constexpr size_t dims_64 = 64;
181+
auto sq_int8_fixed = scalar::SQDataset<std::int8_t, dims_64>({}, 1.0F, 0.0F);
182+
CATCH_REQUIRE(sq_int8_fixed.element_size() == sizeof(std::int8_t) * dims_64);
183+
184+
// Test with int16_t, dynamic dimensions
185+
auto sq_int16_dynamic = scalar::SQDataset<std::int16_t>(5, 256);
186+
CATCH_REQUIRE(sq_int16_dynamic.element_size() == sizeof(std::int16_t) * 256);
187+
188+
// Test with int16_t, fixed dimensions
189+
constexpr size_t dims_32 = 32;
190+
auto sq_int16_fixed = scalar::SQDataset<std::int16_t, dims_32>({}, 2.0F, 1.0F);
191+
CATCH_REQUIRE(sq_int16_fixed.element_size() == sizeof(std::int16_t) * dims_32);
192+
193+
// Test with uint8_t, dynamic dimensions
194+
auto sq_uint8_dynamic = scalar::SQDataset<std::uint8_t>(8, 96);
195+
CATCH_REQUIRE(sq_uint8_dynamic.element_size() == sizeof(std::uint8_t) * 96);
196+
197+
// Test with uint8_t, fixed dimensions
198+
constexpr size_t dims_48 = 48;
199+
auto sq_uint8_fixed = scalar::SQDataset<std::uint8_t, dims_48>({}, 1.5F, 0.5F);
200+
CATCH_REQUIRE(sq_uint8_fixed.element_size() == sizeof(std::uint8_t) * dims_48);
201+
202+
// Test with uint16_t, dynamic dimensions
203+
auto sq_uint16_dynamic = scalar::SQDataset<std::uint16_t>(12, 200);
204+
CATCH_REQUIRE(sq_uint16_dynamic.element_size() == sizeof(std::uint16_t) * 200);
205+
206+
// Test element_size consistency across different instances with same type/dims
207+
auto sq1 = scalar::SQDataset<std::int8_t>(10, 100);
208+
// Different size, same dims
209+
auto sq2 = scalar::SQDataset<std::int8_t>(50, 100);
210+
CATCH_REQUIRE(sq1.element_size() == sq2.element_size());
211+
212+
// Test different quantized types with same dimensions
213+
auto sq_int8_128 = scalar::SQDataset<std::int8_t>(5, 128);
214+
auto sq_int16_128 = scalar::SQDataset<std::int16_t>(5, 128);
215+
CATCH_REQUIRE(sq_int8_128.element_size() == sizeof(std::int8_t) * 128);
216+
CATCH_REQUIRE(sq_int16_128.element_size() == sizeof(std::int16_t) * 128);
217+
// int16 is 2x int8
218+
CATCH_REQUIRE(sq_int16_128.element_size() == 2 * sq_int8_128.element_size());
219+
220+
// Test that element_size reflects the quantized type, not original float
221+
auto original = svs::data::SimpleData<float>(5, 128);
222+
auto compressed = scalar::SQDataset<std::int8_t>::compress(original);
223+
CATCH_REQUIRE(compressed.element_size() == sizeof(std::int8_t) * 128);
224+
// Should be smaller
225+
CATCH_REQUIRE(compressed.element_size() != sizeof(float) * 128);
226+
// int8 vs float
227+
CATCH_REQUIRE(compressed.element_size() == original.element_size() / 4);
228+
}
229+
}
230+
173231
CATCH_TEST_CASE("Testing SQDataset", "[quantization][scalar]") {
174232
CATCH_SECTION("Default SQDataset") {}
175233

0 commit comments

Comments
 (0)