Fix issue with GNA Import/Export. (#4563)
* Fix issue with GNA Import/Export. Application may create ostream and write some app-specific header to beginning of stream. GNA plug-in shall not 'seekg' to beginning of stream where app-specific content is located. Instead it shall seek to the original position of stream Functional test is also updated to cover this case * Updated according to review comments - Check return code right after 'istream::tellg' - Added 'applicationHeader' as parameter to base ImportExport tests * Error check for startPos + unit test * Suppress deprecated warnings like in other GNA tests
This commit is contained in:
@@ -77,12 +77,17 @@ const int gna_header_magic = is_little_endian() ? 0x4d414e47 : 0x474e414d;
|
||||
|
||||
GNAPluginNS::HeaderLatest::ModelHeader GNAModelSerial::ReadHeader(std::istream &is) {
|
||||
is.exceptions(std::istream::failbit);
|
||||
auto startPos = is.tellg();
|
||||
if (startPos == -1) {
|
||||
THROW_GNA_EXCEPTION << "Can't open stream to import";
|
||||
}
|
||||
is.seekg(0, is.end);
|
||||
auto stream_len = is.tellg();
|
||||
if (stream_len == -1) {
|
||||
THROW_GNA_EXCEPTION << "Can't open file to import";
|
||||
}
|
||||
is.seekg(0, is.beg);
|
||||
stream_len -= startPos;
|
||||
is.seekg(startPos, is.beg);
|
||||
|
||||
HeaderLatest::ModelHeader header;
|
||||
header.version.major = 0u;
|
||||
@@ -103,7 +108,7 @@ GNAPluginNS::HeaderLatest::ModelHeader GNAModelSerial::ReadHeader(std::istream &
|
||||
std::hex << std::setw(2) << static_cast<short>(header.gnam[3]);
|
||||
}
|
||||
|
||||
is.seekg(0, is.beg);
|
||||
is.seekg(startPos, is.beg);
|
||||
Header2dot1::ModelHeader tempHeader2dot1;
|
||||
switch (header.version.major) {
|
||||
case 2:
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include <istream>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <ie_input_info.hpp>
|
||||
#include <ie_icnn_network.hpp>
|
||||
|
||||
#include "descriptions/gna_input_desc.hpp"
|
||||
#include "descriptions/gna_output_desc.hpp"
|
||||
|
||||
@@ -14,13 +14,23 @@ namespace {
|
||||
class ImportReshapePermuteConvGNA : public ImportReshapePermuteConv {
|
||||
private:
|
||||
void exportImportNetwork() override {
|
||||
executableNetwork.Export(fileName);
|
||||
std::fstream inputStream(fileName, std::ios_base::in | std::ios_base::binary);
|
||||
if (inputStream.fail()) {
|
||||
FAIL() << "Cannot open file to import model: " << fileName;
|
||||
{
|
||||
std::ofstream out(fileName);
|
||||
out.write(applicationHeader.c_str(), applicationHeader.size());
|
||||
executableNetwork.Export(out);
|
||||
}
|
||||
{
|
||||
std::string appHeader(applicationHeader.size(), ' ');
|
||||
std::fstream inputStream(fileName, std::ios_base::in | std::ios_base::binary);
|
||||
if (inputStream.fail()) {
|
||||
FAIL() << "Cannot open file to import model: " << fileName;
|
||||
}
|
||||
inputStream.read(&appHeader[0], applicationHeader.size());
|
||||
ASSERT_EQ(appHeader, applicationHeader);
|
||||
executableNetwork = core->ImportNetwork(inputStream, targetDevice, configuration);
|
||||
}
|
||||
executableNetwork = core->ImportNetwork(inputStream, targetDevice, configuration);
|
||||
}
|
||||
|
||||
protected:
|
||||
void TearDown() override {
|
||||
if (remove(fileName.c_str()) != 0) {
|
||||
@@ -59,12 +69,18 @@ const std::vector<std::map<std::string, std::string>> importConfigs = {
|
||||
},
|
||||
};
|
||||
|
||||
const std::vector<std::string> appHeaders = {
|
||||
"",
|
||||
"APPLICATION_HEADER"
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_ImportNetworkCase, ImportReshapePermuteConvGNA,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(exportConfigs),
|
||||
::testing::ValuesIn(importConfigs)),
|
||||
::testing::ValuesIn(importConfigs),
|
||||
::testing::ValuesIn(appHeaders)),
|
||||
ImportReshapePermuteConvGNA::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -21,12 +21,18 @@ const std::vector<std::map<std::string, std::string>> importConfigs = {
|
||||
{}
|
||||
};
|
||||
|
||||
const std::vector<std::string> appHeaders = {
|
||||
"",
|
||||
"APPLICATION_HEADER"
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_ImportNetworkCase, ImportNonZero,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_MYRIAD),
|
||||
::testing::ValuesIn(exportConfigs),
|
||||
::testing::ValuesIn(importConfigs)),
|
||||
::testing::ValuesIn(importConfigs),
|
||||
::testing::ValuesIn(appHeaders)),
|
||||
ImportNonZero::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -12,7 +12,8 @@ typedef std::tuple<
|
||||
InferenceEngine::Precision, // Network Precision
|
||||
std::string, // Target Device
|
||||
std::map<std::string, std::string>, // Export Configuration
|
||||
std::map<std::string, std::string> // Import Configuration
|
||||
std::map<std::string, std::string>, // Import Configuration
|
||||
std::string // Application Header
|
||||
> exportImportNetworkParams;
|
||||
|
||||
namespace FuncTestUtils {
|
||||
@@ -26,6 +27,7 @@ public:
|
||||
protected:
|
||||
std::map<std::string, std::string> exportConfiguration;
|
||||
std::map<std::string, std::string> importConfiguration;
|
||||
std::string applicationHeader;
|
||||
|
||||
private:
|
||||
virtual void exportImportNetwork();
|
||||
|
||||
@@ -13,7 +13,8 @@ std::string ImportNetworkTestBase::getTestCaseName(testing::TestParamInfo<export
|
||||
std::string targetDevice;
|
||||
std::map<std::string, std::string> exportConfiguration;
|
||||
std::map<std::string, std::string> importConfiguration;
|
||||
std::tie(netPrecision, targetDevice, exportConfiguration, importConfiguration) = obj.param;
|
||||
std::string appHeader;
|
||||
std::tie(netPrecision, targetDevice, exportConfiguration, importConfiguration, appHeader) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
@@ -24,12 +25,19 @@ std::string ImportNetworkTestBase::getTestCaseName(testing::TestParamInfo<export
|
||||
for (auto const& configItem : importConfiguration) {
|
||||
result << "_importConfigItem=" << configItem.first << "_" << configItem.second;
|
||||
}
|
||||
result << "_appHeader=" << appHeader;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void ImportNetworkTestBase::exportImportNetwork() {
|
||||
std::stringstream strm;
|
||||
strm.write(applicationHeader.c_str(), applicationHeader.size());
|
||||
executableNetwork.Export(strm);
|
||||
|
||||
strm.seekg(0, strm.beg);
|
||||
std::string appHeader(applicationHeader.size(), ' ');
|
||||
strm.read(&appHeader[0], applicationHeader.size());
|
||||
ASSERT_EQ(appHeader, applicationHeader);
|
||||
executableNetwork = core->ImportNetwork(strm, targetDevice, configuration);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace LayerTestsDefinitions {
|
||||
|
||||
void ImportNonZero::SetUp() {
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(netPrecision, targetDevice, exportConfiguration, importConfiguration) = this->GetParam();
|
||||
std::tie(netPrecision, targetDevice, exportConfiguration, importConfiguration, applicationHeader) = this->GetParam();
|
||||
const auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
|
||||
const auto parameter = std::make_shared<ngraph::opset5::Parameter>(ngPrc, ngraph::Shape{1000});
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace LayerTestsDefinitions {
|
||||
|
||||
void ImportReshapePermuteConv::SetUp() {
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(netPrecision, targetDevice, exportConfiguration, importConfiguration) = this->GetParam();
|
||||
std::tie(netPrecision, targetDevice, exportConfiguration, importConfiguration, applicationHeader) = this->GetParam();
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
|
||||
auto params = ngraph::builder::makeParams(ngPrc, { {1, 336} });
|
||||
|
||||
25
inference-engine/tests/unit/gna/gna_model_serial_test.cpp
Normal file
25
inference-engine/tests/unit/gna/gna_model_serial_test.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
// to suppress deprecated definition errors
|
||||
#define IMPLEMENT_INFERENCE_ENGINE_PLUGIN
|
||||
#include "gna_model_serial.hpp"
|
||||
|
||||
using ::testing::Return;
|
||||
using ::testing::_;
|
||||
|
||||
class IstreamMock final: public std::streambuf {
|
||||
public:
|
||||
MOCK_METHOD3(seekoff, std::streampos(std::streamoff, std::ios_base::seekdir,
|
||||
std::ios_base::openmode));
|
||||
};
|
||||
|
||||
TEST(GNAModelSerialTest, TestErrorOnTellg) {
|
||||
IstreamMock mock;
|
||||
EXPECT_CALL(mock, seekoff(_, _, _)).WillRepeatedly(Return(-1));
|
||||
std::istream is(&mock);
|
||||
ASSERT_THROW(GNAModelSerial::ReadHeader(is), InferenceEngine::details::InferenceEngineException);
|
||||
}
|
||||
Reference in New Issue
Block a user