Eliminates Nop Broadcast/Tile and Slice Before GatherElements (#18614)
This commit is contained in:
parent
53fe969773
commit
b51069dd79
@ -20,6 +20,8 @@ class TRANSFORMATIONS_API EliminateSplit;
|
||||
class TRANSFORMATIONS_API EliminateSplitConcat;
|
||||
class TRANSFORMATIONS_API EliminateSqueeze;
|
||||
class TRANSFORMATIONS_API EliminateTranspose;
|
||||
class TRANSFORMATIONS_API EliminateNopBroadcast;
|
||||
class TRANSFORMATIONS_API NopSliceBeforeGatherElements;
|
||||
class TRANSFORMATIONS_API NopElimination;
|
||||
|
||||
} // namespace pass
|
||||
@ -130,3 +132,25 @@ public:
|
||||
OPENVINO_RTTI("EliminateSplitConcat", "0");
|
||||
EliminateSplitConcat();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_comm on_api
|
||||
* @brief EliminateNopBroadcast eliminates broadcast or tile with all ones on the second input
|
||||
*/
|
||||
class ov::pass::EliminateNopBroadcast : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("EliminateNopBroadcast", "0");
|
||||
EliminateNopBroadcast();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_comm on_api
|
||||
* @brief NopSliceBeforeGatherElements eliminates slice before GElements if slicing from 0
|
||||
* It is valid since GatherElements doesn't support negative indices and Slice won't affect
|
||||
* indexing of elements in the original tensor that GatherElements would like to take
|
||||
*/
|
||||
class ov::pass::NopSliceBeforeGatherElements : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("NopSliceBeforeGatherElements", "0");
|
||||
NopSliceBeforeGatherElements();
|
||||
};
|
||||
|
@ -221,6 +221,7 @@ TRANSFORMATIONS_API bool is_dequantization_subgraph(const Output<Node>& node);
|
||||
TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr<Node>& eltwise,
|
||||
const Output<Node>& constant,
|
||||
const Output<Node>& non_constant_input);
|
||||
TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int64_t& v);
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
@ -788,6 +788,47 @@ pass::EliminateScatterUpdate::EliminateScatterUpdate() {
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
// Registers a matcher rooted at Broadcast (v1 / v3) or Tile (v0). A matched node is
// bypassed when its second input is a constant made only of ones.
ov::pass::EliminateNopBroadcast::EliminateNopBroadcast() {
    MATCHER_SCOPE(EliminateNopBroadcast);

    // Accept only nodes whose input and output ranks are both static and equal.
    const auto keeps_static_rank = [](std::shared_ptr<Node> candidate) {
        const auto data_rank = candidate->get_input_partial_shape(0).rank();
        const auto result_rank = candidate->get_output_partial_shape(0).rank();
        if (!data_rank.is_static() || !result_rank.is_static())
            return false;
        return data_rank == result_rank;
    };
    auto root = pattern::wrap_type<op::v1::Broadcast, op::v3::Broadcast, op::v0::Tile>(
        pattern::op::as_value_predicate(keeps_static_rank));

    ov::matcher_pass_callback callback = [](pattern::Matcher& matcher) {
        const auto matched = matcher.get_match_root();
        const bool second_input_all_ones =
            op::util::is_constant_and_all_values_equal_int(matched->input_value(1), 1);
        if (!second_input_all_ones)
            return false;
        // Route consumers of the matched node straight to its data input.
        return replace_output_update_name(matched->output(0), matched->input_value(0));
    };

    auto pattern_matcher = std::make_shared<pattern::Matcher>(root, matcher_name);
    register_matcher(pattern_matcher, callback);
}
|
||||
|
||||
// Registers a matcher for the Slice (v8) -> GatherElements (v6) pattern, where the Slice
// feeds the first input of GatherElements. When the Slice start values are all 0 and its
// step values are all 1, GatherElements is reconnected directly to the Slice data input.
ov::pass::NopSliceBeforeGatherElements::NopSliceBeforeGatherElements() {
    MATCHER_SCOPE(NopSliceBeforeGatherElements);

    auto slice_pattern = pattern::wrap_type<op::v8::Slice>();
    auto gather_pattern = pattern::wrap_type<op::v6::GatherElements>({slice_pattern, pattern::any_input()});

    ov::matcher_pass_callback callback = [=](pattern::Matcher& matcher) {
        const auto& matched_nodes = matcher.get_pattern_map();
        const auto& slice_op = matched_nodes.at(slice_pattern);

        // Slice inputs: 1 is start, 3 is step.
        const bool zero_start = op::util::is_constant_and_all_values_equal_int(slice_op->input_value(1), 0);
        const bool unit_step = op::util::is_constant_and_all_values_equal_int(slice_op->input_value(3), 1);
        if (!(zero_start && unit_step))
            return false;

        // Bypass the Slice: GatherElements now reads the original, unsliced tensor.
        const auto& gather_op = matched_nodes.at(gather_pattern);
        gather_op->input(0).replace_source_output(slice_op->input_value(0));
        return true;
    };

    auto pattern_matcher = std::make_shared<pattern::Matcher>(gather_pattern, matcher_name);
    register_matcher(pattern_matcher, callback);
}
|
||||
|
||||
ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
|
||||
// shape-agnostic transformations
|
||||
ADD_MATCHER_FOR_THIS(EliminatePad)
|
||||
@ -807,6 +848,8 @@ ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
|
||||
ADD_MATCHER_FOR_THIS(EliminateSqueeze)
|
||||
ADD_MATCHER_FOR_THIS(EliminateUnsqueeze)
|
||||
ADD_MATCHER_FOR_THIS(EliminateBroadcast)
|
||||
ADD_MATCHER_FOR_THIS(EliminateNopBroadcast)
|
||||
ADD_MATCHER_FOR_THIS(NopSliceBeforeGatherElements)
|
||||
ADD_MATCHER_FOR_THIS(EliminateGather)
|
||||
}
|
||||
}
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <ngraph/op/util/op_annotations.hpp>
|
||||
#include <openvino/core/validation_util.hpp>
|
||||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
@ -353,6 +354,18 @@ float cast_eps_to_float(double eps_d) {
|
||||
return eps_f;
|
||||
}
|
||||
|
||||
bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int64_t& v) {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
if (const auto& constant = ov::get_constant_from_source(output)) {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
const auto& values = constant->cast_vector<int64_t>();
|
||||
return std::all_of(values.begin(), values.end(), [&](const int64_t& i) {
|
||||
return i == v;
|
||||
});
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
@ -1338,3 +1338,87 @@ TEST(nop_elimination, gather_to_squeeze) {
|
||||
run_and_check(func_axis_2);
|
||||
run_and_check(func_axis_3);
|
||||
}
|
||||
|
||||
// v1 Broadcast with a target shape of all ones: the reference model is the same graph
// without the Broadcast.
TEST_F(TransformationTestsF, Nopv1Broadcast) {
    {
        // Model to transform: Parameter -> Broadcast({1, 1, 1, 1}) -> Relu -> Result.
        auto input = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
        auto target_shape = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1});
        auto nop_broadcast = std::make_shared<op::v1::Broadcast>(input, target_shape);
        auto activation = std::make_shared<op::v0::Relu>(nop_broadcast);
        auto output = std::make_shared<opset10::Result>(activation);
        model = std::make_shared<ov::Model>(ResultVector{output}, ParameterVector{input});
        manager.register_pass<ov::pass::EliminateNopBroadcast>();
    }
    {
        // Reference model: Parameter -> Relu -> Result.
        auto input = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
        auto activation = std::make_shared<op::v0::Relu>(input);
        auto output = std::make_shared<opset10::Result>(activation);
        model_ref = std::make_shared<ov::Model>(ResultVector{output}, ParameterVector{input});
    }
}
|
||||
|
||||
// v3 Broadcast with a target shape of all ones: the reference model is the same graph
// without the Broadcast.
TEST_F(TransformationTestsF, Nopv3Broadcast) {
    {
        // Model to transform: Parameter -> Broadcast({1, 1, 1, 1}) -> Relu -> Result.
        auto input = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
        auto target_shape = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1});
        auto nop_broadcast = std::make_shared<op::v3::Broadcast>(input, target_shape);
        auto activation = std::make_shared<op::v0::Relu>(nop_broadcast);
        auto output = std::make_shared<opset10::Result>(activation);
        model = std::make_shared<ov::Model>(ResultVector{output}, ParameterVector{input});
        manager.register_pass<ov::pass::EliminateNopBroadcast>();
    }
    {
        // Reference model: Parameter -> Relu -> Result.
        auto input = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
        auto activation = std::make_shared<op::v0::Relu>(input);
        auto output = std::make_shared<opset10::Result>(activation);
        model_ref = std::make_shared<ov::Model>(ResultVector{output}, ParameterVector{input});
    }
}
|
||||
|
||||
// Tile with repeats of all ones: the reference model is the same graph without the Tile.
TEST_F(TransformationTestsF, NopTile) {
    {
        // Model to transform: Parameter -> Tile({1, 1, 1, 1}) -> Relu -> Result.
        auto input = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
        auto unit_repeats = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1});
        auto nop_tile = std::make_shared<op::v0::Tile>(input, unit_repeats);
        auto activation = std::make_shared<op::v0::Relu>(nop_tile);
        auto output = std::make_shared<opset10::Result>(activation);
        model = std::make_shared<ov::Model>(ResultVector{output}, ParameterVector{input});
        manager.register_pass<ov::pass::EliminateNopBroadcast>();
    }
    {
        // Reference model: Parameter -> Relu -> Result.
        auto input = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
        auto activation = std::make_shared<op::v0::Relu>(input);
        auto output = std::make_shared<opset10::Result>(activation);
        model_ref = std::make_shared<ov::Model>(ResultVector{output}, ParameterVector{input});
    }
}
|
||||
|
||||
// Slice (start 0, step 1) feeding the data input of GatherElements: the reference model
// gathers directly from the unsliced Parameter.
TEST_F(TransformationTestsF, NopSliceBeforeGatherElements) {
    {
        // Model to transform: Parameter -> Slice -> GatherElements(axis 2) -> Relu -> Result.
        auto input = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});

        auto slice_start = opset10::Constant::create(element::i32, Shape{1}, {0});
        auto slice_stop = opset10::Constant::create(element::i32, Shape{1}, {2});
        auto slice_step = opset10::Constant::create(element::i32, Shape{1}, {1});
        auto slice_axis = opset10::Constant::create(element::i32, Shape{1}, {-1});
        auto sliced = std::make_shared<op::v8::Slice>(input, slice_start, slice_stop, slice_step, slice_axis);

        auto gather_indices = std::make_shared<opset10::Parameter>(element::i64, PartialShape{-1, -1, -1, -1});
        auto gathered = std::make_shared<op::v6::GatherElements>(sliced, gather_indices, 2);

        auto activation = std::make_shared<op::v0::Relu>(gathered);
        auto output = std::make_shared<opset10::Result>(activation);
        model = std::make_shared<ov::Model>(ResultVector{output}, ParameterVector{input, gather_indices});
        manager.register_pass<ov::pass::NopSliceBeforeGatherElements>();
    }
    {
        // Reference model: Parameter -> GatherElements(axis 2) -> Relu -> Result.
        auto input = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
        auto gather_indices = std::make_shared<opset10::Parameter>(element::i64, PartialShape{-1, -1, -1, -1});

        auto gathered = std::make_shared<op::v6::GatherElements>(input, gather_indices, 2);

        auto activation = std::make_shared<op::v0::Relu>(gathered);
        auto output = std::make_shared<opset10::Result>(activation);
        model_ref = std::make_shared<ov::Model>(ResultVector{output}, ParameterVector{input, gather_indices});
    }
}
|
||||
|
Loading…
Reference in New Issue
Block a user