PullReshapeThroughReduce - skip transformation if Reshape doesn't unsqueeze input (#19477)

Ticket: CVS-118905
This commit is contained in:
Mateusz Tabaka 2023-09-04 13:58:53 +02:00 committed by GitHub
parent c46f6bf115
commit bd0c156a70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 0 deletions

View File

@ -55,6 +55,9 @@ const std::vector<int64_t> adjust_axes(const std::vector<int64_t>& axes_to_align
// - Reshape(input_shape={5,10,15}, target_shape={5,10,1,15}), 2 axis is returned // - Reshape(input_shape={5,10,15}, target_shape={5,10,1,15}), 2 axis is returned
std::vector<int64_t> try_get_unsqueeze_axes_from_reshape(const ov::Shape& target_shape, const ov::Shape& input_shape) { std::vector<int64_t> try_get_unsqueeze_axes_from_reshape(const ov::Shape& target_shape, const ov::Shape& input_shape) {
std::vector<int64_t> result; std::vector<int64_t> result;
if (target_shape.size() <= input_shape.size()) {
return result;
}
if (input_shape.size() == 0) { // scalar case - can be reshaped only to [1,..,1] shape if (input_shape.size() == 0) { // scalar case - can be reshaped only to [1,..,1] shape
result.resize(target_shape.size(), 0); result.resize(target_shape.size(), 0);
std::iota(std::begin(result), std::end(result), 0); std::iota(std::begin(result), std::end(result), 0);

View File

@ -360,6 +360,11 @@ TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfTheSameAxesScalarCase
manager.register_pass<pass::PullReshapeThroughReduce>(); manager.register_pass<pass::PullReshapeThroughReduce>();
} }
TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfReshapeDoesntUnsqueeze) {
model = generate_reshape_model<ReduceMean>(element::f32, {1, 100, 1}, {1, 1, 100}, {2});
manager.register_pass<pass::PullReshapeThroughReduce>();
}
TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfNonConstAxes) { TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfNonConstAxes) {
const auto input = std::make_shared<Parameter>(element::f32, PartialShape{5, 10, 15}); const auto input = std::make_shared<Parameter>(element::f32, PartialShape{5, 10, 15});
const auto target_shape = Constant::create(element::i64, Shape{4}, {1, 5, 10, 15}); const auto target_shape = Constant::create(element::i64, Shape{4}, {1, 5, 10, 15});