Remove ExecutableNetwork::QueryState (#8034)
* removed ExecutableNetwork::QueryState from code * removed ExecutableNetwork::QueryStates from tests (not checked) * buildable version * remove unneeded change and fix cpplint error * remove extra space * remove QueryState from GNAExecutableNetwork * clean up GNA tests for QueryState in tests_deprecated (without replacement because deprecated) * fix tests after merge * remove tests again after merge * fixed tests with _REGULAR_API suffix
This commit is contained in:
parent
3354275da1
commit
e6884c3fd7
@ -67,13 +67,6 @@ class GNAExecutableNetwork : public InferenceEngine::IExecutableNetworkInternal
|
||||
return std::make_shared<GNAInferRequest>(plg, inputs, outputs);
|
||||
}
|
||||
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
return plg->QueryState();
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void Export(const std::string &modelFileName) override {
|
||||
plg->Export(modelFileName);
|
||||
}
|
||||
|
@ -190,18 +190,6 @@ public:
|
||||
*/
|
||||
INFERENCE_ENGINE_DEPRECATED("Use ExecutableNetwork::CreateInferRequest instead")
|
||||
InferRequest::Ptr CreateInferRequestPtr();
|
||||
|
||||
/**
|
||||
* @deprecated Use InferRequest::QueryState instead
|
||||
* @brief Gets state control interface for given executable network.
|
||||
*
|
||||
* State control essential for recurrent networks
|
||||
*
|
||||
* @return A vector of Memory State objects
|
||||
*/
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
||||
std::vector<VariableState> QueryState();
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
|
@ -65,16 +65,6 @@ ExecutableNetwork::operator IExecutableNetwork::Ptr() {
|
||||
return std::make_shared<ExecutableNetworkBase>(_impl);
|
||||
}
|
||||
|
||||
std::vector<VariableState> ExecutableNetwork::QueryState() {
|
||||
std::vector<VariableState> controller;
|
||||
EXEC_NET_CALL_STATEMENT({
|
||||
for (auto&& state : _impl->QueryState()) {
|
||||
controller.emplace_back(VariableState{_so, state});
|
||||
}
|
||||
});
|
||||
return controller;
|
||||
}
|
||||
|
||||
InferRequest ExecutableNetwork::CreateInferRequest() {
|
||||
EXEC_NET_CALL_STATEMENT(return {_so, _impl->CreateInferRequest()});
|
||||
}
|
||||
|
@ -87,10 +87,6 @@ std::shared_ptr<ngraph::Function> IExecutableNetworkInternal::GetExecGraphInfo()
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<IVariableStateInternal>> IExecutableNetworkInternal::QueryState() {
|
||||
IE_THROW(NotImplemented);
|
||||
}
|
||||
|
||||
void IExecutableNetworkInternal::SetPointerToPlugin(const std::shared_ptr<IInferencePlugin>& plugin) {
|
||||
_plugin = plugin;
|
||||
}
|
||||
|
@ -283,12 +283,6 @@ bool MKLDNNExecNetwork::CanProcessDynBatch(const InferenceEngine::CNNNetwork &ne
|
||||
return true;
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> MKLDNNExecNetwork::QueryState() {
|
||||
return memoryStates;
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
void MKLDNNExecNetwork::Export(std::ostream& modelStream) {
|
||||
CNNNetworkSerializer serializer(modelStream, extensionManager);
|
||||
serializer <<_network;
|
||||
|
@ -44,9 +44,6 @@ public:
|
||||
|
||||
std::shared_ptr<ngraph::Function> GetExecGraphInfo() override;
|
||||
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
|
||||
|
||||
void Export(std::ostream& modelStream) override;
|
||||
|
||||
protected:
|
||||
|
@ -60,26 +60,20 @@ void MKLDNNPlugin::MKLDNNInferRequest::CreateInferRequest() {
|
||||
// Save all MemoryLayer data tensors. Will use insight about mechanics
|
||||
// of MemoryLayer implementation. It uses output edge of MemoryLayer
|
||||
// producer as storage for tensor to keep it between infer calls.
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
if (execNetwork->_numRequests > 1 || execNetwork->QueryState().size() == 0) {
|
||||
for (auto &node : graph->GetNodes()) {
|
||||
if (node->getType() == MemoryInput) {
|
||||
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
|
||||
auto state_store = memoryNode->getStore();
|
||||
auto state_name = memoryNode->getId();
|
||||
for (auto& node : graph->GetNodes()) {
|
||||
if (node->getType() == MemoryInput) {
|
||||
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
|
||||
auto state_store = memoryNode->getStore();
|
||||
auto state_name = memoryNode->getId();
|
||||
|
||||
// Remove suffix with pair ID. Internal information.
|
||||
auto suffix_idx = state_name.find("/id=");
|
||||
if (suffix_idx != std::string::npos)
|
||||
state_name = state_name.substr(0, suffix_idx);
|
||||
// Remove suffix with pair ID. Internal information.
|
||||
auto suffix_idx = state_name.find("/id=");
|
||||
if (suffix_idx != std::string::npos)
|
||||
state_name = state_name.substr(0, suffix_idx);
|
||||
|
||||
memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
|
||||
}
|
||||
memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
|
||||
}
|
||||
} else {
|
||||
memoryStates = execNetwork->QueryState();
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
MKLDNNPlugin::MKLDNNInferRequest::~MKLDNNInferRequest() {
|
||||
|
@ -115,13 +115,6 @@ public:
|
||||
*/
|
||||
virtual std::shared_ptr<ngraph::Function> GetExecGraphInfo();
|
||||
|
||||
/**
|
||||
* @deprecated Need to implement GetVariablesInfo for ExecutableNetwork
|
||||
* @brief Queries memory states.
|
||||
* @return Returns memory states
|
||||
*/
|
||||
virtual std::vector<std::shared_ptr<IVariableStateInternal>> QueryState();
|
||||
|
||||
/**
|
||||
* @brief Sets the pointer to plugin internal.
|
||||
* @param[in] plugin The plugin
|
||||
|
@ -35,13 +35,6 @@ TEST(ExecutableNetworkTests, throwsOnUninitializedGetExecGraphInfo) {
|
||||
ASSERT_THROW(exec.GetExecGraphInfo(), InferenceEngine::NotAllocated);
|
||||
}
|
||||
|
||||
TEST(ExecutableNetworkTests, throwsOnUninitializedQueryState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
ExecutableNetwork exec;
|
||||
ASSERT_THROW(exec.QueryState(), InferenceEngine::NotAllocated);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST(ExecutableNetworkTests, throwsOnUninitializedSetConfig) {
|
||||
ExecutableNetwork exec;
|
||||
ASSERT_THROW(exec.SetConfig({{}}), InferenceEngine::NotAllocated);
|
||||
|
@ -65,16 +65,14 @@ public:
|
||||
}
|
||||
auto importedNetwork = core->ImportNetwork(inputStream, targetDevice, configuration);
|
||||
std::vector<std::string> queryToState;
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
for (const auto &query_state : executableNetwork.QueryState()) {
|
||||
InferenceEngine::InferRequest importInfer = importedNetwork.CreateInferRequest();
|
||||
for (const auto &query_state : importInfer.QueryState()) {
|
||||
queryToState.push_back(query_state.GetName());
|
||||
}
|
||||
for (const auto &next_memory : importedNetwork.QueryState()) {
|
||||
for (const auto &next_memory : importInfer.QueryState()) {
|
||||
ASSERT_TRUE(std::find(queryToState.begin(), queryToState.end(), next_memory.GetName()) != queryToState.end())
|
||||
<< "State " << next_memory.GetName() << " expected to be in memory states but it is not!";
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
InferenceEngine::InferRequest importInfer = importedNetwork.CreateInferRequest();
|
||||
importInfer.Infer();
|
||||
}
|
||||
|
||||
|
@ -42,8 +42,7 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
|
||||
manager.register_pass<ngraph::pass::LowLatency2>(); // LowLatency enables UnrollTI
|
||||
manager.run_passes(function);
|
||||
LoadNetwork();
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto states = executableNetwork.QueryState();
|
||||
auto states = inferRequest.QueryState();
|
||||
for (auto& state : states) {
|
||||
auto name = state.GetName();
|
||||
if (name.find("cell_state_1") != std::string::npos) {
|
||||
@ -58,7 +57,6 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
|
||||
GTEST_FAIL() << "unknown memory state";
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
// Run and compare
|
||||
Infer();
|
||||
const auto& actualOutputs = GetOutputs();
|
||||
|
@ -31,98 +31,6 @@ InferenceEngine::ExecutableNetwork InferRequestVariableStateTest::PrepareNetwork
|
||||
return ie->LoadNetwork(net, deviceName);
|
||||
}
|
||||
|
||||
TEST_P(InferRequestVariableStateTest, smoke_VariableState_QueryState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto executableNet = PrepareNetwork();
|
||||
|
||||
auto states = executableNet.QueryState();
|
||||
ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates";
|
||||
|
||||
for (auto &&state : states) {
|
||||
auto name = state.GetName();
|
||||
ASSERT_TRUE(std::find(statesToQuery.begin(), statesToQuery.end(), name) != statesToQuery.end())
|
||||
<< "State " << name << "expected to be in memory states but it is not!";
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_P(InferRequestVariableStateTest, smoke_VariableState_SetState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto executableNet = PrepareNetwork();
|
||||
const float new_state_val = 13.0f;
|
||||
for (auto &&state : executableNet.QueryState()) {
|
||||
state.Reset();
|
||||
auto state_val = state.GetState();
|
||||
auto element_count = state_val->size();
|
||||
|
||||
float *new_state_data = new float[element_count];
|
||||
for (int i = 0; i < element_count; i++) {
|
||||
new_state_data[i] = new_state_val;
|
||||
}
|
||||
auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
|
||||
stateBlob->allocate();
|
||||
std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
|
||||
delete[]new_state_data;
|
||||
state.SetState(stateBlob);
|
||||
}
|
||||
|
||||
for (auto &&state : executableNet.QueryState()) {
|
||||
auto lastState = state.GetState();
|
||||
auto last_state_size = lastState->size();
|
||||
auto last_state_data = lastState->cbuffer().as<float *>();
|
||||
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
|
||||
|
||||
for (int i = 0; i < last_state_size; i++) {
|
||||
EXPECT_NEAR(new_state_val, last_state_data[i], 1e-5);
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_P(InferRequestVariableStateTest, smoke_VariableState_Reset) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto executableNet = PrepareNetwork();
|
||||
const float new_state_val = 13.0f;
|
||||
for (auto &&state : executableNet.QueryState()) {
|
||||
state.Reset();
|
||||
auto state_val = state.GetState();
|
||||
auto element_count = state_val->size();
|
||||
|
||||
float *new_state_data = new float[element_count];
|
||||
for (int i = 0; i < element_count; i++) {
|
||||
new_state_data[i] = new_state_val;
|
||||
}
|
||||
auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
|
||||
stateBlob->allocate();
|
||||
std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
|
||||
delete[]new_state_data;
|
||||
|
||||
state.SetState(stateBlob);
|
||||
}
|
||||
|
||||
executableNet.QueryState().front().Reset();
|
||||
|
||||
auto states = executableNet.QueryState();
|
||||
for (int i = 0; i < states.size(); ++i) {
|
||||
auto lastState = states[i].GetState();
|
||||
auto last_state_size = lastState->size();
|
||||
auto last_state_data = lastState->cbuffer().as<float *>();
|
||||
|
||||
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
|
||||
|
||||
if (i == 0) {
|
||||
for (int j = 0; j < last_state_size; ++j) {
|
||||
EXPECT_NEAR(0, last_state_data[j], 1e-5);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < last_state_size; ++j) {
|
||||
EXPECT_NEAR(new_state_val, last_state_data[j], 1e-5);
|
||||
}
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_QueryState) {
|
||||
auto executableNet = PrepareNetwork();
|
||||
auto inferReq = executableNet.CreateInferRequest();
|
||||
|
@ -57,6 +57,7 @@ void DetectNetworkBatch::LoadNetwork() {
|
||||
functionRefs = ngraph::clone_function(*cnnNetwork.getFunction());
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
TEST_P(DetectNetworkBatch, InferWithOneInput) {
|
||||
|
@ -54,10 +54,10 @@ namespace ConfigurationTestsDefinitions {
|
||||
ConfigureNetwork();
|
||||
cnnNetwork.setBatchSize(max_batch_size);
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void DynamicBatchTest::Infer() {
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
inputs.clear();
|
||||
|
||||
for (int i = 0; i < batch_sizes.size(); i++) {
|
||||
|
@ -344,6 +344,7 @@ void LayerTestsCommon::LoadNetwork() {
|
||||
CoreConfiguration(this);
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void LayerTestsCommon::GenerateInputs() {
|
||||
@ -361,8 +362,6 @@ void LayerTestsCommon::GenerateInputs() {
|
||||
}
|
||||
|
||||
void LayerTestsCommon::Infer() {
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
|
||||
const auto& inputsInfo = executableNetwork.GetInputsInfo();
|
||||
const auto& functionParams = function->get_parameters();
|
||||
for (int i = 0; i < functionParams.size(); ++i) {
|
||||
|
@ -81,6 +81,7 @@ namespace LayerTestsDefinitions {
|
||||
CoreConfiguration(this);
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
GenerateInputs();
|
||||
for (int64_t i = 0; i < iteration_count; ++i) {
|
||||
|
@ -103,8 +103,8 @@ namespace SubgraphTestsDefinitions {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
LoadNetwork();
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto states = executableNetwork.QueryState();
|
||||
|
||||
auto states = inferRequest.QueryState();
|
||||
for (auto& state : states) {
|
||||
auto name = state.GetName();
|
||||
if (name == "memory_1") {
|
||||
@ -119,7 +119,6 @@ namespace SubgraphTestsDefinitions {
|
||||
GTEST_FAIL() << "unknown memory state";
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
GenerateInputs();
|
||||
Infer();
|
||||
switchToNgraphFriendlyModel();
|
||||
|
@ -6,8 +6,7 @@
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
void DelayedCopyTestBase::InitMemory() {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto states = executableNetwork.QueryState();
|
||||
auto states = inferRequest.QueryState();
|
||||
for (auto& state : states) {
|
||||
auto name = state.GetName();
|
||||
if (name.find("id") != std::string::npos) {
|
||||
@ -18,7 +17,6 @@ namespace SubgraphTestsDefinitions {
|
||||
GTEST_FAIL() << "unknown memory state";
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void DelayedCopyTestBase::Run() {
|
||||
|
@ -276,8 +276,7 @@ namespace SubgraphTestsDefinitions {
|
||||
}
|
||||
|
||||
void MemoryLSTMCellTest::InitMemory() {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto states = executableNetwork.QueryState();
|
||||
auto states = inferRequest.QueryState();
|
||||
for (auto& state : states) {
|
||||
auto name = state.GetName();
|
||||
if (name.find("cell_state_1") != std::string::npos) {
|
||||
@ -292,7 +291,6 @@ namespace SubgraphTestsDefinitions {
|
||||
GTEST_FAIL() << "unknown memory state";
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void MemoryLSTMCellTest::ApplyLowLatency() {
|
||||
@ -330,6 +328,7 @@ namespace SubgraphTestsDefinitions {
|
||||
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
} else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) {
|
||||
cnnNetwork = InferenceEngine::CNNNetwork{function};
|
||||
InferenceEngine::lowLatency2(cnnNetwork);
|
||||
@ -339,6 +338,7 @@ namespace SubgraphTestsDefinitions {
|
||||
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
}
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
@ -122,7 +122,7 @@ void MemoryEltwiseReshapeConcatTest::Run() {
|
||||
InferenceEngine::Layout::NC);
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto states = executableNetwork.QueryState();
|
||||
auto states = inferRequest.QueryState();
|
||||
auto state_values_blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description,
|
||||
memory_init.data(), memory_init.size());
|
||||
states[0].SetState(state_values_blob);
|
||||
|
@ -397,8 +397,7 @@ void MultipleLSTMCellTest::InitMemory() {
|
||||
InferenceEngine::TensorDesc state_description(InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::SizeVector({1, hiddenSize}),
|
||||
InferenceEngine::Layout::NC);
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto states = executableNetwork.QueryState();
|
||||
auto states = inferRequest.QueryState();
|
||||
for (auto& state : states) {
|
||||
auto name = state.GetName();
|
||||
if (name.find("cell_state_1") != std::string::npos) {
|
||||
@ -421,7 +420,6 @@ void MultipleLSTMCellTest::InitMemory() {
|
||||
GTEST_FAIL() << "unknown memory state";
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void MultipleLSTMCellTest::ApplyLowLatency() {
|
||||
@ -459,6 +457,7 @@ void MultipleLSTMCellTest::ApplyLowLatency() {
|
||||
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
} else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) {
|
||||
cnnNetwork = InferenceEngine::CNNNetwork{function};
|
||||
InferenceEngine::lowLatency2(cnnNetwork);
|
||||
@ -468,6 +467,7 @@ void MultipleLSTMCellTest::ApplyLowLatency() {
|
||||
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,8 +72,7 @@ namespace SubgraphTestsDefinitions {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
LoadNetwork();
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
auto states = executableNetwork.QueryState();
|
||||
auto states = inferRequest.QueryState();
|
||||
for (auto& state : states) {
|
||||
auto name = state.GetName();
|
||||
if (name == "memory") {
|
||||
@ -84,7 +83,6 @@ namespace SubgraphTestsDefinitions {
|
||||
GTEST_FAIL() << "unknown memory state";
|
||||
}
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
GenerateInputs();
|
||||
Infer();
|
||||
switchToNgraphFriendlyModel();
|
||||
|
@ -24,7 +24,6 @@ public:
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr(void));
|
||||
MOCK_METHOD1(Export, void(const std::string &));
|
||||
void Export(std::ostream &) override {};
|
||||
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>(void));
|
||||
MOCK_METHOD0(GetExecGraphInfo, std::shared_ptr<ngraph::Function>(void));
|
||||
|
||||
MOCK_METHOD1(SetConfig, void(const std::map<std::string, Parameter> &config));
|
||||
|
@ -43,162 +43,6 @@ class InferRequestVariableStateTests : public ::testing::Test {
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToAPI) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn(1);
|
||||
toReturn[0] = mockVariableStateInternal;
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
|
||||
auto state = net->QueryState();
|
||||
ASSERT_EQ(state.size(), 1);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppToAPI) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
|
||||
|
||||
auto state = net->QueryState();
|
||||
ASSERT_EQ(state.size(), 0);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAPI) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
|
||||
auto state = net->QueryState();
|
||||
ASSERT_EQ(state.size(), 2);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesReset) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
|
||||
|
||||
auto state = net->QueryState();
|
||||
state.front()->Reset();
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesExceptionsFromReset) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
|
||||
|
||||
auto state = net->QueryState();
|
||||
EXPECT_ANY_THROW(state.front()->Reset());
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetName) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
auto state = net->QueryState();
|
||||
EXPECT_STREQ(state.front()->GetName().c_str(), "someName");
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
auto pState = net->QueryState().front();
|
||||
EXPECT_NO_THROW(pState->GetName());
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
auto pState = net->QueryState().front();
|
||||
std::string name;
|
||||
EXPECT_NO_THROW(name = pState->GetName());
|
||||
EXPECT_EQ(name, "someName");
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
||||
|
||||
auto pState = net->QueryState().front();
|
||||
std::string name;
|
||||
EXPECT_NO_THROW(name = pState->GetName());
|
||||
EXPECT_EQ(name, "someName");
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, VariableStateCanPropagateSetState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
Blob::Ptr saver;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
|
||||
|
||||
float data[] = {123, 124, 125};
|
||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
||||
|
||||
EXPECT_NO_THROW(net->QueryState().front()->SetState(stateBlob));
|
||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
|
||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
|
||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
TEST_F(InferRequestVariableStateTests, VariableStateCanPropagateGetLastState) {
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
||||
toReturn.push_back(mockVariableStateInternal);
|
||||
|
||||
float data[] = {123, 124, 125};
|
||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
||||
|
||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
|
||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob));
|
||||
|
||||
auto saver = net->QueryState().front()->GetState();
|
||||
ASSERT_NE(saver, nullptr);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
|
||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
class VariableStateInternalMockImpl : public IVariableStateInternal {
|
||||
public:
|
||||
VariableStateInternalMockImpl(const char* name) : IVariableStateInternal(name) {}
|
||||
|
@ -88,26 +88,6 @@ TEST_F(ExecutableNetworkTests, GetInputsInfo) {
|
||||
ASSERT_EQ(info, InferenceEngine::ConstInputsDataMap{});
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
TEST_F(ExecutableNetworkTests, QueryStateThrowsIfReturnErr) {
|
||||
EXPECT_CALL(*mockIExeNet.get(), QueryState())
|
||||
.Times(1)
|
||||
.WillOnce(Throw(InferenceEngine::GeneralError{""}));
|
||||
EXPECT_THROW(exeNetwork->QueryState(), InferenceEngine::Exception);
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkTests, QueryState) {
|
||||
auto mockIMemState_p = std::make_shared<MockIVariableStateInternal>();
|
||||
EXPECT_CALL(*mockIExeNet.get(), QueryState())
|
||||
.Times(1)
|
||||
.WillOnce(Return(std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(1, mockIMemState_p)));
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> MemState_v;
|
||||
EXPECT_NO_THROW(MemState_v = exeNetwork->QueryState());
|
||||
EXPECT_EQ(MemState_v.size(), 1);
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests {
|
||||
protected:
|
||||
@ -135,6 +115,29 @@ TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestThrowsIfReturnNotO
|
||||
ASSERT_THROW(exeNetwork->CreateInferRequest(), InferenceEngine::Exception);
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkWithIInferReqTests, QueryStateThrowsIfReturnErr) {
|
||||
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
|
||||
IInferRequestInternal::Ptr actualInferReq;
|
||||
ASSERT_NO_THROW(actualInferReq = exeNetwork->CreateInferRequest());
|
||||
EXPECT_CALL(*mockIInferReq_p.get(), QueryState())
|
||||
.Times(1)
|
||||
.WillOnce(Throw(InferenceEngine::GeneralError{""}));
|
||||
EXPECT_THROW(actualInferReq->QueryState(), InferenceEngine::Exception);
|
||||
}
|
||||
|
||||
TEST_F(ExecutableNetworkWithIInferReqTests, QueryState) {
|
||||
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
|
||||
IInferRequestInternal::Ptr actualInferReq;
|
||||
ASSERT_NO_THROW(actualInferReq = exeNetwork->CreateInferRequest());
|
||||
auto mockIMemState_p = std::make_shared<MockIVariableStateInternal>();
|
||||
EXPECT_CALL(*mockIInferReq_p.get(), QueryState())
|
||||
.Times(1)
|
||||
.WillOnce(Return(std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(1, mockIMemState_p)));
|
||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> MemState_v;
|
||||
EXPECT_NO_THROW(MemState_v = actualInferReq->QueryState());
|
||||
EXPECT_EQ(MemState_v.size(), 1);
|
||||
}
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
class ExecutableNetworkBaseTests : public ::testing::Test {
|
||||
|
@ -798,30 +798,6 @@ void GNAQueryStateMatcher :: match() {
|
||||
|
||||
EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
|
||||
#endif
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
try {
|
||||
loadNetwork();
|
||||
if (GnaPluginTestEnvironment::kAnyNotNull == _env.numberOfStates) {
|
||||
auto states = executer->QueryState();
|
||||
ASSERT_NE(states.size(), 0);
|
||||
// usually states are callable
|
||||
for (auto & state : states) {
|
||||
state->Reset();
|
||||
}
|
||||
} else if (_env.numberOfStates >= 0) {
|
||||
ASSERT_EQ(executer->QueryState().size(), _env.numberOfStates);
|
||||
} else {
|
||||
FAIL() << "number of memory states expectation not set";
|
||||
}
|
||||
|
||||
}
|
||||
catch(std::exception &ex) {
|
||||
FAIL() << ex.what();
|
||||
}
|
||||
catch(...) {
|
||||
FAIL() << "unknown exception thrown";
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,25 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
#include "gna_matcher.hpp"
|
||||
|
||||
class QueryStateTest : public GNATest<> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
}
|
||||
};
|
||||
using namespace GNATestIRs;
|
||||
|
||||
// Recursive Algorithm
|
||||
// Precision Threshold
|
||||
|
||||
TEST_F(QueryStateTest, returnEmptyCollectionOfStatesIfNoMemoryInIR) {
|
||||
assert_that().afterLoadingModel(TanhActivationModel()).withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f).queryState().isEmpty();
|
||||
}
|
||||
|
||||
TEST_F(QueryStateTest, returnNonEmptyCollectionOfStatesForMemoryIR) {
|
||||
assert_that().afterLoadingModel(affineToMemoryModel()).withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f).queryState().isNotEmpty();
|
||||
}
|
Loading…
Reference in New Issue
Block a user