Split IR readers (#1167)

* Split IR readers

* Fixed tests

* CMAKE: Removed add_clang_format_target usage from readers
Ilya Lavrenov 2020-07-02 13:31:44 +03:00 committed by GitHub
parent 0e904405f7
commit ef6280ab99
39 changed files with 174 additions and 345 deletions

View File

@@ -88,6 +88,9 @@ function(ie_add_plugin)
if(TARGET inference_engine_ir_reader)
add_dependencies(${IE_PLUGIN_NAME} inference_engine_ir_reader)
endif()
if(TARGET inference_engine_ir_reader_v7)
add_dependencies(${IE_PLUGIN_NAME} inference_engine_ir_reader_v7)
endif()
if(TARGET inference_engine_onnx_reader)
add_dependencies(${IE_PLUGIN_NAME} inference_engine_onnx_reader)
endif()

View File

@@ -34,26 +34,6 @@ namespace InferenceEngine {
namespace {
std::once_flag flag;
InferenceEngine::details::SharedObjectLoader::Ptr cnnReaderLoader;
InferenceEngine::details::SharedObjectLoader::Ptr createCnnReaderLoader() {
std::call_once(flag, [&] () {
FileUtils::FilePath libraryName = FileUtils::toFilePath(std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
FileUtils::FilePath irReadersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
if (!FileUtils::fileExist(irReadersLibraryPath)) {
THROW_IE_EXCEPTION << "Please, make sure that Inference Engine IR readers library "
<< FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
<< getIELibraryPath();
}
cnnReaderLoader = std::shared_ptr<InferenceEngine::details::SharedObjectLoader>(
new InferenceEngine::details::SharedObjectLoader(irReadersLibraryPath.c_str()));
});
return cnnReaderLoader;
}
IInferencePluginAPI* getInferencePluginAPIInterface(IInferencePlugin* iplugin) {
return dynamic_cast<IInferencePluginAPI*>(iplugin);
}
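
The createCnnReaderLoader removed above lazily loaded the IR reader library on first use, guarded by std::call_once; after this commit the generic reader registry in ie_network_reader takes over that job. For reference, a condensed sketch of the pattern, with a stubbed Loader type standing in for SharedObjectLoader:

#include <memory>
#include <mutex>

// Stub standing in for InferenceEngine::details::SharedObjectLoader.
struct Loader {};

// Thread-safe lazy initialization: the loader is created exactly once,
// on the first call, no matter how many threads race here.
std::shared_ptr<Loader> getLoader() {
    static std::once_flag flag;
    static std::shared_ptr<Loader> loader;
    std::call_once(flag, [] { loader = std::make_shared<Loader>(); });
    return loader;
}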

View File

@@ -37,7 +37,6 @@ public:
* @brief This class is a wrapper for reader interfaces
*/
class Reader: public IReader {
private:
InferenceEngine::details::SOPointer<IReader> ptr;
std::once_flag readFlag;
std::string name;
@@ -120,10 +119,16 @@ void registerReaders() {
readers.emplace("prototxt", onnxReader);
}
// try to load IR reader if library exists
auto irReader = create_if_exists("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
if (irReader)
readers.emplace("xml", irReader);
// try to load IR reader v10 if library exists
auto irReaderv10 = create_if_exists("IRv10", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
if (irReaderv10)
readers.emplace("xml", irReaderv10);
// try to load IR reader v7 if library exists
auto irReaderv7 = create_if_exists("IRv7", std::string("inference_engine_ir_reader_v7") + std::string(IE_BUILD_POSTFIX));
if (irReaderv7)
readers.emplace("xml", irReaderv7);
initialized = true;
}
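
Note that two readers are now registered for the same "xml" extension; the registry therefore has to hold duplicate keys and probe candidates in order, letting each reader's supportModel() decide whether it takes the file. A compact sketch of that dispatch, with types simplified from IReader:

#include <istream>
#include <map>
#include <memory>
#include <string>

// Simplified stand-in for InferenceEngine::IReader.
struct IReaderLike {
    virtual ~IReaderLike() = default;
    virtual bool supportModel(std::istream& model) const = 0;
};
using ReaderPtr = std::shared_ptr<IReaderLike>;

// Duplicate keys are legal in a multimap, so both the v10 and the v7
// reader can sit under "xml"; the first reader that claims the stream wins.
ReaderPtr pickReader(const std::multimap<std::string, ReaderPtr>& readers,
                     const std::string& ext, std::istream& model) {
    auto range = readers.equal_range(ext);
    for (auto it = range.first; it != range.second; ++it)
        if (it->second->supportModel(model))
            return it->second;
    return nullptr;
}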

View File

@@ -3,7 +3,7 @@
#
set(TARGET_NAME inference_engine_reader_api)
# Reader API library
# Reader API interface library
add_library(${TARGET_NAME} INTERFACE)
target_include_directories(${TARGET_NAME} INTERFACE
@@ -16,6 +16,8 @@ file(GLOB_RECURSE reader_api_src "${CMAKE_CURRENT_SOURCE_DIR}/reader_api/*.hpp"
add_cpplint_target(${TARGET_NAME}_cpplint FOR_SOURCES ${reader_api_src})
add_subdirectory(ir_reader)
add_subdirectory(ir_reader_v7)
if(NGRAPH_ONNX_IMPORT_ENABLE)
add_subdirectory(onnx_reader)
endif()

View File

@@ -24,7 +24,8 @@ source_group("include" FILES ${PUBLIC_HEADERS})
add_library(${TARGET_NAME} SHARED ${LIBRARY_SRC} ${PUBLIC_HEADERS})
target_compile_definitions(${TARGET_NAME} PRIVATE IMPLEMENT_INFERENCE_ENGINE_API
IMPLEMENT_INFERENCE_ENGINE_PLUGIN)
IMPLEMENT_INFERENCE_ENGINE_PLUGIN
IR_READER_V10)
target_include_directories(${TARGET_NAME} PUBLIC ${PUBLIC_HEADERS_DIR})
target_include_directories(${TARGET_NAME} PRIVATE "${IE_MAIN_SOURCE_DIR}/src/inference_engine")
@@ -35,11 +36,6 @@ target_link_libraries(${TARGET_NAME} PRIVATE pugixml)
# code style
add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME})
add_clang_format_target(${TARGET_NAME}_clang_format FOR_TARGETS ${TARGET_NAME})
# developer package
ie_developer_export_targets(${TARGET_NAME})
# install

View File

@@ -27,25 +27,17 @@
#include <cpp/ie_cnn_network.h>
#include "ie_blob_stream.hpp"
#include "cnn_network_impl.hpp"
#include "details/caseless.hpp"
#include "details/ie_cnn_network_tools.h"
#include "ie_format_parser.h"
#include "ie_ngraph_utils.hpp"
#include "generic_ie.hpp"
#include "precision_utils.h"
#include "blob_factory.hpp"
#include "ie_cnn_net_reader_impl.h"
using namespace InferenceEngine;
using namespace XMLParseUtils;
IRParser::IRParser(size_t version): IRParser(version, {}) {}
IRParser::IRParser(size_t version, const std::vector<InferenceEngine::IExtensionPtr>& exts) {
if (version < 10) {
parser = std::make_shared<CNNParser>();
return;
}
switch (version) {
case 10:
parser = std::make_shared<V10Parser>(exts);
@@ -72,45 +64,6 @@ public:
originBlob(weights) { }
};
std::shared_ptr<ICNNNetwork> CNNParser::parse(const pugi::xml_node& root, std::istream& binStream) {
auto getBlobStream = [](std::istream& binStream) {
details::BlobStream* blobStream = dynamic_cast<details::BlobStream*>(&binStream);
if (blobStream == nullptr) {
details::BlobStream helper({});
std::string typeStream = typeid(binStream).name();
std::string typeBlobStream = typeid(helper).name();
if (typeStream == typeBlobStream)
blobStream = static_cast<details::BlobStream*>(&binStream);
}
return blobStream;
};
details::CNNNetReaderImpl reader(std::make_shared<details::V2FormatParserCreator>());
ResponseDesc resp;
StatusCode ret = reader.ReadNetwork(root, &resp);
if (ret != OK)
THROW_IE_EXCEPTION << resp.msg;
TBlob<uint8_t>::Ptr weightsPtr;
// Try to get BlobStream to work with original blob
details::BlobStream* blobStream = getBlobStream(binStream);
if (blobStream != nullptr) {
weightsPtr = std::make_shared<WeightsHolderBlob>(blobStream->getBlob());
} else {
// Allocate a blob for weights
binStream.seekg(0, std::ios::end);
size_t length = binStream.tellg();
weightsPtr = std::make_shared<TBlob<uint8_t>>(TensorDesc(Precision::U8, {length}, Layout::C));
weightsPtr->allocate();
char* data = weightsPtr->buffer().as<char*>();
binStream.seekg(0, std::ios::beg);
binStream.read(data, length);
}
ret = reader.SetWeights(weightsPtr, &resp);
if (ret != OK)
THROW_IE_EXCEPTION << resp.msg;
return reader.getNetwork();
}
V10Parser::V10Parser(const std::vector<IExtensionPtr>& exts) {
// Load default opsets
opsets["opset1"] = ngraph::get_opset1();

View File

@@ -4,7 +4,9 @@
#pragma once
#include <ngraph/opsets/opset.hpp>
#include <ie_blob.h>
#include <ie_icnn_network.hpp>
#include <ie_iextension.h>
#include <xml_parse_utils.h>
@@ -18,7 +20,6 @@
#include <string>
#include <vector>
#include "cnn_network_impl.hpp"
#include "ie_ngraph_utils.hpp"
namespace InferenceEngine {

View File

@@ -27,17 +27,38 @@ bool IRReader::supportModel(std::istream& model) const {
const int header_size = 128;
std::string header(header_size, ' ');
model.read(&header[0], header_size);
// find '<net ' substring in the .xml file
return (header.find("<net ") != std::string::npos) || (header.find("<Net ") != std::string::npos);
bool supports = (header.find("<net ") != std::string::npos) ||
(header.find("<Net ") != std::string::npos);
if (supports) {
pugi::xml_document xmlDoc;
model.seekg(0, model.beg);
pugi::xml_parse_result res = xmlDoc.load(model);
if (res.status != pugi::status_ok) {
supports = false;
} else {
pugi::xml_node root = xmlDoc.document_element();
auto version = GetIRVersion(root);
#ifdef IR_READER_V10
supports = version == 10;
#else
supports = version < 10;
#endif
}
}
model.seekg(0, model.beg);
return supports;
}
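
Both reader builds share this probe and differ only in the #ifdef branch. A standalone sketch (pugixml only) of pulling the version attribute from the root <net> node, mirroring GetIRVersion, with the stream restored for the caller:

#include <istream>
#include <pugixml.hpp>

// Returns the IR version, or 0 if the stream is not parseable XML or the
// attribute is absent; always rewinds the stream before returning.
size_t getIRVersion(std::istream& model) {
    pugi::xml_document doc;
    model.seekg(0, model.beg);
    size_t version = 0;
    if (doc.load(model).status == pugi::status_ok)
        version = doc.document_element().attribute("version").as_uint(0);
    model.seekg(0, model.beg);
    return version;
}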
CNNNetwork IRReader::read(std::istream& model, const std::vector<IExtensionPtr>& exts) const {
std::istringstream emptyStream;
return read(model, emptyStream, exts);
}
CNNNetwork IRReader::read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const {
model.seekg(0, model.beg);
weights.seekg(0, weights.beg);
pugi::xml_document xmlDoc;
pugi::xml_parse_result res = xmlDoc.load(model);
if (res.status != pugi::status_ok) {

View File

@@ -32,8 +32,6 @@ namespace InferenceEngine {
*/
class IRReader: public IReader {
public:
IRReader() = default;
void Release() noexcept override {
delete this;
}

View File

@@ -0,0 +1,39 @@
# Copyright (C) 2018-2019 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
set(TARGET_NAME "inference_engine_ir_reader_v7")
if(ENABLE_LTO)
ie_enable_lto()
endif()
set(PUBLIC_HEADERS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/")
file(GLOB_RECURSE LIBRARY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
list(APPEND LIBRARY_SRC ${IE_MAIN_SOURCE_DIR}/src/readers/ir_reader/ie_ir_reader.cpp)
file(GLOB_RECURSE PUBLIC_HEADERS ${PUBLIC_HEADERS_DIR}/*.h ${PUBLIC_HEADERS_DIR}/*.hpp)
# Create named folders for the sources within the .vcproj
# Empty name lists them directly under the .vcproj
source_group("src" FILES ${LIBRARY_SRC})
source_group("include" FILES ${PUBLIC_HEADERS})
# Create shared library
add_library(${TARGET_NAME} SHARED ${LIBRARY_SRC} ${PUBLIC_HEADERS})
target_compile_definitions(${TARGET_NAME} PRIVATE IMPLEMENT_INFERENCE_ENGINE_API
IMPLEMENT_INFERENCE_ENGINE_PLUGIN)
target_include_directories(${TARGET_NAME} PUBLIC ${PUBLIC_HEADERS_DIR})
target_include_directories(${TARGET_NAME} PRIVATE "${IE_MAIN_SOURCE_DIR}/src/inference_engine"
"${IE_MAIN_SOURCE_DIR}/src/readers/ir_reader")
target_link_libraries(${TARGET_NAME} PUBLIC inference_engine_reader_api inference_engine_plugin_api ${NGRAPH_LIBRARIES} inference_engine)
target_link_libraries(${TARGET_NAME} PRIVATE pugixml)
# code style
add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME})

View File

@@ -17,7 +17,6 @@
#include "cnn_network_ngraph_impl.hpp"
#include "details/os/os_filesystem.hpp"
#include "ie_format_parser.h"
#include "ie_ir_parser.hpp"
#include "ie_profiling.hpp"
#include "parsers.h"
#include "xml_parse_utils.h"
@@ -29,20 +28,13 @@ using namespace InferenceEngine::details;
CNNNetReaderImpl::CNNNetReaderImpl(const FormatParserCreator::Ptr& _creator)
: parseSuccess(false), _version(0), parserCreator(_creator) {}
CNNNetReaderImpl::~CNNNetReaderImpl() { }
StatusCode CNNNetReaderImpl::SetWeights(const TBlob<uint8_t>::Ptr& weights, ResponseDesc* desc) noexcept {
if (!_parser && _version < 10) {
return DescriptionBuffer(desc) << "network must be read first";
}
try {
if (_version == 10) {
// It's time to perform actual reading of V10 network and instantiate CNNNetworkNGraphImpl
IRParser parser(_version, extensions);
pugi::xml_node root = xmlDoc->document_element();
details::BlobStream blobStream(weights);
network = parser.parse(root, blobStream);
} else if (weights) {
if (_version < 10) {
_parser->SetWeights(weights);
}
} catch (const InferenceEngineException& iee) {
@@ -54,7 +46,7 @@ StatusCode CNNNetReaderImpl::SetWeights(const TBlob<uint8_t>::Ptr& weights, Resp
return OK;
}
size_t CNNNetReaderImpl::GetFileVersion(pugi::xml_node& root) {
static size_t GetFileVersion(pugi::xml_node& root) {
return XMLParseUtils::GetUIntAttr(root, "version", 0);
}
@@ -126,12 +118,8 @@ StatusCode CNNNetReaderImpl::ReadNetwork(const pugi::xml_node& const_root, Respo
pugi::xml_node root = *const_cast<pugi::xml_node*>(&const_root);
_version = GetFileVersion(root);
if (_version < 2) THROW_IE_EXCEPTION << "deprecated IR version: " << _version;
if (_version == 10) {
// Activate an alternative code path for V10 that should be read into ngraph::Function
// We cannot proceed with reading right now, because there is not binary file loaded.
// So we are postponing real read until weights are specified.
parseSuccess = true;
} else if (_version < 10) {
if (_version < 10) {
_parser = parserCreator->create(_version);
InferenceEngine::details::CNNNetworkImplPtr local_network = _parser->Parse(root);
name = local_network->getName();
@@ -185,10 +173,6 @@ StatusCode CNNNetReaderImpl::ReadNetwork() {
return OK;
}
void CNNNetReaderImpl::addExtensions(const std::vector<InferenceEngine::IExtensionPtr>& ext) {
extensions = ext;
}
std::shared_ptr<IFormatParser> V2FormatParserCreator::create(size_t version) {
return std::make_shared<FormatParser>(version);
}

View File

@@ -58,14 +58,6 @@ public:
return network;
}
bool isParseSuccess(ResponseDesc* resp) noexcept {
return parseSuccess;
}
StatusCode getDescription(ResponseDesc* desc) noexcept {
return DescriptionBuffer(OK, desc) << description;
}
StatusCode getName(char* name, size_t len, ResponseDesc* resp) noexcept {
if (len > 0) {
size_t length = std::min(this->name.size(), len - 1); // cut the name if buffer is too small
@@ -75,21 +67,8 @@ public:
return OK;
}
int getVersion(ResponseDesc* resp) noexcept {
return _version;
}
void Release() noexcept {
delete this;
}
void addExtensions(const std::vector<InferenceEngine::IExtensionPtr>& ext);
~CNNNetReaderImpl();
private:
std::shared_ptr<InferenceEngine::details::IFormatParser> _parser;
size_t GetFileVersion(pugi::xml_node& root);
StatusCode ReadNetwork();
std::string description;
@@ -101,7 +80,6 @@ private:
// Stashed xmlDoc that is needed to delayed loading of V10 IR version
std::shared_ptr<pugi::xml_document> xmlDoc;
std::vector<InferenceEngine::IExtensionPtr> extensions;
};
} // namespace details

View File

@@ -496,6 +496,9 @@ Blob::Ptr FormatParser::GetBlobFromSegment(const TBlob<uint8_t>::Ptr& weights, c
}
void FormatParser::SetWeights(const TBlob<uint8_t>::Ptr& weights) {
if (weights == nullptr)
return;
for (auto& kvp : _network->allLayers()) {
auto fit = layersParseInfo.find(kvp.second->name);
// todo: may check that earlier - while parsing...

View File

@@ -0,0 +1,76 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ie_reader.hpp"
#include "ie_ir_parser.hpp"
#include "ie_blob_stream.hpp"
#include "ie_cnn_net_reader_impl.h"
using namespace InferenceEngine;
IRParser::IRParser(size_t version): IRParser(version, {}) {}
IRParser::IRParser(size_t version, const std::vector<InferenceEngine::IExtensionPtr>& exts) {
if (version < 10) {
parser = std::make_shared<CNNParser>();
return;
} else {
THROW_IE_EXCEPTION << "Unsupported IR version: " << version;
}
}
std::shared_ptr<ICNNNetwork> IRParser::parse(const pugi::xml_node& root, std::istream& binStream) {
return parser->parse(root, binStream);
}
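
To illustrate the split (the snippet below is ours, not part of the commit): the v7 library's IRParser now accepts only pre-v10 IRs and throws for anything newer, which the v10 library handles instead.

#include "ie_ir_parser.hpp"
#include <exception>
#include <iostream>

int main() {
    InferenceEngine::IRParser v7(7);        // fine: dispatches to CNNParser
    try {
        InferenceEngine::IRParser v10(10);  // throws "Unsupported IR version: 10"
    } catch (const std::exception& e) {
        std::cout << e.what() << std::endl;
    }
}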
/**
* Hold original blob in order to avoid situations when original blob is allocated on stack
*/
class WeightsHolderBlob : public TBlob<uint8_t> {
Blob::CPtr originBlob;
public:
explicit WeightsHolderBlob(const Blob::CPtr& weights) :
TBlob<uint8_t>(weights->getTensorDesc(),
weights->cbuffer().as<uint8_t*>()),
originBlob(weights) { }
};
std::shared_ptr<ICNNNetwork> CNNParser::parse(const pugi::xml_node& root, std::istream& binStream) {
auto getBlobStream = [](std::istream& binStream) {
details::BlobStream* blobStream = dynamic_cast<details::BlobStream*>(&binStream);
if (blobStream == nullptr) {
details::BlobStream helper({});
std::string typeStream = typeid(binStream).name();
std::string typeBlobStream = typeid(helper).name();
if (typeStream == typeBlobStream)
blobStream = static_cast<details::BlobStream*>(&binStream);
}
return blobStream;
};
details::CNNNetReaderImpl reader(std::make_shared<details::V2FormatParserCreator>());
ResponseDesc resp;
StatusCode ret = reader.ReadNetwork(root, &resp);
if (ret != OK)
THROW_IE_EXCEPTION << resp.msg;
TBlob<uint8_t>::Ptr weightsPtr;
// Try to get BlobStream to work with original blob
details::BlobStream* blobStream = getBlobStream(binStream);
if (blobStream != nullptr) {
weightsPtr = std::make_shared<WeightsHolderBlob>(blobStream->getBlob());
} else {
// Allocate a blob for weights
binStream.seekg(0, std::ios::end);
size_t length = binStream.tellg();
weightsPtr = std::make_shared<TBlob<uint8_t>>(TensorDesc(Precision::U8, {length}, Layout::C));
weightsPtr->allocate();
char* data = weightsPtr->buffer().as<char*>();
binStream.seekg(0, std::ios::beg);
binStream.read(data, length);
}
ret = reader.SetWeights(weightsPtr, &resp);
if (ret != OK)
THROW_IE_EXCEPTION << resp.msg;
return reader.getNetwork();
}
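
One detail in getBlobStream is worth a note: dynamic_cast can fail when the stream object was created in a different shared library with its own RTTI, so the code falls back to comparing typeid names before a static_cast. A generic sketch of that fallback (the helper name is ours):

#include <istream>
#include <string>
#include <typeinfo>

// Try the normal downcast first; if RTTI is split across shared objects,
// matching typeid names still identifies the same dynamic type.
template <typename Derived>
Derived* streamAs(std::istream& stream, const Derived& probe) {
    if (auto* p = dynamic_cast<Derived*>(&stream))
        return p;
    if (std::string(typeid(stream).name()) == typeid(probe).name())
        return static_cast<Derived*>(&stream);
    return nullptr;
}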

View File

@@ -13,6 +13,7 @@ bool ONNXReader::supportModel(std::istream& model) const {
const int header_size = 128;
std::string header(header_size, ' ');
model.read(&header[0], header_size);
model.seekg(0, model.beg);
// find 'onnx' substring in the .onnx files
// find 'ir_version' and 'graph' for prototxt
// return (header.find("onnx") != std::string::npos) || (header.find("pytorch") != std::string::npos) ||
@@ -21,7 +22,6 @@ bool ONNXReader::supportModel(std::istream& model) const {
}
CNNNetwork ONNXReader::read(std::istream& model, const std::vector<IExtensionPtr>& exts) const {
model.seekg(0, model.beg);
return CNNNetwork(ngraph::onnx_import::import_onnx_model(model));
}
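
The seekg added to supportModel (and removed from read) enforces a simple contract: whichever reader sniffs the stream must rewind it, because several readers may probe the same stream in turn. A minimal sketch of that header-peek pattern:

#include <cstddef>
#include <istream>
#include <string>

// Read a fixed-size prefix for format detection, then restore the stream
// position so the next probe, or the eventual read, sees the full model.
std::string peekHeader(std::istream& model, std::size_t size = 128) {
    std::string header(size, ' ');
    model.read(&header[0], static_cast<std::streamsize>(size));
    model.seekg(0, model.beg);
    return header;
}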

View File

@@ -10,7 +10,6 @@ list(APPEND EXPORT_DEPENDENCIES
commonTestUtils_s
inference_engine_s
inference_engine_lp_transformations
inference_engine_ir_reader
gmock)
addIeTarget(

View File

@@ -6,7 +6,6 @@
#include "unit_test_utils/mocks/mock_icnn_network.hpp"
#include "unit_test_utils/mocks/mock_ie_imemory_state.hpp"
#include "unit_test_utils/mocks/mock_iexecutable_network.hpp"
#include "unit_test_utils/mocks/mock_iformat_parser.hpp"
#include "unit_test_utils/mocks/mock_iinfer_request.hpp"
#include "unit_test_utils/mocks/mock_iinference_plugin.hpp"
#include "unit_test_utils/mocks/mock_not_empty_icnn_network.hpp"

View File

@@ -1,24 +0,0 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
/**
* \brief mock file for header file for IFormatParser
* \file mock_iformat_parser.hpp
*/
#pragma once
#include <gmock/gmock.h>
#include "ie_icnn_network.hpp"
#include <ie_cnn_net_reader_impl.h>
#include <parsers.h>
#include "pugixml.hpp"
struct MockIFormatParser : public InferenceEngine::details::IFormatParser {
public:
MOCK_METHOD1(Parse, InferenceEngine::details::CNNNetworkImplPtr(pugi::xml_node &));
MOCK_METHOD1(SetWeights, void(const InferenceEngine::TBlob<uint8_t>::Ptr &));
};

View File

@@ -9,6 +9,7 @@ addIeTargetTest(
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
LINK_LIBRARIES
unitTestUtils
inference_engine_ir_reader_v7
ADD_CPPLINT
DEPENDENCIES
mock_engine

View File

@@ -4,7 +4,7 @@
#include "classification_matcher.hpp"
#include <gtest/gtest.h>
#include <xml_helper.hpp>
#include <fstream>
#include "details/ie_cnn_network_iterator.hpp"
using namespace Regression;

View File

@@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <xml_helper.hpp>
#include "object_detection_matcher.hpp"
#include "details/ie_cnn_network_iterator.hpp"

View File

@@ -6,7 +6,6 @@
#include <tests_common.hpp>
#include <tests_common_func.hpp>
#include <memory>
#include "xml_helper.hpp"
#include <ie_core.hpp>
#define XBYAK_NO_OP_NAMES

View File

@@ -7,7 +7,6 @@
#include <gtest/gtest.h>
#include <tests_common.hpp>
#include <ie_format_parser.h>
#include <ie_layers_internal.hpp>
#include <details/ie_cnn_network_iterator.hpp>
#include <functional_test_utils/plugin_cache.hpp>

View File

@@ -5,7 +5,6 @@
#include <memory>
#include <unordered_set>
#include <xml_helper.hpp>
#include <gtest/gtest.h>
#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"

View File

@@ -6,6 +6,7 @@
#include <ngraph_functions/subgraph_builders.hpp>
#include "myriad_layers_tests.hpp"
#include "vpu_tests_config.hpp"
#include <fstream>
using namespace InferenceEngine;
using namespace ::testing;

View File

@@ -8,7 +8,6 @@
#include <chrono>
#include <iostream>
#include "ie_ir_reader.hpp"
#include "functional_test_utils/plugin_cache.hpp"
using namespace InferenceEngine;

View File

@@ -3,7 +3,7 @@
//
#include "vpu_layer_tests_utils.hpp"
#include <fstream>
#include "common_test_utils/common_layers_params.hpp"
using namespace InferenceEngine;

View File

@@ -15,7 +15,6 @@
#include <common/include/vpu/utils/error.hpp>
#include "blob_factory.hpp"
#include "ie_ir_reader.hpp"
#include "debug.h"
#include "vpu_tests_config.hpp"

View File

@@ -21,7 +21,6 @@ function(add_helpers target_name)
target_include_directories(${target_name} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}"
"${IE_MAIN_SOURCE_DIR}/src/inference_engine"
$<TARGET_PROPERTY:inference_engine_ir_reader,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:inference_engine_lp_transformations,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:pugixml,INTERFACE_INCLUDE_DIRECTORIES>
"${IE_MAIN_SOURCE_DIR}/src/vpu/"

View File

@@ -6,10 +6,10 @@
#include <ie_blob.h>
#include <ie_core.hpp>
#include <cnn_network_impl.hpp>
#include <ie_layers_property.hpp>
#include <precision_utils.h>
#include <common_test_utils/xml_net_builder/xml_net_builder.hpp>
#include <xml_helper.hpp>
#include <tests_common.hpp>
#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 9) && !defined(__clang__)

View File

@@ -143,7 +143,7 @@ target_link_libraries(${TARGET_NAME} PRIVATE
# dynamic libraries
inference_engine_lp_transformations
inference_engine_transformations
inference_engine_ir_reader
inference_engine_ir_reader_v7
${CMAKE_DL_LIBS})
if(TARGET libGNAStubs)

View File

@@ -10,7 +10,6 @@
#include <thread>
#include "unit_test_utils/mocks/mock_icnn_network.hpp"
#include "unit_test_utils/mocks/mock_iformat_parser.hpp"
#include "common_test_utils/common_utils.hpp"
using namespace testing;
@@ -24,184 +23,6 @@ public:
ResponseDesc resp;
};
struct MockFormatParserCreator : public FormatParserCreator {
MockFormatParserCreator() {
_parser = make_shared<MockIFormatParser>();
}
std::shared_ptr<IFormatParser> create(size_t version) override {
return _parser;
}
MockIFormatParser* getParser() {
return _parser.get();
}
private:
std::shared_ptr<MockIFormatParser> _parser;
};
TEST_F(CNNNetReaderImplTest, validateIsCalled) {
std::string model =
"<net name=\"PVANET\" version=\"2\" batch=\"1\">"
" <layers>"
" <layer name=\"data\" type=\"Input\" precision=\"FP32\" id=\"0\">"
" <output>"
" <port id=\"0\">"
" <dim>1</dim>"
" <dim>3</dim>"
" <dim>544</dim>"
" <dim>992</dim>"
" </port>"
" </output>"
" </layer>"
" <layer name=\"conv1_1_conv\" type=\"Convolution\" precision=\"FP32\" id=\"2\">"
" <convolution_data stride-x=\"2\" stride-y=\"2\" pad-x=\"3\" pad-y=\"3\" kernel-x=\"7\" kernel-y=\"7\" output=\"16\" group=\"1\"/>"
" <input>"
" <port id=\"2\">"
" <dim>1</dim>"
" <dim>3</dim>"
" <dim>544</dim>"
" <dim>992</dim>"
" </port>"
" </input>"
" <output>"
" <port id=\"3\">"
" <dim>1</dim>"
" <dim>16</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </output>"
" <weights offset=\"0\" size=\"9408\"/>"
" <biases offset=\"9408\" size=\"64\"/>"
" </layer>"
" <layer name=\"conv1_1_neg\" type=\"Power\" precision=\"FP32\" id=\"3\">"
" <power_data power=\"1\" scale=\"-1\" shift=\"0\"/>"
" <input>"
" <port id=\"4\">"
" <dim>1</dim>"
" <dim>16</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </input>"
" <output>"
" <port id=\"5\">"
" <dim>1</dim>"
" <dim>16</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </output>"
" </layer>"
" <layer name=\"conv1_1_concat\" type=\"Concat\" precision=\"FP32\" id=\"4\">"
" <concat_data axis=\"1\"/>"
" <input>"
" <port id=\"6\">"
" <dim>1</dim>"
" <dim>16</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" <port id=\"7\">"
" <dim>1</dim>"
" <dim>16</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </input>"
" <output>"
" <port id=\"8\">"
" <dim>1</dim>"
" <dim>32</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </output>"
" </layer>"
" <layer name=\"conv1_1_scale\" type=\"ScaleShift\" precision=\"FP32\" id=\"5\">"
" <input>"
" <port id=\"9\">"
" <dim>1</dim>"
" <dim>32</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </input>"
" <output>"
" <port id=\"10\">"
" <dim>1</dim>"
" <dim>32</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </output>"
" <weights offset=\"9472\" size=\"128\"/>"
" <biases offset=\"9600\" size=\"128\"/>"
" </layer>"
" <layer name=\"conv1_1_relu\" type=\"ReLU\" precision=\"FP32\" id=\"6\">"
" <data negative_slope=\"0\" engine=\"caffe.ReLUParameter.DEFAULT\"/>"
" <input>"
" <port id=\"11\">"
" <dim>1</dim>"
" <dim>32</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </input>"
" <output>"
" <port id=\"12\">"
" <dim>1</dim>"
" <dim>32</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </output>"
" </layer>"
" <layer name=\"pool1\" type=\"Pooling\" precision=\"FP32\" id=\"7\">"
" <pooling_data kernel-x=\"3\" kernel-y=\"3\" pad-x=\"0\" pad-y=\"0\" stride-x=\"2\" stride-y=\"2\" rounding-type=\"ceil\" pool-method=\"max\"/>"
" <input>"
" <port id=\"13\">"
" <dim>1</dim>"
" <dim>32</dim>"
" <dim>272</dim>"
" <dim>496</dim>"
" </port>"
" </input>"
" <output>"
" <port id=\"14\">"
" <dim>1</dim>"
" <dim>32</dim>"
" <dim>136</dim>"
" <dim>248</dim>"
" </port>"
" </output>"
" </layer>"
" </layers>"
" <edges>"
" <edge from-layer=\"0\" from-port=\"0\" to-layer=\"2\" to-port=\"2\"/>"
" <edge from-layer=\"2\" from-port=\"3\" to-layer=\"3\" to-port=\"4\"/>"
" <edge from-layer=\"2\" from-port=\"3\" to-layer=\"4\" to-port=\"6\"/>"
" <edge from-layer=\"3\" from-port=\"5\" to-layer=\"4\" to-port=\"7\"/>"
" <edge from-layer=\"4\" from-port=\"8\" to-layer=\"5\" to-port=\"9\"/>"
" <edge from-layer=\"5\" from-port=\"10\" to-layer=\"6\" to-port=\"11\"/>"
" <edge from-layer=\"6\" from-port=\"12\" to-layer=\"7\" to-port=\"13\"/>"
" </edges>"
"</net>";
auto parserCreator = make_shared<MockFormatParserCreator>();
CNNNetReaderImpl reader(parserCreator);
auto network = make_shared<MockCNNNetworkImpl>();
auto name = std::string{"AlexNet"};
EXPECT_CALL(*parserCreator->getParser(), Parse(_)).Times(1).WillOnce(Return(network));
EXPECT_CALL(*network.get(), validate(_)).Times(1);
EXPECT_CALL(*network.get(), getName()).Times(1).WillOnce(ReturnRef(name));
ASSERT_NO_THROW(sts = reader.ReadNetwork(model.data(), model.length(), &resp));
ASSERT_EQ(OK, sts);
}
TEST_F(CNNNetReaderImplTest, cycleIsDetectedInReader) {
std::string model =
"<net batch=\"1\" name=\"model\" version=\"2\">"