[Transformations] Assert valid output shape on ReduceReshape fusion (#16712)
* Assert valid output shape on ReduceReshape fusion * Remove useless test step * Reduce tests flow * _dummy_ to retrigger failed nonrestartable check
This commit is contained in:
parent
cafc7359c5
commit
cb436112b2
@ -60,6 +60,7 @@ ov::pass::ReduceReshapeFusion::ReduceReshapeFusion() {
|
||||
std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduce_node)) {
|
||||
logical_reduce_node->set_keep_dims(true);
|
||||
}
|
||||
reduce_node->validate_and_infer_types();
|
||||
reduce_node->set_friendly_name(reshape_node->get_friendly_name());
|
||||
copy_runtime_info(reshape_node, reduce_node);
|
||||
replace_node(m.get_match_root(), reduce_node);
|
||||
|
@ -10,11 +10,13 @@
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <string>
|
||||
#include <transformations/common_optimizations/reduce_reshape_fusion.hpp>
|
||||
#include <transformations/common_optimizations/transpose_to_reshape.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace testing;
|
||||
using namespace ov;
|
||||
using namespace opset9;
|
||||
@ -154,3 +156,21 @@ TEST_F(TransformationTestsF, ReduceMeanReshapeFusionSkipIfMoreThanOneReduceConsu
|
||||
model = std::make_shared<Model>(NodeVector{reshape, add}, ParameterVector{input});
|
||||
manager.register_pass<pass::ReduceReshapeFusion>();
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ReduceMeanReshapeFusionAssertValidOutputShape) {
|
||||
const auto input = make_shared<Parameter>(element::f32, PartialShape{1, 16, 16, 24});
|
||||
const auto reduce_axes = Constant::create(element::i64, Shape{2}, {1, 2});
|
||||
const auto reduce_mean = make_shared<ReduceMean>(input, reduce_axes, false);
|
||||
const auto target_shape = Constant::create(element::i64, Shape{4}, {1, 1, 1, 24});
|
||||
const auto reshape = make_shared<Reshape>(reduce_mean, target_shape, false);
|
||||
const auto order = Constant::create(element::i64, Shape{4}, {0, 3, 1, 2});
|
||||
const auto transpose = make_shared<Transpose>(reshape, order);
|
||||
|
||||
auto model = make_shared<Model>(NodeVector{transpose}, ParameterVector{input});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.set_per_pass_validation(false);
|
||||
manager.register_pass<pass::ReduceReshapeFusion>();
|
||||
manager.register_pass<pass::TransposeToReshape>();
|
||||
ASSERT_NO_THROW(manager.run_passes(model));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user