[ShapeInference] Improvements of GatherElements shape inference (#15264)

* GatherElements shape infer improvements

* Add new tests

* Update test error message
This commit is contained in:
Katarzyna Mitrus 2023-01-24 10:20:28 +01:00 committed by GitHub
parent 9ee80d67b7
commit c72c6ba331
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 244 additions and 143 deletions

View File

@ -37,10 +37,6 @@ public:
private: private:
int64_t m_axis{0}; int64_t m_axis{0};
template <class T>
void friend shape_infer(const GatherElements* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes);
}; };
} // namespace v6 } // namespace v6
} // namespace op } // namespace op

View File

@ -20,14 +20,11 @@ void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, s
auto indices_rank = indices_pshape.rank(); auto indices_rank = indices_pshape.rank();
auto& output_shape = output_shapes[0]; auto& output_shape = output_shapes[0];
int64_t axis = op->m_axis; int64_t axis = op->get_axis();
if (data_rank.is_static()) if (data_rank.is_static())
axis = ov::normalize_axis(op, axis, data_rank); axis = ov::normalize_axis(op, axis, data_rank);
output_shape = indices_pshape;
NODE_VALIDATION_CHECK(op, data_rank.is_dynamic() || data_rank.get_length() >= 1, "data rank must be >= 1."); NODE_VALIDATION_CHECK(op, data_rank.is_dynamic() || data_rank.get_length() >= 1, "data rank must be >= 1.");
NODE_VALIDATION_CHECK(op, NODE_VALIDATION_CHECK(op,
indices_rank.is_dynamic() || indices_rank.get_length() >= 1, indices_rank.is_dynamic() || indices_rank.get_length() >= 1,
"indices rank must be >= 1."); "indices rank must be >= 1.");
@ -40,9 +37,11 @@ void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, s
} }
if (data_rank.is_dynamic()) { if (data_rank.is_dynamic()) {
// can't decide rank, set it to all dynamic if (indices_rank.is_dynamic()) {
if (indices_rank.is_dynamic())
output_shape = PartialShape::dynamic(); output_shape = PartialShape::dynamic();
return;
}
output_shape = indices_pshape;
return; return;
} }
@ -54,25 +53,20 @@ void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, s
" and ", " and ",
indices_rank.get_length()); indices_rank.get_length());
for (int i = 0; i < indices_rank.get_length(); i++) { // if size of the current dimension of indices is unknown it will be retrieved from data
if (i != axis) { // e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0
// if size of the current dimension of indices is unknown it will be retrieved from data // (and if intervals intersect) then output_pshape will be {1, 4, 5}
// e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0 output_shape = data_pshape;
// (and if intervals intersect) then output_pshape will be {1, 4, 5} output_shape[axis] = indices_pshape[axis];
NODE_VALIDATION_CHECK(op,
NODE_VALIDATION_CHECK(op, output_shape.merge_into(output_shape, indices_pshape),
data_pshape[i].compatible(indices_pshape[i]), "Shapes ",
"Shapes ", data_pshape,
data_pshape, " and ",
" and ", indices_pshape,
indices_pshape, " are not consistent, `data` and `indices` must have equal or "
" are not consistent. data and indices must have equal or " "intersecting dimensions, except for the dimension at axis index.",
"intersecting sizes, except for axis ", axis);
axis);
output_shape[i] = data_pshape[i] & indices_pshape[i];
}
}
} }
} // namespace v6 } // namespace v6
} // namespace op } // namespace op

View File

@ -30,9 +30,8 @@ void op::v6::GatherElements::validate_and_infer_types() {
"indices must be of int32 or int64 type. But instead got: ", "indices must be of int32 or int64 type. But instead got: ",
indices_type); indices_type);
const auto& data_pshape = get_input_partial_shape(0); const auto input_shapes = get_node_input_partial_shapes(*this);
const auto& indices_pshape = get_input_partial_shape(1); auto output_shapes = std::vector<ov::PartialShape>(1);
std::vector<PartialShape> input_shapes = {data_pshape, indices_pshape}, output_shapes = {PartialShape{}};
shape_infer(this, input_shapes, output_shapes); shape_infer(this, input_shapes, output_shapes);
set_output_type(0, data_type, output_shapes[0]); set_output_type(0, data_type, output_shapes[0]);
} }

View File

@ -2,145 +2,219 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "dimension_tracker.hpp"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "openvino/op/ops.hpp"
#include "util/type_prop.hpp" #include "util/type_prop.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ov;
using namespace op;
using namespace testing;
// ------------------------------ V6 ------------------------------ // ------------------------------ V6 ------------------------------
TEST(type_prop, gather_elements_default_constructor) {
PartialShape data_shape{1, 2, 3, 4};
PartialShape indices_shape{1, 2, 10, 4};
int64_t axis = -2;
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
auto op = make_shared<v6::GatherElements>();
op->set_axis(axis);
EXPECT_EQ(op->get_axis(), axis);
op->set_argument(0, data);
op->set_argument(1, indices);
op->validate_and_infer_types();
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
EXPECT_EQ(op->get_output_partial_shape(0), indices_shape);
}
TEST(type_prop, gather_elements_2D_axis_0) { TEST(type_prop, gather_elements_2D_axis_0) {
Shape data_shape{3, 3}; Shape data_shape{3, 3};
Shape indices_shape{2, 3}; Shape indices_shape{2, 3};
int axis = 0; int axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
ASSERT_EQ(GE->get_shape(), indices_shape); EXPECT_EQ(op->get_shape(), indices_shape);
} }
TEST(type_prop, gather_elements_2D_axis_1) { TEST(type_prop, gather_elements_2D_axis_1) {
Shape data_shape{3, 3}; Shape data_shape{3, 3};
Shape indices_shape{3, 1}; Shape indices_shape{3, 1};
int axis = 1; int axis = 1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
ASSERT_EQ(GE->get_shape(), indices_shape); EXPECT_EQ(op->get_shape(), indices_shape);
} }
TEST(type_prop, gather_elements_3D_axis_0) { TEST(type_prop, gather_elements_3D_axis_0) {
Shape data_shape{3, 3, 10000}; Shape data_shape{3, 3, 10000};
Shape indices_shape{300, 3, 10000}; Shape indices_shape{300, 3, 10000};
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
ASSERT_EQ(GE->get_shape(), indices_shape); EXPECT_EQ(op->get_shape(), indices_shape);
} }
TEST(type_prop, gather_elements_3D_axis_2) { TEST(type_prop, gather_elements_3D_axis_2) {
Shape data_shape{300, 3, 10}; Shape data_shape{300, 3, 10};
Shape indices_shape{300, 3, 10000}; Shape indices_shape{300, 3, 10000};
int64_t axis = 2; int64_t axis = 2;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
ASSERT_EQ(GE->get_shape(), indices_shape); EXPECT_EQ(op->get_shape(), indices_shape);
} }
TEST(type_prop, gather_elements_4D_axis_minus_1) { TEST(type_prop, gather_elements_4D_axis_minus_1) {
Shape data_shape{300, 3, 10, 1}; Shape data_shape{300, 3, 10, 1};
Shape indices_shape{300, 3, 10, 33333}; Shape indices_shape{300, 3, 10, 33333};
int64_t axis = -1; int64_t axis = -1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
ASSERT_EQ(GE->get_shape(), indices_shape); EXPECT_EQ(op->get_shape(), indices_shape);
} }
TEST(type_prop, gather_elements_nonfloat_data_type_int64_indices) { TEST(type_prop, gather_elements_nonfloat_data_type_int64_indices) {
Shape data_shape{300, 3, 10, 1}; Shape data_shape{300, 3, 10, 1};
Shape indices_shape{300, 3, 10, 33333}; Shape indices_shape{300, 3, 10, 33333};
int64_t axis = -1; int64_t axis = -1;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_shape(), indices_shape); EXPECT_EQ(op->get_shape(), indices_shape);
} }
TEST(type_prop, gather_elements_dynamic_consistent_shapes) { TEST(type_prop, gather_elements_dynamic_consistent_shapes) {
PartialShape data_shape{4, 4, Dimension::dynamic()}; PartialShape data_shape{4, 4, Dimension::dynamic()};
PartialShape indices_shape{1, Dimension::dynamic(), 5}; PartialShape indices_shape{1, Dimension::dynamic(), 5};
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_shape(), Shape({1, 4, 5})); EXPECT_EQ(op->get_shape(), Shape({1, 4, 5}));
} }
TEST(type_prop, gather_elements_dynamic_out_shape) { TEST(type_prop, gather_elements_dynamic_out_shape) {
PartialShape data_shape{4, 4, Dimension::dynamic()}; PartialShape data_shape{4, 4, Dimension::dynamic()};
PartialShape indices_shape{1, Dimension::dynamic(), Dimension::dynamic()}; PartialShape indices_shape{1, Dimension::dynamic(), Dimension::dynamic()};
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({1, 4, Dimension::dynamic()})); EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({1, 4, Dimension::dynamic()}));
} }
TEST(type_prop, gather_elements_interval_shapes) { TEST(type_prop, gather_elements_interval_shapes) {
PartialShape data_shape{4, Dimension(1, 7), 5}; PartialShape data_shape{4, Dimension(1, 7), 5};
PartialShape indices_shape{1, Dimension(5, 10), 5}; PartialShape indices_shape{1, Dimension(5, 10), 5};
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({1, Dimension(5, 7), 5})); EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({1, {5, 7}, 5}));
} }
TEST(type_prop, gather_elements_data_rank_dynamic_indices_rank_static) { TEST(type_prop, gather_elements_data_rank_dynamic_indices_rank_static) {
PartialShape data_shape = PartialShape::dynamic(); PartialShape data_shape = PartialShape::dynamic();
PartialShape indices_shape{4, 7, 5}; PartialShape indices_shape{4, 7, 5};
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({4, 7, 5})); EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({4, 7, 5}));
} }
TEST(type_prop, gather_elements_data_rank_static_indices_rank_dynamic) { TEST(type_prop, gather_elements_data_rank_static_indices_rank_dynamic) {
PartialShape data_shape{4, Dimension(1, 7), 5}; PartialShape data_shape{4, Dimension(1, 7), 5};
PartialShape indices_shape = PartialShape::dynamic(); PartialShape indices_shape = PartialShape::dynamic();
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), Dimension(1, 7), 5})); EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), {1, 7}, 5}));
} }
TEST(type_prop, gather_elements_data_pshape_static_indices_rank_dynamic) { TEST(type_prop, gather_elements_data_pshape_static_indices_rank_dynamic) {
PartialShape data_shape{4, 7, 5}; PartialShape data_shape{4, 7, 5};
PartialShape indices_shape = PartialShape::dynamic(); PartialShape indices_shape = PartialShape::dynamic();
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), 7, 5})); EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({-1, 7, 5}));
}
TEST(type_prop, gather_elements_interval_dims_with_labels_both_inputs) {
PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}};
set_shape_labels(data_shape, 10);
PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}};
set_shape_labels(indices_shape, 20);
int64_t axis = 2;
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
auto op = make_shared<v6::GatherElements>(data, indices, axis);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22, 23, 24, 25));
}
TEST(type_prop, gather_elements_interval_dims_with_labels_data) {
PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}};
set_shape_labels(data_shape, 10);
PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}};
int64_t axis = 2;
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
auto op = make_shared<v6::GatherElements>(data, indices, axis);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, ov::no_label, 13, 14, 15));
}
TEST(type_prop, gather_elements_interval_dims_with_labels_indices) {
PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}};
PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}};
set_shape_labels(indices_shape, 20);
int64_t axis = 2;
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
auto op = make_shared<v6::GatherElements>(data, indices, axis);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22, 23, 24, 25));
} }
// --------------------- Negative tests ------------------------------ // --------------------- Negative tests ------------------------------
@ -148,11 +222,11 @@ TEST(type_prop, gather_elements_type_inconsistency) {
Shape data_shape{3, 3}; Shape data_shape{3, 3};
Shape indices_shape{2, 1}; Shape indices_shape{2, 1};
int64_t axis = 1; int64_t axis = 1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::u32, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::u32, indices_shape);
try { try {
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "the indices tensor type check failed"; FAIL() << "the indices tensor type check failed";
} catch (const NodeValidationFailure& error) { } catch (const NodeValidationFailure& error) {
@ -166,11 +240,11 @@ TEST(type_prop, gather_elements_out_of_bounds_axis) {
Shape data_shape{3, 3}; Shape data_shape{3, 3};
Shape indices_shape{2, 1}; Shape indices_shape{2, 1};
int64_t axis = -33; int64_t axis = -33;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
try { try {
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "axis out of bounds check failed"; FAIL() << "axis out of bounds check failed";
} catch (const ov::AssertFailure& error) { } catch (const ov::AssertFailure& error) {
@ -184,11 +258,11 @@ TEST(type_prop, gather_elements_rank_consistency_check) {
Shape data_shape{3, 3}; Shape data_shape{3, 3};
Shape indices_shape{2, 3, 3333}; Shape indices_shape{2, 3, 3333};
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
try { try {
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "rank consistency check failed"; FAIL() << "rank consistency check failed";
} catch (const NodeValidationFailure& error) { } catch (const NodeValidationFailure& error) {
@ -202,16 +276,15 @@ TEST(type_prop, gather_elements_shape_inconsistency) {
Shape data_shape{3, 3}; Shape data_shape{3, 3};
Shape indices_shape{2, 1}; Shape indices_shape{2, 1};
int64_t axis = 1; int64_t axis = 1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
try { try {
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Shape inconsistency check failed"; FAIL() << "Shape inconsistency check failed";
} catch (const NodeValidationFailure& error) { } catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent"));
std::string("data and indices must have equal or intersecting sizes, except for axis"));
} catch (...) { } catch (...) {
FAIL() << "Shape inconsistency check failed for unexpected reason"; FAIL() << "Shape inconsistency check failed for unexpected reason";
} }
@ -221,16 +294,15 @@ TEST(type_prop, gather_elements_dynamic_inconsistent_shapes) {
PartialShape data_shape{4, 2, 4, Dimension::dynamic()}; PartialShape data_shape{4, 2, 4, Dimension::dynamic()};
PartialShape indices_shape{1, 3, Dimension::dynamic(), 5}; PartialShape indices_shape{1, 3, Dimension::dynamic(), 5};
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
try { try {
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Shape inconsistency check for dynamic PartialShape failed"; FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
} catch (const NodeValidationFailure& error) { } catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent"));
std::string("data and indices must have equal or intersecting sizes, except for axis"));
} catch (...) { } catch (...) {
FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason"; FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason";
} }
@ -240,15 +312,15 @@ TEST(type_prop, gather_elements_incosistent_interval_shapes) {
PartialShape data_shape{4, 4, 5}; PartialShape data_shape{4, 4, 5};
PartialShape indices_shape{1, Dimension(5, 10), 5}; PartialShape indices_shape{1, Dimension(5, 10), 5};
int64_t axis = 0; int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape); const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape); const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
try { try {
auto GE = make_shared<op::v6::GatherElements>(D, I, axis); auto op = make_shared<v6::GatherElements>(data, indices, axis);
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Shape inconsistency check for dynamic PartialShape failed"; FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
} catch (const NodeValidationFailure& error) { } catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent"));
std::string("data and indices must have equal or intersecting sizes, except for axis"));
} catch (...) { } catch (...) {
FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason"; FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason";
} }

View File

@ -1,25 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <gather_elements_shape_inference.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
using namespace ov::intel_cpu;
TEST(StaticShapeInferenceTest, GatherElementsTest) {
int64_t axis = -1;
auto D = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
auto GE = std::make_shared<op::v6::GatherElements>(D, I, axis);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{300, 3, 10, 1}, StaticShape{300, 3, 10, 33333}},
static_output_shapes = {StaticShape{}};
shape_inference(GE.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], (StaticShape{300, 3, 10, 33333}));
}

View File

@ -0,0 +1,65 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/test_assertions.hpp"
#include "gather_elements_shape_inference.hpp"
#include "openvino/op/ops.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
class GatherElementsStaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v6::GatherElements> {};
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_basic) {
int64_t axis = -1;
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
const auto indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
op = make_op(data, indices, axis);
input_shapes = {StaticShape{300, 3, 10, 2}, StaticShape{300, 3, 10, 33333}};
output_shapes = {StaticShape{}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], (StaticShape{300, 3, 10, 33333}));
}
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_incompatible_rank) {
int64_t axis = -1;
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
const auto indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
op = make_op(data, indices, axis);
input_shapes = {StaticShape{1, 2, 3, 4, 5}, StaticShape{1, 2, 3, 4}};
output_shapes = {StaticShape{}};
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
ov::NodeValidationFailure,
HasSubstr("rank must be equal"));
}
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_incompatible_dims) {
int64_t axis = -1;
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
const auto indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
op = make_op(data, indices, axis);
input_shapes = {StaticShape{300, 4, 10, 2}, StaticShape{300, 5, 10, 33333}};
output_shapes = {StaticShape{}};
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
ov::NodeValidationFailure,
HasSubstr("are not consistent"));
}
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_default_constructor) {
int64_t axis = -1;
op = make_op();
op->set_axis(axis);
input_shapes = {StaticShape{300, 3, 10, 2}, StaticShape{300, 3, 10, 33333}};
output_shapes = {StaticShape{}};
shape_infer(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], (StaticShape{300, 3, 10, 33333}));
}