[ShapeInference] Improvements of GatherElements shape inference (#15264)
* GatherElements shape infer improvements * Add new tests * Update test error message
This commit is contained in:
parent
9ee80d67b7
commit
c72c6ba331
@ -37,10 +37,6 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
int64_t m_axis{0};
|
int64_t m_axis{0};
|
||||||
template <class T>
|
|
||||||
void friend shape_infer(const GatherElements* op,
|
|
||||||
const std::vector<T>& input_shapes,
|
|
||||||
std::vector<T>& output_shapes);
|
|
||||||
};
|
};
|
||||||
} // namespace v6
|
} // namespace v6
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -20,14 +20,11 @@ void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, s
|
|||||||
auto indices_rank = indices_pshape.rank();
|
auto indices_rank = indices_pshape.rank();
|
||||||
auto& output_shape = output_shapes[0];
|
auto& output_shape = output_shapes[0];
|
||||||
|
|
||||||
int64_t axis = op->m_axis;
|
int64_t axis = op->get_axis();
|
||||||
if (data_rank.is_static())
|
if (data_rank.is_static())
|
||||||
axis = ov::normalize_axis(op, axis, data_rank);
|
axis = ov::normalize_axis(op, axis, data_rank);
|
||||||
|
|
||||||
output_shape = indices_pshape;
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(op, data_rank.is_dynamic() || data_rank.get_length() >= 1, "data rank must be >= 1.");
|
NODE_VALIDATION_CHECK(op, data_rank.is_dynamic() || data_rank.get_length() >= 1, "data rank must be >= 1.");
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(op,
|
NODE_VALIDATION_CHECK(op,
|
||||||
indices_rank.is_dynamic() || indices_rank.get_length() >= 1,
|
indices_rank.is_dynamic() || indices_rank.get_length() >= 1,
|
||||||
"indices rank must be >= 1.");
|
"indices rank must be >= 1.");
|
||||||
@ -40,9 +37,11 @@ void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (data_rank.is_dynamic()) {
|
if (data_rank.is_dynamic()) {
|
||||||
// can't decide rank, set it to all dynamic
|
if (indices_rank.is_dynamic()) {
|
||||||
if (indices_rank.is_dynamic())
|
|
||||||
output_shape = PartialShape::dynamic();
|
output_shape = PartialShape::dynamic();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
output_shape = indices_pshape;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,25 +53,20 @@ void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, s
|
|||||||
" and ",
|
" and ",
|
||||||
indices_rank.get_length());
|
indices_rank.get_length());
|
||||||
|
|
||||||
for (int i = 0; i < indices_rank.get_length(); i++) {
|
// if size of the current dimension of indices is unknown it will be retrieved from data
|
||||||
if (i != axis) {
|
// e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0
|
||||||
// if size of the current dimension of indices is unknown it will be retrieved from data
|
// (and if intervals intersect) then output_pshape will be {1, 4, 5}
|
||||||
// e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0
|
output_shape = data_pshape;
|
||||||
// (and if intervals intersect) then output_pshape will be {1, 4, 5}
|
output_shape[axis] = indices_pshape[axis];
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
NODE_VALIDATION_CHECK(op,
|
output_shape.merge_into(output_shape, indices_pshape),
|
||||||
data_pshape[i].compatible(indices_pshape[i]),
|
"Shapes ",
|
||||||
"Shapes ",
|
data_pshape,
|
||||||
data_pshape,
|
" and ",
|
||||||
" and ",
|
indices_pshape,
|
||||||
indices_pshape,
|
" are not consistent, `data` and `indices` must have equal or "
|
||||||
" are not consistent. data and indices must have equal or "
|
"intersecting dimensions, except for the dimension at axis index.",
|
||||||
"intersecting sizes, except for axis ",
|
axis);
|
||||||
axis);
|
|
||||||
|
|
||||||
output_shape[i] = data_pshape[i] & indices_pshape[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} // namespace v6
|
} // namespace v6
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -30,9 +30,8 @@ void op::v6::GatherElements::validate_and_infer_types() {
|
|||||||
"indices must be of int32 or int64 type. But instead got: ",
|
"indices must be of int32 or int64 type. But instead got: ",
|
||||||
indices_type);
|
indices_type);
|
||||||
|
|
||||||
const auto& data_pshape = get_input_partial_shape(0);
|
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||||
const auto& indices_pshape = get_input_partial_shape(1);
|
auto output_shapes = std::vector<ov::PartialShape>(1);
|
||||||
std::vector<PartialShape> input_shapes = {data_pshape, indices_pshape}, output_shapes = {PartialShape{}};
|
|
||||||
shape_infer(this, input_shapes, output_shapes);
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
set_output_type(0, data_type, output_shapes[0]);
|
set_output_type(0, data_type, output_shapes[0]);
|
||||||
}
|
}
|
||||||
|
@ -2,145 +2,219 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "dimension_tracker.hpp"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "ngraph/ngraph.hpp"
|
#include "openvino/op/ops.hpp"
|
||||||
#include "util/type_prop.hpp"
|
#include "util/type_prop.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ov;
|
||||||
|
using namespace op;
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
// ------------------------------ V6 ------------------------------
|
// ------------------------------ V6 ------------------------------
|
||||||
|
|
||||||
|
TEST(type_prop, gather_elements_default_constructor) {
|
||||||
|
PartialShape data_shape{1, 2, 3, 4};
|
||||||
|
PartialShape indices_shape{1, 2, 10, 4};
|
||||||
|
int64_t axis = -2;
|
||||||
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
|
auto op = make_shared<v6::GatherElements>();
|
||||||
|
|
||||||
|
op->set_axis(axis);
|
||||||
|
EXPECT_EQ(op->get_axis(), axis);
|
||||||
|
|
||||||
|
op->set_argument(0, data);
|
||||||
|
op->set_argument(1, indices);
|
||||||
|
|
||||||
|
op->validate_and_infer_types();
|
||||||
|
|
||||||
|
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(0), indices_shape);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_2D_axis_0) {
|
TEST(type_prop, gather_elements_2D_axis_0) {
|
||||||
Shape data_shape{3, 3};
|
Shape data_shape{3, 3};
|
||||||
Shape indices_shape{2, 3};
|
Shape indices_shape{2, 3};
|
||||||
int axis = 0;
|
int axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_2D_axis_1) {
|
TEST(type_prop, gather_elements_2D_axis_1) {
|
||||||
Shape data_shape{3, 3};
|
Shape data_shape{3, 3};
|
||||||
Shape indices_shape{3, 1};
|
Shape indices_shape{3, 1};
|
||||||
int axis = 1;
|
int axis = 1;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_3D_axis_0) {
|
TEST(type_prop, gather_elements_3D_axis_0) {
|
||||||
Shape data_shape{3, 3, 10000};
|
Shape data_shape{3, 3, 10000};
|
||||||
Shape indices_shape{300, 3, 10000};
|
Shape indices_shape{300, 3, 10000};
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_3D_axis_2) {
|
TEST(type_prop, gather_elements_3D_axis_2) {
|
||||||
Shape data_shape{300, 3, 10};
|
Shape data_shape{300, 3, 10};
|
||||||
Shape indices_shape{300, 3, 10000};
|
Shape indices_shape{300, 3, 10000};
|
||||||
int64_t axis = 2;
|
int64_t axis = 2;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_4D_axis_minus_1) {
|
TEST(type_prop, gather_elements_4D_axis_minus_1) {
|
||||||
Shape data_shape{300, 3, 10, 1};
|
Shape data_shape{300, 3, 10, 1};
|
||||||
Shape indices_shape{300, 3, 10, 33333};
|
Shape indices_shape{300, 3, 10, 33333};
|
||||||
int64_t axis = -1;
|
int64_t axis = -1;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_nonfloat_data_type_int64_indices) {
|
TEST(type_prop, gather_elements_nonfloat_data_type_int64_indices) {
|
||||||
Shape data_shape{300, 3, 10, 1};
|
Shape data_shape{300, 3, 10, 1};
|
||||||
Shape indices_shape{300, 3, 10, 33333};
|
Shape indices_shape{300, 3, 10, 33333};
|
||||||
int64_t axis = -1;
|
int64_t axis = -1;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||||
ASSERT_EQ(GE->get_shape(), indices_shape);
|
EXPECT_EQ(op->get_shape(), indices_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_dynamic_consistent_shapes) {
|
TEST(type_prop, gather_elements_dynamic_consistent_shapes) {
|
||||||
PartialShape data_shape{4, 4, Dimension::dynamic()};
|
PartialShape data_shape{4, 4, Dimension::dynamic()};
|
||||||
PartialShape indices_shape{1, Dimension::dynamic(), 5};
|
PartialShape indices_shape{1, Dimension::dynamic(), 5};
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||||
ASSERT_EQ(GE->get_shape(), Shape({1, 4, 5}));
|
EXPECT_EQ(op->get_shape(), Shape({1, 4, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_dynamic_out_shape) {
|
TEST(type_prop, gather_elements_dynamic_out_shape) {
|
||||||
PartialShape data_shape{4, 4, Dimension::dynamic()};
|
PartialShape data_shape{4, 4, Dimension::dynamic()};
|
||||||
PartialShape indices_shape{1, Dimension::dynamic(), Dimension::dynamic()};
|
PartialShape indices_shape{1, Dimension::dynamic(), Dimension::dynamic()};
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({1, 4, Dimension::dynamic()}));
|
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({1, 4, Dimension::dynamic()}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_interval_shapes) {
|
TEST(type_prop, gather_elements_interval_shapes) {
|
||||||
PartialShape data_shape{4, Dimension(1, 7), 5};
|
PartialShape data_shape{4, Dimension(1, 7), 5};
|
||||||
PartialShape indices_shape{1, Dimension(5, 10), 5};
|
PartialShape indices_shape{1, Dimension(5, 10), 5};
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({1, Dimension(5, 7), 5}));
|
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({1, {5, 7}, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_data_rank_dynamic_indices_rank_static) {
|
TEST(type_prop, gather_elements_data_rank_dynamic_indices_rank_static) {
|
||||||
PartialShape data_shape = PartialShape::dynamic();
|
PartialShape data_shape = PartialShape::dynamic();
|
||||||
PartialShape indices_shape{4, 7, 5};
|
PartialShape indices_shape{4, 7, 5};
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({4, 7, 5}));
|
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({4, 7, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_data_rank_static_indices_rank_dynamic) {
|
TEST(type_prop, gather_elements_data_rank_static_indices_rank_dynamic) {
|
||||||
PartialShape data_shape{4, Dimension(1, 7), 5};
|
PartialShape data_shape{4, Dimension(1, 7), 5};
|
||||||
PartialShape indices_shape = PartialShape::dynamic();
|
PartialShape indices_shape = PartialShape::dynamic();
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), Dimension(1, 7), 5}));
|
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), {1, 7}, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, gather_elements_data_pshape_static_indices_rank_dynamic) {
|
TEST(type_prop, gather_elements_data_pshape_static_indices_rank_dynamic) {
|
||||||
PartialShape data_shape{4, 7, 5};
|
PartialShape data_shape{4, 7, 5};
|
||||||
PartialShape indices_shape = PartialShape::dynamic();
|
PartialShape indices_shape = PartialShape::dynamic();
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
|
EXPECT_EQ(op->get_element_type(), element::Type_t::i8);
|
||||||
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({Dimension::dynamic(), 7, 5}));
|
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({-1, 7, 5}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, gather_elements_interval_dims_with_labels_both_inputs) {
|
||||||
|
PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}};
|
||||||
|
set_shape_labels(data_shape, 10);
|
||||||
|
|
||||||
|
PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}};
|
||||||
|
set_shape_labels(indices_shape, 20);
|
||||||
|
|
||||||
|
int64_t axis = 2;
|
||||||
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
|
|
||||||
|
const auto& out_shape = op->get_output_partial_shape(0);
|
||||||
|
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||||
|
EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4}));
|
||||||
|
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22, 23, 24, 25));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, gather_elements_interval_dims_with_labels_data) {
|
||||||
|
PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}};
|
||||||
|
set_shape_labels(data_shape, 10);
|
||||||
|
|
||||||
|
PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}};
|
||||||
|
|
||||||
|
int64_t axis = 2;
|
||||||
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
|
|
||||||
|
const auto& out_shape = op->get_output_partial_shape(0);
|
||||||
|
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||||
|
EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4}));
|
||||||
|
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, ov::no_label, 13, 14, 15));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, gather_elements_interval_dims_with_labels_indices) {
|
||||||
|
PartialShape data_shape{-1, {2, 4}, {1, 5}, -1, {4, 8}, {2, 4}};
|
||||||
|
PartialShape indices_shape{-1, {3, 6}, {6, 10}, {4, 8}, -1, {4, 6}};
|
||||||
|
set_shape_labels(indices_shape, 20);
|
||||||
|
|
||||||
|
int64_t axis = 2;
|
||||||
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
|
|
||||||
|
const auto& out_shape = op->get_output_partial_shape(0);
|
||||||
|
EXPECT_EQ(op->get_element_type(), element::Type_t::f32);
|
||||||
|
EXPECT_EQ(out_shape, PartialShape({-1, {3, 4}, {6, 10}, {4, 8}, {4, 8}, 4}));
|
||||||
|
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22, 23, 24, 25));
|
||||||
}
|
}
|
||||||
// --------------------- Negative tests ------------------------------
|
// --------------------- Negative tests ------------------------------
|
||||||
|
|
||||||
@ -148,11 +222,11 @@ TEST(type_prop, gather_elements_type_inconsistency) {
|
|||||||
Shape data_shape{3, 3};
|
Shape data_shape{3, 3};
|
||||||
Shape indices_shape{2, 1};
|
Shape indices_shape{2, 1};
|
||||||
int64_t axis = 1;
|
int64_t axis = 1;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::u32, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::u32, indices_shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
// Should have thrown, so fail if it didn't
|
// Should have thrown, so fail if it didn't
|
||||||
FAIL() << "the indices tensor type check failed";
|
FAIL() << "the indices tensor type check failed";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
@ -166,11 +240,11 @@ TEST(type_prop, gather_elements_out_of_bounds_axis) {
|
|||||||
Shape data_shape{3, 3};
|
Shape data_shape{3, 3};
|
||||||
Shape indices_shape{2, 1};
|
Shape indices_shape{2, 1};
|
||||||
int64_t axis = -33;
|
int64_t axis = -33;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
// Should have thrown, so fail if it didn't
|
// Should have thrown, so fail if it didn't
|
||||||
FAIL() << "axis out of bounds check failed";
|
FAIL() << "axis out of bounds check failed";
|
||||||
} catch (const ov::AssertFailure& error) {
|
} catch (const ov::AssertFailure& error) {
|
||||||
@ -184,11 +258,11 @@ TEST(type_prop, gather_elements_rank_consistency_check) {
|
|||||||
Shape data_shape{3, 3};
|
Shape data_shape{3, 3};
|
||||||
Shape indices_shape{2, 3, 3333};
|
Shape indices_shape{2, 3, 3333};
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
// Should have thrown, so fail if it didn't
|
// Should have thrown, so fail if it didn't
|
||||||
FAIL() << "rank consistency check failed";
|
FAIL() << "rank consistency check failed";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
@ -202,16 +276,15 @@ TEST(type_prop, gather_elements_shape_inconsistency) {
|
|||||||
Shape data_shape{3, 3};
|
Shape data_shape{3, 3};
|
||||||
Shape indices_shape{2, 1};
|
Shape indices_shape{2, 1};
|
||||||
int64_t axis = 1;
|
int64_t axis = 1;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::f32, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i32, indices_shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
// Should have thrown, so fail if it didn't
|
// Should have thrown, so fail if it didn't
|
||||||
FAIL() << "Shape inconsistency check failed";
|
FAIL() << "Shape inconsistency check failed";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
EXPECT_HAS_SUBSTRING(error.what(),
|
EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent"));
|
||||||
std::string("data and indices must have equal or intersecting sizes, except for axis"));
|
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
FAIL() << "Shape inconsistency check failed for unexpected reason";
|
FAIL() << "Shape inconsistency check failed for unexpected reason";
|
||||||
}
|
}
|
||||||
@ -221,16 +294,15 @@ TEST(type_prop, gather_elements_dynamic_inconsistent_shapes) {
|
|||||||
PartialShape data_shape{4, 2, 4, Dimension::dynamic()};
|
PartialShape data_shape{4, 2, 4, Dimension::dynamic()};
|
||||||
PartialShape indices_shape{1, 3, Dimension::dynamic(), 5};
|
PartialShape indices_shape{1, 3, Dimension::dynamic(), 5};
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
// Should have thrown, so fail if it didn't
|
// Should have thrown, so fail if it didn't
|
||||||
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
|
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
EXPECT_HAS_SUBSTRING(error.what(),
|
EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent"));
|
||||||
std::string("data and indices must have equal or intersecting sizes, except for axis"));
|
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason";
|
FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason";
|
||||||
}
|
}
|
||||||
@ -240,15 +312,15 @@ TEST(type_prop, gather_elements_incosistent_interval_shapes) {
|
|||||||
PartialShape data_shape{4, 4, 5};
|
PartialShape data_shape{4, 4, 5};
|
||||||
PartialShape indices_shape{1, Dimension(5, 10), 5};
|
PartialShape indices_shape{1, Dimension(5, 10), 5};
|
||||||
int64_t axis = 0;
|
int64_t axis = 0;
|
||||||
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
|
const auto data = make_shared<v0::Parameter>(element::Type_t::i8, data_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
|
const auto indices = make_shared<v0::Parameter>(element::Type_t::i64, indices_shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
|
auto op = make_shared<v6::GatherElements>(data, indices, axis);
|
||||||
// Should have thrown, so fail if it didn't
|
// Should have thrown, so fail if it didn't
|
||||||
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
|
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
EXPECT_HAS_SUBSTRING(error.what(),
|
EXPECT_HAS_SUBSTRING(error.what(), std::string("are not consistent"));
|
||||||
std::string("data and indices must have equal or intersecting sizes, except for axis"));
|
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason";
|
FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason";
|
||||||
}
|
}
|
||||||
|
@ -1,25 +0,0 @@
|
|||||||
// Copyright (C) 2018-2023 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
|
|
||||||
#include <gather_elements_shape_inference.hpp>
|
|
||||||
#include <openvino/op/ops.hpp>
|
|
||||||
#include <openvino/op/parameter.hpp>
|
|
||||||
#include <utils/shape_inference/shape_inference.hpp>
|
|
||||||
#include <utils/shape_inference/static_shape.hpp>
|
|
||||||
|
|
||||||
using namespace ov;
|
|
||||||
using namespace ov::intel_cpu;
|
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, GatherElementsTest) {
|
|
||||||
int64_t axis = -1;
|
|
||||||
auto D = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
|
||||||
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
|
|
||||||
auto GE = std::make_shared<op::v6::GatherElements>(D, I, axis);
|
|
||||||
// Test StaticShape
|
|
||||||
std::vector<StaticShape> static_input_shapes = {StaticShape{300, 3, 10, 1}, StaticShape{300, 3, 10, 33333}},
|
|
||||||
static_output_shapes = {StaticShape{}};
|
|
||||||
shape_inference(GE.get(), static_input_shapes, static_output_shapes);
|
|
||||||
ASSERT_EQ(static_output_shapes[0], (StaticShape{300, 3, 10, 33333}));
|
|
||||||
}
|
|
@ -0,0 +1,65 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "common_test_utils/test_assertions.hpp"
|
||||||
|
#include "gather_elements_shape_inference.hpp"
|
||||||
|
#include "openvino/op/ops.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::intel_cpu;
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
class GatherElementsStaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v6::GatherElements> {};
|
||||||
|
|
||||||
|
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_basic) {
|
||||||
|
int64_t axis = -1;
|
||||||
|
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||||
|
const auto indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
|
||||||
|
|
||||||
|
op = make_op(data, indices, axis);
|
||||||
|
input_shapes = {StaticShape{300, 3, 10, 2}, StaticShape{300, 3, 10, 33333}};
|
||||||
|
output_shapes = {StaticShape{}};
|
||||||
|
|
||||||
|
shape_inference(op.get(), input_shapes, output_shapes);
|
||||||
|
EXPECT_EQ(output_shapes[0], (StaticShape{300, 3, 10, 33333}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_incompatible_rank) {
|
||||||
|
int64_t axis = -1;
|
||||||
|
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
const auto indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
|
||||||
|
|
||||||
|
op = make_op(data, indices, axis);
|
||||||
|
input_shapes = {StaticShape{1, 2, 3, 4, 5}, StaticShape{1, 2, 3, 4}};
|
||||||
|
output_shapes = {StaticShape{}};
|
||||||
|
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
|
||||||
|
ov::NodeValidationFailure,
|
||||||
|
HasSubstr("rank must be equal"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_incompatible_dims) {
|
||||||
|
int64_t axis = -1;
|
||||||
|
const auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||||
|
const auto indices = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
|
||||||
|
|
||||||
|
op = make_op(data, indices, axis);
|
||||||
|
input_shapes = {StaticShape{300, 4, 10, 2}, StaticShape{300, 5, 10, 33333}};
|
||||||
|
output_shapes = {StaticShape{}};
|
||||||
|
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
|
||||||
|
ov::NodeValidationFailure,
|
||||||
|
HasSubstr("are not consistent"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GatherElementsStaticShapeInferenceTest, GatherElements_default_constructor) {
|
||||||
|
int64_t axis = -1;
|
||||||
|
op = make_op();
|
||||||
|
op->set_axis(axis);
|
||||||
|
input_shapes = {StaticShape{300, 3, 10, 2}, StaticShape{300, 3, 10, 33333}};
|
||||||
|
output_shapes = {StaticShape{}};
|
||||||
|
|
||||||
|
shape_infer(op.get(), input_shapes, output_shapes);
|
||||||
|
EXPECT_EQ(output_shapes[0], (StaticShape{300, 3, 10, 33333}));
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user