[GNA] add support for NCHW & NHWC layouts for exporting output (#2031)
* [GNA] add support for NCHW & NHWC ExportScores * fix cpplint
This commit is contained in:
@@ -86,10 +86,53 @@ class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<
|
||||
}
|
||||
};
|
||||
|
||||
class RemovePermutationsNHWCToNCHWPass4DOutputTest : public testing::WithParamInterface<removePermutationsPassParams>,
|
||||
public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<removePermutationsPassParams> obj) {
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::string targetDevice;
|
||||
std::map<std::string, std::string> configuration;
|
||||
std::tie(netPrecision, targetDevice, configuration) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "targetDevice=" << targetDevice << "_";
|
||||
for (auto const& configItem : configuration) {
|
||||
result << "_configItem=" << configItem.first << "_" << configItem.second;
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(netPrecision, targetDevice, configuration) = this->GetParam();
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
|
||||
auto params = ngraph::builder::makeParams(ngPrc, { {1, 1, 168, 2} });
|
||||
auto permute1 = std::make_shared<ngraph::opset1::Transpose>(params[0],
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 3, 1, 2 }));
|
||||
|
||||
auto conv1 = ngraph::builder::makeConvolution(permute1, ngPrc, { 1, 8 }, { 1, 1 }, { 0, 0 }, { 0, 0 }, { 1, 1 }, ngraph::op::PadType::VALID, 12);
|
||||
|
||||
auto permute2 = std::make_shared<ngraph::opset1::Transpose>(conv1,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 4 }, { 0, 2, 3, 1 }));
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(permute2) };
|
||||
|
||||
function = std::make_shared<ngraph::Function>(results, params, "RemovePermutationPass4DOutput");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(RemovePermutationsNHWCToNCHWPassTest, CompareWithRefImpl) {
|
||||
Run();
|
||||
};
|
||||
|
||||
TEST_P(RemovePermutationsNHWCToNCHWPass4DOutputTest, CompareWithRefImpl) {
|
||||
Run();
|
||||
};
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
@@ -109,5 +152,12 @@ class RemovePermutationsNHWCToNCHWPassTest : public testing::WithParamInterface<
|
||||
::testing::ValuesIn(configs)),
|
||||
RemovePermutationsNHWCToNCHWPassTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(PermutationPass, RemovePermutationsNHWCToNCHWPass4DOutputTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(configs)),
|
||||
RemovePermutationsNHWCToNCHWPass4DOutputTest::getTestCaseName);
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
||||
|
||||
Reference in New Issue
Block a user