35 changes: 20 additions & 15 deletions ai_edge_quantizer/algorithms/utils/common_utils.py
@@ -295,24 +295,28 @@ def _materialize_standard_op_with_same_as_input_scale(
   op_tensor_params.append(input_tensor_params)
   # Use input quantization params for all output tensors but without
   # quantized_data in case the input is a constant tensor.
-  input_quant_params = dataclasses.replace(
-      input_tensor_params.consumers[0].parameters,
-      quantized_data=None,
-  )
-  if not isinstance(input_quant_params, qtyping.UniformQuantParams):
-    raise ValueError(
-        "_materialize_standard_op_with_same_as_input_scale only supports"
-        f" UniformQuantParams. For tensor {input_tensor_params.tensor_name},"
-        f" got {type(input_quant_params)}"
-    )
+  input_quant_params = input_tensor_params.consumers[0].parameters
+  if input_quant_params is not None:
+    input_quant_params = dataclasses.replace(
+        input_quant_params,
+        quantized_data=None,
+    )
+    if not isinstance(input_quant_params, qtyping.UniformQuantParams):
+      raise ValueError(
+          "_materialize_standard_op_with_same_as_input_scale only supports"
+          f" UniformQuantParams. For tensor {input_tensor_params.tensor_name},"
+          f" got {type(input_quant_params)}"
+      )
   # Materialize each of the output tensors separately in case there are
   # constants among them, requiring updating `quantized_data` first.
   for output_tensor in output_tensors:
     output_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
         output_tensor, graph_info.buffers
     )
     # Quantize constant inputs' data with the output quantization params.
-    if output_tensor_data is None:
+    if input_quant_params is None:
+      quant_params = None
+    elif output_tensor_data is None:
       quant_params = input_quant_params
     else:
       quantized_data = uniform_quantize_tensor.uniform_quantize(
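In words: the input consumer's parameters may now be None (the input tensor is left unquantized), so quantized_data is stripped via dataclasses.replace only when parameters actually exist, and outputs get quant_params = None in the None case instead of the code raising. Below is a minimal, self-contained sketch of that guard; FakeUniformQuantParams is a hypothetical stand-in for qtyping.UniformQuantParams, and its field names are assumptions made only for illustration.

import dataclasses
from typing import Any, Optional


# Hypothetical stand-in for qtyping.UniformQuantParams; the real class has
# more fields. These names are illustrative assumptions, not the library API.
@dataclasses.dataclass(frozen=True)
class FakeUniformQuantParams:
  num_bits: int
  scale: float
  zero_point: int
  quantized_data: Optional[Any] = None


def strip_quantized_data(
    params: Optional[FakeUniformQuantParams],
) -> Optional[FakeUniformQuantParams]:
  # Mirrors the new guard: only call dataclasses.replace when the input
  # consumer actually carries quantization parameters.
  if params is None:
    # Unquantized (e.g. float) input: propagate None so the outputs are
    # also left unquantized instead of raising.
    return None
  return dataclasses.replace(params, quantized_data=None)


# A quantized input keeps scale/zero point but drops the packed data.
print(strip_quantized_data(
    FakeUniformQuantParams(8, 0.05, 0, quantized_data=b"\x7f")))
# An unquantized input now simply stays None.
print(strip_quantized_data(None))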
@@ -335,11 +339,12 @@ def _materialize_standard_op_with_same_as_input_scale(
 
   # Change output qsv to be the same as input qsv. This is safe since TFL
   # subgraph is acyclic.
-  input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
-  for output_tensor in output_tensors:
-    tensor_name_to_qsv[tfl_flatbuffer_utils.get_tensor_name(output_tensor)] = (
-        input_tensor_qsv
-    )
+  if input_tensor_params.tensor_name in tensor_name_to_qsv:
+    input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
+    for output_tensor in output_tensors:
+      tensor_name_to_qsv[
+          tfl_flatbuffer_utils.get_tensor_name(output_tensor)
+      ] = input_tensor_qsv
 
   return op_tensor_params
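The second hunk applies the same idea to the quantization statistics (qsv): outputs inherit the input's qsv only when the input actually has an entry, so an input without recorded statistics no longer triggers a KeyError on the dict lookup. A small sketch of the guarded propagation follows; the function and variable names are my own, not the library's.

from typing import Any, Dict, List


def propagate_input_qsv(
    tensor_name_to_qsv: Dict[str, Any],
    input_name: str,
    output_names: List[str],
) -> Dict[str, Any]:
  # Outputs copy the input's qsv only when the input has one; otherwise the
  # mapping is left untouched (an unconditional lookup could raise KeyError).
  if input_name in tensor_name_to_qsv:
    input_qsv = tensor_name_to_qsv[input_name]
    for name in output_names:
      tensor_name_to_qsv[name] = input_qsv
  return tensor_name_to_qsv


# Output inherits the input's qsv when it exists...
print(propagate_input_qsv({"conv_out": {"min": -1.0, "max": 1.0}},
                          "conv_out", ["reshape_out"]))
# ...and the mapping is returned unchanged when it does not.
print(propagate_input_qsv({}, "float_in", ["reshape_out"]))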
