SplitSqueezeConcatFusion - handle Squeeze nodes without second input (#19512)
Ticket: CVS-119330
This commit is contained in:
parent
a4b0fe51af
commit
441adcc122
@ -45,14 +45,28 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
|
||||
|
||||
nodes_to_delete.push_back(squeeze);
|
||||
|
||||
auto split_to_check =
|
||||
std::dynamic_pointer_cast<ov::op::v1::Split>(squeeze->input_value(0).get_node_shared_ptr());
|
||||
auto squeeze_axes =
|
||||
std::dynamic_pointer_cast<ov::op::v0::Constant>(squeeze->input_value(1).get_node_shared_ptr());
|
||||
if (!squeeze_axes || !split_to_check)
|
||||
auto split_to_check = std::dynamic_pointer_cast<ov::op::v1::Split>(squeeze->get_input_node_shared_ptr(0));
|
||||
if (!split_to_check)
|
||||
return false;
|
||||
std::vector<int64_t> squeeze_axes_vec;
|
||||
if (squeeze->get_input_size() < 2) {
|
||||
const auto& shape = squeeze->get_input_partial_shape(0);
|
||||
if (shape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
if (shape[i].get_length() == 1)
|
||||
squeeze_axes_vec.push_back(static_cast<int64_t>(i));
|
||||
}
|
||||
|
||||
} else {
|
||||
auto squeeze_axes =
|
||||
std::dynamic_pointer_cast<ov::op::v0::Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||
if (!squeeze_axes)
|
||||
return false;
|
||||
squeeze_axes_vec = squeeze_axes->cast_vector<int64_t>();
|
||||
}
|
||||
|
||||
auto squeeze_axes_vec = squeeze_axes->cast_vector<int64_t>();
|
||||
if (squeeze_axes_vec.size() != 1)
|
||||
return false;
|
||||
|
||||
|
@ -52,6 +52,38 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusion) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SplitSqueezeConcatFusionSqueezeWithoutAxesInput) {
|
||||
size_t num_splits = 4;
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset7::Parameter>(element::f32, Shape{3, 2, num_splits, 640, 20, 2});
|
||||
auto split_axis = opset7::Constant::create(element::i64, Shape{}, {2});
|
||||
auto split = std::make_shared<opset7::Split>(input, split_axis, num_splits);
|
||||
OutputVector squeeze_vec(num_splits);
|
||||
|
||||
for (size_t i = 0; i < squeeze_vec.size(); i++) {
|
||||
squeeze_vec[i] = std::make_shared<opset7::Squeeze>(split->output(i));
|
||||
}
|
||||
|
||||
auto concat = std::make_shared<opset7::Concat>(squeeze_vec, 4);
|
||||
|
||||
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset7::Parameter>(element::f32, Shape{3, 2, num_splits, 640, 20, 2});
|
||||
auto transpose_order = opset7::Constant::create(element::i64, Shape{6}, {0, 1, 3, 4, 2, 5});
|
||||
auto transpose = std::make_shared<opset7::Transpose>(input, transpose_order);
|
||||
auto reshape_shape =
|
||||
opset7::Constant::create<int64_t>(element::i64, Shape{5}, {3, 2, 640, 20, 2 * (int64_t)num_splits});
|
||||
auto reshape = std::make_shared<opset7::Reshape>(transpose, reshape_shape, false);
|
||||
|
||||
model_ref = std::make_shared<ov::Model>(NodeVector{reshape}, ParameterVector{input});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseNotAllSplitOutputsGoToSqueeze) {
|
||||
size_t num_splits = 4;
|
||||
|
||||
@ -172,3 +204,24 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseSplitAxisDiffer
|
||||
model_ref = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeSqueezeWithoutAxesInputMultipleUnitDimensions) {
|
||||
size_t num_splits = 4;
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset7::Parameter>(element::f32, Shape{1, 2, num_splits, 640, 20, 2});
|
||||
auto split_axis = opset7::Constant::create(element::i64, Shape{}, {2});
|
||||
auto split = std::make_shared<opset7::Split>(input, split_axis, num_splits);
|
||||
OutputVector squeeze_vec(num_splits);
|
||||
|
||||
for (size_t i = 0; i < squeeze_vec.size(); i++) {
|
||||
squeeze_vec[i] = std::make_shared<opset7::Squeeze>(split->output(i));
|
||||
}
|
||||
|
||||
auto concat = std::make_shared<opset7::Concat>(squeeze_vec, 3);
|
||||
|
||||
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user