From ceeafaf64f346c6f14a67c612e131da5c27ef620 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Mon, 8 Jan 2024 20:42:05 +0800 Subject: [PATCH] [CPU] Support group beam search (#21983) * support group beam search * support dyn batch without set_state * apply review comments * strides may be incorrect when batch changed --------- Co-authored-by: Yu Xu --- src/plugins/intel_cpu/src/memory_state.cpp | 8 +- .../intel_cpu/src/nodes/scaled_attn.cpp | 148 ++++++++++- src/plugins/intel_cpu/src/nodes/scaled_attn.h | 1 + .../cpu_opset/common/op/sdpa.cpp | 7 +- .../src/sdpa_group_beam_search.cpp | 231 ++++++++++++++++++ 5 files changed, 383 insertions(+), 12 deletions(-) create mode 100644 src/plugins/intel_cpu/tests/functional/subgraph_tests/src/sdpa_group_beam_search.cpp diff --git a/src/plugins/intel_cpu/src/memory_state.cpp b/src/plugins/intel_cpu/src/memory_state.cpp index 371e60c0f0ea4d..758b7530dbd0e1 100644 --- a/src/plugins/intel_cpu/src/memory_state.cpp +++ b/src/plugins/intel_cpu/src/memory_state.cpp @@ -164,8 +164,12 @@ VariableStateKVcache::VariableStateKVcache( } ov::SoPtr VariableStateKVcache::get_state() const { - OPENVINO_ASSERT(m_internal_mem && m_hidden_state, "KVState internal memory is not initialized"); - OPENVINO_ASSERT(!is_reset_state(), "KVState is undefined after reset"); + if (!m_internal_mem || !m_hidden_state || is_reset_state()) { + auto new_desc = to_static(get_external_desc()); + auto external_mem = std::make_shared(get_engine(), new_desc); + return std::make_shared(external_mem); + } + auto actual_internal_desc = m_internal_mem->getDescWithType(); auto&& dims = actual_internal_desc->getShape().getStaticDims(); diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index 3dc40ff00c8c8c..c10b7038d637ab 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -824,13 +824,153 @@ void ScaledDotProductAttention::assignState(const std::shared_ptr order = {0, 1, 2, 3}; + if (!m_config.config.permute_axes.empty()) { + order = m_config.config.permute_axes; + } + PlainTensor beam_idx, old_beam_table_k; + auto old_hidden_state_k = m_k_state->hidden_state_mem(); + beam_idx.reset(mem_beam_idx); + + auto inputNumber = getOriginalInputsNumber(); + auto&& v_dims = getParentEdgeAt(inputNumber - 1)->getMemory().getStaticDims(); + size_t L0 = v_dims.at(order[2]); + auto B_state = v_dims.at(order[0]); + old_beam_table_k.reset(old_hidden_state_k); + + PlainTensor cur_k; + PlainTensor cur_v; + cur_k.reset(mem_cur_k); + cur_v.reset(mem_cur_v); + cur_k = cur_k.permute(order); + cur_v = cur_v.permute(order); + auto B = cur_k.size(0); + auto H = cur_k.size(1); + auto L1 = cur_k.size(2); + auto S = cur_k.size(3); + auto reverse = [&order] (const std::vector& cur) { + std::vector result(cur.size()); + for (size_t i = 0; i < cur.size(); i++) { + result[order[i]] = cur[i]; + } + return result; + }; + + // 1. check beam idx if it's valid + auto* table = beam_idx.data(); + for (size_t i = 0; i < B; i++) { + OPENVINO_ASSERT(static_cast(table[i]) < B_state, "beam_idx[", i, "]=", table[i], + " should less than batch of previous pastkv: ", B_state); + } + + // 2. resize pastkv + { + auto shape = {B, H, (L0 + L1) * 2, S}; + auto mem_desc = std::make_shared(m_kvcache_precision, + Shape(reverse(shape)), + shape, + order); + auto new_internal_mem_k = std::make_shared(getEngine(), mem_desc); + auto new_internal_mem_v = std::make_shared(getEngine(), mem_desc); + + PlainTensor new_pastk, new_pastv, old_past_k, old_past_v; + new_pastk.reset(new_internal_mem_k); + new_pastv.reset(new_internal_mem_v); + new_pastk = new_pastk.permute(order); + new_pastv = new_pastv.permute(order); + if (L0 > 0) { + auto old_internal_mem_k = m_k_state->internal_state_mem(); + auto old_internal_mem_v = m_v_state->internal_state_mem(); + old_past_k.reset(old_internal_mem_k); + old_past_v.reset(old_internal_mem_v); + old_past_k = old_past_k.permute(order); + old_past_v = old_past_v.permute(order); + parallel_for3d(B, H, L0, [&](size_t b, size_t h, size_t m) { + auto idx = static_cast(table[b]); + auto b_kv = static_cast(old_beam_table_k.at({idx, m})); + memcpy(&new_pastk.at({b, h, m}), + &old_past_k.at({b_kv, h, m}), + S * old_past_k.m_element_size); + memcpy(&new_pastv.at({b, h, m}), + &old_past_v.at({b_kv, h, m}), + S * old_past_v.m_element_size); + }); + } + + auto new_shape = {B, H, (L0 + L1), S}; + mem_desc = std::make_shared(m_kvcache_precision, + Shape(reverse(new_shape)), + new_shape, + order, + 0, + VectorDims{}, + mem_desc->getStrides()); + new_internal_mem_k->redefineDesc(mem_desc); + new_internal_mem_v->redefineDesc(mem_desc); + attn_memcpy(cur_k, cur_v, new_pastk.slice(2, L0, L0 + L1), new_pastv.slice(2, L0, L0 + L1)); + + m_k_state->assign_internal_state(new_internal_mem_k); + m_v_state->assign_internal_state(new_internal_mem_v); + m_k_state->assign_internal_state_max_size(B * H * (L0 + L1) * 2 * S); + m_v_state->assign_internal_state_max_size(B * H * (L0 + L1) * 2 * S); + } + // 3. create beam table + { + auto mem_desc = std::make_shared(ov::element::i32, Shape{B, (L0 + L1) * 2}); + + auto new_hidden_state_k = std::make_shared(getEngine(), mem_desc); + auto new_hidden_state_v = std::make_shared(getEngine(), mem_desc); + PlainTensor new_beam_table_k, new_beam_table_v; + new_beam_table_k.reset(new_hidden_state_k); + new_beam_table_v.reset(new_hidden_state_v); + + for (size_t b = 0; b < B; b++) { + for (size_t l = 0; l < L0 + L1; l++) { + new_beam_table_k.at({b, l}) = b; + new_beam_table_v.at({b, l}) = b; + } + } + + std::vector new_shape{B, (L0 + L1)}; + mem_desc = std::make_shared(ov::element::i32, + Shape(new_shape), + new_shape, + VectorDims{0, 1}, + 0, + VectorDims{}, + mem_desc->getStrides()); + new_hidden_state_k->redefineDesc(mem_desc); + new_hidden_state_v->redefineDesc(mem_desc); + + m_k_state->assign_hidden_state(new_hidden_state_k); + m_v_state->assign_hidden_state(new_hidden_state_v); + m_k_state->assign_hidden_state_max_size(B * (L0 + L1) * 2); + m_v_state->assign_hidden_state_max_size(B * (L0 + L1) * 2); + } +} + void ScaledDotProductAttention::gatherConcatPastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx) { PlainTensor cur_k; cur_k.reset(mem_cur_k); - if (!m_config.config.permute_axes.empty()) + auto inputNumber = getOriginalInputsNumber(); + auto&& v_dims = getParentEdgeAt(inputNumber - 1)->getMemory().getStaticDims(); + size_t B_state; + if (!m_config.config.permute_axes.empty()) { cur_k = cur_k.permute(m_config.config.permute_axes); + B_state = v_dims.at(m_config.config.permute_axes[0]); + } else { + B_state = v_dims.at(0); + } + + auto B = cur_k.size(0); + auto L1 = cur_k.size(2); + if (B != B_state) { + resetBeamTablePastkv(mem_cur_k, mem_cur_v, mem_beam_idx); + return; + } - updateBeamTable(mem_beam_idx, cur_k.size(2)); + updateBeamTable(mem_beam_idx, L1); updatePastkv(mem_cur_k, mem_cur_v); } @@ -858,7 +998,7 @@ void ScaledDotProductAttention::updateBeamTable(const MemoryPtr& mem_beam_idx, s OPENVINO_ASSERT(B == B_state, "beam idx batch: ", B, " is not equal to batch of state: ", B_state); OPENVINO_ASSERT(B * (L0 + L1) > 0, "B or (L0+L1) is zero, B: ", B, ", L0: ", L0, ", L1: ", L1); // resize buffer - if (B * (L0 + L1) > m_k_state->hidden_state_max_size()) { + if (is_reset || B * (L0 + L1) > m_k_state->hidden_state_max_size()) { auto mem_desc = std::make_shared(ov::element::i32, Shape{B, (L0 + L1) * 2}); auto new_hidden_state_k = std::make_shared(getEngine(), mem_desc); @@ -981,7 +1121,7 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M OPENVINO_ASSERT(B == B_state, "pastkv batch: ", B, " is not equal to batch of state: ", B_state); OPENVINO_ASSERT(B * (L0 + L1) > 0, "B or (L0+L1) is zero, B: ", B, ", L0: ", L0, ", L1: ", L1); // resize buffer - if (B * H * (L0 + L1) * S > m_k_state->internal_state_max_size()) { + if (is_reset || B * H * (L0 + L1) * S > m_k_state->internal_state_max_size()) { auto new_shape = {B, H, (L0 + L1) * 2, S}; auto mem_desc = std::make_shared(m_kvcache_precision, Shape(reverse(new_shape)), diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h index 4cb09ac32d242f..4b25dc0c0fdeee 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h @@ -56,6 +56,7 @@ class ScaledDotProductAttention : public Node { void updateBeamTable(const MemoryPtr& mem_beam_idx, size_t new_q_len); void updatePastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v); ov::element::Type getRuntimePrecision() const override; + void resetBeamTablePastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx); struct Config { ScaledDotProductAttentionWithKVCache::Config config; diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp index f581d981797513..31bce21d3579d3 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp @@ -46,13 +46,8 @@ void ov::intel_cpu::ScaledDotProductAttentionWithKVCache::validate_and_infer_typ "shape not compatiable at index ", i); } - } else if (i == length_index) { - continue; } else { - NODE_VALIDATION_CHECK(this, - q_ps[i].compatible(past_kv_ps[i]), - "shape not compatiable at index ", - i); + continue; } } past_kv_ps[length_index] += q_ps[length_index]; diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/sdpa_group_beam_search.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/sdpa_group_beam_search.cpp new file mode 100644 index 00000000000000..13de6c5dc11e9e --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/sdpa_group_beam_search.cpp @@ -0,0 +1,231 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "openvino/opsets/opset13.hpp" +#include "transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp" + +#include "ov_models/utils/ov_helpers.hpp" +#include "shared_test_classes/base/layer_test_utils.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "test_utils/cpu_test_utils.hpp" +#include "common_test_utils/ov_tensor_utils.hpp" + +using namespace CPUTestUtils; + +namespace ov { +namespace test { + +using SDPAGroupBeamSearchTestParams = std::tuple + >; +// Subgraph: +/* Parameter + * | + * Parameter ReadValue | ReadValue Parameter + * \ / | \ / + * Gather / Gather / + * \ / | \ / + * Concat | Concat + * / \ | / \ + * / \ | / \ + * / \ | / \ + * Assign ScaledDotProductAttention Assign + * | + * Add + * | + * Result + */ + +class SDPAGroupBeamSearchTest : public testing::WithParamInterface, + virtual public ov::test::SubgraphBaseTest, + public CPUTestsBase { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + ElementType inType; + std::vector inputShapes; + std::tie(inType, inputShapes) = obj.param; + std::ostringstream result; + result << "IS="; + for (const auto& shape : inputShapes) { + result << ov::test::utils::partialShape2str({shape.first}) << "_"; + } + result << "TS="; + for (const auto& shape : inputShapes) { + result << "("; + if (!shape.second.empty()) { + for (const auto& itr : shape.second) { + result << ov::test::utils::vec2str(itr); + } + } + result << ")_"; + } + result << "Prc=" << inType; + return result.str(); + } + + void SetUp() override { + ElementType inType; + std::vector inputShapes; + std::tie(inType, inputShapes) = this->GetParam(); + targetDevice = ov::test::utils::DEVICE_CPU; + rel_threshold = 1e-2f; + if (inType == ElementType::bf16) { + configuration.insert({"ENFORCE_BF16", "YES"}); + rel_threshold = 0.01f; + } + init_input_shapes(inputShapes); + ov::ParameterVector inputParams; + // q,k,v + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); + inputParams[0]->set_friendly_name("q"); + inputParams[1]->set_friendly_name("k"); + inputParams[2]->set_friendly_name("v"); + // pastkv init_cost + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[1])); + auto var_k = std::make_shared( + ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastk"}); + auto pastk = std::make_shared(inputParams[3], var_k); + pastk->set_friendly_name("pastk_r"); + auto var_v = std::make_shared( + ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastv"}); + auto pastv = std::make_shared(inputParams[3], var_v); + pastv->set_friendly_name("pastv_r"); + + auto beam_idx = std::make_shared(ElementType::i32, ov::PartialShape{-1}); + beam_idx->set_friendly_name("beam_idx"); + inputParams.push_back(beam_idx); + auto gatherK = std::make_shared(pastk, beam_idx, op::v0::Constant::create(ElementType::i32, {1}, {0})); + auto gatherV = std::make_shared(pastv, beam_idx, op::v0::Constant::create(ElementType::i32, {1}, {0})); + auto concatK = std::make_shared(OutputVector{gatherK, inputParams[1]}, 2); + auto concatV = std::make_shared(OutputVector{gatherV, inputParams[2]}, 2); + auto sdp = std::make_shared(inputParams[0], concatK, concatV, false); + sdp->set_friendly_name("mha"); + auto add = std::make_shared(sdp, op::v0::Constant::create(inType, {1}, {1.0f})); + auto pastk_assign = std::make_shared(concatK, var_k); + auto pastv_assign = std::make_shared(concatV, var_v); + pastk_assign->set_friendly_name("pastk_w"); + pastv_assign->set_friendly_name("pastv_w"); + + ResultVector results{std::make_shared(add)}; + + SinkVector sinks{pastk_assign, pastv_assign}; + function = std::make_shared(results, sinks, inputParams, "ConcatSDP"); + targetDevice = ov::test::utils::DEVICE_CPU; + + functionRefs = function->clone(); + pass::Manager manager; + // decompose ScaledDotProductAttention + manager.register_pass(); + manager.run_passes(functionRefs); + } + void generate_inputs(const std::vector& targetInputStaticShapes) override { + std::vector shapes(4); + shapes[0] = targetInputStaticShapes[0]; + shapes[1] = targetInputStaticShapes[0]; + shapes[2] = targetInputStaticShapes[0]; + shapes[3] = targetInputStaticShapes[1]; + SubgraphBaseTest::generate_inputs(shapes); + } + template + void strided_iota(IT first, size_t n, T value, T stride) { + for (size_t i = 0; i < n; i++) { + *first++ = value; + value += stride; + } + } + void generate(int idx, const std::vector& targetInputStaticShapes, size_t beam_num) { + inputs.clear(); + auto create_input = [this, beam_num] (std::shared_ptr param, ov::Shape shape, float val) { + if (param->get_element_type() == element::i32) { + ov::Tensor t{ov::element::i32, shape}; + auto size = shape[0]; + auto* p = static_cast(t.data()); + auto start = static_cast(val); + for (size_t i = 0; i < size; i++) { + p[i] = (start + i) % beam_num; + } + inputs.insert({param, t}); + } else if (param->get_element_type() == element::f32) { + ov::Tensor t{ov::element::f32, shape}; + strided_iota(static_cast(t.data()), t.get_size(), val, 0.1f); + inputs.insert({param, t}); + } else { + ov::Tensor t{ov::element::bf16, shape}; + strided_iota(static_cast(t.data()), t.get_size(), val, 0.1f); + inputs.insert({param, t}); + } + }; + // q, k, v, pastkv + create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f); + create_input(function->get_parameters()[1], targetInputStaticShapes[0], idx + 2.0f); + create_input(function->get_parameters()[2], targetInputStaticShapes[0], idx + 3.0f); + create_input(function->get_parameters()[3], targetInputStaticShapes[1], idx + 4.0f); + create_input(function->get_parameters()[4], ov::Shape{targetInputStaticShapes[0][0]}, idx + 0.0f); + } + void prepare() { + compile_model(); + inferRequest = compiledModel.create_infer_request(); + ASSERT_TRUE(inferRequest); + } + void reset() { + for (auto&& state : inferRequest.query_state()) { + state.reset(); + } + } + std::vector run_test(std::shared_ptr model) { + function = model; + prepare(); + std::vector outputs; + + for (int idx = 0; idx < static_cast(targetStaticShapes.size()); idx++) { + auto& shapes = targetStaticShapes[idx]; + generate(idx, shapes, targetStaticShapes[idx > 0 ? idx - 1 : 0][0][0]); + for (const auto& input : inputs) { + inferRequest.set_tensor(input.first, input.second); + } + inferRequest.infer(); + auto outputTensor = inferRequest.get_output_tensor(0); + ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()}; + outputTensor.copy_to(copy); + outputs.push_back(copy); + } + reset(); + + return outputs; + } +}; + +TEST_P(SDPAGroupBeamSearchTest, CompareWithRefs) { + auto actualOutputs = run_test(function); + CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1); + CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0); + CheckNumberOfNodesWithType(compiledModel, "Reorder", 0); + CheckNumberOfNodesWithType(compiledModel, "Gather", 0); + auto expectedOutputs = run_test(functionRefs); + CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0); + for (size_t i = 0; i < actualOutputs.size(); i++) { + ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold); + } +} + +namespace { +const std::vector> inputShapes = { + { + // B, H, L1, S + {{-1, 8, -1, 64}, {{1, 8, 10, 64}, {4, 8, 1, 64}, {2, 8, 1, 64}, {4, 8, 1, 64}, {4, 8, 1, 64}}}, + // B, H, L0, S + {{-1, 8, -1, 64}, {{1, 8, 0, 64}, {4, 8, 10, 64}, {2, 8, 11, 64}, {4, 8, 12, 64}, {4, 8, 13, 64}}}, + }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_SDPAGroupBeamSearchTest, + SDPAGroupBeamSearchTest, + ::testing::Combine(::testing::Values(ElementType::f32), + ::testing::ValuesIn(inputShapes)), + SDPAGroupBeamSearchTest::getTestCaseName); + +} // namespace +} // namespace test +} // namespace ov