This commit is contained in:
Evgeny Kotov 2023-02-20 18:33:03 +01:00
parent cbd56c3ed9
commit 48f20927af
6 changed files with 1737 additions and 14 deletions

View File

@ -0,0 +1,101 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/gather_sinking_binary.hpp"
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include <openvino/cc/ngraph/itt.hpp>
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/utils/gather_sinking_utils.hpp"
#include "transformations/rt_info/gather_sinking_attr.hpp"
using namespace ov;
using namespace ov::opset9;
using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace gather_sinking;
using namespace ov::intel_gna::pass;
using namespace ov::intel_gna::rt_info;
GatherSinkingBinaryForward::GatherSinkingBinaryForward() {
    MATCHER_SCOPE(GatherSinkingBinaryForward);

    // Accept only Gathers whose indices and axis constants are scalars or 1D vectors.
    auto gather_constants_rank_ok = [](const GatherInputsInfo& inputs_info) -> bool {
        return constant_has_rank_not_more_than(inputs_info.axis_const, 1) &&
               constant_has_rank_not_more_than(inputs_info.indices_const, 1);
    };

    // Match any binary elementwise arithmetic node fed by such a Gather.
    auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(
        [gather_constants_rank_ok](const Output<Node>& output) -> bool {
            return IfNodeHasGatherInputs(output, gather_constants_rank_ok);
        });

    matcher_pass_callback callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto& main_node_output = pattern_to_output.at(main_node_label);
        auto main_node = main_node_output.get_node_shared_ptr();

        const GatherInputsInfo gather_input_info = GetFirstGatherInput(main_node);

        // Detach the Gather from the matched input (inserting inverted Gathers
        // on the remaining inputs) ...
        sink_forward::UpdateInputGather(main_node, gather_input_info);

        // ... then re-create an equivalent Gather on every output and mark it
        // as a candidate for further sinking.
        for (auto& sunk_gather : sink_forward::InsertOutputGather(main_node, gather_input_info)) {
            register_new_node(sunk_gather);
            gather_sinking::UpdateForwardGatherSinkingAbility(sunk_gather);
        }
        return true;
    };

    auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
    register_matcher(m, callback);
}
// Moves a sinkable Gather from the outputs of a binary elementwise arithmetic
// node to its inputs.
GatherSinkingBinaryBackward::GatherSinkingBinaryBackward() {
    MATCHER_SCOPE(GatherSinkingBinaryBackward);
    // Match a binary elementwise node whose outputs all lead to equivalent Gather nodes.
    auto main_node_label =
        wrap_type<op::util::BinaryElementwiseArithmetic>([](const Output<Node>& output) -> bool {
            return has_static_rank()(output) && HasSameOutputGatherNodes(output);
        });
    // Indices and axis constants must be scalars or 1D vectors.
    auto indices_const_label = wrap_type<Constant>(rank_not_more_than(1));
    auto axes_const_label = wrap_type<Constant>(rank_not_more_than(1));
    // The Gather itself must be marked as a sinking candidate (gather_sinking_attr).
    auto gather_label =
        wrap_type<Gather>({main_node_label, indices_const_label, axes_const_label}, [](const Output<Node>& output) -> bool {
            return has_static_rank()(output) && is_gather_sinking_node(output);
        });
    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto indices_const = as_type_ptr<Constant>(pattern_to_output.at(indices_const_label).get_node_shared_ptr());
        auto axes_const = as_type_ptr<Constant>(pattern_to_output.at(axes_const_label).get_node_shared_ptr());
        auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
        auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
        // Re-create the Gather in front of each input of the main node.
        for (auto& new_node : sink_backward::InsertGatherBeforeNode(main_node, indices_const, axes_const, gather)) {
            register_new_node(new_node);
        }
        // Remove the now-redundant output Gather nodes.
        RemoveSingleOutputConsumers(main_node);
        SwapNames(gather, main_node);
        return true;
    };
    auto m = std::make_shared<Matcher>(gather_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

View File

@ -0,0 +1,29 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace intel_gna {
namespace pass {
/**
 * @brief Forward gather sinking: moves a Gather layer from an input of a
 * binary elementwise arithmetic operation to all of its outputs.
 */
class GatherSinkingBinaryForward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("GatherSinkingBinaryForward", "0");
    GatherSinkingBinaryForward();
};
/**
 * @brief Backward gather sinking: moves Gather layers from the outputs of a
 * binary elementwise arithmetic operation to its inputs.
 */
class GatherSinkingBinaryBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("GatherSinkingBinaryBackward", "0");
    GatherSinkingBinaryBackward();
};
} // namespace pass
} // namespace intel_gna
} // namespace ov

View File

@ -183,9 +183,10 @@ GatherSinkingUnaryBackwardMultiConsumers::GatherSinkingUnaryBackwardMultiConsume
const auto& pattern_to_output = m.get_pattern_value_map();
auto indices_const = as_type_ptr<Constant>(pattern_to_output.at(indices_const_label).get_node_shared_ptr());
auto axes_const = as_type_ptr<Constant>(pattern_to_output.at(axes_const_label).get_node_shared_ptr());
auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
for (auto& new_node : sink_backward::InsertGatherBeforeNode(unary, indices_const, axes_const)) {
for (auto& new_node : sink_backward::InsertGatherBeforeNode(unary, indices_const, axes_const, gather)) {
register_new_node(new_node);
}

View File

@ -33,14 +33,14 @@ GatherInputsInfo GetFirstGatherInput(NodePtr node) {
auto indices_const_node = as_type_ptr<Constant>(gather_node->input_value(1).get_node_shared_ptr());
if (!indices_const_node)
continue;
auto axes_const_node = as_type_ptr<Constant>(gather_node->input_value(2).get_node_shared_ptr());
if (!axes_const_node)
auto axis_const_node = as_type_ptr<Constant>(gather_node->input_value(2).get_node_shared_ptr());
if (!axis_const_node)
continue;
{
GatherInputsInfo input_info;
input_info.gather = gather_node;
input_info.indices_const = indices_const_node;
input_info.axes_const = axes_const_node;
input_info.axis_const = axis_const_node;
input_info.input_idx = input_idx;
return input_info;
}
@ -94,32 +94,237 @@ Output<Node> FixInputNodeRank(Output<Node> input_node, Rank::value_type required
return InsertUnsqueeze(input_node, required_rank - output_rank)->output(0);
}
/*
Converts gather indices to positive form.
A negative index is offset by the number of indices: this assumes the gather
selects along the full axis extent (i.e. the indices form a permutation, as
required by ReverseGatherIndexes) — TODO confirm for non-permutation gathers,
where the axis dimension size should be used instead.
*/
std::vector<int64_t> NormalizeGatherIndices(const std::vector<int64_t>& indices) {
    const int64_t axis_size = static_cast<int64_t>(indices.size());
    std::vector<int64_t> normalized;
    normalized.reserve(indices.size());
    // size_t iteration avoids the signed/unsigned comparison of the original loop.
    for (const int64_t index : indices) {
        normalized.push_back(index < 0 ? index + axis_size : index);
    }
    return normalized;
}
/*
Reads the indices of a Gather from its Constant and returns them in positive form.
*/
std::vector<int64_t> GetNormalizedGatherIndices(const std::shared_ptr<Constant>& indices) {
    const auto raw_indices = indices->cast_vector<int64_t>();
    return NormalizeGatherIndices(raw_indices);
}
/*
Converts a gather axis to negative form.
*/
int64_t NormalizeNegativeGatherAxis(int64_t axis, ov::Rank::value_type gather_input_rank) {
    return (axis < 0) ? axis : axis - gather_input_rank;
}
/*
Reads the gather axis from a Constant and returns it in negative form.
*/
int64_t GetNormalizedNegativeGatherAxis(const std::shared_ptr<Constant>& axis, ov::Rank::value_type gather_input_rank) {
    const int64_t axis_value = axis->cast_vector<int64_t>()[0];
    return NormalizeNegativeGatherAxis(axis_value, gather_input_rank);
}
// Converts an axis to positive form given the tensor rank.
int64_t ConvertAxisToPositive(int64_t axis, ov::Rank::value_type rank) {
    return (axis < 0) ? axis + rank : axis;
}
/*
Builds the inverse permutation of the given gather indices, so that applying
the reversed gather after the initial one (or vice versa) is an identity.
Works only with positive-form indices (no negative indices).
*/
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
    std::vector<int64_t> reversed(indexes.size());
    int64_t position = 0;
    for (const int64_t target : indexes) {
        // at() keeps the original bounds-checked behavior for malformed input.
        reversed.at(target) = position++;
    }
    return reversed;
}
size_t GetDimByAxis(const Shape& shape, int64_t axis) {
if (axis < 0)
axis += shape.size();
return shape[axis];
}
/*
Broadcasts @arg shape to @arg rank by prepending dimensions of size 1
(numpy-style alignment). Returns the shape unchanged when its rank already
meets or exceeds @arg rank.
*/
Shape Broadcast(const Shape& shape, ov::Rank::value_type rank) {
    // Compute the delta in signed arithmetic: the original int64/size_t mixed
    // subtraction wrapped to a huge unsigned value (then narrowed to int) when
    // rank < shape.size().
    const auto rank_delta = rank - static_cast<ov::Rank::value_type>(shape.size());
    if (rank_delta <= 0)
        return shape;
    Shape broadcasted(rank);
    std::fill_n(broadcasted.begin(), rank_delta, 1);  // leading unit dims
    std::copy(shape.begin(), shape.end(), broadcasted.begin() + rank_delta);
    return broadcasted;
}
} // namespace
// Exchanges the tensor name sets of the two outputs.
void SwapOutputNames(Output<Node> output1, Output<Node> output2) {
    const auto names_of_first = output1.get_names();
    output1.set_names(output2.get_names());
    output2.set_names(names_of_first);
}
// Exchanges the friendly names of the two nodes.
void SwapFriendlyNames(NodePtr node1, NodePtr node2) {
    const std::string name_of_first = node1->get_friendly_name();
    node1->set_friendly_name(node2->get_friendly_name());
    node2->set_friendly_name(name_of_first);
}
// Exchanges both the friendly names of the two nodes and the tensor names of their first outputs.
void SwapNames(NodePtr node1, NodePtr node2) {
    SwapOutputNames(node1->output(0), node2->output(0));
    SwapFriendlyNames(node1, node2);
}
namespace sink_forward {
/** @brief
 * Inserts an inverted Gather layer on all @arg main_node inputs except the one described by the
 * GatherInputsInfo argument. Works only with 1D indices.
 * It is simpler to work with a negative gather axis since it does not depend on shape broadcasting,
 * so the gather axis is converted to negative form.
 * No Gather layer is added when input_node_shape[axis] == 1, since it would be useless and cause an
 * invalid result.
 * Input nodes can have different shapes whose ranks may be smaller or larger. To handle this, the
 * maximum input shape rank is found and all input shapes are broadcast to it.
 */
void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_info) {
    // Nothing to do without a detected input Gather; dynamic ranks cannot be handled.
    if (gather_input_info.isEmpty() || HasDynamicRankInput(main_node))
        return;
    // Negative axis form is rank-independent, so it remains valid for sibling
    // inputs whose shapes are broadcast to a different rank.
    const int64_t gather_negative_axis = GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
        gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
    const std::vector<int64_t> gather_indices = GetNormalizedGatherIndices(gather_input_info.indices_const);
    // Inverse permutation: applying it to the other inputs cancels the original Gather.
    const std::vector<int64_t> reversed_gather_indices = ReverseGatherIndexes(gather_indices);
    const auto indices_element_type = gather_input_info.indices_const->get_element_type();
    const auto axis_element_type = gather_input_info.axis_const->get_element_type();
    const auto max_input_rank = GetMaxInputRank(main_node);
    if (max_input_rank < 0)
        return;
    for (size_t i = 0; i < main_node->get_input_size(); ++i) {
        auto input_node = main_node->input_value(i);
        if (i == gather_input_info.input_idx) {
            // This input already has the Gather: bypass it by reconnecting to its parent.
            auto gather_parent = input_node.get_node()->input_value(0);
            main_node->input(i).replace_source_output(gather_parent);
        } else {
            // Skip axis dims of size 1: gathering along them is useless and causes an invalid result.
            const Shape broadcasted_input_shape = Broadcast(input_node.get_shape(), max_input_rank);
            if (GetDimByAxis(broadcasted_input_shape, gather_negative_axis) == 1)
                continue;
            auto new_indices_const = std::make_shared<Constant>(indices_element_type,
                Shape{reversed_gather_indices.size()},
                reversed_gather_indices);
            // Gather needs a positive axis relative to this particular input's rank.
            const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
                input_node.get_partial_shape().rank().get_length());
            auto new_axis_const = std::make_shared<Constant>(axis_element_type,
                Shape{1},
                gather_positive_axis);
            auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);
            main_node->input(i).replace_source_output(new_gather->output(0));
            copy_runtime_info(input_node.get_node_shared_ptr(), {new_gather, new_indices_const, new_axis_const});
        }
    }
}
/*
Inserts a Gather with the indices from @arg gather_input_info after each output
of @arg main_node and reconnects all former consumers to the new Gather nodes.
Returns the newly created Gather nodes.
*/
NodeVector InsertOutputGather(NodePtr main_node, const GatherInputsInfo& gather_input_info) {
    if (gather_input_info.isEmpty())
        return {};
    // Axis is taken in negative form relative to the original Gather input rank,
    // then converted back to positive per output below.
    const int64_t gather_negative_axis = GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
        gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
    const auto axis_element_type = gather_input_info.axis_const->get_element_type();
    NodeVector new_nodes;
    for (size_t i = 0; i < main_node->get_output_size(); ++i) {
        // Capture the consumers before rewiring so they can be repointed to the new Gather.
        auto main_node_consumers = main_node->output(i).get_target_inputs();
        auto new_indices_const = gather_input_info.indices_const->clone_with_new_inputs({});
        const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
            main_node->output(i).get_partial_shape().rank().get_length());
        auto new_axis_const = std::make_shared<Constant>(axis_element_type,
            Shape{1},
            gather_positive_axis);
        auto new_gather = std::make_shared<Gather>(main_node->output(i), new_indices_const, new_axis_const);
        for (auto& consumer : main_node_consumers) {
            consumer.replace_source_output(new_gather);
        }
        copy_runtime_info(main_node, {new_gather, new_indices_const, new_axis_const});
        SwapOutputNames(main_node->output(i), new_gather->output(0));
        // Friendly names: with multiple outputs, suffix with the output index;
        // with a single output, the new Gather takes over the main node's name.
        if (main_node->get_output_size() > 1)
            new_gather->set_friendly_name(main_node->get_friendly_name() + "." + std::to_string(i));
        else
            SwapFriendlyNames(new_gather, main_node);
        new_nodes.push_back(new_gather);
    }
    return new_nodes;
}
} // namespace sink_forward
namespace sink_backward {
NodeVector InsertGatherBeforeNode(NodePtr main_node,
const std::shared_ptr<Constant>& indices_const,
const std::shared_ptr<Constant>& axes_const) {
const std::shared_ptr<Constant>& axis_const,
const std::shared_ptr<Gather>& gather_node) {
if (HasDynamicRankInput(main_node))
return {};
NodeVector new_nodes;
const int64_t gather_negative_axis = GetNormalizedNegativeGatherAxis(axis_const,
gather_node->get_input_partial_shape(0).rank().get_length());
const auto axis_element_type = axis_const->get_element_type();
const auto max_input_rank = GetMaxInputRank(main_node);
if (max_input_rank < 0)
return {};
NodeVector new_nodes;
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
auto input_node = FixInputNodeRank(main_node->input_value(i), max_input_rank);
auto input_node = main_node->input_value(i);
const Shape broadcasted_input_shape = Broadcast(input_node.get_shape(), max_input_rank);
if (GetDimByAxis(broadcasted_input_shape, gather_negative_axis) == 1)
continue;
auto new_indices_const = indices_const->clone_with_new_inputs({});
auto new_axes_const = axes_const->clone_with_new_inputs({});
auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axes_const);
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
input_node.get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{1},
gather_positive_axis);
auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);
main_node->input(i).replace_source_output(new_gather->output(0));
copy_runtime_info(input_node.get_node_shared_ptr(), {new_gather, new_indices_const, new_axes_const});
copy_runtime_info(input_node.get_node_shared_ptr(), {new_gather, new_indices_const, new_axis_const});
new_nodes.push_back(new_gather);
}
@ -257,4 +462,16 @@ void RemoveSingleOutputConsumers(NodePtr node) {
}
}
// Predicate factory: returns true when the output's rank is static and does not exceed expected_rank.
std::function<bool(Output<Node>)> rank_not_more_than(const ov::Rank::value_type expected_rank) {
    return [expected_rank](Output<Node> output) -> bool {
        const Rank output_rank = output.get_partial_shape().rank();
        if (!output_rank.is_static())
            return false;
        return output_rank.get_length() <= expected_rank;
    };
}
// Returns true when the constant's output rank is static and does not exceed expected_rank.
bool constant_has_rank_not_more_than(const std::shared_ptr<Constant>& node, const ov::Rank::value_type expected_rank) {
    const Rank node_rank = node->get_output_partial_shape(0).rank();
    if (!node_rank.is_static())
        return false;
    return node_rank.get_length() <= expected_rank;
}
} // namespace gather_sinking

View File

@ -20,11 +20,11 @@ namespace gather_sinking {
struct GatherInputsInfo {
std::shared_ptr<ov::opset9::Gather> gather;
std::shared_ptr<ov::opset9::Constant> indices_const;
std::shared_ptr<ov::opset9::Constant> axes_const;
std::shared_ptr<ov::opset9::Constant> axis_const;
size_t input_idx;
bool isEmpty() const {
return !gather || !indices_const || !axes_const;
return !gather || !indices_const || !axis_const;
}
};
@ -37,7 +37,49 @@ GatherInputsInfo GetFirstGatherInput(std::shared_ptr<ov::Node>);
/**
 * @brief Checks whether @arg output's node has an input Gather operation that satisfies @arg gather_info_predicate
 */
bool IfNodeHasGatherInputs(const ov::Output<ov::Node>&);
// Returns true when @arg output's node has a Gather input and that Gather's
// description satisfies @arg gather_info_predicate.
template <typename GatherInfoPredicate>
bool IfNodeHasGatherInputs(const ov::Output<ov::Node>& output, GatherInfoPredicate gather_info_predicate) {
    const GatherInputsInfo inputs_info = GetFirstGatherInput(output.get_node_shared_ptr());
    // Short-circuit keeps the predicate from ever seeing an empty info struct.
    return !inputs_info.isEmpty() && gather_info_predicate(inputs_info);
}
/**
* @brief Swaps @args output tensor names
*/
void SwapOutputNames(ov::Output<ov::Node>, ov::Output<ov::Node>);
/**
* @brief Swaps @args friendly names
*/
void SwapFriendlyNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
/**
* @brief Swaps @args output tensor names and friendly names
*/
void SwapNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
namespace sink_forward {
/**
 * @brief Inserts a reversed Gather on each of @arg main_node's other inputs and removes the input
 * Gather specified in the @arg GatherInputsInfo argument
 */
void UpdateInputGather(std::shared_ptr<ov::Node> main_node, const GatherInputsInfo&);
/**
* @brief Removes @arg input node
*/
void RemoveInputNode(std::shared_ptr<ov::Node>, size_t input_idx);
/**
* @brief Inserts Gather on each main_node output with the order specified in @arg GatherInputsInfo
*/
ov::NodeVector InsertOutputGather(std::shared_ptr<ov::Node> main_node,
const GatherInputsInfo&);
} // namespace sink_forward
namespace sink_backward {
/**
@ -45,7 +87,8 @@ namespace sink_backward {
*/
ov::NodeVector InsertGatherBeforeNode(std::shared_ptr<ov::Node> main_node,
const std::shared_ptr<ov::opset9::Constant>& indices_const,
const std::shared_ptr<ov::opset9::Constant>& axes_const);
const std::shared_ptr<ov::opset9::Constant>& axes_const,
const std::shared_ptr<ov::opset9::Gather>& gather_node);
} // namespace sink_backward
void UpdateForwardGatherSinkingAbility(std::shared_ptr<ov::Node>);
@ -61,4 +104,11 @@ bool HasSameOutputGatherNodes(const ov::Output<ov::Node>&);
*/
void RemoveSingleOutputConsumers(std::shared_ptr<ov::Node>);
bool constant_has_rank_not_more_than(const std::shared_ptr<ov::opset9::Constant>&, const ov::Rank::value_type expected_rank);
/**
* Checks if output has rank not more than expected
*/
std::function<bool(ov::Output<ov::Node>)> rank_not_more_than(const ov::Rank::value_type expected_rank);
} // namespace gather_sinking

File diff suppressed because it is too large Load Diff