[ nG ] opset1::Reshape output shape inference refactoring (#3542)

* [ nG ] opset1::Reshape output shape inference refactoring

* SET_INPUT_IS_RELEVANT_TO_SHAPE

* small refactoring

* brought legacy check back

* style

* Integizer

* style

* Apply suggestions from code review

* merge master
Evgenya Stepyreva 2021-01-10 21:37:24 +03:00 committed by GitHub
parent 7f04723d25
commit 511db4724f
2 changed files with 399 additions and 221 deletions


@@ -54,6 +54,150 @@ namespace reshapeop
output_shape.push_back(shape_pattern_ptr[i]);
}
}
void calculate_output_shape(const ngraph::op::v1::Reshape* reshape_node,
vector<Dimension>& reshape_pattern,
const int64_t& minus_one_idx,
const PartialShape& input_pshape,
vector<Dimension>& output_shape)
{
if (reshape_pattern == std::vector<Dimension>{0} && !reshape_node->get_special_zero())
{ // legacy check introduced by PR #1206
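// e.g. a pattern of {0} with special_zero == false reshapes the input to a
// scalar (empty output shape), matching the pre-refactoring behavior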
reshape_pattern = std::vector<Dimension>{};
output_shape = {};
return;
}
Dimension output_product(1);
for (size_t i = 0; i < reshape_pattern.size(); ++i)
{
if (i == minus_one_idx) // resolving everything except -1
continue;
auto pattern_dim = reshape_pattern[i];
if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 &&
reshape_node->get_special_zero())
{
if (input_pshape.rank().is_dynamic())
{
output_shape[i] = Dimension::dynamic();
output_product *= Dimension::dynamic();
}
else
{
NODE_VALIDATION_CHECK(reshape_node,
i < input_pshape.rank().get_length(),
"'0' dimension is out of range");
output_shape[i] = input_pshape[i];
// We exclude this dimension from the output product here and will likewise
// exclude it from the input product below, because the input product is
// divided by the output product and this dimension contributes to both
// equally. Excluding it from both lets us fully define the output dimension
// masked by -1 even when this dimension is dynamic and the others are not.
}
}
else
{
output_shape[i] = pattern_dim;
output_product *= pattern_dim;
}
}
Dimension input_product(1);
if (input_pshape.rank().is_static())
for (size_t i = 0; i < input_pshape.rank().get_length(); ++i)
{
if (i < reshape_pattern.size() && reshape_pattern[i] == 0)
continue;
input_product *= input_pshape[i];
}
else
input_product = Dimension::dynamic();
if (minus_one_idx != -1) // resolving -1 masked dimension
{
if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0)
{
// TODO: Decide if this is desired behavior here. (NumPy seems
// to fail.)
NODE_VALIDATION_CHECK(reshape_node,
input_product.get_min_length() == 0 &&
input_product.get_max_length() == 0,
"Cannot infer '-1' dimension with zero-size output "
"dimension unless at least one input dimension is "
"also zero-size");
output_shape[minus_one_idx] = Dimension(0);
}
else
{
if (input_product.is_static() && output_product.is_static())
{
NODE_VALIDATION_CHECK(
reshape_node,
input_product.get_length() % output_product.get_length() == 0,
"Non-'-1' output dimensions do not evenly divide the input dimensions");
}
if (output_product.get_min_length() == 0 || output_product == Dimension() ||
input_product == Dimension())
{
output_shape[minus_one_idx] = Dimension::dynamic();
}
else
{
Dimension::value_type lower;
if (input_product.get_min_length() == 0)
lower = 0;
else if (input_product.get_min_length() == -1 ||
output_product.get_max_length() == 0 ||
output_product.get_max_length() == -1)
lower = -1; // dynamic
else
lower = static_cast<Dimension::value_type>(
ceil(static_cast<double>(input_product.get_min_length()) /
output_product.get_max_length()));
Dimension::value_type upper;
if (input_product.get_max_length() == 0)
upper = 0;
else if (input_product.get_max_length() == -1 ||
output_product.get_min_length() == 0 ||
output_product.get_min_length() == -1)
upper = -1; // dynamic
else
upper = static_cast<Dimension::value_type>(
floor(static_cast<double>(input_product.get_max_length()) /
output_product.get_min_length()));
if (lower == -1)
output_shape[minus_one_idx] = Dimension::dynamic();
else if (upper == -1)
output_shape[minus_one_idx] = Dimension(lower, upper);
else if (lower > upper) // empty intersection
output_shape[minus_one_idx] = Dimension::dynamic();
else
output_shape[minus_one_idx] = Dimension(lower, upper);
}
}
}
PartialShape output_pshape(output_shape);
if (input_pshape.is_static() && output_pshape.is_static())
{
size_t zero_dims =
std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
return dim.get_max_length() == 0 && dim.get_min_length() == 0;
});
bool backward_compatible_check =
(zero_dims && reshape_node->get_special_zero()) || minus_one_idx != -1;
bool in_out_elements_equal = shape_size(reshape_node->get_input_shape(0)) ==
shape_size(output_pshape.to_shape());
NODE_VALIDATION_CHECK(reshape_node,
backward_compatible_check || in_out_elements_equal,
"Requested output shape ",
output_shape,
" is incompatible with input shape ",
reshape_node->get_input_shape(0));
}
}
}
NGRAPH_RTTI_DEFINITION(op::v1::Reshape, "Reshape", 1);
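To make the interval arithmetic above concrete, here is a minimal standalone sketch (illustration only, not part of the diff; the Interval struct and resolve_minus_one helper are hypothetical stand-ins for ngraph's Dimension bounds):
#include <cmath>
#include <cstdint>
#include <iostream>
struct Interval
{
    int64_t min, max; // max == -1 encodes a dynamic (unbounded) bound
};
// Mirrors the lower/upper computation in calculate_output_shape: the -1
// dimension is bounded by input_product / output_product, rounded outward.
Interval resolve_minus_one(Interval input_product, Interval output_product)
{
    int64_t lower;
    if (input_product.min == 0)
        lower = 0;
    else if (input_product.min == -1 || output_product.max == 0 ||
             output_product.max == -1)
        lower = -1; // dynamic
    else
        lower = static_cast<int64_t>(
            std::ceil(static_cast<double>(input_product.min) / output_product.max));
    int64_t upper;
    if (input_product.max == 0)
        upper = 0;
    else if (input_product.max == -1 || output_product.min == 0 ||
             output_product.min == -1)
        upper = -1; // dynamic
    else
        upper = static_cast<int64_t>(
            std::floor(static_cast<double>(input_product.max) / output_product.min));
    return {lower, upper};
}
int main()
{
    // Input {1, Dimension(1, 3), 2} reshaped with pattern {0, -1}: the copied
    // dim is excluded from both products, so input_product = [2, 6] and
    // output_product = [1, 1]; -1 resolves to [2, 6].
    Interval d = resolve_minus_one({2, 6}, {1, 1});
    std::cout << '[' << d.min << ", " << d.max << "]\n"; // prints [2, 6]
}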
@@ -90,136 +234,35 @@ void op::v1::Reshape::validate_and_infer_types()
".");
Rank output_rank =
shape_pattern_shape.rank().is_dynamic() ? Rank::dynamic() : shape_pattern_shape[0];
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
set_input_is_relevant_to_shape(1);
if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
std::vector<Dimension> reshape_pattern;
int64_t minus_one_idx = -1;
if (const auto constant = as_type_ptr<op::Constant>(get_input_node_shared_ptr(1)))
{
std::vector<int64_t> out_shape_val = const_shape->cast_vector<int64_t>();
NODE_VALIDATION_CHECK(this,
std::none_of(out_shape_val.begin(),
out_shape_val.end(),
[](int64_t v) { return v < -1; }),
"Dim size cannot be less than -1 ");
int zero_dims = std::count_if(
out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == 0; });
int negative_dims = std::count_if(
out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == -1; });
NODE_VALIDATION_CHECK(this,
negative_dims <= 1,
"More than one dimension has size of -1 (",
negative_dims,
")");
if (!(zero_dims && m_special_zero) && !negative_dims)
const auto pattern_vector = constant->cast_vector<int64_t>();
for (size_t i = 0; i < pattern_vector.size(); ++i)
{
auto output_shape = const_shape->get_shape_val();
if (output_shape == Shape{0})
{
output_shape = Shape{};
}
if (get_input_partial_shape(0).is_static())
{
NODE_VALIDATION_CHECK(this,
shape_size(get_input_shape(0)) == shape_size(output_shape),
"Requested output shape ",
output_shape,
" is incompatible with input shape ",
get_input_shape(0));
}
set_output_type(0, get_input_element_type(0), output_shape);
}
else
{
std::vector<Dimension> partial_shape(output_rank.get_length());
// Replace zeros with Dynamic dimensions as needed
for (size_t i = 0; i < out_shape_val.size(); ++i)
{
const auto& v = out_shape_val[i];
if (v < 0)
{
partial_shape[i] = Dimension();
}
else if (v == 0 && m_special_zero)
{
partial_shape[i] = ((input_pshape.rank().is_static() &&
input_pshape.rank().get_length() == out_shape_val.size())
? input_pshape[i]
: Dimension());
}
else
{
partial_shape[i] = Dimension(v);
}
}
NODE_VALIDATION_CHECK(this, pattern_vector[i] >= -1, "Dim size cannot be less than -1");
if (input_pshape.is_static())
{
size_t output_elements = 1;
int negative_dim = -1;
auto input_shape = input_pshape.to_shape();
size_t input_elements = shape_size(input_shape);
for (size_t i = 0; i < output_rank.get_length(); i++)
{
if (out_shape_val[i] == 0 && m_special_zero)
{
// Copy input_shape[i] for zero values
NODE_VALIDATION_CHECK(
this, i < input_shape.size(), "'0' dimension is out of range");
partial_shape[i] = Dimension(input_shape[i]);
output_elements *= input_shape[i];
}
else if (out_shape_val[i] == -1)
{
negative_dim = i;
}
else
{
output_elements *= out_shape_val[i];
}
}
if (negative_dim != -1)
{
// Infer size such that number of output elements matches
// input elements
if (output_elements == 0)
{
// TODO(amprocte): Decide if this is desired behavior here. (NumPy seems
// to fail.)
NODE_VALIDATION_CHECK(this,
input_elements == 0,
"Cannot infer '-1' dimension with zero-size output "
"dimension unless at least one input dimension is "
"also zero-size");
partial_shape[negative_dim] = Dimension(0);
}
else
{
NODE_VALIDATION_CHECK(this,
input_elements % output_elements == 0,
"Non-'-1' output dimensions do not evenly divide "
"the input dimensions");
partial_shape[negative_dim] = Dimension(input_elements / output_elements);
}
}
if (pattern_vector[i] == -1)
{ // constructing Dimension(-1) would yield Dimension(0, max_int), so record the index instead
NODE_VALIDATION_CHECK(
this, minus_one_idx == -1, "More than one dimension has size of -1");
minus_one_idx = static_cast<int64_t>(i);
}
if (out_shape_val == std::vector<std::int64_t>{0, -1} &&
input_pshape.rank().is_static() && input_pshape.rank().get_length() == 2)
{
partial_shape[0] = input_pshape[0];
partial_shape[1] = input_pshape[1];
}
set_output_type(0, get_input_element_type(0), PartialShape(partial_shape));
reshape_pattern.emplace_back(pattern_vector[i]);
}
}
else
if (!reshape_pattern.empty())
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
std::vector<Dimension> output_shape(output_rank.get_length());
reshapeop::calculate_output_shape(
this, reshape_pattern, minus_one_idx, input_pshape, output_shape);
set_output_type(0, get_input_element_type(0), output_shape);
}
}
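As a quick usage sketch of the refactored inference (illustrative shapes, written in the style of the type_prop tests further below):
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
// special_zero copies input dim 0 (= 3) and -1 absorbs the rest:
// r->get_shape() == Shape{3, 20}, since 4 * 5 = 20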
@@ -259,86 +302,26 @@ bool op::v1::Reshape::evaluate_reshape(const HostTensorVector& outputs,
default: throw ngraph_error("shape_pattern element type is not integral data type");
}
NODE_VALIDATION_CHECK(
this,
std::none_of(out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v < -1; }),
"Dim size cannot be less than -1 ");
int zero_dims =
std::count_if(out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == 0; });
int negative_dims = std::count_if(
out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == -1; });
NODE_VALIDATION_CHECK(
this, negative_dims <= 1, "More than one dimension has size of -1 (", negative_dims, ")");
Shape output_shape;
std::copy(out_shape_val.begin(), out_shape_val.end(), std::back_inserter(output_shape));
if (!(zero_dims && m_special_zero) && !negative_dims)
std::vector<Dimension> reshape_pattern;
int64_t minus_one_idx = -1;
for (size_t i = 0; i < out_shape_val.size(); ++i)
{
if (get_input_partial_shape(0).is_static())
{
NODE_VALIDATION_CHECK(this,
shape_size(inputs[0]->get_shape()) == shape_size(output_shape),
"Requested output shape ",
output_shape,
" is incompatible with input shape ",
get_input_shape(0));
NODE_VALIDATION_CHECK(this, out_shape_val[i] >= -1, "Dim size cannot be less than -1");
if (out_shape_val[i] == -1)
{ // constructing Dimension(-1) would yield Dimension(0, max_int), so record the index instead
NODE_VALIDATION_CHECK(
this, minus_one_idx == -1, "More than one dimension has size of -1");
minus_one_idx = static_cast<int64_t>(i);
}
outputs[0]->set_shape(output_shape);
reshape_pattern.emplace_back(out_shape_val[i]);
}
else
{
size_t output_elements = 1;
int negative_dim = -1;
auto input_shape = inputs[0]->get_shape();
size_t input_elements = shape_size(input_shape);
std::vector<Dimension> output_shape(out_shape_val.size());
reshapeop::calculate_output_shape(
this, reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape);
NGRAPH_CHECK(PartialShape(output_shape).is_static());
outputs[0]->set_shape(PartialShape(output_shape).to_shape());
// compute the output shape
for (size_t i = 0; i < output_rank; i++)
{
if (out_shape_val[i] == 0 && m_special_zero)
{
// Copy input_shape[i] for zero values
NODE_VALIDATION_CHECK(
this, i < input_shape.size(), "'0' dimension is out of range");
output_shape[i] = input_shape[i];
output_elements *= input_shape[i];
}
else if (out_shape_val[i] == -1)
{
negative_dim = i;
}
else
{
output_elements *= out_shape_val[i];
}
}
if (negative_dim != -1)
{
// Infer size such that number of output elements matches
// input elements
if (output_elements == 0)
{
NODE_VALIDATION_CHECK(this,
input_elements == 0,
"Cannot infer '-1' dimension with zero-size output "
"dimension unless at least one input dimension is "
"also zero-size");
output_shape[negative_dim] = 0;
}
else
{
NODE_VALIDATION_CHECK(
this,
input_elements % output_elements == 0,
"Non-'-1' output dimensions do not evenly divide the input dimensions");
output_shape[negative_dim] = input_elements / output_elements;
}
}
outputs[0]->set_shape(output_shape);
}
const AxisVector order = get_default_order(inputs[0]->get_shape());
return reshapeop::evaluate_reshape(inputs[0], outputs[0], order);
}
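For a concrete trace of the evaluate path (illustrative, in the same test style):
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
// With a concrete input tensor, calculate_output_shape produces a fully
// static shape, so the NGRAPH_CHECK above holds: dims 0 and 1 are copied
// (2, 2), and -1 resolves to 12 / (2 * 2 * 1) = 3, giving Shape{2, 2, 1, 3}.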


@@ -21,7 +21,7 @@
using namespace std;
using namespace ngraph;
TEST(type_prop, reshape_deduce_s2v)
TEST(type_prop, reshape_deduce_s2t)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{});
auto r = make_shared<op::v1::Reshape>(
@@ -39,7 +39,7 @@ TEST(type_prop, reshape_deduce_s2m)
ASSERT_EQ(r->get_shape(), (Shape{1, 1}));
}
TEST(type_prop, reshape_deduce_s2t)
TEST(type_prop, reshape_deduce_s2m3)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{});
auto r = make_shared<op::v1::Reshape>(
@@ -48,7 +48,7 @@ TEST(type_prop, reshape_deduce_s2t)
ASSERT_EQ(r->get_shape(), (Shape{1, 1, 1}));
}
TEST(type_prop, reshape_deduce_m2v_01)
TEST(type_prop, reshape_deduce_2d_to_1d)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4});
auto r = make_shared<op::v1::Reshape>(
@@ -57,25 +57,7 @@ TEST(type_prop, reshape_deduce_m2v_01)
ASSERT_EQ(r->get_shape(), (Shape{12}));
}
TEST(type_prop, reshape_deduce_m2v_10)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::u64, {1}, Shape{12}), false);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{12}));
}
TEST(type_prop, reshape_deduce_t2v_012)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::u64, {1}, Shape{60}), false);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{60}));
}
TEST(type_prop, reshape_deduce_t2v_120)
TEST(type_prop, reshape_deduce_3d_to_1d)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
auto r = make_shared<op::v1::Reshape>(
@@ -105,9 +87,7 @@ TEST(type_prop, reshape_deduce_wrong_output_shape)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Check 'shape_size(get_input_shape(0)) == shape_size(output_shape)'"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("is incompatible with input shape"));
}
catch (...)
{
@@ -158,3 +138,218 @@ TEST(type_prop, reshape_partial_rank_static_dynamic_but_zero_ok)
ASSERT_TRUE(r->get_output_partial_shape(0).is_static());
ASSERT_EQ(r->get_shape(), (Shape{3, 1, 0, 2}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1, 2});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{-1, 0}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{6, 1}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1, 2});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{3, 2}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_copy_input)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{3, 1}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{2, 2, 1, 3}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 1, 2});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{-1, 0}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 1, 1});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension::dynamic(), 1, 3}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_copy_input_dynamic)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 1});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1}));
}
TEST(type_prop, reshape_partial_rank_dynamic_special_zero)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{3, 1, 0, 2}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{3, 1, Dimension::dynamic(), 2}));
}
TEST(type_prop, reshape_partial_rank_dynamic_special_neg)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{3, -1, 0, 2}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0),
(PartialShape{3, Dimension::dynamic(), Dimension::dynamic(), 2}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg_dynamic_with_interval)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension(1, 3), 3});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension(1, 3), 1, 3}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg_double_dynamic_with_interval)
{
auto param = make_shared<op::Parameter>(element::f32,
PartialShape{2, Dimension(1, 3), Dimension::dynamic()});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0),
(PartialShape{2, Dimension(1, 3), 1, Dimension::dynamic()}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_dynamic_with_interval)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension(1, 3)});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension(1, 3)}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero_dynamic_with_interval)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension(1, 3)});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{-1, 0}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension(1, 3)}));
}
TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero_dynamic_with_interval_1)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension(1, 3), 2});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{-1, 0}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(1, 3), 2}));
}
TEST(type_prop, reshape_pass_interval_dimension_through_minus_one)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension(1, 3), 2});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {3}, std::vector<int64_t>{0, -1, 2}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(1, 3), 2}));
}
TEST(type_prop, reshape_multiply_interval_by_defined_dim_for_minus_one)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension(1, 3), 2});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(2, 6)}));
}
TEST(type_prop, reshape_multiply_interval_by_interval_for_minus_one)
{
auto param =
make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension(1, 3), Dimension(1, 6)});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(1, 18)}));
}
TEST(type_prop, reshape_multiply_interval_by_interval_divide_by_defined_dim_for_minus_one)
{
auto param = make_shared<op::Parameter>(element::f32,
PartialShape{1, Dimension(1, 3), 3, Dimension(1, 6)});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {3}, std::vector<int64_t>{0, -1, 3}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(1, 18), 3}));
}
TEST(type_prop, reshape_multiply_interval_by_interval_divide_by_interval_for_minus_one)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{1, -1, Dimension(1, 6)});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic()}));
}
TEST(type_prop,
reshape_multiply_interval_by_interval_divide_by_interval_for_minus_one_zero_included_in_input)
{
auto param = make_shared<op::Parameter>(element::f32, PartialShape{1, -1, Dimension(0, 6)});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic()}));
}
TEST(type_prop, reshape_multiply_intervals_by_interval)
{
auto param = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension(1, 2), Dimension(1, 3), Dimension(1, 4)});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{-1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(1, 24)}));
}
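(For reference: the input element count here is bounded by [1 * 1 * 1, 2 * 3 * 4] = [1, 24] and the non-'-1' output product is the empty product 1, so the -1 dimension resolves to [ceil(1 / 1), floor(24 / 1)] = Dimension(1, 24).)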
TEST(type_prop, reshape_multiply_intervals_by_interval_zero_included)
{
auto param = make_shared<op::Parameter>(
element::f32, PartialShape{Dimension(0, 2), Dimension(0, 3), Dimension(0, 4)});
auto r = make_shared<op::v1::Reshape>(
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{-1}), true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(0, 24)}));
}