diff --git a/paddleslim/quant/layers/custom_attention.py b/paddleslim/quant/layers/custom_attention.py new file mode 100644 index 000000000..7d3d5ee64 --- /dev/null +++ b/paddleslim/quant/layers/custom_attention.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Custome Attention Layer for quantization. +""" +import paddle.tensor as tensor +from paddle.nn import Layer +from paddle.nn.quant.format import ConvertibleQuantedLayer + + +class QuantizedCustomAttentionLayer(ConvertibleQuantedLayer): + """ + Quantized Custom Attention Layer. + """ + + def __init__(self, layer: Layer, q_config=None): + """ + Initialize the QuantizeWrapper class. + + Args: + layer (Layer): The layer to be quantized. + q_config (QuantConfig, optional): The quantization configuration. Defaults to None. + """ + super().__init__() + # hard code: get activation quanter from weight + self.activation_quanter_k = q_config.weight._instance(layer) + self.activation_quanter_v = q_config.activation._instance(layer) + self.layer = layer + self.quant_info = None + layer_name = self.layer.full_name() + self.layer_id = int(layer_name.split("_")[-1]) + self.kv_losses = {} + + def forward(self, q, config, k, v, attention_mask, output_attentions, **kwargs): + """forward""" + perm = [0, 2, 1, 3] # [1, 2, 0, 3] if self.sequence_parallel else [0, 2, 1, 3] + tmp_k = tensor.transpose(x=k, perm=perm) + tmp_v = tensor.transpose(x=v, perm=perm) + if self.activation_quanter_k is not None: + tmp_k = self.activation_quanter_k(tmp_k) + if self.activation_quanter_v is not None: + tmp_v = self.activation_quanter_v(tmp_v) + k = tensor.transpose(x=tmp_k, perm=perm) + v = tensor.transpose(x=tmp_v, perm=perm) + return self.layer( + q, + config, + k, + v, + attention_mask, + output_attentions, + **kwargs, + ) + + def weights_to_quanters(self): + """weights to quanters""" + return [] + + def activation_quanters(self): + """activation to quanters""" + return ["activation_quanter_k", "activation_quanter_v"] diff --git a/paddleslim/quant/observers/abs_max_headwise.py b/paddleslim/quant/observers/abs_max_headwise.py new file mode 100644 index 000000000..8ec286f04 --- /dev/null +++ b/paddleslim/quant/observers/abs_max_headwise.py @@ -0,0 +1,92 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from .channel_wise import ChannelWiseObserver +from paddle.quantization.factory import ObserverFactory + + +class AbsMaxHeadwiseObserver(ObserverFactory): + r""" + It collects channel-wise maximum absolute values of target weights. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import AbsMaxHeadwiseObserver + quanter = AbsMaxHeadwiseObserver() + q_config = QuantConfig(activation=None, weight=quanter) + """ + + def __init__(self, quant_bits=8, quant_axis=None): + super(AbsMaxHeadwiseObserver, self).__init__(quant_bits=quant_bits, quant_axis=quant_axis) + + def _get_class(self): + return AbsMaxHeadwiseObserverLayer + + +class AbsMaxHeadwiseObserverLayer(ChannelWiseObserver): + def __init__(self, layer, quant_bits=8, quant_axis=None): + super(AbsMaxHeadwiseObserverLayer, self).__init__( + layer, quant_bits=quant_bits, sign=True, symmetric=True, quant_axis=quant_axis + ) + self.quant_bits = quant_bits + self.calibration_loss = float("inf") + self.qmin, self.qmax = self.qmin_qmax + self._layer = layer + self._max = None + self._scale = None + self._zero_point = None + + def forward(self, inputs): + self._max = self._cal_abs_max(inputs) + return inputs + + def _cal_abs_max(self, inputs): + reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != self.quant_axis()]) + abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis).cast("float32") + abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values) + + if self._max is not None: + abs_max_values = paddle.maximum(abs_max_values, self._max) + + return abs_max_values + + def min_value(self) -> float: + return 0.0 + + def max_value(self) -> float: + return self._max + + def cal_thresholds(self): + """Compute thresholds for MAX function.""" + self._scale = self._max + self._zero_point = paddle.zeros_like(self._scale) + + def scales(self): + """Return output scales.""" + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + """Return output zero points.""" + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point diff --git a/paddleslim/quant/observers/avg.py b/paddleslim/quant/observers/avg.py index 199a2aa0e..3a4b6b769 100644 --- a/paddleslim/quant/observers/avg.py +++ b/paddleslim/quant/observers/avg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/paddleslim/quant/observers/avg_headwise.py b/paddleslim/quant/observers/avg_headwise.py new file mode 100644 index 000000000..a040ffbf8 --- /dev/null +++ b/paddleslim/quant/observers/avg_headwise.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from paddle.quantization.factory import ObserverFactory + +from .abs_max_headwise import AbsMaxHeadwiseObserverLayer + + +class AvgHeadwiseObserver(ObserverFactory): + r""" + It collects channel-wise maximum absolute values of target weights. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import AbsMaxHeadwiseObserver + quanter = AbsMaxHeadwiseObserver() + q_config = QuantConfig(activation=None, weight=quanter) + """ + + def __init__(self, quant_bits=8, quant_axis=None, moving_avg=False): + super(AvgHeadwiseObserver, self).__init__(quant_bits=quant_bits, quant_axis=quant_axis, moving_avg=moving_avg) + + def _get_class(self): + return AvgHeadwiseObserverLayer + + +class AvgHeadwiseObserverLayer(AbsMaxHeadwiseObserverLayer): + def __init__(self, layer, quant_bits=8, quant_axis=None, moving_avg=True): + super(AvgHeadwiseObserverLayer, self).__init__(layer, quant_bits=quant_bits, quant_axis=quant_axis) + self.quant_bits = quant_bits + self._qmin, self._qmax = self.qmin_qmax + self._max = None + self._scale = None + self._zero_point = None + if quant_axis is not None: + self._channel_axis = quant_axis + self._current_iters = 0 + self._range_update_factor_min = 0.001 + self._moving_avg = moving_avg + + def forward(self, inputs, quant_axis=None): + if quant_axis is not None: + self._channel_axis = quant_axis + self._max = self._cal_abs_max(inputs) + return inputs + + def _cal_abs_max(self, inputs): + self._current_iters += 1 + reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != self.quant_axis()]) + abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis).cast("float32") + abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values) + if self._max is not None: + if self._moving_avg: + # exponential moving average update + update_factor = 1.0 / self._current_iters + update_factor = max(update_factor, self._range_update_factor_min) + abs_max_values = self._max * (1 - update_factor) + abs_max_values * update_factor + else: + # normal average + abs_max_values = (self._max * (self._current_iters - 1) + abs_max_values) / self._current_iters + return abs_max_values + + def min_value(self) -> float: + return 0.0 + + def max_value(self) -> float: + return self._max + + def cal_thresholds(self): + """Compute thresholds for MAX function.""" + if self._scale is not None: + self._zero_point = paddle.zeros_like(self._scale) + return + self._scale = self._max + self._zero_point = paddle.zeros_like(self._scale) + + def scales(self): + """Return output scales.""" + self.cal_thresholds() + return self._scale + + def zero_points(self): + """Return output zero points.""" + self.cal_thresholds() + return self._zero_point diff --git a/paddleslim/quant/observers/channel_wise.py b/paddleslim/quant/observers/channel_wise.py index 9962af835..202ff1c8d 100644 --- a/paddleslim/quant/observers/channel_wise.py +++ b/paddleslim/quant/observers/channel_wise.py @@ -28,25 +28,23 @@ class ChannelWiseObserver(UniformObserver): - def __init__( - self, - layer, - quant_bits=8, - sign=True, - symmetric=True, ): + def __init__(self, layer, quant_bits=8, sign=True, symmetric=True, quant_axis=None): super(ChannelWiseObserver, self).__init__( quant_bits=quant_bits, sign=sign, - symmetric=symmetric, ) - self._channel_axis = CHANNEL_AXIS[type(layer)] + symmetric=symmetric, + ) + if quant_axis is not None: + self._channel_axis = quant_axis + else: + assert type(layer) in CHANNEL_AXIS, "Unsupported layer type: {}".format(type(layer)) + self._channel_axis = CHANNEL_AXIS[type(layer)] self._quant_bits = quant_bits def quant_axis(self): - """ Return quantization axis. - """ + """Return quantization axis.""" return self._channel_axis def bit_length(self): - """ Return the bit length of quantized data. - """ + """Return the bit length of quantized data.""" return self._quant_bits diff --git a/paddleslim/quant/observers/uniform.py b/paddleslim/quant/observers/uniform.py index d874fa687..e5c6ca579 100644 --- a/paddleslim/quant/observers/uniform.py +++ b/paddleslim/quant/observers/uniform.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ class UniformObserver(BaseObserver): - """ This is the base class for a uniform quantization observer, which provides + """This is the base class for a uniform quantization observer, which provides common functions for calculating the scale and zero-point used in uniform quantization. Uniform quantization maps floating point values to integers, where the scale determines the step size of the quantizer and the floating point zero is mapped to the zero-point, @@ -31,14 +31,15 @@ class UniformObserver(BaseObserver): symmetric (bool): Whether it is symmetric quantization. the quantization is symmetric. In symmetric quantization, the range of floating point values is relaxed to be symmetric around zero and the zero-point is always 0. - + """ def __init__( - self, - quant_bits=8, - sign=True, - symmetric=True, ): + self, + quant_bits=8, + sign=True, + symmetric=True, + ): super(UniformObserver, self).__init__() self._quant_bits = quant_bits self._sign = sign @@ -54,14 +55,26 @@ def __init__( @property def qmin_qmax(self): - """ Calculate the range of the quantized integer based on the specified + """Calculate the range of the quantized integer based on the specified quant_bits, sign, and symmetric properties.""" - if self._sign: - self._qmin = -2**(self.bit_length() - 1) - self._qmax = 2**(self.bit_length() - 1) - 1 + if isinstance(self._quant_bits, tuple): + if self._quant_bits[0] == 4 and self._quant_bits[1] == 3 and len(self._quant_bits) == 2: + self._qmin = -448.0 + self._qmax = 448.0 + elif self._quant_bits[0] == 5 and self._quant_bits[1] == 2 and len(self._quant_bits) == 2: + self._qmin = -57344.0 + self._qmax = 57344.0 + else: + raise NotImplementedError( + "Currently, only float8_e4m3 and float8_e5m2 formats are supported. Please set quant_bits to (4,3) or (5,2) for the corresponding format." + ) else: - self._qmin = 0 - self._qmax = 2**self.bit_length() + if self._sign: + self._qmin = -(2 ** (self.bit_length() - 1)) + self._qmax = 2 ** (self.bit_length() - 1) - 1 + else: + self._qmin = 0 + self._qmax = 2 ** self.bit_length() return self._qmin, self._qmax @abc.abstractmethod