From 6a5b6b8b30165c71cdb1eb160136ce429c0332f3 Mon Sep 17 00:00:00 2001 From: Anton Chetverikov Date: Wed, 9 Jun 2021 17:26:52 +0300 Subject: [PATCH] Add transformation to normalize negative indices in Gather (#6028) * Add transformation to normalize negative indices in Gather * Update transformation name and add constant indices input to the pattern * Apply comments * Fix copyright year * Add more tests * Update type logic * Add test with different types and update old ones * Update pattern and logic to check only axis dimension is static, add appropriate tests * Fix cdestyle * Add axis normalization * Fix wrong value in gather input * Add related tests * Add axis_constant to check * Remove ngraph_check --- .../gather_normalize_negative_indices.hpp | 29 ++ .../common_optimizations.cpp | 2 + .../gather_normalize_negative_indices.cpp | 77 +++++ ...gather_normalize_negative_indices_test.cpp | 306 ++++++++++++++++++ 4 files changed, 414 insertions(+) create mode 100644 inference-engine/src/transformations/include/transformations/op_conversions/gather_normalize_negative_indices.hpp create mode 100644 inference-engine/src/transformations/src/transformations/op_conversions/gather_normalize_negative_indices.cpp create mode 100644 inference-engine/tests/functional/inference_engine/transformations/gather_normalize_negative_indices_test.cpp diff --git a/inference-engine/src/transformations/include/transformations/op_conversions/gather_normalize_negative_indices.hpp b/inference-engine/src/transformations/include/transformations/op_conversions/gather_normalize_negative_indices.hpp new file mode 100644 index 00000000000..1ec1ffe628e --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/op_conversions/gather_normalize_negative_indices.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace ngraph { +namespace pass { + + class TRANSFORMATIONS_API GatherNegativeConstIndicesNormalize; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief GatherNegativeConstIndicesNormalize checks if indices value is negative scalar and + * normalizes it using ShapeOf->Add->Cast subgraph. + * We need to remove this transformation after adding support of negative indices in + * future version of Gather operation. + */ +class ngraph::pass::GatherNegativeConstIndicesNormalize : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + GatherNegativeConstIndicesNormalize(); +}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 79f1dee8882..4ab5cf1e80d 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -70,6 +70,7 @@ #include "transformations/op_conversions/log_softmax_decomposition.hpp" #include "transformations/op_conversions/mvn6_decomposition.hpp" #include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp" +#include "transformations/op_conversions/gather_normalize_negative_indices.hpp" #include #include @@ -157,6 +158,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptradd_matcher(); decomp->add_matcher(); decomp->add_matcher(); + decomp->add_matcher(); decomp->set_name("ngraph::pass::CommonDecompositions"); // CF is required after all decompositions diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/gather_normalize_negative_indices.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/gather_normalize_negative_indices.cpp new file mode 100644 index 00000000000..86713451869 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/op_conversions/gather_normalize_negative_indices.cpp @@ -0,0 +1,77 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/gather_normalize_negative_indices.hpp" + +#include + +#include +#include +#include +#include "itt.hpp" + +NGRAPH_RTTI_DEFINITION(ngraph::pass::GatherNegativeConstIndicesNormalize, "GatherNegativeConstIndicesNormalize", 0); + +ngraph::pass::GatherNegativeConstIndicesNormalize::GatherNegativeConstIndicesNormalize() { + MATCHER_SCOPE(GatherNegativeConstIndicesNormalize); + auto data_input = ngraph::pattern::any_input(pattern::has_static_rank()); + auto axis_input = ngraph::pattern::wrap_type(); + auto indices_input = ngraph::pattern::wrap_type(); + auto gather_node = std::make_shared(data_input, indices_input, axis_input); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto gather = std::dynamic_pointer_cast(pattern_to_output.at(gather_node).get_node_shared_ptr()); + auto data = pattern_to_output.at(data_input); + auto axis_constant = std::dynamic_pointer_cast(pattern_to_output.at(axis_input).get_node_shared_ptr()); + auto indices_constant = std::dynamic_pointer_cast(pattern_to_output.at(indices_input).get_node_shared_ptr()); + + if (!gather || !axis_constant || !indices_constant) { + return false; + } + + auto indices = indices_constant->cast_vector(); + if (indices.size() != 1 || indices[0] >= 0) { + return false; + } + + auto axis = axis_constant->cast_vector(); + if (axis.size() != 1) { + return false; + } + + auto axis_value = axis[0]; + + // normalize `axis` value if it is negative + if (axis_value < 0) { + axis_value = axis_value + data.get_partial_shape().rank().get_length(); + } + + if (data.get_partial_shape().rank().get_length() < axis_value) { + return false; + } + + // check `axis` dimension of data tensor is static + if (!data.get_partial_shape()[axis_value].is_static()) { + return false; + } + + auto input_type = indices_constant->get_element_type(); + auto shape_of = std::make_shared(data, input_type); + auto input_gather = std::make_shared(shape_of, + ngraph::opset7::Constant::create(input_type, Shape{}, {axis_value}), ngraph::opset7::Constant::create(input_type, Shape{}, {0})); + + auto add = std::make_shared(input_gather, indices_constant); + auto gather_new = gather_node->copy_with_new_inputs({data, add, axis_constant}); + gather_new->set_friendly_name(gather->get_friendly_name()); + + ngraph::copy_runtime_info(gather, {shape_of, input_gather, add, gather_new}); + ngraph::replace_node(gather, gather_new); + + return true; + }; + + auto m = std::make_shared(gather_node, matcher_name); + register_matcher(m, callback); +} diff --git a/inference-engine/tests/functional/inference_engine/transformations/gather_normalize_negative_indices_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/gather_normalize_negative_indices_test.cpp new file mode 100644 index 00000000000..ec6c4204a9b --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/gather_normalize_negative_indices_test.cpp @@ -0,0 +1,306 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; + +TEST(TransformationTests, GatherNegativeIndicesNormalize) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 15, 128}); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1}); + + auto gather = std::make_shared(data, indices, axis, 0); + + f = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto indices_type = ngraph::element::i32; + + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 15, 128}); + auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1}); + + auto shape_of = std::make_shared(data, indices_type); + auto input_gather = std::make_shared(shape_of, + ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0})); + auto add = std::make_shared(input_gather, indices); + auto gather = std::make_shared(data, add, axis); + + f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, GatherNegativeIndicesNormalize_neg_axis) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 15, 128}); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2}); + + auto gather = std::make_shared(data, indices, axis, 0); + + f = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto indices_type = ngraph::element::i32; + + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 15, 128}); + auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2}); + + auto shape_of = std::make_shared(data, indices_type); + auto input_gather = std::make_shared(shape_of, + ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0})); + auto add = std::make_shared(input_gather, indices); + auto gather = std::make_shared(data, add, axis); + + f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, GatherNegativeIndicesNormalize_dif_input_types) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 15, 128}); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1}); + + auto gather = std::make_shared(data, indices, axis, 0); + + f = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto indices_type = ngraph::element::i32; + + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 15, 128}); + auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1}); + + auto shape_of = std::make_shared(data, indices_type); + auto input_gather = std::make_shared(shape_of, + ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0})); + auto add = std::make_shared(input_gather, indices); + auto gather = std::make_shared(data, add, axis); + + f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, GatherNegativeIndicesNormalize_static_axis_dim) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN}); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1}); + + auto gather = std::make_shared(data, indices, axis, 0); + + f = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto indices_type = ngraph::element::i32; + + auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN}); + auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1}); + + auto shape_of = std::make_shared(data, indices_type); + auto input_gather = std::make_shared(shape_of, + ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0})); + auto add = std::make_shared(input_gather, indices); + auto gather = std::make_shared(data, add, axis); + + f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, GatherNegativeIndicesNormalize_static_axis_dim_neg_axis) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN}); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2}); + + auto gather = std::make_shared(data, indices, axis, 0); + + f = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto indices_type = ngraph::element::i32; + + auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN}); + auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2}); + + auto shape_of = std::make_shared(data, indices_type); + auto input_gather = std::make_shared(shape_of, + ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0})); + auto add = std::make_shared(input_gather, indices); + auto gather = std::make_shared(data, add, axis); + + f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, GatherNegativeIndicesNormalize_non_static_axis_dim) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape{DYN, DYN, DYN}); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1}); + + auto gather = std::make_shared(data, indices, axis, 0); + + f = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto indices_type = ngraph::element::i32; + + auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape{DYN, DYN, DYN}); + auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1}); + + auto gather = std::make_shared(data, indices, axis); + + f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, GatherNegativeIndicesNormalize_positive_ind) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{2, 3}); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0}); + + auto gather = std::make_shared(data, indices, axis, 0); + + f = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{2, 3}); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0}); + + auto gather = std::make_shared(data, indices, axis); + + f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, GatherNegativeIndicesNormalize_non_static_rank) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic(ngraph::Rank::dynamic())); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0}); + + auto gather = std::make_shared(data, indices, axis, 0); + + f = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data = std::make_shared(ngraph::element::f32, ngraph::PartialShape::dynamic()); + auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1}); + auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0}); + + auto gather = std::make_shared(data, indices, axis); + + f_ref = std::make_shared(ngraph::NodeVector{gather}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +}