@@ -362,42 +362,37 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4)
 
 
 
-template <typename T, int DATA_TYPE> void quantize_blockwise(const float *code, const T *A, float *absmax, unsigned char *out, int blocksize, int n)
+template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n)
 {
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
+
   int num_blocks = n/blocksize;
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
 
+  const DataType_* A_data = reinterpret_cast<const DataType_*>(A.data<data_t>());
   if (blocksize == 4096)
-    kQuantizeBlockwise<T, 4096, 4, 0><<<num_blocks, 1024>>>(code, A, absmax, out, n);
+    kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax, out, n);
   else if (blocksize == 2048)
-    kQuantizeBlockwise<T, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, n);
+    kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax, out, n);
   else if (blocksize == 1024)
-    kQuantizeBlockwise<T, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, n);
+    kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
   else if (blocksize == 512)
-    kQuantizeBlockwise<T, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, n);
+    kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
   else if (blocksize == 256)
-    kQuantizeBlockwise<T, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, n);
+    kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax, out, n);
   else if (blocksize == 128)
-    kQuantizeBlockwise<T, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, n);
+    kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax, out, n);
   else if (blocksize == 64)
-    kQuantizeBlockwise<T, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, n);
+    kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax, out, n);
   else
     PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096].");
 
 
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
-template void quantize_blockwise<half, General8bit>(const float *code, const half *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<half, FP4>(const float *code, const half *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<half, NF4>(const float *code, const half *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<float, General8bit>(const float *code, const float *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<float, FP4>(const float *code, const float *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<float, NF4>(const float *code, const float *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<__nv_bfloat16, General8bit>(const float *code, const __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<__nv_bfloat16, FP4>(const float *code, const __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<__nv_bfloat16, NF4>(const float *code, const __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, int n);
-
 std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) {
     int n = input.numel();
     std::vector<int64_t> out_shape = input.shape();
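The converted function keys its behaviour off a PDTraits helper: traits_::DataType is the CUDA compute type handed to kQuantizeBlockwise, while traits_::data_t is the element type accepted by paddle::Tensor::data<T>(). PDTraits itself is defined outside the lines shown in this diff; the sketch below is only an assumed set of specializations, inferred from how the member typedefs are used above, and the actual definitions in the file may differ. Note also that because the wrapper below now instantiates quantize_blockwise with concrete paddle::DataType values in the same translation unit, the explicit template instantiations for half, float, and __nv_bfloat16 are no longer needed and are removed.

// Sketch only (assumption): PDTraits specializations consistent with the usage above.
template <paddle::DataType D> class PDTraits;

template <> class PDTraits<paddle::DataType::FLOAT32> {
 public:
  typedef float DataType;   // type passed to the CUDA kernel
  typedef float data_t;     // type used with paddle::Tensor::data<T>()
};

template <> class PDTraits<paddle::DataType::FLOAT16> {
 public:
  typedef half DataType;
  typedef paddle::float16 data_t;
};

template <> class PDTraits<paddle::DataType::BFLOAT16> {
 public:
  typedef __nv_bfloat16 DataType;
  typedef paddle::bfloat16 data_t;
};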
@@ -410,28 +405,28 @@ std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const
     switch (input.type()) {
         case paddle::DataType::FLOAT32:
             if (quant_type == "8bit")
-                quantize_blockwise<float, General8bit>(code.data<float>(), input.data<float>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
             else if (quant_type == "nf4") {
-                quantize_blockwise<float, NF4>(NULL, input.data<float>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
             }
             else if (quant_type == "fp4")
-                quantize_blockwise<float, FP4>(NULL, input.data<float>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
             return {out, absmax};
         case paddle::DataType::FLOAT16:
             if (quant_type == "8bit")
-                quantize_blockwise<half, General8bit>(code.data<float>(), input.data<half>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
             else if (quant_type == "nf4")
-                quantize_blockwise<half, NF4>(NULL, input.data<half>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
             else if (quant_type == "fp4")
-                quantize_blockwise<half, FP4>(NULL, input.data<half>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
             return {out, absmax};
         case paddle::DataType::BFLOAT16:
             if (quant_type == "8bit")
-                quantize_blockwise<__nv_bfloat16, General8bit>(code.data<float>(), input.data<__nv_bfloat16>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
             else if (quant_type == "nf4")
-                quantize_blockwise<__nv_bfloat16, NF4>(NULL, input.data<__nv_bfloat16>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
             else if (quant_type == "fp4")
-                quantize_blockwise<__nv_bfloat16, FP4>(NULL, input.data<__nv_bfloat16>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+                quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
             return {out, absmax};
 
         default:
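Both the old and the new function end with CUDA_CHECK_RETURN(cudaPeekAtLastError()), which presumably surfaces kernel-launch errors at this point in the op. The macro is defined elsewhere in the source and is not part of this diff; the following is only a minimal sketch of the kind of check such a macro typically performs, not the file's actual definition.

// Sketch only (assumption): an error-checking macro in the spirit of CUDA_CHECK_RETURN.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CUDA_CHECK_RETURN(value)                                  \
  do {                                                            \
    cudaError_t err_ = (value);                                   \
    if (err_ != cudaSuccess) {                                    \
      fprintf(stderr, "CUDA error %s at %s:%d\n",                 \
              cudaGetErrorString(err_), __FILE__, __LINE__);      \
      exit(1);                                                    \
    }                                                             \
  } while (0)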