-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Serialize any node as postponed constant #32490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |||||
|
|
||||||
| #include "openvino/xml_util/xml_serialize_util.hpp" | ||||||
|
|
||||||
| #include <functional> | ||||||
| #include <pugixml.hpp> | ||||||
|
|
||||||
| #include "openvino/core/descriptor_tensor.hpp" | ||||||
|
|
@@ -875,6 +876,60 @@ std::unique_ptr<XmlSerializer> XmlSerializer::make_visitor(pugi::xml_node& data, | |||||
| data_is_temporary); | ||||||
| } | ||||||
|
|
||||||
| namespace { | ||||||
| void find_postponed_constants_and_exclude_nodes(const std::vector<std::shared_ptr<ov::Node>>& sorted_ops, | ||||||
| std::unordered_set<ov::Node*>& postponed_constants, | ||||||
| std::unordered_set<ov::Node*>& nodes_to_exclude) { | ||||||
| // Collect all nodes with postponed_constant attribute (not for exclusion, but as starting points) | ||||||
| for (const auto& node : sorted_ops) { | ||||||
| if (node->get_rt_info().count("postponed_constant")) { | ||||||
| postponed_constants.insert(node.get()); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Perform reverse DFS to find nodes that only feed into postponed_constant nodes | ||||||
| std::function<void(ov::Node*)> reverse_dfs = [&](ov::Node* node) { | ||||||
| // Skip if it's a Parameter (model input) | ||||||
| if (ov::op::util::is_parameter(node)) { | ||||||
| return; | ||||||
| } | ||||||
|
|
||||||
| // Check if ALL outputs go to postponed_constant or already excluded nodes | ||||||
| bool all_outputs_excluded = true; | ||||||
| for (const auto& output : node->outputs()) { | ||||||
| for (const auto& target_input : output.get_target_inputs()) { | ||||||
| auto* target_node = target_input.get_node(); | ||||||
| if (!postponed_constants.count(target_node) && !nodes_to_exclude.count(target_node)) { | ||||||
| all_outputs_excluded = false; | ||||||
| break; | ||||||
| } | ||||||
| } | ||||||
| if (!all_outputs_excluded) { | ||||||
| break; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // If all outputs are excluded, mark this node and continue DFS | ||||||
| if (all_outputs_excluded && node->get_output_size() > 0) { | ||||||
|
||||||
| if (all_outputs_excluded && node->get_output_size() > 0) { | |
| if (all_outputs_excluded) { |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,7 +8,10 @@ | |
| #include "common_test_utils/common_utils.hpp" | ||
| #include "common_test_utils/file_utils.hpp" | ||
| #include "common_test_utils/graph_comparator.hpp" | ||
| #include "openvino/op/add.hpp" | ||
| #include "openvino/op/concat.hpp" | ||
| #include "openvino/op/constant.hpp" | ||
| #include "openvino/op/multiply.hpp" | ||
| #include "openvino/pass/manager.hpp" | ||
| #include "openvino/pass/serialize.hpp" | ||
| #include "openvino/runtime/core.hpp" | ||
|
|
@@ -166,3 +169,146 @@ TEST(PostponedOpSerializationTest, IncorrectRtInfo) { | |
| std::stringstream serialized_model, serialized_weigths; | ||
| ov::pass::Serialize(serialized_model, serialized_weigths).run_on_model(model); | ||
| } | ||
| // Round-trips a model whose Concat of two Constants is tagged "postponed_constant" in rt_info and
| // expects the deserialized model to contain a single {4, 2} Constant in place of that Concat.
||
| TEST(PostponedConstantTest, ConcatWithPostponedConstant) { | ||
| std::stringstream serialized_xml, serialized_bin; | ||
| { | ||
| auto const1 = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4}); | ||
| auto const2 = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8}); | ||
| auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{const1, const2}, 0); | ||
| concat->get_rt_info()["postponed_constant"] = true; | ||
| // The flagged Concat feeds an Add together with a Parameter, the model's only input.
|
||
| auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2}); | ||
| auto add = std::make_shared<ov::op::v1::Add>(concat, param); | ||
|
|
||
| auto model = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel"); | ||
| // Serialize into the two in-memory string streams declared at the top of the test.
|
||
| ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model); | ||
| } | ||
| ov::Core core; | ||
| // Copy the serialized weights out of the stream and pass them to read_model as a 1-D u8 tensor
| // built over the string's buffer (presumably not copied, so `weights` must stay alive -- TODO confirm).
||
| auto weights = serialized_bin.str(); | ||
| ov::Tensor weights_tensor(ov::element::u8, ov::Shape{weights.size()}, weights.data()); | ||
|
|
||
| auto deserialized_model = core.read_model(serialized_xml.str(), weights_tensor); | ||
| // Reference model: the same Add, fed by one literal Constant holding const1's values followed by const2's.
|
||
| { | ||
| auto constant = std::make_shared<ov::op::v0::Constant>(ov::element::f32, | ||
| ov::Shape{4, 2}, | ||
| std::vector<float>{1, 2, 3, 4, 5, 6, 7, 8}); | ||
| auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2}); | ||
| auto add = std::make_shared<ov::op::v1::Add>(constant, param); | ||
|
|
||
| auto expected = std::make_shared<ov::Model>(add->outputs(), ov::ParameterVector{param}, "ConcatAddModel"); | ||
| // NOTE(review): the meaning of the five boolean flags is defined by compare_functions and is not
| // visible here -- confirm which comparisons they enable before changing them.
||
| const auto& [success, message] = | ||
| compare_functions(deserialized_model, expected, true, false, false, true, true); | ||
| ASSERT_TRUE(success) << message; | ||
| } | ||
| } | ||
| // Intended check: a Concat tagged "postponed_constant" and fed by an Add -> Multiply chain of
| // Constants round-trips as one precomputed Constant. Skipped via GTEST_SKIP() as the first statement.
||
| TEST(PostponedConstantTest, SubgraphExclusion) { | ||
| GTEST_SKIP() << "Subgraph exclusion is not supported in the current implementation"; | ||
| std::stringstream serialized_xml, serialized_bin; | ||
| { | ||
| auto const1 = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4}); | ||
| auto const2 = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8}); | ||
| // Constant-only subgraph feeding the flagged Concat: (const1 + const2) * const2, then Concat with const2.
|
||
| auto add1 = std::make_shared<ov::op::v1::Add>(const1, const2); | ||
| auto multiply = std::make_shared<ov::op::v1::Multiply>(add1, const2); | ||
| auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{multiply, const2}, 0); | ||
| concat->get_rt_info()["postponed_constant"] = true; | ||
| // The flagged Concat feeds the final Add together with a Parameter, the model's only input.
|
||
| auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2}); | ||
| auto final_add = std::make_shared<ov::op::v1::Add>(concat, param); | ||
|
|
||
| auto model = | ||
| std::make_shared<ov::Model>(final_add->outputs(), ov::ParameterVector{param}, "SubgraphExclusionModel"); | ||
|
|
||
| ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model); | ||
| } | ||
| ov::Core core; | ||
| // Copy the serialized weights out of the stream and pass them to read_model as a 1-D u8 tensor
| // built over the string's buffer (presumably not copied, so `weights` must stay alive -- TODO confirm).
||
| auto weights = serialized_bin.str(); | ||
| ov::Tensor weights_tensor(ov::element::u8, ov::Shape{weights.size()}, weights.data()); | ||
|
|
||
| auto deserialized_model = core.read_model(serialized_xml.str(), weights_tensor); | ||
|
|
||
| { | ||
| // Expected folded value: add1 = const1 + const2 = [6,8,10,12] | ||
| // multiply = add1 * const2 ([5,6,7,8]) = [30,48,70,96] | ||
| // concat(multiply, const2) along axis 0 -> shape {4,2} with data [30,48,70,96,5,6,7,8] | ||
|
Comment on lines
+242
to
+244
|
||
| auto constant = std::make_shared<ov::op::v0::Constant>(ov::element::f32, | ||
| ov::Shape{4, 2}, | ||
| std::vector<float>{30, 48, 70, 96, 5, 6, 7, 8}); | ||
| auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 2}); | ||
| auto final_add = std::make_shared<ov::op::v1::Add>(constant, param); | ||
|
|
||
| auto expected = | ||
| std::make_shared<ov::Model>(final_add->outputs(), ov::ParameterVector{param}, "SubgraphExclusionModel"); | ||
| // NOTE(review): the meaning of the five boolean flags is defined by compare_functions and is not
| // visible here -- confirm which comparisons they enable before changing them.
||
| const auto& [success, message] = | ||
| compare_functions(deserialized_model, expected, true, false, false, true, true); | ||
| ASSERT_TRUE(success) << message; | ||
| } | ||
| } | ||
| // const1 and const2 feed both a plain Add and a Concat tagged "postponed_constant"; both ops are model
| // outputs. Expects the round-tripped model to keep the Add on two Constants and hold one Constant for the Concat.
||
| TEST(PostponedConstantTest, NodeWithMultipleConsumers) { | ||
| std::stringstream serialized_xml, serialized_bin; | ||
| { | ||
| auto const1 = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4}); | ||
| auto const2 = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8}); | ||
| // Both consumers below read the same two Constants.
|
||
| auto add = std::make_shared<ov::op::v1::Add>(const1, const2); | ||
| auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{const1, const2}, 0); | ||
| // The model has no Parameters; its outputs are the Concat and the Add, in that order.
|
||
| auto model = | ||
| std::make_shared<ov::Model>(ov::OutputVector{concat, add}, ov::ParameterVector{}, "MultipleConsumersModel"); | ||
| // Flag the Concat (done here after the Model has been constructed).
|
||
| concat->get_rt_info()["postponed_constant"] = true; | ||
|
|
||
| ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model); | ||
| } | ||
| ov::Core core; | ||
| // Copy the serialized weights out of the stream and pass them to read_model as a 1-D u8 tensor
| // built over the string's buffer (presumably not copied, so `weights` must stay alive -- TODO confirm).
||
| auto weights = serialized_bin.str(); | ||
| ov::Tensor weights_tensor(ov::element::u8, ov::Shape{weights.size()}, weights.data()); | ||
|
|
||
| auto deserialized_model = core.read_model(serialized_xml.str(), weights_tensor); | ||
| // Reference model: the Add still consumes two separate Constants, while the Concat output is
| // replaced by one literal {4, 2} Constant (named `concat` below to mirror the original graph).
||
| { | ||
| auto const1 = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{1, 2, 3, 4}); | ||
| auto const2 = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{2, 2}, std::vector<float>{5, 6, 7, 8}); | ||
|
|
||
| auto add = std::make_shared<ov::op::v1::Add>(const1, const2); | ||
| auto concat = std::make_shared<ov::op::v0::Constant>(ov::element::f32, | ||
| ov::Shape{4, 2}, | ||
| std::vector<float>{1, 2, 3, 4, 5, 6, 7, 8}); | ||
|
|
||
| auto expected = | ||
| std::make_shared<ov::Model>(ov::OutputVector{concat, add}, ov::ParameterVector{}, "MultipleConsumersModel"); | ||
| // NOTE(review): the meaning of the five boolean flags is defined by compare_functions and is not
| // visible here -- confirm which comparisons they enable before changing them.
||
| const auto& [success, message] = | ||
| compare_functions(deserialized_model, expected, true, false, false, true, true); | ||
| ASSERT_TRUE(success) << message; | ||
| } | ||
| } | ||
| // A Concat whose only input is a Parameter is tagged "postponed_constant";
| // serializing such a model is expected to throw ov::Exception.
||
| TEST(PostponedConstantTest, ParameterNotExcluded) { | ||
| std::stringstream serialized_xml, serialized_bin; | ||
| auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 2}); | ||
| auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{param}, 0); | ||
| auto model = std::make_shared<ov::Model>(concat->outputs(), ov::ParameterVector{param}, "ParameterModel"); | ||
| // Flag the Concat (done here after the Model has been constructed).
|
||
| concat->get_rt_info()["postponed_constant"] = true; | ||
| // The test passes only if run_on_model throws ov::Exception; nothing is read back afterwards.
|
||
| EXPECT_THROW(ov::pass::Serialize(serialized_xml, serialized_bin).run_on_model(model), ov::Exception); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The check for Parameter nodes should throw an exception when a Parameter is encountered in a postponed constant's dependency chain. Test
`ParameterNotExcluded` expects an exception (line 313), but this implementation silently returns, preventing the exception from being raised.