Add transformation to normalize negative indices in Gather (#6028)
* Add transformation to normalize negative indices in Gather * Update transformation name and add constant indices input to the pattern * Apply comments * Fix copyright year * Add more tests * Update type logic * Add test with different types and update old ones * Update pattern and logic to check only axis dimension is static, add appropriate tests * Fix cdestyle * Add axis normalization * Fix wrong value in gather input * Add related tests * Add axis_constant to check * Remove ngraph_check
This commit is contained in:
parent
3bc2c46693
commit
6a5b6b8b30
@ -0,0 +1,29 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API GatherNegativeConstIndicesNormalize;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief GatherNegativeConstIndicesNormalize checks if indices value is negative scalar and
|
||||
* normalizes it using ShapeOf->Add->Cast subgraph.
|
||||
* We need to remove this transformation after adding support of negative indices in
|
||||
* future version of Gather operation.
|
||||
*/
|
||||
class ngraph::pass::GatherNegativeConstIndicesNormalize : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
GatherNegativeConstIndicesNormalize();
|
||||
};
|
@ -70,6 +70,7 @@
|
||||
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
|
||||
#include "transformations/op_conversions/mvn6_decomposition.hpp"
|
||||
#include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp"
|
||||
#include "transformations/op_conversions/gather_normalize_negative_indices.hpp"
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
@ -157,6 +158,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
decomp->add_matcher<ngraph::pass::MVN6Decomposition>();
|
||||
decomp->add_matcher<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
|
||||
decomp->add_matcher<ngraph::pass::EinsumDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::GatherNegativeConstIndicesNormalize>();
|
||||
decomp->set_name("ngraph::pass::CommonDecompositions");
|
||||
|
||||
// CF is required after all decompositions
|
||||
|
@ -0,0 +1,77 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/gather_normalize_negative_indices.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include "itt.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::GatherNegativeConstIndicesNormalize, "GatherNegativeConstIndicesNormalize", 0);
|
||||
|
||||
ngraph::pass::GatherNegativeConstIndicesNormalize::GatherNegativeConstIndicesNormalize() {
|
||||
MATCHER_SCOPE(GatherNegativeConstIndicesNormalize);
|
||||
auto data_input = ngraph::pattern::any_input(pattern::has_static_rank());
|
||||
auto axis_input = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
||||
auto indices_input = ngraph::pattern::wrap_type<ngraph::opset7::Constant>();
|
||||
auto gather_node = std::make_shared<ngraph::opset7::Gather>(data_input, indices_input, axis_input);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto gather = std::dynamic_pointer_cast<ngraph::opset7::Gather>(pattern_to_output.at(gather_node).get_node_shared_ptr());
|
||||
auto data = pattern_to_output.at(data_input);
|
||||
auto axis_constant = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(axis_input).get_node_shared_ptr());
|
||||
auto indices_constant = std::dynamic_pointer_cast<ngraph::opset7::Constant>(pattern_to_output.at(indices_input).get_node_shared_ptr());
|
||||
|
||||
if (!gather || !axis_constant || !indices_constant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto indices = indices_constant->cast_vector<int64_t>();
|
||||
if (indices.size() != 1 || indices[0] >= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axis = axis_constant->cast_vector<int64_t>();
|
||||
if (axis.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axis_value = axis[0];
|
||||
|
||||
// normalize `axis` value if it is negative
|
||||
if (axis_value < 0) {
|
||||
axis_value = axis_value + data.get_partial_shape().rank().get_length();
|
||||
}
|
||||
|
||||
if (data.get_partial_shape().rank().get_length() < axis_value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// check `axis` dimension of data tensor is static
|
||||
if (!data.get_partial_shape()[axis_value].is_static()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_type = indices_constant->get_element_type();
|
||||
auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(data, input_type);
|
||||
auto input_gather = std::make_shared<ngraph::opset7::Gather>(shape_of,
|
||||
ngraph::opset7::Constant::create(input_type, Shape{}, {axis_value}), ngraph::opset7::Constant::create(input_type, Shape{}, {0}));
|
||||
|
||||
auto add = std::make_shared<ngraph::opset7::Add>(input_gather, indices_constant);
|
||||
auto gather_new = gather_node->copy_with_new_inputs({data, add, axis_constant});
|
||||
gather_new->set_friendly_name(gather->get_friendly_name());
|
||||
|
||||
ngraph::copy_runtime_info(gather, {shape_of, input_gather, add, gather_new});
|
||||
ngraph::replace_node(gather, gather_new);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(gather_node, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,306 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/op_conversions/gather_normalize_negative_indices.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, GatherNegativeIndicesNormalize) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{1, 15, 128});
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto indices_type = ngraph::element::i32;
|
||||
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{1, 15, 128});
|
||||
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
|
||||
|
||||
auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(data, indices_type);
|
||||
auto input_gather = std::make_shared<ngraph::opset7::Gather>(shape_of,
|
||||
ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0}));
|
||||
auto add = std::make_shared<ngraph::opset7::Add>(input_gather, indices);
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, add, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, GatherNegativeIndicesNormalize_neg_axis) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{1, 15, 128});
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto indices_type = ngraph::element::i32;
|
||||
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{1, 15, 128});
|
||||
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2});
|
||||
|
||||
auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(data, indices_type);
|
||||
auto input_gather = std::make_shared<ngraph::opset7::Gather>(shape_of,
|
||||
ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0}));
|
||||
auto add = std::make_shared<ngraph::opset7::Add>(input_gather, indices);
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, add, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, GatherNegativeIndicesNormalize_dif_input_types) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{1, 15, 128});
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto indices_type = ngraph::element::i32;
|
||||
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{1, 15, 128});
|
||||
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
|
||||
|
||||
auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(data, indices_type);
|
||||
auto input_gather = std::make_shared<ngraph::opset7::Gather>(shape_of,
|
||||
ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0}));
|
||||
auto add = std::make_shared<ngraph::opset7::Add>(input_gather, indices);
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, add, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, GatherNegativeIndicesNormalize_static_axis_dim) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN});
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto indices_type = ngraph::element::i32;
|
||||
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN});
|
||||
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
|
||||
|
||||
auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(data, indices_type);
|
||||
auto input_gather = std::make_shared<ngraph::opset7::Gather>(shape_of,
|
||||
ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0}));
|
||||
auto add = std::make_shared<ngraph::opset7::Add>(input_gather, indices);
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, add, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, GatherNegativeIndicesNormalize_static_axis_dim_neg_axis) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN});
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto indices_type = ngraph::element::i32;
|
||||
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN});
|
||||
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2});
|
||||
|
||||
auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(data, indices_type);
|
||||
auto input_gather = std::make_shared<ngraph::opset7::Gather>(shape_of,
|
||||
ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0}));
|
||||
auto add = std::make_shared<ngraph::opset7::Add>(input_gather, indices);
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, add, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, GatherNegativeIndicesNormalize_non_static_axis_dim) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, DYN, DYN});
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto indices_type = ngraph::element::i32;
|
||||
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, DYN, DYN});
|
||||
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, GatherNegativeIndicesNormalize_positive_ind) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, GatherNegativeIndicesNormalize_non_static_rank) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(ngraph::Rank::dynamic()));
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::GatherNegativeConstIndicesNormalize>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto indices = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-1});
|
||||
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {0});
|
||||
|
||||
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
Loading…
Reference in New Issue
Block a user