fix bug in Serialize (#74447) (#9840)

* fix bug in Serialize (#74447)

add simple serialization test to check pads changes

clang fix

add check and change pads in conv

refactor ov::clone_model

fix

check in test

* fix FrameworkNode and add test

* fix assert in identiry.cpp

* fix clone_nodes

* remove for node and constructor for node_input.cpp

add spaces

add space
This commit is contained in:
Smirnov Grigorii 2022-02-08 22:00:20 +03:00 committed by GitHub
parent a18069926e
commit d951433b12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 34 additions and 4 deletions

View File

@ -21,7 +21,7 @@ void Identity::validate_and_infer_types() {
//! [op:copy] //! [op:copy]
std::shared_ptr<ov::Node> Identity::clone_with_new_inputs(const ov::OutputVector& new_args) const { std::shared_ptr<ov::Node> Identity::clone_with_new_inputs(const ov::OutputVector& new_args) const {
OPENVINO_ASSERT(new_args.size() != 1, "Incorrect number of new arguments"); OPENVINO_ASSERT(new_args.size() == 1, "Incorrect number of new arguments");
return std::make_shared<Identity>(new_args.at(0)); return std::make_shared<Identity>(new_args.at(0));
} }

View File

@ -245,6 +245,12 @@ std::vector<std::shared_ptr<ngraph::Node>> ngraph::clone_nodes(const std::vector
new_output.get_rt_info() = output_rt_info; new_output.get_rt_info() = output_rt_info;
} }
for (auto input : node->inputs()) {
const auto& output_rt_info = input.get_rt_info();
auto new_input = cloned_node->input(input.get_index());
new_input.get_rt_info() = output_rt_info;
}
cloned_node->set_op_annotations(node->get_op_annotations()); cloned_node->set_op_annotations(node->get_op_annotations());
node_map[node.get()] = cloned_node; node_map[node.get()] = cloned_node;

View File

@ -20,6 +20,9 @@ std::shared_ptr<ov::Node> ov::op::util::FrameworkNode::clone_with_new_inputs(con
for (size_t i = 0; i < get_output_size(); ++i) { for (size_t i = 0; i < get_output_size(); ++i) {
node->set_output_type(i, get_output_element_type(i), get_output_partial_shape(i)); node->set_output_type(i, get_output_element_type(i), get_output_partial_shape(i));
} }
node->m_inputs_desc = m_inputs_desc;
node->m_output_desc = m_output_desc;
node->m_attrs = m_attrs;
return node; return node;
} }

View File

@ -1012,7 +1012,8 @@ void serializeFunc(std::ostream& xml_file,
} // namespace } // namespace
namespace ov { namespace ov {
bool pass::Serialize::run_on_model(const std::shared_ptr<ngraph::Function>& f) { bool pass::Serialize::run_on_model(const std::shared_ptr<ngraph::Function>& f_orig) {
auto f = ov::clone_model(*f_orig);
if (m_xmlFile && m_binFile) { if (m_xmlFile && m_binFile) {
serializeFunc(*m_xmlFile, *m_binFile, f, m_version, m_custom_opsets); serializeFunc(*m_xmlFile, *m_binFile, f, m_version, m_custom_opsets);
} else { } else {

View File

@ -24,7 +24,7 @@
</output> </output>
</layer> </layer>
<layer id="2" name="Convolution_72" type="Convolution" version="opset1"> <layer id="2" name="Convolution_72" type="Convolution" version="opset1">
<data strides="1, 1" dilations="1, 1" pads_begin="0, 0" pads_end="0, 0" auto_pad="explicit" PrimitivesPriority="_IMPLS_"/> <data strides="1, 1" dilations="1, 1" pads_begin="1, 1" pads_end="1, 1" auto_pad="SAME_LOWER" PrimitivesPriority="_IMPLS_"/>
<input> <input>
<port id="0"> <port id="0">
<dim>1</dim> <dim>1</dim>

View File

@ -8,6 +8,8 @@
#include <fstream> #include <fstream>
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/serialize.hpp"
#include "openvino/util/file_util.hpp" #include "openvino/util/file_util.hpp"
#include "read_ir.hpp" #include "read_ir.hpp"
#include "util/graph_comparator.hpp" #include "util/graph_comparator.hpp"
@ -41,14 +43,16 @@ public:
TEST_P(SerializationTest, CompareFunctions) { TEST_P(SerializationTest, CompareFunctions) {
auto expected = ov::test::readModel(m_model_path, m_binary_path); auto expected = ov::test::readModel(m_model_path, m_binary_path);
auto orig = ov::clone_model(*expected);
ov::pass::Serialize(m_out_xml_path, m_out_bin_path).run_on_model(expected); ov::pass::Serialize(m_out_xml_path, m_out_bin_path).run_on_model(expected);
auto result = ov::test::readModel(m_out_xml_path, m_out_bin_path); auto result = ov::test::readModel(m_out_xml_path, m_out_bin_path);
const auto fc = FunctionsComparator::with_default() const auto fc = FunctionsComparator::with_default()
.enable(FunctionsComparator::ATTRIBUTES) .enable(FunctionsComparator::ATTRIBUTES)
.enable(FunctionsComparator::CONST_VALUES); .enable(FunctionsComparator::CONST_VALUES);
const auto res = fc.compare(result, expected); const auto res = fc.compare(result, expected);
const auto res2 = fc.compare(expected, orig);
EXPECT_TRUE(res.valid) << res.message; EXPECT_TRUE(res.valid) << res.message;
EXPECT_TRUE(res2.valid) << res2.message;
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(

View File

@ -170,3 +170,19 @@ TEST_F(CustomOpsSerializationTest, CustomOpOVExtensions) {
ASSERT_TRUE(success) << message; ASSERT_TRUE(success) << message;
} }
TEST_F(CustomOpsSerializationTest, CloneFrameworkNode) {
const std::string model = CommonTestUtils::getModelFromTestModelZoo(IR_SERIALIZATION_MODELS_PATH "custom_op.xml");
InferenceEngine::Core ie;
auto extension = std::make_shared<FrameworkNodeExtension>();
ie.AddExtension(extension);
auto expected = ie.ReadNetwork(model);
auto clone = ov::clone_model(*expected.getFunction());
const FunctionsComparator func_comparator = FunctionsComparator::with_default()
.enable(FunctionsComparator::ATTRIBUTES)
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::PRECISIONS);
const FunctionsComparator::Result result = func_comparator.compare(clone, expected.getFunction());
ASSERT_TRUE(result.valid) << result.message;
}