[CPU] Convolution shape inference fix (#14548)

This commit is contained in:
Vladislav Golubev 2022-12-12 10:59:37 +01:00 committed by GitHub
parent 95cf0ca7ba
commit 6091b425af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 1 deletions

View File

@ -1504,7 +1504,7 @@ void Convolution::executeDynamicImpl(dnnl::stream strm) {
auto out = subgraph->getOutput(0);
const auto& outMem = out->getParentEdgesAtPort(0).front()->getMemory();
auto convOutMem = getChildEdgesAtPort(0).front()->getMemoryPtr();
convOutMem->redefineDesc(getBaseMemDescAtOutputPort(0)->cloneWithNewDims(outMem.getStaticDims()));
Node::redefineOutputMemory({outMem.getStaticDims()});
convOutMem->SetData(outMem);
}
}

View File

@ -219,6 +219,25 @@ TEST_P(ConcatConvSumInPlaceTestInt8, CompareWithRefs) {
CheckPluginRelatedResults(compiledModel, "Convolution");
}
class ConcatConvSumInPlaceTestSeveralConsumers : public ConcatConvSumInPlaceTest {
public:
std::shared_ptr<ngraph::Node> addSum(std::shared_ptr<ngraph::Node> lastNode, const ngraph::ParameterVector& inputParams) override {
auto sum = std::make_shared<ngraph::opset3::Add>(lastNode, inputParams[1]);
fusedOps.insert(fusedOps.begin(), "Add");
auto shapeOf = std::make_shared<ngraph::opset3::ShapeOf>(sum);
return std::make_shared<ngraph::opset3::Reshape>(sum, shapeOf, true);
}
};
TEST_P(ConcatConvSumInPlaceTestSeveralConsumers, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
CheckPluginRelatedResults(compiledModel, "Convolution");
}
namespace {
const auto fusingMulAddFQMullAdd = fusingSpecificParams{ std::make_shared<postNodesMgr>(std::vector<postNodeBuilder>{
{[](postNodeConfig& cfg) {
@ -368,5 +387,14 @@ INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_INT8, ConcatConvSumInPlaceTest
::testing::Values(cpuEmptyPluginConfig)),
ConcatConvSumInPlaceTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_Several_Consumers, ConcatConvSumInPlaceTestSeveralConsumers,
::testing::Combine(
::testing::Values(convInpShape),
::testing::ValuesIn(secondInp),
::testing::Values(true),
::testing::Values(emptyFusingSpec),
::testing::Values(cpuEmptyPluginConfig)),
ConcatConvSumInPlaceTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions