external_port_id is calculated based on number of op inputs. (#6037)
* External_port_id is calcultaed based on number of op inputs. * Add test for external_port_id serialization. * Restore data section appearance in xml file.
This commit is contained in:
parent
0981a15846
commit
613bb981ce
@ -235,7 +235,7 @@ class XmlSerializer : public ngraph::AttributeVisitor {
|
||||
|
||||
void output_descriptions_on_adapter(const std::vector<std::shared_ptr<
|
||||
ngraph::op::util::SubGraphOp::OutputDescription>>& output_descriptions,
|
||||
const std::vector<std::string>& parameter_mapping,
|
||||
const uint32_t& input_count,
|
||||
const std::vector<std::string>& result_mapping,
|
||||
pugi::xml_node& port_map) {
|
||||
NGRAPH_CHECK(!result_mapping.empty(), "No results found in body Function.");
|
||||
@ -246,7 +246,7 @@ class XmlSerializer : public ngraph::AttributeVisitor {
|
||||
|
||||
for (const auto& output_description : output_descriptions) {
|
||||
pugi::xml_node output = port_map.append_child("output");
|
||||
output.append_attribute("external_port_id").set_value(parameter_mapping.size() + output_description->m_output_index);
|
||||
output.append_attribute("external_port_id").set_value(input_count + output_description->m_output_index);
|
||||
output.append_attribute("internal_layer_id").set_value(result_mapping[output_description->m_body_value_index].c_str());
|
||||
|
||||
if (auto concat_output = as_type_ptr<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(output_description)) {
|
||||
@ -306,7 +306,11 @@ public:
|
||||
input_descriptions_on_adapter(a->get(), parameter_mapping, result_mapping, port_map);
|
||||
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<std::shared_ptr
|
||||
<ngraph::op::util::SubGraphOp::OutputDescription>>>>(&adapter)) {
|
||||
output_descriptions_on_adapter(a->get(), parameter_mapping, result_mapping, port_map);
|
||||
uint32_t op_input_count = 0;
|
||||
for (auto c = m_xml_node.parent().child("input").first_child(); !c.empty(); c = c.next_sibling()) {
|
||||
op_input_count++;
|
||||
}
|
||||
output_descriptions_on_adapter(a->get(), op_input_count, result_mapping, port_map);
|
||||
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::v5::Loop::SpecialBodyPorts>>(&adapter)) {
|
||||
special_body_ports_on_adapter(a->get(), parameter_mapping, result_mapping, port_map);
|
||||
}
|
||||
@ -700,19 +704,6 @@ void ngfunction_2_irv10(pugi::xml_node& netXml,
|
||||
|
||||
// <layers/data> general attributes
|
||||
pugi::xml_node data = layer.append_child("data");
|
||||
XmlSerializer visitor(data, node_type_name, custom_opsets, constant_node_write_handler);
|
||||
NGRAPH_CHECK(node->visit_attributes(visitor), "Visitor API is not supported in ", node);
|
||||
rt_info::XmlSerializer{data}.serialize(node->get_rt_info());
|
||||
|
||||
if (exec_graph) {
|
||||
visit_exec_graph_node(layer, node);
|
||||
}
|
||||
|
||||
const bool data_attr_size =
|
||||
data.attributes().begin() == data.attributes().end();
|
||||
if (data_attr_size) {
|
||||
layer.remove_child(data);
|
||||
}
|
||||
|
||||
int port_id = 0;
|
||||
// <layers/input>
|
||||
@ -780,6 +771,21 @@ void ngfunction_2_irv10(pugi::xml_node& netXml,
|
||||
layer.insert_move_after(output, layer.first_child());
|
||||
}
|
||||
}
|
||||
|
||||
// fill <data> general attributes
|
||||
XmlSerializer visitor(data, node_type_name, custom_opsets, constant_node_write_handler);
|
||||
NGRAPH_CHECK(node->visit_attributes(visitor), "Visitor API is not supported in ", node);
|
||||
rt_info::XmlSerializer{data}.serialize(node->get_rt_info());
|
||||
|
||||
if (exec_graph) {
|
||||
visit_exec_graph_node(layer, node);
|
||||
}
|
||||
|
||||
const bool data_attr_size =
|
||||
data.attributes().begin() == data.attributes().end();
|
||||
if (data_attr_size) {
|
||||
layer.remove_child(data);
|
||||
}
|
||||
}
|
||||
// <edges>
|
||||
const std::vector<Edge> edge_mapping = create_edge_mapping(layer_ids, f);
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "ie_core.hpp"
|
||||
#include "ie_blob.h"
|
||||
#include "common_test_utils/data_utils.hpp"
|
||||
#include "pugixml.hpp"
|
||||
|
||||
#ifndef IR_SERIALIZATION_MODELS_PATH // should be already defined by cmake
|
||||
#define IR_SERIALIZATION_MODELS_PATH ""
|
||||
@ -84,3 +85,40 @@ TEST_F(SerializationTensorIteratorTest, TiNegativeStride) {
|
||||
|
||||
serialize_and_compare(model_path, weights);
|
||||
}
|
||||
|
||||
TEST_F(SerializationTensorIteratorTest, SerializationExternalPortIdInXmlFile) {
|
||||
const std::string model_path = IR_SERIALIZATION_MODELS_PATH "loop_2d_add.xml";
|
||||
const std::string binary_path = IR_SERIALIZATION_MODELS_PATH "loop_2d_add.bin";
|
||||
|
||||
InferenceEngine::Core ie;
|
||||
InferenceEngine::CNNNetwork expected;
|
||||
pugi::xml_document loop_orig;
|
||||
pugi::xml_document loop_serialized;
|
||||
|
||||
expected = ie.ReadNetwork(model_path, binary_path);
|
||||
expected.serialize(m_out_xml_path, m_out_bin_path);
|
||||
|
||||
pugi::xml_parse_result result = loop_orig.load_file(model_path.c_str());
|
||||
ASSERT_FALSE(result.status) << result.description();
|
||||
result = loop_serialized.load_file(m_out_xml_path.c_str());
|
||||
ASSERT_FALSE(result.status) << result.description();
|
||||
|
||||
auto node1 = loop_orig.child("net").child("layers").find_child_by_attribute("type", "Loop");
|
||||
auto node2 = loop_serialized.child("net").child("layers").find_child_by_attribute("type", "Loop");
|
||||
auto node2_port_map = node2.child("port_map").first_child();
|
||||
|
||||
for (auto ch = node1.child("port_map").first_child(); !ch.empty(); ch = ch.next_sibling()) {
|
||||
auto node1_external_port_id = std::stoi(ch.attribute("external_port_id").value());
|
||||
auto node2_external_port_id = std::stoi(node2_port_map.attribute("external_port_id").value());
|
||||
|
||||
if (node1_external_port_id == -1) {
|
||||
continue;
|
||||
}
|
||||
if (node2_external_port_id == -1) {
|
||||
node2_external_port_id = std::stoi(node2_port_map.next_sibling().attribute("external_port_id").value());
|
||||
}
|
||||
node2_port_map = node2_port_map.next_sibling();
|
||||
|
||||
EXPECT_EQ(node1_external_port_id, node2_external_port_id);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user