Commit 603e0b8

support nf4 channel wise quant & fix bug when blocksize>512 (#1817)
1 parent 9c9b6a6 commit 603e0b8

File tree

2 files changed: +162 −37

2 files changed

+162
-37
lines changed

csrc/lc/dequantize_blockwise.cu

+74 −10
@@ -201,7 +201,6 @@ template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(const floa
 //template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(const float *code, const unsigned char * A, const float * absmax, __nv_bfloat16 *out, int blocksize, int n);


-
 template<typename T, int DATA_TYPE> void dequantize_blockwise(const float *code, const unsigned char *A, const float *absmax, T *out, int blocksize, int n)
 {
     int num_blocks = n/blocksize;
@@ -226,6 +225,50 @@ template void dequantize_blockwise<float, NF4>(const float *code, const unsigned
 //template void dequantize_blockwise<__nv_bfloat16, FP4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);
 //template void dequantize_blockwise<__nv_bfloat16, NF4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);

+template <typename T, int DATA_TYPE>
+__global__ void kDequantizeChannelwise(const unsigned char* A,
+                                       const float *absmax,
+                                       float *out,
+                                       int n,
+                                       int cout) {
+    int idx = blockDim.x * blockIdx.x + threadIdx.x;
+
+    int num = n / 2;
+    //int part_n = num / cout;
+    for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+        float local_absmax = absmax[i % cout];
+        int idx = 2 * (i / cout) * cout + i % cout;
+        switch(DATA_TYPE)
+        {
+            case FP4:
+                out[i*2 + i%cout] = dDequantizeFP4Tree(A[i] >> 4, local_absmax);
+                out[i*2 + cout + i%cout] = dDequantizeFP4Tree(A[i] & 0x0F, local_absmax);
+                break;
+            case NF4:
+                out[idx] = dDequantizeNF4(A[i] >> 4) * local_absmax;
+                out[idx + cout] = dDequantizeNF4(A[i] & 0x0F) * local_absmax;
+                break;
+        }
+        __syncthreads();
+    }
+}
+
+template<typename T, int DATA_TYPE> void dequantize_channelwise(const unsigned char *A, const float *absmax, T *out, int n, int cout)
+{
+    int max_threads = 1024;
+    int64_t block_size =
+        std::min(static_cast<int64_t>(n),
+                 static_cast<int64_t>(max_threads / 4));
+
+    const int64_t max_blocks =
+        std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
+    const int64_t grid_size =
+        std::min(max_blocks, (n + block_size - 1) / block_size);
+
+    kDequantizeChannelwise<T, DATA_TYPE><<<grid_size, block_size>>>(A, absmax, out, n, cout);
+    CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
 std::vector<paddle::Tensor> DequantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, const paddle::Tensor& absmax, int blocksize, std::string quant_type) {
     int64_t input_numel = input.numel();
     int n = input_numel;
@@ -234,23 +277,44 @@ std::vector<paddle::Tensor> DequantizeBlockwise(const paddle::Tensor& input, con
         out_shape = {input_numel * 2, 1};
         n = n * 2;
     }
+    if (blocksize == -1) {
+        out_shape = {input.shape()[0] * 2, input.shape()[1]};
+    }
     auto out = paddle::empty(out_shape, paddle::DataType::FLOAT32, input.place());

-    if (quant_type == "8bit")
-        dequantize_blockwise<float, General8bit>(code.data<float>(), input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
-    else if (quant_type == "nf4")
-        dequantize_blockwise<float, NF4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
-    else if (quant_type == "fp4")
-        dequantize_blockwise<float, FP4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
-    else
-        PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. ");
+    if (blocksize == -1) {
+        if (quant_type == "8bit")
+            PD_THROW("blocksize is -1 only support NF4 and FP4.");
+        else
+            blocksize = n / absmax.numel() * 2;
+
+        int cout = input.shape()[1];
+        if (quant_type == "nf4")
+            dequantize_channelwise<float, NF4>(input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), n, cout);
+        else if (quant_type == "fp4")
+            dequantize_channelwise<float, FP4>(input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), n, cout);
+        else
+            PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. ");
+    } else {
+        if (quant_type == "8bit")
+            dequantize_blockwise<float, General8bit>(code.data<float>(), input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
+        else if (quant_type == "nf4")
+            dequantize_blockwise<float, NF4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
+        else if (quant_type == "fp4")
+            dequantize_blockwise<float, FP4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
+        else
+            PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. ");
+    }
     return {out};
 };

 std::vector<std::vector<int64_t>> GetDequantizeBlockwiseInferShape(const std::vector<int64_t>& input_shape, const std::vector<int64_t>& code_shape, const std::vector<int64_t>& abs_max_shape, int blocksize, std::string quant_type){
     int64_t first_shape = input_shape[0] * input_shape[1] * 2;
     if (quant_type != "8bit")
-        return {{first_shape, 1}};
+        if (blocksize != -1)
+            return {{first_shape, 1}};
+        else
+            return {{input_shape[0] * 2, input_shape[1]}};
     else
         return {input_shape};
 }
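
Editor's aside (not part of the diff): in the NF4 branch above, byte i of the packed input carries two 4-bit codes for column i % cout, and they land one row apart in the row-major output, at idx = 2*(i/cout)*cout + i%cout and idx + cout. A minimal CPU sketch of that mapping follows; the nf4_codebook table is an assumption standing in for the device-side dDequantizeNF4 lookup (the values are the 16 NF4 levels from the bitsandbytes reference implementation).

// CPU reference for the NF4 channelwise layout of kDequantizeChannelwise.
// Assumptions: n counts output floats (two per input byte), cout is the number
// of columns, and nf4_codebook matches the device-side dDequantizeNF4 table.
const float nf4_codebook[16] = {
    -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
    0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f,
    0.33791524171829224f, 0.44070982933044434f, 0.5626170039176941f,
    0.7229568362236023f, 1.0f};

void dequantize_channelwise_cpu(const unsigned char* A, const float* absmax,
                                float* out, int n, int cout) {
    int num = n / 2;                                  // each byte yields two floats
    for (int i = 0; i < num; ++i) {
        float scale = absmax[i % cout];               // one absmax per column
        int idx = 2 * (i / cout) * cout + i % cout;   // base index of the row pair
        out[idx]        = nf4_codebook[A[i] >> 4]   * scale;  // high nibble
        out[idx + cout] = nf4_codebook[A[i] & 0x0F] * scale;  // low nibble
    }
}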

csrc/lc/quantize_blockwise.cu

+88 −27
@@ -279,6 +279,7 @@ __global__ void kQuantizeBlockwise(const float * code, const T * __restrict__ A,
     #pragma unroll NUM_PER_TH
     for(int j = 0; j < NUM_PER_TH/2; j++)
     {
+      packed_4bit = 0;
       packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
       packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
       qvals[j] = packed_4bit;
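
This one-line reset is the "fix bug when blocksize>512" half of the commit. packed_4bit is declared once per thread, and the dispatch table below uses NUM_PER_TH = 4 for blocksizes 1024/2048/4096, so the loop packs two bytes per thread; without re-zeroing, nibble bits from the previous byte survive the |= and corrupt the next one. A minimal host-side sketch of the failure mode (illustrative, not from the commit):

#include <cstdio>

// Demonstrates why packed_4bit must be reset per output byte when a thread
// packs more than one byte (NUM_PER_TH/2 > 1, i.e. blocksize > 512).
int main() {
    unsigned char nibbles[4] = {0xA, 0x1, 0x2, 0x3};  // hypothetical 4-bit codes
    unsigned char packed = 0;                         // declared once, as in the kernel
    for (int j = 0; j < 2; ++j) {                     // two bytes per thread
        // packed = 0;                                // the committed fix; with it: 0xA1, 0x23
        packed |= nibbles[2*j] << 4;
        packed |= nibbles[2*j + 1];
        printf("byte %d = 0x%02X\n", j, packed);      // without it: 0xA1, then 0xB3 (wrong)
    }
    return 0;
}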
@@ -360,9 +361,39 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, NF4)
 MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, NF4)
 MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4)

+template <typename T, int DATA_TYPE>
+__global__ void kQuantizeChannelwise(const float *code,
+                                     const T* A,
+                                     unsigned char* out,
+                                     float *absmax,
+                                     int n,
+                                     int cout) {
+    int idx = blockDim.x * blockIdx.x + threadIdx.x;
+
+    int num = n / 2;
+    for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+        int idx = 2 * (i / cout) * cout + i % cout;
+        float local_absmax = absmax[i % cout];
+        float inv_local_absmax = 1.0f / local_absmax;

+        unsigned char packed_4bit = 0;
+        switch(DATA_TYPE)
+        {
+            case FP4:
+                packed_4bit |= dQuantizeFP4(((float)A[idx])*inv_local_absmax) << 4;
+                packed_4bit |= dQuantizeFP4(((float)A[idx+cout])*inv_local_absmax);
+                out[i] = packed_4bit;
+                break;
+            case NF4:
+                packed_4bit |= dQuantizeNF4(((float)A[idx])*inv_local_absmax) << 4;
+                packed_4bit |= dQuantizeNF4(((float)A[idx+cout])*inv_local_absmax);
+                out[i] = packed_4bit;
+                break;
+        }
+    }
+}

-template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n)
+template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, paddle::Tensor& absmax, unsigned char *out, int blocksize, int n, int channelwise)
 {
     typedef PDTraits<D> traits_;
     typedef typename traits_::DataType DataType_;
@@ -372,61 +403,88 @@ template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float
     num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

     const DataType_* A_data = reinterpret_cast<const DataType_*>(A.data<data_t>());
-    if(blocksize == 4096)
-        kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax, out, n);
-    else if(blocksize == 2048)
-        kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax, out, n);
-    else if(blocksize == 1024)
-        kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
-    else if(blocksize == 512)
-        kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
-    else if(blocksize == 256)
-        kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax, out, n);
-    else if(blocksize == 128)
-        kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax, out, n);
-    else if(blocksize == 64)
-        kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax, out, n);
-    else
-        PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096].");
+    if (channelwise == 0) {
+        if(blocksize == 4096)
+            kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax.data<float>(), out, n);
+        else if(blocksize == 2048)
+            kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax.data<float>(), out, n);
+        else if(blocksize == 1024)
+            kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
+        else if(blocksize == 512)
+            kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
+        else if(blocksize == 256)
+            kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax.data<float>(), out, n);
+        else if(blocksize == 128)
+            kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax.data<float>(), out, n);
+        else if(blocksize == 64)
+            kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax.data<float>(), out, n);
+    }
+    else {
+        if (DATA_TYPE == General8bit)
+            PD_THROW("blocksize is -1 only support NF4 and FP4.");
+
+        int cout = A.shape()[1];
+        int max_threads = 1024;
+
+        absmax = A.abs().max({0});
+
+        int64_t block_size =
+            std::min(static_cast<int64_t>(n),
+                     static_cast<int64_t>(max_threads / 4));
+
+        const int64_t max_blocks =
+            std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
+        const int64_t grid_size =
+            std::min(max_blocks, (n + block_size - 1) / block_size);
+
+        kQuantizeChannelwise<DataType_, DATA_TYPE><<<grid_size, block_size, 0>>>(
+            code, A_data, out, absmax.data<float>(), n, cout);
+    }


     CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

 std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) {
     int n = input.numel();
+    int channelwise = 0;
     std::vector<int64_t> out_shape = input.shape();
     if (quant_type != "8bit") { // 4bit
         out_shape = {(n + 1) / 2, 1};
     }
+    if (blocksize == -1){
+        blocksize = input.shape()[0];
+        out_shape = {input.shape()[0]/2, input.shape()[1]};
+        channelwise = 1;
+    }
     auto out = paddle::empty(out_shape, paddle::DataType::UINT8, input.place());
     int64_t absmax_shape = n / blocksize;
     auto absmax = paddle::empty({absmax_shape}, paddle::DataType::FLOAT32, input.place());
     switch(input.type()) {
         case paddle::DataType::FLOAT32:
             if (quant_type == "8bit")
-                quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "nf4") {
-                quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             }
             else if (quant_type == "fp4")
-                quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             return {out, absmax};
         case paddle::DataType::FLOAT16:
             if (quant_type == "8bit")
-                quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "nf4")
-                quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "fp4")
-                quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             return {out, absmax};
         case paddle::DataType::BFLOAT16:
             if (quant_type == "8bit")
-                quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "nf4")
-                quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "fp4")
-                quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             return {out, absmax};

         default:
@@ -440,7 +498,10 @@ std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const
 std::vector<std::vector<int64_t>> GetQuantizeBlockwiseInferShape(const std::vector<int64_t>& input_shape, const std::vector<int64_t>& code_shape, int blocksize, std::string quant_type){
     int64_t first_shape = (input_shape[0] * input_shape[1] + 1) / 2;
     if (quant_type != "8bit")
-        return {{first_shape, 1}};
+        if (blocksize != -1)
+            return {{first_shape, 1}};
+        else
+            return {{input_shape[0]/2, input_shape[1]}};
     else
         return {input_shape};
 }
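
Closing aside (not part of the diff): the quantize path mirrors the dequantize sketch after the first file, reading the vertical pair A[idx], A[idx + cout] and packing it into byte i. With blocksize == -1 the host code sets blocksize = input.shape()[0], so absmax is allocated with one entry per column and then recomputed as A.abs().max({0}). A CPU reference under the same assumptions follows; quantize_nf4_cpu is a plain nearest-level scan standing in for the device-side dQuantizeNF4 tree (assumed to round to the nearest level), and nf4_codebook is the table from the earlier sketch, assumed to be in scope.

#include <cmath>

// Nearest-level stand-in for the device-side dQuantizeNF4 tree search.
unsigned char quantize_nf4_cpu(float x) {
    int best = 0;
    float best_dist = std::fabs(x - nf4_codebook[0]);
    for (int k = 1; k < 16; ++k) {
        float d = std::fabs(x - nf4_codebook[k]);
        if (d < best_dist) { best_dist = d; best = k; }
    }
    return (unsigned char)best;
}

// CPU reference for kQuantizeChannelwise (NF4 case): A is row-major with cout
// columns and n elements; out receives n/2 packed bytes.
void quantize_channelwise_cpu(const float* A, float* absmax,
                              unsigned char* out, int n, int cout) {
    int rows = n / cout;
    for (int c = 0; c < cout; ++c) {                  // per-column scale, mirroring
        float m = 0.0f;                               // absmax = A.abs().max({0})
        for (int r = 0; r < rows; ++r)
            m = std::fmax(m, std::fabs(A[r * cout + c]));
        absmax[c] = m;
    }
    int num = n / 2;                                  // two rows pack into one byte
    for (int i = 0; i < num; ++i) {
        int idx = 2 * (i / cout) * cout + i % cout;   // element paired with idx + cout
        float inv = 1.0f / absmax[i % cout];
        unsigned char packed = 0;
        packed |= quantize_nf4_cpu(A[idx] * inv) << 4;    // even row -> high nibble
        packed |= quantize_nf4_cpu(A[idx + cout] * inv);  // odd row -> low nibble
        out[i] = packed;
    }
}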
