Skip to content

Commit 727b730

Browse files
Marco Giordano authored and facebook-github-bot committed
Fix Conv1d w8a32 operator
Summary: This diff fixes the Conv1d w8a32 operator by adding a transformation to the `val` attribute of the `other_inputs[0].meta` dictionary. Specifically, the `permute` operation is applied to the `original_val` tensor with the `fake_mode` context, and the resulting `transposed_val` is assigned to `transposed_inputs.meta["val"]`. Reviewed By: mcremon-meta Differential Revision: D89863750
1 parent 9cbe754 commit 727b730

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

backends/cadence/aot/compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
print_memory_planning_info,
2323
)
2424
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
25+
from executorch.exir.passes.spec_prop_pass import SpecPropPass
2526
from executorch.backends.cadence.aot.quantizer.quantizer import (
2627
CadenceDefaultQuantizer,
2728
CadenceQuantizer,
@@ -157,7 +158,9 @@ def apply_pre_edge_transform_passes(
157158
# Get patterns and apply fusion of dq -> op -> q to qop
158159
# pyre-ignore[16]: no attribute
159160
patterns = [q.pattern for q in quantizer.quantizers]
160-
fused_program = _transform(converted_program, QuantFusion(patterns))
161+
fused_program = _transform(
162+
converted_program, QuantFusion(patterns), SpecPropPass()
163+
)
161164

162165
return fused_program
163166

backends/cadence/aot/quantizer/patterns.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ def get_anchors(
625625
)
626626

627627
cnn_weights = conv_layer.args[1]
628-
if hasattr(cnn_weights.meta, "tensor_meta"):
628+
if "tensor_meta" in cnn_weights.meta:
629629
cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape
630630
# Bail if the channels are not multiple of 4 (SIMD)
631631
if cnn_weights_shape[0] % 4 != 0:
@@ -651,6 +651,18 @@ def get_anchors(
651651
conv_layer,
652652
)
653653

654+
inputs = conv_layer.args[0]
655+
if "tensor_meta" in inputs.meta:
656+
inputs_shape = inputs.meta["tensor_meta"].shape
657+
# Bail if length != kernel size - Not yet supported
658+
if inputs_shape[-1] != cnn_weights_shape[2]:
659+
return (
660+
PartitionAnchors(
661+
empty=True,
662+
),
663+
conv_layer,
664+
)
665+
654666
return (
655667
PartitionAnchors(
656668
inputs=[],

0 commit comments

Comments (0)