
Commit e72db64

Update scaled_dot_product_attention to work with >6 inputs in latest torch version (#2021)
1 parent 0c17d42 commit e72db64
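
For context, recent torch versions added an optional scale keyword to torch.nn.functional.scaled_dot_product_attention, so the traced op can now carry a seventh input. Below is a minimal sketch, not from this commit, of a module that produces the >6-input node the commit handles; the module name and tensor shapes are illustrative.

import torch
import torch.nn.functional as F

class SDPABlock(torch.nn.Module):
    def forward(self, q, k, v):
        # In recent torch versions, the traced node for this call carries
        # seven inputs: (q, k, v, attn_mask, dropout_p, is_causal, scale).
        return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.rand(1, 8, 16, 32)  # (B, h, seq, d), illustrative
traced = torch.jit.trace(SDPABlock(), (q, k, v))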

File tree

1 file changed (+10, -1)

  • coremltools/converters/mil/frontend/torch/ops.py


coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 10 additions & 1 deletion
@@ -6437,6 +6437,7 @@ def scaled_dot_product_attention(context, node):
     - attn_mask : (target_seq, source_seq) or (B, target_seq, source_seq) or (B, h, target_seq, source_seq) or
                   (B, ..., target_seq, source_seq)
     - is_causal : bool
+    - scale : optional float

     Output shape: (target_seq, d_v) or (B,...,target_seq, d_v)

@@ -6448,14 +6449,22 @@ def scaled_dot_product_attention(context, node):
     https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     """
     inputs = _get_inputs(context, node, min_expected=3)
-    q, k, v = inputs[: 3]
+    q, k, v = inputs[:3]
     attn_mask = None if len(inputs) < 4 else inputs[3]
     dropout = 0.0 if len(inputs) < 5 else inputs[4]
     is_causal = False if len(inputs) < 6 else inputs[5].val
+
+    # When len(inputs) == 7, the inputs are (q, k, v, attn_mask, dropout, is_causal, scale)
+    if len(inputs) == 7 and inputs[6] is not None:
+        raise NotImplementedError(
+            "scaled_dot_product_attention op: scale parameter is not handled."
+        )
+
     if attn_mask is not None and is_causal:
         raise ValueError(
             "scaled_dot_product_attention op: attn_mask cannot be provided when is_causal is set to True."
         )
+
     if dropout is not None and (dropout.val is None or dropout.val != 0.0):
         raise ValueError("scaled_dot_product_attention op: dropout is not supported yet")
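
A hedged conversion sketch, assuming the traced module from the example above and the same illustrative shapes: with scale left at its default (None), the seventh input passes the new check and conversion proceeds as before, while tracing a call that sets scale explicitly would raise the NotImplementedError added in this commit.

import coremltools as ct

# Converts cleanly: the seventh input (scale) is None, so the converter
# unpacks q, k, v, attn_mask, dropout, and is_causal as it did previously.
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name=n, shape=(1, 8, 16, 32)) for n in ("q", "k", "v")],
)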

0 commit comments