@@ -6437,6 +6437,7 @@ def scaled_dot_product_attention(context, node):
     - attn_mask : (target_seq, source_seq) or (B, target_seq, source_seq) or (B, h, target_seq, source_seq) or
                   (B, ..., target_seq, source_seq)
     - is_causal : bool
+    - scale : optional float
 
     Output shape: (target_seq, d_v) or (B,...,target_seq, d_v)
 
@@ -6448,14 +6449,22 @@ def scaled_dot_product_attention(context, node):
     https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     """
     inputs = _get_inputs(context, node, min_expected=3)
-    q, k, v = inputs[: 3]
+    q, k, v = inputs[:3]
     attn_mask = None if len(inputs) < 4 else inputs[3]
     dropout = 0.0 if len(inputs) < 5 else inputs[4]
     is_causal = False if len(inputs) < 6 else inputs[5].val
+
+    # When len(inputs) == 7, the inputs are (q, k, v, attn_mask, dropout, is_causal, scale)
+    if len(inputs) == 7 and inputs[6] is not None:
+        raise NotImplementedError(
+            "scaled_dot_product_attention op: scale parameter is not handled."
+        )
+
     if attn_mask is not None and is_causal:
         raise ValueError(
             "scaled_dot_product_attention op: attn_mask cannot be provided when is_causal is set to True."
         )
+
     if dropout is not None and (dropout.val is None or dropout.val != 0.0):
         raise ValueError("scaled_dot_product_attention op: dropout is not supported yet")
 
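For context, a minimal reference sketch of what the new seventh input controls (this is not part of the commit and not the converter's code; the name sdpa_reference and its signature are illustrative): in torch.nn.functional.scaled_dot_product_attention the scale factor defaults to 1/sqrt(d_k), and a user-supplied value replaces that default, which is why the converter raises rather than silently ignoring it.

# Hedged sketch of scaled dot-product attention semantics, assuming the
# PyTorch default scale of 1/sqrt(d_k); not coremltools code.
import math
import torch

def sdpa_reference(q, k, v, scale=None):
    # A non-default scale changes the softmax logits, so it cannot be
    # dropped when converting the op.
    scale = 1.0 / math.sqrt(q.size(-1)) if scale is None else scale
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v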