Commit 6780425

junjiang-lab authored and copybara-github committed

Add None check before dataclasses.replace.
PiperOrigin-RevId: 797357585
1 parent 451ac5c commit 6780425
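
For context on the fix: dataclasses.replace expects a dataclass instance as its first argument, so calling it while the consumer's parameters are None raises a TypeError. Below is a minimal standalone sketch of the guarded pattern, using a hypothetical QuantParams dataclass as a stand-in for the quantizer's qtyping.UniformQuantParams; it is an illustration of the idea, not the project's code.

import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class QuantParams:
  # Hypothetical stand-in for qtyping.UniformQuantParams.
  scale: float
  quantized_data: Optional[Any] = None


def strip_quantized_data(params: Optional[QuantParams]) -> Optional[QuantParams]:
  # dataclasses.replace(None, ...) raises TypeError, so check for None first
  # and let a missing parameter set simply propagate as None.
  if params is None:
    return None
  return dataclasses.replace(params, quantized_data=None)

With the guard in place, the isinstance check in the diff below only runs for non-None parameters.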

1 file changed (+20, -15 lines)

ai_edge_quantizer/algorithms/utils/common_utils.py (20 additions, 15 deletions)
@@ -295,24 +295,28 @@ def _materialize_standard_op_with_same_as_input_scale(
   op_tensor_params.append(input_tensor_params)
   # Use input quantization params for all output tensors but without
   # quantized_data in case the input is a constant tensor.
-  input_quant_params = dataclasses.replace(
-      input_tensor_params.consumers[0].parameters,
-      quantized_data=None,
-  )
-  if not isinstance(input_quant_params, qtyping.UniformQuantParams):
-    raise ValueError(
-        "_materialize_standard_op_with_same_as_input_scale only supports"
-        f" UniformQuantParams. For tensor {input_tensor_params.tensor_name},"
-        f" got {type(input_quant_params)}"
+  input_quant_params = input_tensor_params.consumers[0].parameters
+  if input_quant_params is not None:
+    input_quant_params = dataclasses.replace(
+        input_quant_params,
+        quantized_data=None,
     )
+    if not isinstance(input_quant_params, qtyping.UniformQuantParams):
+      raise ValueError(
+          "_materialize_standard_op_with_same_as_input_scale only supports"
+          f" UniformQuantParams. For tensor {input_tensor_params.tensor_name},"
+          f" got {type(input_quant_params)}"
+      )
   # Materialize each of the output tensors separately in case there are
   # constants among them, requiring updating `quantized_data` first.
   for output_tensor in output_tensors:
     output_tensor_data = tfl_flatbuffer_utils.get_tensor_data(
         output_tensor, graph_info.buffers
     )
     # Quantize constant inputs' data with the output quantization params.
-    if output_tensor_data is None:
+    if input_quant_params is None:
+      quant_params = None
+    elif output_tensor_data is None:
       quant_params = input_quant_params
     else:
       quantized_data = uniform_quantize_tensor.uniform_quantize(
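
The output loop above now distinguishes three cases. A simplified sketch of that branch follows; pick_output_params and quantize_fn are illustrative names only, with quantize_fn standing in for uniform_quantize_tensor.uniform_quantize.

import dataclasses


def pick_output_params(input_quant_params, output_tensor_data, quantize_fn):
  # No input quantization params: leave this output unquantized.
  if input_quant_params is None:
    return None
  # Non-constant output tensor: reuse the input params as-is.
  if output_tensor_data is None:
    return input_quant_params
  # Constant output tensor: attach data quantized with the input params.
  quantized_data = quantize_fn(output_tensor_data, input_quant_params)
  return dataclasses.replace(input_quant_params, quantized_data=quantized_data)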
@@ -335,11 +339,12 @@ def _materialize_standard_op_with_same_as_input_scale(
 
   # Change output qsv to be the same as input qsv. This is safe since TFL
   # subgraph is acyclic.
-  input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
-  for output_tensor in output_tensors:
-    tensor_name_to_qsv[tfl_flatbuffer_utils.get_tensor_name(output_tensor)] = (
-        input_tensor_qsv
-    )
+  if input_tensor_params.tensor_name in tensor_name_to_qsv:
+    input_tensor_qsv = tensor_name_to_qsv[input_tensor_params.tensor_name]
+    for output_tensor in output_tensors:
+      tensor_name_to_qsv[
+          tfl_flatbuffer_utils.get_tensor_name(output_tensor)
+      ] = input_tensor_qsv
 
   return op_tensor_params
 
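The second hunk applies the same defensive style to qsv propagation: output tensors only inherit the input tensor's qsv when an entry for it actually exists, so a missing entry no longer raises a KeyError. A minimal sketch of the dict-membership guard, with illustrative function and argument names rather than the quantizer's API:

def propagate_input_qsv(tensor_name_to_qsv, input_tensor_name, output_tensor_names):
  # Skip propagation entirely if the input tensor has no recorded qsv;
  # an unconditional lookup here would raise KeyError.
  if input_tensor_name in tensor_name_to_qsv:
    input_qsv = tensor_name_to_qsv[input_tensor_name]
    for name in output_tensor_names:
      tensor_name_to_qsv[name] = input_qsv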