Remove ExecutableNetwork::QueryState (#8034)
* removed ExecutableNetwork::QueryState from code * removed ExecutableNetwork::QueryStates from tests (not checked) * buildable version * remove unneeded change and fix cpplint error * remove extra space * remove QueryState from GNAExecutableNetwork * clean up GNA tests for QueryState in tests_deprecated (without replacement because deprecated) * fix tests after merge * remove tests again after merge * fixed tests with _REGULAR_API suffix
This commit is contained in:
parent
3354275da1
commit
e6884c3fd7
@ -67,13 +67,6 @@ class GNAExecutableNetwork : public InferenceEngine::IExecutableNetworkInternal
|
|||||||
return std::make_shared<GNAInferRequest>(plg, inputs, outputs);
|
return std::make_shared<GNAInferRequest>(plg, inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
|
||||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
return plg->QueryState();
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
void Export(const std::string &modelFileName) override {
|
void Export(const std::string &modelFileName) override {
|
||||||
plg->Export(modelFileName);
|
plg->Export(modelFileName);
|
||||||
}
|
}
|
||||||
|
@ -190,18 +190,6 @@ public:
|
|||||||
*/
|
*/
|
||||||
INFERENCE_ENGINE_DEPRECATED("Use ExecutableNetwork::CreateInferRequest instead")
|
INFERENCE_ENGINE_DEPRECATED("Use ExecutableNetwork::CreateInferRequest instead")
|
||||||
InferRequest::Ptr CreateInferRequestPtr();
|
InferRequest::Ptr CreateInferRequestPtr();
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated Use InferRequest::QueryState instead
|
|
||||||
* @brief Gets state control interface for given executable network.
|
|
||||||
*
|
|
||||||
* State control essential for recurrent networks
|
|
||||||
*
|
|
||||||
* @return A vector of Memory State objects
|
|
||||||
*/
|
|
||||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
|
||||||
std::vector<VariableState> QueryState();
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace InferenceEngine
|
} // namespace InferenceEngine
|
||||||
|
@ -65,16 +65,6 @@ ExecutableNetwork::operator IExecutableNetwork::Ptr() {
|
|||||||
return std::make_shared<ExecutableNetworkBase>(_impl);
|
return std::make_shared<ExecutableNetworkBase>(_impl);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<VariableState> ExecutableNetwork::QueryState() {
|
|
||||||
std::vector<VariableState> controller;
|
|
||||||
EXEC_NET_CALL_STATEMENT({
|
|
||||||
for (auto&& state : _impl->QueryState()) {
|
|
||||||
controller.emplace_back(VariableState{_so, state});
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return controller;
|
|
||||||
}
|
|
||||||
|
|
||||||
InferRequest ExecutableNetwork::CreateInferRequest() {
|
InferRequest ExecutableNetwork::CreateInferRequest() {
|
||||||
EXEC_NET_CALL_STATEMENT(return {_so, _impl->CreateInferRequest()});
|
EXEC_NET_CALL_STATEMENT(return {_so, _impl->CreateInferRequest()});
|
||||||
}
|
}
|
||||||
|
@ -87,10 +87,6 @@ std::shared_ptr<ngraph::Function> IExecutableNetworkInternal::GetExecGraphInfo()
|
|||||||
IE_THROW(NotImplemented);
|
IE_THROW(NotImplemented);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<IVariableStateInternal>> IExecutableNetworkInternal::QueryState() {
|
|
||||||
IE_THROW(NotImplemented);
|
|
||||||
}
|
|
||||||
|
|
||||||
void IExecutableNetworkInternal::SetPointerToPlugin(const std::shared_ptr<IInferencePlugin>& plugin) {
|
void IExecutableNetworkInternal::SetPointerToPlugin(const std::shared_ptr<IInferencePlugin>& plugin) {
|
||||||
_plugin = plugin;
|
_plugin = plugin;
|
||||||
}
|
}
|
||||||
|
@ -283,12 +283,6 @@ bool MKLDNNExecNetwork::CanProcessDynBatch(const InferenceEngine::CNNNetwork &ne
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> MKLDNNExecNetwork::QueryState() {
|
|
||||||
return memoryStates;
|
|
||||||
}
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
|
|
||||||
void MKLDNNExecNetwork::Export(std::ostream& modelStream) {
|
void MKLDNNExecNetwork::Export(std::ostream& modelStream) {
|
||||||
CNNNetworkSerializer serializer(modelStream, extensionManager);
|
CNNNetworkSerializer serializer(modelStream, extensionManager);
|
||||||
serializer <<_network;
|
serializer <<_network;
|
||||||
|
@ -44,9 +44,6 @@ public:
|
|||||||
|
|
||||||
std::shared_ptr<ngraph::Function> GetExecGraphInfo() override;
|
std::shared_ptr<ngraph::Function> GetExecGraphInfo() override;
|
||||||
|
|
||||||
INFERENCE_ENGINE_DEPRECATED("Use InferRequest::QueryState instead")
|
|
||||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override;
|
|
||||||
|
|
||||||
void Export(std::ostream& modelStream) override;
|
void Export(std::ostream& modelStream) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -60,26 +60,20 @@ void MKLDNNPlugin::MKLDNNInferRequest::CreateInferRequest() {
|
|||||||
// Save all MemoryLayer data tensors. Will use insight about mechanics
|
// Save all MemoryLayer data tensors. Will use insight about mechanics
|
||||||
// of MemoryLayer implementation. It uses output edge of MemoryLayer
|
// of MemoryLayer implementation. It uses output edge of MemoryLayer
|
||||||
// producer as storage for tensor to keep it between infer calls.
|
// producer as storage for tensor to keep it between infer calls.
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
for (auto& node : graph->GetNodes()) {
|
||||||
if (execNetwork->_numRequests > 1 || execNetwork->QueryState().size() == 0) {
|
if (node->getType() == MemoryInput) {
|
||||||
for (auto &node : graph->GetNodes()) {
|
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
|
||||||
if (node->getType() == MemoryInput) {
|
auto state_store = memoryNode->getStore();
|
||||||
auto memoryNode = dynamic_cast<MKLDNNMemoryInputNode*>(node.get());
|
auto state_name = memoryNode->getId();
|
||||||
auto state_store = memoryNode->getStore();
|
|
||||||
auto state_name = memoryNode->getId();
|
|
||||||
|
|
||||||
// Remove suffix with pair ID. Internal information.
|
// Remove suffix with pair ID. Internal information.
|
||||||
auto suffix_idx = state_name.find("/id=");
|
auto suffix_idx = state_name.find("/id=");
|
||||||
if (suffix_idx != std::string::npos)
|
if (suffix_idx != std::string::npos)
|
||||||
state_name = state_name.substr(0, suffix_idx);
|
state_name = state_name.substr(0, suffix_idx);
|
||||||
|
|
||||||
memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
|
memoryStates.emplace_back(new MKLDNNVariableState(state_name, state_store));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
memoryStates = execNetwork->QueryState();
|
|
||||||
}
|
}
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MKLDNNPlugin::MKLDNNInferRequest::~MKLDNNInferRequest() {
|
MKLDNNPlugin::MKLDNNInferRequest::~MKLDNNInferRequest() {
|
||||||
|
@ -115,13 +115,6 @@ public:
|
|||||||
*/
|
*/
|
||||||
virtual std::shared_ptr<ngraph::Function> GetExecGraphInfo();
|
virtual std::shared_ptr<ngraph::Function> GetExecGraphInfo();
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated Need to implement GetVariablesInfo for ExecutableNetwork
|
|
||||||
* @brief Queries memory states.
|
|
||||||
* @return Returns memory states
|
|
||||||
*/
|
|
||||||
virtual std::vector<std::shared_ptr<IVariableStateInternal>> QueryState();
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Sets the pointer to plugin internal.
|
* @brief Sets the pointer to plugin internal.
|
||||||
* @param[in] plugin The plugin
|
* @param[in] plugin The plugin
|
||||||
|
@ -35,13 +35,6 @@ TEST(ExecutableNetworkTests, throwsOnUninitializedGetExecGraphInfo) {
|
|||||||
ASSERT_THROW(exec.GetExecGraphInfo(), InferenceEngine::NotAllocated);
|
ASSERT_THROW(exec.GetExecGraphInfo(), InferenceEngine::NotAllocated);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ExecutableNetworkTests, throwsOnUninitializedQueryState) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
ExecutableNetwork exec;
|
|
||||||
ASSERT_THROW(exec.QueryState(), InferenceEngine::NotAllocated);
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ExecutableNetworkTests, throwsOnUninitializedSetConfig) {
|
TEST(ExecutableNetworkTests, throwsOnUninitializedSetConfig) {
|
||||||
ExecutableNetwork exec;
|
ExecutableNetwork exec;
|
||||||
ASSERT_THROW(exec.SetConfig({{}}), InferenceEngine::NotAllocated);
|
ASSERT_THROW(exec.SetConfig({{}}), InferenceEngine::NotAllocated);
|
||||||
|
@ -65,16 +65,14 @@ public:
|
|||||||
}
|
}
|
||||||
auto importedNetwork = core->ImportNetwork(inputStream, targetDevice, configuration);
|
auto importedNetwork = core->ImportNetwork(inputStream, targetDevice, configuration);
|
||||||
std::vector<std::string> queryToState;
|
std::vector<std::string> queryToState;
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
InferenceEngine::InferRequest importInfer = importedNetwork.CreateInferRequest();
|
||||||
for (const auto &query_state : executableNetwork.QueryState()) {
|
for (const auto &query_state : importInfer.QueryState()) {
|
||||||
queryToState.push_back(query_state.GetName());
|
queryToState.push_back(query_state.GetName());
|
||||||
}
|
}
|
||||||
for (const auto &next_memory : importedNetwork.QueryState()) {
|
for (const auto &next_memory : importInfer.QueryState()) {
|
||||||
ASSERT_TRUE(std::find(queryToState.begin(), queryToState.end(), next_memory.GetName()) != queryToState.end())
|
ASSERT_TRUE(std::find(queryToState.begin(), queryToState.end(), next_memory.GetName()) != queryToState.end())
|
||||||
<< "State " << next_memory.GetName() << " expected to be in memory states but it is not!";
|
<< "State " << next_memory.GetName() << " expected to be in memory states but it is not!";
|
||||||
}
|
}
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
InferenceEngine::InferRequest importInfer = importedNetwork.CreateInferRequest();
|
|
||||||
importInfer.Infer();
|
importInfer.Infer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,8 +42,7 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
|
|||||||
manager.register_pass<ngraph::pass::LowLatency2>(); // LowLatency enables UnrollTI
|
manager.register_pass<ngraph::pass::LowLatency2>(); // LowLatency enables UnrollTI
|
||||||
manager.run_passes(function);
|
manager.run_passes(function);
|
||||||
LoadNetwork();
|
LoadNetwork();
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
auto states = inferRequest.QueryState();
|
||||||
auto states = executableNetwork.QueryState();
|
|
||||||
for (auto& state : states) {
|
for (auto& state : states) {
|
||||||
auto name = state.GetName();
|
auto name = state.GetName();
|
||||||
if (name.find("cell_state_1") != std::string::npos) {
|
if (name.find("cell_state_1") != std::string::npos) {
|
||||||
@ -58,7 +57,6 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) {
|
|||||||
GTEST_FAIL() << "unknown memory state";
|
GTEST_FAIL() << "unknown memory state";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
// Run and compare
|
// Run and compare
|
||||||
Infer();
|
Infer();
|
||||||
const auto& actualOutputs = GetOutputs();
|
const auto& actualOutputs = GetOutputs();
|
||||||
|
@ -31,98 +31,6 @@ InferenceEngine::ExecutableNetwork InferRequestVariableStateTest::PrepareNetwork
|
|||||||
return ie->LoadNetwork(net, deviceName);
|
return ie->LoadNetwork(net, deviceName);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(InferRequestVariableStateTest, smoke_VariableState_QueryState) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
auto executableNet = PrepareNetwork();
|
|
||||||
|
|
||||||
auto states = executableNet.QueryState();
|
|
||||||
ASSERT_TRUE(states.size() == 2) << "Incorrect number of VariableStates";
|
|
||||||
|
|
||||||
for (auto &&state : states) {
|
|
||||||
auto name = state.GetName();
|
|
||||||
ASSERT_TRUE(std::find(statesToQuery.begin(), statesToQuery.end(), name) != statesToQuery.end())
|
|
||||||
<< "State " << name << "expected to be in memory states but it is not!";
|
|
||||||
}
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_P(InferRequestVariableStateTest, smoke_VariableState_SetState) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
auto executableNet = PrepareNetwork();
|
|
||||||
const float new_state_val = 13.0f;
|
|
||||||
for (auto &&state : executableNet.QueryState()) {
|
|
||||||
state.Reset();
|
|
||||||
auto state_val = state.GetState();
|
|
||||||
auto element_count = state_val->size();
|
|
||||||
|
|
||||||
float *new_state_data = new float[element_count];
|
|
||||||
for (int i = 0; i < element_count; i++) {
|
|
||||||
new_state_data[i] = new_state_val;
|
|
||||||
}
|
|
||||||
auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
|
|
||||||
stateBlob->allocate();
|
|
||||||
std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
|
|
||||||
delete[]new_state_data;
|
|
||||||
state.SetState(stateBlob);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto &&state : executableNet.QueryState()) {
|
|
||||||
auto lastState = state.GetState();
|
|
||||||
auto last_state_size = lastState->size();
|
|
||||||
auto last_state_data = lastState->cbuffer().as<float *>();
|
|
||||||
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
|
|
||||||
|
|
||||||
for (int i = 0; i < last_state_size; i++) {
|
|
||||||
EXPECT_NEAR(new_state_val, last_state_data[i], 1e-5);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_P(InferRequestVariableStateTest, smoke_VariableState_Reset) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
auto executableNet = PrepareNetwork();
|
|
||||||
const float new_state_val = 13.0f;
|
|
||||||
for (auto &&state : executableNet.QueryState()) {
|
|
||||||
state.Reset();
|
|
||||||
auto state_val = state.GetState();
|
|
||||||
auto element_count = state_val->size();
|
|
||||||
|
|
||||||
float *new_state_data = new float[element_count];
|
|
||||||
for (int i = 0; i < element_count; i++) {
|
|
||||||
new_state_data[i] = new_state_val;
|
|
||||||
}
|
|
||||||
auto stateBlob = make_blob_with_precision(state_val->getTensorDesc());
|
|
||||||
stateBlob->allocate();
|
|
||||||
std::memcpy(stateBlob->buffer(), new_state_data, element_count * sizeof(float));
|
|
||||||
delete[]new_state_data;
|
|
||||||
|
|
||||||
state.SetState(stateBlob);
|
|
||||||
}
|
|
||||||
|
|
||||||
executableNet.QueryState().front().Reset();
|
|
||||||
|
|
||||||
auto states = executableNet.QueryState();
|
|
||||||
for (int i = 0; i < states.size(); ++i) {
|
|
||||||
auto lastState = states[i].GetState();
|
|
||||||
auto last_state_size = lastState->size();
|
|
||||||
auto last_state_data = lastState->cbuffer().as<float *>();
|
|
||||||
|
|
||||||
ASSERT_TRUE(last_state_size != 0) << "State size should not be 0";
|
|
||||||
|
|
||||||
if (i == 0) {
|
|
||||||
for (int j = 0; j < last_state_size; ++j) {
|
|
||||||
EXPECT_NEAR(0, last_state_data[j], 1e-5);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int j = 0; j < last_state_size; ++j) {
|
|
||||||
EXPECT_NEAR(new_state_val, last_state_data[j], 1e-5);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_QueryState) {
|
TEST_P(InferRequestVariableStateTest, inferreq_smoke_VariableState_QueryState) {
|
||||||
auto executableNet = PrepareNetwork();
|
auto executableNet = PrepareNetwork();
|
||||||
auto inferReq = executableNet.CreateInferRequest();
|
auto inferReq = executableNet.CreateInferRequest();
|
||||||
|
@ -57,6 +57,7 @@ void DetectNetworkBatch::LoadNetwork() {
|
|||||||
functionRefs = ngraph::clone_function(*cnnNetwork.getFunction());
|
functionRefs = ngraph::clone_function(*cnnNetwork.getFunction());
|
||||||
ConfigureNetwork();
|
ConfigureNetwork();
|
||||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||||
|
inferRequest = executableNetwork.CreateInferRequest();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(DetectNetworkBatch, InferWithOneInput) {
|
TEST_P(DetectNetworkBatch, InferWithOneInput) {
|
||||||
|
@ -54,10 +54,10 @@ namespace ConfigurationTestsDefinitions {
|
|||||||
ConfigureNetwork();
|
ConfigureNetwork();
|
||||||
cnnNetwork.setBatchSize(max_batch_size);
|
cnnNetwork.setBatchSize(max_batch_size);
|
||||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||||
|
inferRequest = executableNetwork.CreateInferRequest();
|
||||||
}
|
}
|
||||||
|
|
||||||
void DynamicBatchTest::Infer() {
|
void DynamicBatchTest::Infer() {
|
||||||
inferRequest = executableNetwork.CreateInferRequest();
|
|
||||||
inputs.clear();
|
inputs.clear();
|
||||||
|
|
||||||
for (int i = 0; i < batch_sizes.size(); i++) {
|
for (int i = 0; i < batch_sizes.size(); i++) {
|
||||||
|
@ -344,6 +344,7 @@ void LayerTestsCommon::LoadNetwork() {
|
|||||||
CoreConfiguration(this);
|
CoreConfiguration(this);
|
||||||
ConfigureNetwork();
|
ConfigureNetwork();
|
||||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||||
|
inferRequest = executableNetwork.CreateInferRequest();
|
||||||
}
|
}
|
||||||
|
|
||||||
void LayerTestsCommon::GenerateInputs() {
|
void LayerTestsCommon::GenerateInputs() {
|
||||||
@ -361,8 +362,6 @@ void LayerTestsCommon::GenerateInputs() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void LayerTestsCommon::Infer() {
|
void LayerTestsCommon::Infer() {
|
||||||
inferRequest = executableNetwork.CreateInferRequest();
|
|
||||||
|
|
||||||
const auto& inputsInfo = executableNetwork.GetInputsInfo();
|
const auto& inputsInfo = executableNetwork.GetInputsInfo();
|
||||||
const auto& functionParams = function->get_parameters();
|
const auto& functionParams = function->get_parameters();
|
||||||
for (int i = 0; i < functionParams.size(); ++i) {
|
for (int i = 0; i < functionParams.size(); ++i) {
|
||||||
|
@ -81,6 +81,7 @@ namespace LayerTestsDefinitions {
|
|||||||
CoreConfiguration(this);
|
CoreConfiguration(this);
|
||||||
ConfigureNetwork();
|
ConfigureNetwork();
|
||||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||||
|
inferRequest = executableNetwork.CreateInferRequest();
|
||||||
}
|
}
|
||||||
GenerateInputs();
|
GenerateInputs();
|
||||||
for (int64_t i = 0; i < iteration_count; ++i) {
|
for (int64_t i = 0; i < iteration_count; ++i) {
|
||||||
|
@ -103,8 +103,8 @@ namespace SubgraphTestsDefinitions {
|
|||||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||||
|
|
||||||
LoadNetwork();
|
LoadNetwork();
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
auto states = executableNetwork.QueryState();
|
auto states = inferRequest.QueryState();
|
||||||
for (auto& state : states) {
|
for (auto& state : states) {
|
||||||
auto name = state.GetName();
|
auto name = state.GetName();
|
||||||
if (name == "memory_1") {
|
if (name == "memory_1") {
|
||||||
@ -119,7 +119,6 @@ namespace SubgraphTestsDefinitions {
|
|||||||
GTEST_FAIL() << "unknown memory state";
|
GTEST_FAIL() << "unknown memory state";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
GenerateInputs();
|
GenerateInputs();
|
||||||
Infer();
|
Infer();
|
||||||
switchToNgraphFriendlyModel();
|
switchToNgraphFriendlyModel();
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
|
|
||||||
namespace SubgraphTestsDefinitions {
|
namespace SubgraphTestsDefinitions {
|
||||||
void DelayedCopyTestBase::InitMemory() {
|
void DelayedCopyTestBase::InitMemory() {
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
auto states = inferRequest.QueryState();
|
||||||
auto states = executableNetwork.QueryState();
|
|
||||||
for (auto& state : states) {
|
for (auto& state : states) {
|
||||||
auto name = state.GetName();
|
auto name = state.GetName();
|
||||||
if (name.find("id") != std::string::npos) {
|
if (name.find("id") != std::string::npos) {
|
||||||
@ -18,7 +17,6 @@ namespace SubgraphTestsDefinitions {
|
|||||||
GTEST_FAIL() << "unknown memory state";
|
GTEST_FAIL() << "unknown memory state";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void DelayedCopyTestBase::Run() {
|
void DelayedCopyTestBase::Run() {
|
||||||
|
@ -276,8 +276,7 @@ namespace SubgraphTestsDefinitions {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void MemoryLSTMCellTest::InitMemory() {
|
void MemoryLSTMCellTest::InitMemory() {
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
auto states = inferRequest.QueryState();
|
||||||
auto states = executableNetwork.QueryState();
|
|
||||||
for (auto& state : states) {
|
for (auto& state : states) {
|
||||||
auto name = state.GetName();
|
auto name = state.GetName();
|
||||||
if (name.find("cell_state_1") != std::string::npos) {
|
if (name.find("cell_state_1") != std::string::npos) {
|
||||||
@ -292,7 +291,6 @@ namespace SubgraphTestsDefinitions {
|
|||||||
GTEST_FAIL() << "unknown memory state";
|
GTEST_FAIL() << "unknown memory state";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemoryLSTMCellTest::ApplyLowLatency() {
|
void MemoryLSTMCellTest::ApplyLowLatency() {
|
||||||
@ -330,6 +328,7 @@ namespace SubgraphTestsDefinitions {
|
|||||||
|
|
||||||
ConfigureNetwork();
|
ConfigureNetwork();
|
||||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||||
|
inferRequest = executableNetwork.CreateInferRequest();
|
||||||
} else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) {
|
} else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) {
|
||||||
cnnNetwork = InferenceEngine::CNNNetwork{function};
|
cnnNetwork = InferenceEngine::CNNNetwork{function};
|
||||||
InferenceEngine::lowLatency2(cnnNetwork);
|
InferenceEngine::lowLatency2(cnnNetwork);
|
||||||
@ -339,6 +338,7 @@ namespace SubgraphTestsDefinitions {
|
|||||||
|
|
||||||
ConfigureNetwork();
|
ConfigureNetwork();
|
||||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||||
|
inferRequest = executableNetwork.CreateInferRequest();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace SubgraphTestsDefinitions
|
} // namespace SubgraphTestsDefinitions
|
||||||
|
@ -122,7 +122,7 @@ void MemoryEltwiseReshapeConcatTest::Run() {
|
|||||||
InferenceEngine::Layout::NC);
|
InferenceEngine::Layout::NC);
|
||||||
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
IE_SUPPRESS_DEPRECATED_START
|
||||||
auto states = executableNetwork.QueryState();
|
auto states = inferRequest.QueryState();
|
||||||
auto state_values_blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description,
|
auto state_values_blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description,
|
||||||
memory_init.data(), memory_init.size());
|
memory_init.data(), memory_init.size());
|
||||||
states[0].SetState(state_values_blob);
|
states[0].SetState(state_values_blob);
|
||||||
|
@ -397,8 +397,7 @@ void MultipleLSTMCellTest::InitMemory() {
|
|||||||
InferenceEngine::TensorDesc state_description(InferenceEngine::Precision::FP32,
|
InferenceEngine::TensorDesc state_description(InferenceEngine::Precision::FP32,
|
||||||
InferenceEngine::SizeVector({1, hiddenSize}),
|
InferenceEngine::SizeVector({1, hiddenSize}),
|
||||||
InferenceEngine::Layout::NC);
|
InferenceEngine::Layout::NC);
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
auto states = inferRequest.QueryState();
|
||||||
auto states = executableNetwork.QueryState();
|
|
||||||
for (auto& state : states) {
|
for (auto& state : states) {
|
||||||
auto name = state.GetName();
|
auto name = state.GetName();
|
||||||
if (name.find("cell_state_1") != std::string::npos) {
|
if (name.find("cell_state_1") != std::string::npos) {
|
||||||
@ -421,7 +420,6 @@ void MultipleLSTMCellTest::InitMemory() {
|
|||||||
GTEST_FAIL() << "unknown memory state";
|
GTEST_FAIL() << "unknown memory state";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void MultipleLSTMCellTest::ApplyLowLatency() {
|
void MultipleLSTMCellTest::ApplyLowLatency() {
|
||||||
@ -459,6 +457,7 @@ void MultipleLSTMCellTest::ApplyLowLatency() {
|
|||||||
|
|
||||||
ConfigureNetwork();
|
ConfigureNetwork();
|
||||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||||
|
inferRequest = executableNetwork.CreateInferRequest();
|
||||||
} else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) {
|
} else if (transformation == ngraph::helpers::MemoryTransformation::LOW_LATENCY_V2_REGULAR_API) {
|
||||||
cnnNetwork = InferenceEngine::CNNNetwork{function};
|
cnnNetwork = InferenceEngine::CNNNetwork{function};
|
||||||
InferenceEngine::lowLatency2(cnnNetwork);
|
InferenceEngine::lowLatency2(cnnNetwork);
|
||||||
@ -468,6 +467,7 @@ void MultipleLSTMCellTest::ApplyLowLatency() {
|
|||||||
|
|
||||||
ConfigureNetwork();
|
ConfigureNetwork();
|
||||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||||
|
inferRequest = executableNetwork.CreateInferRequest();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,8 +72,7 @@ namespace SubgraphTestsDefinitions {
|
|||||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||||
|
|
||||||
LoadNetwork();
|
LoadNetwork();
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
auto states = inferRequest.QueryState();
|
||||||
auto states = executableNetwork.QueryState();
|
|
||||||
for (auto& state : states) {
|
for (auto& state : states) {
|
||||||
auto name = state.GetName();
|
auto name = state.GetName();
|
||||||
if (name == "memory") {
|
if (name == "memory") {
|
||||||
@ -84,7 +83,6 @@ namespace SubgraphTestsDefinitions {
|
|||||||
GTEST_FAIL() << "unknown memory state";
|
GTEST_FAIL() << "unknown memory state";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
GenerateInputs();
|
GenerateInputs();
|
||||||
Infer();
|
Infer();
|
||||||
switchToNgraphFriendlyModel();
|
switchToNgraphFriendlyModel();
|
||||||
|
@ -24,7 +24,6 @@ public:
|
|||||||
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr(void));
|
MOCK_METHOD0(CreateInferRequest, IInferRequestInternal::Ptr(void));
|
||||||
MOCK_METHOD1(Export, void(const std::string &));
|
MOCK_METHOD1(Export, void(const std::string &));
|
||||||
void Export(std::ostream &) override {};
|
void Export(std::ostream &) override {};
|
||||||
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>(void));
|
|
||||||
MOCK_METHOD0(GetExecGraphInfo, std::shared_ptr<ngraph::Function>(void));
|
MOCK_METHOD0(GetExecGraphInfo, std::shared_ptr<ngraph::Function>(void));
|
||||||
|
|
||||||
MOCK_METHOD1(SetConfig, void(const std::map<std::string, Parameter> &config));
|
MOCK_METHOD1(SetConfig, void(const std::map<std::string, Parameter> &config));
|
||||||
|
@ -43,162 +43,6 @@ class InferRequestVariableStateTests : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, ExecutableNetworkCanConvertOneVariableStateFromCppToAPI) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn(1);
|
|
||||||
toReturn[0] = mockVariableStateInternal;
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
|
||||||
|
|
||||||
auto state = net->QueryState();
|
|
||||||
ASSERT_EQ(state.size(), 1);
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, ExecutableNetworkCanConvertZeroVariableStateFromCppToAPI) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillOnce(Return(toReturn));
|
|
||||||
|
|
||||||
auto state = net->QueryState();
|
|
||||||
ASSERT_EQ(state.size(), 0);
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, ExecutableNetworkCanConvert2VariableStatesFromCPPtoAPI) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
|
||||||
|
|
||||||
auto state = net->QueryState();
|
|
||||||
ASSERT_EQ(state.size(), 2);
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesReset) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
|
||||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).Times(1);
|
|
||||||
|
|
||||||
auto state = net->QueryState();
|
|
||||||
state.front()->Reset();
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesExceptionsFromReset) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
|
||||||
EXPECT_CALL(*mockVariableStateInternal.get(), Reset()).WillOnce(Throw(std::logic_error("some error")));
|
|
||||||
|
|
||||||
auto state = net->QueryState();
|
|
||||||
EXPECT_ANY_THROW(state.front()->Reset());
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetName) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
|
||||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
|
||||||
|
|
||||||
auto state = net->QueryState();
|
|
||||||
EXPECT_STREQ(state.front()->GetName().c_str(), "someName");
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetNameWithZeroLen) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
|
||||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
|
||||||
|
|
||||||
auto pState = net->QueryState().front();
|
|
||||||
EXPECT_NO_THROW(pState->GetName());
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetNameWithLenOfOne) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
|
||||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
|
||||||
|
|
||||||
auto pState = net->QueryState().front();
|
|
||||||
std::string name;
|
|
||||||
EXPECT_NO_THROW(name = pState->GetName());
|
|
||||||
EXPECT_EQ(name, "someName");
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, VariableStatePropagatesGetNameWithLenOfTwo) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).Times(1).WillRepeatedly(Return(toReturn));
|
|
||||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetName()).WillOnce(Return("someName"));
|
|
||||||
|
|
||||||
auto pState = net->QueryState().front();
|
|
||||||
std::string name;
|
|
||||||
EXPECT_NO_THROW(name = pState->GetName());
|
|
||||||
EXPECT_EQ(name, "someName");
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, VariableStateCanPropagateSetState) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
Blob::Ptr saver;
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
|
|
||||||
EXPECT_CALL(*mockVariableStateInternal.get(), SetState(_)).WillOnce(SaveArg<0>(&saver));
|
|
||||||
|
|
||||||
float data[] = {123, 124, 125};
|
|
||||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
|
||||||
|
|
||||||
EXPECT_NO_THROW(net->QueryState().front()->SetState(stateBlob));
|
|
||||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[0], 123);
|
|
||||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[1], 124);
|
|
||||||
ASSERT_FLOAT_EQ(saver->buffer().as<float*>()[2], 125);
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InferRequestVariableStateTests, VariableStateCanPropagateGetLastState) {
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
std::vector<IVariableStateInternal::Ptr> toReturn;
|
|
||||||
toReturn.push_back(mockVariableStateInternal);
|
|
||||||
|
|
||||||
float data[] = {123, 124, 125};
|
|
||||||
auto stateBlob = make_shared_blob<float>({ Precision::FP32, {3}, C }, data, sizeof(data) / sizeof(*data));
|
|
||||||
|
|
||||||
EXPECT_CALL(*mockExeNetworkInternal.get(), QueryState()).WillRepeatedly(Return(toReturn));
|
|
||||||
EXPECT_CALL(*mockVariableStateInternal.get(), GetState()).WillOnce(Return(stateBlob));
|
|
||||||
|
|
||||||
auto saver = net->QueryState().front()->GetState();
|
|
||||||
ASSERT_NE(saver, nullptr);
|
|
||||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[0], 123);
|
|
||||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[1], 124);
|
|
||||||
ASSERT_FLOAT_EQ(saver->cbuffer().as<const float*>()[2], 125);
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
class VariableStateInternalMockImpl : public IVariableStateInternal {
|
class VariableStateInternalMockImpl : public IVariableStateInternal {
|
||||||
public:
|
public:
|
||||||
VariableStateInternalMockImpl(const char* name) : IVariableStateInternal(name) {}
|
VariableStateInternalMockImpl(const char* name) : IVariableStateInternal(name) {}
|
||||||
|
@ -88,26 +88,6 @@ TEST_F(ExecutableNetworkTests, GetInputsInfo) {
|
|||||||
ASSERT_EQ(info, InferenceEngine::ConstInputsDataMap{});
|
ASSERT_EQ(info, InferenceEngine::ConstInputsDataMap{});
|
||||||
}
|
}
|
||||||
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
|
|
||||||
TEST_F(ExecutableNetworkTests, QueryStateThrowsIfReturnErr) {
|
|
||||||
EXPECT_CALL(*mockIExeNet.get(), QueryState())
|
|
||||||
.Times(1)
|
|
||||||
.WillOnce(Throw(InferenceEngine::GeneralError{""}));
|
|
||||||
EXPECT_THROW(exeNetwork->QueryState(), InferenceEngine::Exception);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ExecutableNetworkTests, QueryState) {
|
|
||||||
auto mockIMemState_p = std::make_shared<MockIVariableStateInternal>();
|
|
||||||
EXPECT_CALL(*mockIExeNet.get(), QueryState())
|
|
||||||
.Times(1)
|
|
||||||
.WillOnce(Return(std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(1, mockIMemState_p)));
|
|
||||||
std::vector<InferenceEngine::IVariableStateInternal::Ptr> MemState_v;
|
|
||||||
EXPECT_NO_THROW(MemState_v = exeNetwork->QueryState());
|
|
||||||
EXPECT_EQ(MemState_v.size(), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
|
|
||||||
class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests {
|
class ExecutableNetworkWithIInferReqTests : public ExecutableNetworkTests {
|
||||||
protected:
|
protected:
|
||||||
@ -135,6 +115,29 @@ TEST_F(ExecutableNetworkWithIInferReqTests, CreateInferRequestThrowsIfReturnNotO
|
|||||||
ASSERT_THROW(exeNetwork->CreateInferRequest(), InferenceEngine::Exception);
|
ASSERT_THROW(exeNetwork->CreateInferRequest(), InferenceEngine::Exception);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ExecutableNetworkWithIInferReqTests, QueryStateThrowsIfReturnErr) {
|
||||||
|
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
|
||||||
|
IInferRequestInternal::Ptr actualInferReq;
|
||||||
|
ASSERT_NO_THROW(actualInferReq = exeNetwork->CreateInferRequest());
|
||||||
|
EXPECT_CALL(*mockIInferReq_p.get(), QueryState())
|
||||||
|
.Times(1)
|
||||||
|
.WillOnce(Throw(InferenceEngine::GeneralError{""}));
|
||||||
|
EXPECT_THROW(actualInferReq->QueryState(), InferenceEngine::Exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ExecutableNetworkWithIInferReqTests, QueryState) {
|
||||||
|
EXPECT_CALL(*mockIExeNet.get(), CreateInferRequest()).WillOnce(Return(mockIInferReq_p));
|
||||||
|
IInferRequestInternal::Ptr actualInferReq;
|
||||||
|
ASSERT_NO_THROW(actualInferReq = exeNetwork->CreateInferRequest());
|
||||||
|
auto mockIMemState_p = std::make_shared<MockIVariableStateInternal>();
|
||||||
|
EXPECT_CALL(*mockIInferReq_p.get(), QueryState())
|
||||||
|
.Times(1)
|
||||||
|
.WillOnce(Return(std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(1, mockIMemState_p)));
|
||||||
|
std::vector<InferenceEngine::IVariableStateInternal::Ptr> MemState_v;
|
||||||
|
EXPECT_NO_THROW(MemState_v = actualInferReq->QueryState());
|
||||||
|
EXPECT_EQ(MemState_v.size(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
IE_SUPPRESS_DEPRECATED_START
|
||||||
|
|
||||||
class ExecutableNetworkBaseTests : public ::testing::Test {
|
class ExecutableNetworkBaseTests : public ::testing::Test {
|
||||||
|
@ -798,30 +798,6 @@ void GNAQueryStateMatcher :: match() {
|
|||||||
|
|
||||||
EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
|
EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
|
||||||
#endif
|
#endif
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
|
||||||
try {
|
|
||||||
loadNetwork();
|
|
||||||
if (GnaPluginTestEnvironment::kAnyNotNull == _env.numberOfStates) {
|
|
||||||
auto states = executer->QueryState();
|
|
||||||
ASSERT_NE(states.size(), 0);
|
|
||||||
// usually states are callable
|
|
||||||
for (auto & state : states) {
|
|
||||||
state->Reset();
|
|
||||||
}
|
|
||||||
} else if (_env.numberOfStates >= 0) {
|
|
||||||
ASSERT_EQ(executer->QueryState().size(), _env.numberOfStates);
|
|
||||||
} else {
|
|
||||||
FAIL() << "number of memory states expectation not set";
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
catch(std::exception &ex) {
|
|
||||||
FAIL() << ex.what();
|
|
||||||
}
|
|
||||||
catch(...) {
|
|
||||||
FAIL() << "unknown exception thrown";
|
|
||||||
}
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,25 +0,0 @@
|
|||||||
// Copyright (C) 2018-2021 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include "gna_matcher.hpp"
|
|
||||||
|
|
||||||
class QueryStateTest : public GNATest<> {
|
|
||||||
protected:
|
|
||||||
void SetUp() override {
|
|
||||||
}
|
|
||||||
};
|
|
||||||
using namespace GNATestIRs;
|
|
||||||
|
|
||||||
// Recursive Algorithm
|
|
||||||
// Precision Threshold
|
|
||||||
|
|
||||||
TEST_F(QueryStateTest, returnEmptyCollectionOfStatesIfNoMemoryInIR) {
|
|
||||||
assert_that().afterLoadingModel(TanhActivationModel()).withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f).queryState().isEmpty();
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(QueryStateTest, returnNonEmptyCollectionOfStatesForMemoryIR) {
|
|
||||||
assert_that().afterLoadingModel(affineToMemoryModel()).withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f).queryState().isNotEmpty();
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user