diff --git a/nncf/openvino/graph/nncf_graph_builder.py b/nncf/openvino/graph/nncf_graph_builder.py
index d14b6cd4946..519911d7410 100644
--- a/nncf/openvino/graph/nncf_graph_builder.py
+++ b/nncf/openvino/graph/nncf_graph_builder.py
@@ -98,7 +98,7 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None:
         in_node_id = graph.get_node_by_name(op.get_friendly_name()).node_id
         for output_port_id, out in enumerate(op.outputs()):
             node_vs_target_inputs = defaultdict(list)
-            for inp in out.get_target_inputs():
+            for inp in sorted(out.get_target_inputs(), key=lambda inp: inp.get_node().get_friendly_name()):
                 node_vs_target_inputs[inp.get_node()].append(inp)
 
             for out_node, inputs in node_vs_target_inputs.items():
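
Note: `Output.get_target_inputs()` returns an unordered collection, so the `sorted(...)` call above makes NNCF graph edge insertion deterministic across runs. A minimal standalone sketch of the effect (plain Python; the `name` attribute is a stand-in for `get_friendly_name()`, and `Node` is a toy class, not an OpenVINO type):

from collections import defaultdict

class Node:
    def __init__(self, name: str):
        self.name = name

# A set models the unordered collection returned by get_target_inputs().
consumers = {Node("relu_2"), Node("add_0"), Node("concat_1")}

node_vs_target_inputs = defaultdict(list)
for node in sorted(consumers, key=lambda n: n.name):  # stable, name-based order
    node_vs_target_inputs[node].append(node)

# Python dicts preserve insertion order, so downstream iteration is reproducible.
print([n.name for n in node_vs_target_inputs])  # ['add_0', 'concat_1', 'relu_2']
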
diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py
index 79363722934..2d126e60474 100644
--- a/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -709,19 +709,22 @@ def apply(
         )
         return transformed_model
 
-    def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int]:
+    def _get_activation_node_port_and_channel(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int, int]:
         """
-        This method returns the activation layer and corresponding port id for the node.
+        This method returns the activation layer, corresponding port id and channel axis for the given node.
 
         :param node: NNCFGraph node for which the activation is sought.
         :param nncf_graph: NNCFGraph instance with the node.
-        :return: Tuple with the activation node and port id.
+        :return: Tuple with the activation node, port id and channel axis.
         """
         activation_port = self._backend_entity.get_activation_port_id(node, nncf_graph)
         activation_edge = nncf_graph.get_input_edge_by_port_id(node, activation_port)
         activation_node = activation_edge.from_node
         port_id = activation_edge.output_port_id
-        return activation_node, port_id
+        activation_channel_axis = self._backend_entity.get_activation_channel_axis(
+            node, port_id, activation_edge.tensor_shape
+        )
+        return activation_node, port_id, activation_channel_axis
 
     def get_matmul_input_to_output_nodes_map(
         self, matmul_nodes: list[NNCFNode], graph: NNCFGraph
@@ -742,8 +745,8 @@ def get_matmul_input_to_output_nodes_map(
         """
         matmul_input_to_output_nodes_map = defaultdict(list)
         for node in matmul_nodes:
-            act_node, output_port_id = self._get_activation_node_and_port(node, graph)
-            matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node)
+            act_node, output_port_id, act_channel_axis = self._get_activation_node_port_and_channel(node, graph)
+            matmul_input_to_output_nodes_map[(act_node, output_port_id, act_channel_axis)].append(node)
         return matmul_input_to_output_nodes_map
 
     def get_compression_nodes_info(
@@ -809,15 +812,17 @@ def get_statistic_points(
         statistic_container = StatisticPointsContainer()
         # Statistics for data aware algorithms
         if self._data_aware_compression:
-            for node, output_port_id in nodes_and_port_ids:
+            for node, output_port_id, input_channel_axis in nodes_and_port_ids:
                 statistic_point = self._backend_entity.target_point(
                     TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id
                 )
-                # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden
-                # size dimension.
+                # Reduce activations across all but the hidden dimension.
                 n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape)
+                # A negative axis (e.g. -1 for the last axis) is converted into the corresponding positive value.
+                input_channel_axis = input_channel_axis % n_dims
+                reduction_axes = tuple(i for i in range(n_dims) if i != input_channel_axis)
                 stat_collector = self._backend_entity.mean_statistic_collector(
-                    reduction_axes=tuple(range(n_dims - 1)), subset_size=self._subset_size
+                    reduction_axes=reduction_axes, subset_size=self._subset_size
                 )
                 statistic_container.add_statistic_point(
                     StatisticPoint(
@@ -854,7 +859,7 @@ def _get_statistics_for_weights_compression(
         # Where mean_value is a 1D tensor representing an activation reduced over batch and sequence length dimensions,
        # shape is an original shape of an activation before reduction, n is the size of the dataset (or subset_size).
         statistics = {}
-        for (act_node, output_port_id), matmul_nodes in matmul_input_to_output_nodes_map.items():
+        for (act_node, output_port_id, _), matmul_nodes in matmul_input_to_output_nodes_map.items():
             tensor_collectors = list(
                 statistic_points.get_algo_statistics_for_node(
                     act_node.node_name,
diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py
index 62d0745a0f4..bd24e872b26 100644
--- a/nncf/quantization/algorithms/weight_compression/backend.py
+++ b/nncf/quantization/algorithms/weight_compression/backend.py
@@ -257,6 +257,18 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) ->
         :return: Backend-specific callable to filter statistic containers according to its statistic point.
         """
 
+    @staticmethod
+    @abstractmethod
+    def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
+        """
+        Returns the axis of the activation tensor that corresponds to its channel.
+
+        :param node: NNCFNode instance.
+        :param port_id: Port ID for input.
+        :param input_shape: Shape of the input.
+        :return: Channel axis number.
+        """
+
 
 class AWQAlgoBackend(WeightCompressionAlgoBackend):
     @staticmethod
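
Note: the reduction axes for mean statistics are now derived from the channel axis instead of being hard-coded to "all but the last dimension". A small NumPy sketch of the same arithmetic (shapes are assumed for illustration only):

import numpy as np

def reduction_axes(n_dims: int, channel_axis: int) -> tuple[int, ...]:
    # Negative axes (e.g. -1) are normalized to their positive counterparts first.
    channel_axis = channel_axis % n_dims
    return tuple(i for i in range(n_dims) if i != channel_axis)

act = np.ones((1, 24, 16))           # [batch, seq_len, hidden]
axes = reduction_axes(act.ndim, -1)  # (0, 1): reduce batch and sequence dims
mean = act.mean(axis=axes)           # shape (16,), one value per channel
assert reduction_axes(3, 1) == (0, 2)  # transposed activations: hidden on axis 1
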
diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py
index 76674fd9288..b6f2e27c765 100644
--- a/nncf/quantization/algorithms/weight_compression/gptq.py
+++ b/nncf/quantization/algorithms/weight_compression/gptq.py
@@ -123,8 +123,13 @@ def apply(
             ]:
                 continue
             _, input_tensors = next(iter(inputs.items()))
-            hessian = self._calculate_hessian(node, input_tensors)
-            scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors)
+            input_channel_axis = self._backend_entity.get_activation_channel_axis(
+                node, self._backend_entity.get_activation_port_id(node, graph), input_tensors[0].shape
+            )
+            hessian = self._calculate_hessian(node, input_tensors, input_channel_axis)
+            scale, zero_point = self._quantize_weights(
+                model, graph, wc_params, hessian, input_tensors, input_channel_axis
+            )
             scales[wc_params.weight_name] = scale
             zero_points[wc_params.weight_name] = zero_point
 
@@ -157,7 +162,7 @@ def get_statistic_points(
 
         return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes)
 
-    def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
+    def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor], input_channel_axis: int) -> Tensor:
         """
         Calculates the Hessian matrix for the given node and inputs.
 
@@ -170,19 +175,18 @@ def _calculate_hessian(self, node: NNCFNode, inputs: list[Tensor]) -> Tensor:
         if node.metatype in self._backend_entity.convolution_metatypes:
             msg = "Convolution metatypes are not supported"
             raise nncf.UnsupportedModelError(msg)
-        if node.layer_attributes.input_attributes["transpose"]:
-            msg = "Transposed input is not supported"
-            raise nncf.UnsupportedModelError(msg)
 
         hessian = fns.zeros(
-            (inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
+            (inputs[0].shape[input_channel_axis], inputs[0].shape[input_channel_axis]),
+            backend=inputs[0].backend,
+            dtype=TensorDataType.float32,
         )
 
         for inp in inputs:
             batch_size = 1 if len(inp.shape) == 2 else inp.shape[0]
             if node.metatype in self._backend_entity.matmul_metatypes:
                 if len(inp.shape) == 3:
-                    inp = inp.reshape((-1, inp.shape[-1]))
+                    inp = inp.reshape((-1, inp.shape[input_channel_axis]))
                 inp = fns.transpose(inp)
             hessian *= nsamples / (nsamples + batch_size)
             nsamples += batch_size
@@ -198,6 +202,7 @@ def _quantize_weights(
         wc_params: WeightCompressionParameters,
         hessian: Tensor,
         inputs: list[Tensor],
+        input_channel_axis: int,
     ):
         """
         Quantizes the weights of the model based on the calculated Hessian matrix.
@@ -210,10 +215,7 @@ def _quantize_weights(
         """
         if wc_params.node_with_weight.metatype in self._backend_entity.convolution_metatypes:
             msg = "Convolution metatypes are not supported"
-            raise RuntimeError(msg)
-        if not wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]:
-            msg = "Transpose is not supported"
-            raise RuntimeError(msg)
+            raise nncf.UnsupportedModelError(msg)
 
         weight_tensor = self._backend_entity.get_weight(
             wc_params.node_with_weight, wc_params.weight_port_id, model, graph
@@ -267,8 +269,12 @@ def _quantize_weights(
                        scales.append(scale)
                    else:
                        if self._scale_estimation and block_compression_config.num_bits == 4:
-                            activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
-                            wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
+                            slicing_along_axis = [slice(None)] * len(inputs[0].shape)
+                            slicing_along_axis[input_channel_axis] = slice(i1 + i, i1 + i + group_size)
+                            activations = [inp[tuple(slicing_along_axis)] for inp in inputs]
+                            wc_statistics = ScaleEstimation.activations_to_wc_statistics(
+                                activations, input_channel_axis
+                            )
                             scale, zero_point = ScaleEstimation.calculate_quantization_params(
                                 wc_statistics,
                                 weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
diff --git a/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/nncf/quantization/algorithms/weight_compression/mixed_precision.py
index 6bb9391d0cb..fa185473965 100644
--- a/nncf/quantization/algorithms/weight_compression/mixed_precision.py
+++ b/nncf/quantization/algorithms/weight_compression/mixed_precision.py
@@ -277,7 +277,7 @@ def get_statistic_points(
         self._set_backend_entity(model)
 
         statistic_container = StatisticPointsContainer()
-        for act_node, output_port_id in nodes_and_port_ids:
+        for act_node, output_port_id, _ in nodes_and_port_ids:
             n_dims = len(graph.get_output_edges_by_port_id(act_node, output_port_id)[0].tensor_shape)
             if n_dims < 2:
                 msg = (
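
Note: the `slicing_along_axis` construction in `_quantize_weights` generalizes `inp[..., start:stop]` to an arbitrary channel axis. The same trick in isolation (NumPy; shapes are hypothetical):

import numpy as np

def take_group(t: np.ndarray, axis: int, start: int, size: int) -> np.ndarray:
    # Build a full-slice selector, then narrow only the requested axis.
    selector = [slice(None)] * t.ndim
    selector[axis] = slice(start, start + size)
    return t[tuple(selector)]

inp = np.arange(2 * 8 * 4).reshape(2, 8, 4)
assert take_group(inp, -1, 0, 2).shape == (2, 8, 2)  # last-axis group (old behavior)
assert take_group(inp, 1, 0, 2).shape == (2, 2, 4)   # axis-1 group (e.g. transposed input)
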
diff --git a/nncf/quantization/algorithms/weight_compression/onnx_backend.py b/nncf/quantization/algorithms/weight_compression/onnx_backend.py
index a962cf163bc..36a7f9e6170 100644
--- a/nncf/quantization/algorithms/weight_compression/onnx_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/onnx_backend.py
@@ -32,6 +32,7 @@
 from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES
 from nncf.onnx.graph.model_transformer import remove_initializer
 from nncf.onnx.graph.model_transformer import set_initializer
+from nncf.onnx.graph.node_utils import get_act_quantization_axis
 from nncf.onnx.graph.node_utils import get_weight_quantization_axis
 from nncf.onnx.graph.onnx_helper import ONNX_DTYPE_TO_NNCF_DTYPE
 from nncf.onnx.graph.onnx_helper import get_name_to_node_map
@@ -240,6 +241,10 @@ def filter_func(point: StatisticPoint) -> bool:
 
         return filter_func
 
+    @staticmethod
+    def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
+        return get_act_quantization_axis(node, port_id)
+
     def insert_adapters(
         self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
     ) -> None:
diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
index 7c1838eb8d2..f284073b4f9 100644
--- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -33,6 +33,7 @@
 from nncf.openvino.graph.model_transformer import OVModelTransformer
 from nncf.openvino.graph.node_utils import convert_op
 from nncf.openvino.graph.node_utils import create_ov_const_from_tensor
+from nncf.openvino.graph.node_utils import get_activation_channel_axis
 from nncf.openvino.graph.node_utils import get_const_value_as_numpy_tensor
 from nncf.openvino.graph.node_utils import get_const_value_as_ov_tensor
 from nncf.openvino.graph.node_utils import get_weight_channel_axes
@@ -114,9 +115,6 @@ def mean_statistic_collector(
 
     @staticmethod
     def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
-        if node.layer_attributes.input_attributes["transpose"]:
-            msg = "Transposed input is not supported"
-            raise nncf.UnsupportedModelError(msg)
         constant_ports = node.layer_attributes.get_const_port_ids()
         activation_ports = [
             e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports
@@ -133,6 +131,9 @@ def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> list[tupl
         return result
 
     def get_weight(self, node_with_weight: NNCFNode, weight_port_id: int, model: ov.Model, graph: NNCFGraph) -> Tensor:
+        if not node_with_weight.layer_attributes.constant_attributes[weight_port_id]["transpose"]:
+            msg = "Only transposed weights are supported"
+            raise nncf.UnsupportedModelError(msg)
         weight_name = node_with_weight.layer_attributes.constant_attributes[weight_port_id]["name"]
         weight_node = self.name_to_node_mapping[weight_name]
         weight_tensor = get_const_value_as_numpy_tensor(weight_node)
@@ -199,7 +200,12 @@ def insert_adapters(
         A_W = opset.constant(lora_A.data)
         B_W = opset.constant(lora_B.data)
 
-        A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True)
+        A_MM = opset.matmul(
+            input_node,
+            A_W,
+            transpose_a=wc_params.node_with_weight.layer_attributes.input_attributes["transpose"],
+            transpose_b=True,
+        )
         B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True)
 
         node_output_port = mm_node.output(0)
@@ -367,6 +373,10 @@ def filter_func(point: StatisticPoint) -> bool:
 
         return filter_func
 
+    @staticmethod
+    def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
+        return get_activation_channel_axis(node, port_id, input_shape)
+
 
class OVTensorWeightCompressionAlgoBackend(OVWeightCompressionAlgoBackend):
     """
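
Note: each backend now supplies its own channel-axis rule — ONNX delegates to `get_act_quantization_axis`, OpenVINO to `get_activation_channel_axis`, and the Torch backends use the metatype's `output_channel_axis`. For MatMul activations the rule essentially picks the hidden dimension; a simplified sketch of that idea (not the actual helper, which also handles 1-D inputs and other edge cases):

def matmul_activation_channel_axis(transpose_a: bool) -> int:
    # For C = matmul(A, B): the reduced ("hidden") dimension of A is its last
    # axis normally, and its second-to-last axis when A is transposed.
    return -2 if transpose_a else -1

assert matmul_activation_channel_axis(False) == -1
assert matmul_activation_channel_axis(True) == -2
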
diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py
index 4aea4633ebb..308828e1f2a 100644
--- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py
+++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py
@@ -371,7 +371,7 @@ def calculate_quantization_params(
         return result_scale, zp
 
     @staticmethod
-    def activations_to_wc_statistics(activations: list[Tensor]) -> WCTensorStatistic:
+    def activations_to_wc_statistics(activations: list[Tensor], input_channel_axis: int) -> WCTensorStatistic:
         """
         Mimic the activation reducing logic from WeightCompression.get_statistic_points.
 
@@ -382,7 +382,9 @@ def activations_to_wc_statistics(activations: list[Tensor]) -> WCTensorStatistic
         shapes = []
         for act in activations:
             shapes.append(act.shape)
-            reduction_shape = tuple(range(act.ndim - 1))
+            # A negative axis (e.g. -1 for the last axis) is converted into the corresponding positive value.
+            input_channel_axis = input_channel_axis % len(act.shape)
+            reduction_shape = tuple(i for i in range(len(act.shape)) if i != input_channel_axis)
             mean_values.append(fns.mean(act, axis=reduction_shape))
         wc_statistics = WCTensorStatistic(mean_values, shapes)
         return wc_statistics
diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py
index 50f765c35c3..c3435d27a3a 100644
--- a/nncf/quantization/algorithms/weight_compression/torch_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py
@@ -497,6 +497,10 @@ def transform_model(
 
         return transformed_model
 
+    @staticmethod
+    def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
+        return node.metatype.output_channel_axis
+
 
class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend):
     @staticmethod
diff --git a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py
index 2650f16600c..590c60ab970 100644
--- a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py
@@ -281,6 +281,10 @@ def transform_model(
 
         return transformed_model
 
+    @staticmethod
+    def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: tuple[int]) -> int:
+        return node.metatype.output_channel_axis
+
 
class FXMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, FXWeightCompressionAlgoBackend):
     @staticmethod
diff --git a/tests/openvino/native/quantization/test_gptq.py b/tests/openvino/native/quantization/test_gptq.py
index a141d7c99fe..8d12703a720 100644
--- a/tests/openvino/native/quantization/test_gptq.py
+++ b/tests/openvino/native/quantization/test_gptq.py
@@ -346,7 +346,10 @@ def test_calculate_scale_linear():
     nodes = graph.get_all_nodes()
 
     wrapped_inputs = [Tensor(inp) for inp in inputs]
-    H = gptq._calculate_hessian(nodes[1], wrapped_inputs)
+    input_channel_axis = gptq._backend_entity.get_activation_channel_axis(
+        nodes[1], gptq._backend_entity.get_activation_port_id(nodes[1], graph), wrapped_inputs[0].shape
+    )
+    H = gptq._calculate_hessian(nodes[1], wrapped_inputs, input_channel_axis)
 
     ref_H = ref_gptq.H.numpy()
     assert np.all(np.isclose(ref_H, H.data))
@@ -356,7 +359,7 @@ def test_calculate_scale_linear():
     )
     wc_params.compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT4_SYM, group_size=16)
 
-    scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs)
+    scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs, input_channel_axis)
     ref_scale = ref_scale.numpy()
     scale = scale.reshape(ref_scale.shape)
     assert np.all(np.isclose(ref_scale, scale.data))
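
Note: `_calculate_hessian` keeps a running average of 2·X·Xᵀ over the calibration batches, with the matrix dimension taken from the channel axis. A rough NumPy sketch of that accumulation (simplified, not the exact NNCF implementation; `np.moveaxis` stands in for the backend-specific reshape/transpose):

import numpy as np

def accumulate_hessian(inputs: list[np.ndarray], channel_axis: int) -> np.ndarray:
    dim = inputs[0].shape[channel_axis]
    hessian = np.zeros((dim, dim), dtype=np.float32)
    nsamples = 0
    for inp in inputs:
        batch_size = 1 if inp.ndim == 2 else inp.shape[0]
        # Flatten every axis except the channel axis into "tokens".
        x = np.moveaxis(inp, channel_axis, -1).reshape(-1, dim).T
        hessian *= nsamples / (nsamples + batch_size)  # rescale the running average
        nsamples += batch_size
        hessian += 2.0 / nsamples * (x @ x.T)
    return hessian

H = accumulate_hessian([np.ones((1, 24, 16), dtype=np.float32)], channel_axis=-1)
assert H.shape == (16, 16)
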
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index 6122f273114..9ef14cab931 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -11,7 +11,8 @@
 
 import inspect
 import os
-from typing import Callable
+from contextlib import nullcontext
+from typing import Callable, Optional
 from unittest.mock import patch
 
 import numpy as np
@@ -95,7 +96,9 @@ class LMLinearModel(OVReferenceModel):
     HIDDEN_DIM = 16
     INPUT_SHAPE = [1, 24, HIDDEN_DIM]  # [B, SeqLen, HiddenDim]
 
-    def _create_ov_model(self, transpose_b: bool = True, transpose_a=False, input_shape=None):
+    def _create_ov_model(
+        self, transpose_b: bool = True, transpose_a: bool = False, input_shape: Optional[list[int]] = None
+    ):
         self._input_shape = self.INPUT_SHAPE if input_shape is None else input_shape
         hdim_axis = -2 if transpose_a else -1
         self._hidden_dim = self._input_shape[hdim_axis]
@@ -1454,6 +1457,16 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs):
     )
 
 
+@pytest.mark.parametrize(
+    ("transpose_a", "transpose_b", "raises_error"),
+    [
+        (False, True, False),
+        (True, True, False),
+        (False, False, True),
+        (True, False, True),
+    ],
+    ids=["tb_nota", "ta_tb", "nota_notb", "ta_notb"],
+)
 @pytest.mark.parametrize(
     "kwargs",
     [
@@ -1466,14 +1479,19 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs):
             advanced_parameters=CompressionParams(gptq_params=GPTQParams(subset_size=2)),
         ),
     ],
+    ids=["se", "lora", "gptq_se_awq"],
 )
-def test_compression_with_transposed_activations(kwargs):
+def test_compression_with_transpose(transpose_a, transpose_b, raises_error, kwargs):
     dataset_size = 4
-    model = LMLinearModel(transpose_a=True, transpose_b=False).ov_model
+    model = LMLinearModel(transpose_a=transpose_a, transpose_b=transpose_b).ov_model
     input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size
     dataset = Dataset(input_data)
 
-    with pytest.raises(nncf.UnsupportedModelError):
+    with (
+        pytest.raises(nncf.UnsupportedModelError)
+        if raises_error and not kwargs.get("lora_correction", False)
+        else nullcontext()
+    ):
         compress_weights(
             model,
             mode=CompressWeightsMode.INT4_SYM,
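
Note: the conditional context manager in the last hunk is a common pytest idiom for covering both the failing and the passing parameter combinations with a single test body. A generic, self-contained sketch of the pattern (toy `compress` function, not NNCF API):

from contextlib import nullcontext

import pytest

def compress(transposed_weights: bool) -> None:
    if not transposed_weights:
        raise ValueError("Only transposed weights are supported")

@pytest.mark.parametrize("transposed_weights, raises_error", [(True, False), (False, True)])
def test_compress(transposed_weights, raises_error):
    with pytest.raises(ValueError) if raises_error else nullcontext():
        compress(transposed_weights)
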