Relax shape inference for Split:v1 (#2444)

* Relaxed shape inference for Split

* added unit tests

* review remarks
This commit is contained in:
Mateusz Bencer 2020-10-01 19:29:19 +02:00 committed by GitHub
parent e3270b6b34
commit 7784b97bbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 164 additions and 26 deletions

View File

@ -141,15 +141,19 @@ void op::v1::Split::validate_and_infer_types()
const auto axis_ps = input_value(1).get_partial_shape();
const auto axis_et = input_value(1).get_element_type();
if (axis_ps.rank().is_static())
{
NODE_VALIDATION_CHECK(this,
axis_ps.rank().is_static() && axis_ps.rank().get_length() == 0,
axis_ps.rank().get_length() == 0,
"The 'axis' input is expected to be a scalar. Got: ",
axis_ps);
}
NODE_VALIDATION_CHECK(
this, axis_et.is_integral(), "The 'axis' input only accepts integral types");
if (op::is_constant(input_value(1).get_node()) && data_ps.is_static())
PartialShape each_output_shape{data_ps};
if (op::is_constant(input_value(1).get_node()) && data_ps.rank().is_static())
{
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
auto axis = axis_input->cast_vector<int64_t>()[0];
@ -157,8 +161,9 @@ void op::v1::Split::validate_and_infer_types()
const auto data_rank = get_input_partial_shape(0).rank();
axis = ngraph::normalize_axis(this, axis, data_rank);
const auto data_shape = data_ps.to_shape();
const auto dimension_at_axis = data_shape.at(axis);
if (data_ps[axis].is_static())
{
const auto dimension_at_axis = data_ps[axis].get_length();
NODE_VALIDATION_CHECK(this,
dimension_at_axis % m_num_splits == 0,
@ -167,23 +172,24 @@ void op::v1::Split::validate_and_infer_types()
" has to be a multiple of the 'num_splits' attribute value: ",
m_num_splits);
Shape each_output_shape{data_shape};
each_output_shape.at(axis) = dimension_at_axis / m_num_splits;
each_output_shape[axis] = dimension_at_axis / m_num_splits;
}
else
{
each_output_shape[axis] = Dimension::dynamic();
}
}
else
{
each_output_shape = PartialShape::dynamic(data_ps.rank());
}
for (size_t i = 0; i < m_num_splits; ++i)
{
set_output_type(i, get_input_element_type(0), each_output_shape);
}
}
else
{
for (size_t i = 0; i < m_num_splits; ++i)
{
set_output_type(i, get_input_element_type(0), PartialShape::dynamic());
}
set_input_is_relevant_to_shape(0);
}
}
shared_ptr<Node> op::v1::Split::clone_with_new_inputs(const OutputVector& new_args) const

View File

@ -102,3 +102,135 @@ TEST(type_prop, split_axis_must_be_constant)
FAIL() << "Deduced type check failed for unexpected reason.";
}
}
TEST(type_prop, split_v1)
{
const auto data = make_shared<op::Parameter>(element::f16, Shape{2, 3, 4});
const auto axis = op::Constant::create(element::i64, {}, {1});
const size_t num_splits = 3;
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
EXPECT_EQ(split->outputs().size(), num_splits);
for (int i = 0; i < num_splits; ++i)
{
EXPECT_EQ(split->get_output_element_type(i), element::f16);
EXPECT_EQ(split->get_output_shape(i), (Shape{2, 1, 4}));
}
}
TEST(type_prop, split_v1_axis_const_data_axis_dim_known)
{
const auto data =
make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
const auto axis = op::Constant::create(element::i32, {}, {1});
const size_t num_splits = 3;
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
EXPECT_EQ(split->outputs().size(), num_splits);
for (int i = 0; i < num_splits; ++i)
{
EXPECT_EQ(split->get_output_partial_shape(i), (PartialShape{2, 1, Dimension::dynamic()}));
}
}
TEST(type_prop, split_v1_axis_const_only_data_axis_dim_known)
{
const auto data = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
const auto axis = op::Constant::create(element::i16, {}, {0});
const size_t num_splits = 2;
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
EXPECT_EQ(split->outputs().size(), num_splits);
for (int i = 0; i < num_splits; ++i)
{
EXPECT_EQ(split->get_output_partial_shape(i),
(PartialShape{1, Dimension::dynamic(), Dimension::dynamic()}));
}
}
TEST(type_prop, split_v1_axis_const_data_axis_dim_unknown)
{
const auto data =
make_shared<op::Parameter>(element::f32, PartialShape{4, Dimension::dynamic(), 3, 5});
const auto axis = op::Constant::create(element::i8, {}, {1});
const size_t num_splits = 3;
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
EXPECT_EQ(split->outputs().size(), num_splits);
for (int i = 0; i < num_splits; ++i)
{
EXPECT_EQ(split->get_output_partial_shape(i),
(PartialShape{4, Dimension::dynamic(), 3, 5}));
}
}
TEST(type_prop, split_v1_axis_const_only_data_rank_known)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
const auto axis = op::Constant::create(element::u64, {}, {1});
const size_t num_splits = 3;
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
EXPECT_EQ(split->outputs().size(), num_splits);
for (int i = 0; i < num_splits; ++i)
{
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic(4));
}
}
TEST(type_prop, split_v1_axis_not_const_only_data_rank_known)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
const auto axis = make_shared<op::Parameter>(element::u32, PartialShape{});
const size_t num_splits = 3;
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
EXPECT_EQ(split->outputs().size(), num_splits);
for (int i = 0; i < num_splits; ++i)
{
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic(4));
}
}
TEST(type_prop, split_v1_axis_const_data_rank_unknown)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto axis = op::Constant::create(element::u16, {}, {2});
const size_t num_splits = 3;
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
EXPECT_EQ(split->outputs().size(), num_splits);
for (int i = 0; i < num_splits; ++i)
{
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic());
}
}
TEST(type_prop, split_v1_axis_not_const_data_rank_unknown)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto axis = make_shared<op::Parameter>(element::u8, PartialShape{});
const size_t num_splits = 3;
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
EXPECT_EQ(split->outputs().size(), num_splits);
for (int i = 0; i < num_splits; ++i)
{
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic());
}
}
TEST(type_prop, split_v1_axis_dynamic_rank)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto axis = make_shared<op::Parameter>(element::u8, PartialShape::dynamic());
const size_t num_splits = 3;
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
EXPECT_EQ(split->outputs().size(), num_splits);
for (int i = 0; i < num_splits; ++i)
{
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic());
}
}