Add transformation to normalize negative indices in Gather (#6028)

* Add transformation to normalize negative indices in Gather

* Update transformation name and add constant indices input to the pattern

* Apply comments

* Fix copyright year

* Add more tests

* Update type logic

* Add test with different types and update old ones

* Update pattern and logic to check only axis dimension is static, add appropriate tests

* Fix codestyle

* Add axis normalization

* Fix wrong value in gather input

* Add related tests

* Add axis_constant to check

* Remove ngraph_check
This commit is contained in:
Anton Chetverikov 2021-06-09 17:26:52 +03:00 committed by GitHub
parent 3bc2c46693
commit 6a5b6b8b30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 414 additions and 0 deletions

View File

@ -0,0 +1,29 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
// Forward declaration; the class itself is defined below via its fully qualified name.
class TRANSFORMATIONS_API GatherNegativeConstIndicesNormalize;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief GatherNegativeConstIndicesNormalize matches a Gather whose indices and axis inputs are
 * Constants. When the indices constant holds a single negative value, the data input has a
 * static rank and the data dimension along the gather axis is static, the indices input is
 * replaced by a ShapeOf->Gather->Add subgraph that adds the size of the axis dimension to the
 * original (negative) index.
 * We need to remove this transformation after adding support of negative indices in
 * future version of Gather operation.
 */
class ngraph::pass::GatherNegativeConstIndicesNormalize : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
// Builds the Gather pattern and registers the matcher callback that performs the rewrite.
GatherNegativeConstIndicesNormalize();
};

View File

@ -70,6 +70,7 @@
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
#include "transformations/op_conversions/mvn6_decomposition.hpp"
#include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp"
#include "transformations/op_conversions/gather_normalize_negative_indices.hpp"
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
@ -157,6 +158,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
decomp->add_matcher<ngraph::pass::MVN6Decomposition>();
decomp->add_matcher<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
decomp->add_matcher<ngraph::pass::EinsumDecomposition>();
decomp->add_matcher<ngraph::pass::GatherNegativeConstIndicesNormalize>();
decomp->set_name("ngraph::pass::CommonDecompositions");
// CF is required after all decompositions

View File

@ -0,0 +1,77 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/gather_normalize_negative_indices.hpp"
#include <memory>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::GatherNegativeConstIndicesNormalize, "GatherNegativeConstIndicesNormalize", 0);

/**
 * Registers a matcher for opset7 Gather(data, Constant indices, Constant axis) where `data`
 * has a static rank.
 *
 * The callback rewrites the matched Gather only when all of the following hold:
 *  - the indices constant contains exactly one element and that element is negative;
 *  - the axis constant contains exactly one element which, after normalizing a negative
 *    value by the data rank, lies in [0, rank);
 *  - the data dimension along that axis is static.
 *
 * In that case the indices input is replaced with
 *     Add(Gather(ShapeOf(data), axis, 0), indices)
 * i.e. the size of the axis dimension is added to the negative index. In every other case the
 * callback returns false and the graph is left untouched.
 */
ngraph::pass::GatherNegativeConstIndicesNormalize::GatherNegativeConstIndicesNormalize() {
    MATCHER_SCOPE(GatherNegativeConstIndicesNormalize);
    auto data_input = ngraph::pattern::any_input(pattern::has_static_rank());
    auto axis_input = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
    auto indices_input = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
    auto gather_node = std::make_shared<ngraph::opset7::Gather>(data_input, indices_input, axis_input);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        auto gather = std::dynamic_pointer_cast<ngraph::opset7::Gather>(pattern_to_output.at(gather_node).get_node_shared_ptr());
        auto data = pattern_to_output.at(data_input);
        auto axis_constant = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(axis_input).get_node_shared_ptr());
        auto indices_constant = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(indices_input).get_node_shared_ptr());

        if (!gather || !axis_constant || !indices_constant) {
            return false;
        }

        // Only a single negative index is handled.
        const auto indices = indices_constant->cast_vector<int64_t>();
        if (indices.size() != 1 || indices[0] >= 0) {
            return false;
        }

        const auto axis = axis_constant->cast_vector<int64_t>();
        if (axis.size() != 1) {
            return false;
        }

        // The pattern guarantees a static rank, so get_length() is safe here.
        const int64_t data_rank = data.get_partial_shape().rank().get_length();

        // normalize `axis` value if it is negative
        int64_t axis_value = axis[0];
        if (axis_value < 0) {
            axis_value += data_rank;
        }

        // Reject an axis outside [0, rank): it may still be negative after normalization, and
        // `axis == rank` would index past the end of the partial shape below.
        if (axis_value < 0 || axis_value >= data_rank) {
            return false;
        }

        // check `axis` dimension of data tensor is static
        if (!data.get_partial_shape()[axis_value].is_static()) {
            return false;
        }

        // Build the helper subgraph in the element type of the original indices so that Add
        // receives operands of the same type.
        auto input_type = indices_constant->get_element_type();
        auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(data, input_type);
        auto input_gather = std::make_shared<ngraph::opset7::Gather>(shape_of,
            ngraph::opset7::Constant::create(input_type, Shape{}, {axis_value}), ngraph::opset7::Constant::create(input_type, Shape{}, {0}));

        auto add = std::make_shared<ngraph::opset7::Add>(input_gather, indices_constant);

        // Clone the matched node rather than the pattern node, so the replacement is built
        // from the Gather that is actually in the graph.
        auto gather_new = gather->copy_with_new_inputs({data, add, axis_constant});
        gather_new->set_friendly_name(gather->get_friendly_name());

        ngraph::copy_runtime_info(gather, {shape_of, input_gather, add, gather_new});
        ngraph::replace_node(gather, gather_new);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(gather_node, matcher_name);
    register_matcher(m, callback);
}

View File

@ -0,0 +1,306 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/gather_normalize_negative_indices.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
// Static data shape, scalar index -1, non-negative axis: the indices input is expected to
// become ShapeOf(data) -> Gather(axis dim) -> Add(index).
TEST(TransformationTests, GatherNegativeIndicesNormalize) {
    namespace opset7 = ngraph::opset7;
    const auto index_type = ngraph::element::i32;
    const ngraph::Shape data_shape{1, 15, 128};

    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);

    // Graph under test: run the transformation on Gather(data, -1, axis = 1).
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
        auto original = std::make_shared<opset7::Gather>(input, negative_index, gather_axis, 0);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{original}, ngraph::ParameterVector{input});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph: index normalized through the size of dimension 1.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(index_type, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});

        auto data_shape_of = std::make_shared<opset7::ShapeOf>(input, index_type);
        auto dim_selector = opset7::Constant::create(index_type, ngraph::Shape{}, {1});
        auto shape_axis = opset7::Constant::create(index_type, ngraph::Shape{}, {0});
        auto axis_dim = std::make_shared<opset7::Gather>(data_shape_of, dim_selector, shape_axis);
        auto normalized_index = std::make_shared<opset7::Add>(axis_dim, negative_index);

        auto expected = std::make_shared<opset7::Gather>(input, normalized_index, gather_axis);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{expected}, ngraph::ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Static data shape, scalar index -1, negative axis (-2): the helper Gather over ShapeOf is
// expected to use the normalized axis (1) while the outer Gather keeps the original axis.
TEST(TransformationTests, GatherNegativeIndicesNormalize_neg_axis) {
    namespace opset7 = ngraph::opset7;
    const auto index_type = ngraph::element::i32;
    const ngraph::Shape data_shape{1, 15, 128};

    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);

    // Graph under test: run the transformation on Gather(data, -1, axis = -2).
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2});
        auto original = std::make_shared<opset7::Gather>(input, negative_index, gather_axis, 0);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{original}, ngraph::ParameterVector{input});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph: dimension 1 of the data shape is added to the index.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(index_type, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2});

        auto data_shape_of = std::make_shared<opset7::ShapeOf>(input, index_type);
        auto dim_selector = opset7::Constant::create(index_type, ngraph::Shape{}, {1});
        auto shape_axis = opset7::Constant::create(index_type, ngraph::Shape{}, {0});
        auto axis_dim = std::make_shared<opset7::Gather>(data_shape_of, dim_selector, shape_axis);
        auto normalized_index = std::make_shared<opset7::Add>(axis_dim, negative_index);

        auto expected = std::make_shared<opset7::Gather>(input, normalized_index, gather_axis);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{expected}, ngraph::ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Indices are i32 while the axis is i64: the helper subgraph is expected to be built in the
// indices element type and the i64 axis constant is expected to stay on the outer Gather.
TEST(TransformationTests, GatherNegativeIndicesNormalize_dif_input_types) {
    namespace opset7 = ngraph::opset7;
    const auto index_type = ngraph::element::i32;
    const ngraph::Shape data_shape{1, 15, 128};

    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);

    // Graph under test: i32 index -1, i64 axis 1.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
        auto original = std::make_shared<opset7::Gather>(input, negative_index, gather_axis, 0);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{original}, ngraph::ParameterVector{input});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph: ShapeOf / helper constants all use the i32 indices type.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(index_type, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});

        auto data_shape_of = std::make_shared<opset7::ShapeOf>(input, index_type);
        auto dim_selector = opset7::Constant::create(index_type, ngraph::Shape{}, {1});
        auto shape_axis = opset7::Constant::create(index_type, ngraph::Shape{}, {0});
        auto axis_dim = std::make_shared<opset7::Gather>(data_shape_of, dim_selector, shape_axis);
        auto normalized_index = std::make_shared<opset7::Add>(axis_dim, negative_index);

        auto expected = std::make_shared<opset7::Gather>(input, normalized_index, gather_axis);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{expected}, ngraph::ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Partially dynamic data shape where only the gather axis dimension is static: the
// transformation is still expected to apply.
TEST(TransformationTests, GatherNegativeIndicesNormalize_static_axis_dim) {
    namespace opset7 = ngraph::opset7;
    const auto index_type = ngraph::element::i32;
    const ngraph::PartialShape data_shape{DYN, 15, DYN};

    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);

    // Graph under test: Gather(data, -1, axis = 1) with a static dimension at axis 1.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
        auto original = std::make_shared<opset7::Gather>(input, negative_index, gather_axis, 0);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{original}, ngraph::ParameterVector{input});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph: index normalized through the size of dimension 1.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(index_type, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});

        auto data_shape_of = std::make_shared<opset7::ShapeOf>(input, index_type);
        auto dim_selector = opset7::Constant::create(index_type, ngraph::Shape{}, {1});
        auto shape_axis = opset7::Constant::create(index_type, ngraph::Shape{}, {0});
        auto axis_dim = std::make_shared<opset7::Gather>(data_shape_of, dim_selector, shape_axis);
        auto normalized_index = std::make_shared<opset7::Add>(axis_dim, negative_index);

        auto expected = std::make_shared<opset7::Gather>(input, normalized_index, gather_axis);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{expected}, ngraph::ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Partially dynamic data shape with a negative axis (-2) that points at the only static
// dimension: the transformation is expected to apply using the normalized axis (1).
TEST(TransformationTests, GatherNegativeIndicesNormalize_static_axis_dim_neg_axis) {
    namespace opset7 = ngraph::opset7;
    const auto index_type = ngraph::element::i32;
    const ngraph::PartialShape data_shape{DYN, 15, DYN};

    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);

    // Graph under test: Gather(data, -1, axis = -2).
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2});
        auto original = std::make_shared<opset7::Gather>(input, negative_index, gather_axis, 0);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{original}, ngraph::ParameterVector{input});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph: dimension 1 of the data shape is added to the index.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(index_type, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2});

        auto data_shape_of = std::make_shared<opset7::ShapeOf>(input, index_type);
        auto dim_selector = opset7::Constant::create(index_type, ngraph::Shape{}, {1});
        auto shape_axis = opset7::Constant::create(index_type, ngraph::Shape{}, {0});
        auto axis_dim = std::make_shared<opset7::Gather>(data_shape_of, dim_selector, shape_axis);
        auto normalized_index = std::make_shared<opset7::Add>(axis_dim, negative_index);

        auto expected = std::make_shared<opset7::Gather>(input, normalized_index, gather_axis);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{expected}, ngraph::ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// All data dimensions are dynamic, including the gather axis: the transformation is
// expected to leave the graph unchanged.
TEST(TransformationTests, GatherNegativeIndicesNormalize_non_static_axis_dim) {
    namespace opset7 = ngraph::opset7;
    const auto index_type = ngraph::element::i32;
    const ngraph::PartialShape data_shape{DYN, DYN, DYN};

    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);

    // Graph under test: Gather(data, -1, axis = 1) with a dynamic dimension at axis 1.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
        auto original = std::make_shared<opset7::Gather>(input, negative_index, gather_axis, 0);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{original}, ngraph::ParameterVector{input});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph: the very same, untouched Gather.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto negative_index = opset7::Constant::create(index_type, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
        auto expected = std::make_shared<opset7::Gather>(input, negative_index, gather_axis);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{expected}, ngraph::ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// The index is already non-negative: the transformation is expected to leave the graph
// unchanged.
TEST(TransformationTests, GatherNegativeIndicesNormalize_positive_ind) {
    namespace opset7 = ngraph::opset7;
    const ngraph::Shape data_shape{2, 3};

    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);

    // Graph under test: Gather(data, 1, axis = 0).
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto positive_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0});
        auto original = std::make_shared<opset7::Gather>(input, positive_index, gather_axis, 0);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{original}, ngraph::ParameterVector{input});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph: the very same, untouched Gather.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, data_shape);
        auto positive_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0});
        auto expected = std::make_shared<opset7::Gather>(input, positive_index, gather_axis);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{expected}, ngraph::ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Data has a dynamic rank: the pattern requires a static rank, so the graph is expected to
// stay unchanged.
TEST(TransformationTests, GatherNegativeIndicesNormalize_non_static_rank) {
    namespace opset7 = ngraph::opset7;

    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);

    // Graph under test: Gather(data, -1, axis = 0) on rank-dynamic data.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32,
                                                         ngraph::PartialShape::dynamic(ngraph::Rank::dynamic()));
        auto negative_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0});
        auto original = std::make_shared<opset7::Gather>(input, negative_index, gather_axis, 0);
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{original}, ngraph::ParameterVector{input});

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference graph: the very same, untouched Gather.
    {
        auto input = std::make_shared<opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
        auto negative_index = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
        auto gather_axis = opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0});
        auto expected = std::make_shared<opset7::Gather>(input, negative_index, gather_axis);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{expected}, ngraph::ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}