@@ -235,6 +235,24 @@ class Comparator<vt::bf16> {
235235 }
236236};
237237
238+ template <>
239+ class Comparator <vt::tf32> {
240+ public:
241+ static uint32_t generate () {
242+ auto fvalue = float (rand ()) / RAND_MAX;
243+ return rv_ftox_s (bit_cast<uint32_t >(fvalue), 8 , 10 , 0 , nullptr );
244+ }
245+ static bool compare (uint32_t a, uint32_t b, int index, int errors) {
246+ if (a != b) {
247+ if (errors < MAX_ERRORS) {
248+ printf (" *** error: [%d] expected=0x%x, actual=0x%x\n " , index, b, a);
249+ }
250+ return false ;
251+ }
252+ return true ;
253+ }
254+ };
255+
238256template <>
239257class Comparator <vt::fp32> {
240258public:
@@ -311,6 +329,15 @@ struct muladd_t<vt::bf16, vt::bf16> {
311329 }
312330};
313331
332+ template <>
333+ struct muladd_t <vt::tf32, vt::fp32> {
334+ static float eval (uint32_t a, uint32_t b, float c) {
335+ auto fa = bit_cast<float >(rv_xtof_s (a, 8 , 10 , 0 , nullptr ));
336+ auto fb = bit_cast<float >(rv_xtof_s (b, 8 , 10 , 0 , nullptr ));
337+ return fa * fb + c;
338+ }
339+ };
340+
314341template <>
315342struct muladd_t <vt::int4, vt::int32> {
316343 static int32_t eval (uint8_t a, uint8_t b, int32_t c) {
@@ -367,61 +394,13 @@ static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t
367394 }
368395}
369396
370- /*
371- static void matmul_cpu_sparseA(
372- otype_t* C, // [M × N] output
373- const SparseMat& A, // sparse-A
374- const itype_t* B, // [K × N] dense-B
375- uint32_t N) // number of columns of B/C
376- {
377- const uint32_t M = A.rows;
378- const uint32_t K = A.cols;
379-
380- const uint32_t subbytes = 8 / vt::ITYPE::bits;
381-
382- // --- helper lambdas to index sparse arrays by row ---
383- auto row_values = [&](uint32_t m) {
384- return A.values.data() + m * (K / 2); // two kept per block
385- };
386- auto row_meta = [&](uint32_t m) {
387- return A.meta .data() + m * (K / 4);
388- };
389-
390- for (uint32_t m = 0; m < M; ++m) {
391-
392- const itype_t* Avals = row_values(m);
393- const uint8_t* Ameta = row_meta (m);
394- size_t v_idx = 0; // cursor inside values[]
395-
396- for (uint32_t n = 0; n < N; ++n) {
397- otype_t sum(0);
398- for (uint32_t blk = 0; blk < K; blk += 4) {
399- uint8_t mask = *(Ameta++);
400- assert(mask);
401- for (uint32_t i = 0; i < 4; ++i) {
402- if (mask & (1u << i)) {
403- auto a_val = Avals[v_idx++];
404- uint32_t k = blk + i; // logical K index
405- uint32_t kk = subbytes ? k * subbytes // packed-layout idx
406- : k;
407- auto b_val = data_accessor_t<vt::ITYPE>::read(
408- B, kk * N + n);
409- sum = muladd_t<vt::ITYPE, vt::OTYPE>::eval(a_val, b_val, sum);
410- }
411- }
412- }
413- data_accessor_t<vt::OTYPE>::write(C, m * N + n, sum);
414- }
415- }
416- }*/
417-
418397// /////////////////////////////////////////////////////////////////////////////
419398
420399const char *kernel_file = " kernel.vxbin" ;
421400
422- uint32_t xm = 4 ;
423- uint32_t xn = 8 ;
424- uint32_t xk = 2 ;
401+ uint32_t xm = 32 ;
402+ uint32_t xn = 32 ;
403+ uint32_t xk = 32 ;
425404
426405vx_device_h device = nullptr ;
427406vx_buffer_h A_buffer = nullptr ;
@@ -568,9 +547,9 @@ int main(int argc, char *argv[]) {
568547 return -1 ;
569548 }
570549
571- uint32_t M = xm * cfg::tileM ;
572- uint32_t N = xn * cfg::tileN ;
573- uint32_t K = xk = cfg::tileK ;
550+ uint32_t M = xm;
551+ uint32_t N = xn;
552+ uint32_t K = xk;
574553
575554 if ((M % cfg::tileM) != 0 ) {
576555 std::cout << " Error: M must be a multiple of tensor tileM!" << std::endl;
0 commit comments