[CPU] Fixes for static postops after dynamic convolutions (#13260)

This commit is contained in:
Vladislav Golubev 2022-09-29 15:42:05 +02:00 committed by GitHub
parent 2958756a39
commit 1a51d1cac0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 6 deletions

View File

@ -1493,7 +1493,7 @@ void Convolution::redefineOutputMemory(const std::vector<VectorDims> &newOutputS
MemoryDescPtr Convolution::getSumMemDesc(primitive_desc_iterator &primitive_desc_it) { MemoryDescPtr Convolution::getSumMemDesc(primitive_desc_iterator &primitive_desc_it) {
if (getOutputShapeAtPort(0).isDynamic()) { if (getOutputShapeAtPort(0).isDynamic()) {
return DnnlExtensionUtils::makeUndefinedDesc(primitive_desc_it.dst_desc(0), getInputShapeAtPort(getParentEdges().size() - 1)); return DnnlExtensionUtils::makeUndefinedDesc(primitive_desc_it.dst_desc(0), getOutputShapeAtPort(0));
} }
return DnnlExtensionUtils::makeDescriptor(primitive_desc_it.dst_desc(0)); return DnnlExtensionUtils::makeDescriptor(primitive_desc_it.dst_desc(0));
} }

View File

@ -128,7 +128,7 @@ void Reorder::initSupportedPrimitiveDescriptors() {
} }
void Reorder::createPrimitive() { void Reorder::createPrimitive() {
if (inputShapesDefined()) { if (shapesDefined()) {
if (needPrepareParams()) if (needPrepareParams())
prepareParams(); prepareParams();
updateLastInputDims(); updateLastInputDims();

View File

@ -321,7 +321,8 @@ InputShape convInpShape = {
} }
}; };
InputShape secondInp = { const std::vector<InputShape> secondInp = {
{
//dynamic shapes //dynamic shapes
{-1, -1, -1, -1}, {-1, -1, -1, -1},
{ //target static shapes { //target static shapes
@ -331,12 +332,19 @@ InputShape secondInp = {
{1, 64, 8, 8}, {1, 64, 8, 8},
{1, 64, 8, 1} {1, 64, 8, 1}
} }
},
{
{1, 64, 8, 8},
{
{1, 64, 8, 8}
}
},
}; };
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_FP32, ConcatConvSumInPlaceTest, INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_FP32, ConcatConvSumInPlaceTest,
::testing::Combine( ::testing::Combine(
::testing::Values(convInpShape), ::testing::Values(convInpShape),
::testing::Values(secondInp), ::testing::ValuesIn(secondInp),
::testing::Values(true, false), ::testing::Values(true, false),
::testing::ValuesIn(fusingParamsSet), ::testing::ValuesIn(fusingParamsSet),
::testing::Values(cpuEmptyPluginConfig)), ::testing::Values(cpuEmptyPluginConfig)),
@ -345,7 +353,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_FP32, ConcatConvSumInPlaceTest
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_BF16, ConcatConvSumInPlaceTest, INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_BF16, ConcatConvSumInPlaceTest,
::testing::Combine( ::testing::Combine(
::testing::Values(convInpShape), ::testing::Values(convInpShape),
::testing::Values(secondInp), ::testing::ValuesIn(secondInp),
::testing::Values(true, false), ::testing::Values(true, false),
::testing::ValuesIn(fusingParamsSetBF16), ::testing::ValuesIn(fusingParamsSetBF16),
::testing::Values(cpuBF16PluginConfig)), ::testing::Values(cpuBF16PluginConfig)),
@ -354,7 +362,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_BF16, ConcatConvSumInPlaceTest
INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_INT8, ConcatConvSumInPlaceTestInt8, INSTANTIATE_TEST_SUITE_P(smoke_Conv_Sum_Broadcast_INT8, ConcatConvSumInPlaceTestInt8,
::testing::Combine( ::testing::Combine(
::testing::Values(convInpShape), ::testing::Values(convInpShape),
::testing::Values(secondInp), ::testing::ValuesIn(secondInp),
::testing::Values(true, false), ::testing::Values(true, false),
::testing::ValuesIn(fusingParamsSet), ::testing::ValuesIn(fusingParamsSet),
::testing::Values(cpuEmptyPluginConfig)), ::testing::Values(cpuEmptyPluginConfig)),