Skip to content

Commit eea2ffb

Browse files
committed
Fix: Wire remaining kernels in 5 places each
This brings the current state of the project to 2'000 kernels including both SIMD and serial endpoints.
1 parent 0a152c3 commit eea2ffb

13 files changed

Lines changed: 147 additions & 2 deletions

File tree

CONTRIBUTING.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,21 @@ To add a new operation family, for example `foo`:
357357
7. __Cross-platform tests__: add entries to `test/test_cross.hpp` and the relevant `test_cross_*.cpp` files.
358358
8. __CMakeLists.txt__: wire the new source files into the `nk_test` and `nk_bench` targets.
359359
9. __Language bindings__: update `python/numkong.c`, `javascript/numkong.c`, `rust/numkong.rs`, etc. as needed.
360+
361+
## Adding a Backend Kernel to an Existing Family
362+
363+
For primary kernels, every backend implementation should be wired in five places beyond the backend header itself:
364+
365+
1. __Forward declaration__: add the `NK_PUBLIC` declaration with the matching `@copydoc` in the first half of `include/numkong/<family>.h`.
366+
2. __Compile-time dispatch__: add the `#if !NK_DYNAMIC_DISPATCH` branch in the second half of `include/numkong/<family>.h`.
367+
3. __Run-time dispatch__: add the dtype-specific entry to the relevant `c/dispatch_*.c` table.
368+
4. __Precision tests__: register the kernel in `nk_test`, usually in the existing `test/test_<family>.cpp` suite.
369+
5. __Benchmarks__: register the kernel in `nk_bench`, usually in the existing `bench/bench_<family>.cpp` suite.
370+
371+
Use the existing family suite unless the kernel introduces a genuinely new test shape.
372+
The rule is about coverage and reachability, not about creating a brand new source file for every symbol.
373+
374+
There are two intentional exceptions:
375+
376+
- `cast`: the family-level `nk_cast_*` kernels follow the same header/dispatch/test/bench rule, but scalar conversion helpers are wired through `c/dispatch_other.c` and are covered through `test/test_cast.cpp` and `bench/bench_cast.cpp`.
377+
- `scalar`: scalar helpers are centrally declared in `include/numkong/scalar.h`, wired through `c/dispatch_other.c`, and currently do not follow the per-helper `nk_test` and `nk_bench` registration pattern.

c/dispatch_bf16.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ void nk_dispatch_bf16_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
8686
case nk_kernel_euclidean_k: *m = (m_t)&nk_euclidean_bf16_neonbfdot, *c = nk_cap_neonbfdot_k; return;
8787
case nk_kernel_bilinear_k: *m = (m_t)&nk_bilinear_bf16_neonbfdot, *c = nk_cap_neonbfdot_k; return;
8888
case nk_kernel_mahalanobis_k: *m = (m_t)&nk_mahalanobis_bf16_neonbfdot, *c = nk_cap_neonbfdot_k; return;
89+
case nk_kernel_rmsd_k: *m = (m_t)&nk_rmsd_bf16_neonbfdot, *c = nk_cap_neonbfdot_k; return;
90+
case nk_kernel_kabsch_k: *m = (m_t)&nk_kabsch_bf16_neonbfdot, *c = nk_cap_neonbfdot_k; return;
91+
case nk_kernel_umeyama_k: *m = (m_t)&nk_umeyama_bf16_neonbfdot, *c = nk_cap_neonbfdot_k; return;
8992
case nk_kernel_each_fma_k: *m = (m_t)&nk_each_fma_bf16_neonbfdot, *c = nk_cap_neonbfdot_k; return;
9093
case nk_kernel_each_blend_k: *m = (m_t)&nk_each_blend_bf16_neonbfdot, *c = nk_cap_neonbfdot_k; return;
9194
case nk_kernel_each_scale_k: *m = (m_t)&nk_each_scale_bf16_neonbfdot, *c = nk_cap_neonbfdot_k; return;
@@ -229,6 +232,9 @@ void nk_dispatch_bf16_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
229232
case nk_kernel_euclidean_k: *m = (m_t)&nk_euclidean_bf16_haswell, *c = nk_cap_haswell_k; return;
230233
case nk_kernel_bilinear_k: *m = (m_t)&nk_bilinear_bf16_haswell, *c = nk_cap_haswell_k; return;
231234
case nk_kernel_mahalanobis_k: *m = (m_t)&nk_mahalanobis_bf16_haswell, *c = nk_cap_haswell_k; return;
235+
case nk_kernel_rmsd_k: *m = (m_t)&nk_rmsd_bf16_haswell, *c = nk_cap_haswell_k; return;
236+
case nk_kernel_kabsch_k: *m = (m_t)&nk_kabsch_bf16_haswell, *c = nk_cap_haswell_k; return;
237+
case nk_kernel_umeyama_k: *m = (m_t)&nk_umeyama_bf16_haswell, *c = nk_cap_haswell_k; return;
232238
case nk_kernel_each_fma_k: *m = (m_t)&nk_each_fma_bf16_haswell, *c = nk_cap_haswell_k; return;
233239
case nk_kernel_each_blend_k: *m = (m_t)&nk_each_blend_bf16_haswell, *c = nk_cap_haswell_k; return;
234240
case nk_kernel_each_scale_k: *m = (m_t)&nk_each_scale_bf16_haswell, *c = nk_cap_haswell_k; return;

c/dispatch_e4m3.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ void nk_dispatch_e4m3_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
5353
#endif
5454
#if NK_TARGET_NEONFHM
5555
if (v & nk_cap_neonfhm_k) switch (k) {
56+
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e4m3_neonfhm, *c = nk_cap_neonfhm_k; return;
5657
case nk_kernel_reduce_moments_k: *m = (m_t)&nk_reduce_moments_e4m3_neonfhm, *c = nk_cap_neonfhm_k; return;
5758
case nk_kernel_reduce_minmax_k: *m = (m_t)&nk_reduce_minmax_e4m3_neonfhm, *c = nk_cap_neonfhm_k; return;
5859
case nk_kernel_dots_packed_size_k: *m = (m_t)&nk_dots_packed_size_e4m3_neonfhm, *c = nk_cap_neonfhm_k; return;
@@ -70,6 +71,12 @@ void nk_dispatch_e4m3_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
7071
default: break;
7172
}
7273
#endif
74+
#if NK_TARGET_NEONBFDOT
75+
if (v & nk_cap_neonbfdot_k) switch (k) {
76+
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e4m3_neonbfdot, *c = nk_cap_neonbfdot_k; return;
77+
default: break;
78+
}
79+
#endif
7380
#if NK_TARGET_NEON
7481
if (v & nk_cap_neon_k) switch (k) {
7582
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e4m3_neon, *c = nk_cap_neon_k; return;

c/dispatch_e5m2.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ void nk_dispatch_e5m2_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
5353
#endif
5454
#if NK_TARGET_NEONFHM
5555
if (v & nk_cap_neonfhm_k) switch (k) {
56+
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e5m2_neonfhm, *c = nk_cap_neonfhm_k; return;
5657
case nk_kernel_reduce_moments_k: *m = (m_t)&nk_reduce_moments_e5m2_neonfhm, *c = nk_cap_neonfhm_k; return;
5758
case nk_kernel_reduce_minmax_k: *m = (m_t)&nk_reduce_minmax_e5m2_neonfhm, *c = nk_cap_neonfhm_k; return;
5859
case nk_kernel_dots_packed_size_k: *m = (m_t)&nk_dots_packed_size_e5m2_neonfhm, *c = nk_cap_neonfhm_k; return;
@@ -70,6 +71,12 @@ void nk_dispatch_e5m2_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
7071
default: break;
7172
}
7273
#endif
74+
#if NK_TARGET_NEONBFDOT
75+
if (v & nk_cap_neonbfdot_k) switch (k) {
76+
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e5m2_neonbfdot, *c = nk_cap_neonbfdot_k; return;
77+
default: break;
78+
}
79+
#endif
7380
#if NK_TARGET_NEON
7481
if (v & nk_cap_neon_k) switch (k) {
7582
case nk_kernel_dot_k: *m = (m_t)&nk_dot_e5m2_neon, *c = nk_cap_neon_k; return;

c/dispatch_f32.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ void nk_dispatch_f32_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_punn
4848
return;
4949
case nk_kernel_maxsim_pack_k: *m = (m_t)&nk_maxsim_pack_f32_v128relaxed, *c = nk_cap_v128relaxed_k; return;
5050
case nk_kernel_maxsim_packed_k: *m = (m_t)&nk_maxsim_packed_f32_v128relaxed, *c = nk_cap_v128relaxed_k; return;
51+
case nk_kernel_rmsd_k: *m = (m_t)&nk_rmsd_f32_v128relaxed, *c = nk_cap_v128relaxed_k; return;
52+
case nk_kernel_kabsch_k: *m = (m_t)&nk_kabsch_f32_v128relaxed, *c = nk_cap_v128relaxed_k; return;
53+
case nk_kernel_umeyama_k: *m = (m_t)&nk_umeyama_f32_v128relaxed, *c = nk_cap_v128relaxed_k; return;
5154
default: break;
5255
}
5356
#endif
@@ -61,6 +64,8 @@ void nk_dispatch_f32_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_punn
6164
#endif
6265
#if NK_TARGET_SMEF64
6366
if (v & nk_cap_smef64_k) switch (k) {
67+
case nk_kernel_bilinear_k: *m = (m_t)&nk_bilinear_f32_smef64, *c = nk_cap_smef64_k; return;
68+
case nk_kernel_mahalanobis_k: *m = (m_t)&nk_mahalanobis_f32_smef64, *c = nk_cap_smef64_k; return;
6469
case nk_kernel_dots_packed_size_k: *m = (m_t)&nk_dots_packed_size_f32_smef64, *c = nk_cap_smef64_k; return;
6570
case nk_kernel_dots_pack_k: *m = (m_t)&nk_dots_pack_f32_smef64, *c = nk_cap_smef64_k; return;
6671
case nk_kernel_dots_packed_k: *m = (m_t)&nk_dots_packed_f32_smef64, *c = nk_cap_smef64_k; return;

c/dispatch_f32c.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ void nk_dispatch_f32c_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
1515
default: break;
1616
}
1717
#endif
18+
#if NK_TARGET_SMEF64
19+
if (v & nk_cap_smef64_k) switch (k) {
20+
case nk_kernel_bilinear_k: *m = (m_t)&nk_bilinear_f32c_smef64, *c = nk_cap_smef64_k; return;
21+
default: break;
22+
}
23+
#endif
1824
#if NK_TARGET_SVE
1925
if (v & nk_cap_sve_k) switch (k) {
2026
case nk_kernel_dot_k: *m = (m_t)&nk_dot_f32c_sve, *c = nk_cap_sve_k; return;

c/dispatch_f64.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,16 @@ void nk_dispatch_f64_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_punn
4343
case nk_kernel_euclideans_symmetric_k:
4444
*m = (m_t)&nk_euclideans_symmetric_f64_v128relaxed, *c = nk_cap_v128relaxed_k;
4545
return;
46+
case nk_kernel_rmsd_k: *m = (m_t)&nk_rmsd_f64_v128relaxed, *c = nk_cap_v128relaxed_k; return;
47+
case nk_kernel_kabsch_k: *m = (m_t)&nk_kabsch_f64_v128relaxed, *c = nk_cap_v128relaxed_k; return;
48+
case nk_kernel_umeyama_k: *m = (m_t)&nk_umeyama_f64_v128relaxed, *c = nk_cap_v128relaxed_k; return;
4649
default: break;
4750
}
4851
#endif
4952
#if NK_TARGET_SMEF64
5053
if (v & nk_cap_smef64_k) switch (k) {
54+
case nk_kernel_bilinear_k: *m = (m_t)&nk_bilinear_f64_smef64, *c = nk_cap_smef64_k; return;
55+
case nk_kernel_mahalanobis_k: *m = (m_t)&nk_mahalanobis_f64_smef64, *c = nk_cap_smef64_k; return;
5156
case nk_kernel_dots_packed_size_k: *m = (m_t)&nk_dots_packed_size_f64_smef64, *c = nk_cap_smef64_k; return;
5257
case nk_kernel_dots_pack_k: *m = (m_t)&nk_dots_pack_f64_smef64, *c = nk_cap_smef64_k; return;
5358
case nk_kernel_dots_packed_k: *m = (m_t)&nk_dots_packed_f64_smef64, *c = nk_cap_smef64_k; return;

c/dispatch_f64c.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ void nk_dispatch_f64c_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
1515
default: break;
1616
}
1717
#endif
18+
#if NK_TARGET_SMEF64
19+
if (v & nk_cap_smef64_k) switch (k) {
20+
case nk_kernel_bilinear_k: *m = (m_t)&nk_bilinear_f64c_smef64, *c = nk_cap_smef64_k; return;
21+
default: break;
22+
}
23+
#endif
1824
#if NK_TARGET_SVE
1925
if (v & nk_cap_sve_k) switch (k) {
2026
case nk_kernel_dot_k: *m = (m_t)&nk_dot_f64c_sve, *c = nk_cap_sve_k; return;
@@ -55,6 +61,8 @@ void nk_dispatch_f64c_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
5561
#endif
5662
#if NK_TARGET_HASWELL
5763
if (v & nk_cap_haswell_k) switch (k) {
64+
case nk_kernel_dot_k: *m = (m_t)&nk_dot_f64c_haswell, *c = nk_cap_haswell_k; return;
65+
case nk_kernel_vdot_k: *m = (m_t)&nk_vdot_f64c_haswell, *c = nk_cap_haswell_k; return;
5866
case nk_kernel_each_scale_k: *m = (m_t)&nk_each_scale_f64c_haswell, *c = nk_cap_haswell_k; return;
5967
case nk_kernel_each_blend_k: *m = (m_t)&nk_each_blend_f64c_haswell, *c = nk_cap_haswell_k; return;
6068
case nk_kernel_each_fma_k: *m = (m_t)&nk_each_fma_f64c_haswell, *c = nk_cap_haswell_k; return;

include/numkong/curved.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,15 @@ NK_PUBLIC void nk_bilinear_f32c_smef64(nk_f32c_t const *a, nk_f32c_t const *b, n
209209
/** @copydoc nk_mahalanobis_f32 */
210210
NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
211211
nk_f64_t *result);
212+
/** @copydoc nk_bilinear_f64 */
213+
NK_PUBLIC void nk_bilinear_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
214+
nk_f64_t *result);
215+
/** @copydoc nk_bilinear_f64c */
216+
NK_PUBLIC void nk_bilinear_f64c_smef64(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
217+
nk_f64c_t *result);
218+
/** @copydoc nk_mahalanobis_f64 */
219+
NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
220+
nk_f64_t *result);
212221
#endif // NK_TARGET_SMEF64
213222

214223
#if NK_TARGET_HASWELL
@@ -345,6 +354,8 @@ extern "C" {
345354
NK_PUBLIC void nk_bilinear_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n, nk_f64_t *result) {
346355
#if NK_TARGET_SKYLAKE
347356
nk_bilinear_f64_skylake(a, b, c, n, result);
357+
#elif NK_TARGET_SMEF64
358+
nk_bilinear_f64_smef64(a, b, c, n, result);
348359
#elif NK_TARGET_RVV
349360
nk_bilinear_f64_rvv(a, b, c, n, result);
350361
#else
@@ -355,6 +366,8 @@ NK_PUBLIC void nk_bilinear_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t co
355366
NK_PUBLIC void nk_bilinear_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result) {
356367
#if NK_TARGET_SKYLAKE
357368
nk_bilinear_f32_skylake(a, b, c, n, result);
369+
#elif NK_TARGET_SMEF64
370+
nk_bilinear_f32_smef64(a, b, c, n, result);
358371
#elif NK_TARGET_HASWELL
359372
nk_bilinear_f32_haswell(a, b, c, n, result);
360373
#elif NK_TARGET_NEON
@@ -397,6 +410,8 @@ NK_PUBLIC void nk_bilinear_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_
397410
nk_f64c_t *results) {
398411
#if NK_TARGET_SKYLAKE
399412
nk_bilinear_f64c_skylake(a, b, c, n, results);
413+
#elif NK_TARGET_SMEF64
414+
nk_bilinear_f64c_smef64(a, b, c, n, results);
400415
#else
401416
nk_bilinear_f64c_serial(a, b, c, n, results);
402417
#endif
@@ -406,6 +421,8 @@ NK_PUBLIC void nk_bilinear_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_
406421
nk_f64c_t *results) {
407422
#if NK_TARGET_SKYLAKE
408423
nk_bilinear_f32c_skylake(a, b, c, n, results);
424+
#elif NK_TARGET_SMEF64
425+
nk_bilinear_f32c_smef64(a, b, c, n, results);
409426
#elif NK_TARGET_NEON
410427
nk_bilinear_f32c_neon(a, b, c, n, results);
411428
#else
@@ -437,6 +454,8 @@ NK_PUBLIC void nk_mahalanobis_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t
437454
nk_f64_t *result) {
438455
#if NK_TARGET_SKYLAKE
439456
nk_mahalanobis_f64_skylake(a, b, c, n, result);
457+
#elif NK_TARGET_SMEF64
458+
nk_mahalanobis_f64_smef64(a, b, c, n, result);
440459
#elif NK_TARGET_RVV
441460
nk_mahalanobis_f64_rvv(a, b, c, n, result);
442461
#else
@@ -448,6 +467,8 @@ NK_PUBLIC void nk_mahalanobis_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t
448467
nk_f64_t *result) {
449468
#if NK_TARGET_SKYLAKE
450469
nk_mahalanobis_f32_skylake(a, b, c, n, result);
470+
#elif NK_TARGET_SMEF64
471+
nk_mahalanobis_f32_smef64(a, b, c, n, result);
451472
#elif NK_TARGET_HASWELL
452473
nk_mahalanobis_f32_haswell(a, b, c, n, result);
453474
#elif NK_TARGET_NEON

include/numkong/dot.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,10 @@ NK_PUBLIC void nk_vdot_f16c_neonhalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_
307307
#if NK_TARGET_NEONFHM
308308
/** @copydoc nk_dot_f16 */
309309
NK_PUBLIC void nk_dot_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
310+
/** @copydoc nk_dot_e4m3 */
311+
NK_PUBLIC void nk_dot_e4m3_neonfhm(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
312+
/** @copydoc nk_dot_e5m2 */
313+
NK_PUBLIC void nk_dot_e5m2_neonfhm(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
310314
/** @copydoc nk_dot_f16c */
311315
NK_PUBLIC void nk_dot_f16c_neonfhm(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
312316
/** @copydoc nk_vdot_f16c */
@@ -318,6 +322,10 @@ NK_PUBLIC void nk_vdot_f16c_neonfhm(nk_f16c_t const *a, nk_f16c_t const *b, nk_s
318322
NK_PUBLIC void nk_dot_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
319323
/** @copydoc nk_dot_u8 */
320324
NK_PUBLIC void nk_dot_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
325+
/** @copydoc nk_dot_i4 */
326+
NK_PUBLIC void nk_dot_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
327+
/** @copydoc nk_dot_u4 */
328+
NK_PUBLIC void nk_dot_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
321329
/** @copydoc nk_dot_e2m3 */
322330
NK_PUBLIC void nk_dot_e2m3_neonsdot(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
323331
/** @copydoc nk_dot_e3m2 */
@@ -327,6 +335,10 @@ NK_PUBLIC void nk_dot_e3m2_neonsdot(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_s
327335
#if NK_TARGET_NEONBFDOT
328336
/** @copydoc nk_dot_bf16 */
329337
NK_PUBLIC void nk_dot_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
338+
/** @copydoc nk_dot_e4m3 */
339+
NK_PUBLIC void nk_dot_e4m3_neonbfdot(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
340+
/** @copydoc nk_dot_e5m2 */
341+
NK_PUBLIC void nk_dot_e5m2_neonbfdot(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
330342
/** @copydoc nk_dot_bf16c */
331343
NK_PUBLIC void nk_dot_bf16c_neonbfdot(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
332344
/** @copydoc nk_vdot_bf16c */
@@ -371,6 +383,10 @@ NK_PUBLIC void nk_dot_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_
371383
NK_PUBLIC void nk_dot_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
372384
/** @copydoc nk_vdot_f32c */
373385
NK_PUBLIC void nk_vdot_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
386+
/** @copydoc nk_dot_f64c */
387+
NK_PUBLIC void nk_dot_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
388+
/** @copydoc nk_vdot_f64c */
389+
NK_PUBLIC void nk_vdot_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
374390

375391
/** @copydoc nk_dot_f16 */
376392
NK_PUBLIC void nk_dot_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
@@ -694,6 +710,8 @@ NK_PUBLIC void nk_dot_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32
694710
NK_PUBLIC void nk_dot_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result) {
695711
#if NK_TARGET_ICELAKE
696712
nk_dot_i4_icelake(a, b, n, result);
713+
#elif NK_TARGET_NEONSDOT
714+
nk_dot_i4_neonsdot(a, b, n, result);
697715
#elif NK_TARGET_RVV
698716
nk_dot_i4_rvv(a, b, n, result);
699717
#elif NK_TARGET_HASWELL
@@ -708,6 +726,8 @@ NK_PUBLIC void nk_dot_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk
708726
NK_PUBLIC void nk_dot_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
709727
#if NK_TARGET_ICELAKE
710728
nk_dot_u4_icelake(a, b, n, result);
729+
#elif NK_TARGET_NEONSDOT
730+
nk_dot_u4_neonsdot(a, b, n, result);
711731
#elif NK_TARGET_RVV
712732
nk_dot_u4_rvv(a, b, n, result);
713733
#elif NK_TARGET_HASWELL
@@ -788,6 +808,10 @@ NK_PUBLIC void nk_dot_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
788808
NK_PUBLIC void nk_dot_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
789809
#if NK_TARGET_GENOA
790810
nk_dot_e4m3_genoa(a, b, n, result);
811+
#elif NK_TARGET_NEONBFDOT
812+
nk_dot_e4m3_neonbfdot(a, b, n, result);
813+
#elif NK_TARGET_NEONFHM
814+
nk_dot_e4m3_neonfhm(a, b, n, result);
791815
#elif NK_TARGET_RVVHALF
792816
nk_dot_e4m3_rvvhalf(a, b, n, result);
793817
#elif NK_TARGET_RVVBF16
@@ -810,6 +834,10 @@ NK_PUBLIC void nk_dot_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n,
810834
NK_PUBLIC void nk_dot_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
811835
#if NK_TARGET_GENOA
812836
nk_dot_e5m2_genoa(a, b, n, result);
837+
#elif NK_TARGET_NEONBFDOT
838+
nk_dot_e5m2_neonbfdot(a, b, n, result);
839+
#elif NK_TARGET_NEONFHM
840+
nk_dot_e5m2_neonfhm(a, b, n, result);
813841
#elif NK_TARGET_RVVHALF
814842
nk_dot_e5m2_rvvhalf(a, b, n, result);
815843
#elif NK_TARGET_RVVBF16
@@ -962,6 +990,8 @@ NK_PUBLIC void nk_dot_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n,
962990
nk_dot_f64c_rvv(a, b, n, result);
963991
#elif NK_TARGET_SKYLAKE
964992
nk_dot_f64c_skylake(a, b, n, result);
993+
#elif NK_TARGET_HASWELL
994+
nk_dot_f64c_haswell(a, b, n, result);
965995
#elif NK_TARGET_V128RELAXED
966996
nk_dot_f64c_v128relaxed(a, b, n, result);
967997
#else
@@ -1022,6 +1052,8 @@ NK_PUBLIC void nk_vdot_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n,
10221052
nk_vdot_f64c_rvv(a, b, n, result);
10231053
#elif NK_TARGET_SKYLAKE
10241054
nk_vdot_f64c_skylake(a, b, n, result);
1055+
#elif NK_TARGET_HASWELL
1056+
nk_vdot_f64c_haswell(a, b, n, result);
10251057
#elif NK_TARGET_V128RELAXED
10261058
nk_vdot_f64c_v128relaxed(a, b, n, result);
10271059
#else

0 commit comments

Comments
 (0)