Eliminates Nop Broadcast/Tile and Slice Before GatherElements (#18614)

This commit is contained in:
Evgenya Stepyreva 2023-07-19 14:12:13 +04:00 committed by GitHub
parent 53fe969773
commit b51069dd79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 165 additions and 0 deletions

View File

@ -20,6 +20,8 @@ class TRANSFORMATIONS_API EliminateSplit;
class TRANSFORMATIONS_API EliminateSplitConcat;
class TRANSFORMATIONS_API EliminateSqueeze;
class TRANSFORMATIONS_API EliminateTranspose;
class TRANSFORMATIONS_API EliminateNopBroadcast;
class TRANSFORMATIONS_API NopSliceBeforeGatherElements;
class TRANSFORMATIONS_API NopElimination;
} // namespace pass
@ -130,3 +132,25 @@ public:
OPENVINO_RTTI("EliminateSplitConcat", "0");
EliminateSplitConcat();
};
/**
* @ingroup ie_transformation_comm on_api
* @brief EliminateNopBroadcast eliminates broadcast or tile with all ones on the second input
*/
class ov::pass::EliminateNopBroadcast : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("EliminateNopBroadcast", "0");
EliminateNopBroadcast();
};
/**
* @ingroup ie_transformation_comm on_api
* @brief NopSliceBeforeGatherElements eliminates slice before GElements if slicing from 0
* It is valid since GatherElements doesn't support negative indices and Slice won't affect
* indexing of elements in the original tensor that GatherElements would like to take
*/
class ov::pass::NopSliceBeforeGatherElements : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("NopSliceBeforeGatherElements", "0");
NopSliceBeforeGatherElements();
};

View File

@ -221,6 +221,7 @@ TRANSFORMATIONS_API bool is_dequantization_subgraph(const Output<Node>& node);
TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr<Node>& eltwise,
const Output<Node>& constant,
const Output<Node>& non_constant_input);
TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int64_t& v);
} // namespace util
} // namespace op
} // namespace ov

View File

@ -788,6 +788,47 @@ pass::EliminateScatterUpdate::EliminateScatterUpdate() {
this->register_matcher(m, callback);
}
ov::pass::EliminateNopBroadcast::EliminateNopBroadcast() {
MATCHER_SCOPE(EliminateNopBroadcast);
auto root = pattern::wrap_type<op::v1::Broadcast, op::v3::Broadcast, op::v0::Tile>(
pattern::op::as_value_predicate([](std::shared_ptr<Node> node) {
auto input_rank = node->get_input_partial_shape(0).rank();
auto output_rank = node->get_output_partial_shape(0).rank();
return input_rank.is_static() && output_rank.is_static() && input_rank == output_rank;
}));
ov::matcher_pass_callback matcher_pass_callback = [](pattern::Matcher& m) {
const auto& op = m.get_match_root();
if (op::util::is_constant_and_all_values_equal_int(op->input_value(1), 1))
return replace_output_update_name(op->output(0), op->input_value(0));
return false;
};
auto m = std::make_shared<pattern::Matcher>(root, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ov::pass::NopSliceBeforeGatherElements::NopSliceBeforeGatherElements() {
MATCHER_SCOPE(NopSliceBeforeGatherElements);
auto slice = pattern::wrap_type<op::v8::Slice>();
auto gather = pattern::wrap_type<op::v6::GatherElements>({slice, pattern::any_input()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_node = m.get_pattern_map();
const auto& slice_node = pattern_to_node.at(slice);
bool start_from_zero = op::util::is_constant_and_all_values_equal_int(slice_node->input_value(1), 0);
bool step_is_one = op::util::is_constant_and_all_values_equal_int(slice_node->input_value(3), 1);
if (!start_from_zero || !step_is_one)
return false;
const auto& gather_node = pattern_to_node.at(gather);
gather_node->input(0).replace_source_output(slice_node->input_value(0));
return true;
};
auto m = std::make_shared<pattern::Matcher>(gather, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
// shape-agnostic transformations
ADD_MATCHER_FOR_THIS(EliminatePad)
@ -807,6 +848,8 @@ ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
ADD_MATCHER_FOR_THIS(EliminateSqueeze)
ADD_MATCHER_FOR_THIS(EliminateUnsqueeze)
ADD_MATCHER_FOR_THIS(EliminateBroadcast)
ADD_MATCHER_FOR_THIS(EliminateNopBroadcast)
ADD_MATCHER_FOR_THIS(NopSliceBeforeGatherElements)
ADD_MATCHER_FOR_THIS(EliminateGather)
}
}

View File

@ -9,6 +9,7 @@
#include <functional>
#include <memory>
#include <ngraph/op/util/op_annotations.hpp>
#include <openvino/core/validation_util.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/gather.hpp>
@ -353,6 +354,18 @@ float cast_eps_to_float(double eps_d) {
return eps_f;
}
bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int64_t& v) {
OPENVINO_SUPPRESS_DEPRECATED_START
if (const auto& constant = ov::get_constant_from_source(output)) {
OPENVINO_SUPPRESS_DEPRECATED_END
const auto& values = constant->cast_vector<int64_t>();
return std::all_of(values.begin(), values.end(), [&](const int64_t& i) {
return i == v;
});
}
return false;
}
} // namespace util
} // namespace op
} // namespace ov

View File

@ -1338,3 +1338,87 @@ TEST(nop_elimination, gather_to_squeeze) {
run_and_check(func_axis_2);
run_and_check(func_axis_3);
}
TEST_F(TransformationTestsF, Nopv1Broadcast) {
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto broadcast_shape = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1});
auto broadcast = std::make_shared<op::v1::Broadcast>(data, broadcast_shape);
auto relu = std::make_shared<op::v0::Relu>(broadcast);
auto result = std::make_shared<opset10::Result>(relu);
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
manager.register_pass<ov::pass::EliminateNopBroadcast>();
}
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto relu = std::make_shared<op::v0::Relu>(data);
auto result = std::make_shared<opset10::Result>(relu);
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, Nopv3Broadcast) {
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto broadcast_shape = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1});
auto broadcast = std::make_shared<op::v3::Broadcast>(data, broadcast_shape);
auto relu = std::make_shared<op::v0::Relu>(broadcast);
auto result = std::make_shared<opset10::Result>(relu);
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
manager.register_pass<ov::pass::EliminateNopBroadcast>();
}
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto relu = std::make_shared<op::v0::Relu>(data);
auto result = std::make_shared<opset10::Result>(relu);
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, NopTile) {
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto repeats = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1});
auto tile = std::make_shared<op::v0::Tile>(data, repeats);
auto relu = std::make_shared<op::v0::Relu>(tile);
auto result = std::make_shared<opset10::Result>(relu);
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
manager.register_pass<ov::pass::EliminateNopBroadcast>();
}
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto relu = std::make_shared<op::v0::Relu>(data);
auto result = std::make_shared<opset10::Result>(relu);
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, NopSliceBeforeGatherElements) {
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto start = opset10::Constant::create(element::i32, Shape{1}, {0});
auto stop = opset10::Constant::create(element::i32, Shape{1}, {2});
auto step = opset10::Constant::create(element::i32, Shape{1}, {1});
auto axis = opset10::Constant::create(element::i32, Shape{1}, {-1});
auto slice = std::make_shared<op::v8::Slice>(data, start, stop, step, axis);
auto indices = std::make_shared<opset10::Parameter>(element::i64, PartialShape{-1, -1, -1, -1});
auto gather_elements = std::make_shared<op::v6::GatherElements>(slice, indices, 2);
auto relu = std::make_shared<op::v0::Relu>(gather_elements);
auto result = std::make_shared<opset10::Result>(relu);
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data, indices});
manager.register_pass<ov::pass::NopSliceBeforeGatherElements>();
}
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto indices = std::make_shared<opset10::Parameter>(element::i64, PartialShape{-1, -1, -1, -1});
auto gather_elements = std::make_shared<op::v6::GatherElements>(data, indices, 2);
auto relu = std::make_shared<op::v0::Relu>(gather_elements);
auto result = std::make_shared<opset10::Result>(relu);
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data, indices});
}
}