From 6780425ecfe85c76ac0bcd0dd44de82f9ed8e859 Mon Sep 17 00:00:00 2001
From: Jun Jiang
Date: Wed, 20 Aug 2025 09:59:11 -0700
Subject: [PATCH] Add None check before dataclasses.replace.

PiperOrigin-RevId: 797357585
---
 .../algorithms/utils/common_utils.py | 35 +++++++++++--------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/ai_edge_quantizer/algorithms/utils/common_utils.py b/ai_edge_quantizer/algorithms/utils/common_utils.py
index 184dbf3d..730fede5 100644
--- a/ai_edge_quantizer/algorithms/utils/common_utils.py
+++ b/ai_edge_quantizer/algorithms/utils/common_utils.py
@@ -295,16 +295,18 @@ def _materialize_standard_op_with_same_as_input_scale(
   op_tensor_params.append(input_tensor_params)
   # Use input quantization params for all output tensors but without
   # quantized_data in case the input is a constant tensor.
-  input_quant_params = dataclasses.replace(
-      input_tensor_params.consumers[0].parameters,
-      quantized_data=None,
-  )
-  if not isinstance(input_quant_params, qtyping.UniformQuantParams):
-    raise ValueError(
-        "_materialize_standard_op_with_same_as_input_scale only supports"
-        f" UniformQuantParams. For tensor {input_tensor_params.tensor_name},"
-        f" got {type(input_quant_params)}"
+  input_quant_params = input_tensor_params.consumers[0].parameters
+  if input_quant_params is not None:
+    input_quant_params = dataclasses.replace(
+        input_quant_params,
+        quantized_data=None,
     )
+    if not isinstance(input_quant_params, qtyping.UniformQuantParams):
+      raise ValueError(
+          "_materialize_standard_op_with_same_as_input_scale only supports"
+          f" UniformQuantParams. For tensor {input_tensor_params.tensor_name},"
+          f" got {type(input_quant_params)}"
+      )
   # Materialize each of the output tensors separately in case there are
   # constants among them, requiring updating `quantized_data` first.
   for output_tensor in output_tensors:
@@ -312,7 +314,9 @@ def _materialize_standard_op_with_same_as_input_scale(
         output_tensor, graph_info.buffers
     )
     # Quantize constant inputs' data with the output quantization params.
-    if output_tensor_data is None:
+    if input_quant_params is None:
+      quant_params = None
+    elif output_tensor_data is None:
       quant_params = input_quant_params
     else:
       quantized_data = uniform_quantize_tensor.uniform_quantize(
@@ -335,11 +339,12 @@ def _materialize_standard_op_with_same_as_input_scale(
 
   # Change output qsv to be the same as input qsv. This is safe since TFL
   # subgraph is acyclic.
-  input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
-  for output_tensor in output_tensors:
-    tensor_name_to_qsv[tfl_flatbuffer_utils.get_tensor_name(output_tensor)] = (
-        input_tensor_qsv
-    )
+  if input_tensor_params.tensor_name in tensor_name_to_qsv:
+    input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
+    for output_tensor in output_tensors:
+      tensor_name_to_qsv[
+          tfl_flatbuffer_utils.get_tensor_name(output_tensor)
+      ] = input_tensor_qsv
 
   return op_tensor_params
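
Editor's note: all three hunks guard against a `parameters` value that may be None
before calling `dataclasses.replace` on it or propagating its qsv. Below is a
minimal, self-contained sketch of the failure mode and the guard pattern;
`QuantParams` and `strip_quantized_data` are hypothetical stand-ins for
illustration, not the real `qtyping.UniformQuantParams` API.

import dataclasses
from typing import Optional

@dataclasses.dataclass
class QuantParams:
  # Hypothetical stand-in for qtyping.UniformQuantParams.
  scale: float
  quantized_data: Optional[bytes] = None

def strip_quantized_data(
    params: Optional[QuantParams],
) -> Optional[QuantParams]:
  """Returns a copy of `params` without quantized_data, or None."""
  # dataclasses.replace() raises TypeError when handed None, because None is
  # not a dataclass instance; that is the crash this patch prevents.
  if params is None:
    return None
  return dataclasses.replace(params, quantized_data=None)

assert strip_quantized_data(None) is None
assert strip_quantized_data(QuantParams(0.5, b"\x01")).quantized_data is None

Once the copy may be None, every downstream use has to tolerate it too, which
is why hunk 2 assigns `quant_params = None` before the constant-tensor branch
and hunk 3 only propagates qsv when `input_tensor_params.tensor_name` is
actually present in `tensor_name_to_qsv`.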