The supported models detection improvement (#1235)

* The supported models detection improvement

* Unit test for supported models detection
This commit is contained in:
Vladislav Volkov
2020-07-08 12:57:21 +03:00
committed by GitHub
parent c0b28daf9c
commit 499c976194
2 changed files with 80 additions and 13 deletions

View File

@@ -11,6 +11,8 @@
#include <string>
#include <vector>
#include <sstream>
#include <algorithm>
#include <cctype>
#include "description_buffer.hpp"
#include "ie_ir_parser.hpp"
@@ -27,20 +29,21 @@ bool IRReader::supportModel(std::istream& model) const {
const int header_size = 128;
std::string header(header_size, ' ');
model.read(&header[0], header_size);
model.seekg(0, model.beg);
// find '<net ' substring in the .xml file
bool supports = (header.find("<net ") != std::string::npos) ||
(header.find("<Net ") != std::string::npos);
pugi::xml_document doc;
auto res = doc.load_string(header.c_str(), pugi::parse_default | pugi::parse_fragment);
if (supports) {
pugi::xml_document xmlDoc;
model.seekg(0, model.beg);
pugi::xml_parse_result res = xmlDoc.load(model);
if (res.status != pugi::status_ok) {
supports = false;
} else {
pugi::xml_node root = xmlDoc.document_element();
auto version = GetIRVersion(root);
bool supports = false;
if (res == pugi::status_ok) {
pugi::xml_node root = doc.document_element();
std::string node_name = root.name();
std::transform(node_name.begin(), node_name.end(), node_name.begin(), ::tolower);
if (node_name == "net") {
size_t const version = GetIRVersion(root);
#ifdef IR_READER_V10
supports = version == 10;
#else
@@ -49,7 +52,6 @@ bool IRReader::supportModel(std::istream& model) const {
}
}
model.seekg(0, model.beg);
return supports;
}

View File

@@ -146,6 +146,71 @@ TEST_P(NetReaderTest, ReadCorrectModelWithWeightsUnicodePath) {
#endif
TEST(NetReaderTest, IRSupportModelDetection) {
InferenceEngine::Core ie;
static char const *model = R"V0G0N(<net name="Network" version="10" some_attribute="Test Attribute">
<layers>
<layer name="in1" type="Parameter" id="0" version="opset1">
<data element_type="f32" shape="1,3,22,22"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
</output>
</layer>
<layer name="Abs" id="1" type="Abs" version="experimental">
<input>
<port id="1" precision="FP32">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
</output>
</layer>
<layer name="output" type="Result" id="2" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
<edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
</edges>
</net>
)V0G0N";
std::string headers[] = {
R"()",
R"(<!-- <net name="Network" version="100500"> -->)",
R"(<!-- <net name="Network" version="10" some_attribute="Test Attribute"> -->)"
};
InferenceEngine::Blob::CPtr weights;
for (auto header : headers) {
ASSERT_NO_THROW(ie.ReadNetwork(header + model, weights));
}
}
std::string getTestCaseName(testing::TestParamInfo<NetReaderTestParams> testParams) {
InferenceEngine::SizeVector dims;
InferenceEngine::Precision prc;