[GNA] add support for NCHW & NHWC layouts for exporting output (#2031)

* [GNA] add support for NCHW & NHWC ExportScores

* fix cpplint
This commit is contained in:
Anna Alberska
2020-09-08 09:57:44 +02:00
committed by GitHub
parent 8e6d9470bb
commit 6357ce83c5
2 changed files with 105 additions and 43 deletions

View File

@@ -86,10 +86,53 @@ class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<
}
};
class RemovePermutationsNHWCToNCHWPass4DOutputTest : public testing::WithParamInterface<removePermutationsPassParams>,
public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<removePermutationsPassParams> obj) {
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::map<std::string, std::string> configuration;
std::tie(netPrecision, targetDevice, configuration) = obj.param;
std::ostringstream result;
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
for (auto const& configItem : configuration) {
result << "_configItem=" << configItem.first << "_" << configItem.second;
}
return result.str();
}
protected:
void SetUp() override {
InferenceEngine::Precision netPrecision;
std::tie(netPrecision, targetDevice, configuration) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, { {1, 1, 168, 2} });
auto permute1 = std::make_shared<ngraph::opset1::Transpose>(params[0],
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 3, 1, 2 }));
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, { 1, 8 }, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 }, ngraph::op::PadType::VALID, 12);
auto permute2 = std::make_shared<ngraph::opset1::Transpose>(conv1,
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 2, 3, 1 }));
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(permute2) };
function = std::make_shared<ngraph::Function>(results, params, "RemovePermutationPass4DOutput");
}
};
TEST_P(RemovePermutationsNHWCToNCHWPassTest, CompareWithRefImpl) {
Run();
};
TEST_P(RemovePermutationsNHWCToNCHWPass4DOutputTest, CompareWithRefImpl) {
Run();
};
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
@@ -109,5 +152,12 @@ class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<
::testing::ValuesIn(configs)),
RemovePermutationsNHWCToNCHWPassTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(PermutationPass, RemovePermutationsNHWCToNCHWPass4DOutputTest,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs)),
RemovePermutationsNHWCToNCHWPass4DOutputTest::getTestCaseName);
} // namespace LayerTestsDefinitions