The supported models detection improvement (#1235)
* The supported models detection improvement * Unit test for supported models detection
This commit is contained in:
@@ -11,6 +11,8 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
|
||||
#include "description_buffer.hpp"
|
||||
#include "ie_ir_parser.hpp"
|
||||
@@ -27,20 +29,21 @@ bool IRReader::supportModel(std::istream& model) const {
|
||||
const int header_size = 128;
|
||||
std::string header(header_size, ' ');
|
||||
model.read(&header[0], header_size);
|
||||
model.seekg(0, model.beg);
|
||||
|
||||
// find '<net ' substring in the .xml file
|
||||
bool supports = (header.find("<net ") != std::string::npos) ||
|
||||
(header.find("<Net ") != std::string::npos);
|
||||
pugi::xml_document doc;
|
||||
auto res = doc.load_string(header.c_str(), pugi::parse_default | pugi::parse_fragment);
|
||||
|
||||
if (supports) {
|
||||
pugi::xml_document xmlDoc;
|
||||
model.seekg(0, model.beg);
|
||||
pugi::xml_parse_result res = xmlDoc.load(model);
|
||||
if (res.status != pugi::status_ok) {
|
||||
supports = false;
|
||||
} else {
|
||||
pugi::xml_node root = xmlDoc.document_element();
|
||||
auto version = GetIRVersion(root);
|
||||
bool supports = false;
|
||||
|
||||
if (res == pugi::status_ok) {
|
||||
pugi::xml_node root = doc.document_element();
|
||||
|
||||
std::string node_name = root.name();
|
||||
std::transform(node_name.begin(), node_name.end(), node_name.begin(), ::tolower);
|
||||
|
||||
if (node_name == "net") {
|
||||
size_t const version = GetIRVersion(root);
|
||||
#ifdef IR_READER_V10
|
||||
supports = version == 10;
|
||||
#else
|
||||
@@ -49,7 +52,6 @@ bool IRReader::supportModel(std::istream& model) const {
|
||||
}
|
||||
}
|
||||
|
||||
model.seekg(0, model.beg);
|
||||
return supports;
|
||||
}
|
||||
|
||||
|
||||
@@ -146,6 +146,71 @@ TEST_P(NetReaderTest, ReadCorrectModelWithWeightsUnicodePath) {
|
||||
|
||||
#endif
|
||||
|
||||
TEST(NetReaderTest, IRSupportModelDetection) {
|
||||
InferenceEngine::Core ie;
|
||||
|
||||
static char const *model = R"V0G0N(<net name="Network" version="10" some_attribute="Test Attribute">
|
||||
<layers>
|
||||
<layer name="in1" type="Parameter" id="0" version="opset1">
|
||||
<data element_type="f32" shape="1,3,22,22"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>3</dim>
|
||||
<dim>22</dim>
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer name="Abs" id="1" type="Abs" version="experimental">
|
||||
<input>
|
||||
<port id="1" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>3</dim>
|
||||
<dim>22</dim>
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>3</dim>
|
||||
<dim>22</dim>
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer name="output" type="Result" id="2" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>3</dim>
|
||||
<dim>22</dim>
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
</input>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
|
||||
<edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
|
||||
std::string headers[] = {
|
||||
R"()",
|
||||
R"(<!-- <net name="Network" version="100500"> -->)",
|
||||
R"(<!-- <net name="Network" version="10" some_attribute="Test Attribute"> -->)"
|
||||
};
|
||||
|
||||
InferenceEngine::Blob::CPtr weights;
|
||||
|
||||
for (auto header : headers) {
|
||||
ASSERT_NO_THROW(ie.ReadNetwork(header + model, weights));
|
||||
}
|
||||
}
|
||||
|
||||
std::string getTestCaseName(testing::TestParamInfo<NetReaderTestParams> testParams) {
|
||||
InferenceEngine::SizeVector dims;
|
||||
InferenceEngine::Precision prc;
|
||||
|
||||
Reference in New Issue
Block a user