From 9d8016d1e687c7e64842bab90bb3ff24ccf84c76 Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Thu, 23 Feb 2023 17:12:54 +0000 Subject: [PATCH] Fix validate for split, revert changes for concat, add BatchToSpace/SpaceToBatch --- .../transpose_sinking.cpp | 10 ++++++---- .../transpose_sinking_binary.cpp | 20 +++++++++++-------- .../transpose_sinking_concat.cpp | 17 ++++++++-------- .../transpose_sinking_general.cpp | 3 +++ .../transpose_sinking_pad.cpp | 1 + 5 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp index 61d8e797bbd..adcdbe27327 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp @@ -20,7 +20,8 @@ using namespace ov; namespace { -std::vector get_updated_order_forward(std::vector& axes_values, std::vector& order_values) { +std::vector get_updated_order_forward(const std::vector& axes_values, + const std::vector& order_values) { size_t buffer_size = order_values.size() - axes_values.size(); std::vector aligned_order(buffer_size, 0); std::vector values_to_reduce(axes_values); @@ -33,14 +34,15 @@ std::vector get_updated_order_forward(std::vector& axes_values, continue; } - auto ub = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]); - aligned_order[j] = order_values[i] - (ub - values_to_reduce.begin()); + auto lb = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]); + aligned_order[j] = order_values[i] - (lb - values_to_reduce.begin()); ++j; } return aligned_order; } -std::vector get_updated_order_backward(std::vector& axes_values, std::vector& order_values) { +std::vector get_updated_order_backward(const std::vector& axes_values, + const std::vector& order_values) { size_t buffer_size = order_values.size() + axes_values.size(); std::vector aligned_order(buffer_size); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp index 1a190d8c1ed..490f438d0ad 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp @@ -23,10 +23,12 @@ using namespace transpose_sinking; ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() { MATCHER_SCOPE(TransposeSinkingBinaryForward); - auto main_node_label = - wrap_type([](const Output& output) -> bool { - return has_static_rank()(output) && IfNodeHasTransposeInputs(output); - }); + auto main_node_label = wrap_type([](const Output& output) -> bool { + return has_static_rank()(output) && IfNodeHasTransposeInputs(output); + }); matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); @@ -54,10 +56,12 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() { ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() { MATCHER_SCOPE(TransposeSinkingBinaryBackward); - auto main_node_label = - wrap_type([](const Output& output) -> bool { - return has_static_rank()(output) && HasSameOutputTransposeNodes(output); - }); + auto main_node_label = wrap_type([](const Output& output) -> bool { + return has_static_rank()(output) && HasSameOutputTransposeNodes(output); + }); auto transpose_const_label = wrap_type(); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp index 6eb1e913993..315b5045318 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp @@ -68,19 +68,18 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() { return has_static_rank()(output) && HasSameOutputTransposeNodes(output); }); - /* auto transpose_const_label = wrap_type(); + auto transpose_const_label = wrap_type(); - auto transpose_label = - wrap_type({main_node_label, transpose_const_label}, [](const Output& output) -> bool { - return has_static_rank()(output); - });*/ + auto transpose_label = + wrap_type({main_node_label, transpose_const_label}, [](const Output& output) -> bool { + return has_static_rank()(output) && is_sinking_node(output); + }); matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); - + auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); + auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); - auto transpose = main_node->output(0).get_target_inputs().begin()->get_node()->shared_from_this(); - auto transpose_const = as_type_ptr(transpose->input_value(1).get_node_shared_ptr()); auto concat_node = as_type_ptr(main_node); auto concat_axis = concat_node->get_concatenation_axis(); if (concat_axis < 0) { @@ -102,6 +101,6 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() { return true; }; - auto m = std::make_shared(main_node_label, matcher_name); + auto m = std::make_shared(transpose_label, matcher_name); register_matcher(m, matcher_pass_callback); } diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp index 5291e234046..d54477f2da3 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp @@ -12,6 +12,7 @@ #include "itt.hpp" #include "transformations/common_optimizations/transpose_sinking.hpp" +#include "transformations/common_optimizations/transpose_sinking_batch_to_space.hpp" #include "transformations/common_optimizations/transpose_sinking_binary.hpp" #include "transformations/common_optimizations/transpose_sinking_concat.hpp" #include "transformations/common_optimizations/transpose_sinking_pad.hpp" @@ -27,6 +28,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() { add_matcher(); add_matcher(); add_matcher(); + add_matcher(); add_matcher(); } @@ -38,6 +40,7 @@ ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() { add_matcher(); add_matcher(); add_matcher(); + add_matcher(); add_matcher(); } diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_pad.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_pad.cpp index 80deacde17c..014145c3456 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_pad.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_pad.cpp @@ -49,6 +49,7 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() { main_node->input(2).replace_source_output( ChangeValuesOrder(main_node->input_value(2), reversed_transpose_order, axis)); + main_node->validate_and_infer_types(); // insert Transpose for Pad output TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0}; for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {