From ccd568a2d2b0bc5b10e0b4469c1e9acb0681a26a Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Thu, 17 Jun 2021 13:02:07 +0300 Subject: [PATCH] [nG] Gather-8 shell (#6171) * gather-8 nGraph shell * added opset8 tbl; added visit_attribute test * corrected opset tbl --- ngraph/core/include/ngraph/op/gather.hpp | 27 ++ ngraph/core/include/ngraph/opsets/opset8.hpp | 17 + .../core/include/ngraph/opsets/opset8_tbl.hpp | 179 ++++++++ ngraph/core/src/op/gather.cpp | 47 ++ ngraph/test/type_prop/gather.cpp | 413 +++++++++++++++++- ngraph/test/visitors/op/gather.cpp | 16 + 6 files changed, 691 insertions(+), 8 deletions(-) create mode 100644 ngraph/core/include/ngraph/opsets/opset8.hpp create mode 100644 ngraph/core/include/ngraph/opsets/opset8_tbl.hpp diff --git a/ngraph/core/include/ngraph/op/gather.hpp b/ngraph/core/include/ngraph/op/gather.hpp index 9293521a7d0..61bbf33bc3a 100644 --- a/ngraph/core/include/ngraph/op/gather.hpp +++ b/ngraph/core/include/ngraph/op/gather.hpp @@ -61,5 +61,32 @@ namespace ngraph clone_with_new_inputs(const OutputVector& new_args) const override; }; } // namespace v7 + + namespace v8 + { + /// \brief Gather slices from axis of params according to indices + class NGRAPH_API Gather : public op::util::GatherBase + { + public: + NGRAPH_RTTI_DECLARATION; + Gather() = default; + + /// \param data The tensor from which slices are gathered + /// \param indices Tensor with indexes to gather + /// \param axis The tensor is a dimension index to gather data from + /// \param batch_dims The number of batch dimension in data and indices tensors. + Gather(const Output& data, + const Output& indices, + const Output& axis, + const int64_t batch_dims = 0); + + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + int64_t get_batch_dims() const; + + std::shared_ptr + clone_with_new_inputs(const OutputVector& new_args) const override; + }; + } // namespace v8 } // namespace op } // namespace ngraph diff --git a/ngraph/core/include/ngraph/opsets/opset8.hpp b/ngraph/core/include/ngraph/opsets/opset8.hpp new file mode 100644 index 00000000000..f31a3142cf6 --- /dev/null +++ b/ngraph/core/include/ngraph/opsets/opset8.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/ops.hpp" + +namespace ngraph +{ + namespace opset8 + { +#define NGRAPH_OP(a, b) using b::a; +#include "ngraph/opsets/opset8_tbl.hpp" +#undef NGRAPH_OP + } // namespace opset8 +} // namespace ngraph diff --git a/ngraph/core/include/ngraph/opsets/opset8_tbl.hpp b/ngraph/core/include/ngraph/opsets/opset8_tbl.hpp new file mode 100644 index 00000000000..0dbe077ddae --- /dev/null +++ b/ngraph/core/include/ngraph/opsets/opset8_tbl.hpp @@ -0,0 +1,179 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef NGRAPH_OP +#warning "NGRAPH_OP not defined" +#define NGRAPH_OP(x, y) +#endif + +NGRAPH_OP(Abs, ngraph::op::v0) +NGRAPH_OP(Acos, ngraph::op::v0) +NGRAPH_OP(Add, ngraph::op::v1) +NGRAPH_OP(Asin, ngraph::op::v0) +NGRAPH_OP(Atan, ngraph::op::v0) +NGRAPH_OP(AvgPool, ngraph::op::v1) +NGRAPH_OP(BatchNormInference, ngraph::op::v5) +NGRAPH_OP(BinaryConvolution, ngraph::op::v1) +NGRAPH_OP(Broadcast, ngraph::op::v3) +NGRAPH_OP(Bucketize, ngraph::op::v3) +NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0) +NGRAPH_OP(Ceiling, ngraph::op::v0) +NGRAPH_OP(Clamp, ngraph::op::v0) +NGRAPH_OP(Concat, ngraph::op::v0) +NGRAPH_OP(Constant, ngraph::op) +NGRAPH_OP(Convert, ngraph::op::v0) +NGRAPH_OP(ConvertLike, ngraph::op::v1) +NGRAPH_OP(Convolution, ngraph::op::v1) +NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1) +NGRAPH_OP(Cos, ngraph::op::v0) +NGRAPH_OP(Cosh, ngraph::op::v0) +NGRAPH_OP(CumSum, ngraph::op::v0) +NGRAPH_OP(DeformableConvolution, ngraph::op::v1) +NGRAPH_OP(DeformablePSROIPooling, ngraph::op::v1) +NGRAPH_OP(DepthToSpace, ngraph::op::v0) +NGRAPH_OP(DetectionOutput, ngraph::op::v0) +NGRAPH_OP(Divide, ngraph::op::v1) +NGRAPH_OP(Elu, ngraph::op::v0) +NGRAPH_OP(Erf, ngraph::op::v0) +NGRAPH_OP(Equal, ngraph::op::v1) +NGRAPH_OP(Exp, ngraph::op::v0) +NGRAPH_OP(ExtractImagePatches, ngraph::op::v3) +NGRAPH_OP(FakeQuantize, ngraph::op::v0) +NGRAPH_OP(Floor, ngraph::op::v0) +NGRAPH_OP(FloorMod, ngraph::op::v1) +NGRAPH_OP(GatherTree, ngraph::op::v1) +NGRAPH_OP(Greater, ngraph::op::v1) +NGRAPH_OP(GreaterEqual, ngraph::op::v1) +NGRAPH_OP(GroupConvolution, ngraph::op::v1) +NGRAPH_OP(GroupConvolutionBackpropData, ngraph::op::v1) +NGRAPH_OP(GRN, ngraph::op::v0) +NGRAPH_OP(HardSigmoid, ngraph::op::v0) +NGRAPH_OP(Less, ngraph::op::v1) +NGRAPH_OP(LessEqual, ngraph::op::v1) +NGRAPH_OP(Log, ngraph::op::v0) +NGRAPH_OP(LogicalAnd, ngraph::op::v1) +NGRAPH_OP(LogicalNot, ngraph::op::v1) +NGRAPH_OP(LogicalOr, ngraph::op::v1) +NGRAPH_OP(LogicalXor, ngraph::op::v1) +NGRAPH_OP(LRN, ngraph::op::v0) +NGRAPH_OP(LSTMCell, ngraph::op::v4) +NGRAPH_OP(MatMul, ngraph::op::v0) +NGRAPH_OP(MaxPool, ngraph::op::v1) +NGRAPH_OP(Maximum, ngraph::op::v1) +NGRAPH_OP(Minimum, ngraph::op::v1) +NGRAPH_OP(Mod, ngraph::op::v1) +NGRAPH_OP(Multiply, ngraph::op::v1) +NGRAPH_OP(Negative, ngraph::op::v0) +NGRAPH_OP(NormalizeL2, ngraph::op::v0) +NGRAPH_OP(NotEqual, ngraph::op::v1) +NGRAPH_OP(OneHot, ngraph::op::v1) +NGRAPH_OP(PRelu, ngraph::op::v0) +NGRAPH_OP(PSROIPooling, ngraph::op::v0) +NGRAPH_OP(Pad, ngraph::op::v1) +NGRAPH_OP(Parameter, ngraph::op::v0) +NGRAPH_OP(Power, ngraph::op::v1) +NGRAPH_OP(PriorBox, ngraph::op::v0) +NGRAPH_OP(PriorBoxClustered, ngraph::op::v0) +NGRAPH_OP(Proposal, ngraph::op::v4) +NGRAPH_OP(Range, ngraph::op::v4) +NGRAPH_OP(Relu, ngraph::op::v0) +NGRAPH_OP(ReduceMax, ngraph::op::v1) +NGRAPH_OP(ReduceLogicalAnd, ngraph::op::v1) +NGRAPH_OP(ReduceLogicalOr, ngraph::op::v1) +NGRAPH_OP(ReduceMean, ngraph::op::v1) +NGRAPH_OP(ReduceMin, ngraph::op::v1) +NGRAPH_OP(ReduceProd, ngraph::op::v1) +NGRAPH_OP(ReduceSum, ngraph::op::v1) +NGRAPH_OP(RegionYolo, ngraph::op::v0) +NGRAPH_OP(ReorgYolo, ngraph::op::v0) +NGRAPH_OP(Reshape, ngraph::op::v1) +NGRAPH_OP(Result, ngraph::op::v0) +NGRAPH_OP(ReverseSequence, ngraph::op::v0) +NGRAPH_OP(ROIPooling, ngraph::op::v0) +NGRAPH_OP(ScatterNDUpdate, ngraph::op::v3) +NGRAPH_OP(Select, ngraph::op::v1) +NGRAPH_OP(Selu, ngraph::op::v0) +NGRAPH_OP(Sign, ngraph::op::v0) +NGRAPH_OP(Sigmoid, ngraph::op::v0) +NGRAPH_OP(Sin, ngraph::op::v0) +NGRAPH_OP(Sinh, ngraph::op::v0) +NGRAPH_OP(Softmax, ngraph::op::v1) +NGRAPH_OP(Sqrt, ngraph::op::v0) +NGRAPH_OP(SpaceToDepth, ngraph::op::v0) +NGRAPH_OP(Split, ngraph::op::v1) +NGRAPH_OP(SquaredDifference, ngraph::op::v0) +NGRAPH_OP(Squeeze, ngraph::op::v0) +NGRAPH_OP(StridedSlice, ngraph::op::v1) +NGRAPH_OP(Subtract, ngraph::op::v1) +NGRAPH_OP(Tan, ngraph::op::v0) +NGRAPH_OP(Tanh, ngraph::op::v0) +NGRAPH_OP(TensorIterator, ngraph::op::v0) +NGRAPH_OP(Tile, ngraph::op::v0) +NGRAPH_OP(Transpose, ngraph::op::v1) +NGRAPH_OP(Unsqueeze, ngraph::op::v0) +NGRAPH_OP(VariadicSplit, ngraph::op::v1) + +// New operations added in opset2 +NGRAPH_OP(BatchToSpace, ngraph::op::v1) +NGRAPH_OP(SpaceToBatch, ngraph::op::v1) + +// New operations added in opset3 +NGRAPH_OP(EmbeddingBagPackedSum, ngraph::op::v3) +NGRAPH_OP(EmbeddingSegmentsSum, ngraph::op::v3) +NGRAPH_OP(EmbeddingBagOffsetsSum, ngraph::op::v3) +NGRAPH_OP(GRUCell, ngraph::op::v3) +NGRAPH_OP(NonZero, ngraph::op::v3) +NGRAPH_OP(RNNCell, ngraph::op::v0) +NGRAPH_OP(ROIAlign, ngraph::op::v3) +NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3) +NGRAPH_OP(ScatterUpdate, ngraph::op::v3) +NGRAPH_OP(ShuffleChannels, ngraph::op::v0) +NGRAPH_OP(ShapeOf, ngraph::op::v3) +NGRAPH_OP(TopK, ngraph::op::v3) + +// New operations added in opset4 +NGRAPH_OP(Acosh, ngraph::op::v3) +NGRAPH_OP(Asinh, ngraph::op::v3) +NGRAPH_OP(Atanh, ngraph::op::v3) +NGRAPH_OP(CTCLoss, ngraph::op::v4) +NGRAPH_OP(HSwish, ngraph::op::v4) +NGRAPH_OP(Interpolate, ngraph::op::v4) +NGRAPH_OP(Mish, ngraph::op::v4) +NGRAPH_OP(ReduceL1, ngraph::op::v4) +NGRAPH_OP(ReduceL2, ngraph::op::v4) +NGRAPH_OP(SoftPlus, ngraph::op::v4) +NGRAPH_OP(Swish, ngraph::op::v4) + +// New operations added in opset5 +NGRAPH_OP(GatherND, ngraph::op::v5) +NGRAPH_OP(GRUSequence, ngraph::op::v5) +NGRAPH_OP(HSigmoid, ngraph::op::v5) +NGRAPH_OP(LogSoftmax, ngraph::op::v5) +NGRAPH_OP(Loop, ngraph::op::v5) +NGRAPH_OP(LSTMSequence, ngraph::op::v5) +NGRAPH_OP(NonMaxSuppression, ngraph::op::v5) +NGRAPH_OP(RNNSequence, ngraph::op::v5) +NGRAPH_OP(Round, ngraph::op::v5) + +// New operations added in opset6 +NGRAPH_OP(CTCGreedyDecoderSeqLen, ngraph::op::v6) +NGRAPH_OP(ExperimentalDetectronDetectionOutput, ngraph::op::v6) +NGRAPH_OP(ExperimentalDetectronGenerateProposalsSingleImage, ngraph::op::v6) +NGRAPH_OP(ExperimentalDetectronPriorGridGenerator, ngraph::op::v6) +NGRAPH_OP(ExperimentalDetectronROIFeatureExtractor, ngraph::op::v6) +NGRAPH_OP(ExperimentalDetectronTopKROIs, ngraph::op::v6) +NGRAPH_OP(GatherElements, ngraph::op::v6) +NGRAPH_OP(MVN, ngraph::op::v6) +NGRAPH_OP(Assign, ngraph::op::v6) // new version +NGRAPH_OP(ReadValue, ngraph::op::v6) // new version + +// New operations added in opset7 +NGRAPH_OP(DFT, ngraph::op::v7) +NGRAPH_OP(Einsum, ngraph::op::v7) +NGRAPH_OP(Gelu, ngraph::op::v7) +NGRAPH_OP(IDFT, ngraph::op::v7) +NGRAPH_OP(Roll, ngraph::op::v7) + +// New operations added in opset8 +NGRAPH_OP(Gather, ngraph::op::v8) diff --git a/ngraph/core/src/op/gather.cpp b/ngraph/core/src/op/gather.cpp index f2ddc9948b3..aeacffc2719 100644 --- a/ngraph/core/src/op/gather.cpp +++ b/ngraph/core/src/op/gather.cpp @@ -86,3 +86,50 @@ shared_ptr op::v7::Gather::clone_with_new_inputs(const OutputVector& new_a check_new_args_count(this, new_args); return make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims); } + +NGRAPH_RTTI_DEFINITION(op::v8::Gather, "Gather", 8, op::util::GatherBase); + +op::v8::Gather::Gather(const Output& data, + const Output& indices, + const Output& axis, + const int64_t batch_dims) + : GatherBase(data, indices, axis, batch_dims) +{ + constructor_validate_and_infer_types(); +} + +void op::v8::Gather::validate_and_infer_types() +{ + NGRAPH_OP_SCOPE(v8_Gather_validate_and_infer_types); + NODE_VALIDATION_CHECK(this, + get_input_element_type(1).is_integral_number(), + "Indices element type must be of an integral number type."); + + NODE_VALIDATION_CHECK(this, + get_input_element_type(2).is_integral_number(), + "Axis element type must be of an integral number type."); + + op::util::GatherBase::validate_and_infer_types(); +} + +int64_t op::v8::Gather::get_batch_dims() const +{ + if (m_batch_dims < 0 && get_input_partial_shape(1).rank().is_static()) + return m_batch_dims + get_input_partial_shape(1).rank().get_length(); + else + return m_batch_dims; +} + +bool ngraph::op::v8::Gather::visit_attributes(AttributeVisitor& visitor) +{ + NGRAPH_OP_SCOPE(v8_Gather_visit_attributes); + visitor.on_attribute("batch_dims", m_batch_dims); + return true; +} + +shared_ptr op::v8::Gather::clone_with_new_inputs(const OutputVector& new_args) const +{ + NGRAPH_OP_SCOPE(v8_Gather_clone_with_new_inputs); + check_new_args_count(this, new_args); + return make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims); +} diff --git a/ngraph/test/type_prop/gather.cpp b/ngraph/test/type_prop/gather.cpp index 70bfc8530ee..96b2bbff832 100644 --- a/ngraph/test/type_prop/gather.cpp +++ b/ngraph/test/type_prop/gather.cpp @@ -13,7 +13,7 @@ using namespace ngraph; // ------------------------------ V1 ------------------------------ -TEST(type_prop, gather_axis_0) +TEST(type_prop, gather_v1_axis_0) { Shape params_shape{3, 2}; Shape indices_shape{2, 2}; @@ -27,7 +27,7 @@ TEST(type_prop, gather_axis_0) ASSERT_EQ(G->get_axis(), 0); } -TEST(type_prop, gather_7_uint8) +TEST(type_prop, gather_v1_uint8) { // Gather_1 must allow even if indices is not int32/int64 PartialShape data_shape{3, 2}; @@ -44,7 +44,7 @@ TEST(type_prop, gather_7_uint8) ASSERT_EQ(G->get_axis(), 0); } -TEST(type_prop, gather_7_float32) +TEST(type_prop, gather_v1_float32) { // Gather_1 should allow non int32/int64 indices PartialShape data_shape{3, 2}; @@ -335,7 +335,7 @@ TEST(type_prop, gather_7_axis_not_set_positive_batch_dims) ASSERT_EQ(G->get_output_partial_shape(0), out_shape); } -// --------------------- Negative tests ------------------------------ +// --------------------- V7 Negative tests ------------------------------ TEST(type_prop, gather_7_incorrect_axis_shape) { @@ -470,8 +470,7 @@ TEST(type_prop, gather_7_batch_dims_less_indices_rank_check) } } -// disabled until decision of type constrains for gather -TEST(type_prop, DISABLED_gather_7_indices_type_check) +TEST(type_prop, gather_7_indices_type_check) { PartialShape data_shape{1, 20, 20, 22, 22}; PartialShape indices_shape{1, 3}; @@ -500,8 +499,7 @@ TEST(type_prop, DISABLED_gather_7_indices_type_check) } } -// disabled until decision of type constrains for gather -TEST(type_prop, DISABLED_gather_7_axis_type_check) +TEST(type_prop, gather_7_axis_type_check) { PartialShape data_shape{1, 20, 20, 22, 22}; PartialShape indices_shape{1, 3}; @@ -529,3 +527,402 @@ TEST(type_prop, DISABLED_gather_7_axis_type_check) FAIL() << "Deduced type check failed for unexpected reason"; } } + +// ------------------------------ V8 ------------------------------ + +TEST(type_prop, gather_v8_axis_0) +{ + PartialShape data_shape{3, 2}; + PartialShape indices_shape{2, 2}; + PartialShape out_shape{2, 2, 2}; + int64_t batch_dims = 0; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i32, indices_shape); + auto A = op::Constant::create(element::i64, Shape{}, {0}); + auto G = make_shared(D, I, A, batch_dims); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); + ASSERT_EQ(G->get_axis(), 0); +} + +TEST(type_prop, gather_v8_axis_1) +{ + PartialShape data_shape{3, 3}; + PartialShape indices_shape{1, 2}; + PartialShape out_shape{3, 1, 2}; + int64_t axis = 1; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i32, indices_shape); + auto A = op::Constant::create(element::i64, Shape{}, {axis}); + auto G = make_shared(D, I, A); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); + ASSERT_EQ(G->get_axis(), 1); +} + +TEST(type_prop, gather_v8_negative_axis) +{ + PartialShape data_shape{5, 6, 7}; + PartialShape indices_shape{4}; + PartialShape out_shape{5, 4, 7}; + int64_t axis = -2; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + auto G = make_shared(D, I, A); + + ASSERT_EQ(G->get_axis(), 1); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +TEST(type_prop, gather_v8_dynamic_pshape_batch_dims_1_axis_1) +{ + PartialShape data_shape{Dimension(1, 7), 20, 20}; + PartialShape indices_shape{Dimension(7, 10), 3, 8}; + PartialShape out_shape{7, 3, 8, 20}; + int64_t axis = 1; + int64_t batch_dims = 1; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + auto G = make_shared(D, I, A, batch_dims); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +TEST(type_prop, gather_v8_dynamic_pshape_batch_dims_1_axis_3) +{ + PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400}; + PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8}; + PartialShape out_shape{7, Dimension(1, 3), 200, Dimension(2, 10), 3, 8}; + int64_t axis = 3; + int64_t batch_dims = 1; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + auto G = make_shared(D, I, A, batch_dims); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +TEST(type_prop, gather_v8_dynamic_2d_pshape_batch_dim) +{ + PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400}; + PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8}; + PartialShape out_shape{7, Dimension(2, 3), 3, 8, 400}; + int64_t axis = 2; + int64_t batch_dims = 2; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + auto G = make_shared(D, I, A, batch_dims); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +TEST(type_prop, gather_v8_dynamic_2d_pshape_batch_dim_axis_3) +{ + PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400}; + PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8}; + PartialShape out_shape{7, Dimension(2, 3), 200, 3, 8}; + int64_t axis = 3; + int64_t batch_dims = 2; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + auto G = make_shared(D, I, A, batch_dims); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +TEST(type_prop, gather_v8_dynamic_rank) +{ + PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400}; + PartialShape indices_shape = PartialShape::dynamic(Rank(3, 5)); + PartialShape out_shape = PartialShape::dynamic(Rank(4, 6)); + int64_t axis = 3; + int64_t batch_dims = 2; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + auto G = make_shared(D, I, A, batch_dims); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +TEST(type_prop, gather_v8_axis_boundcheck_for_dynamic_data_rank) +{ + PartialShape data_shape = PartialShape::dynamic(); + PartialShape indices_shape{7, 3, 8}; + PartialShape out_shape = PartialShape::dynamic(); + int64_t axis = 3; + int64_t batch_dims = 2; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + auto G = make_shared(D, I, A, batch_dims); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +TEST(type_prop, gather_v8_dynamic_rank_negative_batch_dims) +{ + PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400}; + PartialShape indices_shape = PartialShape::dynamic(Rank(3, 5)); + PartialShape out_shape = PartialShape::dynamic(Rank(3, 5)); + int64_t axis = 3; + int64_t batch_dims = -2; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + auto G = make_shared(D, I, A, batch_dims); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +TEST(type_prop, gather_v8_axis_not_set) +{ + PartialShape data_shape{1, 1, 200, 400}; + PartialShape indices_shape{2, 2}; + // default batch_dims = 0 + PartialShape out_shape = PartialShape::dynamic(5); // out_rank = data_rank + indices_rank - 1 - batch_dims + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i32, Shape{1}); + auto G = make_shared(D, I, A); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +TEST(type_prop, gather_v8_axis_not_set_positive_batch_dims) +{ + PartialShape data_shape{2, 1, 200, 400}; + PartialShape indices_shape{2, 2}; + int64_t batch_dims = 1; + PartialShape out_shape = PartialShape({2, + Dimension::dynamic(), + Dimension::dynamic(), + Dimension::dynamic()}); + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + auto A = make_shared(element::i32, Shape{1}); + auto G = make_shared(D, I, A, batch_dims); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); +} + +// --------------------- V8 Negative tests ------------------------------ + +TEST(type_prop, gather_v8_incorrect_axis_shape) +{ + auto D = make_shared(element::f32, Shape{5, 6}); + auto I = make_shared(element::i64, Shape{4}); + auto A = make_shared(element::i64, Shape{2}); + + try + { + auto G = make_shared(D, I, A); + // Should have thrown, so fail if it didn't + FAIL() << "Incorrect A input shape"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Axis input must be scalar or have 1 element")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, gather_v8_axis_out_of_input_rank) +{ + auto D = make_shared(element::f32, Shape{5, 6}); + auto I = make_shared(element::i64, Shape{4}); + auto A = make_shared(element::i64, Shape{1}, vector{2}); + int64_t batch_dims = 0; + try + { + auto G = make_shared(D, I, A, batch_dims); + // Should have thrown, so fail if it didn't + FAIL() << "axis check failed"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), std::string("Normalized axis must be >= 0 and < data_rank. But instead got")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, gather_v8_dynamic_batch_dims_inconsistent) +{ + PartialShape data_shape{Dimension(1, 7), 20, 20}; + PartialShape indices_shape{Dimension(8, 10), 3, 8}; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + int64_t axis = 1; + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + int64_t batch_dims = 1; + + try + { + auto G = make_shared(D, I, A, batch_dims); + // Should have thrown, so fail if it didn't + FAIL() << "Shape inconsistency check for dynamic PartialShape failed"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("data and indices must have equal or intersecting sizes until batch_dims")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, gather_v8_batch_dims_less_check) +{ + PartialShape data_shape{1, 3, 20}; + PartialShape indices_shape{1, 3, 8}; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + int64_t axis = 1; + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + int64_t batch_dims = 2; + + try + { + auto G = make_shared(D, I, A, batch_dims); + // Should have thrown, so fail if it didn't + FAIL() << "batch_dims check failed"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("After normalization batch_dims must be <= axis. But instead got: batch_dims =")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, gather_v8_batch_dims_less_indices_rank_check) +{ + PartialShape data_shape{1, 20, 20, 22, 22}; + PartialShape indices_shape{1, 3}; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i64, indices_shape); + int64_t axis = 4; + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + int64_t batch_dims = 3; + + try + { + auto G = make_shared(D, I, A, batch_dims); + // Should have thrown, so fail if it didn't + FAIL() << "batch_dims check failed"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("batch_dims must be <= indices_rank")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, gather_v8_indices_type_check) +{ + PartialShape data_shape{1, 20, 20, 22, 22}; + PartialShape indices_shape{1, 3}; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::f32, indices_shape); + int64_t axis = 4; + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + int64_t batch_dims = 0; + + try + { + auto G = make_shared(D, I, A, batch_dims); + // Should have thrown, so fail if it didn't + FAIL() << "indices element_type check failed"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("Indices element type must be of an integral number type")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +TEST(type_prop, gather_v8_axis_type_check) +{ + PartialShape data_shape{1, 20, 20, 22, 22}; + PartialShape indices_shape{1, 3}; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i32, indices_shape); + int64_t axis = 4; + auto A = make_shared(element::f32, Shape{1}, vector{axis}); + int64_t batch_dims = 0; + + try + { + auto G = make_shared(D, I, A, batch_dims); + // Should have thrown, so fail if it didn't + FAIL() << "axis element_type check failed"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("Axis element type must be of an integral number type")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} diff --git a/ngraph/test/visitors/op/gather.cpp b/ngraph/test/visitors/op/gather.cpp index 3e6446a07b8..c162d8949ff 100644 --- a/ngraph/test/visitors/op/gather.cpp +++ b/ngraph/test/visitors/op/gather.cpp @@ -7,6 +7,7 @@ #include "ngraph/ngraph.hpp" #include "ngraph/opsets/opset1.hpp" #include "ngraph/opsets/opset7.hpp" +#include "ngraph/opsets/opset8.hpp" #include "util/visitor.hpp" @@ -29,3 +30,18 @@ TEST(attributes, gather_v7_op) EXPECT_EQ(g_gather->get_batch_dims(), gather->get_batch_dims()); } + +TEST(attributes, gather_v8_op) +{ + NodeBuilder::get_ops().register_factory(); + auto data = make_shared(element::i32, Shape{2, 3, 4}); + auto indices = make_shared(element::i32, Shape{2}); + auto axis = make_shared(element::i32, Shape{}, 2); + int64_t batch_dims = 1; + + auto gather = make_shared(data, indices, axis, batch_dims); + NodeBuilder builder(gather); + auto g_gather = as_type_ptr(builder.create()); + + EXPECT_EQ(g_gather->get_batch_dims(), gather->get_batch_dims()); +}