Revise of reshape operator after documentation update (new tests, remove unused code) (#3410)

* Remove unnecessary code from the reshape operator shape deduction

* Add new tests to cover cornercases after documentation update

* Fix typo in test name

* Fix codestyle issues

* Fix tests naming
This commit is contained in:
Bartosz Sledz
2020-12-01 04:03:20 +01:00
committed by GitHub
parent 0a52702e6a
commit cb03a5e052
3 changed files with 107 additions and 7 deletions

View File

@@ -205,13 +205,6 @@ void op::v1::Reshape::validate_and_infer_types()
}
}
if (out_shape_val == std::vector<std::int64_t>{0, -1} &&
input_pshape.rank().is_static() && input_pshape.rank().get_length() == 2)
{
partial_shape[0] = input_pshape[0];
partial_shape[1] = input_pshape[1];
}
set_output_type(0, get_input_element_type(0), PartialShape(partial_shape));
}
}

View File

@@ -727,6 +727,77 @@ TEST(eval, evaluate_reshape_v1_pattern_int16)
ASSERT_EQ(computed_val, expected_val);
}
TEST(eval, evaluate_reshape_v1_special_zero_shape_neg_zero)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{3, 1, 2});
auto pattern = make_shared<op::Parameter>(element::i64, Shape{2});
auto dyn_reshape = make_shared<op::v1::Reshape>(data, pattern, true);
auto func = make_shared<Function>(OutputVector{dyn_reshape}, ParameterVector{data, pattern});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(
func->evaluate({result_tensor},
{make_host_tensor<element::Type_t::f32>({3, 1, 2}, {0, 1, 2, 3, 4, 5}),
make_host_tensor<element::Type_t::i64>({2}, {-1, 0})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{6, 1}));
auto computed_val = read_vector<float>(result_tensor);
vector<float> expected_val{0, 1, 2, 3, 4, 5};
ASSERT_EQ(computed_val, expected_val);
}
TEST(eval, evaluate_reshape_v1_special_zero_shape_zero_neg)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{3, 1, 2});
auto pattern = make_shared<op::Parameter>(element::i64, Shape{2});
auto dyn_reshape = make_shared<op::v1::Reshape>(data, pattern, true);
auto func = make_shared<Function>(OutputVector{dyn_reshape}, ParameterVector{data, pattern});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(
func->evaluate({result_tensor},
{make_host_tensor<element::Type_t::f32>({3, 1, 2}, {0, 1, 2, 3, 4, 5}),
make_host_tensor<element::Type_t::i64>({2}, {0, -1})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 2}));
auto computed_val = read_vector<float>(result_tensor);
vector<float> expected_val{0, 1, 2, 3, 4, 5};
ASSERT_EQ(computed_val, expected_val);
}
TEST(eval, evaluate_reshape_v1_special_zero_shape_zero_neg_copy_input)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto pattern = make_shared<op::Parameter>(element::i64, Shape{2});
auto dyn_reshape = make_shared<op::v1::Reshape>(data, pattern, true);
auto func = make_shared<Function>(OutputVector{dyn_reshape}, ParameterVector{data, pattern});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(func->evaluate({result_tensor},
{make_host_tensor<element::Type_t::f32>({3, 1}, {0, 1, 2}),
make_host_tensor<element::Type_t::i64>({2}, {0, -1})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 1}));
auto computed_val = read_vector<float>(result_tensor);
vector<float> expected_val{0, 1, 2};
ASSERT_EQ(computed_val, expected_val);
}
TEST(eval, evaluate_reshape_v1_special_zero_shape_zero_zero_one_neg)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
auto pattern = make_shared<op::Parameter>(element::i64, Shape{4});
auto dyn_reshape = make_shared<op::v1::Reshape>(data, pattern, true);
auto func = make_shared<Function>(OutputVector{dyn_reshape}, ParameterVector{data, pattern});
auto result_tensor = make_shared<HostTensor>();
ASSERT_TRUE(func->evaluate(
{result_tensor},
{make_host_tensor<element::Type_t::f32>({2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}),
make_host_tensor<element::Type_t::i64>({4}, {0, 0, 1, -1})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{2, 2, 1, 3}));
auto computed_val = read_vector<float>(result_tensor);
vector<float> expected_val{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
ASSERT_EQ(computed_val, expected_val);
}
TEST(eval, evaluate_convert)
{
auto p = make_shared<op::Parameter>(element::f32, PartialShape{-1, -1});

View File

@@ -158,3 +158,39 @@ TEST(type_prop, reshape_partial_rank_static_dynamic_but_zero_ok)
ASSERT_TRUE(r->get_output_partial_shape(0).is_static());
ASSERT_EQ(r->get_shape(), (Shape{3, 1, 0, 2}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1, 2});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::u64, {2}, std::vector<int64_t>{-1, 0}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{6, 1}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1, 2});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::u64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{3, 2}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_copy_input)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::u64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{3, 1}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::u64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{2, 2, 1, 3}));
}