move creation of InferRequests back to Infer; fixed tests for QueryState (#8962)
* move creaton of InferRequests back to Infer; fixed tets for QueryState * fix indentation * fix indentation * fixed MemoryLSTMCell test
This commit is contained in:
parent
d18e80b604
commit
12ccc66920
@ -8,6 +8,16 @@
|
||||
#include "shared_test_classes/subgraph/basic_lstm.hpp"
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
void Basic_LSTM_S::LoadNetwork() {
|
||||
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void Basic_LSTM_S::Infer() {
|
||||
ConfigureInferRequest();
|
||||
inferRequest.Infer();
|
||||
}
|
||||
|
||||
TEST_P(Basic_LSTM_S, CompareWithRefImpl) {
|
||||
Run();
|
||||
};
|
||||
|
@ -57,7 +57,6 @@ void DetectNetworkBatch::LoadNetwork() {
|
||||
functionRefs = ngraph::clone_function(*cnnNetwork.getFunction());
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
TEST_P(DetectNetworkBatch, InferWithOneInput) {
|
||||
|
@ -54,10 +54,10 @@ namespace ConfigurationTestsDefinitions {
|
||||
ConfigureNetwork();
|
||||
cnnNetwork.setBatchSize(max_batch_size);
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void DynamicBatchTest::Infer() {
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
inputs.clear();
|
||||
|
||||
for (int i = 0; i < batch_sizes.size(); i++) {
|
||||
|
@ -139,6 +139,8 @@ protected:
|
||||
|
||||
virtual void GenerateInputs();
|
||||
|
||||
virtual void ConfigureInferRequest();
|
||||
|
||||
virtual void Infer();
|
||||
|
||||
TargetDevice targetDevice;
|
||||
|
@ -36,6 +36,9 @@ public:
|
||||
std::vector<float>* cell_memory_init_out = nullptr);
|
||||
void GenerateInputs() override;
|
||||
protected:
|
||||
void LoadNetwork() override;
|
||||
void Infer() override;
|
||||
|
||||
size_t hidden_size;
|
||||
size_t third_dim;
|
||||
std::vector<float> hidden_memory_init;
|
||||
|
@ -35,5 +35,7 @@ public:
|
||||
protected:
|
||||
void SetUp() override;
|
||||
void Run() override;
|
||||
void LoadNetwork() override;
|
||||
void Infer() override;
|
||||
};
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
@ -30,6 +30,8 @@ private:
|
||||
virtual void switchToNgraphFriendlyModel() = 0;
|
||||
protected:
|
||||
void Run() override;
|
||||
void LoadNetwork() override;
|
||||
void Infer() override;
|
||||
std::vector<float> memory_init;
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<DelayedCopyTuple> &obj);
|
||||
|
@ -39,6 +39,8 @@ private:
|
||||
protected:
|
||||
void SetUp() override;
|
||||
void Run() override;
|
||||
void LoadNetwork() override;
|
||||
void Infer() override;
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<memoryLSTMCellParams> &obj);
|
||||
};
|
||||
|
@ -33,6 +33,8 @@ private:
|
||||
protected:
|
||||
void SetUp() override;
|
||||
void Run() override;
|
||||
void LoadNetwork() override;
|
||||
void Infer() override;
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<memoryEltwiseReshapeConcatParams> &obj);
|
||||
};
|
||||
|
@ -34,5 +34,7 @@ public:
|
||||
protected:
|
||||
void SetUp() override;
|
||||
void Run() override;
|
||||
void LoadNetwork() override;
|
||||
void Infer() override;
|
||||
};
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
@ -345,7 +345,6 @@ void LayerTestsCommon::LoadNetwork() {
|
||||
CoreConfiguration(this);
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void LayerTestsCommon::GenerateInputs() {
|
||||
@ -362,7 +361,7 @@ void LayerTestsCommon::GenerateInputs() {
|
||||
}
|
||||
}
|
||||
|
||||
void LayerTestsCommon::Infer() {
|
||||
void LayerTestsCommon::ConfigureInferRequest() {
|
||||
const auto& inputsInfo = executableNetwork.GetInputsInfo();
|
||||
const auto& functionParams = function->get_parameters();
|
||||
for (int i = 0; i < functionParams.size(); ++i) {
|
||||
@ -379,6 +378,13 @@ void LayerTestsCommon::Infer() {
|
||||
auto batchSize = executableNetwork.GetInputsInfo().begin()->second->getTensorDesc().getDims()[0] / 2;
|
||||
inferRequest.SetBatch(batchSize);
|
||||
}
|
||||
}
|
||||
|
||||
void LayerTestsCommon::Infer() {
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
|
||||
ConfigureInferRequest();
|
||||
|
||||
inferRequest.Infer();
|
||||
}
|
||||
|
||||
|
@ -81,7 +81,6 @@ namespace LayerTestsDefinitions {
|
||||
CoreConfiguration(this);
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
GenerateInputs();
|
||||
for (int64_t i = 0; i < iteration_count; ++i) {
|
||||
|
@ -99,6 +99,16 @@ namespace SubgraphTestsDefinitions {
|
||||
function = std::make_shared<ngraph::Function>(sigm, input, "concat_quant_during_memory_requant_nomemory");
|
||||
}
|
||||
|
||||
void ConcatQuantDuringMemoryRequantTest::LoadNetwork() {
|
||||
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void ConcatQuantDuringMemoryRequantTest::Infer() {
|
||||
ConfigureInferRequest();
|
||||
inferRequest.Infer();
|
||||
}
|
||||
|
||||
void ConcatQuantDuringMemoryRequantTest::Run() {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
|
@ -19,6 +19,17 @@ namespace SubgraphTestsDefinitions {
|
||||
}
|
||||
}
|
||||
|
||||
void DelayedCopyTestBase::LoadNetwork() {
|
||||
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void DelayedCopyTestBase::Infer() {
|
||||
ConfigureInferRequest();
|
||||
inferRequest.Infer();
|
||||
}
|
||||
|
||||
|
||||
void DelayedCopyTestBase::Run() {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
|
@ -254,6 +254,16 @@ namespace SubgraphTestsDefinitions {
|
||||
function = std::make_shared<Function>(final_reshape, input_parameter, "PureTI");
|
||||
}
|
||||
|
||||
void MemoryLSTMCellTest::LoadNetwork() {
|
||||
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void MemoryLSTMCellTest::Infer() {
|
||||
ConfigureInferRequest();
|
||||
inferRequest.Infer();
|
||||
}
|
||||
|
||||
void MemoryLSTMCellTest::Run() {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
if (transformation != ngraph::helpers::MemoryTransformation::NONE) {
|
||||
|
@ -112,6 +112,16 @@ void MemoryEltwiseReshapeConcatTest::initNgraphFriendlyModel() {
|
||||
function = std::make_shared<ngraph::Function>(concat, input_parameter, "memory_multiply_reshape_concat");
|
||||
}
|
||||
|
||||
void MemoryEltwiseReshapeConcatTest::LoadNetwork() {
|
||||
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void MemoryEltwiseReshapeConcatTest::Infer() {
|
||||
ConfigureInferRequest();
|
||||
inferRequest.Infer();
|
||||
}
|
||||
|
||||
void MemoryEltwiseReshapeConcatTest::Run() {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
initTestModel();
|
||||
|
@ -68,6 +68,16 @@ namespace SubgraphTestsDefinitions {
|
||||
function = std::make_shared<ngraph::Function>(sigm, input, "negative_memory_layer_offset_nonmemory");
|
||||
}
|
||||
|
||||
void NegativeMemoryOffsetTest::LoadNetwork() {
|
||||
LayerTestsUtils::LayerTestsCommon::LoadNetwork();
|
||||
inferRequest = executableNetwork.CreateInferRequest();
|
||||
}
|
||||
|
||||
void NegativeMemoryOffsetTest::Infer() {
|
||||
ConfigureInferRequest();
|
||||
inferRequest.Infer();
|
||||
}
|
||||
|
||||
void NegativeMemoryOffsetTest::Run() {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user