From 7784b97bbc2eaf31a62e31379906be02f50f46b8 Mon Sep 17 00:00:00 2001 From: Mateusz Bencer Date: Thu, 1 Oct 2020 19:29:19 +0200 Subject: [PATCH] Relax shape inference for Split:v1 (#2444) * Relaxed shape inference for Split * added unit tests * review remarks --- ngraph/core/src/op/split.cpp | 58 +++++++------- ngraph/test/type_prop/split.cpp | 132 ++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 26 deletions(-) diff --git a/ngraph/core/src/op/split.cpp b/ngraph/core/src/op/split.cpp index 2045e2a5907..0e95fdd6dba 100644 --- a/ngraph/core/src/op/split.cpp +++ b/ngraph/core/src/op/split.cpp @@ -141,15 +141,19 @@ void op::v1::Split::validate_and_infer_types() const auto axis_ps = input_value(1).get_partial_shape(); const auto axis_et = input_value(1).get_element_type(); - NODE_VALIDATION_CHECK(this, - axis_ps.rank().is_static() && axis_ps.rank().get_length() == 0, - "The 'axis' input is expected to be a scalar. Got: ", - axis_ps); + if (axis_ps.rank().is_static()) + { + NODE_VALIDATION_CHECK(this, + axis_ps.rank().get_length() == 0, + "The 'axis' input is expected to be a scalar. Got: ", + axis_ps); + } NODE_VALIDATION_CHECK( this, axis_et.is_integral(), "The 'axis' input only accepts integral types"); - if (op::is_constant(input_value(1).get_node()) && data_ps.is_static()) + PartialShape each_output_shape{data_ps}; + if (op::is_constant(input_value(1).get_node()) && data_ps.rank().is_static()) { const auto axis_input = as_type_ptr(input_value(1).get_node_shared_ptr()); auto axis = axis_input->cast_vector()[0]; @@ -157,33 +161,35 @@ void op::v1::Split::validate_and_infer_types() const auto data_rank = get_input_partial_shape(0).rank(); axis = ngraph::normalize_axis(this, axis, data_rank); - const auto data_shape = data_ps.to_shape(); - const auto dimension_at_axis = data_shape.at(axis); - - NODE_VALIDATION_CHECK(this, - dimension_at_axis % m_num_splits == 0, - "The input tensor's dimension pointed by the 'axis' parameter: ", - dimension_at_axis, - " has to be a multiple of the 'num_splits' attribute value: ", - m_num_splits); - - Shape each_output_shape{data_shape}; - each_output_shape.at(axis) = dimension_at_axis / m_num_splits; - - for (size_t i = 0; i < m_num_splits; ++i) + if (data_ps[axis].is_static()) { - set_output_type(i, get_input_element_type(0), each_output_shape); + const auto dimension_at_axis = data_ps[axis].get_length(); + + NODE_VALIDATION_CHECK(this, + dimension_at_axis % m_num_splits == 0, + "The input tensor's dimension pointed by the 'axis' parameter: ", + dimension_at_axis, + " has to be a multiple of the 'num_splits' attribute value: ", + m_num_splits); + + each_output_shape[axis] = dimension_at_axis / m_num_splits; + } + else + { + each_output_shape[axis] = Dimension::dynamic(); } } else { - for (size_t i = 0; i < m_num_splits; ++i) - { - set_output_type(i, get_input_element_type(0), PartialShape::dynamic()); - } - - set_input_is_relevant_to_shape(0); + each_output_shape = PartialShape::dynamic(data_ps.rank()); } + + for (size_t i = 0; i < m_num_splits; ++i) + { + set_output_type(i, get_input_element_type(0), each_output_shape); + } + + set_input_is_relevant_to_shape(0); } shared_ptr op::v1::Split::clone_with_new_inputs(const OutputVector& new_args) const diff --git a/ngraph/test/type_prop/split.cpp b/ngraph/test/type_prop/split.cpp index 5c0e4a66a45..70d431b66ab 100644 --- a/ngraph/test/type_prop/split.cpp +++ b/ngraph/test/type_prop/split.cpp @@ -102,3 +102,135 @@ TEST(type_prop, split_axis_must_be_constant) FAIL() << "Deduced type check failed for unexpected reason."; } } + +TEST(type_prop, split_v1) +{ + const auto data = make_shared(element::f16, Shape{2, 3, 4}); + const auto axis = op::Constant::create(element::i64, {}, {1}); + const size_t num_splits = 3; + const auto split = make_shared(data, axis, num_splits); + + EXPECT_EQ(split->outputs().size(), num_splits); + for (int i = 0; i < num_splits; ++i) + { + EXPECT_EQ(split->get_output_element_type(i), element::f16); + EXPECT_EQ(split->get_output_shape(i), (Shape{2, 1, 4})); + } +} + +TEST(type_prop, split_v1_axis_const_data_axis_dim_known) +{ + const auto data = + make_shared(element::f32, PartialShape{2, 3, Dimension::dynamic()}); + const auto axis = op::Constant::create(element::i32, {}, {1}); + const size_t num_splits = 3; + const auto split = make_shared(data, axis, num_splits); + + EXPECT_EQ(split->outputs().size(), num_splits); + for (int i = 0; i < num_splits; ++i) + { + EXPECT_EQ(split->get_output_partial_shape(i), (PartialShape{2, 1, Dimension::dynamic()})); + } +} + +TEST(type_prop, split_v1_axis_const_only_data_axis_dim_known) +{ + const auto data = make_shared( + element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()}); + const auto axis = op::Constant::create(element::i16, {}, {0}); + const size_t num_splits = 2; + const auto split = make_shared(data, axis, num_splits); + + EXPECT_EQ(split->outputs().size(), num_splits); + for (int i = 0; i < num_splits; ++i) + { + EXPECT_EQ(split->get_output_partial_shape(i), + (PartialShape{1, Dimension::dynamic(), Dimension::dynamic()})); + } +} + +TEST(type_prop, split_v1_axis_const_data_axis_dim_unknown) +{ + const auto data = + make_shared(element::f32, PartialShape{4, Dimension::dynamic(), 3, 5}); + const auto axis = op::Constant::create(element::i8, {}, {1}); + const size_t num_splits = 3; + const auto split = make_shared(data, axis, num_splits); + + EXPECT_EQ(split->outputs().size(), num_splits); + for (int i = 0; i < num_splits; ++i) + { + EXPECT_EQ(split->get_output_partial_shape(i), + (PartialShape{4, Dimension::dynamic(), 3, 5})); + } +} + +TEST(type_prop, split_v1_axis_const_only_data_rank_known) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic(4)); + const auto axis = op::Constant::create(element::u64, {}, {1}); + const size_t num_splits = 3; + const auto split = make_shared(data, axis, num_splits); + + EXPECT_EQ(split->outputs().size(), num_splits); + for (int i = 0; i < num_splits; ++i) + { + EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic(4)); + } +} + +TEST(type_prop, split_v1_axis_not_const_only_data_rank_known) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic(4)); + const auto axis = make_shared(element::u32, PartialShape{}); + const size_t num_splits = 3; + const auto split = make_shared(data, axis, num_splits); + + EXPECT_EQ(split->outputs().size(), num_splits); + for (int i = 0; i < num_splits; ++i) + { + EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic(4)); + } +} + +TEST(type_prop, split_v1_axis_const_data_rank_unknown) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic()); + const auto axis = op::Constant::create(element::u16, {}, {2}); + const size_t num_splits = 3; + const auto split = make_shared(data, axis, num_splits); + + EXPECT_EQ(split->outputs().size(), num_splits); + for (int i = 0; i < num_splits; ++i) + { + EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic()); + } +} + +TEST(type_prop, split_v1_axis_not_const_data_rank_unknown) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic()); + const auto axis = make_shared(element::u8, PartialShape{}); + const size_t num_splits = 3; + const auto split = make_shared(data, axis, num_splits); + + EXPECT_EQ(split->outputs().size(), num_splits); + for (int i = 0; i < num_splits; ++i) + { + EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic()); + } +} + +TEST(type_prop, split_v1_axis_dynamic_rank) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic()); + const auto axis = make_shared(element::u8, PartialShape::dynamic()); + const size_t num_splits = 3; + const auto split = make_shared(data, axis, num_splits); + + EXPECT_EQ(split->outputs().size(), num_splits); + for (int i = 0; i < num_splits; ++i) + { + EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic()); + } +}