[ShapeInference] Improve GatherND shape inference (#15378)

* Add shape_infer function for GatherND

* GatherND shape infer improvements

* Align test to trigger correct error message

* Add new and improve GatherND type_prop tests

* Update tests to use ov namespace

* Add GatherND common shape_infer tests

* Init shape infer tests for not common cases

* Tests refactor

* Add default ctor tests

* Add more test cases

* Register shape_infer for GatherND V5 and V8

* Enable more tests and print params

* Move GatherNDTestParams
This commit is contained in:
Katarzyna Mitrus 2023-01-31 11:12:12 +01:00 committed by GitHub
parent 4ce3e9a88d
commit f342e5d208
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 757 additions and 349 deletions

View File

@ -27,6 +27,10 @@ public:
return m_batch_dims;
}
void set_batch_dims(size_t batch_dims) {
m_batch_dims = batch_dims;
}
void validate_inputs_and_infer_shape();
bool visit_attributes(AttributeVisitor& visitor) override;

View File

@ -0,0 +1,93 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/gather_nd.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace gather_nd {
template <class TShape, class TOp>
std::vector<TShape> gather_nd_base_shape_infer(const TOp* op, const std::vector<TShape>& input_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2);
const auto& data_pshape = input_shapes[0];
const auto& indices_pshape = input_shapes[1];
if (data_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(op, data_pshape.size() > 0, "Data rank must be at least 1.");
NODE_VALIDATION_CHECK(op,
data_pshape.size() > op->get_batch_dims(),
"Number of batch dimensions must not exceed a rank of data.");
}
if (indices_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(op, indices_pshape.size() > 0, "Indices rank must be at least 1.");
NODE_VALIDATION_CHECK(op,
indices_pshape.size() > op->get_batch_dims(),
"Number of batch dimensions must not exceed a rank of indices.");
}
if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() &&
indices_pshape[indices_pshape.size() - 1].is_static()) {
auto batch_dims = op->get_batch_dims();
auto indices_tuple_length = indices_pshape[indices_pshape.size() - 1].get_length(); // last dim of indices
NODE_VALIDATION_CHECK(
op,
static_cast<int64_t>(indices_tuple_length + op->get_batch_dims()) <= data_pshape.rank().get_length(),
"Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions.");
int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - batch_dims;
int64_t output_indices_length = indices_pshape.rank().get_length() - batch_dims - 1;
auto output_rank = output_indices_length + slice_length;
using DimType = typename TShape::value_type;
std::vector<DimType> output_shape(output_rank + batch_dims);
for (size_t dim = 0; dim < batch_dims; ++dim) {
NODE_VALIDATION_CHECK(op,
DimType::merge(output_shape[dim], data_pshape[dim], indices_pshape[dim]),
"Batch dimensions of data and indices must be the same.");
}
for (int64_t dim = 0; dim < output_indices_length; ++dim) {
output_shape[batch_dims + dim] = indices_pshape[batch_dims + dim];
}
for (int64_t dim = 0; dim < slice_length; ++dim) {
output_shape[batch_dims + output_indices_length + dim] =
data_pshape[batch_dims + indices_tuple_length + dim];
}
return std::vector<TShape>{TShape(output_shape)};
} else {
return std::vector<TShape>{ov::PartialShape::dynamic()};
}
}
} // namespace gather_nd
namespace v5 {
template <class TShape>
void shape_infer(const GatherND* op, const std::vector<TShape>& input_shapes, std::vector<TShape>& output_shapes) {
using DimType = typename TShape::value_type;
output_shapes = gather_nd::gather_nd_base_shape_infer(op, input_shapes);
// If batch_dims > 1, batch dimensions are need to be fused
auto batch_dims = op->get_batch_dims();
if (batch_dims > 1 && output_shapes[0].rank().is_static()) {
const auto& output_base_shape = output_shapes[0];
std::vector<DimType> output_shape{1};
for (size_t dim = 0; dim < batch_dims; ++dim) {
output_shape[0] *= output_base_shape[dim];
}
output_shape.insert(output_shape.begin() + 1, output_base_shape.begin() + batch_dims, output_base_shape.end());
output_shapes[0] = TShape(output_shape);
}
}
} // namespace v5
namespace v8 {
template <class TShape>
void shape_infer(const GatherND* op, const std::vector<TShape>& input_shapes, std::vector<TShape>& output_shapes) {
output_shapes = gather_nd::gather_nd_base_shape_infer(op, input_shapes);
}
} // namespace v8
} // namespace op
} // namespace ov

View File

@ -4,6 +4,7 @@
#include "ngraph/op/gather_nd.hpp"
#include "gather_nd_shape_inference.hpp"
#include "itt.hpp"
#include "ngraph/shape.hpp"
@ -18,38 +19,18 @@ op::v5::GatherND::GatherND(const Output<Node>& data, const Output<Node>& indices
void op::v5::GatherND::validate_and_infer_types() {
OV_OP_SCOPE(v5_GatherND_validate_and_infer_types);
validate_inputs_and_infer_shape();
// If we have m_batch_dims > 1 we need to fuse batch dimensions of output
if (m_batch_dims > 1) {
const auto& output_pshape = get_output_partial_shape(0);
const auto& data_type = get_input_element_type(0);
const auto& data_type = get_input_element_type(0);
const auto& indices_type = get_input_element_type(1);
if (output_pshape.rank().is_static()) {
const auto& out_size = output_pshape.size();
std::vector<Dimension> output_shape(out_size - m_batch_dims + 1);
output_shape[0] = 1;
for (size_t dim = 0; dim < m_batch_dims; dim++) {
if (output_pshape[dim].is_static()) {
output_shape[0] *= output_pshape[dim].get_length();
} else {
output_shape[0] = Dimension::dynamic();
break;
}
}
size_t ind = 1;
for (size_t dim = m_batch_dims; dim < out_size; dim++) {
if (output_pshape[dim].is_static()) {
output_shape[ind] = output_pshape[dim].get_length();
} else {
output_shape[ind] = Dimension::dynamic();
}
ind++;
}
NODE_VALIDATION_CHECK(this,
indices_type.is_integral_number(),
"The indices type is expected to be an integer type. Got: ",
indices_type);
set_output_type(0, data_type, ov::PartialShape(output_shape));
}
}
std::vector<PartialShape> out_shapes(1);
shape_infer(this, {get_input_partial_shape(0), get_input_partial_shape(1)}, out_shapes);
set_output_type(0, data_type, out_shapes[0]);
}
shared_ptr<Node> op::v5::GatherND::clone_with_new_inputs(const OutputVector& new_args) const {
@ -66,7 +47,17 @@ op::v8::GatherND::GatherND(const Output<Node>& data, const Output<Node>& indices
void op::v8::GatherND::validate_and_infer_types() {
OV_OP_SCOPE(v8_GatherND_validate_and_infer_types);
validate_inputs_and_infer_shape();
const auto& data_type = get_input_element_type(0);
const auto& indices_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
indices_type.is_integral_number(),
"The indices type is expected to be an integer type. Got: ",
indices_type);
std::vector<PartialShape> out_shapes(1);
shape_infer(this, {get_input_partial_shape(0), get_input_partial_shape(1)}, out_shapes);
set_output_type(0, data_type, ov::PartialShape(out_shapes[0]));
}
shared_ptr<Node> op::v8::GatherND::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@ -6,6 +6,7 @@
#include <ngraph/validation_util.hpp>
#include "gather_nd_shape_inference.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/squeeze.hpp"
@ -30,80 +31,9 @@ void ov::op::util::GatherNDBase::validate_inputs_and_infer_shape() {
"The indices type is expected to be an integer type. Got: ",
indices_type);
// check ranks of input tensors
const auto& data_pshape = get_input_partial_shape(0);
const auto& indices_pshape = get_input_partial_shape(1);
if (data_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1.");
NODE_VALIDATION_CHECK(this,
data_pshape.rank().get_length() > static_cast<int64_t>(m_batch_dims),
"Number of batch dimensions must not exceed a rank of data.");
}
if (indices_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1.");
NODE_VALIDATION_CHECK(this,
indices_pshape.rank().get_length() > static_cast<int64_t>(m_batch_dims),
"Number of batch dimensions must not exceed a rank of indices.");
}
if (data_pshape.rank().is_static() && indices_pshape.rank().is_static()) {
// check that batch dimensions of data and indices are the same
for (size_t batch_dim = 0; batch_dim < m_batch_dims; batch_dim++) {
if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static()) {
NODE_VALIDATION_CHECK(this,
data_pshape[batch_dim].get_length() == indices_pshape[batch_dim].get_length(),
"Batch dimensions of data and indices must be the same.");
}
}
if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) {
NODE_VALIDATION_CHECK(
this,
static_cast<int64_t>(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() +
m_batch_dims) <= data_pshape.rank().get_length(),
"Length of a tuple with indices must not exceed a rank of data tensor "
"excluding "
"batch dimensions.");
}
}
// set output shape
set_output_size(1);
if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() &&
indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) {
auto indices_tuple_length = indices_pshape[indices_pshape.rank().get_length() - 1].get_length();
int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims;
int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1;
auto output_rank = output_indices_length + slice_length;
size_t delta_output_rank = 0;
delta_output_rank = m_batch_dims;
std::vector<Dimension> output_shape(output_rank + delta_output_rank);
for (size_t dim = 0; dim < m_batch_dims; dim++) {
output_shape[dim] = 1;
if (data_pshape[dim].is_static()) {
output_shape[dim] = data_pshape[dim].get_length();
} else if (indices_pshape[dim].is_static()) {
output_shape[dim] = indices_pshape[dim].get_length();
} else {
output_shape[dim] = Dimension::dynamic();
break;
}
}
for (int64_t dim = 0; dim < output_indices_length; dim++) {
output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims];
}
for (int64_t dim = 0; dim < slice_length; dim++) {
output_shape[output_indices_length + dim + delta_output_rank] =
data_pshape[m_batch_dims + indices_tuple_length + dim];
}
set_output_type(0, data_type, ov::PartialShape(output_shape));
} else {
set_output_type(0, data_type, ov::PartialShape::dynamic());
}
std::vector<PartialShape> in_shapes{get_input_partial_shape(0), get_input_partial_shape(1)};
std::vector<PartialShape> out_shapes = op::gather_nd::gather_nd_base_shape_infer(this, in_shapes);
set_output_type(0, data_type, out_shapes[0]);
}
bool ov::op::util::GatherNDBase::visit_attributes(AttributeVisitor& visitor) {

View File

@ -2,99 +2,248 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "dimension_tracker.hpp"
#include "gmock/gmock.h"
#include "openvino/op/ops.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
using namespace ov;
using namespace op;
using namespace testing;
template <typename T>
class gather_nd_type_prop : public TypePropOpTest<T> {};
TYPED_TEST_SUITE_P(gather_nd_type_prop);
// ------------------------------ V5 & V8 ----------------------------------------
// Output shape for V5 and V8 is the same, when batch_dims attribute is equal to 1
TYPED_TEST_P(gather_nd_type_prop, default_ctor) {
PartialShape data_shape{8, 3, 11, 12};
PartialShape indices_shape{8, 4, 2};
PartialShape expected_shape{8, 4, 12};
auto op = this->make_op();
constexpr auto batch_dims = 1;
op->set_batch_dims(batch_dims);
EXPECT_EQ(op->get_batch_dims(), batch_dims);
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
op->set_argument(0, data_param);
op->set_argument(1, indices_param);
op->validate_and_infer_types();
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), expected_shape);
}
TYPED_TEST_P(gather_nd_type_prop, static_shape_batch_dims_1_ind_tuple_2) {
PartialShape data_shape{8, 3, 11, 12};
PartialShape indices_shape{8, 4, 2};
PartialShape expected_shape{8, 4, 12};
constexpr auto batch_dims = 1;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = this->make_op(data_param, indices_param, batch_dims);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), expected_shape);
}
TYPED_TEST_P(gather_nd_type_prop, static_shape_batch_dims_1_ind_tuple_3) {
PartialShape data_shape{8, 3, 11, 12};
PartialShape indices_shape{8, 4, 3};
PartialShape expected_shape{8, 4};
constexpr auto batch_dims = 1;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = this->make_op(data_param, indices_param, batch_dims);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), expected_shape);
}
TYPED_TEST_P(gather_nd_type_prop, static_shape_batch_dims_1_ind_tuple_dynamic) {
PartialShape data_shape{8, 3, 11, 12};
PartialShape indices_shape{8, 4, -1};
PartialShape expected_shape = PartialShape::dynamic();
constexpr auto batch_dims = 1;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = this->make_op(data_param, indices_param, batch_dims);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), expected_shape);
}
TYPED_TEST_P(gather_nd_type_prop, interval_both_labeled_batch_dims_1_ind_tuple_2) {
PartialShape data_shape{{2, 6}, {3, 7}, {8, 10}, {12, 14}};
set_shape_labels(data_shape, 10);
PartialShape indices_shape{{4, 8}, {6, 10}, 2};
set_shape_labels(indices_shape, 20);
PartialShape expected_shape{{4, 6}, {6, 10}, {12, 14}};
constexpr auto batch_dims = 1;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = this->make_op(data_param, indices_param, batch_dims);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(out_shape, expected_shape);
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 13));
}
TYPED_TEST_P(gather_nd_type_prop, interval_data_labeled_batch_dims_1_ind_tuple_2) {
PartialShape data_shape{{2, 6}, {3, 7}, {8, 10}, {12, 14}};
set_shape_labels(data_shape, 10);
PartialShape indices_shape{{4, 8}, {6, 10}, 2};
PartialShape expected_shape{{4, 6}, {6, 10}, {12, 14}};
constexpr auto batch_dims = 1;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = this->make_op(data_param, indices_param, batch_dims);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(out_shape, expected_shape);
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, ov::no_label, 13));
}
TYPED_TEST_P(gather_nd_type_prop, interval_indices_labeled_batch_dims_1_ind_tuple_2) {
PartialShape data_shape{{2, 6}, {3, 7}, {8, 10}, {12, 14}};
PartialShape indices_shape{{4, 8}, {6, 10}, 2};
set_shape_labels(indices_shape, 20);
PartialShape expected_shape{{4, 6}, {6, 10}, {12, 14}};
constexpr auto batch_dims = 1;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = this->make_op(data_param, indices_param, batch_dims);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(out_shape, expected_shape);
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, ov::no_label));
}
REGISTER_TYPED_TEST_SUITE_P(gather_nd_type_prop,
default_ctor,
static_shape_batch_dims_1_ind_tuple_2,
static_shape_batch_dims_1_ind_tuple_3,
static_shape_batch_dims_1_ind_tuple_dynamic,
interval_both_labeled_batch_dims_1_ind_tuple_2,
interval_data_labeled_batch_dims_1_ind_tuple_2,
interval_indices_labeled_batch_dims_1_ind_tuple_2);
typedef Types<v5::GatherND, v8::GatherND> GatherNDTypes;
INSTANTIATE_TYPED_TEST_SUITE_P(type_prop, gather_nd_type_prop, GatherNDTypes);
// ------------------------------ V5 ------------------------------
TEST(type_prop, gather_nd_slices_from_4d_batch_dims0) {
Shape params_shape{2, 3, 11, 12};
TEST(type_prop, gather_nd_v5_slices_from_4d_batch_dims0) {
Shape data_shape{2, 3, 11, 12};
Shape indices_shape{2, 3, 2};
Shape out_shape{2, 3, 11, 12};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 0);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{2, 3, 11, 12};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 0);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_scalars_from_4d_batch_dims2) {
Shape params_shape{2, 3, 11, 12};
TEST(type_prop, gather_nd_v5_scalars_from_4d_batch_dims2) {
Shape data_shape{2, 3, 11, 12};
Shape indices_shape{2, 3, 2};
Shape out_shape{6};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{6};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_slices_from_5d_batch_dims2) {
Shape params_shape{7, 5, 11, 12, 32};
TEST(type_prop, gather_nd_v5_slices_from_5d_batch_dims2) {
Shape data_shape{7, 5, 11, 12, 32};
Shape indices_shape{7, 5, 3, 1};
Shape out_shape{35, 3, 12, 32};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{35, 3, 12, 32};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_dim2_with_dyn_dim) {
PartialShape params_shape{7, Dimension::dynamic(), 11, 12, 32};
TEST(type_prop, gather_nd_v5_batch_dim2_with_dyn_dim) {
PartialShape data_shape{7, Dimension::dynamic(), 11, 12, 32};
Shape indices_shape{7, 5, 3, 1};
Shape out_shape{35, 3, 12, 32};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{35, 3, 12, 32};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_dim2_with_dyn_dim2) {
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, 32};
TEST(type_prop, gather_nd_v5_batch_dim2_with_dyn_dim2) {
PartialShape data_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, 32};
Shape indices_shape{7, 5, 3, 1};
Shape out_shape{35, 3, 12, 32};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{35, 3, 12, 32};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_dim2_with_dyn_dim3) {
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
TEST(type_prop, gather_nd_v5_batch_dim2_with_dyn_dim3) {
PartialShape data_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
Shape indices_shape{7, 5, 3, 1};
PartialShape out_shape{35, 3, 12, Dimension::dynamic()};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_TRUE(G5->get_output_partial_shape(0).same_scheme(out_shape));
PartialShape expected_shape{35, 3, 12, Dimension::dynamic()};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(expected_shape));
}
TEST(type_prop, gather_nd_batch_dim0_with_dyn_ind_dim) {
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
TEST(type_prop, gather_nd_v5_batch_dim0_with_dyn_ind_dim) {
PartialShape data_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
PartialShape indices_shape{7, 5, 3, Dimension::dynamic()};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 0);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_TRUE(G5->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 0);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, gather_nd_fail_batch_dims_greater_indices_rank) {
Shape params_shape{2, 3, 4, 5};
TEST(type_prop, gather_nd_v5_fail_batch_dims_greater_indices_rank) {
Shape data_shape{2, 3, 4, 5};
Shape indices_shape{2, 1};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
try {
auto G5 = make_shared<op::v5::GatherND>(P, I, 3);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 3);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {
@ -105,14 +254,14 @@ TEST(type_prop, gather_nd_fail_batch_dims_greater_indices_rank) {
}
}
TEST(type_prop, gather_nd_fail_unequal_batch_dims) {
Shape params_shape{2, 3, 4, 5};
Shape indices_shape{2, 1, 4};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
TEST(type_prop, gather_nd_v5_fail_unequal_batch_dims) {
Shape data_shape{2, 3, 4, 5};
Shape indices_shape{2, 1, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
try {
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 2);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {
@ -122,14 +271,14 @@ TEST(type_prop, gather_nd_fail_unequal_batch_dims) {
}
}
TEST(type_prop, gather_nd_fail_indices_tuple_greater_data_rank_batch_dims2) {
Shape params_shape{2, 1, 4, 5};
TEST(type_prop, gather_nd_v5_fail_indices_tuple_greater_data_rank_batch_dims2) {
Shape data_shape{2, 1, 4, 5};
Shape indices_shape{2, 1, 5, 3};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
try {
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
auto op = make_shared<v5::GatherND>(data_param, indices_param, 2);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {
@ -141,137 +290,177 @@ TEST(type_prop, gather_nd_fail_indices_tuple_greater_data_rank_batch_dims2) {
}
}
// ------------------------------ V0 + V5 ------------------------------
TEST(type_prop, gather_nd_scalar_from_2d) {
Shape params_shape{2, 2};
TEST(type_prop, gather_nd_v5_scalar_from_2d) {
Shape data_shape{2, 2};
Shape indices_shape{2, 2};
Shape out_shape{2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_1d_from_2d) {
Shape params_shape{2, 2};
TEST(type_prop, gather_nd_v5_1d_from_2d) {
Shape data_shape{2, 2};
Shape indices_shape{2, 1};
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_scalar_from_3d) {
Shape params_shape{2, 2, 2};
TEST(type_prop, gather_nd_v5_scalar_from_3d) {
Shape data_shape{2, 2, 2};
Shape indices_shape{2, 3};
Shape out_shape{2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_1d_from_3d) {
Shape params_shape{2, 2, 2};
TEST(type_prop, gather_nd_v5_1d_from_3d) {
Shape data_shape{2, 2, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_2d_from_3d) {
Shape params_shape{2, 2, 2};
TEST(type_prop, gather_nd_v5_2d_from_3d) {
Shape data_shape{2, 2, 2};
Shape indices_shape{1, 1};
Shape out_shape{1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{1, 2, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_scalar_from_2d) {
Shape params_shape{2, 2};
TEST(type_prop, gather_nd_v5_batch_scalar_from_2d) {
Shape data_shape{2, 2};
Shape indices_shape{2, 1, 2};
Shape out_shape{2, 1};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2, 1};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_1d_from_2d) {
Shape params_shape{2, 2};
TEST(type_prop, gather_nd_v5_batch_1d_from_2d) {
Shape data_shape{2, 2};
Shape indices_shape{2, 1, 1};
Shape out_shape{2, 1, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2, 1, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_scalar_from_3d) {
Shape params_shape{2, 2, 2};
TEST(type_prop, gather_nd_v5_batch_scalar_from_3d) {
Shape data_shape{2, 2, 2};
Shape indices_shape{2, 2, 3};
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_1d_from_3d) {
Shape params_shape{2, 2, 2};
TEST(type_prop, gather_nd_v5_batch_1d_from_3d) {
Shape data_shape{2, 2, 2};
Shape indices_shape{2, 2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2, 2, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_batch_2d_from_3d) {
Shape params_shape{2, 2, 2};
TEST(type_prop, gather_nd_v5_batch_2d_from_3d) {
Shape data_shape{2, 2, 2};
Shape indices_shape{2, 1, 1};
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2, 1, 2, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_fail_params_rank) {
Shape params_shape{};
TEST(type_prop, gather_nd_v5_interval_both_labeled_batch_dims_2_ind_tuple_2) {
PartialShape data_shape{{2, 6}, {3, 7}, {8, 10}, {12, 14}};
set_shape_labels(data_shape, 10);
PartialShape indices_shape{{4, 8}, {6, 10}, 2};
set_shape_labels(indices_shape, 20);
PartialShape expected_shape{{24, 42}};
constexpr auto batch_dims = 2;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param, batch_dims);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(out_shape, expected_shape);
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(ov::no_label));
}
TEST(type_prop, gather_nd_v5_interval_both_labeled_batch_dims_2_ind_tuple_1) {
PartialShape data_shape{{2, 6}, {3, 7}, {8, 10}, {12, 14}};
set_shape_labels(data_shape, 10);
PartialShape indices_shape{{4, 8}, {6, 10}, 1};
set_shape_labels(indices_shape, 20);
PartialShape expected_shape{{24, 42}, {12, 14}};
constexpr auto batch_dims = 2;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v5::GatherND>(data_param, indices_param, batch_dims);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(out_shape, expected_shape);
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(ov::no_label, 13));
}
TEST(type_prop, gather_nd_v5_fail_params_rank) {
Shape data_shape{};
Shape indices_shape{2, 1, 1};
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2, 1, 2, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
try {
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect params rank";
} catch (const NodeValidationFailure& error) {
@ -281,15 +470,15 @@ TEST(type_prop, gather_nd_fail_params_rank) {
}
}
TEST(type_prop, gather_nd_fail_indices_rank) {
Shape params_shape{2, 2, 2};
TEST(type_prop, gather_nd_v5_fail_indices_rank) {
Shape data_shape{2, 2, 2};
Shape indices_shape{};
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
Shape expected_shape{2, 1, 2, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
try {
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {
@ -299,15 +488,15 @@ TEST(type_prop, gather_nd_fail_indices_rank) {
}
}
TEST(type_prop, gather_nd_fail_indices_element_type) {
Shape params_shape{2, 2, 2};
TEST(type_prop, gather_nd_v5_fail_indices_element_type) {
Shape data_shape{2, 2, 2};
Shape indices_shape{2, 1, 1};
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::f32, indices_shape);
Shape expected_shape{2, 1, 2, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::f32, indices_shape);
try {
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto op = make_shared<v5::GatherND>(data_param, indices_param);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices element type";
} catch (const NodeValidationFailure& error) {
@ -319,90 +508,132 @@ TEST(type_prop, gather_nd_fail_indices_element_type) {
// ------------------------------ V8 ------------------------------
TEST(type_prop, gather_nd_8_slices_from_4d_batch_dims0) {
Shape params_shape{2, 3, 11, 12};
TEST(type_prop, gather_nd_v8_slices_from_4d_batch_dims0) {
Shape data_shape{2, 3, 11, 12};
Shape indices_shape{2, 3, 2};
Shape out_shape{2, 3, 11, 12};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v8::GatherND>(P, I, 0);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{2, 3, 11, 12};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 0);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_8_scalars_from_4d_batch_dims2) {
Shape params_shape{2, 3, 11, 12};
TEST(type_prop, gather_nd_v8_scalars_from_4d_batch_dims2) {
Shape data_shape{2, 3, 11, 12};
Shape indices_shape{2, 3, 2};
Shape out_shape{2, 3};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{2, 3};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_8_slices_from_5d_batch_dims2) {
Shape params_shape{7, 5, 11, 12, 32};
TEST(type_prop, gather_nd_v8_slices_from_5d_batch_dims2) {
Shape data_shape{7, 5, 11, 12, 32};
Shape indices_shape{7, 5, 3, 1};
Shape out_shape{7, 5, 3, 12, 32};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{7, 5, 3, 12, 32};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_8_batch_dim2_with_dyn_dim) {
PartialShape params_shape{7, Dimension::dynamic(), 11, 12, 32};
TEST(type_prop, gather_nd_v8_batch_dim2_with_dyn_dim) {
PartialShape data_shape{7, Dimension::dynamic(), 11, 12, 32};
Shape indices_shape{7, 5, 3, 1};
Shape out_shape{7, 5, 3, 12, 32};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{7, 5, 3, 12, 32};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_8_batch_dim2_with_dyn_dim2) {
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, 32};
TEST(type_prop, gather_nd_v8_batch_dim2_with_dyn_dim2) {
PartialShape data_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, 32};
Shape indices_shape{7, 5, 3, 1};
Shape out_shape{7, 5, 3, 12, 32};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
Shape expected_shape{7, 5, 3, 12, 32};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_EQ(op->get_shape(), expected_shape);
}
TEST(type_prop, gather_nd_8_batch_dim2_with_dyn_dim3) {
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
TEST(type_prop, gather_nd_v8_batch_dim2_with_dyn_dim3) {
PartialShape data_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
Shape indices_shape{7, 5, 3, 1};
PartialShape out_shape{7, 5, 3, 12, Dimension::dynamic()};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_TRUE(G5->get_output_partial_shape(0).same_scheme(out_shape));
PartialShape expected_shape{7, 5, 3, 12, Dimension::dynamic()};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 2);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(expected_shape));
}
TEST(type_prop, gather_nd_8_batch_dim0_with_dyn_ind_dim) {
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
TEST(type_prop, gather_nd_v8_batch_dim0_with_dyn_ind_dim) {
PartialShape data_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
PartialShape indices_shape{7, 5, 3, Dimension::dynamic()};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v8::GatherND>(P, I, 0);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_TRUE(G5->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 0);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, gather_nd_8_fail_batch_dims_greater_indices_rank) {
Shape params_shape{2, 3, 4, 5};
TEST(type_prop, gather_nd_v8_interval_both_labeled_batch_dims_2_ind_tuple_2) {
PartialShape data_shape{{2, 6}, {3, 7}, {8, 10}, {12, 14}};
set_shape_labels(data_shape, 10);
PartialShape indices_shape{{4, 8}, {6, 10}, 2};
set_shape_labels(indices_shape, 20);
PartialShape expected_shape{{4, 6}, {6, 7}};
constexpr auto batch_dims = 2;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v8::GatherND>(data_param, indices_param, batch_dims);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(out_shape, expected_shape);
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21));
}
TEST(type_prop, gather_nd_v8_interval_both_labeled_batch_dims_2_ind_tuple_1) {
PartialShape data_shape{{2, 6}, {3, 7}, {8, 10}, {12, 14}};
set_shape_labels(data_shape, 10);
PartialShape indices_shape{{4, 8}, {6, 10}, 1};
set_shape_labels(indices_shape, 20);
PartialShape expected_shape{{4, 6}, {6, 7}, {12, 14}};
constexpr auto batch_dims = 2;
auto data_param = std::make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = std::make_shared<v0::Parameter>(element::i32, indices_shape);
auto op = make_shared<v8::GatherND>(data_param, indices_param, batch_dims);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(out_shape, expected_shape);
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 13));
}
TEST(type_prop, gather_nd_v8_fail_batch_dims_greater_indices_rank) {
Shape data_shape{2, 3, 4, 5};
Shape indices_shape{2, 1};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
try {
auto G5 = make_shared<op::v8::GatherND>(P, I, 3);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 3);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {
@ -413,14 +644,14 @@ TEST(type_prop, gather_nd_8_fail_batch_dims_greater_indices_rank) {
}
}
TEST(type_prop, gather_nd_8_fail_unequal_batch_dims) {
Shape params_shape{2, 3, 4, 5};
Shape indices_shape{2, 1, 4};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
TEST(type_prop, gather_nd_v8_fail_unequal_batch_dims) {
Shape data_shape{2, 3, 4, 5};
Shape indices_shape{2, 1, 2};
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
try {
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 2);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {
@ -430,14 +661,14 @@ TEST(type_prop, gather_nd_8_fail_unequal_batch_dims) {
}
}
TEST(type_prop, gather_nd_8_fail_indices_tuple_greater_data_rank_batch_dims2) {
Shape params_shape{2, 1, 4, 5};
TEST(type_prop, gather_nd_v8_fail_indices_tuple_greater_data_rank_batch_dims2) {
Shape data_shape{2, 1, 4, 5};
Shape indices_shape{2, 1, 5, 3};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto data_param = make_shared<v0::Parameter>(element::f32, data_shape);
auto indices_param = make_shared<v0::Parameter>(element::i32, indices_shape);
try {
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
auto op = make_shared<v8::GatherND>(data_param, indices_param, 2);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {

View File

@ -5,6 +5,7 @@
#include <openvino/core/node.hpp>
#include <openvino/opsets/opset1.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/opsets/opset5.hpp>
#include <openvino/opsets/opset7.hpp>
#include "assign_shape_inference.hpp"
@ -33,6 +34,7 @@
#include "fake_quantize.hpp"
#include "fft_base_shape_inference.hpp"
#include "gather_elements_shape_inference.hpp"
#include "gather_nd_shape_inference.hpp"
#include "gather_shape_inference.hpp"
#include "gather_tree_shape_inference.hpp"
#include "grid_sample_shape_inference.hpp"
@ -550,6 +552,7 @@ const IShapeInferCommonFactory::TRegistry IShapeInferCommonFactory::registry{
_OV_OP_SHAPE_INFER_REG(Eye, entryIOC),
_OV_OP_SHAPE_INFER_REG(FakeQuantize, entryIO),
_OV_OP_SHAPE_INFER_REG(GatherElements, entryIO),
_OV_OP_SHAPE_INFER_REG(GatherND, entryIO),
_OV_OP_SHAPE_INFER_REG(GatherTree, entryIO),
_OV_OP_SHAPE_INFER_REG(GridSample, entryIO),
_OV_OP_SHAPE_INFER_REG(GRUCell, entryIO),
@ -605,6 +608,8 @@ const IShapeInferCommonFactory::TRegistry IShapeInferCommonFactory::registry{
_OV_OP_SHAPE_INFER_VA_REG(ReduceSum, entryIOC, op::util::ArithmeticReductionKeepDims),
// opset7
_OV_OP_SHAPE_INFER_VA_REG(opset7::Gather, entryIOC, ov::op::util::GatherBase),
// opset5
_OV_OP_SHAPE_INFER_REG(opset5::GatherND, entryIO),
// opset3
_OV_OP_SHAPE_INFER_REG(opset3::Assign, entryIO),
_OV_OP_SHAPE_INFER_REG(opset3::ReadValue, entryIO),

View File

@ -0,0 +1,154 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gather_nd_shape_inference.hpp"
#include <gtest/gtest.h>
#include "openvino/op/ops.hpp"
#include "openvino/util/common_util.hpp"
#include "utils.hpp"
#include "utils/shape_inference/shape_inference.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
struct GatherNDTestParams {
ShapeVector input_shapes;
StaticShape exp_shape;
size_t batch_dims;
};
namespace {
template <class TGatherND>
std::shared_ptr<TGatherND> make_gather_nd(size_t batch_dims) {
auto data_param = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
auto indicies_param = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
return std::make_shared<TGatherND>(data_param, indicies_param, batch_dims);
}
template <typename TGatherND>
void run_gather_nd_test(const GatherNDTestParams& test_params) {
auto op = make_gather_nd<TGatherND>(test_params.batch_dims);
ShapeVector output_shapes(1);
shape_inference(op.get(), test_params.input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], test_params.exp_shape)
<< "Failed for input shapes: " << ov::util::vector_to_string(test_params.input_shapes)
<< " and batch_dims = " << test_params.batch_dims << std::endl;
}
std::string print_params(const testing::TestParamInfo<GatherNDTestParams>& test_params) {
std::ostringstream results;
results << "in_shapes: " << ov::util::vector_to_string(test_params.param.input_shapes)
<< ", batch_dims = " << test_params.param.batch_dims;
return results.str();
}
} // namespace
template <class TGatherND>
class StaticShapeInferenceGatherNDTest : public OpStaticShapeInferenceTest<TGatherND> {};
// Output shape for V5 and V8 is the same, when batch_dims attribute is less than 2
const auto GatherNDGatherNDTestParams = std::vector<GatherNDTestParams>{
// Test: batch_dims = 0
GatherNDTestParams{ShapeVector{{8}, {1}}, StaticShape{}, 0},
GatherNDTestParams{ShapeVector{{8}, {1, 1}}, StaticShape{1}, 0},
GatherNDTestParams{ShapeVector{{8}, {5, 1}}, StaticShape{5}, 0},
GatherNDTestParams{ShapeVector{{8, 11}, {2}}, StaticShape{}, 0},
GatherNDTestParams{ShapeVector{{8, 11}, {5, 2}}, StaticShape{5}, 0},
GatherNDTestParams{ShapeVector{{8, 11, 12}, {2}}, StaticShape{12}, 0},
GatherNDTestParams{ShapeVector{{8, 11, 12}, {5, 2}}, StaticShape{5, 12}, 0},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {2}}, StaticShape{11, 12}, 0},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {2, 1}}, StaticShape{2, 3, 11, 12}, 0},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {2, 2}}, StaticShape{2, 11, 12}, 0},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {2, 5, 4}}, StaticShape{2, 5}, 0},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {2, 5, 20, 3}}, StaticShape{2, 5, 20, 12}, 0},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {6, 4, 2}}, StaticShape{6, 4, 11, 12}, 0},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {8, 4, 2}}, StaticShape{8, 4, 11, 12}, 0},
GatherNDTestParams{ShapeVector{{7, 3, 11, 12}, {8, 6, 5, 4, 1}}, StaticShape{8, 6, 5, 4, 3, 11, 12}, 0},
GatherNDTestParams{ShapeVector{{7, 3, 11, 12}, {8, 6, 5, 4, 2}}, StaticShape{8, 6, 5, 4, 11, 12}, 0},
GatherNDTestParams{ShapeVector{{7, 3, 11, 12}, {8, 6, 5, 4, 3}}, StaticShape{8, 6, 5, 4, 12}, 0},
GatherNDTestParams{ShapeVector{{7, 3, 11, 12}, {8, 6, 5, 4, 4}}, StaticShape{8, 6, 5, 4}, 0},
GatherNDTestParams{ShapeVector{{7, 3, 11}, {8, 6, 5, 4, 1}}, StaticShape{8, 6, 5, 4, 3, 11}, 0},
// Test: batch_dims = 1
GatherNDTestParams{ShapeVector{{8, 11}, {8, 1}}, StaticShape{8}, 1},
GatherNDTestParams{ShapeVector{{8, 11, 12}, {8, 1}}, StaticShape{8, 12}, 1},
GatherNDTestParams{ShapeVector{{8, 11, 12}, {8, 2}}, StaticShape{8}, 1},
GatherNDTestParams{ShapeVector{{8, 11, 12}, {8, 5, 1}}, StaticShape{8, 5, 12}, 1},
GatherNDTestParams{ShapeVector{{8, 11, 12}, {8, 5, 2}}, StaticShape{8, 5}, 1},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {8, 2}}, StaticShape{8, 12}, 1},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {8, 2, 1}}, StaticShape{8, 2, 11, 12}, 1},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {8, 5, 2}}, StaticShape{8, 5, 12}, 1},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {8, 5, 3}}, StaticShape{8, 5}, 1},
GatherNDTestParams{ShapeVector{{8, 3, 11, 12}, {8, 7, 4, 2}}, StaticShape{8, 7, 4, 12}, 1},
GatherNDTestParams{ShapeVector{{7, 3, 11, 12}, {7, 6, 5, 4, 1}}, StaticShape{7, 6, 5, 4, 11, 12}, 1},
GatherNDTestParams{ShapeVector{{7, 3, 11, 12}, {7, 6, 5, 4, 2}}, StaticShape{7, 6, 5, 4, 12}, 1},
GatherNDTestParams{ShapeVector{{7, 3, 11, 12}, {7, 6, 5, 4, 3}}, StaticShape{7, 6, 5, 4}, 1}};
TYPED_TEST_SUITE_P(StaticShapeInferenceGatherNDTest);
TYPED_TEST_P(StaticShapeInferenceGatherNDTest, gather_nd_common_batch_dims) {
for (const auto& params : GatherNDGatherNDTestParams) {
run_gather_nd_test<TypeParam>(params);
}
}
TYPED_TEST_P(StaticShapeInferenceGatherNDTest, gather_nd_common_default_ctor) {
auto op = std::make_shared<TypeParam>();
op->set_batch_dims(1);
ShapeVector input_shapes{{8, 3, 11, 12}, {8, 5, 2}};
ShapeVector output_shapes(1);
shape_infer(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], (StaticShape{8, 5, 12}));
}
REGISTER_TYPED_TEST_SUITE_P(StaticShapeInferenceGatherNDTest,
gather_nd_common_batch_dims,
gather_nd_common_default_ctor);
using GatherNDTypes = Types<op::v5::GatherND, op::v8::GatherND>;
INSTANTIATE_TYPED_TEST_SUITE_P(shape_infer, StaticShapeInferenceGatherNDTest, GatherNDTypes);
// ------------------------------ V5 ------------------------------
class StaticShapeInferenceGatherNDV5Test : public TestWithParam<GatherNDTestParams> {};
TEST_P(StaticShapeInferenceGatherNDV5Test, gather_nd_v5_test) {
run_gather_nd_test<op::v5::GatherND>(GetParam());
}
INSTANTIATE_TEST_SUITE_P(
shape_infer,
StaticShapeInferenceGatherNDV5Test,
::testing::Values(GatherNDTestParams{ShapeVector{{6, 4, 11, 12, 13}, {6, 4, 2}}, StaticShape{24, 13}, 2},
GatherNDTestParams{ShapeVector{{6, 4, 11, 12, 13}, {6, 4, 5, 7, 2}}, StaticShape{24, 5, 7, 13}, 2},
GatherNDTestParams{ShapeVector{{6, 4, 11, 12, 13}, {6, 4, 3}}, StaticShape{24}, 2},
GatherNDTestParams{ShapeVector{{6, 4, 11, 12, 13}, {6, 4, 5, 3}}, StaticShape{24, 5}, 2},
GatherNDTestParams{ShapeVector{{6, 4, 1, 12, 13}, {6, 4, 1, 1}}, StaticShape{24, 13}, 3},
GatherNDTestParams{ShapeVector{{6, 4, 1, 12, 13}, {6, 4, 1, 2}}, StaticShape{24}, 3},
GatherNDTestParams{ShapeVector{{6, 4, 1, 12, 13}, {6, 4, 1, 5, 2}}, StaticShape{24, 5}, 3}),
print_params);
// ------------------------------ V8 ------------------------------
class StaticShapeInferenceGatherNDV8Test : public TestWithParam<GatherNDTestParams> {};
TEST_P(StaticShapeInferenceGatherNDV8Test, gather_nd_v8_test) {
run_gather_nd_test<op::v8::GatherND>(GetParam());
}
INSTANTIATE_TEST_SUITE_P(
shape_infer,
StaticShapeInferenceGatherNDV8Test,
::testing::Values(GatherNDTestParams{ShapeVector{{6, 4, 11, 12, 13}, {6, 4, 2}}, StaticShape{6, 4, 13}, 2},
GatherNDTestParams{ShapeVector{{6, 4, 11, 12, 13}, {6, 4, 5, 7, 2}}, StaticShape{6, 4, 5, 7, 13}, 2},
GatherNDTestParams{ShapeVector{{6, 4, 11, 12, 13}, {6, 4, 3}}, StaticShape{6, 4}, 2},
GatherNDTestParams{ShapeVector{{6, 4, 11, 12, 13}, {6, 4, 5, 3}}, StaticShape{6, 4, 5}, 2},
GatherNDTestParams{ShapeVector{{6, 4, 1, 12, 13}, {6, 4, 1, 1}}, StaticShape{6, 4, 1, 13}, 3},
GatherNDTestParams{ShapeVector{{6, 4, 1, 12, 13}, {6, 4, 1, 2}}, StaticShape{6, 4, 1}, 3},
GatherNDTestParams{ShapeVector{{6, 4, 1, 12, 13}, {6, 4, 1, 5, 2}}, StaticShape{6, 4, 1, 5}, 3}),
print_params);