@@ -250,19 +250,16 @@ void compute_matmul(
250250 m_int, // const int M
251251 n_int, // const int N
252252 k_int, // const int K
253- 1 .0f , // float alpha (explicitly float)
253+ 1 .0f , // float alpha
254254 data, // const float* A
255255 k_int, // const int lda
256256 centroids, // const float* B
257257 k_int, // const int ldb
258- 0 .0f , // const float beta (explicitly float)
258+ 0 .0f , // const float beta
259259 results, // float* c
260260 n_int // const int ldc
261261 );
262262 } else if constexpr (std::is_same_v<T, BFloat16>) {
263- // Intel MKL BFloat16 GEMM requires careful parameter casting to avoid parameter
264- // errors Ensure all integer parameters are properly cast to int (MKL expects int,
265- // not size_t)
266263 int m_int = static_cast <int >(m);
267264 int n_int = static_cast <int >(n);
268265 int k_int = static_cast <int >(k);
@@ -274,19 +271,16 @@ void compute_matmul(
274271 m_int, // const int M
275272 n_int, // const int N
276273 k_int, // const int K
277- 1 .0f , // float alpha (explicitly float)
274+ 1 .0f , // float alpha
278275 (const uint16_t *)data, // const *uint16_t A
279276 k_int, // const int lda
280277 (const uint16_t *)centroids, // const uint16_t* B
281278 k_int, // const int ldb
282- 0 .0f , // const float beta (explicitly float)
279+ 0 .0f , // const float beta
283280 results, // float* c
284281 n_int // const int ldc
285282 );
286283 } else if constexpr (std::is_same_v<T, Float16>) {
287- // Intel MKL Float16 GEMM requires careful parameter casting to avoid parameter
288- // errors Ensure all integer parameters are properly cast to int (MKL expects int,
289- // not size_t)
290284 int m_int = static_cast <int >(m);
291285 int n_int = static_cast <int >(n);
292286 int k_int = static_cast <int >(k);
@@ -298,12 +292,12 @@ void compute_matmul(
298292 m_int, // const int M
299293 n_int, // const int N
300294 k_int, // const int K
301- 1 .0f , // float alpha (explicitly float)
295+ 1 .0f , // float alpha
302296 (const uint16_t *)data, // const *uint16_t A
303297 k_int, // const int lda
304298 (const uint16_t *)centroids, // const uint16_t* B
305299 k_int, // const int ldb
306- 0 .0f , // const float beta (explicitly float)
300+ 0 .0f , // const float beta
307301 results, // float* c
308302 n_int // const int ldc
309303 );
0 commit comments