[Transformations] Assert valid output shape on ReduceReshape fusion (#16712)

* Assert valid output shape on ReduceReshape fusion

* Remove useless test step

* Reduce tests flow

* _dummy_

to retrigger failed nonrestartable check
This commit is contained in:
Tomasz Jankowski 2023-04-06 13:55:57 +02:00 committed by GitHub
parent cafc7359c5
commit cb436112b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 0 deletions

View File

@ -60,6 +60,7 @@ ov::pass::ReduceReshapeFusion::ReduceReshapeFusion() {
std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduce_node)) {
logical_reduce_node->set_keep_dims(true);
}
reduce_node->validate_and_infer_types();
reduce_node->set_friendly_name(reshape_node->get_friendly_name());
copy_runtime_info(reshape_node, reduce_node);
replace_node(m.get_match_root(), reduce_node);

View File

@ -10,11 +10,13 @@
#include <openvino/pass/manager.hpp>
#include <string>
#include <transformations/common_optimizations/reduce_reshape_fusion.hpp>
#include <transformations/common_optimizations/transpose_to_reshape.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace std;
using namespace testing;
using namespace ov;
using namespace opset9;
@ -154,3 +156,21 @@ TEST_F(TransformationTestsF, ReduceMeanReshapeFusionSkipIfMoreThanOneReduceConsu
model = std::make_shared<Model>(NodeVector{reshape, add}, ParameterVector{input});
manager.register_pass<pass::ReduceReshapeFusion>();
}
TEST(TransformationTests, ReduceMeanReshapeFusionAssertValidOutputShape) {
const auto input = make_shared<Parameter>(element::f32, PartialShape{1, 16, 16, 24});
const auto reduce_axes = Constant::create(element::i64, Shape{2}, {1, 2});
const auto reduce_mean = make_shared<ReduceMean>(input, reduce_axes, false);
const auto target_shape = Constant::create(element::i64, Shape{4}, {1, 1, 1, 24});
const auto reshape = make_shared<Reshape>(reduce_mean, target_shape, false);
const auto order = Constant::create(element::i64, Shape{4}, {0, 3, 1, 2});
const auto transpose = make_shared<Transpose>(reshape, order);
auto model = make_shared<Model>(NodeVector{transpose}, ParameterVector{input});
pass::Manager manager;
manager.set_per_pass_validation(false);
manager.register_pass<pass::ReduceReshapeFusion>();
manager.register_pass<pass::TransposeToReshape>();
ASSERT_NO_THROW(manager.run_passes(model));
}