From 511db4724f19de08985896eed17797cd3b0eae1e Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Sun, 10 Jan 2021 21:37:24 +0300 Subject: [PATCH] [ nG ] opset1::Reshape output shape inference refactoring (#3542) * [ nG ] opset1::Reshape output shape inference refactoring * SET_INPUT_IS_RELEVANT_TO_SHAPE * small refactoring * brought legacy check back * style * Integizer * stylee * Apply suggestions from code review * merge master --- ngraph/core/src/op/reshape.cpp | 375 ++++++++++++++---------------- ngraph/test/type_prop/reshape.cpp | 245 +++++++++++++++++-- 2 files changed, 399 insertions(+), 221 deletions(-) diff --git a/ngraph/core/src/op/reshape.cpp b/ngraph/core/src/op/reshape.cpp index e9cd2254f85..ce16c62c9f2 100644 --- a/ngraph/core/src/op/reshape.cpp +++ b/ngraph/core/src/op/reshape.cpp @@ -54,6 +54,150 @@ namespace reshapeop output_shape.push_back(shape_pattern_ptr[i]); } } + + void calculate_output_shape(const ngraph::op::v1::Reshape* reshape_node, + vector& reshape_pattern, + const int64_t& minus_one_idx, + const PartialShape& input_pshape, + vector& output_shape) + { + if (reshape_pattern == std::vector{0} && !reshape_node->get_special_zero()) + { // legacy check introduced by PR #1206 + reshape_pattern = std::vector{}; + output_shape = {}; + return; + } + Dimension output_product(1); + for (size_t i = 0; i < reshape_pattern.size(); ++i) + { + if (i == minus_one_idx) // resolving everything except -1 + continue; + + auto pattern_dim = reshape_pattern[i]; + if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 && + reshape_node->get_special_zero()) + { + if (input_pshape.rank().is_dynamic()) + { + output_shape[i] = Dimension::dynamic(); + output_product *= Dimension::dynamic(); + } + else + { + NODE_VALIDATION_CHECK(reshape_node, + i < input_pshape.rank().get_length(), + "'0' dimension is out of range"); + output_shape[i] = input_pshape[i]; + // we do not include dimension to output product here and won't include in input + // product later because we will divide output_product by input_product. This + // dimension contributes to both products equally, but in case this dimension + // is dynamic and others are not we could fully define output dimension that + // is masked by -1 + } + } + else + { + output_shape[i] = pattern_dim; + output_product *= pattern_dim; + } + } + Dimension input_product(1); + if (input_pshape.rank().is_static()) + for (size_t i = 0; i < input_pshape.rank().get_length(); ++i) + { + if (i < reshape_pattern.size() && reshape_pattern[i] == 0) + continue; + input_product *= input_pshape[i]; + } + else + input_product = Dimension::dynamic(); + + if (minus_one_idx != -1) // resolving -1 masked dimension + { + if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0) + { + // TODO: Decide if this is desired behavior here. (NumPy seems + // to fail.) + NODE_VALIDATION_CHECK(reshape_node, + input_product.get_min_length() == 0 && + input_product.get_max_length() == 0, + "Cannot infer '-1' dimension with zero-size output " + "dimension unless at least one input dimension is " + "also zero-size"); + output_shape[minus_one_idx] = Dimension(0); + } + else + { + if (input_product.is_static() && output_product.is_static()) + { + NODE_VALIDATION_CHECK( + reshape_node, + input_product.get_length() % output_product.get_length() == 0, + "Non-'-1' output dimensions do not evenly divide the input dimensions"); + } + if (output_product.get_min_length() == 0 || output_product == Dimension() || + input_product == Dimension()) + { + output_shape[minus_one_idx] = Dimension::dynamic(); + } + else + { + Dimension::value_type lower; + if (input_product.get_min_length() == 0) + lower = 0; + else if (input_product.get_min_length() == -1 || + output_product.get_max_length() == 0 || + output_product.get_max_length() == -1) + lower = -1; // dynamic + else + lower = static_cast( + ceil(static_cast(input_product.get_min_length()) / + output_product.get_max_length())); + + Dimension::value_type upper; + if (input_product.get_max_length() == 0) + upper = 0; + else if (input_product.get_max_length() == -1 || + output_product.get_min_length() == 0 || + output_product.get_min_length() == -1) + upper = -1; // dynamic + else + upper = static_cast( + floor(static_cast(input_product.get_max_length()) / + output_product.get_min_length())); + + if (lower == -1) + output_shape[minus_one_idx] = Dimension::dynamic(); + else if (upper == -1) + output_shape[minus_one_idx] = Dimension(lower, upper); + else if (lower > upper) // empty intersection + output_shape[minus_one_idx] = Dimension::dynamic(); + else + output_shape[minus_one_idx] = Dimension(lower, upper); + } + } + } + PartialShape output_pshape(output_shape); + if (input_pshape.is_static() && output_pshape.is_static()) + { + size_t zero_dims = + std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) { + return dim.get_max_length() == 0 && dim.get_min_length() == 0; + }); + + bool backward_compatible_check = + (zero_dims && reshape_node->get_special_zero()) || minus_one_idx != -1; + bool in_out_elements_equal = shape_size(reshape_node->get_input_shape(0)) == + shape_size(output_pshape.to_shape()); + + NODE_VALIDATION_CHECK(reshape_node, + backward_compatible_check || in_out_elements_equal, + "Requested output shape ", + output_shape, + " is incompatible with input shape ", + reshape_node->get_input_shape(0)); + } + } } NGRAPH_RTTI_DEFINITION(op::v1::Reshape, "Reshape", 1); @@ -90,136 +234,35 @@ void op::v1::Reshape::validate_and_infer_types() "."); Rank output_rank = shape_pattern_shape.rank().is_dynamic() ? Rank::dynamic() : shape_pattern_shape[0]; - + set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank)); set_input_is_relevant_to_shape(1); - if (auto const_shape = as_type_ptr(input_value(1).get_node_shared_ptr())) + std::vector reshape_pattern; + int64_t minus_one_idx = -1; + + if (const auto constant = as_type_ptr(get_input_node_shared_ptr(1))) { - std::vector out_shape_val = const_shape->cast_vector(); - NODE_VALIDATION_CHECK(this, - std::none_of(out_shape_val.begin(), - out_shape_val.end(), - [](int64_t v) { return v < -1; }), - "Dim size cannot be less than -1 "); - - int zero_dims = std::count_if( - out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == 0; }); - int negative_dims = std::count_if( - out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == -1; }); - NODE_VALIDATION_CHECK(this, - negative_dims <= 1, - "More than one dimension has size of -1 (", - negative_dims, - ")"); - - if (!(zero_dims && m_special_zero) && !negative_dims) + const auto pattern_vector = constant->cast_vector(); + for (size_t i = 0; i < pattern_vector.size(); ++i) { - auto output_shape = const_shape->get_shape_val(); - if (output_shape == Shape{0}) - { - output_shape = Shape{}; - } - if (get_input_partial_shape(0).is_static()) - { - NODE_VALIDATION_CHECK(this, - shape_size(get_input_shape(0)) == shape_size(output_shape), - "Requested output shape ", - output_shape, - " is incompatible with input shape ", - get_input_shape(0)); - } - set_output_type(0, get_input_element_type(0), output_shape); - } - else - { - std::vector partial_shape(output_rank.get_length()); - // Replace zeros with Dynamic dimensions as needed - for (size_t i = 0; i < out_shape_val.size(); ++i) - { - const auto& v = out_shape_val[i]; - if (v < 0) - { - partial_shape[i] = Dimension(); - } - else if (v == 0 && m_special_zero) - { - partial_shape[i] = ((input_pshape.rank().is_static() && - input_pshape.rank().get_length() == out_shape_val.size()) - ? input_pshape[i] - : Dimension()); - } - else - { - partial_shape[i] = Dimension(v); - } - } + NODE_VALIDATION_CHECK(this, pattern_vector[i] >= -1, "Dim size cannot be less than -1"); - if (input_pshape.is_static()) - { - size_t output_elements = 1; - int negative_dim = -1; - - auto input_shape = input_pshape.to_shape(); - size_t input_elements = shape_size(input_shape); - for (size_t i = 0; i < output_rank.get_length(); i++) - { - if (out_shape_val[i] == 0 && m_special_zero) - { - // Copy input_shape[i] for zero values - NODE_VALIDATION_CHECK( - this, i < input_shape.size(), "'0' dimension is out of range"); - partial_shape[i] = Dimension(input_shape[i]); - output_elements *= input_shape[i]; - } - else if (out_shape_val[i] == -1) - { - negative_dim = i; - } - else - { - output_elements *= out_shape_val[i]; - } - } - - if (negative_dim != -1) - { - // Infer size such that number of output elements matches - // input elements - if (output_elements == 0) - { - // TODO(amprocte): Decide if this is desired behavior here. (NumPy seems - // to fail.) - NODE_VALIDATION_CHECK(this, - input_elements == 0, - "Cannot infer '-1' dimension with zero-size output " - "dimension unless at least one input dimension is " - "also zero-size"); - partial_shape[negative_dim] = Dimension(0); - } - else - { - NODE_VALIDATION_CHECK(this, - input_elements % output_elements == 0, - "Non-'-1' output dimensions do not evenly divide " - "the input dimensions"); - partial_shape[negative_dim] = Dimension(input_elements / output_elements); - } - } + if (pattern_vector[i] == -1) + { // ctor of Dimension(-1) would turn input Dimension(0, max_int) + NODE_VALIDATION_CHECK( + this, minus_one_idx == -1, "More than one dimension has size of -1"); + minus_one_idx = static_cast(i); } - - if (out_shape_val == std::vector{0, -1} && - input_pshape.rank().is_static() && input_pshape.rank().get_length() == 2) - { - partial_shape[0] = input_pshape[0]; - partial_shape[1] = input_pshape[1]; - } - - set_output_type(0, get_input_element_type(0), PartialShape(partial_shape)); + reshape_pattern.emplace_back(pattern_vector[i]); } } - else + + if (!reshape_pattern.empty()) { - set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank)); + std::vector output_shape(output_rank.get_length()); + reshapeop::calculate_output_shape( + this, reshape_pattern, minus_one_idx, input_pshape, output_shape); + set_output_type(0, get_input_element_type(0), output_shape); } } @@ -259,86 +302,26 @@ bool op::v1::Reshape::evaluate_reshape(const HostTensorVector& outputs, default: throw ngraph_error("shape_pattern element type is not integral data type"); } - NODE_VALIDATION_CHECK( - this, - std::none_of(out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v < -1; }), - "Dim size cannot be less than -1 "); - - int zero_dims = - std::count_if(out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == 0; }); - int negative_dims = std::count_if( - out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == -1; }); - NODE_VALIDATION_CHECK( - this, negative_dims <= 1, "More than one dimension has size of -1 (", negative_dims, ")"); - - Shape output_shape; - std::copy(out_shape_val.begin(), out_shape_val.end(), std::back_inserter(output_shape)); - if (!(zero_dims && m_special_zero) && !negative_dims) + std::vector reshape_pattern; + int64_t minus_one_idx = -1; + for (size_t i = 0; i < out_shape_val.size(); ++i) { - if (get_input_partial_shape(0).is_static()) - { - NODE_VALIDATION_CHECK(this, - shape_size(inputs[0]->get_shape()) == shape_size(output_shape), - "Requested output shape ", - output_shape, - " is incompatible with input shape ", - get_input_shape(0)); + NODE_VALIDATION_CHECK(this, out_shape_val[i] >= -1, "Dim size cannot be less than -1"); + if (out_shape_val[i] == -1) + { // ctor of Dimension(-1) would turn input Dimension(0, max_int) + NODE_VALIDATION_CHECK( + this, minus_one_idx == -1, "More than one dimension has size of -1"); + minus_one_idx = static_cast(i); } - outputs[0]->set_shape(output_shape); + reshape_pattern.emplace_back(out_shape_val[i]); } - else - { - size_t output_elements = 1; - int negative_dim = -1; - auto input_shape = inputs[0]->get_shape(); - size_t input_elements = shape_size(input_shape); + std::vector output_shape(out_shape_val.size()); + reshapeop::calculate_output_shape( + this, reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape); + NGRAPH_CHECK(PartialShape(output_shape).is_static()); + outputs[0]->set_shape(PartialShape(output_shape).to_shape()); - // compute the output shape - for (size_t i = 0; i < output_rank; i++) - { - if (out_shape_val[i] == 0 && m_special_zero) - { - // Copy input_shape[i] for zero values - NODE_VALIDATION_CHECK( - this, i < input_shape.size(), "'0' dimension is out of range"); - output_shape[i] = input_shape[i]; - output_elements *= input_shape[i]; - } - else if (out_shape_val[i] == -1) - { - negative_dim = i; - } - else - { - output_elements *= out_shape_val[i]; - } - } - - if (negative_dim != -1) - { - // Infer size such that number of output elements matches - // input elements - if (output_elements == 0) - { - NODE_VALIDATION_CHECK(this, - input_elements == 0, - "Cannot infer '-1' dimension with zero-size output " - "dimension unless at least one input dimension is " - "also zero-size"); - output_shape[negative_dim] = 0; - } - else - { - NODE_VALIDATION_CHECK( - this, - input_elements % output_elements == 0, - "Non-'-1' output dimensions do not evenly divide the input dimensions"); - output_shape[negative_dim] = input_elements / output_elements; - } - } - outputs[0]->set_shape(output_shape); - } const AxisVector order = get_default_order(inputs[0]->get_shape()); return reshapeop::evaluate_reshape(inputs[0], outputs[0], order); } diff --git a/ngraph/test/type_prop/reshape.cpp b/ngraph/test/type_prop/reshape.cpp index 0d2f73b60bf..4a3b8270526 100644 --- a/ngraph/test/type_prop/reshape.cpp +++ b/ngraph/test/type_prop/reshape.cpp @@ -21,7 +21,7 @@ using namespace std; using namespace ngraph; -TEST(type_prop, reshape_deduce_s2v) +TEST(type_prop, reshape_deduce_s2t) { auto param = make_shared(element::f32, Shape{}); auto r = make_shared( @@ -39,7 +39,7 @@ TEST(type_prop, reshape_deduce_s2m) ASSERT_EQ(r->get_shape(), (Shape{1, 1})); } -TEST(type_prop, reshape_deduce_s2t) +TEST(type_prop, reshape_deduce_s2m3) { auto param = make_shared(element::f32, Shape{}); auto r = make_shared( @@ -48,7 +48,7 @@ TEST(type_prop, reshape_deduce_s2t) ASSERT_EQ(r->get_shape(), (Shape{1, 1, 1})); } -TEST(type_prop, reshape_deduce_m2v_01) +TEST(type_prop, reshape_deduce_2d_to_1d) { auto param = make_shared(element::f32, Shape{3, 4}); auto r = make_shared( @@ -57,25 +57,7 @@ TEST(type_prop, reshape_deduce_m2v_01) ASSERT_EQ(r->get_shape(), (Shape{12})); } -TEST(type_prop, reshape_deduce_m2v_10) -{ - auto param = make_shared(element::f32, Shape{3, 4}); - auto r = make_shared( - param, op::Constant::create(element::u64, {1}, Shape{12}), false); - ASSERT_EQ(r->get_element_type(), element::f32); - ASSERT_EQ(r->get_shape(), (Shape{12})); -} - -TEST(type_prop, reshape_deduce_t2v_012) -{ - auto param = make_shared(element::f32, Shape{3, 4, 5}); - auto r = make_shared( - param, op::Constant::create(element::u64, {1}, Shape{60}), false); - ASSERT_EQ(r->get_element_type(), element::f32); - ASSERT_EQ(r->get_shape(), (Shape{60})); -} - -TEST(type_prop, reshape_deduce_t2v_120) +TEST(type_prop, reshape_deduce_3d_to_1d) { auto param = make_shared(element::f32, Shape{3, 4, 5}); auto r = make_shared( @@ -105,9 +87,7 @@ TEST(type_prop, reshape_deduce_wrong_output_shape) } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING( - error.what(), - std::string("Check 'shape_size(get_input_shape(0)) == shape_size(output_shape)'")); + EXPECT_HAS_SUBSTRING(error.what(), std::string("is incompatible with input shape")); } catch (...) { @@ -158,3 +138,218 @@ TEST(type_prop, reshape_partial_rank_static_dynamic_but_zero_ok) ASSERT_TRUE(r->get_output_partial_shape(0).is_static()); ASSERT_EQ(r->get_shape(), (Shape{3, 1, 0, 2})); } + +TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero) +{ + auto param = make_shared(element::f32, Shape{3, 1, 2}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{-1, 0}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_shape(), (Shape{6, 1})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg) +{ + auto param = make_shared(element::f32, Shape{3, 1, 2}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{0, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_shape(), (Shape{3, 2})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_copy_input) +{ + auto param = make_shared(element::f32, Shape{3, 1}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{0, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_shape(), (Shape{3, 1})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg) +{ + auto param = make_shared(element::f32, Shape{2, 2, 3}); + auto r = make_shared( + param, op::Constant::create(element::i64, {4}, std::vector{0, 0, 1, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_shape(), (Shape{2, 2, 1, 3})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero_dynamic) +{ + auto param = make_shared(element::f32, PartialShape{Dimension::dynamic(), 1, 2}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{-1, 0}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_dynamic) +{ + auto param = make_shared(element::f32, PartialShape{Dimension::dynamic(), 1, 1}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{0, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg_dynamic) +{ + auto param = make_shared(element::f32, PartialShape{2, Dimension::dynamic(), 3}); + auto r = make_shared( + param, op::Constant::create(element::i64, {4}, std::vector{0, 0, 1, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension::dynamic(), 1, 3})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_copy_input_dynamic) +{ + auto param = make_shared(element::f32, PartialShape{Dimension::dynamic(), 1}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{0, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1})); +} + +TEST(type_prop, reshape_partial_rank_dynamic_special_zero) +{ + auto param = make_shared(element::f32, PartialShape::dynamic()); + auto r = make_shared( + param, op::Constant::create(element::i64, {4}, std::vector{3, 1, 0, 2}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{3, 1, Dimension::dynamic(), 2})); +} + +TEST(type_prop, reshape_partial_rank_dynamic_special_neg) +{ + auto param = make_shared(element::f32, PartialShape::dynamic()); + auto r = make_shared( + param, op::Constant::create(element::i64, {4}, std::vector{3, -1, 0, 2}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), + (PartialShape{3, Dimension::dynamic(), Dimension::dynamic(), 2})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg_dynamic_with_interval) +{ + auto param = make_shared(element::f32, PartialShape{2, Dimension(1, 3), 3}); + auto r = make_shared( + param, op::Constant::create(element::i64, {4}, std::vector{0, 0, 1, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension(1, 3), 1, 3})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg_double_dynamic_with_interval) +{ + auto param = make_shared(element::f32, + PartialShape{2, Dimension(1, 3), Dimension::dynamic()}); + auto r = make_shared( + param, op::Constant::create(element::i64, {4}, std::vector{0, 0, 1, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), + (PartialShape{2, Dimension(1, 3), 1, Dimension::dynamic()})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_dynamic_with_interval) +{ + auto param = make_shared(element::f32, PartialShape{2, Dimension(1, 3)}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{0, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension(1, 3)})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero_dynamic_with_interval) +{ + auto param = make_shared(element::f32, PartialShape{2, Dimension(1, 3)}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{-1, 0}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension(1, 3)})); +} + +TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero_dynamic_with_interval_1) +{ + auto param = make_shared(element::f32, PartialShape{Dimension(1, 3), 2}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{-1, 0}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(1, 3), 2})); +} + +TEST(type_prop, reshape_pass_interval_dimension_through_minus_one) +{ + auto param = make_shared(element::f32, PartialShape{1, Dimension(1, 3), 2}); + auto r = make_shared( + param, op::Constant::create(element::i64, {3}, std::vector{0, -1, 2}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(1, 3), 2})); +} + +TEST(type_prop, reshape_multiply_interval_by_defined_dim_for_minus_one) +{ + auto param = make_shared(element::f32, PartialShape{1, Dimension(1, 3), 2}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{0, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(2, 6)})); +} + +TEST(type_prop, reshape_multiply_interval_by_interval_for_minus_one) +{ + auto param = + make_shared(element::f32, PartialShape{1, Dimension(1, 3), Dimension(1, 6)}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{0, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(1, 18)})); +} + +TEST(type_prop, reshape_multiply_interval_by_interval_divide_by_defined_dim_for_minus_one) +{ + auto param = make_shared(element::f32, + PartialShape{1, Dimension(1, 3), 3, Dimension(1, 6)}); + auto r = make_shared( + param, op::Constant::create(element::i64, {3}, std::vector{0, -1, 3}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(1, 18), 3})); +} + +TEST(type_prop, reshape_multiply_interval_by_interval_divide_by_interval_for_minus_one) +{ + auto param = make_shared(element::f32, PartialShape{1, -1, Dimension(1, 6)}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{0, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic()})); +} + +TEST(type_prop, + reshape_multiply_interval_by_interval_divide_by_interval_for_minus_one_zero_included_in_input) +{ + auto param = make_shared(element::f32, PartialShape{1, -1, Dimension(0, 6)}); + auto r = make_shared( + param, op::Constant::create(element::i64, {2}, std::vector{0, -1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic()})); +} + +TEST(type_prop, reshape_multiply_intervals_by_interval) +{ + auto param = make_shared( + element::f32, PartialShape{Dimension(1, 2), Dimension(1, 3), Dimension(1, 4)}); + auto r = make_shared( + param, op::Constant::create(element::i64, {1}, std::vector{-1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(1, 24)})); +} + +TEST(type_prop, reshape_multiply_intervals_by_interval_zero_included) +{ + auto param = make_shared( + element::f32, PartialShape{Dimension(0, 2), Dimension(0, 3), Dimension(0, 4)}); + auto r = make_shared( + param, op::Constant::create(element::i64, {1}, std::vector{-1}), true); + ASSERT_EQ(r->get_element_type(), element::f32); + ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(0, 24)})); +}