[TESTS] Behavior test check input layout propagation (#4239)

This commit is contained in:
Maxim Kurin 2021-02-10 01:12:18 +03:00 committed by GitHub
parent 78c045b7ae
commit 929fa26e2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 14 deletions

View File

@ -12,7 +12,9 @@ namespace {
const std::vector<InferenceEngine::Layout> Layout = {
InferenceEngine::Layout::NCHW,
InferenceEngine::Layout::NHWC,
InferenceEngine::Layout::CHW,
InferenceEngine::Layout::HWC,
InferenceEngine::Layout::NC,
InferenceEngine::Layout::C
};
@ -27,10 +29,9 @@ namespace {
INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, LayoutTest,
::testing::Combine(
::testing::Values(InferenceEngine::Precision::FP32),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CommonTestUtils::DEVICE_CPU, "HETERO:CPU"),
::testing::ValuesIn(configs),
::testing::ValuesIn(Layout),
::testing::ValuesIn(inputShapes)),
LayoutTest::getTestCaseName);
} // namespace
} // namespace

View File

@ -16,20 +16,26 @@ std::string LayoutTest::getTestCaseName(testing::TestParamInfo<LayoutParams> obj
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
result << "layout=" << layout << "_";
result << "inputShapes=";
if (!inputShapes.empty()) {
for (auto &Item : inputShapes) {
result << "inputShapes=" << Item << "_";
result << Item << "x";
}
}
return result.str();
auto str = result.str();
str.pop_back();
return str;
}
void LayoutTest::SetUp() {
std::tie(netPrecision, targetDevice, configuration, layout, inputShapes) = this->GetParam();
function = ngraph::builder::subgraph::make2InputSubtract(inputShapes, FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision));
void LayoutTest::SetUp() {
std::tie(netPrecision, targetDevice, configuration,
layout, inputShapes) = this->GetParam();
function = ngraph::builder::subgraph::make2InputSubtract(
inputShapes, FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision));
}
void LayoutTest::TearDown() {
void LayoutTest::TearDown() {
if ((targetDevice == CommonTestUtils::DEVICE_GPU) || (!configuration.empty())) {
PluginCache::get().reset();
}
@ -61,6 +67,7 @@ inline bool checkLayout(InferenceEngine::Layout layout, std::vector<size_t> &inp
check = 4 == inputShapes.size();
break;
case InferenceEngine::Layout::CHW:
case InferenceEngine::Layout::HWC:
check = 3 == inputShapes.size();
break;
case InferenceEngine::Layout::CN:
@ -81,13 +88,24 @@ TEST_P(LayoutTest, NetWithLayout) {
InferenceEngine::CNNNetwork cnnNet(function);
if (checkLayout(layout, inputShapes)) {
ASSERT_NO_THROW(cnnNet.getInputsInfo().begin()->second->setLayout(layout));
if (targetDevice == CommonTestUtils::DEVICE_GNA) {
return;
}
InferenceEngine::ExecutableNetwork exeNetwork;
ASSERT_NO_THROW(exeNetwork = ie->LoadNetwork(cnnNet, targetDevice, configuration));
InferenceEngine::InferRequest request;
ASSERT_NO_THROW(request = exeNetwork.CreateInferRequest());
InferenceEngine::Blob::Ptr inputBlob;
ASSERT_NO_THROW(inputBlob = request.GetBlob(cnnNet.getInputsInfo().begin()->second->name()));
ASSERT_EQ(inputBlob->getTensorDesc().getLayout(), layout);
} else {
ASSERT_THROW(cnnNet.getInputsInfo().begin()->second->setLayout(layout),
InferenceEngine::details::InferenceEngineException);
}
if (targetDevice != CommonTestUtils::DEVICE_GNA) {
ASSERT_NO_THROW(InferenceEngine::ExecutableNetwork exeNetwork =
ie->LoadNetwork(cnnNet, targetDevice, configuration));
}
}
} // namespace BehaviorTestsDefinitions
} // namespace BehaviorTestsDefinitions