clang fixes + remove unneeded functions

This commit is contained in:
Evgeny Kotov 2023-02-20 18:52:49 +01:00
parent 4417a13bad
commit 3fed498f03
4 changed files with 101 additions and 109 deletions

View File

@ -4,11 +4,11 @@
#include "transformations/gather_sinking_binary.hpp"
#include <openvino/cc/ngraph/itt.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include <openvino/cc/ngraph/itt.hpp>
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
@ -16,8 +16,8 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/utils/gather_sinking_utils.hpp"
#include "transformations/rt_info/gather_sinking_attr.hpp"
#include "transformations/utils/gather_sinking_utils.hpp"
using namespace ov;
using namespace ov::opset9;
@ -27,7 +27,6 @@ using namespace gather_sinking;
using namespace ov::intel_gna::pass;
using namespace ov::intel_gna::rt_info;
GatherSinkingBinaryForward::GatherSinkingBinaryForward() {
MATCHER_SCOPE(GatherSinkingBinaryForward);
@ -64,18 +63,17 @@ GatherSinkingBinaryForward::GatherSinkingBinaryForward() {
GatherSinkingBinaryBackward::GatherSinkingBinaryBackward() {
MATCHER_SCOPE(GatherSinkingBinaryBackward);
auto main_node_label =
wrap_type<op::util::BinaryElementwiseArithmetic>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputGatherNodes(output);
});
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputGatherNodes(output);
});
auto indices_const_label = wrap_type<Constant>(rank_not_more_than(1));
auto axes_const_label = wrap_type<Constant>(rank_not_more_than(1));
auto gather_label =
wrap_type<Gather>({main_node_label, indices_const_label, axes_const_label}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_gather_sinking_node(output);
});
auto gather_label = wrap_type<Gather>({main_node_label, indices_const_label, axes_const_label},
[](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_gather_sinking_node(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();

View File

@ -49,11 +49,6 @@ GatherInputsInfo GetFirstGatherInput(NodePtr node) {
return GatherInputsInfo();
}
bool IfNodeHasGatherInputs(const Output<Node>& output) {
GatherInputsInfo inputs_info = GetFirstGatherInput(output.get_node_shared_ptr());
return !inputs_info.isEmpty();
}
namespace {
bool HasDynamicRankInput(NodePtr node) {
@ -78,15 +73,6 @@ Rank::value_type GetMaxInputRank(const NodePtr& node) {
return max_input_rank;
}
NodePtr InsertUnsqueeze(Output<Node> node, size_t n_dims) {
std::vector<size_t> dims(n_dims);
std::iota(dims.begin(), dims.end(), 0);
auto unsqueeze_const = std::make_shared<Constant>(element::i64, Shape{dims.size()}, dims);
auto unsqueeze = std::make_shared<Unsqueeze>(node, unsqueeze_const);
copy_runtime_info(node.get_node_shared_ptr(), {unsqueeze, unsqueeze_const});
return unsqueeze;
}
/*
Converts gather indices to positive form
*/
@ -185,20 +171,21 @@ void SwapNames(NodePtr node1, NodePtr node2) {
namespace sink_forward {
/** @brief
* Inserts inverted Gather layer on all @main_node inputs except input from GatherInputsInfo argument
* Works only with 1D indices.
* It's simpler to work with negative gather axis since it doesn't depend on shape broadcasting.
* Converts gather axis to a negative form
* Doesn't add Gather layer if input_node_shape[axis] == 1 since it is useless and causes an invalid result.
* Input nodes can have different shapes. That shapes can have smaller or larger ranks. To manage it we need
* to find max input shape rank and broadcast all input shapes to it.
*/
* Inserts inverted Gather layer on all @main_node inputs except input from GatherInputsInfo argument
* Works only with 1D indices.
* It's simpler to work with negative gather axis since it doesn't depend on shape broadcasting.
* Converts gather axis to a negative form
* Doesn't add Gather layer if input_node_shape[axis] == 1 since it is useless and causes an invalid result.
* Input nodes can have different shapes. That shapes can have smaller or larger ranks. To manage it we need
* to find max input shape rank and broadcast all input shapes to it.
*/
void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_info) {
if (gather_input_info.isEmpty() || HasDynamicRankInput(main_node))
return;
const int64_t gather_negative_axis = GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
const int64_t gather_negative_axis =
GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
const std::vector<int64_t> gather_indices = GetNormalizedGatherIndices(gather_input_info.indices_const);
const std::vector<int64_t> reversed_gather_indices = ReverseGatherIndexes(gather_indices);
@ -221,14 +208,12 @@ void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_i
continue;
auto new_indices_const = std::make_shared<Constant>(indices_element_type,
Shape{reversed_gather_indices.size()},
reversed_gather_indices);
Shape{reversed_gather_indices.size()},
reversed_gather_indices);
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
input_node.get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{},
gather_positive_axis);
const int64_t gather_positive_axis =
ConvertAxisToPositive(gather_negative_axis, input_node.get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type, Shape{}, gather_positive_axis);
auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);
@ -243,8 +228,9 @@ NodeVector InsertOutputGather(NodePtr main_node, const GatherInputsInfo& gather_
if (gather_input_info.isEmpty())
return {};
const int64_t gather_negative_axis = GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
const int64_t gather_negative_axis =
GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
const auto axis_element_type = gather_input_info.axis_const->get_element_type();
NodeVector new_nodes;
@ -253,11 +239,9 @@ NodeVector InsertOutputGather(NodePtr main_node, const GatherInputsInfo& gather_
auto new_indices_const = gather_input_info.indices_const->clone_with_new_inputs({});
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
main_node->output(i).get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{},
gather_positive_axis);
const int64_t gather_positive_axis =
ConvertAxisToPositive(gather_negative_axis, main_node->output(i).get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type, Shape{}, gather_positive_axis);
auto new_gather = std::make_shared<Gather>(main_node->output(i), new_indices_const, new_axis_const);
for (auto& consumer : main_node_consumers) {
@ -278,7 +262,7 @@ NodeVector InsertOutputGather(NodePtr main_node, const GatherInputsInfo& gather_
return new_nodes;
}
} // namespace sink_forward
} // namespace sink_forward
namespace sink_backward {
@ -289,8 +273,8 @@ NodeVector InsertGatherBeforeNode(NodePtr main_node,
if (HasDynamicRankInput(main_node))
return {};
const int64_t gather_negative_axis = GetNormalizedNegativeGatherAxis(axis_const,
gather_node->get_input_partial_shape(0).rank().get_length());
const int64_t gather_negative_axis =
GetNormalizedNegativeGatherAxis(axis_const, gather_node->get_input_partial_shape(0).rank().get_length());
const auto axis_element_type = axis_const->get_element_type();
const auto max_input_rank = GetMaxInputRank(main_node);
@ -307,11 +291,9 @@ NodeVector InsertGatherBeforeNode(NodePtr main_node,
auto new_indices_const = indices_const->clone_with_new_inputs({});
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
input_node.get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{},
gather_positive_axis);
const int64_t gather_positive_axis =
ConvertAxisToPositive(gather_negative_axis, input_node.get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type, Shape{}, gather_positive_axis);
auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);
@ -464,7 +446,7 @@ std::function<bool(Output<Node>)> rank_not_more_than(const ov::Rank::value_type
bool constant_has_rank_not_more_than(const std::shared_ptr<Constant>& node, const ov::Rank::value_type expected_rank) {
const Rank rank = node->get_output_partial_shape(0).rank();
return (rank.is_static() && (rank.get_length() <= expected_rank));
return (rank.is_static() && (rank.get_length() <= expected_rank));
}
} // namespace gather_sinking

View File

@ -42,7 +42,7 @@ bool IfNodeHasGatherInputs(const ov::Output<ov::Node>& output, GatherInfoPredica
GatherInputsInfo inputs_info = GetFirstGatherInput(output.get_node_shared_ptr());
if (inputs_info.isEmpty())
return false;
return gather_info_predicate(inputs_info);
}
@ -76,11 +76,9 @@ void RemoveInputNode(std::shared_ptr<ov::Node>, size_t input_idx);
/**
* @brief Inserts Gather on each main_node output with the order specified in @arg GatherInputsInfo
*/
ov::NodeVector InsertOutputGather(std::shared_ptr<ov::Node> main_node,
const GatherInputsInfo&);
ov::NodeVector InsertOutputGather(std::shared_ptr<ov::Node> main_node, const GatherInputsInfo&);
} // namespace sink_forward
namespace sink_backward {
/**
* @brief Inserts Gather layers on each input of @arg main_node with cloned indices and axes constants
@ -104,11 +102,12 @@ bool HasSameOutputGatherNodes(const ov::Output<ov::Node>&);
*/
void RemoveSingleOutputConsumers(std::shared_ptr<ov::Node>);
bool constant_has_rank_not_more_than(const std::shared_ptr<ov::opset9::Constant>&, const ov::Rank::value_type expected_rank);
bool constant_has_rank_not_more_than(const std::shared_ptr<ov::opset9::Constant>&,
const ov::Rank::value_type expected_rank);
/**
* Checks if output has rank not more than expected
*/
*/
std::function<bool(ov::Output<ov::Node>)> rank_not_more_than(const ov::Rank::value_type expected_rank);
} // namespace gather_sinking

View File

@ -87,7 +87,10 @@ std::shared_ptr<Gather> MakeGather(NodePtr input_node, CreateIndicesF create_ind
}
template <typename CreateIndicesF>
std::shared_ptr<Gather> MakeGather(NodePtr input_node, CreateIndicesF create_indices_func, size_t axis, size_t indices_size) {
std::shared_ptr<Gather> MakeGather(NodePtr input_node,
CreateIndicesF create_indices_func,
size_t axis,
size_t indices_size) {
const std::vector<size_t> indexes = create_indices_func(indices_size, 0);
auto gather_indexes_node = Constant::create(ngraph::element::i64, ov::Shape{indexes.size()}, indexes);
@ -393,7 +396,7 @@ using TestBinaryParams = std::tuple<BinaryFactoryPtr,
size_t>; /* binary_gather_input_idx */
class GatherSinkingBinaryTestFixture : public ::testing::WithParamInterface<TestBinaryParams>,
public TransformationTestsF {
public TransformationTestsF {
public:
static std::string get_test_name(const testing::TestParamInfo<TestBinaryParams>& obj) {
BinaryFactoryPtr binary_factory;
@ -470,9 +473,8 @@ INSTANTIATE_TEST_SUITE_P(
// --------------------------------------------------------------------------------------
using CreateGraphBinaryIncompatShapesF = std::function<std::shared_ptr<Model>(BinaryFactoryPtr unary_factory,
element::Type input_type,
size_t binary_gather_input_idx)>;
using CreateGraphBinaryIncompatShapesF = std::function<
std::shared_ptr<Model>(BinaryFactoryPtr unary_factory, element::Type input_type, size_t binary_gather_input_idx)>;
using TestBinaryIncompatShapesParams = std::tuple<BinaryFactoryPtr,
PassFactoryPtr,
@ -548,7 +550,7 @@ std::shared_ptr<Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
return std::make_shared<Model>(ov::OutputVector{binary_op}, ov::ParameterVector{X});
}
} // namespace insert_gather
} // namespace insert_gather
namespace no_insert_gather {
std::shared_ptr<Model> CreateFunction(BinaryFactoryPtr binary_factory,
@ -584,7 +586,7 @@ std::shared_ptr<Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
return std::make_shared<Model>(ov::OutputVector{binary_op}, ov::ParameterVector{X});
}
} // namespace no_insert_gather
} // namespace no_insert_gather
} // namespace incompat_shapes
} // namespace backward
@ -628,7 +630,7 @@ std::shared_ptr<Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
return std::make_shared<Model>(ov::OutputVector{gather1}, ov::ParameterVector{X});
}
} // namespace insert_gather
} // namespace insert_gather
namespace no_insert_gather {
std::shared_ptr<Model> CreateFunction(BinaryFactoryPtr binary_factory,
@ -665,9 +667,9 @@ std::shared_ptr<Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
return std::make_shared<Model>(ov::OutputVector{gather1}, ov::ParameterVector{X});
}
} // namespace no_insert_gather
} // namespace no_insert_gather
} // namespace gather_small_input
} // namespace gather_small_input
namespace gather_large_input {
namespace insert_gather {
@ -706,9 +708,9 @@ std::shared_ptr<Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
return std::make_shared<Model>(ov::OutputVector{gather1}, ov::ParameterVector{X});
}
} // namespace insert_gather
} // namespace insert_gather
} // namespace gather_large_input
} // namespace gather_large_input
} // namespace incompat_shapes
} // namespace forward
@ -719,56 +721,68 @@ std::shared_ptr<Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
INSTANTIATE_TEST_SUITE_P(
GatherSinkingBinaryIncompatShapesBackwardInsertGatherTestSuite,
GatherSinkingBinaryIncompatShapesTestFixture,
::testing::Combine(::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryBackward)),
::testing::Values(binary::single_consumer::backward::incompat_shapes::insert_gather::CreateFunction),
::testing::Values(binary::single_consumer::backward::incompat_shapes::insert_gather::CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
::testing::Combine(
::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryBackward)),
::testing::Values(binary::single_consumer::backward::incompat_shapes::insert_gather::CreateFunction),
::testing::Values(binary::single_consumer::backward::incompat_shapes::insert_gather::CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
GatherSinkingBinaryIncompatShapesTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
GatherSinkingBinaryIncompatShapesBackwardNoGatherInsertTestSuite,
GatherSinkingBinaryIncompatShapesTestFixture,
::testing::Combine(::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryBackward)),
::testing::Values(binary::single_consumer::backward::incompat_shapes::no_insert_gather::CreateFunction),
::testing::Values(binary::single_consumer::backward::incompat_shapes::no_insert_gather::CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
::testing::Combine(
::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryBackward)),
::testing::Values(binary::single_consumer::backward::incompat_shapes::no_insert_gather::CreateFunction),
::testing::Values(
binary::single_consumer::backward::incompat_shapes::no_insert_gather::CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
GatherSinkingBinaryIncompatShapesTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
GatherSinkingBinaryIncompatShapesGatherSmallInputForwardInsertGatherTestSuite,
GatherSinkingBinaryIncompatShapesTestFixture,
::testing::Combine(::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryForward)),
::testing::Values(binary::single_consumer::forward::incompat_shapes::gather_small_input::insert_gather::CreateFunction),
::testing::Values(binary::single_consumer::forward::incompat_shapes::gather_small_input::insert_gather::CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
::testing::Combine(
::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryForward)),
::testing::Values(
binary::single_consumer::forward::incompat_shapes::gather_small_input::insert_gather::CreateFunction),
::testing::Values(binary::single_consumer::forward::incompat_shapes::gather_small_input::insert_gather::
CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
GatherSinkingBinaryIncompatShapesTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
GatherSinkingBinaryIncompatShapesGatherSmallInputForwardNoGatherInsertTestSuite,
GatherSinkingBinaryIncompatShapesTestFixture,
::testing::Combine(::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryForward)),
::testing::Values(binary::single_consumer::forward::incompat_shapes::gather_small_input::no_insert_gather::CreateFunction),
::testing::Values(binary::single_consumer::forward::incompat_shapes::gather_small_input::no_insert_gather::CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
::testing::Combine(
::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryForward)),
::testing::Values(
binary::single_consumer::forward::incompat_shapes::gather_small_input::no_insert_gather::CreateFunction),
::testing::Values(binary::single_consumer::forward::incompat_shapes::gather_small_input::no_insert_gather::
CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
GatherSinkingBinaryIncompatShapesTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
GatherSinkingBinaryIncompatShapesGatherLargeInputInsertGatherForwardTestSuite,
GatherSinkingBinaryIncompatShapesTestFixture,
::testing::Combine(::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryForward)),
::testing::Values(binary::single_consumer::forward::incompat_shapes::gather_large_input::insert_gather::CreateFunction),
::testing::Values(binary::single_consumer::forward::incompat_shapes::gather_large_input::insert_gather::CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
::testing::Combine(
::testing::ValuesIn(binary_elementwise_factories),
::testing::Values(CREATE_PASS_FACTORY(GatherSinkingBinaryForward)),
::testing::Values(
binary::single_consumer::forward::incompat_shapes::gather_large_input::insert_gather::CreateFunction),
::testing::Values(binary::single_consumer::forward::incompat_shapes::gather_large_input::insert_gather::
CreateReferenceFunction),
::testing::Values(element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)),
GatherSinkingBinaryIncompatShapesTestFixture::get_test_name);
} // namespace one_input_transpose
@ -1154,9 +1168,8 @@ std::shared_ptr<Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
} // namespace backward
using CreateGraphF = std::function<std::shared_ptr<Model>(BinaryFactoryPtr binary_factory,
element::Type input_type,
size_t binary_gather_input_idx)>;
using CreateGraphF = std::function<
std::shared_ptr<Model>(BinaryFactoryPtr binary_factory, element::Type input_type, size_t binary_gather_input_idx)>;
struct CreateGraphFunctionDesc {
CreateGraphFunctionDesc() = default;