 from nncf.common.graph.transformations.commands import TransformationCommand
 from nncf.common.hardware.config import HWConfig
 from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
+from nncf.common.quantization.structs import QuantizationScheme
 from nncf.common.quantization.structs import QuantizerConfig
+from nncf.common.quantization.structs import TypedQuantizerConfig
 from nncf.experimental.common.tensor_statistics.collectors import REDUCERS_MAP
 from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
 from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
 from nncf.quantization.fake_quantize import FakeConvertParameters
 from nncf.quantization.fake_quantize import FakeQuantizeParameters
 from nncf.quantization.range_estimator import StatisticsType
+from nncf.tensor.definitions import TensorDataType
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph import PTTargetPoint
 from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
 from nncf.torch.model_graph_manager import is_matmul_with_constant
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
-from nncf.torch.quantization.layers import QUANTIZATION_MODULES
-from nncf.torch.quantization.layers import AsymmetricQuantizer
-from nncf.torch.quantization.layers import BaseQuantizer
-from nncf.torch.quantization.layers import PTQuantizerSpec
-from nncf.torch.quantization.layers import get_scale_shape
-from nncf.torch.quantization.strip import convert_to_torch_fakequantizer
+from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high


 class FXMinMaxAlgoBackend(MinMaxAlgoBackend):
@@ -175,63 +173,83 @@ def get_weight_config(config: QuantizerConfig, model: NNCFNetwork) -> QuantizerC
         return config

     @staticmethod
-    def _get_input_scale_shape(
-        nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool
-    ) -> tuple[tuple[int, ...], tuple[int, ...], int]:
-        is_weights = target_point.is_weight_target_point()
-        if is_weights:
+    def _get_channel_axis(is_weight_quantizer: bool) -> int:
+        if is_weight_quantizer:
             # TODO(dlyakhov): support transpose conv/ make channel_idx common
-            channel_idx = 0
-        else:
-            channel_idx = 1  # channel dim for activations
-
-        input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
-        scale_shape = tuple(
-            get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
-        )
-
-        return input_shape, scale_shape, channel_idx
+            return 0
+        return 1

     @staticmethod
     def _create_quantizer(
         quantizer_config: QuantizerConfig,
-        scale_shape: tuple,
         parameters: FakeQuantizeParameters,
-        target_type: TargetType,
+        is_weight_quantizer: bool,
     ) -> FakeQuantize:
-        mode = quantizer_config.mode
-        quantizer_cls = QUANTIZATION_MODULES.get(mode)
-        quantizer_spec = PTQuantizerSpec.from_config(
-            quantizer_config,
-            narrow_range=quantizer_config.narrow_range,
-            scale_shape=scale_shape,
-            half_range=False,
-            logarithm_scale=False,
-            is_quantized_on_export=False,
-            compression_lr_multiplier=None,
-        )
-        quantizer = quantizer_cls(quantizer_spec)
+        per_channel = quantizer_config.per_channel
+        dtype = None
+        if isinstance(quantizer_config, TypedQuantizerConfig):
+            dtype = quantizer_config.dest_dtype

-        # Fill it with minmax
-        # TODO(dlyakhov) Prevent creation of intermediate objects like nncf quantizer.
-        FXMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape)
-        # Convert to the torch fake quantizer
-        torch_fq = convert_to_torch_fakequantizer(quantizer)
-        return torch_fq
+            if dtype not in [TensorDataType.int8, TensorDataType.uint8]:
+                msg = f"Quantization configurations with dest_dtype=={dtype} are not supported."
+                raise nncf.ParameterNotSupportedError(msg)

-    @staticmethod
-    def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None:
-        if isinstance(quantizer, AsymmetricQuantizer):
-            quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape))
-            input_range = parameters.input_high - parameters.input_low
-            # Subtract eps from the input_range to make quantizer parameters equal to
-            # original parameters on the forward call.
-            quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape))
+        elif quantizer_config.mode != QuantizationScheme.SYMMETRIC:
+            dtype = TensorDataType.uint8
+        else:
+            dtype = (
+                TensorDataType.int8
+                if quantizer_config.signedness_to_force or torch.any(parameters.input_low.data < 0.0)
+                else TensorDataType.uint8
+            )
+
+        if per_channel:
+            observer = torch.ao.quantization.observer.PerChannelMinMaxObserver
+        else:
+            observer = torch.ao.quantization.observer.MinMaxObserver
+
+        if dtype is TensorDataType.int8:
+            level_high = 127
+            level_low = -128
+        else:
+            level_high = 255
+            level_low = 0
+
+        if quantizer_config.narrow_range:
+            if level_low < 0:
+                level_low += 1
+            else:
+                level_high -= 1
+
+        if quantizer_config.mode == QuantizationScheme.SYMMETRIC:
+            qscheme = torch.per_channel_symmetric if per_channel else torch.per_tensor_symmetric
         else:
-            quantizer.signed = bool(torch.any(parameters.input_low.data < 0))
-            # Subtract eps from the scale to make quantizer parameters equal to
-            # original parameters on the forward call.
-            quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape))
+            qscheme = torch.per_channel_affine if per_channel else torch.per_tensor_affine
+
+        scale, zero_point = get_scale_zp_from_input_low_input_high(
+            level_low, level_high, parameters.input_low.data, parameters.input_high.data
+        )
+
+        scale = scale.view(-1)
+        zero_point = zero_point.view(-1)
+
+        fakequantizer = FakeQuantize(
+            observer=observer,
+            quant_max=level_high,
+            quant_min=level_low,
+            dtype=torch.qint8 if dtype is TensorDataType.int8 else torch.quint8,
+            qscheme=qscheme,
+            eps=1e-16,
+        )
+
+        fakequantizer.scale = scale
+        fakequantizer.zero_point = zero_point
+        if per_channel:
+            fakequantizer.ch_axis = FXMinMaxAlgoBackend._get_channel_axis(is_weight_quantizer)
+
+        # Disable observer to save parameters
+        fakequantizer.disable_observer()
+        return fakequantizer

     @staticmethod
     def create_quantizer_insertion_command(
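For context on the new `_create_quantizer` path, the standalone sketch below mirrors what it builds for an asymmetric per-tensor activation quantizer: pick unsigned 8-bit levels, derive scale/zero_point from the calibrated range, load them into a torch FakeQuantize, and freeze the observer. The range values and the affine scale/zero-point formula are illustrative assumptions; in NNCF that arithmetic lives inside get_scale_zp_from_input_low_input_high.

import torch
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver

# Hypothetical calibrated activation range (stand-in for FakeQuantizeParameters).
input_low = torch.tensor([-1.5])
input_high = torch.tensor([1.2])

# Asymmetric mode maps to unsigned 8-bit levels in the branch above.
level_low, level_high = 0, 255

# Illustrative affine mapping from the float range to integer levels;
# NNCF derives these inside get_scale_zp_from_input_low_input_high.
scale = (input_high - input_low) / (level_high - level_low)
zero_point = torch.clamp(
    torch.round(level_low - input_low / scale), level_low, level_high
).to(torch.int32)

fq = FakeQuantize(
    observer=MinMaxObserver,
    quant_min=level_low,
    quant_max=level_high,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    eps=1e-16,
)
fq.scale = scale.view(-1)
fq.zero_point = zero_point.view(-1)
fq.disable_observer()  # keep the calibrated parameters fixed, as the backend does

print(fq(torch.randn(8)))  # fake-quantized activations

Disabling the observer is what keeps the calibrated scale and zero_point from being recomputed on subsequent forward calls.
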
@@ -240,12 +258,8 @@ def create_quantizer_insertion_command(
         quantizer_config: QuantizerConfig,
         parameters: FakeQuantizeParameters,
     ) -> FXApplyTransformationCommand:
-        _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape(
-            nncf_graph, target_point, quantizer_config.per_channel
-        )
-
         quantizer = FXMinMaxAlgoBackend._create_quantizer(
-            quantizer_config, scale_shape, parameters, target_point.target_type
+            quantizer_config, parameters, target_point.is_weight_target_point()
         )
         transformation = qdq_insertion_transformation_builder(quantizer, [target_point])
         return FXApplyTransformationCommand(transformation)
@@ -257,12 +271,8 @@ def create_unified_scales_quantizers_insertion_commands(
         quantizer_config: QuantizerConfig,
         parameters: FakeQuantizeParameters,
     ) -> list[PTSharedFnInsertionCommand]:
-        _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape(
-            nncf_graph, target_points[0], quantizer_config.per_channel
-        )
-
         quantizer = FXMinMaxAlgoBackend._create_quantizer(
-            quantizer_config, scale_shape, parameters, target_points[0].target_type
+            quantizer_config, parameters, target_points[0].is_weight_target_point()
         )

         transformations = []
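
A per-channel weight variant of the same sketch shows the channel-axis convention that _get_channel_axis encodes (axis 0 for weights, axis 1 for activations) together with narrow-range int8 levels. The per-channel ranges and the symmetric scale formula are again assumptions for illustration only.

import torch
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import PerChannelMinMaxObserver

# Hypothetical symmetric per-channel ranges for a weight of shape (out_ch=3, in_ch=8).
input_high = torch.tensor([0.5, 1.0, 0.25])
level_low, level_high = -127, 127  # int8 with narrow_range, as adjusted above

# Illustrative symmetric convention: zero_point pinned to 0.
scale = input_high / level_high
zero_point = torch.zeros(3, dtype=torch.int32)

fq = FakeQuantize(
    observer=PerChannelMinMaxObserver,
    quant_min=level_low,
    quant_max=level_high,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
)
fq.scale = scale
fq.zero_point = zero_point
fq.ch_axis = 0  # weights quantize along dim 0, matching _get_channel_axis
fq.disable_observer()

w = torch.randn(3, 8)
print(fq(w).shape)  # torch.Size([3, 8])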