From 04f300e187c8cb0cca807d9fb03088fbf3e9a922 Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Fri, 17 Feb 2023 18:36:22 +0100 Subject: [PATCH] Gather Sinking for Unary operations (#15289) * initial * fix year * CanGatherPropagateForward review fix * HasSameOutputGatherNodes fix code review * fix namespaces code review * fix ReverseGatherIndexes code review * clang fixes * clang fixes * remove unneeded function * move utils to utils dir + change namespace * clang fixes * windows build fixes * resotore attr file * resotore attr file * code review fix --- .../transformations/gather_sinking_unary.cpp | 206 +++++++ .../transformations/gather_sinking_unary.hpp | 111 ++++ .../rt_info/gather_sinking_attr.cpp | 28 + .../rt_info/gather_sinking_attr.hpp | 37 ++ .../utils/gather_sinking_utils.cpp | 260 ++++++++ .../utils/gather_sinking_utils.hpp | 64 ++ .../gather_sinking_unary_test.cpp | 563 ++++++++++++++++++ 7 files changed, 1269 insertions(+) create mode 100644 src/plugins/intel_gna/src/transformations/gather_sinking_unary.cpp create mode 100644 src/plugins/intel_gna/src/transformations/gather_sinking_unary.hpp create mode 100644 src/plugins/intel_gna/src/transformations/rt_info/gather_sinking_attr.cpp create mode 100644 src/plugins/intel_gna/src/transformations/rt_info/gather_sinking_attr.hpp create mode 100644 src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp create mode 100644 src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.hpp create mode 100644 src/plugins/intel_gna/tests/unit/transformations/gather_sinking_unary_test.cpp diff --git a/src/plugins/intel_gna/src/transformations/gather_sinking_unary.cpp b/src/plugins/intel_gna/src/transformations/gather_sinking_unary.cpp new file mode 100644 index 00000000000..1fb88d78052 --- /dev/null +++ b/src/plugins/intel_gna/src/transformations/gather_sinking_unary.cpp @@ -0,0 +1,206 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/gather_sinking_unary.hpp" + +#include +#include +#include + +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/rt_info/gather_sinking_attr.hpp" +#include "transformations/utils/gather_sinking_utils.hpp" + +using namespace ov; +using namespace ov::opset9; +using namespace ov::pass::pattern; +using namespace ov::op::util; +using namespace gather_sinking; +using namespace ov::intel_gna::pass; +using namespace ov::intel_gna::rt_info; + +namespace { + +using NodePtr = std::shared_ptr; +using NodePair = std::pair; + +/** + * @brief SwapNodes allows to perform swapping nodes even if there are more than one consumers but has less performance + * + * @param first_node first node pointer + * @param second_node first node pointer + * @return NodePair pair of nodes in new order that allows to register them in MatcherPass + */ +NodePair SwapNodes(NodePtr first_node, NodePtr second_node) { + auto second_node_inputs = second_node->input_values(); + second_node_inputs[0] = first_node->input_value(0); + + auto new_first_node = second_node->clone_with_new_inputs(second_node_inputs); + + auto first_node_inputs = first_node->input_values(); + first_node_inputs[0] = new_first_node; + auto new_second_node = first_node->clone_with_new_inputs(first_node_inputs); + + new_second_node->set_friendly_name(second_node->get_friendly_name()); + ov::copy_runtime_info({first_node, second_node}, {new_first_node, new_second_node}); + + ov::replace_node(second_node, new_second_node); + + 
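+    // The swapped clones now drive the original second_node's consumers and the old pair
+    // becomes dead; return the clones in their new order so the matcher can register them.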
return std::make_pair(new_first_node, new_second_node); +} + +/** + * @brief SwapOutputs has much better performance than SwapNodes and covers the most of the real situations + * but cannot work when the consumers count greater than one + * @param first_node first node pointer + * @param second_node second node pointer + * @return NodePair pair of nodes in new order that allows to register them in MatcherPass + */ +NodePair SwapOutputs(NodePtr first_node, NodePtr second_node) { + const auto first_node_output_names = first_node->output(0).get_names(); + const auto second_node_output_names = second_node->output(0).get_names(); + + auto swap_names = [&]() { + const std::string first_name = first_node->get_friendly_name(); + first_node->set_friendly_name(second_node->get_friendly_name()); + second_node->set_friendly_name(first_name); + + first_node->output(0).set_names(second_node_output_names); + second_node->output(0).set_names(first_node_output_names); + }; + + auto out_1 = first_node->input_value(0); + second_node->input(0).replace_source_output(out_1); + + auto out_2 = second_node->output(0); + second_node->output(0).replace(first_node->output(0)); + + first_node->input(0).replace_source_output(out_2); + + swap_names(); + + return std::make_pair(second_node, first_node); +} + +/** + * Swapping inputs/outputs has better perfomance that Swapping nodes with clone but it cannot be used + * in multiple consumers case + */ +NodePair Swap(NodePtr first_node, NodePtr second_node) { + NodePair new_nodes; + + if (first_node->output(0).get_target_inputs().size() > 1 || second_node->output(0).get_target_inputs().size() > 1) + new_nodes = SwapNodes(first_node, second_node); + else + new_nodes = SwapOutputs(first_node, second_node); + + return new_nodes; +} + +} // namespace + +GatherSinkingUnaryForward::GatherSinkingUnaryForward() { + MATCHER_SCOPE(GatherSinkingUnaryForward); + auto gather_label = wrap_type({any_input(), any_input(), any_input()}); + auto unary_label = wrap_type({gather_label}); + + ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto gather = pattern_to_output.at(gather_label).get_node_shared_ptr(); + auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); + + const NodePair new_nodes = Swap(gather, unary); + + register_new_node(new_nodes.first); + register_new_node(new_nodes.second); + + UpdateForwardGatherSinkingAbility(new_nodes.second); + + return true; + }; + + auto m = std::make_shared(unary_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} + +namespace { +bool IfGatherSinkingEnabled(const Output& output) { + return is_gather_sinking_node(output.get_node_shared_ptr()); +} +} // namespace + +GatherSinkingUnaryBackwardSingleConsumer::GatherSinkingUnaryBackwardSingleConsumer() { + MATCHER_SCOPE(GatherSinkingUnaryBackwardSingleConsumer); + auto unary_label = + wrap_type({any_input()}, + consumers_count(1)); + + auto gather_label = wrap_type({unary_label, any_input(), any_input()}, IfGatherSinkingEnabled); + + ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto gather = pattern_to_output.at(gather_label).get_node_shared_ptr(); + auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); + + const NodePair new_nodes = Swap(unary, gather); + + register_new_node(new_nodes.first); + register_new_node(new_nodes.second); + + return true; + }; + + auto m = 
std::make_shared(gather_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} + +namespace { +std::function)> consumers_more_than(size_t n) { + return [=](Output output) -> bool { + return output.get_target_inputs().size() > n; + }; +} +} // namespace + +GatherSinkingUnaryBackwardMultiConsumers::GatherSinkingUnaryBackwardMultiConsumers() { + MATCHER_SCOPE(GatherSinkingUnaryBackwardMultiConsumers); + auto unary_restrictions = [](const Output& output) -> bool { + return consumers_more_than(1)(output) && HasSameOutputGatherNodes(output); + }; + + auto unary_label = + wrap_type({any_input()}, + unary_restrictions); + + auto indices_const_label = wrap_type(); + auto axes_const_label = wrap_type(); + + auto gather_label = wrap_type({unary_label, indices_const_label, axes_const_label}, IfGatherSinkingEnabled); + + ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto indices_const = as_type_ptr(pattern_to_output.at(indices_const_label).get_node_shared_ptr()); + auto axes_const = as_type_ptr(pattern_to_output.at(axes_const_label).get_node_shared_ptr()); + auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); + + for (auto& new_node : sink_backward::InsertGatherBeforeNode(unary, indices_const, axes_const)) { + register_new_node(new_node); + } + + // remove output transposes + RemoveSingleOutputConsumers(unary); + + return true; + }; + + auto m = std::make_shared(gather_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} + +GatherSinkingUnaryBackward::GatherSinkingUnaryBackward() { + MATCHER_SCOPE(GatherSinkingUnaryBackward); + add_matcher(); + add_matcher(); +} diff --git a/src/plugins/intel_gna/src/transformations/gather_sinking_unary.hpp b/src/plugins/intel_gna/src/transformations/gather_sinking_unary.hpp new file mode 100644 index 00000000000..c25a59a3986 --- /dev/null +++ b/src/plugins/intel_gna/src/transformations/gather_sinking_unary.hpp @@ -0,0 +1,111 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace intel_gna { +namespace pass { + +/** + * @brief Moves Gather layer forward from the start to the end of the graph + * through the unary operations UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert + * + * Gather Unary + * | => | + * Unary Gather + * | | + * Another Another + * + * Gather Unary + * | => | + * Unary Gather + * | | | | + * Any1 Any2 Any1 Any2 + * + * Gather Unary1 + * | => | | + * Unary1 Unary2 Unary3 + * | | | | + * Unary2 Unary3 Gather Gather + * + * Another1 Another1 + * | | | + * Gather Unary Gather + * | | => | | + * Unary Another2 Gather Another2 + * | | + * Another3 Another3 + * + * All GatherSinking tranformations are designed to work in 2 steps: + * - forward push + * - backward push + * Add flag into Gather layer rt_info that prevents backward sinking if the next layer + * after Gather does not support by GatherSinking transformations. That is done to + * prevent backward pushing the layer that already pushed forward through the graph. + */ +class GatherSinkingUnaryForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("GatherSinkingUnaryForward", "0"); + GatherSinkingUnaryForward(); +}; + +/** + * @brief Moves Gather layer backward from the end to the start of the graph + * Works only with single consumer case. 
If Gather is marked as not-sinkable + * (since it was moved previously by forward sinking) it is not proceeded. + * + * Any Any + * | | + * Unary => Gather + * | | + * Gather Unary + */ +class GatherSinkingUnaryBackwardSingleConsumer : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("GatherSinkingUnaryBackwardSingleConsumer", "0"); + GatherSinkingUnaryBackwardSingleConsumer(); +}; + +/** + * @brief Moves Gather layer backward from the end to the start of the graph + * Works only with multiple consumer case. If Gather is marked as non-sinkable + * (since it was moved previously by forward sinking) it is not proceeded. + * + * Any1 Any1 + * | | + * Unary => Gather + * | | | + * Gather Gather Unary + * | | | | + * Any2 Any3 Any2 Any3 + * + * Moves Gather layer backward only if: + * - Gather is not marked as non-sinkable + * - Unary layer has > 1 gather consumers + * - All Unary consumers are Gather layers + * - All that Gather layers equal each other + */ +class GatherSinkingUnaryBackwardMultiConsumers : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("GatherSinkingUnaryBackwardMultiConsumers", "0"); + GatherSinkingUnaryBackwardMultiConsumers(); +}; + +/** + * @brief GatherSinkingUnaryBackward transformations calls GatherSinkingUnaryBackward and + * GatherSinkingUnaryBackwardMultiConsumers so there is no need to use them if GatherSinkingUnaryBackward is used + */ +class GatherSinkingUnaryBackward : public ov::pass::GraphRewrite { +public: + OPENVINO_RTTI("GatherSinkingUnaryBackward", "0"); + GatherSinkingUnaryBackward(); +}; + +} // namespace pass +} // namespace intel_gna +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_gna/src/transformations/rt_info/gather_sinking_attr.cpp b/src/plugins/intel_gna/src/transformations/rt_info/gather_sinking_attr.cpp new file mode 100644 index 00000000000..cd18a94db93 --- /dev/null +++ b/src/plugins/intel_gna/src/transformations/rt_info/gather_sinking_attr.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gather_sinking_attr.hpp" + +void ov::intel_gna::rt_info::mark_as_no_gather_sinking_node(const std::shared_ptr& node) { + auto& rt_info = node->get_rt_info(); + rt_info[NoGatherSinkingAttr::get_type_info_static()] = NoGatherSinkingAttr(); +} + +template +bool is_gather_sinking_node_private(NodePtr node) { + const auto& rt_info = node->get_rt_info(); + return rt_info.find(ov::intel_gna::rt_info::NoGatherSinkingAttr::get_type_info_static()) == rt_info.end(); +} + +bool ov::intel_gna::rt_info::is_gather_sinking_node(const std::shared_ptr& node) { + return is_gather_sinking_node_private(node); +} + +bool ov::intel_gna::rt_info::is_gather_sinking_node(const Node* node) { + return is_gather_sinking_node_private(node); +} + +bool ov::intel_gna::rt_info::is_gather_sinking_node(Output output) { + return is_gather_sinking_node(output.get_node()); +} diff --git a/src/plugins/intel_gna/src/transformations/rt_info/gather_sinking_attr.hpp b/src/plugins/intel_gna/src/transformations/rt_info/gather_sinking_attr.hpp new file mode 100644 index 00000000000..2848fae9513 --- /dev/null +++ b/src/plugins/intel_gna/src/transformations/rt_info/gather_sinking_attr.hpp @@ -0,0 +1,37 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/core/runtime_attribute.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace intel_gna { +namespace rt_info { + 
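+// Runtime-info helpers: a Gather that has already been pushed forward is marked with
+// NoGatherSinkingAttr so that the backward sinking passes leave it in place.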
+void mark_as_no_gather_sinking_node(const std::shared_ptr& node); + +bool is_gather_sinking_node(const std::shared_ptr& node); +bool is_gather_sinking_node(const Node* node); +bool is_gather_sinking_node(ov::Output output); + +/** + * @ingroup ie_runtime_attr_api + * @brief NoGatherSinkingAttr class represents runtime info attribute that marks gather + * operation should not be moved be backward sinking propagation. + */ +class NoGatherSinkingAttr : public RuntimeAttribute { +public: + OPENVINO_RTTI("no_gather_sinking", "0"); + virtual ~NoGatherSinkingAttr() = default; + bool is_copyable() const override { + return false; + } +}; + +} // namespace rt_info +} // namespace intel_gna +} // namespace ov diff --git a/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp b/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp new file mode 100644 index 00000000000..a8a1fbad30b --- /dev/null +++ b/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.cpp @@ -0,0 +1,260 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/utils/gather_sinking_utils.hpp" + +#include +#include +#include + +#include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/util/common_util.hpp" +#include "openvino/util/log.hpp" +#include "transformations/rt_info/gather_sinking_attr.hpp" + +namespace gather_sinking { + +using namespace ov; +using namespace ov::intel_gna::rt_info; +using namespace ov::opset9; + +using NodePtr = std::shared_ptr; + +GatherInputsInfo GetFirstGatherInput(NodePtr node) { + for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) { + NodePtr input_node = node->get_input_node_shared_ptr(input_idx); + auto gather_node = as_type_ptr(input_node); + if (!gather_node) + continue; + auto indices_const_node = as_type_ptr(gather_node->input_value(1).get_node_shared_ptr()); + if (!indices_const_node) + continue; + auto axes_const_node = as_type_ptr(gather_node->input_value(2).get_node_shared_ptr()); + if (!axes_const_node) + continue; + { + GatherInputsInfo input_info; + input_info.gather = gather_node; + input_info.indices_const = indices_const_node; + input_info.axes_const = axes_const_node; + input_info.input_idx = input_idx; + return input_info; + } + } + + return GatherInputsInfo(); +} + +bool IfNodeHasGatherInputs(const Output& output) { + GatherInputsInfo inputs_info = GetFirstGatherInput(output.get_node_shared_ptr()); + return !inputs_info.isEmpty(); +} + +namespace { + +bool HasDynamicRankInput(NodePtr node) { + for (auto& input_node : node->input_values()) { + const Rank output_rank = input_node.get_partial_shape().rank(); + if (output_rank.is_dynamic()) + return true; + } + return false; +} + +Rank::value_type GetMaxInputRank(const NodePtr& node) { + Rank::value_type max_input_rank = 0; + for (auto& input_node : node->input_values()) { + const Rank output_rank = input_node.get_partial_shape().rank(); + if (output_rank.is_dynamic()) + return -1; + const Rank::value_type output_rank_len = output_rank.get_length(); + if (output_rank_len > max_input_rank) + max_input_rank = output_rank_len; + } + return max_input_rank; +} + +NodePtr InsertUnsqueeze(Output node, size_t n_dims) { + std::vector dims(n_dims); + std::iota(dims.begin(), dims.end(), 0); + auto unsqueeze_const = std::make_shared(element::i64, Shape{dims.size()}, dims); + 
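+    // Unsqueeze on axes [0 .. n_dims-1] prepends n_dims singleton dimensions, raising this
+    // input to the maximum rank among the main node's inputs.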
auto unsqueeze = std::make_shared(node, unsqueeze_const); + copy_runtime_info(node.get_node_shared_ptr(), {unsqueeze, unsqueeze_const}); + return unsqueeze; +} + +Output FixInputNodeRank(Output input_node, Rank::value_type required_rank) { + const Rank::value_type output_rank = input_node.get_partial_shape().rank().get_length(); + if (output_rank >= required_rank) + return input_node; + return InsertUnsqueeze(input_node, required_rank - output_rank)->output(0); +} + +} // namespace + +namespace sink_backward { + +NodeVector InsertGatherBeforeNode(NodePtr main_node, + const std::shared_ptr& indices_const, + const std::shared_ptr& axes_const) { + if (HasDynamicRankInput(main_node)) + return {}; + + NodeVector new_nodes; + + const auto max_input_rank = GetMaxInputRank(main_node); + if (max_input_rank < 0) + return {}; + + for (size_t i = 0; i < main_node->get_input_size(); ++i) { + auto input_node = FixInputNodeRank(main_node->input_value(i), max_input_rank); + + auto new_indices_const = indices_const->clone_with_new_inputs({}); + auto new_axes_const = axes_const->clone_with_new_inputs({}); + auto new_gather = std::make_shared(input_node, new_indices_const, new_axes_const); + + main_node->input(i).replace_source_output(new_gather->output(0)); + + copy_runtime_info(input_node.get_node_shared_ptr(), {new_gather, new_indices_const, new_axes_const}); + + new_nodes.push_back(new_gather); + } + + return new_nodes; +} + +} // namespace sink_backward + +namespace { +#define CHECK_GATHER_SINKING_SUPPORTED(TYPE, node) \ + if (dynamic_cast(node)) { \ + return true; \ + } + +bool CanPropagateGatherForwardThrough(Node* node) { + CHECK_GATHER_SINKING_SUPPORTED(ov::op::util::UnaryElementwiseArithmetic, node); + CHECK_GATHER_SINKING_SUPPORTED(Clamp, node); + CHECK_GATHER_SINKING_SUPPORTED(Elu, node); + CHECK_GATHER_SINKING_SUPPORTED(SoftPlus, node); + CHECK_GATHER_SINKING_SUPPORTED(LogicalNot, node); + CHECK_GATHER_SINKING_SUPPORTED(Convert, node); + return false; +} + +#undef CHECK_GATHER_SINKING_SUPPORTED + +bool CanGatherPropagateForward(NodePtr node) { + for (auto output : node->outputs()) { + for (auto& consumer_input : output.get_target_inputs()) { + if (!CanPropagateGatherForwardThrough(consumer_input.get_node())) + return false; + } + } + + return true; +} + +} // namespace + +void UpdateForwardGatherSinkingAbility(NodePtr node) { + if (!CanGatherPropagateForward(node)) + mark_as_no_gather_sinking_node(node); +} + +namespace { + +struct GatherInfo { + bool isEmpty() const { + return indices.empty(); + } + bool operator==(const GatherInfo& another) { + if (indices.size() != another.indices.size()) + return false; + if (!std::equal(indices.begin(), indices.end(), another.indices.begin())) + return false; + return axis == another.axis; + } + bool operator!=(const GatherInfo& another) { + return !(*this == another); + } + + ov::AxisVector indices; + int64_t axis = {}; +}; + +GatherInfo GetGatherInfo(Node* node) { + GatherInfo gather_info; + + auto gather_node = dynamic_cast(node); + if (!gather_node) + return {}; + + auto constant_node = as_type_ptr(gather_node->input_value(1).get_node_shared_ptr()); + if (!constant_node) + return {}; + + gather_info.indices = constant_node->get_axis_vector_val(); + + constant_node = as_type_ptr(gather_node->input_value(2).get_node_shared_ptr()); + if (!constant_node) + return {}; + + gather_info.axis = constant_node->get_axis_vector_val()[0]; + + return gather_info; +} + +Node* FindFirstConsumer(NodePtr node) { + for (auto output : node->outputs()) { + auto inputs = 
output.get_target_inputs(); + if (inputs.empty()) + continue; + return inputs.begin()->get_node(); + } + return nullptr; +} + +bool HasSameOutputGatherNodes(NodePtr main_node) { + GatherInfo first_gather_info; + { + Node* first_consumer = FindFirstConsumer(main_node); + if (!first_consumer) + return false; + first_gather_info = GetGatherInfo(first_consumer); + if (first_gather_info.isEmpty()) + return false; + } + + for (size_t output_idx = 0; output_idx < main_node->get_output_size(); ++output_idx) { + for (auto& input : main_node->get_output_target_inputs(output_idx)) { + GatherInfo gather_info = GetGatherInfo(input.get_node()); + if (gather_info.isEmpty() || gather_info != first_gather_info) + return false; + } + } + + return true; +} + +} // namespace + +bool HasSameOutputGatherNodes(const Output& output) { + return HasSameOutputGatherNodes(output.get_node_shared_ptr()); +} + +void RemoveSingleOutputConsumers(NodePtr node) { + for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) { + for (auto& input : node->get_output_target_inputs(output_idx)) { + Node* consumer = input.get_node(); + if (consumer->get_output_size() != 1) + continue; + consumer->output(0).replace(node->output(output_idx)); + } + } +} + +} // namespace gather_sinking diff --git a/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.hpp b/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.hpp new file mode 100644 index 00000000000..c3fd1530a36 --- /dev/null +++ b/src/plugins/intel_gna/src/transformations/utils/gather_sinking_utils.hpp @@ -0,0 +1,64 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/util/common_util.hpp" +#include "openvino/util/log.hpp" + +namespace gather_sinking { + +struct GatherInputsInfo { + std::shared_ptr gather; + std::shared_ptr indices_const; + std::shared_ptr axes_const; + size_t input_idx; + + bool isEmpty() const { + return !gather || !indices_const || !axes_const; + } +}; + +/** + * @brief Finds node first input that is a Gather operation and returns filled GatherInputsInfo + * for it + */ +GatherInputsInfo GetFirstGatherInput(std::shared_ptr); + +/** + * @brief Checks if @arg has any input node that is a Gather operation + */ +bool IfNodeHasGatherInputs(const ov::Output&); + +namespace sink_backward { +/** + * @brief Inserts Gather layers on each input of @arg main_node with cloned indices and axes constants + */ +ov::NodeVector InsertGatherBeforeNode(std::shared_ptr main_node, + const std::shared_ptr& indices_const, + const std::shared_ptr& axes_const); +} // namespace sink_backward + +void UpdateForwardGatherSinkingAbility(std::shared_ptr); + +/** + * @brief Checks if @arg has consumers that all are the same Gather operation. If no consumers at all + * returns false. 
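+ * Used by GatherSinkingUnaryBackwardMultiConsumers to ensure every consumer is an identical
+ * Gather before sinking it backward through the unary node.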
+ */ +bool HasSameOutputGatherNodes(const ov::Output&); + +/** + * Removes all direct node consumers that have one output + */ +void RemoveSingleOutputConsumers(std::shared_ptr); + +} // namespace gather_sinking diff --git a/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_unary_test.cpp b/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_unary_test.cpp new file mode 100644 index 00000000000..95ebb5dbde2 --- /dev/null +++ b/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_unary_test.cpp @@ -0,0 +1,563 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/gather_sinking_unary.hpp" + +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "gtest/gtest.h" + +using namespace ov; +using namespace ov::opset9; + +using NodePtr = std::shared_ptr; + +namespace { +std::string to_string(const Shape& shape) { + std::ostringstream result; + result << "{"; + for (size_t idx = 0; idx < shape.size(); ++idx) { + if (idx) + result << ","; + result << shape[idx]; + } + result << "}"; + return result.str(); +} + +std::vector GenerateVector(size_t size, size_t initial_value) { + std::vector vec(size); + std::iota(vec.begin(), vec.end(), initial_value); + return vec; +} + +std::vector MakeGatherIndexes(size_t size) { + std::vector indexes = GenerateVector(size, 0); + std::next_permutation(indexes.begin(), indexes.end()); + return indexes; +} + +std::shared_ptr MakeGather(NodePtr input_node) { + const ov::Shape& input_shape = input_node->get_output_shape(0); + const size_t input_shape_product = + std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + const std::vector indexes = MakeGatherIndexes(input_shape_product); + auto gather_indexes_node = Constant::create(element::i64, ov::Shape{indexes.size()}, indexes); + + const size_t axis = 1; + auto gather_axis_node = Constant::create(element::i64, Shape{}, {axis}); + + return std::make_shared(input_node, gather_indexes_node, gather_axis_node); +} + +} // namespace + +class IUnaryFactory { +public: + IUnaryFactory(const std::string& type_name) : type_name_(type_name) {} + virtual ~IUnaryFactory() = default; + virtual NodePtr create(NodePtr parent_node) const = 0; + + const std::string& getTypeName() const { + return type_name_; + } + +private: + const std::string type_name_; +}; + +using UnaryFactoryPtr = std::shared_ptr; + +template +class UnaryFactory : public IUnaryFactory { +public: + UnaryFactory(const std::string& type_name) : IUnaryFactory(type_name) {} + NodePtr create(NodePtr parent_node) const override { + return std::make_shared(parent_node); + } +}; + +template <> +NodePtr UnaryFactory::create(NodePtr parent_node) const { + return std::make_shared(parent_node, 0.1); +} + +template <> +NodePtr UnaryFactory::create(NodePtr parent_node) const { + return std::make_shared(parent_node, 0.1, 0.2); +} + +template <> +NodePtr UnaryFactory::create(NodePtr parent_node) const { + return std::make_shared(parent_node, element::f64); +} + +template +UnaryFactoryPtr CreateUnaryFactory(const std::string& type_name) { + return std::make_shared>(type_name); +} + +// ---------------------------------------------------------------------------- + +class IPassFactory { +public: + IPassFactory(const std::string& type_name) : type_name_(type_name) {} + virtual ~IPassFactory() = default; + virtual void registerPass(ov::pass::Manager& pass_manager) const = 0; + const std::string& getTypeName() const { + 
return type_name_; + } + +private: + const std::string type_name_; +}; + +using PassFactoryPtr = std::shared_ptr; + +template +class PassFactory : public IPassFactory { +public: + PassFactory(const std::string& type_name) : IPassFactory(type_name) {} + void registerPass(ov::pass::Manager& pass_manager) const override { + pass_manager.register_pass(); + } +}; + +#define CREATE_PASS_FACTORY(pass_name) std::make_shared>(#pass_name) + +#undef CREATE_UNARY_FACTORY +#define CREATE_UNARY_FACTORY(type_name) CreateUnaryFactory(#type_name) + +// ---------------------------------------------------------------------------- + +using FloatPtr = std::unique_ptr; + +using CreateGraphF = std::function(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type)>; + +using TestParams = std::tuple; /* input type */ + +class GatherSinkingUnaryTestFixture : public ::testing::WithParamInterface, public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + UnaryFactoryPtr unary_factory; + PassFactoryPtr pass_factory; + size_t num_unary_ops; + CreateGraphF model_factory; + CreateGraphF reference_model_factory; + Shape input_shape; + element::Type input_type; + std::tie(unary_factory, + pass_factory, + num_unary_ops, + model_factory, + reference_model_factory, + input_shape, + input_type) = obj.param; + + std::ostringstream test_name; + test_name << "unaryFactory=" << unary_factory->getTypeName() << "/"; + test_name << "numUnaryOps=" << num_unary_ops << "/"; + test_name << "inputShape=" << to_string(input_shape) << "/"; + test_name << "unaryFactory=" << unary_factory->getTypeName() << "/"; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << "inputType=" << input_type; + + return test_name.str(); + } +}; + +namespace { + +std::string GetFinalNodeName(std::shared_ptr model, int index = 0) { + NodePtr result_node = model->get_results()[index]; + return result_node->get_input_node_ptr(0)->get_friendly_name(); +} + +std::shared_ptr CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + auto gather = MakeGather(X); + + NodePtr in_op = gather; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + return std::make_shared(in_op, ov::ParameterVector{X}); +} + +std::shared_ptr CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto gather = MakeGather(in_op); + + return std::make_shared(gather, ov::ParameterVector{X}); +} + +static NodePtr CreateReshape(NodePtr parent_node) { + const Shape& input_shape = parent_node->get_output_shape(0); + const size_t mul = std::accumulate(input_shape.begin(), input_shape.end(), (size_t)1, std::multiplies()); + auto reshape_const = std::make_shared(element::u64, Shape{1}, Shape{mul}); + return std::make_shared(parent_node, reshape_const, false); +} + +namespace mult_consumers_last_node { +namespace with_reshape { + +std::shared_ptr CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, 
input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto gather = MakeGather(in_op); + + auto reshape1 = CreateReshape(gather); + auto reshape2 = CreateReshape(gather); + + return std::make_shared(ov::OutputVector{reshape1, reshape2}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + auto gather = MakeGather(X); + + NodePtr in_op = gather; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto reshape1 = CreateReshape(in_op); + auto reshape2 = CreateReshape(in_op); + + return std::make_shared(ov::OutputVector{reshape1, reshape2}, ov::ParameterVector{X}); +} +} // namespace with_reshape + +namespace with_eltwise { + +std::shared_ptr CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto sinh = std::make_shared(in_op); + + auto gather0 = MakeGather(sinh); + + auto cosh = std::make_shared(in_op); + + auto gather1 = MakeGather(cosh); + + return std::make_shared(ov::OutputVector{gather0, gather1}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + auto gather0 = MakeGather(X); + + NodePtr in_op = gather0; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto sinh = std::make_shared(in_op); + auto cosh = std::make_shared(in_op); + + return std::make_shared(ov::OutputVector{sinh, cosh}, ov::ParameterVector{X}); +} + +} // namespace with_eltwise +} // namespace mult_consumers_last_node + +namespace mult_consumers_first_node { +namespace backward { + +std::shared_ptr CreateFunction(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + ov::OutputVector outputs; + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + auto cosh = std::make_shared(in_op); + outputs.push_back(cosh); + } + + auto gather0 = MakeGather(in_op); + + outputs.push_back(gather0); + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +} // namespace backward + +namespace backward_mult_transposes { + +std::shared_ptr CreateFunction(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto gather0 = MakeGather(in_op); + + auto tanh0 = std::make_shared(gather0); + + auto gather1 = MakeGather(in_op); + + auto tanh1 = std::make_shared(gather1); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = 
std::make_shared(input_type, input_shape); + + auto gather0 = MakeGather(X); + + NodePtr in_op = gather0; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto tanh0 = std::make_shared(in_op); + auto tanh1 = std::make_shared(in_op); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +} // namespace backward_mult_transposes + +namespace forward { + +std::shared_ptr CreateFunction(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + auto sinh = std::make_shared(X); + + auto gather0 = MakeGather(sinh); + + auto reshape = CreateReshape(gather0); + + NodePtr in_op = gather0; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + return std::make_shared(ov::OutputVector{in_op, reshape}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + auto sinh = std::make_shared(X); + + auto gather0 = MakeGather(sinh); + + auto reshape = CreateReshape(gather0); + + NodePtr in_op = sinh; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto gather1 = MakeGather(in_op); + + return std::make_shared(ov::OutputVector{gather1, reshape}, ov::ParameterVector{X}); +} + +} // namespace forward +} // namespace mult_consumers_first_node + +std::vector unary_factories = { + CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus), + CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs), + CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), + CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos), + CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp), + CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish), + CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu), + CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin), + CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt), + CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)}; + +std::vector unary_operations_numbers = {1, 10}; + +} // namespace + +TEST_P(GatherSinkingUnaryTestFixture, CompareFunctions) { + UnaryFactoryPtr unary_factory; + PassFactoryPtr pass_factory; + size_t num_unary_ops; + CreateGraphF model_factory; + CreateGraphF reference_model_factory; + Shape input_shape; + element::Type input_type; + std::tie(unary_factory, + pass_factory, + num_unary_ops, + model_factory, + reference_model_factory, + input_shape, + input_type) = this->GetParam(); + + model = model_factory(unary_factory, num_unary_ops, input_shape, input_type); + model_ref = reference_model_factory(unary_factory, num_unary_ops, input_shape, input_type); + pass_factory->registerPass(manager); +} + +INSTANTIATE_TEST_SUITE_P(GatherSinkingUnaryForwardTestSuite, + GatherSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryForward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(CreateFunctionTransposeBefore), + 
::testing::Values(CreateFunctionTransposeAfter), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + GatherSinkingUnaryTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P(GatherSinkingUnaryBackwardTestSuite, + GatherSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryBackward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(CreateFunctionTransposeAfter), + ::testing::Values(CreateFunctionTransposeBefore), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + GatherSinkingUnaryTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + GatherSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape, + GatherSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryForward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore), + ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + GatherSinkingUnaryTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + GatherSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape, + GatherSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryBackward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter), + ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + GatherSinkingUnaryTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + GatherSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise, + GatherSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryForward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore), + ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + GatherSinkingUnaryTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + GatherSinkingUnaryForwardMultConsumersTestSuiteFirstNode, + GatherSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryForward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(mult_consumers_first_node::forward::CreateFunction), + ::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + GatherSinkingUnaryTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P(GatherSinkingUnaryBackwardMultConsumersTestSuiteFirstNode, + GatherSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryBackward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(mult_consumers_first_node::backward::CreateFunction), + 
::testing::Values(mult_consumers_first_node::backward::CreateFunction), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + GatherSinkingUnaryTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + GatherSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode, + GatherSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryBackward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateFunction), + ::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + GatherSinkingUnaryTestFixture::get_test_name);
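
Usage sketch (not part of the patch): the new passes plug into the standard ov::pass::Manager pipeline. Only the pass names, namespaces, and the header path come from this patch; the helper function name and the point where the plugin would invoke it are assumptions for illustration. Running the forward pass before the backward pass matches the two-step design described in gather_sinking_unary.hpp.

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/gather_sinking_unary.hpp"

// Hypothetical helper: runs Gather sinking for unary operations on a model.
void run_gather_sinking_unary(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    // Step 1: push Gather forward through supported unary ops; a Gather that cannot
    // sink any further is marked with NoGatherSinkingAttr in its rt_info.
    manager.register_pass<ov::intel_gna::pass::GatherSinkingUnaryForward>();
    // Step 2: pull still-sinkable Gathers backward; GatherSinkingUnaryBackward
    // dispatches to the single-consumer and multi-consumer matchers.
    manager.register_pass<ov::intel_gna::pass::GatherSinkingUnaryBackward>();
    manager.run_passes(model);
}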