Relax shape inference for Split:v1 (#2444)
* Relaxed shape inference for Split * added unit tests * review remarks
This commit is contained in:
parent
e3270b6b34
commit
7784b97bbc
@ -141,15 +141,19 @@ void op::v1::Split::validate_and_infer_types()
|
||||
const auto axis_ps = input_value(1).get_partial_shape();
|
||||
const auto axis_et = input_value(1).get_element_type();
|
||||
|
||||
if (axis_ps.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis_ps.rank().is_static() && axis_ps.rank().get_length() == 0,
|
||||
axis_ps.rank().get_length() == 0,
|
||||
"The 'axis' input is expected to be a scalar. Got: ",
|
||||
axis_ps);
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, axis_et.is_integral(), "The 'axis' input only accepts integral types");
|
||||
|
||||
if (op::is_constant(input_value(1).get_node()) && data_ps.is_static())
|
||||
PartialShape each_output_shape{data_ps};
|
||||
if (op::is_constant(input_value(1).get_node()) && data_ps.rank().is_static())
|
||||
{
|
||||
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
|
||||
auto axis = axis_input->cast_vector<int64_t>()[0];
|
||||
@ -157,8 +161,9 @@ void op::v1::Split::validate_and_infer_types()
|
||||
const auto data_rank = get_input_partial_shape(0).rank();
|
||||
axis = ngraph::normalize_axis(this, axis, data_rank);
|
||||
|
||||
const auto data_shape = data_ps.to_shape();
|
||||
const auto dimension_at_axis = data_shape.at(axis);
|
||||
if (data_ps[axis].is_static())
|
||||
{
|
||||
const auto dimension_at_axis = data_ps[axis].get_length();
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
dimension_at_axis % m_num_splits == 0,
|
||||
@ -167,24 +172,25 @@ void op::v1::Split::validate_and_infer_types()
|
||||
" has to be a multiple of the 'num_splits' attribute value: ",
|
||||
m_num_splits);
|
||||
|
||||
Shape each_output_shape{data_shape};
|
||||
each_output_shape.at(axis) = dimension_at_axis / m_num_splits;
|
||||
each_output_shape[axis] = dimension_at_axis / m_num_splits;
|
||||
}
|
||||
else
|
||||
{
|
||||
each_output_shape[axis] = Dimension::dynamic();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
each_output_shape = PartialShape::dynamic(data_ps.rank());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < m_num_splits; ++i)
|
||||
{
|
||||
set_output_type(i, get_input_element_type(0), each_output_shape);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (size_t i = 0; i < m_num_splits; ++i)
|
||||
{
|
||||
set_output_type(i, get_input_element_type(0), PartialShape::dynamic());
|
||||
}
|
||||
|
||||
set_input_is_relevant_to_shape(0);
|
||||
}
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::Split::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
|
@ -102,3 +102,135 @@ TEST(type_prop, split_axis_must_be_constant)
|
||||
FAIL() << "Deduced type check failed for unexpected reason.";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, split_v1)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f16, Shape{2, 3, 4});
|
||||
const auto axis = op::Constant::create(element::i64, {}, {1});
|
||||
const size_t num_splits = 3;
|
||||
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
|
||||
|
||||
EXPECT_EQ(split->outputs().size(), num_splits);
|
||||
for (int i = 0; i < num_splits; ++i)
|
||||
{
|
||||
EXPECT_EQ(split->get_output_element_type(i), element::f16);
|
||||
EXPECT_EQ(split->get_output_shape(i), (Shape{2, 1, 4}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, split_v1_axis_const_data_axis_dim_known)
|
||||
{
|
||||
const auto data =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
|
||||
const auto axis = op::Constant::create(element::i32, {}, {1});
|
||||
const size_t num_splits = 3;
|
||||
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
|
||||
|
||||
EXPECT_EQ(split->outputs().size(), num_splits);
|
||||
for (int i = 0; i < num_splits; ++i)
|
||||
{
|
||||
EXPECT_EQ(split->get_output_partial_shape(i), (PartialShape{2, 1, Dimension::dynamic()}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, split_v1_axis_const_only_data_axis_dim_known)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(
|
||||
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
|
||||
const auto axis = op::Constant::create(element::i16, {}, {0});
|
||||
const size_t num_splits = 2;
|
||||
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
|
||||
|
||||
EXPECT_EQ(split->outputs().size(), num_splits);
|
||||
for (int i = 0; i < num_splits; ++i)
|
||||
{
|
||||
EXPECT_EQ(split->get_output_partial_shape(i),
|
||||
(PartialShape{1, Dimension::dynamic(), Dimension::dynamic()}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, split_v1_axis_const_data_axis_dim_unknown)
|
||||
{
|
||||
const auto data =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{4, Dimension::dynamic(), 3, 5});
|
||||
const auto axis = op::Constant::create(element::i8, {}, {1});
|
||||
const size_t num_splits = 3;
|
||||
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
|
||||
|
||||
EXPECT_EQ(split->outputs().size(), num_splits);
|
||||
for (int i = 0; i < num_splits; ++i)
|
||||
{
|
||||
EXPECT_EQ(split->get_output_partial_shape(i),
|
||||
(PartialShape{4, Dimension::dynamic(), 3, 5}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, split_v1_axis_const_only_data_rank_known)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
|
||||
const auto axis = op::Constant::create(element::u64, {}, {1});
|
||||
const size_t num_splits = 3;
|
||||
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
|
||||
|
||||
EXPECT_EQ(split->outputs().size(), num_splits);
|
||||
for (int i = 0; i < num_splits; ++i)
|
||||
{
|
||||
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic(4));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, split_v1_axis_not_const_only_data_rank_known)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
|
||||
const auto axis = make_shared<op::Parameter>(element::u32, PartialShape{});
|
||||
const size_t num_splits = 3;
|
||||
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
|
||||
|
||||
EXPECT_EQ(split->outputs().size(), num_splits);
|
||||
for (int i = 0; i < num_splits; ++i)
|
||||
{
|
||||
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic(4));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, split_v1_axis_const_data_rank_unknown)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto axis = op::Constant::create(element::u16, {}, {2});
|
||||
const size_t num_splits = 3;
|
||||
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
|
||||
|
||||
EXPECT_EQ(split->outputs().size(), num_splits);
|
||||
for (int i = 0; i < num_splits; ++i)
|
||||
{
|
||||
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, split_v1_axis_not_const_data_rank_unknown)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto axis = make_shared<op::Parameter>(element::u8, PartialShape{});
|
||||
const size_t num_splits = 3;
|
||||
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
|
||||
|
||||
EXPECT_EQ(split->outputs().size(), num_splits);
|
||||
for (int i = 0; i < num_splits; ++i)
|
||||
{
|
||||
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, split_v1_axis_dynamic_rank)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto axis = make_shared<op::Parameter>(element::u8, PartialShape::dynamic());
|
||||
const size_t num_splits = 3;
|
||||
const auto split = make_shared<op::v1::Split>(data, axis, num_splits);
|
||||
|
||||
EXPECT_EQ(split->outputs().size(), num_splits);
|
||||
for (int i = 0; i < num_splits; ++i)
|
||||
{
|
||||
EXPECT_EQ(split->get_output_partial_shape(i), PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user