* fix bug in Serialize (#74447) add simple serialization test to check pads changes clang fix add check and change pads in conv refactor ov::clone_model fix check in test * fix FrameworkNode and add test * fix assert in identiry.cpp * fix clone_nodes * remove for node and constructor for node_input.cpp add spaces add space
This commit is contained in:
parent
a18069926e
commit
d951433b12
@ -21,7 +21,7 @@ void Identity::validate_and_infer_types() {
|
||||
|
||||
//! [op:copy]
|
||||
std::shared_ptr<ov::Node> Identity::clone_with_new_inputs(const ov::OutputVector& new_args) const {
|
||||
OPENVINO_ASSERT(new_args.size() != 1, "Incorrect number of new arguments");
|
||||
OPENVINO_ASSERT(new_args.size() == 1, "Incorrect number of new arguments");
|
||||
|
||||
return std::make_shared<Identity>(new_args.at(0));
|
||||
}
|
||||
|
@ -245,6 +245,12 @@ std::vector<std::shared_ptr<ngraph::Node>> ngraph::clone_nodes(const std::vector
|
||||
new_output.get_rt_info() = output_rt_info;
|
||||
}
|
||||
|
||||
for (auto input : node->inputs()) {
|
||||
const auto& output_rt_info = input.get_rt_info();
|
||||
auto new_input = cloned_node->input(input.get_index());
|
||||
new_input.get_rt_info() = output_rt_info;
|
||||
}
|
||||
|
||||
cloned_node->set_op_annotations(node->get_op_annotations());
|
||||
|
||||
node_map[node.get()] = cloned_node;
|
||||
|
@ -20,6 +20,9 @@ std::shared_ptr<ov::Node> ov::op::util::FrameworkNode::clone_with_new_inputs(con
|
||||
for (size_t i = 0; i < get_output_size(); ++i) {
|
||||
node->set_output_type(i, get_output_element_type(i), get_output_partial_shape(i));
|
||||
}
|
||||
node->m_inputs_desc = m_inputs_desc;
|
||||
node->m_output_desc = m_output_desc;
|
||||
node->m_attrs = m_attrs;
|
||||
return node;
|
||||
}
|
||||
|
||||
|
@ -1012,7 +1012,8 @@ void serializeFunc(std::ostream& xml_file,
|
||||
} // namespace
|
||||
|
||||
namespace ov {
|
||||
bool pass::Serialize::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
|
||||
bool pass::Serialize::run_on_model(const std::shared_ptr<ngraph::Function>& f_orig) {
|
||||
auto f = ov::clone_model(*f_orig);
|
||||
if (m_xmlFile && m_binFile) {
|
||||
serializeFunc(*m_xmlFile, *m_binFile, f, m_version, m_custom_opsets);
|
||||
} else {
|
||||
|
@ -24,7 +24,7 @@
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="2" name="Convolution_72" type="Convolution" version="opset1">
|
||||
<data strides="1, 1" dilations="1, 1" pads_begin="0, 0" pads_end="0, 0" auto_pad="explicit" PrimitivesPriority="_IMPLS_"/>
|
||||
<data strides="1, 1" dilations="1, 1" pads_begin="1, 1" pads_end="1, 1" auto_pad="SAME_LOWER" PrimitivesPriority="_IMPLS_"/>
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
|
@ -8,6 +8,8 @@
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "ngraph/pass/serialize.hpp"
|
||||
#include "openvino/util/file_util.hpp"
|
||||
#include "read_ir.hpp"
|
||||
#include "util/graph_comparator.hpp"
|
||||
@ -41,14 +43,16 @@ public:
|
||||
|
||||
TEST_P(SerializationTest, CompareFunctions) {
|
||||
auto expected = ov::test::readModel(m_model_path, m_binary_path);
|
||||
auto orig = ov::clone_model(*expected);
|
||||
ov::pass::Serialize(m_out_xml_path, m_out_bin_path).run_on_model(expected);
|
||||
auto result = ov::test::readModel(m_out_xml_path, m_out_bin_path);
|
||||
|
||||
const auto fc = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::ATTRIBUTES)
|
||||
.enable(FunctionsComparator::CONST_VALUES);
|
||||
const auto res = fc.compare(result, expected);
|
||||
const auto res2 = fc.compare(expected, orig);
|
||||
EXPECT_TRUE(res.valid) << res.message;
|
||||
EXPECT_TRUE(res2.valid) << res2.message;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
|
@ -170,3 +170,19 @@ TEST_F(CustomOpsSerializationTest, CustomOpOVExtensions) {
|
||||
|
||||
ASSERT_TRUE(success) << message;
|
||||
}
|
||||
|
||||
TEST_F(CustomOpsSerializationTest, CloneFrameworkNode) {
|
||||
const std::string model = CommonTestUtils::getModelFromTestModelZoo(IR_SERIALIZATION_MODELS_PATH "custom_op.xml");
|
||||
InferenceEngine::Core ie;
|
||||
auto extension = std::make_shared<FrameworkNodeExtension>();
|
||||
ie.AddExtension(extension);
|
||||
auto expected = ie.ReadNetwork(model);
|
||||
auto clone = ov::clone_model(*expected.getFunction());
|
||||
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::ATTRIBUTES)
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::PRECISIONS);
|
||||
const FunctionsComparator::Result result = func_comparator.compare(clone, expected.getFunction());
|
||||
ASSERT_TRUE(result.valid) << result.message;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user