[nG] Gather-8 shell (#6171)

* gather-8 nGraph shell

* added opset8 tbl; added visit_attribute test

* corrected opset tbl
This commit is contained in:
Pavel Esir 2021-06-17 13:02:07 +03:00 committed by GitHub
parent 8390f40788
commit ccd568a2d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 691 additions and 8 deletions

View File

@ -61,5 +61,32 @@ namespace ngraph
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v7
namespace v8
{
/// \brief Gather slices from axis of params according to indices
class NGRAPH_API Gather : public op::util::GatherBase
{
public:
NGRAPH_RTTI_DECLARATION;
Gather() = default;
/// \param data The tensor from which slices are gathered
/// \param indices Tensor with indexes to gather
/// \param axis The tensor is a dimension index to gather data from
/// \param batch_dims The number of batch dimension in data and indices tensors.
Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims = 0);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
int64_t get_batch_dims() const;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace v8
} // namespace op
} // namespace ngraph

View File

@ -0,0 +1,17 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "ngraph/ops.hpp"
namespace ngraph
{
namespace opset8
{
#define NGRAPH_OP(a, b) using b::a;
#include "ngraph/opsets/opset8_tbl.hpp"
#undef NGRAPH_OP
} // namespace opset8
} // namespace ngraph

View File

@ -0,0 +1,179 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(Abs, ngraph::op::v0)
NGRAPH_OP(Acos, ngraph::op::v0)
NGRAPH_OP(Add, ngraph::op::v1)
NGRAPH_OP(Asin, ngraph::op::v0)
NGRAPH_OP(Atan, ngraph::op::v0)
NGRAPH_OP(AvgPool, ngraph::op::v1)
NGRAPH_OP(BatchNormInference, ngraph::op::v5)
NGRAPH_OP(BinaryConvolution, ngraph::op::v1)
NGRAPH_OP(Broadcast, ngraph::op::v3)
NGRAPH_OP(Bucketize, ngraph::op::v3)
NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
NGRAPH_OP(Ceiling, ngraph::op::v0)
NGRAPH_OP(Clamp, ngraph::op::v0)
NGRAPH_OP(Concat, ngraph::op::v0)
NGRAPH_OP(Constant, ngraph::op)
NGRAPH_OP(Convert, ngraph::op::v0)
NGRAPH_OP(ConvertLike, ngraph::op::v1)
NGRAPH_OP(Convolution, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(Cos, ngraph::op::v0)
NGRAPH_OP(Cosh, ngraph::op::v0)
NGRAPH_OP(CumSum, ngraph::op::v0)
NGRAPH_OP(DeformableConvolution, ngraph::op::v1)
NGRAPH_OP(DeformablePSROIPooling, ngraph::op::v1)
NGRAPH_OP(DepthToSpace, ngraph::op::v0)
NGRAPH_OP(DetectionOutput, ngraph::op::v0)
NGRAPH_OP(Divide, ngraph::op::v1)
NGRAPH_OP(Elu, ngraph::op::v0)
NGRAPH_OP(Erf, ngraph::op::v0)
NGRAPH_OP(Equal, ngraph::op::v1)
NGRAPH_OP(Exp, ngraph::op::v0)
NGRAPH_OP(ExtractImagePatches, ngraph::op::v3)
NGRAPH_OP(FakeQuantize, ngraph::op::v0)
NGRAPH_OP(Floor, ngraph::op::v0)
NGRAPH_OP(FloorMod, ngraph::op::v1)
NGRAPH_OP(GatherTree, ngraph::op::v1)
NGRAPH_OP(Greater, ngraph::op::v1)
NGRAPH_OP(GreaterEqual, ngraph::op::v1)
NGRAPH_OP(GroupConvolution, ngraph::op::v1)
NGRAPH_OP(GroupConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(GRN, ngraph::op::v0)
NGRAPH_OP(HardSigmoid, ngraph::op::v0)
NGRAPH_OP(Less, ngraph::op::v1)
NGRAPH_OP(LessEqual, ngraph::op::v1)
NGRAPH_OP(Log, ngraph::op::v0)
NGRAPH_OP(LogicalAnd, ngraph::op::v1)
NGRAPH_OP(LogicalNot, ngraph::op::v1)
NGRAPH_OP(LogicalOr, ngraph::op::v1)
NGRAPH_OP(LogicalXor, ngraph::op::v1)
NGRAPH_OP(LRN, ngraph::op::v0)
NGRAPH_OP(LSTMCell, ngraph::op::v4)
NGRAPH_OP(MatMul, ngraph::op::v0)
NGRAPH_OP(MaxPool, ngraph::op::v1)
NGRAPH_OP(Maximum, ngraph::op::v1)
NGRAPH_OP(Minimum, ngraph::op::v1)
NGRAPH_OP(Mod, ngraph::op::v1)
NGRAPH_OP(Multiply, ngraph::op::v1)
NGRAPH_OP(Negative, ngraph::op::v0)
NGRAPH_OP(NormalizeL2, ngraph::op::v0)
NGRAPH_OP(NotEqual, ngraph::op::v1)
NGRAPH_OP(OneHot, ngraph::op::v1)
NGRAPH_OP(PRelu, ngraph::op::v0)
NGRAPH_OP(PSROIPooling, ngraph::op::v0)
NGRAPH_OP(Pad, ngraph::op::v1)
NGRAPH_OP(Parameter, ngraph::op::v0)
NGRAPH_OP(Power, ngraph::op::v1)
NGRAPH_OP(PriorBox, ngraph::op::v0)
NGRAPH_OP(PriorBoxClustered, ngraph::op::v0)
NGRAPH_OP(Proposal, ngraph::op::v4)
NGRAPH_OP(Range, ngraph::op::v4)
NGRAPH_OP(Relu, ngraph::op::v0)
NGRAPH_OP(ReduceMax, ngraph::op::v1)
NGRAPH_OP(ReduceLogicalAnd, ngraph::op::v1)
NGRAPH_OP(ReduceLogicalOr, ngraph::op::v1)
NGRAPH_OP(ReduceMean, ngraph::op::v1)
NGRAPH_OP(ReduceMin, ngraph::op::v1)
NGRAPH_OP(ReduceProd, ngraph::op::v1)
NGRAPH_OP(ReduceSum, ngraph::op::v1)
NGRAPH_OP(RegionYolo, ngraph::op::v0)
NGRAPH_OP(ReorgYolo, ngraph::op::v0)
NGRAPH_OP(Reshape, ngraph::op::v1)
NGRAPH_OP(Result, ngraph::op::v0)
NGRAPH_OP(ReverseSequence, ngraph::op::v0)
NGRAPH_OP(ROIPooling, ngraph::op::v0)
NGRAPH_OP(ScatterNDUpdate, ngraph::op::v3)
NGRAPH_OP(Select, ngraph::op::v1)
NGRAPH_OP(Selu, ngraph::op::v0)
NGRAPH_OP(Sign, ngraph::op::v0)
NGRAPH_OP(Sigmoid, ngraph::op::v0)
NGRAPH_OP(Sin, ngraph::op::v0)
NGRAPH_OP(Sinh, ngraph::op::v0)
NGRAPH_OP(Softmax, ngraph::op::v1)
NGRAPH_OP(Sqrt, ngraph::op::v0)
NGRAPH_OP(SpaceToDepth, ngraph::op::v0)
NGRAPH_OP(Split, ngraph::op::v1)
NGRAPH_OP(SquaredDifference, ngraph::op::v0)
NGRAPH_OP(Squeeze, ngraph::op::v0)
NGRAPH_OP(StridedSlice, ngraph::op::v1)
NGRAPH_OP(Subtract, ngraph::op::v1)
NGRAPH_OP(Tan, ngraph::op::v0)
NGRAPH_OP(Tanh, ngraph::op::v0)
NGRAPH_OP(TensorIterator, ngraph::op::v0)
NGRAPH_OP(Tile, ngraph::op::v0)
NGRAPH_OP(Transpose, ngraph::op::v1)
NGRAPH_OP(Unsqueeze, ngraph::op::v0)
NGRAPH_OP(VariadicSplit, ngraph::op::v1)
// New operations added in opset2
NGRAPH_OP(BatchToSpace, ngraph::op::v1)
NGRAPH_OP(SpaceToBatch, ngraph::op::v1)
// New operations added in opset3
NGRAPH_OP(EmbeddingBagPackedSum, ngraph::op::v3)
NGRAPH_OP(EmbeddingSegmentsSum, ngraph::op::v3)
NGRAPH_OP(EmbeddingBagOffsetsSum, ngraph::op::v3)
NGRAPH_OP(GRUCell, ngraph::op::v3)
NGRAPH_OP(NonZero, ngraph::op::v3)
NGRAPH_OP(RNNCell, ngraph::op::v0)
NGRAPH_OP(ROIAlign, ngraph::op::v3)
NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3)
NGRAPH_OP(ScatterUpdate, ngraph::op::v3)
NGRAPH_OP(ShuffleChannels, ngraph::op::v0)
NGRAPH_OP(ShapeOf, ngraph::op::v3)
NGRAPH_OP(TopK, ngraph::op::v3)
// New operations added in opset4
NGRAPH_OP(Acosh, ngraph::op::v3)
NGRAPH_OP(Asinh, ngraph::op::v3)
NGRAPH_OP(Atanh, ngraph::op::v3)
NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(HSwish, ngraph::op::v4)
NGRAPH_OP(Interpolate, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
NGRAPH_OP(ReduceL1, ngraph::op::v4)
NGRAPH_OP(ReduceL2, ngraph::op::v4)
NGRAPH_OP(SoftPlus, ngraph::op::v4)
NGRAPH_OP(Swish, ngraph::op::v4)
// New operations added in opset5
NGRAPH_OP(GatherND, ngraph::op::v5)
NGRAPH_OP(GRUSequence, ngraph::op::v5)
NGRAPH_OP(HSigmoid, ngraph::op::v5)
NGRAPH_OP(LogSoftmax, ngraph::op::v5)
NGRAPH_OP(Loop, ngraph::op::v5)
NGRAPH_OP(LSTMSequence, ngraph::op::v5)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v5)
NGRAPH_OP(RNNSequence, ngraph::op::v5)
NGRAPH_OP(Round, ngraph::op::v5)
// New operations added in opset6
NGRAPH_OP(CTCGreedyDecoderSeqLen, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronDetectionOutput, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronGenerateProposalsSingleImage, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronPriorGridGenerator, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronROIFeatureExtractor, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronTopKROIs, ngraph::op::v6)
NGRAPH_OP(GatherElements, ngraph::op::v6)
NGRAPH_OP(MVN, ngraph::op::v6)
NGRAPH_OP(Assign, ngraph::op::v6) // new version
NGRAPH_OP(ReadValue, ngraph::op::v6) // new version
// New operations added in opset7
NGRAPH_OP(DFT, ngraph::op::v7)
NGRAPH_OP(Einsum, ngraph::op::v7)
NGRAPH_OP(Gelu, ngraph::op::v7)
NGRAPH_OP(IDFT, ngraph::op::v7)
NGRAPH_OP(Roll, ngraph::op::v7)
// New operations added in opset8
NGRAPH_OP(Gather, ngraph::op::v8)

View File

@ -86,3 +86,50 @@ shared_ptr<Node> op::v7::Gather::clone_with_new_inputs(const OutputVector& new_a
check_new_args_count(this, new_args);
return make_shared<v7::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
}
NGRAPH_RTTI_DEFINITION(op::v8::Gather, "Gather", 8, op::util::GatherBase);
op::v8::Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
: GatherBase(data, indices, axis, batch_dims)
{
constructor_validate_and_infer_types();
}
void op::v8::Gather::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v8_Gather_validate_and_infer_types);
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).is_integral_number(),
"Indices element type must be of an integral number type.");
NODE_VALIDATION_CHECK(this,
get_input_element_type(2).is_integral_number(),
"Axis element type must be of an integral number type.");
op::util::GatherBase::validate_and_infer_types();
}
int64_t op::v8::Gather::get_batch_dims() const
{
if (m_batch_dims < 0 && get_input_partial_shape(1).rank().is_static())
return m_batch_dims + get_input_partial_shape(1).rank().get_length();
else
return m_batch_dims;
}
bool ngraph::op::v8::Gather::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v8_Gather_visit_attributes);
visitor.on_attribute("batch_dims", m_batch_dims);
return true;
}
shared_ptr<Node> op::v8::Gather::clone_with_new_inputs(const OutputVector& new_args) const
{
NGRAPH_OP_SCOPE(v8_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v8::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
}

View File

@ -13,7 +13,7 @@ using namespace ngraph;
// ------------------------------ V1 ------------------------------
TEST(type_prop, gather_axis_0)
TEST(type_prop, gather_v1_axis_0)
{
Shape params_shape{3, 2};
Shape indices_shape{2, 2};
@ -27,7 +27,7 @@ TEST(type_prop, gather_axis_0)
ASSERT_EQ(G->get_axis(), 0);
}
TEST(type_prop, gather_7_uint8)
TEST(type_prop, gather_v1_uint8)
{
// Gather_1 must allow even if indices is not int32/int64
PartialShape data_shape{3, 2};
@ -44,7 +44,7 @@ TEST(type_prop, gather_7_uint8)
ASSERT_EQ(G->get_axis(), 0);
}
TEST(type_prop, gather_7_float32)
TEST(type_prop, gather_v1_float32)
{
// Gather_1 should allow non int32/int64 indices
PartialShape data_shape{3, 2};
@ -335,7 +335,7 @@ TEST(type_prop, gather_7_axis_not_set_positive_batch_dims)
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
// --------------------- Negative tests ------------------------------
// --------------------- V7 Negative tests ------------------------------
TEST(type_prop, gather_7_incorrect_axis_shape)
{
@ -470,8 +470,7 @@ TEST(type_prop, gather_7_batch_dims_less_indices_rank_check)
}
}
// disabled until decision of type constrains for gather
TEST(type_prop, DISABLED_gather_7_indices_type_check)
TEST(type_prop, gather_7_indices_type_check)
{
PartialShape data_shape{1, 20, 20, 22, 22};
PartialShape indices_shape{1, 3};
@ -500,8 +499,7 @@ TEST(type_prop, DISABLED_gather_7_indices_type_check)
}
}
// disabled until decision of type constrains for gather
TEST(type_prop, DISABLED_gather_7_axis_type_check)
TEST(type_prop, gather_7_axis_type_check)
{
PartialShape data_shape{1, 20, 20, 22, 22};
PartialShape indices_shape{1, 3};
@ -529,3 +527,402 @@ TEST(type_prop, DISABLED_gather_7_axis_type_check)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
// ------------------------------ V8 ------------------------------
TEST(type_prop, gather_v8_axis_0)
{
PartialShape data_shape{3, 2};
PartialShape indices_shape{2, 2};
PartialShape out_shape{2, 2, 2};
int64_t batch_dims = 0;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto A = op::Constant::create(element::i64, Shape{}, {0});
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
ASSERT_EQ(G->get_axis(), 0);
}
TEST(type_prop, gather_v8_axis_1)
{
PartialShape data_shape{3, 3};
PartialShape indices_shape{1, 2};
PartialShape out_shape{3, 1, 2};
int64_t axis = 1;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto A = op::Constant::create(element::i64, Shape{}, {axis});
auto G = make_shared<op::v8::Gather>(D, I, A);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
ASSERT_EQ(G->get_axis(), 1);
}
TEST(type_prop, gather_v8_negative_axis)
{
PartialShape data_shape{5, 6, 7};
PartialShape indices_shape{4};
PartialShape out_shape{5, 4, 7};
int64_t axis = -2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v8::Gather>(D, I, A);
ASSERT_EQ(G->get_axis(), 1);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_v8_dynamic_pshape_batch_dims_1_axis_1)
{
PartialShape data_shape{Dimension(1, 7), 20, 20};
PartialShape indices_shape{Dimension(7, 10), 3, 8};
PartialShape out_shape{7, 3, 8, 20};
int64_t axis = 1;
int64_t batch_dims = 1;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_v8_dynamic_pshape_batch_dims_1_axis_3)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
PartialShape out_shape{7, Dimension(1, 3), 200, Dimension(2, 10), 3, 8};
int64_t axis = 3;
int64_t batch_dims = 1;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_v8_dynamic_2d_pshape_batch_dim)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
PartialShape out_shape{7, Dimension(2, 3), 3, 8, 400};
int64_t axis = 2;
int64_t batch_dims = 2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_v8_dynamic_2d_pshape_batch_dim_axis_3)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
PartialShape out_shape{7, Dimension(2, 3), 200, 3, 8};
int64_t axis = 3;
int64_t batch_dims = 2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_v8_dynamic_rank)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape = PartialShape::dynamic(Rank(3, 5));
PartialShape out_shape = PartialShape::dynamic(Rank(4, 6));
int64_t axis = 3;
int64_t batch_dims = 2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_v8_axis_boundcheck_for_dynamic_data_rank)
{
PartialShape data_shape = PartialShape::dynamic();
PartialShape indices_shape{7, 3, 8};
PartialShape out_shape = PartialShape::dynamic();
int64_t axis = 3;
int64_t batch_dims = 2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_v8_dynamic_rank_negative_batch_dims)
{
PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
PartialShape indices_shape = PartialShape::dynamic(Rank(3, 5));
PartialShape out_shape = PartialShape::dynamic(Rank(3, 5));
int64_t axis = 3;
int64_t batch_dims = -2;
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_v8_axis_not_set)
{
PartialShape data_shape{1, 1, 200, 400};
PartialShape indices_shape{2, 2};
// default batch_dims = 0
PartialShape out_shape = PartialShape::dynamic(5); // out_rank = data_rank + indices_rank - 1 - batch_dims
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Parameter>(element::i32, Shape{1});
auto G = make_shared<op::v8::Gather>(D, I, A);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
TEST(type_prop, gather_v8_axis_not_set_positive_batch_dims)
{
PartialShape data_shape{2, 1, 200, 400};
PartialShape indices_shape{2, 2};
int64_t batch_dims = 1;
PartialShape out_shape = PartialShape({2,
Dimension::dynamic(),
Dimension::dynamic(),
Dimension::dynamic()});
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
auto A = make_shared<op::Parameter>(element::i32, Shape{1});
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
}
// --------------------- V8 Negative tests ------------------------------
TEST(type_prop, gather_v8_incorrect_axis_shape)
{
auto D = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto I = make_shared<op::Parameter>(element::i64, Shape{4});
auto A = make_shared<op::Parameter>(element::i64, Shape{2});
try
{
auto G = make_shared<op::v8::Gather>(D, I, A);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect A input shape";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Axis input must be scalar or have 1 element"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_v8_axis_out_of_input_rank)
{
auto D = make_shared<op::Parameter>(element::f32, Shape{5, 6});
auto I = make_shared<op::Parameter>(element::i64, Shape{4});
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{2});
int64_t batch_dims = 0;
try
{
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "axis check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("Normalized axis must be >= 0 and < data_rank. But instead got"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_v8_dynamic_batch_dims_inconsistent)
{
PartialShape data_shape{Dimension(1, 7), 20, 20};
PartialShape indices_shape{Dimension(8, 10), 3, 8};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
int64_t axis = 1;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 1;
try
{
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("data and indices must have equal or intersecting sizes until batch_dims"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_v8_batch_dims_less_check)
{
PartialShape data_shape{1, 3, 20};
PartialShape indices_shape{1, 3, 8};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
int64_t axis = 1;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 2;
try
{
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "batch_dims check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("After normalization batch_dims must be <= axis. But instead got: batch_dims ="));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_v8_batch_dims_less_indices_rank_check)
{
PartialShape data_shape{1, 20, 20, 22, 22};
PartialShape indices_shape{1, 3};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
int64_t axis = 4;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 3;
try
{
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "batch_dims check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("batch_dims must be <= indices_rank"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_v8_indices_type_check)
{
PartialShape data_shape{1, 20, 20, 22, 22};
PartialShape indices_shape{1, 3};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::f32, indices_shape);
int64_t axis = 4;
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 0;
try
{
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "indices element_type check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Indices element type must be of an integral number type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_v8_axis_type_check)
{
PartialShape data_shape{1, 20, 20, 22, 22};
PartialShape indices_shape{1, 3};
auto D = make_shared<op::Parameter>(element::f32, data_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
int64_t axis = 4;
auto A = make_shared<op::Constant>(element::f32, Shape{1}, vector<int64_t>{axis});
int64_t batch_dims = 0;
try
{
auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
// Should have thrown, so fail if it didn't
FAIL() << "axis element_type check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Axis element type must be of an integral number type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}

View File

@ -7,6 +7,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset7.hpp"
#include "ngraph/opsets/opset8.hpp"
#include "util/visitor.hpp"
@ -29,3 +30,18 @@ TEST(attributes, gather_v7_op)
EXPECT_EQ(g_gather->get_batch_dims(), gather->get_batch_dims());
}
TEST(attributes, gather_v8_op)
{
NodeBuilder::get_ops().register_factory<opset8::Gather>();
auto data = make_shared<opset1::Parameter>(element::i32, Shape{2, 3, 4});
auto indices = make_shared<opset1::Parameter>(element::i32, Shape{2});
auto axis = make_shared<opset1::Constant>(element::i32, Shape{}, 2);
int64_t batch_dims = 1;
auto gather = make_shared<opset8::Gather>(data, indices, axis, batch_dims);
NodeBuilder builder(gather);
auto g_gather = as_type_ptr<opset8::Gather>(builder.create());
EXPECT_EQ(g_gather->get_batch_dims(), gather->get_batch_dims());
}