[Transformations] Added interchangeable reshape elimination (#9691)
* [Transformations] Added interchangeable reshape elimination * Applied comments #2 * returned Reshape in condition * applied comments #3 * applied comments #4 * added comment in plugin with reason about transformation
This commit is contained in:
committed by
GitHub
parent
a002b26294
commit
fce49e6d80
@@ -18,11 +18,11 @@ class TRANSFORMATIONS_API ReshapeSequenceFusion;
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief ReshpaeSequenceFusion fuses sequence of Reshape operation into single Reshape
|
||||
* @brief ReshapeSequenceFusion fuses sequence of Reshape operation into single Reshape or eliminates full redundant sequence
|
||||
*/
|
||||
|
||||
class ngraph::pass::ReshapeSequenceFusion: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ReshapeSequenceFusion();
|
||||
ReshapeSequenceFusion(bool use_shape_for_elimination = true);
|
||||
};
|
||||
|
||||
@@ -153,7 +153,7 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
|
||||
common_fusions->add_matcher<ngraph::pass::DivideFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SubtractFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
|
||||
common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>(m_use_shapes);
|
||||
common_fusions->set_name("ngraph::pass::CommonFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::BinarizeWeights>();
|
||||
|
||||
@@ -102,16 +102,25 @@ static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
|
||||
if (ov::as_type_ptr<opset3::Squeeze>(input_node) ||
|
||||
ov::as_type_ptr<opset3::Unsqueeze>(input_node) ||
|
||||
ov::as_type_ptr<opset3::Reshape>(input_node)) {
|
||||
if (input_node->get_output_target_inputs(0).size() != 1)
|
||||
return false;
|
||||
|
||||
auto shape = node->get_output_shape(0);
|
||||
std::vector<int64_t> vi;
|
||||
vi.assign(shape.begin(), shape.end());
|
||||
auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
|
||||
auto new_reshape =
|
||||
make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
|
||||
new_reshape->set_friendly_name(node->get_friendly_name());
|
||||
copy_runtime_info({input_node, node}, new_reshape);
|
||||
replace_node(node, new_reshape);
|
||||
return true;
|
||||
|
||||
// remove interchangeable nodes
|
||||
if (input_node->get_input_partial_shape(0).is_static() && input_node->get_input_shape(0) == shape) {
|
||||
return replace_output_update_name(node->output(0), input_node->input_value(0));
|
||||
} else {
|
||||
std::vector<int64_t> vi;
|
||||
vi.assign(shape.begin(), shape.end());
|
||||
auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
|
||||
auto new_reshape =
|
||||
make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
|
||||
new_reshape->set_friendly_name(node->get_friendly_name());
|
||||
copy_runtime_info({input_node, node}, new_reshape);
|
||||
replace_node(node, new_reshape);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
|
||||
@@ -55,7 +55,7 @@ bool has_valid_pattern(const ov::Output<ov::Node>& node_out) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() {
|
||||
ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion(bool use_shape_for_elimination) {
|
||||
MATCHER_SCOPE(ReshapeSequenceFusion);
|
||||
auto reshape_input = pattern::any_input();
|
||||
auto reshape_a_pattern = pattern::wrap_type<opset8::Constant>();
|
||||
@@ -87,9 +87,21 @@ ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() {
|
||||
input = node->input_value(0);
|
||||
}
|
||||
|
||||
reshape->input(0).replace_source_output(input);
|
||||
copy_runtime_info(nodes, reshape);
|
||||
return false;
|
||||
// remove redundant reshapes
|
||||
bool replaced = false;
|
||||
if (use_shape_for_elimination && input.get_partial_shape().is_static() && reshape->get_output_partial_shape(0).is_static() &&
|
||||
input.get_shape() == reshape->get_output_shape(0)) {
|
||||
// in case if elimination is not allowed we still can eliminate all transposes except last one
|
||||
replaced = replace_output_update_name(reshape->output(0), input);
|
||||
}
|
||||
|
||||
if (!replaced) {
|
||||
reshape->input(0).replace_source_output(input);
|
||||
copy_runtime_info(nodes, reshape);
|
||||
return false; // because root node wasn't replaced
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_b, matcher_name);
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "transformations/convert_precision.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
#include "rnn_sequences_optimization.hpp"
|
||||
#include "transformations/common_optimizations/reshape_sequence_fusion.hpp"
|
||||
|
||||
namespace MKLDNNPlugin {
|
||||
|
||||
@@ -34,6 +35,8 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ngraph::Function> &nGraphF
|
||||
if (!ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(nGraphFunc)) {
|
||||
manager.register_pass<ReshapeFullyConnectedFusion>();
|
||||
}
|
||||
// after transformation "MoveEltwiseUpThroughDataMov" there can be Reshape sequences that should be eliminated or fused
|
||||
manager.register_pass<ngraph::pass::ReshapeSequenceFusion>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});
|
||||
|
||||
|
||||
@@ -140,17 +140,48 @@ TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
|
||||
pass_manager.register_pass<pass::NopElimination>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
bool reshape_is_missing = true;
|
||||
bool movement_are_missing = true;
|
||||
for (auto node : f->get_ops()) {
|
||||
if (node->get_friendly_name() == "reshape") {
|
||||
reshape_is_missing = false;
|
||||
ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
|
||||
auto original_names = ngraph::getFusedNamesVector(node);
|
||||
sort(original_names.begin(), original_names.end());
|
||||
ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
|
||||
if (node->get_friendly_name() == "reshape" || node->get_friendly_name() == "squeeze") {
|
||||
movement_are_missing = false;
|
||||
}
|
||||
}
|
||||
ASSERT_FALSE(reshape_is_missing);
|
||||
ASSERT_TRUE(movement_are_missing);
|
||||
}
|
||||
|
||||
TEST(nop_elimination, squeeze_unsqueeze_elimination) {
|
||||
std::shared_ptr<Function> f;
|
||||
{
|
||||
auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
|
||||
|
||||
auto relu = std::make_shared<opset4::Relu>(arg);
|
||||
relu->set_friendly_name("relu");
|
||||
|
||||
auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto squeeze = std::make_shared<opset4::Squeeze>(relu, squeeze_axes);
|
||||
squeeze->set_friendly_name("squeeze");
|
||||
|
||||
auto unsqueeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto unsqueeze = std::make_shared<opset4::Unsqueeze>(squeeze, unsqueeze_axes);
|
||||
unsqueeze->set_friendly_name("unsqueeze");
|
||||
|
||||
auto abs = std::make_shared<opset4::Abs>(unsqueeze);
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
|
||||
}
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::InitNodeInfo>();
|
||||
pass_manager.register_pass<pass::NopElimination>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
bool movement_are_missing = true;
|
||||
for (auto node : f->get_ops()) {
|
||||
if (node->get_friendly_name() == "squeeze" || node->get_friendly_name() == "unsqueeze") {
|
||||
movement_are_missing = false;
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(movement_are_missing);
|
||||
}
|
||||
|
||||
TEST(nop_elimination, reshape_elimination_v1_dynamic) {
|
||||
@@ -165,6 +196,33 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic) {
|
||||
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 1);
|
||||
}
|
||||
|
||||
TEST(nop_elimination, reshape_elimination_v1_check_consumer_count) {
|
||||
std::shared_ptr<Function> f;
|
||||
{
|
||||
auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
|
||||
|
||||
auto reshape_1_shape = opset4::Constant::create(element::i64, Shape{2}, {128, 3});
|
||||
auto reshape_1 = std::make_shared<opset4::Reshape>(arg, reshape_1_shape, false);
|
||||
reshape_1->set_friendly_name("reshape_1");
|
||||
|
||||
auto reshape_2_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3});
|
||||
auto reshape_2 = std::make_shared<opset4::Reshape>(reshape_1, reshape_2_shape, false);
|
||||
reshape_2->set_friendly_name("reshape_2");
|
||||
|
||||
auto relu = std::make_shared<opset4::Relu>(reshape_1);
|
||||
relu->set_friendly_name("relu");
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{reshape_2, relu}, ParameterVector{arg});
|
||||
}
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::InitNodeInfo>();
|
||||
pass_manager.register_pass<pass::NopElimination>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 2);
|
||||
}
|
||||
|
||||
TEST(nop_elimination, concat_elimination_single_node) {
|
||||
int64_t a = 0;
|
||||
auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
|
||||
|
||||
@@ -305,3 +305,21 @@ TEST_F(TransformationTestsF, ReshapeSequenceFusionNeg5_special_zero_false) {
|
||||
manager.register_pass<pass::ReshapeSequenceFusion>();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ReshapeSequenceFusionEliminate) {
|
||||
{
|
||||
auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 2, 3});
|
||||
auto relu = std::make_shared<opset6::Relu>(data);
|
||||
auto a = reshape(relu, {2, 3});
|
||||
auto b = reshape(a, {1, 2, 3});
|
||||
function = std::make_shared<Function>(OutputVector{b}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<pass::ReshapeSequenceFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 2, 3});
|
||||
auto relu = std::make_shared<opset6::Relu>(data);
|
||||
function_ref = std::make_shared<Function>(OutputVector{relu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user