Gather Sinking for Unary operations (#15289)
* initial * fix year * CanGatherPropagateForward review fix * HasSameOutputGatherNodes fix code review * fix namespaces code review * fix ReverseGatherIndexes code review * clang fixes * clang fixes * remove unneeded function * move utils to utils dir + change namespace * clang fixes * windows build fixes * restore attr file * restore attr file * code review fix
This commit is contained in:
parent
efb51b058c
commit
04f300e187
@ -0,0 +1,206 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/gather_sinking_unary.hpp"
|
||||||
|
|
||||||
|
#include <openvino/cc/ngraph/itt.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "transformations/rt_info/gather_sinking_attr.hpp"
|
||||||
|
#include "transformations/utils/gather_sinking_utils.hpp"
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
using namespace ov::pass::pattern;
|
||||||
|
using namespace ov::op::util;
|
||||||
|
using namespace gather_sinking;
|
||||||
|
using namespace ov::intel_gna::pass;
|
||||||
|
using namespace ov::intel_gna::rt_info;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using NodePtr = std::shared_ptr<ov::Node>;
|
||||||
|
using NodePair = std::pair<NodePtr, NodePtr>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief SwapNodes allows to perform swapping nodes even if there are more than one consumers but has less performance
|
||||||
|
*
|
||||||
|
* @param first_node first node pointer
|
||||||
|
* @param second_node first node pointer
|
||||||
|
* @return NodePair pair of nodes in new order that allows to register them in MatcherPass
|
||||||
|
*/
|
||||||
|
NodePair SwapNodes(NodePtr first_node, NodePtr second_node) {
|
||||||
|
auto second_node_inputs = second_node->input_values();
|
||||||
|
second_node_inputs[0] = first_node->input_value(0);
|
||||||
|
|
||||||
|
auto new_first_node = second_node->clone_with_new_inputs(second_node_inputs);
|
||||||
|
|
||||||
|
auto first_node_inputs = first_node->input_values();
|
||||||
|
first_node_inputs[0] = new_first_node;
|
||||||
|
auto new_second_node = first_node->clone_with_new_inputs(first_node_inputs);
|
||||||
|
|
||||||
|
new_second_node->set_friendly_name(second_node->get_friendly_name());
|
||||||
|
ov::copy_runtime_info({first_node, second_node}, {new_first_node, new_second_node});
|
||||||
|
|
||||||
|
ov::replace_node(second_node, new_second_node);
|
||||||
|
|
||||||
|
return std::make_pair(new_first_node, new_second_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief SwapOutputs has much better performance than SwapNodes and covers the most of the real situations
|
||||||
|
* but cannot work when the consumers count greater than one
|
||||||
|
* @param first_node first node pointer
|
||||||
|
* @param second_node second node pointer
|
||||||
|
* @return NodePair pair of nodes in new order that allows to register them in MatcherPass
|
||||||
|
*/
|
||||||
|
NodePair SwapOutputs(NodePtr first_node, NodePtr second_node) {
|
||||||
|
const auto first_node_output_names = first_node->output(0).get_names();
|
||||||
|
const auto second_node_output_names = second_node->output(0).get_names();
|
||||||
|
|
||||||
|
auto swap_names = [&]() {
|
||||||
|
const std::string first_name = first_node->get_friendly_name();
|
||||||
|
first_node->set_friendly_name(second_node->get_friendly_name());
|
||||||
|
second_node->set_friendly_name(first_name);
|
||||||
|
|
||||||
|
first_node->output(0).set_names(second_node_output_names);
|
||||||
|
second_node->output(0).set_names(first_node_output_names);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto out_1 = first_node->input_value(0);
|
||||||
|
second_node->input(0).replace_source_output(out_1);
|
||||||
|
|
||||||
|
auto out_2 = second_node->output(0);
|
||||||
|
second_node->output(0).replace(first_node->output(0));
|
||||||
|
|
||||||
|
first_node->input(0).replace_source_output(out_2);
|
||||||
|
|
||||||
|
swap_names();
|
||||||
|
|
||||||
|
return std::make_pair(second_node, first_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Swapping inputs/outputs has better perfomance that Swapping nodes with clone but it cannot be used
|
||||||
|
* in multiple consumers case
|
||||||
|
*/
|
||||||
|
NodePair Swap(NodePtr first_node, NodePtr second_node) {
|
||||||
|
NodePair new_nodes;
|
||||||
|
|
||||||
|
if (first_node->output(0).get_target_inputs().size() > 1 || second_node->output(0).get_target_inputs().size() > 1)
|
||||||
|
new_nodes = SwapNodes(first_node, second_node);
|
||||||
|
else
|
||||||
|
new_nodes = SwapOutputs(first_node, second_node);
|
||||||
|
|
||||||
|
return new_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
GatherSinkingUnaryForward::GatherSinkingUnaryForward() {
    MATCHER_SCOPE(GatherSinkingUnaryForward);

    // Pattern: Gather -> unary operation.
    auto gather_label = wrap_type<Gather>({any_input(), any_input(), any_input()});
    auto unary_label = wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({gather_label});

    ov::matcher_pass_callback on_match = [=](Matcher& m) {
        const auto& matched = m.get_pattern_value_map();
        auto gather_node = matched.at(gather_label).get_node_shared_ptr();
        auto unary_node = matched.at(unary_label).get_node_shared_ptr();

        const NodePair swapped = Swap(gather_node, unary_node);

        register_new_node(swapped.first);
        register_new_node(swapped.second);

        // The second node of the returned pair is checked for further forward sinking ability.
        UpdateForwardGatherSinkingAbility(swapped.second);

        return true;
    };

    auto matcher = std::make_shared<Matcher>(unary_label, matcher_name);
    register_matcher(matcher, on_match);
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
bool IfGatherSinkingEnabled(const Output<Node>& output) {
|
||||||
|
return is_gather_sinking_node(output.get_node_shared_ptr());
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
GatherSinkingUnaryBackwardSingleConsumer::GatherSinkingUnaryBackwardSingleConsumer() {
    MATCHER_SCOPE(GatherSinkingUnaryBackwardSingleConsumer);

    // Pattern: unary operation with exactly one consumer -> Gather accepted by IfGatherSinkingEnabled.
    auto unary_label =
        wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
                                                                                         consumers_count(1));
    auto gather_label = wrap_type<Gather>({unary_label, any_input(), any_input()}, IfGatherSinkingEnabled);

    ov::matcher_pass_callback on_match = [=](Matcher& m) {
        const auto& matched = m.get_pattern_value_map();
        auto gather_node = matched.at(gather_label).get_node_shared_ptr();
        auto unary_node = matched.at(unary_label).get_node_shared_ptr();

        const NodePair swapped = Swap(unary_node, gather_node);

        register_new_node(swapped.first);
        register_new_node(swapped.second);

        return true;
    };

    auto matcher = std::make_shared<Matcher>(gather_label, matcher_name);
    register_matcher(matcher, on_match);
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
std::function<bool(Output<Node>)> consumers_more_than(size_t n) {
|
||||||
|
return [=](Output<Node> output) -> bool {
|
||||||
|
return output.get_target_inputs().size() > n;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
GatherSinkingUnaryBackwardMultiConsumers::GatherSinkingUnaryBackwardMultiConsumers() {
    MATCHER_SCOPE(GatherSinkingUnaryBackwardMultiConsumers);

    // The unary output must have more than one consumer and HasSameOutputGatherNodes must hold for it.
    auto unary_restrictions = [](const Output<Node>& output) -> bool {
        return consumers_more_than(1)(output) && HasSameOutputGatherNodes(output);
    };

    auto unary_label =
        wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
                                                                                         unary_restrictions);

    auto indices_const_label = wrap_type<Constant>();
    auto axes_const_label = wrap_type<Constant>();

    auto gather_label = wrap_type<Gather>({unary_label, indices_const_label, axes_const_label}, IfGatherSinkingEnabled);

    ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto indices_const = as_type_ptr<Constant>(pattern_to_output.at(indices_const_label).get_node_shared_ptr());
        auto axes_const = as_type_ptr<Constant>(pattern_to_output.at(axes_const_label).get_node_shared_ptr());
        auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();

        const auto new_gathers = sink_backward::InsertGatherBeforeNode(unary, indices_const, axes_const);
        // InsertGatherBeforeNode returns an empty list without touching the graph when an input has
        // dynamic rank. Removing the output gathers in that case would drop them from the graph
        // without a replacement, so the match is rejected instead.
        if (new_gathers.empty())
            return false;

        for (const auto& new_node : new_gathers) {
            register_new_node(new_node);
        }

        // remove output gathers
        RemoveSingleOutputConsumers(unary);

        return true;
    };

    auto m = std::make_shared<Matcher>(gather_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
|
||||||
|
|
||||||
|
GatherSinkingUnaryBackward::GatherSinkingUnaryBackward() {
    MATCHER_SCOPE(GatherSinkingUnaryBackward);
    // Registers both backward matchers: single consumer first, then multiple consumers.
    this->add_matcher<GatherSinkingUnaryBackwardSingleConsumer>();
    this->add_matcher<GatherSinkingUnaryBackwardMultiConsumers>();
}
|
@ -0,0 +1,111 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/pass/graph_rewrite.hpp"
|
||||||
|
#include "transformations_visibility.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace intel_gna {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Moves Gather layer forward from the start to the end of the graph
|
||||||
|
* through the unary operations UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert
|
||||||
|
*
|
||||||
|
* Gather Unary
|
||||||
|
* | => |
|
||||||
|
* Unary Gather
|
||||||
|
* | |
|
||||||
|
* Another Another
|
||||||
|
*
|
||||||
|
* Gather Unary
|
||||||
|
* | => |
|
||||||
|
* Unary Gather
|
||||||
|
* | | | |
|
||||||
|
* Any1 Any2 Any1 Any2
|
||||||
|
*
|
||||||
|
* Gather Unary1
|
||||||
|
* | => | |
|
||||||
|
* Unary1 Unary2 Unary3
|
||||||
|
* | | | |
|
||||||
|
* Unary2 Unary3 Gather Gather
|
||||||
|
*
|
||||||
|
* Another1 Another1
|
||||||
|
* | | |
|
||||||
|
* Gather Unary Gather
|
||||||
|
* | | => | |
|
||||||
|
* Unary Another2 Gather Another2
|
||||||
|
* | |
|
||||||
|
* Another3 Another3
|
||||||
|
*
|
||||||
|
* All GatherSinking tranformations are designed to work in 2 steps:
|
||||||
|
* - forward push
|
||||||
|
* - backward push
|
||||||
|
* Add flag into Gather layer rt_info that prevents backward sinking if the next layer
|
||||||
|
* after Gather does not support by GatherSinking transformations. That is done to
|
||||||
|
* prevent backward pushing the layer that already pushed forward through the graph.
|
||||||
|
*/
|
||||||
|
class GatherSinkingUnaryForward : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("GatherSinkingUnaryForward", "0");
|
||||||
|
GatherSinkingUnaryForward();
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Moves Gather layer backward from the end to the start of the graph
|
||||||
|
* Works only with single consumer case. If Gather is marked as not-sinkable
|
||||||
|
* (since it was moved previously by forward sinking) it is not proceeded.
|
||||||
|
*
|
||||||
|
* Any Any
|
||||||
|
* | |
|
||||||
|
* Unary => Gather
|
||||||
|
* | |
|
||||||
|
* Gather Unary
|
||||||
|
*/
|
||||||
|
class GatherSinkingUnaryBackwardSingleConsumer : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("GatherSinkingUnaryBackwardSingleConsumer", "0");
|
||||||
|
GatherSinkingUnaryBackwardSingleConsumer();
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Moves Gather layer backward from the end to the start of the graph
|
||||||
|
* Works only with multiple consumer case. If Gather is marked as non-sinkable
|
||||||
|
* (since it was moved previously by forward sinking) it is not proceeded.
|
||||||
|
*
|
||||||
|
* Any1 Any1
|
||||||
|
* | |
|
||||||
|
* Unary => Gather
|
||||||
|
* | | |
|
||||||
|
* Gather Gather Unary
|
||||||
|
* | | | |
|
||||||
|
* Any2 Any3 Any2 Any3
|
||||||
|
*
|
||||||
|
* Moves Gather layer backward only if:
|
||||||
|
* - Gather is not marked as non-sinkable
|
||||||
|
* - Unary layer has > 1 gather consumers
|
||||||
|
* - All Unary consumers are Gather layers
|
||||||
|
* - All that Gather layers equal each other
|
||||||
|
*/
|
||||||
|
class GatherSinkingUnaryBackwardMultiConsumers : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("GatherSinkingUnaryBackwardMultiConsumers", "0");
|
||||||
|
GatherSinkingUnaryBackwardMultiConsumers();
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief GatherSinkingUnaryBackward transformations calls GatherSinkingUnaryBackward and
|
||||||
|
* GatherSinkingUnaryBackwardMultiConsumers so there is no need to use them if GatherSinkingUnaryBackward is used
|
||||||
|
*/
|
||||||
|
class GatherSinkingUnaryBackward : public ov::pass::GraphRewrite {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("GatherSinkingUnaryBackward", "0");
|
||||||
|
GatherSinkingUnaryBackward();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace intel_gna
|
||||||
|
} // namespace ov
|
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "gather_sinking_attr.hpp"
|
||||||
|
|
||||||
|
void ov::intel_gna::rt_info::mark_as_no_gather_sinking_node(const std::shared_ptr<Node>& node) {
|
||||||
|
auto& rt_info = node->get_rt_info();
|
||||||
|
rt_info[NoGatherSinkingAttr::get_type_info_static()] = NoGatherSinkingAttr();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename NodePtr>
|
||||||
|
bool is_gather_sinking_node_private(NodePtr node) {
|
||||||
|
const auto& rt_info = node->get_rt_info();
|
||||||
|
return rt_info.find(ov::intel_gna::rt_info::NoGatherSinkingAttr::get_type_info_static()) == rt_info.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ov::intel_gna::rt_info::is_gather_sinking_node(const std::shared_ptr<Node>& node) {
|
||||||
|
return is_gather_sinking_node_private(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ov::intel_gna::rt_info::is_gather_sinking_node(const Node* node) {
|
||||||
|
return is_gather_sinking_node_private(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ov::intel_gna::rt_info::is_gather_sinking_node(Output<Node> output) {
|
||||||
|
return is_gather_sinking_node(output.get_node());
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/core/runtime_attribute.hpp"
|
||||||
|
#include "transformations_visibility.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace intel_gna {
|
||||||
|
namespace rt_info {
|
||||||
|
|
||||||
|
void mark_as_no_gather_sinking_node(const std::shared_ptr<Node>& node);
|
||||||
|
|
||||||
|
bool is_gather_sinking_node(const std::shared_ptr<Node>& node);
|
||||||
|
bool is_gather_sinking_node(const Node* node);
|
||||||
|
bool is_gather_sinking_node(ov::Output<ov::Node> output);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_runtime_attr_api
|
||||||
|
* @brief NoGatherSinkingAttr class represents runtime info attribute that marks gather
|
||||||
|
* operation should not be moved be backward sinking propagation.
|
||||||
|
*/
|
||||||
|
class NoGatherSinkingAttr : public RuntimeAttribute {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("no_gather_sinking", "0");
|
||||||
|
virtual ~NoGatherSinkingAttr() = default;
|
||||||
|
bool is_copyable() const override {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace rt_info
|
||||||
|
} // namespace intel_gna
|
||||||
|
} // namespace ov
|
@ -0,0 +1,260 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/utils/gather_sinking_utils.hpp"
|
||||||
|
|
||||||
|
#include <openvino/pass/pattern/op/or.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "openvino/op/util/op_types.hpp"
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/label.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "openvino/util/common_util.hpp"
|
||||||
|
#include "openvino/util/log.hpp"
|
||||||
|
#include "transformations/rt_info/gather_sinking_attr.hpp"
|
||||||
|
|
||||||
|
namespace gather_sinking {
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::intel_gna::rt_info;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
|
||||||
|
using NodePtr = std::shared_ptr<Node>;
|
||||||
|
|
||||||
|
// Scans the inputs of node in order and returns a filled GatherInputsInfo for the first input that is
// a Gather whose indices (input 1) and axes (input 2) are both Constants. Returns a value-initialized
// GatherInputsInfo when there is no such input.
GatherInputsInfo GetFirstGatherInput(NodePtr node) {
    const size_t inputs_count = node->get_input_size();
    for (size_t idx = 0; idx < inputs_count; ++idx) {
        auto gather = as_type_ptr<Gather>(node->get_input_node_shared_ptr(idx));
        if (gather) {
            auto indices = as_type_ptr<Constant>(gather->input_value(1).get_node_shared_ptr());
            if (indices) {
                auto axes = as_type_ptr<Constant>(gather->input_value(2).get_node_shared_ptr());
                if (axes) {
                    GatherInputsInfo found;
                    found.gather = gather;
                    found.indices_const = indices;
                    found.axes_const = axes;
                    found.input_idx = idx;
                    return found;
                }
            }
        }
    }

    return GatherInputsInfo();
}
|
||||||
|
|
||||||
|
bool IfNodeHasGatherInputs(const Output<Node>& output) {
|
||||||
|
GatherInputsInfo inputs_info = GetFirstGatherInput(output.get_node_shared_ptr());
|
||||||
|
return !inputs_info.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// True when at least one input of node has a partial shape with dynamic rank.
bool HasDynamicRankInput(NodePtr node) {
    const size_t inputs_count = node->get_input_size();
    for (size_t idx = 0; idx < inputs_count; ++idx) {
        if (node->input_value(idx).get_partial_shape().rank().is_dynamic())
            return true;
    }
    return false;
}
|
||||||
|
|
||||||
|
// Returns the largest rank among the inputs of node (0 when there are no inputs),
// or -1 as soon as an input with dynamic rank is met.
Rank::value_type GetMaxInputRank(const NodePtr& node) {
    Rank::value_type widest = 0;
    for (const auto& source : node->input_values()) {
        const Rank rank = source.get_partial_shape().rank();
        if (rank.is_dynamic())
            return -1;
        const Rank::value_type length = rank.get_length();
        widest = (length > widest) ? length : widest;
    }
    return widest;
}
|
||||||
|
|
||||||
|
// Creates an Unsqueeze over node with axes 0, 1, ..., n_dims - 1 and copies the runtime info of the
// source node to the new Unsqueeze and its axes constant.
NodePtr InsertUnsqueeze(Output<Node> node, size_t n_dims) {
    std::vector<size_t> axes;
    axes.reserve(n_dims);
    for (size_t axis = 0; axis < n_dims; ++axis)
        axes.push_back(axis);

    auto axes_const = std::make_shared<Constant>(element::i64, Shape{axes.size()}, axes);
    auto unsqueeze = std::make_shared<Unsqueeze>(node, axes_const);
    copy_runtime_info(node.get_node_shared_ptr(), {unsqueeze, axes_const});
    return unsqueeze;
}
|
||||||
|
|
||||||
|
// Returns input_node unchanged when its rank is at least required_rank; otherwise returns the output
// of an Unsqueeze that adds the missing (required_rank - rank) dimensions.
Output<Node> FixInputNodeRank(Output<Node> input_node, Rank::value_type required_rank) {
    const Rank::value_type current_rank = input_node.get_partial_shape().rank().get_length();
    const bool needs_unsqueeze = current_rank < required_rank;
    if (!needs_unsqueeze)
        return input_node;
    return InsertUnsqueeze(input_node, required_rank - current_rank)->output(0);
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace sink_backward {
|
||||||
|
|
||||||
|
// Inserts a Gather (with cloned indices and axes constants) in front of every input of main_node.
// Inputs whose rank is below the maximal input rank are unsqueezed first.
// Returns the inserted Gather nodes; returns an empty vector and leaves the graph untouched when
// any input has dynamic rank.
NodeVector InsertGatherBeforeNode(NodePtr main_node,
                                  const std::shared_ptr<Constant>& indices_const,
                                  const std::shared_ptr<Constant>& axes_const) {
    if (HasDynamicRankInput(main_node))
        return {};

    const auto target_rank = GetMaxInputRank(main_node);
    if (target_rank < 0)
        return {};

    NodeVector inserted_gathers;
    const size_t inputs_count = main_node->get_input_size();
    for (size_t idx = 0; idx < inputs_count; ++idx) {
        auto source = FixInputNodeRank(main_node->input_value(idx), target_rank);

        auto indices_copy = indices_const->clone_with_new_inputs({});
        auto axes_copy = axes_const->clone_with_new_inputs({});
        auto gather = std::make_shared<Gather>(source, indices_copy, axes_copy);

        main_node->input(idx).replace_source_output(gather->output(0));

        copy_runtime_info(source.get_node_shared_ptr(), {gather, indices_copy, axes_copy});

        inserted_gathers.push_back(gather);
    }

    return inserted_gathers;
}
|
||||||
|
|
||||||
|
} // namespace sink_backward
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#define CHECK_GATHER_SINKING_SUPPORTED(TYPE, node) \
|
||||||
|
if (dynamic_cast<TYPE*>(node)) { \
|
||||||
|
return true; \
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CanPropagateGatherForwardThrough(Node* node) {
|
||||||
|
CHECK_GATHER_SINKING_SUPPORTED(ov::op::util::UnaryElementwiseArithmetic, node);
|
||||||
|
CHECK_GATHER_SINKING_SUPPORTED(Clamp, node);
|
||||||
|
CHECK_GATHER_SINKING_SUPPORTED(Elu, node);
|
||||||
|
CHECK_GATHER_SINKING_SUPPORTED(SoftPlus, node);
|
||||||
|
CHECK_GATHER_SINKING_SUPPORTED(LogicalNot, node);
|
||||||
|
CHECK_GATHER_SINKING_SUPPORTED(Convert, node);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef CHECK_GATHER_SINKING_SUPPORTED
|
||||||
|
|
||||||
|
// True when every consumer of every output of node is accepted by CanPropagateGatherForwardThrough
// (vacuously true when there are no consumers).
bool CanGatherPropagateForward(NodePtr node) {
    for (const auto& produced : node->outputs()) {
        for (const auto& target : produced.get_target_inputs()) {
            const bool passable = CanPropagateGatherForwardThrough(target.get_node());
            if (!passable)
                return false;
        }
    }
    return true;
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Marks node with the no-gather-sinking attribute when it cannot be propagated forward
// through all of its consumers.
void UpdateForwardGatherSinkingAbility(NodePtr node) {
    if (CanGatherPropagateForward(node))
        return;
    mark_as_no_gather_sinking_node(node);
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct GatherInfo {
|
||||||
|
bool isEmpty() const {
|
||||||
|
return indices.empty();
|
||||||
|
}
|
||||||
|
bool operator==(const GatherInfo& another) {
|
||||||
|
if (indices.size() != another.indices.size())
|
||||||
|
return false;
|
||||||
|
if (!std::equal(indices.begin(), indices.end(), another.indices.begin()))
|
||||||
|
return false;
|
||||||
|
return axis == another.axis;
|
||||||
|
}
|
||||||
|
bool operator!=(const GatherInfo& another) {
|
||||||
|
return !(*this == another);
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::AxisVector indices;
|
||||||
|
int64_t axis = {};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Builds a GatherInfo from node. Returns an empty GatherInfo when node is not a Gather or when its
// indices (input 1) or axis (input 2) is not a Constant.
GatherInfo GetGatherInfo(Node* node) {
    auto* gather = dynamic_cast<Gather*>(node);
    if (gather == nullptr)
        return {};

    GatherInfo info;

    auto indices_const = as_type_ptr<Constant>(gather->input_value(1).get_node_shared_ptr());
    if (indices_const == nullptr)
        return {};
    info.indices = indices_const->get_axis_vector_val();

    auto axis_const = as_type_ptr<Constant>(gather->input_value(2).get_node_shared_ptr());
    if (axis_const == nullptr)
        return {};
    // Only the first value of the axis constant is used.
    info.axis = axis_const->get_axis_vector_val()[0];

    return info;
}
|
||||||
|
|
||||||
|
// Returns the node of the first target input of the first output of node that has any consumers,
// or nullptr when no output has consumers.
Node* FindFirstConsumer(NodePtr node) {
    for (const auto& produced : node->outputs()) {
        const auto consumers = produced.get_target_inputs();
        if (!consumers.empty())
            return consumers.begin()->get_node();
    }
    return nullptr;
}
|
||||||
|
|
||||||
|
// True when main_node has at least one consumer and every consumer of every output is a Gather with
// constant indices/axis that equal those of the first consumer found.
bool HasSameOutputGatherNodes(NodePtr main_node) {
    Node* first_consumer = FindFirstConsumer(main_node);
    if (first_consumer == nullptr)
        return false;

    GatherInfo reference_info = GetGatherInfo(first_consumer);
    if (reference_info.isEmpty())
        return false;

    const size_t outputs_count = main_node->get_output_size();
    for (size_t port = 0; port < outputs_count; ++port) {
        for (auto& consumer_input : main_node->get_output_target_inputs(port)) {
            GatherInfo candidate = GetGatherInfo(consumer_input.get_node());
            if (candidate.isEmpty())
                return false;
            if (candidate != reference_info)
                return false;
        }
    }

    return true;
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool HasSameOutputGatherNodes(const Output<Node>& output) {
|
||||||
|
return HasSameOutputGatherNodes(output.get_node_shared_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
// For every direct consumer of node that has exactly one output, redirects the users of that
// consumer's output to the corresponding output of node, which bypasses the consumer.
void RemoveSingleOutputConsumers(NodePtr node) {
    const size_t outputs_count = node->get_output_size();
    for (size_t port = 0; port < outputs_count; ++port) {
        for (auto& consumer_input : node->get_output_target_inputs(port)) {
            Node* consumer = consumer_input.get_node();
            if (consumer->get_output_size() == 1) {
                consumer->output(0).replace(node->output(port));
            }
        }
    }
}
|
||||||
|
|
||||||
|
} // namespace gather_sinking
|
@ -0,0 +1,64 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <openvino/pass/pattern/op/or.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "openvino/op/util/op_types.hpp"
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/label.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "openvino/util/common_util.hpp"
|
||||||
|
#include "openvino/util/log.hpp"
|
||||||
|
|
||||||
|
namespace gather_sinking {
|
||||||
|
|
||||||
|
struct GatherInputsInfo {
|
||||||
|
std::shared_ptr<ov::opset9::Gather> gather;
|
||||||
|
std::shared_ptr<ov::opset9::Constant> indices_const;
|
||||||
|
std::shared_ptr<ov::opset9::Constant> axes_const;
|
||||||
|
size_t input_idx;
|
||||||
|
|
||||||
|
bool isEmpty() const {
|
||||||
|
return !gather || !indices_const || !axes_const;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Finds node first input that is a Gather operation and returns filled GatherInputsInfo
|
||||||
|
* for it
|
||||||
|
*/
|
||||||
|
GatherInputsInfo GetFirstGatherInput(std::shared_ptr<ov::Node>);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Checks if @arg has any input node that is a Gather operation
|
||||||
|
*/
|
||||||
|
bool IfNodeHasGatherInputs(const ov::Output<ov::Node>&);
|
||||||
|
|
||||||
|
namespace sink_backward {
|
||||||
|
/**
|
||||||
|
* @brief Inserts Gather layers on each input of @arg main_node with cloned indices and axes constants
|
||||||
|
*/
|
||||||
|
ov::NodeVector InsertGatherBeforeNode(std::shared_ptr<ov::Node> main_node,
|
||||||
|
const std::shared_ptr<ov::opset9::Constant>& indices_const,
|
||||||
|
const std::shared_ptr<ov::opset9::Constant>& axes_const);
|
||||||
|
} // namespace sink_backward
|
||||||
|
|
||||||
|
void UpdateForwardGatherSinkingAbility(std::shared_ptr<ov::Node>);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Checks if @arg has consumers that all are the same Gather operation. If no consumers at all
|
||||||
|
* returns false.
|
||||||
|
*/
|
||||||
|
bool HasSameOutputGatherNodes(const ov::Output<ov::Node>&);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes all direct node consumers that have one output
|
||||||
|
*/
|
||||||
|
void RemoveSingleOutputConsumers(std::shared_ptr<ov::Node>);
|
||||||
|
|
||||||
|
} // namespace gather_sinking
|
@ -0,0 +1,563 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/gather_sinking_unary.hpp"
|
||||||
|
|
||||||
|
#include <openvino/frontend/manager.hpp>
|
||||||
|
#include <openvino/opsets/opset9.hpp>
|
||||||
|
#include <openvino/pass/manager.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
|
||||||
|
using NodePtr = std::shared_ptr<ov::Node>;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
std::string to_string(const Shape& shape) {
|
||||||
|
std::ostringstream result;
|
||||||
|
result << "{";
|
||||||
|
for (size_t idx = 0; idx < shape.size(); ++idx) {
|
||||||
|
if (idx)
|
||||||
|
result << ",";
|
||||||
|
result << shape[idx];
|
||||||
|
}
|
||||||
|
result << "}";
|
||||||
|
return result.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns {initial_value, initial_value + 1, ..., initial_value + size - 1}.
std::vector<size_t> GenerateVector(size_t size, size_t initial_value) {
    std::vector<size_t> sequence;
    sequence.reserve(size);
    for (size_t offset = 0; offset < size; ++offset)
        sequence.push_back(initial_value + offset);
    return sequence;
}
|
||||||
|
|
||||||
|
std::vector<size_t> MakeGatherIndexes(size_t size) {
|
||||||
|
std::vector<size_t> indexes = GenerateVector(size, 0);
|
||||||
|
std::next_permutation(indexes.begin(), indexes.end());
|
||||||
|
return indexes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds a Gather over input_node along axis 1. The indices constant holds a permutation
// (see MakeGatherIndexes) with as many elements as the input shape has in total.
std::shared_ptr<Gather> MakeGather(NodePtr input_node) {
    const ov::Shape& input_shape = input_node->get_output_shape(0);
    // The initial value is size_t{1}: std::accumulate accumulates in the type of the initial value,
    // so a plain int literal would compute the product in int and could truncate or overflow.
    const size_t input_shape_product =
        std::accumulate(input_shape.begin(), input_shape.end(), size_t{1}, std::multiplies<size_t>());
    const std::vector<size_t> indexes = MakeGatherIndexes(input_shape_product);
    auto gather_indexes_node = Constant::create(element::i64, ov::Shape{indexes.size()}, indexes);

    const size_t axis = 1;
    auto gather_axis_node = Constant::create(element::i64, Shape{}, {axis});

    return std::make_shared<Gather>(input_node, gather_indexes_node, gather_axis_node);
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
class IUnaryFactory {
|
||||||
|
public:
|
||||||
|
IUnaryFactory(const std::string& type_name) : type_name_(type_name) {}
|
||||||
|
virtual ~IUnaryFactory() = default;
|
||||||
|
virtual NodePtr create(NodePtr parent_node) const = 0;
|
||||||
|
|
||||||
|
const std::string& getTypeName() const {
|
||||||
|
return type_name_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string type_name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Shared-ownership handle to a unary operation factory.
using UnaryFactoryPtr = std::shared_ptr<IUnaryFactory>;
|
||||||
|
|
||||||
|
template <typename UnaryT>
|
||||||
|
class UnaryFactory : public IUnaryFactory {
|
||||||
|
public:
|
||||||
|
UnaryFactory(const std::string& type_name) : IUnaryFactory(type_name) {}
|
||||||
|
NodePtr create(NodePtr parent_node) const override {
|
||||||
|
return std::make_shared<UnaryT>(parent_node);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
NodePtr UnaryFactory<Elu>::create(NodePtr parent_node) const {
|
||||||
|
return std::make_shared<Elu>(parent_node, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
NodePtr UnaryFactory<Clamp>::create(NodePtr parent_node) const {
|
||||||
|
return std::make_shared<Clamp>(parent_node, 0.1, 0.2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
NodePtr UnaryFactory<Convert>::create(NodePtr parent_node) const {
|
||||||
|
return std::make_shared<Convert>(parent_node, element::f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename UnaryT>
|
||||||
|
UnaryFactoryPtr CreateUnaryFactory(const std::string& type_name) {
|
||||||
|
return std::make_shared<UnaryFactory<UnaryT>>(type_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class IPassFactory {
|
||||||
|
public:
|
||||||
|
IPassFactory(const std::string& type_name) : type_name_(type_name) {}
|
||||||
|
virtual ~IPassFactory() = default;
|
||||||
|
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||||
|
const std::string& getTypeName() const {
|
||||||
|
return type_name_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string type_name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Shared-ownership handle to a pass factory.
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||||
|
|
||||||
|
template <typename PassT>
|
||||||
|
class PassFactory : public IPassFactory {
|
||||||
|
public:
|
||||||
|
PassFactory(const std::string& type_name) : IPassFactory(type_name) {}
|
||||||
|
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||||
|
pass_manager.register_pass<PassT>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Creates a PassFactory for ov::intel_gna::pass::<pass_name>; the stringified pass name becomes the factory type name.
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::intel_gna::pass::pass_name>>(#pass_name)
|
||||||
|
|
||||||
|
#undef CREATE_UNARY_FACTORY
|
||||||
|
// Creates a unary factory for <type_name> via CreateUnaryFactory; the stringified type becomes the factory type name.
#define CREATE_UNARY_FACTORY(type_name) CreateUnaryFactory<type_name>(#type_name)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Owning pointer to a float array.
// NOTE(review): not referenced in the visible part of this file — confirm whether it is still needed.
using FloatPtr = std::unique_ptr<float[]>;
|
||||||
|
|
||||||
|
using CreateGraphF = std::function<std::shared_ptr<ov::Model>(UnaryFactoryPtr unary_factory,
|
||||||
|
size_t num_unary_ops,
|
||||||
|
const Shape& input_shape,
|
||||||
|
element::Type input_type)>;
|
||||||
|
|
||||||
|
using TestParams = std::tuple<UnaryFactoryPtr,
|
||||||
|
PassFactoryPtr,
|
||||||
|
size_t, /* num_unary_ops */
|
||||||
|
CreateGraphF, /* model_factory */
|
||||||
|
CreateGraphF, /* reference_model_factory */
|
||||||
|
Shape, /* input shape */
|
||||||
|
element::Type>; /* input type */
|
||||||
|
|
||||||
|
/**
 * @brief Parameterized fixture for the gather sinking tests over unary operations.
 *
 * Combines the TestParams parameter interface with TransformationTestsF.
 */
class GatherSinkingUnaryTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
public:
    /**
     * @brief Builds the name of one parameterized test instance.
     *
     * The name lists the unary factory, the number of unary operations, the input shape,
     * the pass factory and the input type, each as a "key=value" segment separated by "/".
     *
     * @param obj gtest descriptor holding the parameters of the instance
     * @return std::string the composed test name
     */
    static std::string get_test_name(const testing::TestParamInfo<TestParams>& obj) {
        // Only the fields that appear in the name are read; the two graph factories
        // (tuple elements 3 and 4) are not part of it and are deliberately not copied.
        const UnaryFactoryPtr& unary_factory = std::get<0>(obj.param);
        const PassFactoryPtr& pass_factory = std::get<1>(obj.param);
        const size_t num_unary_ops = std::get<2>(obj.param);
        const Shape& input_shape = std::get<5>(obj.param);
        const element::Type& input_type = std::get<6>(obj.param);

        std::ostringstream test_name;
        // each field is emitted exactly once
        test_name << "unaryFactory=" << unary_factory->getTypeName() << "/";
        test_name << "numUnaryOps=" << num_unary_ops << "/";
        test_name << "inputShape=" << to_string(input_shape) << "/";
        test_name << "passFactory=" << pass_factory->getTypeName() << "/";
        test_name << "inputType=" << input_type;

        return test_name.str();
    }
};
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/**
 * @brief Returns the friendly name of the node feeding input 0 of a model result.
 *
 * @param model model whose results are inspected
 * @param index position of the result in model->get_results(); not range-checked
 * @return std::string friendly name of the producing node
 */
std::string GetFinalNodeName(std::shared_ptr<ov::Model> model, int index = 0) {
    const auto& results = model->get_results();
    const ov::Node* producer = results[index]->get_input_node_ptr(0);
    return producer->get_friendly_name();
}
|
||||||
|
|
||||||
|
/**
 * @brief Builds the model Parameter -> Gather -> num_unary_ops unary operations.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model whose output is the last node of the chain
 */
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
                                                         size_t num_unary_ops,
                                                         const Shape& input_shape,
                                                         element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);

    // the gather sits directly on the parameter; the unary chain is stacked on top of it
    NodePtr chain_tail = MakeGather(parameter);
    for (size_t remaining = num_unary_ops; remaining > 0; --remaining) {
        chain_tail = unary_factory->create(chain_tail);
    }

    return std::make_shared<ov::Model>(chain_tail, ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
/**
 * @brief Builds the model Parameter -> num_unary_ops unary operations -> Gather.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model whose output is the Gather node
 */
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
                                                        size_t num_unary_ops,
                                                        const Shape& input_shape,
                                                        element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);

    // the unary chain starts at the parameter; the gather is appended after its last node
    NodePtr chain_tail = parameter;
    for (size_t remaining = num_unary_ops; remaining > 0; --remaining) {
        chain_tail = unary_factory->create(chain_tail);
    }

    auto trailing_gather = MakeGather(chain_tail);
    return std::make_shared<ov::Model>(trailing_gather, ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
/**
 * @brief Appends a Reshape that flattens output 0 of parent_node into a one-dimensional tensor.
 *
 * The target shape is a u64 constant with a single value: the product of all dimensions of the
 * parent output shape. The Reshape is created with `false` as its third constructor argument.
 *
 * @param parent_node node whose output 0 is reshaped; its output shape must be static
 * @return NodePtr the created Reshape node
 */
static NodePtr CreateReshape(NodePtr parent_node) {
    const Shape& parent_shape = parent_node->get_output_shape(0);

    size_t element_count = 1;
    for (const size_t dimension : parent_shape) {
        element_count *= dimension;
    }

    auto target_shape = std::make_shared<Constant>(element::u64, Shape{1}, Shape{element_count});
    return std::make_shared<Reshape>(parent_node, target_shape, false);
}
|
||||||
|
|
||||||
|
namespace mult_consumers_last_node {
|
||||||
|
namespace with_reshape {
|
||||||
|
|
||||||
|
/**
 * @brief Builds Parameter -> unary chain -> Gather, with the Gather consumed by two Reshape nodes.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model with the two Reshape nodes as outputs
 */
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
                                                        size_t num_unary_ops,
                                                        const Shape& input_shape,
                                                        element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);

    NodePtr chain_tail = parameter;
    size_t created_ops = 0;
    while (created_ops < num_unary_ops) {
        chain_tail = unary_factory->create(chain_tail);
        ++created_ops;
    }

    auto shared_gather = MakeGather(chain_tail);

    // both reshapes read the same gather output
    auto first_reshape = CreateReshape(shared_gather);
    auto second_reshape = CreateReshape(shared_gather);

    return std::make_shared<ov::Model>(ov::OutputVector{first_reshape, second_reshape},
                                       ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
/**
 * @brief Builds Parameter -> Gather -> unary chain, with the last chain node consumed by two Reshape nodes.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model with the two Reshape nodes as outputs
 */
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
                                                         size_t num_unary_ops,
                                                         const Shape& input_shape,
                                                         element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);

    NodePtr chain_tail = MakeGather(parameter);
    size_t created_ops = 0;
    while (created_ops < num_unary_ops) {
        chain_tail = unary_factory->create(chain_tail);
        ++created_ops;
    }

    // both reshapes read the output of the last chain node
    auto first_reshape = CreateReshape(chain_tail);
    auto second_reshape = CreateReshape(chain_tail);

    return std::make_shared<ov::Model>(ov::OutputVector{first_reshape, second_reshape},
                                       ov::ParameterVector{parameter});
}
|
||||||
|
} // namespace with_reshape
|
||||||
|
|
||||||
|
namespace with_eltwise {
|
||||||
|
|
||||||
|
/**
 * @brief Builds Parameter -> unary chain, followed by two branches: Sinh -> Gather and Cosh -> Gather.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model with the two Gather nodes as outputs
 */
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
                                                        size_t num_unary_ops,
                                                        const Shape& input_shape,
                                                        element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);

    NodePtr chain_tail = parameter;
    for (size_t remaining = num_unary_ops; remaining > 0; --remaining) {
        chain_tail = unary_factory->create(chain_tail);
    }

    // first branch: Sinh followed by its own gather
    auto sinh_branch = std::make_shared<Sinh>(chain_tail);
    auto sinh_gather = MakeGather(sinh_branch);

    // second branch: Cosh followed by its own gather
    auto cosh_branch = std::make_shared<Cosh>(chain_tail);
    auto cosh_gather = MakeGather(cosh_branch);

    return std::make_shared<ov::Model>(ov::OutputVector{sinh_gather, cosh_gather}, ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
/**
 * @brief Builds Parameter -> Gather -> unary chain, with the last chain node consumed by Sinh and Cosh.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model with the Sinh and Cosh nodes as outputs
 */
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
                                                         size_t num_unary_ops,
                                                         const Shape& input_shape,
                                                         element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);

    NodePtr chain_tail = MakeGather(parameter);
    for (size_t remaining = num_unary_ops; remaining > 0; --remaining) {
        chain_tail = unary_factory->create(chain_tail);
    }

    // two consumers of the same chain output
    auto sinh_consumer = std::make_shared<Sinh>(chain_tail);
    auto cosh_consumer = std::make_shared<Cosh>(chain_tail);

    return std::make_shared<ov::Model>(ov::OutputVector{sinh_consumer, cosh_consumer},
                                       ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
} // namespace with_eltwise
|
||||||
|
} // namespace mult_consumers_last_node
|
||||||
|
|
||||||
|
namespace mult_consumers_first_node {
|
||||||
|
namespace backward {
|
||||||
|
|
||||||
|
/**
 * @brief Builds a unary chain in which every chain node also feeds a Cosh output, ending with a Gather.
 *
 * The model outputs are the Cosh nodes in chain order followed by the Gather created on the
 * last chain node.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> the built model
 */
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
                                          size_t num_unary_ops,
                                          const Shape& input_shape,
                                          element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);
    ov::OutputVector model_outputs;

    NodePtr chain_tail = parameter;
    for (size_t remaining = num_unary_ops; remaining > 0; --remaining) {
        chain_tail = unary_factory->create(chain_tail);
        // every chain node gets an extra consumer that is exposed as a model output
        model_outputs.push_back(std::make_shared<Cosh>(chain_tail));
    }

    model_outputs.push_back(MakeGather(chain_tail));

    return std::make_shared<ov::Model>(model_outputs, ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
} // namespace backward
|
||||||
|
|
||||||
|
namespace backward_mult_transposes {
|
||||||
|
|
||||||
|
/**
 * @brief Builds Parameter -> unary chain, followed by two branches: Gather -> Tanh and Gather -> Tanh.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model with the two Tanh nodes as outputs
 */
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
                                          size_t num_unary_ops,
                                          const Shape& input_shape,
                                          element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);

    NodePtr chain_tail = parameter;
    size_t created_ops = 0;
    while (created_ops < num_unary_ops) {
        chain_tail = unary_factory->create(chain_tail);
        ++created_ops;
    }

    // first branch: its own gather on the chain output, then Tanh
    auto first_gather = MakeGather(chain_tail);
    auto first_tanh = std::make_shared<Tanh>(first_gather);

    // second branch: another gather on the same chain output, then Tanh
    auto second_gather = MakeGather(chain_tail);
    auto second_tanh = std::make_shared<Tanh>(second_gather);

    return std::make_shared<ov::Model>(ov::OutputVector{first_tanh, second_tanh}, ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
/**
 * @brief Builds Parameter -> Gather -> unary chain, with the last chain node consumed by two Tanh nodes.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model with the two Tanh nodes as outputs
 */
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
                                                   size_t num_unary_ops,
                                                   const Shape& input_shape,
                                                   element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);

    // a single gather directly on the parameter
    NodePtr chain_tail = MakeGather(parameter);
    size_t created_ops = 0;
    while (created_ops < num_unary_ops) {
        chain_tail = unary_factory->create(chain_tail);
        ++created_ops;
    }

    auto first_tanh = std::make_shared<Tanh>(chain_tail);
    auto second_tanh = std::make_shared<Tanh>(chain_tail);

    return std::make_shared<ov::Model>(ov::OutputVector{first_tanh, second_tanh}, ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
} // namespace backward_mult_transposes
|
||||||
|
|
||||||
|
namespace forward {
|
||||||
|
|
||||||
|
/**
 * @brief Builds Parameter -> Sinh -> Gather, where the Gather feeds both a Reshape and a unary chain.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model whose outputs are the last chain node and the Reshape
 */
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
                                          size_t num_unary_ops,
                                          const Shape& input_shape,
                                          element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);
    auto sinh_node = std::make_shared<Sinh>(parameter);
    auto shared_gather = MakeGather(sinh_node);

    // second consumer of the gather, created before the unary chain
    auto reshape_consumer = CreateReshape(shared_gather);

    NodePtr chain_tail = shared_gather;
    for (size_t remaining = num_unary_ops; remaining > 0; --remaining) {
        chain_tail = unary_factory->create(chain_tail);
    }

    return std::make_shared<ov::Model>(ov::OutputVector{chain_tail, reshape_consumer},
                                       ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
/**
 * @brief Builds Parameter -> Sinh with two branches: Gather -> Reshape and unary chain -> Gather.
 *
 * @param unary_factory factory creating each unary operation of the chain
 * @param num_unary_ops number of chained unary operations
 * @param input_shape shape of the model parameter
 * @param input_type element type of the model parameter
 * @return std::shared_ptr<ov::Model> model whose outputs are the chain's Gather and the Reshape
 */
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
                                                   size_t num_unary_ops,
                                                   const Shape& input_shape,
                                                   element::Type input_type) {
    auto parameter = std::make_shared<Parameter>(input_type, input_shape);
    auto sinh_node = std::make_shared<Sinh>(parameter);

    // branch kept in front of the reshape
    auto reshape_gather = MakeGather(sinh_node);
    auto reshape_consumer = CreateReshape(reshape_gather);

    // the unary chain reads Sinh directly and gets its own gather at the end
    NodePtr chain_tail = sinh_node;
    for (size_t remaining = num_unary_ops; remaining > 0; --remaining) {
        chain_tail = unary_factory->create(chain_tail);
    }
    auto chain_gather = MakeGather(chain_tail);

    return std::make_shared<ov::Model>(ov::OutputVector{chain_gather, reshape_consumer},
                                       ov::ParameterVector{parameter});
}
|
||||||
|
|
||||||
|
} // namespace forward
|
||||||
|
} // namespace mult_consumers_first_node
|
||||||
|
|
||||||
|
std::vector<UnaryFactoryPtr> unary_factories = {
|
||||||
|
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
|
||||||
|
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
|
||||||
|
CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
|
||||||
|
CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
|
||||||
|
CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
|
||||||
|
CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
|
||||||
|
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
|
||||||
|
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
|
||||||
|
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
|
||||||
|
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
|
||||||
|
|
||||||
|
// Lengths of the unary operation chain used by the parameterized suites: a single op and ten ops.
std::vector<size_t> unary_operations_numbers = {1, 10};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Builds the model under test and the reference model from the test parameters and registers
// the pass under test in the fixture's pass manager. The fixture members `model`, `model_ref`
// and `manager` come from TransformationTestsF.
TEST_P(GatherSinkingUnaryTestFixture, CompareFunctions) {
    const TestParams& params = this->GetParam();
    const UnaryFactoryPtr& unary_factory = std::get<0>(params);
    const PassFactoryPtr& pass_factory = std::get<1>(params);
    const size_t num_unary_ops = std::get<2>(params);
    const CreateGraphF& model_factory = std::get<3>(params);
    const CreateGraphF& reference_model_factory = std::get<4>(params);
    const Shape& input_shape = std::get<5>(params);
    const element::Type& input_type = std::get<6>(params);

    model = model_factory(unary_factory, num_unary_ops, input_shape, input_type);
    model_ref = reference_model_factory(unary_factory, num_unary_ops, input_shape, input_type);
    pass_factory->registerPass(manager);
}
|
||||||
|
|
||||||
|
// Forward pass, single consumer: the model has the Gather before the unary chain,
// the reference has it after the chain.
INSTANTIATE_TEST_SUITE_P(GatherSinkingUnaryForwardTestSuite,
                         GatherSinkingUnaryTestFixture,
                         ::testing::Combine(::testing::ValuesIn(unary_factories),
                                            ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryForward)),
                                            ::testing::ValuesIn(unary_operations_numbers),
                                            ::testing::Values(CreateFunctionTransposeBefore),
                                            ::testing::Values(CreateFunctionTransposeAfter),
                                            ::testing::Values(Shape{1, 96, 55, 55}),
                                            ::testing::Values(element::f32)),
                         GatherSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
|
// Backward pass, single consumer: the model has the Gather after the unary chain,
// the reference has it before the chain.
INSTANTIATE_TEST_SUITE_P(GatherSinkingUnaryBackwardTestSuite,
                         GatherSinkingUnaryTestFixture,
                         ::testing::Combine(::testing::ValuesIn(unary_factories),
                                            ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryBackward)),
                                            ::testing::ValuesIn(unary_operations_numbers),
                                            ::testing::Values(CreateFunctionTransposeAfter),
                                            ::testing::Values(CreateFunctionTransposeBefore),
                                            ::testing::Values(Shape{1, 96, 55, 55}),
                                            ::testing::Values(element::f32)),
                         GatherSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
|
// Forward pass where the last node is consumed by two Reshape nodes
// (factories from mult_consumers_last_node::with_reshape).
INSTANTIATE_TEST_SUITE_P(
    GatherSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
    GatherSinkingUnaryTestFixture,
    ::testing::Combine(::testing::ValuesIn(unary_factories),
                       ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryForward)),
                       ::testing::ValuesIn(unary_operations_numbers),
                       ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
                       ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
                       ::testing::Values(Shape{1, 96, 55, 55}),
                       ::testing::Values(element::f32)),
    GatherSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
|
// Backward pass where the last node is consumed by two Reshape nodes
// (factories from mult_consumers_last_node::with_reshape, model and reference swapped
// relative to the forward suite).
INSTANTIATE_TEST_SUITE_P(
    GatherSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
    GatherSinkingUnaryTestFixture,
    ::testing::Combine(::testing::ValuesIn(unary_factories),
                       ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryBackward)),
                       ::testing::ValuesIn(unary_operations_numbers),
                       ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
                       ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
                       ::testing::Values(Shape{1, 96, 55, 55}),
                       ::testing::Values(element::f32)),
    GatherSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
|
// Forward pass where the last chain node is consumed by Sinh and Cosh
// (factories from mult_consumers_last_node::with_eltwise).
INSTANTIATE_TEST_SUITE_P(
    GatherSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
    GatherSinkingUnaryTestFixture,
    ::testing::Combine(::testing::ValuesIn(unary_factories),
                       ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryForward)),
                       ::testing::ValuesIn(unary_operations_numbers),
                       ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
                       ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
                       ::testing::Values(Shape{1, 96, 55, 55}),
                       ::testing::Values(element::f32)),
    GatherSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
|
// Forward pass where the Gather in front of the chain has an additional Reshape consumer
// (factories from mult_consumers_first_node::forward).
INSTANTIATE_TEST_SUITE_P(
    GatherSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
    GatherSinkingUnaryTestFixture,
    ::testing::Combine(::testing::ValuesIn(unary_factories),
                       ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryForward)),
                       ::testing::ValuesIn(unary_operations_numbers),
                       ::testing::Values(mult_consumers_first_node::forward::CreateFunction),
                       ::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction),
                       ::testing::Values(Shape{1, 96, 55, 55}),
                       ::testing::Values(element::f32)),
    GatherSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
|
// Backward pass on a chain whose nodes have extra Cosh consumers. The same factory is used for
// the model and for the reference, i.e. the graph is expected to be left unchanged by the pass.
INSTANTIATE_TEST_SUITE_P(GatherSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
                         GatherSinkingUnaryTestFixture,
                         ::testing::Combine(::testing::ValuesIn(unary_factories),
                                            ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryBackward)),
                                            ::testing::ValuesIn(unary_operations_numbers),
                                            ::testing::Values(mult_consumers_first_node::backward::CreateFunction),
                                            ::testing::Values(mult_consumers_first_node::backward::CreateFunction),
                                            ::testing::Values(Shape{1, 96, 55, 55}),
                                            ::testing::Values(element::f32)),
                         GatherSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
|
// Backward pass where the chain output is consumed by two separate Gather nodes
// (factories from mult_consumers_first_node::backward_mult_transposes).
INSTANTIATE_TEST_SUITE_P(
    GatherSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
    GatherSinkingUnaryTestFixture,
    ::testing::Combine(::testing::ValuesIn(unary_factories),
                       ::testing::Values(CREATE_PASS_FACTORY(GatherSinkingUnaryBackward)),
                       ::testing::ValuesIn(unary_operations_numbers),
                       ::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateFunction),
                       ::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction),
                       ::testing::Values(Shape{1, 96, 55, 55}),
                       ::testing::Values(element::f32)),
    GatherSinkingUnaryTestFixture::get_test_name);
|
Loading…
Reference in New Issue
Block a user