move creation of InferRequests back to Infer; fixed tests for QueryState (#8962)

* move creation of InferRequests back to Infer; fixed tests for QueryState

* fix indentation

* fix indentation

* fixed MemoryLSTMCell test
This commit is contained in:
Svetlana Dolinina 2021-12-09 00:58:33 +03:00 committed by GitHub
parent d18e80b604
commit 12ccc66920
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 85 additions and 5 deletions

View File

@ -8,6 +8,16 @@
#include "shared_test_classes/subgraph/basic_lstm.hpp" #include "shared_test_classes/subgraph/basic_lstm.hpp"
namespace SubgraphTestsDefinitions { namespace SubgraphTestsDefinitions {
void Basic_LSTM_S::LoadNetwork() {
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
inferRequest = executableNetwork.CreateInferRequest();
}
void Basic_LSTM_S::Infer() {
ConfigureInferRequest();
inferRequest.Infer();
}
TEST_P(Basic_LSTM_S, CompareWithRefImpl) { TEST_P(Basic_LSTM_S, CompareWithRefImpl) {
Run(); Run();
}; };

View File

@ -57,7 +57,6 @@ void DetectNetworkBatch::LoadNetwork() {
functionRefs = ngraph::clone_function(*cnnNetwork.getFunction()); functionRefs = ngraph::clone_function(*cnnNetwork.getFunction());
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
TEST_P(DetectNetworkBatch, InferWithOneInput) { TEST_P(DetectNetworkBatch, InferWithOneInput) {

View File

@ -54,10 +54,10 @@ namespace ConfigurationTestsDefinitions {
ConfigureNetwork(); ConfigureNetwork();
cnnNetwork.setBatchSize(max_batch_size); cnnNetwork.setBatchSize(max_batch_size);
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
void DynamicBatchTest::Infer() { void DynamicBatchTest::Infer() {
inferRequest = executableNetwork.CreateInferRequest();
inputs.clear(); inputs.clear();
for (int i = 0; i < batch_sizes.size(); i++) { for (int i = 0; i < batch_sizes.size(); i++) {

View File

@ -139,6 +139,8 @@ protected:
virtual void GenerateInputs(); virtual void GenerateInputs();
virtual void ConfigureInferRequest();
virtual void Infer(); virtual void Infer();
TargetDevice targetDevice; TargetDevice targetDevice;

View File

@ -36,6 +36,9 @@ public:
std::vector<float>* cell_memory_init_out = nullptr); std::vector<float>* cell_memory_init_out = nullptr);
void GenerateInputs() override; void GenerateInputs() override;
protected: protected:
void LoadNetwork() override;
void Infer() override;
size_t hidden_size; size_t hidden_size;
size_t third_dim; size_t third_dim;
std::vector<float> hidden_memory_init; std::vector<float> hidden_memory_init;

View File

@ -35,5 +35,7 @@ public:
protected: protected:
void SetUp() override; void SetUp() override;
void Run() override; void Run() override;
void LoadNetwork() override;
void Infer() override;
}; };
} // namespace SubgraphTestsDefinitions } // namespace SubgraphTestsDefinitions

View File

@ -30,6 +30,8 @@ private:
virtual void switchToNgraphFriendlyModel() = 0; virtual void switchToNgraphFriendlyModel() = 0;
protected: protected:
void Run() override; void Run() override;
void LoadNetwork() override;
void Infer() override;
std::vector<float> memory_init; std::vector<float> memory_init;
public: public:
static std::string getTestCaseName(const testing::TestParamInfo<DelayedCopyTuple> &obj); static std::string getTestCaseName(const testing::TestParamInfo<DelayedCopyTuple> &obj);

View File

@ -39,6 +39,8 @@ private:
protected: protected:
void SetUp() override; void SetUp() override;
void Run() override; void Run() override;
void LoadNetwork() override;
void Infer() override;
public: public:
static std::string getTestCaseName(const testing::TestParamInfo<memoryLSTMCellParams> &obj); static std::string getTestCaseName(const testing::TestParamInfo<memoryLSTMCellParams> &obj);
}; };

View File

@ -33,6 +33,8 @@ private:
protected: protected:
void SetUp() override; void SetUp() override;
void Run() override; void Run() override;
void LoadNetwork() override;
void Infer() override;
public: public:
static std::string getTestCaseName(const testing::TestParamInfo<memoryEltwiseReshapeConcatParams> &obj); static std::string getTestCaseName(const testing::TestParamInfo<memoryEltwiseReshapeConcatParams> &obj);
}; };

View File

@ -34,5 +34,7 @@ public:
protected: protected:
void SetUp() override; void SetUp() override;
void Run() override; void Run() override;
void LoadNetwork() override;
void Infer() override;
}; };
} // namespace SubgraphTestsDefinitions } // namespace SubgraphTestsDefinitions

View File

@ -345,7 +345,6 @@ void LayerTestsCommon::LoadNetwork() {
CoreConfiguration(this); CoreConfiguration(this);
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
void LayerTestsCommon::GenerateInputs() { void LayerTestsCommon::GenerateInputs() {
@ -362,7 +361,7 @@ void LayerTestsCommon::GenerateInputs() {
} }
} }
void LayerTestsCommon::Infer() { void LayerTestsCommon::ConfigureInferRequest() {
const auto& inputsInfo = executableNetwork.GetInputsInfo(); const auto& inputsInfo = executableNetwork.GetInputsInfo();
const auto& functionParams = function->get_parameters(); const auto& functionParams = function->get_parameters();
for (int i = 0; i < functionParams.size(); ++i) { for (int i = 0; i < functionParams.size(); ++i) {
@ -379,6 +378,13 @@ void LayerTestsCommon::Infer() {
auto batchSize = executableNetwork.GetInputsInfo().begin()->second->getTensorDesc().getDims()[0] / 2; auto batchSize = executableNetwork.GetInputsInfo().begin()->second->getTensorDesc().getDims()[0] / 2;
inferRequest.SetBatch(batchSize); inferRequest.SetBatch(batchSize);
} }
}
void LayerTestsCommon::Infer() {
inferRequest = executableNetwork.CreateInferRequest();
ConfigureInferRequest();
inferRequest.Infer(); inferRequest.Infer();
} }

View File

@ -81,7 +81,6 @@ namespace LayerTestsDefinitions {
CoreConfiguration(this); CoreConfiguration(this);
ConfigureNetwork(); ConfigureNetwork();
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
inferRequest = executableNetwork.CreateInferRequest();
} }
GenerateInputs(); GenerateInputs();
for (int64_t i = 0; i < iteration_count; ++i) { for (int64_t i = 0; i < iteration_count; ++i) {

View File

@ -99,6 +99,16 @@ namespace SubgraphTestsDefinitions {
function = std::make_shared<ngraph::Function>(sigm, input, "concat_quant_during_memory_requant_nomemory"); function = std::make_shared<ngraph::Function>(sigm, input, "concat_quant_during_memory_requant_nomemory");
} }
void ConcatQuantDuringMemoryRequantTest::LoadNetwork() {
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
inferRequest = executableNetwork.CreateInferRequest();
}
void ConcatQuantDuringMemoryRequantTest::Infer() {
ConfigureInferRequest();
inferRequest.Infer();
}
void ConcatQuantDuringMemoryRequantTest::Run() { void ConcatQuantDuringMemoryRequantTest::Run() {
SKIP_IF_CURRENT_TEST_IS_DISABLED() SKIP_IF_CURRENT_TEST_IS_DISABLED()

View File

@ -19,6 +19,17 @@ namespace SubgraphTestsDefinitions {
} }
} }
void DelayedCopyTestBase::LoadNetwork() {
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
inferRequest = executableNetwork.CreateInferRequest();
}
void DelayedCopyTestBase::Infer() {
ConfigureInferRequest();
inferRequest.Infer();
}
void DelayedCopyTestBase::Run() { void DelayedCopyTestBase::Run() {
SKIP_IF_CURRENT_TEST_IS_DISABLED() SKIP_IF_CURRENT_TEST_IS_DISABLED()

View File

@ -254,6 +254,16 @@ namespace SubgraphTestsDefinitions {
function = std::make_shared<Function>(final_reshape, input_parameter, "PureTI"); function = std::make_shared<Function>(final_reshape, input_parameter, "PureTI");
} }
void MemoryLSTMCellTest::LoadNetwork() {
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
inferRequest = executableNetwork.CreateInferRequest();
}
void MemoryLSTMCellTest::Infer() {
ConfigureInferRequest();
inferRequest.Infer();
}
void MemoryLSTMCellTest::Run() { void MemoryLSTMCellTest::Run() {
SKIP_IF_CURRENT_TEST_IS_DISABLED() SKIP_IF_CURRENT_TEST_IS_DISABLED()
if (transformation != ngraph::helpers::MemoryTransformation::NONE) { if (transformation != ngraph::helpers::MemoryTransformation::NONE) {

View File

@ -112,6 +112,16 @@ void MemoryEltwiseReshapeConcatTest::initNgraphFriendlyModel() {
function = std::make_shared<ngraph::Function>(concat, input_parameter, "memory_multiply_reshape_concat"); function = std::make_shared<ngraph::Function>(concat, input_parameter, "memory_multiply_reshape_concat");
} }
void MemoryEltwiseReshapeConcatTest::LoadNetwork() {
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
inferRequest = executableNetwork.CreateInferRequest();
}
void MemoryEltwiseReshapeConcatTest::Infer() {
ConfigureInferRequest();
inferRequest.Infer();
}
void MemoryEltwiseReshapeConcatTest::Run() { void MemoryEltwiseReshapeConcatTest::Run() {
SKIP_IF_CURRENT_TEST_IS_DISABLED() SKIP_IF_CURRENT_TEST_IS_DISABLED()
initTestModel(); initTestModel();

View File

@ -68,6 +68,16 @@ namespace SubgraphTestsDefinitions {
function = std::make_shared<ngraph::Function>(sigm, input, "negative_memory_layer_offset_nonmemory"); function = std::make_shared<ngraph::Function>(sigm, input, "negative_memory_layer_offset_nonmemory");
} }
void NegativeMemoryOffsetTest::LoadNetwork() {
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
inferRequest = executableNetwork.CreateInferRequest();
}
void NegativeMemoryOffsetTest::Infer() {
ConfigureInferRequest();
inferRequest.Infer();
}
void NegativeMemoryOffsetTest::Run() { void NegativeMemoryOffsetTest::Run() {
SKIP_IF_CURRENT_TEST_IS_DISABLED() SKIP_IF_CURRENT_TEST_IS_DISABLED()