nGraph shell implementation of Gather-7 (#4885)

* nGraph shell implementation of Gather-7

* review comments applied

* style_apply

* applied @ilyachur's comments

* style-apply

* applied @popovaan's comments

* changed ieFuncTest for Gather (now is created from op version instead of opset) added check for batch_dims

* clang_format_fix and some other corrections

* returned back opset3::Gather in ieFuncTests

* added `constexpr` to `AXIS_NOT_SET_VALUE` as @vgavrilo suggested

* removed AXIS_NOT_SET_VALUE and added proper support when axis is not specified

* clang_format_fix_all

* applied review comments: added support for dynamic axis

* applied review comments, minor corrections in gather_elements
This commit is contained in:
Pavel Esir 2021-04-01 19:38:27 +03:00 committed by GitHub
parent dca99aed64
commit 1a3e9abfbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 549 additions and 6 deletions

View File

@ -52,5 +52,38 @@ namespace ngraph
const HostTensorVector& inputs) const;
};
} // namespace v1
namespace v7
{
/// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
Gather() = default;
/// \param data The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from
/// \param batch_dims The number of batch dimension in data and indices tensors
Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims = 0);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
int64_t get_batch_dims() const;
int64_t get_axis() const;
bool is_axis_set() const;
private:
int64_t m_batch_dims = 0;
};
} // namespace v7
} // namespace op
} // namespace ngraph

View File

@ -42,7 +42,7 @@ NGRAPH_OP(ExtractImagePatches, ngraph::op::v3)
NGRAPH_OP(FakeQuantize, ngraph::op::v0)
NGRAPH_OP(Floor, ngraph::op::v0)
NGRAPH_OP(FloorMod, ngraph::op::v1)
NGRAPH_OP(Gather, ngraph::op::v1)
NGRAPH_OP(Gather, ngraph::op::v7)
NGRAPH_OP(GatherTree, ngraph::op::v1)
NGRAPH_OP(Greater, ngraph::op::v1)
NGRAPH_OP(GreaterEqual, ngraph::op::v1)

View File

@ -126,6 +126,189 @@ shared_ptr<Node> op::v1::Gather::clone_with_new_inputs(const OutputVector& new_a
return make_shared<v1::Gather>(new_args.at(PARAMS), new_args.at(INDICES), new_args.at(AXIS));
}
NGRAPH_RTTI_DEFINITION(op::v7::Gather, "Gather", 7);
op::v7::Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
: Op({data, indices, axis})
, m_batch_dims(batch_dims)
{
constructor_validate_and_infer_types();
}
bool ngraph::op::v7::Gather::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v7_Gather_visit_attributes);
visitor.on_attribute("batch_dims", m_batch_dims);
return true;
}
void op::v7::Gather::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v7_Gather_validate_and_infer_types);
const auto& data_type = get_input_element_type(0);
const auto& indices_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
indices_type == element::Type_t::i32 ||
indices_type == element::Type_t::i64,
"indices must be of int32 or int64 type. But instead got: ",
indices_type);
const auto& data_pshape = get_input_partial_shape(0);
const auto& indices_pshape = get_input_partial_shape(1);
const auto& axis_pshape = get_input_partial_shape(2);
auto data_rank = data_pshape.rank();
auto indices_rank = indices_pshape.rank();
auto axis_rank = axis_pshape.rank();
if (axis_rank.is_static() && axis_pshape.is_static())
{
const auto axis_is_scalar = axis_rank.get_length() == 0;
const auto axis_has_one_elem =
axis_rank.get_length() == 1 && axis_pshape[0].get_length() == 1;
NODE_VALIDATION_CHECK(
this,
axis_is_scalar || axis_has_one_elem,
"Axes input must be scalar or have 1 element. But instead got axis_shape = ",
axis_pshape);
}
int64_t batch_dims = get_batch_dims(); // will not be converted to positive if axis is not set
if (is_axis_set())
{
int64_t axis = get_axis();
NODE_VALIDATION_CHECK(this,
batch_dims <= axis,
"The batch_dims <= axis. But instead got: batch_dims = ",
batch_dims,
", axis = ",
axis);
if (data_rank.is_static())
{
NODE_VALIDATION_CHECK(this,
axis >= 0 && axis < data_rank.get_length(),
"The axis must be => 0 and < data_rank. But instead got axis = ",
axis,
" data_rank = ",
data_rank.get_length());
}
}
if (indices_rank.is_static() && batch_dims >= 0)
{
NODE_VALIDATION_CHECK(
this,
batch_dims < indices_rank.get_length(),
"The batch_dims must be < indices_rank. But instead got: batch_dims = ",
batch_dims,
", indices_rank = ",
indices_rank.get_length());
}
if (data_rank.is_static() && indices_rank.is_static())
{
if (batch_dims >= 0)
{
auto out_rank = data_rank.get_length() + indices_rank.get_length() - 1 - batch_dims;
PartialShape output_pshape = PartialShape::dynamic(out_rank);
// implementation of out_shape formula
// data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] +
// data.shape[axis + 1:]
int i = 0;
for (; i < batch_dims; i++)
{
NODE_VALIDATION_CHECK(this,
data_pshape[i].compatible(indices_pshape[i]),
"Shapes ",
data_pshape,
" and ",
indices_pshape,
" are not consistent. data and indices must have equal or "
"intersecting sizes until batch_dims");
output_pshape[i] = data_pshape[i] & indices_pshape[i];
}
if (is_axis_set())
{
int64_t axis = get_axis();
for (; i < axis; i++)
{
output_pshape[i] = data_pshape[i];
}
for (; i < axis + indices_rank.get_length() - batch_dims; i++)
{
output_pshape[i] = indices_pshape[batch_dims - axis + i];
}
for (; i < out_rank; i++)
{
output_pshape[i] = data_pshape[batch_dims + 1 - indices_rank.get_length() + i];
}
}
set_output_type(0, data_type, output_pshape);
}
else if (batch_dims < 0)
{
// batch_dims < 0 could be only if axis is not set
// as soon as axis value will arrive negative batch_dims should be resolved
// batch_dims value will be within [0, data_rank] && [0, indices_rank]
int64_t max_rank = data_rank.get_length() + indices_rank.get_length() - 1;
int64_t min_rank = max_rank - max(data_rank.get_length(), indices_rank.get_length());
set_output_type(0, data_type, PartialShape::dynamic(Dimension(min_rank, max_rank)));
}
}
else
{
set_output_type(0, data_type, PartialShape::dynamic());
}
}
int64_t op::v7::Gather::get_axis() const
{
const auto& const_op = get_constant_from_source(input_value(2));
int64_t axis = const_op->cast_vector<int64_t>()[0];
if (axis < 0)
{
const auto& data_rank = get_input_partial_shape(0).rank();
if (data_rank.is_static())
{
axis += data_rank.get_length();
}
}
return axis;
}
int64_t op::v7::Gather::get_batch_dims() const
{
if (m_batch_dims < 0 && is_axis_set())
return get_axis() + m_batch_dims;
else
return m_batch_dims;
}
bool op::v7::Gather::is_axis_set() const
{
const auto& axes_constant = get_constant_from_source(input_value(2));
if (axes_constant)
return true;
else
return false;
}
shared_ptr<Node> op::v7::Gather::clone_with_new_inputs(const OutputVector& new_args) const
{
NGRAPH_OP_SCOPE(v7_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v7::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
}
namespace gather
{
template <element::Type_t ET>

View File

@ -87,13 +87,12 @@ void op::v6::GatherElements::validate_and_infer_types()
{
if (i != axis)
{
// if size of the current axis of indices is unknown it will retrieve it from data
// e.g., if data_shape = {4, 4, ?} indices_shape = {1, ?, 5} and axis = 0
// if size of the current dimension of indices is unknown it will be retrieved from data
// e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0
// (and if intervals intersect) then output_pshape will be {1, 4, 5}
Dimension curr_dim = data_pshape[i] & indices_pshape[i];
NODE_VALIDATION_CHECK(this,
!curr_dim.get_interval().empty(),
data_pshape[i].compatible(indices_pshape[i]),
"Shapes ",
data_pshape,
" and ",
@ -102,7 +101,7 @@ void op::v6::GatherElements::validate_and_infer_types()
"intersecting sizes, except for axis ",
m_axis);
output_pshape[i] = curr_dim;
output_pshape[i] = data_pshape[i] & indices_pshape[i];
}
}
set_output_type(0, data_type, output_pshape);

View File

@ -11,6 +11,8 @@ NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
// ------------------------------ V1 ------------------------------
TEST(type_prop, gather_axis_0)
{
Shape params_shape{3, 2};
@ -92,3 +94,329 @@ TEST(type_prop, gather_v1_negative_axis)
auto gather_v1 = make_shared<op::v1::Gather>(params, indices, axis_node);
ASSERT_EQ(gather_v1->get_axis(), 1);
}
// ------------------------------ V7 ------------------------------
TEST(type_prop, gather_7_axis_0)
{
PartialShape data_shape{3, 2};
PartialShape indices_shape{2, 2};
PartialShape out_shape{2, 2, 2};
int64_t batch_dims = 0;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto A = op::Constant::create(element::i64, Shape{}, {0});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
ASSERT_EQ(G->get_axis(), 0);
}
TEST(type_prop, gather_7_axis_1)
{
PartialShape data_shape{3, 3};
PartialShape indices_shape{1, 2};
PartialShape out_shape{3, 1, 2};
int64_t axis = 1;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto A = op::Constant::create(element::i64, Shape{}, {axis});
auto G = make_shared<op::v7::Gather>(D, I, A);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
ASSERT_EQ(G->get_axis(), 1);
}
TEST(type_prop, gather_7_negative_axis)
{
PartialShape data_shape{5, 6, 7};
PartialShape indices_shape{4};
PartialShape out_shape{5, 4, 7};
int64_t axis = -2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A);
ASSERT_EQ(G->get_axis(), 1);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_7_batch_dims_1_axis_3)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
PartialShape out_shape{7, Dimension(1, 3), 200, Dimension(2, 10), 3, 8};
int64_t axis = 3;
int64_t batch_dims = 1;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_7_dynamic_batch_dim)
{
PartialShape data_shape{Dimension(1, 7), 20, 20};
PartialShape indices_shape{Dimension(7, 10), 3, 8};
PartialShape out_shape{7, 3, 8, 20};
int64_t axis = 1;
int64_t batch_dims = 1;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_7_dynamic_2d_batch_dim)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
PartialShape out_shape{7, Dimension(2, 3), 3, 8, 400};
int64_t axis = 2;
int64_t batch_dims = 2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_7_dynamic_2d_batch_dim_axis_3)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
PartialShape out_shape{7, Dimension(2, 3), 200, 3, 8};
int64_t axis = 3;
int64_t batch_dims = 2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_7_dynamic_data_indices_rank)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape = PartialShape::dynamic();
PartialShape out_shape = PartialShape::dynamic();
int64_t axis = 3;
int64_t batch_dims = 2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_7_axis_not_set)
{
PartialShape data_shape{1, 1, 200, 400};
PartialShape indices_shape{2, 2};
// default batch_dims = 0
PartialShape out_shape = PartialShape::dynamic(5); // out_rank = data_rank + indices_rank - 1 - batch_dims
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Parameter>(element::f32, Shape{1});
auto G = make_shared<op::v7::Gather>(D, I, A);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_7_axis_not_set_positive_batch_dims)
{
PartialShape data_shape{2, 1, 200, 400};
PartialShape indices_shape{2, 2};
int64_t batch_dims = 1;
PartialShape out_shape = PartialShape({2,
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()});
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Parameter>(element::f32, Shape{1});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_7_axis_not_set_negative_batch)
{
PartialShape data_shape{1, 1, 200, 400};
PartialShape indices_shape{2, 2};
int64_t batch_dims = -1;
// negative batch_dims together with unknown axis could mean any value
// within the intervals [0, data_rank] && [0, indices_rank] so out_rank will be dynamic with the range
// out_rank = data_rank + indices_rank - 1 - interval(0, max(data_rank, indices_rank))
PartialShape out_shape = PartialShape::dynamic(Dimension(2, 5));
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Parameter>(element::f32, Shape{1});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
// --------------------- Negative tests ------------------------------
TEST(type_prop, gather_7_incorrect_axis_shape)
{
auto D = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto I = make_shared<op::Parameter>(element::i64, Shape{4});
auto A = make_shared<op::Parameter>(element::i64, Shape{2});
try
{
auto G = make_shared<op::v7::Gather>(D, I, A);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect A input shape";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Axes input must be scalar or have 1 element"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_7_axis_out_of_input_rank)
{
auto D = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto I = make_shared<op::Parameter>(element::i64, Shape{4});
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{2});
int64_t batch_dims = 0;
try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "axis check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("The axis must be => 0 and < data_rank. But instead got"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_7_dynamic_batch_dims_inconsistent)
{
PartialShape data_shape{Dimension(1, 7), 20, 20};
PartialShape indices_shape{Dimension(8, 10), 3, 8};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
int64_t axis = 1;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 1;
try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("data and indices must have equal or intersecting sizes until batch_dims"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_7_batch_dims_less_check)
{
PartialShape data_shape{1, 20, 20};
PartialShape indices_shape{1, 3, 8};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
int64_t axis = 1;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 2;
try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "batch_dims check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("batch_dims <= axis. But instead got: batch_dims ="));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_7_batch_dims_less_indices_rank_check)
{
PartialShape data_shape{1, 20, 20, 22, 22};
PartialShape indices_shape{1, 3, 8};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
int64_t axis = 4;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 3;
try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "batch_dims check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("batch_dims must be < indices_rank"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}