[ShapeInference] Improvements of GatherElements shape inference (#15264)
* GatherElements shape infer improvements * Add new tests * Update test error message
This commit is contained in:
parent
9ee80d67b7
commit
c72c6ba331
@ -37,10 +37,6 @@ public:
|
||||
|
||||
private:
|
||||
int64_t m_axis{0};
|
||||
template <class T>
|
||||
void friend shape_infer(const GatherElements* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<T>& output_shapes);
|
||||
};
|
||||
} // namespace v6
|
||||
} // namespace op
|
||||
|
@ -20,14 +20,11 @@ void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, s
|
||||
auto indices_rank = indices_pshape.rank();
|
||||
auto& output_shape = output_shapes[0];
|
||||
|
||||
int64_t axis = op->m_axis;
|
||||
int64_t axis = op->get_axis();
|
||||
if (data_rank.is_static())
|
||||
axis = ov::normalize_axis(op, axis, data_rank);
|
||||
|
||||
output_shape = indices_pshape;
|
||||
|
||||
NODE_VALIDATION_CHECK(op, data_rank.is_dynamic() || data_rank.get_length() >= 1, "data rank must be >= 1.");
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
indices_rank.is_dynamic() || indices_rank.get_length() >= 1,
|
||||
"indices rank must be >= 1.");
|
||||
@ -40,9 +37,11 @@ void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, s
|
||||
}
|
||||
|
||||
if (data_rank.is_dynamic()) {
|
||||
// can't decide rank, set it to all dynamic
|
||||
if (indices_rank.is_dynamic())
|
||||
if (indices_rank.is_dynamic()) {
|
||||
output_shape = PartialShape::dynamic();
|
||||
return;
|
||||
}
|
||||
output_shape = indices_pshape;
|
||||
return;
|
||||
}
|
||||
|
||||
@ -54,25 +53,20 @@ void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, s
|
||||
" and ",
|
||||
indices_rank.get_length());
|
||||
|
||||
for (int i = 0; i < indices_rank.get_length(); i++) {
|
||||
if (i != axis) {
|
||||
// if size of the current dimension of indices is unknown it will be retrieved from data
|
||||
// e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0
|
||||
// (and if intervals intersect) then output_pshape will be {1, 4, 5}
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
data_pshape[i].compatible(indices_pshape[i]),
|
||||
"Shapes ",
|
||||
data_pshape,
|
||||
" and ",
|
||||
indices_pshape,
|
||||
" are not consistent. data and indices must have equal or "
|
||||
"intersecting sizes, except for axis ",
|
||||
axis);
|
||||
|
||||
output_shape[i] = data_pshape[i] & indices_pshape[i];
|
||||
}
|
||||
}
|
||||
// if size of the current dimension of indices is unknown it will be retrieved from data
|
||||
// e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0
|
||||
// (and if intervals intersect) then output_pshape will be {1, 4, 5}
|
||||
output_shape = data_pshape;
|
||||
output_shape[axis] = indices_pshape[axis];
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
output_shape.merge_into(output_shape, indices_pshape),
|
||||
"Shapes ",
|
||||
data_pshape,
|
||||
" and ",
|
||||
indices_pshape,
|
||||
" are not consistent, `data` and `indices` must have equal or "
|
||||
"intersecting dimensions, except for the dimension at axis index.",
|
||||
axis);
|
||||
}
|
||||
} // namespace v6
|
||||
} // namespace op
|
||||
|
@ -30,9 +30,8 @@ void op::v6::GatherElements::validate_and_infer_types() {
|
||||
"indices must be of int32 or int64 type. But instead got: ",
|
||||
indices_type);
|
||||
|
||||
const auto& data_pshape = get_input_partial_shape(0);
|
||||
const auto& indices_pshape = get_input_partial_shape(1);
|
||||
std::vector<PartialShape> input_shapes = {data_pshape, indices_pshape}, output_shapes = {PartialShape{}};
|
||||
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||
auto output_shapes = std::vector<ov::PartialShape>(1);
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
set_output_type(0, data_type, output_shapes[0]);
|
||||
}
|
||||
|
@ -2,145 +2,219 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "openvino/op/ops.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace ov;
|
||||
using namespace op;
|
||||
using namespace testing;
|
||||
|
||||
// ------------------------------ V6 ------------------------------
|
||||
|
||||
TEST(type_prop, gather_elements_default_constructor) {
|
||||
PartialShape data_shape{1, 2, 3, 4};
|
||||
PartialShape indices_shape{1, 2, 10, 4};
|
||||
int64_t axis = -2;
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>();
|
||||
|
||||
op->set_axis(axis);
|
||||
EXPECT_EQ(op->get_axis(), axis);
|
||||
|
||||
op->set_argument(0, data);
|
||||
op->set_argument(1, indices);
|
||||
|
||||
op->validate_and_infer_types();
|
||||
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), indices_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_2D_axis_0) {
|
||||
Shape data_shape{3, 3};
|
||||
Shape indices_shape{2, 3};
|
||||
int axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_2D_axis_1) {
|
||||
Shape data_shape{3, 3};
|
||||
Shape indices_shape{3, 1};
|
||||
int axis = 1;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_3D_axis_0) {
|
||||
Shape data_shape{3, 3, 10000};
|
||||
Shape indices_shape{300, 3, 10000};
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_3D_axis_2) {
|
||||
Shape data_shape{300, 3, 10};
|
||||
Shape indices_shape{300, 3, 10000};
|
||||
int64_t axis = 2;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_4D_axis_minus_1) {
|
||||
Shape data_shape{300, 3, 10, 1};
|
||||
Shape indices_shape{300, 3, 10, 33333};
|
||||
int64_t axis = -1;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_nonfloat_data_type_int64_indices) {
|
||||
Shape data_shape{300, 3, 10, 1};
|
||||
Shape indices_shape{300, 3, 10, 33333};
|
||||
int64_t axis = -1;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_dynamic_consistent_shapes) {
|
||||
PartialShape data_shape{4, 4, Dimension::dynamic()};
|
||||
PartialShape indices_shape{1, Dimension::dynamic(), 5};
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
||||
ASSERT_EQ(GE->get_shape(), Shape({1, 4, 5}));
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||
EXPECT_EQ(op->get_shape(), Shape({1, 4, 5}));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_dynamic_out_shape) {
|
||||
PartialShape data_shape{4, 4, Dimension::dynamic()};
|
||||
PartialShape indices_shape{1, Dimension::dynamic(), Dimension::dynamic()};
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({1, 4, Dimension::dynamic()}));
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({1, 4, Dimension::dynamic()}));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_interval_shapes) {
|
||||
PartialShape data_shape{4, Dimension(1, 7), 5};
|
||||
PartialShape indices_shape{1, Dimension(5, 10), 5};
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({1, Dimension(5, 7), 5}));
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({1, {5, 7}, 5}));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_data_rank_dynamic_indices_rank_static) {
|
||||
PartialShape data_shape = PartialShape::dynamic();
|
||||
PartialShape indices_shape{4, 7, 5};
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({4, 7, 5}));
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({4, 7, 5}));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_data_rank_static_indices_rank_dynamic) {
|
||||
PartialShape data_shape{4, Dimension(1, 7), 5};
|
||||
PartialShape indices_shape = PartialShape::dynamic();
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), Dimension(1, 7), 5}));
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), {1, 7}, 5}));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_data_pshape_static_indices_rank_dynamic) {
|
||||
PartialShape data_shape{4, 7, 5};
|
||||
PartialShape indices_shape = PartialShape::dynamic();
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), 7, 5}));
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({-1, 7, 5}));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_interval_dims_with_labels_both_inputs) {
|
||||
PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}};
|
||||
set_shape_labels(data_shape, 10);
|
||||
|
||||
PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}};
|
||||
set_shape_labels(indices_shape, 20);
|
||||
|
||||
int64_t axis = 2;
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||
EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22, 23, 24, 25));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_interval_dims_with_labels_data) {
|
||||
PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}};
|
||||
set_shape_labels(data_shape, 10);
|
||||
|
||||
PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}};
|
||||
|
||||
int64_t axis = 2;
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||
EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, ov::no_label, 13, 14, 15));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_elements_interval_dims_with_labels_indices) {
|
||||
PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}};
|
||||
PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}};
|
||||
set_shape_labels(indices_shape, 20);
|
||||
|
||||
int64_t axis = 2;
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||
EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22, 23, 24, 25));
|
||||
}
|
||||
// --------------------- Negative tests ------------------------------
|
||||
|
||||
@ -148,11 +222,11 @@ TEST(type_prop, gather_elements_type_inconsistency) {
|
||||
Shape data_shape{3, 3};
|
||||
Shape indices_shape{2, 1};
|
||||
int64_t axis = 1;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::u32, indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::u32, indices_shape);
|
||||
|
||||
try {
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "the indices tensor type check failed";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
@ -166,11 +240,11 @@ TEST(type_prop, gather_elements_out_of_bounds_axis) {
|
||||
Shape data_shape{3, 3};
|
||||
Shape indices_shape{2, 1};
|
||||
int64_t axis = -33;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
|
||||
try {
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "axis out of bounds check failed";
|
||||
} catch (const ov::AssertFailure& error) {
|
||||
@ -184,11 +258,11 @@ TEST(type_prop, gather_elements_rank_consistency_check) {
|
||||
Shape data_shape{3, 3};
|
||||
Shape indices_shape{2, 3, 3333};
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
|
||||
try {
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "rank consistency check failed";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
@ -202,16 +276,15 @@ TEST(type_prop, gather_elements_shape_inconsistency) {
|
||||
Shape data_shape{3, 3};
|
||||
Shape indices_shape{2, 1};
|
||||
int64_t axis = 1;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||
|
||||
try {
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Shape inconsistency check failed";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("data and indices must have equal or intersecting sizes, except for axis"));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent"));
|
||||
} catch (...) {
|
||||
FAIL() << "Shape inconsistency check failed for unexpected reason";
|
||||
}
|
||||
@ -221,16 +294,15 @@ TEST(type_prop, gather_elements_dynamic_inconsistent_shapes) {
|
||||
PartialShape data_shape{4, 2, 4, Dimension::dynamic()};
|
||||
PartialShape indices_shape{1, 3, Dimension::dynamic(), 5};
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||
|
||||
try {
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("data and indices must have equal or intersecting sizes, except for axis"));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent"));
|
||||
} catch (...) {
|
||||
FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason";
|
||||
}
|
||||
@ -240,15 +312,15 @@ TEST(type_prop, gather_elements_incosistent_interval_shapes) {
|
||||
PartialShape data_shape{4, 4, 5};
|
||||
PartialShape indices_shape{1, Dimension(5, 10), 5};
|
||||
int64_t axis = 0;
|
||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
||||
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||
|
||||
try {
|
||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("data and indices must have equal or intersecting sizes, except for axis"));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent"));
|
||||
} catch (...) {
|
||||
FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason";
|
||||
}
|
||||
|
@ -1,25 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <gather_elements_shape_inference.hpp>
|
||||
#include <openvino/op/ops.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <utils/shape_inference/shape_inference.hpp>
|
||||
#include <utils/shape_inference/static_shape.hpp>
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
TEST(StaticShapeInferenceTest, GatherElementsTest) {
|
||||
int64_t axis = -1;
|
||||
auto D = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
|
||||
auto GE = std::make_shared<op::v6::GatherElements>(D, I, axis);
|
||||
// Test StaticShape
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{300, 3, 10, 1}, StaticShape{300, 3, 10, 33333}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(GE.get(), static_input_shapes, static_output_shapes);
|
||||
ASSERT_EQ(static_output_shapes[0], (StaticShape{300, 3, 10, 33333}));
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "gather_elements_shape_inference.hpp"
|
||||
#include "openvino/op/ops.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
using namespace testing;
|
||||
|
||||
class GatherElementsStaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v6::GatherElements> {};
|
||||
|
||||
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_basic) {
|
||||
int64_t axis = -1;
|
||||
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
const auto indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
|
||||
|
||||
op = make_op(data, indices, axis);
|
||||
input_shapes = {StaticShape{300, 3, 10, 2}, StaticShape{300, 3, 10, 33333}};
|
||||
output_shapes = {StaticShape{}};
|
||||
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], (StaticShape{300, 3, 10, 33333}));
|
||||
}
|
||||
|
||||
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_incompatible_rank) {
|
||||
int64_t axis = -1;
|
||||
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
|
||||
|
||||
op = make_op(data, indices, axis);
|
||||
input_shapes = {StaticShape{1, 2, 3, 4, 5}, StaticShape{1, 2, 3, 4}};
|
||||
output_shapes = {StaticShape{}};
|
||||
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("rank must be equal"));
|
||||
}
|
||||
|
||||
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_incompatible_dims) {
|
||||
int64_t axis = -1;
|
||||
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
const auto indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
|
||||
|
||||
op = make_op(data, indices, axis);
|
||||
input_shapes = {StaticShape{300, 4, 10, 2}, StaticShape{300, 5, 10, 33333}};
|
||||
output_shapes = {StaticShape{}};
|
||||
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("are not consistent"));
|
||||
}
|
||||
|
||||
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_default_constructor) {
|
||||
int64_t axis = -1;
|
||||
op = make_op();
|
||||
op->set_axis(axis);
|
||||
input_shapes = {StaticShape{300, 3, 10, 2}, StaticShape{300, 3, 10, 33333}};
|
||||
output_shapes = {StaticShape{}};
|
||||
|
||||
shape_infer(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], (StaticShape{300, 3, 10, 33333}));
|
||||
}
|
Loading…
Reference in New Issue
Block a user