From c72c6ba3315b8f72524dd15637f5bdf362d106ac Mon Sep 17 00:00:00 2001 From: Katarzyna Mitrus Date: Tue, 24 Jan 2023 10:20:28 +0100 Subject: [PATCH] [ShapeInference] Improvements of GatherElements shape inference (#15264) * GatherElements shape infer improvements * Add new tests * Update test error message --- .../include/openvino/op/gather_elements.hpp | 4 - .../gather_elements_shape_inference.hpp | 44 ++-- src/core/src/op/gather_elements.cpp | 5 +- src/core/tests/type_prop/gather_elements.cpp | 244 ++++++++++++------ .../gather_elements_shape_inference.cpp | 25 -- .../gather_elements_shape_inference_test.cpp | 65 +++++ 6 files changed, 244 insertions(+), 143 deletions(-) delete mode 100644 src/plugins/intel_cpu/tests/unit/shape_inference_test/gather_elements_shape_inference.cpp create mode 100644 src/plugins/intel_cpu/tests/unit/shape_inference_test/gather_elements_shape_inference_test.cpp diff --git a/src/core/include/openvino/op/gather_elements.hpp b/src/core/include/openvino/op/gather_elements.hpp index e85384aee54..c184d86e362 100644 --- a/src/core/include/openvino/op/gather_elements.hpp +++ b/src/core/include/openvino/op/gather_elements.hpp @@ -37,10 +37,6 @@ public: private: int64_t m_axis{0}; - template - void friend shape_infer(const GatherElements* op, - const std::vector& input_shapes, - std::vector& output_shapes); }; } // namespace v6 } // namespace op diff --git a/src/core/shape_inference/include/gather_elements_shape_inference.hpp b/src/core/shape_inference/include/gather_elements_shape_inference.hpp index 426e816b5b1..8bc763371c3 100644 --- a/src/core/shape_inference/include/gather_elements_shape_inference.hpp +++ b/src/core/shape_inference/include/gather_elements_shape_inference.hpp @@ -20,14 +20,11 @@ void shape_infer(const GatherElements* op, const std::vector& input_shapes, s auto indices_rank = indices_pshape.rank(); auto& output_shape = output_shapes[0]; - int64_t axis = op->m_axis; + int64_t axis = op->get_axis(); if (data_rank.is_static()) axis = ov::normalize_axis(op, axis, data_rank); - output_shape = indices_pshape; - NODE_VALIDATION_CHECK(op, data_rank.is_dynamic() || data_rank.get_length() >= 1, "data rank must be >= 1."); - NODE_VALIDATION_CHECK(op, indices_rank.is_dynamic() || indices_rank.get_length() >= 1, "indices rank must be >= 1."); @@ -40,9 +37,11 @@ void shape_infer(const GatherElements* op, const std::vector& input_shapes, s } if (data_rank.is_dynamic()) { - // can't decide rank, set it to all dynamic - if (indices_rank.is_dynamic()) + if (indices_rank.is_dynamic()) { output_shape = PartialShape::dynamic(); + return; + } + output_shape = indices_pshape; return; } @@ -54,25 +53,20 @@ void shape_infer(const GatherElements* op, const std::vector& input_shapes, s " and ", indices_rank.get_length()); - for (int i = 0; i < indices_rank.get_length(); i++) { - if (i != axis) { - // if size of the current dimension of indices is unknown it will be retrieved from data - // e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0 - // (and if intervals intersect) then output_pshape will be {1, 4, 5} - - NODE_VALIDATION_CHECK(op, - data_pshape[i].compatible(indices_pshape[i]), - "Shapes ", - data_pshape, - " and ", - indices_pshape, - " are not consistent. data and indices must have equal or " - "intersecting sizes, except for axis ", - axis); - - output_shape[i] = data_pshape[i] & indices_pshape[i]; - } - } + // if size of the current dimension of indices is unknown it will be retrieved from data + // e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0 + // (and if intervals intersect) then output_pshape will be {1, 4, 5} + output_shape = data_pshape; + output_shape[axis] = indices_pshape[axis]; + NODE_VALIDATION_CHECK(op, + output_shape.merge_into(output_shape, indices_pshape), + "Shapes ", + data_pshape, + " and ", + indices_pshape, + " are not consistent, `data` and `indices` must have equal or " + "intersecting dimensions, except for the dimension at axis index.", + axis); } } // namespace v6 } // namespace op diff --git a/src/core/src/op/gather_elements.cpp b/src/core/src/op/gather_elements.cpp index 3bf1cac3520..391e4bb6e6c 100644 --- a/src/core/src/op/gather_elements.cpp +++ b/src/core/src/op/gather_elements.cpp @@ -30,9 +30,8 @@ void op::v6::GatherElements::validate_and_infer_types() { "indices must be of int32 or int64 type. But instead got: ", indices_type); - const auto& data_pshape = get_input_partial_shape(0); - const auto& indices_pshape = get_input_partial_shape(1); - std::vector input_shapes = {data_pshape, indices_pshape}, output_shapes = {PartialShape{}}; + const auto input_shapes = get_node_input_partial_shapes(*this); + auto output_shapes = std::vector(1); shape_infer(this, input_shapes, output_shapes); set_output_type(0, data_type, output_shapes[0]); } diff --git a/src/core/tests/type_prop/gather_elements.cpp b/src/core/tests/type_prop/gather_elements.cpp index 9d4906277e2..87fb269e1f5 100644 --- a/src/core/tests/type_prop/gather_elements.cpp +++ b/src/core/tests/type_prop/gather_elements.cpp @@ -2,145 +2,219 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "dimension_tracker.hpp" #include "gtest/gtest.h" -#include "ngraph/ngraph.hpp" +#include "openvino/op/ops.hpp" #include "util/type_prop.hpp" using namespace std; -using namespace ngraph; +using namespace ov; +using namespace op; +using namespace testing; // ------------------------------ V6 ------------------------------ +TEST(type_prop, gather_elements_default_constructor) { + PartialShape data_shape{1, 2, 3, 4}; + PartialShape indices_shape{1, 2, 10, 4}; + int64_t axis = -2; + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); + auto op = make_shared(); + + op->set_axis(axis); + EXPECT_EQ(op->get_axis(), axis); + + op->set_argument(0, data); + op->set_argument(1, indices); + + op->validate_and_infer_types(); + + EXPECT_EQ(op->get_element_type(), element::Type_t::f32); + EXPECT_EQ(op->get_output_partial_shape(0), indices_shape); +} + TEST(type_prop, gather_elements_2D_axis_0) { Shape data_shape{3, 3}; Shape indices_shape{2, 3}; int axis = 0; - auto D = make_shared(element::Type_t::f32, data_shape); - auto I = make_shared(element::Type_t::i32, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); - ASSERT_EQ(GE->get_shape(), indices_shape); + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::f32); + EXPECT_EQ(op->get_shape(), indices_shape); } TEST(type_prop, gather_elements_2D_axis_1) { Shape data_shape{3, 3}; Shape indices_shape{3, 1}; int axis = 1; - auto D = make_shared(element::Type_t::f32, data_shape); - auto I = make_shared(element::Type_t::i32, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); - ASSERT_EQ(GE->get_shape(), indices_shape); + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::f32); + EXPECT_EQ(op->get_shape(), indices_shape); } TEST(type_prop, gather_elements_3D_axis_0) { Shape data_shape{3, 3, 10000}; Shape indices_shape{300, 3, 10000}; int64_t axis = 0; - auto D = make_shared(element::Type_t::f32, data_shape); - auto I = make_shared(element::Type_t::i32, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); - ASSERT_EQ(GE->get_shape(), indices_shape); + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::f32); + EXPECT_EQ(op->get_shape(), indices_shape); } TEST(type_prop, gather_elements_3D_axis_2) { Shape data_shape{300, 3, 10}; Shape indices_shape{300, 3, 10000}; int64_t axis = 2; - auto D = make_shared(element::Type_t::f32, data_shape); - auto I = make_shared(element::Type_t::i32, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); - ASSERT_EQ(GE->get_shape(), indices_shape); + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::f32); + EXPECT_EQ(op->get_shape(), indices_shape); } TEST(type_prop, gather_elements_4D_axis_minus_1) { Shape data_shape{300, 3, 10, 1}; Shape indices_shape{300, 3, 10, 33333}; int64_t axis = -1; - auto D = make_shared(element::Type_t::f32, data_shape); - auto I = make_shared(element::Type_t::i32, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::f32); - ASSERT_EQ(GE->get_shape(), indices_shape); + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::f32); + EXPECT_EQ(op->get_shape(), indices_shape); } TEST(type_prop, gather_elements_nonfloat_data_type_int64_indices) { Shape data_shape{300, 3, 10, 1}; Shape indices_shape{300, 3, 10, 33333}; int64_t axis = -1; - auto D = make_shared(element::Type_t::i8, data_shape); - auto I = make_shared(element::Type_t::i64, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); - ASSERT_EQ(GE->get_shape(), indices_shape); + const auto data = make_shared(element::Type_t::i8, data_shape); + const auto indices = make_shared(element::Type_t::i64, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::i8); + EXPECT_EQ(op->get_shape(), indices_shape); } TEST(type_prop, gather_elements_dynamic_consistent_shapes) { PartialShape data_shape{4, 4, Dimension::dynamic()}; PartialShape indices_shape{1, Dimension::dynamic(), 5}; int64_t axis = 0; - auto D = make_shared(element::Type_t::i8, data_shape); - auto I = make_shared(element::Type_t::i64, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); - ASSERT_EQ(GE->get_shape(), Shape({1, 4, 5})); + const auto data = make_shared(element::Type_t::i8, data_shape); + const auto indices = make_shared(element::Type_t::i64, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::i8); + EXPECT_EQ(op->get_shape(), Shape({1, 4, 5})); } TEST(type_prop, gather_elements_dynamic_out_shape) { PartialShape data_shape{4, 4, Dimension::dynamic()}; PartialShape indices_shape{1, Dimension::dynamic(), Dimension::dynamic()}; int64_t axis = 0; - auto D = make_shared(element::Type_t::i8, data_shape); - auto I = make_shared(element::Type_t::i64, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); - ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({1, 4, Dimension::dynamic()})); + const auto data = make_shared(element::Type_t::i8, data_shape); + const auto indices = make_shared(element::Type_t::i64, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::i8); + EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({1, 4, Dimension::dynamic()})); } TEST(type_prop, gather_elements_interval_shapes) { PartialShape data_shape{4, Dimension(1, 7), 5}; PartialShape indices_shape{1, Dimension(5, 10), 5}; int64_t axis = 0; - auto D = make_shared(element::Type_t::i8, data_shape); - auto I = make_shared(element::Type_t::i64, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); - ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({1, Dimension(5, 7), 5})); + const auto data = make_shared(element::Type_t::i8, data_shape); + const auto indices = make_shared(element::Type_t::i64, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::i8); + EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({1, {5, 7}, 5})); } TEST(type_prop, gather_elements_data_rank_dynamic_indices_rank_static) { PartialShape data_shape = PartialShape::dynamic(); PartialShape indices_shape{4, 7, 5}; int64_t axis = 0; - auto D = make_shared(element::Type_t::i8, data_shape); - auto I = make_shared(element::Type_t::i64, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); - ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({4, 7, 5})); + const auto data = make_shared(element::Type_t::i8, data_shape); + const auto indices = make_shared(element::Type_t::i64, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::i8); + EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({4, 7, 5})); } TEST(type_prop, gather_elements_data_rank_static_indices_rank_dynamic) { PartialShape data_shape{4, Dimension(1, 7), 5}; PartialShape indices_shape = PartialShape::dynamic(); int64_t axis = 0; - auto D = make_shared(element::Type_t::i8, data_shape); - auto I = make_shared(element::Type_t::i64, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); - ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), Dimension(1, 7), 5})); + const auto data = make_shared(element::Type_t::i8, data_shape); + const auto indices = make_shared(element::Type_t::i64, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::i8); + EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), {1, 7}, 5})); } TEST(type_prop, gather_elements_data_pshape_static_indices_rank_dynamic) { PartialShape data_shape{4, 7, 5}; PartialShape indices_shape = PartialShape::dynamic(); int64_t axis = 0; - auto D = make_shared(element::Type_t::i8, data_shape); - auto I = make_shared(element::Type_t::i64, indices_shape); - auto GE = make_shared(D, I, axis); - ASSERT_EQ(GE->get_element_type(), element::Type_t::i8); - ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), 7, 5})); + const auto data = make_shared(element::Type_t::i8, data_shape); + const auto indices = make_shared(element::Type_t::i64, indices_shape); + auto op = make_shared(data, indices, axis); + EXPECT_EQ(op->get_element_type(), element::Type_t::i8); + EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({-1, 7, 5})); +} + +TEST(type_prop, gather_elements_interval_dims_with_labels_both_inputs) { + PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}}; + set_shape_labels(data_shape, 10); + + PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}}; + set_shape_labels(indices_shape, 20); + + int64_t axis = 2; + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); + auto op = make_shared(data, indices, axis); + + const auto& out_shape = op->get_output_partial_shape(0); + EXPECT_EQ(op->get_element_type(), element::Type_t::f32); + EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4})); + EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22, 23, 24, 25)); +} + +TEST(type_prop, gather_elements_interval_dims_with_labels_data) { + PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}}; + set_shape_labels(data_shape, 10); + + PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}}; + + int64_t axis = 2; + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); + auto op = make_shared(data, indices, axis); + + const auto& out_shape = op->get_output_partial_shape(0); + EXPECT_EQ(op->get_element_type(), element::Type_t::f32); + EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4})); + EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, ov::no_label, 13, 14, 15)); +} + +TEST(type_prop, gather_elements_interval_dims_with_labels_indices) { + PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}}; + PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}}; + set_shape_labels(indices_shape, 20); + + int64_t axis = 2; + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); + auto op = make_shared(data, indices, axis); + + const auto& out_shape = op->get_output_partial_shape(0); + EXPECT_EQ(op->get_element_type(), element::Type_t::f32); + EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4})); + EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22, 23, 24, 25)); } // --------------------- Negative tests ------------------------------ @@ -148,11 +222,11 @@ TEST(type_prop, gather_elements_type_inconsistency) { Shape data_shape{3, 3}; Shape indices_shape{2, 1}; int64_t axis = 1; - auto D = make_shared(element::Type_t::f32, data_shape); - auto I = make_shared(element::Type_t::u32, indices_shape); + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::u32, indices_shape); try { - auto GE = make_shared(D, I, axis); + auto op = make_shared(data, indices, axis); // Should have thrown, so fail if it didn't FAIL() << "the indices tensor type check failed"; } catch (const NodeValidationFailure& error) { @@ -166,11 +240,11 @@ TEST(type_prop, gather_elements_out_of_bounds_axis) { Shape data_shape{3, 3}; Shape indices_shape{2, 1}; int64_t axis = -33; - auto D = make_shared(element::Type_t::f32, data_shape); - auto I = make_shared(element::Type_t::i32, indices_shape); + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); try { - auto GE = make_shared(D, I, axis); + auto op = make_shared(data, indices, axis); // Should have thrown, so fail if it didn't FAIL() << "axis out of bounds check failed"; } catch (const ov::AssertFailure& error) { @@ -184,11 +258,11 @@ TEST(type_prop, gather_elements_rank_consistency_check) { Shape data_shape{3, 3}; Shape indices_shape{2, 3, 3333}; int64_t axis = 0; - auto D = make_shared(element::Type_t::f32, data_shape); - auto I = make_shared(element::Type_t::i32, indices_shape); + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); try { - auto GE = make_shared(D, I, axis); + auto op = make_shared(data, indices, axis); // Should have thrown, so fail if it didn't FAIL() << "rank consistency check failed"; } catch (const NodeValidationFailure& error) { @@ -202,16 +276,15 @@ TEST(type_prop, gather_elements_shape_inconsistency) { Shape data_shape{3, 3}; Shape indices_shape{2, 1}; int64_t axis = 1; - auto D = make_shared(element::Type_t::f32, data_shape); - auto I = make_shared(element::Type_t::i32, indices_shape); + const auto data = make_shared(element::Type_t::f32, data_shape); + const auto indices = make_shared(element::Type_t::i32, indices_shape); try { - auto GE = make_shared(D, I, axis); + auto op = make_shared(data, indices, axis); // Should have thrown, so fail if it didn't FAIL() << "Shape inconsistency check failed"; } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), - std::string("data and indices must have equal or intersecting sizes, except for axis")); + EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent")); } catch (...) { FAIL() << "Shape inconsistency check failed for unexpected reason"; } @@ -221,16 +294,15 @@ TEST(type_prop, gather_elements_dynamic_inconsistent_shapes) { PartialShape data_shape{4, 2, 4, Dimension::dynamic()}; PartialShape indices_shape{1, 3, Dimension::dynamic(), 5}; int64_t axis = 0; - auto D = make_shared(element::Type_t::i8, data_shape); - auto I = make_shared(element::Type_t::i64, indices_shape); + const auto data = make_shared(element::Type_t::i8, data_shape); + const auto indices = make_shared(element::Type_t::i64, indices_shape); try { - auto GE = make_shared(D, I, axis); + auto op = make_shared(data, indices, axis); // Should have thrown, so fail if it didn't FAIL() << "Shape inconsistency check for dynamic PartialShape failed"; } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), - std::string("data and indices must have equal or intersecting sizes, except for axis")); + EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent")); } catch (...) { FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason"; } @@ -240,15 +312,15 @@ TEST(type_prop, gather_elements_incosistent_interval_shapes) { PartialShape data_shape{4, 4, 5}; PartialShape indices_shape{1, Dimension(5, 10), 5}; int64_t axis = 0; - auto D = make_shared(element::Type_t::i8, data_shape); - auto I = make_shared(element::Type_t::i64, indices_shape); + const auto data = make_shared(element::Type_t::i8, data_shape); + const auto indices = make_shared(element::Type_t::i64, indices_shape); + try { - auto GE = make_shared(D, I, axis); + auto op = make_shared(data, indices, axis); // Should have thrown, so fail if it didn't FAIL() << "Shape inconsistency check for dynamic PartialShape failed"; } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), - std::string("data and indices must have equal or intersecting sizes, except for axis")); + EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent")); } catch (...) { FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason"; } diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/gather_elements_shape_inference.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/gather_elements_shape_inference.cpp deleted file mode 100644 index 61e13f7e8be..00000000000 --- a/src/plugins/intel_cpu/tests/unit/shape_inference_test/gather_elements_shape_inference.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// -#include - -#include -#include -#include -#include -#include - -using namespace ov; -using namespace ov::intel_cpu; - -TEST(StaticShapeInferenceTest, GatherElementsTest) { - int64_t axis = -1; - auto D = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); - auto I = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1}); - auto GE = std::make_shared(D, I, axis); - // Test StaticShape - std::vector static_input_shapes = {StaticShape{300, 3, 10, 1}, StaticShape{300, 3, 10, 33333}}, - static_output_shapes = {StaticShape{}}; - shape_inference(GE.get(), static_input_shapes, static_output_shapes); - ASSERT_EQ(static_output_shapes[0], (StaticShape{300, 3, 10, 33333})); -} diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/gather_elements_shape_inference_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/gather_elements_shape_inference_test.cpp new file mode 100644 index 00000000000..7813596547d --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/gather_elements_shape_inference_test.cpp @@ -0,0 +1,65 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include + +#include "common_test_utils/test_assertions.hpp" +#include "gather_elements_shape_inference.hpp" +#include "openvino/op/ops.hpp" +#include "utils.hpp" + +using namespace ov; +using namespace ov::intel_cpu; +using namespace testing; + +class GatherElementsStaticShapeInferenceTest : public OpStaticShapeInferenceTest {}; + +TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_basic) { + int64_t axis = -1; + const auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + const auto indices = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1}); + + op = make_op(data, indices, axis); + input_shapes = {StaticShape{300, 3, 10, 2}, StaticShape{300, 3, 10, 33333}}; + output_shapes = {StaticShape{}}; + + shape_inference(op.get(), input_shapes, output_shapes); + EXPECT_EQ(output_shapes[0], (StaticShape{300, 3, 10, 33333})); +} + +TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_incompatible_rank) { + int64_t axis = -1; + const auto data = std::make_shared(element::f32, PartialShape::dynamic()); + const auto indices = std::make_shared(element::i32, PartialShape::dynamic()); + + op = make_op(data, indices, axis); + input_shapes = {StaticShape{1, 2, 3, 4, 5}, StaticShape{1, 2, 3, 4}}; + output_shapes = {StaticShape{}}; + OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes), + ov::NodeValidationFailure, + HasSubstr("rank must be equal")); +} + +TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_incompatible_dims) { + int64_t axis = -1; + const auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + const auto indices = std::make_shared(element::i32, PartialShape{-1, -1, -1, -1}); + + op = make_op(data, indices, axis); + input_shapes = {StaticShape{300, 4, 10, 2}, StaticShape{300, 5, 10, 33333}}; + output_shapes = {StaticShape{}}; + OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes), + ov::NodeValidationFailure, + HasSubstr("are not consistent")); +} + +TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_default_constructor) { + int64_t axis = -1; + op = make_op(); + op->set_axis(axis); + input_shapes = {StaticShape{300, 3, 10, 2}, StaticShape{300, 3, 10, 33333}}; + output_shapes = {StaticShape{}}; + + shape_infer(op.get(), input_shapes, output_shapes); + EXPECT_EQ(output_shapes[0], (StaticShape{300, 3, 10, 33333})); +}