diff --git a/src/ggml-bitnet-lut.cpp b/src/ggml-bitnet-lut.cpp index 59422d54..1cbddbee 100644 --- a/src/ggml-bitnet-lut.cpp +++ b/src/ggml-bitnet-lut.cpp @@ -73,7 +73,7 @@ size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const stru const size_t ne11 = src1->ne[1]; const int bits = ggml_bitnet_get_type_bits(src0->type); - size_t wsize = ne10 * ne11 * 15 * sizeof(int8_t) + 1 * ne11 * 2 * sizeof(bitnet_float_type); + size_t wsize = ne10 * ne01 * 15 * sizeof(int8_t) + 1 * ne11 * 2 * sizeof(bitnet_float_type); if (sizeof(bitnet_float_type) == 2) { // Need fp32 to fp16 conversion wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type); @@ -145,7 +145,7 @@ size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const stru const size_t ne10 = src1->ne[0]; const size_t ne11 = src1->ne[1]; - size_t wsize = ne10 * ne11 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type); + size_t wsize = ne10 * ne01 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type); if (sizeof(bitnet_float_type) == 2) { // Need fp32 to fp16 conversion wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);