Skip to content

Commit a08750c

Browse files
authored
fix bf16&fp16 quantize to nf4&fp4 (#1805)
1 parent e085e0d commit a08750c

File tree

4 files changed

+97
-29
lines changed

4 files changed

+97
-29
lines changed

csrc/lc/common.h

+24
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,30 @@ typedef enum LC_DataType_t
2828
template <typename T, int DATA_TYPE> void quantize_blockwise(const float * code, const T *A, float *absmax, unsigned char *out, int blocksize, int n);
2929
template<typename T, int DATA_TYPE> void dequantize_blockwise(const float *code, const unsigned char *A, float *absmax, T *out, int block_size, int n);
3030

31+
template <paddle::DataType D>
32+
class PDTraits;
33+
34+
template <>
35+
class PDTraits<paddle::DataType::FLOAT32> {
36+
public:
37+
typedef float DataType;
38+
typedef float data_t;
39+
};
40+
41+
template <>
42+
class PDTraits<paddle::DataType::FLOAT16> {
43+
public:
44+
typedef half DataType;
45+
typedef paddle::float16 data_t;
46+
};
47+
48+
template <>
49+
class PDTraits<paddle::DataType::BFLOAT16> {
50+
public:
51+
typedef __nv_bfloat16 DataType;
52+
typedef paddle::bfloat16 data_t;
53+
};
54+
3155

3256
#define CUDA_CHECK_RETURN(value) { \
3357
cudaError_t _m_cudaStat = value; \

csrc/lc/quantize_blockwise.cu

+22-27
Original file line numberDiff line numberDiff line change
@@ -362,42 +362,37 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4)
362362

363363

364364

365-
template <typename T, int DATA_TYPE> void quantize_blockwise(const float *code, const T *A, float *absmax, unsigned char *out, int blocksize, int n)
365+
template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n)
366366
{
367+
typedef PDTraits<D> traits_;
368+
typedef typename traits_::DataType DataType_;
369+
typedef typename traits_::data_t data_t;
370+
367371
int num_blocks = n/blocksize;
368372
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
369373

374+
const DataType_* A_data = reinterpret_cast<const DataType_*>(A.data<data_t>());
370375
if(blocksize == 4096)
371-
kQuantizeBlockwise<T, 4096, 4, 0><<<num_blocks, 1024>>>(code, A, absmax, out, n);
376+
kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax, out, n);
372377
else if(blocksize == 2048)
373-
kQuantizeBlockwise<T, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, n);
378+
kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax, out, n);
374379
else if(blocksize == 1024)
375-
kQuantizeBlockwise<T, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, n);
380+
kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
376381
else if(blocksize == 512)
377-
kQuantizeBlockwise<T, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, n);
382+
kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
378383
else if(blocksize == 256)
379-
kQuantizeBlockwise<T, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, n);
384+
kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax, out, n);
380385
else if(blocksize == 128)
381-
kQuantizeBlockwise<T, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, n);
386+
kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax, out, n);
382387
else if(blocksize == 64)
383-
kQuantizeBlockwise<T, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, n);
388+
kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax, out, n);
384389
else
385390
PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096].");
386391

387392

388393
CUDA_CHECK_RETURN(cudaPeekAtLastError());
389394
}
390395

391-
template void quantize_blockwise<half, General8bit>(const float *code, const half *A, float *absmax, unsigned char *out, int blocksize, int n);
392-
template void quantize_blockwise<half, FP4>(const float *code, const half *A, float *absmax, unsigned char *out, int blocksize, int n);
393-
template void quantize_blockwise<half, NF4>(const float *code, const half *A, float *absmax, unsigned char *out, int blocksize, int n);
394-
template void quantize_blockwise<float, General8bit>(const float *code, const float *A, float *absmax, unsigned char *out, int blocksize, int n);
395-
template void quantize_blockwise<float, FP4>(const float *code, const float *A, float *absmax, unsigned char *out, int blocksize, int n);
396-
template void quantize_blockwise<float, NF4>(const float *code, const float *A, float *absmax, unsigned char *out, int blocksize, int n);
397-
template void quantize_blockwise<__nv_bfloat16, General8bit>(const float *code, const __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, int n);
398-
template void quantize_blockwise<__nv_bfloat16, FP4>(const float *code, const __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, int n);
399-
template void quantize_blockwise<__nv_bfloat16, NF4>(const float *code, const __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, int n);
400-
401396
std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) {
402397
int n = input.numel();
403398
std::vector<int64_t> out_shape = input.shape();
@@ -410,28 +405,28 @@ std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const
410405
switch(input.type()) {
411406
case paddle::DataType::FLOAT32:
412407
if (quant_type == "8bit")
413-
quantize_blockwise<float, General8bit>(code.data<float>(), input.data<float>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
408+
quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
414409
else if (quant_type == "nf4") {
415-
quantize_blockwise<float, NF4>(NULL, input.data<float>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
410+
quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
416411
}
417412
else if (quant_type == "fp4")
418-
quantize_blockwise<float, FP4>(NULL, input.data<float>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
413+
quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
419414
return {out, absmax};
420415
case paddle::DataType::FLOAT16:
421416
if (quant_type == "8bit")
422-
quantize_blockwise<half, General8bit>(code.data<float>(), input.data<half>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
417+
quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
423418
else if (quant_type == "nf4")
424-
quantize_blockwise<half, NF4>(NULL, input.data<half>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
419+
quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
425420
else if (quant_type == "fp4")
426-
quantize_blockwise<half, FP4>(NULL, input.data<half>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
421+
quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
427422
return {out, absmax};
428423
case paddle::DataType::BFLOAT16:
429424
if (quant_type == "8bit")
430-
quantize_blockwise<__nv_bfloat16, General8bit>(code.data<float>(), input.data<__nv_bfloat16>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
425+
quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
431426
else if (quant_type == "nf4")
432-
quantize_blockwise<__nv_bfloat16, NF4>(NULL, input.data<__nv_bfloat16>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
427+
quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
433428
else if (quant_type == "fp4")
434-
quantize_blockwise<__nv_bfloat16, FP4>(NULL, input.data<__nv_bfloat16>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
429+
quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
435430
return {out, absmax};
436431

437432
default:

paddleslim/lc/quantizers/quant_func.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import paddle
23
from paddleslim_ops import quant_blockwise, dequant_blockwise
34

@@ -90,6 +91,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
9091
for i in range(gap):
9192
values.append(0)
9293
values.sort()
94+
code = paddle.to_tensor(values)
9395
code /= code.max()
9496

9597
return code
@@ -110,15 +112,15 @@ def dequantize_fp4(x, absmax, blocksize):
110112
def quantize_8bit(x, code, blocksize, quant_type="fp8"):
    """Blockwise 8-bit quantization of `x`.

    When `code` is None, a codebook is built from `quant_type`:
    "fp8" uses an fp8-style map, anything else a dynamic map.
    Returns whatever `quant_blockwise` returns (quantized data + absmax).
    """
    if code is None:
        # NOTE(review): total_bits=4 for the "fp8" codebook looks suspicious
        # (4 bits for an 8-bit quantizer) — confirm against create_fp8_map.
        code = (
            create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
            if quant_type == "fp8"
            else paddle.to_tensor(create_dynamic_map())
        )
    return quant_blockwise(x, code, blocksize=blocksize, quant_type="8bit")
117119

118120
def dequantize_8bit(x, code, absmax, blocksize, quant_type="fp8"):
119121
if code is None:
120122
if quant_type=="fp8":
121-
code = paddle.to_tensor(create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4))
123+
code = create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
122124
else:
123125
code = paddle.to_tensor(create_dynamic_map())
124126

tests/lc/test_func.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import sys
2+
sys.path.append("../../")
3+
import numpy as np
4+
import unittest
5+
import paddle
6+
from paddleslim.lc.layers import NF4Linear, FP4Linear
7+
from paddleslim.lc.quantizers.quant_func import quantize_nf4, quantize_fp4, dequantize_nf4, dequantize_fp4, quantize_8bit, dequantize_8bit
8+
9+
class NF4(unittest.TestCase):
    """Round-trip smoke test for NF4 blockwise quantization."""

    def setUp(self):
        self.quant_type = "nf4"
        self.blocksize = 64

    def test_nf4_fp16(self):
        a = paddle.uniform([2, 64], dtype="float16")
        nf4_a, scale_a = quantize_nf4(a, self.blocksize)
        fp16_a = dequantize_nf4(nf4_a, scale_a, self.blocksize).cast("float16")
        # NF4 is lossy, so no exact-value check; the original test asserted
        # nothing at all — at least pin shape/dtype and finiteness.
        self.assertEqual(fp16_a.shape, a.shape)
        self.assertEqual(fp16_a.dtype, a.dtype)
        self.assertTrue(bool(paddle.all(paddle.isfinite(fp16_a.cast("float32")))))
19+
class FP4(unittest.TestCase):
    """Round-trip smoke test for FP4 blockwise quantization."""

    def setUp(self):
        self.quant_type = "fp4"
        self.blocksize = 64

    def test_fp4_fp16(self):
        a = paddle.uniform([2, 64], dtype="float16")
        # renamed from `nf4_a` — this test exercises FP4, not NF4
        fp4_a, scale_a = quantize_fp4(a, self.blocksize)
        fp16_a = dequantize_fp4(fp4_a, scale_a, self.blocksize).cast("float16")
        # FP4 is lossy; the original test asserted nothing — pin shape/dtype
        # and finiteness as a minimal correctness signal.
        self.assertEqual(fp16_a.shape, a.shape)
        self.assertEqual(fp16_a.dtype, a.dtype)
        self.assertTrue(bool(paddle.all(paddle.isfinite(fp16_a.cast("float32")))))
29+
class BIT8(unittest.TestCase):
    """Round-trip smoke tests for 8-bit blockwise quantization (fp8 and
    dynamic codebooks)."""

    def setUp(self):
        self.quant_type = "fp8"
        self.blocksize = 64

    def _check_roundtrip(self, quant_type):
        # Shared driver: quantize, dequantize, and sanity-check the result.
        # Quantization is lossy, so only shape/dtype/finiteness are pinned
        # (the original tests asserted nothing at all).
        a = paddle.uniform([2, 64], dtype="float16")
        q_a, scale_a = quantize_8bit(a, None, self.blocksize, quant_type=quant_type)
        fp16_a = dequantize_8bit(q_a, None, scale_a, self.blocksize, quant_type=quant_type).cast("float16")
        self.assertEqual(fp16_a.shape, a.shape)
        self.assertEqual(fp16_a.dtype, a.dtype)
        self.assertTrue(bool(paddle.all(paddle.isfinite(fp16_a.cast("float32")))))

    def test_fp8_fp16(self):
        self._check_roundtrip("fp8")

    def test_dynamic_fp8_fp16(self):
        self._check_roundtrip("dynamic_fp8")
43+
44+
if __name__ == '__main__':
45+
unittest.main()
46+
47+

0 commit comments

Comments
 (0)