[CPU] Convolution shape inference fix (#14548)
This commit is contained in:
parent
95cf0ca7ba
commit
6091b425af
@ -1504,7 +1504,7 @@ void Convolution::executeDynamicImpl(dnnl::stream strm) {
|
||||
auto out = subgraph->getOutput(0);
|
||||
const auto& outMem = out->getParentEdgesAtPort(0).front()->getMemory();
|
||||
auto convOutMem = getChildEdgesAtPort(0).front()->getMemoryPtr();
|
||||
convOutMem->redefineDesc(getBaseMemDescAtOutputPort(0)->cloneWithNewDims(outMem.getStaticDims()));
|
||||
Node::redefineOutputMemory({outMem.getStaticDims()});
|
||||
convOutMem->SetData(outMem);
|
||||
}
|
||||
}
|
||||
|
@ -219,6 +219,25 @@ TEST_P(ConcatConvSumInPlaceTestInt8, CompareWithRefs) {
|
||||
CheckPluginRelatedResults(compiledModel, "Convolution");
|
||||
}
|
||||
|
||||
class ConcatConvSumInPlaceTestSeveralConsumers : public ConcatConvSumInPlaceTest {
|
||||
public:
|
||||
std::shared_ptr<ngraph::Node> addSum(std::shared_ptr<ngraph::Node> lastNode, const ngraph::ParameterVector& inputParams) override {
|
||||
auto sum = std::make_shared<ngraph::opset3::Add>(lastNode, inputParams[1]);
|
||||
fusedOps.insert(fusedOps.begin(), "Add");
|
||||
|
||||
auto shapeOf = std::make_shared<ngraph::opset3::ShapeOf>(sum);
|
||||
return std::make_shared<ngraph::opset3::Reshape>(sum, shapeOf, true);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ConcatConvSumInPlaceTestSeveralConsumers, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
run();
|
||||
|
||||
CheckPluginRelatedResults(compiledModel, "Convolution");
|
||||
}
|
||||
|
||||
namespace {
|
||||
const auto fusingMulAddFQMullAdd = fusingSpecificParams{ std::make_shared<postNodesMgr>(std::vector<postNodeBuilder>{
|
||||
{[](postNodeConfig& cfg) {
|
||||
@ -368,5 +387,14 @@ INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_INT8, ConcatConvSumInPlaceTest
|
||||
::testing::Values(cpuEmptyPluginConfig)),
|
||||
ConcatConvSumInPlaceTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_Several_Consumers, ConcatConvSumInPlaceTestSeveralConsumers,
|
||||
::testing::Combine(
|
||||
::testing::Values(convInpShape),
|
||||
::testing::ValuesIn(secondInp),
|
||||
::testing::Values(true),
|
||||
::testing::Values(emptyFusingSpec),
|
||||
::testing::Values(cpuEmptyPluginConfig)),
|
||||
ConcatConvSumInPlaceTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
|
Loading…
Reference in New Issue
Block a user