@@ -279,6 +279,7 @@ __global__ void kQuantizeBlockwise(const float * code, const T * __restrict__ A,
     #pragma unroll NUM_PER_TH
     for(int j = 0; j < NUM_PER_TH/2; j++)
     {
+      packed_4bit = 0;
       packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
       packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
       qvals[j] = packed_4bit;
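The added reset is the whole bug fix: `packed_4bit` is declared once outside the loop and each byte is assembled by OR-ing two 4-bit codes, so without clearing it, bits from the previous iteration bleed into every subsequent byte. A minimal host-side sketch (with a hypothetical `quantize4` standing in for dQuantizeNF4) shows the failure mode:

    #include <cstdio>

    // Hypothetical stand-in for dQuantizeNF4: returns a 4-bit code (0..15).
    static unsigned char quantize4(float x) { return x > 0.0f ? 0x0F : 0x01; }

    int main() {
        float vals[4] = {1.0f, -1.0f, -1.0f, -1.0f};
        unsigned char packed_4bit = 0;
        for (int j = 0; j < 2; j++) {
            packed_4bit = 0;                            // the fix: clear stale bits
            packed_4bit |= quantize4(vals[2*j]) << 4;   // high nibble
            packed_4bit |= quantize4(vals[2*j + 1]);    // low nibble
            printf("byte %d = 0x%02X\n", j, (unsigned)packed_4bit);
        }
        // Prints 0xF1 then 0x11. Drop the reset and byte 1 comes out 0xF1 too:
        // byte 0's high nibble is OR-ed into it.
        return 0;
    }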
@@ -360,9 +361,39 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, NF4)
 MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, NF4)
 MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4)

+template <typename T, int DATA_TYPE>
+__global__ void kQuantizeChannelwise(const float *code,
+                                     const T* A,
+                                     unsigned char* out,
+                                     float *absmax,
+                                     int n,
+                                     int cout) {
+    int idx = blockDim.x * blockIdx.x + threadIdx.x;
+
+    int num = n / 2;
+    for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+        int idx = 2*(i/cout)*cout + i%cout;
+        float local_absmax = absmax[i%cout];
+        float inv_local_absmax = 1.0f/local_absmax;

+        unsigned char packed_4bit = 0;
+        switch (DATA_TYPE)
+        {
+            case FP4:
+                packed_4bit |= dQuantizeFP4(((float)A[idx])*inv_local_absmax) << 4;
+                packed_4bit |= dQuantizeFP4(((float)A[idx+cout])*inv_local_absmax);
+                out[i] = packed_4bit;
+                break;
+            case NF4:
+                packed_4bit |= dQuantizeNF4(((float)A[idx])*inv_local_absmax) << 4;
+                packed_4bit |= dQuantizeNF4(((float)A[idx+cout])*inv_local_absmax);
+                out[i] = packed_4bit;
+                break;
+        }
+    }
+}
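The addressing in kQuantizeChannelwise is the subtle part: A is row-major with `cout` columns, and output byte `i` belongs to column `i % cout` and the row pair starting at `2*(i/cout)`, so `A[idx]` and `A[idx + cout]` are two vertically adjacent elements of one column, both scaled by that column's absmax. Threads grid-stride over the `n/2` output bytes. A host-side reference of the same mapping (a sketch, not the kernel; `quantize4` is a hypothetical stand-in for dQuantizeFP4/dQuantizeNF4):

    // Assumed: maps a scaled value to a 4-bit code, like dQuantizeFP4/NF4.
    unsigned char quantize4(float x);

    // Host-side reference for the channelwise packing.
    void quantize_channelwise_ref(const float* A, unsigned char* out,
                                  const float* absmax, int n, int cout) {
        int num = n / 2;                                 // one output byte per element pair
        for (int i = 0; i < num; i++) {
            int idx = 2 * (i / cout) * cout + i % cout;  // row 2*(i/cout), column i % cout
            float inv = 1.0f / absmax[i % cout];         // per-column scale
            unsigned char b = 0;
            b |= quantize4(A[idx] * inv) << 4;           // element at row r   -> high nibble
            b |= quantize4(A[idx + cout] * inv);         // element at row r+1 -> low nibble
            out[i] = b;
        }
    }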

-template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n)
+template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, paddle::Tensor& absmax, unsigned char *out, int blocksize, int n, int channelwise)
 {
     typedef PDTraits<D> traits_;
     typedef typename traits_::DataType DataType_;
@@ -372,61 +403,88 @@ template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float
     num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

     const DataType_* A_data = reinterpret_cast<const DataType_*>(A.data<data_t>());
-    if (blocksize == 4096)
-        kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax, out, n);
-    else if (blocksize == 2048)
-        kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax, out, n);
-    else if (blocksize == 1024)
-        kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
-    else if (blocksize == 512)
-        kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
-    else if (blocksize == 256)
-        kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax, out, n);
-    else if (blocksize == 128)
-        kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax, out, n);
-    else if (blocksize == 64)
-        kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax, out, n);
-    else
-        PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096].");
+    if (channelwise == 0) {
+        if (blocksize == 4096)
+            kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax.data<float>(), out, n);
+        else if (blocksize == 2048)
+            kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax.data<float>(), out, n);
+        else if (blocksize == 1024)
+            kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
+        else if (blocksize == 512)
+            kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax.data<float>(), out, n);
+        else if (blocksize == 256)
+            kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax.data<float>(), out, n);
+        else if (blocksize == 128)
+            kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax.data<float>(), out, n);
+        else if (blocksize == 64)
+            kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax.data<float>(), out, n);
+    }
+    else {
+        if (DATA_TYPE == General8bit)
+            PD_THROW("blocksize is -1 only support NF4 and FP4.");
+
+        int cout = A.shape()[1];
+        int max_threads = 1024;
+
+        absmax = A.abs().max({0});
+
+        int64_t block_size =
+            std::min(static_cast<int64_t>(n),
+                     static_cast<int64_t>(max_threads / 4));
+
+        const int64_t max_blocks =
+            std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
+        const int64_t grid_size =
+            std::min(max_blocks, (n + block_size - 1) / block_size);
+
+        kQuantizeChannelwise<DataType_, DATA_TYPE><<<grid_size, block_size, 0>>>(
+            code, A_data, out, absmax.data<float>(), n, cout);
+    }


     CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
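Two things happen in the channelwise branch above. First, the preallocated absmax buffer is replaced wholesale by a per-column reduction, `absmax = A.abs().max({0})`, which is why the signature now takes `paddle::Tensor& absmax` by reference instead of a raw `float*`. Second, the launch size is derived from `max_threads`, and `max_blocks` effectively caps the grid at 4 blocks; the kernel's grid-stride loop picks up the slack. A quick host-only sketch of the same arithmetic, using an assumed [4096, 1024] input:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Sketch of the launch-size arithmetic above, host-only,
    // with an assumed [4096, 1024] input (n = 4194304 elements).
    int main() {
        const int max_threads = 1024;
        int64_t n = 4096LL * 1024;

        int64_t block_size = std::min<int64_t>(n, max_threads / 4);                     // 256
        int64_t max_blocks = std::max<int64_t>((max_threads - 1) / block_size + 1, 1);  // 4
        int64_t grid_size  = std::min(max_blocks, (n + block_size - 1) / block_size);   // 4

        printf("<<<%lld, %lld>>>\n", (long long)grid_size, (long long)block_size);
        // Only 4 blocks ever launch; the kernel's grid-stride loop covers the rest.
        return 0;
    }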

 std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) {
     int n = input.numel();
+    int channelwise = 0;
     std::vector<int64_t> out_shape = input.shape();
     if (quant_type != "8bit") { // 4bit
         out_shape = {(n + 1) / 2, 1};
     }
+    if (blocksize == -1){
+        blocksize = input.shape()[0];
+        out_shape = {input.shape()[0]/2, input.shape()[1]};
+        channelwise = 1;
+    }
     auto out = paddle::empty(out_shape, paddle::DataType::UINT8, input.place());
     int64_t absmax_shape = n / blocksize;
     auto absmax = paddle::empty({absmax_shape}, paddle::DataType::FLOAT32, input.place());
     switch (input.type()) {
         case paddle::DataType::FLOAT32:
             if (quant_type == "8bit")
-                quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "nf4") {
-                quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             }
             else if (quant_type == "fp4")
-                quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             return {out, absmax};
         case paddle::DataType::FLOAT16:
             if (quant_type == "8bit")
-                quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "nf4")
-                quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "fp4")
-                quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             return {out, absmax};
         case paddle::DataType::BFLOAT16:
             if (quant_type == "8bit")
-                quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "nf4")
-                quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             else if (quant_type == "fp4")
-                quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax, out.data<unsigned char>(), blocksize, n, channelwise);
             return {out, absmax};

         default:
@@ -440,7 +498,10 @@ std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const
 std::vector<std::vector<int64_t>> GetQuantizeBlockwiseInferShape(const std::vector<int64_t>& input_shape, const std::vector<int64_t>& code_shape, int blocksize, std::string quant_type){
     int64_t first_shape = (input_shape[0] * input_shape[1] + 1) / 2;
     if (quant_type != "8bit")
-        return {{first_shape, 1}};
+        if (blocksize != -1)
+            return {{first_shape, 1}};
+        else
+            return {{input_shape[0]/2, input_shape[1]}};
     else
         return {input_shape};
 }
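The infer-shape rule now mirrors the runtime logic in QuantizeBlockwise: blockwise 4-bit output stays a flat {(rows*cols + 1)/2, 1} column, while channelwise (blocksize == -1) keeps the column count and halves the rows. A small sketch checking both branches with an assumed [4096, 1024] input:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Worked check of the two 4-bit infer-shape branches above.
    std::vector<int64_t> infer_4bit_shape(const std::vector<int64_t>& in, int blocksize) {
        int64_t first_shape = (in[0] * in[1] + 1) / 2;
        if (blocksize != -1)
            return {first_shape, 1};     // blockwise: flat packed column
        return {in[0] / 2, in[1]};       // channelwise: rows halved, columns kept
    }

    int main() {
        auto blockwise   = infer_4bit_shape({4096, 1024}, 64);   // {2097152, 1}
        auto channelwise = infer_4bit_shape({4096, 1024}, -1);   // {2048, 1024}
        printf("{%lld, %lld} {%lld, %lld}\n",
               (long long)blockwise[0], (long long)blockwise[1],
               (long long)channelwise[0], (long long)channelwise[1]);
        return 0;
    }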