[MO][nG] add Gather-8 to MO, layer tests, nG python api unit-tests and up/down-grading transformations (#6560)

* gather-8 upgrade/downgrade transforms

* bump to opset8

* add Gather-8 to MO

* fix permutation for Gather

* added TF layer tests, ONNX layer tests for negative indices, and nG python api unit-tests for negative indices

* typo fix, disable downgrading transformation

* disable downgrade in clDNN, line width style fix

* all Gathers are converted to the 7th version; the transformations will be enabled/disabled as the op is added to plugins

* disabled Gather8LayerTest on GPU

* added common function for Op replacement

* clarified the meaning of negative indices, fixed some typos

* applied review comments: left only meaningful layer tests

* removed op replacing functions from common utils

* restored transformations without helper subroutines

* corrected style, added comments to common_optimizations.cpp
Pavel Esir 2021-08-13 11:56:49 +03:00 committed by GitHub
parent f324ca7fcd
commit 4738bbd757
25 changed files with 503 additions and 185 deletions

View File

@ -60,6 +60,7 @@
#include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
#include <transformations/op_conversions/convert_nms_to_nms_ie_internal.hpp>
#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
#include <transformations/op_conversions/convert_gather_downgrade.hpp>
#include <transformations/op_conversions/convert_gather_0d.hpp>
#include <transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
@ -362,6 +363,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
pass_config->disable<ngraph::pass::ConvertBroadcast3>();
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
pass_config->disable<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
pass_config->enable<ngraph::pass::ConvertGather8ToGather7>();
if (!config.enable_loop_unrolling) {
pass_config->disable<ngraph::pass::ConvertTensorIteratorToRNNSequence>();

View File

@ -31,8 +31,8 @@
#include <transformations/op_conversions/convert_shuffle_channels3.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
#include <transformations/op_conversions/convert_gelu.hpp>
#include <transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp>
#include <transformations/op_conversions/convert_gather_v1_to_gather_v7.hpp>
#include <transformations/op_conversions/convert_gather_downgrade.hpp>
#include <transformations/op_conversions/convert_gather_upgrade.hpp>
#include <transformations/op_conversions/gelu7_downgrade.hpp>
#include <transformations/op_conversions/hswish_decomposition.hpp>
#include <transformations/op_conversions/hsigmoid_decomposition.hpp>
@ -311,6 +311,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();
pass_config->enable<ngraph::pass::ConvertGather1ToGather7>();
pass_config->enable<ngraph::pass::ConvertGather8ToGather7>();
if (useLpt) {
pass_config->set_callback<ngraph::pass::ConvertQuantizeDequantize>([](const_node_ptr &node) -> bool {

View File

@ -12,6 +12,7 @@ namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertGather7ToGather1;
class TRANSFORMATIONS_API ConvertGather8ToGather7;
} // namespace pass
} // namespace ngraph
@ -25,3 +26,13 @@ public:
NGRAPH_RTTI_DECLARATION;
ConvertGather7ToGather1();
};
/**
* @ingroup ie_transformation_common_api
* @brief ConvertGather8ToGather7 converts v8::Gather into v7::Gather.
*/
class ngraph::pass::ConvertGather8ToGather7 : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertGather8ToGather7();
};

View File

@ -13,6 +13,8 @@ namespace pass {
class TRANSFORMATIONS_API ConvertGather1ToGather7;
class TRANSFORMATIONS_API ConvertGather7ToGather8;
} // namespace pass
} // namespace ngraph
@ -25,3 +27,13 @@ public:
NGRAPH_RTTI_DECLARATION;
ConvertGather1ToGather7();
};
/**
* @ingroup ie_transformation_common_api
* @brief ConvertGather7ToGather8 converts v7::Gather into v8::Gather.
*/
class ngraph::pass::ConvertGather7ToGather8 : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertGather7ToGather8();
};

View File

@ -16,6 +16,10 @@
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace op {
namespace util {
@ -130,7 +134,6 @@ Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & inpu
}
TRANSFORMATIONS_API std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& node);
} // namespace util
} // namespace op
} // namespace ngraph

View File

@ -47,8 +47,8 @@
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp"
#include "transformations/op_conversions/convert_gather_v1_to_gather_v7.hpp"
#include "transformations/op_conversions/convert_gather_downgrade.hpp"
#include "transformations/op_conversions/convert_gather_upgrade.hpp"
#include "transformations/op_conversions/convert_mod.hpp"
#include "transformations/op_conversions/convert_minimum_to_power_and_max.hpp"
#include "transformations/op_conversions/convert_negative.hpp"
@ -179,8 +179,10 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
conv_fusions->set_name("ngraph::pass::ConvFusions");
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
manager.register_pass<ngraph::pass::ConvertGather8ToGather7>(); // not all plugins have implemented Gather-8 yet
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>(); // not all plugins have implemented Gather-7 yet
manager.register_pass<ngraph::pass::ConvertGather1ToGather7, false>();
manager.register_pass<ngraph::pass::ConvertGather7ToGather8, false>();
manager.register_pass<ngraph::pass::ConvertDeformableConv8To1>();
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();

View File

@ -0,0 +1,69 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_gather_downgrade.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(pass::ConvertGather7ToGather1, "ConvertGather7ToGather1", 0);
NGRAPH_RTTI_DEFINITION(pass::ConvertGather8ToGather7, "ConvertGather8ToGather7", 0);
pass::ConvertGather7ToGather1::ConvertGather7ToGather1() {
MATCHER_SCOPE(ConvertGather7ToGather1);
auto gather_v7_pattern = pattern::wrap_type<opset7::Gather>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto gather_v7_node = std::dynamic_pointer_cast<opset7::Gather>(m.get_match_root());
if (!gather_v7_node)
return false;
if (gather_v7_node->get_batch_dims() != 0)
return false;
auto gather_v1_node = make_shared<opset1::Gather>(gather_v7_node->input_value(0),
gather_v7_node->input_value(1),
gather_v7_node->input_value(2));
gather_v1_node->set_friendly_name(gather_v7_node->get_friendly_name());
ngraph::copy_runtime_info(gather_v7_node, gather_v1_node);
ngraph::replace_node(gather_v7_node, gather_v1_node);
return true;
};
auto m = make_shared<pattern::Matcher>(gather_v7_pattern, matcher_name);
register_matcher(m, callback);
}
pass::ConvertGather8ToGather7::ConvertGather8ToGather7() {
MATCHER_SCOPE(ConvertGather8ToGather7);
auto gather_v8_pattern = pattern::wrap_type<opset8::Gather>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto gather_v8_node = std::dynamic_pointer_cast<opset8::Gather>(m.get_match_root());
if (!gather_v8_node)
return false;
auto gather_v7_node = make_shared<opset7::Gather>(gather_v8_node->input_value(0),
gather_v8_node->input_value(1),
gather_v8_node->input_value(2),
gather_v8_node->get_batch_dims());
gather_v7_node->set_friendly_name(gather_v8_node->get_friendly_name());
ngraph::copy_runtime_info(gather_v8_node, gather_v7_node);
ngraph::replace_node(gather_v8_node, gather_v7_node);
return true;
};
auto m = make_shared<pattern::Matcher>(gather_v8_pattern, matcher_name);
register_matcher(m, callback);
}

View File

@ -0,0 +1,68 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_gather_upgrade.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(pass::ConvertGather1ToGather7, "ConvertGather1ToGather7", 0);
NGRAPH_RTTI_DEFINITION(pass::ConvertGather7ToGather8, "ConvertGather7ToGather8", 0);
pass::ConvertGather1ToGather7::ConvertGather1ToGather7() {
MATCHER_SCOPE(ConvertGather1ToGather7);
auto gather_v1_pattern = pattern::wrap_type<opset1::Gather>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto gather_v1_node = std::dynamic_pointer_cast<opset1::Gather>(m.get_match_root());
if (!gather_v1_node)
return false;
auto gather_v7_node = make_shared<opset7::Gather>(gather_v1_node->input_value(0),
gather_v1_node->input_value(1),
gather_v1_node->input_value(2),
0);
gather_v7_node->set_friendly_name(gather_v1_node->get_friendly_name());
ngraph::copy_runtime_info(gather_v1_node, gather_v7_node);
ngraph::replace_node(gather_v1_node, gather_v7_node);
return true;
};
auto m = make_shared<pattern::Matcher>(gather_v1_pattern, matcher_name);
register_matcher(m, callback);
}
pass::ConvertGather7ToGather8::ConvertGather7ToGather8() {
MATCHER_SCOPE(ConvertGather7ToGather8);
auto gather_v7_pattern = pattern::wrap_type<opset7::Gather>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto gather_v7_node = std::dynamic_pointer_cast<opset7::Gather>(m.get_match_root());
if (!gather_v7_node)
return false;
auto gather_v8_node = make_shared<opset8::Gather>(gather_v7_node->input_value(0),
gather_v7_node->input_value(1),
gather_v7_node->input_value(2),
gather_v7_node->get_batch_dims());
gather_v8_node->set_friendly_name(gather_v7_node->get_friendly_name());
ngraph::copy_runtime_info(gather_v7_node, gather_v8_node);
ngraph::replace_node(gather_v7_node, gather_v8_node);
return true;
};
auto m = make_shared<pattern::Matcher>(gather_v7_pattern, matcher_name);
register_matcher(m, callback);
}

View File

@ -1,37 +0,0 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_gather_v1_to_gather_v7.hpp"
#include <ngraph/rt_info.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGather1ToGather7, "ConvertGather1ToGather7", 0);
ngraph::pass::ConvertGather1ToGather7::ConvertGather1ToGather7() {
MATCHER_SCOPE(ConvertGather1ToGather7);
auto gather_v1 = pattern::wrap_type<ngraph::opset1::Gather>();
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto gather_v1_node = m.get_match_root();
if (!gather_v1_node)
return false;
auto data_input = gather_v1_node->input_value(0);
auto indices_input = gather_v1_node->input_value(1);
auto axis_input = gather_v1_node->input_value(2);
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data_input, indices_input, axis_input, 0);
gather_v7->set_friendly_name(gather_v1_node->get_friendly_name());
ngraph::copy_runtime_info(gather_v1_node, gather_v7);
ngraph::replace_node(gather_v1_node, gather_v7);
return true;
};
auto m = std::make_shared<pattern::Matcher>(gather_v1, matcher_name);
register_matcher(m, callback);
}

View File

@ -1,41 +0,0 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGather7ToGather1, "ConvertGather7ToGather1", 0);
ngraph::pass::ConvertGather7ToGather1::ConvertGather7ToGather1() {
MATCHER_SCOPE(ConvertGather7ToGather1);
auto gather_v7 = pattern::wrap_type<ngraph::opset7::Gather>();
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto gather_v7_node = std::dynamic_pointer_cast<ngraph::opset7::Gather>(m.get_match_root());
if (!gather_v7_node)
return false;
if (gather_v7_node->get_batch_dims() != 0)
return false;
auto data_input = gather_v7_node->input_value(0);
auto indices_input = gather_v7_node->input_value(1);
auto axis_input = gather_v7_node->input_value(2);
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data_input, indices_input, axis_input);
gather_v1->set_friendly_name(gather_v7_node->get_friendly_name());
ngraph::copy_runtime_info(gather_v7_node, gather_v1);
ngraph::replace_node(gather_v7_node, gather_v1);
return true;
};
auto m = std::make_shared<pattern::Matcher>(gather_v7, matcher_name);
register_matcher(m, callback);
}

View File

@ -0,0 +1,109 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/convert_gather_downgrade.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, ConvertGather7toGather1) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data, indices, axis);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertGather7toGather1_nonzero_batch_dims) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, -1);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
// if batch_dims != 0 Gather-7 must remain
ASSERT_EQ(count_ops_of_type<ngraph::opset1::Gather>(f), 0);
ASSERT_EQ(count_ops_of_type<ngraph::opset7::Gather>(f), 1);
}
TEST(TransformationTests, ConvertGather8toGather7) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
int64_t batch_dims = 1;
auto gather_v8 = std::make_shared<ngraph::opset8::Gather>(data, indices, axis, batch_dims);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v8}, ngraph::ParameterVector{data, indices});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertGather8ToGather7>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
int64_t batch_dims = 1;
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, batch_dims);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -10,32 +10,18 @@
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/op/constant.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp>
#include <transformations/op_conversions/convert_gather_upgrade.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, ConvertGather7toGather1) {
TEST(TransformationTests, ConvertGather1toGather7) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
@ -43,32 +29,58 @@ TEST(TransformationTests, ConvertGather7toGather1) {
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data, indices, axis);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertGather1ToGather7>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertGather7toGather1_nonzero_batch_dims) {
TEST(TransformationTests, ConvertGather7toGather8) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, -1);
int64_t batch_dims = 1;
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, batch_dims);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
manager.register_pass<ngraph::pass::ConvertGather7ToGather8>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
// if batch_dims != 0 Gather-7 must remain
ASSERT_EQ(count_ops_of_type<ngraph::opset1::Gather>(f), 0);
ASSERT_EQ(count_ops_of_type<ngraph::opset7::Gather>(f), 1);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
int64_t batch_dims = 1;
auto gather_v8 = std::make_shared<ngraph::opset8::Gather>(data, indices, axis, batch_dims);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v8}, ngraph::ParameterVector{data, indices});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -1,52 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/op/constant.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/convert_gather_v1_to_gather_v7.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, ConvertGather1toGather7) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data, indices, axis);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertGather1ToGather7>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -61,6 +61,9 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*NormalizeL2LayerTest.*axes=\(\).*)",
// Not allowed dynamic loop tests on GPU
R"(.*smoke_StaticShapeLoop_dynamic_exit.*)"
R"(.*smoke_StaticShapeLoop_dynamic_exit.*)",
// TODO: until issue xxx-59670 is resolved
R"(.*Gather8LayerTest.*)"
};
}

View File

@ -18,7 +18,7 @@ class Gather(Op):
super().__init__(graph, {
'op': self.op,
'type': self.op,
'version': 'opset7',
'version': 'opset8',
'batch_dims': 0,
'infer': self.infer,
'force_precision_in_ports': {1: 'int32', 2: 'int64'},
@ -31,7 +31,7 @@ class Gather(Op):
def backend_attrs(self):
version = self.get_opset()
if version == 'opset7':
if version in ['opset7', 'opset8']:
return ['batch_dims']
elif version == 'opset1':
return []
@ -75,7 +75,7 @@ class Gather(Op):
# we import PermuteInputs locally because it uses Gather inside and we have recursive imports
from mo.graph.perm_inputs import PermuteInputs
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'axis')
batch_dims_range = indices_shape[:batch_dims]
out_shape = np.concatenate((data_shape[:axis], indices_shape[batch_dims:], data_shape[axis + 1:]))
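For reference, here is a minimal numpy sketch (illustrative only, not part of this change) of the Gather-8 semantics that `infer` relies on: negative indices count back from the end of the gathered axis, and the output shape follows the `out_shape` formula above. The helper name `gather_reference` is hypothetical.

```python
import numpy as np


def gather_reference(data, indices, axis, batch_dims=0):
    """Numpy reference for Gather-8: output shape is
    data_shape[:axis] + indices_shape[batch_dims:] + data_shape[axis + 1:]."""
    data, indices = np.asarray(data), np.asarray(indices)
    if batch_dims == 0:
        # np.take resolves negative indices the same way Gather-8 does
        return np.take(data, indices, axis=axis)
    # batched case: gather independently inside every leading batch slice
    return np.stack([gather_reference(d, i, axis - 1, batch_dims - 1)
                     for d, i in zip(data, indices)])


print(gather_reference([1, 2, 3, 4, 5], [-1, -2, -3], axis=0))  # [5 4 3]
```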

View File

@ -114,6 +114,12 @@ class TestGatherPartialInfer(unittest.TestCase):
indices=[0, 0, 4],
ref_value=[1, 1, 5])
def test_axis_0_batch_dims_0_negative_indices(self):
self.build_and_test_value_inference(axis=0, batch_dims=0,
data=[1, 2, 3, 4, 5],
indices=[-1, -2, -3],
ref_value=[5, 4, 3])
def test_axis_1_batch_dims_1(self):
self.build_and_test_value_inference(axis=1, batch_dims=1,
data=[[1, 2, 3, 4, 5],
@ -155,6 +161,27 @@ class TestGatherPartialInfer(unittest.TestCase):
[33, 34, 35, 36],
[29, 30, 31, 32]]]])
def test_axis_2_batch_dims_1_with_negative_indices(self):
self.build_and_test_value_inference(axis=2, batch_dims=1,
data=[[[[ 1, 2, 3, 4], # <-- first batch
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20]]],
[[[21, 22, 23, 24], # <-- second batch
[25, 26, 27, 28],
[29, 30, 31, 32],
[33, 34, 35, 36],
[37, 38, 39, 40]]]], # data_shape = (2, 1, 5, 4)
indices=[[-4, -3, -1],
[-1, 3, 2]],
ref_value=[[[[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[17, 18, 19, 20]]],
[[[37, 38, 39, 40],
[33, 34, 35, 36],
[29, 30, 31, 32]]]])
def test_axis_2_batch_dims_mimus_1(self):
self.build_and_test_value_inference(axis=2, batch_dims=-1,
data=[[[[ 1, 2, 3, 4], # <-- first batch

View File

@ -9,13 +9,13 @@
namespace ngraph {
namespace op {
namespace v1 {
/// \brief Gather slices from axis of params according to indices
/// \brief Gather slices from axis of data according to indices
class NGRAPH_API Gather : public op::util::GatherBase {
public:
NGRAPH_RTTI_DECLARATION;
static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
Gather() = default;
/// \param params The tensor from which slices are gathered
/// \param data The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor holding the dimension index to gather data from
Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axis);
@ -28,7 +28,7 @@ public:
} // namespace v1
namespace v7 {
/// \brief Gather slices from axis of params according to indices
/// \brief Gather slices from axis of data according to indices
class NGRAPH_API Gather : public op::util::GatherBase {
public:
NGRAPH_RTTI_DECLARATION;
@ -53,7 +53,8 @@ public:
} // namespace v7
namespace v8 {
/// \brief Gather slices from axis of params according to indices
/// \brief Gather slices from axis of data according to indices. Negative indices
/// are supported and indicate reverse indexing from the end
class NGRAPH_API Gather : public op::util::GatherBase {
public:
NGRAPH_RTTI_DECLARATION;

View File

@ -53,7 +53,7 @@ from ngraph.opset1.ops import exp
from ngraph.opset1.ops import fake_quantize
from ngraph.opset1.ops import floor
from ngraph.opset1.ops import floor_mod
from ngraph.opset7.ops import gather
from ngraph.opset8.ops import gather
from ngraph.opset6.ops import gather_elements
from ngraph.opset5.ops import gather_nd
from ngraph.opset1.ops import gather_tree

View File

@ -253,3 +253,26 @@ def matrix_nms(
}
return _get_node_factory_opset8().create("MatrixNms", inputs, attributes)
@nameable_op
def gather(
data: NodeInput,
indices: NodeInput,
axis: NodeInput,
batch_dims: Optional[int] = 0,
) -> Node:
"""Return a node which performs Gather with support of negative indices.
@param data: N-D tensor with data for gathering
@param indices: N-D tensor with indices by which data is gathered. Negative indices
indicate reverse indexing from the end
@param axis: axis along which elements are gathered
@param batch_dims: number of batch dimensions
@return: The new node which performs Gather
"""
inputs = as_nodes(data, indices, axis)
attributes = {
"batch_dims": batch_dims
}
return _get_node_factory_opset8().create("Gather", inputs, attributes)
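Below is a hedged usage sketch of the new opset8 `gather` binding; it assumes an ngraph Python build where `ng.gather` resolves to this opset8 version, and the example values are illustrative only.

```python
import numpy as np
import ngraph as ng

# Raw numpy arrays are wrapped into Constant nodes by as_nodes inside gather.
data = np.array([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]], dtype=np.float32)
indices = np.array([0, -1], dtype=np.int32)  # -1 selects the last column
axis = np.array(1, dtype=np.int32)

node = ng.gather(data, indices, axis, batch_dims=0)
# Evaluating the node (e.g. via run_op_node in the tests further below) is
# expected to produce [[1., 3.], [4., 6.]]
```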

View File

@ -145,3 +145,5 @@ xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadca
xfail_issue_58033 = xfail_test(reason="Einsum operation misses support for complex ellipsis equations")
xfail_issue_58676 = xfail_test(reason="AssertionError: Not equal to tolerance rtol=0.001, atol=1e-07")
xfail_issue_onnx_models_140 = xfail_test(reason="https://github.com/onnx/models/issues/140")
xfail_issue_54630 = xfail_test(reason="Gather with negative indices is not yet implemented on CPU")

View File

@ -4,6 +4,7 @@
import ngraph as ng
import numpy as np
from tests import xfail_issue_54630
from tests.test_ngraph.util import run_op_node
@ -52,3 +53,37 @@ def test_gather_batch_dims_1():
result = run_op_node([input_data], ng.gather, input_indices, input_axis, batch_dims)
assert np.allclose(result, expected)
@xfail_issue_54630
def test_gather_negative_indices():
input_data = np.array(
[1.0, 1.1, 1.2, 2.0, 2.1, 2.2, 3.0, 3.1, 3.2], np.float32
).reshape((3, 3))
input_indices = np.array([0, -1], np.int32).reshape(1, 2)
input_axis = np.array([1], np.int32)
expected = np.array([1.0, 1.2, 2.0, 2.2, 3.0, 3.2], dtype=np.float32).reshape(
(3, 1, 2)
)
result = run_op_node([input_data], ng.gather, input_indices, input_axis)
assert np.allclose(result, expected)
@xfail_issue_54630
def test_gather_batch_dims_1_negative_indices():
input_data = np.array([[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10]], np.float32)
input_indices = np.array([[0, 1, -2],
[-2, 0, 0]], np.int32)
input_axis = np.array([1], np.int32)
batch_dims = 1
expected = np.array([[1, 2, 4],
[9, 6, 6]], np.float32)
result = run_op_node([input_data], ng.gather, input_indices, input_axis, batch_dims)
assert np.allclose(result, expected)

View File

@ -1,7 +1,7 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, List
from typing import Any, Callable, List, Union
import numpy as np
@ -16,7 +16,7 @@ def _get_numpy_dtype(scalar):
def run_op_node(input_data, op_fun, *args):
# type: (NumericData, Callable, *Any) -> List[NumericData]
# type: (Union[NumericData, List[NumericData]], Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`.
`op_fun` has to accept a node as an argument.

View File

@ -31,7 +31,7 @@ def generate_ir(coverage=False, **kwargs):
continue
elif (isinstance(value, tuple) and value) or (isinstance(value, str)):
params.extend(("--{}".format(key), str('"{}"'.format(value))))
elif (key == "mean_values" and (' ' in value or '(' in value)):
elif key == "mean_values" and (' ' in value or '(' in value):
params.extend(("--{}".format(key), str('"{}"'.format(value))))
else:
params.extend(("--{}".format(key), str(value)))

View File

@ -215,31 +215,23 @@ class TestGather(OnnxRuntimeLayerTest):
test_data_precommit = [
dict(shape=[6, 8, 10, 12], axis=2, indices=[[0, 2, 4], [5, 7, 9]], output_shape=[6, 8, 2, 3, 12]),
dict(shape=[4, 6, 8, 10, 12], axis=1, indices=[2, 5], output_shape=[4, 2, 8, 10, 12]),
dict(shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], output_shape=[4, 6, 8, 10, 2])]
dict(shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], output_shape=[4, 6, 8, 10, 2]),
dict(shape=[6, 8, 10, 12], axis=-1, indices=[[[2, -1], [3, 2]], [[5, -1], [3, -2]]], output_shape=[6, 8, 10, 2, 2, 2])
]
test_data = [dict(shape=[10, 12], axis=0, indices=[3, 6], output_shape=[2, 12]),
dict(shape=[10, 12], axis=1, indices=[4, 7], output_shape=[10, 2]),
dict(shape=[10, 12], axis=-1, indices=[4, 7], output_shape=[10, 2]),
dict(shape=[10, 12], axis=None, indices=[[0, 1, 3, 4], [5, 6, 8, 9]], output_shape=[2, 4, 12]),
dict(shape=[10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]], output_shape=[10, 2, 5]),
dict(shape=[8, 10, 12], axis=0, indices=[3, 6], output_shape=[2, 10, 12]),
dict(shape=[8, 10, 12], axis=1, indices=[4, 7], output_shape=[8, 2, 12]),
dict(shape=[8, 10, 12], axis=2, indices=[5, 8], output_shape=[8, 10, 2]),
dict(shape=[8, 10, 12], axis=-1, indices=[5, 8], output_shape=[8, 10, 2]),
dict(shape=[8, 10, 12], axis=None, indices=[[0, 1], [3, 4], [6, 7]], output_shape=[3, 2, 10, 12]),
dict(shape=[8, 10, 12], axis=1, indices=[[0, 2, 4], [5, 7, 9]], output_shape=[8, 2, 3, 12]),
dict(shape=[6, 8, 10, 12], axis=0, indices=[2, 5], output_shape=[2, 8, 10, 12]),
dict(shape=[6, 8, 10, 12], axis=1, indices=[3, 6], output_shape=[6, 2, 10, 12]),
dict(shape=[6, 8, 10, 12], axis=2, indices=[4, 7], output_shape=[6, 8, 2, 12]),
dict(shape=[6, 8, 10, 12], axis=3, indices=[5, 8], output_shape=[6, 8, 10, 2]),
dict(shape=[6, 8, 10, 12], axis=-1, indices=[5, 8], output_shape=[6, 8, 10, 2]),
dict(shape=[6, 8, 10, 12], axis=None, indices=[[0, 1, 2], [3, 4, 5]], output_shape=[2, 3, 8, 10, 12]),
dict(shape=[6, 8, 10, 12], axis=2, indices=[[0, 2, 4], [5, 7, 9]], output_shape=[6, 8, 2, 3, 12]),
dict(shape=[4, 6, 8, 10, 12], axis=0, indices=[1, 3], output_shape=[2, 6, 8, 10, 12]),
dict(shape=[4, 6, 8, 10, 12], axis=1, indices=[2, 5], output_shape=[4, 2, 8, 10, 12]),
dict(shape=[4, 6, 8, 10, 12], axis=2, indices=[3, 6], output_shape=[4, 6, 2, 10, 12]),
dict(shape=[4, 6, 8, 10, 12], axis=3, indices=[4, 7], output_shape=[4, 6, 8, 2, 12]),
dict(shape=[4, 6, 8, 10, 12], axis=4, indices=[5, 8], output_shape=[4, 6, 8, 10, 2]),
dict(shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], output_shape=[4, 6, 8, 10, 2])]
@pytest.mark.parametrize("params", test_data_precommit)
@ -259,3 +251,18 @@ class TestGather(OnnxRuntimeLayerTest):
def test_gather_const(self, params, ie_device, precision, ir_version, temp_dir):
self._test(*self.create_net_const(**params, ir_version=ir_version), ie_device, precision, ir_version,
temp_dir=temp_dir)
test_data_negative_indices = [dict(shape=[10, 12], axis=0, indices=[3, -1, -4], output_shape=[3, 12]),
dict(shape=[6, 10, 14, 12], axis=1, indices=[[0, -1, 3, -4], [-5, 6, -7, 8]],
output_shape=[6, 2, 4, 14, 12]),
dict(shape=[8, 10, 14, 12], axis=1, indices=[[-2, 2, -4], [5, -7, 9]],
output_shape=[8, 2, 3, 14, 12]),
dict(shape=[6, 8, 10, 12], axis=-1, indices=[[[2, -1], [3, 2]], [[5, -1], [3, -2]]],
output_shape=[6, 8, 10, 2, 2, 2])]
@pytest.mark.xfail(reason='negative indices are not yet implemented on CPU: xxx-54630')
@pytest.mark.parametrize("params", test_data_negative_indices)
@pytest.mark.nightly
def test_gather_nightly_negative_indices(self, params, ie_device, precision, ir_version, temp_dir):
self._test(*self.create_net(**params, ir_version=ir_version),
ie_device, precision, ir_version, temp_dir=temp_dir)

View File

@ -0,0 +1,61 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from common.tf_layer_test_class import CommonTFLayerTest
class TestGather(CommonTFLayerTest):
def create_indices_constant(self):
pass
def create_gather_net(self, data_shape, indices, axis, batch_dims, **kwargs):
import tensorflow as tf
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
data = tf.compat.v1.placeholder(tf.float32, data_shape, 'data')
indices = tf.constant(indices, dtype=tf.int32)
gather = tf.gather(data, indices, axis=axis, batch_dims=batch_dims, name='gather_output')
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
ref_net = None
return tf_net, ref_net
test_data_precommit = [
dict(data_shape=[6, 8, 10, 12], indices=[[0, 2, 4], [5, 7, 9]], axis=2, batch_dims=0),
dict(data_shape=[4, 6, 8, 10, 12], indices=[2, 5], axis=1, batch_dims=0),
dict(data_shape=[4, 6, 8, 10, 12], indices=[2, 5], axis=-1, batch_dims=0)
]
@pytest.mark.parametrize("params", test_data_precommit)
@pytest.mark.precommit
def test_gather(self, params, ie_device, precision, ir_version, temp_dir):
self._test(*self.create_gather_net(**params, ir_version=ir_version),
ie_device, precision, ir_version, temp_dir=temp_dir)
test_data_nightly = [
dict(data_shape=[2, 3], axis=1, indices=[0, 2], batch_dims=0),
dict(data_shape=[10, 12], axis=0, indices=[3, 6], batch_dims=0),
dict(data_shape=[10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]], batch_dims=0),
dict(data_shape=[8, 10, 12], axis=0, indices=[3, 6], batch_dims=0),
dict(data_shape=[8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
dict(data_shape=[6, 8, 10, 12], axis=0, indices=[2, 5], batch_dims=0),
dict(data_shape=[6, 8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
dict(data_shape=[6, 8, 10, 12], axis=2, indices=[[0, 2, 4], [5, 7, 9]], batch_dims=0),
dict(data_shape=[2, 14, 10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]], batch_dims=1),
dict(data_shape=[4, 6, 8, 10, 12], axis=0, indices=[1, 3], batch_dims=0),
dict(data_shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
]
@pytest.mark.parametrize("params", test_data_nightly)
@pytest.mark.nightly
def test_gather_nightly(self, params, ie_device, precision, ir_version, temp_dir):
self._test(*self.create_gather_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir)