[Arm] Enable Gather MatMul with KleidiAI Microkernels #34303
abhijain1204fujitsu wants to merge 9 commits into openvinotoolkit:master
Conversation

    #else

    ov::element::Type getRuntimePrecision() const override;
    Algorithm algorithm = Algorithm::GatherMatmulDefault;
    size_t numExperts = 0;

    std::vector<ExecutorPtr> executor;
    std::vector<MemoryArgs> memArgsFC;

    MemoryPtr m_weightsMemory = nullptr;
    MemoryPtr m_tmpInpBuffer = nullptr;
    MemoryDescPtr m_tmpInputDesc = nullptr;
    MemoryDescPtr m_tmpOutputDesc = nullptr;

    #endif
Some fields are clearly duplicated between the if and else branches. Should we narrow the scope?

I assume this file is a temporary solution and the ARM-specific implementation will be moved to the corresponding executor.

        continue;
    }

    parallel_for(num_valid_rows, [&](size_t m) {
It's better to use the CpuParallel class in such contexts to align the implementation with the x64 approach.

Can't we reuse the existing x64 test by moving it to the common scope and enabling the corresponding instances for ARM?
Hi @maxnick. Thanks for the comment. I have modified the implementation, moving some of the GatherMatmul logic to KleidiAIExecutor while keeping the executor interface light, as discussed. I have also integrated the logic into the same file, gathermatmul.cpp, and reused the existing x86 code. I will update this PR with the new refactored logic in the coming week once it's approved internally, and will move the relevant tests to the common scope as well.
Force-pushed: c4f4a52 → 71109bc
Hi @maxnick, I have made the requested changes and fixed the test cases; please review the PR. Thanks!

    TEST_P(MoECompressedWeightsSubgraphTest, CompareWithRefs) {
        SKIP_IF_CURRENT_TEST_IS_DISABLED()
    #ifndef OPENVINO_ARCH_X86
OPENVINO_ARCH_X86_64 should be checked here as well.
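A minimal sketch of the combined guard the comment asks for, assuming the intent is to take the non-x86 branch only when neither 32-bit nor 64-bit x86 is targeted (the branch body is a placeholder, not the actual test code):

```cpp
// Sketch only: guard that treats both x86 variants as x86, so ARM-specific
// skips/thresholds are applied only on genuinely non-x86 builds.
#if !defined(OPENVINO_ARCH_X86) && !defined(OPENVINO_ARCH_X86_64)
    // ARM (and other non-x86) specific handling goes here
#endif
```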

            reduce_node->clone_with_new_inputs({new_final_mul->output(0), reduce_node->input_value(1)});
        ov::copy_runtime_info(reduce_node, new_reduce_node);
        new_reduce_sum =
            squeeze_node->clone_with_new_inputs({new_reduce_node->output(0), new_reduce_node->input_value(1)});
Are we assuming that the squeeze axes equal the reduce axes, so we can use new_reduce_node->input_value(1) in the new_reduce_sum initialization? Shouldn't we use squeeze_node->input_value(1) here?

Hi @alvoron, I have made the changes as per your suggestion. Thanks!
Yes, it was based on the assumption that the squeeze axis and the reduce axis are the same. But since we do not intend to make any modification to the original pattern here, I think it makes sense not to rely on that assumption.
Force-pushed: 4a59819 → 6aa7b53
Force-pushed: 6aa7b53 → c48e78b
Pull request overview
Note: Copilot was unable to run its full agentic suite in this review.
Enables GatherMatmul / GatherMatmul-Compressed on ARM by wiring the MoE-to-GatherMatmul transformation through the common CPU pass pipeline and adding a KleidiAI-backed execution path for ARM64.
Changes:
- Register MoE→GatherMatmul and GatherMatmul→Compressed conversions for non-x86 CPU builds.
- Add ARM64 KleidiAI execution path for GatherMatmul (including compressed weights constraints) and fix handling of transposed weights before packing.
- Extend the MoE transformation + tests to also match ReduceSum implemented as `keep_dims=true` followed by `Squeeze` (see the sketch after this list).
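For illustration, here is a minimal sketch of the two equivalent ReduceSum forms that the transformation now matches; the element type, shapes, and reduction axis are arbitrary example values, not taken from the PR:

```cpp
#include <memory>
#include <utility>

#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/squeeze.hpp"

// Builds both subgraph forms over an example [8, 4, 16] input reduced along axis 1.
auto make_reduce_sum_forms() {
    auto x = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{8, 4, 16});
    auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});

    // Form 1: ReduceSum with keep_dims=false -> output shape [8, 16].
    auto form1 = std::make_shared<ov::op::v1::ReduceSum>(x, axis, /*keep_dims=*/false);

    // Form 2: ReduceSum with keep_dims=true (shape [8, 1, 16]) followed by Squeeze
    // on the same axis -> also [8, 16]; the pattern now recognizes this variant too.
    auto keep = std::make_shared<ov::op::v1::ReduceSum>(x, axis, /*keep_dims=*/true);
    auto form2 = std::make_shared<ov::op::v0::Squeeze>(keep, axis);

    return std::make_pair(form1, form2);
}
```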
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/moe.cpp | Adjusts ARM-specific thresholds, skips, and config/parameter generation for compressed MoE tests. |
| src/plugins/intel_cpu/src/transformations/cpu_opset/convert_to_cpu_specific_opset.hpp | Registers MoE conversion passes for all CPU arches (not only x64). |
| src/plugins/intel_cpu/src/nodes/gathermatmul.h | Adds kernel-type selection (OneDNN vs KleidiAI) and ARM-only executor state. |
| src/plugins/intel_cpu/src/nodes/gathermatmul.cpp | Implements ARM64 KleidiAI primitive creation/prepare/execute paths and ARM compressed-type gating. |
| src/plugins/intel_cpu/src/nodes/executors/kleidiai/kleidiai_mm.hpp | Adds GatherMatmul-specific mode + API to provide gather/scatter index map. |
| src/plugins/intel_cpu/src/nodes/executors/kleidiai/kleidiai_mm.cpp | Implements GatherMatmul gather/scatter flow + fixes repack path for transposed weights. |
| src/common/transformations/tests/common_optimizations/convert_tiled_moe_block_to_gather_matmuls_test.cpp | Expands test matrix to cover ReduceSum keep-dims + Squeeze form. |
| src/common/transformations/src/transformations/common_optimizations/convert_tiled_moe_block_to_gather_matmuls.cpp | Updates pattern to match both ReduceSum forms and clones reduce+squeeze appropriately. |

    if (pm.find(p.reduceSum_squeeze) != pm.end()) {
        const auto reduce_node = pm.at(p.reduceSum_keepDims).get_node_shared_ptr();
        const auto squeeze_node = pm.at(p.reduceSum_squeeze).get_node_shared_ptr();
        const auto new_reduce_node =
            reduce_node->clone_with_new_inputs({new_final_mul->output(0), reduce_node->input_value(1)});
        ov::copy_runtime_info(reduce_node, new_reduce_node);
        new_reduce_sum =
            squeeze_node->clone_with_new_inputs({new_reduce_node->output(0), squeeze_node->input_value(1)});
        ov::copy_runtime_info(squeeze_node, new_reduce_sum);
    } else {
        const auto reduce_node = pm.at(p.reduceSum_noKeepDims).get_node_shared_ptr();
        new_reduce_sum =
            reduce_node->clone_with_new_inputs({new_final_mul->output(0), reduce_node->input_value(1)});
        ov::copy_runtime_info(reduce_node, new_reduce_sum);
    }

    new_reduce_sum->set_friendly_name(p.reduce_sum->get_friendly_name());
p.reduce_sum is a pattern node (an OR of two alternatives), so p.reduce_sum->get_friendly_name() is not guaranteed to match the friendly name of the actually matched runtime node. Prefer taking the friendly name from the matched node (pm.at(p.reduceSum_squeeze) or pm.at(p.reduceSum_noKeepDims)), and (in the keep-dims path) also consider preserving the ReduceSum node name separately if name-stability is important for debugging/telemetry.
Current:

    if (pm.find(p.reduceSum_squeeze) != pm.end()) {
        const auto reduce_node = pm.at(p.reduceSum_keepDims).get_node_shared_ptr();
        const auto squeeze_node = pm.at(p.reduceSum_squeeze).get_node_shared_ptr();
        const auto new_reduce_node =
            reduce_node->clone_with_new_inputs({new_final_mul->output(0), reduce_node->input_value(1)});
        ov::copy_runtime_info(reduce_node, new_reduce_node);
        new_reduce_sum =
            squeeze_node->clone_with_new_inputs({new_reduce_node->output(0), squeeze_node->input_value(1)});
        ov::copy_runtime_info(squeeze_node, new_reduce_sum);
    } else {
        const auto reduce_node = pm.at(p.reduceSum_noKeepDims).get_node_shared_ptr();
        new_reduce_sum =
            reduce_node->clone_with_new_inputs({new_final_mul->output(0), reduce_node->input_value(1)});
        ov::copy_runtime_info(reduce_node, new_reduce_sum);
    }
    new_reduce_sum->set_friendly_name(p.reduce_sum->get_friendly_name());

Suggested:

    std::string new_reduce_sum_friendly_name;
    if (pm.find(p.reduceSum_squeeze) != pm.end()) {
        const auto reduce_node = pm.at(p.reduceSum_keepDims).get_node_shared_ptr();
        const auto squeeze_node = pm.at(p.reduceSum_squeeze).get_node_shared_ptr();
        const auto new_reduce_node =
            reduce_node->clone_with_new_inputs({new_final_mul->output(0), reduce_node->input_value(1)});
        ov::copy_runtime_info(reduce_node, new_reduce_node);
        new_reduce_node->set_friendly_name(reduce_node->get_friendly_name());
        new_reduce_sum =
            squeeze_node->clone_with_new_inputs({new_reduce_node->output(0), squeeze_node->input_value(1)});
        ov::copy_runtime_info(squeeze_node, new_reduce_sum);
        new_reduce_sum_friendly_name = squeeze_node->get_friendly_name();
    } else {
        const auto reduce_node = pm.at(p.reduceSum_noKeepDims).get_node_shared_ptr();
        new_reduce_sum =
            reduce_node->clone_with_new_inputs({new_final_mul->output(0), reduce_node->input_value(1)});
        ov::copy_runtime_info(reduce_node, new_reduce_sum);
        new_reduce_sum_friendly_name = reduce_node->get_friendly_name();
    }
    new_reduce_sum->set_friendly_name(new_reduce_sum_friendly_name);
Incorporated these changes in the latest commit.

    p.reduceSum_keepDims = pattern::wrap_type<ov::op::v1::ReduceSum>({p.mul3, pattern::any_input()},
                                                                     pattern::consumers_count(1),
                                                                     {{"keep_dims", true}});
    p.reduceSum_squeeze = pattern::wrap_type<ov::op::v0::Squeeze>({p.reduceSum_keepDims, pattern::any_input()});
    p.reduceSum_noKeepDims = pattern::wrap_type<ov::op::v1::ReduceSum>({p.mul3, pattern::any_input()},
                                                                       pattern::consumers_count(1),
                                                                       {{"keep_dims", false}});
The added pattern::consumers_count(1) constraint can make the transformation stop matching if ReduceSum has >1 consumer in real graphs (even if the subgraph is otherwise valid). If the single-consumer requirement is not strictly necessary for correctness, consider removing/relaxing it to avoid unintentionally disabling the MoE→GatherMatmul optimization.
Current:

    p.reduceSum_keepDims = pattern::wrap_type<ov::op::v1::ReduceSum>({p.mul3, pattern::any_input()},
                                                                     pattern::consumers_count(1),
                                                                     {{"keep_dims", true}});
    p.reduceSum_squeeze = pattern::wrap_type<ov::op::v0::Squeeze>({p.reduceSum_keepDims, pattern::any_input()});
    p.reduceSum_noKeepDims = pattern::wrap_type<ov::op::v1::ReduceSum>({p.mul3, pattern::any_input()},
                                                                       pattern::consumers_count(1),
                                                                       {{"keep_dims", false}});

Suggested:

    p.reduceSum_keepDims =
        pattern::wrap_type<ov::op::v1::ReduceSum>({p.mul3, pattern::any_input()}, {{"keep_dims", true}});
    p.reduceSum_squeeze = pattern::wrap_type<ov::op::v0::Squeeze>({p.reduceSum_keepDims, pattern::any_input()});
    p.reduceSum_noKeepDims =
        pattern::wrap_type<ov::op::v1::ReduceSum>({p.mul3, pattern::any_input()}, {{"keep_dims", false}});
We have not made this change, as we intentionally match only the single-consumer case via consumers_count(1).

    void moveMemToNumaNode(int numaNodeID) override;

    void setKaiExecutorImplAsGatherMatmul();
    void set_gather_idx(std::vector<std::pair<int32_t, int32_t>> idxMap);
set_gather_idx takes the vector by value, which forces an extra copy when the caller passes an lvalue. Prefer taking std::vector<...>&& (and std::move at the callsite) or const std::vector<...>& (and copying only if needed) to reduce per-iteration overhead in GatherMatmul execution.
Current:

    void set_gather_idx(std::vector<std::pair<int32_t, int32_t>> idxMap);

Suggested:

    void set_gather_idx(const std::vector<std::pair<int32_t, int32_t>>& idxMap);
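Beyond the const-reference signature suggested above, the comment also mentions an rvalue-reference variant with std::move at the call site. A minimal self-contained sketch of that pattern (Executor and m_gather_idx are illustrative names, not the actual KleidiAI executor interface):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Illustrative executor: the caller hands over the index map, so no per-call copy is made.
struct Executor {
    std::vector<std::pair<int32_t, int32_t>> m_gather_idx;

    void set_gather_idx(std::vector<std::pair<int32_t, int32_t>>&& idxMap) {
        m_gather_idx = std::move(idxMap);  // take ownership of the caller's buffer
    }
};

// Usage sketch at the call site:
//   executor.set_gather_idx(std::move(kai_gather_idx));
```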
Incorporated these changes in the latest commit.

    auto gather_idx_expertOffset = gather_idx_map.begin() + gather_axis_index * M;
    std::vector<std::pair<int32_t, int32_t>> kai_gather_idx(gather_idx_expertOffset,
                                                            gather_idx_expertOffset + num_valid_rows);
    executor[gather_axis_index]->set_gather_idx(kai_gather_idx);
    executor[gather_axis_index]->execute(memArgs[gather_axis_index]);
This allocates and copies kai_gather_idx on every execute() call for every expert. To reduce overhead, consider reusing a per-expert buffer (e.g., store a vector in the GatherMatmul node, clear() + reserve() + assign()), and then pass/move it into the executor (especially if set_gather_idx is changed to accept an rvalue-reference).
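A rough sketch of what that reuse could look like in the quoted snippet, assuming a scratch vector m_kai_gather_idx is added as a member of the GatherMatmul node (names other than those already in the snippet are assumptions):

```cpp
// Sketch: reuse one scratch vector across execute() calls instead of constructing
// a fresh kai_gather_idx for every expert. m_kai_gather_idx is an assumed member.
auto first = gather_idx_map.begin() + gather_axis_index * M;
m_kai_gather_idx.assign(first, first + num_valid_rows);            // reuses existing capacity
executor[gather_axis_index]->set_gather_idx(m_kai_gather_idx);     // const-ref setter from the change above
executor[gather_axis_index]->execute(memArgs[gather_axis_index]);
```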

    const auto srcPrc = dstDesc->getPrecision();
    m_tmpInputDesc = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(srcPrc, Shape({M, K}));
    m_tmpOutputDesc = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(srcPrc, Shape({M, N}));
For GatherMatmul, the temporary input buffer represents gathered activations, so its precision should be derived from the source tensor (ARG_SRC), not from the destination descriptor. Using dstDesc->getPrecision() could silently create mismatched temp buffers if output precision diverges from input; prefer memory.at(ARG_SRC)->getDescPtr()->getPrecision() here.
Current:

    const auto srcPrc = dstDesc->getPrecision();
    m_tmpInputDesc = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(srcPrc, Shape({M, K}));
    m_tmpOutputDesc = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(srcPrc, Shape({M, N}));

Suggested:

    const auto srcPrc = memory.at(ARG_SRC)->getDescPtr()->getPrecision();
    const auto dstPrc = dstDesc->getPrecision();
    m_tmpInputDesc = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(srcPrc, Shape({M, K}));
    m_tmpOutputDesc = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(dstPrc, Shape({M, N}));
Incorporated these changes in the latest commit.

    namespace {
    // TODO: OffsetHelper is common util function. Move it to some common location
    class OffsetHelper {
OffsetHelper logic is duplicated (there is a very similar helper in GatherMatmul). Since correct offset computation is easy to get subtly wrong over time (broadcasting, bitwidth, stride semantics), it would be safer to centralize this helper in a shared utility and reuse it in both places.
Currently kept as is. Please suggest a common path if it needs to be moved.
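For illustration only, one possible shape such a shared helper could take; the location and the interface below are assumptions and may not match the OffsetHelper actually used in this PR:

```cpp
// Hypothetical shared header, e.g. somewhere under src/plugins/intel_cpu/src/nodes/common/.
#pragma once

#include <cstddef>
#include <utility>
#include <vector>

namespace ov::intel_cpu {

// Illustrative stride-based offset helper; the real OffsetHelper may compute offsets differently.
class OffsetHelper {
public:
    explicit OffsetHelper(std::vector<size_t> strides) : m_strides(std::move(strides)) {}

    // Flat element offset for the given multi-dimensional indices.
    size_t offset(const std::vector<size_t>& indices) const {
        size_t off = 0;
        for (size_t i = 0; i < indices.size() && i < m_strides.size(); ++i) {
            off += indices[i] * m_strides[i];
        }
        return off;
    }

private:
    std::vector<size_t> m_strides;
};

}  // namespace ov::intel_cpu
```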
Hi @maxnick, I have fixed the clang formatting check, which was failing earlier. Thanks.
Force-pushed: ca6191b → 80491e5
Force-pushed: 80491e5 → d16cc01
[About]
[Design]
[Benchmark Results]

**Results are measured on a single-socket Graviton4 machine (96 cores).**
KleidiAI support is enabled and tested for F32, INT8, and INT4 precisions. For F32, oneDNN is kept as the default.
This work is contributed by @ashwins990 and @abhijain1204fujitsu.