Skip to content

Commit 5132ba0

Browse files
Marco Giordano and facebook-github-bot
authored and committed
Fix Conv1d w8a32 operator (pytorch#16607)
Summary: This diff fixes the Conv1d w8a32 operator by adding a transformation to the `val` attribute of the `other_inputs[0].meta` dictionary. Specifically, the `permute` operation is applied to the `original_val` tensor within the `fake_mode` context, and the resulting `transposed_val` is assigned to `transposed_inputs.meta["val"]`. Reviewed By: mcremon-meta Differential Revision: D89863750
1 parent e638059 commit 5132ba0

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -438,26 +438,36 @@ def get_args_and_kwargs_mixed_w8a32_conv(
438438
torch.ops.aten.permute.default,
439439
(other_inputs[0], [0, 2, 1]), # NCL -> NLC
440440
)
441-
assert "val" in other_inputs[0].meta, "Missing val metadata on input node"
442-
original_val = other_inputs[0].meta["val"]
443-
assert original_val.fake_mode is not None, "fake_mode is None on input node"
444-
with original_val.fake_mode:
445-
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
446-
original_val, [0, 2, 1]
447-
)
441+
# Propagate val metadata for transposed_inputs
442+
if "val" in other_inputs[0].meta:
443+
original_val = other_inputs[0].meta["val"]
444+
fake_mode = original_val.fake_mode
445+
if fake_mode is not None:
446+
with fake_mode:
447+
transposed_val = torch.ops.aten.permute.default(original_val, [0, 2, 1])
448+
transposed_inputs.meta["val"] = transposed_val
449+
else:
450+
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
451+
original_val, [0, 2, 1]
452+
)
448453
copy_node_metadata(transposed_inputs, other_inputs[0])
449454

450455
transposed_weights = graph_module.graph.call_function(
451456
torch.ops.aten.permute.default,
452457
(weights_inputs[0], [2, 0, 1]), # NCL -> LNC
453458
)
454-
assert "val" in weights_inputs[0].meta, "Missing val metadata on weight node"
455-
original_val = weights_inputs[0].meta["val"]
456-
assert original_val.fake_mode is not None, "fake_mode is None on weight node"
457-
with original_val.fake_mode:
458-
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
459-
original_val, [2, 0, 1]
460-
)
459+
# Propagate val metadata for transposed_weights
460+
if "val" in weights_inputs[0].meta:
461+
original_val = weights_inputs[0].meta["val"]
462+
fake_mode = original_val.fake_mode
463+
if fake_mode is not None:
464+
with fake_mode:
465+
transposed_val = torch.ops.aten.permute.default(original_val, [2, 0, 1])
466+
transposed_weights.meta["val"] = transposed_val
467+
else:
468+
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
469+
original_val, [2, 0, 1]
470+
)
461471
copy_node_metadata(transposed_weights, weights_inputs[0])
462472

463473
args = (

backends/cadence/aot/quantizer/patterns.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,18 @@ def get_anchors(
744744
conv_layer,
745745
)
746746

747+
inputs = conv_layer.args[0]
748+
if "tensor_meta" in inputs.meta:
749+
inputs_shape = inputs.meta["tensor_meta"].shape
750+
# Bail if length != kernel size - Not yet supported
751+
if inputs_shape[-1] != cnn_weights_shape[2]:
752+
return (
753+
PartitionAnchors(
754+
empty=True,
755+
),
756+
conv_layer,
757+
)
758+
747759
return (
748760
PartitionAnchors(
749761
inputs=[],

0 commit comments

Comments
 (0)