[CPU] Fixes for static postops after dynamic convolutions (#13260)

This commit is contained in:
Vladislav Golubev 2022-09-29 15:42:05 +02:00 committed by GitHub
parent 2958756a39
commit 1a51d1cac0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 6 deletions

View File

@ -1493,7 +1493,7 @@ void Convolution::redefineOutputMemory(const std::vector<VectorDims> &newOutputS
MemoryDescPtr Convolution::getSumMemDesc(primitive_desc_iterator &primitive_desc_it) {
if (getOutputShapeAtPort(0).isDynamic()) {
return DnnlExtensionUtils::makeUndefinedDesc(primitive_desc_it.dst_desc(0), getInputShapeAtPort(getParentEdges().size() - 1));
return DnnlExtensionUtils::makeUndefinedDesc(primitive_desc_it.dst_desc(0), getOutputShapeAtPort(0));
}
return DnnlExtensionUtils::makeDescriptor(primitive_desc_it.dst_desc(0));
}

View File

@ -128,7 +128,7 @@ void Reorder::initSupportedPrimitiveDescriptors() {
}
void Reorder::createPrimitive() {
if (inputShapesDefined()) {
if (shapesDefined()) {
if (needPrepareParams())
prepareParams();
updateLastInputDims();

View File

@ -321,7 +321,8 @@ InputShape convInpShape = {
}
};
InputShape secondInp = {
const std::vector<InputShape> secondInp = {
{
//dynamic shapes
{-1, -1, -1, -1},
{ //target static shapes
@ -331,12 +332,19 @@ InputShape secondInp = {
{1, 64, 8, 8},
{1, 64, 8, 1}
}
},
{
{1, 64, 8, 8},
{
{1, 64, 8, 8}
}
},
};
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_FP32, ConcatConvSumInPlaceTest,
::testing::Combine(
::testing::Values(convInpShape),
::testing::Values(secondInp),
::testing::ValuesIn(secondInp),
::testing::Values(true, false),
::testing::ValuesIn(fusingParamsSet),
::testing::Values(cpuEmptyPluginConfig)),
@ -345,7 +353,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_FP32, ConcatConvSumInPlaceTest
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_BF16, ConcatConvSumInPlaceTest,
::testing::Combine(
::testing::Values(convInpShape),
::testing::Values(secondInp),
::testing::ValuesIn(secondInp),
::testing::Values(true, false),
::testing::ValuesIn(fusingParamsSetBF16),
::testing::Values(cpuBF16PluginConfig)),
@ -354,7 +362,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_BF16, ConcatConvSumInPlaceTest
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_INT8, ConcatConvSumInPlaceTestInt8,
::testing::Combine(
::testing::Values(convInpShape),
::testing::Values(secondInp),
::testing::ValuesIn(secondInp),
::testing::Values(true, false),
::testing::ValuesIn(fusingParamsSet),
::testing::Values(cpuEmptyPluginConfig)),