* fix bug in Serialize (#74447) add simple serialization test to check pads changes clang fix add check and change pads in conv refactor ov::clone_model fix check in test * fix FrameworkNode and add test * fix assert in identiry.cpp * fix clone_nodes * remove for node and constructor for node_input.cpp add spaces add space
This commit is contained in:
parent
a18069926e
commit
d951433b12
@ -21,7 +21,7 @@ void Identity::validate_and_infer_types() {
|
|||||||
|
|
||||||
//! [op:copy]
|
//! [op:copy]
|
||||||
std::shared_ptr<ov::Node> Identity::clone_with_new_inputs(const ov::OutputVector& new_args) const {
|
std::shared_ptr<ov::Node> Identity::clone_with_new_inputs(const ov::OutputVector& new_args) const {
|
||||||
OPENVINO_ASSERT(new_args.size() != 1, "Incorrect number of new arguments");
|
OPENVINO_ASSERT(new_args.size() == 1, "Incorrect number of new arguments");
|
||||||
|
|
||||||
return std::make_shared<Identity>(new_args.at(0));
|
return std::make_shared<Identity>(new_args.at(0));
|
||||||
}
|
}
|
||||||
|
@ -245,6 +245,12 @@ std::vector<std::shared_ptr<ngraph::Node>> ngraph::clone_nodes(const std::vector
|
|||||||
new_output.get_rt_info() = output_rt_info;
|
new_output.get_rt_info() = output_rt_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto input : node->inputs()) {
|
||||||
|
const auto& output_rt_info = input.get_rt_info();
|
||||||
|
auto new_input = cloned_node->input(input.get_index());
|
||||||
|
new_input.get_rt_info() = output_rt_info;
|
||||||
|
}
|
||||||
|
|
||||||
cloned_node->set_op_annotations(node->get_op_annotations());
|
cloned_node->set_op_annotations(node->get_op_annotations());
|
||||||
|
|
||||||
node_map[node.get()] = cloned_node;
|
node_map[node.get()] = cloned_node;
|
||||||
|
@ -20,6 +20,9 @@ std::shared_ptr<ov::Node> ov::op::util::FrameworkNode::clone_with_new_inputs(con
|
|||||||
for (size_t i = 0; i < get_output_size(); ++i) {
|
for (size_t i = 0; i < get_output_size(); ++i) {
|
||||||
node->set_output_type(i, get_output_element_type(i), get_output_partial_shape(i));
|
node->set_output_type(i, get_output_element_type(i), get_output_partial_shape(i));
|
||||||
}
|
}
|
||||||
|
node->m_inputs_desc = m_inputs_desc;
|
||||||
|
node->m_output_desc = m_output_desc;
|
||||||
|
node->m_attrs = m_attrs;
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1012,7 +1012,8 @@ void serializeFunc(std::ostream& xml_file,
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
bool pass::Serialize::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
|
bool pass::Serialize::run_on_model(const std::shared_ptr<ngraph::Function>& f_orig) {
|
||||||
|
auto f = ov::clone_model(*f_orig);
|
||||||
if (m_xmlFile && m_binFile) {
|
if (m_xmlFile && m_binFile) {
|
||||||
serializeFunc(*m_xmlFile, *m_binFile, f, m_version, m_custom_opsets);
|
serializeFunc(*m_xmlFile, *m_binFile, f, m_version, m_custom_opsets);
|
||||||
} else {
|
} else {
|
||||||
|
@ -24,7 +24,7 @@
|
|||||||
</output>
|
</output>
|
||||||
</layer>
|
</layer>
|
||||||
<layer id="2" name="Convolution_72" type="Convolution" version="opset1">
|
<layer id="2" name="Convolution_72" type="Convolution" version="opset1">
|
||||||
<data strides="1, 1" dilations="1, 1" pads_begin="0, 0" pads_end="0, 0" auto_pad="explicit" PrimitivesPriority="_IMPLS_"/>
|
<data strides="1, 1" dilations="1, 1" pads_begin="1, 1" pads_end="1, 1" auto_pad="SAME_LOWER" PrimitivesPriority="_IMPLS_"/>
|
||||||
<input>
|
<input>
|
||||||
<port id="0">
|
<port id="0">
|
||||||
<dim>1</dim>
|
<dim>1</dim>
|
||||||
|
@ -8,6 +8,8 @@
|
|||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "ngraph/pass/manager.hpp"
|
||||||
|
#include "ngraph/pass/serialize.hpp"
|
||||||
#include "openvino/util/file_util.hpp"
|
#include "openvino/util/file_util.hpp"
|
||||||
#include "read_ir.hpp"
|
#include "read_ir.hpp"
|
||||||
#include "util/graph_comparator.hpp"
|
#include "util/graph_comparator.hpp"
|
||||||
@ -41,14 +43,16 @@ public:
|
|||||||
|
|
||||||
TEST_P(SerializationTest, CompareFunctions) {
|
TEST_P(SerializationTest, CompareFunctions) {
|
||||||
auto expected = ov::test::readModel(m_model_path, m_binary_path);
|
auto expected = ov::test::readModel(m_model_path, m_binary_path);
|
||||||
|
auto orig = ov::clone_model(*expected);
|
||||||
ov::pass::Serialize(m_out_xml_path, m_out_bin_path).run_on_model(expected);
|
ov::pass::Serialize(m_out_xml_path, m_out_bin_path).run_on_model(expected);
|
||||||
auto result = ov::test::readModel(m_out_xml_path, m_out_bin_path);
|
auto result = ov::test::readModel(m_out_xml_path, m_out_bin_path);
|
||||||
|
|
||||||
const auto fc = FunctionsComparator::with_default()
|
const auto fc = FunctionsComparator::with_default()
|
||||||
.enable(FunctionsComparator::ATTRIBUTES)
|
.enable(FunctionsComparator::ATTRIBUTES)
|
||||||
.enable(FunctionsComparator::CONST_VALUES);
|
.enable(FunctionsComparator::CONST_VALUES);
|
||||||
const auto res = fc.compare(result, expected);
|
const auto res = fc.compare(result, expected);
|
||||||
|
const auto res2 = fc.compare(expected, orig);
|
||||||
EXPECT_TRUE(res.valid) << res.message;
|
EXPECT_TRUE(res.valid) << res.message;
|
||||||
|
EXPECT_TRUE(res2.valid) << res2.message;
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
@ -170,3 +170,19 @@ TEST_F(CustomOpsSerializationTest, CustomOpOVExtensions) {
|
|||||||
|
|
||||||
ASSERT_TRUE(success) << message;
|
ASSERT_TRUE(success) << message;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CustomOpsSerializationTest, CloneFrameworkNode) {
|
||||||
|
const std::string model = CommonTestUtils::getModelFromTestModelZoo(IR_SERIALIZATION_MODELS_PATH "custom_op.xml");
|
||||||
|
InferenceEngine::Core ie;
|
||||||
|
auto extension = std::make_shared<FrameworkNodeExtension>();
|
||||||
|
ie.AddExtension(extension);
|
||||||
|
auto expected = ie.ReadNetwork(model);
|
||||||
|
auto clone = ov::clone_model(*expected.getFunction());
|
||||||
|
|
||||||
|
const FunctionsComparator func_comparator = FunctionsComparator::with_default()
|
||||||
|
.enable(FunctionsComparator::ATTRIBUTES)
|
||||||
|
.enable(FunctionsComparator::CONST_VALUES)
|
||||||
|
.enable(FunctionsComparator::PRECISIONS);
|
||||||
|
const FunctionsComparator::Result result = func_comparator.compare(clone, expected.getFunction());
|
||||||
|
ASSERT_TRUE(result.valid) << result.message;
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user