diff --git a/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp b/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp index 9e0249d0d1f7..fbec42961c95 100644 --- a/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp +++ b/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp @@ -516,6 +516,17 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/cstNone); } + seqLensKType = cast(seqlensK.getType()); + // ORT accepts seqlens_k as [B, 1] or [1, B]; normalize it to [B]. + if (seqLensKType.hasSizes() && seqLensKType.getSizes().size() == 2) { + int64_t squeezeDim = + seqLensKType.getSizes()[0] == 1 && seqLensKType.getSizes()[1] != 1 + ? 0 + : 1; + seqlensK = Torch::squeezeTensor(rewriter, binder.op, loc, squeezeDim, + seqlensK) + .value(); + } // Reshape Q/K/V from [batch, seq, hidden] to [batch, heads, seq, // head_size]. This requires: @@ -669,220 +680,209 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( /*scale=*/cstFloatOne); } - // Build present_key/present_value by padding past with zeros, then - // scattering current K/V into the correct per-batch position. - // - // Why pad instead of cat? With cat(past, current), the current token - // ends up at position past_seq_len. With variable seqlens_k, ORT places - // the current token at seqlens_k[b] and leaves position past_seq_len as - // zero/uninitialized. Using cat+scatter leaves a stale copy of the - // current token at past_seq_len which doesn't match ORT's output. - // Padding with zeros then scattering avoids this. 
- // - // constant_pad_nd pads innermost dims first: [0, 0, 0, seq_len] - // dim 3 (head_size): [0, 0] -- no padding - // dim 2 (seq): [0, seq_len] -- extend by seq_len on the right - Value cstFloatZero = Torch::ConstantFloatOp::create( - rewriter, loc, rewriter.getType(), - rewriter.getF64FloatAttr(0.0)); - Value padList = Torch::PrimListConstructOp::create( - rewriter, loc, intListType, - SmallVector{cstIntZero, cstIntZero, cstIntZero, - cstSequenceLength}); - Value presentKey = Torch::AtenConstantPadNdOp::create( - rewriter, loc, resultTypes[1], pastKey, padList, cstFloatZero); - Value presentValue = Torch::AtenConstantPadNdOp::create( - rewriter, loc, resultTypes[2], pastValue, padList, cstFloatZero); - - // Scatter current K/V into the padded buffer at position pastLen[b]+q. - // pastLen = seqlens_k + 1 - seq_len - Value totalSeqForScatter = Torch::AtenAddScalarOp::create( + auto pastKeyType = cast(pastKey.getType()); + auto pastValueType = cast(pastValue.getType()); + auto presentKeyType = cast(resultTypes[1]); + auto presentValueType = cast(resultTypes[2]); + // This lowering writes into a fixed-capacity cache buffer. + if (pastKeyType != presentKeyType || pastValueType != presentValueType) + return rewriter.notifyMatchFailure( + binder.op, + "Only buffer-sharing GQA is supported: past_key/past_value " + "must have the same type as present_key/present_value"); + + Type keyElemType = presentKeyType.getOptionalDtype(); + Type valueElemType = presentValueType.getOptionalDtype(); + int64_t pastSeqStatic = Torch::kUnknownSize; + if (pastKeyType.hasSizes() && pastKeyType.getSizes().size() > 2) + pastSeqStatic = pastKeyType.getSizes()[2]; + + // Convert seqlens_k into per-batch cache write offsets. 
+ Value totalSeqForWrite = Torch::AtenAddScalarOp::create( rewriter, loc, seqlensK.getType(), seqlensK, cstIntOne, cstIntOne); Value pastLen = Torch::AtenSubScalarOp::create( - rewriter, loc, totalSeqForScatter.getType(), totalSeqForScatter, + rewriter, loc, totalSeqForWrite.getType(), totalSeqForWrite, cstSequenceLength, cstIntOne); - // qRange: [0, 1, ..., seqLen-1] — shared by scatter and mask. + Value cstDim3 = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(3)); + auto intSiType = rewriter.getIntegerType(64, /*isSigned=*/true); + Torch::ValueTensorType qRangeType = Torch::ValueTensorType::get( - context, {sequenceLength}, - rewriter.getIntegerType(64, /*isSigned=*/true)); + context, SmallVector{sequenceLength}, intSiType); Value qRange = Torch::AtenArangeOp::create( rewriter, loc, qRangeType, cstSequenceLength, cstInt64Dtype, /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone); - // pastLen -> [B, 1, 1, 1] via unsqueeze chain for 4D scatter index - // broadcasting - auto intSiType = rewriter.getIntegerType(64, /*isSigned=*/true); - Value cstDim3 = Torch::ConstantIntOp::create( - rewriter, loc, rewriter.getI64IntegerAttr(3)); - // [batch] -> [batch, 1] - Torch::ValueTensorType pastLenUnsq1Type = Torch::ValueTensorType::get( - context, SmallVector{batchSize, 1}, intSiType); - Value pastLenUnsq1 = Torch::AtenUnsqueezeOp::create( - rewriter, loc, pastLenUnsq1Type, pastLen, cstDim1); - // [batch, 1] -> [batch, 1, 1] - Torch::ValueTensorType pastLenUnsq2Type = Torch::ValueTensorType::get( - context, SmallVector{batchSize, 1, 1}, intSiType); - Value pastLenUnsq2 = Torch::AtenUnsqueezeOp::create( - rewriter, loc, pastLenUnsq2Type, pastLenUnsq1, cstDim2); - // [batch, 1, 1] -> [batch, 1, 1, 1] - Torch::ValueTensorType pastLenView4dType = Torch::ValueTensorType::get( - context, SmallVector{batchSize, 1, 1, 1}, intSiType); - Value pastLenView4d = Torch::AtenUnsqueezeOp::create( - rewriter, loc, pastLenView4dType, pastLenUnsq2, 
cstDim3); - - // qRange -> [1, 1, seq, 1] via unsqueeze chain for scatter - // [seq] -> [1, seq] - Torch::ValueTensorType qUnsq0Type = Torch::ValueTensorType::get( - context, SmallVector{1, sequenceLength}, intSiType); - Value qUnsq0 = Torch::AtenUnsqueezeOp::create(rewriter, loc, qUnsq0Type, - qRange, cstIntZero); - // [1, seq] -> [1, 1, seq] - Torch::ValueTensorType qUnsq1Type = Torch::ValueTensorType::get( - context, SmallVector{1, 1, sequenceLength}, intSiType); - Value qUnsq1 = Torch::AtenUnsqueezeOp::create(rewriter, loc, qUnsq1Type, - qUnsq0, cstIntZero); - // [1, 1, seq] -> [1, 1, seq, 1] - Torch::ValueTensorType scatterQViewType = Torch::ValueTensorType::get( - context, SmallVector{1, 1, sequenceLength, 1}, intSiType); - Value scatterQRangeView = Torch::AtenUnsqueezeOp::create( - rewriter, loc, scatterQViewType, qUnsq1, cstDim3); - - // scatterIdxBase = pastLen[B,1,1,1] + qRange[1,1,seq,1] - // -> [B, 1, seq, 1] - SmallVector scatterIdxBaseSizes{batchSize, 1, sequenceLength, - 1}; - Torch::ValueTensorType scatterIdxBaseType = Torch::ValueTensorType::get( - context, scatterIdxBaseSizes, - rewriter.getIntegerType(64, /*isSigned=*/true)); - Value scatterIdxBase = Torch::AtenAddTensorOp::create( - rewriter, loc, scatterIdxBaseType, pastLenView4d, scatterQRangeView, - cstIntOne); - - // Expand to [B, kv_heads, seq, head_size] to match current K/V shape - SmallVector scatterExpandSizes{batchSize, kvNumHeads, - sequenceLength, headSize}; - Torch::ValueTensorType scatterIdxType = Torch::ValueTensorType::get( - context, scatterExpandSizes, - rewriter.getIntegerType(64, /*isSigned=*/true)); - Value scatterExpandSizeList = Torch::PrimListConstructOp::create( + Value pastSeq = rewriter.createOrFold( + loc, rewriter.getType(), pastKey, cstDim2); + Value kvPastStride = Torch::AtenMulIntOp::create( + rewriter, loc, rewriter.getType(), cstKVNumHeads, + pastSeq); + + int64_t flatPastRowsStatic = Torch::kUnknownSize; + if (batchSize != Torch::kUnknownSize && + pastSeqStatic 
!= Torch::kUnknownSize) + flatPastRowsStatic = batchSize * kvNumHeads * pastSeqStatic; + int64_t flatSrcRowsStatic = Torch::kUnknownSize; + if (batchSize != Torch::kUnknownSize && + sequenceLength != Torch::kUnknownSize) + flatSrcRowsStatic = batchSize * kvNumHeads * sequenceLength; + + // Compute flat row indices for [batch, kv_head, token] cache writes. + Torch::ValueTensorType bRangeType = Torch::ValueTensorType::get( + context, SmallVector{batchSize}, intSiType); + Value bRange = Torch::AtenArangeOp::create( + rewriter, loc, bRangeType, cstBatchSize, cstInt64Dtype, + /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone); + Value bRange2d = + Torch::unsqueezeTensor(rewriter, binder.op, bRange, cstDim1) + .value(); + Value bRange3d = + Torch::unsqueezeTensor(rewriter, binder.op, bRange2d, cstDim2) + .value(); + Torch::ValueTensorType bhqPrefixType = + cast(bRange3d.getType()); + + Torch::ValueTensorType hRangeType = Torch::ValueTensorType::get( + context, SmallVector{kvNumHeads}, intSiType); + Value hRange = Torch::AtenArangeOp::create( + rewriter, loc, hRangeType, cstKVNumHeads, cstInt64Dtype, + /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone); + Value hRange2d = + Torch::unsqueezeTensor(rewriter, binder.op, hRange, cstIntZero) + .value(); + Value hRange3d = + Torch::unsqueezeTensor(rewriter, binder.op, hRange2d, cstDim2) + .value(); + Torch::ValueTensorType h3dType = + cast(hRange3d.getType()); + + Value qRange2d = + Torch::unsqueezeTensor(rewriter, binder.op, qRange, cstIntZero) + .value(); + Value qRange3d = + Torch::unsqueezeTensor(rewriter, binder.op, qRange2d, cstIntZero) + .value(); + + Value pastLen2d = + Torch::unsqueezeTensor(rewriter, binder.op, pastLen, cstDim1) + .value(); + Value pastLen3d = + Torch::unsqueezeTensor(rewriter, binder.op, pastLen2d, cstDim2) + .value(); + + Value bContrib = Torch::AtenMulScalarOp::create( + rewriter, loc, bhqPrefixType, bRange3d, kvPastStride); + Value hContrib = 
Torch::AtenMulScalarOp::create(rewriter, loc, h3dType, + hRange3d, pastSeq); + + Torch::ValueTensorType bhOneType = Torch::ValueTensorType::get( + context, SmallVector{batchSize, kvNumHeads, 1}, intSiType); + Value targetBh1 = Torch::AtenAddTensorOp::create( + rewriter, loc, bhOneType, bContrib, hContrib, cstIntOne); + targetBh1 = Torch::AtenAddTensorOp::create( + rewriter, loc, bhOneType, targetBh1, pastLen3d, cstIntOne); + + Torch::ValueTensorType target3dType = Torch::ValueTensorType::get( + context, + SmallVector{batchSize, kvNumHeads, sequenceLength}, + intSiType); + Value target3d = Torch::AtenAddTensorOp::create( + rewriter, loc, target3dType, targetBh1, qRange3d, cstIntOne); + + Torch::ValueTensorType target1dType = Torch::ValueTensorType::get( + context, SmallVector{flatSrcRowsStatic}, intSiType); + Value target1d = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, target1dType, target3d, cstIntZero, cstDim2); + + // Flatten K/V caches to rows for index_put updates. + Torch::ValueTensorType past2dType = Torch::ValueTensorType::get( + context, SmallVector{flatPastRowsStatic, headSize}, + keyElemType); + Torch::ValueTensorType pastValue2dType = Torch::ValueTensorType::get( + context, SmallVector{flatPastRowsStatic, headSize}, + valueElemType); + Torch::ValueTensorType src2dType = Torch::ValueTensorType::get( + context, SmallVector{flatSrcRowsStatic, headSize}, + keyElemType); + Torch::ValueTensorType valueSrc2dType = Torch::ValueTensorType::get( + context, SmallVector{flatSrcRowsStatic, headSize}, + valueElemType); + + Value pastKey2d = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, past2dType, pastKey, cstIntZero, cstDim2); + Value pastValue2d = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, pastValue2dType, pastValue, cstIntZero, cstDim2); + Value kRotary2d = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, src2dType, kRotary, cstIntZero, cstDim2); + Value vInput2d = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, 
valueSrc2dType, vInput, cstIntZero, cstDim2); + + Type idxListType = Torch::ListType::get( + Torch::ValueTensorType::getWithLeastStaticInformation(context)); + Value idxList = Torch::PrimListConstructOp::create( + rewriter, loc, idxListType, SmallVector{target1d}); + + Value newPresentKey2d = Torch::AtenIndexPutHackedTwinOp::create( + rewriter, loc, past2dType, pastKey2d, idxList, kRotary2d, cstFalse); + Value newPresentValue2d = Torch::AtenIndexPutHackedTwinOp::create( + rewriter, loc, pastValue2dType, pastValue2d, idxList, vInput2d, + cstFalse); + + Value cacheUnflattenSizeList = Torch::PrimListConstructOp::create( rewriter, loc, intListType, - SmallVector{cstBatchSize, cstKVNumHeads, cstSequenceLength, - cstHeadSize}); - Value scatterIdx = Torch::AtenExpandOp::create( - rewriter, loc, scatterIdxType, scatterIdxBase, - scatterExpandSizeList, /*implicit=*/cstFalse); - - // Scatter current K/V into buffer at position pastLen[b] + q - presentKey = Torch::AtenScatterSrcOp::create( - rewriter, loc, resultTypes[1], presentKey, cstDim2, scatterIdx, - kRotary); - presentValue = Torch::AtenScatterSrcOp::create( - rewriter, loc, resultTypes[2], presentValue, cstDim2, scatterIdx, - vInput); - - // Generate causal attention mask. - // With scatter, KV layout matches ORT: current at pastLen[b]. - // Simple boolean mask: k <= pastLen[b] + q - // Mask shape: [batch, 1, seqLen, kvSeqLen] (i1). + SmallVector{cstBatchSize, cstKVNumHeads, pastSeq}); + Value presentKey = Torch::AtenUnflattenIntOp::create( + rewriter, loc, presentKeyType, newPresentKey2d, cstIntZero, + cacheUnflattenSizeList); + Value presentValue = Torch::AtenUnflattenIntOp::create( + rewriter, loc, presentValueType, newPresentValue2d, cstIntZero, + cacheUnflattenSizeList); + + // Build the causal mask over the full fixed-capacity KV cache. 
Value attnMask = cstNone; - - // Get the KV sequence length from presentKey shape - Torch::ValueTensorType presentKeyType = - cast(presentKey.getType()); if (presentKeyType.hasSizes() && presentKeyType.getSizes().size() == 4) { int64_t kvSeqLen = presentKeyType.getSizes()[2]; - - // Only generate mask if KV sequence length is dynamic or > 0 - // For dynamic shapes or non-trivial sequences, we need to mask if (kvSeqLen == Torch::kUnknownSize || kvSeqLen > 0) { - // Get KV sequence dimension size - Value kvSeqLenVal = Torch::AtenSizeIntOp::create( + Value kvSeqLenValue = Torch::AtenSizeIntOp::create( rewriter, loc, rewriter.getType(), presentKey, cstDim2); - - // kRange: [0, 1, 2, ..., kvSeqLen-1] shape [kvSeqLen] Torch::ValueTensorType kRangeType = Torch::ValueTensorType::get( - context, {kvSeqLen}, - rewriter.getIntegerType(64, /*isSigned=*/true)); + context, SmallVector{kvSeqLen}, intSiType); Value kRange = Torch::AtenArangeOp::create( - rewriter, loc, kRangeType, kvSeqLenVal, cstInt64Dtype, + rewriter, loc, kRangeType, kvSeqLenValue, cstInt64Dtype, /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone); - // Reshape for broadcasting: - // pastLen: [batch] -> [batch, 1, 1] - // qRange: [seqLen] -> [1, seqLen, 1] (reuses qRange from above) - // kRange: [kvSeqLen] -> [1, 1, kvSeqLen] + Value qMask3d = + Torch::unsqueezeTensor(rewriter, binder.op, qRange2d, cstDim2) + .value(); + Value kRange2d = + Torch::unsqueezeTensor(rewriter, binder.op, kRange, cstIntZero) + .value(); + Value kMask3d = Torch::unsqueezeTensor(rewriter, binder.op, + kRange2d, cstIntZero) + .value(); - // pastLen -> [batch, 1, 1] via unsqueeze chain - // [batch] -> [batch, 1] - Torch::ValueTensorType pastLenMaskUnsq1Type = - Torch::ValueTensorType::get( - context, SmallVector{batchSize, 1}, intSiType); - Value pastLenMaskUnsq1 = Torch::AtenUnsqueezeOp::create( - rewriter, loc, pastLenMaskUnsq1Type, pastLen, cstDim1); - // [batch, 1] -> [batch, 1, 1] - Torch::ValueTensorType seqlensViewType 
= - Torch::ValueTensorType::get( - context, SmallVector{batchSize, 1, 1}, intSiType); - Value pastLenView = Torch::AtenUnsqueezeOp::create( - rewriter, loc, seqlensViewType, pastLenMaskUnsq1, cstDim2); - - // qRange -> [1, seqLen, 1] via unsqueeze chain - // [seqLen] -> [1, seqLen] - Torch::ValueTensorType qMaskUnsq0Type = Torch::ValueTensorType::get( - context, SmallVector{1, sequenceLength}, intSiType); - Value qMaskUnsq0 = Torch::AtenUnsqueezeOp::create( - rewriter, loc, qMaskUnsq0Type, qRange, cstIntZero); - // [1, seqLen] -> [1, seqLen, 1] - Torch::ValueTensorType qViewType = Torch::ValueTensorType::get( - context, SmallVector{1, sequenceLength, 1}, intSiType); - Value qRangeView = Torch::AtenUnsqueezeOp::create( - rewriter, loc, qViewType, qMaskUnsq0, cstDim2); - - // kRange -> [1, 1, kvSeqLen] via unsqueeze chain - // [kvSeqLen] -> [1, kvSeqLen] - Torch::ValueTensorType kUnsq0Type = Torch::ValueTensorType::get( - context, SmallVector{1, kvSeqLen}, intSiType); - Value kUnsq0 = Torch::AtenUnsqueezeOp::create( - rewriter, loc, kUnsq0Type, kRange, cstIntZero); - // [1, kvSeqLen] -> [1, 1, kvSeqLen] - Torch::ValueTensorType kViewType = Torch::ValueTensorType::get( - context, SmallVector{1, 1, kvSeqLen}, intSiType); - Value kRangeView = Torch::AtenUnsqueezeOp::create( - rewriter, loc, kViewType, kUnsq0, cstIntZero); - - // Causal mask: k <= pastLen + q - // pastLenView[batch,1,1] + qRangeView[1,seqLen,1] - // -> [batch, seqLen, 1] - SmallVector pastLenPlusQSizes{batchSize, sequenceLength, - 1}; Torch::ValueTensorType pastLenPlusQType = Torch::ValueTensorType::get( - context, pastLenPlusQSizes, - rewriter.getIntegerType(64, /*isSigned=*/true)); + context, SmallVector{batchSize, sequenceLength, 1}, + intSiType); Value pastLenPlusQ = Torch::AtenAddTensorOp::create( - rewriter, loc, pastLenPlusQType, pastLenView, qRangeView, - cstIntOne); + rewriter, loc, pastLenPlusQType, pastLen3d, qMask3d, cstIntOne); - // kRangeView[1,1,kvSeqLen] <= pastLenPlusQ[batch,seqLen,1] - 
// -> [batch, seqLen, kvSeqLen] - SmallVector maskBoolSizes{batchSize, sequenceLength, - kvSeqLen}; Torch::ValueTensorType maskBoolType = Torch::ValueTensorType::get( - context, maskBoolSizes, rewriter.getI1Type()); + context, + SmallVector{batchSize, sequenceLength, kvSeqLen}, + rewriter.getI1Type()); Value causalMask = Torch::AtenLeTensorOp::create( - rewriter, loc, maskBoolType, kRangeView, pastLenPlusQ); - - // Unsqueeze to [batch, 1, seqLen, kvSeqLen] for SDPA. - // Pass the boolean mask directly — downstream backends (e.g. - // IREE's iree_linalg_ext.attention) handle bool-to-float - // conversion internally. - SmallVector attnMaskSizes{batchSize, 1, sequenceLength, - kvSeqLen}; + rewriter, loc, maskBoolType, kMask3d, pastLenPlusQ); + Torch::ValueTensorType attnMaskType = Torch::ValueTensorType::get( - context, attnMaskSizes, rewriter.getI1Type()); + context, + SmallVector{batchSize, 1, sequenceLength, kvSeqLen}, + rewriter.getI1Type()); attnMask = Torch::AtenUnsqueezeOp::create( rewriter, loc, attnMaskType, causalMask, cstDim1); } @@ -896,6 +896,10 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( rewriter, loc, rewriter.getType(), rewriter.getF64FloatAttr(scale)); + Value cstFloatZero = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getType(), + rewriter.getF64FloatAttr(0.0)); + // Use presentKey/presentValue (full KV cache) for attention, not just // the current token's K/V. This is essential for proper KV caching. 
Value attention = Torch::AtenScaledDotProductAttentionOp::create( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 0e9e7b1ebd61..bed7e1ad439b 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -2501,367 +2501,91 @@ func.func @test_mwm(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si6 // ----- -// CHECK-LABEL: func.func @test_group_query_attention -func.func @test_group_query_attention(%arg0: !torch.vtensor<[1,1,16],f32>, %arg1: !torch.vtensor<[1,1,16],f32>, %arg2: !torch.vtensor<[1,1,16],f32>) -> (!torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[Q_RESHAPE:.+]] = torch.aten.unflatten.int %arg0, {{.*}} -> !torch.vtensor<[1,1,2,8],f32> - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int %[[Q_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_RESHAPE:.+]] = torch.aten.unflatten.int %arg1, {{.*}} -> !torch.vtensor<[1,1,2,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int %[[K_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_RESHAPE:.+]] = torch.aten.unflatten.int %arg2, {{.*}} -> !torch.vtensor<[1,1,2,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int %[[V_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], 
{{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[MASK:.+]] = torch.aten.le.Tensor {{.*}} -> !torch.vtensor<[1,1,1],i1> - // CHECK: %[[MASK_RESHAPE:.+]] = torch.aten.unsqueeze %[[MASK]], {{.*}} -> !torch.vtensor<[1,1,1,1],i1> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_TRANSPOSE]], %[[K_SCATTER]], %[[V_SCATTER]], %[[MASK_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[OUT_TRANSPOSE:.+]] = torch.aten.transpose.int %[[OUTPUT]], {{.*}} -> !torch.vtensor<[1,1,2,8],f32> - // CHECK: %[[OUT_RESHAPE:.+]] = torch.aten.flatten.using_ints %[[OUT_TRANSPOSE]], {{.*}} -> !torch.vtensor<[1,1,16],f32> - %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : tensor<1x2x0x8xf32>} : () -> !torch.vtensor<[1,2,0,8],f32> - %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : tensor<1x2x0x8xf32>} : () -> !torch.vtensor<[1,2,0,8],f32> - %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %4:3 = torch.operator "onnx.GroupQueryAttention"(%arg0, %arg1, %arg2, %0, %1, %2, %3) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> (!torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) - return %4#0, %4#1, %4#2 : !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32> -} - -// ----- - -// Test GQA with dynamic sequence length -// CHECK-LABEL: func.func @test_group_query_attention_dynamic_seq -func.func @test_group_query_attention_dynamic_seq(%arg0: !torch.vtensor<[1,?,16],f32>, %arg1: 
!torch.vtensor<[1,?,16],f32>, %arg2: !torch.vtensor<[1,?,16],f32>) -> (!torch.vtensor<[1,?,16],f32>, !torch.vtensor<[1,2,?,8],f32>, !torch.vtensor<[1,2,?,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[PAST_KEY:.+]] = torch.vtensor.literal(dense<> : tensor<1x2x0x8xf32>) : !torch.vtensor<[1,2,0,8],f32> - // CHECK: %[[PAST_VALUE:.+]] = torch.vtensor.literal(dense<> : tensor<1x2x0x8xf32>) : !torch.vtensor<[1,2,0,8],f32> - // CHECK: %[[SEQ_LEN:.+]] = torch.aten.size.int %arg0, {{.*}} : !torch.vtensor<[1,?,16],f32>, !torch.int -> !torch.int - // CHECK: %[[Q_RESHAPE:.+]] = torch.aten.unflatten.int %arg0, {{.*}} -> !torch.vtensor<[1,?,2,8],f32> - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int %[[Q_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,?,8],f32> - // CHECK: %[[K_RESHAPE:.+]] = torch.aten.unflatten.int %arg1, {{.*}} -> !torch.vtensor<[1,?,2,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int %[[K_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,?,8],f32> - // CHECK: %[[V_RESHAPE:.+]] = torch.aten.unflatten.int %arg2, {{.*}} -> !torch.vtensor<[1,?,2,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int %[[V_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,?,8],f32> - // CHECK: %[[PAD_KEY:.+]] = torch.aten.constant_pad_nd %[[PAST_KEY]], {{.*}} -> !torch.vtensor<[1,2,?,8],f32> - // CHECK: %[[PAD_VALUE:.+]] = torch.aten.constant_pad_nd %[[PAST_VALUE]], {{.*}} -> !torch.vtensor<[1,2,?,8],f32> - // CHECK: %[[PRESENT_KEY:.+]] = torch.aten.scatter.src %[[PAD_KEY]], {{.*}}, %[[K_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,?,8],f32> - // CHECK: %[[PRESENT_VALUE:.+]] = torch.aten.scatter.src %[[PAD_VALUE]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,?,8],f32> - // CHECK: %[[MASK:.+]] = torch.aten.unsqueeze {{.*}} -> !torch.vtensor<[1,1,?,?],i1> - // CHECK: %[[OUTPUT:.+]] = 
torch.aten.scaled_dot_product_attention %[[Q_TRANSPOSE]], %[[PRESENT_KEY]], %[[PRESENT_VALUE]], %[[MASK]], {{.*}} -> !torch.vtensor<[1,2,?,8],f32> - // CHECK: %[[OUT_TRANSPOSE:.+]] = torch.aten.transpose.int %[[OUTPUT]], {{.*}} -> !torch.vtensor<[1,?,2,8],f32> - // CHECK: %[[OUT_RESHAPE:.+]] = torch.aten.flatten.using_ints %[[OUT_TRANSPOSE]], {{.*}} -> !torch.vtensor<[1,?,16],f32> - %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : tensor<1x2x0x8xf32>} : () -> !torch.vtensor<[1,2,0,8],f32> - %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : tensor<1x2x0x8xf32>} : () -> !torch.vtensor<[1,2,0,8],f32> - %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %4:3 = torch.operator "onnx.GroupQueryAttention"(%arg0, %arg1, %arg2, %0, %1, %2, %3) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[1,?,16],f32>, !torch.vtensor<[1,?,16],f32>, !torch.vtensor<[1,?,16],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> (!torch.vtensor<[1,?,16],f32>, !torch.vtensor<[1,2,?,8],f32>, !torch.vtensor<[1,2,?,8],f32>) - return %4#0, %4#1, %4#2 : !torch.vtensor<[1,?,16],f32>, !torch.vtensor<[1,2,?,8],f32>, !torch.vtensor<[1,2,?,8],f32> -} - -// ----- - -// Test GQA with dynamic batch and sequence dimensions -// CHECK-LABEL: func.func @test_group_query_attention_dynamic_batch_seq -func.func @test_group_query_attention_dynamic_batch_seq(%arg0: !torch.vtensor<[?,?,16],f32>, %arg1: !torch.vtensor<[?,?,16],f32>, %arg2: !torch.vtensor<[?,?,16],f32>, %past_key: !torch.vtensor<[?,2,0,8],f32>, %past_value: !torch.vtensor<[?,2,0,8],f32>, %seqlens_k: !torch.vtensor<[?],si32>, %total_seq_length: !torch.vtensor<[?],si32>) -> (!torch.vtensor<[?,?,16],f32>, 
!torch.vtensor<[?,2,?,8],f32>, !torch.vtensor<[?,2,?,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[BATCH:.+]] = torch.aten.size.int %arg0, {{.*}} : !torch.vtensor<[?,?,16],f32>, !torch.int -> !torch.int - // CHECK: %[[SEQ_LEN:.+]] = torch.aten.size.int %arg0, {{.*}} : !torch.vtensor<[?,?,16],f32>, !torch.int -> !torch.int - // CHECK: %[[Q_RESHAPE:.+]] = torch.aten.unflatten.int %arg0, {{.*}} -> !torch.vtensor<[?,?,2,8],f32> - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int %[[Q_RESHAPE]], {{.*}} -> !torch.vtensor<[?,2,?,8],f32> - // CHECK: %[[K_RESHAPE:.+]] = torch.aten.unflatten.int %arg1, {{.*}} -> !torch.vtensor<[?,?,2,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int %[[K_RESHAPE]], {{.*}} -> !torch.vtensor<[?,2,?,8],f32> - // CHECK: %[[V_RESHAPE:.+]] = torch.aten.unflatten.int %arg2, {{.*}} -> !torch.vtensor<[?,?,2,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int %[[V_RESHAPE]], {{.*}} -> !torch.vtensor<[?,2,?,8],f32> - // CHECK: %[[PAD_KEY:.+]] = torch.aten.constant_pad_nd %arg3, {{.*}} -> !torch.vtensor<[?,2,?,8],f32> - // CHECK: %[[PAD_VALUE:.+]] = torch.aten.constant_pad_nd %arg4, {{.*}} -> !torch.vtensor<[?,2,?,8],f32> - // CHECK: %[[PRESENT_KEY:.+]] = torch.aten.scatter.src %[[PAD_KEY]], {{.*}}, %[[K_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[?,2,?,8],f32> - // CHECK: %[[PRESENT_VALUE:.+]] = torch.aten.scatter.src %[[PAD_VALUE]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[?,2,?,8],f32> - // CHECK: %[[MASK:.+]] = torch.aten.unsqueeze {{.*}} -> !torch.vtensor<[?,1,?,?],i1> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_TRANSPOSE]], %[[PRESENT_KEY]], %[[PRESENT_VALUE]], %[[MASK]], {{.*}} -> !torch.vtensor<[?,2,?,8],f32> - // CHECK: %[[OUT_TRANSPOSE:.+]] = torch.aten.transpose.int %[[OUTPUT]], {{.*}} -> !torch.vtensor<[?,?,2,8],f32> - 
// CHECK: %[[OUT_RESHAPE:.+]] = torch.aten.flatten.using_ints %[[OUT_TRANSPOSE]], {{.*}} -> !torch.vtensor<[?,?,16],f32> - %4:3 = torch.operator "onnx.GroupQueryAttention"(%arg0, %arg1, %arg2, %past_key, %past_value, %seqlens_k, %total_seq_length) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[?,?,16],f32>, !torch.vtensor<[?,?,16],f32>, !torch.vtensor<[?,?,16],f32>, !torch.vtensor<[?,2,0,8],f32>, !torch.vtensor<[?,2,0,8],f32>, !torch.vtensor<[?],si32>, !torch.vtensor<[?],si32>) -> (!torch.vtensor<[?,?,16],f32>, !torch.vtensor<[?,2,?,8],f32>, !torch.vtensor<[?,2,?,8],f32>) - return %4#0, %4#1, %4#2 : !torch.vtensor<[?,?,16],f32>, !torch.vtensor<[?,2,?,8],f32>, !torch.vtensor<[?,2,?,8],f32> -} - -// ----- - -// CHECK-LABEL: func.func @test_group_query_attention_with_rotary_embedding -func.func @test_group_query_attention_with_rotary_embedding(%query: !torch.vtensor<[1,1,16],f32>, %key: !torch.vtensor<[1,1,16],f32>, %value: !torch.vtensor<[1,1,16],f32>, %cos_cache: !torch.vtensor<[2,4],f32>, %sin_cache: !torch.vtensor<[2,4],f32>) -> (!torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[Q_ROTARY:.+]] = torch.onnx.rotary_embedding %[[Q_TRANSPOSE]], {{.*}} %arg3, %arg4, {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_ROTARY:.+]] = torch.onnx.rotary_embedding %[[K_TRANSPOSE]], {{.*}} %arg3, %arg4, {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd {{.*}} -> 
!torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_ROTARY]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[MASK:.+]] = torch.aten.le.Tensor {{.*}} -> !torch.vtensor<[1,1,1],i1> - // CHECK: %[[MASK_RESHAPE:.+]] = torch.aten.unsqueeze %[[MASK]], {{.*}} -> !torch.vtensor<[1,1,1,1],i1> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_ROTARY]], %[[K_SCATTER]], %[[V_SCATTER]], %[[MASK_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[OUT_TRANSPOSE:.+]] = torch.aten.transpose.int %[[OUTPUT]], {{.*}} -> !torch.vtensor<[1,1,2,8],f32> - // CHECK: %[[OUT_RESHAPE:.+]] = torch.aten.flatten.using_ints %[[OUT_TRANSPOSE]], {{.*}} -> !torch.vtensor<[1,1,16],f32> - %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : tensor<1x2x0x8xf32>} : () -> !torch.vtensor<[1,2,0,8],f32> - %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : tensor<1x2x0x8xf32>} : () -> !torch.vtensor<[1,2,0,8],f32> - %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %none = torch.constant.none - %4:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %0, %1, %2, %3, %cos_cache, %sin_cache) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64, torch.onnx.do_rotary = 1 : si64} : (!torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[2,4],f32>, !torch.vtensor<[2,4],f32>) 
-> (!torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) - return %4#0, %4#1, %4#2 : !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32> -} - -// ----- - -// Test GroupQueryAttention with packed QKV input and rotary embedding -// packed_qkv shape: [batch, seq, num_heads*head_size + 2*kv_num_heads*head_size] -// num_heads=4, kv_num_heads=2 so Q slice [1,1,32] differs from K/V slices [1,1,16] -// CHECK-LABEL: func.func @test_group_query_attention_packed_qkv_rotary -func.func @test_group_query_attention_packed_qkv_rotary(%packed_qkv: !torch.vtensor<[1,1,64],f32>, %past_key: !torch.vtensor<[1,2,0,8],f32>, %past_value: !torch.vtensor<[1,2,0,8],f32>, %cos_cache: !torch.vtensor<[2,4],f32>, %sin_cache: !torch.vtensor<[2,4],f32>) -> (!torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // Q slice is [1,1,32] (num_heads*head_size=4*8), K/V are [1,1,16] (kv_num_heads*head_size=2*8) - // CHECK: %[[Q_SLICE:.+]] = torch.aten.slice.Tensor %arg0, {{.*}} -> !torch.vtensor<[1,1,32],f32> - // CHECK: %[[K_SLICE:.+]] = torch.aten.slice.Tensor %arg0, {{.*}} -> !torch.vtensor<[1,1,16],f32> - // CHECK: %[[V_SLICE:.+]] = torch.aten.slice.Tensor %arg0, {{.*}} -> !torch.vtensor<[1,1,16],f32> - // CHECK: %[[Q_RESHAPE:.+]] = torch.aten.unflatten.int %[[Q_SLICE]], {{.*}} -> !torch.vtensor<[1,1,4,8],f32> - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int %[[Q_RESHAPE]], {{.*}} -> !torch.vtensor<[1,4,1,8],f32> - // CHECK: %[[K_RESHAPE:.+]] = torch.aten.unflatten.int %[[K_SLICE]], {{.*}} -> !torch.vtensor<[1,1,2,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int %[[K_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_RESHAPE:.+]] = torch.aten.unflatten.int %[[V_SLICE]], 
{{.*}} -> !torch.vtensor<[1,1,2,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int %[[V_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[Q_ROTARY:.+]] = torch.onnx.rotary_embedding %[[Q_TRANSPOSE]], {{.*}} -> !torch.vtensor<[1,4,1,8],f32> - // CHECK: %[[K_ROTARY:.+]] = torch.onnx.rotary_embedding %[[K_TRANSPOSE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd %arg1, {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd %arg2, {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_ROTARY]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_ROTARY]], %[[K_SCATTER]], %[[V_SCATTER]], {{.*}} -> !torch.vtensor<[1,4,1,8],f32> - %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %2:3 = torch.operator "onnx.GroupQueryAttention"(%packed_qkv, %past_key, %past_value, %0, %1, %cos_cache, %sin_cache) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 4 : si64, torch.onnx.do_rotary = 1 : si64} : (!torch.vtensor<[1,1,64],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[2,4],f32>, !torch.vtensor<[2,4],f32>) -> (!torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) - return %2#0, %2#1, %2#2 : !torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32> -} - -// ----- - -// Test GroupQueryAttention with packed QKV and dynamic dimensions +// CHECK-LABEL: 
func.func @test_group_query_attention_fixed_capacity_cache +// CHECK-SAME: %[[PAST_KEY:[a-zA-Z0-9_]+]]: !torch.vtensor<[1,2,4,8],f32> +// CHECK-SAME: %[[PAST_VALUE:[a-zA-Z0-9_]+]]: !torch.vtensor<[1,2,4,8],f32> +func.func @test_group_query_attention_fixed_capacity_cache(%query: !torch.vtensor<[1,2,16],f32>, %key: !torch.vtensor<[1,2,16],f32>, %value: !torch.vtensor<[1,2,16],f32>, %past_key: !torch.vtensor<[1,2,4,8],f32>, %past_value: !torch.vtensor<[1,2,4,8],f32>, %seqlens_k: !torch.vtensor<[1],si32>, %total_seq_length: !torch.vtensor<[1],si32>) -> (!torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,4,8],f32>, !torch.vtensor<[1,2,4,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK-NOT: torch.aten.constant_pad_nd + // CHECK-NOT: torch.aten.view + // CHECK: %[[K_FLAT:.+]] = torch.aten.flatten.using_ints %[[PAST_KEY]], %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,4,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8,8],f32> + // CHECK: %[[V_FLAT:.+]] = torch.aten.flatten.using_ints %[[PAST_VALUE]], %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,4,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8,8],f32> + // CHECK: %[[K_WRITE:.+]] = torch.aten.index_put.hacked_twin %[[K_FLAT]], %{{.*}} -> !torch.vtensor<[8,8],f32> + // CHECK: %[[V_WRITE:.+]] = torch.aten.index_put.hacked_twin %[[V_FLAT]], %{{.*}} -> !torch.vtensor<[8,8],f32> + // CHECK: %[[PRESENT_KEY:.+]] = torch.aten.unflatten.int %[[K_WRITE]], %{{.*}}, %{{.*}} : !torch.vtensor<[8,8],f32>, !torch.int, !torch.list -> !torch.vtensor<[1,2,4,8],f32> + // CHECK: %[[PRESENT_VALUE:.+]] = torch.aten.unflatten.int %[[V_WRITE]], %{{.*}}, %{{.*}} : !torch.vtensor<[8,8],f32>, !torch.int, !torch.list -> !torch.vtensor<[1,2,4,8],f32> + // CHECK: torch.aten.scaled_dot_product_attention %{{.*}}, %[[PRESENT_KEY]], %[[PRESENT_VALUE]], %{{.*}} -> !torch.vtensor<[1,2,2,8],f32> + %0:3 = torch.operator 
"onnx.GroupQueryAttention"(%query, %key, %value, %past_key, %past_value, %seqlens_k, %total_seq_length) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,4,8],f32>, !torch.vtensor<[1,2,4,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> (!torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,4,8],f32>, !torch.vtensor<[1,2,4,8],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,4,8],f32>, !torch.vtensor<[1,2,4,8],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_group_query_attention_packed_qkv_dynamic func.func @test_group_query_attention_packed_qkv_dynamic(%packed_qkv: !torch.vtensor<[?,?,6144],f32>, %past_key: !torch.vtensor<[?,8,?,128],f32>, %past_value: !torch.vtensor<[?,8,?,128],f32>, %seqlens_k: !torch.vtensor<[?],si32>, %total_seq_len: !torch.vtensor<[],si32>, %cos_cache: !torch.vtensor<[131072,64],f32>, %sin_cache: !torch.vtensor<[131072,64],f32>) -> (!torch.vtensor<[?,?,4096],f32>, !torch.vtensor<[?,8,?,128],f32>, !torch.vtensor<[?,8,?,128],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[Q_SLICE:.+]] = torch.aten.slice.Tensor %arg0, {{.*}} -> !torch.vtensor<[?,?,4096],f32> - // CHECK: %[[K_SLICE:.+]] = torch.aten.slice.Tensor %arg0, {{.*}} -> !torch.vtensor<[?,?,1024],f32> - // CHECK: %[[V_SLICE:.+]] = torch.aten.slice.Tensor %arg0, {{.*}} -> !torch.vtensor<[?,?,1024],f32> - // CHECK: %[[Q_RESHAPE:.+]] = torch.aten.unflatten.int %[[Q_SLICE]], {{.*}} -> !torch.vtensor<[?,?,32,128],f32> - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int %[[Q_RESHAPE]], {{.*}} -> !torch.vtensor<[?,32,?,128],f32> - // CHECK: %[[K_RESHAPE:.+]] = torch.aten.unflatten.int %[[K_SLICE]], {{.*}} -> !torch.vtensor<[?,?,8,128],f32> - // CHECK: 
%[[K_TRANSPOSE:.+]] = torch.aten.transpose.int %[[K_RESHAPE]], {{.*}} -> !torch.vtensor<[?,8,?,128],f32> - // CHECK: %[[V_RESHAPE:.+]] = torch.aten.unflatten.int %[[V_SLICE]], {{.*}} -> !torch.vtensor<[?,?,8,128],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int %[[V_RESHAPE]], {{.*}} -> !torch.vtensor<[?,8,?,128],f32> - // CHECK: %[[Q_ROTARY:.+]] = torch.onnx.rotary_embedding %[[Q_TRANSPOSE]], {{.*}} -> !torch.vtensor<[?,32,?,128],f32> - // CHECK: %[[K_ROTARY:.+]] = torch.onnx.rotary_embedding %[[K_TRANSPOSE]], {{.*}} -> !torch.vtensor<[?,8,?,128],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd %arg1, {{.*}} -> !torch.vtensor<[?,8,?,128],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd %arg2, {{.*}} -> !torch.vtensor<[?,8,?,128],f32> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_ROTARY]] : {{.*}} -> !torch.vtensor<[?,8,?,128],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[?,8,?,128],f32> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_ROTARY]], %[[K_SCATTER]], %[[V_SCATTER]], {{.*}} -> !torch.vtensor<[?,32,?,128],f32> + // CHECK: torch.aten.slice.Tensor %arg0, %{{.*}} -> !torch.vtensor<[?,?,4096],f32> + // CHECK: torch.aten.slice.Tensor %arg0, %{{.*}} -> !torch.vtensor<[?,?,1024],f32> + // CHECK: torch.aten.slice.Tensor %arg0, %{{.*}} -> !torch.vtensor<[?,?,1024],f32> + // CHECK: torch.aten.transpose.int %{{.*}} -> !torch.vtensor<[?,32,?,128],f32> + // CHECK: torch.aten.transpose.int %{{.*}} -> !torch.vtensor<[?,8,?,128],f32> + // CHECK: torch.aten.transpose.int %{{.*}} -> !torch.vtensor<[?,8,?,128],f32> + // CHECK: torch.onnx.rotary_embedding %{{.*}} -> !torch.vtensor<[?,32,?,128],f32> + // CHECK: torch.onnx.rotary_embedding %{{.*}} -> !torch.vtensor<[?,8,?,128],f32> + // CHECK-NOT: torch.aten.constant_pad_nd + // CHECK-NOT: torch.aten.view + // CHECK: %[[K_FLAT:.+]] = 
torch.aten.flatten.using_ints %arg1, %{{.*}}, %{{.*}} : !torch.vtensor<[?,8,?,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> + // CHECK: %[[V_FLAT:.+]] = torch.aten.flatten.using_ints %arg2, %{{.*}}, %{{.*}} : !torch.vtensor<[?,8,?,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> + // CHECK: %[[K_WRITE:.+]] = torch.aten.index_put.hacked_twin %[[K_FLAT]], %{{.*}} -> !torch.vtensor<[?,128],f32> + // CHECK: %[[V_WRITE:.+]] = torch.aten.index_put.hacked_twin %[[V_FLAT]], %{{.*}} -> !torch.vtensor<[?,128],f32> + // CHECK: %[[PRESENT_KEY:.+]] = torch.aten.unflatten.int %[[K_WRITE]], %{{.*}}, %{{.*}} : !torch.vtensor<[?,128],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,8,?,128],f32> + // CHECK: %[[PRESENT_VALUE:.+]] = torch.aten.unflatten.int %[[V_WRITE]], %{{.*}}, %{{.*}} : !torch.vtensor<[?,128],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,8,?,128],f32> + // CHECK: torch.aten.scaled_dot_product_attention %{{.*}}, %[[PRESENT_KEY]], %[[PRESENT_VALUE]], %{{.*}} -> !torch.vtensor<[?,32,?,128],f32> %0:3 = torch.operator "onnx.GroupQueryAttention"(%packed_qkv, %past_key, %past_value, %seqlens_k, %total_seq_len, %cos_cache, %sin_cache) {torch.onnx.kv_num_heads = 8 : si64, torch.onnx.num_heads = 32 : si64, torch.onnx.do_rotary = 1 : si64, torch.onnx.smooth_softmax = -1 : si64, torch.onnx.scale = 8.838835e-02 : f32} : (!torch.vtensor<[?,?,6144],f32>, !torch.vtensor<[?,8,?,128],f32>, !torch.vtensor<[?,8,?,128],f32>, !torch.vtensor<[?],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[131072,64],f32>, !torch.vtensor<[131072,64],f32>) -> (!torch.vtensor<[?,?,4096],f32>, !torch.vtensor<[?,8,?,128],f32>, !torch.vtensor<[?,8,?,128],f32>) return %0#0, %0#1, %0#2 : !torch.vtensor<[?,?,4096],f32>, !torch.vtensor<[?,8,?,128],f32>, !torch.vtensor<[?,8,?,128],f32> } // ----- -// Test GroupQueryAttention with packed QKV but without rotary embedding -// num_heads=4, kv_num_heads=2 so Q slice [1,1,32] differs from K/V slices [1,1,16] -// CHECK-LABEL: 
func.func @test_group_query_attention_packed_qkv_no_rotary -func.func @test_group_query_attention_packed_qkv_no_rotary(%packed_qkv: !torch.vtensor<[1,1,64],f32>, %past_key: !torch.vtensor<[1,2,0,8],f32>, %past_value: !torch.vtensor<[1,2,0,8],f32>) -> (!torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[Q_SLICE:.+]] = torch.aten.slice.Tensor %arg0, {{.*}} -> !torch.vtensor<[1,1,32],f32> - // CHECK: %[[K_SLICE:.+]] = torch.aten.slice.Tensor %arg0, {{.*}} -> !torch.vtensor<[1,1,16],f32> - // CHECK: %[[V_SLICE:.+]] = torch.aten.slice.Tensor %arg0, {{.*}} -> !torch.vtensor<[1,1,16],f32> - // CHECK: %[[Q_RESHAPE:.+]] = torch.aten.unflatten.int %[[Q_SLICE]], {{.*}} -> !torch.vtensor<[1,1,4,8],f32> - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int %[[Q_RESHAPE]], {{.*}} -> !torch.vtensor<[1,4,1,8],f32> - // CHECK: %[[K_RESHAPE:.+]] = torch.aten.unflatten.int %[[K_SLICE]], {{.*}} -> !torch.vtensor<[1,1,2,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int %[[K_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_RESHAPE:.+]] = torch.aten.unflatten.int %[[V_SLICE]], {{.*}} -> !torch.vtensor<[1,1,2,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int %[[V_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK-NOT: torch.onnx.rotary_embedding - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd %arg1, {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd %arg2, {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> 
!torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_TRANSPOSE]], %[[K_SCATTER]], %[[V_SCATTER]], {{.*}} -> !torch.vtensor<[1,4,1,8],f32> - %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %2:3 = torch.operator "onnx.GroupQueryAttention"(%packed_qkv, %past_key, %past_value, %0, %1) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 4 : si64} : (!torch.vtensor<[1,1,64],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> (!torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) - return %2#0, %2#1, %2#2 : !torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32> -} - -// ----- - -// Test GQA with num_heads != kv_num_heads and rotary embedding -// CHECK-LABEL: func.func @test_group_query_attention_gqa_rotary -func.func @test_group_query_attention_gqa_rotary(%query: !torch.vtensor<[1,1,32],f32>, %key: !torch.vtensor<[1,1,16],f32>, %value: !torch.vtensor<[1,1,16],f32>, %cos_cache: !torch.vtensor<[2,4],f32>, %sin_cache: !torch.vtensor<[2,4],f32>) -> (!torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,4,1,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[Q_ROTARY:.+]] = torch.onnx.rotary_embedding %[[Q_TRANSPOSE]], {{.*}} -> 
!torch.vtensor<[1,4,1,8],f32> - // CHECK: %[[K_ROTARY:.+]] = torch.onnx.rotary_embedding %[[K_TRANSPOSE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_ROTARY]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_ROTARY]], %[[K_SCATTER]], %[[V_SCATTER]], {{.*}} -> !torch.vtensor<[1,4,1,8],f32> - // CHECK: %[[OUT_TRANSPOSE:.+]] = torch.aten.transpose.int %[[OUTPUT]], {{.*}} -> !torch.vtensor<[1,1,4,8],f32> - // CHECK: %[[OUT_RESHAPE:.+]] = torch.aten.flatten.using_ints %[[OUT_TRANSPOSE]], {{.*}} -> !torch.vtensor<[1,1,32],f32> - %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : tensor<1x2x0x8xf32>} : () -> !torch.vtensor<[1,2,0,8],f32> - %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : tensor<1x2x0x8xf32>} : () -> !torch.vtensor<[1,2,0,8],f32> - %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %4:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %0, %1, %2, %3, %cos_cache, %sin_cache) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 4 : si64, torch.onnx.do_rotary = 1 : si64} : (!torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1,2,0,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, !torch.vtensor<[2,4],f32>, !torch.vtensor<[2,4],f32>) -> 
(!torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32>) - return %4#0, %4#1, %4#2 : !torch.vtensor<[1,1,32],f32>, !torch.vtensor<[1,2,1,8],f32>, !torch.vtensor<[1,2,1,8],f32> -} - -// ----- - -// Test GQA with non-zero past KV cache -// CHECK-LABEL: func.func @test_group_query_attention_kv_cache -func.func @test_group_query_attention_kv_cache(%query: !torch.vtensor<[1,1,16],f32>, %key: !torch.vtensor<[1,1,16],f32>, %value: !torch.vtensor<[1,1,16],f32>, %past_key: !torch.vtensor<[1,2,4,8],f32>, %past_value: !torch.vtensor<[1,2,4,8],f32>) -> (!torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,5,8],f32>, !torch.vtensor<[1,2,5,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[SEQLENS_K:.+]] = torch.aten.to.dtype {{.*}} -> !torch.vtensor<[1],si64> - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd %arg3, {{.*}} -> !torch.vtensor<[1,2,5,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd %arg4, {{.*}} -> !torch.vtensor<[1,2,5,8],f32> - // CHECK: %[[Q_RANGE:.+]] = torch.aten.arange {{.*}} -> !torch.vtensor<[1],si64> - // CHECK: %[[PAST_VIEW:.+]] = torch.aten.unsqueeze {{.*}} -> !torch.vtensor<[1,1,1,1],si64> - // CHECK: %[[Q_RANGE_4D:.+]] = torch.aten.unsqueeze {{.*}} -> !torch.vtensor<[1,1,1,1],si64> - // CHECK: %[[IDX_BASE:.+]] = torch.aten.add.Tensor %[[PAST_VIEW]], %[[Q_RANGE_4D]], {{.*}} -> !torch.vtensor<[1,1,1,1],si64> - // CHECK: %[[SCATTER_IDX:.+]] = torch.aten.expand %[[IDX_BASE]], {{.*}} -> !torch.vtensor<[1,2,1,8],si64> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src 
%[[K_PAD]], - // CHECK-SAME: %[[SCATTER_IDX]], %[[K_TRANSPOSE]] : - // CHECK-SAME: -> !torch.vtensor<[1,2,5,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], - // CHECK-SAME: %[[SCATTER_IDX]], %[[V_TRANSPOSE]] : - // CHECK-SAME: -> !torch.vtensor<[1,2,5,8],f32> - // CHECK: %[[MASK:.+]] = torch.aten.le.Tensor {{.*}} -> !torch.vtensor<[1,1,5],i1> - // CHECK: %[[MASK_RESHAPE:.+]] = torch.aten.unsqueeze %[[MASK]], {{.*}} -> !torch.vtensor<[1,1,1,5],i1> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_TRANSPOSE]], %[[K_SCATTER]], %[[V_SCATTER]], %[[MASK_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f32> - // CHECK: %[[OUT_RESHAPE:.+]] = torch.aten.flatten.using_ints {{.*}} -> !torch.vtensor<[1,1,16],f32> - %seqlens_k = torch.operator "onnx.Constant"() {torch.onnx.value = dense<4> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %total_seq_len = torch.operator "onnx.Constant"() {torch.onnx.value = dense<5> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %0:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %past_key, %past_value, %seqlens_k, %total_seq_len) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,4,8],f32>, !torch.vtensor<[1,2,4,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> (!torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,5,8],f32>, !torch.vtensor<[1,2,5,8],f32>) - return %0#0, %0#1, %0#2 : !torch.vtensor<[1,1,16],f32>, !torch.vtensor<[1,2,5,8],f32>, !torch.vtensor<[1,2,5,8],f32> -} - -// ----- - -// Test GQA causal mask shape with multi-token sequence -// CHECK-LABEL: func.func @test_group_query_attention_seqlens_k_mask -func.func @test_group_query_attention_seqlens_k_mask(%query: !torch.vtensor<[1,4,16],f32>, %key: !torch.vtensor<[1,4,16],f32>, %value: !torch.vtensor<[1,4,16],f32>, %past_key: !torch.vtensor<[1,2,3,8],f32>, 
%past_value: !torch.vtensor<[1,2,3,8],f32>) -> (!torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,2,7,8],f32>, !torch.vtensor<[1,2,7,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd %arg3, {{.*}} -> !torch.vtensor<[1,2,7,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd %arg4, {{.*}} -> !torch.vtensor<[1,2,7,8],f32> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,7,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,7,8],f32> - // CHECK: %[[MASK:.+]] = torch.aten.le.Tensor {{.*}} -> !torch.vtensor<[1,4,7],i1> - // CHECK: %[[MASK_RESHAPE:.+]] = torch.aten.unsqueeze %[[MASK]], {{.*}} -> !torch.vtensor<[1,1,4,7],i1> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_TRANSPOSE]], %[[K_SCATTER]], %[[V_SCATTER]], %[[MASK_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - %seqlens_k = torch.operator "onnx.Constant"() {torch.onnx.value = dense<6> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %total_seq_len = torch.operator "onnx.Constant"() {torch.onnx.value = dense<7> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %0:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %past_key, %past_value, %seqlens_k, %total_seq_len) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,4,16],f32>, 
!torch.vtensor<[1,2,3,8],f32>, !torch.vtensor<[1,2,3,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> (!torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,2,7,8],f32>, !torch.vtensor<[1,2,7,8],f32>) - return %0#0, %0#1, %0#2 : !torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,2,7,8],f32>, !torch.vtensor<[1,2,7,8],f32> -} - -// ----- - -// Test GQA with multi-token prefill (seqLen=2) -// CHECK-LABEL: func.func @test_group_query_attention_prefill_mask_shape -func.func @test_group_query_attention_prefill_mask_shape(%query: !torch.vtensor<[1,2,16],f32>, %key: !torch.vtensor<[1,2,16],f32>, %value: !torch.vtensor<[1,2,16],f32>, %past_key: !torch.vtensor<[1,2,3,8],f32>, %past_value: !torch.vtensor<[1,2,3,8],f32>) -> (!torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,5,8],f32>, !torch.vtensor<[1,2,5,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,2,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,2,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,2,8],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd %arg3, {{.*}} -> !torch.vtensor<[1,2,5,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd %arg4, {{.*}} -> !torch.vtensor<[1,2,5,8],f32> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,5,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,5,8],f32> - // CHECK: %[[MASK:.+]] = torch.aten.le.Tensor {{.*}} -> !torch.vtensor<[1,2,5],i1> - // CHECK: %[[MASK_RESHAPE:.+]] = torch.aten.unsqueeze %[[MASK]], {{.*}} -> !torch.vtensor<[1,1,2,5],i1> - // CHECK: %[[OUTPUT:.+]] = 
torch.aten.scaled_dot_product_attention %[[Q_TRANSPOSE]], %[[K_SCATTER]], %[[V_SCATTER]], %[[MASK_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,2,8],f32> - // CHECK: %[[OUT_RESHAPE:.+]] = torch.aten.flatten.using_ints {{.*}} -> !torch.vtensor<[1,2,16],f32> - %seqlens_k = torch.operator "onnx.Constant"() {torch.onnx.value = dense<4> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %total_seq_len = torch.operator "onnx.Constant"() {torch.onnx.value = dense<5> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %0:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %past_key, %past_value, %seqlens_k, %total_seq_len) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,3,8],f32>, !torch.vtensor<[1,2,3,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> (!torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,5,8],f32>, !torch.vtensor<[1,2,5,8],f32>) - return %0#0, %0#1, %0#2 : !torch.vtensor<[1,2,16],f32>, !torch.vtensor<[1,2,5,8],f32>, !torch.vtensor<[1,2,5,8],f32> -} - -// ----- - -// Test GQA position ID calculation with rotary embeddings -// CHECK-LABEL: func.func @test_group_query_attention_position_ids -func.func @test_group_query_attention_position_ids(%query: !torch.vtensor<[1,4,16],f32>, %key: !torch.vtensor<[1,4,16],f32>, %value: !torch.vtensor<[1,4,16],f32>, %past_key: !torch.vtensor<[1,2,3,8],f32>, %past_value: !torch.vtensor<[1,2,3,8],f32>, %cos_cache: !torch.vtensor<[2,4],f32>, %sin_cache: !torch.vtensor<[2,4],f32>) -> (!torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,2,7,8],f32>, !torch.vtensor<[1,2,7,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - // CHECK: 
%[[K_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - // CHECK: %[[POS_IDS:.+]] = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1,4],si64> - // CHECK: %[[Q_ROTARY:.+]] = torch.onnx.rotary_embedding %[[Q_TRANSPOSE]], %[[POS_IDS]], {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - // CHECK: %[[K_ROTARY:.+]] = torch.onnx.rotary_embedding %[[K_TRANSPOSE]], %[[POS_IDS]], {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd %arg3, {{.*}} -> !torch.vtensor<[1,2,7,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd %arg4, {{.*}} -> !torch.vtensor<[1,2,7,8],f32> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_ROTARY]] : {{.*}} -> !torch.vtensor<[1,2,7,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,7,8],f32> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_ROTARY]], %[[K_SCATTER]], %[[V_SCATTER]], {{.*}} -> !torch.vtensor<[1,2,4,8],f32> - // CHECK: %[[OUT_RESHAPE:.+]] = torch.aten.flatten.using_ints {{.*}} -> !torch.vtensor<[1,4,16],f32> - %seqlens_k = torch.operator "onnx.Constant"() {torch.onnx.value = dense<6> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %total_seq_len = torch.operator "onnx.Constant"() {torch.onnx.value = dense<7> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %0:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %past_key, %past_value, %seqlens_k, %total_seq_len, %cos_cache, %sin_cache) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64, torch.onnx.do_rotary = 1 : si64} : (!torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,2,3,8],f32>, !torch.vtensor<[1,2,3,8],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>, 
!torch.vtensor<[2,4],f32>, !torch.vtensor<[2,4],f32>) -> (!torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,2,7,8],f32>, !torch.vtensor<[1,2,7,8],f32>) - return %0#0, %0#1, %0#2 : !torch.vtensor<[1,4,16],f32>, !torch.vtensor<[1,2,7,8],f32>, !torch.vtensor<[1,2,7,8],f32> -} - -// ----- - -// Test GQA with f16 inputs -// CHECK-LABEL: func.func @test_group_query_attention_f16 -func.func @test_group_query_attention_f16(%arg0: !torch.vtensor<[1,1,16],f16>, %arg1: !torch.vtensor<[1,1,16],f16>, %arg2: !torch.vtensor<[1,1,16],f16>) -> (!torch.vtensor<[1,1,16],f16>, !torch.vtensor<[1,2,1,8],f16>, !torch.vtensor<[1,2,1,8],f16>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f16> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f16> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[1,2,1,8],f16> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd {{.*}} -> !torch.vtensor<[1,2,1,8],f16> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd {{.*}} -> !torch.vtensor<[1,2,1,8],f16> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], {{.*}}, %[[K_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f16> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], {{.*}}, %[[V_TRANSPOSE]] : {{.*}} -> !torch.vtensor<[1,2,1,8],f16> - // CHECK: %[[MASK:.+]] = torch.aten.le.Tensor {{.*}} -> !torch.vtensor<[1,1,1],i1> - // CHECK: %[[MASK_RESHAPE:.+]] = torch.aten.unsqueeze %[[MASK]], {{.*}} -> !torch.vtensor<[1,1,1,1],i1> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_TRANSPOSE]], %[[K_SCATTER]], %[[V_SCATTER]], %[[MASK_RESHAPE]], {{.*}} -> !torch.vtensor<[1,2,1,8],f16> - %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : 
tensor<1x2x0x8xf16>} : () -> !torch.vtensor<[1,2,0,8],f16> - %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<> : tensor<1x2x0x8xf16>} : () -> !torch.vtensor<[1,2,0,8],f16> - %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %4:3 = torch.operator "onnx.GroupQueryAttention"(%arg0, %arg1, %arg2, %0, %1, %2, %3) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[1,1,16],f16>, !torch.vtensor<[1,1,16],f16>, !torch.vtensor<[1,1,16],f16>, !torch.vtensor<[1,2,0,8],f16>, !torch.vtensor<[1,2,0,8],f16>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> (!torch.vtensor<[1,1,16],f16>, !torch.vtensor<[1,2,1,8],f16>, !torch.vtensor<[1,2,1,8],f16>) - return %4#0, %4#1, %4#2 : !torch.vtensor<[1,1,16],f16>, !torch.vtensor<[1,2,1,8],f16>, !torch.vtensor<[1,2,1,8],f16> -} - -// ----- - -// Test GQA with variable seqlens_k across batch (batch=2) -// CHECK-LABEL: func.func @test_group_query_attention_variable_seqlens_k -func.func @test_group_query_attention_variable_seqlens_k(%query: !torch.vtensor<[2,1,16],f32>, %key: !torch.vtensor<[2,1,16],f32>, %value: !torch.vtensor<[2,1,16],f32>, %past_key: !torch.vtensor<[2,2,4,8],f32>, %past_value: !torch.vtensor<[2,2,4,8],f32>, %seqlens_k: !torch.vtensor<[2],si32>) -> (!torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,2,5,8],f32>, !torch.vtensor<[2,2,5,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[SEQLENS_K:.+]] = torch.aten.to.dtype %arg5, {{.*}} -> !torch.vtensor<[2],si64> - // CHECK: %[[Q_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[2,2,1,8],f32> - // CHECK: %[[K_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> 
!torch.vtensor<[2,2,1,8],f32> - // CHECK: %[[V_TRANSPOSE:.+]] = torch.aten.transpose.int {{.*}} -> !torch.vtensor<[2,2,1,8],f32> - // CHECK: %[[K_PAD:.+]] = torch.aten.constant_pad_nd %arg3, {{.*}} -> !torch.vtensor<[2,2,5,8],f32> - // CHECK: %[[V_PAD:.+]] = torch.aten.constant_pad_nd %arg4, {{.*}} -> !torch.vtensor<[2,2,5,8],f32> - // CHECK: %[[Q_RANGE:.+]] = torch.aten.arange {{.*}} -> !torch.vtensor<[1],si64> - // CHECK: %[[PAST_VIEW:.+]] = torch.aten.unsqueeze {{.*}} -> !torch.vtensor<[2,1,1,1],si64> - // CHECK: %[[Q_RANGE_4D:.+]] = torch.aten.unsqueeze {{.*}} -> !torch.vtensor<[1,1,1,1],si64> - // CHECK: %[[IDX_BASE:.+]] = torch.aten.add.Tensor %[[PAST_VIEW]], %[[Q_RANGE_4D]], {{.*}} -> !torch.vtensor<[2,1,1,1],si64> - // CHECK: %[[SCATTER_IDX:.+]] = torch.aten.expand %[[IDX_BASE]], {{.*}} -> !torch.vtensor<[2,2,1,8],si64> - // CHECK: %[[K_SCATTER:.+]] = torch.aten.scatter.src %[[K_PAD]], - // CHECK-SAME: %[[SCATTER_IDX]], %[[K_TRANSPOSE]] : - // CHECK-SAME: -> !torch.vtensor<[2,2,5,8],f32> - // CHECK: %[[V_SCATTER:.+]] = torch.aten.scatter.src %[[V_PAD]], - // CHECK-SAME: %[[SCATTER_IDX]], %[[V_TRANSPOSE]] : - // CHECK-SAME: -> !torch.vtensor<[2,2,5,8],f32> - // CHECK: %[[MASK:.+]] = torch.aten.le.Tensor {{.*}} -> !torch.vtensor<[2,1,5],i1> - // CHECK: %[[MASK_RESHAPE:.+]] = torch.aten.unsqueeze %[[MASK]], {{.*}} -> !torch.vtensor<[2,1,1,5],i1> - // CHECK: %[[OUTPUT:.+]] = torch.aten.scaled_dot_product_attention %[[Q_TRANSPOSE]], %[[K_SCATTER]], %[[V_SCATTER]], %[[MASK_RESHAPE]], {{.*}} -> !torch.vtensor<[2,2,1,8],f32> - %total_seq_len = torch.operator "onnx.Constant"() {torch.onnx.value = dense<5> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> - %0:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %past_key, %past_value, %seqlens_k, %total_seq_len) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,1,16],f32>, 
!torch.vtensor<[2,2,4,8],f32>, !torch.vtensor<[2,2,4,8],f32>, !torch.vtensor<[2],si32>, !torch.vtensor<[1],si32>) -> (!torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,2,5,8],f32>, !torch.vtensor<[2,2,5,8],f32>) - return %0#0, %0#1, %0#2 : !torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,2,5,8],f32>, !torch.vtensor<[2,2,5,8],f32> +// Test GQA with per-batch seqlens_k (batch=2) on f16 inputs whose past and present KV share the type [2,2,4,8]: the cache is expected to be written via index_put, with no constant_pad_nd. +// CHECK-LABEL: func.func @test_group_query_attention_variable_seqlens_k_f16 +// CHECK-SAME: %[[PAST_KEY:[a-zA-Z0-9_]+]]: !torch.vtensor<[2,2,4,8],f16> +// CHECK-SAME: %[[PAST_VALUE:[a-zA-Z0-9_]+]]: !torch.vtensor<[2,2,4,8],f16> +func.func @test_group_query_attention_variable_seqlens_k_f16(%query: !torch.vtensor<[2,1,16],f16>, %key: !torch.vtensor<[2,1,16],f16>, %value: !torch.vtensor<[2,1,16],f16>, %past_key: !torch.vtensor<[2,2,4,8],f16>, %past_value: !torch.vtensor<[2,2,4,8],f16>, %seqlens_k: !torch.vtensor<[2],si32>, %total_seq_length: !torch.vtensor<[1],si32>) -> (!torch.vtensor<[2,1,16],f16>, !torch.vtensor<[2,2,4,8],f16>, !torch.vtensor<[2,2,4,8],f16>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.to.dtype %{{.*}} -> !torch.vtensor<[2],si64> + // CHECK-NOT: torch.aten.constant_pad_nd + // CHECK: torch.aten.index_put.hacked_twin %{{.*}} -> !torch.vtensor<[16,8],f16> + // CHECK: torch.aten.index_put.hacked_twin %{{.*}} -> !torch.vtensor<[16,8],f16> + // CHECK: torch.aten.unsqueeze %{{.*}} -> !torch.vtensor<[2,1,1,4],i1> + // CHECK: torch.aten.scaled_dot_product_attention %{{.*}} -> !torch.vtensor<[2,2,1,8],f16> + %0:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %past_key, %past_value, %seqlens_k, %total_seq_length) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[2,1,16],f16>, !torch.vtensor<[2,1,16],f16>, !torch.vtensor<[2,1,16],f16>, !torch.vtensor<[2,2,4,8],f16>, !torch.vtensor<[2,2,4,8],f16>, !torch.vtensor<[2],si32>,
!torch.vtensor<[1],si32>) -> (!torch.vtensor<[2,1,16],f16>, !torch.vtensor<[2,2,4,8],f16>, !torch.vtensor<[2,2,4,8],f16>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[2,1,16],f16>, !torch.vtensor<[2,2,4,8],f16>, !torch.vtensor<[2,2,4,8],f16> +} + +// ----- + +// Test GQA with a 2-D seqlens_k of shape [2,1]: it is expected to be cast to si64 and squeezed to [2] before feeding add.Scalar. +// CHECK-LABEL: func.func @test_group_query_attention_2d_seqlens_k +// CHECK-SAME: %[[SEQLENS_K:[a-zA-Z0-9_]+]]: !torch.vtensor<[2,1],si32> +func.func @test_group_query_attention_2d_seqlens_k(%query: !torch.vtensor<[2,1,16],f32>, %key: !torch.vtensor<[2,1,16],f32>, %value: !torch.vtensor<[2,1,16],f32>, %past_key: !torch.vtensor<[2,2,4,8],f32>, %past_value: !torch.vtensor<[2,2,4,8],f32>, %seqlens_k: !torch.vtensor<[2,1],si32>, %total_seq_length: !torch.vtensor<[1],si32>) -> (!torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,2,4,8],f32>, !torch.vtensor<[2,2,4,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SEQLENS_I64:.+]] = torch.aten.to.dtype %[[SEQLENS_K]]{{.*}} -> !torch.vtensor<[2,1],si64> + // CHECK: %[[SEQLENS_1D:.+]] = torch.aten.squeeze.dim %[[SEQLENS_I64]], %{{.*}} : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64> + // CHECK: torch.aten.add.Scalar %[[SEQLENS_1D]] + // CHECK-NOT: torch.aten.constant_pad_nd + // CHECK: torch.aten.index_put.hacked_twin %{{.*}} -> !torch.vtensor<[16,8],f32> + // CHECK: torch.aten.index_put.hacked_twin %{{.*}} -> !torch.vtensor<[16,8],f32> + // CHECK: torch.aten.scaled_dot_product_attention %{{.*}} -> !torch.vtensor<[2,2,1,8],f32> + %0:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %past_key, %past_value, %seqlens_k, %total_seq_length) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64} : (!torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,2,4,8],f32>, !torch.vtensor<[2,2,4,8],f32>, !torch.vtensor<[2,1],si32>,
!torch.vtensor<[1],si32>) -> (!torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,2,4,8],f32>, !torch.vtensor<[2,2,4,8],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[2,1,16],f32>, !torch.vtensor<[2,2,4,8],f32>, !torch.vtensor<[2,2,4,8],f32> +} + +// ----- + +// Test GQA with separate query/key/value inputs and do_rotary = 1 (query length 3, past and present KV of type [1,2,6,8]): expects two rotary_embedding ops, index_put cache writes with no constant_pad_nd, and an i1 tensor of shape [1,1,3,6]. +// CHECK-LABEL: func.func @test_group_query_attention_separate_qkv_rotary +func.func @test_group_query_attention_separate_qkv_rotary(%query: !torch.vtensor<[1,3,16],f32>, %key: !torch.vtensor<[1,3,16],f32>, %value: !torch.vtensor<[1,3,16],f32>, %past_key: !torch.vtensor<[1,2,6,8],f32>, %past_value: !torch.vtensor<[1,2,6,8],f32>, %seqlens_k: !torch.vtensor<[1],si32>, %total_seq_length: !torch.vtensor<[1],si32>, %cos_cache: !torch.vtensor<[16,4],f32>, %sin_cache: !torch.vtensor<[16,4],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[1,2,6,8],f32>, !torch.vtensor<[1,2,6,8],f32>) attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.onnx.rotary_embedding %{{.*}} -> !torch.vtensor<[1,2,3,8],f32> + // CHECK: torch.onnx.rotary_embedding %{{.*}} -> !torch.vtensor<[1,2,3,8],f32> + // CHECK-NOT: torch.aten.constant_pad_nd + // CHECK: torch.aten.index_put.hacked_twin %{{.*}} -> !torch.vtensor<[12,8],f32> + // CHECK: torch.aten.index_put.hacked_twin %{{.*}} -> !torch.vtensor<[12,8],f32> + // CHECK: torch.aten.unsqueeze %{{.*}} -> !torch.vtensor<[1,1,3,6],i1> + // CHECK: torch.aten.scaled_dot_product_attention %{{.*}} -> !torch.vtensor<[1,2,3,8],f32> + %0:3 = torch.operator "onnx.GroupQueryAttention"(%query, %key, %value, %past_key, %past_value, %seqlens_k, %total_seq_length, %cos_cache, %sin_cache) {torch.onnx.kv_num_heads = 2 : si64, torch.onnx.num_heads = 2 : si64, torch.onnx.do_rotary = 1 : si64} : (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[1,3,16],f32>, !torch.vtensor<[1,3,16],f32>, !torch.vtensor<[1,2,6,8],f32>, !torch.vtensor<[1,2,6,8],f32>, !torch.vtensor<[1],si32>,
!torch.vtensor<[1],si32>, !torch.vtensor<[16,4],f32>, !torch.vtensor<[16,4],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[1,2,6,8],f32>, !torch.vtensor<[1,2,6,8],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[1,3,16],f32>, !torch.vtensor<[1,2,6,8],f32>, !torch.vtensor<[1,2,6,8],f32> }