[ nG ] opset1::Reshape output shape inference refactoring (#3542)
* [ nG ] opset1::Reshape output shape inference refactoring * SET_INPUT_IS_RELEVANT_TO_SHAPE * small refactoring * brought legacy check back * style * Integizer * stylee * Apply suggestions from code review * merge master
This commit is contained in:
parent
7f04723d25
commit
511db4724f
@ -54,6 +54,150 @@ namespace reshapeop
|
||||
output_shape.push_back(shape_pattern_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void calculate_output_shape(const ngraph::op::v1::Reshape* reshape_node,
|
||||
vector<Dimension>& reshape_pattern,
|
||||
const int64_t& minus_one_idx,
|
||||
const PartialShape& input_pshape,
|
||||
vector<Dimension>& output_shape)
|
||||
{
|
||||
if (reshape_pattern == std::vector<Dimension>{0} && !reshape_node->get_special_zero())
|
||||
{ // legacy check introduced by PR #1206
|
||||
reshape_pattern = std::vector<Dimension>{};
|
||||
output_shape = {};
|
||||
return;
|
||||
}
|
||||
Dimension output_product(1);
|
||||
for (size_t i = 0; i < reshape_pattern.size(); ++i)
|
||||
{
|
||||
if (i == minus_one_idx) // resolving everything except -1
|
||||
continue;
|
||||
|
||||
auto pattern_dim = reshape_pattern[i];
|
||||
if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 &&
|
||||
reshape_node->get_special_zero())
|
||||
{
|
||||
if (input_pshape.rank().is_dynamic())
|
||||
{
|
||||
output_shape[i] = Dimension::dynamic();
|
||||
output_product *= Dimension::dynamic();
|
||||
}
|
||||
else
|
||||
{
|
||||
NODE_VALIDATION_CHECK(reshape_node,
|
||||
i < input_pshape.rank().get_length(),
|
||||
"'0' dimension is out of range");
|
||||
output_shape[i] = input_pshape[i];
|
||||
// we do not include dimension to output product here and won't include in input
|
||||
// product later because we will divide output_product by input_product. This
|
||||
// dimension contributes to both products equally, but in case this dimension
|
||||
// is dynamic and others are not we could fully define output dimension that
|
||||
// is masked by -1
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
output_shape[i] = pattern_dim;
|
||||
output_product *= pattern_dim;
|
||||
}
|
||||
}
|
||||
Dimension input_product(1);
|
||||
if (input_pshape.rank().is_static())
|
||||
for (size_t i = 0; i < input_pshape.rank().get_length(); ++i)
|
||||
{
|
||||
if (i < reshape_pattern.size() && reshape_pattern[i] == 0)
|
||||
continue;
|
||||
input_product *= input_pshape[i];
|
||||
}
|
||||
else
|
||||
input_product = Dimension::dynamic();
|
||||
|
||||
if (minus_one_idx != -1) // resolving -1 masked dimension
|
||||
{
|
||||
if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0)
|
||||
{
|
||||
// TODO: Decide if this is desired behavior here. (NumPy seems
|
||||
// to fail.)
|
||||
NODE_VALIDATION_CHECK(reshape_node,
|
||||
input_product.get_min_length() == 0 &&
|
||||
input_product.get_max_length() == 0,
|
||||
"Cannot infer '-1' dimension with zero-size output "
|
||||
"dimension unless at least one input dimension is "
|
||||
"also zero-size");
|
||||
output_shape[minus_one_idx] = Dimension(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (input_product.is_static() && output_product.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
reshape_node,
|
||||
input_product.get_length() % output_product.get_length() == 0,
|
||||
"Non-'-1' output dimensions do not evenly divide the input dimensions");
|
||||
}
|
||||
if (output_product.get_min_length() == 0 || output_product == Dimension() ||
|
||||
input_product == Dimension())
|
||||
{
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
}
|
||||
else
|
||||
{
|
||||
Dimension::value_type lower;
|
||||
if (input_product.get_min_length() == 0)
|
||||
lower = 0;
|
||||
else if (input_product.get_min_length() == -1 ||
|
||||
output_product.get_max_length() == 0 ||
|
||||
output_product.get_max_length() == -1)
|
||||
lower = -1; // dynamic
|
||||
else
|
||||
lower = static_cast<Dimension::value_type>(
|
||||
ceil(static_cast<double>(input_product.get_min_length()) /
|
||||
output_product.get_max_length()));
|
||||
|
||||
Dimension::value_type upper;
|
||||
if (input_product.get_max_length() == 0)
|
||||
upper = 0;
|
||||
else if (input_product.get_max_length() == -1 ||
|
||||
output_product.get_min_length() == 0 ||
|
||||
output_product.get_min_length() == -1)
|
||||
upper = -1; // dynamic
|
||||
else
|
||||
upper = static_cast<Dimension::value_type>(
|
||||
floor(static_cast<double>(input_product.get_max_length()) /
|
||||
output_product.get_min_length()));
|
||||
|
||||
if (lower == -1)
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
else if (upper == -1)
|
||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||
else if (lower > upper) // empty intersection
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
else
|
||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||
}
|
||||
}
|
||||
}
|
||||
PartialShape output_pshape(output_shape);
|
||||
if (input_pshape.is_static() && output_pshape.is_static())
|
||||
{
|
||||
size_t zero_dims =
|
||||
std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
|
||||
return dim.get_max_length() == 0 && dim.get_min_length() == 0;
|
||||
});
|
||||
|
||||
bool backward_compatible_check =
|
||||
(zero_dims && reshape_node->get_special_zero()) || minus_one_idx != -1;
|
||||
bool in_out_elements_equal = shape_size(reshape_node->get_input_shape(0)) ==
|
||||
shape_size(output_pshape.to_shape());
|
||||
|
||||
NODE_VALIDATION_CHECK(reshape_node,
|
||||
backward_compatible_check || in_out_elements_equal,
|
||||
"Requested output shape ",
|
||||
output_shape,
|
||||
" is incompatible with input shape ",
|
||||
reshape_node->get_input_shape(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Reshape, "Reshape", 1);
|
||||
@ -90,136 +234,35 @@ void op::v1::Reshape::validate_and_infer_types()
|
||||
".");
|
||||
Rank output_rank =
|
||||
shape_pattern_shape.rank().is_dynamic() ? Rank::dynamic() : shape_pattern_shape[0];
|
||||
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
|
||||
set_input_is_relevant_to_shape(1);
|
||||
|
||||
if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
|
||||
std::vector<Dimension> reshape_pattern;
|
||||
int64_t minus_one_idx = -1;
|
||||
|
||||
if (const auto constant = as_type_ptr<op::Constant>(get_input_node_shared_ptr(1)))
|
||||
{
|
||||
std::vector<int64_t> out_shape_val = const_shape->cast_vector<int64_t>();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
std::none_of(out_shape_val.begin(),
|
||||
out_shape_val.end(),
|
||||
[](int64_t v) { return v < -1; }),
|
||||
"Dim size cannot be less than -1 ");
|
||||
|
||||
int zero_dims = std::count_if(
|
||||
out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == 0; });
|
||||
int negative_dims = std::count_if(
|
||||
out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == -1; });
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
negative_dims <= 1,
|
||||
"More than one dimension has size of -1 (",
|
||||
negative_dims,
|
||||
")");
|
||||
|
||||
if (!(zero_dims && m_special_zero) && !negative_dims)
|
||||
const auto pattern_vector = constant->cast_vector<int64_t>();
|
||||
for (size_t i = 0; i < pattern_vector.size(); ++i)
|
||||
{
|
||||
auto output_shape = const_shape->get_shape_val();
|
||||
if (output_shape == Shape{0})
|
||||
{
|
||||
output_shape = Shape{};
|
||||
}
|
||||
if (get_input_partial_shape(0).is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
shape_size(get_input_shape(0)) == shape_size(output_shape),
|
||||
"Requested output shape ",
|
||||
output_shape,
|
||||
" is incompatible with input shape ",
|
||||
get_input_shape(0));
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<Dimension> partial_shape(output_rank.get_length());
|
||||
// Replace zeros with Dynamic dimensions as needed
|
||||
for (size_t i = 0; i < out_shape_val.size(); ++i)
|
||||
{
|
||||
const auto& v = out_shape_val[i];
|
||||
if (v < 0)
|
||||
{
|
||||
partial_shape[i] = Dimension();
|
||||
}
|
||||
else if (v == 0 && m_special_zero)
|
||||
{
|
||||
partial_shape[i] = ((input_pshape.rank().is_static() &&
|
||||
input_pshape.rank().get_length() == out_shape_val.size())
|
||||
? input_pshape[i]
|
||||
: Dimension());
|
||||
}
|
||||
else
|
||||
{
|
||||
partial_shape[i] = Dimension(v);
|
||||
}
|
||||
}
|
||||
NODE_VALIDATION_CHECK(this, pattern_vector[i] >= -1, "Dim size cannot be less than -1");
|
||||
|
||||
if (input_pshape.is_static())
|
||||
{
|
||||
size_t output_elements = 1;
|
||||
int negative_dim = -1;
|
||||
|
||||
auto input_shape = input_pshape.to_shape();
|
||||
size_t input_elements = shape_size(input_shape);
|
||||
for (size_t i = 0; i < output_rank.get_length(); i++)
|
||||
{
|
||||
if (out_shape_val[i] == 0 && m_special_zero)
|
||||
{
|
||||
// Copy input_shape[i] for zero values
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, i < input_shape.size(), "'0' dimension is out of range");
|
||||
partial_shape[i] = Dimension(input_shape[i]);
|
||||
output_elements *= input_shape[i];
|
||||
}
|
||||
else if (out_shape_val[i] == -1)
|
||||
{
|
||||
negative_dim = i;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_elements *= out_shape_val[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (negative_dim != -1)
|
||||
{
|
||||
// Infer size such that number of output elements matches
|
||||
// input elements
|
||||
if (output_elements == 0)
|
||||
{
|
||||
// TODO(amprocte): Decide if this is desired behavior here. (NumPy seems
|
||||
// to fail.)
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_elements == 0,
|
||||
"Cannot infer '-1' dimension with zero-size output "
|
||||
"dimension unless at least one input dimension is "
|
||||
"also zero-size");
|
||||
partial_shape[negative_dim] = Dimension(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_elements % output_elements == 0,
|
||||
"Non-'-1' output dimensions do not evenly divide "
|
||||
"the input dimensions");
|
||||
partial_shape[negative_dim] = Dimension(input_elements / output_elements);
|
||||
}
|
||||
}
|
||||
if (pattern_vector[i] == -1)
|
||||
{ // ctor of Dimension(-1) would turn input Dimension(0, max_int)
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, minus_one_idx == -1, "More than one dimension has size of -1");
|
||||
minus_one_idx = static_cast<int64_t>(i);
|
||||
}
|
||||
|
||||
if (out_shape_val == std::vector<std::int64_t>{0, -1} &&
|
||||
input_pshape.rank().is_static() && input_pshape.rank().get_length() == 2)
|
||||
{
|
||||
partial_shape[0] = input_pshape[0];
|
||||
partial_shape[1] = input_pshape[1];
|
||||
}
|
||||
|
||||
set_output_type(0, get_input_element_type(0), PartialShape(partial_shape));
|
||||
reshape_pattern.emplace_back(pattern_vector[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
if (!reshape_pattern.empty())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
|
||||
std::vector<Dimension> output_shape(output_rank.get_length());
|
||||
reshapeop::calculate_output_shape(
|
||||
this, reshape_pattern, minus_one_idx, input_pshape, output_shape);
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
}
|
||||
}
|
||||
|
||||
@ -259,86 +302,26 @@ bool op::v1::Reshape::evaluate_reshape(const HostTensorVector& outputs,
|
||||
default: throw ngraph_error("shape_pattern element type is not integral data type");
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
std::none_of(out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v < -1; }),
|
||||
"Dim size cannot be less than -1 ");
|
||||
|
||||
int zero_dims =
|
||||
std::count_if(out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == 0; });
|
||||
int negative_dims = std::count_if(
|
||||
out_shape_val.begin(), out_shape_val.end(), [](int64_t v) { return v == -1; });
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, negative_dims <= 1, "More than one dimension has size of -1 (", negative_dims, ")");
|
||||
|
||||
Shape output_shape;
|
||||
std::copy(out_shape_val.begin(), out_shape_val.end(), std::back_inserter(output_shape));
|
||||
if (!(zero_dims && m_special_zero) && !negative_dims)
|
||||
std::vector<Dimension> reshape_pattern;
|
||||
int64_t minus_one_idx = -1;
|
||||
for (size_t i = 0; i < out_shape_val.size(); ++i)
|
||||
{
|
||||
if (get_input_partial_shape(0).is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
shape_size(inputs[0]->get_shape()) == shape_size(output_shape),
|
||||
"Requested output shape ",
|
||||
output_shape,
|
||||
" is incompatible with input shape ",
|
||||
get_input_shape(0));
|
||||
NODE_VALIDATION_CHECK(this, out_shape_val[i] >= -1, "Dim size cannot be less than -1");
|
||||
if (out_shape_val[i] == -1)
|
||||
{ // ctor of Dimension(-1) would turn input Dimension(0, max_int)
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, minus_one_idx == -1, "More than one dimension has size of -1");
|
||||
minus_one_idx = static_cast<int64_t>(i);
|
||||
}
|
||||
outputs[0]->set_shape(output_shape);
|
||||
reshape_pattern.emplace_back(out_shape_val[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t output_elements = 1;
|
||||
int negative_dim = -1;
|
||||
|
||||
auto input_shape = inputs[0]->get_shape();
|
||||
size_t input_elements = shape_size(input_shape);
|
||||
std::vector<Dimension> output_shape(out_shape_val.size());
|
||||
reshapeop::calculate_output_shape(
|
||||
this, reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape);
|
||||
NGRAPH_CHECK(PartialShape(output_shape).is_static());
|
||||
outputs[0]->set_shape(PartialShape(output_shape).to_shape());
|
||||
|
||||
// compute the output shape
|
||||
for (size_t i = 0; i < output_rank; i++)
|
||||
{
|
||||
if (out_shape_val[i] == 0 && m_special_zero)
|
||||
{
|
||||
// Copy input_shape[i] for zero values
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, i < input_shape.size(), "'0' dimension is out of range");
|
||||
output_shape[i] = input_shape[i];
|
||||
output_elements *= input_shape[i];
|
||||
}
|
||||
else if (out_shape_val[i] == -1)
|
||||
{
|
||||
negative_dim = i;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_elements *= out_shape_val[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (negative_dim != -1)
|
||||
{
|
||||
// Infer size such that number of output elements matches
|
||||
// input elements
|
||||
if (output_elements == 0)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_elements == 0,
|
||||
"Cannot infer '-1' dimension with zero-size output "
|
||||
"dimension unless at least one input dimension is "
|
||||
"also zero-size");
|
||||
output_shape[negative_dim] = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
input_elements % output_elements == 0,
|
||||
"Non-'-1' output dimensions do not evenly divide the input dimensions");
|
||||
output_shape[negative_dim] = input_elements / output_elements;
|
||||
}
|
||||
}
|
||||
outputs[0]->set_shape(output_shape);
|
||||
}
|
||||
const AxisVector order = get_default_order(inputs[0]->get_shape());
|
||||
return reshapeop::evaluate_reshape(inputs[0], outputs[0], order);
|
||||
}
|
||||
|
@ -21,7 +21,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, reshape_deduce_s2v)
|
||||
TEST(type_prop, reshape_deduce_s2t)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
@ -39,7 +39,7 @@ TEST(type_prop, reshape_deduce_s2m)
|
||||
ASSERT_EQ(r->get_shape(), (Shape{1, 1}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_s2t)
|
||||
TEST(type_prop, reshape_deduce_s2m3)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
@ -48,7 +48,7 @@ TEST(type_prop, reshape_deduce_s2t)
|
||||
ASSERT_EQ(r->get_shape(), (Shape{1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_m2v_01)
|
||||
TEST(type_prop, reshape_deduce_2d_to_1d)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
@ -57,25 +57,7 @@ TEST(type_prop, reshape_deduce_m2v_01)
|
||||
ASSERT_EQ(r->get_shape(), (Shape{12}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_m2v_10)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::u64, {1}, Shape{12}), false);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_shape(), (Shape{12}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_t2v_012)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::u64, {1}, Shape{60}), false);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_shape(), (Shape{60}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_t2v_120)
|
||||
TEST(type_prop, reshape_deduce_3d_to_1d)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
@ -105,9 +87,7 @@ TEST(type_prop, reshape_deduce_wrong_output_shape)
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Check 'shape_size(get_input_shape(0)) == shape_size(output_shape)'"));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("is incompatible with input shape"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
@ -158,3 +138,218 @@ TEST(type_prop, reshape_partial_rank_static_dynamic_but_zero_ok)
|
||||
ASSERT_TRUE(r->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(r->get_shape(), (Shape{3, 1, 0, 2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1, 2});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{-1, 0}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_shape(), (Shape{6, 1}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1, 2});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_shape(), (Shape{3, 2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_copy_input)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_shape(), (Shape{3, 1}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_shape(), (Shape{2, 2, 1, 3}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero_dynamic)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 1, 2});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{-1, 0}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_dynamic)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 1, 1});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg_dynamic)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension::dynamic(), 1, 3}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_copy_input_dynamic)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 1});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension::dynamic(), 1}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_partial_rank_dynamic_special_zero)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{3, 1, 0, 2}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{3, 1, Dimension::dynamic(), 2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_partial_rank_dynamic_special_neg)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{3, -1, 0, 2}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0),
|
||||
(PartialShape{3, Dimension::dynamic(), Dimension::dynamic(), 2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg_dynamic_with_interval)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension(1, 3), 3});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension(1, 3), 1, 3}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_zero_zero_one_neg_double_dynamic_with_interval)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{2, Dimension(1, 3), Dimension::dynamic()});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {4}, std::vector<int64_t>{0, 0, 1, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0),
|
||||
(PartialShape{2, Dimension(1, 3), 1, Dimension::dynamic()}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_zero_neg_dynamic_with_interval)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension(1, 3)});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension(1, 3)}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero_dynamic_with_interval)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension(1, 3)});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{-1, 0}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{2, Dimension(1, 3)}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_deduce_special_zero_shape_neg_zero_dynamic_with_interval_1)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension(1, 3), 2});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{-1, 0}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(1, 3), 2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_pass_interval_dimension_through_minus_one)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension(1, 3), 2});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {3}, std::vector<int64_t>{0, -1, 2}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(1, 3), 2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_multiply_interval_by_defined_dim_for_minus_one)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension(1, 3), 2});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(2, 6)}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_multiply_interval_by_interval_for_minus_one)
|
||||
{
|
||||
auto param =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension(1, 3), Dimension(1, 6)});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(1, 18)}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_multiply_interval_by_interval_divide_by_defined_dim_for_minus_one)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{1, Dimension(1, 3), 3, Dimension(1, 6)});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {3}, std::vector<int64_t>{0, -1, 3}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension(1, 18), 3}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_multiply_interval_by_interval_divide_by_interval_for_minus_one)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{1, -1, Dimension(1, 6)});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic()}));
|
||||
}
|
||||
|
||||
TEST(type_prop,
|
||||
reshape_multiply_interval_by_interval_divide_by_interval_for_minus_one_zero_included_in_input)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape{1, -1, Dimension(0, 6)});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {2}, std::vector<int64_t>{0, -1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic()}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_multiply_intervals_by_interval)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(
|
||||
element::f32, PartialShape{Dimension(1, 2), Dimension(1, 3), Dimension(1, 4)});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{-1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(1, 24)}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_multiply_intervals_by_interval_zero_included)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(
|
||||
element::f32, PartialShape{Dimension(0, 2), Dimension(0, 3), Dimension(0, 4)});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{-1}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(0, 24)}));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user