
Commit cc1f139

Copilot and ahuber21 authored
Refactor AVX2 distance computations to consistently use generic_simd_op (#196)
- [x] Understand the existing code structure and identify inconsistencies
- [x] Create `L2FloatOp<8>` for AVX2 L2 distance computations
- [x] Create `ConvertToFloat<8>` base class for AVX2
- [x] Refactor L2 AVX2 implementations to use `simd::generic_simd_op()`
- [x] Create `IPFloatOp<8>` for AVX2 Inner Product computations
- [x] Refactor Inner Product AVX2 implementations to use `simd::generic_simd_op()`
- [x] Create `CosineFloatOp<8>` for AVX2 Cosine Similarity computations
- [x] Add AVX2 implementations for Cosine Similarity with all type combinations
- [x] Build and test all changes
- [x] Fix compilation warnings
- [x] Address code review feedback
- [x] Optimize masked load implementation

## Recent Changes

Reverted the AVX512VL conditional specializations based on reviewer feedback. All AVX2 implementations now use a consistent blend-mask approach, with no runtime conditionals for AVX512VL. The implementation now:

- Uses the `create_blend_mask_avx2()` helper function to create masks
- Uses `_mm256_blendv_ps` for masked loads on AVX2 (a hedged sketch of this masked-load pattern follows the commit metadata below)
- Handles masking in the load operations for the accumulate functions
- Maintains a clean separation between the AVX2 and AVX512 code paths

Performance regression resolved: benchmarks confirmed performance parity on both AVX512 and AVX2 systems.

<details>
<summary>Original prompt</summary>

> ----
>
> *This section details the original issue you should resolve.*
>
> <issue_title>Consistently use compute ops in distance computations</issue_title>
> <issue_description>Refactor the compute operations in our [distance primitives](https://github.com/intel/ScalableVectorSearch/tree/main/include/svs/core/distance).
>
> The goal is that all `compute()` methods will dispatch to `simd::generic_simd_op()`, as is done, for example, in
>
> https://github.com/intel/ScalableVectorSearch/blob/18ba515e595ee65ff39d3abae3f42a56b3f5ab2c/include/svs/core/distance/euclidean.h#L321
>
> To achieve this, the actual compute operation must be wrapped in a SIMD struct that provides the necessary protocol. One example is `L2FloatOp`:
>
> https://github.com/intel/ScalableVectorSearch/blob/18ba515e595ee65ff39d3abae3f42a56b3f5ab2c/include/svs/core/distance/euclidean.h#L240-L259
>
> That is, it is necessary to provide `init()`, `accumulate()`, and `combine()` operations.
>
> The chain `compute() -> simd::generic_simd_op() -> <actual compute>` already exists in many places, but not consistently. One example where AVX intrinsics are still used directly in `compute()` is
>
> https://github.com/intel/ScalableVectorSearch/blob/18ba515e595ee65ff39d3abae3f42a56b3f5ab2c/include/svs/core/distance/euclidean.h#L369
>
> Many other examples exist in our [distance primitives](https://github.com/intel/ScalableVectorSearch/tree/main/include/svs/core/distance).</issue_description>
>
> <agent_instructions>As a first step, it is enough to focus on consistency between the L2, Inner Product, and Cosine Similarity distances.</agent_instructions>
>
> ## Comments on the Issue (you are @copilot in this section)
>
> <comments>
> </comments>

</details>

Fixes #183

---

💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment, and configuring Model Context Protocol (MCP) servers. Learn more in the [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) docs.
---------

Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: ahuber21 <[email protected]>
Co-authored-by: Andreas Huber <[email protected]>
1 parent 070fa79 commit cc1f139
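The "Recent Changes" notes above describe the AVX2 tail handling in terms of `create_blend_mask_avx2()` and `_mm256_blendv_ps`. The snippet below is a minimal sketch of that masked-load pattern, assuming the mask is built from a lane-index comparison; `blend_mask_avx2()` and `masked_load_ps()` are hypothetical stand-ins written for this description, not the repository's helpers, and the real `create_blend_mask_avx2()` signature may differ.

```cpp
#include <immintrin.h>
#include <cstddef>
#include <cstdint>

// Build an 8-lane float mask whose first `remaining` lanes are all-ones.
// Hypothetical stand-in for create_blend_mask_avx2().
inline __m256 blend_mask_avx2(size_t remaining) {
    const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
    const __m256i limit = _mm256_set1_epi32(static_cast<int32_t>(remaining));
    // lane_id < remaining  ->  all-ones lane, otherwise zero.
    return _mm256_castsi256_ps(_mm256_cmpgt_epi32(limit, lane_ids));
}

// Masked tail load: lanes beyond `remaining` are forced to zero so they do not
// contribute to the accumulated distance.
inline __m256 masked_load_ps(const float* p, size_t remaining) {
    __m256 mask = blend_mask_avx2(remaining);
    // _mm256_maskload_ps only reads memory for lanes whose mask sign bit is set.
    __m256 v = _mm256_maskload_ps(p, _mm256_castps_si256(mask));
    // Blend against zero to make the pattern explicit (maskload already zeroes
    // the inactive lanes).
    return _mm256_blendv_ps(_mm256_setzero_ps(), v, mask);
}
```

Because out-of-range lanes arrive as zeros, the masked `accumulate()` overloads in the diffs below can ignore the mask argument, as their "masking is handled in the load operations" comments note.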

File tree

4 files changed: +264 −204 lines changed


include/svs/core/distance/cosine.h

Lines changed: 114 additions & 0 deletions
@@ -383,6 +383,120 @@ struct CosineSimilarityImpl<N, Float16, Float16, AVX_AVAILABILITY::AVX512> {
 
 #endif
 
+/////
+///// Intel(R) AVX2 Implementations
+/////
+
+SVS_VALIDATE_BOOL_ENV(SVS_AVX512_F)
+SVS_VALIDATE_BOOL_ENV(SVS_AVX2)
+#if !SVS_AVX512_F && SVS_AVX2
+
+template <> struct CosineFloatOp<8> : public svs::simd::ConvertToFloat<8> {
+    using parent = svs::simd::ConvertToFloat<8>;
+    using mask_t = typename parent::mask_t;
+    static constexpr size_t simd_width = 8;
+
+    // A lightweight struct to contain both the partial results for the inner product
+    // of the left-hand and right-hand as well as partial results for computing the norm
+    // of the right-hand.
+    struct Pair {
+        __m256 op;
+        __m256 norm;
+    };
+
+    static Pair init() { return {_mm256_setzero_ps(), _mm256_setzero_ps()}; };
+
+    static Pair accumulate(Pair accumulator, __m256 a, __m256 b) {
+        return {
+            _mm256_fmadd_ps(a, b, accumulator.op), _mm256_fmadd_ps(b, b, accumulator.norm)};
+    }
+
+    static Pair accumulate(mask_t /*m*/, Pair accumulator, __m256 a, __m256 b) {
+        // For AVX2, masking is handled in the load operations
+        return {
+            _mm256_fmadd_ps(a, b, accumulator.op), _mm256_fmadd_ps(b, b, accumulator.norm)};
+    }
+
+    static Pair combine(Pair x, Pair y) {
+        return {_mm256_add_ps(x.op, y.op), _mm256_add_ps(x.norm, y.norm)};
+    }
+
+    static std::pair<float, float> reduce(Pair x) {
+        return std::make_pair(
+            simd::_mm256_reduce_add_ps(x.op), simd::_mm256_reduce_add_ps(x.norm)
+        );
+    }
+};
+
+// Floating and Mixed Types
+template <size_t N> struct CosineSimilarityImpl<N, float, float, AVX_AVAILABILITY::AVX2> {
+    SVS_NOINLINE static float
+    compute(const float* a, const float* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<8>(), a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    }
+};
+
+template <size_t N> struct CosineSimilarityImpl<N, float, uint8_t, AVX_AVAILABILITY::AVX2> {
+    SVS_NOINLINE static float
+    compute(const float* a, const uint8_t* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<8>(), a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    };
+};
+
+template <size_t N> struct CosineSimilarityImpl<N, float, int8_t, AVX_AVAILABILITY::AVX2> {
+    SVS_NOINLINE static float
+    compute(const float* a, const int8_t* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<8>(), a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    };
+};
+
+template <size_t N> struct CosineSimilarityImpl<N, float, Float16, AVX_AVAILABILITY::AVX2> {
+    SVS_NOINLINE static float
+    compute(const float* a, const Float16* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<8>{}, a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    }
+};
+
+template <size_t N> struct CosineSimilarityImpl<N, Float16, float, AVX_AVAILABILITY::AVX2> {
+    SVS_NOINLINE static float
+    compute(const Float16* a, const float* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<8>{}, a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    }
+};
+
+template <size_t N>
+struct CosineSimilarityImpl<N, Float16, Float16, AVX_AVAILABILITY::AVX2> {
+    SVS_NOINLINE static float
+    compute(const Float16* a, const Float16* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<8>{}, a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    }
+};
+
+template <size_t N> struct CosineSimilarityImpl<N, int8_t, int8_t, AVX_AVAILABILITY::AVX2> {
+    SVS_NOINLINE static float
+    compute(const int8_t* a, const int8_t* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<8>{}, a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    }
+};
+
+template <size_t N>
+struct CosineSimilarityImpl<N, uint8_t, uint8_t, AVX_AVAILABILITY::AVX2> {
+    SVS_NOINLINE static float
+    compute(const uint8_t* a, const uint8_t* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<8>{}, a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    }
+};
+
+#endif
+
 #if defined(__x86_64__)
 
 #include "svs/multi-arch/x86/preprocessor.h"
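As a point of reference for the AVX2 specializations above, the scalar computation that `CosineFloatOp<8>` vectorizes is shown below. This is an illustrative rewrite for this commit description, not code from the repository, and the function name is invented.

```cpp
#include <cmath>
#include <cstddef>

// Scalar analogue of CosineFloatOp<8>: accumulate the dot product of a and b (Pair::op)
// and the squared norm of b (Pair::norm), then combine them as the compute() methods do.
float cosine_similarity_reference(const float* a, const float* b, float a_norm, size_t n) {
    float sum = 0.0f;  // plays the role of Pair::op
    float norm = 0.0f; // plays the role of Pair::norm
    for (size_t i = 0; i < n; ++i) {
        sum += a[i] * b[i];  // vectorized as _mm256_fmadd_ps(a, b, accumulator.op)
        norm += b[i] * b[i]; // vectorized as _mm256_fmadd_ps(b, b, accumulator.norm)
    }
    return sum / (std::sqrt(norm) * a_norm);
}
```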

include/svs/core/distance/euclidean.h

Lines changed: 30 additions & 105 deletions
@@ -366,144 +366,69 @@ template <size_t N> struct L2Impl<N, Float16, Float16, AVX_AVAILABILITY::AVX512>
 SVS_VALIDATE_BOOL_ENV(SVS_AVX512_F)
 SVS_VALIDATE_BOOL_ENV(SVS_AVX2)
 #if !SVS_AVX512_F && SVS_AVX2
+
+template <> struct L2FloatOp<8> : public svs::simd::ConvertToFloat<8> {
+    using parent = svs::simd::ConvertToFloat<8>;
+    using mask_t = typename parent::mask_t;
+    static constexpr size_t simd_width = 8;
+
+    // Here, we can fill-in the shared init, accumulate, combine, and reduce methods.
+    static __m256 init() { return _mm256_setzero_ps(); }
+
+    static __m256 accumulate(__m256 accumulator, __m256 a, __m256 b) {
+        auto c = _mm256_sub_ps(a, b);
+        return _mm256_fmadd_ps(c, c, accumulator);
+    }
+
+    static __m256 accumulate(mask_t /*m*/, __m256 accumulator, __m256 a, __m256 b) {
+        // For AVX2, masking is handled in the load operations
+        auto c = _mm256_sub_ps(a, b);
+        return _mm256_fmadd_ps(c, c, accumulator);
+    }
+
+    static __m256 combine(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
+    static float reduce(__m256 x) { return simd::_mm256_reduce_add_ps(x); }
+};
+
 template <size_t N> struct L2Impl<N, float, float, AVX_AVAILABILITY::AVX2> {
     SVS_NOINLINE static float
     compute(const float* a, const float* b, lib::MaybeStatic<N> length) {
-        constexpr size_t vector_size = 8;
-
-        // Peel off the last iterations if the SIMD vector width does not evenly the total
-        // vector width.
-        size_t upper = lib::upper<vector_size>(length);
-        auto rest = lib::rest<vector_size>(length);
-        auto sum = _mm256_setzero_ps();
-        for (size_t j = 0; j < upper; j += vector_size) {
-            auto va = _mm256_loadu_ps(a + j);
-            auto vb = _mm256_loadu_ps(b + j);
-            auto tmp = _mm256_sub_ps(va, vb);
-            sum = _mm256_fmadd_ps(tmp, tmp, sum);
-        }
-        return simd::_mm256_reduce_add_ps(sum) + generic_l2(a + upper, b + upper, rest);
+        return simd::generic_simd_op(L2FloatOp<8>{}, a, b, length);
     }
 };
 
 template <size_t N> struct L2Impl<N, Float16, Float16, AVX_AVAILABILITY::AVX2> {
     SVS_NOINLINE static float
     compute(const Float16* a, const Float16* b, lib::MaybeStatic<N> length) {
-        constexpr size_t vector_size = 8;
-
-        // Peel off the last iterations if the SIMD vector width does not evenly the total
-        // vector width.
-        size_t upper = lib::upper<vector_size>(length);
-        auto rest = lib::rest<vector_size>(length);
-        auto sum = _mm256_setzero_ps();
-        for (size_t j = 0; j < upper; j += vector_size) {
-            auto va =
-                _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(a + j)));
-            auto vb =
-                _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(b + j)));
-            auto tmp = _mm256_sub_ps(va, vb);
-            sum = _mm256_fmadd_ps(tmp, tmp, sum);
-        }
-        return simd::_mm256_reduce_add_ps(sum) + generic_l2(a + upper, b + upper, rest);
+        return simd::generic_simd_op(L2FloatOp<8>{}, a, b, length);
     }
 };
 
 template <size_t N> struct L2Impl<N, float, Float16, AVX_AVAILABILITY::AVX2> {
     SVS_NOINLINE static float
     compute(const float* a, const Float16* b, lib::MaybeStatic<N> length) {
-        constexpr size_t vector_size = 8;
-
-        // Peel off the last iterations if the SIMD vector width does not evenly the total
-        // vector width.
-        size_t upper = lib::upper<vector_size>(length);
-        auto rest = lib::rest<vector_size>(length);
-        auto sum = _mm256_setzero_ps();
-        for (size_t j = 0; j < upper; j += vector_size) {
-            auto va = _mm256_loadu_ps(a + j);
-            auto vb =
-                _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(b + j)));
-            auto tmp = _mm256_sub_ps(va, vb);
-            sum = _mm256_fmadd_ps(tmp, tmp, sum);
-        }
-        return simd::_mm256_reduce_add_ps(sum) + generic_l2(a + upper, b + upper, rest);
+        return simd::generic_simd_op(L2FloatOp<8>{}, a, b, length);
    }
 };
 
 template <size_t N> struct L2Impl<N, float, int8_t, AVX_AVAILABILITY::AVX2> {
     SVS_NOINLINE static float
     compute(const float* a, const int8_t* b, lib::MaybeStatic<N> length) {
-        constexpr size_t vector_size = 8;
-
-        // Peel off the last iterations if the SIMD vector width does not evenly the total
-        // vector width.
-        size_t upper = lib::upper<vector_size>(length);
-        auto rest = lib::rest<vector_size>(length);
-        auto sum = _mm256_setzero_ps();
-        for (size_t j = 0; j < upper; j += vector_size) {
-            auto va = _mm256_castsi256_ps(
-                _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(a + j))
-            );
-            auto vb = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
-                _mm_cvtsi64_si128(*(reinterpret_cast<const int64_t*>(b + j)))
-            ));
-            auto tmp = _mm256_sub_ps(va, vb);
-            sum = _mm256_fmadd_ps(tmp, tmp, sum);
-        }
-        return simd::_mm256_reduce_add_ps(sum) + generic_l2(a + upper, b + upper, rest);
+        return simd::generic_simd_op(L2FloatOp<8>{}, a, b, length);
     }
 };
 
 template <size_t N> struct L2Impl<N, int8_t, int8_t, AVX_AVAILABILITY::AVX2> {
     SVS_NOINLINE static float
     compute(const int8_t* a, const int8_t* b, lib::MaybeStatic<N> length) {
-        constexpr size_t vector_size = 8;
-
-        size_t upper = lib::upper<vector_size>(length);
-        auto rest = lib::rest<vector_size>(length);
-        auto sum = _mm256_setzero_ps();
-        for (size_t j = 0; j < upper; j += vector_size) {
-            // * Strategy: Load 8 bytes as a 64-bit int.
-            // * Use `_mm_cvtsi64_si128` to convert to a 128-bit vector.
-            // * Use `mm256_evtepi8_epi32` to convert the 8-bytes to
-            //   8 32-bit integers.
-            // * Finally, convert to single precision floating point.
-            auto va = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
-                _mm_cvtsi64_si128(*(reinterpret_cast<const int64_t*>(a + j)))
-            ));
-            auto vb = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
-                _mm_cvtsi64_si128(*(reinterpret_cast<const int64_t*>(b + j)))
-            ));
-            auto diff = _mm256_sub_ps(va, vb);
-            sum = _mm256_fmadd_ps(diff, diff, sum);
-        }
-        return simd::_mm256_reduce_add_ps(sum) + generic_l2(a + upper, b + upper, rest);
+        return simd::generic_simd_op(L2FloatOp<8>{}, a, b, length);
     }
 };
 
 template <size_t N> struct L2Impl<N, uint8_t, uint8_t, AVX_AVAILABILITY::AVX2> {
     SVS_NOINLINE static float
     compute(const uint8_t* a, const uint8_t* b, lib::MaybeStatic<N> length) {
-        constexpr size_t vector_size = 8;
-
-        size_t upper = lib::upper<vector_size>(length);
-        auto rest = lib::rest<vector_size>(length);
-        auto sum = _mm256_setzero_ps();
-        for (size_t j = 0; j < upper; j += vector_size) {
-            // * Strategy: Load 8 bytes as a 64-bit int.
-            // * Use `_mm_cvtsi64_si128` to convert to a 128-bit vector.
-            // * Use `mm256_evtepi8_epi32` to convert the 8-bytes to
-            //   8 32-bit integers.
-            // * Finally, convert to single precision floating point.
-            auto va = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_cvtsi64_si128(*(reinterpret_cast<const int64_t*>(a + j)))
-            ));
-            auto vb = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
-                _mm_cvtsi64_si128(*(reinterpret_cast<const int64_t*>(b + j)))
-            ));
-            auto diff = _mm256_sub_ps(va, vb);
-            sum = _mm256_fmadd_ps(diff, diff, sum);
-        }
-        return simd::_mm256_reduce_add_ps(sum) + generic_l2(a + upper, b + upper, rest);
+        return simd::generic_simd_op(L2FloatOp<8>{}, a, b, length);
     }
 };
 
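The deleted loops above all repeated the same peel-and-reduce pattern that `simd::generic_simd_op()` now centralizes. Below is a self-contained sketch of the `init()`/`accumulate()`/`combine()`/`reduce()` protocol for the float-only case, together with a simplified driver loop. It is illustrative only: `L2FloatOpSketch` and `l2_squared_sketch` are names invented for this sketch, the real driver also converts `int8_t`/`uint8_t`/`Float16` inputs through `ConvertToFloat<8>`, and it handles the tail with the masked-load path rather than the scalar loop shown here. Building it assumes AVX2 and FMA (e.g. `-mavx2 -mfma`).

```cpp
#include <immintrin.h>
#include <cstddef>

// Sketch of the op protocol for float inputs. The real L2FloatOp<8> additionally
// inherits from svs::simd::ConvertToFloat<8> to widen narrower element types.
struct L2FloatOpSketch {
    static constexpr size_t simd_width = 8;
    static __m256 init() { return _mm256_setzero_ps(); }
    static __m256 accumulate(__m256 acc, __m256 a, __m256 b) {
        __m256 c = _mm256_sub_ps(a, b);
        return _mm256_fmadd_ps(c, c, acc); // acc += (a - b)^2, lane-wise
    }
    // combine() merges independent accumulators when the driver keeps several in flight.
    static __m256 combine(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
    static float reduce(__m256 x) {
        // Horizontal sum of the 8 lanes (stand-in for simd::_mm256_reduce_add_ps).
        __m128 lo = _mm256_castps256_ps128(x);
        __m128 hi = _mm256_extractf128_ps(x, 1);
        __m128 s = _mm_add_ps(lo, hi);
        s = _mm_hadd_ps(s, s);
        s = _mm_hadd_ps(s, s);
        return _mm_cvtss_f32(s);
    }
};

// Minimal driver in the spirit of simd::generic_simd_op(): full-width chunks, then a
// scalar tail in place of the library's masked-load path.
inline float l2_squared_sketch(const float* a, const float* b, size_t n) {
    __m256 acc = L2FloatOpSketch::init();
    size_t i = 0;
    for (; i + L2FloatOpSketch::simd_width <= n; i += L2FloatOpSketch::simd_width) {
        acc = L2FloatOpSketch::accumulate(acc, _mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i));
    }
    float sum = L2FloatOpSketch::reduce(acc);
    for (; i < n; ++i) {
        float d = a[i] - b[i];
        sum += d * d;
    }
    return sum;
}
```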