From 3b8f58f1af358315bad049638d6235bcfdaedb85 Mon Sep 17 00:00:00 2001 From: Mikhail Ryzhov Date: Thu, 20 Jul 2023 12:08:02 +0200 Subject: [PATCH] [GNA] Added Reshape support to TS patterns (#18391) * Added Reshape support to TS patterns * fixed unit test * replaced duplicated code --- .../intel_gna/src/common/graph_utils.hpp | 10 ++++- .../src/gna_transformations_pipeline.cpp | 4 +- .../reshape_transpose_substitute.cpp | 5 ++- .../src/transformations/ts_concat_forward.cpp | 23 +++++++++-- .../src/transformations/ts_split_backward.cpp | 39 ++++++++++++++----- .../utils/gather_sinking_utils.cpp | 2 +- .../utils/transformation_helper.cpp | 7 ++-- .../gather_sinking_ts_split_test.cpp | 3 +- 8 files changed, 69 insertions(+), 24 deletions(-) diff --git a/src/plugins/intel_gna/src/common/graph_utils.hpp b/src/plugins/intel_gna/src/common/graph_utils.hpp index f7949c70ead..4122f369dfe 100644 --- a/src/plugins/intel_gna/src/common/graph_utils.hpp +++ b/src/plugins/intel_gna/src/common/graph_utils.hpp @@ -81,6 +81,13 @@ inline bool get_constant_value(const std::shared_ptr& return true; } +/** + * @brief Checks if 2 shapes are the same + */ +inline bool are_shapes_equal(const ov::Shape& shape_1, const ov::Shape& shape_2) { + return (shape_1.size() == shape_2.size()) && std::equal(shape_1.begin(), shape_1.end(), shape_2.begin()); +} + inline bool is_aligned_split(const std::shared_ptr input_op, size_t input_op_out_index) { size_t offset = 0; @@ -609,8 +616,7 @@ inline bool is_reshape_unsqueeze(const ov::Output& output) { auto reshape = output.get_node_shared_ptr(); const ov::Shape input_shape = trim_shape(reshape->get_input_shape(0)); const ov::Shape output_shape = trim_shape(reshape->get_output_shape(0)); - return (input_shape.size() == output_shape.size()) && - std::equal(input_shape.begin(), input_shape.end(), output_shape.begin()); + return are_shapes_equal(input_shape, output_shape); } /** diff --git a/src/plugins/intel_gna/src/gna_transformations_pipeline.cpp b/src/plugins/intel_gna/src/gna_transformations_pipeline.cpp index 8f382d4c333..255d9021d69 100644 --- a/src/plugins/intel_gna/src/gna_transformations_pipeline.cpp +++ b/src/plugins/intel_gna/src/gna_transformations_pipeline.cpp @@ -139,12 +139,12 @@ void TransformationsPipeline::apply(const std::shared_ptr& model, manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); - manager.register_pass(); - manager.register_pass(); manager.register_pass(); } manager.register_pass(input_output_subgraphs); diff --git a/src/plugins/intel_gna/src/transformations/reshape_transpose_substitute.cpp b/src/plugins/intel_gna/src/transformations/reshape_transpose_substitute.cpp index 5b9b5579a53..3f7e91d478a 100644 --- a/src/plugins/intel_gna/src/transformations/reshape_transpose_substitute.cpp +++ b/src/plugins/intel_gna/src/transformations/reshape_transpose_substitute.cpp @@ -8,6 +8,7 @@ #include #include +#include "common/graph_utils.hpp" #include "openvino/opsets/opset12.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" @@ -15,6 +16,7 @@ using namespace ov; using namespace ov::opset12; using namespace ov::pass::pattern; using namespace ov::op::util; +using namespace ov::intel_gna; using namespace ov::intel_gna::pass; namespace { @@ -91,8 +93,7 @@ AxisVector find_suitable_transpose_order(const Shape& input_shape, const std::vector& orders) { for (const auto& order : orders) { const Shape transposed_shape = apply_permutation(input_shape, order); - if ((transposed_shape.size() == output_shape.size()) && - std::equal(transposed_shape.begin(), transposed_shape.end(), output_shape.begin())) + if (graph_utils::are_shapes_equal(transposed_shape, output_shape)) return order; } diff --git a/src/plugins/intel_gna/src/transformations/ts_concat_forward.cpp b/src/plugins/intel_gna/src/transformations/ts_concat_forward.cpp index 73119b4df4b..cdf26b5c2ee 100644 --- a/src/plugins/intel_gna/src/transformations/ts_concat_forward.cpp +++ b/src/plugins/intel_gna/src/transformations/ts_concat_forward.cpp @@ -23,6 +23,7 @@ using namespace ov::intel_gna::pass::helper; using namespace ov::intel_gna::limitations; namespace { + bool is_concat_sinked(const Output& output) { auto concat_node = ov::as_type_ptr(output.get_node_shared_ptr()); @@ -33,13 +34,18 @@ bool is_concat_sinked(const Output& output) { for (size_t i = 0; i < concat_node->get_input_size(); ++i) { auto concat_input = concat_node->input_value(i); - auto transpose = ov::as_type_ptr(concat_input.get_node_shared_ptr()); + + auto target_node = graph_utils::get_prev_node_skipping_certain(concat_input.get_node_shared_ptr(), + graph_utils::is_gna_non_functional_node); + std::shared_ptr transpose = ov::as_type_ptr(target_node); + if (transpose && !Limitations::is_transpose_supported(transpose)) return true; } return false; } + } // namespace TSConcatForward::TSConcatForward() { @@ -57,7 +63,10 @@ TSConcatForward::TSConcatForward() { OutputVector concat_inputs = {}; for (size_t i = 0; i < concat_node->get_input_size(); ++i) { ov::Output concat_input = concat_node->input_value(i); - std::shared_ptr transpose = ov::as_type_ptr(concat_input.get_node_shared_ptr()); + + auto target_node = graph_utils::get_prev_node_skipping_certain(concat_input.get_node_shared_ptr(), + graph_utils::is_gna_non_functional_node); + std::shared_ptr transpose = ov::as_type_ptr(target_node); ov::Shape transpose_shape = concat_input.get_shape(); ov::AxisVector transpose_order(transpose_shape.size()); @@ -66,7 +75,7 @@ TSConcatForward::TSConcatForward() { ov::as_type_ptr(transpose->get_input_node_shared_ptr(1)); transpose_order = transpose_const->get_axis_vector_val(); transpose_shape = transpose->get_input_shape(0); - concat_input = concat_input.get_node_shared_ptr()->input_value(0); + concat_input = transpose->input_value(0); } else { std::iota(transpose_order.begin(), transpose_order.end(), 0); } @@ -102,7 +111,13 @@ TSConcatForward::TSConcatForward() { std::make_shared(ov::element::i64, ov::Shape{concat_shape_out.size()}, concat_shape_out); auto reshape_output = std::make_shared(gather_node, reshape_output_const, false); - ov::replace_node_update_name(concat_node, reshape_output); + // skip reshape if the input and output shapes are the same + if (graph_utils::are_shapes_equal(reshape_output->get_input_shape(0), reshape_output->get_output_shape(0))) { + ov::replace_node_update_name(concat_node, gather_node); + } else { + ov::replace_node_update_name(concat_node, reshape_output); + } + return true; }; diff --git a/src/plugins/intel_gna/src/transformations/ts_split_backward.cpp b/src/plugins/intel_gna/src/transformations/ts_split_backward.cpp index aa62b88d85b..8935b49349a 100644 --- a/src/plugins/intel_gna/src/transformations/ts_split_backward.cpp +++ b/src/plugins/intel_gna/src/transformations/ts_split_backward.cpp @@ -18,17 +18,21 @@ using namespace ov; using namespace ov::opset12; using namespace ov::pass::pattern; +using namespace ov::intel_gna::graph_utils; using namespace ov::intel_gna::pass; using namespace ov::intel_gna::pass::helper; using namespace ov::intel_gna::limitations; using namespace ov::intel_gna::graph_utils; namespace { + bool is_split_sinked(const Output& output) { auto split_node = output.get_node_shared_ptr(); for (size_t output_idx = 0; output_idx < split_node->get_output_size(); ++output_idx) { for (auto& input : split_node->get_output_target_inputs(output_idx)) { - auto transpose = ov::as_type_ptr(input.get_node()->shared_from_this()); + auto target_node = + get_next_node_skipping_certain(input.get_node()->shared_from_this(), is_gna_non_functional_node); + std::shared_ptr transpose = ov::as_type_ptr(target_node); if (transpose && !Limitations::is_transpose_supported(transpose)) return true; } @@ -40,19 +44,25 @@ bool is_split_sinked(const Output& output) { TSSplitBackward::TSSplitBackward() { MATCHER_SCOPE(TSSplitBackward); - auto split_node_label = wrap_type(is_split_sinked); + auto split_node_label = wrap_type(is_split_sinked); matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); auto& split_node_label_output = pattern_to_output.at(split_node_label); - auto split_node = as_type_ptr(split_node_label_output.get_node_shared_ptr()); + std::shared_ptr split_node = as_type_ptr(split_node_label_output.get_node_shared_ptr()); + if (!split_node) { + split_node = as_type_ptr(split_node_label_output.get_node_shared_ptr()); + } ov::AxisVector gather_ids = {}; std::vector gather_indices_vecs; + std::vector split_slices; for (size_t output_idx = 0; output_idx < split_node->get_output_size(); ++output_idx) { for (auto& input : split_node->get_output_target_inputs(output_idx)) { - auto transpose = ov::as_type_ptr(input.get_node()->shared_from_this()); + auto target_node = + get_next_node_skipping_certain(input.get_node()->shared_from_this(), is_gna_non_functional_node); + std::shared_ptr transpose = ov::as_type_ptr(target_node); ov::Shape transpose_shape = split_node->get_output_shape(output_idx); ov::AxisVector transpose_order(transpose_shape.size()); @@ -74,6 +84,8 @@ TSSplitBackward::TSSplitBackward() { i += id; }); gather_ids.insert(gather_ids.end(), slice_ids.begin(), slice_ids.end()); + // collect slice sizes + split_slices.push_back(slice_ids.size()); } } @@ -89,18 +101,27 @@ TSSplitBackward::TSSplitBackward() { ov::copy_runtime_info(split_node, {reshape_input, reshape_input_const}); + std::shared_ptr gather; auto gather_axis = std::make_shared(ov::element::i64, ov::Shape{}, 1); auto gather_indices = std::make_shared(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids); - auto gather = std::make_shared(reshape_input, gather_indices, gather_axis); + if (graph_utils::are_shapes_equal(split_input_shape, reshape_input_shape)) { + gather = std::make_shared(split_node->input_value(0), gather_indices, gather_axis); + } else { + gather = std::make_shared(reshape_input, gather_indices, gather_axis); + } - auto split_axis_new = std::make_shared(ov::element::i64, ov::Shape{}, 1); - auto split_new = std::make_shared(gather, split_axis_new, split_node->get_num_splits()); + auto split_new_axis = std::make_shared(ov::element::i64, ov::Shape{}, 1); + auto split_new_lengths = + std::make_shared(ov::element::i64, ov::Shape{split_slices.size()}, split_slices); + auto split_new = std::make_shared(gather, split_new_axis, split_new_lengths); - ov::copy_runtime_info(split_node, {gather_axis, gather_indices, gather, split_axis_new, split_new}); + ov::copy_runtime_info(split_node, {gather_axis, gather_indices, gather, split_new_axis, split_new}); for (size_t output_idx = 0; output_idx < split_node->get_output_size(); ++output_idx) { for (auto& input : split_node->get_output_target_inputs(output_idx)) { - auto transpose = ov::as_type_ptr(input.get_node()->shared_from_this()); + auto target_node = get_next_node_skipping_certain(input.get_node()->shared_from_this(), + graph_utils::is_gna_non_functional_node); + std::shared_ptr transpose = ov::as_type_ptr(target_node); if (transpose) { auto reshape_output_const_new = std::make_shared(ov::element::i64, diff --git a/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp b/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp index 8dc687310f8..a1c21383394 100644 --- a/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp +++ b/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp @@ -257,7 +257,7 @@ struct GatherInfo { bool operator==(const GatherInfo& another) { if (indices.size() != another.indices.size()) return false; - if (!std::equal(indices.begin(), indices.end(), another.indices.begin())) + if (!are_shapes_equal(indices, another.indices)) return false; return axis == another.axis; } diff --git a/src/plugins/intel_gna/src/transformations/utils/transformation_helper.cpp b/src/plugins/intel_gna/src/transformations/utils/transformation_helper.cpp index 0d36cba1de8..125259cc6d5 100644 --- a/src/plugins/intel_gna/src/transformations/utils/transformation_helper.cpp +++ b/src/plugins/intel_gna/src/transformations/utils/transformation_helper.cpp @@ -4,6 +4,7 @@ #include "transformation_helper.hpp" +#include "common/graph_utils.hpp" #include "log/debug.hpp" #include "openvino/core/rt_info.hpp" #include "openvino/opsets/opset12.hpp" @@ -13,6 +14,7 @@ #include "transformations/rt_info/transpose_sinking_attr.hpp" using namespace ov::opset12; +using namespace ov::intel_gna; namespace ov { namespace intel_gna { @@ -92,7 +94,7 @@ bool TransposeOrderMatches(std::shared_ptr transpose, std::vector node) { if (!node_parent) { THROW_GNA_EXCEPTION << "The removing node has no parrent node"; } - if (input_node_shape.size() != output_node_shape.size() || - !std::equal(input_node_shape.begin(), input_node_shape.end(), output_node_shape.begin())) { + if (!graph_utils::are_shapes_equal(input_node_shape, output_node_shape)) { auto reshape_const_node = std::make_shared(ov::element::i64, ov::Shape{output_node_shape.size()}, output_node_shape); node_parent = std::make_shared(node_parent, reshape_const_node, false); diff --git a/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_ts_split_test.cpp b/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_ts_split_test.cpp index 712a9413c31..69ca9807b49 100644 --- a/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_ts_split_test.cpp +++ b/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_ts_split_test.cpp @@ -43,7 +43,8 @@ TEST(TSSplit, Backward) { auto reshape1 = std::make_shared(input_params, reshape_const1, false); auto gather = make_gather(reshape1, TSSplit_Backward_indexes, /* axis */ 1); auto split_axis = Constant::create(element::i64, ov::Shape{}, ov::Shape{1}); - auto split = std::make_shared(gather, split_axis, 1); + auto split_lengths = Constant::create(element::i64, ov::Shape{1}, ov::Shape{8}); + auto split = std::make_shared(gather, split_axis, split_lengths); auto reshape_const2 = Constant::create(element::i64, ov::Shape{4}, ov::Shape{1, 1, 2, 4}); auto reshape2 = std::make_shared(split->output(0), reshape_const2, false); const auto result = std::make_shared(reshape2);