Commit 73b39a3
[TorchFX] quantize_pt2e custom quantizers support (#3487)
### Changes

* TorchAOAdapter is updated with a new entity: the `ExtendedQuantizerSetup`. It contains additional info about the dtypes to use in q->dq pairs.
* The TorchFX MinMax algorithm backend is migrated from the restrictive strip procedure to flexible custom quantization parameter assignment code.

### Reason for changes

To fully support quantization via `quantize_pt2e` with custom quantizers (like `XNNPACKQuantizer`).

### Related tickets

#3231

### Tests

* Flexible custom quantization parameter assignment is tested by `tests/torch/fx/test_calculation_quantizer_params.py`.
* The conformance test is updated with 2 new configurations: `OV_QUANTIZER_NNCF`, `OV_QUANTIZER_AO`.
* conformance: post_training_quantization/682/
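For context, a minimal sketch of the workflow this commit enables: quantizing a captured torch.fx model with a third-party quantizer through NNCF's `quantize_pt2e`. The `XNNPACKQuantizer` import path and the exact `quantize_pt2e` signature vary between torch/NNCF versions, and `MyModel` is a placeholder:

```python
import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

import nncf
from nncf.experimental.torch.fx import quantize_pt2e

model = MyModel().eval()  # placeholder torch.nn.Module
example_input = torch.ones(1, 3, 224, 224)

# Capture the model into the torch.fx representation consumed by quantize_pt2e.
captured_model = torch.export.export(model, (example_input,)).module()

# The custom torch.ao quantizer decides where q->dq pairs go and with which dtypes.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

quantized_model = quantize_pt2e(
    captured_model,
    quantizer,
    calibration_dataset=nncf.Dataset([example_input]),
)
```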
1 parent e876e1e commit 73b39a3

35 files changed: +15174 −95 lines

src/nncf/common/quantization/quantizer_setup.py

Lines changed: 12 additions & 1 deletion

```diff
@@ -21,6 +21,7 @@
 from nncf.common.quantization.structs import NonWeightQuantizerId
 from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
 from nncf.common.quantization.structs import QuantizerConfig
+from nncf.common.quantization.structs import TypedQuantizerConfig
 from nncf.common.quantization.structs import UnifiedScaleType
 from nncf.common.quantization.structs import WeightQuantizerId
 from nncf.common.stateful_classes_registry import CommonStatefulClassesRegistry
@@ -193,9 +194,19 @@ def from_state(cls, state: dict[str, Any]) -> "SingleConfigQuantizationPoint":
         insertion_point_cls_name = state[cls._state_names.INSERTION_POINT_CLASS_NAME]
         insertion_point_cls = CommonStatefulClassesRegistry.get_registered_class(insertion_point_cls_name)
         insertion_point = insertion_point_cls.from_state(state[cls._state_names.INSERTION_POINT])  # type: ignore
+        qconfig_state = state[cls._state_names.QCONFIG]
+        # Need to instantiate TypedQuantizerConfig
+        # to support additional fields used by ExecuTorch-specific quantizer configs.
+        # TODO (dlyakhov): Refactor and generalize quantizer config deserialization to cleanly handle both
+        # standard and extended formats without relying on manual key comparison (ticket 170078).
+        if QuantizerConfig().__dict__.keys() == qconfig_state.keys():
+            qconfig = QuantizerConfig.from_state(qconfig_state)
+        else:
+            qconfig = TypedQuantizerConfig.from_state(qconfig_state)
+
         kwargs = {
             cls._state_names.INSERTION_POINT: insertion_point,
-            cls._state_names.QCONFIG: QuantizerConfig.from_state(state[cls._state_names.QCONFIG]),
+            cls._state_names.QCONFIG: qconfig,
             cls._state_names.NAMES_OF_QUANTIZED_OPS: state[cls._state_names.NAMES_OF_QUANTIZED_OPS],
         }
         return cls(**kwargs)
```
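The dispatch above picks the config class by comparing dict keys: a serialized `TypedQuantizerConfig` carries an extra field on top of the base config. A small illustration of the intent, assuming only that `QuantizerConfig` stores its constructor arguments as instance attributes (the literal state dicts are hypothetical):

```python
from nncf.common.quantization.structs import QuantizerConfig

base_keys = QuantizerConfig().__dict__.keys()

# A state saved from a plain QuantizerConfig has exactly the base keys ...
plain_state = {key: None for key in base_keys}
assert plain_state.keys() == base_keys  # -> QuantizerConfig.from_state

# ... while a TypedQuantizerConfig state also carries "dest_dtype".
typed_state = dict(plain_state, dest_dtype="int8")
assert typed_state.keys() != base_keys  # -> TypedQuantizerConfig.from_state
```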

src/nncf/common/quantization/structs.py

Lines changed: 42 additions & 1 deletion

```diff
@@ -11,7 +11,7 @@
 
 from copy import deepcopy
 from enum import Enum
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 import nncf
 from nncf.common.graph import NNCFNode
@@ -22,6 +22,9 @@
 from nncf.config.schemata.defaults import QUANTIZATION_PER_CHANNEL
 from nncf.parameters import StrEnum
 from nncf.parameters import TargetDevice
+from nncf.tensor.definitions import TensorDataType
+
+IntDtype = Literal[TensorDataType.int8, TensorDataType.uint8]
 
 
 @api()
@@ -421,3 +424,41 @@ def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> dict[s
         if quant_group == QuantizerGroup.ACTIVATIONS and self == QuantizationPreset.MIXED:
             return {"mode": QuantizationScheme.ASYMMETRIC}
         return {"mode": QuantizationScheme.SYMMETRIC}
+
+
+class TypedQuantizerConfig(QuantizerConfig):
+    """
+    Extended configuration class for quantizers, including the destination integer dtype.
+    """
+
+    def __init__(
+        self,
+        num_bits: int = QUANTIZATION_BITS,
+        mode: QuantizationScheme = QuantizationScheme.SYMMETRIC,
+        signedness_to_force: Optional[bool] = None,
+        per_channel: bool = QUANTIZATION_PER_CHANNEL,
+        narrow_range: bool = QUANTIZATION_NARROW_RANGE,
+        dest_dtype: IntDtype = TensorDataType.int8,
+    ):
+        """
+        :param num_bits: Bitwidth of the quantization.
+        :param mode: The mode of quantization (symmetric or asymmetric).
+        :param signedness_to_force: True if the quantizer *must* be signed, False if *must* be unsigned,
+            None if the signed/unsigned attribute should be determined based on the incoming activation
+            statistics during range initialization.
+        :param per_channel: True for per-channel quantization, False for per-tensor.
+        :param narrow_range: True if the range of quantized values should be narrowed as compared to the
+            naive case, False if all 2^`num_bits` quantizations should be used.
+        :param dest_dtype: Target integer data type for quantized values.
+        """
+        super().__init__(num_bits, mode, signedness_to_force, per_channel, narrow_range)
+        self.dest_dtype = dest_dtype
+
+    def __str__(self) -> str:
+        retval = super().__str__()
+        return retval + f" DestDtype: {self.dest_dtype}"
+
+    def get_state(self) -> dict[str, Any]:
+        state = super().get_state()
+        state["dest_dtype"] = self.dest_dtype
+        return state
```
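A short usage sketch of the new config class (the field values are illustrative):

```python
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.quantization.structs import TypedQuantizerConfig
from nncf.tensor.definitions import TensorDataType

qconfig = TypedQuantizerConfig(
    num_bits=8,
    mode=QuantizationScheme.SYMMETRIC,
    per_channel=True,
    narrow_range=True,
    dest_dtype=TensorDataType.int8,
)

# get_state() extends the base QuantizerConfig state with the new field,
# which is what the key comparison in from_state() relies on.
state = qconfig.get_state()
assert state["dest_dtype"] is TensorDataType.int8
```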

src/nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py

Lines changed: 12 additions & 7 deletions

```diff
@@ -29,9 +29,10 @@
 from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
 from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
 from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
-from nncf.common.quantization.structs import QuantizerConfig
+from nncf.common.quantization.structs import TypedQuantizerConfig
 from nncf.experimental.quantization.quantizer import Quantizer
 from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
+from nncf.tensor.definitions import TensorDataType
 
 EdgeOrNode = Union[tuple[torch.fx.Node, torch.fx.Node]]
 
@@ -71,15 +72,15 @@ def _get_quantization_points(
     from_node: torch.fx.Node,
     to_nodes: list[torch.fx.Node],
     annotated_model: torch.fx.GraphModule,
-    qconfig: QuantizerConfig,
+    qconfig: TypedQuantizerConfig,
 ) -> list[QuantizationPointBase]:
     """
     Creates quantization points based on the nodes and edges.
 
     :param from_node: The originating node in the computation graph.
     :param to_nodes: The list of destination nodes of the from_node.
     :param annotated_model: The torch.fx.GraphModule instance.
-    :param qconfig: The torch.ao quantization configuration.
+    :param qconfig: The TorchFX quantization configuration.
     :return: A list of NNCF quantization points.
     """
     to_n = to_nodes[0]
@@ -159,15 +160,19 @@ def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -
            msg = f"Unknown qscheme: {qspec.qscheme}"
            raise nncf.InternalError(msg)
 
-        signed = qspec.dtype is torch.int8
+        dtype = TensorDataType.int8 if qspec.dtype is torch.int8 else TensorDataType.uint8
         mode = (
             QuantizationMode.SYMMETRIC
             if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
             else QuantizationMode.ASYMMETRIC
         )
-        narrow_range = qspec.quant_min % 2 != 0
-        qconfig = QuantizerConfig(
-            mode=mode, signedness_to_force=signed, per_channel=per_channel, narrow_range=narrow_range
+        narrow_range = qspec.quant_max - qspec.quant_min == 254
+        qconfig = TypedQuantizerConfig(
+            mode=mode,
+            signedness_to_force=False,
+            per_channel=per_channel,
+            narrow_range=narrow_range,
+            dest_dtype=dtype,
         )
 
     joined_edges = defaultdict(list)
```
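The `narrow_range` detection deserves a note: an 8-bit full range spans 256 levels (`quant_max - quant_min == 255`), while a narrow range drops one level (e.g. int8 `[-127, 127]`), giving a difference of 254. The previous parity check on `quant_min` breaks down for unsigned ranges such as `[0, 254]`. A standalone sketch of the new rule:

```python
def is_narrow_range(quant_min: int, quant_max: int) -> bool:
    # Mirrors the check in get_quantizer_config_from_annotated_model:
    # 254 usable steps instead of 255 means one level was dropped.
    return quant_max - quant_min == 254

assert is_narrow_range(-127, 127)       # narrow int8
assert not is_narrow_range(-128, 127)   # full int8
assert not is_narrow_range(0, 255)      # full uint8
```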

src/nncf/parameters.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 from enum import Enum
+from typing import Any
 
 from nncf.common.utils.api_marker import api
 
@@ -18,6 +19,10 @@ class StrEnum(str, Enum):
     def __str__(self) -> str:
         return str(self.value)
 
+    @staticmethod
+    def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]) -> Any:
+        return name.lower()
+
 
 @api(canonical_alias="nncf.TargetDevice")
 class TargetDevice(StrEnum):
```
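With this `_generate_next_value_`, members declared via `auto()` get their lowercase name as the value, the same behavior as Python 3.11's `enum.StrEnum`. A quick illustration with a hypothetical enum:

```python
from enum import auto

from nncf.parameters import StrEnum

class Color(StrEnum):
    RED = auto()
    GREEN = auto()

assert Color.RED.value == "red"
assert str(Color.GREEN) == "green"
```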

src/nncf/quantization/algorithms/min_max/torch_fx_backend.py

Lines changed: 73 additions & 63 deletions

```diff
@@ -23,7 +23,9 @@
 from nncf.common.graph.transformations.commands import TransformationCommand
 from nncf.common.hardware.config import HWConfig
 from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
+from nncf.common.quantization.structs import QuantizationScheme
 from nncf.common.quantization.structs import QuantizerConfig
+from nncf.common.quantization.structs import TypedQuantizerConfig
 from nncf.experimental.common.tensor_statistics.collectors import REDUCERS_MAP
 from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
 from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
@@ -35,6 +37,7 @@
 from nncf.quantization.fake_quantize import FakeConvertParameters
 from nncf.quantization.fake_quantize import FakeQuantizeParameters
 from nncf.quantization.range_estimator import StatisticsType
+from nncf.tensor.definitions import TensorDataType
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph import PTTargetPoint
 from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
@@ -46,12 +49,7 @@
 from nncf.torch.model_graph_manager import is_matmul_with_constant
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
-from nncf.torch.quantization.layers import QUANTIZATION_MODULES
-from nncf.torch.quantization.layers import AsymmetricQuantizer
-from nncf.torch.quantization.layers import BaseQuantizer
-from nncf.torch.quantization.layers import PTQuantizerSpec
-from nncf.torch.quantization.layers import get_scale_shape
-from nncf.torch.quantization.strip import convert_to_torch_fakequantizer
+from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high
 
 
 class FXMinMaxAlgoBackend(MinMaxAlgoBackend):
@@ -175,63 +173,83 @@ def get_weight_config(config: QuantizerConfig, model: NNCFNetwork) -> QuantizerC
         return config
 
     @staticmethod
-    def _get_input_scale_shape(
-        nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool
-    ) -> tuple[tuple[int, ...], tuple[int, ...], int]:
-        is_weights = target_point.is_weight_target_point()
-        if is_weights:
+    def _get_channel_axis(is_weight_quantizer: bool) -> int:
+        if is_weight_quantizer:
             # TODO(dlyakhov): support transpose conv/ make channel_idx common
-            channel_idx = 0
-        else:
-            channel_idx = 1  # channel dim for activations
-
-        input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
-        scale_shape = tuple(
-            get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
-        )
-
-        return input_shape, scale_shape, channel_idx
+            return 0
+        return 1
 
     @staticmethod
     def _create_quantizer(
         quantizer_config: QuantizerConfig,
-        scale_shape: tuple,
         parameters: FakeQuantizeParameters,
-        target_type: TargetType,
+        is_weight_quantizer: bool,
     ) -> FakeQuantize:
-        mode = quantizer_config.mode
-        quantizer_cls = QUANTIZATION_MODULES.get(mode)
-        quantizer_spec = PTQuantizerSpec.from_config(
-            quantizer_config,
-            narrow_range=quantizer_config.narrow_range,
-            scale_shape=scale_shape,
-            half_range=False,
-            logarithm_scale=False,
-            is_quantized_on_export=False,
-            compression_lr_multiplier=None,
-        )
-        quantizer = quantizer_cls(quantizer_spec)
+        per_channel = quantizer_config.per_channel
+        dtype = None
+        if isinstance(quantizer_config, TypedQuantizerConfig):
+            dtype = quantizer_config.dest_dtype
 
-        # Fill it with minmax
-        # TODO(dlyakhov) Prevent creation of intermediate objects like nncf quantizer.
-        FXMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape)
-        # Convert to the torch fake quantizer
-        torch_fq = convert_to_torch_fakequantizer(quantizer)
-        return torch_fq
+            if dtype not in [TensorDataType.int8, TensorDataType.uint8]:
+                msg = f"Quantization configurations with dest_dtype=={dtype} are not supported."
+                raise nncf.ParameterNotSupportedError(msg)
 
-    @staticmethod
-    def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None:
-        if isinstance(quantizer, AsymmetricQuantizer):
-            quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape))
-            input_range = parameters.input_high - parameters.input_low
-            # Subtract eps from the input_range to make quantizer parameters equal to
-            # original parameters on the forward call.
-            quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape))
+        elif quantizer_config.mode != QuantizationScheme.SYMMETRIC:
+            dtype = TensorDataType.uint8
+        else:
+            dtype = (
+                TensorDataType.int8
+                if quantizer_config.signedness_to_force or torch.any(parameters.input_low.data < 0.0)
+                else TensorDataType.uint8
+            )
+
+        if per_channel:
+            observer = torch.ao.quantization.observer.PerChannelMinMaxObserver
+        else:
+            observer = torch.ao.quantization.observer.MinMaxObserver
+
+        if dtype is TensorDataType.int8:
+            level_high = 127
+            level_low = -128
+        else:
+            level_high = 255
+            level_low = 0
+
+        if quantizer_config.narrow_range:
+            if level_low < 0:
+                level_low += 1
+            else:
+                level_high -= 1
+
+        if quantizer_config.mode == QuantizationScheme.SYMMETRIC:
+            qscheme = torch.per_channel_symmetric if per_channel else torch.per_tensor_symmetric
         else:
-            quantizer.signed = bool(torch.any(parameters.input_low.data < 0))
-            # Subtract eps from the scale to make quantizer parameters equal to
-            # original parameters on the forward call.
-            quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape))
+            qscheme = torch.per_channel_affine if per_channel else torch.per_tensor_affine
+
+        scale, zero_point = get_scale_zp_from_input_low_input_high(
+            level_low, level_high, parameters.input_low.data, parameters.input_high.data
+        )
+
+        scale = scale.view(-1)
+        zero_point = zero_point.view(-1)
+
+        fakequantizer = FakeQuantize(
+            observer=observer,
+            quant_max=level_high,
+            quant_min=level_low,
+            dtype=torch.qint8 if dtype is TensorDataType.int8 else torch.quint8,
+            qscheme=qscheme,
+            eps=1e-16,
+        )
+
+        fakequantizer.scale = scale
+        fakequantizer.zero_point = zero_point
+        if per_channel:
+            fakequantizer.ch_axis = FXMinMaxAlgoBackend._get_channel_axis(is_weight_quantizer)
+
+        # Disable observer to save parameters
+        fakequantizer.disable_observer()
+        return fakequantizer
 
     @staticmethod
     def create_quantizer_insertion_command(
@@ -240,12 +258,8 @@ def create_quantizer_insertion_command(
         quantizer_config: QuantizerConfig,
         parameters: FakeQuantizeParameters,
     ) -> FXApplyTransformationCommand:
-        _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape(
-            nncf_graph, target_point, quantizer_config.per_channel
-        )
-
         quantizer = FXMinMaxAlgoBackend._create_quantizer(
-            quantizer_config, scale_shape, parameters, target_point.target_type
+            quantizer_config, parameters, target_point.is_weight_target_point()
         )
         transformation = qdq_insertion_transformation_builder(quantizer, [target_point])
         return FXApplyTransformationCommand(transformation)
@@ -257,12 +271,8 @@ def create_unified_scales_quantizers_insertion_commands(
         quantizer_config: QuantizerConfig,
         parameters: FakeQuantizeParameters,
     ) -> list[PTSharedFnInsertionCommand]:
-        _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape(
-            nncf_graph, target_points[0], quantizer_config.per_channel
-        )
-
         quantizer = FXMinMaxAlgoBackend._create_quantizer(
-            quantizer_config, scale_shape, parameters, target_points[0].target_type
+            quantizer_config, parameters, target_points[0].is_weight_target_point()
         )
 
         transformations = []
```
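The level-bound arithmetic in `_create_quantizer` can be summarized standalone. A sketch mirroring the branches above, with string dtype tags standing in for `TensorDataType` members for brevity:

```python
def quant_levels(dest_dtype: str, narrow_range: bool) -> tuple[int, int]:
    # Base bounds for the two supported destination dtypes.
    level_low, level_high = (-128, 127) if dest_dtype == "int8" else (0, 255)
    if narrow_range:
        if level_low < 0:
            level_low += 1   # signed narrow range: [-127, 127]
        else:
            level_high -= 1  # unsigned narrow range: [0, 254]
    return level_low, level_high

assert quant_levels("int8", narrow_range=True) == (-127, 127)
assert quant_levels("int8", narrow_range=False) == (-128, 127)
assert quant_levels("uint8", narrow_range=True) == (0, 254)
```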

src/nncf/tensor/definitions.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -14,6 +14,8 @@
 from enum import auto
 from typing import Optional, Union
 
+from nncf.parameters import StrEnum
+
 T_SHAPE_ARRAY = tuple[int, ...]
 T_SHAPE = Union[int, T_SHAPE_ARRAY]
 T_AXIS = Optional[T_SHAPE]
@@ -31,7 +33,7 @@ class TensorBackend(Enum):
     ov = auto()
 
 
-class TensorDataType(Enum):
+class TensorDataType(StrEnum):
     """
     Enum representing the different tensor data types.
    """
```
