Remove ExecutableNetwork::QueryState (#8034)

* removed ExecutableNetwork::QueryState from code

* removed ExecutableNetwork::QueryStates from tests (not checked)

* buildable version

* remove unneeded change and fix cpplint error

* remove extra space

* remove QueryState from GNAExecutableNetwork

* clean up GNA tests for QueryState in tests_deprecated (without replacement because deprecated)

* fix tests after merge

* remove tests again after merge

* fixed tests with _REGULAR_API suffix
This commit is contained in:
Svetlana Dolinina 2021-11-15 13:58:26 +03:00 committed by GitHub
parent 3354275da1
commit e6884c3fd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 52 additions and 417 deletions

View File

@ -67,13 +67,6 @@ class GNAExecutableNetwork : public InferenceEngine::IExecutableNetworkInternal
return std::make_shared<GNAInferRequest>(plg, inputs, outputs); return std::make_shared<GNAInferRequest>(plg, inputs, outputs);
} }
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
IE_SUPPRESS_DEPRECATED_START
return plg->QueryState();
IE_SUPPRESS_DEPRECATED_END
}
void Export(const std::string &modelFileName) override { void Export(const std::string &modelFileName) override {
plg->Export(modelFileName); plg->Export(modelFileName);
} }

View File

@ -190,18 +190,6 @@ public:
*/ */
INFERENCE_ENGINE_DEPRECATED("Use ExecutableNetwork::CreateInferRequest instead") INFERENCE_ENGINE_DEPRECATED("Use ExecutableNetwork::CreateInferRequest instead")
InferRequest::Ptr CreateInferRequestPtr(); InferRequest::Ptr CreateInferRequestPtr();
/**
* @deprecated Use InferRequest::QueryState instead
* @brief Gets state control interface for given executable network.
*
* State control essential for recurrent networks
*
* @return A vector of Memory State objects
*/
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
std::vector<VariableState> QueryState();
IE_SUPPRESS_DEPRECATED_END
}; };
} // namespace InferenceEngine } // namespace InferenceEngine

View File

@ -65,16 +65,6 @@ ExecutableNetwork::operator IExecutableNetwork::Ptr() {
return std::make_shared<ExecutableNetworkBase>(_impl); return std::make_shared<ExecutableNetworkBase>(_impl);
} }
std::vector<VariableState> ExecutableNetwork::QueryState() {
std::vector<VariableState> controller;
EXEC_NET_CALL_STATEMENT({
for (auto&& state : _impl->QueryState()) {
controller.emplace_back(VariableState{_so, state});
}
});
return controller;
}
InferRequest ExecutableNetwork::CreateInferRequest() { InferRequest ExecutableNetwork::CreateInferRequest() {
EXEC_NET_CALL_STATEMENT(return {_so, _impl->CreateInferRequest()}); EXEC_NET_CALL_STATEMENT(return {_so, _impl->CreateInferRequest()});
} }

View File

@ -87,10 +87,6 @@ std::shared_ptr<ngraph::Function> IExecutableNetworkInternal::GetExecGraphInfo()
IE_THROW(NotImplemented); IE_THROW(NotImplemented);
} }
std::vector<std::shared_ptr<IVariableStateInternal>> IExecutableNetworkInternal::QueryState() {
IE_THROW(NotImplemented);
}
void IExecutableNetworkInternal::SetPointerToPlugin(const std::shared_ptr<IInferencePlugin>& plugin) { void IExecutableNetworkInternal::SetPointerToPlugin(const std::shared_ptr<IInferencePlugin>& plugin) {
_plugin = plugin; _plugin = plugin;
} }

View File

@ -283,12 +283,6 @@ bool MKLDNNExecNetwork::CanProcessDynBatch(const InferenceEngine::CNNNetwork &ne
return true; return true;
} }
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> MKLDNNExecNetwork::QueryState() {
return memoryStates;
}
IE_SUPPRESS_DEPRECATED_END
void MKLDNNExecNetwork::Export(std::ostream& modelStream) { void MKLDNNExecNetwork::Export(std::ostream& modelStream) {
CNNNetworkSerializer serializer(modelStream, extensionManager); CNNNetworkSerializer serializer(modelStream, extensionManager);
serializer <<_network; serializer <<_network;

View File

@ -44,9 +44,6 @@ public:
std::shared_ptr<ngraph::Function> GetExecGraphInfo() override; std::shared_ptr<ngraph::Function> GetExecGraphInfo() override;
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
void Export(std::ostream& modelStream) override; void Export(std::ostream& modelStream) override;
protected: protected:

View File

@ -60,26 +60,20 @@ void MKLDNNPlugin::MKLDNNInferRequest::CreateInferRequest() {
// Save all MemoryLayer data tensors. Will use insight about mechanics // Save all MemoryLayer data tensors. Will use insight about mechanics
// of MemoryLayer implementation. It uses output edge of MemoryLayer // of MemoryLayer implementation. It uses output edge of MemoryLayer
// producer as storage for tensor to keep it between infer calls. // producer as storage for tensor to keep it between infer calls.
IE_SUPPRESS_DEPRECATED_START for (auto& node : graph->GetNodes()) {
if (execNetwork->_numRequests > 1 || execNetwork->QueryState().size() == 0) { if (node->getType() == MemoryInput) {
for (auto &node : graph->GetNodes()) { auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
if (node->getType() == MemoryInput) { auto state_store = memoryNode->getStore();
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get()); auto state_name = memoryNode->getId();
auto state_store = memoryNode->getStore();
auto state_name = memoryNode->getId();
// Remove suffix with pair ID. Internal information. // Remove suffix with pair ID. Internal information.
auto suffix_idx = state_name.find("/id="); auto suffix_idx = state_name.find("/id=");
if (suffix_idx != std::string::npos) if (suffix_idx != std::string::npos)
state_name = state_name.substr(0, suffix_idx); state_name = state_name.substr(0, suffix_idx);
memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store)); memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
}
} }
} else {
memoryStates = execNetwork->QueryState();
} }
IE_SUPPRESS_DEPRECATED_END
} }
MKLDNNPlugin::MKLDNNInferRequest::~MKLDNNInferRequest() { MKLDNNPlugin::MKLDNNInferRequest::~MKLDNNInferRequest() {

View File

@ -115,13 +115,6 @@ public:
*/ */
virtual std::shared_ptr<ngraph::Function> GetExecGraphInfo(); virtual std::shared_ptr<ngraph::Function> GetExecGraphInfo();
/**
* @deprecated Need to implement GetVariablesInfo for ExecutableNetwork
* @brief Queries memory states.
* @return Returns memory states
*/
virtual std::vector<std::shared_ptr<IVariableStateInternal>> QueryState();
/** /**
* @brief Sets the pointer to plugin internal. * @brief Sets the pointer to plugin internal.
* @param[in] plugin The plugin * @param[in] plugin The plugin

View File

@ -35,13 +35,6 @@ TEST(ExecutableNetworkTests, throwsOnUninitializedGetExecGraphInfo) {
ASSERT_THROW(exec.GetExecGraphInfo(), InferenceEngine::NotAllocated); ASSERT_THROW(exec.GetExecGraphInfo(), InferenceEngine::NotAllocated);
} }
TEST(ExecutableNetworkTests, throwsOnUninitializedQueryState) {
IE_SUPPRESS_DEPRECATED_START
ExecutableNetwork exec;
ASSERT_THROW(exec.QueryState(), InferenceEngine::NotAllocated);
IE_SUPPRESS_DEPRECATED_END
}
TEST(ExecutableNetworkTests, throwsOnUninitializedSetConfig) { TEST(ExecutableNetworkTests, throwsOnUninitializedSetConfig) {
ExecutableNetwork exec; ExecutableNetwork exec;
ASSERT_THROW(exec.SetConfig({{}}), InferenceEngine::NotAllocated); ASSERT_THROW(exec.SetConfig({{}}), InferenceEngine::NotAllocated);

View File

@ -65,16 +65,14 @@ public:
} }
auto importedNetwork = core->ImportNetwork(inputStream, targetDevice, configuration); auto importedNetwork = core->ImportNetwork(inputStream, targetDevice, configuration);
std::vector<std::string> queryToState; std::vector<std::string> queryToState;
IE_SUPPRESS_DEPRECATED_START InferenceEngine::InferRequest importInfer = importedNetwork.CreateInferRequest();
for (const auto &query_state : executableNetwork.QueryState()) { for (const auto &query_state : importInfer.QueryState()) {
queryToState.push_back(query_state.GetName()); queryToState.push_back(query_state.GetName());
} }
for (const auto &next_memory : importedNetwork.QueryState()) { for (const auto &next_memory : importInfer.QueryState()) {
ASSERT_TRUE(std::find(queryToState.begin(), queryToState.end(), next_memory.GetName()) != queryToState.end()) ASSERT_TRUE(std::find(queryToState.begin(), queryToState.end(), next_memory.GetName()) != queryToState.end())
<< "State " << next_memory.GetName() << " expected to be in memory states but it is not!"; << "State " << next_memory.GetName() << " expected to be in memory states but it is not!";
} }
IE_SUPPRESS_DEPRECATED_END
InferenceEngine::InferRequest importInfer = importedNetwork.CreateInferRequest();
importInfer.Infer(); importInfer.Infer();
} }

View File

@ -42,8 +42,7 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
manager.register_pass<ngraph::pass::LowLatency2>(); // LowLatency enables UnrollTI manager.register_pass<ngraph::pass::LowLatency2>(); // LowLatency enables UnrollTI
manager.run_passes(function); manager.run_passes(function);
LoadNetwork(); LoadNetwork();
IE_SUPPRESS_DEPRECATED_START auto states = inferRequest.QueryState();
auto states = executableNetwork.QueryState();
for (auto& state : states) { for (auto& state : states) {
auto name = state.GetName(); auto name = state.GetName();
if (name.find("cell_state_1") != std::string::npos) { if (name.find("cell_state_1") != std::string::npos) {
@ -58,7 +57,6 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
GTEST_FAIL() << "unknown memory state"; GTEST_FAIL() << "unknown memory state";
} }
} }
IE_SUPPRESS_DEPRECATED_END
// Run and compare // Run and compare
Infer(); Infer();
const auto& actualOutputs = GetOutputs(); const auto& actualOutputs = GetOutputs();

View File

@ -31,98 +31,6 @@ InferenceEngine::ExecutableNetwork InferRequestVariableStateTest::PrepareNetwork
return ie->LoadNetwork(net, deviceName); return ie->LoadNetwork(net, deviceName);
} }
TEST_P(InferRequestVariableStateTest, smoke_VariableState_QueryState) {
IE_SUPPRESS_DEPRECATED_START
auto executableNet = PrepareNetwork();
auto states = executableNet.QueryState();
ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates";
for (auto &&state : states) {
auto name = state.GetName();
ASSERT_TRUE(std::find(statesToQuery.begin(), statesToQuery.end(), name) != statesToQuery.end())
<< "State " << name << "expected to be in memory states but it is not!";
}
IE_SUPPRESS_DEPRECATED_END
}
TEST_P(InferRequestVariableStateTest, smoke_VariableState_SetState) {
IE_SUPPRESS_DEPRECATED_START
auto executableNet = PrepareNetwork();
const float new_state_val = 13.0f;
for (auto &&state : executableNet.QueryState()) {
state.Reset();
auto state_val = state.GetState();
auto element_count = state_val->size();
float *new_state_data = new float[element_count];
for (int i = 0; i < element_count; i++) {
new_state_data[i] = new_state_val;
}
auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
stateBlob->allocate();
std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
delete[]new_state_data;
state.SetState(stateBlob);
}
for (auto &&state : executableNet.QueryState()) {
auto lastState = state.GetState();
auto last_state_size = lastState->size();
auto last_state_data = lastState->cbuffer().as<float *>();
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
for (int i = 0; i < last_state_size; i++) {
EXPECT_NEAR(new_state_val, last_state_data[i], 1e-5);
}
}
IE_SUPPRESS_DEPRECATED_END
}
TEST_P(InferRequestVariableStateTest, smoke_VariableState_Reset) {
IE_SUPPRESS_DEPRECATED_START
auto executableNet = PrepareNetwork();
const float new_state_val = 13.0f;
for (auto &&state : executableNet.QueryState()) {
state.Reset();
auto state_val = state.GetState();
auto element_count = state_val->size();
float *new_state_data = new float[element_count];
for (int i = 0; i < element_count; i++) {
new_state_data[i] = new_state_val;
}
auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
stateBlob->allocate();
std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
delete[]new_state_data;
state.SetState(stateBlob);
}
executableNet.QueryState().front().Reset();
auto states = executableNet.QueryState();
for (int i = 0; i < states.size(); ++i) {
auto lastState = states[i].GetState();
auto last_state_size = lastState->size();
auto last_state_data = lastState->cbuffer().as<float *>();
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
if (i == 0) {
for (int j = 0; j < last_state_size; ++j) {
EXPECT_NEAR(0, last_state_data[j], 1e-5);
}
} else {
for (int j = 0; j < last_state_size; ++j) {
EXPECT_NEAR(new_state_val, last_state_data[j], 1e-5);
}
}
}
IE_SUPPRESS_DEPRECATED_END
}
TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_QueryState) { TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_QueryState) {
auto executableNet = PrepareNetwork(); auto executableNet = PrepareNetwork();
auto inferReq = executableNet.CreateInferRequest(); auto inferReq = executableNet.CreateInferRequest();

View File

@ -57,6 +57,7 @@ void DetectNetworkBatch::LoadNetwork() {
functionRefs = ngraph::clone_function(*cnnNetwork.getFunction()); functionRefs = ngraph::clone_function(*cnnNetwork.getFunction());
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
TEST_P(DetectNetworkBatch, InferWithOneInput) { TEST_P(DetectNetworkBatch, InferWithOneInput) {

View File

@ -54,10 +54,10 @@ namespace ConfigurationTestsDefinitions {
ConfigureNetwork(); ConfigureNetwork();
cnnNetwork.setBatchSize(max_batch_size); cnnNetwork.setBatchSize(max_batch_size);
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
void DynamicBatchTest::Infer() { void DynamicBatchTest::Infer() {
inferRequest = executableNetwork.CreateInferRequest();
inputs.clear(); inputs.clear();
for (int i = 0; i < batch_sizes.size(); i++) { for (int i = 0; i < batch_sizes.size(); i++) {

View File

@ -344,6 +344,7 @@ void LayerTestsCommon::LoadNetwork() {
CoreConfiguration(this); CoreConfiguration(this);
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
void LayerTestsCommon::GenerateInputs() { void LayerTestsCommon::GenerateInputs() {
@ -361,8 +362,6 @@ void LayerTestsCommon::GenerateInputs() {
} }
void LayerTestsCommon::Infer() { void LayerTestsCommon::Infer() {
inferRequest = executableNetwork.CreateInferRequest();
const auto& inputsInfo = executableNetwork.GetInputsInfo(); const auto& inputsInfo = executableNetwork.GetInputsInfo();
const auto& functionParams = function->get_parameters(); const auto& functionParams = function->get_parameters();
for (int i = 0; i < functionParams.size(); ++i) { for (int i = 0; i < functionParams.size(); ++i) {

View File

@ -81,6 +81,7 @@ namespace LayerTestsDefinitions {
CoreConfiguration(this); CoreConfiguration(this);
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
GenerateInputs(); GenerateInputs();
for (int64_t i = 0; i < iteration_count; ++i) { for (int64_t i = 0; i < iteration_count; ++i) {

View File

@ -103,8 +103,8 @@ namespace SubgraphTestsDefinitions {
SKIP_IF_CURRENT_TEST_IS_DISABLED() SKIP_IF_CURRENT_TEST_IS_DISABLED()
LoadNetwork(); LoadNetwork();
IE_SUPPRESS_DEPRECATED_START
auto states = executableNetwork.QueryState(); auto states = inferRequest.QueryState();
for (auto& state : states) { for (auto& state : states) {
auto name = state.GetName(); auto name = state.GetName();
if (name == "memory_1") { if (name == "memory_1") {
@ -119,7 +119,6 @@ namespace SubgraphTestsDefinitions {
GTEST_FAIL() << "unknown memory state"; GTEST_FAIL() << "unknown memory state";
} }
} }
IE_SUPPRESS_DEPRECATED_END
GenerateInputs(); GenerateInputs();
Infer(); Infer();
switchToNgraphFriendlyModel(); switchToNgraphFriendlyModel();

View File

@ -6,8 +6,7 @@
namespace SubgraphTestsDefinitions { namespace SubgraphTestsDefinitions {
void DelayedCopyTestBase::InitMemory() { void DelayedCopyTestBase::InitMemory() {
IE_SUPPRESS_DEPRECATED_START auto states = inferRequest.QueryState();
auto states = executableNetwork.QueryState();
for (auto& state : states) { for (auto& state : states) {
auto name = state.GetName(); auto name = state.GetName();
if (name.find("id") != std::string::npos) { if (name.find("id") != std::string::npos) {
@ -18,7 +17,6 @@ namespace SubgraphTestsDefinitions {
GTEST_FAIL() << "unknown memory state"; GTEST_FAIL() << "unknown memory state";
} }
} }
IE_SUPPRESS_DEPRECATED_END
} }
void DelayedCopyTestBase::Run() { void DelayedCopyTestBase::Run() {

View File

@ -276,8 +276,7 @@ namespace SubgraphTestsDefinitions {
} }
void MemoryLSTMCellTest::InitMemory() { void MemoryLSTMCellTest::InitMemory() {
IE_SUPPRESS_DEPRECATED_START auto states = inferRequest.QueryState();
auto states = executableNetwork.QueryState();
for (auto& state : states) { for (auto& state : states) {
auto name = state.GetName(); auto name = state.GetName();
if (name.find("cell_state_1") != std::string::npos) { if (name.find("cell_state_1") != std::string::npos) {
@ -292,7 +291,6 @@ namespace SubgraphTestsDefinitions {
GTEST_FAIL() << "unknown memory state"; GTEST_FAIL() << "unknown memory state";
} }
} }
IE_SUPPRESS_DEPRECATED_END
} }
void MemoryLSTMCellTest::ApplyLowLatency() { void MemoryLSTMCellTest::ApplyLowLatency() {
@ -330,6 +328,7 @@ namespace SubgraphTestsDefinitions {
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) { } else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) {
cnnNetwork = InferenceEngine::CNNNetwork{function}; cnnNetwork = InferenceEngine::CNNNetwork{function};
InferenceEngine::lowLatency2(cnnNetwork); InferenceEngine::lowLatency2(cnnNetwork);
@ -339,6 +338,7 @@ namespace SubgraphTestsDefinitions {
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
} }
} // namespace SubgraphTestsDefinitions } // namespace SubgraphTestsDefinitions

View File

@ -122,7 +122,7 @@ void MemoryEltwiseReshapeConcatTest::Run() {
InferenceEngine::Layout::NC); InferenceEngine::Layout::NC);
IE_SUPPRESS_DEPRECATED_START IE_SUPPRESS_DEPRECATED_START
auto states = executableNetwork.QueryState(); auto states = inferRequest.QueryState();
auto state_values_blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description, auto state_values_blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description,
memory_init.data(), memory_init.size()); memory_init.data(), memory_init.size());
states[0].SetState(state_values_blob); states[0].SetState(state_values_blob);

View File

@ -397,8 +397,7 @@ void MultipleLSTMCellTest::InitMemory() {
InferenceEngine::TensorDesc state_description(InferenceEngine::Precision::FP32, InferenceEngine::TensorDesc state_description(InferenceEngine::Precision::FP32,
InferenceEngine::SizeVector({1, hiddenSize}), InferenceEngine::SizeVector({1, hiddenSize}),
InferenceEngine::Layout::NC); InferenceEngine::Layout::NC);
IE_SUPPRESS_DEPRECATED_START auto states = inferRequest.QueryState();
auto states = executableNetwork.QueryState();
for (auto& state : states) { for (auto& state : states) {
auto name = state.GetName(); auto name = state.GetName();
if (name.find("cell_state_1") != std::string::npos) { if (name.find("cell_state_1") != std::string::npos) {
@ -421,7 +420,6 @@ void MultipleLSTMCellTest::InitMemory() {
GTEST_FAIL() << "unknown memory state"; GTEST_FAIL() << "unknown memory state";
} }
} }
IE_SUPPRESS_DEPRECATED_END
} }
void MultipleLSTMCellTest::ApplyLowLatency() { void MultipleLSTMCellTest::ApplyLowLatency() {
@ -459,6 +457,7 @@ void MultipleLSTMCellTest::ApplyLowLatency() {
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) { } else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) {
cnnNetwork = InferenceEngine::CNNNetwork{function}; cnnNetwork = InferenceEngine::CNNNetwork{function};
InferenceEngine::lowLatency2(cnnNetwork); InferenceEngine::lowLatency2(cnnNetwork);
@ -468,6 +467,7 @@ void MultipleLSTMCellTest::ApplyLowLatency() {
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
} }

View File

@ -72,8 +72,7 @@ namespace SubgraphTestsDefinitions {
SKIP_IF_CURRENT_TEST_IS_DISABLED() SKIP_IF_CURRENT_TEST_IS_DISABLED()
LoadNetwork(); LoadNetwork();
IE_SUPPRESS_DEPRECATED_START auto states = inferRequest.QueryState();
auto states = executableNetwork.QueryState();
for (auto& state : states) { for (auto& state : states) {
auto name = state.GetName(); auto name = state.GetName();
if (name == "memory") { if (name == "memory") {
@ -84,7 +83,6 @@ namespace SubgraphTestsDefinitions {
GTEST_FAIL() << "unknown memory state"; GTEST_FAIL() << "unknown memory state";
} }
} }
IE_SUPPRESS_DEPRECATED_END
GenerateInputs(); GenerateInputs();
Infer(); Infer();
switchToNgraphFriendlyModel(); switchToNgraphFriendlyModel();

View File

@ -24,7 +24,6 @@ public:
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr(void)); MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr(void));
MOCK_METHOD1(Export, void(const std::string &)); MOCK_METHOD1(Export, void(const std::string &));
void Export(std::ostream &) override {}; void Export(std::ostream &) override {};
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>(void));
MOCK_METHOD0(GetExecGraphInfo, std::shared_ptr<ngraph::Function>(void)); MOCK_METHOD0(GetExecGraphInfo, std::shared_ptr<ngraph::Function>(void));
MOCK_METHOD1(SetConfig, void(const std::map<std::string, Parameter> &config)); MOCK_METHOD1(SetConfig, void(const std::map<std::string, Parameter> &config));

View File

@ -43,162 +43,6 @@ class InferRequestVariableStateTests : public ::testing::Test {
} }
}; };
TEST_F(InferRequestVariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToAPI) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn(1);
toReturn[0] = mockVariableStateInternal;
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
auto state = net->QueryState();
ASSERT_EQ(state.size(), 1);
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppToAPI) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
auto state = net->QueryState();
ASSERT_EQ(state.size(), 0);
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAPI) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
auto state = net->QueryState();
ASSERT_EQ(state.size(), 2);
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesReset) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
auto state = net->QueryState();
state.front()->Reset();
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesExceptionsFromReset) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
auto state = net->QueryState();
EXPECT_ANY_THROW(state.front()->Reset());
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetName) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
auto state = net->QueryState();
EXPECT_STREQ(state.front()->GetName().c_str(), "someName");
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
auto pState = net->QueryState().front();
EXPECT_NO_THROW(pState->GetName());
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
auto pState = net->QueryState().front();
std::string name;
EXPECT_NO_THROW(name = pState->GetName());
EXPECT_EQ(name, "someName");
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
auto pState = net->QueryState().front();
std::string name;
EXPECT_NO_THROW(name = pState->GetName());
EXPECT_EQ(name, "someName");
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, VariableStateCanPropagateSetState) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
Blob::Ptr saver;
toReturn.push_back(mockVariableStateInternal);
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
float data[] = {123, 124, 125};
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
EXPECT_NO_THROW(net->QueryState().front()->SetState(stateBlob));
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
IE_SUPPRESS_DEPRECATED_END
}
TEST_F(InferRequestVariableStateTests, VariableStateCanPropagateGetLastState) {
IE_SUPPRESS_DEPRECATED_START
std::vector<IVariableStateInternal::Ptr> toReturn;
toReturn.push_back(mockVariableStateInternal);
float data[] = {123, 124, 125};
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob));
auto saver = net->QueryState().front()->GetState();
ASSERT_NE(saver, nullptr);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
IE_SUPPRESS_DEPRECATED_END
}
class VariableStateInternalMockImpl : public IVariableStateInternal { class VariableStateInternalMockImpl : public IVariableStateInternal {
public: public:
VariableStateInternalMockImpl(const char* name) : IVariableStateInternal(name) {} VariableStateInternalMockImpl(const char* name) : IVariableStateInternal(name) {}

View File

@ -88,26 +88,6 @@ TEST_F(ExecutableNetworkTests, GetInputsInfo) {
ASSERT_EQ(info, InferenceEngine::ConstInputsDataMap{}); ASSERT_EQ(info, InferenceEngine::ConstInputsDataMap{});
} }
IE_SUPPRESS_DEPRECATED_START
TEST_F(ExecutableNetworkTests, QueryStateThrowsIfReturnErr) {
EXPECT_CALL(*mockIExeNet.get(), QueryState())
.Times(1)
.WillOnce(Throw(InferenceEngine::GeneralError{""}));
EXPECT_THROW(exeNetwork->QueryState(), InferenceEngine::Exception);
}
TEST_F(ExecutableNetworkTests, QueryState) {
auto mockIMemState_p = std::make_shared<MockIVariableStateInternal>();
EXPECT_CALL(*mockIExeNet.get(), QueryState())
.Times(1)
.WillOnce(Return(std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(1, mockIMemState_p)));
std::vector<InferenceEngine::IVariableStateInternal::Ptr> MemState_v;
EXPECT_NO_THROW(MemState_v = exeNetwork->QueryState());
EXPECT_EQ(MemState_v.size(), 1);
}
IE_SUPPRESS_DEPRECATED_END
class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests { class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests {
protected: protected:
@ -135,6 +115,29 @@ TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestThrowsIfReturnNotO
ASSERT_THROW(exeNetwork->CreateInferRequest(), InferenceEngine::Exception); ASSERT_THROW(exeNetwork->CreateInferRequest(), InferenceEngine::Exception);
} }
TEST_F(ExecutableNetworkWithIInferReqTests, QueryStateThrowsIfReturnErr) {
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
IInferRequestInternal::Ptr actualInferReq;
ASSERT_NO_THROW(actualInferReq = exeNetwork->CreateInferRequest());
EXPECT_CALL(*mockIInferReq_p.get(), QueryState())
.Times(1)
.WillOnce(Throw(InferenceEngine::GeneralError{""}));
EXPECT_THROW(actualInferReq->QueryState(), InferenceEngine::Exception);
}
TEST_F(ExecutableNetworkWithIInferReqTests, QueryState) {
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
IInferRequestInternal::Ptr actualInferReq;
ASSERT_NO_THROW(actualInferReq = exeNetwork->CreateInferRequest());
auto mockIMemState_p = std::make_shared<MockIVariableStateInternal>();
EXPECT_CALL(*mockIInferReq_p.get(), QueryState())
.Times(1)
.WillOnce(Return(std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(1, mockIMemState_p)));
std::vector<InferenceEngine::IVariableStateInternal::Ptr> MemState_v;
EXPECT_NO_THROW(MemState_v = actualInferReq->QueryState());
EXPECT_EQ(MemState_v.size(), 1);
}
IE_SUPPRESS_DEPRECATED_START IE_SUPPRESS_DEPRECATED_START
class ExecutableNetworkBaseTests : public ::testing::Test { class ExecutableNetworkBaseTests : public ::testing::Test {

View File

@ -798,30 +798,6 @@ void GNAQueryStateMatcher :: match() {
EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess)); EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
#endif #endif
IE_SUPPRESS_DEPRECATED_START
try {
loadNetwork();
if (GnaPluginTestEnvironment::kAnyNotNull == _env.numberOfStates) {
auto states = executer->QueryState();
ASSERT_NE(states.size(), 0);
// usually states are callable
for (auto & state : states) {
state->Reset();
}
} else if (_env.numberOfStates >= 0) {
ASSERT_EQ(executer->QueryState().size(), _env.numberOfStates);
} else {
FAIL() << "number of memory states expectation not set";
}
}
catch(std::exception &ex) {
FAIL() << ex.what();
}
catch(...) {
FAIL() << "unknown exception thrown";
}
IE_SUPPRESS_DEPRECATED_END
} }

View File

@ -1,25 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <gtest/gtest.h>
#include "gna_matcher.hpp"
class QueryStateTest : public GNATest<> {
protected:
void SetUp() override {
}
};
using namespace GNATestIRs;
// Recursive Algorithm
// Precision Threshold
TEST_F(QueryStateTest, returnEmptyCollectionOfStatesIfNoMemoryInIR) {
assert_that().afterLoadingModel(TanhActivationModel()).withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f).queryState().isEmpty();
}
TEST_F(QueryStateTest, returnNonEmptyCollectionOfStatesForMemoryIR) {
assert_that().afterLoadingModel(affineToMemoryModel()).withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f).queryState().isNotEmpty();
}