diff --git a/docs/template_extension/new/identity.cpp b/docs/template_extension/new/identity.cpp index bf84e3a5cf9..442753c1f5c 100644 --- a/docs/template_extension/new/identity.cpp +++ b/docs/template_extension/new/identity.cpp @@ -21,7 +21,7 @@ void Identity::validate_and_infer_types() { //! [op:copy] std::shared_ptr Identity::clone_with_new_inputs(const ov::OutputVector& new_args) const { - OPENVINO_ASSERT(new_args.size() != 1, "Incorrect number of new arguments"); + OPENVINO_ASSERT(new_args.size() == 1, "Incorrect number of new arguments"); return std::make_shared(new_args.at(0)); } diff --git a/src/core/src/graph_util.cpp b/src/core/src/graph_util.cpp index 25f6ca04241..5f1101b076e 100644 --- a/src/core/src/graph_util.cpp +++ b/src/core/src/graph_util.cpp @@ -245,6 +245,12 @@ std::vector> ngraph::clone_nodes(const std::vector new_output.get_rt_info() = output_rt_info; } + for (auto input : node->inputs()) { + const auto& output_rt_info = input.get_rt_info(); + auto new_input = cloned_node->input(input.get_index()); + new_input.get_rt_info() = output_rt_info; + } + cloned_node->set_op_annotations(node->get_op_annotations()); node_map[node.get()] = cloned_node; diff --git a/src/core/src/op/util/framework_node.cpp b/src/core/src/op/util/framework_node.cpp index cf11bbb2ff2..acfb1ff5be1 100644 --- a/src/core/src/op/util/framework_node.cpp +++ b/src/core/src/op/util/framework_node.cpp @@ -20,6 +20,9 @@ std::shared_ptr ov::op::util::FrameworkNode::clone_with_new_inputs(con for (size_t i = 0; i < get_output_size(); ++i) { node->set_output_type(i, get_output_element_type(i), get_output_partial_shape(i)); } + node->m_inputs_desc = m_inputs_desc; + node->m_output_desc = m_output_desc; + node->m_attrs = m_attrs; return node; } diff --git a/src/core/src/pass/serialize.cpp b/src/core/src/pass/serialize.cpp index 57ccad45f46..64832a1411f 100644 --- a/src/core/src/pass/serialize.cpp +++ b/src/core/src/pass/serialize.cpp @@ -1012,7 +1012,8 @@ void serializeFunc(std::ostream& xml_file, } // namespace namespace ov { -bool pass::Serialize::run_on_model(const std::shared_ptr& f) { +bool pass::Serialize::run_on_model(const std::shared_ptr& f_orig) { + auto f = ov::clone_model(*f_orig); if (m_xmlFile && m_binFile) { serializeFunc(*m_xmlFile, *m_binFile, f, m_version, m_custom_opsets); } else { diff --git a/src/core/tests/models/ir/conv_with_rt_info.xml b/src/core/tests/models/ir/conv_with_rt_info.xml index bd969eead7b..eeda7abb2bf 100644 --- a/src/core/tests/models/ir/conv_with_rt_info.xml +++ b/src/core/tests/models/ir/conv_with_rt_info.xml @@ -24,7 +24,7 @@ - + 1 diff --git a/src/core/tests/pass/serialization/serialize.cpp b/src/core/tests/pass/serialization/serialize.cpp index d3d4329625d..dc632a00284 100644 --- a/src/core/tests/pass/serialization/serialize.cpp +++ b/src/core/tests/pass/serialization/serialize.cpp @@ -8,6 +8,8 @@ #include +#include "ngraph/pass/manager.hpp" +#include "ngraph/pass/serialize.hpp" #include "openvino/util/file_util.hpp" #include "read_ir.hpp" #include "util/graph_comparator.hpp" @@ -41,14 +43,16 @@ public: TEST_P(SerializationTest, CompareFunctions) { auto expected = ov::test::readModel(m_model_path, m_binary_path); + auto orig = ov::clone_model(*expected); ov::pass::Serialize(m_out_xml_path, m_out_bin_path).run_on_model(expected); auto result = ov::test::readModel(m_out_xml_path, m_out_bin_path); - const auto fc = FunctionsComparator::with_default() .enable(FunctionsComparator::ATTRIBUTES) .enable(FunctionsComparator::CONST_VALUES); const auto res = fc.compare(result, expected); + const auto res2 = fc.compare(expected, orig); EXPECT_TRUE(res.valid) << res.message; + EXPECT_TRUE(res2.valid) << res2.message; } INSTANTIATE_TEST_SUITE_P( diff --git a/src/tests/functional/inference_engine/ir_serialization/custom_ops.cpp b/src/tests/functional/inference_engine/ir_serialization/custom_ops.cpp index 4f5978e4022..0e707074ad8 100644 --- a/src/tests/functional/inference_engine/ir_serialization/custom_ops.cpp +++ b/src/tests/functional/inference_engine/ir_serialization/custom_ops.cpp @@ -170,3 +170,19 @@ TEST_F(CustomOpsSerializationTest, CustomOpOVExtensions) { ASSERT_TRUE(success) << message; } + +TEST_F(CustomOpsSerializationTest, CloneFrameworkNode) { + const std::string model = CommonTestUtils::getModelFromTestModelZoo(IR_SERIALIZATION_MODELS_PATH "custom_op.xml"); + InferenceEngine::Core ie; + auto extension = std::make_shared(); + ie.AddExtension(extension); + auto expected = ie.ReadNetwork(model); + auto clone = ov::clone_model(*expected.getFunction()); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default() + .enable(FunctionsComparator::ATTRIBUTES) + .enable(FunctionsComparator::CONST_VALUES) + .enable(FunctionsComparator::PRECISIONS); + const FunctionsComparator::Result result = func_comparator.compare(clone, expected.getFunction()); + ASSERT_TRUE(result.valid) << result.message; +}