[nG] Gather-8 shell (#6171)
* gather-8 nGraph shell
* added opset8 tbl; added visit_attribute test
* corrected opset tbl
This commit is contained in:
parent 8390f40788
commit ccd568a2d2
@@ -61,5 +61,32 @@ namespace ngraph
                clone_with_new_inputs(const OutputVector& new_args) const override;
        };
    } // namespace v7
+
+    namespace v8
+    {
+        /// \brief Gather slices from axis of params according to indices
+        class NGRAPH_API Gather : public op::util::GatherBase
+        {
+        public:
+            NGRAPH_RTTI_DECLARATION;
+            Gather() = default;
+
+            /// \param data The tensor from which slices are gathered
+            /// \param indices Tensor with indexes to gather
+            /// \param axis The tensor is a dimension index to gather data from
+            /// \param batch_dims The number of batch dimension in data and indices tensors.
+            Gather(const Output<Node>& data,
+                   const Output<Node>& indices,
+                   const Output<Node>& axis,
+                   const int64_t batch_dims = 0);
+
+            bool visit_attributes(AttributeVisitor& visitor) override;
+            void validate_and_infer_types() override;
+            int64_t get_batch_dims() const;
+
+            std::shared_ptr<Node>
+                clone_with_new_inputs(const OutputVector& new_args) const override;
+        };
+    } // namespace v8
 } // namespace op
 } // namespace ngraph
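For orientation, a minimal usage sketch of the new class (not part of the commit; the helper name make_v8_gather_example is ours): it builds a tiny ngraph::Function around op::v8::Gather with the same Parameter/Constant helpers used by the tests further down.

#include "ngraph/ngraph.hpp"

using namespace ngraph;

// Gather rows 0 and 2 from a [3, 4] data tensor along axis 0; batch_dims is left at
// its default of 0, so the result has shape [2, 4].
std::shared_ptr<Function> make_v8_gather_example()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});
    auto indices = op::Constant::create(element::i32, Shape{2}, {0, 2});
    auto axis = op::Constant::create(element::i64, Shape{}, {0});
    auto gather = std::make_shared<op::v8::Gather>(data, indices, axis);
    return std::make_shared<Function>(OutputVector{gather}, ParameterVector{data});
}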
ngraph/core/include/ngraph/opsets/opset8.hpp (new file, 17 lines)
@@ -0,0 +1,17 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/ops.hpp"

namespace ngraph
{
    namespace opset8
    {
#define NGRAPH_OP(a, b) using b::a;
#include "ngraph/opsets/opset8_tbl.hpp"
#undef NGRAPH_OP
    } // namespace opset8
} // namespace ngraph
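What the table buys downstream code, sketched outside the commit: each NGRAPH_OP(a, b) entry expands to `using b::a;`, so the newest Gather is reachable as ngraph::opset8::Gather. The static_assert below only restates that expansion.

#include <type_traits>

#include "ngraph/opsets/opset8.hpp"

// NGRAPH_OP(Gather, ngraph::op::v8) expands to `using ngraph::op::v8::Gather;`
// inside namespace ngraph::opset8, so the two names refer to the same type.
static_assert(std::is_same<ngraph::opset8::Gather, ngraph::op::v8::Gather>::value,
              "opset8::Gather is an alias of the v8 Gather");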
ngraph/core/include/ngraph/opsets/opset8_tbl.hpp (new file, 179 lines)
@@ -0,0 +1,179 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
#endif

NGRAPH_OP(Abs, ngraph::op::v0)
NGRAPH_OP(Acos, ngraph::op::v0)
NGRAPH_OP(Add, ngraph::op::v1)
NGRAPH_OP(Asin, ngraph::op::v0)
NGRAPH_OP(Atan, ngraph::op::v0)
NGRAPH_OP(AvgPool, ngraph::op::v1)
NGRAPH_OP(BatchNormInference, ngraph::op::v5)
NGRAPH_OP(BinaryConvolution, ngraph::op::v1)
NGRAPH_OP(Broadcast, ngraph::op::v3)
NGRAPH_OP(Bucketize, ngraph::op::v3)
NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
NGRAPH_OP(Ceiling, ngraph::op::v0)
NGRAPH_OP(Clamp, ngraph::op::v0)
NGRAPH_OP(Concat, ngraph::op::v0)
NGRAPH_OP(Constant, ngraph::op)
NGRAPH_OP(Convert, ngraph::op::v0)
NGRAPH_OP(ConvertLike, ngraph::op::v1)
NGRAPH_OP(Convolution, ngraph::op::v1)
NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(Cos, ngraph::op::v0)
NGRAPH_OP(Cosh, ngraph::op::v0)
NGRAPH_OP(CumSum, ngraph::op::v0)
NGRAPH_OP(DeformableConvolution, ngraph::op::v1)
NGRAPH_OP(DeformablePSROIPooling, ngraph::op::v1)
NGRAPH_OP(DepthToSpace, ngraph::op::v0)
NGRAPH_OP(DetectionOutput, ngraph::op::v0)
NGRAPH_OP(Divide, ngraph::op::v1)
NGRAPH_OP(Elu, ngraph::op::v0)
NGRAPH_OP(Erf, ngraph::op::v0)
NGRAPH_OP(Equal, ngraph::op::v1)
NGRAPH_OP(Exp, ngraph::op::v0)
NGRAPH_OP(ExtractImagePatches, ngraph::op::v3)
NGRAPH_OP(FakeQuantize, ngraph::op::v0)
NGRAPH_OP(Floor, ngraph::op::v0)
NGRAPH_OP(FloorMod, ngraph::op::v1)
NGRAPH_OP(GatherTree, ngraph::op::v1)
NGRAPH_OP(Greater, ngraph::op::v1)
NGRAPH_OP(GreaterEqual, ngraph::op::v1)
NGRAPH_OP(GroupConvolution, ngraph::op::v1)
NGRAPH_OP(GroupConvolutionBackpropData, ngraph::op::v1)
NGRAPH_OP(GRN, ngraph::op::v0)
NGRAPH_OP(HardSigmoid, ngraph::op::v0)
NGRAPH_OP(Less, ngraph::op::v1)
NGRAPH_OP(LessEqual, ngraph::op::v1)
NGRAPH_OP(Log, ngraph::op::v0)
NGRAPH_OP(LogicalAnd, ngraph::op::v1)
NGRAPH_OP(LogicalNot, ngraph::op::v1)
NGRAPH_OP(LogicalOr, ngraph::op::v1)
NGRAPH_OP(LogicalXor, ngraph::op::v1)
NGRAPH_OP(LRN, ngraph::op::v0)
NGRAPH_OP(LSTMCell, ngraph::op::v4)
NGRAPH_OP(MatMul, ngraph::op::v0)
NGRAPH_OP(MaxPool, ngraph::op::v1)
NGRAPH_OP(Maximum, ngraph::op::v1)
NGRAPH_OP(Minimum, ngraph::op::v1)
NGRAPH_OP(Mod, ngraph::op::v1)
NGRAPH_OP(Multiply, ngraph::op::v1)
NGRAPH_OP(Negative, ngraph::op::v0)
NGRAPH_OP(NormalizeL2, ngraph::op::v0)
NGRAPH_OP(NotEqual, ngraph::op::v1)
NGRAPH_OP(OneHot, ngraph::op::v1)
NGRAPH_OP(PRelu, ngraph::op::v0)
NGRAPH_OP(PSROIPooling, ngraph::op::v0)
NGRAPH_OP(Pad, ngraph::op::v1)
NGRAPH_OP(Parameter, ngraph::op::v0)
NGRAPH_OP(Power, ngraph::op::v1)
NGRAPH_OP(PriorBox, ngraph::op::v0)
NGRAPH_OP(PriorBoxClustered, ngraph::op::v0)
NGRAPH_OP(Proposal, ngraph::op::v4)
NGRAPH_OP(Range, ngraph::op::v4)
NGRAPH_OP(Relu, ngraph::op::v0)
NGRAPH_OP(ReduceMax, ngraph::op::v1)
NGRAPH_OP(ReduceLogicalAnd, ngraph::op::v1)
NGRAPH_OP(ReduceLogicalOr, ngraph::op::v1)
NGRAPH_OP(ReduceMean, ngraph::op::v1)
NGRAPH_OP(ReduceMin, ngraph::op::v1)
NGRAPH_OP(ReduceProd, ngraph::op::v1)
NGRAPH_OP(ReduceSum, ngraph::op::v1)
NGRAPH_OP(RegionYolo, ngraph::op::v0)
NGRAPH_OP(ReorgYolo, ngraph::op::v0)
NGRAPH_OP(Reshape, ngraph::op::v1)
NGRAPH_OP(Result, ngraph::op::v0)
NGRAPH_OP(ReverseSequence, ngraph::op::v0)
NGRAPH_OP(ROIPooling, ngraph::op::v0)
NGRAPH_OP(ScatterNDUpdate, ngraph::op::v3)
NGRAPH_OP(Select, ngraph::op::v1)
NGRAPH_OP(Selu, ngraph::op::v0)
NGRAPH_OP(Sign, ngraph::op::v0)
NGRAPH_OP(Sigmoid, ngraph::op::v0)
NGRAPH_OP(Sin, ngraph::op::v0)
NGRAPH_OP(Sinh, ngraph::op::v0)
NGRAPH_OP(Softmax, ngraph::op::v1)
NGRAPH_OP(Sqrt, ngraph::op::v0)
NGRAPH_OP(SpaceToDepth, ngraph::op::v0)
NGRAPH_OP(Split, ngraph::op::v1)
NGRAPH_OP(SquaredDifference, ngraph::op::v0)
NGRAPH_OP(Squeeze, ngraph::op::v0)
NGRAPH_OP(StridedSlice, ngraph::op::v1)
NGRAPH_OP(Subtract, ngraph::op::v1)
NGRAPH_OP(Tan, ngraph::op::v0)
NGRAPH_OP(Tanh, ngraph::op::v0)
NGRAPH_OP(TensorIterator, ngraph::op::v0)
NGRAPH_OP(Tile, ngraph::op::v0)
NGRAPH_OP(Transpose, ngraph::op::v1)
NGRAPH_OP(Unsqueeze, ngraph::op::v0)
NGRAPH_OP(VariadicSplit, ngraph::op::v1)

// New operations added in opset2
NGRAPH_OP(BatchToSpace, ngraph::op::v1)
NGRAPH_OP(SpaceToBatch, ngraph::op::v1)

// New operations added in opset3
NGRAPH_OP(EmbeddingBagPackedSum, ngraph::op::v3)
NGRAPH_OP(EmbeddingSegmentsSum, ngraph::op::v3)
NGRAPH_OP(EmbeddingBagOffsetsSum, ngraph::op::v3)
NGRAPH_OP(GRUCell, ngraph::op::v3)
NGRAPH_OP(NonZero, ngraph::op::v3)
NGRAPH_OP(RNNCell, ngraph::op::v0)
NGRAPH_OP(ROIAlign, ngraph::op::v3)
NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3)
NGRAPH_OP(ScatterUpdate, ngraph::op::v3)
NGRAPH_OP(ShuffleChannels, ngraph::op::v0)
NGRAPH_OP(ShapeOf, ngraph::op::v3)
NGRAPH_OP(TopK, ngraph::op::v3)

// New operations added in opset4
NGRAPH_OP(Acosh, ngraph::op::v3)
NGRAPH_OP(Asinh, ngraph::op::v3)
NGRAPH_OP(Atanh, ngraph::op::v3)
NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(HSwish, ngraph::op::v4)
NGRAPH_OP(Interpolate, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
NGRAPH_OP(ReduceL1, ngraph::op::v4)
NGRAPH_OP(ReduceL2, ngraph::op::v4)
NGRAPH_OP(SoftPlus, ngraph::op::v4)
NGRAPH_OP(Swish, ngraph::op::v4)

// New operations added in opset5
NGRAPH_OP(GatherND, ngraph::op::v5)
NGRAPH_OP(GRUSequence, ngraph::op::v5)
NGRAPH_OP(HSigmoid, ngraph::op::v5)
NGRAPH_OP(LogSoftmax, ngraph::op::v5)
NGRAPH_OP(Loop, ngraph::op::v5)
NGRAPH_OP(LSTMSequence, ngraph::op::v5)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v5)
NGRAPH_OP(RNNSequence, ngraph::op::v5)
NGRAPH_OP(Round, ngraph::op::v5)

// New operations added in opset6
NGRAPH_OP(CTCGreedyDecoderSeqLen, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronDetectionOutput, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronGenerateProposalsSingleImage, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronPriorGridGenerator, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronROIFeatureExtractor, ngraph::op::v6)
NGRAPH_OP(ExperimentalDetectronTopKROIs, ngraph::op::v6)
NGRAPH_OP(GatherElements, ngraph::op::v6)
NGRAPH_OP(MVN, ngraph::op::v6)
NGRAPH_OP(Assign, ngraph::op::v6) // new version
NGRAPH_OP(ReadValue, ngraph::op::v6) // new version

// New operations added in opset7
NGRAPH_OP(DFT, ngraph::op::v7)
NGRAPH_OP(Einsum, ngraph::op::v7)
NGRAPH_OP(Gelu, ngraph::op::v7)
NGRAPH_OP(IDFT, ngraph::op::v7)
NGRAPH_OP(Roll, ngraph::op::v7)

// New operations added in opset8
NGRAPH_OP(Gather, ngraph::op::v8)
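Since opset8_tbl.hpp is an X-macro list, a consumer can re-define NGRAPH_OP to generate other code from the same entries. A hypothetical sketch (the function opset8_type_names is not part of the commit) that collects the registered type names:

#include <string>
#include <vector>

// Redefine the macro to emit one push_back per table entry, then include the table.
std::vector<std::string> opset8_type_names()
{
    std::vector<std::string> names;
#define NGRAPH_OP(NAME, NAMESPACE) names.push_back(#NAME);
#include "ngraph/opsets/opset8_tbl.hpp"
#undef NGRAPH_OP
    return names;
}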
@@ -86,3 +86,50 @@ shared_ptr<Node> op::v7::Gather::clone_with_new_inputs(const OutputVector& new_a
     check_new_args_count(this, new_args);
     return make_shared<v7::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
 }
+
+NGRAPH_RTTI_DEFINITION(op::v8::Gather, "Gather", 8, op::util::GatherBase);
+
+op::v8::Gather::Gather(const Output<Node>& data,
+                       const Output<Node>& indices,
+                       const Output<Node>& axis,
+                       const int64_t batch_dims)
+    : GatherBase(data, indices, axis, batch_dims)
+{
+    constructor_validate_and_infer_types();
+}
+
+void op::v8::Gather::validate_and_infer_types()
+{
+    NGRAPH_OP_SCOPE(v8_Gather_validate_and_infer_types);
+    NODE_VALIDATION_CHECK(this,
+                          get_input_element_type(1).is_integral_number(),
+                          "Indices element type must be of an integral number type.");
+
+    NODE_VALIDATION_CHECK(this,
+                          get_input_element_type(2).is_integral_number(),
+                          "Axis element type must be of an integral number type.");
+
+    op::util::GatherBase::validate_and_infer_types();
+}
+
+int64_t op::v8::Gather::get_batch_dims() const
+{
+    if (m_batch_dims < 0 && get_input_partial_shape(1).rank().is_static())
+        return m_batch_dims + get_input_partial_shape(1).rank().get_length();
+    else
+        return m_batch_dims;
+}
+
+bool ngraph::op::v8::Gather::visit_attributes(AttributeVisitor& visitor)
+{
+    NGRAPH_OP_SCOPE(v8_Gather_visit_attributes);
+    visitor.on_attribute("batch_dims", m_batch_dims);
+    return true;
+}
+
+shared_ptr<Node> op::v8::Gather::clone_with_new_inputs(const OutputVector& new_args) const
+{
+    NGRAPH_OP_SCOPE(v8_Gather_clone_with_new_inputs);
+    check_new_args_count(this, new_args);
+    return make_shared<v8::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
+}
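get_batch_dims() above folds a negative batch_dims into the rank of the indices input when that rank is static. A standalone sketch of the same arithmetic (illustration only; normalize_batch_dims is not library code):

#include <cassert>
#include <cstdint>

// A negative batch_dims counts back from the indices rank, mirroring get_batch_dims().
int64_t normalize_batch_dims(int64_t batch_dims, int64_t indices_rank)
{
    return batch_dims < 0 ? batch_dims + indices_rank : batch_dims;
}

int main()
{
    assert(normalize_batch_dims(-2, 4) == 2); // batch_dims = -2 with indices of rank 4
    assert(normalize_batch_dims(1, 4) == 1);  // non-negative values pass through unchanged
    return 0;
}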
@@ -13,7 +13,7 @@ using namespace ngraph;
 
 // ------------------------------ V1 ------------------------------
 
-TEST(type_prop, gather_axis_0)
+TEST(type_prop, gather_v1_axis_0)
 {
     Shape params_shape{3, 2};
     Shape indices_shape{2, 2};
@@ -27,7 +27,7 @@ TEST(type_prop, gather_axis_0)
     ASSERT_EQ(G->get_axis(), 0);
 }
 
-TEST(type_prop, gather_7_uint8)
+TEST(type_prop, gather_v1_uint8)
 {
     // Gather_1 must allow even if indices is not int32/int64
     PartialShape data_shape{3, 2};
@@ -44,7 +44,7 @@ TEST(type_prop, gather_7_uint8)
     ASSERT_EQ(G->get_axis(), 0);
 }
 
-TEST(type_prop, gather_7_float32)
+TEST(type_prop, gather_v1_float32)
 {
     // Gather_1 should allow non int32/int64 indices
     PartialShape data_shape{3, 2};
@@ -335,7 +335,7 @@ TEST(type_prop, gather_7_axis_not_set_positive_batch_dims)
     ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
 }
 
-// --------------------- Negative tests ------------------------------
+// --------------------- V7 Negative tests ------------------------------
 
 TEST(type_prop, gather_7_incorrect_axis_shape)
 {
@@ -470,8 +470,7 @@ TEST(type_prop, gather_7_batch_dims_less_indices_rank_check)
     }
 }
 
-// disabled until decision of type constrains for gather
-TEST(type_prop, gather_7_indices_type_check)
+TEST(type_prop, DISABLED_gather_7_indices_type_check)
 {
     PartialShape data_shape{1, 20, 20, 22, 22};
     PartialShape indices_shape{1, 3};
@@ -500,8 +499,7 @@ TEST(type_prop, DISABLED_gather_7_indices_type_check)
     }
 }
 
-// disabled until decision of type constrains for gather
-TEST(type_prop, gather_7_axis_type_check)
+TEST(type_prop, DISABLED_gather_7_axis_type_check)
 {
     PartialShape data_shape{1, 20, 20, 22, 22};
     PartialShape indices_shape{1, 3};
@@ -529,3 +527,402 @@ TEST(type_prop, DISABLED_gather_7_axis_type_check)
         FAIL() << "Deduced type check failed for unexpected reason";
     }
 }
+
+// ------------------------------ V8 ------------------------------
+
+TEST(type_prop, gather_v8_axis_0)
+{
+    PartialShape data_shape{3, 2};
+    PartialShape indices_shape{2, 2};
+    PartialShape out_shape{2, 2, 2};
+    int64_t batch_dims = 0;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
+    auto A = op::Constant::create(element::i64, Shape{}, {0});
+    auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+    ASSERT_EQ(G->get_axis(), 0);
+}
+
+TEST(type_prop, gather_v8_axis_1)
+{
+    PartialShape data_shape{3, 3};
+    PartialShape indices_shape{1, 2};
+    PartialShape out_shape{3, 1, 2};
+    int64_t axis = 1;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
+    auto A = op::Constant::create(element::i64, Shape{}, {axis});
+    auto G = make_shared<op::v8::Gather>(D, I, A);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+    ASSERT_EQ(G->get_axis(), 1);
+}
+
+TEST(type_prop, gather_v8_negative_axis)
+{
+    PartialShape data_shape{5, 6, 7};
+    PartialShape indices_shape{4};
+    PartialShape out_shape{5, 4, 7};
+    int64_t axis = -2;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    auto G = make_shared<op::v8::Gather>(D, I, A);
+
+    ASSERT_EQ(G->get_axis(), 1);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+TEST(type_prop, gather_v8_dynamic_pshape_batch_dims_1_axis_1)
+{
+    PartialShape data_shape{Dimension(1, 7), 20, 20};
+    PartialShape indices_shape{Dimension(7, 10), 3, 8};
+    PartialShape out_shape{7, 3, 8, 20};
+    int64_t axis = 1;
+    int64_t batch_dims = 1;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+TEST(type_prop, gather_v8_dynamic_pshape_batch_dims_1_axis_3)
+{
+    PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
+    PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
+    PartialShape out_shape{7, Dimension(1, 3), 200, Dimension(2, 10), 3, 8};
+    int64_t axis = 3;
+    int64_t batch_dims = 1;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+TEST(type_prop, gather_v8_dynamic_2d_pshape_batch_dim)
+{
+    PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
+    PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
+    PartialShape out_shape{7, Dimension(2, 3), 3, 8, 400};
+    int64_t axis = 2;
+    int64_t batch_dims = 2;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+TEST(type_prop, gather_v8_dynamic_2d_pshape_batch_dim_axis_3)
+{
+    PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
+    PartialShape indices_shape{Dimension(7, 10), Dimension(2, 10), 3, 8};
+    PartialShape out_shape{7, Dimension(2, 3), 200, 3, 8};
+    int64_t axis = 3;
+    int64_t batch_dims = 2;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+TEST(type_prop, gather_v8_dynamic_rank)
+{
+    PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
+    PartialShape indices_shape = PartialShape::dynamic(Rank(3, 5));
+    PartialShape out_shape = PartialShape::dynamic(Rank(4, 6));
+    int64_t axis = 3;
+    int64_t batch_dims = 2;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+TEST(type_prop, gather_v8_axis_boundcheck_for_dynamic_data_rank)
+{
+    PartialShape data_shape = PartialShape::dynamic();
+    PartialShape indices_shape{7, 3, 8};
+    PartialShape out_shape = PartialShape::dynamic();
+    int64_t axis = 3;
+    int64_t batch_dims = 2;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+TEST(type_prop, gather_v8_dynamic_rank_negative_batch_dims)
+{
+    PartialShape data_shape{Dimension(1, 7), Dimension(1, 3), 200, 400};
+    PartialShape indices_shape = PartialShape::dynamic(Rank(3, 5));
+    PartialShape out_shape = PartialShape::dynamic(Rank(3, 5));
+    int64_t axis = 3;
+    int64_t batch_dims = -2;
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+TEST(type_prop, gather_v8_axis_not_set)
+{
+    PartialShape data_shape{1, 1, 200, 400};
+    PartialShape indices_shape{2, 2};
+    // default batch_dims = 0
+    PartialShape out_shape = PartialShape::dynamic(5); // out_rank = data_rank + indices_rank - 1 - batch_dims
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Parameter>(element::i32, Shape{1});
+    auto G = make_shared<op::v8::Gather>(D, I, A);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+TEST(type_prop, gather_v8_axis_not_set_positive_batch_dims)
+{
+    PartialShape data_shape{2, 1, 200, 400};
+    PartialShape indices_shape{2, 2};
+    int64_t batch_dims = 1;
+    PartialShape out_shape = PartialShape({2,
+                                           Dimension::dynamic(),
+                                           Dimension::dynamic(),
+                                           Dimension::dynamic()});
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto A = make_shared<op::Parameter>(element::i32, Shape{1});
+    auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+
+    ASSERT_EQ(G->get_element_type(), element::f32);
+    ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
+}
+
+// --------------------- V8 Negative tests ------------------------------
+
+TEST(type_prop, gather_v8_incorrect_axis_shape)
+{
+    auto D = make_shared<op::Parameter>(element::f32, Shape{5, 6});
+    auto I = make_shared<op::Parameter>(element::i64, Shape{4});
+    auto A = make_shared<op::Parameter>(element::i64, Shape{2});
+
+    try
+    {
+        auto G = make_shared<op::v8::Gather>(D, I, A);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incorrect A input shape";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Axis input must be scalar or have 1 element"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, gather_v8_axis_out_of_input_rank)
+{
+    auto D = make_shared<op::Parameter>(element::f32, Shape{5, 6});
+    auto I = make_shared<op::Parameter>(element::i64, Shape{4});
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{2});
+    int64_t batch_dims = 0;
+    try
+    {
+        auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "axis check failed";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(), std::string("Normalized axis must be >= 0 and < data_rank. But instead got"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, gather_v8_dynamic_batch_dims_inconsistent)
+{
+    PartialShape data_shape{Dimension(1, 7), 20, 20};
+    PartialShape indices_shape{Dimension(8, 10), 3, 8};
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    int64_t axis = 1;
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    int64_t batch_dims = 1;
+
+    try
+    {
+        auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("data and indices must have equal or intersecting sizes until batch_dims"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, gather_v8_batch_dims_less_check)
+{
+    PartialShape data_shape{1, 3, 20};
+    PartialShape indices_shape{1, 3, 8};
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    int64_t axis = 1;
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    int64_t batch_dims = 2;
+
+    try
+    {
+        auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "batch_dims check failed";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("After normalization batch_dims must be <= axis. But instead got: batch_dims ="));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, gather_v8_batch_dims_less_indices_rank_check)
+{
+    PartialShape data_shape{1, 20, 20, 22, 22};
+    PartialShape indices_shape{1, 3};
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i64, indices_shape);
+    int64_t axis = 4;
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    int64_t batch_dims = 3;
+
+    try
+    {
+        auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "batch_dims check failed";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("batch_dims must be <= indices_rank"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, gather_v8_indices_type_check)
+{
+    PartialShape data_shape{1, 20, 20, 22, 22};
+    PartialShape indices_shape{1, 3};
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::f32, indices_shape);
+    int64_t axis = 4;
+    auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
+    int64_t batch_dims = 0;
+
+    try
+    {
+        auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "indices element_type check failed";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("Indices element type must be of an integral number type"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, gather_v8_axis_type_check)
+{
+    PartialShape data_shape{1, 20, 20, 22, 22};
+    PartialShape indices_shape{1, 3};
+
+    auto D = make_shared<op::Parameter>(element::f32, data_shape);
+    auto I = make_shared<op::Parameter>(element::i32, indices_shape);
+    int64_t axis = 4;
+    auto A = make_shared<op::Constant>(element::f32, Shape{1}, vector<int64_t>{axis});
+    int64_t batch_dims = 0;
+
+    try
+    {
+        auto G = make_shared<op::v8::Gather>(D, I, A, batch_dims);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "axis element_type check failed";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("Axis element type must be of an integral number type"));
+    }
+    catch (...)
+    {
+        FAIL() << "Deduced type check failed for unexpected reason";
+    }
+}
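The static shapes asserted above all follow one composition rule, assumed here from the tests rather than quoted from the shape-inference sources: output dims are the data dims before axis, then the indices dims after batch_dims, then the data dims after axis. A small sketch for fully static shapes (gather_out_shape is illustrative only):

#include <cstdint>
#include <vector>

// out = data[0:axis] ++ indices[batch_dims:] ++ data[axis+1:]
// so out_rank = data_rank + indices_rank - 1 - batch_dims.
std::vector<int64_t> gather_out_shape(const std::vector<int64_t>& data,
                                      const std::vector<int64_t>& indices,
                                      int64_t axis,
                                      int64_t batch_dims)
{
    std::vector<int64_t> out(data.begin(), data.begin() + axis);
    out.insert(out.end(), indices.begin() + batch_dims, indices.end());
    out.insert(out.end(), data.begin() + axis + 1, data.end());
    return out;
}

// gather_out_shape({3, 2}, {2, 2}, 0, 0) -> {2, 2, 2}   (gather_v8_axis_0)
// gather_out_shape({3, 3}, {1, 2}, 1, 0) -> {3, 1, 2}   (gather_v8_axis_1)
// gather_out_shape({5, 6, 7}, {4}, 1, 0) -> {5, 4, 7}   (gather_v8_negative_axis, axis -2 -> 1)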
@@ -7,6 +7,7 @@
 #include "ngraph/ngraph.hpp"
 #include "ngraph/opsets/opset1.hpp"
 #include "ngraph/opsets/opset7.hpp"
+#include "ngraph/opsets/opset8.hpp"
 
 #include "util/visitor.hpp"
 
@@ -29,3 +30,18 @@ TEST(attributes, gather_v7_op)
 
     EXPECT_EQ(g_gather->get_batch_dims(), gather->get_batch_dims());
 }
+
+TEST(attributes, gather_v8_op)
+{
+    NodeBuilder::get_ops().register_factory<opset8::Gather>();
+    auto data = make_shared<opset1::Parameter>(element::i32, Shape{2, 3, 4});
+    auto indices = make_shared<opset1::Parameter>(element::i32, Shape{2});
+    auto axis = make_shared<opset1::Constant>(element::i32, Shape{}, 2);
+    int64_t batch_dims = 1;
+
+    auto gather = make_shared<opset8::Gather>(data, indices, axis, batch_dims);
+    NodeBuilder builder(gather);
+    auto g_gather = as_type_ptr<opset8::Gather>(builder.create());
+
+    EXPECT_EQ(g_gather->get_batch_dims(), gather->get_batch_dims());
+}