[nGraph] Gather v7 v1 up down transformations (#5118)

* v7->v1 works fine

* gather v1->v7 works fine

* changed op::v7 -> opset7, op::v1 -> opset1

* bumped to gather7 in all transformations

* applied review comments

* fixed pre-commit failures

* disabled incorrect unit-test, fixed f32->i32

* added comments why AddConvertToReorderTest was disabled

* Revert "bumped to gather7 in all transformations"

This reverts commit 965dc295

* fixed typos in v1->v7, v7->v1

* added GatherBase, redefined pattern in eliminate_unsqueeze gather, turned on Gather downgrading transformation

* resolved conflicts and build errors

* 🚀 finally EliminateUnsqueezeGather works: added inheritance of RTTI info from GatherBase

* fixed pre-commit failures

* removed redundant debug code

* reverted f32 for Gather-1 indices in unit-tests and transformations

* relaxed restrictions for indices and axis type for v1::Gather

* corrected op scope

* moved type_check to validation_util.cpp

* removed type checks from Gather, removed upgrading transformation

* applied review comments

* fixed minor typos
Pavel Esir 2021-04-20 10:27:48 +03:00 committed by GitHub
parent f2306adf98
commit 41bdec12df
12 changed files with 699 additions and 588 deletions

View File

@ -0,0 +1,27 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertGather7ToGather1;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief ConvertGather7ToGather1 converts v7::Gather into v1::Gather.
*/
class ngraph::pass::ConvertGather7ToGather1 : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertGather7ToGather1();
};
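
For reference, a minimal sketch of how this pass can be run on its own (assuming a prebuilt ngraph::Function f; the registration added to CommonOptimizations below does the same thing):

    #include <ngraph/pass/manager.hpp>
    #include <transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp>

    void downgrade_gathers(std::shared_ptr<ngraph::Function> f) {
        ngraph::pass::Manager manager;
        // rewrites every v7::Gather with batch_dims == 0 into an equivalent v1::Gather
        manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
        manager.run_passes(f);
    }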

View File

@ -37,12 +37,12 @@ static bool simplify_gather(std::shared_ptr<Node> node) {
return false;
}
- auto axis = gather->get_axis();
- if (axis == opset3::Gather::AXIS_NOT_SET_VALUE) {
+ if (!gather->is_axis_set()) {
NGRAPH_DEBUG << "axis value not set";
return false;
}
+ auto axis = gather->get_axis();
// case_1 : if the input tensor is of shape (4, 1, 4)
// and axis = 1, then the gather would be simply
@ -85,12 +85,12 @@ static bool simplify_gather_shapeof(shared_ptr<Node> node) {
}
auto gather_in_rank = gather->get_input_partial_shape(0).rank();
auto indices_rank = gather->get_input_partial_shape(1).rank();
- auto axis = gather->get_axis();
if (gather_in_rank.is_dynamic() || indices_rank.is_dynamic() ||
- axis == opset3::Gather::AXIS_NOT_SET_VALUE) {
+ !gather->is_axis_set()) {
NGRAPH_DEBUG << gather << " cannot simplify gather->shapeof";
return false;
}
+ auto axis = gather->get_axis();
auto zero_axis = opset3::Constant::create<int64_t>(element::i64, Shape{}, {0});
NodeVector new_ops;
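
To make case_1 concrete (an illustrative sketch, not part of the diff): gathering scalar index 0 along an axis of size 1 only drops that axis, so the Gather is equivalent to a Squeeze of the same axis:

    // data: shape (4, 1, 4), indices: scalar 0, axis = 1
    // Gather output: shape (4, 4), same data as Squeeze(data, axes={1})
    auto data = std::make_shared<opset3::Parameter>(element::f32, Shape{4, 1, 4});
    auto indices = opset3::Constant::create(element::i64, Shape{}, {0});
    auto axis = opset3::Constant::create(element::i64, Shape{}, {1});
    auto gather = std::make_shared<opset3::Gather>(data, indices, axis);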

View File

@ -42,6 +42,7 @@
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp"
#include "transformations/op_conversions/convert_mod.hpp"
#include "transformations/op_conversions/convert_minimum_to_power_and_max.hpp"
#include "transformations/op_conversions/convert_negative.hpp"
@ -156,6 +157,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
conv_fusions->set_name("ngraph::pass::ConvFusions");
manager.register_pass<ngraph::pass::ConstantFolding>();
// need to convert to Gather-1 until plugins support Gather-7
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();

View File

@ -19,7 +19,7 @@ ngraph::pass::EliminateUnsqueezeGather::EliminateUnsqueezeGather() {
const auto unsqueeze = ngraph::pattern::wrap_type<ngraph::opset6::Unsqueeze>({unsqueezeInput, unsqueezeAxis}, pattern::consumers_count(1));
const auto gatherIndices = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
const auto gatherAxis = ngraph::pattern::any_input();
- const auto gather = ngraph::pattern::wrap_type<ngraph::opset6::Gather>({unsqueeze, gatherIndices, gatherAxis});
+ const auto gather = ngraph::pattern::wrap_type<ngraph::op::util::GatherBase>({unsqueeze, gatherIndices, gatherAxis});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& patternValue = m.get_pattern_value_map();
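
Widening the matched type from opset6::Gather to op::util::GatherBase makes the pattern version-agnostic: previously only a v1-style opset6::Gather matched here, while with GatherBase a v7 node is now eliminated as well (a hypothetical graph, not part of the diff):

    auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{3});
    auto unsqAxis = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
    auto unsqueeze = std::make_shared<ngraph::opset6::Unsqueeze>(data, unsqAxis);
    auto idx = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
    auto axis = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
    // matched via GatherBase despite being a v7 node
    auto g7 = std::make_shared<ngraph::opset7::Gather>(unsqueeze, idx, axis, 0);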

View File

@ -0,0 +1,43 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGather7ToGather1, "ConvertGather7ToGather1", 0);
ngraph::pass::ConvertGather7ToGather1::ConvertGather7ToGather1() {
MATCHER_SCOPE(ConvertGather7ToGather1);
auto gather_v7 = pattern::wrap_type<ngraph::opset7::Gather>();
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto gather_v7_node = std::dynamic_pointer_cast<ngraph::opset7::Gather>(m.get_match_root());
if (!gather_v7_node)
return false;
int64_t batch_dims = 0;
if (gather_v7_node->is_axis_set())
batch_dims = gather_v7_node->get_batch_dims();
if (batch_dims != 0)
return false;
auto data_input = gather_v7_node->input_value(0);
auto indices_input = gather_v7_node->input_value(1);
auto axis_input = gather_v7_node->input_value(2);
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data_input, indices_input, axis_input);
gather_v1->set_friendly_name(gather_v7_node->get_friendly_name());
ngraph::copy_runtime_info(gather_v7_node, gather_v1);
ngraph::replace_node(gather_v7_node, gather_v1);
return true;
};
auto m = std::make_shared<pattern::Matcher>(gather_v7, matcher_name);
register_matcher(m, callback);
}
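
Note the early exit on batch_dims != 0: v1::Gather has no batch_dims attribute, so such nodes have no opset1 equivalent and are intentionally left as v7. A hypothetical graph the pass skips:

    auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4});
    auto indices = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i32, ngraph::Shape{2, 5});
    auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {1});
    // batch_dims = 1 cannot be expressed in opset1, so the callback returns false and keeps the node
    auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 1);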

View File

@ -0,0 +1,51 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, ConvertGather7toGather1) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data, indices, axis);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -6,7 +6,6 @@
#include <ngraph_functions/utils/ngraph_helpers.hpp>
#include <common_test_utils/test_common.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>

View File

@ -4,7 +4,7 @@
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/gather_base.hpp"
namespace ngraph
{
@ -13,12 +13,10 @@ namespace ngraph
namespace v1
{
/// \brief Gather slices from axis of params according to indices
- class NGRAPH_API Gather : public Op
+ class NGRAPH_API Gather : public op::util::GatherBase
{
public:
- static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
- static constexpr NodeTypeInfo type_info{"Gather", 1};
- const NodeTypeInfo& get_type_info() const override { return type_info; }
+ NGRAPH_RTTI_DECLARATION;
Gather() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
@ -28,35 +26,16 @@ namespace ngraph
const Output<Node>& axis);
bool visit_attributes(AttributeVisitor& visitor) override;
- int64_t get_axis() const;
- void validate_and_infer_types() override;
- virtual std::shared_ptr<Node>
+ std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- bool evaluate(const HostTensorVector& outputs,
- const HostTensorVector& inputs) const override;
- bool evaluate_lower(const HostTensorVector& outputs) const override;
- bool evaluate_upper(const HostTensorVector& outputs) const override;
- bool constant_fold(OutputVector& output_values,
- const OutputVector& inputs_values) override;
- private:
- static const int PARAMS;
- static const int INDICES;
- static const int AXIS;
- bool evaluate_gather(const HostTensorVector& outputs,
- const HostTensorVector& inputs) const;
};
} // namespace v1
namespace v7
{
/// \brief Gather slices from axis of params according to indices
- class NGRAPH_API Gather : public Op
+ class NGRAPH_API Gather : public op::util::GatherBase
{
public:
NGRAPH_RTTI_DECLARATION;
@ -76,23 +55,6 @@ namespace ngraph
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
- int64_t get_batch_dims() const;
- int64_t get_axis() const;
- bool is_axis_set() const;
- bool evaluate_gather(const HostTensorVector& outputs,
- const HostTensorVector& inputs) const;
- bool evaluate(const HostTensorVector& outputs,
- const HostTensorVector& inputs) const override;
- bool evaluate_lower(const HostTensorVector& outputs) const override;
- bool evaluate_upper(const HostTensorVector& outputs) const override;
- bool constant_fold(OutputVector& output_values,
- const OutputVector& inputs_values) override;
- private:
- int64_t m_batch_dims = 0;
};
} // namespace v7
} // namespace op
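
With both classes deriving from op::util::GatherBase, shape inference and evaluation are shared; only the constructors differ in the optional batch_dims argument (illustrative sketch):

    auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2});
    auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{4});
    auto axis = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
    auto g1 = std::make_shared<ngraph::opset1::Gather>(data, indices, axis);    // no batch_dims in opset1
    auto g7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0); // batch_dims defaults to 0
    // both infer the same output shape (4, 2) via GatherBase::validate_and_infer_types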

View File

@ -0,0 +1,50 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief GatherBase is the base class for Gather v1 and v7
class NGRAPH_API GatherBase : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
GatherBase() = default;
/// \param data The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor holding the dimension index to gather data from
/// \param batch_dims The number of batch dimensions in data and indices tensors
GatherBase(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims = 0);
void validate_and_infer_types() override;
int64_t get_batch_dims() const;
int64_t get_axis() const;
bool is_axis_set() const;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
bool evaluate_lower(const HostTensorVector& outputs) const override;
bool evaluate_upper(const HostTensorVector& outputs) const override;
bool constant_fold(OutputVector& output_values,
const OutputVector& inputs_values) override;
protected:
int64_t m_batch_dims = 0;
};
} // namespace util
} // namespace op
} // namespace ngraph
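
A worked example of the batch_dims handling declared here (a sketch, assuming the axis input is constant-foldable): get_batch_dims() resolves a negative batch_dims against the axis value.

    auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4});
    auto indices = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i32, ngraph::Shape{2, 3, 5});
    auto axis = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{}, {2});
    auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, -1);
    // m_batch_dims = -1, axis = 2  =>  get_batch_dims() == get_axis() + m_batch_dims == 1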

View File

@ -3,30 +3,19 @@
//
#include "ngraph/op/gather.hpp"
#include "itt.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/squeeze.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/shape.hpp"
#include <ngraph/validation_util.hpp>
#include "itt.hpp"
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph;
- constexpr NodeTypeInfo op::v1::Gather::type_info;
- const int64_t op::v1::Gather::AXIS_NOT_SET_VALUE;
- const int op::v1::Gather::PARAMS = 0;
- const int op::v1::Gather::INDICES = 1;
- const int op::v1::Gather::AXIS = 2;
+ NGRAPH_RTTI_DEFINITION(op::v1::Gather, "Gather", 1, op::util::GatherBase);
op::v1::Gather::Gather(const Output<Node>& params,
const Output<Node>& indices,
const Output<Node>& axes)
- : Op({params, indices, axes})
+ : GatherBase(params, indices, axes)
{
constructor_validate_and_infer_types();
}
@ -37,107 +26,38 @@ bool ngraph::op::v1::Gather::visit_attributes(AttributeVisitor& visitor)
return true;
}
void op::v1::Gather::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v1_Gather_validate_and_infer_types);
const auto& input_rank = get_input_partial_shape(PARAMS).rank();
const auto& axis_shape = get_input_partial_shape(AXIS);
const auto& axis_rank = axis_shape.rank();
if (axis_rank.is_static() && axis_shape.is_static())
{
const auto axis_is_scalar = axis_rank.get_length() == 0;
const auto axis_has_one_elem =
axis_rank.get_length() == 1 && axis_shape[0].get_length() == 1;
NODE_VALIDATION_CHECK(this,
axis_is_scalar || axis_has_one_elem,
"Axes input must be scalar or have 1 element (shape: ",
axis_shape,
").");
}
int64_t axis = get_axis();
if (input_rank.is_static() && axis != AXIS_NOT_SET_VALUE)
{
NODE_VALIDATION_CHECK(this,
axis < input_rank.get_length(),
"The axis must => 0 and <= input_rank (axis: ",
axis,
").");
}
element::Type result_et = get_input_element_type(PARAMS);
const PartialShape& params_shape = get_input_partial_shape(PARAMS);
const PartialShape& indices_shape = get_input_partial_shape(INDICES);
PartialShape result_shape;
if (params_shape.rank().is_static() && indices_shape.rank().is_static() &&
axis != AXIS_NOT_SET_VALUE)
{
std::vector<Dimension> result_dims(params_shape.rank().get_length() +
indices_shape.rank().get_length() - 1);
int64_t i = 0;
for (; i < axis; i++)
{
result_dims[i] = params_shape[i];
}
for (int64_t j = 0; j < indices_shape.rank().get_length(); i++, j++)
{
result_dims[i] = indices_shape[j];
}
for (int64_t j = axis + 1; j < params_shape.rank().get_length(); i++, j++)
{
result_dims[i] = params_shape[j];
}
result_shape = PartialShape(result_dims);
}
else
{
result_shape = PartialShape::dynamic();
}
set_output_type(0, result_et, result_shape);
}
int64_t op::v1::Gather::get_axis() const
{
int64_t axis = AXIS_NOT_SET_VALUE;
if (const auto& const_op = get_constant_from_source(input_value(AXIS)))
{
axis = const_op->cast_vector<int64_t>()[0];
}
if (axis < 0)
{
const auto& input_rank = get_input_partial_shape(PARAMS).rank();
if (input_rank.is_static())
{
axis += input_rank.get_length();
}
}
return axis;
}
shared_ptr<Node> op::v1::Gather::clone_with_new_inputs(const OutputVector& new_args) const
{
NGRAPH_OP_SCOPE(v1_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
- return make_shared<v1::Gather>(new_args.at(PARAMS), new_args.at(INDICES), new_args.at(AXIS));
+ return make_shared<v1::Gather>(new_args.at(0), new_args.at(1), new_args.at(2));
}
- NGRAPH_RTTI_DEFINITION(op::v7::Gather, "Gather", 7);
+ NGRAPH_RTTI_DEFINITION(op::v7::Gather, "Gather", 7, op::util::GatherBase);
op::v7::Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
- : Op({data, indices, axis})
- , m_batch_dims(batch_dims)
+ : GatherBase(data, indices, axis, batch_dims)
{
constructor_validate_and_infer_types();
}
+ void op::v7::Gather::validate_and_infer_types()
+ {
+ NGRAPH_OP_SCOPE(v7_Gather_validate_and_infer_types);
+ NODE_VALIDATION_CHECK(this,
+ get_input_element_type(1).is_integral_number(),
+ "Indices element type must be of an integral number type.");
+ NODE_VALIDATION_CHECK(this,
+ get_input_element_type(2).is_integral_number(),
+ "Axis element type must be of an integral number type.");
+ op::util::GatherBase::validate_and_infer_types();
+ }
bool ngraph::op::v7::Gather::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v7_Gather_visit_attributes);
@ -145,440 +65,9 @@ bool ngraph::op::v7::Gather::visit_attributes(AttributeVisitor& visitor)
return true;
}
void op::v7::Gather::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v7_Gather_validate_and_infer_types);
const auto& data_type = get_input_element_type(0);
const auto& indices_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
indices_type == element::Type_t::i32 ||
indices_type == element::Type_t::i64,
"indices must be of int32 or int64 type. But instead got: ",
indices_type);
const auto& data_pshape = get_input_partial_shape(0);
const auto& indices_pshape = get_input_partial_shape(1);
const auto& axis_pshape = get_input_partial_shape(2);
auto data_rank = data_pshape.rank();
auto indices_rank = indices_pshape.rank();
auto axis_rank = axis_pshape.rank();
if (axis_rank.is_static() && axis_pshape.is_static())
{
const auto axis_is_scalar = axis_rank.get_length() == 0;
const auto axis_has_one_elem =
axis_rank.get_length() == 1 && axis_pshape[0].get_length() == 1;
NODE_VALIDATION_CHECK(
this,
axis_is_scalar || axis_has_one_elem,
"Axes input must be scalar or have 1 element. But instead got axis_shape = ",
axis_pshape);
}
int64_t batch_dims = get_batch_dims(); // will not be converted to positive if axis is not set
if (is_axis_set())
{
int64_t axis = get_axis();
NODE_VALIDATION_CHECK(this,
batch_dims <= axis,
"The batch_dims <= axis. But instead got: batch_dims = ",
batch_dims,
", axis = ",
axis);
if (data_rank.is_static())
{
NODE_VALIDATION_CHECK(this,
axis >= 0 && axis < data_rank.get_length(),
"The axis must be => 0 and < data_rank. But instead got axis = ",
axis,
" data_rank = ",
data_rank.get_length());
}
}
if (indices_rank.is_static() && batch_dims >= 0)
{
NODE_VALIDATION_CHECK(
this,
batch_dims <= indices_rank.get_length(),
"The batch_dims must be <= indices_rank. But instead got: batch_dims = ",
batch_dims,
", indices_rank = ",
indices_rank.get_length());
}
if (data_rank.is_static() && indices_rank.is_static())
{
if (batch_dims >= 0)
{
auto out_rank = data_rank.get_length() + indices_rank.get_length() - 1 - batch_dims;
PartialShape output_pshape = PartialShape::dynamic(out_rank);
// implementation of out_shape formula
// data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] +
// data.shape[axis + 1:]
int i = 0;
for (; i < batch_dims; i++)
{
NODE_VALIDATION_CHECK(this,
data_pshape[i].compatible(indices_pshape[i]),
"Shapes ",
data_pshape,
" and ",
indices_pshape,
" are not consistent. data and indices must have equal or "
"intersecting sizes until batch_dims");
output_pshape[i] = data_pshape[i] & indices_pshape[i];
}
if (is_axis_set())
{
int64_t axis = get_axis();
for (; i < axis; i++)
{
output_pshape[i] = data_pshape[i];
}
for (; i < axis + indices_rank.get_length() - batch_dims; i++)
{
output_pshape[i] = indices_pshape[batch_dims - axis + i];
}
for (; i < out_rank; i++)
{
output_pshape[i] = data_pshape[batch_dims + 1 - indices_rank.get_length() + i];
}
}
set_output_type(0, data_type, output_pshape);
}
else if (batch_dims < 0)
{
// batch_dims < 0 could be only if axis is not set
// as soon as axis value will arrive negative batch_dims should be resolved
// batch_dims value will be within [0, data_rank] && [0, indices_rank]
int64_t max_rank = data_rank.get_length() + indices_rank.get_length() - 1;
int64_t min_rank = max_rank - max(data_rank.get_length(), indices_rank.get_length());
set_output_type(0, data_type, PartialShape::dynamic(Dimension(min_rank, max_rank)));
}
}
else
{
set_output_type(0, data_type, PartialShape::dynamic());
}
}
int64_t op::v7::Gather::get_axis() const
{
const auto& const_op = get_constant_from_source(input_value(2));
int64_t axis = const_op->cast_vector<int64_t>()[0];
if (axis < 0)
{
const auto& data_rank = get_input_partial_shape(0).rank();
if (data_rank.is_static())
{
axis += data_rank.get_length();
}
}
return axis;
}
int64_t op::v7::Gather::get_batch_dims() const
{
if (m_batch_dims < 0 && is_axis_set())
return get_axis() + m_batch_dims;
else
return m_batch_dims;
}
bool op::v7::Gather::is_axis_set() const
{
const auto& axes_constant = get_constant_from_source(input_value(2));
if (axes_constant)
return true;
else
return false;
}
shared_ptr<Node> op::v7::Gather::clone_with_new_inputs(const OutputVector& new_args) const
{
NGRAPH_OP_SCOPE(v7_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v7::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
}
namespace gather
{
template <element::Type_t ET>
bool evaluate(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& out,
size_t axis,
size_t batch_dims)
{
using T = typename element_type_traits<ET>::value_type;
Shape params_shape = arg0->get_shape();
Shape indices_shape = arg1->get_shape();
Shape out_shape(params_shape.size() + indices_shape.size() - 1 - batch_dims);
uint64_t i = 0;
for (; i < axis; i++)
{
out_shape[i] = params_shape[i];
}
for (uint64_t j = batch_dims; j < indices_shape.size(); i++, j++)
{
out_shape[i] = indices_shape[j];
}
for (uint64_t j = axis + 1; j < params_shape.size(); i++, j++)
{
out_shape[i] = params_shape[j];
}
out->set_shape(out_shape);
if (arg1->get_element_type() == element::i64)
{
runtime::reference::gather<T, int64_t>(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<int64_t>(),
out->get_data_ptr<ET>(),
arg0->get_shape(),
arg1->get_shape(),
out->get_shape(),
axis,
batch_dims);
}
else if (arg1->get_element_type() == element::i32)
{
runtime::reference::gather<T, int32_t>(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<int32_t>(),
out->get_data_ptr<ET>(),
arg0->get_shape(),
arg1->get_shape(),
out->get_shape(),
axis,
batch_dims);
}
else
{
throw ngraph_error("Unexpected type");
}
return true;
}
bool evaluate_gather(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& out,
size_t axis,
size_t batch_dims = 0)
{
bool rc = true;
switch (out->get_element_type())
{
NGRAPH_TYPE_CASE(evaluate_gather, i32, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, i64, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, u32, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, u64, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, f16, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, f32, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, boolean, arg0, arg1, out, axis, batch_dims);
default: rc = false; break;
}
return rc;
}
bool cf_gather_with_subgraph(OutputVector& output_values,
const OutputVector& input_values,
const PartialShape& gather_ps)
{
if (gather_ps.is_dynamic() || input_values.size() != 3)
{
return false;
}
const auto concat =
std::dynamic_pointer_cast<op::Concat>(input_values[0].get_node_shared_ptr());
const auto indices =
std::dynamic_pointer_cast<op::Constant>(input_values[1].get_node_shared_ptr());
const auto axis =
std::dynamic_pointer_cast<op::Constant>(input_values[2].get_node_shared_ptr());
if (!concat || !indices || !axis)
{
return false;
}
// only along axis=0
if (axis->cast_vector<int64_t>()[0] != 0 || concat->get_axis() != 0)
{
return false;
}
// only single indices are accepted
const auto indices_shape = indices->get_shape();
if (indices_shape.size() > 1 || (indices_shape.size() == 1 && indices_shape[0] > 1))
{
return false;
}
// concat inputs are 1D and their count is equal to Concat output shape
if (concat->get_output_partial_shape(0).is_dynamic())
{
return false;
}
const auto concat_inputs = concat->inputs();
// concat inputs must be single elements
if (concat_inputs.size() != shape_size(concat->get_shape()))
{
return false;
}
const int64_t rank = concat->get_shape()[0];
const int64_t raw_index = indices->cast_vector<int64_t>()[0];
const int64_t positive_index = raw_index < 0 ? rank + raw_index : raw_index;
NGRAPH_CHECK(positive_index >= 0 && positive_index < rank);
// gather takes exactly one element out of the Concat output
const auto gathered_concat_input =
concat_inputs[positive_index].get_source_output().get_node_shared_ptr();
// Concat inputs are 1D, resulting tensor shape depends on Gather indices
auto gathered = gathered_concat_input;
if (indices_shape.empty())
{
// gathering a scalar
const auto axes = op::Constant::create(element::i64, Shape{1}, {0});
gathered = make_shared<op::v0::Squeeze>(gathered_concat_input, axes);
}
output_values[0] = gathered;
return true;
}
} // namespace gather
bool op::v1::Gather::evaluate_gather(const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
int64_t axis = 0;
switch (inputs[2]->get_element_type())
{
case element::Type_t::i8: axis = inputs[2]->get_data_ptr<element::Type_t::i8>()[0]; break;
case element::Type_t::i16: axis = inputs[2]->get_data_ptr<element::Type_t::i16>()[0]; break;
case element::Type_t::i32: axis = inputs[2]->get_data_ptr<element::Type_t::i32>()[0]; break;
case element::Type_t::i64: axis = inputs[2]->get_data_ptr<element::Type_t::i64>()[0]; break;
case element::Type_t::u8: axis = inputs[2]->get_data_ptr<element::Type_t::u8>()[0]; break;
case element::Type_t::u16: axis = inputs[2]->get_data_ptr<element::Type_t::u16>()[0]; break;
case element::Type_t::u32: axis = inputs[2]->get_data_ptr<element::Type_t::u32>()[0]; break;
case element::Type_t::u64: axis = inputs[2]->get_data_ptr<element::Type_t::u64>()[0]; break;
default: throw ngraph_error("axis element type is not integral data type");
}
if (axis < 0)
{
const auto& input_rank = get_input_partial_shape(PARAMS).rank();
if (input_rank.is_static())
{
axis += input_rank.get_length();
}
}
return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis);
}
bool op::v1::Gather::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
{
NGRAPH_OP_SCOPE(v1_Gather_evaluate);
NGRAPH_CHECK(validate_host_tensor_vector(inputs, 3));
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
return evaluate_gather(outputs, inputs);
}
bool op::v1::Gather::evaluate_lower(const HostTensorVector& output_values) const
{
if (!input_value(INDICES).get_tensor().has_and_set_bound() ||
!input_value(AXIS).get_tensor().has_and_set_bound())
return false;
return default_lower_bound_evaluator(this, output_values);
}
bool op::v1::Gather::evaluate_upper(const HostTensorVector& output_values) const
{
if (!input_value(INDICES).get_tensor().has_and_set_bound() ||
!input_value(AXIS).get_tensor().has_and_set_bound())
return false;
return default_upper_bound_evaluator(this, output_values);
}
bool op::v1::Gather::constant_fold(OutputVector& output_values, const OutputVector& input_values)
{
// try the regular constant folding just for the Gather node
if (Node::constant_fold(output_values, input_values))
{
return true;
}
else
{
return gather::cf_gather_with_subgraph(
output_values, input_values, get_output_partial_shape(0));
}
}
bool op::v7::Gather::evaluate_gather(const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
int64_t axis = 0;
switch (inputs[2]->get_element_type())
{
case element::Type_t::i32: axis = inputs[2]->get_data_ptr<element::Type_t::i32>()[0]; break;
case element::Type_t::i64: axis = inputs[2]->get_data_ptr<element::Type_t::i64>()[0]; break;
default: throw ngraph_error("axis must be of int32 or int64 type.");
}
if (axis < 0)
{
const auto& input_rank = get_input_partial_shape(0).rank();
if (input_rank.is_static())
{
axis += input_rank.get_length();
}
}
return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis, get_batch_dims());
}
bool op::v7::Gather::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
{
NGRAPH_OP_SCOPE(v7_Gather_evaluate);
NGRAPH_CHECK(validate_host_tensor_vector(inputs, 3));
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
return evaluate_gather(outputs, inputs);
}
bool op::v7::Gather::evaluate_lower(const HostTensorVector& output_values) const
{
if (!input_value(1).get_tensor().has_and_set_bound() ||
!input_value(2).get_tensor().has_and_set_bound())
return false;
return default_lower_bound_evaluator(this, output_values);
}
bool op::v7::Gather::evaluate_upper(const HostTensorVector& output_values) const
{
if (!input_value(1).get_tensor().has_and_set_bound() ||
!input_value(2).get_tensor().has_and_set_bound())
return false;
return default_upper_bound_evaluator(this, output_values);
}
bool op::v7::Gather::constant_fold(OutputVector& output_values, const OutputVector& input_values)
{
// try the regular constant folding just for the Gather node
if (Node::constant_fold(output_values, input_values))
{
return true;
}
else
{
return gather::cf_gather_with_subgraph(
output_values, input_values, get_output_partial_shape(0));
}
}

View File

@ -0,0 +1,393 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/util/gather_base.hpp"
#include "itt.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/squeeze.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/shape.hpp"
#include <ngraph/validation_util.hpp>
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::util::GatherBase, "GatherBase", 7);
op::util::GatherBase::GatherBase(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
: Op({data, indices, axis})
, m_batch_dims(batch_dims)
{
constructor_validate_and_infer_types();
}
void op::util::GatherBase::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(util_GatherBase_validate_and_infer_types);
const auto& data_type = get_input_element_type(0);
const auto& data_pshape = get_input_partial_shape(0);
const auto& indices_pshape = get_input_partial_shape(1);
const auto& axis_pshape = get_input_partial_shape(2);
auto data_rank = data_pshape.rank();
auto indices_rank = indices_pshape.rank();
auto axis_rank = axis_pshape.rank();
if (axis_rank.is_static() && axis_pshape.is_static())
{
const auto axis_is_scalar = axis_rank.get_length() == 0;
const auto axis_has_one_elem =
axis_rank.get_length() == 1 && axis_pshape[0].get_length() == 1;
NODE_VALIDATION_CHECK(
this,
axis_is_scalar || axis_has_one_elem,
"Axis input must be scalar or have 1 element. But instead got axis_shape = ",
axis_pshape);
}
int64_t batch_dims = get_batch_dims(); // will not be converted to positive if axis is not set
if (is_axis_set())
{
int64_t axis = get_axis();
NODE_VALIDATION_CHECK(this,
batch_dims <= axis,
"The batch_dims <= axis. But instead got: batch_dims = ",
batch_dims,
", axis = ",
axis);
if (data_rank.is_static())
{
NODE_VALIDATION_CHECK(this,
axis >= 0 && axis < data_rank.get_length(),
"The axis must be >= 0 and < data_rank. But instead got axis = ",
axis,
" data_rank = ",
data_rank.get_length());
}
}
if (indices_rank.is_static() && batch_dims >= 0)
{
NODE_VALIDATION_CHECK(
this,
batch_dims <= indices_rank.get_length(),
"The batch_dims must be <= indices_rank. But instead got: batch_dims = ",
batch_dims,
", indices_rank = ",
indices_rank.get_length());
}
if (data_rank.is_static() && indices_rank.is_static())
{
if (batch_dims >= 0)
{
auto out_rank = data_rank.get_length() + indices_rank.get_length() - 1 - batch_dims;
PartialShape output_pshape = PartialShape::dynamic(out_rank);
// implementation of out_shape formula
// data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] +
// data.shape[axis + 1:]
int i = 0;
for (; i < batch_dims; i++)
{
NODE_VALIDATION_CHECK(this,
data_pshape[i].compatible(indices_pshape[i]),
"Shapes ",
data_pshape,
" and ",
indices_pshape,
" are not consistent. data and indices must have equal or "
"intersecting sizes until batch_dims");
output_pshape[i] = data_pshape[i] & indices_pshape[i];
}
if (is_axis_set())
{
int64_t axis = get_axis();
for (; i < axis; i++)
{
output_pshape[i] = data_pshape[i];
}
for (; i < axis + indices_rank.get_length() - batch_dims; i++)
{
output_pshape[i] = indices_pshape[batch_dims - axis + i];
}
for (; i < out_rank; i++)
{
output_pshape[i] = data_pshape[batch_dims + 1 - indices_rank.get_length() + i];
}
}
set_output_type(0, data_type, output_pshape);
}
else if (batch_dims < 0)
{
// batch_dims < 0 could be only if axis is not set
// as soon as axis value will arrive negative batch_dims should be resolved
// batch_dims value will be within [0, data_rank] && [0, indices_rank]
int64_t max_rank = data_rank.get_length() + indices_rank.get_length() - 1;
int64_t min_rank = max_rank - max(data_rank.get_length(), indices_rank.get_length());
set_output_type(0, data_type, PartialShape::dynamic(Dimension(min_rank, max_rank)));
}
}
else
{
set_output_type(0, data_type, PartialShape::dynamic());
}
}
int64_t op::util::GatherBase::get_axis() const
{
const auto& const_op = get_constant_from_source(input_value(2));
if (!const_op)
throw ngraph_error("axis value is not set");
int64_t axis = const_op->cast_vector<int64_t>()[0];
if (axis < 0)
{
const auto& data_rank = get_input_partial_shape(0).rank();
if (data_rank.is_static())
{
axis += data_rank.get_length();
}
}
return axis;
}
int64_t op::util::GatherBase::get_batch_dims() const
{
if (m_batch_dims < 0 && is_axis_set())
return get_axis() + m_batch_dims;
else
return m_batch_dims;
}
bool op::util::GatherBase::is_axis_set() const
{
const auto& axis_constant = get_constant_from_source(input_value(2));
if (axis_constant)
return true;
else
return false;
}
namespace gather
{
template <element::Type_t ET>
bool evaluate(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& out,
size_t axis,
size_t batch_dims)
{
using T = typename element_type_traits<ET>::value_type;
Shape params_shape = arg0->get_shape();
Shape indices_shape = arg1->get_shape();
Shape out_shape(params_shape.size() + indices_shape.size() - 1 - batch_dims);
uint64_t i = 0;
for (; i < axis; i++)
{
out_shape[i] = params_shape[i];
}
for (uint64_t j = batch_dims; j < indices_shape.size(); i++, j++)
{
out_shape[i] = indices_shape[j];
}
for (uint64_t j = axis + 1; j < params_shape.size(); i++, j++)
{
out_shape[i] = params_shape[j];
}
out->set_shape(out_shape);
if (arg1->get_element_type() == element::i64)
{
runtime::reference::gather<T, int64_t>(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<int64_t>(),
out->get_data_ptr<ET>(),
arg0->get_shape(),
arg1->get_shape(),
out->get_shape(),
axis,
batch_dims);
}
else if (arg1->get_element_type() == element::i32)
{
runtime::reference::gather<T, int32_t>(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<int32_t>(),
out->get_data_ptr<ET>(),
arg0->get_shape(),
arg1->get_shape(),
out->get_shape(),
axis,
batch_dims);
}
else
{
throw ngraph_error("Unexpected type");
}
return true;
}
bool evaluate_gather(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& out,
size_t axis,
size_t batch_dims = 0)
{
bool rc = true;
switch (out->get_element_type())
{
NGRAPH_TYPE_CASE(evaluate_gather, i32, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, i64, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, u32, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, u64, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, f16, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, f32, arg0, arg1, out, axis, batch_dims);
NGRAPH_TYPE_CASE(evaluate_gather, boolean, arg0, arg1, out, axis, batch_dims);
default: rc = false; break;
}
return rc;
}
bool cf_gather_with_subgraph(OutputVector& output_values,
const OutputVector& input_values,
const PartialShape& gather_ps)
{
if (gather_ps.is_dynamic() || input_values.size() != 3)
{
return false;
}
const auto concat =
std::dynamic_pointer_cast<op::Concat>(input_values[0].get_node_shared_ptr());
const auto indices =
std::dynamic_pointer_cast<op::Constant>(input_values[1].get_node_shared_ptr());
const auto axis =
std::dynamic_pointer_cast<op::Constant>(input_values[2].get_node_shared_ptr());
if (!concat || !indices || !axis)
{
return false;
}
// only along axis=0
if (axis->cast_vector<int64_t>()[0] != 0 || concat->get_axis() != 0)
{
return false;
}
// only single indices are accepted
const auto indices_shape = indices->get_shape();
if (indices_shape.size() > 1 || (indices_shape.size() == 1 && indices_shape[0] > 1))
{
return false;
}
// concat inputs are 1D and their count is equal to Concat output shape
if (concat->get_output_partial_shape(0).is_dynamic())
{
return false;
}
const auto concat_inputs = concat->inputs();
// concat inputs must be single elements
if (concat_inputs.size() != shape_size(concat->get_shape()))
{
return false;
}
const int64_t rank = concat->get_shape()[0];
const int64_t raw_index = indices->cast_vector<int64_t>()[0];
const int64_t positive_index = raw_index < 0 ? rank + raw_index : raw_index;
NGRAPH_CHECK(positive_index >= 0 && positive_index < rank);
// gather takes exactly one element out of the Concat output
const auto gathered_concat_input =
concat_inputs[positive_index].get_source_output().get_node_shared_ptr();
// Concat inputs are 1D, resulting tensor shape depends on Gather indices
auto gathered = gathered_concat_input;
if (indices_shape.empty())
{
// gathering a scalar
const auto axis_const = op::Constant::create(element::i64, Shape{1}, {0});
gathered = make_shared<op::v0::Squeeze>(gathered_concat_input, axis_const);
}
output_values[0] = gathered;
return true;
}
} // namespace gather
bool op::util::GatherBase::evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
NGRAPH_OP_SCOPE(util_GatherBase_evaluate);
NGRAPH_CHECK(validate_host_tensor_vector(inputs, 3));
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
int64_t axis = 0;
switch (inputs[2]->get_element_type())
{
case element::Type_t::i32: axis = inputs[2]->get_data_ptr<element::Type_t::i32>()[0]; break;
case element::Type_t::i64: axis = inputs[2]->get_data_ptr<element::Type_t::i64>()[0]; break;
case element::Type_t::i8: axis = inputs[2]->get_data_ptr<element::Type_t::i8>()[0]; break;
case element::Type_t::i16: axis = inputs[2]->get_data_ptr<element::Type_t::i16>()[0]; break;
case element::Type_t::u8: axis = inputs[2]->get_data_ptr<element::Type_t::u8>()[0]; break;
case element::Type_t::u16: axis = inputs[2]->get_data_ptr<element::Type_t::u16>()[0]; break;
case element::Type_t::u32: axis = inputs[2]->get_data_ptr<element::Type_t::u32>()[0]; break;
case element::Type_t::u64: axis = inputs[2]->get_data_ptr<element::Type_t::u64>()[0]; break;
default: throw ngraph_error("axis must be of integral data type.");
}
if (axis < 0)
{
const auto& input_rank = get_input_partial_shape(0).rank();
if (input_rank.is_static())
{
axis += input_rank.get_length();
}
}
return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis, get_batch_dims());
}
bool op::util::GatherBase::evaluate_lower(const HostTensorVector& output_values) const
{
if (!input_value(1).get_tensor().has_and_set_bound() ||
!input_value(2).get_tensor().has_and_set_bound())
return false;
return default_lower_bound_evaluator(this, output_values);
}
bool op::util::GatherBase::evaluate_upper(const HostTensorVector& output_values) const
{
if (!input_value(1).get_tensor().has_and_set_bound() ||
!input_value(2).get_tensor().has_and_set_bound())
return false;
return default_upper_bound_evaluator(this, output_values);
}
bool op::util::GatherBase::constant_fold(OutputVector& output_values,
const OutputVector& input_values)
{
// try the regular constant folding just for the Gather node
if (Node::constant_fold(output_values, input_values))
{
return true;
}
else
{
return gather::cf_gather_with_subgraph(
output_values, input_values, get_output_partial_shape(0));
}
}
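
As a sanity check of the out_shape formula implemented above (data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] + data.shape[axis + 1:]), a hypothetical walk-through:

    auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3, 4, 5});
    auto indices = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i32, ngraph::Shape{2, 6});
    auto axis = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{}, {2});
    auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 1);
    // output = data[:1] + data[1:2] + indices[1:] + data[3:] = (2) + (3) + (6) + (5) -> Shape{2, 3, 6, 5}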

View File

@ -27,6 +27,40 @@ TEST(type_prop, gather_axis_0)
ASSERT_EQ(G->get_axis(), 0);
}
TEST(type_prop, gather_v1_uint8)
{
// Gather_1 must allow indices even if they are not int32/int64
PartialShape data_shape{3, 2};
PartialShape indices_shape{2, 2};
PartialShape out_shape{2, 2, 2};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::u8, indices_shape);
auto A = op::Constant::create(element::i64, Shape{}, {0});
auto G = make_shared<op::v1::Gather>(D, I, A);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
ASSERT_EQ(G->get_axis(), 0);
}
TEST(type_prop, gather_v1_float32)
{
// Gather_1 should allow non-int32/int64 indices
PartialShape data_shape{3, 2};
PartialShape indices_shape{2, 2};
PartialShape out_shape{2, 2, 2};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::f32, indices_shape);
auto A = op::Constant::create(element::i64, Shape{}, {0});
auto G = make_shared<op::v1::Gather>(D, I, A);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
ASSERT_EQ(G->get_axis(), 0);
}
TEST(type_prop, gather_axis_1)
{
Shape params_shape{3, 3};
@ -55,7 +89,7 @@ TEST(type_prop, gather_v1_incorrect_axis_shape)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Axes input must be scalar or have 1 element (shape:"));
std::string("Axis input must be scalar or have 1 element"));
}
catch (...)
{
@ -77,7 +111,7 @@ TEST(type_prop, gather_v1_axis_out_of_input_rank)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("The axis must => 0 and <= input_rank (axis:"));
std::string("The axis must be >= 0 and < data_rank. But instead got axis"));
}
catch (...)
{
@ -241,7 +275,7 @@ TEST(type_prop, gather_7_axis_not_set)
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
- auto A = make_shared<op::Parameter>(element::f32, Shape{1});
+ auto A = make_shared<op::Parameter>(element::i32, Shape{1});
auto G = make_shared<op::v7::Gather>(D, I, A);
ASSERT_EQ(G->get_element_type(), element::f32);
@ -260,7 +294,7 @@ TEST(type_prop, gather_7_axis_not_set_positive_batch_dims)
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
- auto A = make_shared<op::Parameter>(element::f32, Shape{1});
+ auto A = make_shared<op::Parameter>(element::i32, Shape{1});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
@ -279,7 +313,7 @@ TEST(type_prop, gather_7_axis_not_set_negative_batch)
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
- auto A = make_shared<op::Parameter>(element::f32, Shape{1});
+ auto A = make_shared<op::Parameter>(element::i32, Shape{1});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
@ -303,7 +337,7 @@ TEST(type_prop, gather_7_incorrect_axis_shape)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Axes input must be scalar or have 1 element"));
std::string("Axis input must be scalar or have 1 element"));
}
catch (...)
{
@ -326,7 +360,7 @@ TEST(type_prop, gather_7_axis_out_of_input_rank)
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("The axis must be => 0 and < data_rank. But instead got"));
error.what(), std::string("The axis must be >= 0 and < data_rank. But instead got"));
}
catch (...)
{
@ -420,3 +454,63 @@ TEST(type_prop, gather_7_batch_dims_less_indices_rank_check)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
// disabled until a decision on type constraints for gather is made
TEST(type_prop, DISABLED_gather_7_indices_type_check)
{
PartialShape data_shape{1, 20, 20, 22, 22};
PartialShape indices_shape{1, 3};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::f32, indices_shape);
int64_t axis = 4;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 0;
try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "indices element_type check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Indices element type must be of an integral number type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
// disabled until a decision on type constraints for gather is made
TEST(type_prop, DISABLED_gather_7_axis_type_check)
{
PartialShape data_shape{1, 20, 20, 22, 22};
PartialShape indices_shape{1, 3};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
int64_t axis = 4;
auto A = make_shared<op::Constant>(element::f32, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 0;
try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "axis element_type check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Axis element type must be of an integral number type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}