While working on applying multiple target shapes to templateFuncTests, still failed
This commit is contained in:
parent
e73a9741de
commit
0870cc6cc1
@ -84,7 +84,9 @@ INSTANTIATE_TEST_SUITE_P(Convolution2D_ExplicitPaddingDynamicShape, ConvolutionL
|
||||
::testing::Values(std::vector<std::vector<std::vector<size_t>>>({{{1, 3, 30, 30}}}),
|
||||
std::vector<std::vector<std::vector<size_t>>>({{{2, 4, 31, 31}}}),
|
||||
std::vector<std::vector<std::vector<size_t>>>({{{1, 3, 30, 30}},
|
||||
{{2, 4, 31, 31}}})),
|
||||
{{2, 4, 31, 31}}}),
|
||||
std::vector<std::vector<std::vector<size_t>>>({{{2, 4, 31, 31}},
|
||||
{{1, 3, 30, 30}}})),
|
||||
::testing::Values(CommonTestUtils::DEVICE_TEMPLATE)),
|
||||
ConvolutionLayerTest::getTestCaseName);
|
||||
// ! [test_convolution:instantiate]
|
||||
|
@ -153,6 +153,8 @@ protected:
|
||||
|
||||
virtual void setTargetStaticShape(std::vector<ngraph::Shape>& desiredTargetStaticShape) {}
|
||||
|
||||
virtual bool updateFunctionRefs() {return false;}
|
||||
|
||||
virtual void Validate();
|
||||
|
||||
virtual std::vector<std::pair<ngraph::element::Type, std::vector<std::uint8_t>>> CalculateRefs();
|
||||
|
@ -46,6 +46,7 @@ protected:
|
||||
void SetUp() override;
|
||||
std::shared_ptr<ngraph::Function> makeConvolution(const std::string& name = "");
|
||||
void setTargetStaticShape(std::vector<ngraph::Shape>& desiredTargetStaticShape) override;
|
||||
bool updateFunctionRefs() override;
|
||||
|
||||
private:
|
||||
InferenceEngine::Precision::ePrecision netPrecision = InferenceEngine::Precision::UNSPECIFIED;
|
||||
|
@ -49,6 +49,7 @@ void LayerTestsCommon::Run() {
|
||||
LoadNetwork();
|
||||
for (auto&& tss : targetStaticShapes) {
|
||||
setTargetStaticShape(tss);
|
||||
updateFunctionRefs();
|
||||
GenerateInputs();
|
||||
Infer();
|
||||
Validate();
|
||||
|
@ -88,4 +88,18 @@ void ConvolutionLayerTest::setTargetStaticShape(std::vector<ngraph::Shape>& desi
|
||||
targetStaticShape = desiredTargetStaticShape;
|
||||
}
|
||||
|
||||
bool ConvolutionLayerTest::updateFunctionRefs() {
|
||||
auto params = functionRefs->get_parameters()[0];
|
||||
if (!params) {
|
||||
return false;
|
||||
}
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto params_new = std::make_shared<ngraph::opset1::Parameter>(ngPrc, targetStaticShape.front());
|
||||
params_new->set_friendly_name(params->get_friendly_name());
|
||||
ngraph::copy_runtime_info(params, params_new);
|
||||
ngraph::replace_node(params, params_new);
|
||||
functionRefs->validate_nodes_and_infer_types();
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
Loading…
Reference in New Issue
Block a user