initial

This commit is contained in:
parent cbd56c3ed9
commit 48f20927af

@@ -0,0 +1,101 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/gather_sinking_binary.hpp"

#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include <openvino/cc/ngraph/itt.hpp>

#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/utils/gather_sinking_utils.hpp"
#include "transformations/rt_info/gather_sinking_attr.hpp"

using namespace ov;
using namespace ov::opset9;
using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace gather_sinking;
using namespace ov::intel_gna::pass;
using namespace ov::intel_gna::rt_info;

GatherSinkingBinaryForward::GatherSinkingBinaryForward() {
    MATCHER_SCOPE(GatherSinkingBinaryForward);

    auto if_gather_has_constants_rank_not_more_than_one = [](const GatherInputsInfo& inputs_info) -> bool {
        return constant_has_rank_not_more_than(inputs_info.axis_const, 1) &&
               constant_has_rank_not_more_than(inputs_info.indices_const, 1);
    };

    auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(
        [if_gather_has_constants_rank_not_more_than_one](const Output<Node>& output) -> bool {
            return IfNodeHasGatherInputs(output, if_gather_has_constants_rank_not_more_than_one);
        });

    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();

        auto& main_node_output = pattern_to_output.at(main_node_label);
        auto main_node = main_node_output.get_node_shared_ptr();

        GatherInputsInfo gather_input_info = GetFirstGatherInput(main_node);

        sink_forward::UpdateInputGather(main_node, gather_input_info);
        for (auto& new_node : sink_forward::InsertOutputGather(main_node, gather_input_info)) {
            register_new_node(new_node);
            gather_sinking::UpdateForwardGatherSinkingAbility(new_node);
        }

        return true;
    };

    auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

GatherSinkingBinaryBackward::GatherSinkingBinaryBackward() {
    MATCHER_SCOPE(GatherSinkingBinaryBackward);
    auto main_node_label =
        wrap_type<op::util::BinaryElementwiseArithmetic>([](const Output<Node>& output) -> bool {
            return has_static_rank()(output) && HasSameOutputGatherNodes(output);
        });

    auto indices_const_label = wrap_type<Constant>(rank_not_more_than(1));
    auto axes_const_label = wrap_type<Constant>(rank_not_more_than(1));

    auto gather_label =
        wrap_type<Gather>({main_node_label, indices_const_label, axes_const_label}, [](const Output<Node>& output) -> bool {
            return has_static_rank()(output) && is_gather_sinking_node(output);
        });

    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto indices_const = as_type_ptr<Constant>(pattern_to_output.at(indices_const_label).get_node_shared_ptr());
        auto axes_const = as_type_ptr<Constant>(pattern_to_output.at(axes_const_label).get_node_shared_ptr());
        auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
        auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();

        for (auto& new_node : sink_backward::InsertGatherBeforeNode(main_node, indices_const, axes_const, gather)) {
            register_new_node(new_node);
        }

        // remove output Gather nodes
        RemoveSingleOutputConsumers(main_node);

        SwapNames(gather, main_node);

        return true;
    };

    auto m = std::make_shared<Matcher>(gather_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
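
// Illustrative registration sketch (pass-manager wiring is assumed here, not part of this diff):
//     ov::pass::Manager manager;
//     manager.register_pass<ov::intel_gna::pass::GatherSinkingBinaryForward>();
//     manager.register_pass<ov::intel_gna::pass::GatherSinkingBinaryBackward>();
//     manager.run_passes(model);  // model is a std::shared_ptr<ov::Model>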

@@ -0,0 +1,29 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace intel_gna {
namespace pass {

class GatherSinkingBinaryForward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("GatherSinkingBinaryForward", "0");
    GatherSinkingBinaryForward();
};

class GatherSinkingBinaryBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("GatherSinkingBinaryBackward", "0");
    GatherSinkingBinaryBackward();
};

} // namespace pass
} // namespace intel_gna
} // namespace ov
@@ -183,9 +183,10 @@ GatherSinkingUnaryBackwardMultiConsumers::GatherSinkingUnaryBackwardMultiConsume
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto indices_const = as_type_ptr<Constant>(pattern_to_output.at(indices_const_label).get_node_shared_ptr());
        auto axes_const = as_type_ptr<Constant>(pattern_to_output.at(axes_const_label).get_node_shared_ptr());
+       auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
        auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();

-       for (auto& new_node : sink_backward::InsertGatherBeforeNode(unary, indices_const, axes_const)) {
+       for (auto& new_node : sink_backward::InsertGatherBeforeNode(unary, indices_const, axes_const, gather)) {
            register_new_node(new_node);
        }

@@ -33,14 +33,14 @@ GatherInputsInfo GetFirstGatherInput(NodePtr node) {
        auto indices_const_node = as_type_ptr<Constant>(gather_node->input_value(1).get_node_shared_ptr());
        if (!indices_const_node)
            continue;
-       auto axes_const_node = as_type_ptr<Constant>(gather_node->input_value(2).get_node_shared_ptr());
-       if (!axes_const_node)
+       auto axis_const_node = as_type_ptr<Constant>(gather_node->input_value(2).get_node_shared_ptr());
+       if (!axis_const_node)
            continue;
        {
            GatherInputsInfo input_info;
            input_info.gather = gather_node;
            input_info.indices_const = indices_const_node;
-           input_info.axes_const = axes_const_node;
+           input_info.axis_const = axis_const_node;
            input_info.input_idx = input_idx;
            return input_info;
        }
@@ -94,32 +94,237 @@ Output<Node> FixInputNodeRank(Output<Node> input_node, Rank::value_type required
    return InsertUnsqueeze(input_node, required_rank - output_rank)->output(0);
}

/*
 Converts gather indices to positive form
*/
std::vector<int64_t> NormalizeGatherIndices(const std::vector<int64_t>& indices) {
    std::vector<int64_t> normalized(indices.size());
    for (int i = 0; i < indices.size(); ++i) {
        int64_t index = indices[i];
        if (index < 0)
            index += indices.size();
        normalized[i] = index;
    }
    return normalized;
}
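
/*
 Worked example (illustrative; assumes the indices describe a full permutation of the gathered
 dimension, so negative values are normalized by indices.size()):
   NormalizeGatherIndices({0, -1, 1, -2}) -> {0, 3, 1, 2}
   NormalizeGatherIndices({-1, -2})       -> {1, 0}
*/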

/*
 Gets gather indices in positive form
*/
std::vector<int64_t> GetNormalizedGatherIndices(const std::shared_ptr<Constant>& indices) {
    return NormalizeGatherIndices(indices->cast_vector<int64_t>());
}

/*
 Converts axis to negative form
*/
int64_t NormalizeNegativeGatherAxis(int64_t axis, ov::Rank::value_type gather_input_rank) {
    if (axis < 0)
        return axis;
    return axis - gather_input_rank;
}
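
/*
 Worked example (illustrative): for a Gather data input of rank 3,
   NormalizeNegativeGatherAxis(1, 3)  -> -2
   NormalizeNegativeGatherAxis(-1, 3) -> -1  (already negative, returned unchanged)
*/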

/*
 Gets gather axis in negative form
*/
int64_t GetNormalizedNegativeGatherAxis(const std::shared_ptr<Constant>& axis, ov::Rank::value_type gather_input_rank) {
    return NormalizeNegativeGatherAxis(axis->cast_vector<int64_t>()[0], gather_input_rank);
}

int64_t ConvertAxisToPositive(int64_t axis, ov::Rank::value_type rank) {
    if (axis >= 0)
        return axis;
    return axis + rank;
}

/*
 Reverses gather indices in such a way that the reversed and the initial Gather cancel each
 other out when one is applied after the other.
 Works only with the positive form (no negative indices).
*/
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
    std::vector<int64_t> out(indexes.size());
    for (size_t i = 0; i < indexes.size(); i++) {
        out.at(indexes[i]) = i;
    }
    return out;
}
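
/*
 Worked example (illustrative): ReverseGatherIndexes({2, 0, 1}) -> {1, 2, 0}.
 Gathering with {2, 0, 1} and then with {1, 2, 0} (or vice versa) restores the original
 element order, which is what lets a Gather be "undone" on the other side of an
 element-wise node.
*/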

size_t GetDimByAxis(const Shape& shape, int64_t axis) {
    if (axis < 0)
        axis += shape.size();
    return shape[axis];
}

Shape Broadcast(const Shape& shape, ov::Rank::value_type rank) {
    const int rank_delta = rank - shape.size();

    if (rank_delta <= 0)
        return shape;

    Shape broadcasted(rank);
    for (int i = 0; i < rank_delta; ++i) {
        broadcasted[i] = 1;
    }
    std::copy(shape.begin(), shape.end(), broadcasted.begin() + rank_delta);

    return broadcasted;
}
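
/*
 Worked example (illustrative): Broadcast({3, 4}, 4) -> {1, 1, 3, 4}, i.e. numpy-style rank
 alignment by prepending ones; Broadcast({3, 4}, 2) returns the shape unchanged.
*/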

} // namespace

void SwapOutputNames(Output<Node> output1, Output<Node> output2) {
    const auto node2_output_names = output2.get_names();
    output2.set_names(output1.get_names());
    output1.set_names(node2_output_names);
}

void SwapFriendlyNames(NodePtr node1, NodePtr node2) {
    const std::string node2_name = node2->get_friendly_name();
    node2->set_friendly_name(node1->get_friendly_name());
    node1->set_friendly_name(node2_name);
}

void SwapNames(NodePtr node1, NodePtr node2) {
    SwapFriendlyNames(node1, node2);
    SwapOutputNames(node1->output(0), node2->output(0));
}

namespace sink_forward {
/** @brief
 * Inserts an inverted Gather layer on all @main_node inputs except the one described by the
 * GatherInputsInfo argument.
 * Works only with 1D indices.
 * It's simpler to work with a negative gather axis since it doesn't depend on shape broadcasting;
 * the gather axis is therefore converted to its negative form.
 * Doesn't add a Gather layer if input_node_shape[axis] == 1, since it would be useless and
 * would produce an invalid result.
 * Input nodes can have different shapes whose ranks may be smaller or larger. To manage this,
 * the maximum input rank is found and all input shapes are broadcast to it.
 */
void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_info) {
    if (gather_input_info.isEmpty() || HasDynamicRankInput(main_node))
        return;

    const int64_t gather_negative_axis = GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
        gather_input_info.gather->get_input_partial_shape(0).rank().get_length());

    const std::vector<int64_t> gather_indices = GetNormalizedGatherIndices(gather_input_info.indices_const);
    const std::vector<int64_t> reversed_gather_indices = ReverseGatherIndexes(gather_indices);

    const auto indices_element_type = gather_input_info.indices_const->get_element_type();
    const auto axis_element_type = gather_input_info.axis_const->get_element_type();

    const auto max_input_rank = GetMaxInputRank(main_node);
    if (max_input_rank < 0)
        return;

    for (size_t i = 0; i < main_node->get_input_size(); ++i) {
        auto input_node = main_node->input_value(i);
        if (i == gather_input_info.input_idx) {
            auto gather_parent = input_node.get_node()->input_value(0);
            main_node->input(i).replace_source_output(gather_parent);
        } else {
            const Shape broadcasted_input_shape = Broadcast(input_node.get_shape(), max_input_rank);
            if (GetDimByAxis(broadcasted_input_shape, gather_negative_axis) == 1)
                continue;

            auto new_indices_const = std::make_shared<Constant>(indices_element_type,
                                                                Shape{reversed_gather_indices.size()},
                                                                reversed_gather_indices);

            const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
                                                                       input_node.get_partial_shape().rank().get_length());
            auto new_axis_const = std::make_shared<Constant>(axis_element_type,
                                                             Shape{1},
                                                             gather_positive_axis);

            auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);

            main_node->input(i).replace_source_output(new_gather->output(0));

            copy_runtime_info(input_node.get_node_shared_ptr(), {new_gather, new_indices_const, new_axis_const});
        }
    }
}
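
/*
 Illustrative effect (example values, not from the sources): for an element-wise Add whose first
 input comes through Gather(A, indices = {2, 0, 1}, axis):
   before:                   Add(Gather(A, {2, 0, 1}), B)
   after UpdateInputGather:  Add(A, Gather(B, {1, 2, 0}))   // reversed indices on the other input
 InsertOutputGather() below then re-applies Gather(..., {2, 0, 1}) to the Add output, so the
 overall function of the graph is unchanged. Inputs whose dimension on the gather axis is 1 are
 left untouched.
*/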

NodeVector InsertOutputGather(NodePtr main_node, const GatherInputsInfo& gather_input_info) {
    if (gather_input_info.isEmpty())
        return {};

    const int64_t gather_negative_axis = GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
        gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
    const auto axis_element_type = gather_input_info.axis_const->get_element_type();

    NodeVector new_nodes;
    for (size_t i = 0; i < main_node->get_output_size(); ++i) {
        auto main_node_consumers = main_node->output(i).get_target_inputs();

        auto new_indices_const = gather_input_info.indices_const->clone_with_new_inputs({});

        const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
                                                                   main_node->output(i).get_partial_shape().rank().get_length());
        auto new_axis_const = std::make_shared<Constant>(axis_element_type,
                                                         Shape{1},
                                                         gather_positive_axis);
        auto new_gather = std::make_shared<Gather>(main_node->output(i), new_indices_const, new_axis_const);

        for (auto& consumer : main_node_consumers) {
            consumer.replace_source_output(new_gather);
        }

        copy_runtime_info(main_node, {new_gather, new_indices_const, new_axis_const});
        SwapOutputNames(main_node->output(i), new_gather->output(0));

        if (main_node->get_output_size() > 1)
            new_gather->set_friendly_name(main_node->get_friendly_name() + "." + std::to_string(i));
        else
            SwapFriendlyNames(new_gather, main_node);

        new_nodes.push_back(new_gather);
    }

    return new_nodes;
}

} // namespace sink_forward

namespace sink_backward {

NodeVector InsertGatherBeforeNode(NodePtr main_node,
                                  const std::shared_ptr<Constant>& indices_const,
-                                 const std::shared_ptr<Constant>& axes_const) {
+                                 const std::shared_ptr<Constant>& axis_const,
+                                 const std::shared_ptr<Gather>& gather_node) {
    if (HasDynamicRankInput(main_node))
        return {};

+   NodeVector new_nodes;
    const int64_t gather_negative_axis = GetNormalizedNegativeGatherAxis(axis_const,
        gather_node->get_input_partial_shape(0).rank().get_length());
    const auto axis_element_type = axis_const->get_element_type();

    const auto max_input_rank = GetMaxInputRank(main_node);
    if (max_input_rank < 0)
        return {};

-   NodeVector new_nodes;
    for (size_t i = 0; i < main_node->get_input_size(); ++i) {
-       auto input_node = FixInputNodeRank(main_node->input_value(i), max_input_rank);
+       auto input_node = main_node->input_value(i);

        const Shape broadcasted_input_shape = Broadcast(input_node.get_shape(), max_input_rank);
        if (GetDimByAxis(broadcasted_input_shape, gather_negative_axis) == 1)
            continue;

        auto new_indices_const = indices_const->clone_with_new_inputs({});
-       auto new_axes_const = axes_const->clone_with_new_inputs({});
-       auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axes_const);

        const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
            input_node.get_partial_shape().rank().get_length());
        auto new_axis_const = std::make_shared<Constant>(axis_element_type,
                                                         Shape{1},
                                                         gather_positive_axis);

+       auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);

        main_node->input(i).replace_source_output(new_gather->output(0));

-       copy_runtime_info(input_node.get_node_shared_ptr(), {new_gather, new_indices_const, new_axes_const});
+       copy_runtime_info(input_node.get_node_shared_ptr(), {new_gather, new_indices_const, new_axis_const});

        new_nodes.push_back(new_gather);
    }
@@ -257,4 +462,16 @@ void RemoveSingleOutputConsumers(NodePtr node) {
    }
}

std::function<bool(Output<Node>)> rank_not_more_than(const ov::Rank::value_type expected_rank) {
    return [=](Output<Node> output) -> bool {
        const Rank rank = output.get_partial_shape().rank();
        return (rank.is_static() && (rank.get_length() <= expected_rank));
    };
}

bool constant_has_rank_not_more_than(const std::shared_ptr<Constant>& node, const ov::Rank::value_type expected_rank) {
    const Rank rank = node->get_output_partial_shape(0).rank();
    return (rank.is_static() && (rank.get_length() <= expected_rank));
}

} // namespace gather_sinking
@@ -20,11 +20,11 @@ namespace gather_sinking {
struct GatherInputsInfo {
    std::shared_ptr<ov::opset9::Gather> gather;
    std::shared_ptr<ov::opset9::Constant> indices_const;
-   std::shared_ptr<ov::opset9::Constant> axes_const;
+   std::shared_ptr<ov::opset9::Constant> axis_const;
    size_t input_idx;

    bool isEmpty() const {
-       return !gather || !indices_const || !axes_const;
+       return !gather || !indices_const || !axis_const;
    }
};

@@ -37,7 +37,49 @@ GatherInputsInfo GetFirstGatherInput(std::shared_ptr<ov::Node>);
/**
 * @brief Checks if @arg has any input node that is a Gather operation
 */
bool IfNodeHasGatherInputs(const ov::Output<ov::Node>&);
template <typename GatherInfoPredicate>
bool IfNodeHasGatherInputs(const ov::Output<ov::Node>& output, GatherInfoPredicate gather_info_predicate) {
    GatherInputsInfo inputs_info = GetFirstGatherInput(output.get_node_shared_ptr());
    if (inputs_info.isEmpty())
        return false;

    return gather_info_predicate(inputs_info);
}
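
/*
 * Usage sketch (illustrative, mirroring the predicate used in gather_sinking_binary.cpp):
 * accept a match only when both Gather constants are at most 1D:
 *   IfNodeHasGatherInputs(output, [](const GatherInputsInfo& info) {
 *       return constant_has_rank_not_more_than(info.axis_const, 1) &&
 *              constant_has_rank_not_more_than(info.indices_const, 1);
 *   });
 */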

/**
 * @brief Swaps @args output tensor names
 */
void SwapOutputNames(ov::Output<ov::Node>, ov::Output<ov::Node>);

/**
 * @brief Swaps @args friendly names
 */
void SwapFriendlyNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);

/**
 * @brief Swaps @args output tensor names and friendly names
 */
void SwapNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);

namespace sink_forward {
/**
 * @brief Inserts a reversed Gather on @arg main_node inputs and removes the input Gather
 * described by the GatherInputsInfo argument
 */
void UpdateInputGather(std::shared_ptr<ov::Node> main_node, const GatherInputsInfo&);

/**
 * @brief Removes @arg input node
 */
void RemoveInputNode(std::shared_ptr<ov::Node>, size_t input_idx);

/**
 * @brief Inserts Gather on each main_node output with the order specified in @arg GatherInputsInfo
 */
ov::NodeVector InsertOutputGather(std::shared_ptr<ov::Node> main_node,
                                  const GatherInputsInfo&);
} // namespace sink_forward


namespace sink_backward {
/**
@@ -45,7 +87,8 @@ namespace sink_backward {
 */
ov::NodeVector InsertGatherBeforeNode(std::shared_ptr<ov::Node> main_node,
                                      const std::shared_ptr<ov::opset9::Constant>& indices_const,
-                                     const std::shared_ptr<ov::opset9::Constant>& axes_const);
+                                     const std::shared_ptr<ov::opset9::Constant>& axes_const,
+                                     const std::shared_ptr<ov::opset9::Gather>& gather_node);
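
/*
 * Rationale sketch (illustrative): for an element-wise main_node f,
 *   Gather(f(A, B), indices, axis) == f(Gather(A, indices, axis), Gather(B, indices, axis)),
 * so an output Gather can be moved in front of every input; inputs whose broadcast dimension
 * on that axis is 1 are skipped by the implementation.
 */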
} // namespace sink_backward

void UpdateForwardGatherSinkingAbility(std::shared_ptr<ov::Node>);
@@ -61,4 +104,11 @@ bool HasSameOutputGatherNodes(const ov::Output<ov::Node>&);
 */
void RemoveSingleOutputConsumers(std::shared_ptr<ov::Node>);

bool constant_has_rank_not_more_than(const std::shared_ptr<ov::opset9::Constant>&, const ov::Rank::value_type expected_rank);

/**
 * Checks if output has rank not more than expected
 */
std::function<bool(ov::Output<ov::Node>)> rank_not_more_than(const ov::Rank::value_type expected_rank);

} // namespace gather_sinking
File diff suppressed because it is too large