Reshape should support reshape to zero shapes (#5828)
* Reshape should support reshape to zero shapes * Fixed comments * Fixed backward compatible check * Fixed myriad tests * Removed header * Fixed myriad tests * Disabled Myriad tests * Fix tests * Fixed evaluate * Fixed comments * FIxed tests * Fixed tests * Fixed code style * Fixed Myriad tests * Added more tests
This commit is contained in:
parent
ac1803c3ad
commit
d56cf51c81
@ -82,7 +82,7 @@ INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticTopKPropagationConcatBased,
|
|||||||
class DynamicToStaticTopKPropagationConcatReshape : public DynamicToStaticTopKPropagationConcatBased {
|
class DynamicToStaticTopKPropagationConcatReshape : public DynamicToStaticTopKPropagationConcatBased {
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<ngraph::Node> buildSubgraph(std::shared_ptr<ngraph::Node> node) const override {
|
std::shared_ptr<ngraph::Node> buildSubgraph(std::shared_ptr<ngraph::Node> node) const override {
|
||||||
return std::make_shared<ngraph::opset5::Reshape>(node, ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}), false);
|
return std::make_shared<ngraph::opset5::Reshape>(node, ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1}), false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -64,6 +64,12 @@ namespace ngraph
|
|||||||
bool m_special_zero;
|
bool m_special_zero;
|
||||||
bool evaluate_reshape(const HostTensorVector& outputs,
|
bool evaluate_reshape(const HostTensorVector& outputs,
|
||||||
const HostTensorVector& inputs) const;
|
const HostTensorVector& inputs) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void calculate_output_shape(std::vector<Dimension>& reshape_pattern,
|
||||||
|
const int64_t& minus_one_idx,
|
||||||
|
const PartialShape& input_pshape,
|
||||||
|
std::vector<Dimension>& output_shape) const;
|
||||||
};
|
};
|
||||||
} // namespace v1
|
} // namespace v1
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -35,158 +35,12 @@ namespace reshapeop
|
|||||||
{
|
{
|
||||||
using T = typename element_type_traits<ET>::value_type;
|
using T = typename element_type_traits<ET>::value_type;
|
||||||
T* shape_pattern_ptr = shape_pattern->get_data_ptr<ET>();
|
T* shape_pattern_ptr = shape_pattern->get_data_ptr<ET>();
|
||||||
size_t output_rank = shape_pattern->get_shape()[0];
|
size_t output_rank = shape_pattern->get_shape().empty() ? 0 : shape_pattern->get_shape()[0];
|
||||||
for (size_t i = 0; i < output_rank; i++)
|
for (size_t i = 0; i < output_rank; i++)
|
||||||
{
|
{
|
||||||
output_shape.push_back(shape_pattern_ptr[i]);
|
output_shape.push_back(shape_pattern_ptr[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void calculate_output_shape(const ngraph::op::v1::Reshape* reshape_node,
|
|
||||||
vector<Dimension>& reshape_pattern,
|
|
||||||
const int64_t& minus_one_idx,
|
|
||||||
const PartialShape& input_pshape,
|
|
||||||
vector<Dimension>& output_shape)
|
|
||||||
{
|
|
||||||
if (reshape_pattern == std::vector<Dimension>{0} && !reshape_node->get_special_zero())
|
|
||||||
{ // legacy check introduced by PR #1206
|
|
||||||
reshape_pattern = std::vector<Dimension>{};
|
|
||||||
output_shape = {};
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Dimension output_product(1);
|
|
||||||
for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i)
|
|
||||||
{
|
|
||||||
if (i == minus_one_idx) // resolving everything except -1
|
|
||||||
continue;
|
|
||||||
|
|
||||||
auto pattern_dim = reshape_pattern[i];
|
|
||||||
if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 &&
|
|
||||||
reshape_node->get_special_zero())
|
|
||||||
{
|
|
||||||
if (input_pshape.rank().is_dynamic())
|
|
||||||
{
|
|
||||||
output_shape[i] = Dimension::dynamic();
|
|
||||||
output_product *= Dimension::dynamic();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
NODE_VALIDATION_CHECK(reshape_node,
|
|
||||||
i < input_pshape.rank().get_length(),
|
|
||||||
"'0' dimension is out of range");
|
|
||||||
output_shape[i] = input_pshape[i];
|
|
||||||
// we do not include dimension to output product here and won't include in input
|
|
||||||
// product later because we will divide output_product by input_product. This
|
|
||||||
// dimension contributes to both products equally, but in case this dimension
|
|
||||||
// is dynamic and others are not we could fully define output dimension that
|
|
||||||
// is masked by -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
output_shape[i] = pattern_dim;
|
|
||||||
output_product *= pattern_dim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Dimension input_product(1);
|
|
||||||
if (input_pshape.rank().is_static())
|
|
||||||
for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i)
|
|
||||||
{
|
|
||||||
if (i < static_cast<int64_t>(reshape_pattern.size()) &&
|
|
||||||
reshape_pattern[i].get_min_length() == 0 &&
|
|
||||||
reshape_pattern[i].get_max_length() == 0)
|
|
||||||
continue;
|
|
||||||
input_product *= input_pshape[i];
|
|
||||||
}
|
|
||||||
else
|
|
||||||
input_product = Dimension::dynamic();
|
|
||||||
|
|
||||||
if (minus_one_idx != -1) // resolving -1 masked dimension
|
|
||||||
{
|
|
||||||
if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0)
|
|
||||||
{
|
|
||||||
// TODO: Decide if this is desired behavior here. (NumPy seems
|
|
||||||
// to fail.)
|
|
||||||
NODE_VALIDATION_CHECK(reshape_node,
|
|
||||||
input_product.get_min_length() == 0 &&
|
|
||||||
input_product.get_max_length() == 0,
|
|
||||||
"Cannot infer '-1' dimension with zero-size output "
|
|
||||||
"dimension unless at least one input dimension is "
|
|
||||||
"also zero-size");
|
|
||||||
output_shape[minus_one_idx] = Dimension(0);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (input_product.is_static() && output_product.is_static())
|
|
||||||
{
|
|
||||||
NODE_VALIDATION_CHECK(
|
|
||||||
reshape_node,
|
|
||||||
input_product.get_length() % output_product.get_length() == 0,
|
|
||||||
"Non-'-1' output dimensions do not evenly divide the input dimensions");
|
|
||||||
}
|
|
||||||
if (output_product.get_min_length() == 0 || output_product == Dimension() ||
|
|
||||||
input_product == Dimension())
|
|
||||||
{
|
|
||||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
Dimension::value_type lower;
|
|
||||||
if (input_product.get_min_length() == 0)
|
|
||||||
lower = 0;
|
|
||||||
else if (input_product.get_min_length() == -1 ||
|
|
||||||
output_product.get_max_length() == 0 ||
|
|
||||||
output_product.get_max_length() == -1)
|
|
||||||
lower = -1; // dynamic
|
|
||||||
else
|
|
||||||
lower = static_cast<Dimension::value_type>(
|
|
||||||
ceil(static_cast<double>(input_product.get_min_length()) /
|
|
||||||
output_product.get_max_length()));
|
|
||||||
|
|
||||||
Dimension::value_type upper;
|
|
||||||
if (input_product.get_max_length() == 0)
|
|
||||||
upper = 0;
|
|
||||||
else if (input_product.get_max_length() == -1 ||
|
|
||||||
output_product.get_min_length() == 0 ||
|
|
||||||
output_product.get_min_length() == -1)
|
|
||||||
upper = -1; // dynamic
|
|
||||||
else
|
|
||||||
upper = static_cast<Dimension::value_type>(
|
|
||||||
floor(static_cast<double>(input_product.get_max_length()) /
|
|
||||||
output_product.get_min_length()));
|
|
||||||
|
|
||||||
if (lower == -1)
|
|
||||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
|
||||||
else if (upper == -1)
|
|
||||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
|
||||||
else if (lower > upper) // empty intersection
|
|
||||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
|
||||||
else
|
|
||||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
PartialShape output_pshape(output_shape);
|
|
||||||
if (input_pshape.is_static() && output_pshape.is_static())
|
|
||||||
{
|
|
||||||
size_t zero_dims =
|
|
||||||
std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
|
|
||||||
return dim.get_max_length() == 0 && dim.get_min_length() == 0;
|
|
||||||
});
|
|
||||||
|
|
||||||
bool backward_compatible_check =
|
|
||||||
(zero_dims && reshape_node->get_special_zero()) || minus_one_idx != -1;
|
|
||||||
bool in_out_elements_equal = shape_size(reshape_node->get_input_shape(0)) ==
|
|
||||||
shape_size(output_pshape.to_shape());
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(reshape_node,
|
|
||||||
backward_compatible_check || in_out_elements_equal,
|
|
||||||
"Requested output shape ",
|
|
||||||
output_shape,
|
|
||||||
" is incompatible with input shape ",
|
|
||||||
reshape_node->get_input_shape(0));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace reshapeop
|
} // namespace reshapeop
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(op::v1::Reshape, "Reshape", 1);
|
NGRAPH_RTTI_DEFINITION(op::v1::Reshape, "Reshape", 1);
|
||||||
@ -204,7 +58,6 @@ bool op::v1::Reshape::visit_attributes(AttributeVisitor& visitor)
|
|||||||
visitor.on_attribute("special_zero", m_special_zero);
|
visitor.on_attribute("special_zero", m_special_zero);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void op::v1::Reshape::validate_and_infer_types()
|
void op::v1::Reshape::validate_and_infer_types()
|
||||||
{
|
{
|
||||||
NGRAPH_OP_SCOPE(v1_Reshape_validate_and_infer_types);
|
NGRAPH_OP_SCOPE(v1_Reshape_validate_and_infer_types);
|
||||||
@ -217,16 +70,21 @@ void op::v1::Reshape::validate_and_infer_types()
|
|||||||
const PartialShape& input_pshape = get_input_partial_shape(0);
|
const PartialShape& input_pshape = get_input_partial_shape(0);
|
||||||
const PartialShape& shape_pattern_shape = get_input_partial_shape(1);
|
const PartialShape& shape_pattern_shape = get_input_partial_shape(1);
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
shape_pattern_shape.rank().compatible(1),
|
shape_pattern_shape.rank().compatible(1) ||
|
||||||
"Pattern shape must have rank 1, got ",
|
(shape_pattern_shape.rank().is_static() &&
|
||||||
|
shape_pattern_shape.rank().get_length() == 0),
|
||||||
|
"Pattern shape must have rank 1 or be empty, got ",
|
||||||
shape_pattern_shape.rank(),
|
shape_pattern_shape.rank(),
|
||||||
".");
|
".");
|
||||||
Rank output_rank =
|
Rank output_rank =
|
||||||
shape_pattern_shape.rank().is_dynamic() ? Rank::dynamic() : shape_pattern_shape[0];
|
shape_pattern_shape.rank().is_dynamic()
|
||||||
|
? Rank::dynamic()
|
||||||
|
: shape_pattern_shape.rank().get_length() == 0 ? 0 : shape_pattern_shape[0];
|
||||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
|
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
|
||||||
set_input_is_relevant_to_shape(1);
|
set_input_is_relevant_to_shape(1);
|
||||||
|
|
||||||
std::vector<Dimension> reshape_pattern;
|
std::vector<Dimension> reshape_pattern;
|
||||||
|
bool shape_can_be_calculated = false;
|
||||||
int64_t minus_one_idx = -1;
|
int64_t minus_one_idx = -1;
|
||||||
|
|
||||||
HostTensorPtr lb, ub;
|
HostTensorPtr lb, ub;
|
||||||
@ -235,6 +93,7 @@ void op::v1::Reshape::validate_and_infer_types()
|
|||||||
{
|
{
|
||||||
const auto lower_bound = std::make_shared<op::Constant>(lb)->cast_vector<int64_t>();
|
const auto lower_bound = std::make_shared<op::Constant>(lb)->cast_vector<int64_t>();
|
||||||
const auto upper_bound = std::make_shared<op::Constant>(ub)->cast_vector<int64_t>();
|
const auto upper_bound = std::make_shared<op::Constant>(ub)->cast_vector<int64_t>();
|
||||||
|
shape_can_be_calculated = true;
|
||||||
NGRAPH_CHECK(lower_bound.size() == upper_bound.size());
|
NGRAPH_CHECK(lower_bound.size() == upper_bound.size());
|
||||||
for (size_t i = 0; i < lower_bound.size(); ++i)
|
for (size_t i = 0; i < lower_bound.size(); ++i)
|
||||||
{
|
{
|
||||||
@ -250,13 +109,22 @@ void op::v1::Reshape::validate_and_infer_types()
|
|||||||
}
|
}
|
||||||
reshape_pattern.emplace_back(lower_bound[i], upper_bound[i]);
|
reshape_pattern.emplace_back(lower_bound[i], upper_bound[i]);
|
||||||
}
|
}
|
||||||
|
// For scalar case reshape_patter should be empty but scalar reshape pattern should be empty
|
||||||
|
// or equal to 1
|
||||||
|
if (output_rank.is_static() && output_rank.get_length() == 0 && !lower_bound.empty())
|
||||||
|
{
|
||||||
|
reshape_pattern.clear();
|
||||||
|
NGRAPH_CHECK(lower_bound.size() == 1);
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
lower_bound[0] == 1 && upper_bound[0] == 1,
|
||||||
|
"The value of scalar shape pattern should be equal to 1!");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!reshape_pattern.empty())
|
if (shape_can_be_calculated)
|
||||||
{
|
{
|
||||||
std::vector<Dimension> output_shape(output_rank.get_length());
|
std::vector<Dimension> output_shape(output_rank.get_length());
|
||||||
reshapeop::calculate_output_shape(
|
calculate_output_shape(reshape_pattern, minus_one_idx, input_pshape, output_shape);
|
||||||
this, reshape_pattern, minus_one_idx, input_pshape, output_shape);
|
|
||||||
set_output_type(0, get_input_element_type(0), output_shape);
|
set_output_type(0, get_input_element_type(0), output_shape);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -311,8 +179,8 @@ bool op::v1::Reshape::evaluate_reshape(const HostTensorVector& outputs,
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Dimension> output_shape(out_shape_val.size());
|
std::vector<Dimension> output_shape(out_shape_val.size());
|
||||||
reshapeop::calculate_output_shape(
|
calculate_output_shape(
|
||||||
this, reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape);
|
reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape);
|
||||||
NGRAPH_CHECK(PartialShape(output_shape).is_static());
|
NGRAPH_CHECK(PartialShape(output_shape).is_static());
|
||||||
outputs[0]->set_shape(PartialShape(output_shape).to_shape());
|
outputs[0]->set_shape(PartialShape(output_shape).to_shape());
|
||||||
|
|
||||||
@ -390,3 +258,140 @@ bool op::v1::Reshape::constant_fold(OutputVector& output_values, const OutputVec
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void op::v1::Reshape::calculate_output_shape(vector<Dimension>& reshape_pattern,
|
||||||
|
const int64_t& minus_one_idx,
|
||||||
|
const PartialShape& input_pshape,
|
||||||
|
vector<Dimension>& output_shape) const
|
||||||
|
{
|
||||||
|
Dimension output_product(1);
|
||||||
|
for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i)
|
||||||
|
{
|
||||||
|
if (i == minus_one_idx) // resolving everything except -1
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto pattern_dim = reshape_pattern[i];
|
||||||
|
if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 &&
|
||||||
|
get_special_zero())
|
||||||
|
{
|
||||||
|
if (input_pshape.rank().is_dynamic())
|
||||||
|
{
|
||||||
|
output_shape[i] = Dimension::dynamic();
|
||||||
|
output_product *= Dimension::dynamic();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
NODE_VALIDATION_CHECK(
|
||||||
|
this, i < input_pshape.rank().get_length(), "'0' dimension is out of range");
|
||||||
|
output_shape[i] = input_pshape[i];
|
||||||
|
// we do not include dimension to output product here and won't include in input
|
||||||
|
// product later because we will divide output_product by input_product. This
|
||||||
|
// dimension contributes to both products equally, but in case this dimension
|
||||||
|
// is dynamic and others are not we could fully define output dimension that
|
||||||
|
// is masked by -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
output_shape[i] = pattern_dim;
|
||||||
|
output_product *= pattern_dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Dimension input_product(1);
|
||||||
|
if (input_pshape.rank().is_static())
|
||||||
|
for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i)
|
||||||
|
{
|
||||||
|
if (i < static_cast<int64_t>(reshape_pattern.size()) &&
|
||||||
|
reshape_pattern[i].get_min_length() == 0 &&
|
||||||
|
reshape_pattern[i].get_max_length() == 0)
|
||||||
|
continue;
|
||||||
|
input_product *= input_pshape[i];
|
||||||
|
}
|
||||||
|
else
|
||||||
|
input_product = Dimension::dynamic();
|
||||||
|
|
||||||
|
if (minus_one_idx != -1) // resolving -1 masked dimension
|
||||||
|
{
|
||||||
|
if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0)
|
||||||
|
{
|
||||||
|
// TODO: Decide if this is desired behavior here. (NumPy seems
|
||||||
|
// to fail.)
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
input_product.get_min_length() == 0 &&
|
||||||
|
input_product.get_max_length() == 0,
|
||||||
|
"Cannot infer '-1' dimension with zero-size output "
|
||||||
|
"dimension unless at least one input dimension is "
|
||||||
|
"also zero-size");
|
||||||
|
output_shape[minus_one_idx] = Dimension(0);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (input_product.is_static() && output_product.is_static())
|
||||||
|
{
|
||||||
|
NODE_VALIDATION_CHECK(
|
||||||
|
this,
|
||||||
|
input_product.get_length() % output_product.get_length() == 0,
|
||||||
|
"Non-'-1' output dimensions do not evenly divide the input dimensions");
|
||||||
|
}
|
||||||
|
if (output_product.get_min_length() == 0 || output_product == Dimension() ||
|
||||||
|
input_product == Dimension())
|
||||||
|
{
|
||||||
|
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Dimension::value_type lower;
|
||||||
|
if (input_product.get_min_length() == 0)
|
||||||
|
lower = 0;
|
||||||
|
else if (input_product.get_min_length() == -1 ||
|
||||||
|
output_product.get_max_length() == 0 ||
|
||||||
|
output_product.get_max_length() == -1)
|
||||||
|
lower = -1; // dynamic
|
||||||
|
else
|
||||||
|
lower = static_cast<Dimension::value_type>(
|
||||||
|
ceil(static_cast<double>(input_product.get_min_length()) /
|
||||||
|
output_product.get_max_length()));
|
||||||
|
|
||||||
|
Dimension::value_type upper;
|
||||||
|
if (input_product.get_max_length() == 0)
|
||||||
|
upper = 0;
|
||||||
|
else if (input_product.get_max_length() == -1 ||
|
||||||
|
output_product.get_min_length() == 0 ||
|
||||||
|
output_product.get_min_length() == -1)
|
||||||
|
upper = -1; // dynamic
|
||||||
|
else
|
||||||
|
upper = static_cast<Dimension::value_type>(
|
||||||
|
floor(static_cast<double>(input_product.get_max_length()) /
|
||||||
|
output_product.get_min_length()));
|
||||||
|
|
||||||
|
if (lower == -1)
|
||||||
|
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||||
|
else if (upper == -1)
|
||||||
|
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||||
|
else if (lower > upper) // empty intersection
|
||||||
|
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||||
|
else
|
||||||
|
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PartialShape output_pshape(output_shape);
|
||||||
|
if (input_pshape.is_static() && output_pshape.is_static())
|
||||||
|
{
|
||||||
|
size_t zero_dims =
|
||||||
|
std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
|
||||||
|
return dim.get_max_length() == 0 && dim.get_min_length() == 0;
|
||||||
|
});
|
||||||
|
|
||||||
|
bool backward_compatible_check = (zero_dims && get_special_zero()) || minus_one_idx != -1;
|
||||||
|
bool in_out_elements_equal =
|
||||||
|
shape_size(get_input_shape(0)) == shape_size(output_pshape.to_shape());
|
||||||
|
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
backward_compatible_check || in_out_elements_equal,
|
||||||
|
"Requested output shape ",
|
||||||
|
output_shape,
|
||||||
|
" is incompatible with input shape ",
|
||||||
|
get_input_shape(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -331,6 +331,37 @@ NGRAPH_TEST(${BACKEND_NAME}, builder_reshape_3D_to_scalar)
|
|||||||
test_case.run();
|
test_case.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, builder_reshape_1d_to_same_shape)
|
||||||
|
{
|
||||||
|
const Shape input_shape{1};
|
||||||
|
auto param = make_shared<op::Parameter>(element::f32, input_shape);
|
||||||
|
auto r = make_shared<op::v1::Reshape>(
|
||||||
|
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
|
||||||
|
auto function = make_shared<Function>(r, ParameterVector{param});
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
vector<float> input_values(shape_size(input_shape), 1.f);
|
||||||
|
test_case.add_input<float>(input_shape, input_values);
|
||||||
|
test_case.add_expected_output<float>(Shape{}, vector<float>{1.f});
|
||||||
|
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, builder_reshape_to_same_shape)
|
||||||
|
{
|
||||||
|
const Shape input_shape{};
|
||||||
|
auto param = make_shared<op::Parameter>(element::f32, input_shape);
|
||||||
|
auto r = make_shared<op::v1::Reshape>(
|
||||||
|
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
|
||||||
|
auto function = make_shared<Function>(r, ParameterVector{param});
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
vector<float> input_values(shape_size(input_shape), 1.f);
|
||||||
|
test_case.add_input<float>(input_shape, input_values);
|
||||||
|
test_case.add_expected_output<float>(Shape{}, vector<float>{1.f});
|
||||||
|
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
#if NGRAPH_INTERPRETER_ENABLE
|
#if NGRAPH_INTERPRETER_ENABLE
|
||||||
|
|
||||||
NGRAPH_TEST(${BACKEND_NAME}, reshape_shufflenet_5d)
|
NGRAPH_TEST(${BACKEND_NAME}, reshape_shufflenet_5d)
|
||||||
|
@ -579,3 +579,66 @@ TEST(type_prop, reshape_multiply_intervals_by_interval_zero_included)
|
|||||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(0, 24)}));
|
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(0, 24)}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, reshape_to_zero_shape)
|
||||||
|
{
|
||||||
|
auto param = make_shared<op::Parameter>(element::f32, Shape{0, 1});
|
||||||
|
auto r = make_shared<op::v1::Reshape>(
|
||||||
|
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), false);
|
||||||
|
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||||
|
ASSERT_EQ(r->get_output_shape(0), (Shape{0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, reshape_to_zero_shape_dynamic)
|
||||||
|
{
|
||||||
|
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
auto r = make_shared<op::v1::Reshape>(
|
||||||
|
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), false);
|
||||||
|
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||||
|
ASSERT_EQ(r->get_output_shape(0), (Shape{0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, reshape_to_zero_shape_incorrect)
|
||||||
|
{
|
||||||
|
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 1});
|
||||||
|
ASSERT_THROW(
|
||||||
|
make_shared<op::v1::Reshape>(
|
||||||
|
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), false),
|
||||||
|
std::exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, reshape_to_zero)
|
||||||
|
{
|
||||||
|
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 1});
|
||||||
|
auto r = make_shared<op::v1::Reshape>(
|
||||||
|
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), true);
|
||||||
|
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||||
|
ASSERT_EQ(r->get_output_shape(0), (Shape{2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, reshape_to_scalar)
|
||||||
|
{
|
||||||
|
auto param = make_shared<op::Parameter>(element::f32, Shape{});
|
||||||
|
auto r = make_shared<op::v1::Reshape>(
|
||||||
|
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
|
||||||
|
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||||
|
ASSERT_EQ(r->get_output_shape(0), (Shape{}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, reshape_to_scalar_2)
|
||||||
|
{
|
||||||
|
auto param = make_shared<op::Parameter>(element::f32, Shape{});
|
||||||
|
auto r = make_shared<op::v1::Reshape>(
|
||||||
|
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
|
||||||
|
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||||
|
ASSERT_EQ(r->get_output_shape(0), (Shape{}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, reshape_to_scalar_3)
|
||||||
|
{
|
||||||
|
auto param = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
|
||||||
|
ASSERT_THROW(
|
||||||
|
make_shared<op::v1::Reshape>(
|
||||||
|
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{100}), false),
|
||||||
|
std::exception);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user