[CPU] Fixed dynamic Concat with negative axis (#9861)

This commit is contained in:
Alexandra Sidorova 2022-01-26 17:10:34 +03:00 committed by GitHub
parent c026f8348b
commit 1d82294e00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 11 deletions

View File

@ -55,13 +55,16 @@ MKLDNNConcatNode::MKLDNNConcatNode(const std::shared_ptr<ngraph::Node>& op, cons
IE_THROW(NotImplemented) << errorMessage;
}
const auto inRank = getInputShapeAtPort(0).getRank();
auto concatOp = ngraph::as_type_ptr<ngraph::op::v0::Concat>(op);
auto axis = concatOp->get_axis();
if (axis < 0) {
this->axis = concatOp->get_input_shape(0).size() + axis;
} else {
this->axis = axis;
axis += inRank;
}
if (axis >= inRank || axis < 0) {
IE_THROW() << "Concat node with name '" << getName() << "' has invalid value of axis parameter: " << axis;
}
this->axis = axis;
}
void MKLDNNConcatNode::getSupportedDescriptors() {

View File

@ -126,7 +126,7 @@ const std::vector<ElementType> netPrecisions = {
INSTANTIATE_TEST_SUITE_P(smoke_Concat4D_CPU_Block8_static, ConcatLayerCPUTest,
::testing::Combine(
::testing::Values(1, 2, 3),
::testing::Values(1, -2, 3),
::testing::Values(static_shapes_to_test_representation({{2, 16, 3, 5}, {2, 16, 3, 5}})),
::testing::ValuesIn(netPrecisions),
::testing::Values(planar_4D_ref, planarChannels_4D, blocked8_4D_ref)),
@ -134,7 +134,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Concat4D_CPU_Block8_static, ConcatLayerCPUTest,
INSTANTIATE_TEST_SUITE_P(smoke_Concat4D_CPU_Block16_static, ConcatLayerCPUTest,
::testing::Combine(
::testing::Values(1, 2, 3),
::testing::Values(1, 2, -1),
::testing::Values(static_shapes_to_test_representation({{3, 32, 3, 5}, {3, 32, 3, 5}})),
::testing::ValuesIn(netPrecisions),
::testing::Values(blocked16_4D_ref)),
@ -156,7 +156,7 @@ const std::vector<std::vector<InputShape>> inputShapes4D_Block_axis1 = {
INSTANTIATE_TEST_SUITE_P(smoke_Concat4D_CPU_Block_dynamic_axis_1, ConcatLayerCPUTest,
::testing::Combine(
::testing::Values(1),
::testing::Values(1, -3),
::testing::ValuesIn(inputShapes4D_Block_axis1),
::testing::ValuesIn(netPrecisions),
::testing::Values(blocked8_4D_ref, blocked16_4D_ref)),
@ -229,7 +229,7 @@ const std::vector<std::vector<InputShape>> inputShapes4D_axis2 = {
INSTANTIATE_TEST_SUITE_P(smoke_Concat4D_CPU_dynamic_axis_2, ConcatLayerCPUTest,
::testing::Combine(
::testing::Values(2),
::testing::Values(2, -2),
::testing::ValuesIn(inputShapes4D_axis2),
::testing::ValuesIn(netPrecisions),
::testing::Values(planar_4D_ref, planarChannels_4D)),
@ -271,7 +271,7 @@ const std::vector<std::vector<InputShape>> inputShapes4D_axis3 = {
INSTANTIATE_TEST_SUITE_P(smoke_Concat4D_CPU_dynamic_axis_3, ConcatLayerCPUTest,
::testing::Combine(
::testing::Values(3),
::testing::Values(3, -1),
::testing::ValuesIn(inputShapes4D_axis3),
::testing::ValuesIn(netPrecisions),
::testing::Values(planar_4D_ref, planarChannels_4D)),
@ -279,7 +279,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Concat4D_CPU_dynamic_axis_3, ConcatLayerCPUTest,
INSTANTIATE_TEST_SUITE_P(smoke_Concat5D_CPU_Block8_static, ConcatLayerCPUTest,
::testing::Combine(
::testing::Values(2, 3, 4),
::testing::Values(2, 3, -2),
::testing::Values(static_shapes_to_test_representation({{2, 16, 3, 5, 7}, {2, 16, 3, 5, 7}})),
::testing::ValuesIn(netPrecisions),
::testing::Values(planar_5D_ref, planarChannels_5D, blocked8_5D_ref)),
@ -350,7 +350,7 @@ const std::vector<std::vector<InputShape>> inputShapes5D_Block_axis2 = {
INSTANTIATE_TEST_SUITE_P(smoke_Concat5D_CPU_Block_dynamic_axis_2, ConcatLayerCPUTest,
::testing::Combine(
::testing::Values(2),
::testing::Values(-3),
::testing::ValuesIn(inputShapes5D_Block_axis2),
::testing::ValuesIn(netPrecisions),
::testing::Values(blocked8_5D_ref, blocked16_5D_ref)),
@ -645,7 +645,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Concat5D_CPU_Block16inPlace, ConcatLayerCPUTest,
INSTANTIATE_TEST_SUITE_P(smoke_Concat_inPlace, ConcatLayerCPUTest,
::testing::Combine(
::testing::Values(0, 1, 2),
::testing::Values(0, 1, 2, -1),
::testing::ValuesIn(std::vector<std::vector<InputShape>>{
static_shapes_to_test_representation({{1, 1, 1, 10}, {1, 1, 1, 10}}),
static_shapes_to_test_representation({{1, 1, 5}, {1, 1, 5}})}),