Reshape should support reshape to zero shapes (#5828)

* Reshape should support reshape to zero shapes

* Fixed comments

* Fixed backward compatible check

* Fixed myriad tests

* Removed header

* Fixed myriad tests

* Disabled Myriad tests

* Fix tests

* Fixed evaluate

* Fixed comments

* Fixed tests

* Fixed tests

* Fixed code style

* Fixed Myriad tests

* Added more tests

Ilya Churaev 2021-06-03 06:26:30 +03:00 committed by GitHub
parent ac1803c3ad
commit d56cf51c81
5 changed files with 262 additions and 157 deletions
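
In short, this change teaches op::v1::Reshape to accept a rank-0 (scalar) shape pattern as well as zero-sized target shapes, and moves the output-shape calculation into a private member function. A minimal sketch of the newly supported usage, modelled on the type_prop tests added below (the umbrella include and variable names are illustrative):

    #include <ngraph/ngraph.hpp>
    using namespace ngraph;

    // Rank-0 pattern: reshape a one-element tensor to a scalar. The new
    // validation requires the scalar pattern value to be exactly 1.
    auto param = std::make_shared<op::Parameter>(element::f32, Shape{1});
    auto to_scalar = std::make_shared<op::v1::Reshape>(
        param, op::Constant::create(element::i64, Shape{}, std::vector<int64_t>{1}), false);
    // to_scalar->get_output_shape(0) == Shape{}

    // Zero target shape: valid because the input also has zero elements.
    auto empty = std::make_shared<op::Parameter>(element::f32, Shape{0, 1});
    auto to_zero = std::make_shared<op::v1::Reshape>(
        empty, op::Constant::create(element::i64, Shape{1}, std::vector<int64_t>{0}), false);
    // to_zero->get_output_shape(0) == Shape{0}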


@@ -82,7 +82,7 @@ INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticTopKPropagationConcatBased,
 class DynamicToStaticTopKPropagationConcatReshape : public DynamicToStaticTopKPropagationConcatBased {
 protected:
     std::shared_ptr<ngraph::Node> buildSubgraph(std::shared_ptr<ngraph::Node> node) const override {
-        return std::make_shared<ngraph::opset5::Reshape>(node, ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}), false);
+        return std::make_shared<ngraph::opset5::Reshape>(node, ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1}), false);
     }
 };
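
This Myriad subgraph previously relied on the legacy rule from PR #1206, under which a 1-D pattern holding a single 0 (with special_zero == false) collapsed the tensor to a scalar. That rule is removed in reshape.cpp below, so a 0 in the pattern now genuinely requests a zero-sized dimension, and the test switches to the explicit scalar spelling:

    // before: legacy spelling, the 1-D pattern {0} meant "reshape to scalar"
    ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
    // after: a rank-0 pattern whose single value must be 1
    ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});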


@@ -64,6 +64,12 @@ namespace ngraph
                 bool m_special_zero;
                 bool evaluate_reshape(const HostTensorVector& outputs,
                                       const HostTensorVector& inputs) const;
+
+            private:
+                void calculate_output_shape(std::vector<Dimension>& reshape_pattern,
+                                            const int64_t& minus_one_idx,
+                                            const PartialShape& input_pshape,
+                                            std::vector<Dimension>& output_shape) const;
             };
         } // namespace v1
     } // namespace op
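
With this declaration, calculate_output_shape becomes a private member of op::v1::Reshape: both validate_and_infer_types and evaluate_reshape can now call it (and use get_special_zero() and NODE_VALIDATION_CHECK on this directly) instead of passing the node pointer into the free helper in the reshapeop namespace, which the next file deletes.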


@@ -35,158 +35,12 @@ namespace reshapeop
     {
         using T = typename element_type_traits<ET>::value_type;
         T* shape_pattern_ptr = shape_pattern->get_data_ptr<ET>();
-        size_t output_rank = shape_pattern->get_shape()[0];
+        size_t output_rank = shape_pattern->get_shape().empty() ? 0 : shape_pattern->get_shape()[0];
         for (size_t i = 0; i < output_rank; i++)
         {
             output_shape.push_back(shape_pattern_ptr[i]);
         }
     }
-
-    void calculate_output_shape(const ngraph::op::v1::Reshape* reshape_node,
-                                vector<Dimension>& reshape_pattern,
-                                const int64_t& minus_one_idx,
-                                const PartialShape& input_pshape,
-                                vector<Dimension>& output_shape)
-    {
-        if (reshape_pattern == std::vector<Dimension>{0} && !reshape_node->get_special_zero())
-        { // legacy check introduced by PR #1206
-            reshape_pattern = std::vector<Dimension>{};
-            output_shape = {};
-            return;
-        }
-        Dimension output_product(1);
-        for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i)
-        {
-            if (i == minus_one_idx) // resolving everything except -1
-                continue;
-            auto pattern_dim = reshape_pattern[i];
-            if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 &&
-                reshape_node->get_special_zero())
-            {
-                if (input_pshape.rank().is_dynamic())
-                {
-                    output_shape[i] = Dimension::dynamic();
-                    output_product *= Dimension::dynamic();
-                }
-                else
-                {
-                    NODE_VALIDATION_CHECK(reshape_node,
-                                          i < input_pshape.rank().get_length(),
-                                          "'0' dimension is out of range");
-                    output_shape[i] = input_pshape[i];
-                    // we do not include dimension to output product here and won't include in input
-                    // product later because we will divide output_product by input_product. This
-                    // dimension contributes to both products equally, but in case this dimension
-                    // is dynamic and others are not we could fully define output dimension that
-                    // is masked by -1
-                }
-            }
-            else
-            {
-                output_shape[i] = pattern_dim;
-                output_product *= pattern_dim;
-            }
-        }
-        Dimension input_product(1);
-        if (input_pshape.rank().is_static())
-            for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i)
-            {
-                if (i < static_cast<int64_t>(reshape_pattern.size()) &&
-                    reshape_pattern[i].get_min_length() == 0 &&
-                    reshape_pattern[i].get_max_length() == 0)
-                    continue;
-                input_product *= input_pshape[i];
-            }
-        else
-            input_product = Dimension::dynamic();
-        if (minus_one_idx != -1) // resolving -1 masked dimension
-        {
-            if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0)
-            {
-                // TODO: Decide if this is desired behavior here. (NumPy seems
-                // to fail.)
-                NODE_VALIDATION_CHECK(reshape_node,
-                                      input_product.get_min_length() == 0 &&
-                                          input_product.get_max_length() == 0,
-                                      "Cannot infer '-1' dimension with zero-size output "
-                                      "dimension unless at least one input dimension is "
-                                      "also zero-size");
-                output_shape[minus_one_idx] = Dimension(0);
-            }
-            else
-            {
-                if (input_product.is_static() && output_product.is_static())
-                {
-                    NODE_VALIDATION_CHECK(
-                        reshape_node,
-                        input_product.get_length() % output_product.get_length() == 0,
-                        "Non-'-1' output dimensions do not evenly divide the input dimensions");
-                }
-                if (output_product.get_min_length() == 0 || output_product == Dimension() ||
-                    input_product == Dimension())
-                {
-                    output_shape[minus_one_idx] = Dimension::dynamic();
-                }
-                else
-                {
-                    Dimension::value_type lower;
-                    if (input_product.get_min_length() == 0)
-                        lower = 0;
-                    else if (input_product.get_min_length() == -1 ||
-                             output_product.get_max_length() == 0 ||
-                             output_product.get_max_length() == -1)
-                        lower = -1; // dynamic
-                    else
-                        lower = static_cast<Dimension::value_type>(
-                            ceil(static_cast<double>(input_product.get_min_length()) /
-                                 output_product.get_max_length()));
-                    Dimension::value_type upper;
-                    if (input_product.get_max_length() == 0)
-                        upper = 0;
-                    else if (input_product.get_max_length() == -1 ||
-                             output_product.get_min_length() == 0 ||
-                             output_product.get_min_length() == -1)
-                        upper = -1; // dynamic
-                    else
-                        upper = static_cast<Dimension::value_type>(
-                            floor(static_cast<double>(input_product.get_max_length()) /
-                                  output_product.get_min_length()));
-                    if (lower == -1)
-                        output_shape[minus_one_idx] = Dimension::dynamic();
-                    else if (upper == -1)
-                        output_shape[minus_one_idx] = Dimension(lower, upper);
-                    else if (lower > upper) // empty intersection
-                        output_shape[minus_one_idx] = Dimension::dynamic();
-                    else
-                        output_shape[minus_one_idx] = Dimension(lower, upper);
-                }
-            }
-        }
-        PartialShape output_pshape(output_shape);
-        if (input_pshape.is_static() && output_pshape.is_static())
-        {
-            size_t zero_dims =
-                std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
-                    return dim.get_max_length() == 0 && dim.get_min_length() == 0;
-                });
-            bool backward_compatible_check =
-                (zero_dims && reshape_node->get_special_zero()) || minus_one_idx != -1;
-            bool in_out_elements_equal = shape_size(reshape_node->get_input_shape(0)) ==
-                                         shape_size(output_pshape.to_shape());
-            NODE_VALIDATION_CHECK(reshape_node,
-                                  backward_compatible_check || in_out_elements_equal,
-                                  "Requested output shape ",
-                                  output_shape,
-                                  " is incompatible with input shape ",
-                                  reshape_node->get_input_shape(0));
-        }
-    }
 } // namespace reshapeop

 NGRAPH_RTTI_DEFINITION(op::v1::Reshape, "Reshape", 1);
@@ -204,7 +58,6 @@ bool op::v1::Reshape::visit_attributes(AttributeVisitor& visitor)
     visitor.on_attribute("special_zero", m_special_zero);
     return true;
 }
-
 void op::v1::Reshape::validate_and_infer_types()
 {
     NGRAPH_OP_SCOPE(v1_Reshape_validate_and_infer_types);
@@ -217,16 +70,21 @@ void op::v1::Reshape::validate_and_infer_types()
     const PartialShape& input_pshape = get_input_partial_shape(0);
     const PartialShape& shape_pattern_shape = get_input_partial_shape(1);
     NODE_VALIDATION_CHECK(this,
-                          shape_pattern_shape.rank().compatible(1),
-                          "Pattern shape must have rank 1, got ",
+                          shape_pattern_shape.rank().compatible(1) ||
+                              (shape_pattern_shape.rank().is_static() &&
+                               shape_pattern_shape.rank().get_length() == 0),
+                          "Pattern shape must have rank 1 or be empty, got ",
                           shape_pattern_shape.rank(),
                           ".");
     Rank output_rank =
-        shape_pattern_shape.rank().is_dynamic() ? Rank::dynamic() : shape_pattern_shape[0];
+        shape_pattern_shape.rank().is_dynamic()
+            ? Rank::dynamic()
+            : shape_pattern_shape.rank().get_length() == 0 ? 0 : shape_pattern_shape[0];
     set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
     set_input_is_relevant_to_shape(1);

     std::vector<Dimension> reshape_pattern;
+    bool shape_can_be_calculated = false;
     int64_t minus_one_idx = -1;

     HostTensorPtr lb, ub;
@@ -235,6 +93,7 @@ void op::v1::Reshape::validate_and_infer_types()
     {
         const auto lower_bound = std::make_shared<op::Constant>(lb)->cast_vector<int64_t>();
         const auto upper_bound = std::make_shared<op::Constant>(ub)->cast_vector<int64_t>();
+        shape_can_be_calculated = true;
         NGRAPH_CHECK(lower_bound.size() == upper_bound.size());
         for (size_t i = 0; i < lower_bound.size(); ++i)
         {
@@ -250,13 +109,22 @@ void op::v1::Reshape::validate_and_infer_types()
             }
             reshape_pattern.emplace_back(lower_bound[i], upper_bound[i]);
         }
+        // For the scalar case reshape_pattern should be empty; a scalar shape
+        // pattern may only hold the single value 1
+        if (output_rank.is_static() && output_rank.get_length() == 0 && !lower_bound.empty())
+        {
+            reshape_pattern.clear();
+            NGRAPH_CHECK(lower_bound.size() == 1);
+            NODE_VALIDATION_CHECK(this,
+                                  lower_bound[0] == 1 && upper_bound[0] == 1,
+                                  "The value of scalar shape pattern should be equal to 1!");
+        }
     }

-    if (!reshape_pattern.empty())
+    if (shape_can_be_calculated)
     {
         std::vector<Dimension> output_shape(output_rank.get_length());
-        reshapeop::calculate_output_shape(
-            this, reshape_pattern, minus_one_idx, input_pshape, output_shape);
+        calculate_output_shape(reshape_pattern, minus_one_idx, input_pshape, output_shape);
         set_output_type(0, get_input_element_type(0), output_shape);
     }
 }
@@ -311,8 +179,8 @@ bool op::v1::Reshape::evaluate_reshape(const HostTensorVector& outputs,
     }

     std::vector<Dimension> output_shape(out_shape_val.size());
-    reshapeop::calculate_output_shape(
-        this, reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape);
+    calculate_output_shape(
+        reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape);
     NGRAPH_CHECK(PartialShape(output_shape).is_static());
     outputs[0]->set_shape(PartialShape(output_shape).to_shape());
@@ -390,3 +258,140 @@ bool op::v1::Reshape::constant_fold(OutputVector& output_values, const OutputVec
     }
     return false;
 }
+
+void op::v1::Reshape::calculate_output_shape(vector<Dimension>& reshape_pattern,
+                                             const int64_t& minus_one_idx,
+                                             const PartialShape& input_pshape,
+                                             vector<Dimension>& output_shape) const
+{
+    Dimension output_product(1);
+    for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i)
+    {
+        if (i == minus_one_idx) // resolving everything except -1
+            continue;
+
+        auto pattern_dim = reshape_pattern[i];
+        if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 &&
+            get_special_zero())
+        {
+            if (input_pshape.rank().is_dynamic())
+            {
+                output_shape[i] = Dimension::dynamic();
+                output_product *= Dimension::dynamic();
+            }
+            else
+            {
+                NODE_VALIDATION_CHECK(
+                    this, i < input_pshape.rank().get_length(), "'0' dimension is out of range");
+                output_shape[i] = input_pshape[i];
+                // we do not include dimension to output product here and won't include in input
+                // product later because we will divide output_product by input_product. This
+                // dimension contributes to both products equally, but in case this dimension
+                // is dynamic and others are not we could fully define output dimension that
+                // is masked by -1
+            }
+        }
+        else
+        {
+            output_shape[i] = pattern_dim;
+            output_product *= pattern_dim;
+        }
+    }
+    Dimension input_product(1);
+    if (input_pshape.rank().is_static())
+        for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i)
+        {
+            if (i < static_cast<int64_t>(reshape_pattern.size()) &&
+                reshape_pattern[i].get_min_length() == 0 &&
+                reshape_pattern[i].get_max_length() == 0)
+                continue;
+            input_product *= input_pshape[i];
+        }
+    else
+        input_product = Dimension::dynamic();
+
+    if (minus_one_idx != -1) // resolving -1 masked dimension
+    {
+        if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0)
+        {
+            // TODO: Decide if this is desired behavior here. (NumPy seems
+            // to fail.)
+            NODE_VALIDATION_CHECK(this,
+                                  input_product.get_min_length() == 0 &&
+                                      input_product.get_max_length() == 0,
+                                  "Cannot infer '-1' dimension with zero-size output "
+                                  "dimension unless at least one input dimension is "
+                                  "also zero-size");
+            output_shape[minus_one_idx] = Dimension(0);
+        }
+        else
+        {
+            if (input_product.is_static() && output_product.is_static())
+            {
+                NODE_VALIDATION_CHECK(
+                    this,
+                    input_product.get_length() % output_product.get_length() == 0,
+                    "Non-'-1' output dimensions do not evenly divide the input dimensions");
+            }
+            if (output_product.get_min_length() == 0 || output_product == Dimension() ||
+                input_product == Dimension())
+            {
+                output_shape[minus_one_idx] = Dimension::dynamic();
+            }
+            else
+            {
+                Dimension::value_type lower;
+                if (input_product.get_min_length() == 0)
+                    lower = 0;
+                else if (input_product.get_min_length() == -1 ||
+                         output_product.get_max_length() == 0 ||
+                         output_product.get_max_length() == -1)
+                    lower = -1; // dynamic
+                else
+                    lower = static_cast<Dimension::value_type>(
+                        ceil(static_cast<double>(input_product.get_min_length()) /
+                             output_product.get_max_length()));
+
+                Dimension::value_type upper;
+                if (input_product.get_max_length() == 0)
+                    upper = 0;
+                else if (input_product.get_max_length() == -1 ||
+                         output_product.get_min_length() == 0 ||
+                         output_product.get_min_length() == -1)
+                    upper = -1; // dynamic
+                else
+                    upper = static_cast<Dimension::value_type>(
+                        floor(static_cast<double>(input_product.get_max_length()) /
+                              output_product.get_min_length()));
+
+                if (lower == -1)
+                    output_shape[minus_one_idx] = Dimension::dynamic();
+                else if (upper == -1)
+                    output_shape[minus_one_idx] = Dimension(lower, upper);
+                else if (lower > upper) // empty intersection
+                    output_shape[minus_one_idx] = Dimension::dynamic();
+                else
+                    output_shape[minus_one_idx] = Dimension(lower, upper);
+            }
+        }
+    }
+
+    PartialShape output_pshape(output_shape);
+    if (input_pshape.is_static() && output_pshape.is_static())
+    {
+        size_t zero_dims =
+            std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
+                return dim.get_max_length() == 0 && dim.get_min_length() == 0;
+            });
+
+        bool backward_compatible_check = (zero_dims && get_special_zero()) || minus_one_idx != -1;
+        bool in_out_elements_equal =
+            shape_size(get_input_shape(0)) == shape_size(output_pshape.to_shape());
+
+        NODE_VALIDATION_CHECK(this,
+                              backward_compatible_check || in_out_elements_equal,
+                              "Requested output shape ",
+                              output_shape,
+                              " is incompatible with input shape ",
+                              get_input_shape(0));
+    }
+}
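
When a -1 is present and either product is an interval, the masked dimension is resolved as input_product divided by output_product, taking the ceiling for the lower bound and the floor for the upper bound. A worked example under assumed shapes, in the spirit of the interval type_prop tests below:

    // input: PartialShape{Dimension(4, 12)}, pattern {2, -1}, special_zero = false
    auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{Dimension(4, 12)});
    auto pattern = op::Constant::create(element::i64, Shape{2}, std::vector<int64_t>{2, -1});
    auto r = std::make_shared<op::v1::Reshape>(data, pattern, false);
    // output_product = 2, input_product = [4, 12]
    // lower = ceil(4 / 2) = 2, upper = floor(12 / 2) = 6
    // r->get_output_partial_shape(0) == PartialShape{2, Dimension(2, 6)}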


@@ -331,6 +331,37 @@ NGRAPH_TEST(${BACKEND_NAME}, builder_reshape_3D_to_scalar)
     test_case.run();
 }
+
+NGRAPH_TEST(${BACKEND_NAME}, builder_reshape_1d_to_same_shape)
+{
+    const Shape input_shape{1};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto r = make_shared<op::v1::Reshape>(
+        param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
+    auto function = make_shared<Function>(r, ParameterVector{param});
+
+    auto test_case = test::TestCase<TestEngine>(function);
+    vector<float> input_values(shape_size(input_shape), 1.f);
+    test_case.add_input<float>(input_shape, input_values);
+    test_case.add_expected_output<float>(Shape{}, vector<float>{1.f});
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, builder_reshape_to_same_shape)
+{
+    const Shape input_shape{};
+    auto param = make_shared<op::Parameter>(element::f32, input_shape);
+    auto r = make_shared<op::v1::Reshape>(
+        param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
+    auto function = make_shared<Function>(r, ParameterVector{param});
+
+    auto test_case = test::TestCase<TestEngine>(function);
+    vector<float> input_values(shape_size(input_shape), 1.f);
+    test_case.add_input<float>(input_shape, input_values);
+    test_case.add_expected_output<float>(Shape{}, vector<float>{1.f});
+    test_case.run();
+}
+
 #if NGRAPH_INTERPRETER_ENABLE
 NGRAPH_TEST(${BACKEND_NAME}, reshape_shufflenet_5d)
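
Both new backend tests exercise the evaluate path with a rank-0 pattern: the first collapses a one-element 1-D tensor to a scalar, the second reshapes a scalar to itself; each expects the single value 1.f back under the output Shape{}.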


@@ -579,3 +579,66 @@ TEST(type_prop, reshape_multiply_intervals_by_interval_zero_included)
     ASSERT_EQ(r->get_element_type(), element::f32);
     ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(0, 24)}));
 }
+
+TEST(type_prop, reshape_to_zero_shape)
+{
+    auto param = make_shared<op::Parameter>(element::f32, Shape{0, 1});
+    auto r = make_shared<op::v1::Reshape>(
+        param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), false);
+    ASSERT_EQ(r->get_element_type(), element::f32);
+    ASSERT_EQ(r->get_output_shape(0), (Shape{0}));
+}
+
+TEST(type_prop, reshape_to_zero_shape_dynamic)
+{
+    auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto r = make_shared<op::v1::Reshape>(
+        param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), false);
+    ASSERT_EQ(r->get_element_type(), element::f32);
+    ASSERT_EQ(r->get_output_shape(0), (Shape{0}));
+}
+
+TEST(type_prop, reshape_to_zero_shape_incorrect)
+{
+    auto param = make_shared<op::Parameter>(element::f32, Shape{2, 1});
+    ASSERT_THROW(
+        make_shared<op::v1::Reshape>(
+            param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), false),
+        std::exception);
+}
+
+TEST(type_prop, reshape_to_zero)
+{
+    auto param = make_shared<op::Parameter>(element::f32, Shape{2, 1});
+    auto r = make_shared<op::v1::Reshape>(
+        param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), true);
+    ASSERT_EQ(r->get_element_type(), element::f32);
+    ASSERT_EQ(r->get_output_shape(0), (Shape{2}));
+}
+
+TEST(type_prop, reshape_to_scalar)
+{
+    auto param = make_shared<op::Parameter>(element::f32, Shape{});
+    auto r = make_shared<op::v1::Reshape>(
+        param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
+    ASSERT_EQ(r->get_element_type(), element::f32);
+    ASSERT_EQ(r->get_output_shape(0), (Shape{}));
+}
+
+TEST(type_prop, reshape_to_scalar_2)
+{
+    auto param = make_shared<op::Parameter>(element::f32, Shape{});
+    auto r = make_shared<op::v1::Reshape>(
+        param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
+    ASSERT_EQ(r->get_element_type(), element::f32);
+    ASSERT_EQ(r->get_output_shape(0), (Shape{}));
+}
+
+TEST(type_prop, reshape_to_scalar_3)
+{
+    auto param = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
+    ASSERT_THROW(
+        make_shared<op::v1::Reshape>(
+            param, op::Constant::create(element::i64, {}, std::vector<int64_t>{100}), false),
+        std::exception);
+}
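
Taken together, the last group of tests pins down what a 0 in the pattern now means: with special_zero == true it copies the matching input dimension (Shape{2, 1} with pattern {0} yields Shape{2}), while with special_zero == false it denotes a genuinely zero-sized dimension, which is only valid when the input itself has zero elements, hence the expected std::exception for Shape{2, 1}. Scalar patterns likewise accept only the value 1, so the value 100 in the final test is rejected.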