Introduce PullThroughReduce and ReduceReshapeFusion transformations (#13047)

* Initial version of PullThroughReduce and ReduceReshapeFusion transformations

* remove std namespace

* headers + namespaces clean-up

* resolved problem with ambiguous namespace

* clang format applied

* tests refactor. part 1

* tests refactor. part.2

* changed includes

* review remarks

* tests refactor. part.3

* review remarks

* fix build problem

* remarks

* fix reserving memory

* review remarks

* styles applied

* handle axes in the middle in try_get_unsqueeze_axes_from_reshape

* handle case of scalar input shape

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
Mateusz Bencer
2022-11-08 14:33:04 +00:00
committed by GitHub
parent 6f2eab4413
commit 116f294a9a
15 changed files with 877 additions and 65 deletions

View File

@@ -0,0 +1,57 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API PullThroughReduce;
class TRANSFORMATIONS_API PullUnsqueezeThroughReduce;
class TRANSFORMATIONS_API PullReshapeThroughReduce;
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief PullUnsqueezeThroughReduce transformation
 * The transformation pulls Unsqueeze operator through Reduce ops if possible.
 * In the further processing such Unsqueeze can be often skipped as nop.
 */
class ov::pass::PullUnsqueezeThroughReduce : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("PullUnsqueezeThroughReduce", "0");
    PullUnsqueezeThroughReduce();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief PullReshapeThroughReduce transformation
 * The transformation pulls Reshape operator through Reduce ops if possible.
 * In the further processing such Reshape can be often skipped as nop.
 */
class ov::pass::PullReshapeThroughReduce : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("PullReshapeThroughReduce", "0");
    PullReshapeThroughReduce();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief PullThroughReduce transformation
 * The transformation pulls Reshape or Unsqueeze operators through Reduce ops if possible.
 * In the further processing such Reshape/Unsqueeze can be often skipped as nop.
 * Convenience wrapper which registers both matcher passes in a single GraphRewrite.
 */
class ov::pass::PullThroughReduce : public ov::pass::GraphRewrite {
public:
    OPENVINO_RTTI("PullThroughReduce", "0");
    PullThroughReduce() {
        add_matcher<ov::pass::PullUnsqueezeThroughReduce>();
        add_matcher<ov::pass::PullReshapeThroughReduce>();
    }
};

View File

@@ -0,0 +1,27 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API ReduceReshapeFusion;
} // namespace pass
} // namespace ov
/**
 * @ingroup ie_transformation_common_api
 * @brief ReduceReshapeFusion transformation
 * Fuse ReduceOp(keep_dims=false)+Reshape to ReduceOp(keep_dims=true)
 * The fusion applies only when the Reshape output shape is exactly the shape
 * the Reduce would produce with keep_dims=true.
 */
class ov::pass::ReduceReshapeFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ReduceReshapeFusion", "0");
    ReduceReshapeFusion();
};

View File

@@ -42,8 +42,10 @@
#include <transformations/common_optimizations/optimize_strided_slice.hpp>
#include <transformations/common_optimizations/pad_fusion.hpp>
#include <transformations/common_optimizations/prelu_fusion.hpp>
#include <transformations/common_optimizations/pull_through_reduce.hpp>
#include <transformations/common_optimizations/pull_transpose_through_fq.hpp>
#include <transformations/common_optimizations/random_uniform_fusion.hpp>
#include <transformations/common_optimizations/reduce_reshape_fusion.hpp>
#include <transformations/common_optimizations/relu_fake_quantize_fusion.hpp>
#include <transformations/common_optimizations/remove_concat_zero_dim_input.hpp>
#include <transformations/common_optimizations/remove_filtering_boxes_by_size.hpp>
@@ -132,10 +134,9 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
}
// workaround until dynamism in NMS is not supported
manager.register_pass<ngraph::pass::ConvertNmsGatherPathToUnsigned>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(m_use_shapes);
manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();
manager.register_pass<ov::pass::PullThroughReduce>();
auto transpose_sinking = manager.register_pass<ngraph::pass::GraphRewrite>();
transpose_sinking->add_matcher<ngraph::pass::TransposeSinking>();
@@ -162,6 +163,7 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
common_fusions->add_matcher<ngraph::pass::ClampFusion>();
common_fusions->add_matcher<ngraph::pass::PadFusion>();
common_fusions->add_matcher<ngraph::pass::SoftmaxFusion>();
common_fusions->add_matcher<ov::pass::ReduceReshapeFusion>();
common_fusions->add_matcher<ngraph::pass::MVNFusion>();
common_fusions->add_matcher<ngraph::pass::DilatedConvolutionConverter>();
common_fusions->add_matcher<ngraph::pass::GeluFusion>();

View File

@@ -0,0 +1,231 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/pull_through_reduce.hpp"
#include <memory>
#include <vector>
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/util/reduction_base.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "sequnce_generator.hpp"
#include "transformations/utils/utils.hpp"
namespace {
// Adjust axes of Unsqueeze/Reduce ops after Unsqueeze pulling
// For example if we have:
// input(shape={5,10,15})
//       |
// Unsqueeze(axes=[0,3]) -> output_shape = {1,5,10,1,15}
//       |
// ReduceOp(axes=[2,4], keep_dims=false) -> output_shape = {1,5,1}
// after pulling it will be:
// input(shape={5,10,15})
//       |
// ReduceOp(axes=[1,2], keep_dims=false) -> output_shape = {5}
//       |
// Unsqueeze(axes=[0,2]) -> output_shape = {1,5,1}
//
// Each axis is shifted left by the number of offset_axes located before it.
// Note: return by (non-const) value to allow NRVO/move at call sites.
std::vector<int64_t> adjust_axes(const std::vector<int64_t>& axes_to_align,
                                 const std::vector<int64_t>& offset_axes) {
    auto number_of_axes_less_than = [&offset_axes](const int64_t current_axis) {
        return std::count_if(std::begin(offset_axes),
                             std::end(offset_axes),
                             [&current_axis](const int64_t excluded_axis) {
                                 return excluded_axis < current_axis;
                             });
    };
    std::vector<int64_t> result;
    result.reserve(axes_to_align.size());  // one push_back per input axis, avoid reallocations
    for (const auto& axis : axes_to_align) {
        result.push_back(axis - number_of_axes_less_than(axis));
    }
    return result;
}
// Try to represent given Reshape node via Unsqueeze and calculate axes of such Unsqueeze
// - Reshape(input_shape={5,10,15}, target_shape={5,10,15,1}), 3 axis returned
// - Reshape(input_shape={5,10,15}, target_shape={1,5,10,15}), 0 axis returned
// - Reshape(input_shape={5,10,15}, target_shape={1,5,10,15,1}), 0 and 3 axes returned
// - Reshape(input_shape={5,10,15}, target_shape={5,10,1,15}), 2 axis is returned
// Returns an empty vector when the Reshape cannot be expressed as an Unsqueeze
// (including the identity-reshape case, which produces no axes).
std::vector<int64_t> try_get_unsqueeze_axes_from_reshape(const ov::Shape& target_shape, const ov::Shape& input_shape) {
    std::vector<int64_t> result;
    if (input_shape.size() == 0) {  // scalar case - can be reshaped only to [1,..,1] shape
        result.resize(target_shape.size());
        std::iota(std::begin(result), std::end(result), 0);
        return result;
    }
    // Walk both shapes in lockstep: a matching dimension consumes one input dim,
    // a "1" that cannot be matched against the current input dim becomes an Unsqueeze axis.
    // Use size_t indices to avoid signed/unsigned comparison against Shape::size().
    size_t cur_input_shape_elem_idx = 0;
    auto cur_input_shape_elem = input_shape[cur_input_shape_elem_idx];
    size_t target_shape_idx = 0;
    for (; target_shape_idx < target_shape.size(); ++target_shape_idx) {
        if (cur_input_shape_elem == target_shape[target_shape_idx] &&
            cur_input_shape_elem_idx + 1 < input_shape.size()) {
            ++cur_input_shape_elem_idx;
            cur_input_shape_elem = input_shape[cur_input_shape_elem_idx];
        } else if (target_shape[target_shape_idx] == 1 &&
                   // past the span covered by input dims + already-found axes, or a "1"
                   // that does not correspond to the current input dim
                   (target_shape_idx >= input_shape.size() + result.size() || cur_input_shape_elem != 1)) {
            result.push_back(static_cast<int64_t>(target_shape_idx));
        }
    }
    // Valid only if every input dim and every target dim was consumed.
    if (cur_input_shape_elem_idx == input_shape.size() - 1 && target_shape_idx == target_shape.size()) {
        return result;
    }
    return {};  // original code had an unreachable trailing return here - removed
}
// Update given reshape_input_shape by inserting "1" dimension at the positions represented by axes_to_insert.
// NOTE: axes_to_insert are expected to be normalized and sorted ascending (as produced by
// try_get_unsqueeze_axes_from_reshape/adjust_axes), so each insert position already accounts
// for the "1"s inserted before it.
// Returns an i64 Constant holding the resulting target shape.
std::shared_ptr<ov::opset9::Constant> update_reshape_target_shape(const ov::Shape& reshape_input_shape,
                                                                  const std::vector<int64_t>& axes_to_insert) {
    auto result = std::vector<int64_t>(std::begin(reshape_input_shape), std::end(reshape_input_shape));
    // reserve the final size up-front so the repeated inserts never reallocate
    result.reserve(reshape_input_shape.size() + axes_to_insert.size());
    for (const auto& axis : axes_to_insert) {
        result.insert(std::next(std::begin(result), axis), 1);
    }
    return ov::opset9::Constant::create(ov::element::i64, ov::Shape{result.size()}, result);
}
// Return true if the two axes lists share at least one common element, otherwise return false.
bool have_same_axes(const std::vector<int64_t>& unsqueeze_axes, const std::vector<int64_t>& reduce_op_axes) {
    return std::any_of(std::begin(unsqueeze_axes), std::end(unsqueeze_axes), [&reduce_op_axes](const int64_t axis) {
        return std::find(std::begin(reduce_op_axes), std::end(reduce_op_axes), axis) != std::end(reduce_op_axes);
    });
}
} // namespace
// Swaps input->Unsqueeze->Reduce into input->Reduce->Unsqueeze (axes adjusted on both ops),
// so that the Unsqueeze can later be eliminated as a nop.
ov::pass::PullUnsqueezeThroughReduce::PullUnsqueezeThroughReduce() {
    MATCHER_SCOPE(PullUnsqueezeThroughReduce);
    // Pattern: any_input -> Unsqueeze(const axes) -> Arithmetic/Logical Reduce(const axes)
    const auto input = pattern::any_input(pattern::has_static_rank());
    const auto unsqueeze_axes = pattern::wrap_type<opset9::Constant>();
    const auto unsqueeze = pattern::wrap_type<opset9::Unsqueeze>({input, unsqueeze_axes}, pattern::has_static_rank());
    const auto reduce_axes = pattern::wrap_type<opset9::Constant>();
    const auto reduce = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
        {unsqueeze, reduce_axes});

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto& pattern_map = m.get_pattern_value_map();
        const auto input_node = pattern_map.at(input);
        const auto reduce_node =
            std::dynamic_pointer_cast<op::util::ReductionBase>(pattern_map.at(reduce).get_node_shared_ptr());
        const auto unsqueeze_node = pattern_map.at(unsqueeze).get_node_shared_ptr();
        auto unsqueeze_axes_input =
            std::dynamic_pointer_cast<opset9::Constant>(pattern_map.at(unsqueeze_axes).get_node_shared_ptr());
        auto reduce_axes_input =
            std::dynamic_pointer_cast<opset9::Constant>(pattern_map.at(reduce_axes).get_node_shared_ptr());
        // Both axes inputs must be Constants and the reduce must be a ReductionBase
        if (!unsqueeze_axes_input || !reduce_axes_input || !reduce_node) {
            return false;
        }
        // Resolve negative Unsqueeze axes against the Unsqueeze output rank
        auto unsqueeze_axes_val = unsqueeze_axes_input->cast_vector<int64_t>();
        normalize_axes(unsqueeze_node.get(),
                       unsqueeze_node->get_output_partial_shape(0).rank().get_length(),
                       unsqueeze_axes_val);
        const auto reduce_axes_val = reduce_node->get_reduction_axes().to_vector();
        // If the Reduce works on a dimension created by the Unsqueeze, pulling is impossible
        if (have_same_axes(unsqueeze_axes_val, reduce_axes_val)) {
            return false;
        }
        const bool keep_dims = reduce_node->get_keep_dims();
        if (!keep_dims) {
            // Reduce drops dimensions, so the Unsqueeze axes must be shifted accordingly
            const auto unsqueeze_adjusted_axes = adjust_axes(unsqueeze_axes_val, reduce_axes_val);
            if (unsqueeze_adjusted_axes != unsqueeze_axes_val) {
                unsqueeze_axes_input = opset9::Constant::create(unsqueeze_axes_input->get_element_type(),
                                                                unsqueeze_axes_input->get_shape(),
                                                                unsqueeze_adjusted_axes);
            }
        }
        // Reduce now operates on the input before Unsqueeze - shift its axes as well
        const auto reduce_adjusted_axes = adjust_axes(reduce_axes_val, unsqueeze_axes_val);
        if (reduce_adjusted_axes != reduce_axes_val) {
            reduce_axes_input = opset9::Constant::create(reduce_axes_input->get_element_type(),
                                                         reduce_axes_input->get_shape(),
                                                         reduce_adjusted_axes);
        }
        // Rebuild the subgraph in swapped order: input -> Reduce -> Unsqueeze.
        // Friendly names are swapped so the subgraph output keeps the original output name.
        const auto new_reduce_node = reduce_node->clone_with_new_inputs({input_node, reduce_axes_input});
        new_reduce_node->set_friendly_name(unsqueeze_node->get_friendly_name());
        const auto new_unsqueeze_node = unsqueeze_node->clone_with_new_inputs({new_reduce_node, unsqueeze_axes_input});
        new_unsqueeze_node->set_friendly_name(reduce_node->get_friendly_name());
        copy_runtime_info({reduce_node, unsqueeze_node}, {new_reduce_node, new_unsqueeze_node});
        replace_node(m.get_match_root(), new_unsqueeze_node);
        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(reduce, matcher_name);
    register_matcher(m, callback);
}
// Swaps input->Reshape->Reduce into input->Reduce->Reshape when the Reshape is
// expressible as an Unsqueeze, so that it can later be eliminated as a nop.
ov::pass::PullReshapeThroughReduce::PullReshapeThroughReduce() {
    MATCHER_SCOPE(PullReshapeThroughReduce);
    // Pattern: any_input(static shape) -> Reshape(const target shape) -> Arithmetic/Logical Reduce(const axes)
    const auto input = pattern::any_input(pattern::has_static_shape());
    const auto reshape_target_shape = pattern::wrap_type<opset9::Constant>();
    const auto reshape =
        pattern::wrap_type<opset9::Reshape>({input, reshape_target_shape}, pattern::has_static_shape());
    const auto reduce_axes = pattern::wrap_type<opset9::Constant>();
    const auto reduce = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
        {reshape, reduce_axes});

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto& pattern_map = m.get_pattern_value_map();
        const auto input_node = pattern_map.at(input).get_node_shared_ptr();
        const auto reduce_node =
            std::dynamic_pointer_cast<op::util::ReductionBase>(pattern_map.at(reduce).get_node_shared_ptr());
        if (!reduce_node) {
            return false;
        }
        const auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
        // The pull is only valid when the Reshape merely inserts "1" dimensions,
        // i.e. it is equivalent to an Unsqueeze; empty axes means it is not
        const auto unsqueeze_axes =
            try_get_unsqueeze_axes_from_reshape(reshape_node->get_shape(), input_node->get_shape());
        if (unsqueeze_axes.empty()) {
            return false;
        }
        const auto reduce_axes_val = reduce_node->get_reduction_axes().to_vector();
        // If the Reduce works on a dimension created by the Reshape, pulling is impossible
        if (have_same_axes(unsqueeze_axes, reduce_axes_val)) {
            return false;
        }
        // Shift both axes sets: Reduce axes against the removed Unsqueeze dims,
        // Unsqueeze dims against the dimensions dropped by the Reduce
        const auto unsqueeze_adjusted_axes = adjust_axes(unsqueeze_axes, reduce_axes_val);
        const auto reduce_adjusted_axes = adjust_axes(reduce_axes_val, unsqueeze_axes);
        auto reduce_axes_input =
            std::dynamic_pointer_cast<opset9::Constant>(pattern_map.at(reduce_axes).get_node_shared_ptr());
        if (!reduce_axes_input) {
            return false;
        }
        if (reduce_adjusted_axes != reduce_axes_val) {
            reduce_axes_input = opset9::Constant::create(reduce_axes_input->get_element_type(),
                                                         reduce_axes_input->get_shape(),
                                                         reduce_adjusted_axes);
        }
        // Rebuild in swapped order: input -> Reduce -> Reshape (with recomputed target shape).
        // Friendly names are swapped so the subgraph output keeps the original output name.
        const auto new_reduce_node = reduce_node->clone_with_new_inputs({input_node, reduce_axes_input});
        new_reduce_node->set_friendly_name(reshape_node->get_friendly_name());
        const auto new_reshape_node = reshape_node->clone_with_new_inputs(
            {new_reduce_node, update_reshape_target_shape(new_reduce_node->get_shape(), unsqueeze_adjusted_axes)});
        new_reshape_node->set_friendly_name(reduce_node->get_friendly_name());
        copy_runtime_info({reduce_node, reshape_node}, {new_reduce_node, new_reshape_node});
        replace_node(m.get_match_root(), new_reshape_node);
        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(reduce, matcher_name);
    register_matcher(m, callback);
}

View File

@@ -0,0 +1,69 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/reduce_reshape_fusion.hpp"
#include <memory>
#include <vector>
#include "itt.hpp"
#include "openvino/op/util/reduction_base.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
// Fuses Reduce(keep_dims=false) + Reshape into Reduce(keep_dims=true) when the Reshape
// output shape is exactly the shape the Reduce would produce with keep_dims=true.
ov::pass::ReduceReshapeFusion::ReduceReshapeFusion() {
    MATCHER_SCOPE(ReduceReshapeFusion);
    // Pattern: Arithmetic/Logical Reduce(const axes, static shape) -> Reshape(static shape)
    const auto reduce_axes = pattern::wrap_type<opset9::Constant>();
    const auto reduce = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
        {pattern::any_input(), reduce_axes},
        pattern::has_static_shape());
    const auto reshape =
        pattern::wrap_type<opset9::Reshape>({reduce, pattern::any_input()}, pattern::has_static_shape());

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto& pattern_map = m.get_pattern_value_map();
        auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
        const auto reduce_node =
            std::dynamic_pointer_cast<op::util::ReductionBase>(pattern_map.at(reduce).get_node_shared_ptr());
        if (!reduce_node) {
            return false;
        }
        // The Reduce node is mutated in place below (keep_dims flipped), which changes
        // its output shape. That is only safe when the matched Reshape is its sole
        // consumer - otherwise other consumers would silently get a different shape.
        if (reduce_node->output(0).get_target_inputs().size() != 1) {
            return false;
        }
        const bool keep_dims = reduce_node->get_keep_dims();
        if (keep_dims) {
            return false;  // nothing to fuse - Reduce already keeps dims
        }
        const auto reduce_axes_val = reduce_node->get_reduction_axes().to_vector();
        const auto& reshape_shape = reshape_node->get_shape();
        // Compute the shape Reduce would produce with keep_dims=true by re-inserting
        // "1" at every reduced axis position
        auto reduce_shape_if_keep_dims = reduce_node->get_shape();
        for (const auto& axis : reduce_axes_val) {
            reduce_shape_if_keep_dims.insert(std::next(std::begin(reduce_shape_if_keep_dims), axis), 1);
        }
        // Fuse only when the Reshape output is exactly that keep_dims=true shape
        if (reduce_shape_if_keep_dims != reshape_shape) {
            return false;
        }

        if (auto arithmetic_reduce_node =
                std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduce_node)) {
            arithmetic_reduce_node->set_keep_dims(true);
        } else if (auto logical_reduce_node =
                       std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduce_node)) {
            logical_reduce_node->set_keep_dims(true);
        }
        // Re-infer the output shape after the attribute change so the mutated node
        // is consistent before it replaces the Reshape
        reduce_node->validate_and_infer_types();
        reduce_node->set_friendly_name(reshape_node->get_friendly_name());
        copy_runtime_info(reshape_node, reduce_node);
        replace_node(m.get_match_root(), reduce_node);
        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(reshape, matcher_name);
    register_matcher(m, callback);
}

View File

@@ -28,17 +28,6 @@ public:
OPENVINO_OP("ArithmeticReduction", "util");
BWDCMP_RTTI_DECLARATION;
void validate_and_infer_types() override;
/// \return true if reduction axes are constant else false.
bool reduction_axes_constant() const;
/// \return The axis positions (0-based) to be eliminated through reduction.
/// \throws CheckFailure if the reduction axes are not constant. (Use
/// reduction_axes_constant to check.)
const AxisSet get_reduction_axes() const;
/// \brief Change the reduction axes
void set_reduction_axes(const AxisSet& reduction_axes);
};
} // namespace util
} // namespace op

View File

@@ -27,7 +27,7 @@ public:
/// \return If set to 1 it holds axes that are used for reduction.
/// For each such axis, output dimension is equal to 1.
bool get_keep_dims() const {
bool get_keep_dims() const override {
return m_keep_dims;
}
void set_keep_dims(bool keep_dims) {

View File

@@ -32,15 +32,6 @@ public:
OPENVINO_OP("LogicalReduction", "util");
BWDCMP_RTTI_DECLARATION;
void validate_and_infer_types() override;
/// \return true if reduction axes are constant else false.
bool reduction_axes_constant() const;
/// \return The axis positions (0-based) to be eliminated through reduction.
/// \throws CheckFailure if the reduction axes are not constant. (Use
/// reduction_axes_constant to check.)
const AxisSet get_reduction_axes() const;
void set_reduction_axes(const AxisSet& reduction_axes);
};
} // namespace util
} // namespace op

View File

@@ -28,7 +28,7 @@ public:
/// \return If set to 1 it holds axes that are used for reduction.
/// For each such axis, output dimension is equal to 1.
bool get_keep_dims() const {
bool get_keep_dims() const override {
return m_keep_dims;
}
void set_keep_dims(bool keep_dims) {

View File

@@ -31,6 +31,23 @@ protected:
public:
OPENVINO_OP("ReductionBase", "util");
BWDCMP_RTTI_DECLARATION;
/// \return true if reduction axes are constant else false.
bool reduction_axes_constant() const;
/// \return The axis positions (0-based) to be eliminated through reduction.
/// \throws CheckFailure if the reduction axes are not constant. (Use
/// reduction_axes_constant to check.)
const AxisSet get_reduction_axes() const;
/// \brief Change the reduction axes
void set_reduction_axes(const AxisSet& reduction_axes);
// \brief Returns true if keep_dims is set to true explicitly.
// Otherwise, (also keep_dims not handled) returns false.
virtual bool get_keep_dims() const {
return false;
}
};
} // namespace util
} // namespace op

View File

@@ -5,7 +5,6 @@
#include "ngraph/op/util/arithmetic_reduction.hpp"
#include "itt.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
@@ -17,27 +16,6 @@ ov::op::util::ArithmeticReduction::ArithmeticReduction() = default;
ov::op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ReductionBase(arg, reduction_axes) {}
bool ov::op::util::ArithmeticReduction::reduction_axes_constant() const {
return ov::is_type<ngraph::op::Constant>(input_value(1).get_node());
}
const ov::AxisSet ov::op::util::ArithmeticReduction::get_reduction_axes() const {
AxisSet axes;
if (const auto& const_op = get_constant_from_source(input_value(1))) {
const auto const_data = const_op->cast_vector<int64_t>();
const auto input_data_rank = get_input_partial_shape(0).rank();
const auto normalized_axes = ngraph::normalize_axes(get_friendly_name(), const_data, input_data_rank);
axes = AxisSet{normalized_axes};
}
return axes;
}
void ov::op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_axes) {
this->input(1).replace_source_output(
ngraph::op::Constant::create(element::i64, ov::Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0));
}
void ov::op::util::ArithmeticReduction::validate_and_infer_types() {
OV_OP_SCOPE(util_ArithmeticReduction_validate_and_infer_types);

View File

@@ -25,24 +25,6 @@ op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const Axis
op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const Output<Node>& reduction_axes)
: ReductionBase(arg, reduction_axes) {}
bool op::util::LogicalReduction::reduction_axes_constant() const {
return ngraph::has_and_set_equal_bounds(input_value(1));
}
const AxisSet op::util::LogicalReduction::get_reduction_axes() const {
AxisSet axes;
if (auto const_op = get_constant_from_source(input_value(1))) {
axes = const_op->get_axis_set_val();
}
return axes;
}
void op::util::LogicalReduction::set_reduction_axes(const AxisSet& reduction_axes) {
this->input(1).replace_source_output(
ngraph::op::Constant::create(element::i64, ov::Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0));
}
void op::util::LogicalReduction::validate_and_infer_types() {
OV_OP_SCOPE(util_LogicalReduction_validate_and_infer_types);

View File

@@ -2,8 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/util/reduction_base.hpp"
#include "openvino/op/util/reduction_base.hpp"
#include "openvino/op/constant.hpp"
#include "reduce_shape_inference.hpp"
using namespace std;
@@ -20,3 +21,24 @@ ov::PartialShape ov::op::util::ReductionBase::infer_reduction_output_shape(const
reduce_shape_infer(this, keep_dims, get_input_partial_shape(0), output_shape);
return output_shape;
}
// Returns true when the reduction axes input (input 1) is a Constant node, else false.
bool ov::op::util::ReductionBase::reduction_axes_constant() const {
    return ov::is_type<op::v0::Constant>(input_value(1).get_node());
}
// Returns the normalized (non-negative) reduction axes. If the axes input cannot be
// folded to a constant, an empty AxisSet is returned.
const ov::AxisSet ov::op::util::ReductionBase::get_reduction_axes() const {
    AxisSet axes;
    if (const auto& const_op = get_constant_from_source(input_value(1))) {
        const auto const_data = const_op->cast_vector<int64_t>();
        const auto input_data_rank = get_input_partial_shape(0).rank();
        // Resolve negative axes against the data input rank
        const auto normalized_axes = ov::normalize_axes(get_friendly_name(), const_data, input_data_rank);
        axes = AxisSet{normalized_axes};
    }
    return axes;
}
// Replaces the reduction axes input (input 1) with a freshly created i64 Constant.
void ov::op::util::ReductionBase::set_reduction_axes(const AxisSet& reduction_axes) {
    this->input(1).replace_source_output(
        op::v0::Constant::create(element::i64, ov::Shape{reduction_axes.size()}, reduction_axes.to_vector())
            ->output(0));
}

View File

@@ -0,0 +1,300 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <openvino/core/model.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/common_optimizations/pull_through_reduce.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ov;
using namespace opset9;
namespace {
// Builds the model under test: Parameter -> Unsqueeze(axes) -> ReduceType(axes, keep_dims).
template <typename ReduceType>
std::shared_ptr<Model> generate_unsqueeze_model(element::Type in_type,
                                                PartialShape in_shape,
                                                std::vector<int64_t> unsqueeze_axes,
                                                std::vector<int64_t> reduce_axes,
                                                bool keep_dims = false) {
    const auto param = std::make_shared<Parameter>(in_type, in_shape);
    const auto unsqueeze_axes_node = Constant::create(element::i64, Shape{unsqueeze_axes.size()}, unsqueeze_axes);
    const auto reduce_axes_node = Constant::create(element::i64, Shape{reduce_axes.size()}, reduce_axes);
    const auto unsqueeze = std::make_shared<Unsqueeze>(param, unsqueeze_axes_node);
    const auto reduce = std::make_shared<ReduceType>(unsqueeze, reduce_axes_node, keep_dims);
    return std::make_shared<Model>(NodeVector{reduce}, ParameterVector{param});
}
// Builds the expected reference model: Parameter -> ReduceType(axes, keep_dims) -> Unsqueeze(axes).
template <typename ReduceType>
std::shared_ptr<Model> generate_unsqueeze_ref_model(element::Type in_type,
                                                    PartialShape in_shape,
                                                    std::vector<int64_t> unsqueeze_axes,
                                                    std::vector<int64_t> reduce_axes,
                                                    bool keep_dims = false) {
    const auto param = std::make_shared<Parameter>(in_type, in_shape);
    const auto reduce_axes_node = Constant::create(element::i64, Shape{reduce_axes.size()}, reduce_axes);
    const auto unsqueeze_axes_node = Constant::create(element::i64, Shape{unsqueeze_axes.size()}, unsqueeze_axes);
    const auto reduce = std::make_shared<ReduceType>(param, reduce_axes_node, keep_dims);
    const auto unsqueeze = std::make_shared<Unsqueeze>(reduce, unsqueeze_axes_node);
    return std::make_shared<Model>(NodeVector{unsqueeze}, ParameterVector{param});
}
// Builds the model under test: Parameter -> Reshape(target_shape, special_zero) -> ReduceType(axes, keep_dims).
template <typename ReduceType>
std::shared_ptr<Model> generate_reshape_model(element::Type in_type,
                                              PartialShape in_shape,
                                              std::vector<int64_t> reshape_target_shape,
                                              std::vector<int64_t> reduce_axes,
                                              bool keep_dims = false,
                                              bool reshape_special_zero = false) {
    const auto param = std::make_shared<Parameter>(in_type, in_shape);
    const auto target_shape_node =
        Constant::create(element::i64, Shape{reshape_target_shape.size()}, reshape_target_shape);
    const auto reduce_axes_node = Constant::create(element::i64, Shape{reduce_axes.size()}, reduce_axes);
    const auto reshape = std::make_shared<Reshape>(param, target_shape_node, reshape_special_zero);
    const auto reduce = std::make_shared<ReduceType>(reshape, reduce_axes_node, keep_dims);
    return std::make_shared<Model>(NodeVector{reduce}, ParameterVector{param});
}
// Builds the expected reference model: Parameter -> ReduceType(axes, keep_dims) -> Reshape(target_shape).
template <typename ReduceType>
std::shared_ptr<Model> generate_reshape_ref_model(element::Type in_type,
                                                  PartialShape in_shape,
                                                  std::vector<int64_t> reshape_target_shape,
                                                  std::vector<int64_t> reduce_axes,
                                                  bool keep_dims = false,
                                                  bool reshape_special_zero = false) {
    const auto param = std::make_shared<Parameter>(in_type, in_shape);
    const auto reduce_axes_node = Constant::create(element::i64, Shape{reduce_axes.size()}, reduce_axes);
    const auto target_shape_node =
        Constant::create(element::i64, Shape{reshape_target_shape.size()}, reshape_target_shape);
    const auto reduce = std::make_shared<ReduceType>(param, reduce_axes_node, keep_dims);
    const auto reshape = std::make_shared<Reshape>(reduce, target_shape_node, reshape_special_zero);
    return std::make_shared<Model>(NodeVector{reshape}, ParameterVector{param});
}
} // namespace
// Test parameters for the PullUnsqueezeThroughReduce cases.
struct PullUnsqueezeParams {
    element::Type in_type;                    // element type of the model input
    PartialShape in_shape;                    // shape of the model input
    std::vector<int64_t> unsqueeze_axes;      // Unsqueeze axes in the original model
    std::vector<int64_t> ref_unsqueeze_axes;  // expected Unsqueeze axes after the transformation
    std::vector<int64_t> reduce_axes;         // Reduce axes in the original model
    std::vector<int64_t> ref_reduce_axes;     // expected Reduce axes after the transformation
    bool keep_dims;                           // Reduce keep_dims attribute (value-initialized to false when omitted)
};

// Parametrized fixture for the ReduceMean flavour of PullUnsqueezeThroughReduce.
class PullUnsqueezeThroughReduceMean
    : public WithParamInterface<PullUnsqueezeParams>,
      public TransformationTestsF {
};

TEST_P(PullUnsqueezeThroughReduceMean, PullUnsqueezeThroughReduceMeanPattern) {
    const auto& p = GetParam();
    {
        // model under test: input -> Unsqueeze -> ReduceMean
        model = generate_unsqueeze_model<ReduceMean>(p.in_type, p.in_shape, p.unsqueeze_axes, p.reduce_axes, p.keep_dims);
        manager.register_pass<pass::PullUnsqueezeThroughReduce>();
    }
    {
        // expected: input -> ReduceMean -> Unsqueeze with adjusted axes
        model_ref = generate_unsqueeze_ref_model<ReduceMean>(p.in_type, p.in_shape, p.ref_unsqueeze_axes, p.ref_reduce_axes, p.keep_dims);
    }
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}

static const std::vector<PullUnsqueezeParams> reduce_mean_params = {
    PullUnsqueezeParams{element::f32, {5, 10, 15}, {0}, {0}, {2}, {1}},
    // unsqueeze axes greater than reduce axes
    PullUnsqueezeParams{element::f32, {5, 10, 15, 20}, {0, 2, 3}, {0, 2, 3}, {5, 6}, {2, 3}},
    // unsqueeze axes lower than reduce axes
    PullUnsqueezeParams{element::f32, {5, 10, 15}, {3, 4, 5}, {1, 2, 3}, {0, 2}, {0, 2}},
    // unsqueeze axes between reduce axes
    PullUnsqueezeParams{element::f32, {5, 10, 15}, {1, 3, 5}, {0, 1, 2}, {0, 2, 4}, {0, 1, 2}},
    // unsqueeze axes between reduce axes 2
    PullUnsqueezeParams{element::f32, {1, 10, 1, 20}, {0, 1, 3, 4}, {0, 1, 2, 3}, {2, 5}, {0, 1}},
    // unsqueeze axes between reduce axes, keep_dims=true
    PullUnsqueezeParams{element::f32, {5, 10, 15}, {1, 3, 5}, {1, 3, 5}, {0, 2, 4}, {0, 1, 2}, true},
    // negative unsqueeze axes between negative reduce axes
    PullUnsqueezeParams{element::f32, {5, 10, 15}, {1, -3, -1}, {0, 1, 2}, {-2, 2, 0}, {0, 1, 2}},
    // dynamic input
    PullUnsqueezeParams{element::f32, {5, Dimension::dynamic(), 15}, {0}, {0}, {2}, {1}},
};

INSTANTIATE_TEST_SUITE_P(PullUnsqueezeThroughReduceMean, PullUnsqueezeThroughReduceMean, ValuesIn(reduce_mean_params));
// Parametrized fixture for the ReduceLogicalOr flavour of PullUnsqueezeThroughReduce.
class PullUnsqueezeThroughReduceLogicalOr
    : public WithParamInterface<PullUnsqueezeParams>,
      public TransformationTestsF {
};

TEST_P(PullUnsqueezeThroughReduceLogicalOr, PullUnsqueezeThroughReduceLogicalOrPattern) {
    const auto& p = GetParam();
    {
        // model under test: input -> Unsqueeze -> ReduceLogicalOr
        model = generate_unsqueeze_model<ReduceLogicalOr>(p.in_type, p.in_shape, p.unsqueeze_axes, p.reduce_axes, p.keep_dims);
        manager.register_pass<pass::PullUnsqueezeThroughReduce>();
    }
    {
        // expected: input -> ReduceLogicalOr -> Unsqueeze with adjusted axes
        model_ref = generate_unsqueeze_ref_model<ReduceLogicalOr>(p.in_type, p.in_shape, p.ref_unsqueeze_axes, p.ref_reduce_axes, p.keep_dims);
    }
    // NOTE: only CONST_VALUES is enabled here (no ACCURACY, unlike the ReduceMean suite)
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}

static const std::vector<PullUnsqueezeParams> reduce_logical_or_params = {
    PullUnsqueezeParams{element::boolean, {5, 10, 15}, {0}, {0}, {2}, {1}},
    // unsqueeze axes between reduce axes, keep_dims=true
    PullUnsqueezeParams{element::boolean, {1, 10, 1, 20}, {0, 1, 3, 4}, {0, 1, 3, 4}, {2, 5}, {0, 1}, true},
};

INSTANTIATE_TEST_SUITE_P(PullUnsqueezeThroughReduceLogicalOr, PullUnsqueezeThroughReduceLogicalOr, ValuesIn(reduce_logical_or_params));
// Negative case: Unsqueeze axes {0,1} and Reduce axes {1,2} share axis 1, so the pass
// must not apply (no model_ref set - the model is expected to stay unchanged).
TEST_F(TransformationTestsF, PullUnsqueezeThroughReduceSkipIfTheSameAxes) {
    model = generate_unsqueeze_model<ReduceMean>(element::f32, {5, 10, 15}, {0, 1}, {1, 2});
    manager.register_pass<pass::PullUnsqueezeThroughReduce>();
}

// Negative case: the Unsqueeze axes come from a Parameter (not a Constant), so the
// pattern requirements are not met and the pass must not apply.
TEST_F(TransformationTestsF, PullUnsqueezeThroughReduceSkipIfNotConstAxes) {
    const auto input = std::make_shared<Parameter>(element::f32, PartialShape{5, Dimension::dynamic(), 15});
    const auto unsqueeze_axes = std::make_shared<Parameter>(element::i64, Shape{});
    const auto unsqueeze = std::make_shared<Unsqueeze>(input, unsqueeze_axes);
    const auto reduce_axes = Constant::create(element::i64, Shape{}, {2});
    const auto reduce_mean = std::make_shared<ReduceMean>(unsqueeze, reduce_axes);

    model = std::make_shared<Model>(NodeVector{reduce_mean}, ParameterVector{input, unsqueeze_axes});
    manager.register_pass<pass::PullUnsqueezeThroughReduce>();
}
// Test parameters for the PullReshapeThroughReduce cases.
struct PullReshapeParams {
    element::Type in_type;                  // element type of the model input
    PartialShape in_shape;                  // shape of the model input
    std::vector<int64_t> target_shape;      // Reshape target shape in the original model
    std::vector<int64_t> ref_target_shape;  // expected Reshape target shape after the transformation
    std::vector<int64_t> reduce_axes;       // Reduce axes in the original model
    std::vector<int64_t> ref_reduce_axes;   // expected Reduce axes after the transformation
    bool keep_dims;                         // Reduce keep_dims attribute (value-initialized to false when omitted)
    bool reshape_special_zero;              // Reshape special_zero attribute (value-initialized to false when omitted)
};

// Parametrized fixture for the ReduceMean flavour of PullReshapeThroughReduce.
class PullReshapeThroughReduceMean
    : public WithParamInterface<PullReshapeParams>,
      public TransformationTestsF {
};

TEST_P(PullReshapeThroughReduceMean, PullReshapeThroughReduceMeanPattern) {
    const auto& p = GetParam();
    {
        // model under test: input -> Reshape -> ReduceMean
        model = generate_reshape_model<ReduceMean>(p.in_type, p.in_shape, p.target_shape, p.reduce_axes, p.keep_dims, p.reshape_special_zero);
        manager.register_pass<pass::PullReshapeThroughReduce>();
    }
    {
        // expected: input -> ReduceMean -> Reshape with recomputed target shape
        model_ref = generate_reshape_ref_model<ReduceMean>(p.in_type, p.in_shape, p.ref_target_shape, p.ref_reduce_axes, p.keep_dims, p.reshape_special_zero);
    }
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}

static const std::vector<PullReshapeParams> reduce_mean_reshape_params = {
    PullReshapeParams{element::f32, {5, 10, 15}, {1, 5, 10, 15}, {1, 5, 15}, {2}, {1}},
    // insert axes at the end
    PullReshapeParams{element::f32, {5, 10, 15}, {5, 10, 15, 1, 1}, {5, 1, 1}, {1, 2}, {1, 2}},
    // insert axes at the begin
    PullReshapeParams{element::f32, {5, 10, 15}, {1, 1, 1, 5, 10, 15}, {1, 1, 1, 10}, {3, 5}, {0, 2}},
    // insert axes on both sides
    PullReshapeParams{element::f32, {5, 10, 15}, {1, 1, 5, 10, 15, 1}, {1, 1, 15, 1}, {2, 3}, {0, 1}},
    // insert axes on both sides 2
    PullReshapeParams{element::f32, {1, 5, 10, 1}, {1, 1, 5, 10, 1, 1}, {1, 1, 1, 1}, {2, 3}, {1, 2}},
    // insert axes in the middle
    PullReshapeParams{element::f32, {4, 5, 6}, {1, 1, 4, 5, 6, 1, 1}, {1, 1, 1, 1}, {2, 3, 4}, {0, 1, 2}},
    PullReshapeParams{element::f32, {4, 5, 6}, {1, 1, 4, 1, 5, 6, 1, 1}, {1, 1, 1, 1, 1}, {2, 4, 5}, {0, 1, 2}},
    PullReshapeParams{element::f32, {4, 5, 6}, {1, 1, 4, 1, 1, 5, 6, 1, 1}, {1, 1, 1, 1, 1, 1}, {2, 5, 6}, {0, 1, 2}},
    PullReshapeParams{element::f32, {4, 1, 1}, {4, 1, 1, 1}, {1}, {0, 1, 2}, {0, 1, 2}},
    PullReshapeParams{element::f32, {1, 1}, {1, 1, 1}, {1, 1}, {1}, {1}},
    PullReshapeParams{element::f32, {2, 1, 3, 1}, {1, 2, 1, 1, 3, 1, 1}, {1, 1, 1}, {1, 2, 4, 5}, {0, 1, 2, 3}},
    PullReshapeParams{element::f32, {1, 3, 1}, {1, 1, 1, 3, 1, 1}, {1, 1, 1}, {0, 3, 4}, {0, 1, 2}},
    // insert axes on both sides, keep_dims=true
    PullReshapeParams{element::f32, {5, 10, 15}, {1, 5, 10, 15}, {1, 1, 1, 1}, {1, 2, 3}, {0, 1, 2}, true},
    // negative axes
    PullReshapeParams{element::f32, {5, 10, 15}, {1, 5, 10, 15, 1}, {1, 5, 1}, {-2, -3}, {1, 2}},
    // special zero true
    PullReshapeParams{element::f32, {5, 10, 15}, {5, 0, -1, 1}, {5, 10, 1}, {2}, {2}, false, true},
};

INSTANTIATE_TEST_SUITE_P(PullReshapeThroughReduceMean, PullReshapeThroughReduceMean, ValuesIn(reduce_mean_reshape_params));
// Parameterized fixture: same pattern as PullReshapeThroughReduceMean, but the
// reduction op under test is ReduceLogicalOr (boolean data).
class PullReshapeThroughReduceLogicalOr
    : public WithParamInterface<PullReshapeParams>,
      public TransformationTestsF {
};
// Positive case for ReduceLogicalOr. Only CONST_VALUES comparison is enabled here
// (ACCURACY is not enabled, unlike the ReduceMean suite - presumably because of
// the boolean element type; confirm if extending this suite).
TEST_P(PullReshapeThroughReduceLogicalOr, PullReshapeThroughReduceLogicalOrPattern) {
    const auto& params = GetParam();
    model = generate_reshape_model<ReduceLogicalOr>(params.in_type,
                                                    params.in_shape,
                                                    params.target_shape,
                                                    params.reduce_axes,
                                                    params.keep_dims,
                                                    params.reshape_special_zero);
    manager.register_pass<pass::PullReshapeThroughReduce>();
    model_ref = generate_reshape_ref_model<ReduceLogicalOr>(params.in_type,
                                                            params.in_shape,
                                                            params.ref_target_shape,
                                                            params.ref_reduce_axes,
                                                            params.keep_dims,
                                                            params.reshape_special_zero);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
// Each entry: {in_type, in_shape, reshape target shape, expected reshape target shape,
// reduce axes, expected reduce axes[, keep_dims]}; omitted trailing bools default to false.
static const std::vector<PullReshapeParams> reduce_logical_or_reshape_params = {
    PullReshapeParams{element::boolean, {5, 10, 15}, {1, 5, 10, 15}, {1, 5, 15}, {2}, {1}},
    // keep_dims=true
    PullReshapeParams{element::boolean, {1, 10, 1, 20}, {1, 1, 10, 1, 20}, {1, 1, 1, 1, 1}, {2, 4}, {1, 3}, true},
};
INSTANTIATE_TEST_SUITE_P(PullReshapeThroughReduceLogicalOr, PullReshapeThroughReduceLogicalOr, ValuesIn(reduce_logical_or_reshape_params));
// Negative case: the input shape contains a dynamic dimension, so the
// transformation must not be applied (no model_ref set - compared against original).
TEST_F(TransformationTestsF, PullReshapeThroughReduceMeanSkipIfDynamicInput) {
    const PartialShape dynamic_input_shape{5, Dimension::dynamic(), 15};
    model = generate_reshape_model<ReduceMean>(element::f32, dynamic_input_shape, {1, 5, 10, 15}, {2});
    manager.register_pass<pass::PullReshapeThroughReduce>();
}
// Negative case: the reduce axis (0) is exactly the axis the Reshape inserts,
// so the transformation must be skipped.
TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfTheSameAxes) {
    model = generate_reshape_model<ReduceMean>(element::f32, {5, 10, 15}, {1, 5, 10, 15}, {0});
    manager.register_pass<pass::PullReshapeThroughReduce>();
}
// Negative case: scalar input reshaped to {1} - the reduce axis coincides with
// the axis introduced by the Reshape, so the transformation must be skipped.
TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfTheSameAxesScalarCase) {
    model = generate_reshape_model<ReduceMean>(element::f32, {}, {1}, {0});
    manager.register_pass<pass::PullReshapeThroughReduce>();
}
// Negative case: scalar input reshaped to {1, 1, 1} with reduce axis 1 - the
// reduced axis is one of the Reshape-introduced axes, so the pass must be skipped.
TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfTheSameAxesScalarCase2) {
    model = generate_reshape_model<ReduceMean>(element::f32, {}, {1, 1, 1}, {1});
    manager.register_pass<pass::PullReshapeThroughReduce>();
}
// Negative case: the Reduce axes input is a Parameter (non-constant), so the
// transformation must not be applied.
TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfNonConstAxes) {
    const auto data = std::make_shared<Parameter>(element::f32, PartialShape{5, 10, 15});
    const auto shape_const = Constant::create(element::i64, Shape{4}, {1, 5, 10, 15});
    const auto reshape_node = std::make_shared<Reshape>(data, shape_const, false);
    const auto axes_param = std::make_shared<Parameter>(element::i64, PartialShape{});
    const auto mean = std::make_shared<ReduceMean>(reshape_node, axes_param);
    model = std::make_shared<Model>(NodeVector{mean}, ParameterVector{data, axes_param});
    manager.register_pass<pass::PullReshapeThroughReduce>();
}
// Negative case: the Reshape target shape comes from a Parameter, so the Reshape
// output shape is dynamic and the transformation must be skipped.
TEST_F(TransformationTestsF, PullReshapeThroughReduceMeanSkipIfDynamicReshapeOutputShape) {
    const auto data = std::make_shared<Parameter>(element::f32, PartialShape{5, 10, 15});
    const auto shape_param = std::make_shared<Parameter>(element::i32, PartialShape{4});
    const auto reshape_node = std::make_shared<Reshape>(data, shape_param, false);
    const auto axes_const = Constant::create(element::i64, Shape{}, {2});
    const auto mean = std::make_shared<ReduceMean>(reshape_node, axes_const);
    model = std::make_shared<Model>(NodeVector{mean}, ParameterVector{data, shape_param});
    manager.register_pass<pass::PullReshapeThroughReduce>();
}

View File

@@ -0,0 +1,147 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <openvino/core/model.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/common_optimizations/reduce_reshape_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ov;
using namespace opset9;
namespace {
// Builds the pattern under test:
// Parameter -> ReduceType(axes, keep_dims) -> Reshape(target, special_zero) -> Result.
template <typename ReduceType>
std::shared_ptr<Model> generate_model(element::Type in_type,
                                      PartialShape in_shape,
                                      std::vector<int64_t> reshape_target_shape,
                                      std::vector<int64_t> reduce_axes,
                                      bool reduce_keep_dims,
                                      bool reshape_special_zero) {
    const auto param = std::make_shared<Parameter>(in_type, in_shape);
    const auto axes = Constant::create(element::i64, Shape{reduce_axes.size()}, reduce_axes);
    const auto reduce = std::make_shared<ReduceType>(param, axes, reduce_keep_dims);
    const auto shape_const = Constant::create(element::i64, Shape{reshape_target_shape.size()}, reshape_target_shape);
    const auto reshape = std::make_shared<Reshape>(reduce, shape_const, reshape_special_zero);
    return std::make_shared<Model>(NodeVector{reshape}, ParameterVector{param});
}

// Builds the expected graph after fusion:
// Parameter -> ReduceType(axes, keep_dims=true) -> Result.
template <typename ReduceType>
std::shared_ptr<Model> generate_ref_model(element::Type in_type,
                                          PartialShape in_shape,
                                          std::vector<int64_t> reduce_axes) {
    const auto param = std::make_shared<Parameter>(in_type, in_shape);
    const auto axes = Constant::create(element::i64, Shape{reduce_axes.size()}, reduce_axes);
    const auto reduce = std::make_shared<ReduceType>(param, axes, true);
    return std::make_shared<Model>(NodeVector{reduce}, ParameterVector{param});
}
}  // namespace
// Parameters for the positive ReduceReshapeFusion pattern tests.
struct ReduceReshapeFusionParams {
    element::Type in_type;                      // element type of the model input
    PartialShape in_shape;                      // input shape
    std::vector<int64_t> reshape_target_shape;  // target shape of the Reshape that follows Reduce
    std::vector<int64_t> reduce_axes;           // Reduce axes
    bool keep_dims;                             // Reduce keep_dims attribute
    bool reshape_special_zero;                  // Reshape special_zero attribute
};
// Parameterized fixture: applies ReduceReshapeFusion to a ReduceMean+Reshape model
// and compares against a keep_dims=true ReduceMean reference.
class ReduceReshapeFusion
    : public WithParamInterface<ReduceReshapeFusionParams>,
      public TransformationTestsF {
};
// Positive case: Reduce(keep_dims=false) followed by a Reshape that only restores
// the reduced dimensions as 1s is fused into a single Reduce(keep_dims=true).
TEST_P(ReduceReshapeFusion, ReduceReshapeFusionPattern) {
    const auto& params = GetParam();
    model = generate_model<ReduceMean>(params.in_type,
                                       params.in_shape,
                                       params.reshape_target_shape,
                                       params.reduce_axes,
                                       params.keep_dims,
                                       params.reshape_special_zero);
    manager.register_pass<pass::ReduceReshapeFusion>();
    model_ref = generate_ref_model<ReduceMean>(params.in_type, params.in_shape, params.reduce_axes);
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
// Each entry: {in_type, in_shape, reshape target shape, reduce axes, keep_dims, special_zero}.
// All cases expect the Reshape to be folded into a keep_dims=true Reduce.
static const std::vector<ReduceReshapeFusionParams> params = {
    ReduceReshapeFusionParams{element::f32, {5, 10, 15}, {5, 1, 15}, {1}, false, false},
    // many axes
    ReduceReshapeFusionParams{element::f32, {5, 10, 15, 20}, {5, 1, 1, 20}, {1, 2}, false, false},
    // many axes 2
    ReduceReshapeFusionParams{element::f32, {5, 10, 15, 1}, {5, 1, 1, 1}, {1, 2}, false, false},
    // special zero
    ReduceReshapeFusionParams{element::f32, {5, 10, 15, 20}, {-1, 0, 1, 1}, {2, 3}, false, true},
    // negative axes
    ReduceReshapeFusionParams{element::f32, {5, 10, 15, 20}, {5, 1, 15, 1}, {-1, -3}, false, false},
    // negative axes 2
    ReduceReshapeFusionParams{element::f32, {5, 10, 15, 20}, {5, 1, 1, 20}, {-2, -3}, false, false},
};
INSTANTIATE_TEST_SUITE_P(ReduceReshapeFusion, ReduceReshapeFusion, ValuesIn(params));
// Positive case with ReduceLogicalOr on boolean data: the keep_dims-restoring
// Reshape is fused into the reduction.
TEST_F(TransformationTestsF, ReduceOrReshapeFusion) {
    model = generate_model<ReduceLogicalOr>(element::boolean, {5, 10, 15, 20}, {5, 1, 1, 20}, {1, 2}, false, false);
    manager.register_pass<pass::ReduceReshapeFusion>();
    model_ref = generate_ref_model<ReduceLogicalOr>(element::boolean, {5, 10, 15, 20}, {1, 2});
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
// Negative case: the Reshape target shape contains a 1 at a position that is not
// a reduced axis, so it is not a pure keep_dims restoration - must not be fused.
TEST_F(TransformationTestsF, ReduceMeanReshapeFusionSkipIfOneInNotAxisPosition) {
    model = generate_model<ReduceMean>(element::f32, {5, 10, 15, 1}, {5, 1, 1, 1, 1}, {1, 2}, false, false);
    manager.register_pass<pass::ReduceReshapeFusion>();
}
// Negative case: the target shape {20, 1, 1, 5} reorders the remaining
// dimensions, so the Reshape is not a keep_dims restoration - must not be fused.
TEST_F(TransformationTestsF, ReduceMeanReshapeFusionSkipIfReshapeNotCompatible) {
    model = generate_model<ReduceMean>(element::f32, {5, 10, 15, 20}, {20, 1, 1, 5}, {1, 2}, false, false);
    manager.register_pass<pass::ReduceReshapeFusion>();
}
// Negative case: the Reshape collapses the reduce output ({5, 10} -> {50}) to a
// lower rank than the Reduce input, so fusion is not applicable.
TEST_F(TransformationTestsF, ReduceMeanReshapeFusion_SkipIfReshapeRankLessThanReduceRank) {
    model = generate_model<ReduceMean>(element::f32, {5, 10, 15}, {50}, {2}, false, false);
    manager.register_pass<pass::ReduceReshapeFusion>();
}
// Negative case: Reduce already has keep_dims=true, so there is no dropped
// dimension for the Reshape to restore - fusion must be skipped.
TEST_F(TransformationTestsF, ReduceMeanReshapeFusion_SkipIfKeepDims) {
    model = generate_model<ReduceMean>(element::f32, {5, 10, 15}, {5, 1, 15}, {1}, true, false);
    manager.register_pass<pass::ReduceReshapeFusion>();
}
// Negative case: the Reduce axes input is a Parameter (non-constant), so the
// fusion must not be applied.
TEST_F(TransformationTestsF, ReduceMeanReshapeFusionSkipIfNonConstReduceAxes) {
    const auto data = std::make_shared<Parameter>(element::f32, PartialShape{5, 10, 15});
    const auto axes_param = std::make_shared<Parameter>(element::i64, PartialShape{1});
    const auto mean = std::make_shared<ReduceMean>(data, axes_param);
    const auto shape_const = Constant::create(element::i64, Shape{3}, {5, 1, 15});
    const auto reshape_node = std::make_shared<Reshape>(mean, shape_const, false);
    model = std::make_shared<Model>(NodeVector{reshape_node}, ParameterVector{data, axes_param});
    manager.register_pass<pass::ReduceReshapeFusion>();
}
// Negative case: the Reshape target shape is a Parameter (non-constant), so the
// fusion must not be applied.
TEST_F(TransformationTestsF, ReduceMeanReshapeFusionSkipIfNonConstReshapeTargetShape) {
    const auto data = std::make_shared<Parameter>(element::f32, PartialShape{5, 10, 15});
    const auto axes_const = Constant::create(element::i64, Shape{}, {1});
    const auto mean = std::make_shared<ReduceMean>(data, axes_const);
    const auto shape_param = std::make_shared<Parameter>(element::i64, PartialShape{3});
    const auto reshape_node = std::make_shared<Reshape>(mean, shape_param, false);
    model = std::make_shared<Model>(NodeVector{reshape_node}, ParameterVector{data, shape_param});
    manager.register_pass<pass::ReduceReshapeFusion>();
}