diff --git a/src/common/transformations/include/transformations/rt_info/transpose_sinking_attr.hpp b/src/common/transformations/include/transformations/rt_info/transpose_sinking_attr.hpp index 87053c72799..e2c2934b187 100644 --- a/src/common/transformations/include/transformations/rt_info/transpose_sinking_attr.hpp +++ b/src/common/transformations/include/transformations/rt_info/transpose_sinking_attr.hpp @@ -11,6 +11,7 @@ namespace ov { TRANSFORMATIONS_API void mark_as_no_sinking_node(const std::shared_ptr& node); +TRANSFORMATIONS_API void reset_no_sinking_attribute(const std::shared_ptr& node); TRANSFORMATIONS_API bool is_sinking_node(const std::shared_ptr& node); TRANSFORMATIONS_API bool is_sinking_node(const Node* node); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_base.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_base.hpp new file mode 100644 index 00000000000..93fb58dec91 --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_base.hpp @@ -0,0 +1,63 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSForwardBase; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSForwardBase is a base class for all forward transformations. + */ +class ov::pass::transpose_sinking::TSForwardBase : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSForwardBase", "0"); + TSForwardBase() = default; + + template + void create_pattern(bool const_transpose_input, std::vector transpose_indices = {}) { + m_const_transpose_input = const_transpose_input; + m_tranpose_indices = std::move(transpose_indices); + m_pattern = ov::pass::pattern::wrap_type([&](const Output& output) -> bool { + return if_node_has_transpose_inputs(output, m_const_transpose_input, m_tranpose_indices); + }); + } + + using sinking_function = + std::function& main_node, const utils::TransposeInputsInfo& transpose_info)>; + + void transpose_sinking(const std::string& pass_name, const sinking_function& sinking_transformation = nullptr); + +protected: + static bool default_inputs_update(const std::shared_ptr& main_node, + const utils::TransposeInputsInfo& transpose_info); + + void default_outputs_update(const std::shared_ptr& main_node, + const utils::TransposeInputsInfo& transpose_info); + +private: + static bool if_node_has_transpose_inputs(const Output& output, + bool const_transpose_input, + const std::vector& transpose_indices); + + std::shared_ptr m_pattern; + bool m_const_transpose_input = true; + std::vector m_tranpose_indices; +}; \ No newline at end of file diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_binary.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_binary.hpp index a1b559f3c4d..f48ed86db1a 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_binary.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_binary.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSBinaryBackward; * @brief TSBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic, * BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction. */ -class ov::pass::transpose_sinking::TSBinaryForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSBinaryForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSBinaryForward", "0"); TSBinaryForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_concat.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_concat.hpp index 904f68ec4fa..36a738440fe 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_concat.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_concat.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSConcatBackward; * @brief TSConcatForward transformation sinks Transpose through Concat operation * in the forward direction. */ -class ov::pass::transpose_sinking::TSConcatForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSConcatForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSConcatForward", "0"); TSConcatForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_data_movement.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_data_movement.hpp index 8a4612a513c..45cbee0ba1b 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_data_movement.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_data_movement.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -25,7 +26,7 @@ class TRANSFORMATIONS_API TSDataMovementBackward; * ReverseSequence and Pad operations in the forward direction. * These operations are categorized as "DataMovement" and are handled in a similar way in this transformation. */ -class ov::pass::transpose_sinking::TSDataMovementForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSDataMovementForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSDataMovementForward", "0"); TSDataMovementForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_gather.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_gather.hpp index 810c538a590..d7c4eeb37cb 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_gather.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_gather.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSGatherBackward; * @brief TSGatherForward transformation sinks Transpose through Gather operations * in the forward direction. */ -class ov::pass::transpose_sinking::TSGatherForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSGatherForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSGatherForward", "0"); TSGatherForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_interpolate.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_interpolate.hpp index 519154626a9..a0ee9bbbfcf 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_interpolate.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_interpolate.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSInterpolateBackward; * @brief TSInterpolateForward transformation sinks Transpose through Interpolate operation * in the forward direction. */ -class ov::pass::transpose_sinking::TSInterpolateForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSInterpolateForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSInterpolateForward", "0"); TSInterpolateForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_reduction.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_reduction.hpp index 4e2cbb6501c..5ad99be5d2c 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_reduction.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_reduction.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSReductionBackward; * @brief TSReductionForward transformation sinks Transpose through Reduce operations * in the forward direction. */ -class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSReductionForward", "0"); TSReductionForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp new file mode 100644 index 00000000000..eb2b4ddb868 --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSResetNoSinkingAttribute; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSResetNoSinkingAttribute transformation resets all NoSinkingAttribute runtime attributes + * in Transpose operations. + */ +class ov::pass::transpose_sinking::TSResetNoSinkingAttribute : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSResetNoSinkingAttribute", "0"); + TSResetNoSinkingAttribute(); +}; diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_slice.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_slice.hpp index a5a135d44b3..8b486609aa1 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_slice.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_slice.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -19,7 +20,7 @@ class TRANSFORMATIONS_API TSSliceBackward; } // namespace pass } // namespace ov -class ov::pass::transpose_sinking::TSSliceForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSSliceForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSSliceForward", "0"); TSSliceForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_split.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_split.hpp index ba75ac65662..c56a93415a6 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_split.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_split.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSSplitForward; * @brief TSSplitForward transformation sinks Transpose through Split, VariadicSplit operations * in the forward direction. */ -class ov::pass::transpose_sinking::TSSplitForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSSplitForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSSplitForward", "0"); TSSplitForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_squeeze.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_squeeze.hpp index c7aa6f2aa0f..f3b0da47a81 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_squeeze.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_squeeze.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSSqueezeBackward; * @brief TSSqueezeForward transformation sinks Transpose through Reshape, Squeeze operations * in the forward direction. */ -class ov::pass::transpose_sinking::TSSqueezeForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSSqueezeForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSSqueezeForward", "0"); TSSqueezeForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_unary.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_unary.hpp index 9c6e93356f7..c0d13b66772 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_unary.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_unary.hpp @@ -5,6 +5,7 @@ #pragma once #include "openvino/pass/graph_rewrite.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -23,7 +24,7 @@ class TRANSFORMATIONS_API TSUnaryBackward; * @brief TSUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu, * SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction. */ -class ov::pass::transpose_sinking::TSUnaryForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSUnaryForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("TSUnaryForward", "0"); TSUnaryForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_unsqueeze.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_unsqueeze.hpp index 05150bfe1fb..c6f3076cc92 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_unsqueeze.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_unsqueeze.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" #include "transformations_visibility.hpp" namespace ov { @@ -24,7 +25,7 @@ class TRANSFORMATIONS_API TSUnsqueezeBackward; * @brief TSUnsqueezeForward transformation sinks Transpose through Unsqueeze, Reshape operations * in the forward direction. */ -class ov::pass::transpose_sinking::TSUnsqueezeForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSUnsqueezeForward : public ov::pass::transpose_sinking::TSForwardBase { public: OPENVINO_RTTI("ov::pass::TSUnsqueezeForward", "0"); TSUnsqueezeForward(); diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_utils.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_utils.hpp index fdd24bd50b0..b2655a822b1 100644 --- a/src/common/transformations/include/transformations/transpose_sinking/ts_utils.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_utils.hpp @@ -30,12 +30,14 @@ struct TransposeInputsInfo { * @brief Finds node first input that is a transpose operation and returns filled TransposeInputsInfo * for it */ -TransposeInputsInfo GetFirstTransposeInput(const std::shared_ptr&); +TransposeInputsInfo GetFirstTransposeInput(const std::shared_ptr&, + bool const_transpose_order, + const std::vector& indices = {}); /** * @brief Checks if @arg has any input node that is a transpose operation */ -bool IfNodeHasTransposeInputs(const ov::Output&); +bool IfNodeHasTransposeInputs(const ov::Output&, const std::vector& indices = {}); /** * @brief Reverses order of transpose operation. Do it in a such way that if we had couple following one after @@ -86,8 +88,6 @@ ov::NodeVector InsertTransposeBeforeNode(const std::shared_ptr& main_n std::vector input_indexes = {}); } // namespace sink_backward -void UpdateForwardSinkingAbility(const std::shared_ptr&); - /** * @brief Checks if @arg has consumers that are all the same Transpose operation * and that sinking is enabled for all these Transpose ops. Otherwise returns false. diff --git a/src/common/transformations/src/transformations/rt_info/transpose_sinking_attr.cpp b/src/common/transformations/src/transformations/rt_info/transpose_sinking_attr.cpp index 6d3b7bc520e..c26798aa9b0 100644 --- a/src/common/transformations/src/transformations/rt_info/transpose_sinking_attr.cpp +++ b/src/common/transformations/src/transformations/rt_info/transpose_sinking_attr.cpp @@ -11,6 +11,14 @@ void ov::mark_as_no_sinking_node(const std::shared_ptr& node) { rt_info[NoTransposeSinkingAttr::get_type_info_static()] = NoTransposeSinkingAttr(); } +void ov::reset_no_sinking_attribute(const std::shared_ptr& node) { + auto& rt_info = node->get_rt_info(); + auto it = rt_info.find(NoTransposeSinkingAttr::get_type_info_static()); + if (it != rt_info.end()) { + rt_info.erase(it); + } +} + namespace { template bool is_sinking_node_private(NodePtr node) { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_base.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_base.cpp new file mode 100644 index 00000000000..8d9c023e38e --- /dev/null +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_base.cpp @@ -0,0 +1,77 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/transpose_sinking/ts_base.hpp" + +#include "itt.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/fake_quantize.hpp" +#include "openvino/op/prelu.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" + +using namespace ov; +using namespace ov::pass::pattern; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; + +void TSForwardBase::transpose_sinking(const std::string& pass_name, + const TSForwardBase::sinking_function& sinking_transformation) { + ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto main_node = pattern_to_output.at(m_pattern).get_node_shared_ptr(); + utils::TransposeInputsInfo transpose_input_info = + utils::GetFirstTransposeInput(main_node, m_const_transpose_input, m_tranpose_indices); + + if (transformation_callback(main_node)) { + mark_as_no_sinking_node(transpose_input_info.transpose); + return false; + } + + bool res; + if (sinking_transformation) { + // use custom function to sink transpose + res = sinking_transformation(main_node, transpose_input_info); + } else { + // default transpose sinking function: + res = default_inputs_update(main_node, transpose_input_info); + if (res) { + default_outputs_update(main_node, transpose_input_info); + } + } + if (!res) { + mark_as_no_sinking_node(transpose_input_info.transpose); + } + return res; + }; + + auto m = std::make_shared(m_pattern, pass_name); + register_matcher(m, matcher_pass_callback); +} + +bool TSForwardBase::default_inputs_update(const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) { + return utils::sink_forward::UpdateInputTransposes(main_node, transpose_info); +} + +void TSForwardBase::default_outputs_update(const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) { + main_node->validate_and_infer_types(); + for (auto& new_node : utils::sink_forward::InsertOutputTransposes(main_node, transpose_info)) { + register_new_node(new_node); + mark_as_no_sinking_node(new_node); + } +} + +bool TSForwardBase::if_node_has_transpose_inputs(const Output& output, + bool const_transpose_input, + const std::vector& transpose_indices) { + utils::TransposeInputsInfo inputs_info = + utils::GetFirstTransposeInput(output.get_node_shared_ptr(), const_transpose_input, transpose_indices); + return !inputs_info.isEmpty(); +} diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp index 2b5603b3ef3..9e3b10b4222 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp @@ -20,42 +20,14 @@ using namespace ov::pass::pattern; using namespace ov::pass::transpose_sinking; using namespace ov::pass::transpose_sinking::utils; -TSBinaryForward::TSBinaryForward() { +TSBinaryForward::TSBinaryForward() : TSForwardBase() { MATCHER_SCOPE(TSBinaryForward); - - auto main_node_label = wrap_type([](const Output& output) -> bool { - return has_static_rank()(output) && IfNodeHasTransposeInputs(output); - }); - - matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { - const auto& pattern_to_output = m.get_pattern_value_map(); - auto& main_node_output = pattern_to_output.at(main_node_label); - auto main_node = main_node_output.get_node_shared_ptr(); - if (transformation_callback(main_node)) { - return false; - } - - TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); - - // todo: support dynamic rank case - bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_input_info); - if (!updated) { - return false; - } - main_node->validate_and_infer_types(); - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } - return true; - }; - - auto m = std::make_shared(main_node_label, matcher_name); - register_matcher(m, matcher_pass_callback); + create_pattern(true); + transpose_sinking(matcher_name); } TSBinaryBackward::TSBinaryBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_concat.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_concat.cpp index bb703faadf1..580cc07a73d 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_concat.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_concat.cpp @@ -21,45 +21,35 @@ using namespace ov::pass::transpose_sinking::utils; TSConcatForward::TSConcatForward() { MATCHER_SCOPE(TSConcatForward); - auto main_node_label = wrap_type(IfNodeHasTransposeInputs); + create_pattern(true); - matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { - const auto& pattern_to_output = m.get_pattern_value_map(); - - auto& main_node_output = pattern_to_output.at(main_node_label); - auto main_node = main_node_output.get_node_shared_ptr(); - if (transformation_callback(main_node)) { + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { + // todo: support dynamic rank case + auto concat_node = as_type_ptr(main_node); + if (!concat_node) { return false; } - TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); - auto concat_node = as_type_ptr(main_node); auto concat_axis = concat_node->get_concatenation_axis(); if (concat_axis < 0) { return false; } // todo: support dyn rank case - bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_input_info); + bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_info); if (!updated) { return false; } - const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); + const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val(); const int64_t transposed_concat_axis = transpose_axis_order[concat_axis]; concat_node->set_axis(transposed_concat_axis); concat_node->set_concatenation_axis(-1); - main_node->validate_and_infer_types(); - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } - + default_outputs_update(main_node, transpose_info); return true; }; - - auto m = std::make_shared(main_node_label, matcher_name); - register_matcher(m, matcher_pass_callback); + transpose_sinking(matcher_name, sinking_transformation); } TSConcatBackward::TSConcatBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp index 26e75201cff..1f4aed8e724 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp @@ -38,35 +38,17 @@ std::vector get_indices_by_op_type(const std::shared_ptr& main_nod TSDataMovementForward::TSDataMovementForward() { MATCHER_SCOPE(TSDataMovementForward); - auto const_label = wrap_type(); - auto transpose_label = wrap_type({any_input(), const_label}); - auto main_node_label = - wrap_type( - {transpose_label, any_input(), any_input(), any_input()}); - - matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { - const auto& pattern_to_node = m.get_pattern_map(); - - auto& main_node = pattern_to_node.at(main_node_label); - if (transformation_callback(main_node)) { - return false; - } - - auto transpose = std::dynamic_pointer_cast(pattern_to_node.at(transpose_label)); - if (!transpose) { - return false; - } - - auto transpose_const = as_type_ptr(pattern_to_node.at(const_label)); - if (!transpose_const) { - return false; - } + create_pattern( + true, + {0}); + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { // remove Transpose on 1st input: auto transpose_parent = main_node->input_value(0).get_node()->input_value(0); main_node->input(0).replace_source_output(transpose_parent); - const auto transpose_axis_order = transpose_const->get_axis_vector_val(); + const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val(); const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order); auto axis = std::make_shared(element::i32, Shape{}, 0); @@ -80,17 +62,12 @@ TSDataMovementForward::TSDataMovementForward() { reverse_seq->set_batch_axis(transpose_axis_order[reverse_seq->get_batch_axis()]); reverse_seq->set_sequence_axis(transpose_axis_order[reverse_seq->get_sequence_axis()]); } - main_node->validate_and_infer_types(); - TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0}; - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } + + default_outputs_update(main_node, transpose_info); return true; }; - auto m = std::make_shared(main_node_label, matcher_name); - register_matcher(m, matcher_pass_callback); + transpose_sinking(matcher_name, sinking_transformation); } TSDataMovementBackward::TSDataMovementBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_fuse.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_fuse.cpp index 49638c83b0c..4ebc7dca880 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_fuse.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_fuse.cpp @@ -12,6 +12,7 @@ #include "openvino/op/constant.hpp" #include "openvino/op/transpose.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" #include "transformations/transpose_sinking/ts_utils.hpp" #include "transformations/utils/utils.hpp" @@ -72,7 +73,7 @@ TSFuse::TSFuse() { copy_runtime_info(transpose1, new_transpose); ov::replace_node(transpose1, new_transpose); - UpdateForwardSinkingAbility(new_transpose); + mark_as_no_sinking_node(new_transpose); } return true; }; diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_gather.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_gather.cpp index 967792da4c0..8752be04b1d 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_gather.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_gather.cpp @@ -23,22 +23,18 @@ using namespace ov::pass::transpose_sinking::utils; TSGatherForward::TSGatherForward() { MATCHER_SCOPE(TSGatherForward); - auto transpose_label = wrap_type({any_input(), wrap_type()}); - auto gather_label = - wrap_type({transpose_label, any_input(), wrap_type()}); + create_pattern(true, {0}); - ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { - const auto& pattern_to_output = m.get_pattern_map(); - - auto transpose = as_type_ptr(pattern_to_output.at(transpose_label)); - auto main_node = as_type_ptr(pattern_to_output.at(gather_label)); - if (transformation_callback(main_node) || !main_node) { + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { + auto gather = as_type_ptr(main_node); + if (!gather) { return false; } - auto transpose_order = as_type_ptr(transpose->get_input_node_shared_ptr(1)); + auto transpose_order = transpose_info.transpose_const; auto gather_axis = as_type_ptr(main_node->get_input_node_shared_ptr(2)); - if (!transpose || !transpose_order || !gather_axis) { + if (!gather_axis) { return false; } @@ -53,7 +49,7 @@ TSGatherForward::TSGatherForward() { } const auto& order_val = transpose_order->cast_vector(); - auto batch_dims = static_cast(main_node->get_batch_dims()); + auto batch_dims = static_cast(gather->get_batch_dims()); for (size_t i = 0; i < batch_dims; ++i) { // transpose changes the order of batch dims if (order_val[i] != i) { @@ -88,7 +84,7 @@ TSGatherForward::TSGatherForward() { auto new_order_const = ov::op::v0::Constant::create(transpose_order->get_element_type(), {new_transpose_order.size()}, new_transpose_order); - TransposeInputsInfo transpose_input_info = {transpose, new_order_const, 0}; + TransposeInputsInfo transpose_input_info = {transpose_info.transpose, new_order_const, 0}; // deletes Transpose from 0 input auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0}); if (!success) { @@ -98,17 +94,12 @@ TSGatherForward::TSGatherForward() { ov::op::v0::Constant::create(gather_axis->get_element_type(), gather_axis->get_shape(), {order_val[axis]}); main_node->input(2).replace_source_output(new_axis); copy_runtime_info(gather_axis, new_axis); - main_node->validate_and_infer_types(); - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } + default_outputs_update(main_node, transpose_input_info); return true; }; - auto m = std::make_shared(gather_label, matcher_name); - register_matcher(m, matcher_pass_callback); + transpose_sinking(matcher_name, sinking_transformation); } TSGatherBackward::TSGatherBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp index 3f247fadca0..efbf30bb21f 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp @@ -17,6 +17,7 @@ #include "transformations/transpose_sinking/ts_gather.hpp" #include "transformations/transpose_sinking/ts_interpolate.hpp" #include "transformations/transpose_sinking/ts_reduction.hpp" +#include "transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp" #include "transformations/transpose_sinking/ts_slice.hpp" #include "transformations/transpose_sinking/ts_split.hpp" #include "transformations/transpose_sinking/ts_squeeze.hpp" @@ -73,6 +74,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr& f) { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.run_passes(f); } diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp index 98167310659..b338b1652e8 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp @@ -22,39 +22,21 @@ using namespace ov::pass::transpose_sinking::utils; TSInterpolateForward::TSInterpolateForward() { MATCHER_SCOPE(TSInterpolateForward); - auto const_label = wrap_type(); - auto transpose_label = wrap_type({any_input(), const_label}); - auto main_node_label = wrap_type({transpose_label, any_input(), any_input(), any_input()}); - - matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { - const auto& pattern_to_node = m.get_pattern_map(); - - auto& main_node = pattern_to_node.at(main_node_label); - if (transformation_callback(main_node)) { - return false; - } - - auto transpose = std::dynamic_pointer_cast(pattern_to_node.at(transpose_label)); - if (!transpose) { - return false; - } - - auto transpose_const = as_type_ptr(pattern_to_node.at(const_label)); - if (!transpose_const) { - return false; - } + create_pattern(true, {0}); + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { // remove Transpose on 1st input: - auto transpose_parent = transpose->input_value(0); + auto transpose_parent = transpose_info.transpose->input_value(0); main_node->input(0).replace_source_output(transpose_parent); - const auto transpose_axis_order = transpose_const->get_axis_vector_val(); + const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val(); auto axis = std::make_shared(element::i32, Shape{}, 0); - const auto& interpolate = std::dynamic_pointer_cast(main_node); const auto& new_axes = ChangeAxes(main_node->input_value(3), transpose_axis_order, axis); main_node->input(3).replace_source_output(new_axes); + const auto& interpolate = std::dynamic_pointer_cast(main_node); if (interpolate) { op::v4::Interpolate::InterpolateAttrs attrs = interpolate->get_attrs(); if (!attrs.pads_begin.empty() || !attrs.pads_end.empty()) { @@ -72,17 +54,11 @@ TSInterpolateForward::TSInterpolateForward() { } } - main_node->validate_and_infer_types(); - TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0}; - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } + default_outputs_update(main_node, transpose_info); return true; }; - auto m = std::make_shared(main_node_label, matcher_name); - register_matcher(m, matcher_pass_callback); + transpose_sinking(matcher_name, sinking_transformation); } TSInterpolateBackward::TSInterpolateBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp index c7fd8150ab0..be01f1c10ab 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp @@ -42,20 +42,11 @@ bool get_keep_dims(const std::shared_ptr& main_node) { TSReductionForward::TSReductionForward() { MATCHER_SCOPE(TSReductionForward); - auto transpose_label = wrap_type({any_input(), wrap_type()}); - auto reduce_label = wrap_type( - {transpose_label, wrap_type()}); - - ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { - const auto& pattern_to_output = m.get_pattern_map(); - auto transpose = as_type_ptr(pattern_to_output.at(transpose_label)); - auto main_node = pattern_to_output.at(reduce_label); - if (!transpose || transformation_callback(main_node)) { - return false; - } - + create_pattern(true, {0}); + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { auto keep_dims = get_keep_dims(main_node); - auto transpose_order = as_type_ptr(transpose->get_input_node_shared_ptr(1)); + auto transpose_order = transpose_info.transpose_const; auto reduction_axes = as_type_ptr(main_node->get_input_node_shared_ptr(1)); if (!transpose_order || !reduction_axes) return false; @@ -84,7 +75,7 @@ TSReductionForward::TSReductionForward() { auto new_const = ov::op::v0::Constant::create(reduction_axes->get_element_type(), {new_values.size()}, new_values); main_node->input(1).replace_source_output(new_const); - TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0}; + TransposeInputsInfo transpose_input_info = {transpose_info.transpose, new_transpose_order, 0}; // deletes Transpose from 0 input auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0}); if (!success) { @@ -92,16 +83,12 @@ TSReductionForward::TSReductionForward() { } copy_runtime_info(reduction_axes, new_const); - main_node->validate_and_infer_types(); - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } + + default_outputs_update(main_node, transpose_input_info); return true; }; - auto m = std::make_shared(reduce_label, matcher_name); - register_matcher(m, matcher_pass_callback); + transpose_sinking(matcher_name, sinking_transformation); } TSReductionBackward::TSReductionBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_reset_no_sinking_attribute.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_reset_no_sinking_attribute.cpp new file mode 100644 index 00000000000..dea1a224345 --- /dev/null +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_reset_no_sinking_attribute.cpp @@ -0,0 +1,33 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp" + +#include "itt.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" + +using namespace ov; +using namespace ov::pass::pattern; +using namespace ov::pass::transpose_sinking; + +TSResetNoSinkingAttribute::TSResetNoSinkingAttribute() { + MATCHER_SCOPE(TSResetNoSinkingAttribute); + + auto transpose_label = wrap_type([](const Output& output) -> bool { + const auto& rt_info = output.get_node()->get_rt_info(); + return rt_info.find(NoTransposeSinkingAttr::get_type_info_static()) != rt_info.end(); + }); + ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { + const auto& pattern_to_output = m.get_pattern_map(); + + const auto& transpose = pattern_to_output.at(transpose_label); + ov::reset_no_sinking_attribute(transpose); + return false; + }; + auto m = std::make_shared(transpose_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_slice.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_slice.cpp index e3d2790c9cf..861010edc65 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_slice.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_slice.cpp @@ -23,34 +23,19 @@ using namespace ov::pass::transpose_sinking::utils; TSSliceForward::TSSliceForward() { MATCHER_SCOPE(TSSliceForward); - auto const_label = wrap_type(); - auto transpose_label = wrap_type({any_input(), const_label}); - auto main_node_label = - wrap_type({transpose_label, any_input(), any_input(), any_input(), any_input()}); + create_pattern(true, {0}); - matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { - const auto& pattern_to_node = m.get_pattern_map(); - - auto& main_node = pattern_to_node.at(main_node_label); - if (transformation_callback(main_node)) { - return false; - } - - auto transpose = std::dynamic_pointer_cast(pattern_to_node.at(transpose_label)); - if (!transpose || main_node->get_input_size() < 5) { - return false; - } - - auto transpose_const = as_type_ptr(pattern_to_node.at(const_label)); - if (!transpose_const) { + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { + if (main_node->get_input_size() < 5) { return false; } // remove Transpose on 1st input: - auto transpose_parent = transpose->input_value(0); + auto transpose_parent = transpose_info.transpose->input_value(0); main_node->input(0).replace_source_output(transpose_parent); - const auto transpose_axis_order = transpose_const->get_axis_vector_val(); + const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val(); auto axis = std::make_shared(element::i32, Shape{}, std::vector{0}); auto data = std::make_shared(element::i32, @@ -61,17 +46,11 @@ TSSliceForward::TSSliceForward() { main_node->input(4).replace_source_output(new_axis); - main_node->validate_and_infer_types(); - TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0}; - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } + default_outputs_update(main_node, transpose_info); return true; }; - auto m = std::make_shared(main_node_label, matcher_name); - register_matcher(m, matcher_pass_callback); + transpose_sinking(matcher_name, sinking_transformation); } TSSliceBackward::TSSliceBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp index 3bb37bc8dea..5062b2ae814 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp @@ -95,6 +95,40 @@ bool GetSplitAxis(const std::shared_ptr& split_axis, const } } // namespace +TSSplitForward::TSSplitForward() { + MATCHER_SCOPE(TSSplitForward); + + create_pattern(true, {0}); + + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { + auto split_axis_constant = as_type_ptr(main_node->input_value(1).get_node_shared_ptr()); + if (!split_axis_constant) { + return false; + } + + int64_t split_axis; + if (!GetSplitAxis(split_axis_constant, main_node->input_value(0).get_partial_shape().rank(), split_axis)) { + return false; + } + + sink_forward::RemoveInputNode(main_node, /* input_idx */ 0); + const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val(); + const size_t transposed_split_axis = transpose_axis_order[split_axis]; + auto new_split_axis_const = std::make_shared(split_axis_constant->get_element_type(), + Shape{}, + transposed_split_axis); + main_node->input(1).replace_source_output(new_split_axis_const); + copy_runtime_info({split_axis_constant, transpose_info.transpose, transpose_info.transpose_const}, + new_split_axis_const); + + default_outputs_update(main_node, transpose_info); + return true; + }; + + transpose_sinking(matcher_name, sinking_transformation); +} + /* * We follow Transpose operations rather than Split. We cannot create matcher pattern * for Split with Transpose outputs since Split can have different number of outputs. @@ -191,51 +225,3 @@ TSSplitBackward::TSSplitBackward() { auto m = std::make_shared(transpose_label, matcher_name); register_matcher(m, matcher_pass_callback); } - -TSSplitForward::TSSplitForward() { - MATCHER_SCOPE(TSSplitForward); - - auto main_node_label = wrap_type(IfNodeHasTransposeInputs); - - matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { - const auto& pattern_to_output = m.get_pattern_value_map(); - - auto& main_node_output = pattern_to_output.at(main_node_label); - auto main_node = main_node_output.get_node_shared_ptr(); - if (transformation_callback(main_node)) { - return false; - } - - auto split_axis_constant = as_type_ptr(main_node->input_value(1).get_node_shared_ptr()); - if (!split_axis_constant) { - return false; - } - - int64_t split_axis; - if (!GetSplitAxis(split_axis_constant, main_node->input_value(0).get_partial_shape().rank(), split_axis)) { - return false; - } - TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); - - sink_forward::RemoveInputNode(main_node, /* input_idx */ 0); - const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); - const size_t transposed_split_axis = transpose_axis_order[split_axis]; - auto new_split_axis_const = std::make_shared(split_axis_constant->get_element_type(), - Shape{}, - transposed_split_axis); - main_node->input(1).replace_source_output(new_split_axis_const); - copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const}, - new_split_axis_const); - main_node->validate_and_infer_types(); - - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } - - return true; - }; - - auto m = std::make_shared(main_node_label, matcher_name); - register_matcher(m, matcher_pass_callback); -} diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_squeeze.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_squeeze.cpp index bc7ca6036c1..7e9a85c68be 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_squeeze.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_squeeze.cpp @@ -104,32 +104,10 @@ bool squeeze_axes_to_shape(const Output& input_node, TSSqueezeForward::TSSqueezeForward() { MATCHER_SCOPE(TSSqueezeForward); - auto transpose_label = wrap_type({any_input(), wrap_type()}); - auto squeeze_with_1_input = wrap_type({transpose_label}); - auto squeeze_label = - wrap_type({transpose_label, wrap_type()}); - auto pattern = std::make_shared(OutputVector{squeeze_with_1_input, squeeze_label}); - - ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { - const auto& pattern_to_output = m.get_pattern_map(); - - auto transpose = as_type_ptr(pattern_to_output.at(transpose_label)); - std::shared_ptr main_node; - if (pattern_to_output.count(squeeze_label)) { - main_node = pattern_to_output.at(squeeze_label); - } else { - main_node = pattern_to_output.at(squeeze_with_1_input); - } - if (!transpose || transformation_callback(main_node)) { - return false; - } - - auto transpose_order = as_type_ptr(transpose->get_input_node_shared_ptr(1)); - - if (!transpose_order) { - return false; - } + create_pattern(true, {0}); + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { std::vector non_negative_axes; std::shared_ptr squeeze_axes; if (main_node->get_input_size() > 1) { @@ -153,7 +131,7 @@ TSSqueezeForward::TSSqueezeForward() { // if 2nd input to main_node is empty then all '1' dims will be deleted. if (non_negative_axes.empty()) { - auto input_pshape = transpose->output(0).get_partial_shape(); + auto input_pshape = transpose_info.transpose->output(0).get_partial_shape(); if (input_pshape.is_dynamic()) { return false; } @@ -164,7 +142,7 @@ TSSqueezeForward::TSSqueezeForward() { } } - auto transpose_order_values = transpose_order->cast_vector(); + auto transpose_order_values = transpose_info.transpose_const->cast_vector(); std::vector new_values; new_values.reserve(non_negative_axes.size()); for (const auto& axis : non_negative_axes) { @@ -172,13 +150,13 @@ TSSqueezeForward::TSSqueezeForward() { } transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values); - auto new_transpose_order = ov::op::v0::Constant::create(transpose_order->get_element_type(), + auto new_transpose_order = ov::op::v0::Constant::create(transpose_info.transpose_const->get_element_type(), {transpose_order_values.size()}, transpose_order_values); if (as_type_ptr(main_node)) { std::vector to_shape; - auto success = squeeze_axes_to_shape(transpose->input_value(0), new_values, to_shape); + auto success = squeeze_axes_to_shape(transpose_info.transpose->input_value(0), new_values, to_shape); if (!success) { return false; } @@ -192,24 +170,18 @@ TSSqueezeForward::TSSqueezeForward() { copy_runtime_info(squeeze_axes, new_const); } - TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0}; + TransposeInputsInfo transpose_input_info = {transpose_info.transpose, new_transpose_order, 0}; // deletes Transpose from 0 input auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0}); if (!success) { return false; } - main_node->validate_and_infer_types(); - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } - + default_outputs_update(main_node, transpose_input_info); return true; }; - auto m = std::make_shared(pattern, matcher_name); - register_matcher(m, matcher_pass_callback); + transpose_sinking(matcher_name, sinking_transformation); } TSSqueezeBackward::TSSqueezeBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp index 416a75638e2..d4b4869c3eb 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp @@ -32,66 +32,21 @@ namespace { using NodePtr = std::shared_ptr; using NodePair = std::pair; -/** - * @brief SwapNodes allows to perform swapping nodes even if there are more than one consumers but has less performance - * - * @param first_node first node pointer - * @param second_node first node pointer - * @return NodePair pair of nodes in new order that allows to register them in MatcherPass - */ -NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) { - auto second_node_inputs = second_node->input_values(); - second_node_inputs[0] = first_node->input_value(0); - - auto new_first_node = second_node->clone_with_new_inputs(second_node_inputs); - - auto first_node_inputs = first_node->input_values(); - first_node_inputs[0] = new_first_node; - auto new_second_node = first_node->clone_with_new_inputs(first_node_inputs); - - new_second_node->set_friendly_name(second_node->get_friendly_name()); - ov::copy_runtime_info({first_node, second_node}, {new_first_node, new_second_node}); - - ov::replace_node(second_node, new_second_node); - - return std::make_pair(new_first_node, new_second_node); -} - } // namespace TSUnaryForward::TSUnaryForward() { MATCHER_SCOPE(TSUnaryForward); - auto transpose_label = wrap_type({any_input(), any_input()}); - auto unary_label = wrap_type({transpose_label}); - - ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { - const auto& pattern_to_output = m.get_pattern_value_map(); - auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); - auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); - if (transformation_callback(unary)) { - return false; - } - - const NodePair new_nodes = SwapNodes(transpose, unary); - - register_new_node(new_nodes.first); - register_new_node(new_nodes.second); - - UpdateForwardSinkingAbility(new_nodes.second); - return true; - }; - - auto m = std::make_shared(unary_label, "ov::pass::TSUnaryForward"); - register_matcher(m, matcher_pass_callback); + create_pattern(true); + transpose_sinking(matcher_name); } TSUnaryBackward::TSUnaryBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp index d5980625839..7abbf6f1699 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp @@ -104,22 +104,12 @@ bool unsqueeze_axes_to_shape(const Output& input_node, TSUnsqueezeForward::TSUnsqueezeForward() { MATCHER_SCOPE(TSUnsqueezeForward); - auto transpose_label = wrap_type({any_input(), wrap_type()}); - auto unsqueeze_label = - wrap_type({transpose_label, wrap_type()}); + create_pattern(true, {0}); - ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { - const auto& pattern_to_output = m.get_pattern_map(); - - auto transpose = as_type_ptr(pattern_to_output.at(transpose_label)); - auto main_node = pattern_to_output.at(unsqueeze_label); - if (!transpose || transformation_callback(main_node)) { - return false; - } - - auto transpose_order = as_type_ptr(transpose->get_input_node_shared_ptr(1)); + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { auto unsqueeze_axes = as_type_ptr(main_node->get_input_node_shared_ptr(1)); - if (!transpose_order || !unsqueeze_axes) { + if (!unsqueeze_axes) { return false; } @@ -136,17 +126,17 @@ TSUnsqueezeForward::TSUnsqueezeForward() { normalize_axes(main_node->get_friendly_name(), unsqueeze_axes->cast_vector(), rank); OPENVINO_SUPPRESS_DEPRECATED_END } - auto ts_order_values = transpose_order->cast_vector(); + auto ts_order_values = transpose_info.transpose_const->cast_vector(); ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values); - auto new_transpose_order = ov::op::v0::Constant::create(transpose_order->get_element_type(), + auto new_transpose_order = ov::op::v0::Constant::create(transpose_info.transpose_const->get_element_type(), {ts_order_values.size()}, ts_order_values); - std::shared_ptr new_unsqueeze; if (as_type_ptr(main_node)) { std::vector new_values; - auto success = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes, new_values); + auto success = + unsqueeze_axes_to_shape(transpose_info.transpose->input_value(0), non_negative_axes, new_values); if (!success) { return false; } @@ -156,24 +146,18 @@ TSUnsqueezeForward::TSUnsqueezeForward() { copy_runtime_info(unsqueeze_axes, new_const); } - TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0}; + TransposeInputsInfo transpose_input_info = {transpose_info.transpose, new_transpose_order, 0}; // deletes Transpose from 0 input auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0}); if (!success) { return false; } - main_node->validate_and_infer_types(); - for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { - register_new_node(new_node); - UpdateForwardSinkingAbility(new_node); - } - + default_outputs_update(main_node, transpose_input_info); return true; }; - auto m = std::make_shared(unsqueeze_label, matcher_name); - register_matcher(m, matcher_pass_callback); + transpose_sinking(matcher_name, sinking_transformation); } TSUnsqueezeBackward::TSUnsqueezeBackward() { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp index 114208158fe..b5a381bde11 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp @@ -66,14 +66,22 @@ Output ChangeAxes(const Output& indices, return ChangeAxes(indices, data, axis); } -TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node) { - for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) { +TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node, + bool const_transpose_order, + const std::vector& indices) { + auto indices_to_check = indices; + if (indices.empty()) { + indices_to_check.resize(node->get_input_size()); + std::iota(indices_to_check.begin(), indices_to_check.end(), 0); + } + + for (const auto& input_idx : indices_to_check) { NodePtr input_node = node->get_input_node_shared_ptr(input_idx); auto transpose_node = as_type_ptr(input_node); if (!transpose_node) continue; auto constant_node = as_type_ptr(transpose_node->input_value(1).get_node_shared_ptr()); - if (!constant_node) + if (const_transpose_order && !constant_node) continue; { TransposeInputsInfo input_info; @@ -87,11 +95,6 @@ TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node) { return {}; } -bool IfNodeHasTransposeInputs(const Output& output) { - TransposeInputsInfo inputs_info = GetFirstTransposeInput(output.get_node_shared_ptr()); - return !inputs_info.isEmpty(); -} - AxisVector ReverseTransposeOrder(const AxisVector& axis_order) { AxisVector out(axis_order.size()); for (size_t i = 0; i < axis_order.size(); i++) { @@ -104,6 +107,12 @@ void SwapOutputNames(Output output1, Output output2) { const auto node2_output_names = output2.get_names(); output2.set_names(output1.get_names()); output1.set_names(node2_output_names); + + OPENVINO_SUPPRESS_DEPRECATED_START + const auto node2_legacy_output_names = get_ov_tensor_legacy_name(output2.get_tensor()); + set_ov_tensor_legacy_name(output2.get_tensor(), get_ov_tensor_legacy_name(output1.get_tensor())); + set_ov_tensor_legacy_name(output1.get_tensor(), node2_legacy_output_names); + OPENVINO_SUPPRESS_DEPRECATED_END } void SwapFriendlyNames(const NodePtr& node1, const NodePtr& node2) { @@ -304,67 +313,6 @@ NodeVector InsertTransposeBeforeNode(const NodePtr& main_node, } } // namespace sink_backward -#define CHECK_TRANSPOSE_SINKING_SUPPORTED(TYPE, node) \ - if (dynamic_cast(node)) { \ - return true; \ - } - -namespace { - -bool CanPropagateForwardThrough(Node* node) { - // todo: collect this info automatically - CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::UnaryElementwiseArithmetic, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Clamp, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Elu, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v4::SoftPlus, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::LogicalNot, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Convert, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v10::IsInf, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v10::IsNaN, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v10::IsFinite, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::BinaryElementwiseArithmetic, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::BinaryElementwiseComparison, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::BinaryElementwiseLogical, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::PRelu, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Pad, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::BatchToSpace, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::SpaceToBatch, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::ReverseSequence, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v8::Gather, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v4::Interpolate, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::ArithmeticReductionKeepDims, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(op::util::LogicalReductionKeepDims, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v8::Slice, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Split, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::VariadicSplit, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Squeeze, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Reshape, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Unsqueeze, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Transpose, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::FakeQuantize, node) - CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Concat, node) - - return false; -} - -bool CanPropagateForward(const NodePtr& node) { - for (size_t i = 0; i < node->get_output_size(); ++i) { - for (auto& consumer_input : node->output(i).get_target_inputs()) { - if (!CanPropagateForwardThrough(consumer_input.get_node())) - return false; - } - } - - return true; -} - -} // namespace - -void UpdateForwardSinkingAbility(const NodePtr& node) { - if (!CanPropagateForward(node)) - mark_as_no_sinking_node(node); -} - namespace { std::shared_ptr GetTransposeConstant(Node* node) { diff --git a/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp index eda664acfce..ed5e27a6e62 100644 --- a/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp @@ -394,13 +394,13 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) { { auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto node0 = MakeAllNodesSubgraph(X, 1, 1); - auto node0 = MakeAllNodesSubgraph(transpose0, 3, 3); + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(node0, ng_order0); auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96}); - auto reshape = std::make_shared(node0, reshape_const, false); + auto reshape = std::make_shared(transpose0, reshape_const, false); auto node1 = MakeAllNodesSubgraph(reshape, 3, 3); diff --git a/src/common/transformations/tests/transpose_sinking/ts_reset_no_sinking_attribute.cpp b/src/common/transformations/tests/transpose_sinking/ts_reset_no_sinking_attribute.cpp new file mode 100644 index 00000000000..207bf8a025f --- /dev/null +++ b/src/common/transformations/tests/transpose_sinking/ts_reset_no_sinking_attribute.cpp @@ -0,0 +1,54 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" + +using namespace testing; +using namespace std; +using namespace ov; + +using namespace op::v0; +using namespace op::v1; + +TEST(TransformationTests, ResetNoSinkingAttribute) { + auto a = std::make_shared(element::f32, Shape{12, 3, 4, 8}); + auto b = std::make_shared(element::f32, Shape{12, 3, 4, 8}); + + auto transpose_a = make_shared(a, Constant::create(element::i64, Shape{4}, {1, 0, 2, 3})); + auto transpose_b = make_shared(b, Constant::create(element::i64, Shape{4}, {1, 0, 2, 3})); + + auto add = std::make_shared(transpose_a, transpose_b); + auto trans_after = make_shared(add, Constant::create(element::i64, Shape{4}, {1, 0, 2, 3})); + auto model = std::make_shared(NodeVector{trans_after}, ParameterVector{a, b}); + + mark_as_no_sinking_node(transpose_a); + mark_as_no_sinking_node(transpose_b); + mark_as_no_sinking_node(trans_after); + + const auto& ops = model->get_ordered_ops(); + const auto cnt_before = count_if(ops.begin(), ops.end(), [](const std::shared_ptr& node) { + const auto& rt_info = node->get_rt_info(); + return rt_info.find(NoTransposeSinkingAttr::get_type_info_static()) != rt_info.end(); + }); + + EXPECT_EQ(cnt_before, 3); + ov::pass::Manager manager; + manager.register_pass(); + manager.run_passes(model); + + const auto cnt_after = count_if(ops.begin(), ops.end(), [](const std::shared_ptr& node) { + const auto& rt_info = node->get_rt_info(); + return rt_info.find(NoTransposeSinkingAttr::get_type_info_static()) != rt_info.end(); + }); + + EXPECT_EQ(cnt_after, 0); +} \ No newline at end of file