Skip to content

Commit d870e83

Browse files
Marco Giordano and meta-codesync[bot]
authored and committed
Fix Conv1d w8a32 operator (pytorch#16607)
Summary: Pull Request resolved: pytorch#16607 #### Summary This diff fixes the Conv1d w8a32 operator by adding a transformation to the `val` attribute of the `other_inputs[0].meta` dictionary. Specifically, the `permute` operation is applied to the `original_val` tensor with the `fake_mode` context, and the resulting `transposed_val` is assigned to `transposed_inputs.meta["val"]`. Reviewed By: mcremon-meta Differential Revision: D89863750
1 parent afc9989 commit d870e83

2 files changed

Lines changed: 40 additions & 14 deletions

File tree

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -438,26 +438,40 @@ def get_args_and_kwargs_mixed_w8a32_conv(
438438
torch.ops.aten.permute.default,
439439
(other_inputs[0], [0, 2, 1]), # NCL -> NLC
440440
)
441-
assert "val" in other_inputs[0].meta, "Missing val metadata on input node"
442-
original_val = other_inputs[0].meta["val"]
443-
assert original_val.fake_mode is not None, "fake_mode is None on input node"
444-
with original_val.fake_mode:
445-
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
446-
original_val, [0, 2, 1]
447-
)
441+
# Propagate val metadata for transposed_inputs
442+
if "val" in other_inputs[0].meta:
443+
original_val = other_inputs[0].meta["val"]
444+
fake_mode = original_val.fake_mode
445+
if fake_mode is not None:
446+
with fake_mode:
447+
transposed_val = torch.ops.aten.permute.default(
448+
original_val, [0, 2, 1]
449+
)
450+
transposed_inputs.meta["val"] = transposed_val
451+
else:
452+
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
453+
original_val, [0, 2, 1]
454+
)
448455
copy_node_metadata(transposed_inputs, other_inputs[0])
449456

450457
transposed_weights = graph_module.graph.call_function(
451458
torch.ops.aten.permute.default,
452459
(weights_inputs[0], [2, 0, 1]), # NCL -> LNC
453460
)
454-
assert "val" in weights_inputs[0].meta, "Missing val metadata on weight node"
455-
original_val = weights_inputs[0].meta["val"]
456-
assert original_val.fake_mode is not None, "fake_mode is None on weight node"
457-
with original_val.fake_mode:
458-
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
459-
original_val, [2, 0, 1]
460-
)
461+
# Propagate val metadata for transposed_weights
462+
if "val" in weights_inputs[0].meta:
463+
original_val = weights_inputs[0].meta["val"]
464+
fake_mode = original_val.fake_mode
465+
if fake_mode is not None:
466+
with fake_mode:
467+
transposed_val = torch.ops.aten.permute.default(
468+
original_val, [2, 0, 1]
469+
)
470+
transposed_weights.meta["val"] = transposed_val
471+
else:
472+
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
473+
original_val, [2, 0, 1]
474+
)
461475
copy_node_metadata(transposed_weights, weights_inputs[0])
462476

463477
args = (

backends/cadence/aot/quantizer/patterns.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,18 @@ def get_anchors(
744744
conv_layer,
745745
)
746746

747+
inputs = conv_layer.args[0]
748+
if "tensor_meta" in inputs.meta:
749+
inputs_shape = inputs.meta["tensor_meta"].shape
750+
# Bail if length != kernel size - Not yet supported
751+
if inputs_shape[-1] != cnn_weights_shape[2]:
752+
return (
753+
PartitionAnchors(
754+
empty=True,
755+
),
756+
conv_layer,
757+
)
758+
747759
return (
748760
PartitionAnchors(
749761
inputs=[],

0 commit comments

Comments
 (0)