ConvolutionLayerTest for dynamic shape case (Test only)

This commit is contained in:
Steve Yoo 2021-09-13 15:30:27 +09:00
parent d8eaf21acd
commit d11e3e917e
3 changed files with 47 additions and 4 deletions

View File

@ -80,7 +80,8 @@ INSTANTIATE_TEST_SUITE_P(Convolution2D_ExplicitPadding, ConvolutionLayerTest,
::testing::Values(InferenceEngine::Layout::ANY), ::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(InferenceEngine::Layout::ANY), ::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(std::vector<std::pair<size_t, size_t>>({{1, 10}, {3, 30}, {30, 300}, {30, 300}})), ::testing::Values(std::vector<std::pair<size_t, size_t>>({{1, 10}, {3, 30}, {30, 300}, {30, 300}})),
::testing::Values(std::vector<std::vector<size_t>>({{1, 3, 30, 30}, {2, 4, 31, 31}})), // ::testing::Values(std::vector<std::vector<size_t>>({{1, 3, 30, 30}, {2, 4, 31, 31}})),
::testing::Values(std::vector<std::vector<size_t>>({{2, 4, 31, 31}})),
::testing::Values(CommonTestUtils::DEVICE_TEMPLATE)), ::testing::Values(CommonTestUtils::DEVICE_TEMPLATE)),
ConvolutionLayerTest::getTestCaseName); ConvolutionLayerTest::getTestCaseName);
// ! [test_convolution:instantiate] // ! [test_convolution:instantiate]

View File

@ -45,6 +45,7 @@ public:
protected: protected:
void SetUp() override; void SetUp() override;
std::shared_ptr<ngraph::Function> makeConvolution(const std::string& name = ""); std::shared_ptr<ngraph::Function> makeConvolution(const std::string& name = "");
void Run() override;
private: private:
InferenceEngine::Precision::ePrecision netPrecision = InferenceEngine::Precision::UNSPECIFIED; InferenceEngine::Precision::ePrecision netPrecision = InferenceEngine::Precision::UNSPECIFIED;

View File

@ -57,8 +57,8 @@ void ConvolutionLayerTest::SetUp() {
std::tie(kernel, stride, padBegin, padEnd, dilation, convOutChannels, padType) = convParams; std::tie(kernel, stride, padBegin, padEnd, dilation, convOutChannels, padType) = convParams;
setTargetStaticShape(targetStaticShapes[0]); setTargetStaticShape(targetStaticShapes[0]);
function = makeConvolution(); function = makeConvolution("convolution");
functionRefs = makeConvolution(); functionRefs = makeConvolution("convolutionRefs");
} }
std::shared_ptr<ngraph::Function> ConvolutionLayerTest::makeConvolution(const std::string& name) { std::shared_ptr<ngraph::Function> ConvolutionLayerTest::makeConvolution(const std::string& name) {
@ -76,7 +76,48 @@ std::shared_ptr<ngraph::Function> ConvolutionLayerTest::makeConvolution(const st
ngraph::builder::makeConvolution(paramOuts[0], ngPrc, kernel, stride, padBegin, ngraph::builder::makeConvolution(paramOuts[0], ngPrc, kernel, stride, padBegin,
padEnd, dilation, padType, convOutChannels, false, filter_weights)); padEnd, dilation, padType, convOutChannels, false, filter_weights));
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(conv)}; ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(conv)};
return std::make_shared<ngraph::Function>(results, params, "convolution"); return std::make_shared<ngraph::Function>(results, params, name);
}
void ConvolutionLayerTest::Run() {
auto crashHandler = [](int errCode) {
auto &s = LayerTestsUtils::Summary::getInstance();
s.saveReport();
std::cout << "Unexpected application crash!" << std::endl;
std::abort();
};
signal(SIGSEGV, crashHandler);
auto &s = LayerTestsUtils::Summary::getInstance();
s.setDeviceName(targetDevice);
if (FuncTestUtils::SkipTestsConfig::currentTestIsDisabled()) {
s.updateOPsStats(function, LayerTestsUtils::PassRate::Statuses::SKIPPED);
GTEST_SKIP() << "Disabled test due to configuration" << std::endl;
} else {
s.updateOPsStats(function, LayerTestsUtils::PassRate::Statuses::CRASHED);
}
try {
LoadNetwork();
for (auto&& tss : targetStaticShapes) {
setTargetStaticShape(tss);
GenerateInputs();
Infer();
Validate();
s.updateOPsStats(function, LayerTestsUtils::PassRate::Statuses::PASSED);
}
}
catch (const std::runtime_error &re) {
s.updateOPsStats(function, LayerTestsUtils::PassRate::Statuses::FAILED);
GTEST_FATAL_FAILURE_(re.what());
} catch (const std::exception &ex) {
s.updateOPsStats(function, LayerTestsUtils::PassRate::Statuses::FAILED);
GTEST_FATAL_FAILURE_(ex.what());
} catch (...) {
s.updateOPsStats(function, LayerTestsUtils::PassRate::Statuses::FAILED);
GTEST_FATAL_FAILURE_("Unknown failure occurred.");
}
} }
} // namespace LayerTestsDefinitions } // namespace LayerTestsDefinitions