Fix RTInfo Deserialization for multiple key occurances (#8455)

This commit is contained in:
Gleb Kazantaev 2021-11-09 21:09:48 +03:00 committed by GitHub
parent 34886b650d
commit eb2b149fca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 0 deletions

View File

@ -553,6 +553,75 @@ TEST_F(RTInfoDeserialization, NodeV11) {
}
}
TEST_F(RTInfoDeserialization, NodeV11MultipleRTKeys) {
std::string model = R"V0G0N(
<net name="Network" version="11">
<layers>
<layer name="in1" type="Parameter" id="0" version="opset8">
<data element_type="f32" shape="1,22,22,3"/>
<rt_info>
<attribute name="old_api_map" version="0" order="0,2,3,1" element_type="f16"/>
<attribute name="old_api_map" version="0" order="0,1,2,3" element_type="f32"/>
<attribute name="fused_names" version="0" value="in1"/>
</rt_info>
<output>
<port id="0" precision="FP32" names="input_tensor">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
</output>
</layer>
<layer name="Round" id="1" type="Round" version="opset8">
<data mode="half_to_even"/>
<rt_info>
<attribute name="fused_names" version="0" value="Round1,Round2"/>
</rt_info>
<input>
<rt_info>
<attribute name="fused_names" version="0" value="check"/>
<attribute name="fused_names" version="0" value="multiple_keys"/>
</rt_info>
<port id="1" precision="FP32">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
</input>
<output>
<port id="2" precision="FP32" names="output_tensor">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
</output>
</layer>
<layer name="output" type="Result" id="2" version="opset8">
<rt_info>
<attribute name="old_api_map" version="0" order="0,3,1,2" element_type="f16"/>
</rt_info>
<input>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
<edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
</edges>
</net>
)V0G0N";
ASSERT_ANY_THROW(getWithIRFrontend(model));
}
TEST_F(RTInfoDeserialization, InputAndOutputV11) {
std::string model = R"V0G0N(
<net name="Network" version="11">

View File

@ -727,6 +727,9 @@ std::shared_ptr<ngraph::Node> XmlDeserializer::createNode(const std::vector<ngra
IE_THROW() << "rt_info attribute: " << attribute_name << " has no \"version\" field";
}
const auto& type_info = ov::DiscreteTypeInfo(attribute_name.c_str(), 0, attribute_version.c_str());
if (rt_info.count(type_info)) {
IE_THROW() << "multiple rt_info attributes are detected: " << type_info;
}
if (auto attr = attrs_factory.create_by_type_info(type_info)) {
RTInfoDeserializer attribute_visitor(item);
if (attr->visit_attributes(attribute_visitor)) {