PullReshapeThroughReduce - skip transformation if Reshape doesn't unsqueeze input (#19477)
Ticket: CVS-118905
This commit is contained in:
parent
c46f6bf115
commit
bd0c156a70
@ -55,6 +55,9 @@ const std::vector<int64_t> adjust_axes(const std::vector<int64_t>& axes_to_align
|
||||
// - Reshape(input_shape={5,10,15}, target_shape={5,10,1,15}), 2 axis is returned
|
||||
std::vector<int64_t> try_get_unsqueeze_axes_from_reshape(const ov::Shape& target_shape, const ov::Shape& input_shape) {
|
||||
std::vector<int64_t> result;
|
||||
if (target_shape.size() <= input_shape.size()) {
|
||||
return result;
|
||||
}
|
||||
if (input_shape.size() == 0) { // scalar case - can be reshaped only to [1,..,1] shape
|
||||
result.resize(target_shape.size(), 0);
|
||||
std::iota(std::begin(result), std::end(result), 0);
|
||||
|
@ -360,6 +360,11 @@ TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfTheSameAxesScalarCase
|
||||
manager.register_pass<pass::PullReshapeThroughReduce>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfReshapeDoesntUnsqueeze) {
|
||||
model = generate_reshape_model<ReduceMean>(element::f32, {1, 100, 1}, {1, 1, 100}, {2});
|
||||
manager.register_pass<pass::PullReshapeThroughReduce>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfNonConstAxes) {
|
||||
const auto input = std::make_shared<Parameter>(element::f32, PartialShape{5, 10, 15});
|
||||
const auto target_shape = Constant::create(element::i64, Shape{4}, {1, 5, 10, 15});
|
||||
|
Loading…
Reference in New Issue
Block a user