nGraph shell implementation of Gather-7 (#4885)
* nGraph shell implementation of Gather-7
* review comments applied
* style_apply
* applied @ilyachur's comments
* style-apply
* applied @popovaan's comments
* changed ieFuncTest for Gather (now created from op version instead of opset); added check for batch_dims
* clang_format_fix and some other corrections
* returned opset3::Gather back in ieFuncTests
* added `constexpr` to `AXIS_NOT_SET_VALUE` as @vgavrilo suggested
* removed AXIS_NOT_SET_VALUE and added proper support when axis is not specified
* clang_format_fix_all
* applied review comments: added support for dynamic axis
* applied review comments, minor corrections in gather_elements
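A minimal usage sketch (not part of the commit) of the new opset7 Gather, based on the constructor signature and shape-inference formula introduced in this diff; the shapes, element types, and values below are illustrative assumptions:

```cpp
#include <iostream>
#include <ngraph/ngraph.hpp>

using namespace ngraph;

int main()
{
    // Illustrative inputs: data {2, 5, 6, 7}, indices {2, 3}, axis = 2, batch_dims = 1.
    auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{2, 5, 6, 7});
    auto indices = std::make_shared<op::Parameter>(element::i32, PartialShape{2, 3});
    auto axis = op::Constant::create(element::i64, Shape{}, {2});

    // New in opset7: batch_dims is the number of leading batch dimensions shared by
    // data and indices; it must be <= axis and < the indices rank.
    auto gather = std::make_shared<op::v7::Gather>(data, indices, axis, 1);

    // The output shape follows
    // data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] + data.shape[axis + 1:],
    // which here evaluates to {2, 5, 3, 7}.
    std::cout << gather->get_output_partial_shape(0) << std::endl;
    return 0;
}
```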
This commit is contained in:
parent dca99aed64
commit 1a3e9abfbe
@@ -52,5 +52,38 @@ namespace ngraph
const HostTensorVector& inputs) const;
};
} // namespace v1

namespace v7
{
/// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
Gather() = default;

/// \param data The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor holding the index of the dimension to gather data from
/// \param batch_dims The number of batch dimensions in the data and indices tensors
Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims = 0);

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;

std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;

int64_t get_batch_dims() const;
int64_t get_axis() const;
bool is_axis_set() const;

private:
int64_t m_batch_dims = 0;
};
} // namespace v7
} // namespace op
} // namespace ngraph

@@ -42,7 +42,7 @@ NGRAPH_OP(ExtractImagePatches, ngraph::op::v3)
NGRAPH_OP(FakeQuantize, ngraph::op::v0)
NGRAPH_OP(Floor, ngraph::op::v0)
NGRAPH_OP(FloorMod, ngraph::op::v1)
NGRAPH_OP(Gather, ngraph::op::v1)
NGRAPH_OP(Gather, ngraph::op::v7)
NGRAPH_OP(GatherTree, ngraph::op::v1)
NGRAPH_OP(Greater, ngraph::op::v1)
NGRAPH_OP(GreaterEqual, ngraph::op::v1)

@@ -126,6 +126,189 @@ shared_ptr<Node> op::v1::Gather::clone_with_new_inputs(const OutputVector& new_a
return make_shared<v1::Gather>(new_args.at(PARAMS), new_args.at(INDICES), new_args.at(AXIS));
}

NGRAPH_RTTI_DEFINITION(op::v7::Gather, "Gather", 7);

op::v7::Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
: Op({data, indices, axis})
, m_batch_dims(batch_dims)
{
constructor_validate_and_infer_types();
}

bool ngraph::op::v7::Gather::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v7_Gather_visit_attributes);
visitor.on_attribute("batch_dims", m_batch_dims);
return true;
}

void op::v7::Gather::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v7_Gather_validate_and_infer_types);
const auto& data_type = get_input_element_type(0);
const auto& indices_type = get_input_element_type(1);

NODE_VALIDATION_CHECK(this,
indices_type == element::Type_t::i32 ||
indices_type == element::Type_t::i64,
"indices must be of int32 or int64 type. But instead got: ",
indices_type);

const auto& data_pshape = get_input_partial_shape(0);
const auto& indices_pshape = get_input_partial_shape(1);
const auto& axis_pshape = get_input_partial_shape(2);
auto data_rank = data_pshape.rank();
auto indices_rank = indices_pshape.rank();
auto axis_rank = axis_pshape.rank();

if (axis_rank.is_static() && axis_pshape.is_static())
{
const auto axis_is_scalar = axis_rank.get_length() == 0;
const auto axis_has_one_elem =
axis_rank.get_length() == 1 && axis_pshape[0].get_length() == 1;
NODE_VALIDATION_CHECK(
this,
axis_is_scalar || axis_has_one_elem,
"Axes input must be scalar or have 1 element. But instead got axis_shape = ",
axis_pshape);
}

int64_t batch_dims = get_batch_dims(); // will not be converted to positive if axis is not set
if (is_axis_set())
{
int64_t axis = get_axis();
NODE_VALIDATION_CHECK(this,
batch_dims <= axis,
"The batch_dims <= axis. But instead got: batch_dims = ",
|
||||
batch_dims,
", axis = ",
axis);

if (data_rank.is_static())
{
NODE_VALIDATION_CHECK(this,
axis >= 0 && axis < data_rank.get_length(),
"The axis must be => 0 and < data_rank. But instead got axis = ",
|
||||
axis,
" data_rank = ",
data_rank.get_length());
}
}

if (indices_rank.is_static() && batch_dims >= 0)
{
NODE_VALIDATION_CHECK(
this,
batch_dims < indices_rank.get_length(),
"The batch_dims must be < indices_rank. But instead got: batch_dims = ",
batch_dims,
", indices_rank = ",
indices_rank.get_length());
}

if (data_rank.is_static() && indices_rank.is_static())
{
if (batch_dims >= 0)
{
auto out_rank = data_rank.get_length() + indices_rank.get_length() - 1 - batch_dims;
PartialShape output_pshape = PartialShape::dynamic(out_rank);

// implementation of out_shape formula
// data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] +
// data.shape[axis + 1:]
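// e.g. data_shape = {2, 5, 6, 7}, indices_shape = {2, 3}, axis = 2, batch_dims = 1 give
// out_shape = {2} + {5} + {3} + {7} = {2, 5, 3, 7}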
int i = 0;
for (; i < batch_dims; i++)
{
NODE_VALIDATION_CHECK(this,
data_pshape[i].compatible(indices_pshape[i]),
"Shapes ",
data_pshape,
" and ",
indices_pshape,
" are not consistent. data and indices must have equal or "
"intersecting sizes until batch_dims");

output_pshape[i] = data_pshape[i] & indices_pshape[i];
}

if (is_axis_set())
{
int64_t axis = get_axis();
for (; i < axis; i++)
{
output_pshape[i] = data_pshape[i];
}
for (; i < axis + indices_rank.get_length() - batch_dims; i++)
{
output_pshape[i] = indices_pshape[batch_dims - axis + i];
}
for (; i < out_rank; i++)
{
output_pshape[i] = data_pshape[batch_dims + 1 - indices_rank.get_length() + i];
}
}

set_output_type(0, data_type, output_pshape);
}
else if (batch_dims < 0)
{
// batch_dims can be negative only if axis is not set;
// once the axis value becomes known the negative batch_dims will be resolved,
// and the resolved batch_dims value will be within [0, data_rank] && [0, indices_rank]
int64_t max_rank = data_rank.get_length() + indices_rank.get_length() - 1;
int64_t min_rank = max_rank - max(data_rank.get_length(), indices_rank.get_length());
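// e.g. data_rank = 3 and indices_rank = 3 give max_rank = 5 and min_rank = 2,
// so the output rank lies in the interval [2, 5]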

set_output_type(0, data_type, PartialShape::dynamic(Dimension(min_rank, max_rank)));
}
}
else
{
set_output_type(0, data_type, PartialShape::dynamic());
}
}

int64_t op::v7::Gather::get_axis() const
{
const auto& const_op = get_constant_from_source(input_value(2));
int64_t axis = const_op->cast_vector<int64_t>()[0];
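// a negative axis is normalized against the data rank below,
// e.g. axis = -2 with data_rank = 3 resolves to axis = 1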
if (axis < 0)
{
const auto& data_rank = get_input_partial_shape(0).rank();
if (data_rank.is_static())
{
axis += data_rank.get_length();
}
}
return axis;
}

int64_t op::v7::Gather::get_batch_dims() const
{
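// a negative m_batch_dims is resolved relative to the axis,
// e.g. m_batch_dims = -1 with axis = 3 yields batch_dims = 2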
if (m_batch_dims < 0 && is_axis_set())
return get_axis() + m_batch_dims;
else
return m_batch_dims;
}

bool op::v7::Gather::is_axis_set() const
{
const auto& axes_constant = get_constant_from_source(input_value(2));
if (axes_constant)
return true;
else
return false;
}

shared_ptr<Node> op::v7::Gather::clone_with_new_inputs(const OutputVector& new_args) const
{
NGRAPH_OP_SCOPE(v7_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v7::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
}

namespace gather
{
template <element::Type_t ET>

@@ -87,13 +87,12 @@ void op::v6::GatherElements::validate_and_infer_types()
{
if (i != axis)
{
// if size of the current axis of indices is unknown it will retrieve it from data
// e.g., if data_shape = {4, 4, ?} indices_shape = {1, ?, 5} and axis = 0
// if size of the current dimension of indices is unknown it will be retrieved from data
// e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0
// (and if intervals intersect) then output_pshape will be {1, 4, 5}
Dimension curr_dim = data_pshape[i] & indices_pshape[i];

NODE_VALIDATION_CHECK(this,
!curr_dim.get_interval().empty(),
data_pshape[i].compatible(indices_pshape[i]),
"Shapes ",
data_pshape,
" and ",
@@ -102,7 +101,7 @@ void op::v6::GatherElements::validate_and_infer_types()
"intersecting sizes, except for axis ",
m_axis);

output_pshape[i] = curr_dim;
output_pshape[i] = data_pshape[i] & indices_pshape[i];
}
}
set_output_type(0, data_type, output_pshape);

@@ -11,6 +11,8 @@ NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;

// ------------------------------ V1 ------------------------------

TEST(type_prop, gather_axis_0)
{
Shape params_shape{3, 2};
@@ -92,3 +94,329 @@ TEST(type_prop, gather_v1_negative_axis)
auto gather_v1 = make_shared<op::v1::Gather>(params, indices, axis_node);
ASSERT_EQ(gather_v1->get_axis(), 1);
}

// ------------------------------ V7 ------------------------------

TEST(type_prop, gather_7_axis_0)
{
PartialShape data_shape{3, 2};
PartialShape indices_shape{2, 2};
PartialShape out_shape{2, 2, 2};
int64_t batch_dims = 0;

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto A = op::Constant::create(element::i64, Shape{}, {0});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
ASSERT_EQ(G->get_axis(), 0);
}

TEST(type_prop, gather_7_axis_1)
{
PartialShape data_shape{3, 3};
PartialShape indices_shape{1, 2};
PartialShape out_shape{3, 1, 2};
int64_t axis = 1;

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto A = op::Constant::create(element::i64, Shape{}, {axis});
auto G = make_shared<op::v7::Gather>(D, I, A);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
ASSERT_EQ(G->get_axis(), 1);
}

TEST(type_prop, gather_7_negative_axis)
{
PartialShape data_shape{5, 6, 7};
PartialShape indices_shape{4};
PartialShape out_shape{5, 4, 7};
int64_t axis = -2;

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A);

ASSERT_EQ(G->get_axis(), 1);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}

TEST(type_prop, gather_7_batch_dims_1_axis_3)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
PartialShape out_shape{7, Dimension(1, 3), 200, Dimension(2, 10), 3, 8};
int64_t axis = 3;
int64_t batch_dims = 1;

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}

TEST(type_prop, gather_7_dynamic_batch_dim)
{
PartialShape data_shape{Dimension(1, 7), 20, 20};
PartialShape indices_shape{Dimension(7, 10), 3, 8};
PartialShape out_shape{7, 3, 8, 20};
int64_t axis = 1;
int64_t batch_dims = 1;

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}

TEST(type_prop, gather_7_dynamic_2d_batch_dim)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
PartialShape out_shape{7, Dimension(2, 3), 3, 8, 400};
int64_t axis = 2;
int64_t batch_dims = 2;

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}

TEST(type_prop, gather_7_dynamic_2d_batch_dim_axis_3)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
PartialShape out_shape{7, Dimension(2, 3), 200, 3, 8};
int64_t axis = 3;
int64_t batch_dims = 2;

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}

TEST(type_prop, gather_7_dynamic_data_indices_rank)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape = PartialShape::dynamic();
PartialShape out_shape = PartialShape::dynamic();
int64_t axis = 3;
int64_t batch_dims = 2;

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}

TEST(type_prop, gather_7_axis_not_set)
{
PartialShape data_shape{1, 1, 200, 400};
PartialShape indices_shape{2, 2};
// default batch_dims = 0
PartialShape out_shape = PartialShape::dynamic(5); // out_rank = data_rank + indices_rank - 1 - batch_dims

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Parameter>(element::f32, Shape{1});
auto G = make_shared<op::v7::Gather>(D, I, A);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}

TEST(type_prop, gather_7_axis_not_set_positive_batch_dims)
{
PartialShape data_shape{2, 1, 200, 400};
PartialShape indices_shape{2, 2};
int64_t batch_dims = 1;
PartialShape out_shape = PartialShape({2,
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()});

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Parameter>(element::f32, Shape{1});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}

TEST(type_prop, gather_7_axis_not_set_negative_batch)
{
PartialShape data_shape{1, 1, 200, 400};
PartialShape indices_shape{2, 2};
int64_t batch_dims = -1;
// negative batch_dims together with unknown axis could mean any value
// within the intervals [0, data_rank] && [0, indices_rank] so out_rank will be dynamic with the range
// out_rank = data_rank + indices_rank - 1 - interval(0, max(data_rank, indices_rank))
PartialShape out_shape = PartialShape::dynamic(Dimension(2, 5));

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Parameter>(element::f32, Shape{1});
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);

ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}

// --------------------- Negative tests ------------------------------

TEST(type_prop, gather_7_incorrect_axis_shape)
{
auto D = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto I = make_shared<op::Parameter>(element::i64, Shape{4});
auto A = make_shared<op::Parameter>(element::i64, Shape{2});

try
{
auto G = make_shared<op::v7::Gather>(D, I, A);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect A input shape";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Axes input must be scalar or have 1 element"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}

TEST(type_prop, gather_7_axis_out_of_input_rank)
{
auto D = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto I = make_shared<op::Parameter>(element::i64, Shape{4});
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{2});
int64_t batch_dims = 0;
try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "axis check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("The axis must be => 0 and < data_rank. But instead got"));
|
||||
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}

TEST(type_prop, gather_7_dynamic_batch_dims_inconsistent)
{
PartialShape data_shape{Dimension(1, 7), 20, 20};
PartialShape indices_shape{Dimension(8, 10), 3, 8};

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
int64_t axis = 1;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 1;

try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("data and indices must have equal or intersecting sizes until batch_dims"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}

TEST(type_prop, gather_7_batch_dims_less_check)
{
PartialShape data_shape{1, 20, 20};
PartialShape indices_shape{1, 3, 8};

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
int64_t axis = 1;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 2;

try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "batch_dims check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("batch_dims <= axis. But instead got: batch_dims ="));
|
||||
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}

TEST(type_prop, gather_7_batch_dims_less_indices_rank_check)
{
PartialShape data_shape{1, 20, 20, 22, 22};
PartialShape indices_shape{1, 3, 8};

auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
int64_t axis = 4;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 3;

try
{
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "batch_dims check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("batch_dims must be < indices_rank"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}