
Commit 0b03df4

minor update to sgemm_tcu
1 parent d7e6e12 commit 0b03df4

1 file changed: +33 −54 lines

tests/regression/sgemm_tcu/main.cpp

Lines changed: 33 additions & 54 deletions
@@ -235,6 +235,24 @@ class Comparator<vt::bf16> {
   }
 };
 
+template <>
+class Comparator<vt::tf32> {
+public:
+  static uint32_t generate() {
+    auto fvalue = float(rand()) / RAND_MAX;
+    return rv_ftox_s(bit_cast<uint32_t>(fvalue), 8, 10, 0, nullptr);
+  }
+  static bool compare(uint32_t a, uint32_t b, int index, int errors) {
+    if (a != b) {
+      if (errors < MAX_ERRORS) {
+        printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a);
+      }
+      return false;
+    }
+    return true;
+  }
+};
+
 template <>
 class Comparator<vt::fp32> {
 public:
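
Note on the new tf32 path: Comparator<vt::tf32>::generate() draws a random float in [0, 1] and narrows it to a tf32-style encoding (1 sign bit, 8 exponent bits, 10 mantissa bits) through rv_ftox_s, and compare() then checks results bit-for-bit. The exact rounding and bit packing performed by rv_ftox_s are not visible in this diff; the sketch below only illustrates the narrowing under the assumption of fp32-style packing with plain truncation, using a hypothetical float_to_tf32_trunc helper.

    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    // Hypothetical stand-in for rv_ftox_s(bits, 8, 10, 0, nullptr): keep the
    // sign, the 8 exponent bits, and the top 10 mantissa bits of a binary32
    // value, truncating (not rounding) the low 13 mantissa bits.
    static uint32_t float_to_tf32_trunc(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits)); // portable bit_cast
      return bits & ~((1u << 13) - 1);      // clear the low 13 mantissa bits
    }

    int main() {
      float fvalue = float(rand()) / RAND_MAX;  // same input range as generate()
      uint32_t encoded = float_to_tf32_trunc(fvalue);
      float decoded;
      std::memcpy(&decoded, &encoded, sizeof(decoded));
      // Truncation keeps 10 mantissa bits, so the relative error is below ~2^-10.
      std::printf("%.8f -> 0x%08x -> %.8f\n", fvalue,
                  static_cast<unsigned>(encoded), decoded);
      return 0;
    }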
@@ -311,6 +329,15 @@ struct muladd_t<vt::bf16, vt::bf16> {
   }
 };
 
+template <>
+struct muladd_t<vt::tf32, vt::fp32> {
+  static float eval(uint32_t a, uint32_t b, float c) {
+    auto fa = bit_cast<float>(rv_xtof_s(a, 8, 10, 0, nullptr));
+    auto fb = bit_cast<float>(rv_xtof_s(b, 8, 10, 0, nullptr));
+    return fa * fb + c;
+  }
+};
+
 template <>
 struct muladd_t<vt::int4, vt::int32> {
   static int32_t eval(uint8_t a, uint8_t b, int32_t c) {
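
Note on the matching reference math: muladd_t<vt::tf32, vt::fp32>::eval() widens both tf32 operands back to binary32 via rv_xtof_s and accumulates the product in fp32, which is what the CPU reference matmul applies per element. A minimal sketch of that accumulation pattern follows, with a hypothetical widen_tf32 helper standing in for rv_xtof_s (same packing assumption as the previous sketch) and a contiguous column in place of the test's strided data accessors.

    #include <cstdint>
    #include <cstring>

    // Hypothetical inverse of the narrowing above: a tf32-style pattern stored
    // in an fp32 layout (low 13 mantissa bits cleared) is already a valid
    // binary32 value, so widening is just a reinterpretation.
    static float widen_tf32(uint32_t x) {
      float f;
      std::memcpy(&f, &x, sizeof(f));
      return f;
    }

    // One output element of the reference GEMM: a dot product over K with
    // tf32 inputs and an fp32 accumulator, mirroring muladd_t::eval.
    static float dot_tf32(const uint32_t* a_row, const uint32_t* b_col, uint32_t K) {
      float acc = 0.0f;
      for (uint32_t k = 0; k < K; ++k) {
        acc = widen_tf32(a_row[k]) * widen_tf32(b_col[k]) + acc;
      }
      return acc;
    }

    int main() {
      uint32_t a[2] = {0x3F800000u, 0x40000000u}; // 1.0f, 2.0f (tf32-exact)
      uint32_t b[2] = {0x40400000u, 0x3F000000u}; // 3.0f, 0.5f (tf32-exact)
      return dot_tf32(a, b, 2) == 4.0f ? 0 : 1;   // 1*3 + 2*0.5 = 4
    }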
@@ -367,61 +394,13 @@ static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t
   }
 }
 
-/*
-static void matmul_cpu_sparseA(
-  otype_t* C,          // [M × N] output
-  const SparseMat& A,  // sparse-A
-  const itype_t* B,    // [K × N] dense-B
-  uint32_t N)          // number of columns of B/C
-{
-  const uint32_t M = A.rows;
-  const uint32_t K = A.cols;
-
-  const uint32_t subbytes = 8 / vt::ITYPE::bits;
-
-  // --- helper lambdas to index sparse arrays by row ---
-  auto row_values = [&](uint32_t m) {
-    return A.values.data() + m * (K / 2); // two kept per block
-  };
-  auto row_meta = [&](uint32_t m) {
-    return A.meta .data() + m * (K / 4);
-  };
-
-  for (uint32_t m = 0; m < M; ++m) {
-
-    const itype_t* Avals = row_values(m);
-    const uint8_t* Ameta = row_meta (m);
-    size_t v_idx = 0; // cursor inside values[]
-
-    for (uint32_t n = 0; n < N; ++n) {
-      otype_t sum(0);
-      for (uint32_t blk = 0; blk < K; blk += 4) {
-        uint8_t mask = *(Ameta++);
-        assert(mask);
-        for (uint32_t i = 0; i < 4; ++i) {
-          if (mask & (1u << i)) {
-            auto a_val = Avals[v_idx++];
-            uint32_t k = blk + i;                 // logical K index
-            uint32_t kk = subbytes ? k * subbytes // packed-layout idx
-                                   : k;
-            auto b_val = data_accessor_t<vt::ITYPE>::read(
-                           B, kk * N + n);
-            sum = muladd_t<vt::ITYPE, vt::OTYPE>::eval(a_val, b_val, sum);
-          }
-        }
-      }
-      data_accessor_t<vt::OTYPE>::write(C, m * N + n, sum);
-    }
-  }
-}*/
-
 ///////////////////////////////////////////////////////////////////////////////
 
 const char *kernel_file = "kernel.vxbin";
 
-uint32_t xm = 4;
-uint32_t xn = 8;
-uint32_t xk = 2;
+uint32_t xm = 32;
+uint32_t xn = 32;
+uint32_t xk = 32;
 
 vx_device_h device = nullptr;
 vx_buffer_h A_buffer = nullptr;
@@ -568,9 +547,9 @@ int main(int argc, char *argv[]) {
     return -1;
   }
 
-  uint32_t M = xm * cfg::tileM;
-  uint32_t N = xn * cfg::tileN;
-  uint32_t K = xk = cfg::tileK;
+  uint32_t M = xm;
+  uint32_t N = xn;
+  uint32_t K = xk;
 
   if ((M % cfg::tileM) != 0) {
     std::cout << "Error: M must be a multiple of tensor tileM!" << std::endl;
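
Note on the size parameters: together with the new defaults (xm = xn = xk = 32), this hunk changes what the parameters mean. Previously they counted tensor tiles (M = xm * cfg::tileM, and the removed K line assigned cfg::tileK rather than multiplying by it); now they are the matrix dimensions themselves, and the checks that follow require each to be a multiple of cfg::tileM/tileN/tileK. A small sketch of the two interpretations, using a hypothetical tile size of 8 since the actual cfg::tile* values are not part of this diff:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Hypothetical tile size for illustration; the real value is cfg::tileM.
      const uint32_t tileM = 8;

      // Old interpretation: the parameter counted tiles.
      uint32_t xm_old = 4;              // previous default
      uint32_t M_old  = xm_old * tileM; // 32 rows with this tile size

      // New interpretation: the parameter is the dimension itself and must be
      // a multiple of the tile size (mirroring the runtime check above).
      uint32_t xm_new = 32;             // new default
      uint32_t M_new  = xm_new;
      if ((M_new % tileM) != 0) {
        std::printf("Error: M must be a multiple of tileM!\n");
        return -1;
      }

      std::printf("old: M=%u, new: M=%u\n",
                  static_cast<unsigned>(M_old), static_cast<unsigned>(M_new));
      return 0;
    }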
