@@ -15,12 +15,9 @@
 
 import torch
 import torch.fx
-from torch.ao.quantization.fx.utils import create_getattr_from_value
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.fx.node import map_arg
 from torch.fx.passes.infra.pass_base import PassBase
 from torch.fx.passes.infra.pass_base import PassResult
-from torch.quantization.fake_quantize import FakeQuantize
 
 import nncf
 import nncf.torch
@@ -29,6 +26,7 @@
 from nncf.experimental.torch.fx.constant_folding import constant_fold
 from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
 from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
+from nncf.experimental.torch.fx.quantization.qdq_parameters import TorchQDQParameters
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 
 TransformationFNType = Callable[[torch.fx.GraphModule], None]
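
Aside: the new TorchQDQParameters import replaces the FakeQuantize dependency removed above. Its definition is not part of this diff; judging only from the attributes accessed further down (scale, zero_point, ch_axis, quant_min, quant_max, is_per_channel), a hypothetical minimal sketch of the container could look like:

# Hypothetical sketch, inferred from the attribute accesses in this diff; the
# real class lives in nncf/experimental/torch/fx/quantization/qdq_parameters.py.
from dataclasses import dataclass
from typing import Optional, Union

import torch


@dataclass
class TorchQDQParameters:
    scale: Union[float, torch.Tensor]     # single value per tensor, or one value per channel
    zero_point: Union[int, torch.Tensor]  # same shape as scale
    quant_min: int                        # e.g. -128 for signed int8
    quant_max: int                        # e.g. 127 for signed int8
    is_per_channel: bool                  # selects per-channel vs per-tensor q/dq ops
    ch_axis: Optional[int] = None         # channel axis, used only when is_per_channel is True
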
@@ -223,16 +221,16 @@ def constant_update_fn(
 
 
 def qdq_insertion_transformation_builder(
-    quantizer: FakeQuantize, target_points: list[PTTargetPoint]
+    parameters: TorchQDQParameters, target_points: list[PTTargetPoint]
 ) -> TransformationFNType:
     """
-    Returns transformation which inserts quantize-dequantize operations with parameters
-    inherited from the given quantizer to each given target point.
+    Returns transformation which inserts quantize-dequantize operations with
+    the given parameters to each given target point.
 
-    :param quantizer: Quantizer module to inherit quantization parameters from.
+    :param parameters: Quantization parameters.
     :param target_points: List of target points used to insert quantize-dequantize pairs.
-    :return: Transformation which inserts quantize-dequantize operations with parameters
-    inherited from the given quantizer to each given target point.
+    :return: Transformation which inserts quantize-dequantize operations with
+    the given parameters to each given target point.
     """
 
     def qdq_insertion_transformation(model: torch.fx.GraphModule):
@@ -243,7 +241,7 @@ def qdq_insertion_transformation(model: torch.fx.GraphModule):
             )
             raise nncf.InternalError(msg)
         for target_point in target_points:
-            insert_one_qdq(model, target_point, quantizer)
+            insert_one_qdq(model, target_point, parameters)
 
     return qdq_insertion_transformation
 
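
For context, a minimal usage sketch of the builder (the names params, target_point, and gm are illustrative, not from this PR):

# Illustrative sketch: build a transformation for one target point and apply
# it to an FX graph module. Assumes `params` is a populated TorchQDQParameters
# instance and `target_point` is a PTTargetPoint.
transformation = qdq_insertion_transformation_builder(params, [target_point])
transformation(gm)  # mutates gm in place, inserting one q/dq pair per target point
gm.recompile()
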
@@ -311,38 +309,38 @@ def output_insertion_transformation(model: torch.fx.GraphModule):
     return output_insertion_transformation
 
 
-def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
+def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, parameters: TorchQDQParameters):
     """
     Inserts a quantize-dequantize pair after the target node in the target model.
 
     :param model: Target model.
     :param target_point: Target point; the quantize-dequantize pair is inserted just after
         the corresponding target node.
-    :param quantizer: Quantizer module to inherit quantization parameters from.
+    :param parameters: Quantization parameters.
     """
-    # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e
+    # Copied from torchao.quantization.quantize_pt2e.convert_pt2e
     # 1. extract information for inserting q/dq node from activation_post_process
     node_type = "call_function"
    quantize_op: Optional[Callable] = None
 
-    dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8
-    if quantizer.is_per_channel:
+    dtype = torch.int8 if parameters.quant_min < 0 else torch.uint8
+    if parameters.is_per_channel:
         qparams = {
-            "_scale_": quantizer.scale,
-            "_zero_point_": quantizer.zero_point,
-            "_axis_": quantizer.ch_axis,
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
+            "_scale_": parameters.scale,
+            "_zero_point_": parameters.zero_point,
+            "_axis_": parameters.ch_axis,
+            "_quant_min_": parameters.quant_min,
+            "_quant_max_": parameters.quant_max,
             "_dtype_": dtype,
         }
         quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
         dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
     else:
         qparams = {
-            "_scale_": float(quantizer.scale),
-            "_zero_point_": int(quantizer.zero_point),
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
+            "_scale_": float(parameters.scale),
+            "_zero_point_": int(parameters.zero_point),
+            "_quant_min_": parameters.quant_min,
+            "_quant_max_": parameters.quant_max,
             "_dtype_": dtype,
         }
         quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
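
The torch.ops.quantized_decomposed ops selected above keep the quantization math as explicit nodes in the graph. A standalone per-tensor round trip (plain PyTorch, independent of this PR) shows what an inserted q/dq pair computes:

import torch

x = torch.randn(8)
scale, zero_point, qmin, qmax = 0.1, 0, -128, 127
# quantize: clamp(round(x / scale) + zero_point, qmin, qmax), cast to int8
q = torch.ops.quantized_decomposed.quantize_per_tensor(x, scale, zero_point, qmin, qmax, torch.int8)
# dequantize: (q - zero_point) * scale, back to float32
dq = torch.ops.quantized_decomposed.dequantize_per_tensor(q, scale, zero_point, qmin, qmax, torch.int8)
assert (x - dq).abs().max() <= scale / 2 + 1e-6  # randn values stay well inside the clamp range
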
@@ -721,19 +719,6 @@ def match_filters(match, original_graph, graph):
     _set_meta_for_matches(model, matches)
 
 
-def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
-    """
-    Applies quantization transformations to the model.
-
-    :param model: Model to apply transformations to.
-    """
-    # BatchNorm operations have 3 output ports,
-    # to make it easier for algorithms to work
-    # with the target graph BatchNorm operations
-    # are being fused
-    _fuse_conv_bn_(model)
-
-
 def fold_constant_except_qdq(model: torch.fx.GraphModule):
     """
     Performs constant folding avoiding quantize-dequantize pattern.
@@ -826,3 +811,51 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
+
+
+def get_device(module: torch.nn.Module) -> torch.device:
+    """
+    Retrieves the device of the first parameter of the given module.
+
+    :param module: A torch.nn.Module instance.
+    :return: The device of the first parameter of the given module,
+        or the CPU device if the module has no parameters.
+    """
+    try:
+        param = next(module.parameters())
+    except StopIteration:
+        # Parameterless modules default to CPU.
+        return torch.device("cpu")
+    return param.device
+
+
+def create_getattr_from_value(module: torch.nn.Module, graph: torch.fx.Graph, prefix: str, value: Any) -> torch.fx.Node:
+    """
+    Given a value of any type, creates a getattr node corresponding to the value and
+    registers the value as a buffer of the module.
+
+    :param module: A torch.nn.Module instance.
+    :param graph: A torch.fx.Graph instance.
+    :param prefix: A string to use as a name prefix for the new getattr node.
+    :param value: A value to register as a buffer.
+    :return: A getattr node corresponding to the given value.
+    """
+
+    def get_new_attr_name(module: torch.nn.Module, prefix: str):
+        def get_attr_name(i: int):
+            return prefix + str(i)
+
+        # Find the first unused attribute name of the form prefix0, prefix1, ...
+        i = 0
+        attr_name = get_attr_name(i)
+        while hasattr(module, attr_name):
+            i += 1
+            attr_name = get_attr_name(i)
+        return attr_name
+
+    attr_name = get_new_attr_name(module, prefix.replace(".", "_"))
+    device = get_device(module)
+    new_value = value.detach().clone() if isinstance(value, torch.Tensor) else torch.tensor(value, device=device)
+    module.register_buffer(attr_name, new_value)
+    attr_node = graph.create_node("get_attr", attr_name)
+    return attr_node
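
A quick usage sketch for the new helper (standalone, illustrative names): create_getattr_from_value registers the value as a module buffer under a collision-free name (prefix plus a counter) and returns a get_attr node referring to it:

import torch
import torch.fx

gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
with gm.graph.inserting_before(next(iter(gm.graph.nodes))):
    # Registers buffer "scale0" (first free name for prefix "scale")
    # and returns a get_attr node pointing at it.
    scale_node = create_getattr_from_value(gm, gm.graph, "scale", torch.tensor(0.1))
gm.recompile()
print(gm.scale0)  # tensor(0.1000)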