From 4d9f2f3cd7e012b5ef879a2f9af03703247c60cd Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Tue, 10 Oct 2023 13:06:10 +0200 Subject: [PATCH] Support new operations in TS: Selu, Swish, HSwish, Tile, CumSum, HardSigmoid (#19990) * add new operations as unary * get unary as input(0) instead of iterating pattern map * add CumSum + unit tests * add Tile + unit tests * add tile * fix ts_tile * code review fix: use ADD_MATCHER * fix bug CI tests --- .../transpose_sinking/ts_cumsum.hpp | 41 ++++ .../transpose_sinking/ts_tile.hpp | 41 ++++ .../transpose_sinking/ts_cumsum.cpp | 92 +++++++ .../transpose_sinking/ts_general.cpp | 57 +++-- .../transpose_sinking/ts_tile.cpp | 93 ++++++++ .../transpose_sinking/ts_unary.cpp | 50 +++- .../transpose_sinking/ts_utils.cpp | 1 + .../transpose_sinking/ts_common_test.cpp | 224 ++++++++++++++++++ .../tests/transpose_sinking/ts_unary_test.cpp | 55 ++++- 9 files changed, 605 insertions(+), 49 deletions(-) create mode 100644 src/common/transformations/include/transformations/transpose_sinking/ts_cumsum.hpp create mode 100644 src/common/transformations/include/transformations/transpose_sinking/ts_tile.hpp create mode 100644 src/common/transformations/src/transformations/transpose_sinking/ts_cumsum.cpp create mode 100644 src/common/transformations/src/transformations/transpose_sinking/ts_tile.cpp diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_cumsum.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_cumsum.hpp new file mode 100644 index 00000000000..d8c70e65ad2 --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_cumsum.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSCumSumForward; +class TRANSFORMATIONS_API TSCumSumBackward; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSCumSumForward transformation sinks Transpose through CumSum in the forward direction. + */ +class ov::pass::transpose_sinking::TSCumSumForward : public ov::pass::transpose_sinking::TSForwardBase { +public: + OPENVINO_RTTI("ov::pass::TSBinaryForward", "0"); + TSCumSumForward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSCumSumBackward transformation sinks Transpose through CumSum in the backward direction. + */ +class ov::pass::transpose_sinking::TSCumSumBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0"); + TSCumSumBackward(); +}; diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_tile.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_tile.hpp new file mode 100644 index 00000000000..cd125ca0563 --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_tile.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations/transpose_sinking/ts_base.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSTileForward; +class TRANSFORMATIONS_API TSTileBackward; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSTileForward transformation sinks Transpose through Tile in the forward direction. + */ +class ov::pass::transpose_sinking::TSTileForward : public ov::pass::transpose_sinking::TSForwardBase { +public: + OPENVINO_RTTI("ov::pass::TSBinaryForward", "0"); + TSTileForward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSTileBackward transformation sinks Transpose through Tile in the backward direction. + */ +class ov::pass::transpose_sinking::TSTileBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0"); + TSTileBackward(); +}; diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_cumsum.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_cumsum.cpp new file mode 100644 index 00000000000..623724f8ccf --- /dev/null +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_cumsum.cpp @@ -0,0 +1,92 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/transpose_sinking/ts_cumsum.hpp" + +#include "itt.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/cum_sum.hpp" +#include "openvino/op/fake_quantize.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" + +using namespace ov; +using namespace ov::pass::pattern; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; + +#undef CUMSUM_AXIS_INPUT_IDX +#define CUMSUM_AXIS_INPUT_IDX 1 + +TSCumSumForward::TSCumSumForward() { + MATCHER_SCOPE(TSCumSumForward); + + create_pattern(true, {0}); + + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { + if (transformation_callback(main_node)) { + return false; + } + + bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0}); + if (!res) + return res; + + const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val(); + auto axis = std::make_shared(element::i32, Shape{}, 0); + const auto& new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), transpose_axis_order, axis); + main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes); + + default_outputs_update(main_node, transpose_info); + return true; + }; + transpose_sinking(matcher_name, sinking_transformation); +} + +TSCumSumBackward::TSCumSumBackward() { + MATCHER_SCOPE(TSCumSumBackward); + auto main_node_label = wrap_type([](const Output& output) -> bool { + return has_static_rank()(output) && CheckTransposeConsumers(output); + }); + + auto transpose_const_label = wrap_type(); + + auto transpose_label = wrap_type({main_node_label, transpose_const_label}, + [](const Output& output) -> bool { + return has_static_rank()(output); + }); + matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose_const = + as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); + auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); + auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + + if (transformation_callback(main_node)) { + return false; + } + + for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, + transpose_const, + /* input_indexes= */ {0})) { + register_new_node(new_node); + } + + RemoveTransposeConsumers(main_node); + const auto transpose_axis_order = transpose_const->get_axis_vector_val(); + const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order); + auto axis = std::make_shared(element::i32, Shape{}, 0); + auto new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), reversed_transpose_order, axis); + main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes); + + main_node->validate_and_infer_types(); + return true; + }; + auto m = std::make_shared(transpose_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp index 4b4a0835a9d..ceae4cd45e6 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp @@ -13,6 +13,7 @@ #include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp" #include "transformations/transpose_sinking/ts_binary.hpp" #include "transformations/transpose_sinking/ts_concat.hpp" +#include "transformations/transpose_sinking/ts_cumsum.hpp" #include "transformations/transpose_sinking/ts_data_movement.hpp" #include "transformations/transpose_sinking/ts_fuse.hpp" #include "transformations/transpose_sinking/ts_gather.hpp" @@ -23,6 +24,7 @@ #include "transformations/transpose_sinking/ts_slice.hpp" #include "transformations/transpose_sinking/ts_split.hpp" #include "transformations/transpose_sinking/ts_squeeze.hpp" +#include "transformations/transpose_sinking/ts_tile.hpp" #include "transformations/transpose_sinking/ts_unary.hpp" #include "transformations/transpose_sinking/ts_unsqueeze.hpp" #include "transformations/utils/utils.hpp" @@ -31,35 +33,40 @@ using namespace ov::pass::transpose_sinking; TSGeneralForward::TSGeneralForward() { MATCHER_SCOPE(TSGeneralForward); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); + ADD_MATCHER(this, TSUnaryForward); + ADD_MATCHER(this, TSBinaryForward); + ADD_MATCHER(this, TSConcatForward); + ADD_MATCHER(this, TSSplitForward); + ADD_MATCHER(this, TSDataMovementForward); + ADD_MATCHER(this, TSReductionForward); + ADD_MATCHER(this, TSSqueezeForward); + ADD_MATCHER(this, TSUnsqueezeForward); + ADD_MATCHER(this, TSInterpolateForward); + ADD_MATCHER(this, TSSliceForward); + ADD_MATCHER(this, TSGatherForward); + ADD_MATCHER(this, TSShapeOfForward); + ADD_MATCHER(this, TSCumSumForward); + ADD_MATCHER(this, TSTileForward); + ADD_MATCHER(this, TSFuse); } TSGeneralBackward::TSGeneralBackward() { MATCHER_SCOPE(TSGeneralBackward); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); + ADD_MATCHER(this, TSUnaryBackward); + ADD_MATCHER(this, TSUnaryBackward); + ADD_MATCHER(this, TSBinaryBackward); + ADD_MATCHER(this, TSConcatBackward); + ADD_MATCHER(this, TSSplitBackward); + ADD_MATCHER(this, TSDataMovementBackward); + ADD_MATCHER(this, TSReductionBackward); + ADD_MATCHER(this, TSSqueezeBackward); + ADD_MATCHER(this, TSUnsqueezeBackward); + ADD_MATCHER(this, TSInterpolateBackward); + ADD_MATCHER(this, TSSliceBackward); + ADD_MATCHER(this, TSGatherBackward); + ADD_MATCHER(this, TSCumSumBackward); + ADD_MATCHER(this, TSTileBackward); + ADD_MATCHER(this, TSFuse); } bool TSGeneral::run_on_model(const std::shared_ptr& f) { diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_tile.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_tile.cpp new file mode 100644 index 00000000000..dda10b0a6cb --- /dev/null +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_tile.cpp @@ -0,0 +1,93 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/transpose_sinking/ts_tile.hpp" + +#include "itt.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/fake_quantize.hpp" +#include "openvino/op/tile.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" + +using namespace ov; +using namespace ov::pass::pattern; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; + +#undef TILE_REPEATS_INPUT_IDX +#define TILE_REPEATS_INPUT_IDX 1 + +TSTileForward::TSTileForward() { + MATCHER_SCOPE(TSTileForward); + + create_pattern(true, {0}); + + auto sinking_transformation = [=](const std::shared_ptr& main_node, + const TransposeInputsInfo& transpose_info) -> bool { + if (transformation_callback(main_node)) { + return false; + } + + bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0}); + if (!res) + return res; + + const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val(); + auto repeats = std::make_shared(element::i32, Shape{}, 0); + const auto& new_repeats = + ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats); + main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats); + + default_outputs_update(main_node, transpose_info); + return true; + }; + transpose_sinking(matcher_name, sinking_transformation); +} + +TSTileBackward::TSTileBackward() { + MATCHER_SCOPE(TSTileBackward); + auto main_node_label = wrap_type([](const Output& output) -> bool { + return has_static_rank()(output) && CheckTransposeConsumers(output); + }); + + auto transpose_const_label = wrap_type(); + + auto transpose_label = wrap_type({main_node_label, transpose_const_label}, + [](const Output& output) -> bool { + return has_static_rank()(output); + }); + matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose_const = + as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); + auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); + auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + + if (transformation_callback(main_node)) { + return false; + } + + for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, + transpose_const, + /* input_indexes= */ {0})) { + register_new_node(new_node); + } + + RemoveTransposeConsumers(main_node); + const auto transpose_axis_order = transpose_const->get_axis_vector_val(); + auto repeats = std::make_shared(element::i32, Shape{}, 0); + auto new_repeats = + ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats); + main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats); + + main_node->validate_and_infer_types(); + return true; + }; + auto m = std::make_shared(transpose_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp index d4b4869c3eb..5814634e740 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp @@ -16,6 +16,7 @@ #include "openvino/op/logical_not.hpp" #include "openvino/op/softplus.hpp" #include "openvino/op/transpose.hpp" +#include "openvino/pass/pattern/op/or.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp" #include "transformations/transpose_sinking/ts_utils.hpp" @@ -45,8 +46,21 @@ TSUnaryForward::TSUnaryForward() { ov::op::v0::Convert, ov::op::v10::IsInf, ov::op::v10::IsNaN, - ov::op::v10::IsFinite>(true); - transpose_sinking(matcher_name); + ov::op::v10::IsFinite, + ov::op::v0::Selu, + ov::op::v4::Swish, + ov::op::v0::HardSigmoid, + ov::op::v5::LogSoftmax, + ov::op::v1::ConvertLike>(true); + auto ts_unary_sinking_function = [this](const std::shared_ptr& main_node, + const utils::TransposeInputsInfo& transpose_info) -> bool { + bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, {0}); + if (!res) + return res; + default_outputs_update(main_node, transpose_info); + return true; + }; + transpose_sinking(matcher_name, ts_unary_sinking_function); } TSUnaryBackward::TSUnaryBackward() { @@ -56,15 +70,25 @@ TSUnaryBackward::TSUnaryBackward() { return CheckTransposeConsumers(output); }; - auto unary_label = wrap_type({any_input()}, unary_restrictions); + auto unary_with_1_input_label = wrap_type({any_input()}, unary_restrictions); + + auto unary_with_2_inputs_label = + wrap_type({any_input(), any_input()}, unary_restrictions); + auto unary_with_3_inputs_label = + wrap_type({any_input(), any_input(), any_input()}, + unary_restrictions); + + auto unary_label = std::make_shared( + ov::OutputVector{unary_with_1_input_label, unary_with_2_inputs_label, unary_with_3_inputs_label}); auto transpose_const_label = wrap_type(); @@ -75,12 +99,12 @@ TSUnaryBackward::TSUnaryBackward() { auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); - auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); + auto unary = transpose->get_input_node_shared_ptr(0); if (transformation_callback(unary)) { return false; } - for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) { + for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const, {0})) { register_new_node(new_node); } unary->validate_and_infer_types(); diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp index 177267581eb..38073bc8848 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp @@ -59,6 +59,7 @@ Output ChangeAxes(const Output& indices, copy_runtime_info(indices.get_node_shared_ptr(), gather); return gather; } + Output ChangeAxes(const Output& indices, const AxisVector& transpose_axis_order, const std::shared_ptr& axis) { diff --git a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp index 9a00aa9773e..1da471f1666 100644 --- a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp @@ -8,12 +8,14 @@ #include "openvino/pass/manager.hpp" #include "transformations/transpose_sinking/ts_binary.hpp" #include "transformations/transpose_sinking/ts_concat.hpp" +#include "transformations/transpose_sinking/ts_cumsum.hpp" #include "transformations/transpose_sinking/ts_data_movement.hpp" #include "transformations/transpose_sinking/ts_interpolate.hpp" #include "transformations/transpose_sinking/ts_reduction.hpp" #include "transformations/transpose_sinking/ts_slice.hpp" #include "transformations/transpose_sinking/ts_split.hpp" #include "transformations/transpose_sinking/ts_squeeze.hpp" +#include "transformations/transpose_sinking/ts_tile.hpp" #include "transformations/transpose_sinking/ts_unary.hpp" #include "transformations/transpose_sinking/ts_unsqueeze.hpp" #include "ts_test_case.hpp" @@ -206,6 +208,30 @@ FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_refere return std::make_shared(type_name, is_reference); } +class CumSumFactory : public IFactory { +public: + explicit CumSumFactory(const std::string& type_name) : IFactory(type_name) {} + NodePtr create(const OutputVector& parent_nodes) const override { + return std::make_shared(parent_nodes[0], parent_nodes[1]); + } +}; + +FactoryPtr CreateCumSumFactory(const std::string& type_name) { + return std::make_shared(type_name); +} + +class TileFactory : public IFactory { +public: + explicit TileFactory(const std::string& type_name) : IFactory(type_name) {} + NodePtr create(const OutputVector& parent_nodes) const override { + return std::make_shared(parent_nodes[0], parent_nodes[1]); + } +}; + +FactoryPtr CreateTileFactory(const std::string& type_name) { + return std::make_shared(type_name); +} + class SliceFactory : public IFactory { public: explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {} @@ -285,6 +311,12 @@ FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) { #undef CREATE_INTERPOLATE_FACTORY #define CREATE_INTERPOLATE_FACTORY(type_name, reference_flag) CreateInterpolateFactory(#type_name, reference_flag) +#undef CREATE_CUMSUM_FACTORY +#define CREATE_CUMSUM_FACTORY(type_name) CreateCumSumFactory(#type_name) + +#undef CREATE_TILE_FACTORY +#define CREATE_TILE_FACTORY(type_name) CreateTileFactory(#type_name) + #undef CREATE_SLICE_FACTORY #define CREATE_SLICE_FACTORY(type_name) CreateSliceFactory(#type_name) @@ -761,6 +793,84 @@ auto test_forward_interpolate = []() { INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward, TSTestFixture, test_forward_interpolate()); +auto test_forward_cumsum = []() { + TestCase test_case; + + // Initialize common attributes + test_case.transformation = CREATE_PASS_FACTORY(TSCumSumForward); + test_case.num_main_ops = {1}; + test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}), + constant(element::i64, {}, std::vector{0})}; + + // Test model description: + test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; + test_case.model.main_op = {CREATE_CUMSUM_FACTORY(CumSum)}; + test_case.model.model_template = create_model; + + // Reference model description: + auto set_specific_gather_for = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { + OutputVector result = out_vec; + for (const auto& idx : idxs) { + const auto& out = out_vec[idx]; + vector transpose_order(out_vec[0].get_shape().size()); + iota(transpose_order.begin(), transpose_order.end(), 0); + reverse(transpose_order.begin(), transpose_order.end()); + auto data = make_shared(element::i32, Shape{transpose_order.size()}, transpose_order); + auto axis = make_shared(element::i32, Shape{}, 0); + auto transpose = make_shared(data, out, axis); + result[idx] = transpose; + } + return result; + }; + test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{1}}}; + test_case.model_ref.main_op = {CREATE_CUMSUM_FACTORY(CumSum)}; + test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; + test_case.model_ref.model_template = create_model; + + return wrapper(test_case); +}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonCumSumForward, TSTestFixture, test_forward_cumsum()); + +auto test_forward_tile = []() { + TestCase test_case; + + // Initialize common attributes + test_case.transformation = CREATE_PASS_FACTORY(TSTileForward); + test_case.num_main_ops = {1}; + test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}), + constant(element::i64, {4}, std::vector{1, 2, 3, 4})}; + + // Test model description: + test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; + test_case.model.main_op = {CREATE_TILE_FACTORY(Tile)}; + test_case.model.model_template = create_model; + + // Reference model description: + auto set_specific_gather_for = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { + OutputVector result = out_vec; + for (const auto& idx : idxs) { + const auto& out = out_vec[idx]; + vector transpose_order(out_vec[0].get_shape().size()); + iota(transpose_order.begin(), transpose_order.end(), 0); + reverse(transpose_order.begin(), transpose_order.end()); + auto data = make_shared(element::i32, Shape{transpose_order.size()}, transpose_order); + auto axis = make_shared(element::i32, Shape{}, 0); + auto transpose = make_shared(out, data, axis); + result[idx] = transpose; + } + return result; + }; + test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{1}}}; + test_case.model_ref.main_op = {CREATE_TILE_FACTORY(Tile)}; + test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; + test_case.model_ref.model_template = create_model; + + return wrapper(test_case); +}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonTileForward, TSTestFixture, test_forward_tile()); + auto test_forward_squeeze = []() { TestCase test_case; @@ -1262,6 +1372,120 @@ auto test_backward_interpolate = []() { INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward, TSTestFixture, test_backward_interpolate()); +auto test_backward_cumsum = []() { + TestCase test_case; + + // Initialize common attributes + test_case.transformation = CREATE_PASS_FACTORY(TSCumSumBackward); + test_case.num_main_ops = {1}; + test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}), + constant(element::i64, {}, std::vector{0})}; + + // Test model description: + test_case.model.main_op = {CREATE_CUMSUM_FACTORY(CumSum)}; + test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; + test_case.model.model_template = create_model; + + // Reference model description: + auto set_specific_gather_for = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { + OutputVector result = out_vec; + for (const auto& idx : idxs) { + const auto& out = out_vec[idx]; + vector transpose_order(out_vec[0].get_shape().size()); + iota(transpose_order.begin(), transpose_order.end(), 0); + reverse(transpose_order.begin(), transpose_order.end()); + auto data = make_shared(element::i32, Shape{transpose_order.size()}, transpose_order); + auto axis = make_shared(element::i32, Shape{}, 0); + auto transpose = make_shared(data, out, axis); + result[idx] = transpose; + } + return result; + }; + test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {1}}}; + test_case.model_ref.main_op = {CREATE_CUMSUM_FACTORY(CumSum)}; + test_case.model_ref.model_template = create_model; + + return wrapper(test_case); +}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonCumSumBackward, TSTestFixture, test_backward_cumsum()); + +auto test_backward_tile = []() { + TestCase test_case; + + // Initialize common attributes + test_case.transformation = CREATE_PASS_FACTORY(TSTileBackward); + test_case.num_main_ops = {1}; + test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}), + constant(element::i64, {4}, std::vector{1, 2, 3, 4})}; + + // Test model description: + test_case.model.main_op = {CREATE_TILE_FACTORY(Tile)}; + test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; + test_case.model.model_template = create_model; + + // Reference model description: + auto set_specific_gather_for = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { + OutputVector result = out_vec; + for (const auto& idx : idxs) { + const auto& out = out_vec[idx]; + vector transpose_order(out_vec[0].get_shape().size()); + iota(transpose_order.begin(), transpose_order.end(), 0); + reverse(transpose_order.begin(), transpose_order.end()); + auto data = make_shared(element::i32, Shape{transpose_order.size()}, transpose_order); + auto axis = make_shared(element::i32, Shape{}, 0); + auto transpose = make_shared(out, data, axis); + result[idx] = transpose; + } + return result; + }; + test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {1}}}; + test_case.model_ref.main_op = {CREATE_TILE_FACTORY(Tile)}; + test_case.model_ref.model_template = create_model; + + return wrapper(test_case); +}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonTileBackward, TSTestFixture, test_backward_tile()); + +auto test_backward_tile_tf_case = []() { + TestCase test_case; + + // Initialize common attributes + test_case.transformation = CREATE_PASS_FACTORY(TSTileBackward); + test_case.num_main_ops = {1}; + test_case.inputs_to_main = {parameter(element::f32, {2, 1, 1, 128}), + constant(element::i64, {4}, std::vector{1, 1, 88, 1})}; + + // Test model description: + test_case.model.main_op = {CREATE_TILE_FACTORY(Tile)}; + test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; + test_case.model.model_template = create_model; + + // Reference model description: + auto set_specific_gather_for = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { + OutputVector result = out_vec; + for (const auto& idx : idxs) { + const auto& out = out_vec[idx]; + vector transpose_order(out_vec[0].get_shape().size()); + iota(transpose_order.begin(), transpose_order.end(), 0); + reverse(transpose_order.begin(), transpose_order.end()); + auto data = make_shared(element::i32, Shape{transpose_order.size()}, transpose_order); + auto axis = make_shared(element::i32, Shape{}, 0); + auto transpose = make_shared(out, data, axis); + result[idx] = transpose; + } + return result; + }; + test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {1}}}; + test_case.model_ref.main_op = {CREATE_TILE_FACTORY(Tile)}; + test_case.model_ref.model_template = create_model; + + return wrapper(test_case); +}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonTileBackwardTfCase, TSTestFixture, test_backward_tile_tf_case()); + auto test_backward_unsqueeze = []() { TestCase test_case; diff --git a/src/common/transformations/tests/transpose_sinking/ts_unary_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_unary_test.cpp index e47f378cb2b..8076edf43b2 100644 --- a/src/common/transformations/tests/transpose_sinking/ts_unary_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_unary_test.cpp @@ -7,7 +7,7 @@ #include "common_test_utils/ov_test_utils.hpp" #include "gtest/gtest.h" #include "openvino/frontend/manager.hpp" -#include "openvino/opsets/opset10.hpp" +#include "openvino/opsets/opset12.hpp" #include "openvino/pass/manager.hpp" #include "ts_test_utils.hpp" @@ -85,6 +85,37 @@ NodePtr UnaryFactory::create(const OutputVector& inputs) const { return std::make_shared(inputs[0], element::f64); } +template <> +NodePtr UnaryFactory::create(const OutputVector& inputs) const { + auto alpha = std::make_shared(element::f32, Shape{}, 2.0); + auto lambda = std::make_shared(element::f32, Shape{}, 3.0); + return std::make_shared(inputs[0], alpha, lambda); +} + +template <> +NodePtr UnaryFactory::create(const OutputVector& inputs) const { + auto beta = std::make_shared(element::f32, Shape{}, 0.9); + return std::make_shared(inputs[0], beta); +} + +template <> +NodePtr UnaryFactory::create(const OutputVector& inputs) const { + auto alpha = std::make_shared(element::f32, Shape{}, 2.0); + auto beta = std::make_shared(element::f32, Shape{}, 3.0); + return std::make_shared(inputs[0], alpha, beta); +} + +template <> +NodePtr UnaryFactory::create(const OutputVector& inputs) const { + return std::make_shared(inputs[0], 2); +} + +template <> +NodePtr UnaryFactory::create(const OutputVector& inputs) const { + auto like = std::make_shared(element::f64, Shape{}, 1); + return std::make_shared(inputs[0], like); +} + template FactoryPtr CreateUnaryFactory(const std::string& type_name) { return std::make_shared>(type_name); @@ -352,16 +383,18 @@ std::shared_ptr CreateReferenceFunction(const FactoryPtr& unary_facto } // namespace mult_consumers_first_node std::vector unary_factories = { - CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus), - CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs), - CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), - CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos), - CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp), - CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish), - CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu), - CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin), - CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt), - CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)}; + CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus), + CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs), + CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), + CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos), + CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp), + CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish), + CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu), + CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin), + CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt), + CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh), CREATE_UNARY_FACTORY(Selu), + CREATE_UNARY_FACTORY(Swish), CREATE_UNARY_FACTORY(HardSigmoid), CREATE_UNARY_FACTORY(LogSoftmax), + CREATE_UNARY_FACTORY(ConvertLike)}; TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) { FactoryPtr unary_factory;