diff --git a/src/common/transformations/src/transformations/common_optimizations/reduce_reshape_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/reduce_reshape_fusion.cpp index 726789c34f0..daaa22f3773 100644 --- a/src/common/transformations/src/transformations/common_optimizations/reduce_reshape_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/reduce_reshape_fusion.cpp @@ -60,6 +60,7 @@ ov::pass::ReduceReshapeFusion::ReduceReshapeFusion() { std::dynamic_pointer_cast(reduce_node)) { logical_reduce_node->set_keep_dims(true); } + reduce_node->validate_and_infer_types(); reduce_node->set_friendly_name(reshape_node->get_friendly_name()); copy_runtime_info(reshape_node, reduce_node); replace_node(m.get_match_root(), reduce_node); diff --git a/src/common/transformations/tests/common_optimizations/reduce_reshape_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/reduce_reshape_fusion_tests.cpp index 8234c2030ba..106d7838e88 100644 --- a/src/common/transformations/tests/common_optimizations/reduce_reshape_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/reduce_reshape_fusion_tests.cpp @@ -10,11 +10,13 @@ #include #include #include +#include #include #include #include "common_test_utils/ngraph_test_utils.hpp" +using namespace std; using namespace testing; using namespace ov; using namespace opset9; @@ -154,3 +156,21 @@ TEST_F(TransformationTestsF, ReduceMeanReshapeFusionSkipIfMoreThanOneReduceConsu model = std::make_shared(NodeVector{reshape, add}, ParameterVector{input}); manager.register_pass(); } + +TEST(TransformationTests, ReduceMeanReshapeFusionAssertValidOutputShape) { + const auto input = make_shared(element::f32, PartialShape{1, 16, 16, 24}); + const auto reduce_axes = Constant::create(element::i64, Shape{2}, {1, 2}); + const auto reduce_mean = make_shared(input, reduce_axes, false); + const auto target_shape = Constant::create(element::i64, Shape{4}, {1, 1, 1, 24}); + const auto reshape = make_shared(reduce_mean, target_shape, false); + const auto order = Constant::create(element::i64, Shape{4}, {0, 3, 1, 2}); + const auto transpose = make_shared(reshape, order); + + auto model = make_shared(NodeVector{transpose}, ParameterVector{input}); + + pass::Manager manager; + manager.set_per_pass_validation(false); + manager.register_pass(); + manager.register_pass(); + ASSERT_NO_THROW(manager.run_passes(model)); +}