diff --git a/csrc/lc/helper.h b/csrc/lc/helper.h
new file mode 100644
index 000000000..98cd551c3
--- /dev/null
+++ b/csrc/lc/helper.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/extension.h"
+
+template <paddle::DataType D>
+class PDTraits;
+
+template <>
+class PDTraits<paddle::DataType::FLOAT32> {
+public:
+  typedef float DataType;
+  typedef float data_t;
+  static const paddle::DataType dtype_s = paddle::DataType::FLOAT32;
+};
+
+template <>
+class PDTraits<paddle::DataType::FLOAT16> {
+public:
+  typedef half DataType;
+  typedef paddle::float16 data_t;
+  static const paddle::DataType dtype_s = paddle::DataType::FLOAT16;
+};
+
+template <>
+class PDTraits<paddle::DataType::BFLOAT16> {
+public:
+  typedef __nv_bfloat16 DataType;
+  typedef paddle::bfloat16 data_t;
+  static const paddle::DataType dtype_s = paddle::DataType::BFLOAT16;
+};
\ No newline at end of file
diff --git a/csrc/lc/nf4.cu b/csrc/lc/nf4.cu
new file mode 100644
index 000000000..f5688f07d
--- /dev/null
+++ b/csrc/lc/nf4.cu
@@ -0,0 +1,394 @@
+#include <cstdio>
+#include <cstdlib>
+#include <cfloat>
+#include <cmath>
+#include <vector>
+#include <iostream>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_store.cuh>
+#include <cub/block/block_reduce.cuh>
+#include "helper.h"
+
+using namespace std;
+
+#define HLF_MAX 65504
+#define TH 1024
+#define NUM 4
+#define NUM_BLOCK 4096
+
+__device__ unsigned char dQuantizeNF4(float x)
+{
+  // the values for this tree were generated by test_normal_map_tree
+  // in the file tests/test_functional.py
+  if(x > 0.03979014977812767f)
+    if(x > 0.3893125355243683f) // 1
+      if(x > 0.6427869200706482f) // 11
+        if(x > 0.8614784181118011f) // 111
+          return 0b1111;
+        else
+          return 0b1110;
+      else
+        if(x > 0.5016634166240692f) // 110
+          return 0b1101;
+        else
+          return 0b1100;
+    else
+      if(x > 0.2035212516784668f) // 10
+        if(x > 0.2920137718319893f) // 101
+          return 0b1011;
+        else
+          return 0b1010;
+      else
+        if(x > 0.1202552504837513f) // 100
+          return 0b1001;
+        else
+          return 0b1000;
+  else
+    if(x > -0.33967943489551544f) // 0
+      if(x > -0.13791173323988914f) // 01
+        if(x > -0.045525018125772476f) // 011
+          return 0b0111;
+        else
+          return 0b0110;
+      else
+        if(x > -0.23460740596055984f) // 010
+          return 0b0101;
+        else
+          return 0b0100;
+    else
+      if(x > -0.6106329262256622f) // 00
+        if(x > -0.4599952697753906f) // 001
+          return 0b0011;
+        else
+          return 0b0010;
+      else
+        if(x > -0.8480964004993439f) // 000
+          return 0b0001;
+        else
+          return 0b0000;
+}
+
+template <typename T, int BLOCK_SIZE, int NUM_PER_TH>
+//__launch_bounds__(TH, 4)
+__global__ void kQuantizeBlockwiseNF4(const T* A, float *absmax, unsigned char *out, const int n)
+{
+  // total number of elements covered by all CUDA blocks in one grid-stride pass
+  const int n_full = gridDim.x * BLOCK_SIZE;
+  int valid_items = 0;
+  // starting index of the elements handled by the current CUDA block
+  const int base_idx = (blockIdx.x * BLOCK_SIZE);
+  // input elements handled by the current CUDA thread
+  T vals[NUM_PER_TH];
+  // number of output elements produced by the current CUDA thread
+  const int output_num_per_thread = NUM_PER_TH/2;
+  // output elements produced by the current CUDA thread
+  unsigned char qvals[output_num_per_thread];
+  //float local_abs_max = -FLT_MAX;
+  float local_abs_max = 0.0f;
+
+  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH/2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
+  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+
+  __shared__ typename LoadT::TempStorage loadt;
+  __shared__ typename LoadFloat::TempStorage loadf;
+  __shared__ typename StoreChar::TempStorage storec;
+  __shared__ typename BlockReduce::TempStorage reduce;
+  // absmax of each CUDA block (which is also one quantization block)
+  __shared__ float smem_absmax_value[1];
+
+  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
+  {
+    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
+    local_abs_max = -FLT_MAX;
+
+    __syncthreads();
+    LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);
+
+    // 1. compute local max
+    // 2. broadcast local max
+    // 3. normalize inputs and quantize
+
+    #pragma unroll NUM_PER_TH
+    for(int j = 0; j < NUM_PER_TH; j++)
+      local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
+
+    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
+
+    if(threadIdx.x == 0)
+      smem_absmax_value[0] = local_abs_max;
+
+    __syncthreads();
+
+    if(threadIdx.x == 0)
+      absmax[i/BLOCK_SIZE] = local_abs_max;
+    else
+      local_abs_max = smem_absmax_value[0];
+
+    __syncwarp();
+
+    local_abs_max = 1.0f/local_abs_max;
+
+    #pragma unroll NUM_PER_TH
+    for(int j = 0; j < NUM_PER_TH/2; j++)
+    {
+      // pack two NF4 codes into one byte; reset the byte for every pair
+      unsigned char packed_4bit = 0;
+      packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
+      packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
+      qvals[j] = packed_4bit;
+    }
+
+    __syncthreads();
+    StoreChar(storec).Store(&(out[i/2]), qvals, (valid_items+1)/2);
+  }
+}
+
+__device__ float dDequantizeNF4(unsigned char val)
+{
+  // the values for this tree were generated by test_normal_map_tree
+  // in the file tests/test_functional.py
+  if((val & 0b1000) == 8)
+    if((val & 0b0100) == 4) // 1
+      if((val & 0b0010) == 2) // 11
+        if((val & 0b0001) == 1) // 111
+          return 1.0f;
+        else
+          return 0.7229568362236023f;
+      else
+        if((val & 0b0001) == 1) // 110
+          return 0.5626170039176941f;
+        else
+          return 0.44070982933044434f;
+    else
+      if((val & 0b0010) == 2) // 10
+        if((val & 0b0001) == 1) // 101
+          return 0.33791524171829224f;
+        else
+          return 0.24611230194568634f;
+      else
+        if((val & 0b0001) == 1) // 100
+          return 0.16093020141124725f;
+        else
+          return 0.07958029955625534f;
+  else
+    if((val & 0b0100) == 4) // 0
+      if((val & 0b0010) == 2) // 01
+        if((val & 0b0001) == 1) // 011
+          return 0.0f;
+        else
+          return -0.09105003625154495f;
+      else
+        if((val & 0b0001) == 1) // 010
+          return -0.18477343022823334f;
+        else
+          return -0.28444138169288635f;
+    else
+      if((val & 0b0010) == 2) // 00
+        if((val & 0b0001) == 1) // 001
+          return -0.39491748809814453f;
+        else
+          return -0.5250730514526367f;
+      else
+        if((val & 0b0001) == 1) // 000
+          return -0.6961928009986877f;
+        else
+          return -1.0f;
+}
+
+template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
+__global__ void kDequantizeBlockwiseNF4(const unsigned char * A, const float * absmax, T *out, int blocksize, const int n)
+{
+  const int n_load = (gridDim.x * TILE_SIZE);
+  int valid_items_load = 0;
+  int valid_items_store = 0;
+  const int base_idx = (blockIdx.x * TILE_SIZE);
+
+  T vals[NUM_PER_TH*2];
+  unsigned char qvals[NUM_PER_TH];
+  float local_abs_max = -FLT_MAX;
+
+  typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
+  typedef cub::BlockStore<T, THREADS, NUM_PER_TH*2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
+
+  __shared__ typename LoadChar::TempStorage loadchar;
+  __shared__ typename StoreT::TempStorage storet;
+
+  for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
+  {
+    valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i;
+    valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2;
+    local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]);
+
+    __syncthreads();
+    LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);
+
+    #pragma unroll NUM_PER_TH
+    for(int j = 0; j < NUM_PER_TH; j++)
+    {
+      vals[j*2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max;
+      vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max;
+    }
+
+    __syncthreads();
+    StoreT(storet).Store(&(out[i*2]), vals, valid_items_store);
+  }
+}
+
+#define MAKE_kDequantizeBlockwiseNF4(dtype, tile_size, threads, num_per_thread) \
+template __global__ void kDequantizeBlockwiseNF4<dtype, tile_size, threads, num_per_thread>(const unsigned char * A, const float * absmax, dtype *out, int blocksize, const int n); \
+
+MAKE_kDequantizeBlockwiseNF4(half, 512, 64, 8)
+MAKE_kDequantizeBlockwiseNF4(__nv_bfloat16, 512, 64, 8)
+MAKE_kDequantizeBlockwiseNF4(float, 512, 64, 8)
+
+template <paddle::DataType D>
+std::vector<paddle::Tensor> LaunchDeQuantizeNF4(const paddle::Tensor& input,
+                                                const paddle::Tensor& absmax,
+                                                int block_size) {
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
+
+  auto output = paddle::full(input.shape(), 1, traits_::dtype_s, input.place());
+
+  const unsigned char *in_ptr = input.data<unsigned char>();
+  const float *abs_max_ptr = absmax.data<float>();
+  DataType_ *out_ptr = reinterpret_cast<DataType_*>(output.mutable_data<data_t>());
+  // TODO: revisit the line above
+
+  const int n = input.numel();
+  int num_blocks = n/block_size;
+  num_blocks = n % block_size == 0 ? num_blocks : num_blocks + 1;
+  int tile_size = 1024;
+
+  // each CUDA block dequantizes tile_size output elements, i.e. tile_size/2 packed bytes
+  kDequantizeBlockwiseNF4<DataType_, 512, 64, 8><<<(n+tile_size-1)/tile_size, 64>>>(in_ptr, abs_max_ptr, out_ptr, block_size/2, n);
+  return {output};
+}
+
+#define MAKE_kQuantizeBlockwiseNF4(dtype, blocksize, num_per_thread) \
+template __global__ void kQuantizeBlockwiseNF4<dtype, blocksize, num_per_thread>(const dtype * A, float *absmax, unsigned char *out, const int n); \
+
+MAKE_kQuantizeBlockwiseNF4(half, 4096, 4)
+MAKE_kQuantizeBlockwiseNF4(half, 1024, 4)
+MAKE_kQuantizeBlockwiseNF4(half, 512, 2)
+MAKE_kQuantizeBlockwiseNF4(half, 256, 2)
+MAKE_kQuantizeBlockwiseNF4(half, 128, 2)
+MAKE_kQuantizeBlockwiseNF4(half, 64, 2)
+
+MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 4096, 4)
+MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 1024, 4)
+MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 512, 2)
+MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 256, 2)
+MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 128, 2)
+MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 64, 2)
+
+MAKE_kQuantizeBlockwiseNF4(float, 4096, 4)
+MAKE_kQuantizeBlockwiseNF4(float, 1024, 4)
+MAKE_kQuantizeBlockwiseNF4(float, 512, 2)
+MAKE_kQuantizeBlockwiseNF4(float, 256, 2)
+MAKE_kQuantizeBlockwiseNF4(float, 128, 2)
+MAKE_kQuantizeBlockwiseNF4(float, 64, 2)
+
+template <paddle::DataType D>
+std::vector<paddle::Tensor> LaunchQuantizeNF4(const paddle::Tensor& input, int block_size) {
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
+
+  auto input_shape = input.shape();
+  auto output = paddle::full(input_shape, 1, paddle::DataType::UINT8, input.place());
+  const int n = input.numel();
+  int num_blocks = n/block_size;
+  num_blocks = n % block_size == 0 ? num_blocks : num_blocks + 1;
+
+  auto abs_max = paddle::full({num_blocks}, 1, paddle::DataType::FLOAT32, input.place());
+
+  const DataType_ *in_ptr = reinterpret_cast<const DataType_*>(input.data<data_t>());
+  unsigned char *out_ptr = output.mutable_data<unsigned char>();
+  float *abs_max_ptr = abs_max.mutable_data<float>();
+
+  // one CUDA block per quantization block; threads per block = block_size / NUM_PER_TH
+  if(block_size == 2048) {
+    kQuantizeBlockwiseNF4<DataType_, 2048, 4><<<num_blocks, 512>>>(in_ptr, abs_max_ptr, out_ptr, n);
+  } else if(block_size == 1024) {
+    kQuantizeBlockwiseNF4<DataType_, 1024, 4><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
+  } else if(block_size == 512) {
+    kQuantizeBlockwiseNF4<DataType_, 512, 2><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
+  } else if(block_size == 256) {
+    kQuantizeBlockwiseNF4<DataType_, 256, 2><<<num_blocks, 128>>>(in_ptr, abs_max_ptr, out_ptr, n);
+  } else if(block_size == 128) {
+    kQuantizeBlockwiseNF4<DataType_, 128, 2><<<num_blocks, 64>>>(in_ptr, abs_max_ptr, out_ptr, n);
+  } else if(block_size == 64) {
+    kQuantizeBlockwiseNF4<DataType_, 64, 2><<<num_blocks, 32>>>(in_ptr, abs_max_ptr, out_ptr, n);
+  }
+  return {output, abs_max};
+}
+
+std::vector<paddle::Tensor> QuantizeNF4(const paddle::Tensor& input, int block_size) {
+  switch (input.type()) {
+    case paddle::DataType::BFLOAT16: {
+      return LaunchQuantizeNF4<paddle::DataType::BFLOAT16>(input, block_size);
+    }
+    case paddle::DataType::FLOAT16: {
+      return LaunchQuantizeNF4<paddle::DataType::FLOAT16>(input, block_size);
+    }
+    case paddle::DataType::FLOAT32: {
+      return LaunchQuantizeNF4<paddle::DataType::FLOAT32>(input, block_size);
+    }
+    default: {
+      PD_THROW(
+          "NOT supported data type. "
+          "Only bfloat16, float16 and float32 are supported. ");
+      break;
+    }
+  }
+}
+
+std::vector<paddle::Tensor> DeQuantizeNF4(const paddle::Tensor& input,
+                                          const paddle::Tensor& absmax,
+                                          int block_size,
+                                          int dtype) {
+  switch (dtype) {
+    case 0: {
+      return LaunchDeQuantizeNF4<paddle::DataType::FLOAT32>(input, absmax, block_size);
+    }
+    case 1: {
+      return LaunchDeQuantizeNF4<paddle::DataType::FLOAT16>(input, absmax, block_size);
+    }
+    case 2: {
+      return LaunchDeQuantizeNF4<paddle::DataType::BFLOAT16>(input, absmax, block_size);
+    }
+    default: {
+      PD_THROW(
+          "NOT supported data type. "
+          "Only bfloat16, float16 and float32 are supported. ");
+      break;
+    }
+  }
+}
+
+PD_BUILD_OP(quantize_nf4)
+    .Inputs({"input"})
+    .Outputs({"out", "abs_max"})
+    .Attrs({"block_size: int"})
+    .SetKernelFn(PD_KERNEL(QuantizeNF4));
+
+PD_BUILD_OP(dequantize_nf4)
+    .Inputs({"input", "absmax"})
+    .Outputs({"out"})
+    .Attrs({"block_size: int", "dtype: int"})
+    .SetKernelFn(PD_KERNEL(DeQuantizeNF4));
\ No newline at end of file
diff --git a/csrc/requirements.txt b/csrc/requirements.txt
new file mode 100644
index 000000000..0bf062538
--- /dev/null
+++ b/csrc/requirements.txt
@@ -0,0 +1,2 @@
+cupy-cuda116
+pybind11
\ No newline at end of file
diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py
new file mode 100644
index 000000000..5facfc062
--- /dev/null
+++ b/csrc/setup_cuda.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.utils.cpp_extension import CUDAExtension, setup
+
+setup(
+    name="paddleslim_ops",
+    ext_modules=CUDAExtension(sources=[
+        "./lc/nf4.cu",
+    ]), )
diff --git a/paddleslim/lc/__init__.py b/paddleslim/lc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/paddleslim/lc/layers/__init__.py b/paddleslim/lc/layers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/paddleslim/lc/layers/linear.py b/paddleslim/lc/layers/linear.py
new file mode 100644
index 000000000..0b99c4d28
--- /dev/null
+++ b/paddleslim/lc/layers/linear.py
@@ -0,0 +1,20 @@
+import paddle
+import paddle.nn as nn
+
+
+class WeightQuantizationLinear(nn.Layer):
+    def __init__(
+            self,
+            linear: paddle.nn.Linear, ):
+        super().__init__()
+        self.in_features = linear.weight.shape[0]
+        self.out_features = linear.weight.shape[1]
+        self.dtype = linear.dtype
+        self.weight_name = linear.weight.name
+        self.quant_weight_name = ".".join([self.weight_name, "quant_weight"])
+
+    def forward(self, x):
+        raise NotImplementedError()
+
+    def quantize(self, weight) -> paddle.Tensor:
+        raise NotImplementedError()
diff --git a/paddleslim/lc/layers/nf4_linear.py b/paddleslim/lc/layers/nf4_linear.py
new file mode 100644
index 000000000..0b5500dd2
--- /dev/null
+++ b/paddleslim/lc/layers/nf4_linear.py
@@ -0,0 +1,57 @@
+import paddle
+import paddle.nn as nn
+from paddleslim.lc.quantizers import NF4Quantizer
+from .linear import WeightQuantizationLinear
+
+
+class NF4Linear(WeightQuantizationLinear):
+    quant_dtype = "int4"
+    weight_dtype = "int8"
+    quant_scale_suffix = "quant_scale"
+    double_quant_scale_suffix = "double_quant_scale"
+
+    def __init__(
+            self,
+            linear: nn.Linear,
+            block_size=64,
+            use_double_quant=False, ):
+        super(NF4Linear, self).__init__(linear)
+        self.block_size = block_size
+        self.double_quant = use_double_quant
+        self.quantizer = NF4Quantizer(block_size, use_double_quant)
+        # PaddlePaddle doesn't support an Int4 data type, so one Int8 value stores two Int4 values.
+
+        self.quant_weight = self.create_parameter(
+            shape=[self.out_features // 2, self.in_features],
+            attr=paddle.ParamAttr(self.quant_weight_name),
+            dtype=NF4Linear.weight_dtype,
+            is_bias=False, )
+
+        self.quant_scale_name = ".".join(
+            [self.weight_name, NF4Linear.quant_scale_suffix])
+        self.quant_scale = self.create_parameter(
+            shape=[self.out_features],
+            attr=paddle.ParamAttr(self.quant_scale_name),
+            dtype="float32",  # to be fixed
+            is_bias=False, )
+        if self.double_quant:
+            self.double_quant_scale_name = ".".join(
+                [self.weight_name, NF4Linear.double_quant_scale_suffix])
+            self.double_quant_scale = self.create_parameter(
+                shape=[self.out_features],
+                attr=paddle.ParamAttr(self.double_quant_scale_name),
+                dtype="float32",
+                is_bias=False, )
+
+    def quantize(self, weight):
+        quantized_weight = self.quantizer.quantize(weight)
+        return {
+            self.quant_weight_name: quantized_weight,
+            self.quant_scale_name: self.quantizer.quant_scale,
+            self.double_quant_scale_name: self.quantizer.double_quant_scale
+        }
+
+    def forward(self, x):
+        self.quantizer.quant_scale = self.state_dict()[self.quant_scale_name]
+        self.quantizer.double_quant_scale = self.state_dict()[
+            self.double_quant_scale_name]
+        return self.quantizer.matmul(x, self.quant_weight)
diff --git a/paddleslim/lc/quantizers/__init__.py b/paddleslim/lc/quantizers/__init__.py
new file mode 100644
index 000000000..f71391044
--- /dev/null
+++ b/paddleslim/lc/quantizers/__init__.py
@@ -0,0 +1 @@
+from .nf4 import NF4Quantizer
diff --git a/paddleslim/lc/quantizers/base_quantizer.py b/paddleslim/lc/quantizers/base_quantizer.py
new file mode 100644
index 000000000..5117bdaec
--- /dev/null
+++ b/paddleslim/lc/quantizers/base_quantizer.py
@@ -0,0 +1,12 @@
+import paddle
+
+
+class BaseQuantizer():
+    def quantize(self, x: paddle.Tensor):
+        raise NotImplementedError()
+
+    def dequantize(self, x: paddle.Tensor):
+        raise NotImplementedError()
+
+    def matmul(self, x: paddle.Tensor, y: paddle.Tensor, bias: paddle.Tensor):
+        raise NotImplementedError()
diff --git a/paddleslim/lc/quantizers/nf4.py b/paddleslim/lc/quantizers/nf4.py
new file mode 100644
index 000000000..422c85511
--- /dev/null
+++ b/paddleslim/lc/quantizers/nf4.py
@@ -0,0 +1,27 @@
+import paddle
+from .base_quantizer import BaseQuantizer
+import paddleslim_ops
+
+
+class NF4Quantizer(BaseQuantizer):
+    dtype = "int4"
+
+    def __init__(self, block_size=64, double_quant=False):
+        super(NF4Quantizer, self).__init__()
+        self.block_size = block_size
+        self.double_quant = double_quant
+        self.quant_scale = None
+        self.double_quant_scale = None
+
+    def quantize(self, x: paddle.Tensor):
+        out, abs_max = paddleslim_ops.quantize_nf4(
+            x, block_size=self.block_size)
+        self.quant_scale = abs_max
+        return out
+
+    def dequantize(self, x: paddle.Tensor, dtype: int):
+        return paddleslim_ops.dequantize_nf4(
+            x, self.quant_scale, block_size=self.block_size, dtype=dtype)
+
+    def matmul(self, x: paddle.Tensor, y: paddle.Tensor, bias: paddle.Tensor):
+        return x @ self.dequantize(y) + bias
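
Usage sketch (not part of the diff): a minimal quantize/dequantize round trip through the new custom ops, assuming the extension has been built and installed with "python csrc/setup_cuda.py install" on a CUDA build of PaddlePaddle. The tensor shape, dtype, and the dtype=1 code below are illustrative assumptions only, not behavior confirmed by this PR.

    # Hypothetical usage example; names and the dtype code are assumptions.
    import paddle
    from paddleslim.lc.quantizers import NF4Quantizer

    paddle.set_device("gpu")
    quantizer = NF4Quantizer(block_size=64)

    w = paddle.randn([4096, 4096], dtype="float16")   # example weight tensor
    packed = quantizer.quantize(w)                     # uint8 output, two NF4 codes per byte
    restored = quantizer.dequantize(packed, dtype=1)   # dtype code assumed to select float16

    print(packed.dtype, restored.dtype)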