[MO][nG] add Gather-8 to MO, layer tests, nG python api unit-tests and up/down-grading transformations (#6560)
* gather-8 upgrade/downgrade transforms * bump to opset8 * add Gather-8 to MO * fix permutation for Gather * added TF layer tests, ONNX layer tests for negative indices, and nG python api unit-tests for negative indices * typo fix, disable downgrading transformation * disable downgrade in clDNN, line width style fix * all Gathers are converted to 7th version, transformations will be enabled/disabled while op will be added into plugins * disabled Gather8LayerTest on GPU * added common function for Op replacement * concretized meaning of negative indices, fixed some typos * applied review comments: left only meaningful layer tests * removed op replacing functions from common utils * returned back transformations without subroutines * corrected style, added comments to common_optimizations.cpp
This commit is contained in:
parent
f324ca7fcd
commit
4738bbd757
@ -60,6 +60,7 @@
|
||||
#include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
|
||||
#include <transformations/op_conversions/convert_nms_to_nms_ie_internal.hpp>
|
||||
#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_downgrade.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_0d.hpp>
|
||||
#include <transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp>
|
||||
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
|
||||
@ -362,6 +363,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
|
||||
pass_config->disable<ngraph::pass::ConvertBroadcast3>();
|
||||
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
|
||||
pass_config->disable<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
|
||||
pass_config->enable<ngraph::pass::ConvertGather8ToGather7>();
|
||||
|
||||
if (!config.enable_loop_unrolling) {
|
||||
pass_config->disable<ngraph::pass::ConvertTensorIteratorToRNNSequence>();
|
||||
|
@ -31,8 +31,8 @@
|
||||
#include <transformations/op_conversions/convert_shuffle_channels3.hpp>
|
||||
#include <transformations/op_conversions/convert_space_to_depth.hpp>
|
||||
#include <transformations/op_conversions/convert_gelu.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_v1_to_gather_v7.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_downgrade.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_upgrade.hpp>
|
||||
#include <transformations/op_conversions/gelu7_downgrade.hpp>
|
||||
#include <transformations/op_conversions/hswish_decomposition.hpp>
|
||||
#include <transformations/op_conversions/hsigmoid_decomposition.hpp>
|
||||
@ -311,6 +311,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
|
||||
|
||||
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();
|
||||
pass_config->enable<ngraph::pass::ConvertGather1ToGather7>();
|
||||
pass_config->enable<ngraph::pass::ConvertGather8ToGather7>();
|
||||
|
||||
if (useLpt) {
|
||||
pass_config->set_callback<ngraph::pass::ConvertQuantizeDequantize>([](const_node_ptr &node) -> bool {
|
||||
|
@ -12,6 +12,7 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertGather7ToGather1;
|
||||
class TRANSFORMATIONS_API ConvertGather8ToGather7;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@ -25,3 +26,13 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertGather7ToGather1();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief ConvertGather8ToGather7 converts v8::Gather into v7::Gather.
|
||||
*/
|
||||
class ngraph::pass::ConvertGather8ToGather7 : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertGather8ToGather7();
|
||||
};
|
@ -13,6 +13,8 @@ namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertGather1ToGather7;
|
||||
|
||||
class TRANSFORMATIONS_API ConvertGather7ToGather8;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
@ -25,3 +27,13 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertGather1ToGather7();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief ConvertGather7ToGather8 converts v7::Gather into v8::Gather.
|
||||
*/
|
||||
class ngraph::pass::ConvertGather7ToGather8 : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertGather7ToGather8();
|
||||
};
|
@ -16,6 +16,10 @@
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
@ -130,7 +134,6 @@ Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & inpu
|
||||
}
|
||||
|
||||
TRANSFORMATIONS_API std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& node);
|
||||
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -47,8 +47,8 @@
|
||||
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
|
||||
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
|
||||
#include "transformations/op_conversions/convert_divide.hpp"
|
||||
#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp"
|
||||
#include "transformations/op_conversions/convert_gather_v1_to_gather_v7.hpp"
|
||||
#include "transformations/op_conversions/convert_gather_downgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_gather_upgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_mod.hpp"
|
||||
#include "transformations/op_conversions/convert_minimum_to_power_and_max.hpp"
|
||||
#include "transformations/op_conversions/convert_negative.hpp"
|
||||
@ -179,8 +179,10 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
conv_fusions->set_name("ngraph::pass::ConvFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather8ToGather7>(); // not plugins implemented gather8
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>(); // not plugins implemented gather7
|
||||
manager.register_pass<ngraph::pass::ConvertGather1ToGather7, false>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather8, false>();
|
||||
manager.register_pass<ngraph::pass::ConvertDeformableConv8To1>();
|
||||
|
||||
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
|
@ -0,0 +1,69 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/convert_gather_downgrade.hpp"
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include "itt.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(pass::ConvertGather7ToGather1, "ConvertGather7ToGather1", 0);
|
||||
NGRAPH_RTTI_DEFINITION(pass::ConvertGather8ToGather7, "ConvertGather8ToGather7", 0);
|
||||
|
||||
|
||||
pass::ConvertGather7ToGather1::ConvertGather7ToGather1() {
|
||||
MATCHER_SCOPE(ConvertGather7ToGather1);
|
||||
|
||||
auto gather_v7_pattern = pattern::wrap_type<opset7::Gather>();
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto gather_v7_node = std::dynamic_pointer_cast<opset7::Gather>(m.get_match_root());
|
||||
if (!gather_v7_node)
|
||||
return false;
|
||||
if (gather_v7_node->get_batch_dims() != 0)
|
||||
return false;
|
||||
|
||||
auto gather_v1_node = make_shared<opset1::Gather>(gather_v7_node->input_value(0),
|
||||
gather_v7_node->input_value(1),
|
||||
gather_v7_node->input_value(2));
|
||||
|
||||
gather_v1_node->set_friendly_name(gather_v7_node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(gather_v7_node, gather_v1_node);
|
||||
ngraph::replace_node(gather_v7_node, gather_v1_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(gather_v7_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
pass::ConvertGather8ToGather7::ConvertGather8ToGather7() {
|
||||
MATCHER_SCOPE(ConvertGather8ToGather7);
|
||||
|
||||
auto gather_v8_pattern = pattern::wrap_type<opset8::Gather>();
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto gather_v8_node = std::dynamic_pointer_cast<opset8::Gather>(m.get_match_root());
|
||||
if (!gather_v8_node)
|
||||
return false;
|
||||
|
||||
auto gather_v7_node = make_shared<opset7::Gather>(gather_v8_node->input_value(0),
|
||||
gather_v8_node->input_value(1),
|
||||
gather_v8_node->input_value(2),
|
||||
gather_v8_node->get_batch_dims());
|
||||
|
||||
gather_v7_node->set_friendly_name(gather_v8_node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(gather_v8_node, gather_v7_node);
|
||||
ngraph::replace_node(gather_v8_node, gather_v7_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(gather_v8_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/convert_gather_upgrade.hpp"
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include "itt.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(pass::ConvertGather1ToGather7, "ConvertGather1ToGather7", 0);
|
||||
NGRAPH_RTTI_DEFINITION(pass::ConvertGather7ToGather8, "ConvertGather7ToGather8", 0);
|
||||
|
||||
|
||||
pass::ConvertGather1ToGather7::ConvertGather1ToGather7() {
|
||||
MATCHER_SCOPE(ConvertGather1ToGather7);
|
||||
|
||||
auto gather_v1_pattern = pattern::wrap_type<opset1::Gather>();
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto gather_v1_node = std::dynamic_pointer_cast<opset1::Gather>(m.get_match_root());
|
||||
if (!gather_v1_node)
|
||||
return false;
|
||||
|
||||
auto gather_v7_node = make_shared<opset7::Gather>(gather_v1_node->input_value(0),
|
||||
gather_v1_node->input_value(1),
|
||||
gather_v1_node->input_value(2),
|
||||
0);
|
||||
|
||||
gather_v7_node->set_friendly_name(gather_v1_node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(gather_v1_node, gather_v7_node);
|
||||
ngraph::replace_node(gather_v1_node, gather_v7_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(gather_v1_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
pass::ConvertGather7ToGather8::ConvertGather7ToGather8() {
|
||||
MATCHER_SCOPE(ConvertGather7ToGather8);
|
||||
|
||||
auto gather_v7_pattern = pattern::wrap_type<opset7::Gather>();
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto gather_v7_node = std::dynamic_pointer_cast<opset7::Gather>(m.get_match_root());
|
||||
if (!gather_v7_node)
|
||||
return false;
|
||||
|
||||
auto gather_v8_node = make_shared<opset8::Gather>(gather_v7_node->input_value(0),
|
||||
gather_v7_node->input_value(1),
|
||||
gather_v7_node->input_value(2),
|
||||
gather_v7_node->get_batch_dims());
|
||||
|
||||
gather_v8_node->set_friendly_name(gather_v7_node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(gather_v7_node, gather_v8_node);
|
||||
ngraph::replace_node(gather_v7_node, gather_v8_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(gather_v7_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -1,37 +0,0 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/convert_gather_v1_to_gather_v7.hpp"
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include "itt.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGather1ToGather7, "ConvertGather1ToGather7", 0);
|
||||
|
||||
ngraph::pass::ConvertGather1ToGather7::ConvertGather1ToGather7() {
|
||||
MATCHER_SCOPE(ConvertGather1ToGather7);
|
||||
|
||||
auto gather_v1 = pattern::wrap_type<ngraph::opset1::Gather>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto gather_v1_node = m.get_match_root();
|
||||
if (!gather_v1_node)
|
||||
return false;
|
||||
|
||||
auto data_input = gather_v1_node->input_value(0);
|
||||
auto indices_input = gather_v1_node->input_value(1);
|
||||
auto axis_input = gather_v1_node->input_value(2);
|
||||
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data_input, indices_input, axis_input, 0);
|
||||
gather_v7->set_friendly_name(gather_v1_node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(gather_v1_node, gather_v7);
|
||||
ngraph::replace_node(gather_v1_node, gather_v7);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(gather_v1, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp"
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGather7ToGather1, "ConvertGather7ToGather1", 0);
|
||||
|
||||
ngraph::pass::ConvertGather7ToGather1::ConvertGather7ToGather1() {
|
||||
MATCHER_SCOPE(ConvertGather7ToGather1);
|
||||
|
||||
auto gather_v7 = pattern::wrap_type<ngraph::opset7::Gather>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto gather_v7_node = std::dynamic_pointer_cast<ngraph::opset7::Gather>(m.get_match_root());
|
||||
if (!gather_v7_node)
|
||||
return false;
|
||||
|
||||
if (gather_v7_node->get_batch_dims() != 0)
|
||||
return false;
|
||||
|
||||
auto data_input = gather_v7_node->input_value(0);
|
||||
auto indices_input = gather_v7_node->input_value(1);
|
||||
auto axis_input = gather_v7_node->input_value(2);
|
||||
|
||||
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data_input, indices_input, axis_input);
|
||||
gather_v1->set_friendly_name(gather_v7_node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(gather_v7_node, gather_v1);
|
||||
ngraph::replace_node(gather_v7_node, gather_v1);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(gather_v7, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,109 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_downgrade.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, ConvertGather7toGather1) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
|
||||
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
|
||||
|
||||
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data, indices, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertGather7toGather1_nonzero_batch_dims) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
|
||||
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, -1);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
// if batch_dims != 0 Gather-7 must remain
|
||||
ASSERT_EQ(count_ops_of_type<ngraph::opset1::Gather>(f), 0);
|
||||
ASSERT_EQ(count_ops_of_type<ngraph::opset7::Gather>(f), 1);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertGather8toGather7) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
|
||||
int64_t batch_dims = 1;
|
||||
|
||||
auto gather_v8 = std::make_shared<ngraph::opset8::Gather>(data, indices, axis, batch_dims);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v8}, ngraph::ParameterVector{data, indices});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather8ToGather7>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
|
||||
int64_t batch_dims = 1;
|
||||
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, batch_dims);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -10,32 +10,18 @@
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/op/constant.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_upgrade.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, ConvertGather7toGather1) {
|
||||
TEST(TransformationTests, ConvertGather1toGather7) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
|
||||
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
@ -43,32 +29,58 @@ TEST(TransformationTests, ConvertGather7toGather1) {
|
||||
|
||||
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data, indices, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather1ToGather7>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
|
||||
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertGather7toGather1_nonzero_batch_dims) {
|
||||
TEST(TransformationTests, ConvertGather7toGather8) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
|
||||
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, -1);
|
||||
int64_t batch_dims = 1;
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, batch_dims);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather8>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
// if batch_dims != 0 Gather-7 must remain
|
||||
ASSERT_EQ(count_ops_of_type<ngraph::opset1::Gather>(f), 0);
|
||||
ASSERT_EQ(count_ops_of_type<ngraph::opset7::Gather>(f), 1);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
|
||||
int64_t batch_dims = 1;
|
||||
|
||||
auto gather_v8 = std::make_shared<ngraph::opset8::Gather>(data, indices, axis, batch_dims);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v8}, ngraph::ParameterVector{data, indices});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -1,52 +0,0 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/op/constant.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_v1_to_gather_v7.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, ConvertGather1toGather7) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
|
||||
|
||||
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data, indices, axis);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather1ToGather7>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
|
||||
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -61,6 +61,9 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(.*NormalizeL2LayerTest.*axes=\(\).*)",
|
||||
|
||||
// Not allowed dynamic loop tests on GPU
|
||||
R"(.*smoke_StaticShapeLoop_dynamic_exit.*)"
|
||||
R"(.*smoke_StaticShapeLoop_dynamic_exit.*)",
|
||||
|
||||
// TODO: until issue is xxx-59670 is resolved
|
||||
R"(.*Gather8LayerTest.*)"
|
||||
};
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ class Gather(Op):
|
||||
super().__init__(graph, {
|
||||
'op': self.op,
|
||||
'type': self.op,
|
||||
'version': 'opset7',
|
||||
'version': 'opset8',
|
||||
'batch_dims': 0,
|
||||
'infer': self.infer,
|
||||
'force_precision_in_ports': {1: 'int32', 2: 'int64'},
|
||||
@ -31,7 +31,7 @@ class Gather(Op):
|
||||
|
||||
def backend_attrs(self):
|
||||
version = self.get_opset()
|
||||
if version == 'opset7':
|
||||
if version in ['opset7', 'opset8']:
|
||||
return ['batch_dims']
|
||||
elif version == 'opset1':
|
||||
return []
|
||||
@ -75,7 +75,7 @@ class Gather(Op):
|
||||
|
||||
# we import PermuteInputs locally because it uses Gather inside and we have recursive imports
|
||||
from mo.graph.perm_inputs import PermuteInputs
|
||||
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
|
||||
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'axis')
|
||||
|
||||
batch_dims_range = indices_shape[:batch_dims]
|
||||
out_shape = np.concatenate((data_shape[:axis], indices_shape[batch_dims:], data_shape[axis + 1:]))
|
||||
|
@ -114,6 +114,12 @@ class TestGatherPartialInfer(unittest.TestCase):
|
||||
indices=[0, 0, 4],
|
||||
ref_value=[1, 1, 5])
|
||||
|
||||
def test_axis_0_batch_dims_0_negative_indices(self):
|
||||
self.build_and_test_value_inference(axis=0, batch_dims=0,
|
||||
data=[1, 2, 3, 4, 5],
|
||||
indices=[-1, -2, -3],
|
||||
ref_value=[5, 4, 3])
|
||||
|
||||
def test_axis_1_batch_dims_1(self):
|
||||
self.build_and_test_value_inference(axis=1, batch_dims=1,
|
||||
data=[[1, 2, 3, 4, 5],
|
||||
@ -155,6 +161,27 @@ class TestGatherPartialInfer(unittest.TestCase):
|
||||
[33, 34, 35, 36],
|
||||
[29, 30, 31, 32]]]])
|
||||
|
||||
def test_axis_2_batch_dims_1_with_negative_indices(self):
|
||||
self.build_and_test_value_inference(axis=2, batch_dims=1,
|
||||
data=[[[[ 1, 2, 3, 4], # <-- first batch
|
||||
[ 5, 6, 7, 8],
|
||||
[ 9, 10, 11, 12],
|
||||
[13, 14, 15, 16],
|
||||
[17, 18, 19, 20]]],
|
||||
[[[21, 22, 23, 24], # < -- second batch
|
||||
[25, 26, 27, 28],
|
||||
[29, 30, 31, 32],
|
||||
[33, 34, 35, 36],
|
||||
[37, 38, 39, 40]]]], # data_shape = (2, 1, 5, 4)
|
||||
indices=[[-4, -3, -1],
|
||||
[-1, 3, 2]],
|
||||
ref_value=[[[[ 5, 6, 7, 8],
|
||||
[ 9, 10, 11, 12],
|
||||
[17, 18, 19, 20]]],
|
||||
[[[37, 38, 39, 40],
|
||||
[33, 34, 35, 36],
|
||||
[29, 30, 31, 32]]]])
|
||||
|
||||
def test_axis_2_batch_dims_mimus_1(self):
|
||||
self.build_and_test_value_inference(axis=2, batch_dims=-1,
|
||||
data=[[[[ 1, 2, 3, 4], # <-- first batch
|
||||
|
@ -9,13 +9,13 @@
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace v1 {
|
||||
/// \brief Gather slices from axis of params according to indices
|
||||
/// \brief Gather slices from axis of data according to indices
|
||||
class NGRAPH_API Gather : public op::util::GatherBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
|
||||
Gather() = default;
|
||||
/// \param params The tensor from which slices are gathered
|
||||
/// \param data The tensor from which slices are gathered
|
||||
/// \param indices Tensor with indexes to gather
|
||||
/// \param axis The tensor is a dimension index to gather data from
|
||||
Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axis);
|
||||
@ -28,7 +28,7 @@ public:
|
||||
} // namespace v1
|
||||
|
||||
namespace v7 {
|
||||
/// \brief Gather slices from axis of params according to indices
|
||||
/// \brief Gather slices from axis of data according to indices
|
||||
class NGRAPH_API Gather : public op::util::GatherBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -53,7 +53,8 @@ public:
|
||||
} // namespace v7
|
||||
|
||||
namespace v8 {
|
||||
/// \brief Gather slices from axis of params according to indices
|
||||
/// \brief Gather slices from axis of data according to indices. Negative indices
|
||||
/// are supported and indicate reverse indexing from the end
|
||||
class NGRAPH_API Gather : public op::util::GatherBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
@ -53,7 +53,7 @@ from ngraph.opset1.ops import exp
|
||||
from ngraph.opset1.ops import fake_quantize
|
||||
from ngraph.opset1.ops import floor
|
||||
from ngraph.opset1.ops import floor_mod
|
||||
from ngraph.opset7.ops import gather
|
||||
from ngraph.opset8.ops import gather
|
||||
from ngraph.opset6.ops import gather_elements
|
||||
from ngraph.opset5.ops import gather_nd
|
||||
from ngraph.opset1.ops import gather_tree
|
||||
|
@ -253,3 +253,26 @@ def matrix_nms(
|
||||
}
|
||||
|
||||
return _get_node_factory_opset8().create("MatrixNms", inputs, attributes)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def gather(
|
||||
data: NodeInput,
|
||||
indices: NodeInput,
|
||||
axis: NodeInput,
|
||||
batch_dims: Optional[int] = 0,
|
||||
) -> Node:
|
||||
"""Return a node which performs Gather with support of negative indices.
|
||||
|
||||
@param data: N-D tensor with data for gathering
|
||||
@param indices: N-D tensor with indices by which data is gathered. Negative indices
|
||||
indicate reverse indexing from the end
|
||||
@param axis: axis along which elements are gathered
|
||||
@param batch_dims: number of batch dimensions
|
||||
@return: The new node which performs Gather
|
||||
"""
|
||||
inputs = as_nodes(data, indices, axis)
|
||||
attributes = {
|
||||
"batch_dims": batch_dims
|
||||
}
|
||||
return _get_node_factory_opset8().create("Gather", inputs, attributes)
|
||||
|
@ -145,3 +145,5 @@ xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadca
|
||||
xfail_issue_58033 = xfail_test(reason="Einsum operation misses support for complex ellipsis equations")
|
||||
xfail_issue_58676 = xfail_test(reason="AssertionError: Not equal to tolerance rtol=0.001, atol=1e-07")
|
||||
xfail_issue_onnx_models_140 = xfail_test(reason="https://github.com/onnx/models/issues/140")
|
||||
|
||||
xfail_issue_54630 = xfail_test(reason="Gather with negative indices is not yet implemented on CPU")
|
||||
|
@ -4,6 +4,7 @@
|
||||
import ngraph as ng
|
||||
import numpy as np
|
||||
|
||||
from tests import xfail_issue_54630
|
||||
from tests.test_ngraph.util import run_op_node
|
||||
|
||||
|
||||
@ -52,3 +53,37 @@ def test_gather_batch_dims_1():
|
||||
|
||||
result = run_op_node([input_data], ng.gather, input_indices, input_axis, batch_dims)
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
|
||||
@xfail_issue_54630
|
||||
def test_gather_negative_indices():
|
||||
input_data = np.array(
|
||||
[1.0, 1.1, 1.2, 2.0, 2.1, 2.2, 3.0, 3.1, 3.2], np.float32
|
||||
).reshape((3, 3))
|
||||
input_indices = np.array([0, -1], np.int32).reshape(1, 2)
|
||||
input_axis = np.array([1], np.int32)
|
||||
|
||||
expected = np.array([1.0, 1.2, 2.0, 2.2, 3.0, 3.2], dtype=np.float32).reshape(
|
||||
(3, 1, 2)
|
||||
)
|
||||
|
||||
result = run_op_node([input_data], ng.gather, input_indices, input_axis)
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
|
||||
@xfail_issue_54630
|
||||
def test_gather_batch_dims_1_negative_indices():
|
||||
|
||||
input_data = np.array([[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10]], np.float32)
|
||||
|
||||
input_indices = np.array([[0, 1, -2],
|
||||
[-2, 0, 0]], np.int32)
|
||||
input_axis = np.array([1], np.int32)
|
||||
batch_dims = 1
|
||||
|
||||
expected = np.array([[1, 2, 4],
|
||||
[9, 6, 6]], np.float32)
|
||||
|
||||
result = run_op_node([input_data], ng.gather, input_indices, input_axis, batch_dims)
|
||||
assert np.allclose(result, expected)
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -16,7 +16,7 @@ def _get_numpy_dtype(scalar):
|
||||
|
||||
|
||||
def run_op_node(input_data, op_fun, *args):
|
||||
# type: (NumericData, Callable, *Any) -> List[NumericData]
|
||||
# type: (Union[NumericData, List[NumericData]], Callable, *Any) -> List[NumericData]
|
||||
"""Run computation on node performing `op_fun`.
|
||||
|
||||
`op_fun` has to accept a node as an argument.
|
||||
|
@ -31,7 +31,7 @@ def generate_ir(coverage=False, **kwargs):
|
||||
continue
|
||||
elif (isinstance(value, tuple) and value) or (isinstance(value, str)):
|
||||
params.extend(("--{}".format(key), str('"{}"'.format(value))))
|
||||
elif (key == "mean_values" and (' ' in value or '(' in value)):
|
||||
elif key == "mean_values" and (' ' in value or '(' in value):
|
||||
params.extend(("--{}".format(key), str('"{}"'.format(value))))
|
||||
else:
|
||||
params.extend(("--{}".format(key), str(value)))
|
||||
|
@ -215,31 +215,23 @@ class TestGather(OnnxRuntimeLayerTest):
|
||||
test_data_precommit = [
|
||||
dict(shape=[6, 8, 10, 12], axis=2, indices=[[0, 2, 4], [5, 7, 9]], output_shape=[6, 8, 2, 3, 12]),
|
||||
dict(shape=[4, 6, 8, 10, 12], axis=1, indices=[2, 5], output_shape=[4, 2, 8, 10, 12]),
|
||||
dict(shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], output_shape=[4, 6, 8, 10, 2])]
|
||||
dict(shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], output_shape=[4, 6, 8, 10, 2]),
|
||||
dict(shape=[6, 8, 10, 12], axis=-1, indices=[[[2, -1], [3, 2]], [[5, -1], [3, -2]]], output_shape=[6, 8, 10, 2, 2, 2])
|
||||
]
|
||||
|
||||
test_data = [dict(shape=[10, 12], axis=0, indices=[3, 6], output_shape=[2, 12]),
|
||||
dict(shape=[10, 12], axis=1, indices=[4, 7], output_shape=[10, 2]),
|
||||
dict(shape=[10, 12], axis=-1, indices=[4, 7], output_shape=[10, 2]),
|
||||
dict(shape=[10, 12], axis=None, indices=[[0, 1, 3, 4], [5, 6, 8, 9]], output_shape=[2, 4, 12]),
|
||||
dict(shape=[10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]], output_shape=[10, 2, 5]),
|
||||
dict(shape=[8, 10, 12], axis=0, indices=[3, 6], output_shape=[2, 10, 12]),
|
||||
dict(shape=[8, 10, 12], axis=1, indices=[4, 7], output_shape=[8, 2, 12]),
|
||||
dict(shape=[8, 10, 12], axis=2, indices=[5, 8], output_shape=[8, 10, 2]),
|
||||
dict(shape=[8, 10, 12], axis=-1, indices=[5, 8], output_shape=[8, 10, 2]),
|
||||
dict(shape=[8, 10, 12], axis=None, indices=[[0, 1], [3, 4], [6, 7]], output_shape=[3, 2, 10, 12]),
|
||||
dict(shape=[8, 10, 12], axis=1, indices=[[0, 2, 4], [5, 7, 9]], output_shape=[8, 2, 3, 12]),
|
||||
dict(shape=[6, 8, 10, 12], axis=0, indices=[2, 5], output_shape=[2, 8, 10, 12]),
|
||||
dict(shape=[6, 8, 10, 12], axis=1, indices=[3, 6], output_shape=[6, 2, 10, 12]),
|
||||
dict(shape=[6, 8, 10, 12], axis=2, indices=[4, 7], output_shape=[6, 8, 2, 12]),
|
||||
dict(shape=[6, 8, 10, 12], axis=3, indices=[5, 8], output_shape=[6, 8, 10, 2]),
|
||||
dict(shape=[6, 8, 10, 12], axis=-1, indices=[5, 8], output_shape=[6, 8, 10, 2]),
|
||||
dict(shape=[6, 8, 10, 12], axis=None, indices=[[0, 1, 2], [3, 4, 5]], output_shape=[2, 3, 8, 10, 12]),
|
||||
dict(shape=[6, 8, 10, 12], axis=2, indices=[[0, 2, 4], [5, 7, 9]], output_shape=[6, 8, 2, 3, 12]),
|
||||
dict(shape=[4, 6, 8, 10, 12], axis=0, indices=[1, 3], output_shape=[2, 6, 8, 10, 12]),
|
||||
dict(shape=[4, 6, 8, 10, 12], axis=1, indices=[2, 5], output_shape=[4, 2, 8, 10, 12]),
|
||||
dict(shape=[4, 6, 8, 10, 12], axis=2, indices=[3, 6], output_shape=[4, 6, 2, 10, 12]),
|
||||
dict(shape=[4, 6, 8, 10, 12], axis=3, indices=[4, 7], output_shape=[4, 6, 8, 2, 12]),
|
||||
dict(shape=[4, 6, 8, 10, 12], axis=4, indices=[5, 8], output_shape=[4, 6, 8, 10, 2]),
|
||||
dict(shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], output_shape=[4, 6, 8, 10, 2])]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_precommit)
|
||||
@ -259,3 +251,18 @@ class TestGather(OnnxRuntimeLayerTest):
|
||||
def test_gather_const(self, params, ie_device, precision, ir_version, temp_dir):
|
||||
self._test(*self.create_net_const(**params, ir_version=ir_version), ie_device, precision, ir_version,
|
||||
temp_dir=temp_dir)
|
||||
|
||||
test_data_negative_indices = [dict(shape=[10, 12], axis=0, indices=[3, -1, -4], output_shape=[3, 12]),
|
||||
dict(shape=[6, 10, 14, 12], axis=1, indices=[[0, -1, 3, -4], [-5, 6, -7, 8]],
|
||||
output_shape=[6, 2, 4, 14, 12]),
|
||||
dict(shape=[8, 10, 14, 12], axis=1, indices=[[-2, 2, -4], [5, -7, 9]],
|
||||
output_shape=[8, 2, 3, 14, 12]),
|
||||
dict(shape=[6, 8, 10, 12], axis=-1, indices=[[[2, -1], [3, 2]], [[5, -1], [3, -2]]],
|
||||
output_shape=[6, 8, 10, 2, 2, 2])]
|
||||
|
||||
@pytest.mark.xfail(reason='negative indices are not yet implemented on CPU: xxx-54630')
|
||||
@pytest.mark.parametrize("params", test_data_negative_indices)
|
||||
@pytest.mark.nightly
|
||||
def test_gather_nightly_negative_indices(self, params, ie_device, precision, ir_version, temp_dir):
|
||||
self._test(*self.create_net(**params, ir_version=ir_version),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir)
|
||||
|
61
tests/layer_tests/tensorflow_tests/test_tf_Gather.py
Normal file
61
tests/layer_tests/tensorflow_tests/test_tf_Gather.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestGather(CommonTFLayerTest):
|
||||
|
||||
def create_indices_constant(self):
|
||||
pass
|
||||
|
||||
def create_gather_net(self, data_shape, indices, axis, batch_dims, **kwargs):
|
||||
import tensorflow as tf
|
||||
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
with tf.compat.v1.Session() as sess:
|
||||
data = tf.compat.v1.placeholder(tf.float32, data_shape, 'data')
|
||||
indices = tf.constant(indices, dtype=tf.int32)
|
||||
gather = tf.gather(data, indices, axis=axis, batch_dims=batch_dims, name='gather_output')
|
||||
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
ref_net = None
|
||||
|
||||
return tf_net, ref_net
|
||||
|
||||
test_data_precommit = [
|
||||
dict(data_shape=[6, 8, 10, 12], indices=[[0, 2, 4], [5, 7, 9]], axis=2, batch_dims=0),
|
||||
dict(data_shape=[4, 6, 8, 10, 12], indices=[2, 5], axis=1, batch_dims=0),
|
||||
dict(data_shape=[4, 6, 8, 10, 12], indices=[2, 5], axis=-1, batch_dims=0)
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_precommit)
|
||||
@pytest.mark.precommit
|
||||
def test_gather(self, params, ie_device, precision, ir_version, temp_dir):
|
||||
self._test(*self.create_gather_net(**params, ir_version=ir_version),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir)
|
||||
|
||||
test_data_nightly = [
|
||||
dict(data_shape=[2, 3], axis=1, indices=[0, 2], batch_dims=0),
|
||||
dict(data_shape=[10, 12], axis=0, indices=[3, 6], batch_dims=0),
|
||||
dict(data_shape=[10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]], batch_dims=0),
|
||||
dict(data_shape=[8, 10, 12], axis=0, indices=[3, 6], batch_dims=0),
|
||||
dict(data_shape=[8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
|
||||
dict(data_shape=[6, 8, 10, 12], axis=0, indices=[2, 5], batch_dims=0),
|
||||
dict(data_shape=[6, 8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
|
||||
dict(data_shape=[6, 8, 10, 12], axis=2, indices=[[0, 2, 4], [5, 7, 9]], batch_dims=0),
|
||||
dict(data_shape=[2, 14, 10, 12], axis=1, indices=[[0, 1, 3, 4, 5], [6, 7, 9, 10, 11]], batch_dims=1),
|
||||
dict(data_shape=[4, 6, 8, 10, 12], axis=0, indices=[1, 3], batch_dims=0),
|
||||
dict(data_shape=[4, 6, 8, 10, 12], axis=-1, indices=[5, 8], batch_dims=0),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_nightly)
|
||||
@pytest.mark.nightly
|
||||
def test_gather_nightly(self, params, ie_device, precision, ir_version, temp_dir):
|
||||
self._test(*self.create_gather_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir)
|
Loading…
Reference in New Issue
Block a user