While working on applying multiple target shapes to templateFuncTests, still failed

This commit is contained in:
Steve Yoo 2021-09-27 20:23:27 +09:00
parent e73a9741de
commit 0870cc6cc1
5 changed files with 21 additions and 1 deletions

View File

@ -84,7 +84,9 @@ INSTANTIATE_TEST_SUITE_P(Convolution2D_ExplicitPaddingDynamicShape, ConvolutionL
::testing::Values(std::vector<std::vector<std::vector<size_t>>>({{{1, 3, 30, 30}}}), ::testing::Values(std::vector<std::vector<std::vector<size_t>>>({{{1, 3, 30, 30}}}),
std::vector<std::vector<std::vector<size_t>>>({{{2, 4, 31, 31}}}), std::vector<std::vector<std::vector<size_t>>>({{{2, 4, 31, 31}}}),
std::vector<std::vector<std::vector<size_t>>>({{{1, 3, 30, 30}}, std::vector<std::vector<std::vector<size_t>>>({{{1, 3, 30, 30}},
{{2, 4, 31, 31}}})), {{2, 4, 31, 31}}}),
std::vector<std::vector<std::vector<size_t>>>({{{2, 4, 31, 31}},
{{1, 3, 30, 30}}})),
::testing::Values(CommonTestUtils::DEVICE_TEMPLATE)), ::testing::Values(CommonTestUtils::DEVICE_TEMPLATE)),
ConvolutionLayerTest::getTestCaseName); ConvolutionLayerTest::getTestCaseName);
// ! [test_convolution:instantiate] // ! [test_convolution:instantiate]

View File

@ -153,6 +153,8 @@ protected:
virtual void setTargetStaticShape(std::vector<ngraph::Shape>& desiredTargetStaticShape) {} virtual void setTargetStaticShape(std::vector<ngraph::Shape>& desiredTargetStaticShape) {}
virtual bool updateFunctionRefs() {return false;}
virtual void Validate(); virtual void Validate();
virtual std::vector<std::pair<ngraph::element::Type, std::vector<std::uint8_t>>> CalculateRefs(); virtual std::vector<std::pair<ngraph::element::Type, std::vector<std::uint8_t>>> CalculateRefs();

View File

@ -46,6 +46,7 @@ protected:
void SetUp() override; void SetUp() override;
std::shared_ptr<ngraph::Function> makeConvolution(const std::string& name = ""); std::shared_ptr<ngraph::Function> makeConvolution(const std::string& name = "");
void setTargetStaticShape(std::vector<ngraph::Shape>& desiredTargetStaticShape) override; void setTargetStaticShape(std::vector<ngraph::Shape>& desiredTargetStaticShape) override;
bool updateFunctionRefs() override;
private: private:
InferenceEngine::Precision::ePrecision netPrecision = InferenceEngine::Precision::UNSPECIFIED; InferenceEngine::Precision::ePrecision netPrecision = InferenceEngine::Precision::UNSPECIFIED;

View File

@ -49,6 +49,7 @@ void LayerTestsCommon::Run() {
LoadNetwork(); LoadNetwork();
for (auto&& tss : targetStaticShapes) { for (auto&& tss : targetStaticShapes) {
setTargetStaticShape(tss); setTargetStaticShape(tss);
updateFunctionRefs();
GenerateInputs(); GenerateInputs();
Infer(); Infer();
Validate(); Validate();

View File

@ -88,4 +88,18 @@ void ConvolutionLayerTest::setTargetStaticShape(std::vector<ngraph::Shape>& desi
targetStaticShape = desiredTargetStaticShape; targetStaticShape = desiredTargetStaticShape;
} }
bool ConvolutionLayerTest::updateFunctionRefs() {
auto params = functionRefs->get_parameters()[0];
if (!params) {
return false;
}
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params_new = std::make_shared<ngraph::opset1::Parameter>(ngPrc, targetStaticShape.front());
params_new->set_friendly_name(params->get_friendly_name());
ngraph::copy_runtime_info(params, params_new);
ngraph::replace_node(params, params_new);
functionRefs->validate_nodes_and_infer_types();
return true;
}
} // namespace LayerTestsDefinitions } // namespace LayerTestsDefinitions