SplitSqueezeConcatFusion - handle Squeeze nodes without second input (#19512)

Ticket: CVS-119330
This commit is contained in:
Mateusz Tabaka 2023-09-01 21:35:17 +02:00 committed by GitHub
parent a4b0fe51af
commit 441adcc122
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 73 additions and 6 deletions

View File

@ -45,14 +45,28 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
nodes_to_delete.push_back(squeeze);
auto split_to_check =
std::dynamic_pointer_cast<ov::op::v1::Split>(squeeze->input_value(0).get_node_shared_ptr());
auto squeeze_axes =
std::dynamic_pointer_cast<ov::op::v0::Constant>(squeeze->input_value(1).get_node_shared_ptr());
if (!squeeze_axes || !split_to_check)
auto split_to_check = std::dynamic_pointer_cast<ov::op::v1::Split>(squeeze->get_input_node_shared_ptr(0));
if (!split_to_check)
return false;
std::vector<int64_t> squeeze_axes_vec;
if (squeeze->get_input_size() < 2) {
const auto& shape = squeeze->get_input_partial_shape(0);
if (shape.is_dynamic()) {
return false;
}
for (size_t i = 0; i < shape.size(); i++) {
if (shape[i].get_length() == 1)
squeeze_axes_vec.push_back(static_cast<int64_t>(i));
}
} else {
auto squeeze_axes =
std::dynamic_pointer_cast<ov::op::v0::Constant>(squeeze->get_input_node_shared_ptr(1));
if (!squeeze_axes)
return false;
squeeze_axes_vec = squeeze_axes->cast_vector<int64_t>();
}
auto squeeze_axes_vec = squeeze_axes->cast_vector<int64_t>();
if (squeeze_axes_vec.size() != 1)
return false;

View File

@ -52,6 +52,38 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusion) {
}
}
TEST_F(TransformationTestsF, SplitSqueezeConcatFusionSqueezeWithoutAxesInput) {
size_t num_splits = 4;
{
auto input = std::make_shared<opset7::Parameter>(element::f32, Shape{3, 2, num_splits, 640, 20, 2});
auto split_axis = opset7::Constant::create(element::i64, Shape{}, {2});
auto split = std::make_shared<opset7::Split>(input, split_axis, num_splits);
OutputVector squeeze_vec(num_splits);
for (size_t i = 0; i < squeeze_vec.size(); i++) {
squeeze_vec[i] = std::make_shared<opset7::Squeeze>(split->output(i));
}
auto concat = std::make_shared<opset7::Concat>(squeeze_vec, 4);
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
}
{
auto input = std::make_shared<opset7::Parameter>(element::f32, Shape{3, 2, num_splits, 640, 20, 2});
auto transpose_order = opset7::Constant::create(element::i64, Shape{6}, {0, 1, 3, 4, 2, 5});
auto transpose = std::make_shared<opset7::Transpose>(input, transpose_order);
auto reshape_shape =
opset7::Constant::create<int64_t>(element::i64, Shape{5}, {3, 2, 640, 20, 2 * (int64_t)num_splits});
auto reshape = std::make_shared<opset7::Reshape>(transpose, reshape_shape, false);
model_ref = std::make_shared<ov::Model>(NodeVector{reshape}, ParameterVector{input});
}
}
TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseNotAllSplitOutputsGoToSqueeze) {
size_t num_splits = 4;
@ -172,3 +204,24 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseSplitAxisDiffer
model_ref = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
}
}
TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeSqueezeWithoutAxesInputMultipleUnitDimensions) {
size_t num_splits = 4;
{
auto input = std::make_shared<opset7::Parameter>(element::f32, Shape{1, 2, num_splits, 640, 20, 2});
auto split_axis = opset7::Constant::create(element::i64, Shape{}, {2});
auto split = std::make_shared<opset7::Split>(input, split_axis, num_splits);
OutputVector squeeze_vec(num_splits);
for (size_t i = 0; i < squeeze_vec.size(); i++) {
squeeze_vec[i] = std::make_shared<opset7::Squeeze>(split->output(i));
}
auto concat = std::make_shared<opset7::Concat>(squeeze_vec, 3);
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
}
}